(ns circle-util.core
  (:refer-clojure :exclude [bean remove filter])
  (:require [clojure.core.typed :as t :refer (AnyInteger Option)]
            [clojure.string :as str]
            [clojure.tools.logging :refer (infof)]
            [circle-util.type.ann])
  (:import (clojure.lang IRef
                         Named)
           java.security.MessageDigest
           (java.io InputStream
                    ByteArrayInputStream
                    PushbackReader
                    StringReader)
           java.net.URL))

;; core.typed directive for vars in this namespace that have no t/ann
;; (several defs below, e.g. assoc-when-in and ->int-list, are unannotated).
;; NOTE(review): presumably makes the checker warn about such vars instead of
;; failing — confirm against the core.typed docs for the version in use.
(t/warn-on-unannotated-vars)

(t/ann atom-set! [(t/Atom1 t/Any) t/Any -> t/Any])
(defn atom-set!
  "Unconditionally sets the value of atom a to v, returning v."
  [a v]
  (reset! a v))

(t/ann printfln [String t/Any * -> nil])
(defn printfln
  "printf followed by a newline: formats fmt with args, prints the result to
  *out*, then writes a newline. Returns nil."
  [fmt & args]
  (let [line (apply format fmt args)]
    (print line)
    (newline)))

(t/ann prnlog [String t/Any * -> nil])
(defn prnlog
  "Formats fmt with args (as clojure.core/format does), prints the resulting
  message to *out* followed by a newline, and logs the same message at INFO
  level via clojure.tools.logging."
  [fmt & args]
  (let [s (apply format fmt args)]
    (println s)
    ;; s is already formatted: pass it as an argument to a literal "%s" so any
    ;; % characters it contains are not re-interpreted as format directives.
    (infof "%s" s)))

(t/ann ^:no-check apply-map [(t/IFn [t/Any -> t/Any]) t/Any * -> t/Any])
(defn apply-map
  "Takes a fn and any number of arguments. Applies the arguments like
  apply, except that the last argument, which must be a map (or nil), is
  spread into key/value pairs, for functions that take keyword arguments.

  (apply-map foo :a :b {:c 1 :d 2}) => (foo :a :b :c 1 :d 2)

  Throws AssertionError (when assertions are enabled) if the last argument is
  neither a map nor nil."
  [f & args*]
  (let [normal-args (butlast args*)
        m (last args*)]
    ;; nil (including "no args at all") is treated as an empty map
    (when m
      (assert (map? m) "last argument must be a map"))
    ;; (seq m) yields [k v] entries; concatenating them flattens to k1 v1 k2 v2 ...
    (apply f (concat normal-args (apply concat (seq m))))))

(t/ann in? (t/All [x] [(t/Seqable x) x -> (Option Boolean)]))
(defn in?
  "Linear (O(n)) membership test: returns true when some element of seq is
  = to val, otherwise nil."
  [seq val]
  (some (partial = val) seq))

(t/ann ^:no-check seq1 (t/All [x] [(t/Seqable x) -> (t/Seqable x)]))
(defn seq1
  "Converts a normal seq, with chunking behavior, to one-at-a-time. See http://blog.fogus.me/2010/01/22/de-chunkifying-sequences-in-clojure/"
  [#^clojure.lang.ISeq s]
  ;; Returns a wrapper that implements only clojure.lang.ISeq (not
  ;; IChunkedSeq), delegating to s. Every tail it hands out (more/next/seq)
  ;; is wrapped in seq1 again, so consumers never reach s's chunked tails.
  ;; cons/empty/equiv/count delegate straight to s and are NOT re-wrapped.
  ;; NOTE(review): s is type-hinted as ISeq and used without a nil check, so a
  ;; nil s presumably throws NullPointerException once the wrapper is used,
  ;; and a non-ISeq (e.g. a vector) presumably throws ClassCastException;
  ;; callers likely need to pass (seq coll) of a non-empty coll — confirm.
  (reify clojure.lang.ISeq
    (first [_] (.first s))
    ;; more must never return nil; it wraps whatever (possibly empty) tail s gives
    (more [_] (seq1 (.more s)))
    ;; next/seq return nil at the end, otherwise a re-wrapped seq
    (next [_] (let [sn (.next s)] (and sn (seq1 sn))))
    (seq [_] (let [ss (.seq s)] (and ss (seq1 ss))))
    (count [_] (.count s))
    (cons [_ o] (.cons s o))
    (empty [_] (.empty s))
    (equiv [_ o] (.equiv s o))))

(t/ann ^:no-check apply-if [t/Any (t/IFn [t/Any -> t/Any]) t/Any * -> t/Any])
(defn apply-if
  "When test is logically true, returns (apply f arg args); otherwise returns
  arg untouched."
  [test f arg & args]
  (if-not test
    arg
    (apply f arg args)))

(defn assoc-when-in
  "assoc-in, but only if the nested sequence of keys already exists!
  Returns (assoc-in m ks v) when every key in ks is present along the path in
  m (a key mapped to nil counts as present), otherwise returns m unchanged."
  ;; assoc-in has exactly one arity [m ks v]; the previous `& args` were
  ;; forwarded to it and could only ever cause an ArityException.
  [m ks v]
  ;; A fresh Object can never be a real value in m, so seeing it back from
  ;; get-in means some key along the path is missing.
  (let [sentinel (Object.)]
    (if-not (identical? sentinel (get-in m ks sentinel))
      (assoc-in m ks v)
      m)))

(defn update-when-in
  "Like update-in, but a no-op (returns m unchanged) unless every key in ks is
  already present along the path in m. A key mapped to nil still counts as
  present."
  [m ks f & args]
  (let [missing (Object.)
        current (get-in m ks missing)]
    (if (identical? current missing)
      m
      (apply update-in m ks f args))))

(t/ann byte-array-to-hex-string [(Array byte) -> String])
(defn byte-array-to-hex-string
  "Encodes the byte array ba as a hexadecimal String using commons-codec's
  Hex/encodeHex."
  [ba]
  (let [^chars hex-chars (org.apache.commons.codec.binary.Hex/encodeHex ba)]
    (String. hex-chars)))

(t/ann ^:no-check digest [String (t/U (Array byte) String) -> String])
(defn digest
  "Hashes s with the java.security.MessageDigest algorithm named algo
  (e.g. \"SHA-1\", \"MD5\") and returns the digest as a hex string.
  s may be a byte array, or a String, which is hashed as its UTF-8 bytes."
  [algo s]
  (let [input (if (string? s)
                (.getBytes ^String s "UTF-8")
                s)
        hasher (MessageDigest/getInstance algo)]
    (byte-array-to-hex-string (.digest hasher input))))

(t/ann sha1 [(t/U (Array byte) String) -> String])
(defn sha1
  "Returns the SHA-1 digest of s as a hex string. Takes a byte array or
  string (a string is hashed as its UTF-8 bytes)."
  [s]
  (digest "SHA-1" s))

(t/ann md5 [(t/U (Array byte) String) -> String])
(defn md5
  "Returns the MD5 digest of s as a hex string. Takes a byte array or
  string (a string is hashed as its UTF-8 bytes)."
  [s]
  (digest "MD5" s))

;; One shared generator instead of a new one per call: creating and seeding a
;; SHA1PRNG instance every time is needlessly expensive, and SecureRandom
;; instances may be used from multiple threads. Wrapped in a delay so a
;; missing algorithm still surfaces at call time, not at namespace load.
(defonce ^:private secure-random-instance
  (delay (java.security.SecureRandom/getInstance "SHA1PRNG")))

(defn secure-rand
  "SecureRandom equivalent of clojure.core/rand: returns a random double
  between 0 (inclusive) and n (default 1) (exclusive)."
  ([] (.nextDouble ^java.security.SecureRandom @secure-random-instance))
  ([n] (* n (secure-rand))))

(defn secure-rand-int
  "SecureRandom equivalent of clojure.core/rand-int: (secure-rand n)
  truncated to an int."
  [n]
  (-> n secure-rand int))

(defn- ->interval-number
  "Returns the 0-based number of the interval that index falls into.

  intervals lists the exclusive end position of each consecutive interval of
  some list; index is a position in that same list.

  intervals must be non-empty and contain only positive numbers."
  [intervals index]
  {:pre [(seq intervals)
         (every? pos? intervals)]}
  ;; Every interval ending at or before index lies wholly before it, so the
  ;; number of such leading intervals is the number of index's own interval.
  (->> intervals
       (take-while (fn [end] (<= end index)))
       count))

(defn weighted-rand-value
  "Passed a map of value to weight, returns a random value taking weights
  into account: each value is returned with probability
  weight / (sum of all weights).
  E.g.
  (weighted-rand-value {:red 1, :blue 3, :yellow 10}) is most likely to return
  :yellow.

  weights must be non-empty, and every weight must be a positive number
  (integer or not).

  Inspired by http://stackoverflow.com/a/14467227"
  [weights]
  {:pre [(seq weights)
         (every? pos? (vals weights))]}
  (let [;; one seq of entries, so values and weights are read in the same order
        entries (seq weights)
        ;; Lay the values end to end, each occupying a span as long as its
        ;; weight. intervals holds the (exclusive) end of each span:
        ;; (1 4 14) for the docstring example.
        intervals (reductions + (map val entries))
        ;; A uniform point in [0, total). rand (rather than rand-int) gives the
        ;; same distribution for integer weights and also honours fractional ones.
        point (rand (last intervals))
        selected-interval (->interval-number intervals point)]
    (key (nth entries selected-interval))))

(defmacro defn-once
  "Defines (via defonce) a zero-argument fn named name. The first call
  evaluates body and caches the result; every later call returns that cached
  value without re-running body.

  Example:
    (defn-once foo
      (bar)
      (time/now))"
  [name & body]
  `(let [result# (delay ~@body)]
     (defonce ~name (fn [] @result#))))

(defmacro defn-partial
  "Defines a new fn that is a partial of another fn, e.g.
  (defn-partial foo bar arg1) behaves like (def foo (partial bar arg1)).

  Copies :doc and :arglists from f's var onto the new var (keeping the new
  var's own :name, :ns, :line, etc.). One leading fixed parameter is removed
  from every arglist per arg passed to defn-partial. f should be a symbol
  naming a var.

  The part-args expressions are evaluated on every call of the new fn, and f
  is looked up on every call, so redefinitions of f are picked up."
  [name f & part-args]
  (let [arg-count (count part-args)
        ;; drop bound params from the fixed part of an arglist, keeping any
        ;; variadic tail (& more) intact
        trim (fn [arglist]
               (let [[fixed variadic] (split-with #(not= '& %) arglist)]
                 (vec (concat (drop arg-count fixed) variadic))))
        old-meta (meta (resolve f))
        copied (merge (select-keys old-meta [:doc])
                      (when-let [arglists (:arglists old-meta)]
                        {:arglists (apply list (map trim arglists))}))]
    `(do
       (def ~name (fn [& new-args#]
                    ;; splice the partial args individually; they are values,
                    ;; not collections to concatenate
                    (apply ~f ~@part-args new-args#)))
       (alter-meta! (var ~name) merge (quote ~copied)))))

(t/ann ->int (t/IFn [String -> AnyInteger]
                    [String AnyInteger -> AnyInteger]))
(defn ->int
  "Parses int-str as an int via Integer/parseInt, in the given base
  (radix, default 10)."
  ([int-str]
   (->int int-str 10))
  ([int-str base]
   (Integer/parseInt int-str (int base))))

(t/ann -?>int [String -> (t/Option t/Int)])
(defn -?>int
  "Like ->int, but yields nil rather than throwing when int-str cannot be
  parsed (NumberFormatException)."
  [int-str]
  (try (->int int-str)
       (catch NumberFormatException _ nil)))

(defn ->int-list
  "Converts a string int-list into a lazy seq of ints. int-list is a
  comma-separated list of entries, each either a single int or an inclusive
  range lo-hi, e.g. \"1\", \"1,2\", \"1-3,5\" => (1 2 3 5).
  Whitespace around entries and range bounds is ignored (\"1, 2 - 3\").
  Throws NumberFormatException (when the seq is realized) on an unparseable
  entry."
  [int-list-str]
  (->> (str/split int-list-str #",")
       (map #(str/split % #"-" 2))
       (mapcat (fn [[lo hi :as bounds]]
                 ;; trim so \"1, 2\" parses instead of failing on \" 2\"
                 (if (= 1 (count bounds))
                   (list (->int (str/trim lo)))
                   (range (->int (str/trim lo))
                          (inc (->int (str/trim hi)))))))))

(t/ann ->double [String -> Double])
(defn ->double
  "Parses double-str into a Double with Double/parseDouble; throws on
  unparseable input."
  [double-str]
  (-> double-str
      (Double/parseDouble)))

(t/ann -?>double [(t/U String nil) -> (t/U Double nil)])
(defn -?>double
  "->double, but returns nil (instead of throwing) if the conversion fails,
  including when double-str is nil."
  [double-str]
  ;; Double/parseDouble throws NullPointerException (not NumberFormatException)
  ;; for nil, which the catch below would not handle; guard it explicitly.
  (when double-str
    (try
      (->double double-str)
      (catch NumberFormatException _
        nil))))

(t/ann ->long (t/IFn [String -> AnyInteger]
                     [String AnyInteger -> AnyInteger]))
(defn ->long
  "Parses long-str as a long via Long/parseLong, in the given base
  (radix, default 10)."
  ([long-str]
   (->long long-str 10))
  ([long-str base]
   (Long/parseLong long-str (int base))))

(t/ann -?>long [String -> (t/Option AnyInteger)])
(defn -?>long
  "Like ->long, but yields nil rather than throwing when the conversion
  fails; any Exception raised by the parse is swallowed."
  [long-str]
  (try (->long long-str)
       (catch Exception _ nil)))

(t/ann ^:no-check to-name (t/IFn [(t/U Named String) -> String]
                                 [t/Any -> t/Any]))
(defn to-name
  "Returns (name x) when x is a String or Named (keyword, symbol, ...);
  otherwise returns x unchanged (including nil)."
  [x]
  ;; Check the type up front rather than catching the exception that (name x)
  ;; throws for unsupported types; the catch-all also swallowed unrelated errors.
  (if (or (string? x) (instance? Named x))
    (name x)
    x))


(t/ann log2 [(t/U double long Double AnyInteger) -> Number])
(defn log2
  "Returns the base-2 logarithm of v, computed as ln(v) / ln(2)."
  [v]
  (let [ln-2 (Math/log 2.0)]
    (/ (Math/log (double v)) ln-2)))

;; core.typed alias for a watch fn over a reference holding values of type x:
;; presumably the (key, reference, old-value, new-value) callback registered
;; with add-watch — confirm.
;; NOTE(review): the declared return type x looks stricter than needed if the
;; watch fn's return value is ignored by the caller; verify before relying on it.
(t/defalias WatchFn (t/All [x] [t/Any (t/Ref1 x) x x -> x]))

(t/ann get-watches (t/All [x] [(t/Ref1 x) -> (t/Map t/Any t/Any)]))
;; should be -> (t/Map t/Any WatchFn), but can't prove it
(defn get-watches
  "Returns the map of watches registered on ref (an IRef such as an atom),
  as reported by IRef.getWatches."
  [ref]
  (let [^IRef watched ref]
    (.getWatches watched)))

(defn read-string-all
  "Takes a string of clojure code and reads every form in it, returning a
  fully realized seq of the forms in order. An empty (or whitespace/comment
  only) string yields an empty seq.

  NOTE: uses clojure.core/read, which honours *read-eval*; do not call this on
  untrusted input unless *read-eval* is bound to false."
  [clj-str]
  (let [reader (PushbackReader. (StringReader. clj-str))
        ;; unique EOF sentinel, so forms that read as nil/false don't stop
        ;; reading early and EOF doesn't throw
        eof (Object.)]
    ;; doall: consume the reader here so read errors surface at the call site
    ;; instead of wherever the lazy seq happens to be walked
    (doall (take-while #(not (identical? eof %))
                       (repeatedly #(read reader false eof))))))

(defn name-kw
  "Returns the full name of keyword k, namespace included, unlike name:
  (name :foo/bar)    => \"bar\"
  (name-kw :foo/bar) => \"foo/bar\""
  [k]
  (let [printed (str k)]
    ;; drop the leading \":\" from the keyword's printed form
    (subs printed 1)))

(t/ann fixed-point (t/All [x] [[x -> x] -> [x -> x]]))
(defn fixed-point
  "Takes a fn f of one argument and returns a fn of one argument that keeps
  applying f, starting from its argument, until f returns a value = to its
  input; that value is returned. Throws an Exception after roughly 100000
  iterations without convergence, so only use on fns that converge!"
  [f]
  (t/fn [arg :- x]
    (t/loop [current :- x, arg
             remaining :- Long, 100000] ;; iterations left before giving up
      (let [candidate (f current)]
        (cond
          (= candidate current) candidate
          (pos? remaining) (recur candidate (dec remaining))
          :else (throw (Exception. "no fixed point is found in so many iterations.  Is function convergent?")))))))

(defn bounding
  "Clamps n into the range [floor, ceil]: caps it at ceil, then raises it to
  at least floor."
  [n floor ceil]
  (-> n
      (min ceil)
      (max floor)))

(defn third
  "Returns the element at index 2 of l, using nth."
  [l]
  (-> l (nth 2)))

(defn ->url
  "Constructs a java.net.URL from the string s."
  [s]
  (new URL s))

(defn ungroup-by
  "Inverse of group-by: concatenates all the value collections of m into a
  single seq."
  [m]
  (mapcat identity (vals m)))

(defn stacktrace->string
  "Returns the text that (.printStackTrace e) would print, as a String."
  [^Throwable e]
  (let [buffer (java.io.StringWriter.)]
    (.printStackTrace e (java.io.PrintWriter. buffer))
    (str buffer)))

(t/ann string->stream [String -> InputStream])
(defn string->stream
  "Takes a string, returns an input stream that will return the string's
  characters encoded as UTF-8 bytes."
  [x]
  ;; Encode explicitly as UTF-8: the no-arg .getBytes uses the platform default
  ;; charset, so the stream's bytes would otherwise vary between JVMs.
  (-> x
      (str)
      (.getBytes "UTF-8")
      (ByteArrayInputStream.)))

(defmacro when-seq
  "bindings => [xs ys]
  Binds xs to ys and evaluates body only when (seq xs) is non-nil; i.e. the
  same as (let [xs ys] (when (seq xs) body))."
  [bindings & body]
  (let [sym  (nth bindings 0 nil)
        init (nth bindings 1 nil)]
    `(let [~sym ~init]
       (when (seq ~sym)
         ~@body))))

(defn bean
  "Like clojure.core/bean, but copies the result into an ordinary persistent
  map, one that implements iterator."
  [& args]
  ;; need to call seq on the bean before pouring it into a map
  (->> args
       (apply clojure.core/bean)
       seq
       (into {})))

(defn remove
  "Returns a new, eager, collection of the same type as coll consisting of the
  items in coll for which (pred item) returns logical false, in coll's order.
  Lists and other seqs (including lazy seqs) come back as a list. Pred must be
  free from side-effects."
  [pred coll]
  (let [kept (clojure.core/remove pred coll)]
    (if (seq? coll)
      ;; (into () ...) conj'es each item onto the front of a list, which would
      ;; reverse the result; rebuild seqs as a list in their original order.
      (apply list kept)
      (into (empty coll) kept))))

(defn filter
  "Returns a new, eager, collection of the same type as coll consisting of the
  items in coll for which (pred item) returns logical true, in coll's order.
  Lists and other seqs (including lazy seqs) come back as a list. Pred must be
  free from side-effects."
  [pred coll]
  (let [kept (clojure.core/filter pred coll)]
    (if (seq? coll)
      ;; (into () ...) conj'es each item onto the front of a list, which would
      ;; reverse the result; rebuild seqs as a list in their original order.
      (apply list kept)
      (into (empty coll) kept))))
