(ns clojure.core.matrix.impl.ndarray
  (:refer-clojure :exclude [vector?])
  (:require [clojure.walk :as w]
            [clojure.core.matrix.impl.default]
            [clojure.core.matrix.impl.ndarray-magic :as magic]
            [clojure.core.matrix.protocols :as mp]
            [clojure.core.matrix.implementations :as imp]
            [clojure.core.matrix.impl.mathsops :as mops]
            [clojure.core.matrix.multimethods :as mm]
            [clojure.core.matrix.utils :refer :all]
            [clojure.core.matrix.impl.ndarray-macro :refer :all]))

;; (error "NDArray loaded!")

;; Namespace-wide compiler settings: primitive arithmetic below is compiled
;; without overflow checks, and any reflective interop call is reported so
;; that missing type hints are noticed.
(set! *unchecked-math* true)
(set! *warn-on-reflection* true)

;; **NOTE**: the HTML documentation for this file (ndarray.html) was
;; generated as follows:
;;
;;     lein marg -d . -f ndarray.html -n "NDArray for core.matrix" \
;;         -D "git version: `git log --pretty=format:'%H' -n 1`" \
;;         src/main/clojure/clojure/core/matrix/impl/ndarray.clj \
;;         src/main/clojure/clojure/core/matrix/impl/ndarray_magic.clj \
;;         src/main/clojure/clojure/core/matrix/impl/ndarray_macro.clj
;;
;; [py1]: http://docs.scipy.org/doc/numpy/reference/arrays.ndarray.html
;; [py2]: http://scipy-lectures.github.io/advanced/advanced_numpy/
;; [mult1]: http://penguin.ewu.edu/~trolfe/MatMult/MatOpt.html
;; [mult2]: https://code.google.com/p/efficient-java-matrix-library/source/browse/trunk/src/org/ejml/alg/dense/mult/MatrixMatrixMult.java
;; [lu]: https://github.com/vitaut/gsl/blob/master/linalg/lu.c
;; [lupy]: https://gist.github.com/si14/3ae62f1dd28703e1a6ca


;; ## Intro
;;
;; This is an implementation of strided N-Dimensional array (or NDArray
;; for short). The underlying structure is similar to NumPy's [[py1]],
;; [[py2]].
;;
;; The default striding scheme is row-major, as in C, which is efficient on
;; most modern processors. However, some functions (like `main-diagonal`)
;; return "views" into the original array with a different striding scheme,
;; and this can affect the performance of expensive functions like matrix
;; multiplication. To avoid this penalty, `clone` such a view before passing
;; it to those functions: `clone` always packs data contiguously in memory.
;;
;; ## The "magic"
;;
;; For efficiency, Java's primitive arrays should be used. However, there
;; is no easy way to write code that is polymorphic and uses primitives.
;; Therefore, the majority of code in this namespace is autogenerated.
;; Here is how this works (please consult `ndarray-magic` for further
;; details):
;;
;; 1. `magic/init` defines the available specializations, each identified by
;;    a name (keyword) and a map of values that can be substituted into the
;;    generated code;
;; 2. `magic/with-magic` and `magic/extend-types` collect the forms passed
;;    to them, process those forms by replacing symbols that end with "#"
;;    (both in the forms themselves and in their metadata) with the
;;    corresponding values from the specialization map (so `^typename# a`
;;    will be replaced by `^NDArray a` in the case of specialization to
;;    :object), and store them in atoms. There is an additional trick: when
;;    a function is defined in a `with-magic` block, more than one function
;;    will be generated; for function `foo` there will be functions
;;    `foo-long`, `foo-float` and so on, depending on `:fn-suffix` of the
;;    specialization map. To access such autogenerated functions from
;;    specialized code, the form `foo#t` can be used (`t` is for `type`
;;    here). For example, `(empty-ndarray#t [2 2])` will be specialized to
;;    `(empty-ndarray-double [2 2])`.
;; 3. `magic/spit-code` emits all code collected in step 2, specialized
;;    using the maps from step 1. The downside of this approach is that
;;    line numbers in stack traces are inaccurate.
;;
;; ## Useful macros
;;
;; There are a couple of macros defined in `ndarray-macro` that are used
;; throughout the code here:
;;
;; * `expose-ndarrays` is an anaphoric macro that takes a list of names of
;;   NDArrays and provides properly hinted bindings for NDArray fields. For
;;   example, `(expose-ndarrays [a b])` will bind names `a-shape`, `b-shape`,
;;   `a-offset` and so on in its body. Some macros like `aget-2d*` assume
;;   that they are used inside `expose-ndarrays`.
;; * `loop-over` is an anaphoric macro that efficiently loops over the
;;   provided list of NDArrays. It contains `expose-ndarrays` in itself,
;;   so there is no need to wrap it in that macro. In addition to the names
;;   bound by `expose-ndarrays`, `loop-over` also binds `X-idx` for every
;;   NDArray that was provided. For example, the following code
;;
;;   `(loop-over [a b] (aset a-data a-idx (aget b-data b-idx)))`
;;
;;   will copy the contents of b into a regardless of b's or a's striding
;;   schemes. `loop-over` is most efficient when the provided NDArrays are
;;   "packed", that is, strided in the default (row-major) way.
;; * `fold-over` is like `loop-over`, but accepts an accumulator, binds
;;   an additional name `loop-acc` and updates value of accumulator at each
;;   iteration. Here is an example of `fold-over`:
;;
;;   `(fold-over [m] 0 (+ loop-acc (aget m-data m-idx)))`
;;
;; `loop-over` and `fold-over` can also be used outside of this namespace,
;; providing a very efficient way to perform complex element-wise operations.
;; To do specialization in this case, use `magic/specialize`. An example
;; can be found in the `test-ndarray-implementation` namespace.
;;
;; ## Default striding schemes
;;
;; When we are using C-like striding (row-major ordering), the formula for
;; strides is as follows (see [[py2]]):
;;
;; $$shape = (d\_1, d\_2, \dots, d\_N)$$
;; $$strides = (s\_1, s\_2, \dots, s\_N)$$
;; $$s\_j = d\_{j+1} d\_{j+2} \dots d\_N, \qquad s\_N = 1$$

(defn c-strides
  "Returns a primitive int array with the row-major (C-order) strides for
   `shape`: each dimension's stride is the product of the sizes of all the
   dimensions that follow it, so the last dimension has stride 1. An empty
   shape yields a single stride of 1."
  [shape]
  ;; Walk the trailing dimensions right-to-left, accumulating a running
  ;; product that starts at 1, then restore left-to-right order.
  (->> (int-array shape)
       rest
       reverse
       (reductions * 1)
       reverse
       int-array))

;; We can easily check for correctness here using NumPy:
;;
;;     np.empty([4, 3, 2], dtype=np.int8, order="c").strides
;;     # (6, 2, 1)
;;
;; An actual test can be found in namespace test-ndarray-implementation.

;; ## "Magic" initialization
;;
;; Here we define what should be substituted into specialized versions of
;; NDArray. Please read "The 'magic'" section for better explanation of
;; what's going on.

;; Specialization table consumed by the "magic" macros. Each key names a
;; specialization; each inner map supplies the values substituted for the
;; correspondingly named `#`-suffixed symbols in the code below:
;;   :regname     -> regname#      (implementation key)
;;   :fn-suffix   -> suffix used when expanding `foo#t` (nil = plain `foo`)
;;   :typename    -> typename#     (name of the generated deftype)
;;   :array-tag   -> array-tag#    (type hint of the backing array)
;;   :array-cast  -> array-cast#   (backing array constructor)
;;   :type-cast   -> type-cast#    (element coercion function)
;;   :type-object -> type-object#  (element class)
;; :long and :float are commented out below even though several
;; `with-magic` blocks still list them. Presumably the magic machinery
;; skips specializations absent from this map (TODO confirm in
;; ndarray-magic).
(magic/init
 {:object {:regname :ndarray
           :fn-suffix nil
           :typename 'NDArray
           :array-tag 'objects
           :array-cast 'object-array
           :type-cast 'identity
           :type-object java.lang.Object}
;  :long {:regname :ndarray-long
;         :fn-suffix 'long
;         :typename 'NDArrayLong
;         :array-tag 'longs
;         :array-cast 'long-array
;         :type-cast 'long
;         :type-object Long/TYPE}
;  :float {:regname :ndarray-float
;          :fn-suffix 'float
;          :typename 'NDArrayFloat
;          :array-tag 'floats
;          :array-cast 'float-array
;          :type-cast 'float
;          :type-object Float/TYPE}
  :double {:regname :ndarray-double
           :fn-suffix 'double
           :typename 'NDArrayDouble
           :array-tag 'doubles
           :array-cast 'double-array
           :type-cast 'double
           :type-object Double/TYPE}})

;; ## The structure
;;
;; The structure is identical to NumPy's. Strides are stored explicitly;
;; this allows some operations, like transposition or broadcasting, to be
;; done on the "strides" field alone, without touching the data itself.

;; One deftype per specialization (e.g. NDArray, NDArrayDouble). The field
;; order below is also the positional argument order of `(new typename# ...)`
;; used by every constructor in this file.
(magic/with-magic
  [:long :float :double :object]
  (deftype typename#
      [^array-tag# data  ;; backing array; slices and restrided views share it
       ^int ndims        ;; number of dimensions (length of shape/strides)
       ^ints shape ;; this is compiled to type Object, WAT
       ;; (per the Clojure deftype docs, a non-primitive hint such as ^ints
       ;; does not constrain the JVM field type; it only avoids reflection
       ;; at use sites)
       ^ints strides     ;; per-dimension step in `data` between neighbours
       ^int offset]))    ;; position in `data` of the element at index [0 ... 0]

;; ## Constructors
;;
;; In this section some constructors are provided.

(magic/with-magic
  [:long :float :double :object]
  (defn empty-ndarray
    "Allocates a new packed (row-major, zero offset) NDArray of the given
     shape. Elements hold the JVM default for the backing array: zeros
     for primitive arrays, nil for object arrays."
    [shape]
    (let [dims (int-array shape)
          element-count (reduce * dims)]
      (new typename#
           (array-cast# element-count)
           (count dims)
           dims
           (c-strides dims)
           0))))

(magic/with-magic
  [:long :float :double]
  (defn empty-ndarray-zeroed
    "Returns a new packed NDArray of the given shape whose elements are all
     zero. Freshly allocated JVM primitive arrays are already zero-filled,
     so for primitive specializations this is exactly `empty-ndarray`."
    [shape]
    (empty-ndarray#t shape)))

(magic/with-magic
  [:object]
  (defn empty-ndarray-zeroed
    "Returns a new packed NDArray of the given shape with every element set
     to 0.0. Object arrays start out filled with nil, so the backing array
     is filled explicitly."
    [shape]
    (let [^typename# zeroed (empty-ndarray#t shape)
          ^objects cells (.data zeroed)]
      (java.util.Arrays/fill cells (cast java.lang.Object 0.0))
      zeroed)))

(magic/with-magic
  [:long :float :double :object]
  (defn ndarray
    "Builds a new NDArray with the same shape as `data` and copies the
     contents of `data` into it with `mp/assign!`."
    [data]
    (doto (empty-ndarray#t (mp/get-shape data))
      (mp/assign! data))))

(magic/with-magic
  [:long :float :double :object]
  (defn arbitrary-slice
    "Returns an arbitrary slice of provided NDArray along the given
     dimension and slice index (from 0 to the size of NDArray along this
     dimension)

     The result is a view: it shares the data array of `m`, drops `dim`
     from the shape and strides, and advances the offset by `idx` steps
     along `dim`. Throws IllegalArgumentException if `m` is 0-dimensional,
     if `dim` is not a dimension of `m`, or if `idx` is outside
     [0, size of `m` along `dim`)."
    [^typename# m dim idx]
    (iae-when-not (> (.ndims m) 0)
      (str "can't get slices on [" (.ndims m) "]-dimensional object"))
    (let [^array-tag# data (.data m)
          ndims (.ndims m)
          ^ints shape (.shape m)
          ^ints strides (.strides m)
          offset (.offset m)]
      ;; Validate up front: an out-of-range `idx` can otherwise produce a
      ;; view whose offset still lies inside `data` (e.g. column slices),
      ;; silently aliasing unrelated elements instead of failing.
      (iae-when-not (and (<= 0 dim) (< dim ndims))
        (str "dimension " dim " is out of range for [" ndims
             "]-dimensional object"))
      (iae-when-not (and (<= 0 idx) (< idx (aget shape dim)))
        (str "slice index " idx " is out of range [0, " (aget shape dim)
             ") along dimension " dim))
      (let [new-ndims (dec ndims)
            new-shape (abutnth dim shape)
            new-strides (abutnth dim strides)
            new-offset (+ offset (* idx (aget strides dim)))]
        (iae-when-not (< new-offset (alength data))
          "new offset is larger than the array itself")
        (new typename# data new-ndims new-shape new-strides new-offset)))))

(magic/with-magic
  [:long :float :double :object]
  (defn row-major-slice
    "Returns the `idx`-th slice of `m` along its first (major) dimension,
     e.g. a row of a matrix, as a view that shares `m`'s data."
    [^typename# m idx]
    (let [major-dimension 0]
      (arbitrary-slice#t m major-dimension idx))))

;; TODO: this should be a macro, transpose takes way too much time because
;;       of this function (64ns just for restriding)
(magic/with-magic
  [:long :float :double :object]
  (defn reshape-restride
    "Returns a view of provided NDArray with different shape, strides and
     offset. The data array of `m` is shared, not copied; `new-ndims` and
     `new-offset` are coerced to int."
    [^typename# m new-ndims ^ints new-shape ^ints new-strides new-offset]
    (let [^array-tag# shared (.data m)]
      (new typename#
           shared
           (int new-ndims)
           new-shape
           new-strides
           (int new-offset)))))

;; ## Seqable

(magic/with-magic
  [:long :float :double :object]
  (defn row-major-seq
    "Lazily yields the slices of `m` along its first dimension. Each slice
     is an NDArray view, so a 1d vector yields 0d NDArrays. Throws
     IllegalArgumentException for 0-dimensional input."
    [^typename# m]
    (iae-when-not (pos? (.ndims m))
      (str "can't get slices on [" (.ndims m) "]-dimensional object"))
    (let [^ints dims (.shape m)
          n-slices (aget dims 0)]
      (for [i (range n-slices)]
        (row-major-slice#t m i)))))

(magic/with-magic
  [:long :float :double :object]
  (defn row-major-seq-no0d
    "Like row-major-seq but drops NDArray's wrapping on 0d-slices, so
     `(row-major-seq-no0d (ndarray [1 2 3]))` will return sequence
     `1 2 3` instead of a sequence of 0d NDArrays"
    [^typename# m]
    (let [slices (row-major-seq#t m)]
      (if (== 1 (.ndims m))
        (map mp/get-0d slices)
        slices))))

;; ## LU-decomposition-related stuff
;;
;; This section defines a few functions that allow NDArray to perform
;; quite efficient LU-decomposition and to compute the inverse and the
;; determinant. These functions use mutation heavily to avoid extra
;; allocations, so use them with caution.
;;
;; *NOTE*: to extend this to :object, clojure.math.numeric-tower is needed

(magic/with-magic
  [:double]
  (defn lu-decompose!
    "LU-decomposition of a matrix into P A = L U. Saves L and U into
     the input matrix as follows: L is a lower triangular part of it,
     with diagonal omitted (they are all equal to 1); U is an upper
     triangular part. P returned as a primitive int permutation array.
     Returns a vector of two values: first is integer (-1)^n, where n is
     a number of permutations, and second is a primitive int permutations
     array. Element i of that array is the index of the original row that
     ended up in row i, i.e. (P A)[i] = A[permutations[i]].
     Throws IllegalArgumentException if the matrix is not a square matrix,
     or if an all-zero pivot column is met in one of the first n-1
     columns. The last diagonal element of U is not checked, so some
     singular matrices decompose without error.
     This function is translated from GNU linear algebra library, namely
     `gsl_linalg_LU_decomp` (see [[lu]] for example). Python translation that
     was used to implement this can be found at [[lupy]]."
    [^typename# m]
    (expose-ndarrays [m]
      (iae-when-not (== m-ndims 2)
        "lu-decompose! can operate only on matrices")
      (iae-when-not (== (aget m-shape 0) (aget m-shape 1))
        "lu-decompose! can operate only on square matrices")
      (let [n (aget m-shape 0)
            ;; permutations array, starts as the identity
            permutations (int-array (range n))
            ;; sign of determinant, kept in a 1-element array so it can be
            ;; flipped from inside the loops below
            sign (int-array [1])]
        ;; for all columns but the last one
        (c-for [j (int 0) (< j (dec n)) (inc j)]
          (let [i-pivot
                ;; partial pivoting: row at or below j with the largest
                ;; absolute value in column j
                (loop [i (inc j)
                       max-i j
                       max-abs (Math/abs (aget-2d* m j j))]
                  (if (< i n)
                    (let [current (Math/abs (aget-2d* m i j))]
                      (if (< max-abs current)
                        (recur (inc i) i current)
                        (recur (inc i) max-i max-abs)))
                    (do (iae-when-not (not (== max-abs 0))
                          "lu-decompose can't decompose singular matrix")
                        max-i)))
                pivot (aget-2d* m i-pivot j)]
            ;; when maximum element is not on diagonal, swap rows, update
            ;; permutations and permutation counter
            (when-not (== i-pivot j)
              (c-for [k (int 0) (< k n) (inc k)]
                (let [swap (aget-2d* m i-pivot k)]
                  (aset-2d* m i-pivot k (aget-2d* m j k))
                  (aset-2d* m j k swap)))
              ;; swap the permutation *entries*, not positions: writing `j`
              ;; here would corrupt the array once row j has already been
              ;; swapped in an earlier step
              (let [swap (aget permutations i-pivot)]
                (aset permutations i-pivot (aget permutations j))
                (aset permutations j swap))
              (aset sign 0 (* -1 (aget sign 0))))
            ;; eliminate column j below the diagonal, storing the
            ;; multipliers (entries of L) in place
            (c-for [i (inc j) (< i n) (inc i)]
              (let [scaled (/ (aget-2d* m i j) pivot)]
                (aset-2d* m i j scaled)
                (c-for [k (inc j) (< k n) (inc k)]
                  (aset-2d* m i k (- (aget-2d* m i k)
                                     (* (aget-2d* m j k)
                                        scaled))))))))
        [(aget sign 0) permutations]))))

(magic/with-magic
  [:double]
  (defn lu-solve!
    "Solves a system of linear equations Ax = b using LU-decomposition.
     lu should be a decomposition of A in a form produced by lu-decompose!,
     permutations should be a primitive int vector of permutations (as from
     lu-decompose!), x should be a primitive vector of right hand sides. After
     an execution of this function x will be replaced with solution vector.

     NOTE: `permutations` is not read by this function. x must already hold
     the permuted right hand side P b (`invert` builds it that way before
     calling this). An empty (0x0) system is a no-op. Returns nil."
    [^typename# lu ^ints permutations ^array-tag# x]
    (expose-ndarrays [lu]
      (iae-when-not (== lu-ndims 2)
        "lu-solve! can operate only on matrices")
      (iae-when-not (== (aget lu-shape 0) (aget lu-shape 1))
        "lu-solve! can operate only on square matrices")
      (let [n (aget lu-shape 0)]
        ;; Solving Ly = b using forward substitution (L has an implicit
        ;; unit diagonal, so no division is needed)
        (c-for [i (int 0) (< i n) (inc i)]
          (loop [j (int 0)
                 s (aget x i)]
            (if (< j i)
              (recur (inc j) (- s (* (aget-2d* lu i j)
                                     (aget x j))))
              (aset x i s))))
        ;; Solving Ux = y using backward substitution. Starting at n-1
        ;; handles the last row too (its inner sum is empty), and it makes
        ;; an empty system a no-op instead of indexing x at -1.
        (c-for [i (dec n) (>= i 0) (dec i)]
          (loop [j (inc i)
                 s (aget x i)]
            (if (< j n)
              (recur (inc j) (- s (* (aget-2d* lu i j)
                                     (aget x j))))
              (aset x i (/ s (aget-2d* lu i i))))))
        nil))))

(magic/with-magic
  [:double]
  (defn invert
    "Returns a newly allocated inverse of the square matrix `m`. `m` itself
     is not modified, because the LU factorization runs on a clone. Column
     i of the result is the solution of A x = e_i."
    [^typename# m]
    (expose-ndarrays [m]
      (iae-when-not (== m-ndims 2)
        "invert can operate only on matrices")
      (iae-when-not (== (aget m-shape 0) (aget m-shape 1))
        "invert can operate only on square matrices")
      (let [n (aget m-shape 0)
            ^array-tag# column (array-cast# n)
            ^typename# lu (mp/clone m)
            ^typename# m-inverted (empty-ndarray#t [n n])
            ;; lu-decompose! overwrites `lu` in place; only the row
            ;; permutation (second element) is needed here
            [_ ^ints perm] (lu-decompose!#t lu)]
        (expose-ndarrays [m-inverted]
          (dotimes [i n]
            ;; right hand side P e_i: 1 exactly where the permuted row is i
            (dotimes [j n]
              (aset column j (type-cast# (if (== (aget perm j) i) 1 0))))
            (lu-solve!#t lu perm column)
            (dotimes [j n]
              (aset-2d* m-inverted j i (aget column j)))))
        m-inverted))))

(magic/with-magic
  [:double]
  (defn determinant
    "Finds a determinant of a given square matrix. The matrix itself is not
     modified: LU-decomposition runs on a clone, and the determinant is the
     product of U's diagonal times the permutation sign. Returns 0 for
     matrices that lu-decompose! rejects as singular. Throws
     IllegalArgumentException for non-square or non-2d input."
    [^typename# m]
    (expose-ndarrays [m]
      (iae-when-not (== m-ndims 2)
        "determinant can operate only on matrices")
      (iae-when-not (== (aget m-shape 0) (aget m-shape 1))
        "determinant can operate only on square matrices")
      (let [n (aget m-shape 0)
            ^typename# lu (mp/clone m)
            ;; lu-decompose! mutates lu. Its shape checks were already
            ;; satisfied above, so the IllegalArgumentException left to
            ;; catch is its singularity check; a singular matrix has
            ;; determinant 0.
            lu-output (try
                        (lu-decompose!#t lu)
                        (catch IllegalArgumentException _ nil))]
        (if (nil? lu-output)
          (type-cast# 0)
          (let [sign (first lu-output)]
            (expose-ndarrays [lu]
              (loop [i (int 0)
                     det (type-cast# sign)]
                (if (< i n)
                  (recur (inc i) (* det (aget-2d* lu i i)))
                  det)))))))))

(magic/extend-types
  [:long :float :double :object]
  java.lang.Object
    (toString [m]
       (str (mp/persistent-vector-coerce m)))

  clojure.lang.Seqable
    (seq [m]
      (row-major-seq-no0d#t m))

;; This interface is used to inform user that this type is sequential --
;; that is, it has a defined and constant element order

  clojure.lang.Sequential

;; ## Mandatory protocols for all matrix implementations
;;
;; This bunch of protocols is mandatory for all core.matrix implementations.

  mp/PImplementation
    (implementation-key [m] regname#)
    (meta-info [m]
      {:doc "An implementation of strided N-Dimensional array"})
    (new-vector [m length]
      (empty-ndarray-zeroed#t [length]))
    (new-matrix [m rows columns]
      (empty-ndarray-zeroed#t [rows columns]))
    (new-matrix-nd [m shape]
      (empty-ndarray#t shape))
    (construct-matrix [m data]
      (ndarray#t data))
    (supports-dimensionality? [m dims]
      true)

  mp/PDimensionInfo
    (get-shape [m] (vec shape))
    (is-vector? [m] (= 1 ndims))
    (is-scalar? [m] false)
    (dimensionality [m] ndims)
    (dimension-count [m x] (aget shape x))

  mp/PIndexedAccess
    (get-1d [m x]
    ;; TODO: check if this check is really needed
      (iae-when-not (= 1 (.ndims m))
        "can't use get-1d on non-vector")
      (aget data (+ offset (* (aget strides 0) x))))
    (get-2d [m x y]
      (iae-when-not (= 2 (.ndims m))
        "can't use get-2d on non-matrix")
      (let [idx (+ offset
                   (* (aget strides 0) (int x))
                   (* (aget strides 1) (int y)))]
        (aget data idx)))
    (get-nd [m indexes]
      (iae-when-not (= (count indexes) ndims)
        "index count should match dimensionality")
      (let [idxs (int-array indexes)]
        (aget-nd data strides offset idxs)))

;; PIndexedSetting is for non-mutative update of a matrix. Here we emulate
;; "non-mutative" setting by making a mutable copy and mutating it.

  mp/PIndexedSetting
    (set-1d [m row v]
      (let [m-new (mp/clone m)]
        (mp/set-1d! m-new row v)
        m-new))
    (set-2d [m row column v]
      (let [m-new (mp/clone m)]
        (mp/set-2d! m-new row column v)
        m-new))
    (set-nd [m indexes v]
      (let [m-new (mp/clone m)]
        (mp/set-nd! m-new indexes v)
        m-new))
    (is-mutable? [m] true)

;; ## Mandatory protocols for mutable matrix implementations
;;
;; In this section, protocols that help to work with mutable matrices are
;; defined. It is worth noting that in the previous section, namely
;; PIndexedSetting protocol implementation, we used mutative operations,
;; therefore this section is required for previous section to work.

  mp/PIndexedSettingMutable
    (set-1d! [m x v]
      (when-not (== 1 ndims)
        (throw (IllegalArgumentException. "can't use set-1d! on non-vector")))
      (aset data (+ offset x) (type-cast# v)))
    (set-2d! [m x y v]
      (when-not (== 2 ndims)
        (throw (IllegalArgumentException. "can't use set-2d! on non-matrix")))
      (let [idx (+ (* (aget strides 0) (int x))
                   (* (aget strides 1) (int y))
                   offset)]
        (aset data idx (type-cast# v))))
    (set-nd! [m indexes v]
      (when-not (= (count indexes) ndims)
        (throw (IllegalArgumentException.
                "index count should match dimensionality")))
      (let [idxs (int-array indexes)]
        (aset-nd data strides offset idxs (type-cast# v))))

  mp/PMatrixCloning
    (clone [m]
      (let [a (empty-ndarray#t (.shape m))]
        (loop-over [m a]
          (aset a-data a-idx (aget m-data m-idx)))
        a))

;; ## Optional protocols
;;
;; Following protocols are implemented for performance or better behaviour.

  mp/PConversion
    (convert-to-nested-vectors [m]
      (case ndims
        0 (aget data offset)
        1 (let [n (aget shape 0)
                stride (aget strides 0)]
            (loop [idx (int offset)
                   cnt (int 0)
                   res []]
              (if (< cnt n)
                (recur (+ idx stride) (inc cnt) (conj res (aget data idx)))
                res)))
       (mapv mp/convert-to-nested-vectors
             (mp/get-major-slice-seq m))))

  mp/PTypeInfo
    (element-type [m] type-object#)

  mp/PMutableMatrixConstruction
    (mutable-matrix [m] (mp/clone m))

  mp/PZeroDimensionAccess
    (get-0d [m] (aget data offset))
    (set-0d! [m v] (aset data offset (type-cast# v)))

  mp/PSpecialisedConstructors
    (identity-matrix [m n]
      (let [^typename# new-m (empty-ndarray#t [n n])
            ^array-tag# new-m-data (.data new-m)]
        (when (= type-object# java.lang.Object)
          (c-for [i (int 0) (< i (* n n)) (inc i)]
            (aset new-m-data i (type-cast# 0))))
        (c-for [i (int 0) (< i n) (inc i)]
          (aset new-m-data (+ i (* i n)) (type-cast# 1)))
        new-m))
    (diagonal-matrix [m diag]
      (let [prim-diag (array-cast# diag)
            n (alength prim-diag)
            ^typename# new-m (empty-ndarray#t [n n])
            ^array-tag# new-m-data (.data new-m)]
        (when (= type-object# java.lang.Object)
          (c-for [i (int 0) (< i (* n n)) (inc i)]
            (aset new-m-data i (type-cast# 0))))
        (c-for [i (int 0) (< i n) (inc i)]
          (aset new-m-data (int (+ i (* i n)))
                (type-cast# (aget prim-diag i))))
        new-m))

  ;; mp/PCoercion
  ;;   (coerce-param [m param])
  ;; mp/PBroadcast
  ;;   (broadcast [m target-shape])
  ;; mp/PBroadcastLike
  ;;   (broadcast-like [m a])

  mp/PBroadcastCoerce
    (broadcast-coerce [m a]
      (let [^typename# a (if (instance? typename# a) a (mp/coerce-param m a))]
        (mp/broadcast-like m a)))

  ;; mp/PReshaping
  ;;   (reshape [m shape])

  mp/PMatrixSlices
    (get-row [m i]
      (iae-when-not (== ndims 2)
        "get-row is applicable only for matrices")
      (row-major-slice#t m i))
    (get-column [m i]
      (iae-when-not (== ndims 2)
        "get-column is applicable only for matrices")
      (arbitrary-slice#t m 1 i))
    (get-major-slice [m i]
      (row-major-slice#t m i))
    (get-slice [m dimension i]
               ;;get-slice requires to return a scalar for a slice of a 1-dim
               ;;array
               (let [res (arbitrary-slice#t m dimension i)]
                 (if (= 1 ndims)
                   (mp/get-0d res)
                   res)))

  mp/PSubVector
    (subvector [m start length]
      (iae-when-not (== ndims 1)
        "subvector is applicable only for vectors")
      (let [new-shape (int-array 1 (int length))
            new-offset (+ offset (* (aget strides 0) start))]
        (reshape-restride#t m ndims new-shape strides new-offset)))

  mp/PSliceView
   (get-major-slice-view [m i] (row-major-slice#t m i))

  mp/PSliceSeq
    (get-major-slice-seq [m] (seq m))

  ;; mp/PSliceJoin
  ;;   (join [m a])

  ;; TODO: generalize for higher dimensions (think tensor trace)
  ;; TODO: make it work for rectangular matrices
  ;; TODO: clarify docstring about rectangular matrices
  ;; TODO: clarify docstring about higher dimensions
  mp/PMatrixSubComponents
    (main-diagonal [m]
      (let [new-ndims (int 1)
            min-shape (min (aget shape 0) (aget shape 1))
            new-shape (int-array 1 min-shape)
            new-strides (int-array 1 (* (inc (aget shape 1))
                                        (aget strides 1)))]
        (reshape-restride#t m new-ndims new-shape new-strides offset)))

  ;; mp/PAssignment

  ;; TODO: will not work for stride != 1
  mp/PMutableFill
    (fill! [m v]
      (loop-over [m]
        (aset m-data m-idx (type-cast# v))))

  ;; mp/PDoubleArrayOutput

  ;; mp/PMatrixEquality
  ;;   (matrix-equals [a b]
  ;;     (if (identical? a b)
  ;;       true
  ;;       (if-not (instance? typename# b)
  ;;         ;; Coerce second argument to first one
  ;;         (mp/matrix-equals a (mp/coerce-param a b))
  ;;         ;; Fast path, types are same
  ;;         (loop-over-2d [a b] true
  ;;           (if (== (aget a-data a-idx)
  ;;                   (aget b-data b-idx))
  ;;             (continue true)
  ;;             (break false))))))

  mp/PMatrixEquality
    (matrix-equals [a b]
      (if (identical? a b)
        true
        (if-not (instance? typename# b)
          ;; Coerce second argument to first one
          (mp/matrix-equals a (mp/coerce-param a b))
          ;; Fast path, types are same
          (let [^typename# b b
                ^ints shape-b (.shape b)
                ^array-tag# data-b (.data b)
                ^ints strides-b (.strides b)
                offset-b (.offset b)]
            (if (not (java.util.Arrays/equals shape shape-b))
              false
              (case ndims
                0 (== (aget data 0) (aget data-b 0))
                1 (let [step-a (aget strides 0)
                        step-b (aget strides-b 0)
                        end (+ offset (* (aget shape 0) step-a))]
                    (loop [i-a offset
                           i-b offset-b]
                      (if (< i-a end)
                        (if (== (aget data i-a) (aget data-b i-b))
                          (recur (+ i-a step-a) (+ i-b step-b))
                          false)
                        true)))
                2 (let [nrows (aget shape 0)
                        ncols (aget shape 1)
                        step-col-a (aget strides 1)
                        step-row-a (- (aget strides 0)
                                      (* step-col-a ncols))
                        step-col-b (aget strides-b 1)
                        step-row-b (- (aget strides-b 0)
                                      (* step-col-b ncols))]
                    (loop [i-a offset
                           i-b offset-b
                           row-a 0
                           col-a 0]
                      (if (< row-a nrows)
                        (if (< col-a ncols) 
                          (if (== (aget data i-a) (aget data-b i-b))
                            (recur (+ i-a step-col-a) (+ i-b step-col-b)
                                   row-a (inc col-a))
                            false)
                          (recur (+ i-a step-row-a) (+ i-b step-row-b)
                                   (inc row-a) 0))
                        true)))
                ;; N-dimensional case
                (let [end (+ offset
                             (areduce shape i s (int 0)
                                      (+ s (* (aget shape i)
                                              (aget strides i)))))]
                  (loop [idxs (int-array ndims)]
                    (if (== (aget-nd data strides offset idxs)
                            (aget-nd data-b strides-b offset-b idxs))
                      (if (loop [dim (int (dec ndims))]
                            (if (>= dim 0)
                              (if (< (aget idxs dim) (dec (aget shape dim)))
                                (do (aset idxs dim (inc (aget idxs dim)))
                                    true)
                                (do (aset idxs dim (int 0))
                                    (recur (dec dim))))
                              false))
                        (recur idxs)
                        true)
                      false)))))))))

  ;; TODO: optimize on smaller arrays
  ;; TODO: replace stride multiplication with addition, this is faster
  ;; (one can use explicit addition of stride instead of (inc i))
  ;; TODO: implement transposition of argument for faster access
  ;; For algorithm see [[mult1]], for inspiration check out [[mult2]]

  mp/PMatrixMultiply
   (matrix-multiply [a b]
     (if-not (instance? typename# b)
       ;; Coerce second argument to first one
       (mp/matrix-multiply a (mp/coerce-param a b))
       ;; Fast path, types are same
       (let [^typename# b b
             a-ndims ndims
             b-ndims (.ndims b)
             ^ints b-shape (.shape b)]
         (cond
          (== b-ndims 0) (mp/scale a b)
          (and (== a-ndims 1) (== b-ndims 1))
          (mp/inner-product a b)
          (and (== a-ndims 1) (== b-ndims 2))
          (let [b-rows (aget b-shape (int 0))
                b-cols (aget b-shape (int 1))]
            (mp/reshape (mp/matrix-multiply (mp/reshape a [1 b-rows]) b)
                        [b-cols]))
          (and (== a-ndims 2) (== b-ndims 1))
          (let [a-cols (aget shape (int 1))
                a-rows (aget shape (int 0))]
            (mp/reshape (mp/matrix-multiply a (mp/reshape b [a-cols 1]))
                        [a-rows]))
          (and (== a-ndims 2) (== b-ndims 2))
          (let [^typename# c (empty-ndarray-zeroed#t
                              [(aget shape (int 0))
                               (aget (ints (.shape b)) (int 1))])]
            (expose-ndarrays [a b c]
              (let [a-rows (aget a-shape (int 0))
                    a-cols (aget a-shape (int 1))
                    b-rows (aget b-shape (int 0))
                    b-cols (aget b-shape (int 1))]
                (do (iae-when-not (== a-cols b-rows)
                      (str "dimension mismatch: "
                           [a-rows a-cols] "x" [b-rows b-cols]))
                    (c-for [i (int 0) (< i a-rows) (inc i)
                            k (int 0) (< k a-cols) (inc k)]
                      (let [t (aget-2d a-data a-strides a-offset i k)]
                        (c-for [j (int 0) (< j b-cols) (inc j)]
                          (aadd-2d c-data c-strides c-offset i j
                                   (* t (aget-2d b-data b-strides b-offset k j))))))
                    c))))))))
   (element-multiply [a b]
     ;; Element-wise product of a and b, returned as a new array.
     (cond
       ;; Fast path: b is already an array of this concrete type.
       (instance? typename# b)
       (expose-ndarrays [a b]
         (if (java.util.Arrays/equals a-shape b-shape)
           ;; Equal shapes: multiply a fresh copy of a by b in place.
           (let [c (mp/clone a)]
             (loop-over [b c]
               (aset c-data c-idx
                     (* (aget c-data c-idx)
                        (aget b-data b-idx))))
             c)
           ;; Different shapes: broadcast both to a common shape and retry.
           (let [[a b] (mp/broadcast-compatible a b)]
             (mp/element-multiply a b))))

       ;; A plain number scales every element.
       (number? b)
       (mp/scale a b)

       ;; Anything else is coerced to this implementation first.
       :else
       (mp/element-multiply a (mp/coerce-param a b))))

  ;; mp/PMatrixProducts

  mp/PAddProduct
    (add-product [m a b]
      ;; Returns m + (a .* b) as a new array, where .* is the element-wise
      ;; product; m itself is not modified. Operands not already of this type
      ;; are passed through mp/broadcast-coerce against m.
      (let [to-native (fn [x]
                        (if (instance? typename# x)
                          x
                          (mp/broadcast-coerce m x)))
            ^typename# a (to-native a)
            ^typename# b (to-native b)]
        (iae-when-not (and (java.util.Arrays/equals (ints (.shape m))
                                                    (ints (.shape a)))
                           (java.util.Arrays/equals (ints (.shape a))
                                                    (ints (.shape b))))
          "add-product operates on arrays of equal shape")
        (let [^typename# out (mp/clone m)]
          (loop-over [a b out]
            (aset out-data out-idx
                  (+ (aget out-data out-idx)
                     (* (aget a-data a-idx)
                        (aget b-data b-idx)))))
          out)))

  mp/PAddProductMutable
    (add-product! [m a b]
      ;; In place: m <- m + (a .* b), where .* is the element-wise product.
      ;; Returns m. Operands not already of this type are passed through
      ;; mp/broadcast-coerce against m.
      (let [to-native (fn [x]
                        (if (instance? typename# x)
                          x
                          (mp/broadcast-coerce m x)))
            ^typename# a (to-native a)
            ^typename# b (to-native b)
            ^typename# m m]
        (iae-when-not (and (java.util.Arrays/equals (ints (.shape m))
                                                    (ints (.shape a)))
                           (java.util.Arrays/equals (ints (.shape a))
                                                    (ints (.shape b))))
          "add-product operates on arrays of equal shape")
        (loop-over [a b m]
          (aset m-data m-idx
                (+ (aget m-data m-idx)
                   (* (aget a-data a-idx)
                      (aget b-data b-idx)))))
        m))

  mp/PAddScaledProduct
    (add-scaled-product [m a b factor]
      ;; Returns m + factor * (a .* b) as a new array, where .* is the
      ;; element-wise product (add-product, scaled by factor). m itself is
      ;; not modified. m, a and b must all have the same shape.
      ;; NOTE: this is element-wise, not a matrix product. That matches the
      ;; equal-shape precondition below, works for any rank, and lets
      ;; loop-over handle each array's layout instead of assuming offset 0.
      (let [^typename# a (if (instance? typename# a) a
                             (mp/coerce-param m a))
            ^typename# b (if (instance? typename# b) b
                             (mp/coerce-param m b))]
        (iae-when-not (and (java.util.Arrays/equals (ints (.shape m))
                                                    (ints (.shape a)))
                           (java.util.Arrays/equals (ints (.shape a))
                                                    (ints (.shape b))))
          "add-scaled-product operates on arrays of equal shape")
        (let [^typename# c (mp/clone m)
              ;; cast once, outside the loop, to this array's element type
              factor (type-cast# factor)]
          (loop-over [a b c]
            (aset c-data c-idx
                  (+ (aget c-data c-idx)
                     (* (* (aget a-data a-idx)
                           (aget b-data b-idx))
                        factor))))
          c)))

  mp/PAddScaledProductMutable
    (add-scaled-product! [m a b factor]
      ;; In place: m <- m + factor * (a .* b), where .* is the element-wise
      ;; product. Returns m. m, a and b must all have the same shape.
      ;; NOTE: element-wise, not a matrix product. loop-over handles m's own
      ;; layout, so the update also works when m is a view with a non-zero
      ;; offset.
      (let [^typename# a (if (instance? typename# a) a
                             (mp/coerce-param m a))
            ^typename# b (if (instance? typename# b) b
                             (mp/coerce-param m b))]
        (iae-when-not (and (java.util.Arrays/equals (ints (.shape m))
                                                    (ints (.shape a)))
                           (java.util.Arrays/equals (ints (.shape a))
                                                    (ints (.shape b))))
          "add-scaled-product! operates on arrays of equal shape")
        (let [^typename# m m
              ;; cast once, outside the loop, to this array's element type
              factor (type-cast# factor)]
          (loop-over [a b m]
            (aset m-data m-idx
                  (+ (aget m-data m-idx)
                     (* (* (aget a-data a-idx)
                           (aget b-data b-idx))
                        factor))))
          m)))

  mp/PAddScaled
    (add-scaled [m a factor]
      ;; Returns m + factor * a as a new array; m itself is not modified.
      ;; a (after coercion to this type) must have the same shape as m.
      (let [^typename# a (if (instance? typename# a) a
                             (mp/coerce-param m a))]
        (iae-when-not (java.util.Arrays/equals (ints (.shape m))
                                               (ints (.shape a)))
          "add-scaled operates on arrays of equal shape")
        (let [^typename# b (mp/clone m)
              ;; cast once, outside the loop, to this array's element type
              factor (type-cast# factor)]
          ;; Accumulate into the copy of m: each element becomes m + factor * a
          ;; (not just factor * a).
          (loop-over [a b]
            (aset b-data b-idx (+ (aget b-data b-idx)
                                  (* factor (aget a-data a-idx)))))
          b)))

  mp/PAddScaledMutable
    (add-scaled! [m a factor]
      ;; In place: m <- m + factor * a (element-wise). Returns m.
      ;; a (after coercion to this type) must have the same shape as m.
      (let [^typename# a (if (instance? typename# a) a
                             (mp/coerce-param m a))]
        (iae-when-not (java.util.Arrays/equals (ints (.shape m))
                                               (ints (.shape a)))
          "add-scaled operates on arrays of equal shape")
        ;; cast once, outside the loop, to this array's element type
        (let [factor (type-cast# factor)]
          ;; Accumulate, keeping m's existing value in each element.
          (loop-over [m a]
            (aset m-data m-idx (+ (aget m-data m-idx)
                                  (* factor (aget a-data a-idx))))))
        m))

  mp/PMatrixDivide
    (element-divide [m a]
      ;; Element-wise quotient m / a as a new array. Both operands are first
      ;; brought to a common shape with mp/broadcast-compatible; the divisor
      ;; is then coerced to this implementation.
      (let [[numer denom] (mp/broadcast-compatible m a)
            denom (mp/coerce-param numer denom)
            res (mp/clone numer)]
        (loop-over [res denom]
          (aset res-data res-idx
                (type-cast# (/ (aget res-data res-idx)
                               (type-cast# (aget denom-data denom-idx))))))
        res))
    (element-divide [m]
      ;; Element-wise reciprocal 1 / m, returned as a new array.
      (let [one (type-cast# 1)
            res (mp/clone m)]
        (loop-over [res]
          (aset res-data res-idx
                (type-cast# (/ one (aget res-data res-idx)))))
        res))

  ;; PMatrixMultiplyMutable

  ;; PVectorTransform

  mp/PMatrixScaling
    (scale [m factor]
      ;; Returns a new array with every element of m multiplied by factor.
      ;; factor is first cast to this array's element type.
      (let [out (mp/clone m)
            k (type-cast# factor)]
        (loop-over [out]
          (aset out-data out-idx
                (type-cast# (* (aget out-data out-idx) k))))
        out))
    (pre-scale [m factor]
      ;; Like scale, but factor is the left operand of each product
      ;; (factor * element). Returns a new array; m is not modified.
      (let [out (mp/clone m)
            k (type-cast# factor)]
        (loop-over [out]
          (aset out-data out-idx
                (type-cast# (* k (aget out-data out-idx)))))
        out))

  mp/PMatrixMutableScaling
    (scale! [m factor]
      ;; In place: every element of m becomes element * factor. Returns m.
      (let [k (type-cast# factor)]
        (loop-over [m]
          (aset m-data m-idx
                (type-cast# (* (aget m-data m-idx) k))))
        m))
    (pre-scale! [m factor]
      ;; In place: every element of m becomes factor * element. Returns m.
      (let [k (type-cast# factor)]
        (loop-over [m]
          (aset m-data m-idx
                (type-cast# (* k (aget m-data m-idx)))))
        m))

  mp/PMatrixAdd
    (matrix-add [m a]
      ;; Element-wise sum m + a as a new array. If the shapes differ, both
      ;; operands are broadcast with mp/broadcast-compatible and the addition
      ;; is retried.
      (let [^typename# a (if (instance? typename# a) a
                             (mp/coerce-param m a))
            same-shape? (java.util.Arrays/equals (ints (.shape m))
                                                 (ints (.shape a)))]
        (if same-shape?
          (let [sum (mp/clone m)]
            (loop-over [a sum]
              (aset sum-data sum-idx (+ (aget sum-data sum-idx)
                                        (aget a-data a-idx))))
            sum)
          (let [[m a] (mp/broadcast-compatible m a)]
            (mp/matrix-add m a)))))
    (matrix-sub [m a]
      ;; Element-wise difference m - a as a new array. If the shapes differ,
      ;; both operands are broadcast with mp/broadcast-compatible and the
      ;; subtraction is retried.
      (let [^typename# a (if (instance? typename# a) a
                             (mp/coerce-param m a))
            same-shape? (java.util.Arrays/equals (ints (.shape m))
                                                 (ints (.shape a)))]
        (if same-shape?
          (let [diff (mp/clone m)]
            (loop-over [a diff]
              (aset diff-data diff-idx (- (aget diff-data diff-idx)
                                          (aget a-data a-idx))))
            diff)
          (let [[m a] (mp/broadcast-compatible m a)]
            (mp/matrix-sub m a)))))

  mp/PMatrixAddMutable
    (matrix-add! [m a]
      ;; In place: m <- m + a (element-wise). Returns m.
      ;; If the shapes differ, only a is broadcast, to m's shape, via
      ;; mp/broadcast-coerce. m is never broadcast: a broadcast copy of m
      ;; would be a different object, so the mutation would miss m (and
      ;; something other than m would be returned).
      (let [^typename# a (if (instance? typename# a) a
                             (mp/coerce-param m a))]
        (if-not (java.util.Arrays/equals (ints (.shape m)) (ints (.shape a)))
          (mp/matrix-add! m (mp/broadcast-coerce m a))
          (do
            (loop-over [a m]
              (aset m-data m-idx (+ (aget m-data m-idx)
                                    (aget a-data a-idx))))
            m))))
    (matrix-sub! [m a]
      ;; In place: m <- m - a (element-wise). Returns m.
      ;; If the shapes differ, only a is broadcast, to m's shape, via
      ;; mp/broadcast-coerce. m is never broadcast: a broadcast copy of m
      ;; would be a different object, so the mutation would miss m (and
      ;; something other than m would be returned).
      (let [^typename# a (if (instance? typename# a) a
                             (mp/coerce-param m a))]
        (if-not (java.util.Arrays/equals (ints (.shape m)) (ints (.shape a)))
          (mp/matrix-sub! m (mp/broadcast-coerce m a))
          (do
            (loop-over [a m]
              (aset m-data m-idx (- (aget m-data m-idx)
                                    (aget a-data a-idx))))
            m))))

  ;; mp/PSubMatrix
  ;; mp/PComputeMatrix

  mp/PTranspose
    (transpose [m]
      ;; Reverses the axis order by reversing both shape and strides, keeping
      ;; ndims and offset. This is delegated to reshape-restride#t, which
      ;; presumably builds a view over m's data (confirm in its definition).
      (reshape-restride#t m ndims (areverse shape) (areverse strides) offset))

  ;; mp/PNumerical ;; similar to matrix-equals, needs longjump
  ;; mp/PVectorOps ;; needs fold-over
  ;; mp/PVectorCross
  ;; mp/PVectorDistance
  ;; mp/PVectorView ;; needs "packed" flag to be efficient
  ;; mp/PVectorisable ;; similar to PVectorView
  ;; mp/PMutableVectorOps ;; needs fold-over

  mp/PNegation
    (negate [m]
      ;; Returns a new array holding (* -1 x) for every element x of m.
      (let [neg (mp/clone m)]
        (loop-over [neg]
          (aset neg-data neg-idx
                (* -1 (aget neg-data neg-idx))))
        neg))

  ;; mp/PMatrixRank

  mp/PSummable
    (element-sum [m]
      ;; Sum of all elements of m.
      ;; TODO: needs fold-over support for N-dimensional case
      (if (<= (mp/dimensionality m) 2)
        (fold-over [m] 0
                 (+ loop-acc (aget m-data m-idx)))
        ;; Higher ranks: sum each major slice recursively. The seed is 0, as
        ;; in the fold-over branch, rather than 0.0, so :long and :object
        ;; arrays (e.g. Ratio elements, longs above 2^53) are not forced
        ;; into double.
        (reduce (fn [acc a] (+ acc (mp/element-sum a)))
                0
                (mp/get-major-slice-seq m))))

  mp/PExponent
    (element-pow [m exp]
      ;; Returns a new array in which each element x of m is replaced by
      ;; (Math/pow x exp), cast back to this array's element type.
      (let [out (mp/clone m)]
        (loop-over [out]
          (let [x (aget out-data out-idx)]
            (aset out-data out-idx (type-cast# (Math/pow x exp)))))
        out))

  mp/PSquare
    (square [m]
      ;; Returns a new array in which each element x of m is replaced by x * x.
      (let [sq (mp/clone m)]
        (loop-over [sq]
          (let [x (aget sq-data sq-idx)]
            (aset sq-data sq-idx (* x x))))
        sq))

  ;; mp/PRowOperations ;; use mutable views

  ;; PMathsFunctions/PMathsFunctionsMutable are evaled below

  mp/PElementCount
    (element-count [m]
      ;; Total number of elements: the product of all dimension sizes
      ;; (1 for a 0-dimensional array, whose shape is empty).
      (let [n (alength shape)]
        (loop [d 0
               total 1]
          (if (< d n)
            (recur (inc d) (* total (aget shape d)))
            total))))

  ;; PFunctionalOperators
  ;; PMatrixPredicates ;; needs long-jump
  ;; PGenericValues
  ;; PGenericOperations
  )

;; For common math functions that operate elementwise (like sin or sqrt)
;; we simply generate code that uses `loop-over`.
(eval
  ;; Builds a magic/extend-types form with syntax-quote and evals it. This
  ;; splices one method per entry of mops/maths-ops into both
  ;; PMathsFunctions and PMathsFunctionsMutable, for every element type.
  ;; The ~' prefixes keep m, a, a-data, a-idx, m-data, m-idx and type-cast#
  ;; as bare symbols. Without them, syntax-quote would namespace-qualify the
  ;; names (breaking the naming convention loop-over relies on) and would
  ;; turn type-cast# into an auto-gensym instead of leaving it for the
  ;; type magic.
  `(magic/extend-types
    [:long :float :double :object]
    ;; Non-mutating variants: apply func to every element of a clone of m.
    mp/PMathsFunctions
    ~@(map
       ;; each maths-ops entry is destructured as [name func]; func is
       ;; presumably a symbol naming a double -> number function (e.g. a
       ;; Math/... static) -- confirm in mathsops
       (fn [[name func]]
         `(~name [~'m]
             (let [~'a (mp/clone ~'m)]
               (loop-over [~'a]
                 (aset ~'a-data ~'a-idx
                       ;; element converted to double for func, result cast
                       ;; back to this array's element type
                       (~'type-cast# (~func (double (aget ~'a-data ~'a-idx))))))
               ~'a)))
       mops/maths-ops)

    ;; Mutating variants: the method name gets a "!" suffix and func is
    ;; applied to m in place; m is returned.
    mp/PMathsFunctionsMutable
    ~@(map
       (fn [[name func]]
         `(~(symbol (str name "!")) [~'m]
             (loop-over [~'m]
               (aset ~'m-data ~'m-idx
                     (~'type-cast# (~func (double (aget ~'m-data ~'m-idx))))))
             ~'m))
       mops/maths-ops)))

(magic/extend-types
  [:double]
  mp/PMatrixOps
    ;; Sum of the main diagonal; m must be a square 2-D array.
    (trace [m]
      (iae-when-not (== 2 ndims)
        "trace operates only on matrices")
      (iae-when-not (== (aget shape 0) (aget shape 1))
        "trace operates only on square matrices")
      (-> m mp/main-diagonal mp/element-sum))
    ;; Determinant of a square 2-D array, computed by determinant#t.
    (determinant [m]
      (iae-when-not (== 2 ndims)
        "determinant operates only on matrices")
      (iae-when-not (== (aget shape 0) (aget shape 1))
        "determinant operates only on square matrices")
      (determinant#t m))
    ;; Inverse of a square 2-D array, computed by invert#t.
    (inverse [m]
      (iae-when-not (== 2 ndims)
        "inverse operates only on matrices")
      (iae-when-not (== (aget shape 0) (aget shape 1))
        "inverse operates only on square matrices")
      (invert#t m)))
