(ns tech.smile.kernels
  (:require [clojure.reflect :refer [reflect]]
            [camel-snake-kebab.core :refer [->kebab-case]]
            [tech.datatype :as dtype]
            [tech.smile.utils :as utils])
  (:import [smile.math.kernel MercerKernel]
           [java.lang.reflect Constructor]))


(def package-name
  "Java package prefix joined with a simple class name (see reflect-kernel)
  and handed to utils/construct when building a kernel instance."
  "smile.math.kernel")

(def kernel-class-names
  "Simple (unqualified) names of the kernel classes in package-name that are
  reflected over when `kernels` is built."
  (set [;; binary-sparse variants
        "BinarySparseGaussianKernel"
        "BinarySparseHyperbolicTangentKernel"
        "BinarySparseLaplacianKernel"
        "BinarySparseLinearKernel"
        "BinarySparsePolynomialKernel"
        "BinarySparseThinPlateSplineKernel"
        ;; sparse variants
        "SparseGaussianKernel"
        "SparseHyperbolicTangentKernel"
        "SparseLaplacianKernel"
        "SparseLinearKernel"
        "SparsePolynomialKernel"
        "SparseThinPlateSplineKernel"
        ;; remaining kernels
        "GaussianKernel"
        "HellingerKernel"
        "HyperbolicTangentKernel"
        "LaplacianKernel"
        "LinearKernel"
        "PearsonKernel"
        "PolynomialKernel"
        "ThinPlateSplineKernel"]))


(def kernel-metadata
  "Option specifications for each kernel type, keyed by the type keyword that
  `kernels` matches against kebab-cased class names. A value may carry an
  :options vector of {:name :type :range :default} maps; kernels that take no
  options map to {}.

  NOTE(review): these specs are consumed by tech.smile.utils/construct, which
  is not visible here. Presumably the option order follows the Java
  constructor argument order; confirm there."
  {:gaussian {:options [{:name :sigma
                         :type :float64
                         :range :>0
                         :default 2.0}]}

   :hyperbolic-tangent {:options [{:name :scale
                                   :type :float64
                                   :range :>0
                                   :default 1.0}
                                  {:name :offset
                                   :type :float64
                                   :default 0.0}]}

   :laplacian {:options [{:name :sigma
                          :type :float64
                          :range :>0
                          :default 2.0}]}

   :linear {}

   :polynomial {:options [{:name :degree
                           :type :int32
                           :range :>1
                           :default 2}
                          ;; Range was :>1, which excluded this option's own
                          ;; default of 1.0; use :>0 to match the
                          ;; hyperbolic-tangent :scale option.
                          {:name :scale
                           :type :float64
                           :range :>0
                           :default 1.0}
                          {:name :offset
                           :type :float64
                           :default 0.0}]}

   :thin-plate-spline {:options [{:name :sigma
                                  :type :float64
                                  :range :>0
                                  :default 2.0}]}

   :pearson {:options [{:name :omega
                        :type :float64
                        :default 1.0}
                       {:name :sigma
                        :type :float64
                        :default 1.0}]}
   :hellinger {}})


(defn- reflect-kernel
  "Load the class `kernel-name` from package-name via Class/forName and
  return clojure.reflect's data map describing it."
  [kernel-name]
  (let [fully-qualified (str package-name "." kernel-name)]
    (-> fully-qualified
        Class/forName
        reflect)))


(def sparsity-type-map
  "Translate a kernel's `k` parameter type (as a keyword built from the
  reflected type name) into a sparsity category."
  {:smile.math.SparseArray :sparse
   :int<>                  :sparse-binary
   :double<>               :dense})


(defn- sparsity-type
  "Classify a kernel from its clojure.reflect data.

  Collects the parameter types of every member named `k`, in order, and skips
  `java.lang.Object` (presumably an erased generic/bridge overload; confirm).
  Returns the sparsity-type-map entry for the first remaining type. The result
  is nil when no such type exists or when that first type has no mapping."
  [reflect-data]
  (let [param-kws (for [member (:members reflect-data)
                        :when (= "k" (name (:name member)))
                        param (:parameter-types member)
                        :let [param-kw (keyword (name param))]
                        :when (not= :java.lang.Object param-kw)]
                    param-kw)]
    ;; Looking up nil in the map yields nil, matching the empty case.
    (sparsity-type-map (first param-kws))))


(def kernels
  "Kernel descriptions discovered by reflecting over kernel-class-names when
  this namespace loads, grouped into a map from kernel type to a vector of
  {:name :class-name :type :sparsity-type} maps.

  :type is the first key of kernel-metadata whose name occurs inside the
  kebab-cased class name, or nil when none does."
  (let [kernel-type-of (fn [kernel-kw]
                         (let [kernel-str (name kernel-kw)]
                           (some #(when (.contains kernel-str (name %)) %)
                                 (keys kernel-metadata))))]
    (group-by :type
              (for [class-name kernel-class-names]
                (let [info (reflect-kernel class-name)
                      kernel-kw (keyword (->kebab-case class-name))]
                  {:name kernel-kw
                   :class-name class-name
                   :type (kernel-type-of kernel-kw)
                   :sparsity-type (sparsity-type info)})))))


(defmulti find-kernel
  "Resolve the kernel description for a kernel type and sparsity type.
  Dispatches on the pair [kernel-type sparsity-type], so individual
  combinations can be given their own defmethod."
  (fn [ktype stype]
    [ktype stype]))


;; Fallback lookup: return the first description under `kernel-type` in
;; `kernels` whose :sparsity-type equals `sparsity-type`. When none matches,
;; throw ex-info whose :available entry lists, for every kernel type, the
;; set of sparsity types that were discovered.
(defmethod find-kernel :default
  [kernel-type sparsity-type]
  (or (some (fn [kernel]
              (when (= sparsity-type (:sparsity-type kernel))
                kernel))
            (get kernels kernel-type))
      (throw (ex-info "Failed to find kernel"
                      {:kernel-type kernel-type
                       :sparsity-type sparsity-type
                       :available (map (fn [[ktype entries]]
                                         [ktype (set (map :sparsity-type entries))])
                                       kernels)}))))

(defn dense-kernels
  "Return a vector of every discovered kernel description whose
  :sparsity-type is :dense.

  `kernels` is a map of kernel type -> vector of descriptions (the result of
  group-by). Its groups are flattened before filtering, because iterating the
  map directly yields map entries, which have no :sparsity-type."
  []
  (->> (vals kernels)
       (apply concat)
       (filterv #(= :dense (:sparsity-type %)))))

(defn sparse-binary-kernels
  "Return a vector of every discovered kernel description whose
  :sparsity-type is :sparse-binary.

  `kernels` is a map of kernel type -> vector of descriptions (the result of
  group-by). Its groups are flattened before filtering, because iterating the
  map directly yields map entries, which have no :sparsity-type."
  []
  (->> (vals kernels)
       (apply concat)
       (filterv #(= :sparse-binary (:sparsity-type %)))))

(defn sparse-kernels
  "Return a vector of every discovered kernel description whose
  :sparsity-type is :sparse.

  `kernels` is a map of kernel type -> vector of descriptions (the result of
  group-by). Its groups are flattened before filtering, because iterating the
  map directly yields map entries, which have no :sparsity-type."
  []
  (->> (vals kernels)
       (apply concat)
       (filterv #(= :sparse (:sparsity-type %)))))


(defn construct
  "Build a MercerKernel for `kernel-type` / `sparsity-type`.

  Merges the kernel-metadata entry for `kernel-type` with the description
  returned by find-kernel, then delegates to utils/construct together with
  package-name and the caller's `options`."
  ^MercerKernel [kernel-type sparsity-type options]
  (let [option-specs (get kernel-metadata kernel-type)
        kernel-info  (find-kernel kernel-type sparsity-type)]
    (utils/construct (merge option-specs kernel-info)
                     package-name
                     options)))


;; Options of type :mercer-kernel are always reported as the MercerKernel
;; interface, whatever arguments the dispatch passes in.
(defmethod utils/option->class-type :mercer-kernel
  [& _ignored]
  MercerKernel)


;; Turn a :mercer-kernel option value into a kernel instance by calling
;; `construct`. The value's :kernel-type picks the kernel, :sparsity-type
;; falls back to :dense when absent or falsey, and the whole option value is
;; forwarded as the construction options.
(defmethod utils/option-value->value :mercer-kernel
  [_option option-value]
  (let [kernel-type (:kernel-type option-value)
        sparsity    (or (:sparsity-type option-value) :dense)]
    (construct kernel-type sparsity option-value)))
