(ns fastmath.optimization.bo
  (:require [fastmath.gp :as gp]
            [fastmath.random :as r]
            [fastmath.kernel :as k]
            [fastmath.core :as m]
            [fastmath.optimization :as opt]))

;; Compiler settings for this namespace: warn on reflective Java interop, and
;; use unchecked math while warning whenever arithmetic falls back to boxed numbers.
(set! *warn-on-reflection* true)
(set! *unchecked-math* :warn-on-boxed)
;; Replace the clojure.core arithmetic operators in this namespace with
;; fastmath's primitive-math versions (see fastmath.core docs to confirm details).
(m/use-primitive-operators)

(defmulti ^{:private true} utility-function
  "Builds an acquisition (utility) function from an options map.
  Dispatches on the map's `:method` key; the returned value is a variadic
  function taking the coordinates of a point."
  :method)

;; Any unrecognized (or missing) `:method` is treated as `:ucb`.
(defmethod utility-function :default [params]
  (-> params
      (assoc :method :ucb)
      utility-function))

;; Upper confidence bound acquisition.
;; Options: `:gp` (the Gaussian process) and `:kappa` (exploration weight,
;; default 2.576). Returns a variadic fn of point coordinates that computes
;; mean + kappa * stddev, where `gp/predict` (called with `true`) is
;; destructured as [mean stddev].
(defmethod utility-function :ucb
  [{^double weight :kappa, model :gp, :or {weight 2.576}}]
  (fn [& point]
    (let [[^double mu ^double sigma] (gp/predict model point true)
          exploration (* weight sigma)]
      (+ mu exploration))))

;; Expected improvement acquisition.
;; Options: `:gp`, `:y-max` (best observed value) and `:xi` (margin added to
;; `:y-max`, default 0.001). For a point with predicted mean mu and stddev
;; sigma, with d = mu - (y-max + xi) and z = d / sigma, returns
;; d * Phi(z) + sigma * phi(z), where Phi/phi are the standard normal cdf/pdf.
(defmethod utility-function :ei
  [{^double best :y-max, model :gp, ^double margin :xi, :or {margin 0.001}}]
  (let [threshold (+ best margin)]
    (fn [& point]
      (let [[^double mu ^double sigma] (gp/predict model point true)
            improvement (- mu threshold)
            z (/ improvement sigma)
            big-phi (double (r/cdf r/default-normal z))
            small-phi (double (r/pdf r/default-normal z))]
        (+ (* improvement big-phi)
           (* sigma small-phi))))))

;; Probability of improvement acquisition.
;; Options: `:gp`, `:y-max` (best observed value) and `:xi` (margin, default
;; 0.001). Returns a variadic fn of point coordinates computing
;; Phi((mu - (y-max + xi)) / sigma) with Phi the standard normal cdf.
(defmethod utility-function :poi
  [{^double best :y-max, model :gp, ^double margin :xi, :or {margin 0.001}}]
  (let [threshold (+ best margin)]
    (fn [& point]
      (let [[^double mu ^double sigma] (gp/predict model point true)
            z (/ (- mu threshold) sigma)]
        (r/cdf r/default-normal z)))))

;; `:pi` is an alias for `:poi` (probability of improvement).
(defmethod utility-function :pi [params]
  (-> params
      (assoc :method :poi)
      utility-function))

;; Mixed acquisition: PI(x) + kappa * EI(x).
;; `:kappa` (default 1.0) weights the expected-improvement term; every other
;; option (`:gp`, `:y-max`, `:xi`) is forwarded unchanged to `:ei` and `:poi`.
(defmethod utility-function :ei-pi
  [{^double weight :kappa, :or {weight 1.0}, :as params}]
  (let [ei-fn (utility-function (assoc params :method :ei))
        poi-fn (utility-function (assoc params :method :poi))]
    (fn [& point]
      (let [^double expected (apply ei-fn point)
            ^double probability (apply poi-fn point)]
        (+ probability (* weight expected))))))

;; Randomised GP-UCB (RGP-UCB), see https://www.ijcai.org/Proceedings/2020/0316.pdf
;;
;; Instead of a fixed exploration weight, the UCB weight is drawn at random:
;;   shape = max(0.01, log((t^2 + 1) / sqrt(2*pi)) / log(1 + scale))
;;   kappa = sqrt(X),  X ~ Gamma(shape, scale)
;; and the result is the `:ucb` utility function built with that `kappa`.
;; A new `kappa` is drawn each time this method is invoked; the returned
;; function itself is deterministic.
;;
;; Options: `:scale` (Gamma scale, default 1.0) and `:t` (iteration number,
;; default 3); all other keys (e.g. `:gp`) are passed through to `:ucb`.
;; NOTE(review): the paper's shape denominator may be log(1 + scale/2) rather
;; than log(1 + scale) — confirm against the reference.
(defmethod utility-function :rgp-ucb
  [{:keys [^double scale ^long t] :as m
    :or {scale 1.0 t 3}}]
  (let [shape (max 0.01 (/ (m/log (/ (inc (* t t)) m/SQRT2PI))
                           (m/log (inc scale))))
        ;; `:scale` must be a keyword: with the bare symbol `scale` the map
        ;; entry became {1.0 1.0} and the requested scale never reached the
        ;; distribution.
        gamma (r/distribution :gamma {:shape shape :scale scale})
        kappa (m/sqrt (r/drandom gamma))]
    (utility-function (assoc m :kappa kappa :method :ucb))))

;; REPL example: build an RGP-UCB acquisition function. Kept inside `comment`
;; so it is not evaluated (and does not draw a random number) whenever the
;; namespace is loaded.
(comment
  (utility-function {:method :rgp-ucb
                     :scale 1.0
                     :t 20}))


;; Example Gaussian process built from five 2-D points and their targets,
;; with normalization and a `:mattern-52` kernel; used for REPL experiments.
(def gp
  (let [points [[1 1] [2 2] [3 3] [9 -1] [5 -2]]
        targets [-2 -1 1 -1 2]
        options {:noise 0.001
                 :normalize? true
                 :kernel (k/kernel :mattern-52 1.48796)
                 :kscale 0.87396}]
    (gp/gaussian-process points targets options)))
;; REPL experiments. Kept inside `comment` so that loading the namespace does
;; not evaluate them. Otherwise every require would run the optimizations
;; below, and the `time` form would print the elapsed time to stdout.
(comment
  (gp/predict gp '(4.62100181160906 -3.5))

  ;; Maximize the UCB acquisition of the example `gp` over [0, 10] x [-5, 5].
  (opt/scan-and-maximize :bfgs (utility-function {:method :ucb
                                                  :gp gp})
                         {:bounds [[0 10] [-5 5]]})
  ;; => [(4.862100181160906 -2.349996775981009) 0.557154163177696]
  ;; => [(4.6001465276639495 -3.531503947537969) 4.333240536082748]
  ;; => [(4.981048745997669 -2.0285429857231874) 0.4899829648551593]
  ;; => [(4.750949198693047 -2.818882138909999) 0.19133431735709724]
  ;; => [(4.567607972440291 -3.6893979394097896) 4.945040161917608]
  ;; => [(4.600145774519866 -3.5315036892286065) 4.333240536082746]

  ;; Search the `:mattern-52` kernel parameter `k` and the `:kscale` `s` that
  ;; maximize `gp/L` (presumably the log-likelihood — TODO confirm) of a GP
  ;; on the same training data.
  ;; NOTE(review): `:initial` has one value while `:bounds` describes two
  ;; dimensions — confirm this is intended.
  (time (opt/scan-and-maximize :bfgs
                               (fn [^double k ^double s]
                                 (let [gp (gp/gaussian-process [[1 1] [2 2] [3 3] [9 -1] [5 -2]] [-2 -1 1 -1 2]
                                                               {:noise 0.0001
                                                                :normalize? true
                                                                :kernel (k/kernel :mattern-52 k)
                                                                :kscale s})]
                                   (gp/L gp)))
                               {:bounds [[0.001 20] [0.001 20]]
                                :initial [3.0]}))

  (gp/L gp)

  (gp/predict gp [2 22] true)
  (gp/predict-all gp [[1 1] [2 2]] true))

