(ns tech.smile.classification
  (:require [clojure.reflect :refer [reflect]]
            [tech.smile.utils :as utils]
            [tech.smile.kernels]
            [tech.smile.distance]
            [tech.datatype :as dtype]
            [tech.ml.protocols :as ml-proto]
            [tech.ml.model :as model]
            [tech.ml.details :as ml-details]
            [tech.ml.registry :as registry]
            [camel-snake-kebab.core :refer [->kebab-case]])
  (:import [smile.classification DecisionTree$SplitRule
            NaiveBayes$Model SVM$Multiclass OnlineClassifier SoftClassifier
            Classifier SVM]))



(def package-name
  "Java package prefix handed to the utils helpers when resolving and
  constructing the classifier classes named in this namespace."
  "smile.classification")


(def java-classes
  "Simple (unqualified) names of the classifier classes that
  `java-classes->metadata` reflects over."
  #{"AdaBoost" "DecisionTree" "FLD" "GradientTreeBoost"
    "KNN" "LDA" "LogisticRegression" "Maxent"
    "NaiveBayes" "NeuralNetwork" "PlattScaling" "QDA"
    "RandomForest" "RBFNetwork" "RDA" "SVM"})


(defn reflect-cls
  "Returns the `clojure.reflect/reflect` description of the class that
  `utils/class-name->class` yields for `cls-name` within `package-name`."
  [cls-name]
  (-> (utils/class-name->class package-name cls-name)
      reflect))


(def base-attribute-map
  "Attribute keyword -> name fragment of the base type that grants it.
  `java-classes->metadata` adds the keyword to a classifier's attribute set
  when one of the reflected base names contains the fragment."
  {:probabilities "SoftClassifier"
   :online "OnlineClassifier"})


(defn java-classes->metadata
  "Reflects over every entry of `java-classes` and returns a map of
  kebab-case keyword -> metadata map with the keys `:class-name`, `:name`,
  `:datatypes` (whatever `utils/method-datatype` reports for \"predict\") and
  `:attributes` (the `base-attribute-map` keywords whose base-type fragment
  appears in one of the class's reflected bases)."
  []
  (into {}
        (map (fn [cls-name]
               (let [item-name (keyword (->kebab-case cls-name))
                     reflect-data (reflect-cls cls-name)
                     base-names (map str (:bases reflect-data))
                     ;; True when any reflected base name contains the fragment.
                     derives-from? (fn [base-cls]
                                     (some #(.contains ^String % base-cls)
                                           base-names))
                     attributes (into #{}
                                      (keep (fn [[keywd base-cls]]
                                              (when (derives-from? base-cls)
                                                keywd)))
                                      base-attribute-map)]
                 [item-name
                  {:class-name cls-name
                   :name item-name
                   :datatypes (utils/method-datatype "predict" reflect-data)
                   :attributes attributes}])))
        java-classes))


(defn- options->num-classes
  "Returns the value stored at `[:dataset-info :num-classes]` in `options`.
  Throws an ex-info whose data carries the `:dataset-info` entry when that
  value is missing (nil or false)."
  [options]
  (or (get-in options [:dataset-info :num-classes])
      (throw (ex-info "Failed to find num classes"
                      {:dataset-info (:dataset-info options)}))))



(defn- options->input-dimensionality
  "Returns the value stored at `[:dataset-info :values-ecount]` in `options`.
  Throws an ex-info whose data carries the `:dataset-info` entry when that
  value is missing."
  [options]
  (if-let [retval (get-in options [:dataset-info :values-ecount])]
    retval
    (throw (ex-info "Failed to find values ecount"
                    {:dataset-info (:dataset-info options)}))))


;; Hand-maintained description of every supported classifier, keyed by
;; model-type keyword.  Each entry carries:
;;   :class-name  simple class name inside `package-name`
;;   :attributes  capability flags (:online selects the iterative training
;;                path, :object-data makes block training pass an Object[])
;;   :datatypes   input datatypes listed for the classifier
;;   :options     ordered option descriptors (:name, :type, :default, ...);
;;                a :default may be a function of the options map, e.g.
;;                `options->num-classes`.
;; NOTE(review): how :options are turned into constructor arguments is decided
;; by tech.smile.utils, which is not visible here -- confirm ordering and
;; naming rules there before editing an option list.
(def classifier-metadata
  {:ada-boost {:attributes #{:probabilities :attributes}
               :class-name "AdaBoost"
               :datatypes #{:float64-array}
               :name :ada-boost
               :options [{:name :ntrees
                          :type :int32
                          :default 500}
                         {:name :max-nodes
                          :type :int32
                          :default 6}]}
   :decision-tree {:attributes #{:probabilities :attributes}
                   :class-name "DecisionTree"
                   :datatypes #{:float64-array}
                   :name :decision-tree
                   :options [{:name :max-nodes
                              :type :int32
                              :default 100}
                             {:name :node-size
                              :type :int32
                              :default 1}
                             {:name :split-rule
                              :type :enumeration
                              :class-type DecisionTree$SplitRule
                              :lookup-table {:gini DecisionTree$SplitRule/GINI
                                             :entropy DecisionTree$SplitRule/ENTROPY
                                             :classification-error DecisionTree$SplitRule/CLASSIFICATION_ERROR}
                              :default :gini}]}
   :fld {:attributes #{:projection}
         :class-name "FLD"
         :datatypes #{:float64-array}
         :name :fld
         :options [{:name :L
                    :type :int32
                    :default -1}
                   {:name :tolerance
                    :type :float64
                    :default 1e-4}]}
   :gradient-tree-boost {:attributes #{:probabilities}
                         :class-name "GradientTreeBoost"
                         :datatypes #{:float64-array}
                         :name :gradient-tree-boost
                         :options [{:name :ntrees
                                    :type :int32
                                    :default 500}
                                   {:name :max-nodes
                                    :type :int32
                                    :default 6
                                    :range :>0}
                                   {:name :shrinkage
                                    :type :float64
                                    :default 0.005
                                    :range :>0}
                                   {:name :sampling-fraction
                                    :type :float64
                                    :default 0.7
                                    :range [0.0 1.0]}]}
   :knn {:attributes #{:probabilities :object-data}
         :class-name "KNN"
         :datatypes #{:float64-array}
         :name :knn
         :options [{:name :distance
                    :type :distance
                    :default {:distance-type :euclidean}}
                   {:name :num-clusters
                    :type :int32
                    :default 5}]}
   :lda {:attributes #{:probabilities}
         :class-name "LDA"
         :datatypes #{:float64-array}
         :name :lda
         ;; NOTE(review): :prioiri looks like a misspelling of :priori.  It is
         ;; a key callers may already pass, so it is left unchanged here.
         :options [{:name :prioiri
                    :type :float64-array
                    :default nil}
                   {:name :tolerance
                    :default 1e-4
                    :type :float64}]}
   :logistic-regression {:attributes #{:online :probabilities}
                         :class-name "LogisticRegression"
                         :datatypes #{:float64-array}
                         :name :logistic-regression
                         :options [{:name :lambda
                                    :type :float64
                                    :default 0.0}
                                   {:name :tolerance
                                    :type :float64
                                    :default 1e-5}
                                   {:name :max-iter
                                    :type :int32
                                    :default 500}]}
   ;;Not supported at this time because the constructor pattern is unique
   :maxent {:attributes #{:probabilities}
            :class-name "Maxent"
            :datatypes #{:float64-array :int32-array}
            :name :maxent}

   :naive-bayes {:attributes #{:online :probabilities}
                 :class-name "NaiveBayes"
                 :datatypes #{:float64-array :sparse}
                 :name :naive-bayes
                 :options [{:name :model
                            :type :enumeration
                            :class-type NaiveBayes$Model
                            :lookup-table {:general NaiveBayes$Model/GENERAL
                                           :multinomial NaiveBayes$Model/MULTINOMIAL
                                           :bernoulli NaiveBayes$Model/BERNOULLI
                                           :polyaurn NaiveBayes$Model/POLYAURN}
                            :default :general}
                           {:name :num-classes
                            :type :int32
                            :default options->num-classes}
                           ;; Previously read the misspelled key :datatset-info,
                           ;; so this default could never resolve and always
                           ;; threw.
                           {:name :input-dimensionality
                            :type :int32
                            :default options->input-dimensionality}
                           {:name :sigma
                            :type :float64
                            :default 1.0}]}
   :neural-network {:attributes #{:online :probabilities}
                    :class-name "NeuralNetwork"
                    :datatypes #{:float64-array}
                    :name :neural-network}
   :platt-scaling {:attributes #{}
                   :class-name "PlattScaling"
                   :datatypes #{:double}
                   :name :platt-scaling}
   :qda {:attributes #{:probabilities}
         :class-name "QDA"
         :datatypes #{:float64-array}
         :name :qda}
   :random-forest {:attributes #{:probabilities}
                   :class-name "RandomForest"
                   :datatypes #{:float64-array}
                   :name :random-forest}
   :rbf-network {:attributes #{}
                 :class-name "RBFNetwork"
                 :datatypes #{}
                 :name :rbf-network}
   :rda {:attributes #{:probabilities}
         :class-name "RDA"
         :datatypes #{:float64-array}
         :name :rda}
   :svm {:attributes #{:online :probabilities}
         :class-name "SVM"
         :datatypes #{:float64-array}
         :name :svm
         :options [{:name :kernel
                    :type :mercer-kernel
                    :default {:kernel-type :gaussian}}
                   {:name :soft-margin-penalty
                    :type :float64
                    :altname "C"
                    :default 1.0}
                   {:name :num-classes
                    :type :int32
                    :default options->num-classes}
                   {:name :multiclass-strategy
                    :type :enumeration
                    :class-type SVM$Multiclass
                    ;; NOTE(review): :on-vs-all looks like a misspelling of
                    ;; :one-vs-all.  Left unchanged because callers may already
                    ;; pass this keyword.
                    :lookup-table {:one-vs-one SVM$Multiclass/ONE_VS_ONE
                                   :on-vs-all SVM$Multiclass/ONE_VS_ALL}
                    :default :one-vs-one}]}})


(defn model-type->classification-model
  "Looks `model-type` up in `classifier-metadata` and returns its entry.
  Throws an ex-info listing the requested type and the available types when
  there is no such entry."
  [model-type]
  (or (get classifier-metadata model-type)
      (throw (ex-info "Unrecognized model type"
                      {:model-type model-type
                       :available-types (keys classifier-metadata)}))))



(defn- train-online
  "Trains an `OnlineClassifier` by feeding it the dataset one sample at a
  time through `.learn`, then returns the classifier.  When the entry is
  `:svm`, Platt scaling is additionally trained over the whole dataset."
  [options entry-metadata coalesced-dataset]
  (let [;; No data arguments are prepended (empty vector): the samples are
        ;; supplied incrementally below rather than through the constructor.
        ^OnlineClassifier classifier
        (utils/construct
         (utils/prepend-data-constructor-arguments entry-metadata options [])
         package-name
         options)]
    (doseq [sample coalesced-dataset]
      (.learn classifier
              ^doubles (:values sample)
              (int (dtype/get-value (:label sample) 0))))
    (when (= :svm (:name entry-metadata))
      (println "training platt scaling")
      (let [^SVM svm classifier
            feature-rows (object-array (map :values coalesced-dataset))
            ^ints class-indexes (int-array
                                 (map (fn [sample]
                                        (int (dtype/get-value (:label sample) 0)))
                                      coalesced-dataset))]
        (.trainPlattScaling svm feature-rows class-indexes)))
    ;; Same object that was constructed above, now trained in place.
    classifier))


(defn- train-block
  "Trains in a single shot: gathers every sample's `:values` into one array
  and every `:label` into an int array, then passes both as the leading
  constructor arguments of the classifier described by `entry-metadata`."
  [options entry-metadata coalesced-dataset]
  (let [feature-rows (map :values coalesced-dataset)
        ;; Entries flagged :object-data get an Object[]; everything else gets
        ;; an array typed after its elements via `into-array`.
        object-data? (contains? (:attributes entry-metadata) :object-data)
        x-data (if object-data?
                 (object-array feature-rows)
                 (into-array feature-rows))
        x-datatype (if object-data?
                     :object-array
                     :float64-array-array)
        n-entries (first (dtype/shape x-data))
        label-buffer (dtype/make-array-of-type :int32 n-entries)
        ;; Only the first element of copy-raw->item!'s result is used as the
        ;; label array.
        ^ints y-data (first (dtype/copy-raw->item!
                             (map :label coalesced-dataset)
                             label-buffer
                             0))
        leading-arguments [{:type x-datatype
                            :default x-data
                            :name :training-data}
                           {:type :int32-array
                            :default y-data
                            :name :labels}]]
    (utils/construct
     (utils/prepend-data-constructor-arguments entry-metadata options
                                               leading-arguments)
     package-name
     options)))


;; tech.ml system implementation backed by the smile.classification package.
;; `train` returns the trained model serialized through
;; `model/model->byte-array`; `predict` deserializes it and returns one
;; label -> value map per dataset entry.
(defrecord SmileClassification []
  ml-proto/PMLSystem
  (system-name [_] :smile/classification)
  (coalesce-options [system _]
    {:container-type dtype/make-array-of-type
     :datatype :float64})
  (train [system options coalesced-dataset]
    (let [entry-metadata (model-type->classification-model (:model-type options))
          ;; :online entries learn sample by sample; all others train on the
          ;; full block of data at once.
          train-fn (if (contains? (:attributes entry-metadata) :online)
                     train-online
                     train-block)]
      (model/model->byte-array
       (train-fn options entry-metadata coalesced-dataset))))
  (predict [system options trained-model-bytes coalesced-dataset]
    (let [trained-model (model/byte-array->model trained-model-bytes)
          ;; Labels ordered by the value they map to in the label map, so a
          ;; position in this vector lines up with a predicted class index.
          ordered-labels (->> (ml-details/options->label-map options
                                                             (:label-keys options))
                              (sort-by second)
                              (mapv first))]
      (if (instance? SoftClassifier trained-model)
        (let [posteriori (double-array (count ordered-labels))
              ^SoftClassifier soft-model trained-model]
          ;; `posteriori` is a scratch buffer reused for every entry; zipmap
          ;; reads its contents right after each .predict call.
          (map (fn [{:keys [values]}]
                 (.predict soft-model ^doubles values posteriori)
                 (zipmap ordered-labels posteriori))
               coalesced-dataset))
        (let [^Classifier hard-model trained-model]
          ;; No probabilities available: report the predicted label with 1.0.
          (map (fn [{:keys [values]}]
                 {(get ordered-labels (.predict hard-model ^doubles values)) 1.0})
               coalesced-dataset))))))



;; `system` is a function that ignores its arguments and always returns the
;; single SmileClassification instance created when this namespace loads.
(def system (constantly (->SmileClassification)))


;; Load-time side effect: hands that instance to the tech.ml registry.
(registry/register-system (system))
