(ns tech.smile.regression
  (:require [tech.ml.model :as model]
            [tech.ml.dataset :as dataset]
            [tech.ml.protocols :as ml-proto]
            [tech.ml.registry :as registry]
            ;;Kernels have to be loaded
            [tech.smile.kernels]
            [tech.datatype :as dtype]
            [tech.smile.utils :as utils]
            [clojure.reflect :refer [reflect]])
  (:import [smile.regression
            GradientTreeBoost$Loss
            NeuralNetwork$ActivationFunction
            Regression
            OnlineRegression]))


(def package-name
  "Java package in which the Smile regression classes live."
  "smile.regression")


(def regression-class-names
  "Simple (unqualified) names of the Smile regression classes."
  (set ["ElasticNet" "GaussianProcessRegression" "GradientTreeBoost"
        "LASSO" "NeuralNetwork" "OLS"
        "RandomForest" "RBFNetwork" "RegressionTree"
        "RidgeRegression" "RLS" "SVR"]))


(def regression-metadata
  "Model-type keyword -> description of the Smile regression it maps to.
  Each entry carries :options (name/type/default/range of each constructor
  option), :sparsity-types, :class-name (simple class name inside
  `package-name`) and optional :flags (:online, :object-data)."
  {:elastic-net {:options [{:name :lambda1
                            :type :float64
                            :default 0.1
                            :range :>0}
                           {:name :lambda2
                            :type :float64
                            :default 0.1
                            :range :>0}
                           {:name :tolerance
                            :type :float64
                            :default 1e-4
                            :range :>0}
                           {:name :max-iterations
                            :type :int32
                            :default 1000
                            :range :>0}]
                 :sparsity-types #{:dense :sparse :sparse-binary}
                 :class-name "ElasticNet"}

   :gaussian-process {:options [{:name :kernel
                                 :type :mercer-kernel
                                 :default {:kernel-type :gaussian}}
                                {:name :lambda
                                 :type :float64
                                 :range :>0
                                 :default 2}]
                      :sparsity-types #{:dense :sparse :sparse-binary}
                      :class-name "GaussianProcessRegression"
                      :flags #{:object-data}}

   :gaussian-process-regressors
   {:options [{:name :inducing-samples
               :type :input-array}
              {:name :kernel
               :type :mercer-kernel
               :default {:kernel-type :gaussian}}
              {:name :lambda
               :type :float64
               :range :>0
               :default 2}]
    :sparsity-types #{:dense :sparse :sparse-binary}
    :class-name "GaussianProcessRegression"
    :flags #{:object-data}}

   :gaussian-process-nystrom
   {:options [{:name :inducing-samples
               :type :input-array}
              {:name :kernel
               :type :mercer-kernel
               :default {:kernel-type :gaussian}}
              {:name :lambda
               :type :float64
               :range :>0
               :default 2}
              {:name :nystrom-marker
               :type :boolean
               :default true}]
    :sparsity-types #{:dense :sparse :sparse-binary}
    :class-name "GaussianProcessRegression"
    :flags #{:object-data}}

   :gradient-tree-boost
   {:options [{:name :loss
               :type :enumeration
               :class-type GradientTreeBoost$Loss
               :lookup-table {:least-squares GradientTreeBoost$Loss/LeastSquares
                              :least-absolute-deviation GradientTreeBoost$Loss/LeastAbsoluteDeviation
                              :huber GradientTreeBoost$Loss/Huber}
               :default :least-squares}
              {:name :n-trees
               :type :int32
               :default 500
               :range :>0}
              {:name :max-nodes
               :type :int32
               :default 6
               :range :>0}
              {:name :shrinkage
               :type :float64
               :default 0.005
               :range :>0}
              {:name :sampling-fraction
               :type :float64
               :default 0.7
               :range [0.0 1.0]}]
    :sparsity-types #{:dense}
    :class-name "GradientTreeBoost"}

   :lasso
   {:options [{:name :lambda
               :type :float64
               :default 1.0
               :range :>0}
              {:name :tolerance
               :type :float64
               :default 1e-4
               :range :>0}
              {:name :max-iterations
               :type :int32
               :default 1000
               :range :>0}]
    :class-name "LASSO"
    :sparsity-types #{:dense}}

   :neural-network
   {:options [{:name :activation-function
               :type :enumeration
               :class-type NeuralNetwork$ActivationFunction
               :lookup-table {:logistic-sigmoid NeuralNetwork$ActivationFunction/LOGISTIC_SIGMOID
                              :tanh NeuralNetwork$ActivationFunction/TANH}
               :default :logistic-sigmoid}
              ;; NOTE(review): the :momentum (1e-4) and :weight-decay (0.9)
              ;; defaults look swapped relative to the usual convention
              ;; (momentum ~0.9, weight decay ~1e-4) -- confirm before relying on them.
              {:name :momentum
               :type :float64
               :default 1e-4
               :range :>0}
              {:name :weight-decay
               :type :float64
               :default 0.9}
              ;; Hidden layer sizes only; input and output sizes are added at train time.
              {:name :layer-sizes
               :type :int32-array
               :default (int-array [100])}
              {:name :learning-rate
               :default 0.1
               :type :float64
               :setter "setLearningRate"}]
    :class-name "NeuralNetwork"
    :sparsity-types #{:dense}
    :flags #{:online}}

   :ordinary-least-squares
   {:options [{:name :svd?
               :type :boolean
               :default false}]
    :class-name "OLS"
    :sparsity-types #{:dense}}

   :recursive-least-squares
   {:options [{:name :forgetting-factor
               :default 1.0
               :type :float64}]
    ;; Was misspelled :sparse-types; every other entry uses :sparsity-types.
    :sparsity-types #{:dense}
    :class-name "RLS"}

   :support-vector
   {:options [{:name :loss-function-error-threshold
               :default 0.1
               :type :float64
               :altname "eps"}
              {:name :soft-margin-penalty
               :default 1.0
               :type :float64
               :altname "C"}
              {:name :tolerance
               :default 1e-3
               :type :float64
               :altname "tol"}]
    :class-name "SVR"
    :sparsity-types #{:dense :sparse :sparse-binary}
    :flags #{:object-data}}})


(def marker-interfaces
  "Keyword -> simple name of the Smile regression marker interfaces."
  (zipmap [:regression :online-regression]
          ["Regression" "OnlineRegression"]))


(defmulti model-type->regression-model
  "Maps a model-type keyword to its regression metadata entry.
  Dispatches on the model-type value itself."
  identity)


;; Fallback: look the model type up directly in regression-metadata and throw
;; ex-info (carrying the offending :model-type) when there is no entry.
(defmethod model-type->regression-model :default
  [model-type]
  (or (get regression-metadata model-type)
      (throw (ex-info "Failed to get regression type"
                      {:model-type model-type}))))


;; The generic :regression model type resolves to the LASSO entry.
(defmethod model-type->regression-model :regression
  [_]
  (:lasso regression-metadata))


(defn reflect-regression
  "Loads the class named `cls-name` from `package-name` and returns its
  clojure.reflect/reflect description."
  [cls-name]
  (let [qualified-name (str package-name "." cls-name)]
    (-> qualified-name
        Class/forName
        reflect)))



(defrecord SmileRegression []
  ml-proto/PMLSystem
  (system-name [_] :smile/regression)

  ;; Requests that data be coalesced into :float64 containers built with
  ;; dtype/make-array-of-type.
  (coalesce-options [system]
    {:container-type dtype/make-array-of-type
     :datatype :float64})

  ;; Looks up the metadata entry for (:model-type options), builds and trains the
  ;; corresponding Smile model, and returns it serialized via model/model->byte-array.
  ;; Each element of coalesced-dataset is read for :values (feature row, passed to
  ;; Smile as double[]) and :label (its first value is the regression target).
  (train [system options coalesced-dataset]
    (let [entry-metadata (model-type->regression-model (:model-type options))
          trained-model
          (if (contains? (:flags entry-metadata) :online)
            ;; Online models: construct once, then call .learn once per row.
            (let [;; NeuralNetwork needs the full list of layer sizes, but callers
                  ;; only supply the hidden layers. Prepend the input size (taken
                  ;; from the first row) and append a single output unit.
                  options (if (= (:class-name entry-metadata) "NeuralNetwork")
                            (let [input-size (dtype/ecount (:values (first coalesced-dataset)))
                                  hidden-sizes (utils/get-option-value entry-metadata :layer-sizes options)]
                              (assoc options :layer-sizes
                                     (int-array (concat [input-size] (vec hidden-sizes) [1]))))
                            options)
                  ^OnlineRegression untrained (utils/construct entry-metadata
                                                               package-name
                                                               options)]
              (doseq [{:keys [values label]} coalesced-dataset]
                (.learn untrained ^doubles values (double (dtype/get-value label 0))))
              untrained)
            ;; Batch models: the whole dataset goes to the constructor as the
            ;; :training-data and :labels options, placed ahead of the entry's own options.
            (let [value-seq (map :values coalesced-dataset)
                  ;; Entries flagged :object-data receive an Object[] of rows;
                  ;; the others receive the array produced by into-array.
                  [x-data x-datatype] (if (contains? (:flags entry-metadata) :object-data)
                                        [(object-array value-seq) :object-array]
                                        [(into-array value-seq) :float64-array-array])
                  n-entries (dtype/ecount x-data)
                  ^doubles y-data (first (dtype/copy-raw->item!
                                          (map :label coalesced-dataset)
                                          (dtype/make-array-of-type :float64 n-entries)
                                          0))
                  entry-options (concat [{:type x-datatype
                                          :default x-data
                                          :name :training-data}
                                         {:type :float64-array
                                          :default y-data
                                          :name :labels}]
                                        (get entry-metadata :options))]
              (utils/construct (assoc entry-metadata :options entry-options)
                               package-name
                               options)))]
      (model/model->byte-array trained-model)))

  ;; Deserializes the model and predicts one value per element's :values row.
  ;; into-array infers the element class from the first element, so the result
  ;; is a boxed java.lang.Double[] (an Object[] for an empty dataset).
  (predict [system trained-model-bytes coalesced-dataset]
    (let [^Regression trained-model (model/byte-array->model trained-model-bytes)]
      (->> coalesced-dataset
           (map #(double (.predict trained-model ^doubles (:values %))))
           (into-array)))))


(def system
  "Zero-arg function returning the SmileRegression system; memoized so every
  call yields the same instance."
  (memoize #(->SmileRegression)))


;; Registers the shared SmileRegression instance with the tech.ml registry at namespace load time.
(-> (system) registry/register-system)
