(ns iris
  (:require [clojure.java.io :as io]
            [taoensso.nippy :as nippy]
            [scicloj.ml.core :as ml]
            [scicloj.ml.metamorph :as mm]
            [tablecloth.api :as tc]
            [tech.v3.dataset :as ds]))

(def numerical-cols
  "Names of the four numeric feature columns; `pipeline` min-max scales these."
  ["sepal_length" "sepal_width" "petal_length" "petal_width"])

(def target
  "Name of the column the pipeline encodes and uses as its inference target."
  "label")
(def labels
  "Iris class names (presumably the distinct values of the target column).
  NOTE(review): not referenced by the code visible here — confirm it is still needed."
  ["Iris-versicolor" "Iris-setosa" "Iris-virginica"])

(def dataset
  "The full Iris dataset, loaded from the relative path \"iris.csv\"."
  (ds/->dataset "iris.csv"))


(def splits
  "One seeded holdout split of `dataset`: a map with :train (ratio 0.7)
  and :test (ratio 0.3) datasets."
  (first
   (tc/split->seq dataset
                  :holdout
                  {:seed        123
                   :ratio       [0.7 0.3]
                   :split-names [:train :test]})))

(def train-data
  "Rows of the holdout split used to fit the pipeline."
  (get splits :train))
(def test-data
  "Held-out rows of the split, not seen during fitting."
  (get splits :test))

(def model-type
  "Model key handed to `mm/model`: Smile's logistic-regression classifier."
  :smile.classification/logistic-regression)

(def pipeline
  "metamorph pipeline function. Call it with a context map holding
  :metamorph/data and :metamorph/mode (:fit or :transform). The model step's
  fitted state is stored in the context under :model."
  (ml/pipeline
   ;; Min-max scale every numeric feature column.
   (mm/min-max-scale numerical-cols {})
   ;; Encode the string class labels as numbers for the classifier.
   (mm/categorical->number [target])
   (mm/set-inference-target target)
   ;; Fixed id for the following model step, so callers can find it under :model.
   {:metamorph/id :model}
   (mm/model {:model-type model-type})))

(defn train
  "Fits `pipeline` on `train-data` and returns the fitted context.
  The training data itself and the model's :smile-df-used entry are removed
  (presumably to keep the context small for nippy serialization — confirm)."
  []
  (let [fitted-ctx (pipeline {:metamorph/data train-data
                              :metamorph/mode :fit})]
    (update-in (dissoc fitted-ctx :metamorph/data)
               [:model :model-data]
               dissoc
               :smile-df-used)))

(defn predict
  "Runs the fitted context `ctx` (as returned by `train` or `load-ctx`) over
  `dataset` in :transform mode and returns the resulting dataset with the
  target column dropped.
  A placeholder target column of nils is added first, presumably because the
  pipeline steps expect that column to exist — confirm."
  [dataset ctx]
  (let [input      (tc/add-column dataset target [nil] :cycle)
        result-ctx (pipeline (assoc ctx
                                    :metamorph/data input
                                    :metamorph/mode :transform))]
    (ds/drop-columns (:metamorph/data result-ctx) [target])))

(defn load-ctx
  "Reads a context previously saved with `nippy/freeze-to-file` from `path`."
  [path]
  ;; NOTE(review): thawing is deserialization; only load files this project wrote.
  (nippy/thaw-from-file (io/file path)))


;; RUN THIS FIRST: fits the pipeline and saves the fitted context to iris.nippy.

(comment
  ;; Fit the pipeline on the training split.
  (def train-ctx (train))
  ;; Top-level keys of the fitted context; example output below
  ;; (the uuid is a pipeline step id and may differ between sessions).
  (keys train-ctx)
  ;(:metamorph/mode #uuid "1f84a66a-351e-4884-b28a-5f032bb3136b" :model)
  ;; Persist the fitted context so it can be reloaded with `load-ctx`.
  (nippy/freeze-to-file "iris.nippy" train-ctx))


(comment
  ;; Reload the fitted context written to iris.nippy.
  (def train-ctx (load-ctx "iris.nippy"))
  (keys train-ctx)
  ;(:metamorph/mode #uuid "1f84a66a-351e-4884-b28a-5f032bb3136b" :model)

  ;; Option 1: build the :transform context by hand and run the pipeline on it.
  (def ctx-for-predict
    (merge train-ctx
           {:metamorph/data test-data
            :metamorph/mode :transform}))
  (keys ctx-for-predict)
  (:metamorph/data (pipeline ctx-for-predict))

  ;; Option 2: let scicloj.ml build that context. `transform-pipe` takes the
  ;; pipeline function itself (not the result of calling it), then the data
  ;; and the fitted context.
  (ml/transform-pipe test-data pipeline train-ctx))
  
