(ns orcl.analyzer
  (:require [orcl.utils.cursor :as cursor]
            [orcl.utils :as utils]
            [clojure.set :as set]
            [orcl.analyzer.vars :as vars]
            [orcl.analyzer.patterns :as patterns]
    #?(:clj
            [orcl.analyzer.macro :as macro]))
  #?(:cljs (:require-macros [orcl.analyzer.macro :as macro])))

(defn primitive?
  "Truthy when `n` is a leaf value node: returns the node's `:node` keyword
  (`:const` or `:var`) in that case, nil otherwise."
  [n]
  (let [kind (:node n)]
    (when (or (= kind :const) (= kind :var))
      kind)))

(defn deflate-values
  "Hoists non-primitive operands of the node behind cursor `node` into
  prunings.

  The first argument is a sequence of cursors pointing at operand positions.
  Each operand that is not `primitive?` is overwritten in place
  (`cursor/reset!`) with a `:var` node naming the symbol bound by
  `macro/with-fresh` (presumably a fresh, unique name). The result is then
  wrapped in a `:pruning` that binds that name to the original operand.
  Primitive operands are left untouched. Prunings nest in cursor order, so
  the first hoisted operand is the outermost. The innermost `:left` is the
  dereferenced `node`, taken after all resets have been applied."
  [[cur & more] node]
  (if (nil? cur)
    @node
    (if (primitive? @cur)
      (recur more node)
      (macro/with-fresh fresh
                        (let [operand @cur
                              ;; replace the operand before deflating the rest so the
                              ;; final @node sees every substitution
                              _       (cursor/reset! cur {:node :var :var fresh})
                              inner   (deflate-values more node)]
                          {:node    :pruning
                           :pattern {:type :var :var fresh}
                           :left    inner
                           :right   operand})))))

(defn bindings
  "Returns the set of variable names occurring free in `ast`.

  Only `:var`, `:pruning` and `:sequential` nodes are handled; any other node
  type throws from `case`. A pruning's pattern scopes over its `:left`
  branch and a sequential's pattern over its `:right` branch, mirroring
  `analyze-env`. The names bound by the pattern (`patterns/pattern-bindings`)
  are therefore removed from that branch only."
  [ast]
  (case (:node ast)
    :var #{(:var ast)}
    ;; Previously `(set/difference (:pattern ast))`: a one-argument call that
    ;; returns the raw pattern map, so no bound name was ever removed.
    :pruning (set/union (set/difference (bindings (:left ast))
                                        (patterns/pattern-bindings (:pattern ast)))
                        (bindings (:right ast)))
    :sequential (set/union (bindings (:left ast))
                           (set/difference (bindings (:right ast))
                                           (patterns/pattern-bindings (:pattern ast))))))


;; TODO site imports & type datastructures
(defn normalize-declarations
  "Turns a `:declarations` node (`:decls` plus body `:expr`) into core nodes.

  Declarations are consumed left to right:
  - `:include` splices its own `:decls` in front of the remaining ones.
  - `:def` clauses are grouped by `:name`; each clause appends an instance
    (`:params`, `:body`, `:guard`) to the def of that name.
  - `:def-sig` is skipped.
  - `:val` ends the scan. It becomes a `:pruning` that binds `:pattern` to
    the declaration's `:expr` (`:right`). Its `:left` is either a new
    `:declarations` node holding the remaining declarations, or the body
    `expr` when none remain. If any defs were accumulated before it, the
    pruning is wrapped in a `:defs-group` with them.
  When the declarations run out, the accumulated defs (possibly none) wrap
  `expr` in a `:defs-group`.

  Throws ex-info (`:orcl/error`, `:orcl/error-pos`, `:declaration-type`) for
  any other declaration type, e.g. the site imports and type declarations
  from the TODO above."
  [{:keys [decls expr]}]
  (loop [[d & remaining] decls
         defs            {}]
    (if d
      (case (:type d)
        :include (recur (concat (:decls d) remaining) defs)
        :val (let [n {:node    :pruning
                      :pattern (:pattern d)
                      ;; later declarations stay nested so the prewalk in
                      ;; `normalize` normalizes them under this binding
                      :left    (if (seq remaining)
                                 {:node  :declarations
                                  :decls remaining
                                  :expr  expr}
                                 expr)
                      :right   (:expr d)}]
               (if (seq defs)
                 {:node :defs-group
                  :defs (vals defs)
                  :expr n}
                 n))
        :def (let [inst     {:params (:params d)
                             :body   (:body d)
                             :guard  (:guard d)}
                   existing (get defs (:name d)
                                 {:name      (:name d)
                                  :instances []})]
               (recur remaining (assoc defs (:name d) (update existing :instances conj inst))))
        :def-sig (recur remaining defs)
        (throw (ex-info "Unsupported declaration" {:orcl/error-pos   (:pos d)
                                                   :orcl/error       "Unsupported declaration"
                                                   :declaration-type (:type d)})))
      {:node :defs-group
       :defs (vals defs)
       :expr expr})))

(declare normalize)

(defn normalize*
  "Rewrites a single AST node into normalized form. Children are visited by
  the `utils/ast-prewalk` in `normalize`.

  - `:declarations` are lowered by `normalize-declarations`.
  - `:lambda` becomes a one-def `:defs-group` whose expression references the
    generated def. The def is named `__def_` plus the SHA of the body.
  - `:list`/`:tuple`, `:record`, `:call` and `:field-access` have their
    non-primitive operands hoisted into prunings by `deflate-values`.
  - `:conditional` becomes a `:defs-group` of zero-parameter `then`/`else`
    defs, plus a `:sequential` that binds the condition's value to a
    variable consumed by a `:normalized-conditional` node.
  Any other node is returned unchanged."
  [ast]
  (let [var-ref (fn [var-name] {:node :var :var var-name})]
    (case (:node ast)
      :declarations
      (normalize-declarations ast)

      :lambda
      (let [body     (:body ast)
            def-name (str "__def_" (utils/sha body))]
        {:node :defs-group
         :defs [{:name      def-name
                 :instances [{:guard  (:guard ast)
                              :params (:params ast)
                              :body   body}]}]
         :expr (var-ref def-name)})

      (:list :tuple)
      (macro/as-cursor [c ast] (deflate-values (seq (:values c)) c))

      :record
      (macro/as-cursor [c ast] (deflate-values (map second (:pairs c)) c))

      :call
      (macro/as-cursor [c ast] (deflate-values (concat [(:target c)] (:args c)) c))

      :field-access
      (macro/as-cursor [c ast] (deflate-values [(:target c)] c))

      :conditional
      (let [{:keys [then else]} ast
            cond-var  (str "__v_" (utils/sha ast))
            then-name (str "__then_" (utils/sha then))
            else-name (str "__else_" (utils/sha else))
            thunk     (fn [thunk-name thunk-body]
                        {:name      thunk-name
                         :instances [{:params []
                                      :body   thunk-body}]})]
        {:node :defs-group
         :defs [(thunk then-name then)
                (thunk else-name else)]
         :expr {:node    :sequential
                :pattern {:type :var
                          :var  cond-var}
                :left    (:if ast)
                :right   {:node :normalized-conditional
                          :if   (var-ref cond-var)
                          :then (var-ref then-name)
                          :else (var-ref else-name)}}})

      ast)))

(defn normalize
  "Normalizes a whole AST by applying `normalize*` across the tree with
  `utils/ast-prewalk`."
  [ast]
  (->> ast
       (utils/ast-prewalk normalize*)))

(defn with-sha
  "Applies `utils/with-sha` to every node of `ast` via `utils/ast-postwalk`.
  The `:sha` entries read later by `analyze-env` presumably come from here;
  confirm against `orcl.utils`."
  [ast]
  (->> ast
       (utils/ast-postwalk utils/with-sha)))

(declare analyze-env)

(defn flatten-def
  "Peels consecutive `:declaration` nodes whose `:decl` has `:type :def` off
  the front of `ast`.

  Returns `[rest defs]`. `rest` is the first node that is not such a
  declaration. `defs` is `acc` (default `[]`) with the peeled `:decl` maps
  appended in order."
  ([ast] (flatten-def ast []))
  ([ast acc]
   (loop [node  ast
          found acc]
     (if (and (= :declaration (:node node))
              (= :def (:type (:decl node))))
       (recur (:expr node) (conj found (:decl node)))
       [node found]))))

;; A call counts as being in tail position only if it is not inside
;; - either branch of a sequential combinator,
;; - the right branch of a pruning,
;; - either branch of an otherwise,
;; - either branch of a parallel combinator.
;; `analyze-env` enforces this by binding `*tail-pos*` to nil in those places.
(def ^:dynamic *tail-pos*
  "While analyzing a def instance body: `{:id <def sha>}`, identifying the def
  whose self-calls in tail position get marked. nil otherwise. The root
  value is nil so code running outside any `binding` sees a falsey value
  rather than an (always truthy) Unbound var object."
  nil)

(defn analyze-instance
  "Analyzes one instance (clause) of the def identified by `id`.

  Every name bound by the instance's parameter patterns is brought into
  scope as an `:argument` entry. The entry records the parameter's
  position, the `:sha` of the instance body and the def `id`. The body is
  then analyzed with `*tail-pos*` bound to `{:id id}`, so calls back to this
  def in tail position can be marked. Returns `instance` with its `:body`
  analyzed."
  [id instance]
  (macro/with-envs (into {}
                         (for [[position param] (map-indexed vector (:params instance))
                               bound-name      (patterns/pattern-bindings param)]
                           [bound-name {:type         :argument
                                        :position     position
                                        :instance-sha (:sha (:body instance))
                                        :id           id}]))
    (binding [*tail-pos* {:id id}]
      (update instance :body analyze-env))))

(defn analyze-def
  "Analyzes one member `node` of a defs group.

  `defs` is the whole prepared group (as built by `analyze-defs`). Every
  member's `:name` is brought into scope as a `:def` entry with its `:sha`
  as `:id` and its `:usages` counter, so instances can refer to themselves
  and to their siblings. Returns `node` with `:arity` (the parameter count
  of its first instance) and its `:instances` analyzed by `analyze-instance`
  under the node's own `:sha`."
  [defs {:keys [sha instances] :as node}]
  ;; Group members are destructured under distinct names so they cannot
  ;; shadow `node` or its `sha` (the old code destructured a bogus `:node`
  ;; key here and rebound `name`/`sha`/`usages`).
  (macro/with-envs (into {} (for [{member-name   :name
                                   member-sha    :sha
                                   member-usages :usages} defs]
                              [member-name {:type   :def
                                            :id     member-sha
                                            :usages member-usages}]))
    (assoc node
      :arity (count (:params (first instances)))
      :instances (mapv (partial analyze-instance sha) instances))))

(defn analyze-defs
  "Analyzes every definition of a defs group.

  Each def is first given a fresh `:usages` counter (an atom starting at 0)
  and a `:sha` computed from the def as received. Every member is then
  passed to `analyze-def` together with the whole prepared group. Returns
  a vector of analyzed defs."
  [defs]
  (let [prepare  (fn [d]
                   (assoc d
                     :usages (atom 0)
                     :sha (utils/sha d)))
        prepared (map prepare defs)]
    (mapv (fn [d] (analyze-def prepared d)) prepared)))

;; TODO check target & arity
(defn check-call!
  "Validation hook for an analyzed `:call` node. Not implemented yet: it
  performs no checks and always returns nil."
  [call]
  nil)

;:site (macro/with-env (:name decl) {:type       :site
;                                    :source     {:type :import :pos (:pos decl)}
;                                    :definition (:definition decl)}
;        (analyze-env (:expr ast)))

(defn- analyze-non-tail
  "Runs `analyze-env` on `ast` with `*tail-pos*` bound to nil, i.e. treats
  `ast` as a subexpression that is not in tail position."
  [ast]
  (binding [*tail-pos* nil]
    (analyze-env ast)))

(defn analyze-env
  "Resolves variables and marks self tail calls in `ast`; returns the
  annotated AST.

  - `:var` nodes get `:source`, looked up in `vars/*env*` with `:usages`
    removed. A reference to a `:def` increments that def's `:usages` atom.
    Unknown names throw ex-info carrying `:orcl/error-pos`, `:orcl/error`
    and `:variable`.
  - `:pruning` brings its pattern into scope for `:left`; `:sequential`
    brings its pattern into scope for `:right`. Among the combinators, only
    the left branch of a pruning keeps the current `*tail-pos*`. All other
    branches of pruning/sequential/otherwise/parallel are analyzed with it
    cleared.
  - `:defs-group` analyzes its defs (see `analyze-defs`), then its `:expr`
    with the defs in scope.
  - A `:call` whose target resolves to the def named by `*tail-pos*` gets a
    `:tail-pos` entry.
  Any other node is walked with `utils/ast-walk`."
  [ast]
  (case (:node ast)
    :pruning
    (let [left  (macro/with-pattern (:pattern ast) {:type     :pruning
                                                    :node-sha (:sha (:right ast))}
                  (analyze-env (:left ast)))
          right (analyze-non-tail (:right ast))]
      (assoc ast :left left :right right))

    :sequential
    (let [right (macro/with-pattern (:pattern ast) {:type     :sequential
                                                    :node-sha (:sha (:left ast))}
                  (analyze-non-tail (:right ast)))
          left  (analyze-non-tail (:left ast))]
      (assoc ast :right right :left left))

    :otherwise
    (let [right (analyze-non-tail (:right ast))
          left  (analyze-non-tail (:left ast))]
      (assoc ast :right right :left left))

    :parallel
    (let [left  (analyze-non-tail (:left ast))
          right (analyze-non-tail (:right ast))]
      (assoc ast :left left :right right))

    :defs-group
    (let [defs (analyze-defs (:defs ast))]
      (macro/with-envs (into {} (map (fn [{:keys [name sha usages]}]
                                       [name {:type   :def
                                              :id     sha
                                              :usages usages}])
                                     defs))
        (assoc ast
          :defs defs
          :expr (analyze-env (:expr ast)))))

    :call
    (let [analyzed (-> ast
                       (update :target analyze-env)
                       (assoc :args (mapv analyze-env (:args ast))))
          _        (check-call! analyzed)
          source   (get-in analyzed [:target :source])]
      (cond-> analyzed
        (and *tail-pos*
             (= :def (:type source))
             (= (:id source) (:id *tail-pos*)))
        (assoc :tail-pos *tail-pos*)))

    :var
    (let [source (get vars/*env* (:var ast))]
      (when-not source
        ;; "Undefined variable" (:var ast)
        (throw (ex-info "Undefined variable" {:orcl/error-pos (:pos ast)
                                              :orcl/error     "Undefined variable"
                                              :variable       (:var ast)})))
      (when (= :def (:type source))
        (swap! (:usages source) inc))
      (assoc ast :source (dissoc source :usages)))

    (utils/ast-walk analyze-env identity ast)))

(defn def-instance-locals
  "Locals of one def instance: the `:locals` of its body minus every name
  bound by any of its parameter patterns."
  [instance]
  (let [param-names (map patterns/pattern-bindings (:params instance))]
    (reduce set/difference (:locals (:body instance)) param-names)))

(defn def-locals
  "Set union of `def-instance-locals` over all `:instances` of a def."
  [def-node]
  (into #{} (mapcat def-instance-locals) (:instances def-node)))

(defn analyze-stage2
  "Second pass over the output of `analyze-env`, applied bottom-up with
  `utils/ast-postwalk`.

  For each node it first drops unused definitions. A `:defs-group` keeps
  only the defs whose `:usages` counter (dereferenced from its atom) is
  positive, and is replaced by its `:expr` when none remain.

  It then attaches `:locals`: the set of referenced variable names whose
  source type is neither `:def` nor `:site`, minus the names bound by the
  node's own pattern. Each def in a `:defs-group` also gets `:locals` (see
  `def-locals`), and the group's `:locals` covers both its defs and its
  `:expr`.

  Node types not listed below throw from `case`."
  [ast]
  (let [node-locals   (fn [ast]
                        (case (:node ast)
                          (:otherwise :parallel) (set/union (:locals (:left ast))
                                                            (:locals (:right ast)))
                          (:sequential :pruning) (set/difference (set/union (:locals (:left ast))
                                                                            (:locals (:right ast)))
                                                                 (patterns/pattern-bindings (:pattern ast)))
                          ;; a :normalized-conditional keeps its condition under :if
                          ;; (the old code read a nonexistent :var key and lost it)
                          :normalized-conditional (set/union (:locals (:if ast))
                                                             (:locals (:then ast))
                                                             (:locals (:else ast)))
                          (:field-access :call) (apply set/union
                                                       (get-in ast [:target :locals])
                                                       (map :locals (:args ast)))
                          (:tuple :list) (apply set/union (map :locals (:values ast)))
                          :record (apply set/union (map (comp :locals second) (:pairs ast)))
                          :var (if (#{:def :site} (get-in ast [:source :type]))
                                 #{}
                                 #{(:var ast)})
                          (:const :stop) #{}))
        set-locals    (fn [ast]
                        (if (= :defs-group (:node ast))
                          (let [defs' (mapv #(assoc % :locals (def-locals %)) (:defs ast))]
                            (assoc ast
                              :defs defs'
                              ;; the body expression's free locals belong to the group
                              ;; too; otherwise enclosing nodes/defs never see them
                              :locals (set/union (set (mapcat :locals defs'))
                                                 (:locals (:expr ast)))))
                          (assoc ast :locals (node-locals ast))))
        remove-unused (fn [ast]
                        (case (:node ast)
                          :defs-group
                          (let [in-use (filterv #(pos? (:usages %)) (map #(update % :usages deref) (:defs ast)))]
                            (if (not-empty in-use)
                              (assoc ast :defs in-use)
                              (:expr ast)))
                          ast))]
    (utils/ast-postwalk (comp set-locals remove-unused) ast)))

(defn analyze-final
  "Runs `analyze-env` and then `analyze-stage2` on `ast`. `analyze` calls it
  with the output of `(with-sha (normalize ast))`.

  `env` (default `{}`) becomes the initial `vars/*env*`; `*tail-pos*` starts
  out nil."
  ([ast] (analyze-final ast {}))
  ([ast env]
   (binding [vars/*env* env
             *tail-pos* nil]
     (-> ast
         analyze-env
         analyze-stage2))))

(defn analyze
  "Entry point: normalizes `ast`, annotates it with `with-sha`, and runs
  `analyze-final` with the initial environment `env` (default `{}`)."
  ([ast] (analyze ast {}))
  ([ast env]
   (-> ast
       normalize
       with-sha
       (analyze-final env))))