(ns orcl.analyzer
  (:require [orcl.utils.cursor :as cursor]
            [orcl.utils :as utils]
            [clojure.set :as set]
            [orcl.analyzer.vars :as vars]
            [orcl.typecheck :as typecheck]
            [orcl.analyzer.patterns :as patterns]
            [orcl.analyzer.convert :as convert]
    #?(:clj
            [orcl.analyzer.macro :as macro]))
  #?(:cljs (:require-macros [orcl.analyzer.macro :as macro])))

(defn primitive?
  "Returns :const or :var when `n` is a constant or variable node, nil
  otherwise. deflate-values leaves such nodes in place."
  [n]
  (get #{:const :var} (:node n)))

;; Hoists the non-primitive children of a node into enclosing :pruning nodes.
;; `cursors` is a seq of cursors on the children to inspect, and `node` is a
;; cursor on the node that owns them. Each child that is not a :const/:var
;; (see primitive?) is overwritten, through its cursor, with a reference to a
;; fresh variable. The result is then wrapped as
;;   {:node :pruning, :pattern <fresh var>, :left <rest>, :right <child>}
;; When the cursors are exhausted, the updated node is dereferenced and becomes
;; the innermost :left.
;; NOTE(review): relies on cursor/reset! writes being visible when `node` is
;; dereferenced at the end - confirm against orcl.utils.cursor.
(defn deflate-values
  [[c & cursors] node]
  (cond
    ;; all children processed: return the node with its substitutions
    (nil? c) @node
    ;; constants and variables stay in place
    (primitive? @c) (recur cursors node)
    :else (macro/with-fresh fresh
            (let [orig @c]
              ;; substitute the child before recursing, so the final @node
              ;; already refers to `fresh`
              (cursor/reset! c {:node :var :var fresh})
              {:node    :pruning
               :pattern {:type :var :var fresh}
               :left    (deflate-values cursors node)
               :right   orig}))))

(defn bindings
  "Set of variable names occurring free in `ast`, for ASTs built only from
  :var, :pruning and :sequential nodes (any other node type has no matching
  case clause and throws).

  The names bound by a combinator's pattern are removed only from the branch
  in which the pattern is in scope: the left branch of a :pruning and the
  right branch of a :sequential."
  [ast]
  (let [;; names bound by this node's pattern, as a set
        pattern-vars #(set (patterns/pattern-bindings (:pattern ast)))]
    (case (:node ast)
      :var #{(:var ast)}
      :pruning (set/union (set/difference (bindings (:left ast)) (pattern-vars))
                          (bindings (:right ast)))
      :sequential (set/union (bindings (:left ast))
                             (set/difference (bindings (:right ast)) (pattern-vars))))))

(defn strict-pattern?
  "True unless pattern `p` is a :var or a :wildcard pattern."
  [p]
  (not (contains? #{:wildcard :var} (:type p))))

(defn site
  "Builds a :var node named `n` whose :source describes a site of that same
  name coming from the prelude."
  [n]
  (let [prelude-source {:type       :site
                        :source     {:type :prelude}
                        :definition n}]
    (assoc {:node :var} :var n :source prelude-source)))

;; Prelude site references emitted by the desugarings in this namespace:
;; option wrapping/unwrapping for make-match,
(def site-wrap-some (site "_WrapSome"))
(def site-unwrap-some (site "_UnwrapSome"))
(def site-none (site "_None"))
(def site-is-none (site "_IsNone"))
;; NOTE(review): site-list is not referenced in this file - check other
;; namespaces before removing it.
(def site-list (site "_MakeList"))
;; data-structure construction in translate* (:tuple / :record),
(def site-tuple (site "_MakeTuple"))
(def site-record (site "_MakeRecord"))
;; and the two branches of translate-conditional.
(def site-ift (site "Ift"))
(def site-iff (site "Iff"))

(defn pattern-bindings
  "Sequence of the variable names bound by pattern `p`, in left-to-right
  order. An unknown pattern :type has no matching clause and throws."
  [p]
  (let [collect (fn [sub-patterns] (mapcat pattern-bindings sub-patterns))]
    (case (:type p)
      :var [(:var p)]
      (:wildcard :const) []
      (:list :tuple) (collect (:patterns p))
      :call (collect (:args p))
      :record (collect (map second (:pairs p)))
      :cons (concat (pattern-bindings (:head p))
                    (pattern-bindings (:tail p)))
      :as (cons (:alias p) (pattern-bindings (:pattern p))))))

(defn make-match
  "Builds, with three fresh names res / maybe-res / t, the AST
    ((source >res> _WrapSome(res)) ; _None()) >maybe-res>
      ( _UnwrapSome(maybe-res) >t> (target t)
      | _IsNone(maybe-res) >_> else )
  `target` is a function from the fresh binding name t to the AST to run
  when the wrapped value is unwrapped; `else` is the AST of the other branch."
  [source target else]
  (macro/with-freshs 3 [res maybe-res target-binding]
    (let [ref-to     (fn [v] {:node :var :var v})
          wrapped    {:node  :otherwise
                      :left  {:node    :sequential
                              :left    source
                              :pattern {:type :var :var res}
                              :right   {:node   :call
                                        :target site-wrap-some
                                        :args   [(ref-to res)]}}
                      :right {:node   :call
                              :target site-none
                              :args   []}}
          on-some    {:node    :sequential
                      :left    {:node   :call
                                :target site-unwrap-some
                                :args   [(ref-to maybe-res)]}
                      :pattern {:type :var :var target-binding}
                      :right   (target target-binding)}
          on-none    {:node    :sequential
                      :left    {:node   :call
                                :target site-is-none
                                :args   [(ref-to maybe-res)]}
                      :pattern {:type :wildcard}
                      :right   else}]
      {:node    :sequential
       :left    wrapped
       :pattern {:type :var :var maybe-res}
       :right   {:node  :parallel
                 :left  on-some
                 :right on-none}})))

(defn rebind-vars
  "Wraps `expr` in one :pruning per [var-param var-binding] pair of `vars`, so
  that the name of each :var pattern `var-param` is bound to the variable
  named `var-binding`. The first pair ends up innermost; with no pairs `expr`
  is returned as is."
  [vars expr]
  (loop [wrapped expr
         pending (seq vars)]
    (if pending
      (let [[param source-name] (first pending)]
        (recur {:node    :pruning
                :left    wrapped
                :pattern {:type :var
                          :var  (:var param)}
                :right   {:node :var
                          :var  source-name}}
               (next pending)))
      wrapped)))

;; TODO guards support
(defn translate-clause*
  "Translates one def clause `instance` that has strict parameter patterns.
  `strict` holds [pattern binding-name] pairs for the strict parameters and
  `vars` holds the [var-pattern binding-name] pairs for the plain variable
  parameters. The strict arguments are packed into a tuple, matched against a
  tuple of their patterns (convert/convert-pattern + make-match), and the
  clause body - with the variable parameters rebound - runs on a match.
  Otherwise `else` runs."
  [vars strict else instance]
  ;; TODO optimization for special case: only one strict pattern
  (macro/with-fresh stricted-values
    (let [clause-patterns   (mapv first strict)
          arg-refs          (mapv (fn [[_ b]] {:node :var :var b}) strict)
          [source target-f] (convert/convert-pattern {:type     :tuple
                                                      :patterns clause-patterns}
                                                     stricted-values)
          on-match          (fn [matched]
                              (target-f matched (rebind-vars vars (:body instance))))]
      {:node    :sequential
       :left    {:node   :tuple
                 :values arg-refs}
       :pattern {:type :var :var stricted-values}
       :right   (make-match source on-match else)})))

(defn pattern-type
  "Classifies pattern `p` as :var, :wildcard or - for every other pattern
  type - :strict."
  [p]
  (let [t (:type p)]
    (if (contains? #{:var :wildcard} t) t :strict)))

(defn translate-clause
  "Translates clause `instance` of a def whose parameters were renamed to
  `bindings`, with `else` as the AST to fall back to. Parameters are
  classified with pattern-type. :var parameters are rebound (rebind-vars) and
  :wildcard ones are dropped. When any parameter is :strict, the clause goes
  through translate-clause*."
  [bindings else instance]
  (let [param-pairs (map vector (:params instance) bindings)
        by-kind     (group-by (comp pattern-type first) param-pairs)
        var-params  (:var by-kind)
        strict      (:strict by-kind)]
    (if (empty? strict)
      (rebind-vars var-params (:body instance))
      (translate-clause* var-params strict else instance))))

(defn translate-clauses
  "Chains the clauses `instances` so that each one falls back to the
  translation of the clauses after it. The last clause falls back to
  {:node :stop}; the first clause ends up outermost."
  [bindings instances]
  (->> (reverse instances)
       (reduce (fn [fallback instance]
                 (translate-clause bindings fallback instance))
               {:node :stop})))

(defn translate-def
  "Replaces the instances (clauses) of a def by a single instance. Its
  parameters are fresh :var patterns - as many as the first clause has
  parameters - and its body chains all clauses via translate-clauses."
  [def-node]
  ;; TODO optimization for special case: one clause without strict patterns
  (let [instances (:instances def-node)
        arity     (-> instances first :params count)]
    (macro/with-freshs arity bindings
      (let [params (vec (for [b bindings] {:type :var :var b}))]
        (assoc def-node :instances [{:params params
                                     :body   (translate-clauses bindings instances)}])))))

(defn translate-constuctor
  "Expands a datatype constructor description into a vector of two defs:
  - `name`, taking `arity` arguments and building a :tuple node whose first
    value is the constant `name`, followed by the arguments;
  - `_Unapply<name>`, taking one argument `v` and calling `_Unapply` with `v`
    and the constant `name`. Its type takes T's :return type and returns a
    tuple of T's :params."
  [{:keys [name arity T]}]
  (macro/with-freshs arity bindings
    (let [arg-refs    (for [b bindings] {:node :var :var b})
          constructor {:name      name
                       :instances [{:params (vec (for [b bindings] {:type :var :var b}))
                                    :body   {:node   :tuple
                                             :values (into [{:node :const :value name}] arg-refs)}}]
                       :T         T}
          unapply     {:name      (str "_Unapply" name)
                       :instances [{:params [{:type :var :var "v"}]
                                    :body   {:node   :call
                                             :target {:node :var :var "_Unapply"}
                                             :args   [{:node :var :var "v"}
                                                      {:node :const :value name}]}}]
                       :T         {:type        :fun
                                   :type-params (:type-params T)
                                   :params      [(:return T)]
                                   :return      {:type :tuple
                                                 :args (:params T)}}}]
      [constructor unapply])))

(defn def-instance-free-vars
  "Free variables of one def instance: the :free-vars of its body minus the
  names bound by each of its parameter patterns."
  [instance]
  (let [body-free    (:free-vars (:body instance))
        param-names  (map patterns/pattern-bindings (:params instance))]
    (reduce set/difference body-free param-names)))

(defn def-free-vars
  "Set union of the free variables of every instance of a def."
  [def-node]
  (into #{} (mapcat def-instance-free-vars) (:instances def-node)))

(defn set-free-vars*
  "Annotates one AST node with :free-vars, the set of variable names occurring
  free in it. The value is computed from the :free-vars its children already
  carry, so this is meant to run bottom-up (see set-free-vars). A node that
  already has :free-vars is returned as is.

  For a :defs-group node, every def also gets :free-vars (def-free-vars). The
  group's own free variables are those of its defs and of its :expr, minus
  the names the group defines.
  Node types not handled here have no matching case clause and throw."
  [ast]
  (cond
    (:free-vars ast)
    ast

    (= :defs-group (:node ast))
    (let [defs'    (for [mr-group (:defs ast)]
                     (for [d mr-group]
                       (assoc d :free-vars (def-free-vars d))))
          all-defs (apply concat defs')
          ;; the body expression may use names that no def of the group uses
          used     (set/union (set (mapcat :free-vars all-defs))
                              (set (:free-vars (:expr ast))))
          defined  (set (map :name all-defs))]
      (assoc ast
        :defs defs'
        :free-vars (set/difference used defined)))

    :else
    (let [;; names bound by this node's pattern, as a set
          pattern-vars #(set (patterns/pattern-bindings (:pattern ast)))]
      (assoc ast :free-vars
                 (case (:node ast)
                   (:otherwise :parallel) (set/union (:free-vars (:left ast))
                                                     (:free-vars (:right ast)))
                   ;; left <p< right: p is in scope only in the left branch
                   :pruning (set/union (set/difference (:free-vars (:left ast)) (pattern-vars))
                                       (:free-vars (:right ast)))
                   ;; left >p> right: p is in scope only in the right branch
                   :sequential (set/union (:free-vars (:left ast))
                                          (set/difference (:free-vars (:right ast)) (pattern-vars)))
                   ;; the condition lives under :if (see translate-conditional)
                   :conditional (set/union (:free-vars (:if ast))
                                           (:free-vars (:then ast))
                                           (:free-vars (:else ast)))
                   (:field-access :call) (apply set/union
                                                (get-in ast [:target :free-vars])
                                                (map :free-vars (:args ast)))
                   (:tuple :list) (apply set/union (map :free-vars (:values ast)))
                   (:declare-types :has-type) (:free-vars (:expr ast))
                   :record (apply set/union (map (comp :free-vars second) (:pairs ast)))
                   :var #{(:var ast)}
                   ;; referred symbols are bound by the refer itself
                   :refer (set/difference (:free-vars (:expr ast))
                                          (set (for [[ns symbols] (:namespaces ast)
                                                     s symbols]
                                                 s)))
                   (:const :stop) #{})))))

(defn set-free-vars
  "Annotates every node of `ast` with :free-vars by running set-free-vars*
  through utils/ast-postwalk."
  [ast]
  (->> ast (utils/ast-postwalk set-free-vars*)))

;; Splits the defs of a :defs-group node into mutually recursive groups; any
;; other node is returned unchanged. Only the first element of :defs is read,
;; so the node is expected to still hold a single, undivided group. Each def
;; must already carry :free-vars (see set-free-vars), because dependency edges
;; are derived from them.
(defn divide-defs* [ast]
  (case (:node ast)
    :defs-group (let [;; def name -> def
                      m      (utils/index-by :name (first (:defs ast)))
                      ;; [d1 d2] whenever d1 refers to d2 (self-references included)
                      edges  (for [d1 (first (:defs ast))
                                   d2 (first (:defs ast))
                                   :when (contains? (:free-vars d1) (:name d2))]
                               [(:name d1) (:name d2)])
                      ;; NOTE(review): presumably the def names that occur in
                      ;; `edges`, in DFS order - confirm against utils/dfs-sort
                      sorted (reverse (utils/sort-keys-by-vals (utils/dfs-sort edges)))
                      ;; Mutually recursive groups. Each group depends only on
                      ;; earlier groups or on its own members. Defs without
                      ;; mutual recursion always form one-element groups.
                      groups (reverse (utils/graph-components-in-order edges sorted))
                      defs'  (concat
                               (for [mr-group groups]
                                 (map m mr-group))
                               ;; defs that appear in no edge at all, each in its own group
                               (map (fn [k] [(m k)]) (set/difference (set (keys m)) (set sorted))))]
                  (assoc ast :defs defs'))
    ast))

(defn divide-defs
  "Walks `ast` bottom-up and, in every :defs-group node, splits the list of
  defs into a list of sublists, each sublist holding mutually recursive defs
  (see divide-defs*)."
  [ast]
  (->> ast (utils/ast-postwalk divide-defs*)))

(defn accumulate-def-type
  "Merges the type `T2` of a newly seen def clause into the accumulated def
  type `T1` via utils/assoc-when. :type-params and :return are taken from
  `T2`. Each :params position keeps T1's entry when truthy and otherwise uses
  T2's; the merged list is only as long as the shorter of the two."
  [T1 T2]
  (let [merged-params (map (fn [known incoming] (or known incoming))
                           (:params T1)
                           (:params T2))]
    (utils/assoc-when T1
                      :type-params (:type-params T2)
                      :params merged-params
                      :return (:return T2))))

;; TODO site imports & type datastructures
(defn translate-declarations
  "Desugars a :declarations node (`decls` followed by `expr`).

  Consecutive declarations of the same kind are accumulated into one node:
  - defs and def signatures become a :defs-group (through translate-def when
    `translate-clauses?` is true);
  - refers become a :refer;
  - sites become a :sites;
  - datatypes become a :declare-types wrapping a :defs-group of their
    constructors (translate-constuctor);
  - type aliases become a :declare-types.
  The declarations left from the first one of a different kind onwards are
  nested as a new :declarations node in the :expr. A `val` at the start
  becomes a :pruning whose left side holds the remaining declarations.
  :include declarations are spliced in place.

  `k` is applied to the :pruning produced for a leading `val`. It is also
  applied to `expr` when there is no declaration at all."
  [{:keys [decls expr]} k translate-clauses?]
  (letfn [(finalize [[state acc] decls]
            ;; `decls` are the declarations not consumed yet, starting with
            ;; the one that ended the current run
            (let [expr' (if (seq decls)
                          {:node  :declarations
                           :decls decls
                           :expr  expr}
                          expr)]
              (case state
                ;; Only reached when nothing was accumulated (no decls, or
                ;; only empty :include blocks): the node is just its
                ;; translated expression.
                :init (k expr')
                :def {:node :defs-group
                      :defs [(if translate-clauses?
                               (mapv translate-def (vals acc))
                               (vals acc))]
                      :expr expr'}
                :refer {:node       :refer
                        :namespaces acc
                        :expr       expr'}
                :datatype {:node  :declare-types
                           :types (map first acc)
                           :expr  {:node :defs-group
                                   :defs [(->> (for [[_ constructors] acc
                                                     constructor constructors
                                                     x           (translate-constuctor constructor)]
                                                 x)
                                               (vec))]
                                   :expr expr'}}
                :type-alias {:node  :declare-types
                             :types acc
                             :expr  expr'}
                :site {:node        :sites
                       :definitions acc
                       :expr        expr'})))]
    ;; `state` is the kind of the run being accumulated (:init before the
    ;; first declaration); `acc` is that run's accumulator
    (loop [[state acc :as s] [:init] [d & tail :as decls] decls]
      (if d
        (case (:type d)
          :include (recur s (concat (:decls d) tail))
          :val (if (= :init state)
                 (k {:node    :pruning
                     :pattern (:pattern d)
                     :left    (if (seq tail)
                                {:node  :declarations
                                 :decls tail
                                 :expr  expr}
                                expr)
                     :right   (:expr d)})
                 (finalize s decls))
          :def (if (#{:def :init} state)
                 (let [inst {:params (:params d)
                             :body   (:body d)
                             :guard  (:guard d)
                             :T      (:T d)}
                       def  (get acc (:name d)
                                 {:name      (:name d)
                                  :instances []
                                  :T         (:T d)})]
                   (recur [:def (assoc acc (:name d) (-> def
                                                         (update :instances conj inst)
                                                         (update :T accumulate-def-type (:T inst))))] tail))
                 (finalize s decls))
          :refer (if (#{:refer :init} state)
                   (recur [:refer (conj acc [(:namespace d) (:symbols d)])] tail)
                   (finalize s decls))
          :site (if (#{:site :init} state)
                  (recur [:site (assoc acc (:name d) (:definition d))] tail)
                  (finalize s decls))
          :datatype (if (#{:datatype :init} state)
                      (recur [:datatype (conj acc [[(:name d) {:type         :datatype
                                                               :constructors (:constructors d)
                                                               :type-params  (:type-params d)}]
                                                   (:constructors d)])] tail)
                      (finalize s decls))
          :type-alias (if (#{:type-alias :init} state)
                        (recur [:type-alias (conj acc [(:name d) (:T d)])] tail)
                        (finalize s decls))
          :def-sig (if (#{:def :init} state)
                     ;; a signature sets the def's type; clauses already
                     ;; accumulated under the same name are kept
                     (recur [:def (update acc (:name d)
                                          (fn [existing]
                                            (assoc (or existing {:name      (:name d)
                                                                 :instances []})
                                                   :T (:T d))))]
                            tail)
                     (finalize s decls)))
        (finalize s decls)))))

(defn translate-conditional
  "Desugars a :conditional node into
    (Ift(t) >_> then | Iff(t) >_> else) <t< if
  where t is a fresh variable bound to the condition's value."
  [{test-expr :if then-expr :then else-expr :else}]
  (macro/with-fresh t
    (let [guarded (fn [guard-site body]
                    {:node    :sequential
                     :left    {:node   :call
                               :target guard-site
                               :args   [{:node :var :var t}]}
                     :pattern {:type :wildcard}
                     :right   body})]
      {:node    :pruning
       :pattern {:type :var :var t}
       :right   test-expr
       :left    {:node  :parallel
                 :left  (guarded site-ift then-expr)
                 :right (guarded site-iff else-expr)}})))

(defn constantize
  "Wraps the literal value `x` in a :const AST node."
  [x]
  (hash-map :node :const :value x))

(defn translate-list
  "Builds the call chain that constructs a list from the element ASTs `vals`:
  nested `:` calls (element, tail), ending in a call to `_Nil` with no
  arguments."
  [vals]
  (loop [tail      {:node   :call
                    :target {:node :var :var "_Nil"}
                    :args   []}
         remaining (reverse vals)]
    (if-let [[v & more] (seq remaining)]
      (recur {:node   :call
              :target {:node :var :var ":"}
              :args   [v tail]}
             more)
      tail)))

;; NOTE(review): nothing before the definition of `translate` refers to it,
;; so this forward declaration looks redundant.
(declare translate)
;; Rewrites one AST node according to `options` and returns the new node.
;; `translate` drives it through utils/ast-prewalk (presumably the children of
;; the returned node are visited afterwards - confirm in orcl.utils).
;; Flags read from `options`:
;;   :deflate?          hoist non-primitive operands (deflate-values)
;;   :patterns?         compile strict patterns of :sequential/:pruning
;;   :conditional?      desugar :conditional (translate-conditional)
;;   :clauses?          merge multi-clause defs (translate-def)
;;   :data-structures?  build lists/tuples/records with calls
;; Node types not listed below are returned unchanged.
(defn translate* [{:keys [deflate? patterns? conditional? clauses? data-structures?] :as options} ast]
  (case (:node ast)
    ;; declaration blocks; the pruning built for a `val` goes back through translate*
    :declarations (translate-declarations ast (partial translate* options) clauses?)
    ;; anonymous function -> one-def :defs-group named after the SHA of the
    ;; body, whose expression is a reference to that def
    :lambda (let [body (:body ast)
                  n    (str "__def_" (utils/sha body))]
              {:node   :defs-group
               :lambda true
               :defs   [[{:name      n
                          :instances [{:guard  (:guard ast)
                                       :params (:params ast)
                                       :body   body}]
                          :T (:T ast)}]]
               :expr   {:node :var
                        :var  n}})
    ;; list literal -> `:`/`_Nil` calls, or deflated elements
    :list (cond
            data-structures? (translate* options
                                         (translate-list (:values ast)))
            deflate? (macro/as-cursor [c ast] (deflate-values (seq (:values c)) c))
            :else ast)
    ;; tuple literal -> _MakeTuple(values...), or deflated values
    :tuple (cond
             data-structures? (translate* options
                                          {:node   :call
                                           :target site-tuple
                                           :args   (:values ast)})
             deflate? (macro/as-cursor [c ast] (deflate-values (seq (:values c)) c))
             :else ast)
    ;; record literal -> _MakeRecord(field1, value1, field2, value2, ...)
    ;; with field names as constants, or deflated field values
    :record (cond
              data-structures? (translate* options
                                           {:node   :call
                                            :target site-record
                                            :args   (vec (mapcat (fn [[f arg]] [(constantize f) arg])
                                                                 (:pairs ast)))})
              deflate? (macro/as-cursor [c ast] (deflate-values (map second (:pairs c)) c))
              :else ast)
    ;; a call of `:=` with args [a b] becomes a call of a's "write" field with [b];
    ;; any other call gets its target and arguments deflated
    :call (if (= ":=" (get-in ast [:target :var]))
            (translate* options
                        {:node   :call
                         :args   [(second (:args ast))]
                         :target {:node   :field-access
                                  :target (first (:args ast))
                                  :field  "write"}})
            (if deflate?
              (macro/as-cursor [c ast] (deflate-values (concat [(:target c)] (:args c)) c))
              ast))
    :field-access (if deflate?
                    (macro/as-cursor [c ast] (deflate-values [(:target c)] c))
                    ast)
    ;; either fully desugared, or only the condition (:if) is deflated
    :conditional (cond
                   conditional? (translate-conditional ast)
                   deflate? (macro/as-cursor [c ast] (deflate-values [(:if c)] c))
                   :else ast)

    ;; left >p> right with a strict p becomes
    ;;   left >source-binding> (source >bridge> (target-f bridge right))
    :sequential (if (and patterns? (strict-pattern? (:pattern ast)))
                  (macro/with-freshs 2 [source-binding bridge]
                    (let [[source target-f] (convert/convert-pattern (:pattern ast) source-binding)]
                      {:node    :sequential
                       :left    (:left ast)
                       :pattern {:type :var :var source-binding}
                       :right   {:node    :sequential
                                 :pattern {:type :var :var bridge}
                                 :left    source
                                 :right   (target-f bridge (:right ast))}}))
                  ast)

    ;; left <p< right with a strict p becomes
    ;;   (target-f bridge left) <bridge< (right >source-binding> source)
    :pruning (if (and patterns? (strict-pattern? (:pattern ast)))
               (macro/with-freshs 2 [source-binding bridge]
                 (let [[source target-f] (convert/convert-pattern (:pattern ast) source-binding)]
                   {:node    :pruning
                    :right   {:node    :sequential
                              :left    (:right ast)
                              :pattern {:type :var :var source-binding}
                              :right   source}
                    :pattern {:type :var :var bridge}
                    :left    (target-f bridge (:left ast))}))
               ast)

    ;; dereference -> call of the target's "read" field with no arguments
    :dereference (translate* options
                             {:node   :call
                              :args   []
                              :target {:node   :field-access
                                       :target (:target ast)
                                       :field  "read"}})

    ;; a namespace node is replaced by its translated body
    :ns (translate* options (:body ast))

    ast))

(defn translate
  "Applies translate* with `options` to the nodes of `ast` via
  utils/ast-prewalk."
  [ast options]
  (utils/ast-prewalk #(translate* options %) ast))

(defn with-sha
  "Annotates the nodes of `ast` with utils/with-sha via utils/ast-postwalk."
  [ast]
  (->> ast (utils/ast-postwalk utils/with-sha)))

(declare analyze-env)

;; The def instance whose tail calls are being looked for, as {:id <def sha>},
;; or nil. analyze-instance binds it while analyzing an instance body.
;; analyze-env rebinds it to nil for sub-expressions that are not in tail
;; position, including:
;; - the left branch of a sequential,
;; - the right branch of a pruning,
;; - the left branch of an otherwise.
;; A :call whose target is that same def is tagged with :tail-pos.
;; The root value is nil, so reading it outside of `analyze` yields nil
;; rather than an unbound-var object.
(def ^:dynamic *tail-pos* nil)

(defn analyze-instance
  "Analyzes the body of one def instance. Every name bound by the instance's
  parameter patterns enters the environment as an :argument entry (its
  parameter position, the body's :sha and the def `id`). *tail-pos* is bound
  to {:id id} while the body is analyzed."
  [id instance]
  (let [body-sha  (:sha (:body instance))
        param-env (fn [position pattern]
                    (for [bound-name (patterns/pattern-bindings pattern)]
                      [bound-name {:type         :argument
                                   :position     position
                                   :instance-sha body-sha
                                   :id           id}]))
        env       (into {} (apply concat (map-indexed param-env (:params instance))))]
    (macro/with-envs env
      (binding [*tail-pos* {:id id}]
        (update instance :body analyze-env)))))

(defn analyze-def
  "Analyzes def `node`, a member of the group `defs`. Every def of the group
  is visible in the environment as a :def entry (:id = its :sha, plus its
  :usages counter), and each instance is analyzed with analyze-instance.
  Adds :arity, the parameter count of the first instance."
  [defs {:keys [sha instances] :as node}]
  (let [group-env (into {}
                        (map (fn [{def-name :name def-sha :sha def-usages :usages}]
                               [def-name {:type   :def
                                          :id     def-sha
                                          :usages def-usages}]))
                        defs)]
    (macro/with-envs group-env
      (assoc node
        :arity (count (:params (first instances)))
        :instances (mapv #(analyze-instance sha %) instances)))))

(defn analyze-defs
  "Analyzes one group of defs. Each def gets :usages, a fresh atom holding 0,
  and :sha, the utils/sha of its original form; then every def is analyzed
  with the whole prepared group visible (analyze-def)."
  [defs]
  (let [prepared (mapv (fn [d]
                         (assoc d
                           :usages (atom 0)
                           :sha (utils/sha d)))
                       defs)]
    (mapv (fn [d] (analyze-def prepared d)) prepared)))

;; TODO check target & arity
(defn check-call!
  "Validation hook for an analyzed :call node. Currently a no-op that
  returns nil."
  [_call]
  nil)

(defn analyze-env
  "Resolves the variables of `ast` against vars/*env* and annotates the tree:
  - each :var gets the environment entry as :source. For :def entries the
    :usages counter is incremented and then left out of :source. An unknown
    variable is reported through utils/error (\"Undefined variable\").
  - pattern-bound names (:pruning/:sequential), def groups, refers and sites
    are added to the environment for the sub-trees where they are in scope;
  - a :call whose target is the def identified by *tail-pos* is tagged with
    :tail-pos.
  :sites nodes are replaced by their analyzed :expr. Other nodes are walked
  with utils/ast-walk."
  [ast]
  (case (:node ast)
    :pruning (assoc ast :left (macro/with-pattern (:pattern ast) {:type     :pruning
                                                                  :node-sha (:sha (:right ast))}
                                (analyze-env (:left ast)))
                        :right (binding [*tail-pos* nil]
                                 (analyze-env (:right ast))))
    :sequential (assoc ast :right (macro/with-pattern (:pattern ast) {:type     :sequential
                                                                      :node-sha (:sha (:left ast))}
                                    (analyze-env (:right ast)))
                           :left (binding [*tail-pos* nil]
                                   (analyze-env (:left ast))))
    :otherwise (assoc ast :right (analyze-env (:right ast))
                          :left (binding [*tail-pos* nil]
                                  (analyze-env (:left ast))))
    ;; analyzed eagerly, so every group is processed under the current
    ;; dynamic bindings rather than whenever a lazy seq happens to be realized
    :defs-group (let [defs (mapv analyze-defs (:defs ast))]
                  (macro/with-envs (into {} (for [{:keys [name sha usages]} (apply concat defs)]
                                              [name {:type   :def
                                                     :id     sha
                                                     :usages usages}]))
                    (assoc ast
                      :defs defs
                      :expr (analyze-env (:expr ast)))))
    :refer (macro/with-envs (into {} (for [[ns symbols] (:namespaces ast)
                                           s symbols]
                                       [s {:type      :refer
                                           :namespace ns}]))
             (update ast :expr analyze-env))
    :sites (macro/with-envs (into {} (for [[var definition] (:definitions ast)]
                                       [var {:type       :site
                                             :source     {:type :custom}
                                             :definition definition}]))
             (analyze-env (:expr ast)))
    :call (let [;; the target and the arguments of a call are never in tail
                ;; position themselves, whatever the position of the call
                ast' (binding [*tail-pos* nil]
                       (-> ast
                           (update :target analyze-env)
                           (assoc :args (mapv analyze-env (:args ast)))))]
            (check-call! ast')
            (let [s (get-in ast' [:target :source])]
              (if (and *tail-pos* (= :def (:type s)) (= (:id s) (:id *tail-pos*)))
                (assoc ast'
                  :tail-pos *tail-pos*)
                ast')))
    :var (if-let [source (get vars/*env* (:var ast))]
           (do
             (when (= :def (:type source))
               (swap! (:usages source) inc))
             (assoc ast :source (dissoc source :usages)))
           (utils/error "Undefined variable" ast
                        :variable (:var ast)))
    (utils/ast-walk analyze-env identity ast)))

(defn remove-unused
  "Drops defs whose :usages counter is zero from every :defs-group of `ast`,
  replacing each kept def's :usages atom by its current value. Empty groups
  are dropped. A :defs-group left without any def is replaced by its :expr."
  [ast]
  (letfn [(live-defs [group]
            (seq (for [d group
                       :let [n @(:usages d)]
                       :when (pos? n)]
                   (assoc d :usages n))))
          (prune [node]
            (if (= :defs-group (:node node))
              (let [kept (keep live-defs (:defs node))]
                (if (seq kept)
                  (assoc node :defs kept)
                  (:expr node)))
              node))]
    (utils/ast-postwalk prune ast)))

(defn analyze
  "Runs the analysis pipeline on `ast`:
  1. translate with `options`;
  2. with-sha;
  3. analyze-env against `env`;
  4. remove-unused when :remove-unused? is set;
  5. set-free-vars and divide-defs;
  6. when :typecheck? is set, a second translate with clauses, patterns,
     conditionals and deflation enabled, followed by typecheck/process. That
     call receives the [:source :T] of each `env` entry and
     (:dependencies options).
  The 1-arity uses an empty env; the 1- and 2-arity versions use default
  options."
  ([ast] (analyze ast {}))
  ([ast env]
   (analyze ast env {:deflate?       true
                     :conditional?   false
                     :clauses?       false
                     :patterns?      false
                     :remove-unused? true
                     :typecheck?     false}))
  ([ast env options]
   (binding [vars/*env*  env
             *tail-pos* nil]
     (let [resolved (cond-> (-> ast
                                (translate options)
                                (with-sha)
                                (analyze-env))
                      (:remove-unused? options) (remove-unused))
           grouped  (-> resolved
                        (set-free-vars)
                        (divide-defs))]
       (if (:typecheck? options)
         (-> grouped
             (translate {:clauses?     true
                         :patterns?    true
                         :conditional? true
                         :deflate?     true})
             (typecheck/process (utils/map-vals #(get-in % [:source :T]) env)
                                (:dependencies options)))
         grouped)))))

(defn flat-namespace
  "Collects the defs reachable from the top of an analyzed namespace AST as a
  seq of [name def] pairs. It walks down through :defs-group nodes, :refer
  nodes and :declare-types nodes (the latter wrap datatype constructors and
  type aliases). It stops at any other node, where it contributes nil."
  [ast]
  (case (:node ast)
    :defs-group (concat (for [mr-group (:defs ast)
                              d        mr-group]
                          [(:name d) d])
                        (flat-namespace (:expr ast)))
    ;; these only bring names or types into scope: defs declared after them
    ;; are still part of the namespace
    (:refer :declare-types) (flat-namespace (:expr ast))
    nil))

(defn analyze-namespace
  "Analyzes the declarations of namespace `ns` (its :body, followed by a :stop
  expression) with `env` and `options`, never removing unused defs. Returns
  the [name def] pairs found by flat-namespace, as a map."
  [ns env options]
  (let [decls-node {:node  :declarations
                    :decls (:body ns)
                    :expr  {:node :stop}}
        analyzed   (analyze decls-node env (assoc options :remove-unused? false))]
    (into {} (flat-namespace analyzed))))
