(ns instacheck.util
  (:require [clojure.set :refer [union]]
            [clojure.walk :refer [postwalk]]
            [clojure.string :as string]))


(defn tree-matches
  "Walks tree depth-first (pre-order), descending into map values (map
  keys are not visited) and sequential collections, and returns a seq
  of every node (including the root) for which pred? is truthy. Returns
  nil when nothing matches."
  [pred? tree]
  (->> tree
       (tree-seq (some-fn map? sequential?)
                 #(if (map? %) (vals %) (seq %)))
       (filter pred?)
       seq))

(defn tree-deps
  "Takes a map of trees like {:a tree-a :b tree-b :c tree-c} and returns
  a map from each tree key to the set of tree keys that occur as a node
  of that tree (map keys inside a tree are not searched; see
  tree-matches). For example {:a #{:b :c} :b #{:c} :c #{}} means that
  :b and :c appear in tree-a, :c appears in tree-b, and no tree key
  appears in tree-c. A tree containing its own key lists itself.
  Returns {} when trees is empty."
  [trees]
  (apply merge-with
         union
         ;; Seed with {} so an empty input still yields a map, not nil.
         {}
         (for [k1 (keys trees)
               [k2 t] trees]
           ;; Emit an entry for every tree, even without a match, so each
           ;; key of trees is present in the result.
           (if (tree-matches #(= k1 %) t)
             {k2 #{k1}}
             {k2 #{}}))))

(defn remove-key
  "Walks tree bottom-up and drops every vector whose first element is
  k. Map entries are seen as [key value] vectors, so map entries keyed
  by k are removed from their maps; note that any other vector headed
  by k (e.g. inside a sequence) is replaced by nil as well."
  [tree k]
  (postwalk (fn [node]
              (when-not (and (vector? node) (= k (first node)))
                node))
            tree))


(comment

;; REPL examples for tree-matches and tree-deps.

(def ttree {:a [1 2 [:b] {:foo [:c :c]}]
            :b {:bar {:baz [:qux :c]}}
            :c {:foo {:bar [:baz :qux []]}}})

(tree-matches #(= :c %) (:a ttree))
;=>(:c :c)

;; tree-a contains :b and :c, tree-b contains :c, tree-c contains none.
(tree-deps ttree)
;=>{:a #{:b :c} :b #{:c} :c #{}}

)

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(defn flatten-text*
  "Lazily flattens a tree of nested sequences into a single sequence
  of its number and string leaves, in order. nil and empty collections
  contribute nothing (empty strings are kept as leaves)."
  [tree]
  (lazy-seq
    (if (or (string? tree) (number? tree))
      (list tree)
      (when-not (empty? tree)
        (mapcat flatten-text* tree)))))

(defn flatten-text
  "Take a tree (sequences hierarchy) and flatten it all the way to a
  single string, optionally separated by sep. Empty values (nil, empty
  collections and empty strings) are removed before sep is interposed,
  so they never produce doubled separators. Runs of spaces in the
  result are collapsed to a single space."
  [tree & [sep]]
  ;; flatten-text* keeps "" leaves, so drop them here; otherwise a
  ;; non-space sep would be interposed around them (e.g. "a,,b").
  (let [parts (remove #(= "" %) (flatten-text* tree))]
    (string/replace
      (apply str (if sep (interpose sep parts) parts))
      #" +" " ")))

(comment

(flatten-text ["foo" "" [[nil "bar"] "baz" ["qux"]]])
;=>"foobarbazqux"

;; Without sep, pieces are concatenated; a literal " " piece survives.
(flatten-text ["foo" "" [[nil "bar"] " " "baz" ["qux"]]])
;=>"foobar bazqux"

(flatten-text ["foo" "" [[nil "bar"] "baz" ["qux"]]] " ")
;=>"foo bar baz qux"

)

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; weight utilities

;;(defn prune-zero-weights
;;  "Remove weight paths with weight value of zero."
;;  [weights]
;;  (into {} (filter (comp pos? val) weights)))

(defn filter-alts
  "Keeps only the weight entries whose path is an alternation choice,
  i.e. whose second-to-last path element is :alt. Weights only affect
  alternations (gen/frequency), so every other path is dropped."
  [weights]
  (into {}
        (for [entry weights
              :when (= :alt (last (butlast (key entry))))]
          entry)))

;;(defn weight-leaves
;;  "Prune off intermediate weight paths (e.g. return only leaf node
;;  weights)"
;;  [weights]
;;  (let [keyfn (fn [[k v]] (string/join "-" k))
;;        sorted-kvs (sort-by keyfn weights) ]
;;    (into
;;      {}
;;      (reduce (fn [res [k' v' :as kv']]
;;                (let [[k v :as kv] (last res)]
;;                  (if (= k (take (count k) k'))
;;                    (conj (vec (butlast res)) kv')
;;                    (conj res kv'))))
;;              []
;;              sorted-kvs))))

(comment

;; NOTE: prune-zero-weights and weight-leaves are currently commented
;; out above, so their examples below will not evaluate until those
;; definitions are restored.

(prune-zero-weights {[:a :b :c] 0
                     [:a :b] 3
                     [:a :b :c :d] 0})
;=>{[:a :b] 3}

(filter-alts {[:abc :alt 1] 7
              [] 8
              [:def :alt 2 :star] 9
              [:def :alt 2 :cat 2 :alt 3] 9})
;=>{[:abc :alt 1] 7, [:def :alt 2 :cat 2 :alt 3] 9}


(weight-leaves {[:b :a] 3
                [:a] 4
                [:b] 1
                [:a :b] 6
                [:a :c] 7
                [:a :b :c] 8
                [:a :c :b] 2
                [:a :c :d] 0})
;=>{[:a :b :c] 8, [:a :c :b] 2, [:a :c :d] 0, [:b :a] 3}

)

(defn weight-reducer-sqrt
  "Returns how much to subtract from weight so that what remains is the
  rounded square root of (weight - 1); e.g. 100 -> reduce by 90,
  leaving 10. A weight of 1 yields 1 (leaving 0) and 0 yields 0."
  [weight]
  (let [target (Math/round (Math/sqrt (dec weight)))]
    (- weight target)))

(defn weight-reducer-half
  "Returns how much to subtract from weight so that what remains is
  (weight - 1) / 2 rounded half up, i.e. roughly half the original;
  e.g. 100 -> reduce by 50, leaving 50. A weight of 1 yields 1
  (leaving 0) and 0 yields 0."
  [weight]
  ;; Use double, not float: float only has a 24-bit mantissa, so halves
  ;; of weights above ~2^24 would be rounded to the wrong integer.
  (- weight (Math/round (double (/ (- weight 1) 2)))))

(comment

;; Repeatedly apply each reducer to show the sequence of remaining
;; weights until it reaches 0.

(loop [a [] w 100]
  (if (= (last a) 0) a (recur (conj a w) (- w (weight-reducer-sqrt w)))))
;=>[100 10 3 1 0]

(loop [a [] w 100]
  (if (= (last a) 0) a (recur (conj a w) (- w (weight-reducer-half w)))))
;=>[100 50 25 12 6 3 1 0]

)


(defn group-by-alts
  "Groups the alternation entries of weights (those kept by filter-alts)
  by their path with the last element (the choice following :alt)
  dropped, as a vector. Returns a map of alt-path -> vector of
  [path weight] entries."
  [weights]
  (->> (filter-alts weights)
       (group-by (fn [entry] (vec (butlast (key entry)))))))

(defn group-alts
  "Returns a map from each alternation path (as grouped by
  group-by-alts) to a map of choice -> weight, where choice is the last
  element of each weight path. Non-alternation weights are ignored.
  e.g. {[:foo :a :alt 0] 6, [:foo :a :alt 1] 7}
       => {[:foo :a :alt] {0 6, 1 7}}"
  [weights]
  (reduce-kv
    (fn [acc alt-path entries]
      (assoc acc alt-path
             (into {} (map (fn [[path w]] [(last path) w]) entries))))
    {}
    (group-by-alts weights)))

(comment

;; Weights are grouped per alternation path, keyed by choice index.

(group-alts {[:foo :a :alt 0] 6
             [:foo :a :alt 1] 7
             [:foo :b :alt 0] 2
             [:foo :b :alt 1] 3
             [:foo :b :alt 2] 4})
;=>{[:foo :a :alt] {0 6, 1 7}, [:foo :b :alt] {0 2, 1 3, 2 4}}

)

(defn get-grammar-node
  "Get the grammar node for the given path in grammar. The first path
  element is a rule key in grammar. Each following tag must match the
  :tag of the current node; for a node with :parsers the tag is
  followed by a numeric child index, and for a node with a single
  :parser the tag descends into that parser. Nil is returned if the
  path does not exist in the grammar, a tag along the path doesn't
  match, or a numeric parser index is out of bounds (including
  negative indexes)."
  [grammar path]
  (loop [g (get grammar (first path))
         p (rest path)]
    (let [[t1 t2 & ts] p
          parsers (:parsers g)]
      (cond
        (empty? p)
        g

        (or (nil? g) (not= (:tag g) t1))
        nil

        ;; Check both bounds: a negative index previously passed the
        ;; upper-bound test and made nth throw instead of returning nil.
        (and (number? t2) parsers (<= 0 t2) (< t2 (count parsers)))
        (recur (nth parsers t2) ts)

        (:parser g)
        (recur (:parser g) (rest p))

        :else
        nil))))

(defn filter-weight-nodes
  "Returns the subset of weights whose path resolves (via
  get-grammar-node) to a grammar node satisfying pred?. Paths that do
  not resolve pass nil to pred?. For example, to keep only weights
  that refer to string or regex terminals use the predicate:
    #(#{:string :regex} (:tag %))"
  [grammar weights pred?]
  (into {}
        (for [entry weights
              :let [node (get-grammar-node grammar (key entry))]
              :when (pred? node)]
          entry)))

(defn- weights-ratios
  "Returns a seq of [path ratio] pairs for the entries of change-weights
  that are alternation paths (filter-alts) whose grammar node is a
  :string or :regex terminal. Each ratio is the entry's weight divided
  by the sum of all such weights. Returns an empty seq when there are
  no such entries or when their weights sum to zero."
  [grammar change-weights]
  (let [palts (filter-alts change-weights)
        literals-fn (fn [node] (#{:string :regex} (:tag node)))
        pliterals (filter-weight-nodes grammar palts literals-fn)
        ptotal (reduce + (vals pliterals))]
    ;; A zero total would make (/ v ptotal) throw "Divide by zero" for
    ;; integer weights or produce NaN ratios for floating-point weights;
    ;; treat it as "nothing to reduce" instead.
    (if (zero? ptotal)
      ()
      (for [[k v] pliterals] [k (/ v ptotal)]))))

(defn weights-reduce
  "Takes a grammar, a starting weight map (weights), a map of weights to
  change (change-weights) and a per-weight reducer function, and
  returns weights with the selected paths reduced. Only change-weights
  entries that are alternation paths resolving in grammar to :string
  or :regex terminals are used. Each such path gets a ratio r equal to
  its change-weight divided by the sum of all selected change-weights,
  and its original weight w becomes (int (- w (* r (reducer-fn w)))).
  So with a single selected path the weight is reduced by exactly
  (reducer-fn w); with several, each reduction is scaled by that
  path's share.
  NOTE(review): assumes every selected path exists in weights; a
  missing path passes nil to reducer-fn."
  [grammar weights change-weights reducer-fn]
  (reduce
    (fn [acc [path ratio]]
      (let [w (get acc path)
            reduced (- w (* ratio (reducer-fn w)))]
        (assoc acc path (int reduced))))
    weights
    (weights-ratios grammar change-weights)))

;; TODO: check whether any alt group has a zero total weight and raise
;; an exception. Eventually, for zero-weight alt groups, this should
;; search for alts earlier in the tree and reduce those to zero instead
;; of failing.
(defn check-weight-alts
  "Placeholder for validating alt-group weight totals (see the TODO
  above); not implemented yet. Takes no arguments and returns nil."
  []
  nil)

(comment

(loop [a [] w 100]
  (prn :w w)
  (if (> w 0)
    (recur (conj a w) (- w (weight-reducer-sqrt w)))
    (conj a 0)))
;=>[100 10 3 1 0]

;; weights-reduce takes the grammar as its first argument and a map of
;; change-weights; only alt paths that resolve to :string/:regex
;; terminals in the grammar are reduced.
(def tgrammar {:a {:tag :alt :parsers [{:tag :string :string "a"}]}
               :b {:tag :alt :parsers [{:tag :string :string "b"}]}
               :c {:tag :alt :parsers [{:tag :regex :regex "c+"}]}})

(weights-reduce tgrammar
                {[:a :alt 0] 100 [:b :alt 0] 10 [:c :alt 0] 25}
                {[:a :alt 0] 1}
                weight-reducer-sqrt)
;=>{[:a :alt 0] 10, [:b :alt 0] 10, [:c :alt 0] 25}

;; With two selected paths each reduction is scaled by its 1/2 share.
(weights-reduce tgrammar
                {[:a :alt 0] 100 [:b :alt 0] 10 [:c :alt 0] 25}
                {[:a :alt 0] 1 [:b :alt 0] 1}
                weight-reducer-sqrt)
;=>{[:a :alt 0] 55, [:b :alt 0] 6, [:c :alt 0] 25}

)

