(ns thi.ng.ndarray.macros)

(defn- type-hinted
  "Returns `x` with its metadata replaced by `{:tag (name type)}` when
  `type` is truthy; otherwise returns `x` untouched."
  [type x]
  (cond-> x
    type (with-meta {:tag (name type)})))

(defn- make-symbols
  "Returns a vector of `n` symbols named `<id>0`, `<id>1`, ... `<id>(n-1)`."
  [id n]
  (vec (for [i (range n)]
         (symbol (str id i)))))

(defn- pair-fn
  "Builds a balanced tree of binary `(f a b)` forms out of `coll`.

  Each pass groups adjacent items into pairs `(f a b)`; an odd trailing
  item is carried over unchanged. Passes repeat until at most two items
  remain (always at least one pass). The resulting sequence is meant to be
  spliced into an outer `(f ~@...)` form."
  [f coll]
  (loop [items coll]
    (let [paired (for [pair (partition-all 2 items)]
                   (if (next pair)
                     (cons f pair)
                     (first pair)))]
      (if (<= (count paired) 2)
        paired
        (recur paired)))))

(defn- make-indexer
  "Returns a form computing a flat index:
  `_offset + sum_i (stride_i * (int (p i)))` for `i` in `0..dim-1`.

  `->st` maps an axis to its stride symbol; `p` is a symbol that is called
  with each axis number in the generated code. The sum is emitted as a
  tree of binary `+` forms via `pair-fn`."
  [dim ->st p]
  (let [terms (for [axis (range dim)]
                (list '* (->st axis) `(int (~p ~axis))))]
    `(+ ~@(pair-fn '+ (cons '_offset terms)))))

(defn- make-indexer-syms
  "Returns a form computing a flat index:
  `_offset + sum_i (stride_i * (int pos_i))` for `i` in `0..dim-1`.

  `->st` maps an axis to its stride symbol and `->p` maps an axis to the
  symbol holding that axis' position. The sum is emitted as a tree of
  binary `+` forms via `pair-fn`."
  [dim ->st ->p]
  (let [terms (map (fn [axis]
                     (list '* (->st axis) `(int ~(->p axis))))
                   (range dim))]
    `(+ ~@(pair-fn '+ (cons '_offset terms)))))

(defn- with-bounds-check
  "Wraps `body` in a runtime guard requiring `0 <= (psyms i) < (shapes i)`
  for every axis `i` in `0..dim-1`.

  If any check fails, the generated code throws `IndexOutOfBoundsException`
  (when `clj?` is truthy) or `js/Error`, with all positions printed in the
  message."
  [dim psyms shapes clj? & body]
  (let [checks   (for [axis  (range dim)
                       :let  [pos-sym   (symbol (psyms axis))
                              shape-sym (symbol (shapes axis))]
                       check [`(>= ~pos-sym 0) `(< ~pos-sym ~shape-sym)]]
                   check)
        err-type (if clj? 'IndexOutOfBoundsException 'js/Error)]
    `(if (and ~@checks)
       (do ~@body)
       (throw
        (new ~err-type
             (str "Invalid index: " (pr-str [~@psyms])))))))

(defn- inject-clj-protos
  "Returns a two-element list `(protocol method-impl)` to splice into a
  deftype body, making the generated type seqable over its elements.

  When `clj?` is truthy the JVM `clojure.lang.Seqable` / `seq` pair is
  emitted, otherwise ClojureScript's `ISeqable` / `-seq`. The method
  iterates every position with a nested `for` (axis `i` in `rdim` binds
  `(->a i)` over `(range (->sh i))`) and yields `(get data idx)` for each.

  The lazy sequence is wrapped in `seq` so an array with a zero-length
  axis yields nil. This matches the Seqable contract, so `seq`/`empty?`
  behave correctly; a bare `for` would return a truthy empty lazy seq."
  [clj? get data ->a ->sh idx rdim]
  (let [[proto method] (if clj?
                         ['clojure.lang.Seqable 'seq]
                         ['ISeqable '-seq])]
    (list
     proto
     `(~method
       [_#]
       (seq
        (for [~@(mapcat #(vector (->a %) `(range ~(->sh %))) rdim)]
          (~get ~data ~idx)))))))

(defn- do-cast
  "Wraps the form `body` in a call to `cast`, i.e. `(cast body)`, when
  `cast` is truthy; otherwise returns `body` unchanged."
  [cast body]
  (if-not cast
    body
    `(~cast ~body)))

(defmacro def-ndarray
  "Defines an n-dimensional array type and registers its raw constructor.

  Expands into:
  - a `deftype` named `NDArray<dim><type-id>` with fields `_data`,
    `_offset`, `_stride0..` and `_shape0..`, implementing `PNDArray`,
    `Object/toString` and (via `inject-clj-protos`) Seqable;
  - a fn `make-raw-ndarray<dim>-<type-id>` taking
    `[data offset [strides...] [shapes...]]`, marked `:export`;
  - a `swap!` storing `{:ctor <raw fn> :data-ctor data-ctor}` under
    `[dim type-id]` in `thi.ng.ndarray.core/ctor-registry`.

  Arguments:
  - `dim`       number of axes; must be a literal integer at expansion time
  - `cast`      optional symbol wrapped around every stored value (nil = none)
  - `type-hint` optional `:tag` for the `_data` field (nil = no hint)
  - `type-id`   element-type identifier (passed through `name`; presumably a
                keyword) used in the generated names and as registry key
  - `data-ctor` stored verbatim in the registry entry
  - `get`/`set` accessor symbols, emitted as `(get data idx)` and
                `(set data idx value)`
  - `clj?`      truthy for JVM targets (Seqable, IndexOutOfBoundsException),
                falsy for ClojureScript (ISeqable, js/Error)

  NOTE(review): `PNDArray` is emitted unqualified, so the expansion
  presumably has to happen where that protocol is resolvable - confirm."
  [dim cast type-hint type-id data-ctor get set & [clj?]]
  (let [type-name (symbol (str "NDArray" dim (name type-id)))
        raw-name  (symbol (str "make-raw-ndarray" dim "-" (name type-id)))
        strides   (make-symbols "_stride" dim)
        shapes    (make-symbols "_shape" dim)
        asyms     (make-symbols "a" dim)
        bsyms     (make-symbols "b" dim)
        psyms     (make-symbols "p" dim)
        [->st ->sh ->a ->b ->p] (map #(comp symbol %) [strides shapes asyms bsyms psyms])
        [c d f p o] (repeatedly gensym)
        ;; flat index expression over the position args p0..pN
        idx-syms  (make-indexer-syms dim ->st ->p)
        ;; flat index expression over the iteration symbols a0..aN
        idx-asyms (make-indexer-syms dim ->st ->a)
        data      (type-hinted type-hint '_data)
        rdim      (range dim)]
    `(do
       (deftype ~type-name
           [~data ~'_offset ~@strides ~@shapes]
         ~@(inject-clj-protos clj? get data ->a ->sh idx-asyms rdim)
         ~'PNDArray
         (~'data
           [_#] ~data)
         (~'data-type
           [_#] ~type-id)
         (~'dimension
           [_#] ~dim)
         (~'stride
           [_#] [~@strides])
         (~'shape
           [_#] [~@shapes])
         (~'offset
           [_#] ~'_offset)
         (~'index-at
           [_# ~@psyms] ~idx-syms)
         ;; Decomposes a flat index into per-axis positions by successive
         ;; division by each stride. The final remainder rebinding is dropped
         ;; (drop-last 2) because nothing reads it.
         ;; NOTE(review): assumes positive strides decreasing across axes;
         ;; transposed or stepped views may not round-trip - verify.
         (~'index-pos
           [_# ~p]
           (let [~p (int ~p)
                 ~c (- ~p ~'_offset)
                 ~@(drop-last
                    2 (mapcat
                       #(list
                         (->a %) `(int (/ ~c ~(->st %)))
                         c `(- ~c (* ~(->a %) ~(->st %))))
                       rdim))]
             [~@asyms]))
         (~'index-seq
           [_#]
           (for [~@(mapcat #(vector (->a %) `(range ~(->sh %))) rdim)]
             ~idx-asyms))
         (~'position-seq
           [_#]
           (for [~@(mapcat #(vector (->a %) `(range ~(->sh %))) rdim)]
             [~@asyms]))
         (~'get-at
           [_# ~@psyms]
           (~get ~data ~idx-syms))
         (~'get-at-safe
           [_# ~@psyms]
           ~(with-bounds-check dim psyms shapes clj?
              `(~get ~data ~idx-syms)))
         (~'get-at-index
           [_# i#]
           (~get ~data (int i#)))
         (~'set-at
           [_# ~@psyms ~c]
           (~set ~data ~idx-syms ~(do-cast cast c)) _#)
         (~'set-at-safe
           [_# ~@psyms ~c]
           ~(with-bounds-check dim psyms shapes clj?
              `(~set ~data ~idx-syms ~(do-cast cast c)))
           _#)
         (~'set-at-index
           [_# i# ~c]
           (~set ~data (int i#) ~(do-cast cast c)) _#)
         (~'update-at
           [_# ~@psyms ~f]
           (let [~c ~idx-syms]
             (~set ~data ~c ~(do-cast cast `(~f ~@psyms (~get ~data ~c)))))
           _#)
         (~'update-at-safe
           [_# ~@psyms ~f]
           ~(with-bounds-check dim psyms shapes clj?
              `(let [~c ~idx-syms]
                 (~set ~data ~c ~(do-cast cast `(~f ~@psyms (~get ~data ~c))))))
           _#)
         (~'update-at-index
           [_# ~c ~f]
           (~set ~data ~c
                 ~(do-cast cast `(~f ~c (~get ~data ~c))))
           _#)
         ;; Sets axis i's shape to p_i. A nil/non-numeric or negative p_i
         ;; keeps the current shape, consistent with step/transpose/pick.
         ;; Before the number? guard, nil threw on the JVM and truncated the
         ;; axis to 0 in cljs.
         (~'truncate-h
           [_# ~@psyms]
           (new ~type-name ~data ~'_offset ~@strides
                ~@(map
                   #(list 'if `(and (number? ~(->p %)) (>= ~(->p %) 0))
                          `(int ~(->p %))
                          (->sh %))
                   rdim)))
         ;; Skips the first p_i elements of axis i: the shape shrinks by p_i
         ;; and the offset advances by stride_i * p_i. A nil/non-numeric,
         ;; zero or negative p_i leaves the axis unchanged.
         (~'truncate-l
           [_# ~@psyms]
           (let [~@(mapcat
                    #(list
                      [(->a %) (->b %)]
                      `(if (and (number? ~(->p %)) (pos? ~(->p %)))
                         [(- ~(->sh %) (int ~(->p %)))
                          (* ~(->st %) (int ~(->p %)))]
                         [~(->sh %) 0]))
                    rdim)
                 ~o (+ ~@(->> rdim (map ->b) (cons '_offset) (pair-fn '+)))]
             (new ~type-name ~data ~o ~@strides ~@asyms)))
         ;; p_i selects which existing axis becomes axis i; a falsy p_i
         ;; keeps axis i in place.
         (~'transpose
           [_# ~@psyms]
           (let [~@(mapcat #(list (->p %) `(if ~(->p %) (int ~(->p %)) ~%)) rdim)
                 ~c [~@strides]
                 ~d [~@shapes]]
             (new ~type-name ~data ~'_offset
                  ~@(map #(list c (->p %)) rdim)
                  ~@(map #(list d (->p %)) rdim))))
         ;; A numeric p_i keeps every |p_i|-th element: the shape becomes
         ;; ceil(shape_i / |p_i|) and the stride is multiplied by p_i. A
         ;; negative p_i also moves the offset to the axis' last element so
         ;; iteration runs backwards. A non-numeric p_i leaves the axis as is.
         (~'step
           [_# ~@psyms]
           (let [~o ~'_offset
                 ~@(mapcat
                    #(let [stride' `(* ~(->st %) (int ~(->p %)))]
                       (list
                        [(->a %) (->b %) o]
                        `(if (number? ~(->p %))
                           (if (neg? ~(->p %))
                             [~(list 'int (list 'Math/ceil `(/ (- ~(->sh %)) (int ~(->p %)))))
                              ~stride'
                              (+ ~o (* ~(->st %) (dec ~(->sh %))))]
                             [~(list 'int (list 'Math/ceil `(/ ~(->sh %) (int ~(->p %)))))
                              ~stride'
                              ~o])
                           [~(->sh %) ~(->st %) ~o])))
                    rdim)]
             (new ~type-name ~data ~o ~@bsyms ~@asyms)))
         ;; A non-negative numeric p_i fixes axis i, folding it into the
         ;; offset. Other axes are kept. The remaining axes are wrapped by the
         ;; ctor registered for that dimension; with no axes left the single
         ;; element is returned.
         (~'pick
           [_# ~@psyms]
           (let [~o ~'_offset, ~c [], ~d []
                 ~@(mapcat
                    #(list
                      [c d o]
                      `(if (and (number? ~(->p %)) (>= ~(->p %) 0))
                         [~c ~d (+ ~o (* ~(->st %) (int ~(->p %))))]
                         [(conj ~c ~(->sh %)) (conj ~d ~(->st %)) ~o]))
                    rdim)
                 cnt# (count ~c)]
             (if (pos? cnt#)
               ;; fully qualified, matching the registration swap! below, so
               ;; the expansion does not depend on an unqualified referral
               ((get-in @~'thi.ng.ndarray.core/ctor-registry [cnt# ~type-id :ctor]) ~data ~o ~d ~c)
               (~get ~data ~o))))
         ~'Object
         (~'toString
           [_#]
           (pr-str
            {:data ~data :type ~type-id
             :size (* ~@shapes) :total (count ~data) :offset ~'_offset
             :shape [~@shapes] :stride [~@strides]})))

       (defn ~(with-meta raw-name {:export true})
         [data# o# [~@strides] [~@shapes]]
         (new ~type-name data# o# ~@strides ~@shapes))

       (swap!
        ~'thi.ng.ndarray.core/ctor-registry
        assoc-in [~dim ~type-id]
        {:ctor ~raw-name
         :data-ctor ~data-ctor}))))
