(ns scicloj.ml.try
  (:require [clojure.java.io :as io]
            [taoensso.nippy :as nippy]
            [scicloj.ml.core :as ml]
            [scicloj.ml.metamorph :as mm]
            [tablecloth.api :as tc]
            [tech.v3.dataset :as ds]))

(def numerical-cols
  "Names of the four numeric feature columns; these are the columns handed
  to the min-max scaling step of the pipeline."
  ["sepal_length" "sepal_width" "petal_length" "petal_width"])

(def target
  "Name of the column used as the inference target."
  "species")
;(def labels ["Iris-versicolor" "Iris-setosa" "Iris-virginica"])

(def dataset
  "Dataset built by `ds/->dataset` from the iris CSV hosted in the
  techascent/tech.ml repository."
  (ds/->dataset
   "https://raw.githubusercontent.com/techascent/tech.ml/master/test/data/iris.csv"))


(def splits
  "First element of the sequence returned by `tc/split->seq` for a :holdout
  split of `dataset` (seed 123, ratio 0.7/0.3, parts named :train and :test)."
  (let [holdout-opts {:seed        123
                      :ratio       [0.7 0.3]
                      :split-names [:train :test]}]
    (first (tc/split->seq dataset :holdout holdout-opts))))

(def train-data
  "The :train part of `splits`."
  (get splits :train))

(def test-data
  "The :test part of `splits`."
  (get splits :test))

(def model-type
  "Model type keyword passed to the model step of the pipeline."
  :smile.classification/logistic-regression)

;; Builds a pipeline via `ml/pipeline` from four steps, in this order:
;;   1. min-max scaling of `numerical-cols`
;;   2. categorical->number on the `target` column
;;   3. marking `target` as the inference target
;;   4. a model step configured with `model-type`
;; Takes no arguments; every call invokes `ml/pipeline` again.
(defn make-pipeline []
  (ml/pipeline
   ;; NOTE(review): this id map presumably names the step that follows it;
   ;; the hand-built context later in this file uses a :min-max-scale key.
   ;; Confirm against the metamorph docs.
   {:metamorph/id :min-max-scale}
   (mm/min-max-scale numerical-cols {})
   (mm/categorical->number [target])
   (mm/set-inference-target target)
   ;; Same idea for the model step: the trained context is read under :model.
   {:metamorph/id :model}
   (mm/model {:model-type model-type})))

(def pipeline
  "Pipeline instance from `make-pipeline`, used by `train` and `predict`."
  (make-pipeline))

(defn train
  "Runs `pipeline` in :fit mode on `train-data` and returns whatever the
  pipeline returns (the fitted context)."
  []
  (let [fit-ctx {:metamorph/data train-data
                 :metamorph/mode :fit}]
    (pipeline fit-ctx)))
      ;(dissoc :metamorph/data)
      ;(update-in [:model :model-data] dissoc :smile-df-used)))

(defn predict
  "Runs `pipeline` in :transform mode on `dataset`, starting from context `ctx`.

  Before the run, a `target` column filled with nil (cycled) is added to
  `dataset`. After the run, the `target` column is dropped from the
  resulting :metamorph/data, which is then returned."
  [dataset ctx]
  (let [input         (tc/add-column dataset target [nil] :cycle)
        transform-ctx (assoc ctx
                             :metamorph/data input
                             :metamorph/mode :transform)
        result-ctx    (pipeline transform-ctx)]
    (ds/drop-columns (:metamorph/data result-ctx) [target])))

(def trained-ctx
  "Context returned by `train`."
  (train))

(def bytes-of-model
  "Value stored at [:model :model-data :model-as-bytes] in `trained-ctx`."
  (get-in trained-ctx [:model :model-data :model-as-bytes]))
      
; Persist 'bytes-of-model' at this point so it can be reused later.

;; RUN THE FORMS ABOVE FIRST



;; THEN RESTART REPL BEFORE RUNNING THE NEXT FORM

(def predict-pipeline
  "A second pipeline instance from `make-pipeline`, separate from `pipeline`,
  used for the prediction run against a hand-built context."
  (make-pipeline))


(def prediction
  "Result of calling `predict-pipeline` in :transform mode on `train-data`
  (with a nil-filled, cycled `target` column added), using a context that is
  written out by hand here rather than taken from a fit run."
  (let [;; Entry stored under the :min-max-scale step id.
        scaler-state {:fit-minmax-xform
                      {:min -0.5
                       :max 0.5
                       :column-data
                       {"sepal_length" {:min 4.3 :max 7.9}
                        "sepal_width"  {:min 2.2 :max 4.4}
                        "petal_length" {:min 1.0 :max 6.9}
                        "petal_width"  {:min 0.1 :max 2.5}}}}
        ;; Entry stored under the :model step id; carries the serialized model.
        model-state  {:model-data      {:model-as-bytes bytes-of-model}
                      :options         {:model-type :smile.classification/logistic-regression}
                      :feature-columns ["sepal_length" "sepal_width" "petal_length" "petal_width"]
                      :target-columns  ["species"]
                      :target-categorical-maps
                      {"species"
                       {:lookup-table    {"versicolor" 0 "setosa" 1 "virginica" 2}
                        :src-column      "species"
                        :result-datatype :float64}}}
        input        (tc/add-column train-data target [nil] :cycle)]
    (predict-pipeline
     {:metamorph/mode :transform
      :metamorph/data input
      :min-max-scale  scaler-state
      :model          model-state})))

;; Count how often each value occurs in the "species" column of the prediction data.
(frequencies (get (:metamorph/data prediction) "species"))
