(ns thi.ng.ndarray.macros)

(defn make-symbols
  "Returns a vector of `n` unqualified symbols named by appending the
  integers 0 .. n-1 to `id`, e.g. (make-symbols \"a\" 2) => [a0 a1]."
  [id n]
  (vec (for [i (range n)]
         (symbol (str id i)))))

(defn pair-fn
  "Folds `coll` into a shallow tree of binary `(f x y)` forms.
  Each pass groups adjacent items pairwise into `(f x y)`; an odd
  trailing item is passed through unchanged. At least one pass is always
  made, and passes repeat while more than two items remain. Returns the
  resulting seq (of at most two items for non-empty input)."
  [f coll]
  (loop [items coll]
    (let [merged (map (fn [chunk]
                        ;; a chunk of two becomes (f x y); a lone item stays as-is
                        (if (next chunk)
                          (cons f chunk)
                          (first chunk)))
                      (partition-all 2 items))]
      (if (> (count merged) 2)
        (recur merged)
        merged))))

(defn make-indexer
  "Builds the form `(+ _offset (* stride0 (p 0)) (* stride1 (p 1)) ...)`
  computing a flat data offset from the index vector bound to symbol `p`.
  The additions are nested as binary `+` forms via `pair-fn`."
  [dim strides p]
  (let [axis-term (fn [i]
                    (list '* (symbol (strides i)) (list p i)))
        terms     (cons '_offset (map axis-term (range dim)))]
    (list* `+ (pair-fn '+ terms))))

(defn make-indexer-syms
  "Like `make-indexer`, but reads each axis index from its own symbol in
  `psyms` (e.g. p0, p1, ...) rather than from an index vector. It builds
  `(+ _offset (* stride0 p0) (* stride1 p1) ...)`, nested as binary `+`
  forms via `pair-fn`."
  [dim strides psyms]
  (let [axis-term (fn [i]
                    (list '* (symbol (strides i)) (symbol (psyms i))))
        terms     (cons '_offset (map axis-term (range dim)))]
    (list* `+ (pair-fn '+ terms))))

(defn with-bounds-check
  "Returns a form that evaluates `body` only when, for each of the first
  `dim` axes, the index symbol from `psyms` satisfies 0 <= p_i < shape_i,
  with shape_i taken from `shapes`. Otherwise the form throws
  `(new err msg)`, where msg is \"Invalid index: \" followed by the
  printed vector of all `psyms` values."
  [dim psyms shapes err & body]
  (let [axis-checks (fn [i]
                      (let [idx-sym (symbol (psyms i))]
                        [(list '>= idx-sym 0)
                         (list '<  idx-sym (symbol (shapes i)))]))
        conditions  (mapcat axis-checks (range dim))]
    `(if (and ~@conditions)
       (do ~@body)
       (throw (new ~err (str "Invalid index: " (pr-str [~@psyms])))))))

(defmacro def-ndarray
  "Defines a fixed-dimension ndarray implementation for `dim` axes.

  Expands into:
  - a `deftype` named `NDArray<dim>` with fields
    `_get _set _data _offset _stride0 .. _shape0 ..`, implementing
    `PNDArray` and `Object/toString`,
  - a fn `make-raw-ndarray<dim>` (metadata `:export true`) taking
    `[get set data offset [strides] [shapes]]`, and
  - a `swap!` registering that fn in `ctor-registry` under key `dim`.
    `pick` looks constructors up there by result dimension.

  Element access goes through the stored fns, called as
  `(_get _data index)` and `(_set _data index value)`. `err` is
  instantiated via `new` with a message string when a bounds-checked
  accessor (`get-at` / `set-at`) receives an out-of-range index.

  `dim` must be an integer literal, because all per-axis code is unrolled
  at expansion time. `PNDArray` and `ctor-registry` are emitted as
  unqualified symbols and must resolve where the macro is expanded."
  [dim err]
  (let [type-name (symbol (str "NDArray" dim))
        raw-name  (symbol (str "make-raw-ndarray" dim))
        strides   (make-symbols "_stride" dim)
        shapes    (make-symbols "_shape" dim)
        asyms     (make-symbols "a" dim)
        bsyms     (make-symbols "b" dim)
        psyms     (make-symbols "p" dim)
        ;; axis index -> symbol lookups into the vectors above
        [->st ->sh ->a ->b ->p] (map #(comp symbol %) [strides shapes asyms bsyms psyms])
        ;; fresh names for generated params/locals
        [c d p o] (repeatedly gensym)
        ;; flat offset for an index vector bound to `p`
        idx       (make-indexer dim strides p)
        ;; flat offset for destructured index symbols p0 .. p<dim-1>
        idx-syms  (make-indexer-syms dim strides psyms)
        ;; form testing whether the argument for axis i is a non-negative number
        nonneg-num? #(list 'and (list 'number? (->p %)) (list '>= (->p %) 0))
        rdim      (range dim)]
    `(do
       (deftype ~type-name [~'_get ~'_set ~'_data ~'_offset ~@strides ~@shapes]
         ~'PNDArray
         (~'data
           [_#] ~'_data)
         ;; lazily yields every element; the last axis varies fastest
         (~'data-seq
           [_#]
           (for [~@(mapcat #(vector (->a %) (list 'range (->sh %))) rdim)]
             (~'get-at-unsafe _# [~@asyms])))
         (~'dimension
           [_#] ~dim)
         (~'stride
           [_#] [~@strides])
         (~'shape
           [_#] [~@shapes])
         (~'offset
           [_#] ~'_offset)
         (~'index-at
           [_# ~p] ~idx)
         (~'get-at
           [_# [~@psyms]]
           ~(with-bounds-check dim psyms shapes err
              (list '_get '_data idx-syms)))
         (~'get-at-unsafe
           [_# ~p] (~'_get ~'_data ~idx))
         (~'set-at
           [_# [~@psyms] ~c]
           ;; `(list 'int ...)` emits the same int cast as set-at-unsafe.
           ;; Calling `('int ...)` instead would look the symbol up in the
           ;; form at expansion time and emit nil as the index.
           ~(with-bounds-check dim psyms shapes err
              (list '_set '_data (list 'int idx-syms) c))
           _#)
         (~'set-at-unsafe
           [_# ~p v#] (~'_set ~'_data (int ~idx) v#) _#)
         ;; sets axis i's length to p_i when p_i is a non-negative number;
         ;; any other argument (e.g. nil) keeps the current length
         (~'hi
           [_# [~@psyms]]
           (new ~type-name ~'_get ~'_set ~'_data ~'_offset ~@strides
                ~@(map #(list 'if (nonneg-num? %) (->p %) (->sh %)) rdim)))
         ;; for each positive numeric p_i, skips the first p_i entries of
         ;; axis i: its length shrinks by p_i and the offset advances by
         ;; stride_i * p_i; other arguments leave the axis unchanged
         (~'lo
           [_# [~@psyms]]
           (let [~@(mapcat
                    #(list
                      [(->a %) (->b %)]
                      (list 'if (list 'and (list 'number? (->p %)) (list 'pos? (->p %)))
                            [(list '- (->sh %) (->p %))
                             (list '* (->st %) (->p %))]
                            [(->sh %) 0]))
                    rdim)
                 ~o (+ ~@(->> rdim (map #(->b %)) (concat ['_offset]) (pair-fn '+)))]
             (new ~type-name ~'_get ~'_set ~'_data ~o ~@strides ~@asyms)))
         ;; new axis i takes stride/shape of old axis p_i (defaults to i when p_i is falsey)
         (~'transpose
           [_# [~@psyms]]
           (let [~@(mapcat #(list (->p %) (list 'if (->p %) (->p %) %)) rdim)
                 ~c [~@strides]
                 ~d [~@shapes]]
             (new ~type-name ~'_get ~'_set ~'_data ~'_offset
                  ~@(map #(list c (->p %)) rdim)
                  ~@(map #(list d (->p %)) rdim))))
         ;; for each numeric p_i, multiplies stride i by p_i and sets length to
         ;; ceil(len/p_i); a negative p_i also moves the offset to the last
         ;; entry of the axis first. Non-numeric arguments leave the axis as-is.
         (~'step
           [_# [~@psyms]]
           (let [~o ~'_offset
                 ~@(mapcat
                    #(let [stride' (list '* (->st %) (->p %))]
                       (list
                        [(->a %) (->b %) o]
                        (list 'if (list 'number? (->p %))
                              (list 'if (list 'neg? (->p %))
                                    [(list 'Math/ceil (list '/ (list '- (->sh %)) (->p %)))
                                     stride'
                                     (list '+ o (list '* (->st %) (list 'dec (->sh %))))]
                                    [(list 'Math/ceil (list '/ (->sh %) (->p %)))
                                     stride'
                                     o])
                              [(->sh %) (->st %) o])))
                    rdim)]
             (new ~type-name ~'_get ~'_set ~'_data ~o ~@bsyms ~@asyms)))
         ;; axes with a non-negative numeric p_i are fixed at that index (folded
         ;; into the offset) and dropped; remaining axes keep shape/stride. The
         ;; result is built by the ctor registered for the remaining axis count.
         (~'pick
           [_# [~@psyms]]
           (let [~o ~'_offset, ~c [], ~d []
                 ~@(mapcat
                    #(list
                      [c d o]
                      (list 'if (nonneg-num? %)
                            [c d (list '+ o (list '* (->st %) (->p %)))]
                            [(list 'conj c (->sh %)) (list 'conj d (->st %)) o]))
                    rdim)]
             ;; c holds the remaining shapes, d the remaining strides
             ((@~'ctor-registry (count ~c)) ~'_get ~'_set ~'_data ~o ~d ~c)))
         ~'Object
         (~'toString
           [_#]
           (pr-str
            {:data ~'_data :length (* ~@shapes) :offset ~'_offset
             :shape [~@shapes] :stride [~@strides]})))

       (defn ~(with-meta raw-name {:export true})
         [get# set# data# o# [~@strides] [~@shapes]]
         (new ~type-name get# set# data# o# ~@strides ~@shapes))

       (swap! ~'ctor-registry assoc ~dim ~raw-name))))
