(ns hara.lib.opencl.base
  (:require [uncomplicate.clojurecl.core :as cl]
            [uncomplicate.clojurecl.info :as info]
            [uncomplicate.commons.core :as common]
            [hara.core.base.primitive :as primitive]
            [hara.function :as fn :refer [defexecutive definvoke]]
            [hara.protocol.function :as protocol.function]
            [hara.protocol.opencl :as protocol.opencl]
            [clojure.java.io :as io]
            [hara.io.binary.buffer :as buffer])
  (:import (java.nio Buffer)))

;; Extend ClojureCL's `Mem` protocol to every `java.nio.Buffer`, so that nio
;; buffers can be handed directly to the enqueue read/write calls used below.
(extend-type java.nio.Buffer
  uncomplicate.clojurecl.internal.protocols/Mem
  (ptr [this]
    ;; JOCL pointer over the buffer's backing memory
    (org.jocl.Pointer/toBuffer this))
  (size [this]
    ;; Size in bytes = (bytes per element) * (capacity in elements).
    ;; NOTE(review): assumes `buffer/buffer-primitive` with `:bytes` returns
    ;; the element width in bytes - confirm against hara.io.binary.buffer.
    ;; The multiplication is done on longs: with int arithmetic a large
    ;; multi-byte buffer (e.g. a DoubleBuffer above ~268M elements) silently
    ;; wraps around to a negative size.
    (* (long (buffer/buffer-primitive this :bytes))
       (long (.capacity this)))))

;; Forward declaration: `opencl-invoke` is referenced by the `OpenCL` type
;; definition below (as its `:invoke` entry) but is only defined at the end
;; of this file.
(declare opencl-invoke)

(defn opencl-display
  "returns the displayable subset of an opencl object as a plain map,
   keeping only the `:name`, `:arglist` and `:file` entries that are present
 
   (opencl-display (-hello-cstr-))
   => {:name \"hello_kernel\"
       :arglist [{:type :byte-array
                  :name \"msg\"
                  :output true
                  :flags [:write-only]}]}"
  {:added "3.0"}
  [opencl]
  (let [shown [:name :arglist :file]]
    ;; `into {}` yields a fresh plain map for the selected entries
    (into {} (select-keys opencl shown))))

;; Definition of the `OpenCL` type through hara's `defexecutive` macro.
;; The field vector names the declared fields; the `opencl` constructor in
;; this file additionally assocs `:program` and `:device` before calling
;; `map->OpenCL`.
;; NOTE(review): the exact semantics of the option map are defined by
;; `hara.function/defexecutive`, which is not visible here - the notes below
;; are read off the keys and should be confirmed against that macro.
(defexecutive OpenCL
  "type definition for the opencl object"
  {:added "3.0"}
  [name code path arglist worksize options context queue kernel runtime]
  {:type    defrecord          ;; presumably: back the type with a record
   :tag     "opencl"           ;; presumably: tag used when printing
   :invoke  opencl-invoke      ;; function used when the object is called
   :display opencl-display})   ;; function producing the printed form

(defn default-context
  "creates an opencl context. With no arguments the context spans the
   devices of the first available platform; otherwise it spans `devs`
 
   (default-context)"
  {:added "3.0"}
  ([]
   (let [platform (first (cl/platforms))]
     (default-context (cl/devices platform))))
  ([devs]
   (cl/context devs)))

(defn select-gpu
  "returns the first device whose info reports a `:device-type` of `:gpu`,
   or nil when there is none. With no arguments the devices of the first
   available platform are searched
 
   (select-gpu)"
  {:added "3.0"}
  ([]
   (select-gpu (cl/devices (first (cl/platforms)))))
  ([devs]
   (let [gpu? (fn [dev]
                (= :gpu (:device-type (common/info dev))))]
     (first (filter gpu? devs)))))

(defn opencl-arglist
  "normalises an arglist into a vector, filling in `:flags` for entries
   that do not declare them:
 
   - `:input` and `:output` => [:read-write]
   - `:input` only          => [:read-only]
   - `:output` only         => [:write-only]
 
   Entries with explicit `:flags`, or with neither `:input` nor `:output`,
   are returned unchanged."
  {:added "3.0"}
  [arglist]
  (let [infer-flags (fn [{:keys [flags input output]}]
                      (cond flags              flags
                            (and input output) [:read-write]
                            input              [:read-only]
                            output             [:write-only]))
        fill-flags  (fn [arg]
                      (if-let [resolved (infer-flags arg)]
                        (assoc arg :flags resolved)
                        arg))]
    (mapv fill-flags arglist)))

(defn opencl
  "constructor for the opencl object
 
   The kernel source is taken from the first of `:code` (a string), `:path`
   (anything `slurp` accepts) or `:resource` (a classpath resource name).
   The program is built with `:options` on a context spanning the devices of
   the first platform; commands are queued on the first gpu device, falling
   back to the first device when no gpu is present.
 
   Throws an `ex-info` when no source is supplied or when `:resource` cannot
   be found on the classpath.
   
   (def -mult-float-code-
     \"__kernel void mult(__global float * input,
                         float f){
       int gid = get_global_id(0);
       input[gid] = f * input[gid];
     }\")
   
   (def -mult-float-
     (opencl {:name \"mult\"
              :code  -mult-float-code-
              :worksize (fn [{:strs [src]}]
                         [[(opencl-count src)]])
              :arglist [{:name \"src\" :type :float-array, :input true, :output true}
                        {:name \"f\"   :type :float}]}))"
  {:added "3.0"}
  [{:keys [name code resource path arglist worksize options] :as m}]
  (let [;; Resolve and validate the source before allocating any OpenCL
        ;; handles, so a bad spec fails fast without leaving a context or
        ;; queue behind.
        code    (or code
                    (if path (slurp path))
                    (if resource
                      (if-let [url (io/resource resource)]
                        (slurp url)
                        (throw (ex-info "Cannot find opencl resource"
                                        {:name name
                                         :resource resource})))))
        _       (if-not code
                  (throw (ex-info "No opencl source provided"
                                  {:name name
                                   :require [:code :path :resource]})))
        devs    (-> (cl/platforms)
                    first
                    cl/devices)
        context (default-context devs)
        device  (or (select-gpu devs)
                    (first devs))
        queue   (cl/command-queue-1 context device)
        program (-> (cl/program-with-source context [code])
                    (cl/build-program! options nil))
        kernel  (cl/kernel program name)
        arglist (opencl-arglist arglist)]
    (-> m
        (assoc :name name
               :arglist arglist
               :code code
               :program program
               :context context
               :device device
               :queue queue
               :kernel kernel
               :runtime (atom {}))
        (map->OpenCL))))

;; Code generator backing the `[:opencl {...}]` form of `definvoke`.
;; NOTE(review): how `definvoke` consumes the `[:method ...]` vector is defined
;; in `hara.function` and not visible here; the comments below describe only
;; what this form itself contains.
(definvoke invoke-intern-opencl
  "method body for `:opencl` definvoke form
 
   (invoke-intern-opencl '-mult-float-
                         '{:name \"mult\"
                           :code  -mult-float-code-
                           :worksize (fn [{:strs [src]}]
                                       [[(opencl-count src)]])
                           :arglist [{:name \"src\" :type :float-array, :input true, :output true}
                                     {:name \"f\"   :type :float}]})
   
   (definvoke -mult-float-
     [:opencl {:name \"mult\"
               :code  -mult-float-code-
               :worksize (fn [{:strs [src]}]
                           [[(opencl-count src)]])
               :arglist [{:name \"src\" :type :float-array, :input true, :output true}
                         {:name \"f\"   :type :float}]}])
   
   (vec (-mult-float- (float-array [1.1 20.1 300.1 4000.1])
                      10.1))
   => (map float [11.110001 203.01001 3031.0103 40401.01])"
  {:added "3.0"}
  ;; presumably registers this body on `protocol.function/-invoke-intern`
  ;; under the dispatch value `:opencl` - confirm against `definvoke`
  [:method {:multi protocol.function/-invoke-intern
            :val :opencl}]
  ;; convenience arity: fills in the `:opencl` label and a nil last argument
  ([name config]
   (invoke-intern-opencl :opencl name config nil))
  ;; main arity: the first and last arguments are ignored; `base` is
  ;; destructured from the config but not used in the body
  ([_ name {:keys [base] :as config} _]
   (let [;; construction is wrapped in `delay`, so `(opencl config)` is not
         ;; evaluated when the generated `def` form is evaluated
         body  `(delay (opencl ~config))
         ;; a single arglist built from the `:name` of each `:arglist` entry
         arglists (list (mapv (comp clojure.core/name :name) (:arglist config)))]
     ;; the emitted form defines the var and then sets `:arglists` on its meta
     `(doto (def ~name ~body)
        (alter-meta! assoc :arglists (quote ~arglists))))))

(def +opencl-builtins+
  "lookup table of the argument types handled natively by this namespace.
 
   Scalar entries carry `:val` (coerces a value to the primitive) and `:wrap`
   (builds a primitive array from a collection). Array entries carry
   `:component`, the keyword of their scalar element type."
  (merge
   ;; scalar types
   {:byte   {:val unchecked-byte :wrap byte-array}
    :short  {:val short          :wrap short-array}
    :int    {:val int            :wrap int-array}
    :long   {:val long           :wrap long-array}
    :float  {:val float          :wrap float-array}
    :double {:val double         :wrap double-array}}
   ;; array types
   {:byte-array   {:component :byte}
    :short-array  {:component :short}
    :int-array    {:component :int}
    :long-array   {:component :long}
    :float-array  {:component :float}
    :double-array {:component :double}}))

(defn opencl-count
  "element count of `obj`: the capacity for a `java.nio.Buffer`, otherwise
   whatever `count` returns (arrays, collections)
 
   (opencl-count (byte-array 10))
   => 10
 
   (opencl-count (buffer/byte-buffer 10))
   => 10"
  {:added "3.0"}
  [obj]
  (if (instance? Buffer obj)
    (.capacity ^Buffer obj)
    (count obj)))

(defn opencl-create-input
  "creates the kernel-side value for a single argument
 
   - builtin array types: allocates a cl buffer in `context`, sized as
     (bytes per component) * (element count of `arg`), using the `:flags` of
     `param`. When `:flags` is missing or empty it defaults to
     `[:read-write]`, which is also what OpenCL assumes when no access flag
     is given.
   - builtin scalar types: coerces `arg` and wraps it in a one-element
     primitive array
   - any other type: delegates to `protocol.opencl/-opencl-create-input`
 
   (def -buff- (opencl-create-input {:type :float-array, :flags [:read-write]}
                                    (float-array [1 2 3 4])
                                   (:context -mult-float-)))"
  {:added "3.0"}
  [param arg context]
  (if-let [opts (+opencl-builtins+ (:type param))]
    (if-let [component (:component opts)]
      (let [nbytes (* (primitive/primitive-type component :bytes)
                      (opencl-count arg))
            ;; Applying `cl/cl-buffer` with no flags would drop down to its
            ;; two-argument (size, flag) arity and misread `context` as the
            ;; size, so always pass at least one flag.
            flags  (or (seq (:flags param)) [:read-write])]
        (apply cl/cl-buffer context nbytes flags))
      ((:wrap opts) [((:val opts) arg)]))
    (protocol.opencl/-opencl-create-input param arg context)))

(defn opencl-write-input
  "copies `arg` into its device `buffer` when `param` is marked `:input`
 
   Builtin array types are written through `cl/enq-write!`; builtin scalar
   types need no write and yield nil; any other type is delegated to
   `protocol.opencl/-opencl-write-input`. Returns nil for non-input params.
 
   (opencl-write-input {:type :float-array, :flags [:read-write]}
                       (float-array [1 2 3 4])
                      -buff-
                       (:queue -mult-float-))"
  {:added "3.0"}
  [param arg buffer queue]
  (when (:input param)
    (let [builtin (get +opencl-builtins+ (:type param))]
      (cond (not builtin)
            (protocol.opencl/-opencl-write-input param arg buffer queue)

            (:component builtin)
            (cl/enq-write! queue buffer arg)

            :else nil))))

(defn opencl-init
  "prepares a single invocation: creates one kernel-side value per argument,
   binds them all as kernel arguments, writes the `:input` arguments to the
   device and returns a runtime map with
 
   - `:buffers` the created kernel-side values, in argument order
   - `:args`    the original arguments
   - `:margs`   the arguments keyed by parameter name
   - `:output`  the name of the first parameter marked `:output`
 
   (def -hello- (-hello-cstr-))
   
   (opencl-init -hello- [(byte-array 16)])
   => (contains-in {:buffers anything
                    :args vector?
                    :margs {\"msg\" bytes?},
                    :output \"msg\"})"
  {:added "3.0"}
  [{:keys [kernel context queue arglist] :as opencl} args]
  (let [buffers (map (fn [param arg]
                       (opencl-create-input param arg context))
                     arglist
                     args)]
    ;; applying the lazy seq realises every kernel-side value here
    (apply cl/set-args! kernel buffers)
    (mapv (fn [param arg buffer]
            (opencl-write-input param arg buffer queue))
          arglist
          args
          buffers)
    {:buffers buffers
     :args    args
     :margs   (zipmap (map :name arglist) args)
     :output  (->> arglist (filter :output) first :name)}))

(defn opencl-read-output
  "copies the device `buffer` back into `arg` when `param` is marked `:output`
 
   Builtin array types are read through `cl/enq-read!`; builtin scalar types
   are not read back and yield nil; any other type is delegated to
   `protocol.opencl/-opencl-read-output`. Returns nil for non-output params."
  {:added "3.0"}
  [param arg buffer queue]
  (when (:output param)
    (let [builtin (get +opencl-builtins+ (:type param))]
      (cond (not builtin)
            (protocol.opencl/-opencl-read-output param arg buffer queue)

            (:component builtin)
            (cl/enq-read! queue buffer arg)

            :else nil))))

(defn opencl-invoke
  "invokes a kernel with given arguments
 
   Steps: check the argument count against the arglist, create and fill the
   kernel-side values (`opencl-init`), enqueue the kernel over the work size,
   read every `:output` argument back and release the kernel-side values.
   `:worksize` is either the dimensions to apply to `cl/work-size`, or a
   function that receives the arguments keyed by parameter name and returns
   those dimensions.
 
   Returns the argument bound to the first `:output` parameter (nil when the
   arglist declares no output). Throws an `ex-info` when the number of
   arguments does not match the arglist.
 
   (String. (opencl-invoke -hello-
                           (byte-array 16)))
   => \"Hello Kernel!!!!\""
  {:added "3.0"}
  [{:keys [kernel worksize queue arglist] :as exec} & args]
  (when-not (= (count arglist)
               (count args))
    (throw (ex-info "Wrong number of arguments" {:require (count arglist)
                                                 :input   (count args)})))
  (let [{:keys [buffers margs output]} (opencl-init exec args)]
    (try
      (let [dims (if (fn? worksize)
                   (worksize margs)
                   worksize)
            ws   (apply cl/work-size dims)]
        (cl/enq-kernel! queue kernel ws)
        (mapv #(opencl-read-output %1 %2 %3 queue)
              arglist
              args
              buffers)
        (get margs output))
      (finally
        ;; Release in `finally` so the kernel-side values are freed even when
        ;; the work-size function, the kernel enqueue or the read-back throws.
        (doseq [b buffers]
          (common/release b))))))
