(ns orcl.utils
  (:require [clojure.string :as str]
            [orcl.utils.hashify :as hashify]
    #?(:clj
            [orcl.utils.macro :as macro]))
  #?(:cljs (:require-macros [orcl.utils.macro :as macro])))



(defn sha
  "Returns a content hash for `x`, computed with `hashify/hashify` on a string
  derived from `x`:
  - a map that already carries a truthy `:sha` returns that stored value as-is;
    any other map is hashed as the seq of its entries sorted by key (keys must
    be mutually comparable for `sort-by` to succeed),
  - a sequential collection hashes the concatenation of its elements' hashes,
  - a string is hashed directly; keywords and symbols by their `name`
    (so the namespace part is ignored),
  - anything else by its `str` form (e.g. `nil` hashes like \"\").
  NOTE(review): sets fall into the `str` case, so their hash depends on
  iteration order — confirm sets never reach here."
  [x]
  (cond
    ;; Reuse an already-computed hash instead of returning the whole map, so
    ;; that `with-sha` on a hashed node keeps a hash under `:sha` and nested
    ;; hashed nodes contribute their hash (not their printed form) to a parent.
    (map? x) (if-let [existing (:sha x)]
               existing
               (sha (sort-by first x)))
    (sequential? x) (sha (str/join "" (map sha x)))
    (string? x) (hashify/hashify x)
    (keyword? x) (sha (name x))
    (symbol? x) (sha (name x))
    :else (sha (str x))))

(defn with-sha
  "Returns `node` with its content hash (as computed by `sha`) stored under
  the `:sha` key."
  [node]
  (let [digest (sha node)]
    (assoc node :sha digest)))

(defn map-vals
  "Build map k -> (f v) for [k v] in `m`, preserving the initial type where
  possible:
  - sorted maps are rebuilt from `(empty m)`, so a custom comparator (e.g.
    from `sorted-map-by`) and the map's metadata are kept,
  - any other map (including records) comes back as a plain persistent map
    without the original metadata,
  - anything else is handed to `macro/for-map` with `[k v]` destructuring
    (presumably a seqable of key/value pairs — TODO confirm against for-map)."
  [f m]
  (cond
    (sorted? m)
    ;; Sorted maps are not transient-capable; `(empty m)` keeps the comparator,
    ;; which a fresh `(sorted-map)` would silently replace with `compare`.
    (reduce-kv (fn [out-m k v] (assoc out-m k (f v))) (empty m) m)
    (map? m)
    (persistent! (reduce-kv (fn [out-m k v] (assoc! out-m k (f v))) (transient {}) m))
    :else
    (macro/for-map [[k v] m] k (f v))))

(defn index-by
  "Returns a map from `(f x)` to `x` for every element `x` of `coll`.
   Unlike `clojure.core/group-by`, each key maps to a single element: when
   several elements share a key, the one appearing last in `coll` wins and the
   others are dropped, so `f` should normally yield a unique key per element."
  [f coll]
  (into {} (map (juxt f identity)) coll))

(defn assoc-when
  "Like `assoc`, but a key/value pair is only added when its value is truthy
  (neither nil nor false). A nil `m` is treated as an empty map. Asserts that
  an even number of key/value arguments was supplied."
  [m & kvs]
  (assert (even? (count kvs)))
  (reduce (fn [acc [k v]]
            (if v
              (conj acc [k v])
              acc))
          (or m {})
          (partition 2 kvs)))

(defn update-in-when
  "Like `update-in` but returns `m` unchanged if `key-seq` is not present.
  A path whose value is present but nil still counts as present, so `f` is
  then called with nil (followed by `args`)."
  [m key-seq f & args]
  ;; Bind the sentinel once and compare against that very object: in
  ;; ClojureScript two separate keyword literals are not guaranteed to be
  ;; `identical?` (only `=`/`keyword-identical?` are), which would make the
  ;; "not present" check never fire.
  (let [missing ::none
        found   (get-in m key-seq missing)]
    (if (identical? missing found)
      m
      (assoc-in m key-seq (apply f found args)))))

(defn def-walk
  "Returns `def` with `f` applied to the `:body` of every entry of its
  `:instances`; `:instances` is replaced by a vector of the updated entries
  (an empty vector when it was absent)."
  [f def]
  (update def :instances (partial mapv #(update % :body f))))

(defn ast-walk
  "Like `clojure.walk/walk`, but specific to AST nodes.
  Dispatches on `(:node ast)`, applies `inner` to each child sub-tree of the
  node (children of one node are visited in a fixed order, e.g. `:left`
  before `:right`, `:expr` before the declarations/definitions), rebuilds the
  node from the results and returns `(outer rebuilt-node)`.
  Leaf nodes (`:stop`, `:const`, `:var`) are passed to `outer` unchanged.
  Any other `:node` value throws (no matching `case` clause)."
  [inner outer ast]
  (let [walk-keys (fn [node ks]
                    ;; apply `inner` to each key in `ks`, left to right
                    (reduce (fn [acc k] (update acc k inner)) node ks))
        walk-decl (fn [decl]
                    ;; only `:def` declarations carry a body to walk
                    (if (= :def (:type decl))
                      (update decl :body inner)
                      decl))
        walk-defs (fn [groups]
                    ;; eager seq of vectors, one per mutually-recursive group
                    (doall (map (fn [mr-group]
                                  (mapv (partial def-walk inner) mr-group))
                                groups)))]
    (outer
     (case (:node ast)
       (:pruning :sequential :parallel :otherwise) (walk-keys ast [:left :right])
       :conditional (walk-keys ast [:if :then :else])
       :lambda (-> ast
                   (update :body inner)
                   (update-in-when [:guard] inner))
       :declarations (-> ast
                         (walk-keys [:expr])
                         (update :decls #(mapv walk-decl %)))
       :defs-group (-> ast
                       (walk-keys [:expr])
                       (update :defs walk-defs))
       (:sites :refer :has-type :declare-types) (walk-keys ast [:expr])
       :call (-> ast
                 (walk-keys [:target])
                 (update :args #(mapv inner %)))
       (:list :tuple) (update ast :values #(mapv inner %))
       :record (update ast :pairs #(mapv (fn [[k v]] [k (inner v)]) %))
       (:field-access :dereference) (walk-keys ast [:target])
       (:stop :const :var) ast))))

(defn ast-postwalk
  "Bottom-up AST traversal: every child of `form` is rewritten first, then
  `f` is applied to the rebuilt node and its result returned."
  [f form]
  (ast-walk (fn [child] (ast-postwalk f child))
            f
            form))

(defn ast-prewalk
  "Top-down AST traversal: `f` is applied to `form` first, then the children
  of the node `f` returned are walked the same way; the rebuilt node is
  returned without a further call to `f`."
  [f form]
  (let [node (f form)]
    (ast-walk #(ast-prewalk f %) identity node)))

(defn todo-exception
  "Placeholder for unimplemented code paths: always throws a bare platform
  error (a RuntimeException on the JVM, a js/Error in ClojureScript)."
  []
  (let [e #?(:clj (RuntimeException.) :cljs (js/Error.))]
    (throw e)))

(defn error
  "Throws an `ex-info` with message `msg`. Its data map holds `(:pos node)`
  under `:orcl/error-pos`, `msg` under `:orcl/error`, plus any extra keyword
  arguments (`info`), which win on key clashes."
  [msg node & {:as info}]
  (let [base {:orcl/error-pos (:pos node)
              :orcl/error     msg}
        data (merge base info)]
    (throw (ex-info msg data))))

(defn mk-graph
  "Builds an adjacency map from `edges`, read as [from to] pairs via
  `first`/`second`. Forward (`backward` falsey): from -> seq of tos.
  Backward: to -> seq of froms. Neighbour seqs keep edge order; nodes with
  no edge in the chosen direction do not appear as keys."
  [edges backward]
  (let [[key-of other-end] (if backward [second first] [first second])]
    (map-vals (fn [group] (map other-end group))
              (group-by key-of edges))))

(defn dfs-sort
  "Depth-first search over the directed graph given by `edges` ([from to]
  pairs, forward adjacency from `mk-graph`). Returns a map from every reached
  node to its finish number.

  Roots are tried in the key order of the adjacency map, i.e. only nodes with
  at least one outgoing edge start a search. A shared counter starting at 0
  is bumped when a node is entered and again when it finishes; the finish
  value stored is that second bump. Larger numbers therefore mean a later
  finish, and a node finishes after every node first discovered while
  exploring from it."
  [edges]
  (let [graph (mk-graph edges false)
        search (fn search [nodes counter visited finishes]
                 ;; `seq` instead of a nil check on the head, so a nil node
                 ;; cannot cut the rest of the list short.
                 (if-let [[node & xs] (seq nodes)]
                   (if (contains? visited node)
                     ;; `recur` for siblings keeps stack depth bounded by DFS
                     ;; depth rather than by the number of siblings/roots.
                     (recur xs counter visited finishes)
                     (let [[counter' visited' finishes'] (search (get graph node) (inc counter) (conj visited node) finishes)
                           counter'' (inc counter')]
                       (recur xs counter'' visited' (assoc finishes' node counter''))))
                   [counter visited finishes]))]
    (let [[_ _ finishes] (search (keys graph) 0 #{} {})]
      finishes)))

(defn sort-keys-by-vals
  "Returns the keys of `m` (or the firsts of a seq of pairs) ordered by
  ascending value. `sort-by` is stable, so equal values keep their original
  iteration order."
  [m]
  (for [[k _] (sort-by second m)]
    k))

(defn graph-components-in-order
  "Groups `nodes` by reachability over the reversed `edges` ([from to] pairs,
  backward adjacency from `mk-graph`).

  `nodes` is processed in order. Each node not seen yet starts a new group,
  containing that node followed by every not-yet-seen node reachable from it
  along reversed edges, in depth-first preorder (a node's predecessors are
  explored before its remaining siblings). Returns the non-empty groups, in
  the order they were started, as an eager vector of vectors.
  NOTE(review): together with `dfs-sort` (nodes ordered by decreasing finish
  number) this looks like the second pass of Kosaraju's SCC algorithm —
  confirm at the call sites."
  [edges nodes]
  (let [graph (mk-graph edges true)
        ;; Iterative preorder DFS from `root`, skipping anything in `seen`.
        ;; `stack` holds pending sibling seqs; the top one is consumed first.
        ;; Returns [component seen']; component is empty if `root` was seen.
        explore (fn [seen root]
                  (loop [stack     (list [root])
                         seen      seen
                         component []]
                    (if (empty? stack)
                      [component seen]
                      (let [pending (peek stack)
                            stack   (pop stack)]
                        (if-let [[n & siblings] (seq pending)]
                          (let [stack (conj stack siblings)]
                            (if (contains? seen n)
                              (recur stack seen component)
                              ;; push n's predecessors above its siblings
                              (recur (conj stack (get graph n))
                                     (conj seen n)
                                     (conj component n))))
                          (recur stack seen component))))))]
    (first
     (reduce (fn [[components seen] root]
               (let [[component seen'] (explore seen root)]
                 [(if (seq component) (conj components component) components)
                  seen']))
             [[] #{}]
             nodes))))