(ns shannon.core
  (:require [shannon.ngrams.en :as en]
            [shannon.compatibility :refer [abs arr->num bits->bytearray bytearray->bits
                                           char->int decompose-date fInfinity fNaN get-float-parts
                                           int->char is-boolean? is-date? is-float? is-infinite?
                                           is-nan? num->arr pow recompose-date round roundd]]
            [shannon.distributions :refer [benford-distribution cdf constant-distribution
                                           custom-distribution geometric-distribution
                                           inverse-cdf next-higher next-lower uniform-distribution
                                           zipf-distribution]]
                                        )

        
  (:require-macros [shannon.macros :as m])

       
                                                                                
                                                                                         
                                                                                                
                                
                               
                                                                              )

                                      

;;;

;; Arithmetic-coding interval parameters. Interval endpoints live on the
;; integer scale 0..whole; `coding-precision` is the number of input bits
;; the decoder reads into its window (see with-bits!).
;; NOTE(review): `whole` hard-codes 48 via (m/pow2 48), presumably 2^48,
;; rather than deriving it from `coding-precision`; keep the two in sync.
(def ^:private coding-precision 48)
(def ^:private whole   (m/pow2 48))
(def ^:private half    (/ whole 2.))
(def ^:private quarter (/ whole 4.))

;;;

;; Something bits can be read from one at a time.
(defprotocol EmitsBits
  ;; Consume and return the next bit. Callers treat any truthy value as a
  ;; 1 bit and any falsey value as a 0 bit.
  (read! [o])
  ;; Return the underlying storage the bits are read from.
  (source [o]))

;; Something bits can be appended to one at a time.
(defprotocol ConsumesBits
  ;; Append one bit; the arithmetic coder writes booleans.
  (write! [o bit])
  ;; Return the underlying storage the bits were written into.
  (target [o]))

;; Encodes/decodes a single symbol against a probability distribution.
(defprotocol DistributionCoder
  ;; Encode `sym` under `distribution`. The immutable ArithmeticCoder
  ;; returns an updated coder; StatefulCoder updates its atom instead.
  (encode-symbol [c distribution sym])
  ;; Decode one symbol. ArithmeticCoder returns [sym next-coder];
  ;; StatefulCoder returns only the symbol.
  (decode-symbol [c distribution]))

;; Binds a coder to a bit stream and flushes pending output.
(defprotocol BitCoder
  ;; Attach an input bit source, priming the coder from it.
  (with-bits! [c bits])
  ;; Flush any pending state to the output and return it.
  (finalize! [c]))

;; (deftype SeqBitSource [input]
;;   EmitsBits
;;   (read! [_]
;;     (let [b (first @input)]
;;       (swap! input rest)
;;       b))
;;   (source [_] @input))

;; (deftype ConjBitSink [output]
;;   ConsumesBits
;;   (write! [_ bit]
;;     (swap! output conj bit))
;;   (target [_] @output)

  ;; BitCoder
  ;; (finalize! [_] output))

     
                                        
           
            
        
                          
                           
              
                                             

           
                             

     
                                        
              
                                             
                                               

          
                                   

           
                              

      
;; Bit source backed by a mutable JavaScript array. Bits are consumed from
;; the front of the array, so reading mutates `input`.
(deftype JSArrayBitSource [input]
  EmitsBits
  ;; `(js/shift input)` would call a non-existent global `shift` function;
  ;; the Array method has to be invoked with `.shift`. An exhausted array
  ;; yields undefined, which readers treat as a 0 bit.
  (read! [_] (.shift input))
  (source [_] input))

      
;; Bit sink backed by a mutable JavaScript array; each written bit is
;; appended to `output`.
(deftype JSArrayBitSink [output]
  ConsumesBits
  ;; Appending must go through the Array `.push` method; `(js/push ...)`
  ;; would resolve to an undefined global function.
  (write! [_ bit] (.push output bit))
  (target [_] output)

  BitCoder
  ;; Nothing is buffered, so finalizing just hands back the array.
  (finalize! [_] output))

;; Adapts a value into a bit source (for decoding) or bit sink (for encoding).
;; NOTE(review): no ClojureScript implementation of this protocol is visible
;; in this namespace (the extend-protocol below is commented out) — confirm
;; where, if anywhere, it is extended.
(defprotocol AsBits
  (bit-source [o])
  (bit-sink [o]))

     
                       
             
                                                                

              
                                                              

            
                                                                
                                                              

                     
                                                                     

                     
                                                                   

                                                
                                                          

;; (extend-protocol AsBits
;;   #+clj clojure.lang.PersistentVector
;;   #+cljs cljs.core.PersistentVector
;;   (bit-source [o] (SeqBitSource. (atom o)))
;;   (bit-sink [o] (ConjBitSink. (atom o))))

(defn default-bit-sink
  "Returns a fresh, empty array-backed bit sink for compression output.

  JSArrayBitSink already implements ConsumesBits and BitCoder, which is all
  the arithmetic coder needs, so it is returned directly. It used to be
  passed through `bit-sink`, but AsBits is not extended to JSArrayBitSink in
  this namespace, so that call threw a missing-protocol error."
  []
  (JSArrayBitSink. (array)))

;;;

(defn- update-interval
  "Narrows the current coding interval [la lb] to the sub-interval given by
  [cx dx] (endpoints on the 0..whole scale), rounding each new endpoint.
  If either of cx/dx is nil the current interval is returned unchanged."
  [[^number la ^number lb] [^number cx ^number dx]]
  (if (or (nil? cx) (nil? dx))
    [la lb]
    (let [width (- lb la)
          place (fn [x] (+ la (roundd (* width (/ x whole)))))]
      [(place cx) (place dx)])))

(defn- rescale-interval
  "Scales a CDF interval [a b] (assumed to lie in [0, 1]) onto integers on
  the 0..whole scale, then moves each endpoint inward by one, except that an
  endpoint equal to 0 or to whole stays put. A lower endpoint of 0 takes
  priority: in that case only the upper endpoint is decremented."
  [[^number a ^number b]]
  (let [lo (roundd (* whole a))
        hi (roundd (* whole b))]
    (if (zero? lo)
      [lo (dec hi)]
      [(inc lo) (if (== hi whole) hi (dec hi))])))

(defn- rescaled-inverse-cdf
  "Finds the symbol whose rescaled CDF interval contains the normalized
  point `znorm`. It starts from the distribution's inverse-cdf guess and
  steps to neighbouring symbols (allowing one unit of slack on each side)
  until the point falls inside. Returns [symbol rescaled-interval]."
  [distribution ^number znorm]
  (let [target (roundd (* whole znorm))]
    (loop [candidate (inverse-cdf distribution znorm)]
      (let [[lo hi :as interval] (rescale-interval (cdf distribution candidate))]
        (cond
          (< target (dec lo)) (recur (next-lower distribution candidate))
          (> target (inc hi)) (recur (next-higher distribution candidate))
          :else [candidate interval])))))

(defn- emit-bits
  "Writes bit `i` to `bits`, then `s` copies of its complement (the pending
  bits accumulated by the coder). Returns `bits`."
  [bits ^boolean i ^number s]
  (write! bits i)
  (dotimes [_ s]
    (write! bits (not i)))
  bits)

;; Immutable arithmetic coder: every operation returns a new ArithmeticCoder
;; rather than mutating this one.
;;   a, b - current interval endpoints on the 0..whole scale
;;   s    - count of pending bits: incremented on each middle-half rescale,
;;          then written (as the complement of the next decided bit) by
;;          emit-bits
;;   z    - decoder only: the window of `coding-precision` input bits
;;   bits - bit sink (when encoding) or bit source (when decoding)
(deftype ArithmeticCoder
                                [^number a ^number b ^number s ^number z bits]
  DistributionCoder
  ;; Narrow [a b] to the symbol's (rescaled) CDF interval, then renormalize:
  ;; while the interval lies wholly in the lower/upper half, emit the decided
  ;; bit plus the pending bits and double the interval; once it straddles the
  ;; midpoint, keep doubling around the midpoint while it fits inside the
  ;; middle half, counting one pending bit per doubling.
  (encode-symbol [_ distribution sym]
    (let [[a b] (->> (cdf distribution sym)
                     (rescale-interval)
                     (update-interval [a b]))]
      (loop [a a b b s s bits bits]
        (cond
         ;; entirely in the lower half: the next output bit is 0
         (< b half)
         (recur (* 2. a) (* 2. b) 0
                (emit-bits bits false s))

         ;; entirely in the upper half: the next output bit is 1
         (> a half)
         (recur (* 2. (- a half)) (* 2. (- b half)) 0
                (emit-bits bits true s))

         :else
         ;; straddles the midpoint: expand the middle half and defer the bit
         (loop [a a b b s s]
           (if (and (> a quarter)
                    (< b (- whole quarter)))
             (recur (* 2. (- a quarter))
                    (* 2. (- b quarter)) (inc s))
             (ArithmeticCoder. a b s z bits)))))))

  ;; Mirror of encode-symbol: locate the symbol containing z (normalized to
  ;; the current interval), narrow [a b] to it, then apply the same
  ;; renormalization steps, shifting one fresh input bit into z per doubling.
  ;; Returns [sym next-coder].
  (decode-symbol [_ distribution]
    ;; shift-in helper: add 1 when the next input bit is truthy
    (letfn [(update-z [z bits] (if (read! bits) (inc z) z))]
      (let [znorm (/ (- z a) (- b a))
            [sym cdf-interval] (rescaled-inverse-cdf distribution znorm)
            [a b] (update-interval [a b] cdf-interval)]
        (loop [a a b b z z]
          (cond
           (< b half)
           (recur (* 2. a) (* 2. b)
                  (double (update-z (* 2. z) bits)))

           (> a half)
           (recur (* 2. (- a half)) (* 2. (- b half))
                  (double (update-z (* 2. (- z half)) bits)))

           :else
           (loop [a a b b z z]
             (if (and (> a quarter) (< b (- whole quarter)))
               (recur (* 2. (- a quarter)) (* 2. (- b quarter))
                      (double (update-z (* 2. (- z quarter)) bits)))
               [sym (ArithmeticCoder. a b s z bits)])))))))

  BitCoder
  ;; Prime a decoder: read `coding-precision` bits from newbits, most
  ;; significant first, into z, and return a coder bound to newbits.
  (with-bits! [_ newbits]
    (loop [i (dec coding-precision) z 0.]
      (if (>= i 0)
        (recur (dec i) (+ z (if (read! newbits) (pow 2 i) 0.)))
        (ArithmeticCoder. a b s z newbits))))

  ;; Flush the encoder: emit 0 (if a <= quarter) or 1, followed by s+1
  ;; complementary bits, then finalize the sink. Returns the sink itself.
  (finalize! [_]
    (let [s (inc s)]
      (if (<= a quarter)
        (emit-bits bits false s)
        (emit-bits bits true s))
      (finalize! bits)
      bits))
  )

;; Mutable wrapper around an immutable coder: the atom `coder` holds the
;; current coder value and is replaced after every operation.
(deftype StatefulCoder [coder]
  DistributionCoder
  (encode-symbol [_ distribution sym]
    ;; Returns the updated inner coder (the value swap! produces).
    (swap! coder encode-symbol distribution sym))
  (decode-symbol [_ distribution]
    ;; The inner coder yields [symbol next-coder]; keep the coder and hand
    ;; back only the symbol.
    (let [[decoded next-coder] (decode-symbol @coder distribution)]
      (reset! coder next-coder)
      decoded))

  BitCoder
  (finalize! [_]
    (-> @coder finalize!)))

(defn- empty-coder
  "Returns a fresh ArithmeticCoder spanning the whole interval [0, whole],
  optionally bound to `bit-io` (a bit source or sink)."
  ([] (empty-coder nil))
  ([bit-io] (ArithmeticCoder. 0. whole 0 0. bit-io)))

(defn input-stateful-coder
  "Builds a StatefulCoder for decoding, primed with the first bits read
  from `source`."
  [source]
  (let [primed (with-bits! (empty-coder) source)]
    (StatefulCoder. (atom primed))))

(defn output-stateful-coder
  "Builds a StatefulCoder for encoding that writes into `sink`, or into a
  new default bit sink when none is given."
  ([] (output-stateful-coder (default-bit-sink)))
  ([sink] (StatefulCoder. (atom (empty-coder sink)))))

;;;

;; Anything that can write a value to, or read a value from, a StatefulCoder.
;;   (encode o scoder)                  - value-as-receiver: o encodes itself
;;   (encode coder scoder sym)          - `coder` encodes `sym`
;;   (encode coder scoder sym typehint) - as above with an explicit type key
;;   (decode coder scoder)              - read one value
(defprotocol Codeable
  (encode [o scoder] [o scoder sym] [o scoder sym typehint])
  (decode [o scoder]))

;; Leaf coder: codes one symbol directly against `distribution` through the
;; scoder's DistributionCoder methods.
(deftype AtomCoder [distribution]
  Codeable
  (encode [_ scoder value]
    (encode-symbol scoder distribution value))
  (decode [_ scoder]
    (->> distribution (decode-symbol scoder))))

(defn uniform
  "Coder backed by (uniform-distribution N); presumably N equally likely
  symbols 0..N-1."
  [N] (AtomCoder. (uniform-distribution N)))
(defn zipf
  "Coder backed by (zipf-distribution N)."
  [N] (AtomCoder. (zipf-distribution N)))
(defn geometric
  "Coder backed by (geometric-distribution mu)."
  [mu] (AtomCoder. (geometric-distribution mu)))
(defn benford
  "Coder backed by (benford-distribution base digits)."
  [base digits] (AtomCoder. (benford-distribution base digits)))
(defn custom
  "Coder over an explicit distribution; `options` maps each symbol to its
  probability, e.g. {true 0.5, false 0.5}."
  [options] (AtomCoder. (custom-distribution options)))
(defn constant
  "Coder for a single fixed `value` (for example, nil-coder uses (constant nil))."
  [value] (AtomCoder. (constant-distribution value)))

;; Codes a fixed-length sequence positionally: element i is coded with the
;; i-th coder in `coders`.
(deftype FixedArrayCoder [coders]
  Codeable
  (encode [_ scoder syms]
    ;; Walk (seq coders), not the raw collection: an empty coders vector is
    ;; truthy, so the old loop would go on to encode nil with a nil coder.
    (loop [c (seq coders), s syms]
      (when c
        (encode (first c) scoder (first s))
        (recur (next c) (next s)))))
  (decode [_ scoder]
    ;; Every read consumes bits from the shared, stateful scoder, so it has
    ;; to happen now and in order. A lazy `map` would defer the reads until
    ;; the caller realizes the result, possibly after other values have been
    ;; decoded, which would corrupt the stream.
    (doall (map (fn [c] (decode c scoder)) coders))))

(defn fixed-array
  "Returns a coder for fixed-length sequences whose i-th element is coded
  with the i-th coder in `coders`."
  [coders]
  (FixedArrayCoder. coders))

;; coders
;; Zipf coder over 2^32 symbols (presumably favouring small values).
(def uint32coder (zipf (m/pow2 32)))

;; Building blocks for UInt64Coder:
;;   lencoder     - word-count tag: 1 (high word zero) or 2
;;   short64coder - the single word when the tag is 1
;;   long64coder  - both words of (num->arr n) when the tag is 2
;;                  (presumably [low high]; confirm against num->arr)
(def ^:private uniform-uint32coder (uniform (m/pow2 32)))
(def ^:private lencoder (custom {1 0.6, 2 0.4}))
(def ^:private short64coder uint32coder)
(def ^:private long64coder (fixed-array [uniform-uint32coder uint32coder]))

;; Coder for non-negative integers split into two 32-bit words by num->arr.
;; It writes a one/two-word tag first, then either just the first word or
;; both words.
(deftype UInt64Coder []
  Codeable
  (encode [_ scoder sym]
    (let [words (num->arr (round sym))
          one-word? (zero? (nth words 1))]
      (encode lencoder scoder (if one-word? 1 2))
      (if one-word?
        (encode short64coder scoder (nth words 0))
        (encode long64coder scoder words))))
  (decode [_ scoder]
    (arr->num
     (if (> (decode lencoder scoder) 1)
       (decode long64coder scoder)
       [(decode short64coder scoder) 0]))))

(def uint64coder (UInt64Coder.))

;; Wraps `coder`: values are passed through `encodefn` before encoding and
;; through `decodefn` after decoding.
(deftype TransformCoder [coder encodefn decodefn]
  Codeable
  (encode [_ scoder sym]
    (let [encoded (encodefn sym)]
      (encode coder scoder encoded)))
  (decode [_ scoder]
    (-> (decode coder scoder) decodefn)))

(defn transform
  "Returns a coder that maps values through `encodefn` before delegating to
  `coder`, and through `decodefn` after decoding."
  [coder encodefn decodefn]
  (TransformCoder. coder encodefn decodefn))

(defn offset-zipf
  "Coder for integers in [low, high] built from two zipf coders around `mid`.
  Values below mid are coded by their distance below it (mid-1 -> 0) and
  values from mid up by their distance above it (mid -> 0). Each side is
  chosen with probability proportional to its size."
  [low mid high]
  (let [total-range (inc (- high low))
        low-range (- mid low)
        high-range (inc (- high mid))
        below (transform (zipf low-range)
                         (fn [x] (dec (- mid x)))
                         (fn [k] (- mid (inc k))))
        above (transform (zipf high-range)
                         (fn [x] (- x mid))
                         (fn [k] (+ k mid)))]
    (m/switch {:low  {:pr (/ low-range total-range)
                      :test #(< % mid)
                      :coder below}
               :else {:pr (/ high-range total-range)
                      :coder above}})))

;; int32coder is defined further down as (signed uint32coder). An earlier
;; offset-zipf based definition stood here, but that later redefinition
;; immediately overwrote it, so it was dead code and has been removed.

(def ^:private signcoder (uniform 2))

;; Adds a sign to a coder for non-negative numbers: a sign symbol (1 for
;; negative values, including -0.0; 0 otherwise) followed by the magnitude
;; coded with `numcoder`.
(deftype SignCoder [numcoder]
  Codeable
  (encode [_ scoder sym]
    ;; (neg? -0.0) is false, so detect negative zero via 1/-0.0 = -Infinity.
    ;; Without this, -0.0 would round-trip as +0.0.
    (let [negative? (or (neg? sym)
                        (and (zero? sym) (neg? (/ 1 sym))))
          sign (if negative? 1 0)]
      (encode signcoder scoder sign)
      (encode numcoder scoder (abs sym))))
  (decode [_ scoder]
    (let [sign (decode signcoder scoder)
          num (decode numcoder scoder)]
      (if (pos? sign)
        (- num)
        num))))

(defn signed
  "Returns a coder that writes a sign symbol followed by the magnitude of
  the value, coded with `coder`."
  [coder]
  (SignCoder. coder))

;; Signed integer coders: a sign symbol followed by the magnitude, coded with
;; the 32-bit zipf coder or the 64-bit coder respectively.
(def int32coder (signed uint32coder))
(def int64coder (signed uint64coder))

(defn split-float64
  "Splits a finite double into [exponent high-21-bits low-32-bits] for the
  double coder. The exponent is the one reported by get-float-parts, offset
  by 51. NOTE(review): the precise contract of get-float-parts lives in
  shannon.compatibility; confirm the offset there."
  [^number x]
  (let [[mantissa exponent] (get-float-parts x)
        [low32 high21] (num->arr mantissa)]
    [(+ exponent 51) high21 low32]))

(defn join-float64
  "Inverse of split-float64: rebuilds a double from exponent `e` and the two
  mantissa words. e = -1023 is treated as subnormal, giving
  mantissa * (m/pow2 -1074); otherwise the result is
  2^e + 2^(e-52) * mantissa."
  [^number e ^number h21 ^number l32]
  (let [mantissa (arr->num [l32 h21])]
    (cond
      (== e -1023) (* mantissa (m/pow2 -1074))
      :else (+ (pow 2 e) (* (pow 2 (- e 52)) mantissa)))))

;; Building blocks for UDoubleCoder.
;; NOTE(review): small-pr is not referenced anywhere in this file.
(def ^:private small-pr (/ 1 (m/pow2 32)))
;; Exponent in [-1023, 1024], centred on 0; 1024 is reserved as the marker
;; for NaN/Infinity.
(def ^:private exp (offset-zipf -1023 0 1024))
;; The mantissa as two uniformly coded words: high 21 bits, then low 32 bits.
(def ^:private unif-h21 (uniform (m/pow2 21)))
(def ^:private unif-l32 (uniform (m/pow2 32)))
(def ^:private mcoder (fixed-array [unif-h21 unif-l32]))

;; Coder for the magnitude of a double (SignCoder supplies the sign).
;; Exponent 1024 is reserved for non-finite values: mantissa words [0 0]
;; mean Infinity and [0 1] mean NaN.
(deftype UDoubleCoder []
  Codeable
  (encode [_ scoder sym]
    (let [[e h21 l32] (cond
                        (is-nan? sym) [1024 0 1]
                        (is-infinite? sym) [1024 0 0]
                        :else (split-float64 sym))]
      (encode exp scoder e)
      (encode mcoder scoder [h21 l32])))

  (decode [_ scoder]
    (let [e (decode exp scoder)
          [h21 l32] (decode mcoder scoder)]
      (cond
        (not (== e 1024)) (join-float64 e h21 l32)
        (zero? l32) fInfinity
        :else fNaN))))

;; Full double coder: sign symbol followed by the UDoubleCoder magnitude.
(def doublecoder (signed (UDoubleCoder.)))

;; Booleans, each value with probability 0.5.
(def booleancoder (custom {true 0.5, false 0.5}))

;; Adapts a 0-based coder to 1-based values (used for months and days).
(defn one-indexed [coder] (transform coder dec inc))

;; Per-field coders for dates.
;; Years 0..9999, centred on 2015.
(def ^:private year (offset-zipf 0 2015 9999))
(def ^:private month (one-indexed (uniform 12)))
;; Month lengths; the leap table differs only in February.
(def ^:private days-in-month
  (let [normal [31 28 31 30 31 30 31 31 30 31 30 31]]
    {:normal normal
     :leap   (assoc normal 1 29)}))
;; Day-of-month coders keyed by month length.
(def ^:private days
  (into {}
        (for [n [28 29 30 31]]
          [n (one-indexed (uniform n))])))
(def ^:private hours (uniform 24))
(def ^:private minutes (uniform 60))
(def ^:private seconds-59 (uniform 60))
;; 61 values, admitting second 60 (see sec-coder).
(def ^:private seconds-60 (uniform 61))
(def ^:private milliseconds (uniform 1000))

(defn- leap-year?
  "Gregorian leap-year test using bit masks: divisible by 4, and either not
  divisible by 100 or divisible by 400. Given divisibility by 4, the check
  for 100 reduces to 25 and the check for 400 to 16 (times 25)."
  [year]
  (let [div4?  (zero? (bit-and year 3))
        div16? (zero? (bit-and year 15))
        div25? (zero? (mod year 25))]
    (and div4? (or (not div25?) div16?))))

(defn- day-coder
  "Returns the day-of-month coder for 1-based `month` of `year`, chosen by
  that month's length."
  [year month]
  (let [table (get days-in-month (if (leap-year? year) :leap :normal))]
    (get days (nth table (dec month)))))

(defn- sec-coder
  "Seconds coder for minute `mins`: minute 59 admits second 60 (61 values);
  every other minute uses 0-59."
  [mins]
  (if (== 59 mins)
    seconds-60
    seconds-59))

;; Coder for dates decomposed into fields by decompose-date (presumably UTC,
;; per the type name). The day and second coders depend on fields coded
;; earlier (year/month and minute), so encode and decode must visit the
;; fields in the same order.
(deftype UTCDateCoder []
  Codeable
  (encode [_ scoder sym]
    ;; fields: [year month day hour minute second millisecond]
    (let [[yr mo d hr mn s ms] (decompose-date sym)]
      (encode year scoder yr)
      (encode month scoder mo)
      (encode (day-coder yr mo) scoder d)
      (encode hours scoder hr)
      (encode minutes scoder mn)
      (encode (sec-coder mn) scoder s)
      (encode milliseconds scoder ms)))
  (decode [_ scoder]
    (let [yr (decode year scoder)
          mo (decode month scoder)
          d (decode (day-coder yr mo) scoder)
          hr (decode hours scoder)
          mn (decode minutes scoder)
          s (decode (sec-coder mn) scoder)
          ms (decode milliseconds scoder)]
      (recompose-date yr mo d hr mn s ms))))

;; Shared UTCDateCoder instance.
(def date-coder (UTCDateCoder.))

;; Codes a sequence of any length: the element count (via `countcoder`),
;; then each element (via `coder`).
(deftype VariableArrayCoder [coder countcoder]
  Codeable
  (encode [_ scoder sym]
    (encode countcoder scoder (count sym))
    (doseq [s sym] (encode coder scoder s)))
  (decode [_ scoder]
    (let [cnt (decode countcoder scoder)]
      ;; Realize eagerly. Each element read consumes bits from the shared,
      ;; stateful scoder, so a lazy `repeatedly` would defer the reads and
      ;; interleave them wrongly with whatever is decoded next (e.g. nested
      ;; arrays realized outer-first).
      (doall (repeatedly cnt #(decode coder scoder))))))

(defn variable-array
  "Returns a coder for variable-length sequences of values coded with
  `coder`. The length is coded with `countcoder` (default uint64coder)."
  ([coder]
     (VariableArrayCoder. coder uint64coder))
  ([coder countcoder]
     (VariableArrayCoder. coder countcoder)))

;; Codes an associative collection with integer keys in [0, maxcnt): a
;; presence flag per index, followed by the value when present. Decodes to
;; a map.
(deftype SparseArray [maxcnt valuecoder]
  Codeable
  (encode [_ scoder sym]
    (dotimes [i maxcnt]
      (let [present? (contains? sym i)]
        (encode booleancoder scoder present?)
        (when present?
          (encode valuecoder scoder (get sym i))))))
  (decode [_ scoder]
    (reduce (fn [acc i]
              (if (decode booleancoder scoder)
                (assoc acc i (decode valuecoder scoder))
                acc))
            {}
            (range maxcnt))))

(defn sparse-array
  "Returns a coder for sparse collections with integer keys below `maxcnt`,
  whose values are coded with `valuecoder`."
  [maxcnt valuecoder]
  (SparseArray. maxcnt valuecoder))

(declare default-coder)

;; Coder for values of any registered type. Encoding dispatches on the value
;; itself through its own 2-arity Codeable implementation; decoding delegates
;; to default-coder, which reads the type flag first.
;; NOTE(review): the `type-coder` field is never read. When `polymorphic` is
;; constructed below, default-coder is only declared, not yet defined, so the
;; captured value is presumably nil/undefined rather than the registry.
(deftype PolymorphicCoder [type-coder]
  Codeable
  (encode [_ scoder sym]
    (encode sym scoder))
  (decode [_ scoder]
    (decode default-coder scoder)))

(def polymorphic (PolymorphicCoder. default-coder))

(defn- type-directory->coder
  "Builds a Codeable over a type directory (name -> {:pr :coder :test}).
  Each value is written as a type flag (distributed by the :pr weights)
  followed by the value coded with that type's :coder."
  [directory]
  (let [dist (zipmap (keys directory)
                     (map :pr (vals directory)))
        flag (custom dist)]
    (reify Codeable
      ;; The caller names the type explicitly.
      (encode [_ scoder sym typehint]
        (encode flag scoder typehint)
        (encode (get-in directory [typehint :coder]) scoder sym))
      ;; Use the first registered type whose :test accepts the value.
      (encode [_ scoder sym]
        (if-let [[k spec] (some (fn [[_ v :as entry]]
                                  (when ((:test v) sym) entry))
                                directory)]
          (do (encode flag scoder k)
              (encode (:coder spec) scoder sym))
          ;; Previously this fell through to calling nil as a function,
          ;; which failed with an opaque error.
          (throw (ex-info "shannon: no registered type accepts value"
                          {:value sym}))))
      (decode [_ scoder]
        (let [k (decode flag scoder)]
          (decode (get-in directory [k :coder]) scoder))))))

(defn- get-types-coder
  "Returns the coder cached in the `types-coder` atom. If the cache is empty,
  it first builds one from the current `type-directory` and caches it."
  [type-directory types-coder]
  (if-let [cached @types-coder]
    cached
    (reset! types-coder (type-directory->coder @type-directory))))

;; A registry of polymorphic types. `type-directory` is an atom of
;; name -> {:pr :coder :test}; `types-coder` caches the coder built from it
;; and is reset to nil by register-type!/unregister-type!.
(deftype RegisteredTypesCoder [type-directory types-coder]
  Codeable
  (encode [_ scoder sym typehint]
    (-> (get-types-coder type-directory types-coder)
        (encode scoder sym typehint)))
  (encode [_ scoder sym]
    (-> (get-types-coder type-directory types-coder)
        (encode scoder sym)))
  (decode [_ scoder]
    (-> (get-types-coder type-directory types-coder)
        (decode scoder))))

(defn register-type!
  "Adds (or replaces) type `name` in `registry`'s type directory with the
  given :coder, :test predicate and :pr weight (default 1). It then clears
  the registry's cached coder so the coder is rebuilt on next use."
  [registry name & {:keys [coder test pr] :or {pr 1}}]
  {:pre [(not (nil? (and registry name coder test)))]}
  (let [entry {:pr pr
               :coder coder
               :test test}]
    (swap! (.-type-directory registry) assoc name entry))
  (reset! (.-types-coder registry) nil))

(defn unregister-type!
  "Removes type `name` from `registry` and clears its cached coder."
  [registry name]
  (let [directory (.-type-directory registry)
        cache (.-types-coder registry)]
    (swap! directory dissoc name)
    (reset! cache nil)))

;; The global type registry used by `polymorphic`, by compress with no coder,
;; and by decompress. It starts empty; types are added via register-type!
;; (and, presumably, the register-polymorphic-type! macro calls below).
(def default-coder (RegisteredTypesCoder. (atom nil) (atom nil)))

;; Raw bytes: each byte uniform over 256 values; byte arrays are
;; length-prefixed.
(def byte-coder (uniform 256))
(def bytes-coder (variable-array byte-coder))

;; A single character, mapped through char->int / int->char and coded
;; uniformly over 2^16 values.
(def unicode-char-coder
  (transform (uniform (m/pow2 16))
             char->int
             int->char))

(defn string-coder
  "Returns a coder for strings whose characters are coded one at a time
  with `character-coder`."
  [character-coder]
  (let [->chars (partial map str)
        ->string (partial apply str)]
    (transform (variable-array character-coder) ->chars ->string)))

(def unicode-coder (string-coder unicode-char-coder))

(defn unigram-coder
  "Character coder that uses the `unigrams` distribution for characters it
  contains (chosen with probability `model-pr`) and falls back to
  unicode-char-coder for all other characters."
  [unigrams model-pr]
  (let [modeled-coder (custom unigrams)
        fallback-pr (- 1 model-pr)]
    (m/switch {:modeled {:pr model-pr
                         :coder modeled-coder
                         :test #(contains? unigrams %)}
               :else    {:pr fallback-pr
                         :coder unicode-char-coder}})))

(defn language-coder
  "String coder whose characters are modeled by `unigrams` (see
  unigram-coder)."
  [unigrams model-pr]
  (-> (unigram-coder unigrams model-pr) string-coder))

(def english-coder (language-coder en/unigrams 0.98))

(defn- average-codepoint
  "Mean character code of at most the first 20 characters of `s`; 0 for an
  empty string."
  [s]
  (let [codes (map char->int (take 20 s))]
    (/ (reduce + codes) (max 1 (count codes)))))

;; String coder that uses english-coder when the average code of the first
;; 20 characters is below 2^7 (i.e. the text looks mostly ASCII), and the
;; flat unicode-coder otherwise. The flag weights are 0.8 / 0.2.
(def adaptive-string-coder
  (m/switch {:english {:pr 0.8
                       :coder english-coder
                       :test #(< (average-codepoint %) (m/pow2 7))}
             :else    {:pr 0.2
                       :coder unicode-coder}}))

;; Symbols and keywords are coded as their names using english-coder.
(def symbol-coder (transform english-coder str symbol))
;; (str :k) is ":k"; subs drops the leading colon before coding.
(def keyword-coder (transform english-coder #(subs (str %) 1) keyword))
(def nil-coder (constant nil))

;; Generic collections: the element count followed by each element coded
;; polymorphically. coll-coder returns the decoded elements as a sequence;
;; the coders below rebuild a specific collection type.
(def coll-coder (variable-array polymorphic))

;; (into () (reverse xs)) rebuilds a list in the original element order.
(def list-coder (transform (variable-array polymorphic)
                           identity #(into () (reverse %))))

(def vec-coder (transform (variable-array polymorphic)
                          identity vec))

;; Maps are coded as a sequence of [key value] entries.
(def map-coder (transform (variable-array (fixed-array [polymorphic polymorphic]))
                          identity #(into {} (map vec %))))

(def set-coder (transform (variable-array polymorphic)
                          identity set))

;; Relative type-flag weights (:pr) for atomic and collection types.
(def atom-pr 1.)
(def coll-pr 1.)

;; Standard atoms: each registration names the type key, the host type(s),
;; and the :pr weight, :coder and :test predicate.
;; Registration order is kept exactly as before.
(m/register-polymorphic-type! :nil nil
                              :pr atom-pr :coder nil-coder :test nil?)

(m/register-polymorphic-type! :boolean boolean
                              :pr atom-pr :coder booleancoder :test is-boolean?)

(m/register-polymorphic-type! :symbol cljs.core.Symbol
                              :pr atom-pr :coder symbol-coder :test symbol?)

(m/register-polymorphic-type! :date [goog.date.Date goog.date.DateTime js/Date]
                              :pr atom-pr :coder date-coder :test is-date?)

(m/register-polymorphic-type! :string string
                              :pr atom-pr :coder adaptive-string-coder :test string?)

(m/register-polymorphic-type! :keyword cljs.core.Keyword
                              :pr atom-pr :coder keyword-coder :test keyword?)

     
   
                                                                                                         
                                                                                                         

      
;; JS can't dispatch on integer vs. float, so numbers are registered with
;; explicit tests and always encoded with an explicit typehint.
(do
  (register-type! default-coder :integer :pr atom-pr :coder int64coder :test integer?)
  (register-type! default-coder :float :pr atom-pr :coder doublecoder :test is-float?)
  (extend-protocol Codeable
    number
    (encode [o s]
      (let [kind (if (integer? o) :integer :float)]
        (encode default-coder s o kind)))
    (decode [_ s]
      (decode default-coder s))))

;; Standard collections. Registration order is kept exactly as before.
(m/register-polymorphic-type! :list [cljs.core.List
                                     cljs.core.EmptyList
                                     cljs.core.LazySeq]
                              :pr coll-pr :coder list-coder :test list?)

(m/register-polymorphic-type! :vector cljs.core.PersistentVector
                              :pr coll-pr :coder vec-coder :test vector?)

(m/register-polymorphic-type! :map [cljs.core.PersistentArrayMap
                                    cljs.core.PersistentHashMap
                                    cljs.core.TransientHashMap
                                    cljs.core.PersistentTreeMap]
                              :pr coll-pr :coder map-coder :test map?)

(m/register-polymorphic-type! :set [cljs.core.PersistentHashSet
                                    cljs.core.TransientHashSet
                                    cljs.core.PersistentTreeSet]
                              :pr coll-pr :coder set-coder :test set?)

(defn- with-output
  "Runs `f` against a fresh output StatefulCoder, then finalizes the coder
  and returns the sink's target (the compressed bits)."
  [f]
  (-> (output-stateful-coder)
      (doto f)
      finalize!
      target))

(defn- with-input
  "Runs `f` against a StatefulCoder that reads from `source`, and returns
  f's result. `source` may be:
    - an EmitsBits implementation, used as is;
    - a value implementing AsBits, adapted with bit-source;
    - a plain JavaScript array of bits, read from a copy so the caller's
      array is not consumed."
  [f source]
  (let [bits (cond
               (satisfies? EmitsBits source) source
               (satisfies? AsBits source) (bit-source source)
               ;; AsBits has no array implementation in this namespace, so
               ;; raw arrays are wrapped directly.
               (array? source) (JSArrayBitSource. (aclone source))
               :else (bit-source source))]
    (f (input-stateful-coder bits))))

(defn compress
  "Compresses `o` and returns the compressed bits. With one argument the
  value encodes itself through its Codeable implementation; otherwise
  `coder` encodes it."
  ([o] (with-output (fn [sc] (encode o sc))))
  ([o coder] (with-output (fn [sc] (encode coder sc o)))))

(defn decompress
  "Decodes one value from `source` using `coder`, or the global
  default-coder when no coder is given."
  ([source] (decompress source default-coder))
  ([source coder] (with-input (fn [sc] (decode coder sc)) source)))

;;;;;;;;;;;; This file autogenerated from src/cljx/shannon/core.cljx
