(ns tablecloth.transfrom.projection
  (:require [tablecloth.api.utils :refer [column-names]]
            [tablecloth.api.dataset :refer [rows dataset columns]]
            [tablecloth.api.columns :refer [select-columns drop-columns add-or-replace-columns]]
            [tablecloth.api.group-by :refer [grouped? ungroup process-group-data]]            
            [fastmath.kernel :as k])
  (:import [smile.projection PCA ProbabilisticPCA KPCA GHA RandomProjection Projection]
           [smile.math.kernel MercerKernel]))

(set! *warn-on-reflection* true)

(defn- pca
  "Fits a Smile PCA model on `rows` and sets its projection to `target-dims`.
  With a truthy `cor?` the model is built with `PCA/cor`, otherwise with
  `PCA/fit`. Returns the fitted `PCA` instance."
  ([rows target-dims] (pca rows target-dims false))
  ([rows ^long target-dims cor?]
   (let [^PCA fitted (if-not cor?
                       (PCA/fit rows)
                       (PCA/cor rows))]
     ;; `setProjection` mutates the model in place; `doto` hands the model back.
     (doto fitted
       (.setProjection target-dims)))))

(defn- pca-prob
  "Fits a Smile `ProbabilisticPCA` model on `rows`, passing `target-dims` as
  the requested dimensionality."
  [rows target-dims]
  (-> rows
      (ProbabilisticPCA/fit target-dims)))

(defn- kpca
  "Fits a Smile kernel PCA (`KPCA`) model on `rows`.

  `kernel` may be:
  - a `MercerKernel`, used as-is;
  - a function, wrapped with `fastmath.kernel/smile-mercer`;
  - anything else (presumably a fastmath kernel keyword), which is turned into
    a kernel via `(apply k/kernel kernel kernel-params)` and then wrapped.
  `target-dims` and `threshold` are passed straight to `KPCA/fit`."
  [rows target-dims kernel kernel-params threshold]
  (let [mercer (if (instance? MercerKernel kernel)
                 kernel
                 (k/smile-mercer (if (fn? kernel)
                                   kernel
                                   (apply k/kernel kernel kernel-params))))]
    (KPCA/fit rows mercer target-dims threshold)))

(defn- gha
  "Trains a Smile `GHA` model incrementally, feeding it each of `rows` in order.

  The model is created with the width of the first row, `target-dims` and
  `learning-rate`. Before every row update the learning rate is multiplied by
  `decay`.
  NOTE(review): because the decay happens before the update, even the first row
  is learned with `learning-rate * decay`, not `learning-rate`. Confirm this is
  intended."
  [rows target-dims learning-rate decay]
  (let [^GHA model (GHA. (count (first rows)) target-dims learning-rate)
        step! (fn [row]
                (.setLearningRate model (* decay (.getLearningRate model)))
                (.update model row))]
    (run! step! rows)
    model))

(defn- random
  "Builds a Smile `RandomProjection` from the row width (length of the first
  row of `rows`) to `target-dims`. `rows` is only inspected for its width."
  [rows target-dims]
  (-> rows
      first
      count
      (RandomProjection/of target-dims)))

(defn- build-model
  "Fits a projection model on `rows` according to the `algorithm` keyword.

  Recognized: `:pca-cov`, `:pca-cor`, `:pca-prob`, `:kpca`, `:gha`, `:random`.
  Any other value (including `:pca`) falls back to covariance PCA.

  Option defaults: `:kernel` (fastmath `:gaussian` kernel), `:threshold` 0.0001
  and `:kernel-params` nil (used by `:kpca`); `:learning-rate` 0.0001 and
  `:decay` 0.995 (used by `:gha`)."
  [rows algorithm target-dims {:keys [kernel kernel-params threshold learning-rate decay]
                               :or {kernel (k/kernel :gaussian)
                                    threshold 0.0001
                                    learning-rate 0.0001
                                    decay 0.995}}]
  (case algorithm
    :pca-cor  (pca rows target-dims true)
    :pca-prob (pca-prob rows target-dims)
    :kpca     (kpca rows target-dims kernel kernel-params threshold)
    :gha      (gha rows target-dims learning-rate decay)
    :random   (random rows target-dims)
    ;; :pca-cov, :pca and every unrecognized keyword end up here.
    (pca rows target-dims)))

(defn- rows->array
  "Returns the rows of the `names` columns of `ds` as double arrays
  (tablecloth `rows` with `:as-double-arrays`)."
  [ds names]
  (rows (select-columns ds names) :as-double-arrays))

(defn- array->ds
  "Builds a dataset from the row arrays in `arr`. Each row becomes a map that
  pairs the leading entries of `target-columns` with the row's values."
  [arr target-columns]
  (dataset (for [row arr]
             (zipmap target-columns row))))

(defn process-reduction
  "Projects the `cnames` columns of `ds` and adds the result as new columns
  named by the leading entries of `target-columns`, in that order.

  Uses `(:model opts)` when given (it is used as a
  `smile.projection.Projection`); otherwise fits one with `build-model` on the
  selected rows, forwarding `opts`. When `:drop-columns?` is true (the default)
  the source `cnames` columns are removed from the result."
  [ds algorithm target-dims cnames target-columns {:keys [drop-columns? model]
                                                   :or {drop-columns? true}
                                                   :as opts}]
  (let [data (rows->array ds cnames)
        ^Projection projection (or model (build-model data algorithm target-dims opts))
        ^"[[D" projected (.project projection ^"[[D" data)
        col-map (columns (array->ds projected target-columns) :as-map)
        ;; NOTE: `columns ... :as-map` presumably yields a regular map, which
        ;; becomes a hash map beyond 8 entries and would add the columns in
        ;; arbitrary order. Rebuild it as an array-map (order-preserving at
        ;; any size) keyed in `target-columns` order.
        out-names (take (count (first projected)) target-columns)
        new-cols (apply array-map (mapcat (fn [n] [n (get col-map n)]) out-names))]
    (-> (if drop-columns? (drop-columns ds cnames) ds)
        (add-or-replace-columns new-cols))))


(defn reduce-dimensions
  "Reduces the columns of `ds` chosen by `columns-selector` (default
  `:type/numerical`) to `target-dims` new columns using `algorithm` (default
  `:pca`; see `build-model` for the recognized keywords).

  New columns are named `<prefix>-0`, `<prefix>-1`, ..., where the prefix is
  `(:prefix opts)` or, when absent, the algorithm's name.

  Options:
  - `:model`: a pre-fitted model used instead of fitting one.
  - `:common-model?` (default `true`): for grouped datasets, fit one model on
    the ungrouped data and share it across groups. When false, each group fits
    its own.
  - `:parallel?` (default `false`): forwarded to `process-group-data`.
  - Other keys (`:drop-columns?`, `:kernel`, `:threshold`, ...) are forwarded
    to `process-reduction` / `build-model`."
  ([ds target-dims] (reduce-dimensions ds :pca target-dims))
  ([ds algorithm target-dims] (reduce-dimensions ds :type/numerical algorithm target-dims))
  ([ds columns-selector algorithm target-dims] (reduce-dimensions ds columns-selector algorithm target-dims {}))
  ([ds columns-selector algorithm target-dims {:keys [prefix parallel? model common-model?]
                                               :or {common-model? true
                                                    parallel? false}
                                               :as opts}]
   (let [cnames (column-names ds columns-selector)
         ;; Infinite, lazily realized name sequence; only as many as needed are used.
         target-columns (map (fn [i] (str (or prefix (name algorithm)) "-" i)) (range))
         reduce-one (fn [options data]
                      (process-reduction data algorithm target-dims cnames target-columns options))]
     (if-not (grouped? ds)
       (reduce-one opts ds)
       (let [shared-model (cond
                            model model
                            common-model? (build-model (rows->array (ungroup ds) cnames)
                                                       algorithm target-dims opts)
                            :else model)]
         (process-group-data ds
                             (partial reduce-one (assoc opts :model shared-model))
                             parallel?))))))



(comment
  ;; REPL scratch code. It is kept inside `comment` so that requiring this
  ;; namespace neither reads data/iris.csv (relative to the working directory,
  ;; which throws when the file is missing) nor runs the computation below.
  (def iris (dataset "data/iris.csv"))

  #_(stats/variance ((reduce-dimensions iris :type/numerical :pca 4) "pca-3"))

  ;; Each hard-coded value divided by their sum, i.e. its share of the total
  ;; (presumably per-component variances of a 4-dim projection; verify).
  (let [a [0.023835092973449445
           0.07820950004291945
           0.24267074792863372
           4.228241706034866]
        s (reduce + a)]
    (map #(/ % s) a)))

;; => data/iris.csv [150 5]:
;;    | Species |       pca-0 |       pca-1 |       pca-2 |       pca-3 |
;;    |---------|-------------|-------------|-------------|-------------|
;;    |  setosa | -2.68412563 | -0.31939725 |  0.02791483 |  0.00226244 |
;;    |  setosa | -2.71414169 |  0.17700123 |  0.21046427 |  0.09902655 |
;;    |  setosa | -2.88899057 |  0.14494943 | -0.01790026 |  0.01996839 |
;;    |  setosa | -2.74534286 |  0.31829898 | -0.03155937 | -0.07557582 |
;;    |  setosa | -2.72871654 | -0.32675451 | -0.09007924 | -0.06125859 |
;;    |  setosa | -2.28085963 | -0.74133045 | -0.16867766 | -0.02420086 |
;;    |  setosa | -2.82053775 |  0.08946138 | -0.25789216 | -0.04814311 |
;;    |  setosa | -2.62614497 | -0.16338496 |  0.02187932 | -0.04529787 |
;;    |  setosa | -2.88638273 |  0.57831175 | -0.02075957 | -0.02674474 |
;;    |  setosa | -2.67275580 |  0.11377425 |  0.19763272 | -0.05629540 |
;;    |  setosa | -2.50694709 | -0.64506890 |  0.07531801 | -0.01501992 |
;;    |  setosa | -2.61275523 | -0.01472994 | -0.10215026 | -0.15637921 |
;;    |  setosa | -2.78610927 |  0.23511200 |  0.20684443 | -0.00788791 |
;;    |  setosa | -3.22380374 |  0.51139459 | -0.06129967 | -0.02167981 |
;;    |  setosa | -2.64475039 | -1.17876464 |  0.15162752 |  0.15920972 |
;;    |  setosa | -2.38603903 | -1.33806233 | -0.27777690 |  0.00655155 |
;;    |  setosa | -2.62352788 | -0.81067951 | -0.13818323 |  0.16773474 |
;;    |  setosa | -2.64829671 | -0.31184914 | -0.02666832 |  0.07762818 |
;;    |  setosa | -2.19982032 | -0.87283904 |  0.12030552 |  0.02705187 |
;;    |  setosa | -2.58798640 | -0.51356031 | -0.21366517 | -0.06627265 |
;;    |  setosa | -2.31025622 | -0.39134594 |  0.23944404 | -0.01507079 |
;;    |  setosa | -2.54370523 | -0.43299606 | -0.20845723 |  0.04106540 |
;;    |  setosa | -3.21593942 | -0.13346807 | -0.29239675 |  0.00448213 |
;;    |  setosa | -2.30273318 | -0.09870885 | -0.03912326 |  0.14835259 |
;;    |  setosa | -2.35575405 |  0.03728186 | -0.12502108 | -0.30033090 |
