(ns cortex.util
  (:refer-clojure :exclude [defonce])
  (:require
    [clojure.core.matrix :as m]
    [clojure.core.matrix.random :as rand-matrix]
    [cortex.core-matrix-backends :as b]
    [clojure.string :as str]
    [clojure.pprint :as pp]
    #?(:cljs [goog.string :refer [format]]))
  #?(:clj (:import [mikera.vectorz Vectorz]
                   [java.io Writer]
                   [java.util Random])))

;; JVM only: while this namespace compiles, warn about reflective interop calls
;; and use unchecked primitive math, warning wherever arithmetic falls back to
;; boxed numbers.
#?(:clj (do (set! *warn-on-reflection* true)
            (set! *unchecked-math* :warn-on-boxed)))

;;;; Vars
(defmacro defonce
  "Variant of clojure.core/defonce that also accepts a docstring.

  Expands to a def of name with args (an optional docstring followed by
  an optional init expression) that is only executed when name does not
  yet resolve to a var with a root binding in the namespace where the
  macro is expanded. Evaluates to the var when the def runs, nil
  otherwise."
  {:added "1.0"
   :arglists '([symbol doc-string? init?])}
  [name & args]
  ;; Capture the expanding namespace now so the runtime lookup targets it.
  (let [ns-sym (ns-name *ns*)]
    `(let [^clojure.lang.Var existing# (ns-resolve '~ns-sym '~name)]
       (when-not (and existing# (.hasRoot existing#))
         (def ~name ~@args)))))

(defmacro def-
  "Private counterpart of def, in the way defn- relates to defn: expands
  to a def whose name carries :private true in addition to any metadata
  already attached to it."
  [name & decls]
  `(def ~(vary-meta name assoc :private true) ~@decls))

;;;; Timing

(defmacro ctime*
  "Evaluates expr once and returns a map with two keys:
    :return - the value of expr
    :time   - CPU time consumed by the current thread while evaluating
              expr, in milliseconds (as reported by the JVM ThreadMXBean)."
  [expr]
  `(let [mx-bean# (java.lang.management.ManagementFactory/getThreadMXBean)
         cpu-before# (.getCurrentThreadCpuTime mx-bean#)
         value# ~expr
         cpu-after# (.getCurrentThreadCpuTime mx-bean#)
         ;; The bean reports nanoseconds.
         elapsed-nanos# (- cpu-after# cpu-before#)]
     {:time (/ elapsed-nanos# 1000000.0)
      :return value#}))

(defmacro ctime
  "Evaluates expr and returns a human-readable string reporting the CPU
  time it took, formatted like the message printed by clojure.core/time.
  The value of expr itself is discarded."
  [expr]
  `(let [measured# (ctime* ~expr)]
     (format "Elapsed time: %f msecs" (:time measured#))))

(defn clamp
  "Limits x to the closed interval [floor, ceil]: values above ceil yield
  ceil, values below floor yield floor. The floor is applied last, so it
  wins if floor is greater than ceil."
  ^double [^double floor ^double x ^double ceil]
  (let [capped (min x ceil)]
    (max floor capped)))

(defn relative-error
  "Calculates the relative error between two values:
  |a - b| / max(|a|, |b|).

  Returns 0.0 when both values are zero, where the ratio would otherwise
  be 0/0."
  ^double [^double a ^double b]
  (if (and (zero? a) (zero? b))
    ;; Double literal keeps both branches primitive doubles, so the result
    ;; is not boxed and re-cast to satisfy the ^double return type.
    0.0
    (/ (Math/abs (- a b))
       (max (Math/abs a) (Math/abs b)))))

(defn avg
  "Returns the arithmetic mean of its arguments as a double."
  [& xs]
  (let [total (double (m/esum xs))
        n (count xs)]
    (double (/ total n))))

;;;; Random number generation

;; Platform-specific random number primitives. Both branches define
;; rand-normal and rand-gaussian.
#?(:clj
   (do
     ;; Single java.util.Random instance shared by both generators below.
     (def ^Random RAND-GENERATOR (Random.))

     ;; NOTE(review): despite the name, this draws from Random.nextDouble,
     ;; i.e. a uniformly distributed value, not a normally distributed one.
     (defn rand-normal ^double []
       (.nextDouble RAND-GENERATOR))

     ;; Draws from Random.nextGaussian (standard normal distribution).
     (defn rand-gaussian ^double []
       (.nextGaussian RAND-GENERATOR)))
   :cljs
   (do
     ;; Gaussian sample shifted by mu and scaled by sigma. Candidates are
     ;; redrawn (via recur) until one passes the acceptance test.
     (defn rand-gaussian* [mu sigma]
       ;; This function implements the Kinderman-Monahan ratio method:
       ;;  A.J. Kinderman & J.F. Monahan
       ;;  Computer Generation of Random Variables Using the Ratio of Uniform Deviates
       ;;  ACM Transactions on Mathematical Software 3(3) 257-260, 1977
       (let [u1  (rand)
             u2* (rand)
             ;; presumably flipped so u2 is never 0, keeping the division
             ;; and Math/log below well defined -- assumes (rand) is in [0, 1)
             u2 (- 1. u2*)
             s (* 4 (/ (Math/exp (- 0.5)) (Math/sqrt 2.)))
             z (* s (/ (- u1 0.5) u2))
             zz (+ (* 0.25 z z) (Math/log u2))]
         ;; Reject and retry when z^2/4 > -log(u2).
         (if (> zz 0)
           (recur mu sigma)
           (+ mu (* sigma z)))))

     ;; NOTE(review): like the JVM version this is a plain (rand) draw, i.e.
     ;; uniform rather than normal.
     (defn rand-normal []
       (rand))

     ;; Standard gaussian sample (mu 0, sigma 1.0).
     (defn rand-gaussian []
       (rand-gaussian* 0 1.0))))

;;;; Sequences

(defn seq-like?
  "Decides whether x should be treated as a sequence of elements.

  nil and strings are never sequence-like. Anything implementing
  clojure.lang.Seqable is; any other value is sequence-like exactly when
  calling seq on it does not throw IllegalArgumentException (which covers
  arrays, Iterables and similar). Used by sformat to decide whether an
  argument is formatted element by element."
  [x]
  (cond
    (nil? x) false
    (string? x) false
    (instance? clojure.lang.Seqable x) true
    :else (try
            (seq x)
            true
            (catch IllegalArgumentException _
              false))))

(defn interleave-all
  "Lazily interleaves colls like clojure.core/interleave, but keeps going
  until every collection is exhausted: once a collection runs out it is
  simply skipped in later rounds."
  [& colls]
  (lazy-seq
    ;; keep drops the nils that seq returns for exhausted collections.
    (when-let [live (seq (keep seq colls))]
      (concat (map first live)
              (apply interleave-all (map rest live))))))

(defn pad
  "Returns coll unchanged when it has more than n elements; otherwise
  returns the first n elements of coll followed by pad-seq. The result
  is shorter than n if pad-seq runs out first."
  [n pad-seq coll]
  (if (empty? (drop n coll))
    (take n (concat coll pad-seq))
    coll))

;;;; Collections

(defn map-keys
  "Returns a new map in which every key k of map is replaced by (f k);
  values are left untouched. When several keys map to the same result,
  the entry processed last wins."
  [f map]
  (persistent!
    (reduce-kv (fn [acc k v]
                 (assoc! acc (f k) v))
               (transient {})
               map)))

(defn map-vals
  "Returns a new map with the same keys as map, in which every value v
  is replaced by (f v)."
  [f map]
  (persistent!
    (reduce-kv (fn [acc k v]
                 (assoc! acc k (f v)))
               (transient {})
               map)))

(defn map-entry
  "Builds a clojure.lang.MapEntry with key k and value v."
  [k v]
  (clojure.lang.MapEntry. k v))

(defn approx=
  "Equality check that tolerates floating-point noise. Returns true when
  all colls are equal, except that corresponding floating-point numbers
  may differ by a relative error of up to error. Arbitrarily nested
  collections are supported, and plain floating-point numbers may be
  passed instead of collections. Without floating-point content this
  behaves like regular =. Integers are always compared exactly, since
  they are not expected to carry rounding errors:

  (approx= 0.1 10 11) => false
  (approx= 0.1 10.0 11.0) => true

  The concrete collection types do not have to match, i.e. a set can be
  equal to a vector. This makes the function usable across core.matrix
  vectors of different implementations.

  See also approx-diff."
  [error & colls]
  (let [;; Depth-first flattening of each argument into its nodes.
        flattened (map #(tree-seq seq-like? seq %) colls)
        nodes-match? (fn [& items]
                       (cond
                         (every? float? items)
                         (<= (relative-error (apply min items)
                                             (apply max items))
                             ^double error)

                         ;; Branch nodes are compared through their children.
                         (every? seq-like? items)
                         true

                         :else
                         (apply = items)))]
    (and (apply = (map count flattened))
         (every? identity (apply map nodes-match? flattened)))))

(defn approx-diff
  "Recursively traverses the given data structures, which should be
  identical except for corresponding floating-point numbers and the
  concrete types of sequence-like objects.

  Returns a data structure of the same form as the first one provided,
  with floating-point numbers replaced with the relative errors of
  the corresponding sets of numbers. Element order is preserved: lists
  and other seqs come back as lists, other Clojure collections as an
  empty copy of the first collection filled with the results, and
  sequence-like values that are not Clojure collections (e.g. arrays)
  as vectors.

  Throws IllegalArgumentException when corresponding non-float leaves
  are not equal.

  See also approx=."
  [& colls]
  (cond
    (every? float? colls)
    (relative-error (apply min colls)
                    (apply max colls))

    (every? map-entry? colls)
    (map-entry (apply approx-diff (map key colls))
               (apply approx-diff (map val colls)))

    (every? seq-like? colls)
    ;; map hands approx-diff one element from each collection per call.
    (let [diffs (apply map approx-diff colls)
          template (first colls)]
      (cond
        ;; conj prepends to lists/seqs, so into would reverse the result.
        (seq? template) (apply list diffs)
        (coll? template) (into (empty template) diffs)
        ;; empty returns nil for non-collections; build a vector instead.
        :else (vec diffs)))

    (apply = colls)
    (first colls)

    :else
    (throw
      (IllegalArgumentException.
        (str "Divergent nodes: " colls)))))

(defn vectorize
  "Recursively converts a data structure into nested Clojure vectors while
  keeping its shape. Maps stay maps (with keys and values vectorized),
  strings and other non-sequence values are returned as they are, and
  every other sequence-like value -- including core.matrix vectors --
  becomes a plain vector.

  (vectorize (list 1 2 {\"key\" #{3 4} (list) (new-vector 3 5)} [6 7]))
  => [1 2 {\"key\" [4 3], [] [5 5 5]} [6 7]]"
  [data]
  (cond
    (map? data) (into {}
                      (for [[k v] data]
                        [(vectorize k) (vectorize v)]))
    (seq-like? data) (mapv vectorize data)
    :else data))

(defmacro extend-print
  "Installs print-dup and print-method implementations for class. Both
  write the string returned by calling str-fn on the object being
  printed."
  [class str-fn]
  `(do
     ;; Same method body for both multimethods; print-method is defined last.
     ~@(for [multi [`print-dup `print-method]]
         `(defmethod ~multi ~class
            [obj# ^Writer writer#]
            (.write writer# ^String (~str-fn obj#))))))

(defn calc-mean-variance
  "Computes the mean and the population variance (sum of squared
  deviations divided by the element count) over all elements of data.
  Returns {:mean m :variance v}; both are 0 when data has no elements."
  [data]
  (let [n (double (m/ecount data))
        total (double (m/esum data))]
    (if (zero? n)
      {:mean 0
       :variance 0}
      (let [mean (/ total n)
            squared-deviations (double
                                 (m/ereduce (fn [^double acc elem]
                                              (let [deviation (- mean (double elem))]
                                                (+ acc (* deviation deviation))))
                                            0.0
                                            data))]
        {:mean mean
         :variance (/ squared-deviations n)}))))


(defn ensure-gaussian!
  "Rescales the elements of data (via m/emap!) with a linear transform
  chosen so that their mean and variance, as measured by
  calc-mean-variance, become mean and variance. When the measured
  variance is not positive the elements are only shifted. Returns the
  result of m/emap!."
  [data ^double mean ^double variance]
  (let [stats (calc-mean-variance data)
        observed-mean (double (:mean stats))
        observed-variance (double (:variance stats))
        ;; Scaling values by k scales their variance by k^2, hence the sqrt.
        scale (Math/sqrt (double (if (pos? observed-variance)
                                   (/ variance observed-variance)
                                   1.0)))
        ;; Offset that moves the scaled mean onto the requested mean.
        shift (- mean (* observed-mean scale))
        rescale (fn [^double elem]
                  (+ (* scale elem)
                     shift))]
    (doall (m/emap! rescale data))))


;; The initialization-type keywords accepted by weight-initialization-variance
;; (and therefore by weight-matrix).
(defonce weight-initialization-types
  [:xavier
   :bengio-glorot
   :relu])


(defn weight-initialization-variance
  "Returns the variance to use when initializing the weights of a layer
  with n-inputs inputs and n-outputs outputs:

    :xavier        => 1 / n-inputs
    :bengio-glorot => 2 / (n-inputs + n-outputs)
    :relu          => 2 / n-inputs

  Throws an Exception for any other initialization-type. Background:
  http://andyljones.tumblr.com/post/110998971763/an-explanation-of-xavier-initialization"
  [^long n-inputs ^long n-outputs initialization-type]
  (case initialization-type
    :xavier (/ 1.0 n-inputs)
    :relu (/ 2.0 n-inputs)
    :bengio-glorot (/ 2.0 (+ n-inputs n-outputs))
    (throw (Exception. (format "%s fails to match any initialization type."
                               initialization-type)))))


(defn weight-matrix
  "Builds a randomly initialized weight matrix with n-output rows of
  n-input values each.

  Every row is drawn from rand-gaussian and then normalized with
  ensure-gaussian! to mean 0 and the variance that
  weight-initialization-variance gives for initialization-type
  (:xavier, :bengio-glorot or :relu). See
  http://andyljones.tumblr.com/post/110998971763/an-explanation-of-xavier-initialization.

  Without an initialization-type, :xavier is used, except that a 1x1
  matrix is returned as [[0]] directly."
  ([^long n-output ^long n-input initialization-type]
   (let [target-mean 0.0
         target-variance (weight-initialization-variance n-input n-output initialization-type)
         ;; Java's gaussian generator does not produce great gaussian values for
         ;; small n (mean and variance will be > 20% off). Even for large-ish
         ;; n (100-1000) the variance is usually off by around 10%, so each row
         ;; is corrected after sampling.
         gaussian-row (fn []
                        (-> (repeatedly n-input rand-gaussian)
                            vec
                            double-array
                            (ensure-gaussian! target-mean target-variance)))]
     (b/array (vec (repeatedly n-output gaussian-row)))))
  ([^long n-output ^long n-input]
   (if (= 1 n-output n-input)
     (b/array [[0]])
     (weight-matrix n-output n-input :xavier))))


(defn identity-matrix
  "Creates a square identity matrix with n-output rows and columns:
  ones on the diagonal, zeros everywhere else."
  ([^long n-output]
   (b/array (mapv (fn [^long idx]
                    ;; aset returns the value it stored, not the array, so
                    ;; doto is used to hand back the row itself.
                    (doto (double-array n-output)
                      (aset idx 1.0)))
                  (range n-output)))))



(defn random-matrix
  "Returns an array of shape shape-vector whose elements are sampled by
  clojure.core.matrix.random/sample-normal."
  [shape-vector]
  (rand-matrix/sample-normal shape-vector))

(defn assign-sparse-to-packed!
  "Copies the elements of every array in sparse-data, in order, into
  consecutive ranges of the vector packed-data. When packed-data is nil
  a new array sized to the total element count is created via
  b/new-array. Returns the packed array."
  [packed-data sparse-data]
  (let [target (or packed-data
                   (b/new-array [(reduce + (map m/ecount sparse-data))]))
        ;; Writes one item at offset and returns the offset for the next one.
        copy-item! (fn [^long offset item]
                     (let [flat (m/as-vector item)
                           flat-len (long (m/ecount flat))]
                       (m/assign! (m/subvector target offset flat-len) flat)
                       (+ offset flat-len)))]
    (reduce copy-item! 0 sparse-data)
    target))

(defn assign-packed-to-sparse!
  "Inverse of assign-sparse-to-packed!: fills each array in sparse, in
  order, from consecutive ranges of the vector packed. Returns the total
  number of elements consumed from packed."
  [sparse packed]
  (let [;; Fills one item from offset and returns the offset for the next one.
        fill-item! (fn [^long offset item]
                     (let [item-len (long (m/ecount item))
                           source (m/subvector packed offset item-len)]
                       (m/assign! (m/as-vector item) source)
                       (+ offset item-len)))]
    (reduce fill-item! 0 sparse)))

(defn zero-sparse!
  "Fills every array in sparse with 0.0 via m/fill!. Returns nil."
  [sparse]
  (run! #(m/fill! % 0.0) sparse))


(defn get-or-new-array
  "Returns the value stored under kywd in the associative data structure
  item; when that value is missing (or falsey), returns a new array of
  the given shape created with b/new-array."
  [item kywd shape]
  (if-let [existing (get item kywd)]
    existing
    (b/new-array shape)))

(defn get-or-array
  "Returns the value stored under kywd in the associative data structure
  item; when that value is missing (or falsey), returns a new array
  built from data with b/array."
  [item kywd data]
  (if-let [existing (get item kywd)]
    existing
    (b/array data)))

;; Defaults used by converges? when the options map omits :tolerance
;; or :max-tests.
(def DEFAULT-TOLERANCE 0.001)
(def DEFAULT-MAX-TESTS 100)

(defn converges?
  "Checks whether a sequence of array values converges to target.

  Walks sequence and compares (test-fn value) with target using m/equals
  and the given tolerance. As soon as hits-needed consecutive values
  match, the last of them is returned. Returns nil when the sequence
  ends (or yields a nil/false element) or max-tests values have been
  examined without success.

  Options (all optional):
    :tolerance   - passed to m/equals, defaults to DEFAULT-TOLERANCE
    :max-tests   - maximum number of values examined, defaults to
                   DEFAULT-MAX-TESTS
    :test-fn     - applied to each value before comparing, defaults to
                   identity
    :hits-needed - consecutive matches required, defaults to 1"
  ([sequence target]
   (converges? sequence target nil))
  ([sequence target {:keys [tolerance max-tests test-fn hits-needed] :as options}]
   (let [tolerance (or tolerance DEFAULT-TOLERANCE)
         max-tests (long (or max-tests DEFAULT-MAX-TESTS))
         test-fn (or test-fn identity)
         hits-needed (long (or hits-needed 1))]
     (loop [tested 0
            streak 0
            remaining (seq sequence)]
       (when (< tested max-tests)
         (when-let [candidate (first remaining)]
           (cond
             ;; A miss resets the run of consecutive hits.
             (not (m/equals target (test-fn candidate) tolerance))
             (recur (inc tested) 0 (next remaining))

             (>= (inc streak) hits-needed)
             candidate

             :else
             (recur (inc tested) (inc streak) (next remaining)))))))))

;;;; Errors


;; Raises a mikera.cljutils.Error whose message is the concatenation (via str)
;; of the given messages. A macro on the JVM, a plain function in ClojureScript.
#?(:clj
   (defmacro error
     "Throws an error with the provided message(s). This is a macro in order to try and ensure the
     stack trace reports the error at the correct source line number."
     ([& messages]
      ;; The throw form is expanded inline at the call site.
      `(throw (mikera.cljutils.Error. (str ~@messages)))))
   :cljs
   (defn error [& messages]
     (throw (mikera.cljutils.Error. (apply str messages)))))

(defmacro error?
  "Evaluates body and reports whether it threw: true when any Throwable
  escapes body, false when body completes normally. The value of body is
  discarded."
  [& body]
  `(try
     (do ~@body)
     false
     (catch Throwable _#
       true)))

;;;; Formatting

(defn parse-long
  "Parses the string x as a base-10 long, like (Long/parseLong x).
  Throws NumberFormatException when x is not a valid long."
  [x]
  (Long/valueOf ^String x))

(defn parse-double
  "Parses the string x as a double, like (Double/parseDouble x).
  Throws NumberFormatException when x is not a valid double."
  [x]
  (Double/valueOf ^String x))

(def fmt-string-regex
  "Matches a single Java format specifier, see
  https://docs.oracle.com/javase/8/docs/api/java/util/Formatter.html
  for details.

  Capture group 1 is the explicit argument number (the digits before
  `$`), or nil when the specifier is not positional. Capture group 2 is
  the remainder of the specifier after the leading `%` and the optional
  argument number: flags, width, precision and the conversion
  character(s)."
  #"%(?:([1-9][0-9]*)\$)?([-#+ 0,(]*(?:[1-9][0-9]*)?(?:\.[0-9]+)?(?:[bBhHsScCdoxXeEfgGaA%n]|[tT][HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc]))")

(defn sformat
  "Like format, but smarter. If one of the args is a collection,
  the format specifier is mapped over each element of the collection,
  and the results are placed in the formatted string as a vector.
  Also works on nested collections. For some format types, will
  attempt to cast the object before formatting. That is,
  (sformat \"%.1f\" 42) will return \"42.0\" rather than throwing an
  IllegalFormatConversionException.

  Specifiers that consume no argument (%% and %n) are always emitted
  once, even when the argument at the current position is a collection.

  Will not necessarily handle malformed format specifiers gracefully,
  but anything legal according to the Javadoc is fair game."
  [fmt & args]
  (let [fmt-strings
        (for [[fmt-string ^long arg-index fmt-specifier]
              ;; Resolve each specifier to the 1-based index of the argument
              ;; it refers to: positional specifiers carry their own index,
              ;; the others take the running index, which only advances past
              ;; specifiers that actually consume an argument.
              (loop [unprocessed (re-seq fmt-string-regex fmt)
                     processed []
                     arg-index 1]
                (if (seq unprocessed)
                  (let [match (first unprocessed)
                        positional? (nth match 1)
                        argless? (#{\% \n} (last (nth match 2)))]
                    (recur (rest unprocessed)
                           (conj processed (update match 1 (if positional?
                                                             parse-long
                                                             (constantly arg-index))))
                           (if (or positional? argless?)
                             arg-index
                             (inc arg-index))))
                  processed))]
          (let [arg-index (dec arg-index)
                conversion (last fmt-specifier)
                argless? (#{\% \n} conversion)
                arg (nth args arg-index nil)]
            ;; %% and %n do not belong to any argument, so they must not be
            ;; expanded element-wise over whatever argument comes next.
            (if (and (not argless?) (seq-like? arg))
              (str "["
                   (str/join " "
                             (map (fn [item]
                                    (sformat (str "%" fmt-specifier)
                                             item))
                                  arg))
                   "]")
              (format (str "%" fmt-specifier)
                      ;; Coerce numbers so integer/float conversions accept
                      ;; either kind of numeric argument.
                      (case conversion
                        (\e \E \f \g \G \a \A) (double arg)
                        (\d \o \x \X) (long arg)
                        arg)))))
        ;; Literal text between the specifiers.
        splits (str/split fmt fmt-string-regex)]
    (apply str (interleave-all splits fmt-strings))))

(defn sprintf
  "printf counterpart of sformat: formats fmt and args with sformat and
  prints the resulting string."
  [fmt & args]
  (let [rendered (apply sformat fmt args)]
    (print rendered)))

;;;; Neural networks

(defn mse-gradient-fn
  "Returns the gradient of the mean squared error for a given output and
  target value, i.e. 2 * (output - target), computed in a mutable array
  obtained from output via m/mutable."
  [output target]
  (doto (m/mutable output)
    (m/sub! target)
    (m/scale! 2.0)))
