(ns nicolasoury.distributions.core
  (:import java.lang.Math java.util.Arrays
	   [clojure.lang Seqable IEditableCollection ITransientCollection Counted IPersistentCollection]  )
    (:gen-class)) 

(comment "An atomic distribution is a limited but faster distribution.
          It is atomic in the sense that it does not try to draw from
          any distribution it may contain.
          A loop can then keep drawing for as long as possible.")

;; Ask the compiler to warn whenever an interop call cannot be resolved
;; statically, so missing type hints surface as reflection warnings.
(set! *warn-on-reflection* true)

;; Contract of the building blocks of a sum distribution: report a total mass
;; and, for a number r (callers pass mass * (random)), return the entry r
;; selects without recursing into any distribution that entry may contain.
(definterface AtomicDistribution
  (^double atomicMass [])
  (^Object atomicDraw [^double r]))

;; General distribution contract: a total mass and a draw operation.
(definterface Distribution
  (^double mass [])
  ;; `extra` is threaded through every nested draw call, for callers that
  ;; need to carry additional state.
  (draw [extra]))


;; Identity token of one transient editing session. A node belongs to the
;; session whose token it holds; ownership is tested with identical?.
(deftype TransientVersion [])

;; Nodes that can produce a mutable (transient) version of themselves.
(definterface Transientable
  ;; Returns a transient node tagged with the token tv.
  (asTransient [tv])
  (is_transient [])
  (is_my_transient [tv]))

(defn fresh-transient-version
  "Returns a brand-new TransientVersion token."
  []
  (new TransientVersion))




;; Transient nodes that can be edited in place. Each operation returns the
;; node the caller must use from then on (it may be a replacement node).
;; The overloads taking hashCode let a parent pass down an already computed hash.
(definterface InPlaceUpdatableDistribution
  (insert_BANG_ [^Object x ^double m])
  (insert_BANG_ [^Object x ^int hashCode ^double m])
  (delete_BANG_ [^Object x])
  (delete_BANG_ [^Object x ^int hashCode]))


(defmacro obsolete-double
  "Sentinel stored in cached masses to mark them as stale and in need of
   recomputation."
  []
  -1537.51)



(defmacro linear-threshold
  "Capacity (as an int) of the LinearDistribution created when the empty
   distribution becomes transient."
  []
  (list `int 8))

(defmacro bucket-size
  "Capacity (as an int) of the LinearDistribution bucket created when a
   SingletonDistribution becomes transient."
  []
  `(int 6))

(defmacro random
  "Expands to a uniform random double in [0, 1) (java.lang.Math/random)."
  []
  (list 'java.lang.Math/random))


(defmacro random-within
  "Expands to a uniform random double in [0, (.mass x)), where x is a
   Distribution. x is evaluated exactly once.
   (The previous expansion did not unquote x, so it referred to a var named x
   in this namespace instead of the caller's expression.)"
  [x]
  `(let [^Distribution d# ~x]
     (* (random) (.mass d#))))



;; Layout of the hash trie used by HashNodeDistribution: each level consumes
;; mask-size-num bits of an element's hash, so a node has block-size-num slots.
(def mask-size-num 5)
(def block-size-num (bit-shift-left 1 mask-size-num))
(def mask-num (dec block-size-num))

(defmacro mask-size [] (list `int mask-size-num))
(defmacro block-size [] (list `int block-size-num))

(defmacro masked-value
  "Expands to the mask-size-num-bit slice of the int i starting at bit p
   (the slot index of i at depth p). p is better as an int."
  [i p]
  (list `bit-and (list `bit-shift-right i p) (list `int mask-num)))

(defmacro valid-level?
  "True while depth p is below 32, i.e. still inside the hash bits."
  [p]
  (list `< p (list `int 32)))

(defmacro next-level
  "Depth of the level directly below depth p."
  [p]
  (list `+ p (list `mask-size)))

 
;; Forward declaration; linear-distribution is defined later in this file.
(declare linear-distribution)

(comment "A wrapper for persistent distributions.
	 Its goal is to make sure the distribution is held in
	 a final field, so the transient is up to date when published.")

(declare atomic-distribution-draw)

;; Persistent view over an atomic distribution node. Reads (mass, count, seq,
;; draw) delegate to the wrapped node; edits go through asTransient, which
;; asks the wrapped node for a transient version.
(deftype SumDistribution [^AtomicDistribution dist]
  Distribution
  (mass [this] (.atomicMass dist))
  (draw [this extra]
    ;; Pick an element with atomic-distribution-draw, then draw from it.
    (let [^Distribution picked (atomic-distribution-draw dist)]
      (.draw picked extra)))
  AtomicDistribution
  (^double atomicMass [this] (.atomicMass ^AtomicDistribution dist))
  (atomicDraw [this r] dist)
  Counted
  (count [this] (count ^Counted dist))
  Seqable
  (seq [this] (seq dist))
  Transientable
  (asTransient [this version] (.asTransient ^Transientable dist version))
  (is_transient [this] false)
  (is_my_transient [this version] false)
  IEditableCollection
  (asTransient [this]
    (.asTransient ^Transientable dist (fresh-transient-version)))
  IPersistentCollection
  ;; Identity equality only.
  (equiv [this other] (identical? this other)))

;; The class that persistent distributions are wrapped in. The original var
;; name is misspelled; it is kept for existing callers, and a correctly
;; spelled alias is provided alongside it.
(def persistent-disribution-class SumDistribution)
(def persistent-distribution-class SumDistribution)
(declare singleton-distribution)

;; Internal node of the hash trie. Each of the (block-size) slots holds nil or
;; a child atomic distribution (SingletonDistribution, LinearDistribution or
;; HashNodeDistribution) selected by the bits of the element's hash at
;; bit-depth. `mass` and `my_count` are caches: (obsolete-double) and -1 mark
;; them as stale, and they are recomputed lazily.
(deftype HashNodeDistribution
  [^TransientVersion transientVersion
   ^int bit-depth
   ^{:unsynchronized-mutable true} ^int my_count
   ^{:unsynchronized-mutable true} ^double mass
   ^doubles masses
   ^"[Ljava.lang.Object;" elements]
  AtomicDistribution
  ;; Total mass. Refreshes stale per-slot masses from the children, drops
  ;; children whose mass became zero, and caches the sum.
  (^double atomicMass [x]
    (if (= mass (obsolete-double))
      (loop [i (int 0) partial_mass (double 0.0)]
        (if (= i (block-size))
          (do
            (set! mass partial_mass)
            partial_mass)
          (let [m (aget masses i)]
            (if (= m (obsolete-double))
              (let [new-m (double (.atomicMass ^AtomicDistribution (aget elements i)))]
                (aset masses i new-m)
                (when (== new-m (double 0)) (aset elements i nil))
                (recur (+ i (int 1)) (+ partial_mass new-m)))
              (recur (+ i (int 1)) (+ partial_mass m))))))
      mass))
  ;; Returns the child whose cumulated mass first exceeds r. Reads `masses`
  ;; directly, so stale entries must have been refreshed by atomicMass first.
  (atomicDraw [x r]
    (loop [i (int 0) cumulated-masses (aget masses i)]
      (if (< r cumulated-masses)
        (aget elements i)
        (recur (+ i (int 1)) (+ cumulated-masses (aget masses (+ i (int 1))))))))
  Transientable
  ;; Shallow copy owned by tv; children are copied lazily when edited.
  (asTransient [x tv]
    (HashNodeDistribution.
     tv bit-depth my_count (obsolete-double)
     (aclone masses)
     (aclone elements)))
  (is_transient [x] (not (nil? transientVersion)))
  (is_my_transient [x tv] (identical? tv transientVersion))
  Counted
  ;; Element count, recomputed from the children when the cache is -1.
  ;; Children whose count reached zero are dropped.
  (count [x]
    (if (= my_count (int -1))
      (loop [c (int 0) i (int 0)]
        (if (= i (block-size))
          (do
            (set! my_count c)
            c)
          (let [elt (aget elements i)]
            (cond
             (nil? elt) (recur c (+ i (int 1)))
             (instance? AtomicDistribution elt)
             (let [c1 (count elt)]
               (when (zero? c1) (aset elements i nil))
               (recur (+ c c1) (+ i (int 1))))
             true (recur (+ c (int 1)) (+ i (int 1)))))))
      my_count))
  ITransientCollection
  (persistent [x] (SumDistribution. x))
  InPlaceUpdatableDistribution
  ;; Adds mass m to element x in the slot selected by hashCode at bit-depth.
  ;; A child not owned by this transient is first copied with asTransient.
  ;; Throws when a new element would get a negative mass.
  (insert! [this ^Object x ^int hashCode ^double m]
    (set! mass (obsolete-double))
    (let [i (masked-value hashCode bit-depth)
          elt (aget elements i)]
      (cond
       (nil? elt)
       (do
         ;; Validate before touching the arrays so a rejected insert leaves
         ;; this node unchanged.
         (if (< m (double 0.0)) (throw (new Exception (str "Negative element in distribution  "  m))))
         (aset elements i (singleton-distribution x m (next-level bit-depth)))
         (aset masses i m)
         (when (>= my_count (int 0)) (set! my_count (+ my_count (int 1)))))
       (instance? Transientable elt)
       (if (.is_my_transient ^Transientable elt transientVersion)
         (do
           (aset elements i (.insert! ^InPlaceUpdatableDistribution elt x hashCode m))
           (set! my_count (int -1))
           (aset masses i (obsolete-double)))
         (let [^InPlaceUpdatableDistribution elt (.asTransient ^Transientable elt transientVersion)]
           (aset elements i (.insert! elt x hashCode m))
           (aset masses i (obsolete-double))
           (set! my_count (int -1))))))
    this)
  (insert! [this ^Object x ^double m] (.insert! this x (hash x) m))
  ;; Removes element x from the slot selected by hashCode at bit-depth.
  ;; Throws when that slot is empty.
  (delete! [this ^Object x ^int hashCode]
    (set! mass (obsolete-double)) ;; masses are false
    (let [i (masked-value hashCode bit-depth)
          elt (aget elements i)]
      (cond
       (nil? elt) (throw (new Exception "Cannot delete a non-present element"))
       ;; Children are Singleton, Linear or HashNode distributions; all are
       ;; Transientable, but SingletonDistribution does not implement
       ;; InPlaceUpdatableDistribution, so testing for that interface here
       ;; silently skipped deletions landing on a singleton child.
       (instance? Transientable elt)
       (if (.is_my_transient ^Transientable elt transientVersion)
         (do
           (aset elements i (.delete! ^InPlaceUpdatableDistribution elt x hashCode))
           (when (>= my_count (int 0)) (set! my_count (- my_count (int 1))))
           (aset masses i (obsolete-double)))
         (let [^InPlaceUpdatableDistribution elt (.asTransient ^Transientable elt transientVersion)]
           (aset masses i (obsolete-double))
           (when (>= my_count (int 0)) (set! my_count (- my_count (int 1))))
           (aset elements i (.delete! elt x hashCode)))))
      this))
  (delete! [this ^Object x] (.delete! this x (hash x)))
  Seqable
  ;; Concatenation of the children's seqs.
  (seq [_]
    (apply concat
           (map #(cond
                   (nil? %1) ()
                   (instance? AtomicDistribution %1) (seq %1))
                elements))))


;; Flat bucket of user elements: the first my_count slots of `elements` and
;; `masses` are live. `mass` caches the total; (obsolete-double) marks it stale.
(deftype LinearDistribution
  [^TransientVersion transientVersion
   ^{:unsynchronized-mutable true} ^int my_count
   ^{:unsynchronized-mutable true} ^double mass
   ^doubles masses ^"[Ljava.lang.Object;" elements ^int bit-depth]
  AtomicDistribution
  ;; Total mass of the live entries, cached after the first computation.
  (^double atomicMass [x]
    (if (= mass (obsolete-double))
      (loop [i (int 0) partial_mass (double 0.0)]
        (if (= i my_count)
          (do
            (set! mass partial_mass)
            partial_mass)
          (recur (+ i (int 1)) (+ partial_mass (aget masses i)))))
      mass))
  ;; Returns the first element whose cumulated mass exceeds r.
  (atomicDraw [x r]
    (loop [i (int 0) cumulated-masses (aget masses (int 0))]
      (if (< r cumulated-masses)
        (aget elements i)
        (recur (+ i (int 1)) (+ cumulated-masses (aget masses (+ i (int 1))))))))
  Seqable
  ;; [element mass] pairs of the live entries.
  (seq [_]
    (map vector (take my_count elements) (take my_count masses)))
  Transientable
  (asTransient [x tv]
    (LinearDistribution.
     tv my_count (obsolete-double)
     (aclone masses)
     (aclone elements) bit-depth))
  (is_transient [x] (not (nil? transientVersion)))
  (is_my_transient [x tv] (identical? tv transientVersion))
  ITransientCollection
  (persistent [x] (SumDistribution. x))
  Counted
  (count [x] my_count)
  InPlaceUpdatableDistribution
  (insert! [this ^Object x ^int hashCode ^double m] (.insert! this x m))
  ;; Adds mass m to x. An existing entry has its mass summed (and is removed
  ;; when the sum is zero); a new entry with negative mass throws.
  (insert! [this ^Object x ^double m]
    (if (= my_count (alength elements))
      ;; Full: move every entry into a bigger node (a HashNodeDistribution
      ;; while bit-depth is a valid level, otherwise a LinearDistribution of
      ;; twice the capacity), insert x there and return that node.
      (let [^InPlaceUpdatableDistribution dist
            (if (valid-level? bit-depth)
              (HashNodeDistribution.
               transientVersion
               bit-depth
               0
               0.0
               (double-array (block-size) 0.0)
               (make-array Object (block-size)))
              (LinearDistribution. transientVersion 0 0.0
                                   (double-array (* (int 2) (alength elements)) 0.0)
                                   (make-array Object (* (int 2) (alength elements)))
                                   bit-depth))]
        (loop [i (int 0)]
          (if (= i my_count)
            (.insert! dist x m)
            (do
              (.insert! dist (aget elements i) (aget masses i))
              (recur (+ i (int 1)))))))
      (do
        (set! mass (obsolete-double))
        (loop [i (int 0)]
          (if (= i my_count)
            (do
              (if (< m (double 0.0)) (throw (new Exception (str "insert a negative weight " m " in Linear distribution: " (doall (seq this)) " for object " x ))))
              (aset masses my_count m)
              (aset elements my_count x)
              (set! my_count (+ my_count (int 1)))
              this)
            (let [elt (aget elements i)]
              (if (= x elt)
                (let [new_m (+ (aget masses i) m)]
                  (if (zero? new_m)
                    (.delete! this x)
                    (do
                      (aset masses i new_m)
                      this)))
                (recur (+ i (int 1))))))))))
  (delete! [this ^Object x ^int hashCode] (.delete! this x))
  ;; Removes x by moving the last live entry into its slot.
  ;; Throws when x is not among the live entries.
  (delete! [this ^Object x]
    (set! mass (obsolete-double))
    (loop [i (int 0)]
      (if (= i my_count)
        ;; Only the first my_count slots are live: stop there instead of
        ;; running off the end of the array, reporting the same error as
        ;; HashNodeDistribution does for a missing element.
        (throw (new Exception "Cannot delete a non-present element"))
        (let [elt (aget elements i)]
          (if (= x elt)
            (let [last-index (- my_count (int 1))
                  last-elt (aget elements last-index)
                  last-mass (aget masses last-index)]
              (aset elements i last-elt)
              (aset masses i last-mass)
              (aset elements last-index nil)
              (aset masses last-index 0.0)
              (set! my_count (- my_count (int 1)))
              this)
            (recur (+ i (int 1))))))))
  IPersistentCollection)

(defn linear-distribution
  "Constructs a LinearDistribution from its raw fields: transient token,
   live count, cached mass, masses array, elements array and bit depth."
  [tv count mass masses elements depth]
  (new LinearDistribution tv count mass masses elements depth))

;; The distribution with no elements: mass 0, count 0, and drawing throws.
;; Its transient form is an empty LinearDistribution of capacity
;; (linear-threshold) at bit depth 0.
(deftype EmptyDistribution []
  Counted
  (count [this] 0)
  Seqable
  (seq [this] ())
  AtomicDistribution
  (^double atomicMass [this] 0.0)
  (atomicDraw [this r]
    (throw (new Exception "Cannot draw from an empty distribution")))
  Distribution
  (^double mass [this] 0.)
  (draw [this extra]
    (throw (new Exception "Cannot draw from an empty distribution")))
  Transientable
  (is_transient [this] false)
  (is_my_transient [this version] false)
  (asTransient [this version]
    (LinearDistribution. version 0 (obsolete-double)
                         (double-array (linear-threshold) 0.0)
                         (make-array Object (linear-threshold))
                         0))
  IEditableCollection
  (asTransient [this] (.asTransient this (fresh-transient-version))))

;; Shared empty distribution instance.
(def empty-distribution (new EmptyDistribution))



;; A distribution holding the single element elt with the given mass. When
;; made transient it becomes a LinearDistribution bucket of capacity
;; (bucket-size) at the same bit-depth, already containing elt.
(deftype SingletonDistribution [^double mass elt ^int bit-depth]
  Counted
  (count [_] 1)
  Seqable
  (seq [_] (seq [[elt mass]]))
  AtomicDistribution
  (^double atomicMass [_] mass)
  (atomicDraw [_ r] elt)
  Transientable
  (is_transient [_] false)
  (is_my_transient [_ version] false)
  (asTransient [_ version]
    (let [bucket (LinearDistribution.
                  version 0 (obsolete-double)
                  (double-array (bucket-size) 0.0)
                  (make-array Object (bucket-size))
                  bit-depth)]
      (.insert! bucket elt mass)))
  IEditableCollection
  (asTransient [this] (.asTransient this (fresh-transient-version))))

(defn singleton-distribution
  "Builds a SingletonDistribution holding elt with the given mass, at
   hash-trie depth bd."
  [elt ^double mass ^int bd]
  (new SingletonDistribution mass elt bd))



(defn atomic-distribution-draw
  "Draws an element from the atomic distribution dist, proportionally to mass.
   A HashNodeDistribution holds child atomic distributions weighted by their
   total mass, so the draw descends into the selected child. A
   LinearDistribution or SingletonDistribution holds the elements themselves,
   and the selected element is returned as-is.
   (Previously the element picked from a top-level LinearDistribution was
   itself treated as an atomic distribution, which fails for user elements.)"
  [^AtomicDistribution dist]
  (loop [^AtomicDistribution node dist]
    ;; atomicMass is evaluated before atomicDraw: it refreshes stale slot
    ;; masses that atomicDraw reads.
    (let [r (double (* (.atomicMass node) (random)))
          picked (.atomicDraw node r)]
      (if (instance? HashNodeDistribution node)
        (recur picked)
        picked))))

(defn insert!
  "Adds element a with mass b to the transient distribution x and returns the
   updated transient. The result may be a different object than x, so always
   use the return value. A zero mass is a no-op that returns x."
  [^InPlaceUpdatableDistribution x a b]
  (if (zero? (double b))
    x
    (.insert! x a b)))

(defn delete!
  "Removes element a from the transient distribution x; returns the updated
   transient, which may be a different object than x."
  [^InPlaceUpdatableDistribution x a]
  (. x (delete! a)))

(defn insert
  "Persistent counterpart of insert!: returns a distribution with element a
   given additional mass b. A zero mass returns x unchanged."
  [x a b]
  (if (zero? (double b))
    x
    (-> x transient (insert! a b) persistent!)))

(defn delete
  "Persistent counterpart of delete!: returns a distribution equal to x
   without element a. (Previously this called itself on the transient
   instead of delete!, so it could never succeed.)"
  [x a]
  (persistent! (delete! (transient x) a)))

;; Prints a persistent distribution as the seq of its entries.
(defmethod print-method SumDistribution [dist out]
  (print-method (seq dist) out))

(defn mass-atomic-distribution
  "Total mass of an atomic distribution (its atomicMass)."
  [^AtomicDistribution dist]
  (. dist atomicMass))

(defn plus
  "Returns the persistent distribution big with every [key mass] entry of
   small added to it."
  [big small]
  (persistent!
   (reduce (fn [^InPlaceUpdatableDistribution acc [key mass]]
             (.insert! acc key (double mass)))
           (transient big)
           (seq small))))


(defn minus
  "Returns the persistent distribution big with the mass of every
   [key mass] entry of small subtracted from it."
  [big small]
  (persistent!
   (reduce (fn [^InPlaceUpdatableDistribution acc [key mass]]
             (.insert! acc key (- (double mass))))
           (transient big)
           (seq small))))



(defn distribution-from-seqable
  "Builds a persistent distribution from a seqable of [element mass] pairs.
   Pairs with a zero mass are skipped."
  [l]
  (let [add-pair (fn [acc pair] (insert! acc (first pair) (second pair)))]
    (persistent! (reduce add-pair (transient empty-distribution) (seq l)))))

;;;;; Generic distribution functions



(defn mass
  "Total mass of the distribution x; 0.0 when x is nil."
  [^Distribution x]
  (cond
   (nil? x) 0.0
   :else (.mass x)))


(defn draw
  "Draws from the distribution x, passing extra (nil when omitted) down to
   every nested draw."
  ([x] (draw x nil))
  ([^Distribution x extra] (.draw x extra)))



(defn empty-distribution?
  "True when x is nil or has zero mass."
  [^Distribution x]
  (if (nil? x)
    true
    (zero? (.mass x))))

;;;; Return distributions are used to return a constant result.
;;;; They may have an inside distribution in which they draw, and they totally ignore the result

;; Distribution of mass m that always yields return-value. When dist is set,
;; it is drawn from first (for its side effects) and its result is discarded.
(deftype ReturnDistribution [^double m return-value ^Distribution dist]
  Distribution
  (mass [_] m)
  (draw [_ extra]
    (if dist
      (do (.draw dist extra) return-value)
      return-value)))

(defn return-distribution
  "(return-distribution return-value m): a distribution of mass m whose draw
   always returns return-value."
  [return-value m]
  (new ReturnDistribution m return-value nil))
                  

(defn draw-and-return-distribution
  "(draw-and-return-distribution return-value dist): a distribution
   corresponding to (fmap (const return-value) dist). It has dist's mass,
   draws from dist and returns return-value. Returns dist itself when dist
   is nil or false."
  [return-value ^Distribution dist]
  (if dist
    (new ReturnDistribution (.mass dist) return-value dist)
    dist))

;;;;; mapping distribution

;; Distribution that draws from dist and applies f to the result.
(deftype MapDistribution [^Distribution dist f]
  Distribution
  (mass [_] (.mass dist))
  (draw [_ extra]
    (let [drawn (.draw dist extra)]
      (f drawn))))


(defn map-distribution
  "Maps f over the results of dist. Returns dist itself when it is nil or false."
  [f dist]
  (if dist
    (new MapDistribution dist f)
    dist))


;;;; Scaling distribution
;; A scaling distribution gives a distribution a different mass; drawing is unchanged.



;; Wraps dist, reporting mass m instead of dist's own mass; draws unchanged.
(deftype ScaleDistribution [^Distribution dist ^double m]
  Distribution
  (draw [_ extra] (.draw dist extra))
  (mass [_] m))

(defn scale-distribution
  "Returns dist with its mass multiplied by fact. Returns dist itself when it
   is nil or false."
  [^Distribution dist fact]
  (if dist
    (new ScaleDistribution dist (* (double fact) (.mass dist)))
    dist))

(defn distribution-with-mass
  "Returns dist reporting mass m. Returns dist itself when it is nil or false."
  [^Distribution dist m]
  (if dist
    (new ScaleDistribution dist (double m))
    dist))



(defn normalise
  "Returns x rescaled to mass 1."
  [x]
  (distribution-with-mass x 1.0))


;;;; product distribution
;;; Draw from a product of independent distributions.


;; Draws once from each of the n distributions in dists (indexed with get).
;; When need-result? is true the draws are returned in an Object array;
;; otherwise nothing is allocated and nil is returned.
(deftype ProductDistribution [^double m
                              dists
                              ^int n
                              ^boolean need-result?]
  Distribution
  (mass [_] m)
  (draw [_ extra]
    (let [^"[Ljava.lang.Object;" results (when need-result? (make-array Object n))]
      (dotimes [i n]
        (let [drawn (.draw ^Distribution (get dists i) extra)]
          (when results (aset results i drawn))))
      results)))


(defn product-distribution
  "Product of independent distributions: its mass is the product of their
   masses, and drawing returns an Object array with one draw from each.
   dists may be any seqable; it is coerced to a vector because draws look
   distributions up by index (get on a list or lazy seq would return nil)."
  [dists]
  (let [dists (vec dists)]
    (loop [m (double 1) ds (seq dists)]
      (if ds
        (recur (* m (.mass ^Distribution (first ds))) (next ds))
        (ProductDistribution. m dists (count dists) true)))))

(defn forgetful-product-distribution
  "Like product-distribution, but drawing discards the individual results
   (no array is allocated and nil is returned).
   dists may be any seqable; it is coerced to a vector because draws look
   distributions up by index (get on a list or lazy seq would return nil)."
  [dists]
  (let [dists (vec dists)]
    (loop [m (double 1) ds (seq dists)]
      (if ds
        (recur (* m (.mass ^Distribution (first ds))) (next ds))
        (ProductDistribution. m dists (count dists) false)))))


;;;;;;;;;;;;;;;;;;; Keyed distribution
;;;; A distribution can be keyed in order to indicate some extra data to be passed
;;;;  to the extra field of the draw.
;;; The key must be sufficient to identify the origin of the distribution.
;;; It is used for hashing and equality

;; a protocol for draw helper that accepts keys
(defprotocol KeyedDrawHelper
  "Implemented by the `extra` value threaded through draw, so that keyed
   distributions can report which key (and which copy) is being drawn."
  (give-key [extra key]
    "Gives the key to the draw helper.")
  (give-key-index-out-of [extra key ^int index ^int num]
    "Gives the key and the index of the copy drawn, out of num copies.")
  (give-keynum-out-of [extra key num out-of]
    "Gives the key with a number out of a total (no caller in this file;
     meaning presumed from the name, to be confirmed)."))
  
;; A distribution tagged with a key identifying its origin. Drawing hands the
;; key to the `extra` draw helper; equality and hashing use only the key.
(deftype KeyedDistribution [^double m ^Distribution dist key]
  Distribution
  (draw [x extra]
    ;; The 1-arity draw passes a nil extra. Skip the helper in that case, as
    ;; MultipleDistribution does, rather than failing the protocol dispatch on nil.
    (when extra
      (give-key extra key))
    (.draw dist extra))
  (mass [x] m)
  Object
  (equals [x y]
    (or (identical? x y)
        (and
         (= KeyedDistribution (type y))
         (= key (.key ^KeyedDistribution y)))))
  ;; Util/hash is key.hashCode() for non-nil keys and 0 for nil: nil-safe
  ;; and free of reflection.
  (hashCode [x] (clojure.lang.Util/hash key)))


(defn keyed-distribution
  "Wraps dist with the given key; the result keeps dist's mass."
  [^Distribution dist key]
  (new KeyedDistribution (.mass dist) dist key))

;; a distribution for multiple instances of the same distribution.
;; Compared with ScaleDistribution, it also passes a numbered key to the draw helper.
;; num copies of dist with total mass m. When a draw helper is present, each
;; draw reports the key and a uniformly chosen copy index in [0, num).
;; Equality and hashing use the key and num.
(deftype MultipleDistribution [^double m ^Distribution dist ^int num key]
  Distribution
  (mass [this] m)
  (draw [this extra]
    (when extra
      (give-key-index-out-of extra key (int (* (double num) (random))) num))
    (draw dist extra))
  Object
  (equals [this other]
    (or (identical? this other)
        (and (= MultipleDistribution (type other))
             (let [^MultipleDistribution that other]
               (and (= key (.key that))
                    (= num (.num that)))))))
  (hashCode [this] (.hashCode [key num])))


(defn multiple-distribution
  "Returns a distribution made of num copies of the distribution dist.
   An optional last argument is the key given to a KeyedDrawHelper
   (nil when omitted)."
  ([^Distribution dist ^int num]
     (multiple-distribution dist num nil))
  ([^Distribution dist ^int num key]
     (new MultipleDistribution (.mass dist) dist num key)))
	
	
