(ns scicloj.ml.smile.try)

(ns assignment.decision-tree
  (:require
    [assignment.eda :refer [concrete-data]]
    [calc-metric.patch]
    [fastmath.stats :as stats]
    [scicloj.kindly.v4.kind :as kind]
    [scicloj.ml.core :as ml]
    [scicloj.ml.dataset :as ds]
    [scicloj.ml.metamorph :as mm]))

;; # Smile with Clojure
;; ## Define regressors and response
(def response
  "Name of the column used as the prediction target."
  :concrete-compressive-strength)

(def regressors
  "Names of all columns in `concrete-data` other than `response`."
  (ds/column-names concrete-data #(not= response %)))

;; Summary of the columns of `concrete-data` (notebook output only; the value is not bound).
(ds/info concrete-data)

;; ## Setup Pipelines
(def pipeline-fn
  "Base pipeline with a single step that sets `response` as the inference
  target. It is used as the first step of every model pipeline built by
  `create-model-pipeline`."
  (-> response
      mm/set-inference-target
      ml/pipeline))

;; ### Generic pipeline function
(defn create-model-pipeline
  "Builds a pipeline that runs `pipeline-fn` and then a model step registered
  under the metamorph id `:model`.

  `model-type` is the model keyword; `params` is a map of model options.
  The two are merged into one options map, with entries in `params` taking
  precedence over the `:model-type` entry on a key clash."
  [model-type params]
  (let [model-options (merge {:model-type model-type} params)
        model-step    (mm/model model-options)]
    (ml/pipeline
      pipeline-fn
      {:metamorph/id :model}
      model-step)))

;; #### Gradient tree context
(defn gradient-tree-pipe-fn
  "Returns a model pipeline for `:smile.regression/gradient-tree-boost`
  configured with the options map `params`."
  [params]
  (let [model-type :smile.regression/gradient-tree-boost]
    (create-model-pipeline model-type params)))

;; #### Random forest context
(defn random-forest-pipe-fn
  "Returns a model pipeline for `:smile.regression/random-forest`
  configured with the options map `params`."
  [params]
  (let [model-type :smile.regression/random-forest]
    (create-model-pipeline model-type params)))

;; ## Pipeline Functions
;; ### Evaluate pipeline
(defn evaluate-pipe
  "Runs `ml/evaluate-pipelines` on `pipe` (one or more pipelines) against the
  train/test splits in `data`.

  `stats/omega-sq` is the primary metric and is passed together with
  `:accuracy`; MAE and RMSE are reported as additional metrics. The options
  ask for all pipelines to be returned, but only the best cross-validation
  result of each."
  [pipe data]
  (let [extra-metrics [{:name :mae :metric-fn ml/mae}
                       {:name :rmse :metric-fn ml/rmse}]
        ;; NOTE: `:other-metrices` is spelled exactly as the original call passes it.
        options       {:other-metrices                   extra-metrics
                       :return-best-pipeline-only        false
                       :return-best-crossvalidation-only true}]
    (ml/evaluate-pipelines pipe data stats/omega-sq :accuracy options)))

;; ### Generate hyperparameters for models
; No hyperparameters are available via `ml/hyperparameters` for
; :smile.regression/gradient-tree-boost or :smile.regression/random-forest.
; The call below lists those of :smile.classification/decision-tree instead
; (presumably as a reference for the hand-written grids that follow - confirm).
(ml/hyperparameters :smile.classification/decision-tree)
(defn generate-hyperparams
  "Returns a sequence of hyperparameter maps for `model-type`, drawn with
  `ml/sobol-gridsearch` from a hand-written search space.

  Supported values of `model-type`:
  - `:gradient-tree`  - at most 60 candidates
  - `:random-forest`  - at most 30 candidates

  Throws IllegalArgumentException for any other `model-type`, with a message
  naming the supported keys."
  [model-type]
  ;; NOTE(review): `:node-size` starts at 0 in both spaces - confirm that the
  ;; underlying Smile models accept a node size of 0.
  ;; The map literals are kept as written: the entry order may influence how
  ;; the sobol search assigns its dimensions.
  (case model-type
    :gradient-tree (take 60 (ml/sobol-gridsearch {:trees       1 ;(ml/linear 1 1000 10 :int32)
                                                  :loss        (ml/categorical [:least-absolute-deviation :least-squares])
                                                  :max-depth   (ml/linear 10 50 20 :int32)
                                                  :max-nodes   (ml/linear 10 1000 30 :int32)
                                                  :node-size   (ml/linear 0 20 20 :int32)
                                                  :shrinkage   0.01 ;(ml/linear 0 1)
                                                  :sample-rate (ml/linear 0.1 1 10)}))
    :random-forest (take 30 (ml/sobol-gridsearch {:trees       (ml/linear 1 1000 10 :int32)
                                                  :max-depth   (ml/linear 10 50 20 :int32)
                                                  :max-nodes   (ml/linear 10 1000 30 :int32)
                                                  :node-size   (ml/linear 0 20 20 :int32)
                                                  :sample-rate (ml/linear 0.1 1 10)}))
    ;; Explicit default: same exception class `case` throws without one, but
    ;; with a message that names the accepted keys.
    (throw (IllegalArgumentException.
             (str "Unsupported model-type " (pr-str model-type)
                  "; expected :gradient-tree or :random-forest")))))

;; ### Evaluate a single model
(defn evaluate-model
  "Evaluates one model family on `dataset`.

  `split-fn` turns the dataset into train/test splits. For `:best-subset`,
  `model-fn` is called once with the dataset, `response` and `regressors`
  and must return the pipelines itself. For every other `model-type`,
  `model-fn` is mapped over the hyperparameter maps produced by
  `generate-hyperparams`. The resulting pipelines are passed to
  `evaluate-pipe` together with the splits."
  [dataset split-fn model-type model-fn]
  (let [splits     (split-fn dataset)
        candidates (if (= :best-subset model-type)
                     (model-fn dataset response regressors)
                     (map model-fn (generate-hyperparams model-type)))]
    (evaluate-pipe candidates splits)))

;; ### Split functions
(defn train-test
  "Splits `dataset` with the `:bootstrap` strategy, using seed 123 and
  20 repeats, and returns the resulting sequence of splits."
  [dataset]
  (let [split-opts {:seed 123 :repeats 20}]
    (ds/split->seq dataset :bootstrap split-opts)))

(defn train-val
  "Takes the `:train` part of the first split produced by `train-test` and
  splits it again with the `:kfold` strategy (seed 123, k = 5)."
  [dataset]
  (-> dataset
      train-test
      first
      :train
      (ds/split->seq :kfold {:seed 123 :k 5})))

;; ### Define model types and their corresponding pipeline functions as a map
(def model-type-fns
  "Map from model-type key to the function that builds its pipeline from a
  hyperparameter map. Keep `:gradient-tree` as the first entry: the code
  below takes `(first tree-models)` as the gradient-tree results."
  {:gradient-tree gradient-tree-pipe-fn
   :random-forest random-forest-pipe-fn})


;; ### Evaluate models for a dataset
(defn evaluate-models
  "Runs `evaluate-model` for every entry of `model-type-fns` on `dataset`,
  using `split-fn` to create the splits. Returns a vector with one
  evaluation result per entry, in the iteration order of `model-type-fns`."
  [dataset split-fn]
  (let [run-one (fn [[kind build-pipe]]
                  (evaluate-model dataset split-fn kind build-pipe))]
    (into [] (map run-one) model-type-fns)))

;; ### Evaluate separately
;; Evaluation results for every entry of `model-type-fns` on `concrete-data`,
;; using the `train-val` splits. The evaluation runs when this form is loaded.
(def tree-models (evaluate-models concrete-data train-val))

;; ## Extract Useable Models
(defn best-models
  "Flattens the nested evaluation results in `eval` and turns each result
  into a summary map, sorted by `:metric` in ascending order.

  Each summary map holds:
  - `:summary`         the thawed model from `[:fit-ctx :model]`
  - `:fit-ctx`, `:timing-fit`, `:pipe-fn` copied from the result
  - `:metric`, `:other-metrices` taken from `:test-transform`
  - `:other-metric-1`, `:other-metric-2` the `:metric` of the first and
    second entry of `:other-metrices`
  - `:params`          the `:options` of the fitted model"
  [eval]
  (letfn [(summarize [result]
            (let [fit-ctx (:fit-ctx result)
                  model   (:model fit-ctx)
                  scored  (:test-transform result)
                  extras  (:other-metrices scored)]
              (hash-map :summary (ml/thaw-model model)
                        :fit-ctx fit-ctx
                        :timing-fit (:timing-fit result)
                        :metric (:metric scored)
                        :other-metrices extras
                        :other-metric-1 (:metric (first extras))
                        :other-metric-2 (:metric (second extras))
                        :params (:options model)
                        :pipe-fn (:pipe-fn result))))]
    (sort-by :metric (map summarize (flatten eval)))))

(def best-val-gradient-tree
  "Summaries of the first entry of `tree-models`, ordered from highest to
  lowest `:metric` (the ascending output of `best-models`, reversed)."
  (reverse (best-models (first tree-models))))

;; Inspect the top-ranked entry of `best-val-gradient-tree`:
;; thawed model summary
(-> best-val-gradient-tree first :summary)
;; primary metric value
(-> best-val-gradient-tree first :metric)
;; additional metrics (MAE and RMSE as configured in `evaluate-pipe`)
(-> best-val-gradient-tree first :other-metrices)
;; model options used for this entry
(-> best-val-gradient-tree first :params)
