(ns multi-atom.aws.dynamo
  "A dynamo backend providing a multi-atom implementation
   - keys are always encoded as strings
   - maps and vectors are stored natively
   - byte arrays are stored natively as binary blobs
   - numerics are encoded as strings
   - strings are quoted
   - sets/lists are encoded as strings"
  (:require [multi-atom.core :as atom]
            [taoensso.faraday :as far]
            [clojure.edn :as edn]
            [clojure.walk :as walk])
  (:import (com.amazonaws.services.dynamodbv2.model ConditionalCheckFailedException)
           (clojure.lang Keyword IPersistentMap IPersistentVector Symbol)))


(defmulti write-value*
  "Encodes a single value for dynamo. Dispatches on the value's class;
  classes deriving from ::native (vectors and byte arrays) get their own method."
  (fn [v] (class v)))

;; Vectors and byte arrays are routed to the ::native method.
(doseq [native-class [IPersistentVector (Class/forName "[B")]]
  (derive native-class ::native))


(defmethod write-value* :default
  [v]
  ;; Anything without a specific encoding is stored as its printed (edn) form;
  ;; nil stays nil.
  (some-> v pr-str))

(defmethod write-value* ::native
  [native-value]
  ;; Passed through untouched; the walk still visits any children.
  native-value)

;; Wrapper used by the IPersistentMap method of `write-value*` for map keys that
;; are neither strings nor named. The `Key` method `pr-str`s the wrapped value `x`
;; as a whole, so the walk does not descend into it.
(deftype Key [x])

(defmethod write-value* Key
  [^Key wrapped]
  ;; Encode the wrapped key as one printed string.
  (-> wrapped .x pr-str))

;; Maps are stored natively, but every key ends up as a `pr-str`'d string.
;; Strings and named keys (keywords, symbols) are left for the walk to encode.
;; Any other key is wrapped in `Key`, so it is printed as a whole rather than
;; walked into.
(defmethod write-value* IPersistentMap
  [m]
  (into {}
        (for [[k v] m]
          [(if (or (string? k) (instance? clojure.lang.Named k))
             k
             (Key. k))
           v])))

(defmulti read-value*
  "Decodes a single value read back from dynamo. Dispatches on the value's class."
  (fn [v] (class v)))

(defmethod read-value* :default
  [v]
  ;; Values without a dedicated decoder (e.g. maps, vectors, byte arrays, nil)
  ;; are returned as-is.
  v)

(defmethod read-value* Keyword
  [kw]
  ;; The keyword's text (namespace/name, or just name) is the edn that was
  ;; written. `(str kw)` is ":" followed by exactly that text, so drop the ':'.
  (edn/read-string (subs (str kw) 1)))

(defmethod read-value* String
  [s]
  ;; Strings hold the edn text produced by `pr-str` at write time.
  (-> s edn/read-string))

(defn write-value
  "Encodes `x`, and everything nested inside it, for dynamo.
  Each node is encoded with `write-value*` before its children are visited."
  [x]
  (walk/prewalk (fn [node] (write-value* node)) x))

(defn read-value
  "Decodes a value read from dynamo.
  Children are decoded with `read-value*` before their parent node."
  [x]
  (walk/postwalk (fn [node] (read-value* node)) x))

(defn read-key
  "Decodes a key read back from dynamo: a `[::key k]` vector (as produced by
  `native-key`) is unwrapped to `k`; any other value is returned unchanged."
  [x]
  (let [wrapped? (and (vector? x) (= ::key (first x)))]
    (if-not wrapped?
      x
      (nth x 1 nil))))

(defn native-key
  "Wraps an empty collection as `[::key coll]`; any other value (including
  non-empty collections and non-collections) is returned unchanged.
  `read-key` undoes the wrapping."
  [x]
  (if (or (not (coll? x)) (seq x))
    x
    [::key x]))

(defn write-key
  "Encodes `x` for use as a key or attribute name. Unlike values, keys are
  always encoded as strings, whatever their type: the `pr-str` of `x`, after
  wrapping empty collections via `native-key`.

  This lets you use a non-string key if you wish."
  [x]
  (-> x native-key pr-str))

;; Monkey-patch faraday's `attr-multi-vs`. While `*dynamo-list-hack*` is bound
;; to true (as `fixed-batch-write-item` does), its argument is encoded
;; element-wise with `far/clj-item->db-item` instead of by the original function.
(def ^:dynamic *dynamo-list-hack*
  "When true, the patched `far/attr-multi-vs` maps `far/clj-item->db-item`
  over its argument instead of delegating to faraday's original function."
  false)

(alter-var-root #'far/attr-multi-vs
                (fn [original]
                  (fn [coll]
                    (if-not *dynamo-list-hack*
                      (original coll)
                      (mapv far/clj-item->db-item coll)))))

(defn fixed-batch-write-item
  "Calls `far/batch-write-item` with `*dynamo-list-hack*` bound to true, working
  around lists not being supported properly by batch writes.
  Returns faraday's result, or throws if it reports any unprocessed items."
  [client-opts request]
  (binding [*dynamo-list-hack* true]
    (let [result (far/batch-write-item client-opts request)
          unprocessed (:unprocessed result)]
      (when (seq unprocessed)
        (throw (Exception.
                 (format "Could not successfully write all items to dynamo. failed to write %s items"
                         (count unprocessed)))))
      result)))

;; A 'connection' to one dynamo table: the faraday `client-opts`, the `table`
;; name, and the `key-name` used as the hash key column. Build with `table-client`.
(defrecord TableClient [client-opts table key-name])

(defn table-client
  "Returns a `TableClient`: a 'connection' via faraday (`client-opts`) to
  `table`, whose hash key column is named by `key-name`."
  [client-opts table key-name]
  (->TableClient client-opts table key-name))

(defn prepare-value
  "Normalises `value` into a map, to be passed to `write-value`:
  - a non-map is wrapped as `{::value value}`;
  - a map containing `key-name` has that entry moved under `::key`, preserving
    it alongside the key column that `add-key` later writes;
  - any other map is returned unchanged."
  [value key-name]
  (cond
    (not (map? value)) {::value value}
    (contains? value key-name) (-> value
                                   (dissoc key-name)
                                   (assoc ::key (get value key-name)))
    :else value))

(defn add-key
  "Adds the encoded key column to an already written value map: the
  attribute `(write-key key-name)` is set to `(write-key key)`."
  [prepared-value key-name key]
  (let [column (write-key key-name)
        encoded (write-key key)]
    (assoc prepared-value column encoded)))

(defn read-item
  "Recovers the original value from a decoded (see `read-value`) item map.
  - falsey input returns nil;
  - an item holding `::value` returns that value;
  - otherwise drops the key column, `::key` and `::version`, and restores any
    preserved `::key` entry under `key-name`."
  [read-value key-name]
  (cond
    (not read-value) nil
    (contains? read-value ::value) (get read-value ::value)
    :else (let [stripped (dissoc read-value (native-key key-name) ::key ::version)]
            (if (contains? read-value ::key)
              (assoc stripped key-name (get read-value ::key))
              stripped))))

(defn find-raw-value
  "Performs a consistent `far/get-item` for `key`, returning faraday's
  (still encoded) result."
  [table-client key]
  (let [{:keys [client-opts table key-name]} table-client
        prim-key {(write-key key-name) (write-key key)}]
    (far/get-item client-opts table prim-key {:consistent? true})))

(defn find-raw-item
  "Finds the item at `key` and decodes it with `read-value`, without removing
  the key column or other bookkeeping entries."
  [table-client key]
  (read-value (find-raw-value table-client key)))

(defn find-item
  "Returns the original value stored under `key`, using a consistent read."
  [table-client key]
  (read-item (find-raw-item table-client key)
             (:key-name table-client)))

(defn delete-item!
  "Deletes the item stored under `key` via `far/delete-item`."
  [table-client key]
  (let [{:keys [client-opts table key-name]} table-client
        prim-key {(write-key key-name) (write-key key)}]
    (far/delete-item client-opts table prim-key)))

(defn batch-result-mapping
  "Turns a faraday `batch-get-item` result into a map of (decoded) key to item
  for `table`. Each raw item is decoded with `read-value`. Its key is taken from
  the key column, which reads back as `(native-key key-name)` (the same name
  `read-item` strips) and is then unwrapped with `read-key`. The value is
  `read-item` of the decoded item."
  [results table key-name]
  (let [;; The key column is written as (write-key key-name) = pr-str of
        ;; (native-key key-name), so that is the name it decodes back to.
        key-column (native-key key-name)
        ;; NOTE(review): faraday may keywordize table names in the response;
        ;; fall back to the keyword form when `table` is a string. Confirm.
        items (or (get results table) (get results (keyword table)))]
    ;; Primary keys are unique within a table, so each key appears at most once.
    (reduce (fn [acc raw]
              (let [decoded (read-value raw)]
                (assoc acc
                  (read-key (get decoded key-column))
                  (read-item decoded key-name))))
            {}
            items)))

(defn find-item-mapping
  "Looks up the given coll of `keys` and returns a map from key to item.
  If an item isn't found, its key is missing from the result.
  Each item is read consistently. Returns nil when `keys` is empty.

  DynamoDB's BatchGetItem rejects requests containing duplicate keys or more
  than 100 keys. So keys are de-duplicated (after encoding) and requested in
  batches of at most 100, and the per-batch results are merged."
  [table-client keys]
  (when (seq keys)
    (let [{:keys [client-opts table key-name]} table-client
          key-column (write-key key-name)
          fetch-batch (fn [encoded-keys]
                        (-> (far/batch-get-item client-opts {table {:prim-kvs    {key-column (vec encoded-keys)}
                                                                    :consistent? true}})
                            (batch-result-mapping table key-name)))]
      (->> (distinct (map write-key keys))
           (partition-all 100)
           (map fetch-batch)
           (reduce merge {})))))

(defn find-items
  "Looks up the items given by `keys`.
  Returns a seq of items in the order of `keys`; keys that could not be found
  are omitted.

  Each item is read consistently."
  [table-client keys]
  (let [position (zipmap keys (range))
        found (find-item-mapping table-client keys)]
    (map val (sort-by #(position (key %)) found))))

(defn find-inconsistent-item
  "Like `find-item`, but uses an eventually consistent read."
  [table-client key]
  (let [{:keys [client-opts table key-name]} table-client
        prim-key {(write-key key-name) (write-key key)}
        raw (far/get-item client-opts table prim-key {:consistent? false})]
    (read-item (read-value raw) key-name)))

(defn find-inconsistent-item-mapping
  "Eventually consistent version of `find-item-mapping`: returns a map from key
  to item, omitting keys that weren't found. Returns nil when `keys` is empty.

  Keys are de-duplicated (after encoding) and requested in batches of at most
  100, because DynamoDB's BatchGetItem rejects duplicate keys and more than 100
  keys per request."
  [table-client keys]
  (when (seq keys)
    (let [{:keys [client-opts table key-name]} table-client
          key-column (write-key key-name)
          fetch-batch (fn [encoded-keys]
                        (-> (far/batch-get-item client-opts {table {:prim-kvs    {key-column (vec encoded-keys)}
                                                                    :consistent? false}})
                            (batch-result-mapping table key-name)))]
      (->> (distinct (map write-key keys))
           (partition-all 100)
           (map fetch-batch)
           (reduce merge {})))))

(defn find-inconsistent-items
  "Eventually consistent version of `find-items`: returns the found items in
  the order of `keys`, omitting keys that could not be found."
  [table-client keys]
  (let [position (zipmap keys (range))
        found (find-inconsistent-item-mapping table-client keys)]
    (map val (sort-by #(position (key %)) found))))

(def ^{:dynamic true
       :doc     "The default read throughput for tables created via `create-table!` or `ensure-table!`"}
  *default-read-throughput*
  8)

(def ^{:dynamic true
       :doc     "The default write throughput for tables created via `create-table!` or `ensure-table!`"}
  *default-write-throughput*
  8)

(defn create-table!
  "Creates `table` with a string (:s) hash key column named by the encoded
  `key-name`, as the encoding scheme requires. Throughput defaults to
  `*default-read-throughput*` / `*default-write-throughput*`."
  ([table-client]
   (create-table! table-client *default-read-throughput* *default-write-throughput*))
  ([table-client read-throughput write-throughput]
   (let [{:keys [client-opts table key-name]} table-client
         hash-key [(write-key key-name) :s]
         opts {:throughput {:read read-throughput :write write-throughput}}]
     (far/create-table client-opts table hash-key opts))))
(defn ensure-table!
  "Like `create-table!`, but via `far/ensure-table`, which does nothing if the
  table already exists. Throughput defaults to `*default-read-throughput*` /
  `*default-write-throughput*`."
  ([table-client]
   (ensure-table! table-client *default-read-throughput* *default-write-throughput*))
  ([table-client read-throughput write-throughput]
   (let [{:keys [client-opts table key-name]} table-client
         hash-key [(write-key key-name) :s]
         opts {:throughput {:read read-throughput :write write-throughput}}]
     (far/ensure-table client-opts table hash-key opts))))

(defn put-item!
  "Stores `value` under `key` (unconditional `far/put-item`)."
  [table-client key value]
  (let [{:keys [client-opts table key-name]} table-client
        item (add-key (write-value (prepare-value value key-name)) key-name key)]
    (far/put-item client-opts table item)))

(defn put-items!
  "Stores each value in `kvs` under its corresponding key.
  `kvs` should be a map or a seq of key value pairs. Items are sent via
  `fixed-batch-write-item` in batches of at most 25."
  [table-client kvs]
  (let [{:keys [client-opts table key-name]} table-client
        encode (fn [[k v]]
                 (-> (prepare-value v key-name)
                     write-value
                     (add-key key-name k)))]
    (doseq [batch (partition-all 25 kvs)]
      (fixed-batch-write-item client-opts {table {:put (mapv encode batch)}}))))

(defn cas-put-item!
  "Conditionally writes `value` under `key`, stamping it with a new version.
  - With a non-nil `version`, the write succeeds only if the stored version
    attribute equals `version`; the item is written with `(inc version)`.
  - With a nil `version`, the write succeeds only if the stored item has no
    version attribute (e.g. nothing is stored under `key`, or the item was written
    by `put-item!`); the item is written with version 0.
  Returns true if the write happened, false if the condition check failed."
  [table-client key value version]
  (let [{:keys [client-opts table key-name]} table-client
        version-key (write-key ::version)
        next-version (if version (inc version) 0)
        expectation (if version [:eq (write-value version)] :not-exists)
        item (-> (prepare-value value key-name)
                 write-value
                 (add-key key-name key)
                 (assoc version-key (write-value next-version)))]
    (try
      (far/put-item client-opts table item {:expected {version-key expectation}})
      true
      (catch ConditionalCheckFailedException _
        false))))

(def ^{:dynamic true
       :doc     "How long (ms) to wait after CAS contention before trying the
  operation again, by default."}
  *default-cas-sleep-ms*
  500)

(defn swap-item!*
  "Compare-and-swap loop. Reads the item under `key` consistently, applies `f`
  to its value, and tries to store the result with `cas-put-item!`, retrying on
  contention. Returns the stored result.

  - `sleep-ms`    ms to sleep after each failed attempt (nil = don't sleep).
  - `timeout-ms`  wall-clock budget in ms measured from the call (nil = retry
                  forever). Once it is used up (immediately if <= 0), no further
                  attempt is made and `timeout-val` is returned.
  - `timeout-val` value returned on timeout."
  [table-client key f sleep-ms timeout-ms timeout-val]
  (let [{:keys [key-name]} table-client
        ;; A deadline (rather than decrementing by sleep-ms) keeps working when
        ;; sleep-ms is nil or 0, and also accounts for time spent in dynamo calls.
        deadline (when timeout-ms
                   (+ (System/nanoTime) (long (* 1e6 timeout-ms))))]
    (loop []
      (if (and deadline (<= (- deadline (System/nanoTime)) 0))
        timeout-val
        (let [value (find-raw-item table-client key)
              version (::version value)
              item (f (read-item value key-name))]
          (if (cas-put-item! table-client key item version)
            item
            (do (when sleep-ms (Thread/sleep (long sleep-ms)))
                (recur))))))))

(defn swap-item!
  "Applies the function `f` and any `args` to the value currently under `key`,
  storing the result with a compare-and-swap loop. Returns the result.

  Retry behaviour is read from optional keys on `table-client`:
  - `:cas-sleep-ms`    pause between attempts. Defaults to `*default-cas-sleep-ms*`
                       when the key is absent; an explicit nil disables sleeping.
  - `:cas-timeout-ms`  give-up budget in ms, matching the `multi-atom` option name.
                       The older `:cas-timeout` key is used if `:cas-timeout-ms`
                       is absent. nil = retry forever.
  - `:cas-timeout-val` value returned on timeout."
  ([table-client key f]
   (swap-item!* table-client key f
                (get table-client :cas-sleep-ms *default-cas-sleep-ms*)
                (if (contains? table-client :cas-timeout-ms)
                  (:cas-timeout-ms table-client)
                  (:cas-timeout table-client))
                (:cas-timeout-val table-client)))
  ([table-client key f & args]
   (swap-item! table-client key #(apply f % args))))

;; Multi-atom whose cells are items in the table behind `table-client`.
;; Swaps use `swap-item!*` with the record's CAS sleep/timeout settings.
(defrecord DynamoMultiAtom [table-client cas-sleep-ms cas-timeout-ms cas-timeout-val]
  atom/IMultiDeref
  (-deref-at [_ key not-found]
    (if-some [raw (find-raw-item table-client key)]
      (read-item raw (:key-name table-client))
      not-found))
  atom/IMultiAtom
  (-swap-at! [_ key f]
    (swap-item!* table-client key f cas-sleep-ms cas-timeout-ms cas-timeout-val))
  atom/IDeleteCell
  (delete-at! [_ key]
    (delete-item! table-client key)))

(defn multi-atom
  "Creates a multi-atom, where each cell is an item in the table given by
  `table-client`.
  `cas-sleep-ms` defaults to `*default-cas-sleep-ms*`. `cas-timeout-ms` and
  `cas-timeout-val` default to nil, which means no timeout."
  ([table-client]
   (multi-atom table-client *default-cas-sleep-ms*))
  ([table-client cas-sleep-ms]
   (multi-atom table-client cas-sleep-ms nil nil))
  ([table-client cas-sleep-ms cas-timeout-ms cas-timeout-val]
   (->DynamoMultiAtom table-client cas-sleep-ms cas-timeout-ms cas-timeout-val)))

