(ns faiss.index
  (:import (faiss FaissNative FaissException)
           (com.sun.jna.ptr PointerByReference)
           (com.sun.jna Memory)))

;; Handle to a native FAISS index, built by create!.
;;   pointer - atom holding the native index pointer (the value taken from the
;;             PointerByReference filled by faiss_index_factory); release!
;;             frees the native object and overwrites this atom
;;   metric  - metric keyword given to create! (:l2 or :cosine)
;;   d       - dimensionality of the vectors stored in the index
(defrecord Index [pointer metric d])

(defn- exec
  "Invokes the zero-argument fn `f`, which wraps a FAISS C API call that
  returns an integer status code.

  A non-negative status is returned unchanged. A negative status raises a
  FaissException whose message comes from faiss_get_last_error, falling back
  to a generic message naming the status code when that returns nil."
  [f]
  (let [status (f)]
    (when (neg? status)
      (throw (FaissException.
              (or (FaissNative/faiss_get_last_error)
                  (str "Failure code received: " status)))))
    status))

(defn- index->ptr
  "Returns the live native pointer held in `index`'s :pointer atom.

  Throws IllegalStateException when the atom is nil (the index has been
  released), so a freed index is never handed to FAISS as a null pointer."
  [index]
  (or @(:pointer index)
      (throw (IllegalStateException. "FAISS index has already been released"))))

(defn create!
  "Builds a native FAISS index with faiss_index_factory and wraps it in an
  Index record.

  d           - vector dimensionality; must be a positive number
  description - FAISS index-factory string, e.g. \"Flat\"
  metric-type - :l2 (FAISS metric code 1, METRIC_L2) or
                :cosine (FAISS metric code 0, METRIC_INNER_PRODUCT)

  NOTE: :cosine is plain inner product; vectors are not normalized anywhere
  in this namespace, so scores equal cosine similarity only if callers pass
  unit-length vectors.

  Returns the Index record (free it with release!). Throws FaissException
  when faiss_index_factory returns a negative status."
  [^Long d ^String description metric-type]
  {:pre [(number? d)
         (pos? d)
         (#{:l2 :cosine} metric-type)
         (string? description)]}
  (let [p-idx       (PointerByReference.)
        metric-code (case metric-type
                      :l2     1    ; METRIC_L2
                      :cosine 0)]  ; METRIC_INNER_PRODUCT
    ;; Coerce both numeric args to int, matching the C signature's int params.
    (exec #(FaissNative/faiss_index_factory p-idx (int d) description (int metric-code)))
    (->Index (atom (.getValue p-idx)) metric-type d)))

(defn dim
  "Returns the vector dimensionality stored under :d in `index`."
  [^Index index]
  (get index :d))

(defn metric
  "Returns the metric keyword stored under :metric in `index`
  (:l2 or :cosine for indexes built by create!)."
  [^Index index]
  (get index :metric))

(defn release!
  "Frees the native FAISS index held by `index` and sets its :pointer atom to
  nil.

  Safe to call repeatedly or from several threads: only the caller whose
  compare-and-set! swaps the live pointer for nil calls faiss_Index_free. This
  avoids freeing inside a swap! fn, which Clojure may retry, so the native
  object could be freed twice.

  Returns nil."
  [^Index index]
  (let [ptr-atom (:pointer index)
        ptr      @ptr-atom]
    (when (and ptr (compare-and-set! ptr-atom ptr nil))
      (FaissNative/faiss_Index_free ptr))
    nil))

(defn trained?
  "True when faiss_Index_is_trained returns a positive value for the native
  index behind `index`."
  [^Index index]
  (let [flag (FaissNative/faiss_Index_is_trained (index->ptr index))]
    (> flag 0)))

(defn add!
  "Adds `vectors` to the index via faiss_Index_add.

  `vectors` is a sequential collection of numeric sequences, each of length
  (dim index). They are packed row-major into a single native float buffer
  (4 bytes per component).

  An empty collection is a no-op returning 0: JNA cannot allocate a zero-byte
  Memory buffer. Otherwise returns the number of vectors added. Throws
  FaissException when faiss_Index_add returns a negative status."
  [^Index index vectors]
  {:pre [(sequential? vectors)
         (every? #(= (dim index) (count %)) vectors)]}
  (let [n (count vectors)]
    (if (zero? n)
      0
      (let [d   (dim index)
            buf (Memory. (* 4 d n))]
        ;; One bulk write of the whole n x d row-major float matrix.
        (.write buf 0 (float-array (mapcat #(map float %) vectors)) 0 (int (* n d)))
        (exec #(FaissNative/faiss_Index_add (index->ptr index) (long n) buf))
        n))))

(defn add-with-ids!
  "Adds vectors under caller-supplied ids via faiss_Index_add_with_ids.

  `vectors` is a sequential collection of maps {:id <number> :data <numbers>}.
  Each :data must have length (dim index). Each :id must be non-negative:
  search drops negative labels as empty result slots, but 0 is a valid label.
  Ids are written as 64-bit longs, vectors as row-major 32-bit floats.

  An empty collection is a no-op returning 0: JNA cannot allocate a zero-byte
  Memory buffer. Otherwise returns the number of vectors added. Throws
  FaissException when faiss_Index_add_with_ids returns a negative status.

  NOTE(review): not every FAISS index type supports add_with_ids (an
  \"IDMap,...\" wrapper may be required). An unsupported call should surface
  here as a FaissException; confirm against the FAISS version in use."
  [^Index index vectors]
  {:pre [(sequential? vectors)
         (every? #(= (dim index) (count (:data %))) vectors)
         (every? #(and (number? (:id %)) (not (neg? (:id %)))) vectors)]}
  (let [n (count vectors)]
    (if (zero? n)
      0
      (let [d   (dim index)
            x   (Memory. (* 4 d n))
            ids (Memory. (* 8 n))]
        ;; Bulk-write the n x d float matrix and the n ids in one call each.
        (.write x 0 (float-array (mapcat #(map float (:data %)) vectors)) 0 (int (* n d)))
        (.write ids 0 (long-array (map :id vectors)) 0 (int n))
        (exec #(FaissNative/faiss_Index_add_with_ids (index->ptr index) (long n) x ids))
        n))))

(defn search
  "Runs a k-nearest-neighbour search for every query via faiss_Index_search.

  `queries` is a sequential collection of numeric sequences of length
  (dim index). `k` is the positive number of neighbours requested per query.

  Returns a vector with one entry per query, in query order. Each entry is a
  vector of {:distance <float> :label <long>} maps in the order FAISS returned
  them, cut off at the first negative label (the FAISS marker for an empty
  result slot). An empty `queries` returns [] without calling FAISS, because
  JNA cannot allocate a zero-byte Memory buffer. Throws FaissException when
  faiss_Index_search returns a negative status."
  [^Index index ^Long k queries]
  {:pre [(pos? k)
         (sequential? queries)
         (every? #(= (dim index) (count %)) queries)]}
  (let [n (count queries)]
    (if (zero? n)
      []
      (let [d      (dim index)
            total  (* n k)
            x      (Memory. (* 4 d n))
            dists  (Memory. (* 4 total))    ; n x k float32 distances
            labels (Memory. (* 8 total))]   ; n x k int64 labels
        (.write x 0 (float-array (mapcat #(map float %) queries)) 0 (int (* n d)))
        (exec #(FaissNative/faiss_Index_search (index->ptr index) (long n) x (long k) dists labels))
        ;; Copy both result matrices back in one bulk read each.
        (let [ds (.getFloatArray dists 0 (int total))
              ls (.getLongArray labels 0 (int total))]
          (mapv (fn [row]
                  (let [base (* row k)]
                    (into []
                          (comp (map (fn [j]
                                       {:distance (aget ds (+ base j))
                                        :label    (aget ls (+ base j))}))
                                (take-while (comp not neg? :label)))
                          (range k))))
                (range n)))))))

(defn search-one
  "Searches for a single `query` vector and returns its result vector of
  {:distance :label} maps, the same as the first entry of `search` on
  (vector query)."
  [^Index index ^Long k query]
  (-> (search index k (vector query))
      first))



(comment
  ;; REPL scratchpad: evaluate forms by hand; nothing here runs at load time.

  ;; Plain flat index; vectors are added without explicit ids via add!.
  (def i (create! 128 "Flat" :l2))

  ;; add-with-ids! needs an index type that stores external ids.
  ;; NOTE(review): "IDMap,Flat" wraps the flat index for that; confirm
  ;; against the FAISS version in use.
  (def j (create! 128 "IDMap,Flat" :l2))

  (def x [{:id 1231 :data (map (comp float rand-int) (range 128))}
          {:id 1232 :data (map (comp float rand-int) (range 128))}
          {:id 1233 :data (map (comp float rand-int) (range 128))}
          {:id 1234 :data (map (comp float rand-int) (range 128))}
          {:id 1235 :data (map (comp float rand-int) (range 128))}
          {:id 1236 :data (map (comp float rand-int) (range 128))}])

  (add! i (map :data x))
  (add-with-ids! j x)

  (def q [(map (comp float rand-int) (range 128))
          (map (comp float rand-int) (range 128))])

  (search-one i 20 (first q))
  (search-one j 20 (first q))
  (search i 5 q)

  @(:pointer i)

  (trained? i)

  (release! i)
  (release! j)

  '-)
