(ns liblinear.core
  (:use [clj-utils.io :only (with-temp-file)])
  (:import (java.io File))
  (:import (de.bwaldvogel.liblinear Linear))
  (:import (de.bwaldvogel.liblinear Model))
  (:import (de.bwaldvogel.liblinear Parameter))
  (:import (de.bwaldvogel.liblinear Problem))
  (:import (de.bwaldvogel.liblinear SolverType))
  (:import (de.bwaldvogel.liblinear FeatureNode)))

(defn prepared-fv
  "Prepares the feature vector `fv` (a seq of [index value] pairs) by doing
  two things:
  (1) feature indices are only accepted starting from 1, so every index is
      incremented by one;
  (2) only sorted feature vectors are accepted, so the pairs are sorted by
      index.
  Returns a vector of [incremented-index value] pairs in ascending index order."
  [fv]
  (vec (for [[idx value] (sort-by first fv)]
         [(inc idx) value])))

(defn examples2libliner-format-str
  "Converts a vector shaped like the following into a liblinear-format string:
  [[label1 [[k1 v1] [k2 v2]]]
   [label2 [[k1 v3] [k3 v4]]]]
  Each example becomes one line made of the label, a space, and the
  space-separated k:v pairs sorted by k. Lines are joined with a newline and
  no trailing newline is added."
  [examples]
  (letfn [(pair->str [pair]
            (apply str (interpose ":" pair)))
          (example->line [[label fv]]
            (let [cells (map pair->str (sort-by first fv))]
              (str label " " (apply str (interpose " " cells)))))]
    (apply str (interpose "\n" (map example->line examples)))))

(defn make-problem
  "Builds a liblinear Problem from `training-examples`, a seq of
  [label feature-vector] pairs.
  Every feature vector is passed through `prepared-fv`, the examples are
  rendered with `examples2libliner-format-str`, the text is written to the
  temp file bound by `with-temp-file`, and that file is read back with
  Problem/readFromFile (second argument -1).
  NOTE(review): -1 is presumably the bias argument, i.e. no bias term --
  confirm against the liblinear-java API."
  [training-examples]
  (with-temp-file [tmp-file-obj]
    (let [shifted (mapv (fn [[label fv]] [label (prepared-fv fv)])
                        training-examples)]
      (spit tmp-file-obj (examples2libliner-format-str shifted)))
    (Problem/readFromFile tmp-file-obj -1)))

(defn ^Model make-SVM
  "Trains a model with Linear/train on the Problem built from
  `training-examples` (see `make-problem`) using `param`, and returns the
  resulting Model. liblinear debug output is disabled before training."
  [^Parameter param training-examples]
  (Linear/disableDebugOutput)
  (let [problem (make-problem training-examples)]
    (Linear/train problem param)))

(defn do-cross-validation
  "Runs k-fold cross-validation and returns a vector of the predicted values
  for `training-examples`. Intended for uses such as hyperparameter
  optimization.
  The result has one entry per training example: the target array is
  allocated with (count training-examples) slots and filled in by
  Linear/crossValidation."
  [^Parameter param training-examples k]
  (let [problem     (make-problem training-examples)
        n-examples  (count training-examples)
        predictions (double-array n-examples)]
    (Linear/disableDebugOutput)
    (Linear/crossValidation problem param k predictions)
    (vec predictions)))

(defn predict
  "Function to use when the predicted value itself (a real number) is wanted,
  e.g. with SVR.
  `fv` is a seq of [index value] pairs; it is normalized with `prepared-fv`,
  turned into an array of FeatureNode, and handed to Linear/predict, whose
  result is returned unchanged."
  [^Model model fv]
  (let [nodes    (for [[idx value] (prepared-fv fv)]
                   (FeatureNode. idx value))
        features (into-array FeatureNode nodes)]
    (Linear/predict model features)))

(defn fast-predict
  "Variant of `predict` that reduces overhead in exchange for the caller
  preparing `fv` itself: `fv` is passed to Linear/predict as is, with no
  index shifting, sorting, or conversion done here.
  NOTE(review): `fv` is presumably an array of FeatureNode like the one
  `predict` builds -- confirm against callers."
  [^Model model fv]
  (Linear/predict model fv))

(defn predict-probability
  "Function to use when the predicted values themselves (real numbers) are
  wanted, as with SVR or logistic regression.
  `fv` is a seq of [index value] pairs; it is normalized with `prepared-fv`
  and converted to an array of FeatureNode. Linear/predictProbability fills a
  double array with one slot per class of the model (.getNrClass), and that
  array is returned as a vector."
  [^Model model fv]
  (let [probs    (double-array (.getNrClass model))
        nodes    (for [[idx value] (prepared-fv fv)]
                   (FeatureNode. idx value))
        features (into-array FeatureNode nodes)]
    (Linear/predictProbability model features probs)
    (vec probs)))

(defn fast-predict-probability
  "Variant of `predict-probability` that reduces overhead in exchange for the
  caller preparing `fv` itself: `fv` is passed to Linear/predictProbability
  as is. Returns a vector with one entry per class of the model (.getNrClass).
  NOTE(review): `fv` is presumably an array of FeatureNode like the one
  `predict-probability` builds -- confirm against callers."
  [^Model model fv]
  (let [n-classes (.getNrClass model)
        probs     (double-array n-classes)]
    (Linear/predictProbability model fv probs)
    (vec probs)))

(defn predict-class-probability
  "Returns the probability that `predict-probability` assigns to class label
  `c` for feature vector `fv`, or nil when `c` is not among the model's labels.
  The labels from (.getLabels model) are paired positionally with the
  probability vector, and the first pair whose label equals `c` wins."
  [^Model model fv c]
  (let [probs  (predict-probability model fv)
        labels (vec (.getLabels model))]
    ;; A probability is a double and therefore always truthy, so `some`
    ;; yields the probability of the first matching label.
    (some (fn [[label prob]] (when (= label c) prob))
          (map vector labels probs))))

(defn fast-predict-class-probability
  "Returns the probability that `fast-predict-probability` assigns to class
  label `c` for the caller-prepared `fv`, or nil when `c` is not among the
  model's labels.
  The labels from (.getLabels model) are paired positionally with the
  probability vector, and the first pair whose label equals `c` wins."
  [^Model model fv c]
  (let [probs  (fast-predict-probability model fv)
        labels (vec (.getLabels model))]
    ;; A probability is a double and therefore always truthy, so `some`
    ;; yields the probability of the first matching label.
    (some (fn [[label prob]] (when (= label c) prob))
          (map vector labels probs))))

(defn classify
  "Runs `predict` on `fv` (a seq of [index value] pairs) and coerces the
  result to an int, which is returned as the class label."
  [^Model model fv]
  (let [raw (predict model fv)]
    (int raw)))

(defn fast-classify
  "Runs `fast-predict` on the caller-prepared `fv` and coerces the result to
  an int, which is returned as the class label."
  [^Model model fv]
  (let [raw (fast-predict model fv)]
    (int raw)))

(defn save-model
  "Saves `model` to the file at path `filename` by calling the model's own
  .save method with a java.io.File built from `filename`."
  [^Model model filename]
  (let [target (File. filename)]
    (.save model target)))

(defn ^Model load-model
  "Loads and returns the Model stored in the file at path `filename`,
  using Model/load on a java.io.File built from `filename`."
  [filename]
  (let [source (File. filename)]
    (Model/load source)))



;; REPL scratch: trains a tiny L2R_LR model on six hand-written examples and
;; inspects a `predict-probability` call form.
;;
;; This used to be a bare top-level `let`, so merely loading (or
;; AOT-compiling) the namespace trained a model, wrote a temp file, allocated
;; an unused array and then threw the result away. It is wrapped in `comment`
;; so it only runs when evaluated by hand.
(comment
  (let [model (make-SVM
               ;; Other solvers tried here: L2R_LR_DUAL,
               ;; SolverType/L2R_L2LOSS_SVC_DUAL.
               ;; Other numeric arguments tried: 0.01 0.1.
               (new Parameter SolverType/L2R_LR 0.1 0.1)
               [[1 [[3 -1] [20 1]]]
                [3 [[19 -1] [28 1]]]
                [4 [[13 -1] [23 1]]]
                [6 [[12 -1] [27 1]]]
                [-2 [[1 -1] [2 1] [300 1] [400 1]]]
                [-2 [[1 -1] [2 1] [300 1] [400 1]]]])
        fv [[1 -1] [2 1] [30 1] [40 1]]]
    ;; `predict-probability` is a function, not a macro, so macroexpand-1
    ;; returns the quoted form unchanged.
    [(macroexpand-1 '(predict-probability model fv))
     ;; (predict-class-probability model fv -200)
     ]))




(comment
  ;; REPL scratch, never evaluated on load: trains a small L2R_LR model,
  ;; round-trips it through save-model / load-model, then collects the
  ;; probability output obtained directly from liblinear and through the
  ;; wrapper functions in this namespace.
  (let [model (make-SVM
               ;; Other solvers tried here: L2R_LR_DUAL,
               ;; SolverType/L2R_L2LOSS_SVC_DUAL.
               ;; Other numeric arguments tried: 0.01 0.1.
               (new Parameter SolverType/L2R_LR 0.1 0.1)
               [[1 [[3 -1] [20 1]]]
                [3 [[19 -1] [28 1]]]
                [4 [[13 -1] [23 1]]]
                [6 [[12 -1] [27 1]]]
                [-2 [[1 -1] [2 1] [300 1] [400 1]]]
                [-2 [[1 -1] [2 1] [300 1] [400 1]]]])
        ;; Output buffer filled in place by Linear/predictProbability below.
        d (double-array (.getNrClass model))
        ;; Alternative input tried: [[12 -1] [27 1]]
        fv [[1 -1] [2 1] [30 1] [40 1]]
        _ (save-model model "/tmp/hoge.bin")
        _ (println (load-model "/tmp/hoge.bin"))
        ;; Same conversion that `predict-probability` performs internally.
        nodes (->> (prepared-fv fv)
                   (mapv (fn [[k v]] (new FeatureNode k v)))
                   (into-array FeatureNode))]
    ;; Alternatives tried for the first call: Linear/predictValues,
    ;; Linear/predict.
    [(Linear/predictProbability model nodes d)
     ;; (predict model fv)
     ;; model
     (vec d)
     (vec (. model getLabels))
     (predict-probability model fv)
     (predict-class-probability model fv -200)])

  )