(ns ncm
  (:require
    [clojure.string :as str]
    [clojure.java.io :refer [reader]]
    [clojure.core.async :refer [go thread]]
    [clojure.pprint :refer [pprint]]
    [clojure.stacktrace]
    [clj-time.local  :as l]
    [prism.util :as util]
    [prism.sampling :refer [uniform->cum-uniform samples]]
    [prism.nn.rnn :as rnn]
    [prism.nn.encoder-decoder :as ed]))

(defn convert-rare-word-to-unk
  "Returns `word` when the word-count map `wc` holds a truthy value for it,
  otherwise the placeholder \"<unk>\"."
  [wc word]
  (cond
    (get wc word) word
    :else "<unk>"))

(defn word->feature
  "Looks `word` up in the embedding map `em`. When the word is absent (or
  mapped to nil/false), falls back to the sparse feature #{\"<unk>\"}."
  [em word]
  (if-let [feature (get em word)]
    feature
    #{"<unk>"}))

(defn cnvs->utt-pairs
  "Turns one conversation line into training pairs.
  The line is split on \"<eos>\" into utterances. Each utterance is split on
  single spaces, with empty tokens and tokens containing a space or a
  full-width space dropped. Words for which `wc` has no truthy entry become
  \"<unk>\". Every utterance is then paired with the one that follows it.
  Returns a vector of {:encoder-x utt, :decoder-x (\"<go>\" next-utt...),
  :decoder-y (next-utt... \"<eos>\")} maps, or :skip when the line holds
  fewer than two utterances."
  [wc cnvs-line]
  (let [tokenize (fn [utt]
                   (->> (str/split utt #" ")
                        (remove #(or (re-find #" |　" %) (= "" %)))
                        (map #(convert-rare-word-to-unk wc %))))
        utterances (mapv tokenize (str/split cnvs-line #"<eos>"))]
    (if (<= (count utterances) 1)
      :skip
      ;; consecutive (utterance, reply) windows
      (mapv (fn [[source target]]
              {:encoder-x (vec source)
               :decoder-x (vec (cons "<go>" target))
               :decoder-y (conj (vec target) "<eos>")})
            (partition 2 1 utterances)))))

(defn add-negatives
  "Replaces each target word in the `:decoder-y` of `utt-pair` with a map
  {:pos #{word} :neg #{...}}. The i-th target receives the `negative`
  consecutive items of `negatives` starting at index i * negative, as a set.
  Throws an Exception when `negatives` does not hold exactly
  (count decoder-y) * negative items."
  [utt-pair negative negatives]
  (let [targets (:decoder-y utt-pair)
        negs (vec negatives)]
    (if-not (= (count negs) (* negative (count targets)))
      (throw (Exception. "Invalid negative count"))
      (assoc utt-pair :decoder-y
             (mapv (fn [i target]
                     (let [start (* i negative)]
                       {:pos #{target}
                        :neg (set (subvec negs start (+ start negative)))}))
                   (range)
                   targets)))))

(defn- train-utt-pair!
  "Runs one ed/forward -> ed/bptt -> ed/update-model! step of `model` on a
  single utterance pair (as produced by cnvs->utt-pairs). It adds the pair's
  mean per-step absolute loss to the `tmp-loss` atom. On any Exception it
  prints the offending `line`, the pair and the stack trace, then sleeps
  60 s (debug aid) and returns, so the calling worker keeps going."
  [model em learning-rate tmp-loss line utt-pair]
  (let [{:keys [encoder-x decoder-x decoder-y]} utt-pair]
    (try
      (let [encoder-x (mapv #(word->feature em %) encoder-x)
            ;; decoder input: the <go> marker followed by the embedded reply words
            decoder-x (vec (cons #{"<go>"} (mapv #(word->feature em %) (rest decoder-x))))
            forward (ed/forward model encoder-x decoder-x decoder-y)
            {:keys [param-loss loss]} (ed/bptt model forward decoder-y)
            ;; per non-empty step: sum of |loss| over the target and its negatives
            loss-seq (->> loss
                          (remove empty?)
                          (mapv #(->> % (map (fn [[_ v]] (Math/abs v))) (apply +))))]
        ;; when every step is empty there is no loss to average; (/ 0 0) would throw
        (when (seq loss-seq)
          (swap! tmp-loss + (/ (reduce + loss-seq) (count loss-seq)))) ;; loss per rnn-step
        (ed/update-model! model param-loss learning-rate))
      (catch Exception e
        ;; debug purpose
        (println "error has occured")
        (println "line\n" line)
        (println "words")
        (pprint utt-pair)
        (clojure.stacktrace/print-stack-trace e)
        (Thread/sleep 60000)))))

(defn train-ncm!
  "Trains the encoder-decoder `model` on the conversation file at `train-path`
  and returns the same `model` object. Updates are applied through
  ed/update-model!. Each line is turned into utterance pairs by
  cnvs->utt-pairs (utterances separated by \"<eos>\").

  `option` keys (all optional):
    :interval-ms            progress-report period in ms (default 60000)
    :workers                number of worker threads sharing the reader (default 4)
    :initial-learning-rate  learning rate passed to ed/update-model! (default 0.01)
    :skip-lines             lines skipped at the start of the file (default 0)
    :snapshot               save a snapshot every N report intervals, 0 disables (default 60)
    :model-path             snapshot path prefix; no snapshots when nil
  :negative and :min-learning-rate are currently ignored (learning-rate decay
  is disabled).

  The call returns only after every worker has finished, so the reader stays
  open and the model is not handed back while lines are still being trained."
  [model train-path & [option]]
  (let [{:keys [interval-ms workers initial-learning-rate skip-lines snapshot model-path]
         :or {interval-ms 60000 ;; 1 minute
              workers 4
              initial-learning-rate 0.01
              skip-lines 0
              snapshot 60 ;; 1 hour when interval-ms is 60000
              }} option
        all-lines-num (with-open [r (reader train-path)] (count (line-seq r)))
        {:keys [wc em]} model
        ;; constant: the linear decay toward :min-learning-rate is disabled
        learning-rate initial-learning-rate
        tmp-loss (atom 0)
        interval-counter (atom 0)
        progress-counter (atom 0)
        snapshot-num (atom 1)
        finished-workers (atom 0)]
    (with-open [r (reader train-path)]
      (print (str "skipping " skip-lines " lines ..."))
      (loop [skip skip-lines]
        (when (> skip 0)
          (.readLine r)
          (swap! progress-counter inc)
          (recur (dec skip))))
      (println "done")
      (dotimes [_ workers]
        ;; real threads instead of go blocks: the work is blocking I/O and CPU-bound
        (thread
          (try
            (loop []
              (when-let [line (.readLine r)]
                (let [utt-pairs (cnvs->utt-pairs wc line)]
                  (swap! interval-counter inc)
                  (swap! progress-counter inc)
                  (when-not (= :skip utt-pairs)
                    (doseq [utt-pair utt-pairs]
                      (train-utt-pair! model em learning-rate tmp-loss line utt-pair))))
                (recur)))
            ;; count every exit (EOF or an escaped exception) so the monitor loop always ends
            (finally
              (swap! finished-workers inc)))))
      (loop [loop-counter 0]
        ;; keep monitoring until *all* workers are done, not just the first one
        (when (< @finished-workers workers)
          (println (str (util/progress-format @progress-counter all-lines-num @interval-counter interval-ms "lines/s") ", loss: " (float (/ @tmp-loss (inc @interval-counter))))); loss per line, and avoiding zero divide
          (reset! tmp-loss 0)
          (reset! interval-counter 0)
          (when (and model-path (not (zero? snapshot)) (not (zero? loop-counter)) (zero? (rem loop-counter snapshot)))
            (let [spath (str model-path "-SNAPSHOT-" @snapshot-num)]
              (println (str "saving " spath))
              (util/save-model model spath)
              (swap! snapshot-num inc)))
          (Thread/sleep interval-ms)
          (recur (inc loop-counter))))
      (println "finished learning")))
  model)


(defn encoder-decoder-ncm
  "Creates an untrained encoder-decoder NCM via ed/init-model.
  Configuration passed to ed/init-model:
  - input items #{\"<go>\" \"<unk>\"} and input size `em-size`;
  - the given encoder/decoder hidden sizes and `rnn-type`;
  - multi-class output over every key of `wc` plus \"<eos>\".
  The word counts `wc`, the embedding map `em` and :ncm-type :encoder-decoder
  are attached to the returned model."
  [wc em em-size encoder-hidden-size decoder-hidden-size rnn-type]
  (let [output-items (into #{"<eos>"} (keys wc))
        model (ed/init-model {:input-items #{"<go>" "<unk>"}
                              :input-size em-size
                              :encoder-hidden-size encoder-hidden-size
                              :decoder-hidden-size decoder-hidden-size
                              :output-type :multi-class-classification
                              :output-items output-items
                              :rnn-type rnn-type})]
    (assoc model :wc wc :em em :ncm-type :encoder-decoder)))

(defn make-ncm
  "End-to-end pipeline that builds, trains, saves and returns a new NCM:
  1. builds the word counts from `training-path` with util/make-wc;
  2. loads the embedding model from `embedding-path`;
  3. creates the encoder-decoder NCM;
  4. trains it with train-ncm!, with snapshots prefixed by `export-path`;
  5. saves it to `export-path`."
  [training-path embedding-path export-path em-size encoder-hidden-size decoder-hidden-size rnn-type option]
  (println "making word list...")
  (let [wc (util/make-wc training-path option)]
    (println "done")
    (let [em (util/load-model embedding-path)
          model (encoder-decoder-ncm wc em em-size encoder-hidden-size decoder-hidden-size rnn-type)
          train-option (assoc option :model-path export-path)]
      (train-ncm! model training-path train-option)
      (print (str "Saving NCM model as " export-path " ... "))
      (util/save-model model export-path)
      (println "Done")
      model)))


(defn build-reply
  "Greedily generates a reply to `input-utt` (a seq of words).
  The context starts as the embedded input words followed by #{\"<go>\"}.
  Each step proceeds as follows:
  - rnn/forward is run over the whole context, with targets of :skip for
    every position except the last;
  - the last position is scored over every key of `wc` except \"<unk>\",
    plus \"<eos>\";
  - the best-scoring word (ignoring \"<unk>\") is emitted and its feature is
    appended to the context.
  Generation stops right after \"<eos>\" is emitted or after 21 words.
  Returns {:reply words, :probs their scores, :utt input-utt,
  :unk input words missing from the embedding map}.
  NOTE(review): this calls rnn/forward although encoder-decoder-ncm builds the
  model with ed/init-model; confirm the two are compatible."
  [ncm-model input-utt]
  (let [{:keys [wc em]} ncm-model
        embed #(word->feature em %)
        candidates (conj (set (keys (dissoc wc "<unk>"))) "<eos>")]
    (loop [context (conj (mapv embed input-utt) #{"<go>"})
           words []
           probs []
           last-word nil
           step 0]
      (if (and (not= last-word "<eos>") (<= step 20))
        (let [targets (-> (dec (count context)) (repeat :skip) vec (conj candidates))
              scores (-> (rnn/forward ncm-model context targets)
                         last
                         :activation
                         :output
                         (dissoc "<unk>"))
              [best score] (first (sort-by second > scores))]
          (recur (conj context (embed best))
                 (conj words best)
                 (conj probs score)
                 best
                 (inc step)))
        {:reply words
         :probs probs
         :utt input-utt
         :unk (vec (remove #(get em %) input-utt))}))))


(defn sort-reply-list-by-plausibility
  "Scores every candidate reply in `reply-list` (each a seq of words) as a
  continuation of `input-utt`.
  For each reply word, in order, rnn/forward is run on the current context.
  The context is the embedded input plus #{\"<go>\"} plus the embedded words
  seen so far. Targets are :skip everywhere except #{word} at the last
  position, and that word's output value is recorded. Scoring of a reply stops
  at its first nil/false entry.
  Returns a seq of {:reply, :log-probs (natural logs of the recorded values),
  :utt-prob (exp of their sum)} maps sorted by descending :utt-prob."
  [ncm-model input-utt reply-list]
  (let [{:keys [em]} ncm-model
        embed #(word->feature em %)
        start-context (conj (mapv embed input-utt) #{"<go>"})
        word-probs (fn [reply]
                     (->> (take-while identity reply)
                          (reduce (fn [[ctx acc] word]
                                    (let [targets (conj (vec (repeat (dec (count ctx)) :skip)) #{word})
                                          p (-> (rnn/forward ncm-model ctx targets)
                                                last
                                                :activation
                                                :output
                                                (get word))]
                                      [(conj ctx (embed word)) (conj acc p)]))
                                  [start-context []])
                          second))
        scored (mapv (fn [reply] {:reply reply :probs (word-probs reply)}) reply-list)]
    (->> scored
         (mapv (fn [{:keys [reply probs]}]
                 (let [log-probs (mapv #(Math/log %) probs)]
                   {:reply reply
                    :log-probs log-probs
                    :utt-prob (Math/exp (apply + log-probs))})))
         (sort-by :utt-prob >))))


(defn load-model
  "Loads a previously saved NCM model from `path` by delegating to
  prism.util/load-model."
  [path]
  (-> path util/load-model))

