(ns com.phronemophobic.clip
  (:require [com.phronemophobic.clip.raw :as raw]
            [clojure.java.io :as io]
            [clojure.edn :as edn]
            [clojure.string :as str])
  (:import com.sun.jna.ptr.FloatByReference
           com.sun.jna.Memory
           com.sun.jna.Structure
           com.sun.jna.Pointer
           java.lang.ref.Cleaner)
  (:gen-class))

;; NOTE(review): presumably makes the generated JNA struct classes (such as
;; `clip_tokensByReference`, which `text-embedding` constructs unqualified)
;; available in this namespace -- confirm against `com.phronemophobic.clip.raw`.
(raw/import-structs!)
;; Passed to `raw/clip_image_encode` and `raw/clip_text_encode` by the
;; embedding functions below; rebind it with `binding` to change the value
;; for a given call.
(def ^:dynamic *num-threads*
  "Number of threads to use when creating embeddings."
  1)

(defn- make-clip-image-u8
  "Allocates a native u8 image struct with `raw/clip_image_u8_make` and
  returns it.

  The struct is registered with `raw/cleaner`, which calls
  `raw/clip_image_u8_free` on the struct's pointer once the returned object
  is no longer reachable."
  []
  (let [image (raw/clip_image_u8_make)
        ;; Close over the pointer only: a cleanup action that referenced
        ;; `image` itself would keep it reachable forever.
        native-ptr (.getPointer ^Structure image)
        release! (fn []
                   (raw/clip_image_u8_free native-ptr))]
    (.register ^Cleaner raw/cleaner image release!)
    image))

(defn- make-clip-image-f32
  "Allocates a native f32 image struct with `raw/clip_image_f32_make` and
  returns it.

  The struct is registered with `raw/cleaner`, which calls
  `raw/clip_image_f32_free` on the struct's pointer once the returned object
  is no longer reachable."
  []
  (let [image (raw/clip_image_f32_make)
        ;; Close over the pointer only: a cleanup action that referenced
        ;; `image` itself would keep it reachable forever.
        native-ptr (.getPointer ^Structure image)
        release! (fn []
                   (raw/clip_image_f32_free native-ptr))]
    (.register ^Cleaner raw/cleaner image release!)
    image))

(defn create-context
  "Creates a context using model at `model-path`.

  `model-path` must be a string. The returned context is registered with
  `raw/cleaner`, which calls `raw/clip_free` on the underlying native address
  once the context is no longer reachable.

  Throws `ex-info` (with `:model-path` in its data) when the native loader
  does not return a context."
  [model-path]
  (assert (string? model-path))
  ;; NOTE(review): the second argument is presumably the loader's verbosity
  ;; level -- confirm against the clip.cpp header.
  (let [ctx (raw/clip_model_load model-path 0)]
    ;; Without this check a failed load would surface as a bare
    ;; NullPointerException from `Pointer/nativeValue`.
    (when (nil? ctx)
      (throw (ex-info "Could not load model."
                      {:model-path model-path})))
    ;; Capture the raw address rather than `ctx`, so the cleanup action does
    ;; not itself keep the context reachable.
    (let [ptr (Pointer/nativeValue ctx)]
      (.register ^Cleaner raw/cleaner ctx
                 (fn []
                   (raw/clip_free (Pointer. ptr))))
      ctx)))

(defn image-embedding
  "Computes the embedding of the image stored at `f` and returns it as a
  float array.

  `f` may be anything `clojure.java.io/as-file` can coerce. Encoding is run
  with `*num-threads*` threads.

  Throws `ex-info` (with `:ctx` and `:f` in its data) when the image cannot
  be loaded, preprocessed or encoded."
  [ctx f]
  (let [file (io/as-file f)
        path (.getCanonicalPath file)
        ;; The native calls below signal failure with a zero status.
        fail-when-zero (fn [status message]
                         (when (zero? status)
                           (throw (ex-info message
                                           {:ctx ctx
                                            :f file}))))
        source-img (make-clip-image-u8)
        prepared-img (make-clip-image-f32)
        dim (:projection_dim (raw/clip_get_vision_hparams ctx))]
    (fail-when-zero (raw/clip_image_load_from_file path source-img)
                    "Could not load image.")
    (fail-when-zero (raw/clip_image_preprocess ctx source-img prepared-img)
                    "Could not preprocess image.")
    ;; Output buffer: `dim` floats of `Float/BYTES` (4) bytes each.
    (let [out (Memory. (* Float/BYTES dim))]
      (fail-when-zero (raw/clip_image_encode
                       ctx *num-threads* prepared-img out 1)
                      "Could not encode image.")
      (.getFloatArray out 0 dim))))


(defn text-embedding
  "Returns an embedding for `text` as a float array.

  Encoding is run with `*num-threads*` threads.

  Throws `ex-info` (with `:ctx` and `:text` in its data) when `text` cannot
  be tokenized or encoded."
  [ctx text]
  (let [tokens* (clip_tokensByReference.)
        ;; Check the tokenizer status the same way the other native calls in
        ;; this namespace are checked, instead of silently encoding whatever
        ;; is left in `tokens*` after a failure.
        _ (when (zero? (raw/clip_tokenize ctx text tokens*))
            (throw (ex-info "Could not tokenize text."
                            {:ctx ctx
                             :text text})))

        ;; NOTE(review): the output width is read from the *vision* hparams;
        ;; presumably the text and image embeddings share `projection_dim`
        ;; -- confirm.
        params (raw/clip_get_vision_hparams ctx)
        vec-dim (:projection_dim params)

        ;; Output buffer: `vec-dim` floats of 4 bytes each.
        out (Memory. (* 4 vec-dim))
        _ (when (zero?
                 (raw/clip_text_encode ctx *num-threads* tokens* out 1))
            (throw (ex-info "Could not encode text."
                            {:ctx ctx
                             :text text})))]
    (.getFloatArray out 0 vec-dim)))

(defn cosine-similarity
  "Returns the cosine similarity between two embeddings as a double in
  [-1.0, 1.0] (up to floating point rounding).

  The embeddings should be float arrays of the same length. The dot product
  is divided by the product of both magnitudes, so the inputs do not need to
  be normalized beforehand. Returns 0.0 when either embedding has zero
  magnitude (including empty arrays).

  Throws `IllegalArgumentException` when the lengths differ."
  [^floats emb1 ^floats emb2]
  (let [len (alength emb1)]
    (when-not (== len (alength emb2))
      (throw (IllegalArgumentException.
              (str "Embeddings must have the same length, got "
                   len " and " (alength emb2) "."))))
    (loop [dot 0.0
           sq1 0.0
           sq2 0.0
           n 0]
      (if (< n len)
        (let [a (double (aget emb1 n))
              b (double (aget emb2 n))]
          (recur (+ dot (* a b))
                 (+ sq1 (* a a))
                 (+ sq2 (* b b))
                 (inc n)))
        ;; Take the square roots separately so the product of the two squared
        ;; magnitudes cannot overflow before the root is taken.
        (let [denom (* (Math/sqrt sq1) (Math/sqrt sq2))]
          (if (zero? denom)
            0.0
            (/ dot denom)))))))

(defn distinct-by
  "Returns a lazy sequence of the elements of `coll` with duplicates removed.
  Two elements are duplicates when `(keyfn element)` returns equal values; the
  first element seen for each key is kept and the original order is preserved.
  Returns a stateful transducer when no collection is provided."
  ([keyfn]
   (fn [rf]
     ;; Keys already emitted by this transducer instance.
     (let [seen (volatile! #{})]
       (fn
         ([] (rf))
         ([result] (rf result))
         ([result input]
          (let [k (keyfn input)]
            (if (contains? @seen k)
              result
              (do (vswap! seen conj k)
                  (rf result input)))))))))
  ([keyfn coll]
   (sequence (distinct-by keyfn) coll)))

;; REPL scratchpad. Nothing inside `comment` is evaluated when the namespace
;; loads. The forms depend on a `dev` namespace and on model/image files that
;; are not part of this file.
(comment
  ;; Load a model and keep the context around for the forms below.
  (do
    (def model-path
      "models/CLIP-ViT-B-32-laion2B-s34B-b79K_ggml-model-f16.gguf")

    (def ctx (create-context model-path)))

  ;; Every file under "aimages" (walked with `file-seq`) whose name ends in
  ;; .jpg, .png or .jpeg.
  (def photos
    (into []
          (comp
           (filter #(or (str/ends-with? (.getName %) ".jpg")
                        (str/ends-with? (.getName %) ".png")
                        (str/ends-with? (.getName %) ".jpeg"))))
          (file-seq (io/file "aimages"))))

  (require 'dev)
  (dev/add-local "membrane2")
  (require '[membrane.ui :as ui])
  ;; Photos whose bounds are exactly 2048x1024. Files for which measuring
  ;; throws are treated as non-matching.
  (def to-split
    (->> photos
         (filter (fn [f]
                   (try
                     (let [[w h] (ui/bounds (ui/image (.getCanonicalPath f)))]
                       (and (= w 2048)
                            (= h 1024)))
                     (catch Exception e
                       false))))))

  (require '[membrane.java2d :as java2d])
  ;; Save each 2048x1024 photo as two images under "aimages/left/" and
  ;; "aimages/right/". Presumably these are the left and right 1024x1024
  ;; halves: the right view shifts the image by -1024 in x before applying
  ;; the same [0 0] [1024 1024] scissor.
  (doseq [f to-split]
    (let [img (ui/image (.getCanonicalPath f))
          left (ui/scissor-view
                [0 0]
                [1024 1024]
                img)
          right (ui/scissor-view
                 [0 0]
                 [1024 1024]
                 (ui/translate (- 1024) 0
                               img))]
      (java2d/save-image (str "aimages/left/" (.getName f))
                         left)
      (java2d/save-image (str "aimages/right/" (.getName f))
                         right)))

  ;; NOTE(review): `split?` is not defined anywhere in this file and
  ;; `clojure.set` is not required by the ns form. Possibly `(set to-split)`
  ;; was intended, to exclude the 2048x1024 photos split above -- confirm.

  (def to-index (clojure.set/difference (set photos) split?))


  ;; Embed every file in `to-index` with 8 threads, producing
  ;; [file-name embedding-vector] pairs. Files whose embedding throws are
  ;; dropped silently. `time` prints the elapsed time.
  (def embeddings
    (time
     (binding [*num-threads* 8]
       (into []
             (comp (map (fn [f]
                          (try
                            [(.getName f)
                             (vec (image-embedding ctx f))]
                            (catch Exception e
                              nil))))
                   (remove nil?))
             to-index))))


  ;; Reload previously saved embeddings (this re-defines `embeddings`) or
  ;; write the current ones to "embeddings.edn" using the `dev` helpers.
  (require 'dev)
  (def embeddings (dev/read-edn "embeddings.edn"))

  (with-open [w (io/writer "embeddings.edn")]
    (dev/write-edn w embeddings)
    #_(doseq [o embeddings]
        (dev/write-edn w o)
        (.write w "\n")))
  ;; NOTE(review): the purpose of `dev/w` is not visible from this file.
  (dev/w)

  ;; Freeze the embeddings to "embeddings.nippy": each embedding is converted
  ;; to a float array, and entries whose embedding (compared as a vector)
  ;; equals an earlier entry's embedding are dropped.
  (dev/add-libs '{com.taoensso/nippy {:mvn/version "3.3.0"}})
  (require '[taoensso.nippy :as nippy])
  (nippy/freeze-to-file "embeddings.nippy" (into []
                                                 (comp (map (fn [[fname emb]]
                                                         [fname (float-array emb)]))
                                                       (distinct-by (fn [[fname emb]]
                                                                      (vec emb))))
                                                 embeddings))

  ;; Quick smoke tests for both embedding functions.
  (text-embedding ctx "hello")
  (image-embedding ctx "aimages/005bce0c-710f-4b3d-8c94-5be8d86585e9.jpg")

  ,)

