(ns moncee.sql
  "SQL/TimescaleDB support for moncee - lazy sequence interface over JDBC"
  (:require [next.jdbc :as jdbc]
            [next.jdbc.result-set :as rs]
            [next.jdbc.prepare :as prepare]
            [next.jdbc.date-time]  ;; auto-convert java.time types
            [clojure.string :as str]
            [taoensso.timbre :as log])
  (:import [java.sql PreparedStatement ResultSet Timestamp]
           [java.time Instant LocalDateTime OffsetDateTime]))

;; Dynamic bindings for default datasource and database type
(defonce ^{:dynamic true
           :doc "Default datasource used by the single-arity sql-table and the raw SQL helpers; nil until set."}
  *ds* nil)
(defonce ^{:dynamic true
           :doc "Database flavour: :oracle selects Oracle SQL syntax, any other value is treated as PostgreSQL."}
  *dbtype* :postgresql)

(defn set-datasource!
  "Replace the root binding of the default datasource `*ds*`; the two-arity
   form additionally replaces the root binding of `*dbtype*`.
   Returns the last value installed."
  ([ds]
   (alter-var-root #'*ds* (fn [_] ds)))
  ([ds dbtype]
   (set-datasource! ds)
   (alter-var-root #'*dbtype* (fn [_] dbtype))))

(defn datasource
  "Create a next.jdbc datasource from connection options and record the
   database flavour in the root binding of `*dbtype*` (used by the SQL
   dialect helpers in this namespace).

   Options:
     :dbtype   - :postgresql (default) or :oracle; strings such as \"oracle\"
                 are accepted. Any other value (including nil) is treated
                 as :postgresql, both for the connection and for `*dbtype*`.
     :host     - defaults to \"localhost\"
     :port     - defaults to 1521 for Oracle, 5432 otherwise
     :dbname   - database name. For Oracle it is handed to next.jdbc's
                 \"oracle\" dbtype, which builds a SID-style URL
                 (host:port:dbname). NOTE(review): older docs here claimed
                 a service name is used; if the target needs a service name,
                 next.jdbc's \"oracle:service\" dbtype is the variant to use.
     :user, :password

   Examples:
     (datasource {:dbtype :postgresql :host \"localhost\" :dbname \"mydb\"})
     (datasource {:dbtype :oracle :host \"dbhost\" :dbname \"ORCL\" :user \"scott\" :password \"tiger\"})"
  [{:keys [dbtype host port dbname user password]
    :or {dbtype :postgresql host "localhost"}}]
  (let [;; Normalize so *dbtype* always matches the connection actually built.
        dbtype (if (= :oracle (keyword (or dbtype :postgresql))) :oracle :postgresql)
        port (or port (if (= dbtype :oracle) 1521 5432))]
    (alter-var-root #'*dbtype* (constantly dbtype))
    (jdbc/get-datasource
     {:dbtype (name dbtype)   ;; "oracle" or "postgresql"
      :host host
      :port port
      :dbname dbname
      :user user
      :password password})))

;; SQL operators mapping
(def sql-ops
  "Clojure comparison functions (keys are the fn values themselves) mapped to SQL operator text."
  (zipmap [< > <= >= not= =]
          ["<" ">" "<=" ">=" "<>" "="]))

(defprotocol SqlOperations
  "Query-building and write operations over an accumulated list of query
   operations. Descriptions of return values refer to how SqlContext in this
   namespace implements them."
  (-limit [this] "Effective LIMIT value; 0 when no limit was recorded.")
  (-offset [this] "Effective OFFSET value; 0 when no offset was recorded.")
  (-where [this] "Seq of [field value] conditions, rendered joined with AND.")
  (-select [this] "Merged projection map of field -> include flag.")
  (-order-by [this] "Merged ordering map of field -> direction (1 renders ASC, anything else DESC).")
  (-build-sql [this] "Render the SELECT statement as a [sql params] pair.")
  (-insert [this docs] "Batch-insert a seq of row maps into the table.")
  (-delete [this] "Delete rows matching the WHERE conditions; returns the update count.")
  (-update [this m] "Set the columns in map `m` on matching rows; returns the update count."))

;; Lazy result set wrapper
(deftype LazyResultSet [^ResultSet rs opts]
  ;; Seqable view over an already-executed ResultSet.
  ;; - Each row is materialized with the :builder-fn in `opts` (next.jdbc's
  ;;   rs/as-maps when absent).
  ;; - The row is then passed through rs/datafiable-row with
  ;;   (:connectable opts) so datafy/nav metadata is attached.
  ;; NOTE(review): every call to `seq` advances the shared cursor, so the
  ;; rows can only be walked once.
  clojure.lang.Seqable
  (seq [this]
    (let [builder ((get opts :builder-fn rs/as-maps) rs opts)
          column-total (rs/column-count builder)
          read-row (fn []
                     ;; RowBuilder protocol: start a row, add columns 1..n, finish it.
                     (rs/row! builder
                              (reduce (fn [row i] (rs/with-column builder row i))
                                      (rs/->row builder)
                                      (range 1 (inc column-total)))))]
      (letfn [(step []
                (when (.next rs)
                  ;; datafiable-row expects a materialized row, not the raw ResultSet.
                  (cons (rs/datafiable-row (read-row) (:connectable opts) opts)
                        (lazy-seq (step)))))]
        (step))))

  java.io.Closeable
  (close [this]
    (.close rs)))

(defn- condition->sql
  "Convert one [field value] condition pair into a [sql-fragment params] pair.

   Supported value shapes:
     {:f {> 100}}     -> \"f > ?\"               [100]
     {:f {> 1 < 9}}   -> \"(f > ? AND f < ?)\"   [1 9]   (every operator kept)
     {:f {= nil}}     -> \"f IS NULL\"           []
     {:f {not= nil}}  -> \"f IS NOT NULL\"       []
     {:f {}}          -> \"1 = 1\"               []      (no constraint)
     {:f nil}         -> \"f IS NULL\"           []
     {:f [1 2 3]}     -> \"f IN (?, ?, ?)\"      [1 2 3]
     {:f []}          -> \"1 = 0\"               []      (matches nothing)
     {:f \"x\"}       -> \"f = ?\"               [\"x\"]

   Operators are looked up in `sql-ops`. Keywords and symbols not found
   there are rendered by name (:like -> \"like\"); anything else via `str`."
  [[field value]]
  (let [field-name (name field)
        op->sql (fn [op]
                  (or (get sql-ops op)
                      (if (instance? clojure.lang.Named op) (name op) (str op))))
        op-fragment (fn [[op v]]
                      (cond
                        ;; "= NULL" / "<> NULL" never match in SQL; use IS [NOT] NULL.
                        (and (nil? v) (= op =)) [(str field-name " IS NULL") []]
                        (and (nil? v) (= op not=)) [(str field-name " IS NOT NULL") []]
                        :else [(str field-name " " (op->sql op) " ?") [v]]))]
    (cond
      (map? value)
      (let [parts (map op-fragment value)]
        (case (count parts)
          0 ["1 = 1" []]
          1 (first parts)
          [(str "(" (str/join " AND " (map first parts)) ")")
           (vec (mapcat second parts))]))

      (nil? value)
      [(str field-name " IS NULL") []]

      (sequential? value)
      (if (empty? value)
        ;; "IN ()" is a syntax error; an empty list matches no rows.
        ["1 = 0" []]
        [(str field-name " IN (" (str/join ", " (repeat (count value) "?")) ")") (vec value)])

      :else
      [(str field-name " = ?") [value]])))

(defn- build-where
  "Combine [field value] conditions into a [\"WHERE ...\" params] pair,
   joining the individual fragments with AND. With no conditions the
   result is [\"\" []] so callers can splice it in unconditionally."
  [conditions]
  (if (seq conditions)
    (let [fragments (mapv condition->sql conditions)]
      [(->> fragments (map first) (str/join " AND ") (str "WHERE "))
       (into [] (mapcat second) fragments)])
    ["" []]))

(defn- build-select
  "Build the SELECT column list from a projection map of field -> flag.
   Fields with a truthy flag are selected, in map iteration order.
   Returns \"*\" when the projection is empty or selects nothing (every flag
   false/nil); an empty column list would render invalid SQL."
  [projection]
  (let [selected (keep (fn [[field include?]]
                         (when include? (name field)))
                       projection)]
    (if (seq selected)
      (str/join ", " selected)
      "*")))

(defn- build-order-by
  "Render an ordering map (field -> direction) as an ORDER BY clause.
   A direction equal to 1 renders ASC; every other value renders DESC.
   An empty ordering yields the empty string."
  [ordering]
  (if-not (seq ordering)
    ""
    (let [term (fn [[field dir]]
                 (str (name field) (if (= 1 dir) " ASC" " DESC")))]
      (str "ORDER BY " (str/join ", " (map term ordering))))))

(declare sql-context)

(defn- read-row
  "Materialize the row that the ResultSet behind `builder` is currently
   positioned on, via next.jdbc's RowBuilder protocol (columns are 1-indexed)."
  [builder]
  (rs/row! builder
           (reduce (fn [row i] (rs/with-column builder row i))
                   (rs/->row builder)
                   (range 1 (inc (rs/column-count builder))))))

(deftype SqlContext [operations ^javax.sql.DataSource datasource table-name]
  ;; Map-like access to the three fields so `get`/`update`/`assoc` work.
  clojure.lang.ILookup
  (valAt [this key]
    (condp = key
      :operations operations
      :datasource datasource
      :table table-name
      nil))
  (valAt [this key not-found]
    ;; Fall back only for unknown keys; a known field holding nil returns nil.
    (if (.containsKey this key)
      (.valAt this key)
      not-found))

  clojure.lang.Associative
  (containsKey [this key]
    (contains? #{:operations :datasource :table} key))
  (entryAt [this key]
    ;; Associative contract: an IMapEntry (or nil), not the bare value.
    (when (.containsKey this key)
      (clojure.lang.MapEntry. key (.valAt this key))))
  (assoc [this key val]
    ;; Unknown keys are ignored and return this context unchanged.
    (condp = key
      :operations (SqlContext. val datasource table-name)
      :datasource (SqlContext. operations val table-name)
      :table (SqlContext. operations datasource val)
      this))

  clojure.lang.Counted
  (count [this]
    ;; Counts rows matching the WHERE conditions; limit/offset are not applied.
    (let [[where-clause params] (build-where (-where this))
          sql (str "SELECT COUNT(*) AS cnt FROM " table-name " " where-clause)]
      (log/trace "COUNT SQL:" sql "params:" params)
      ;; Unqualified lower-case keys so :cnt is found on Oracle too, which
      ;; reports the column label as upper-case CNT.
      (:cnt (jdbc/execute-one! datasource (into [sql] params)
                               {:builder-fn rs/as-unqualified-lower-maps}))))

  clojure.lang.Seqable
  (seq [this]
    ;; Streams rows lazily over a dedicated connection. The ResultSet,
    ;; statement and connection are closed once the seq is fully realized,
    ;; or when opening/reading fails.
    ;; NOTE(review): a partially consumed seq leaves them open.
    (let [[sql params] (-build-sql this)]
      (log/trace "SQL:" sql "params:" params)
      (let [conn (jdbc/get-connection datasource)]
        (try
          (let [stmt (jdbc/prepare conn (into [sql] params)
                                   {:fetch-size 1000
                                    :result-type :forward-only
                                    :concurrency :read-only})
                result (.executeQuery ^PreparedStatement stmt)
                builder (rs/as-unqualified-kebab-maps result {})
                row-opts {:builder-fn rs/as-unqualified-kebab-maps}
                close! (fn []
                         (.close ^ResultSet result)
                         (.close ^PreparedStatement stmt)
                         (.close ^java.sql.Connection conn))]
            (letfn [(lazy-rows []
                      (try
                        (if (.next ^ResultSet result)
                          ;; Build the row map first; datafiable-row only labels it.
                          ;; The datasource (not conn) is used for nav, since conn is
                          ;; closed once the seq is exhausted.
                          (cons (rs/datafiable-row (read-row builder) datasource row-opts)
                                (lazy-seq (lazy-rows)))
                          (do (close!) nil))
                        (catch Throwable t
                          (close!)
                          (throw t))))]
              (lazy-rows)))
          (catch Throwable t
            ;; Setup failed before the seq took ownership of the connection.
            (.close ^java.sql.Connection conn)
            (throw t))))))

  clojure.lang.ISeq
  (first [this]
    (let [[sql params] (-build-sql (update this :operations conj {:op :limit :v 1}))]
      (log/trace "FIRST SQL:" sql "params:" params)
      (jdbc/execute-one! datasource (into [sql] params)
                         {:builder-fn rs/as-unqualified-kebab-maps})))
  (more [this]
    ;; Offsets resolve to the max recorded value (-offset), so record
    ;; current + 1 to actually advance on repeated rest/next calls.
    ;; NOTE(review): a recorded LIMIT is not decremented here.
    (update this :operations conj {:op :offset :v (inc (-offset this))}))
  (cons [this o]
    ;; Adding elements is not supported; the context is returned unchanged.
    this)
  (next [this]
    (seq (.more this)))

  SqlOperations
  (-limit [this]
    ;; Largest recorded limit wins; 0 means "no limit".
    (or (last (sort (map :v (filter #(= (:op %) :limit) operations)))) 0))

  (-offset [this]
    ;; Largest recorded offset wins; 0 means "no offset".
    (apply max (cons 0 (map :v (filter #(= (:op %) :offset) operations)))))

  (-where [this]
    (->> operations
         (filter #(= (:op %) :where))
         (mapcat :v)))

  (-select [this]
    (->> operations
         (filter #(= (:op %) :select))
         (map :v)
         (reduce merge {})))

  (-order-by [this]
    (->> operations
         (filter #(= (:op %) :order))
         (map :v)
         (reduce merge {})))

  (-build-sql [this]
    (let [select-clause (build-select (-select this))
          [where-clause where-params] (build-where (-where this))
          order-clause (build-order-by (-order-by this))
          limit-val (-limit this)
          offset-val (-offset this)
          ;; Oracle 12c+ uses OFFSET .. ROWS / FETCH FIRST, PostgreSQL uses LIMIT/OFFSET
          sql (str "SELECT " select-clause
                   " FROM " table-name
                   (when (seq where-clause) (str " " where-clause))
                   (when (seq order-clause) (str " " order-clause))
                   (if (= *dbtype* :oracle)
                     ;; Oracle syntax
                     (str (when (> offset-val 0) (str " OFFSET " offset-val " ROWS"))
                          (when (> limit-val 0) (str " FETCH FIRST " limit-val " ROWS ONLY")))
                     ;; PostgreSQL syntax
                     (str (when (> limit-val 0) (str " LIMIT " limit-val))
                          (when (> offset-val 0) (str " OFFSET " offset-val)))))]
      [sql where-params]))

  (-insert [this docs]
    ;; Column list comes from the first doc's keys; other docs are read with
    ;; those same keys (missing keys become NULL, extra keys are ignored).
    (when (seq docs)
      (let [ks (keys (first docs))
            cols (str/join ", " (map name ks))
            placeholders (str/join ", " (repeat (count ks) "?"))
            sql (str "INSERT INTO " table-name " (" cols ") VALUES (" placeholders ")")]
        (jdbc/execute-batch! datasource sql (map (fn [doc] (map #(get doc %) ks)) docs) {}))))

  (-delete [this]
    ;; With no WHERE conditions this deletes every row in the table.
    (let [[where-clause params] (build-where (-where this))
          sql (str "DELETE FROM " table-name " " where-clause)]
      (log/trace "DELETE SQL:" sql "params:" params)
      (:next.jdbc/update-count
       (jdbc/execute-one! datasource (into [sql] params)))))

  (-update [this m]
    (if (empty? m)
      ;; Nothing to SET: avoid emitting invalid "UPDATE t SET  ..." SQL.
      0
      (let [[where-clause where-params] (build-where (-where this))
            set-clause (str/join ", " (map (fn [[k _]] (str (name k) " = ?")) m))
            ;; (vals m) iterates in the same order as the set-clause above.
            set-params (vals m)
            sql (str "UPDATE " table-name " SET " set-clause " " where-clause)]
        (log/trace "UPDATE SQL:" sql "params:" (concat set-params where-params))
        (:next.jdbc/update-count
         (jdbc/execute-one! datasource (into [sql] (concat set-params where-params))))))))

;; Constructor
(defn sql-table
  "Create a SqlContext for `table-name` with no recorded operations.
   The single-arity form uses the default datasource `*ds*`."
  ([table-name]
   (sql-table *ds* table-name))
  ([datasource table-name]
   (let [table (name table-name)
         no-ops []]
     (SqlContext. no-ops datasource table))))

;; Query operations (same interface as moncee.core)

(defn restrict
  "Record WHERE conditions. Takes alternating field/value arguments followed
   by the source context, e.g. (restrict :id 1 :name \"x\" ctx).
   A trailing field without a value is dropped."
  [& params-and-source]
  (let [pairs (map vec (partition 2 (butlast params-and-source)))
        where-op {:op :where :v pairs}]
    (update (last params-and-source) :operations conj where-op)))

(defn project
  "Record a projection. Either a single vector of fields, all selected,
   e.g. (project [:a :b] ctx), or alternating field/flag arguments,
   e.g. (project :a true :b false ctx). The source context is the last argument."
  [& fields-and-source]
  (let [fields (butlast fields-and-source)
        single-vector? (and (= 1 (count fields)) (vector? (first fields)))
        projection (if single-vector?
                     (zipmap (first fields) (repeat true))
                     (into {} (map vec) (partition 2 fields)))]
    (update (last fields-and-source) :operations conj {:op :select :v projection})))

(defn order
  "Record an ORDER BY term. :asc becomes 1 and :desc becomes -1; any other
   direction value is stored as given."
  [field direction source]
  (let [dir (case direction
              :asc 1
              :desc -1
              direction)]
    (update source :operations conj {:op :order :v {field dir}})))

(defn limit
  "Record a LIMIT of `n` rows on `source`."
  [n source]
  (let [op {:op :limit :v n}]
    (update source :operations #(conj % op))))

(defn skip
  "Record an OFFSET of `n` rows on `source`."
  [n source]
  (let [op {:op :offset :v n}]
    (update source :operations #(conj % op))))

(defn fetch
  "Return the first row of `source` (delegates to `first`)."
  [source]
  (-> source first))

(defn query
  "Run the query by taking `seq` of `source`; nil when there are no rows."
  [source]
  (-> source seq))

(defn insert!
  "Insert the given row maps into `source`'s table."
  [source & docs]
  (->> docs (-insert source)))

(defn delete!
  "Delete rows. With only a source, its recorded conditions are used.
   Otherwise takes alternating field/value pairs followed by the source,
   e.g. (delete! :id 1 :kind \"x\" ctx); each pair is added with `restrict`."
  ([source]
   (-delete source))
  ([k v & kvs-and-source]
   (let [source (last kvs-and-source)
         more-pairs (partition 2 (butlast kvs-and-source))
         add-condition (fn [ctx [field value]]
                         (restrict field value ctx))]
     (-> (reduce add-condition (restrict k v source) more-pairs)
         delete!))))

(defn update!
  "Update matching rows. Takes alternating column/value arguments followed
   by the source, e.g. (update! :status \"done\" ctx)."
  [& ops-and-source]
  (let [pairs (partition 2 (butlast ops-and-source))
        changes (into {} (map vec) pairs)]
    (-update (last ops-and-source) changes)))

;; Aggregation support

(defn count*
  "Row count of `source` (delegates to `count`)."
  [source]
  (-> source count))

(defn- aggregate
  "Run a single-column aggregate over `source`'s WHERE conditions,
   SELECT <sql-fn>(<field>) AS <alias> FROM <table> <where>, and return
   the value stored under `alias`."
  [sql-fn alias field source]
  (let [[where-clause params] (build-where (-where source))
        sql (str "SELECT " sql-fn "(" (name field) ") AS " alias
                 " FROM " (:table source) " " where-clause)
        ;; Unqualified lower-case keys: Oracle reports the label upper-case
        ;; (e.g. TOTAL), which a plain :total lookup would miss.
        row (jdbc/execute-one! (:datasource source) (into [sql] params)
                               {:builder-fn rs/as-unqualified-lower-maps})]
    (get row (keyword alias))))

(defn sum
  "Sum a field"
  [field source]
  (aggregate "SUM" "total" field source))

(defn avg
  "Average a field"
  [field source]
  (aggregate "AVG" "average" field source))

(defn min*
  "Min of a field"
  [field source]
  (aggregate "MIN" "minimum" field source))

(defn max*
  "Max of a field"
  [field source]
  (aggregate "MAX" "maximum" field source))

;; Raw SQL execution

(defn sql!
  "Run a raw SQL query against the default datasource `*ds*` and return
   every row as an unqualified kebab-case map (in a vector).

   Examples:
     (sql! \"SELECT * FROM users WHERE id = ?\" 1)
     (sql! \"SELECT * FROM dual\")"
  [sql & params]
  (let [sqlvec (vec (cons sql params))
        opts {:builder-fn rs/as-unqualified-kebab-maps}]
    (jdbc/execute! *ds* sqlvec opts)))

(defn sql-one!
  "Run a raw SQL query against `*ds*` and return only the first row as an
   unqualified kebab-case map."
  [sql & params]
  (let [sqlvec (vec (cons sql params))
        opts {:builder-fn rs/as-unqualified-kebab-maps}]
    (jdbc/execute-one! *ds* sqlvec opts)))

(defn exec!
  "Run a raw SQL statement (INSERT, UPDATE, DELETE, DDL) against `*ds*`
   with next.jdbc's default options, returning whatever `execute-one!`
   yields (typically the update count)."
  [sql & params]
  (->> params
       (cons sql)
       vec
       (jdbc/execute-one! *ds*)))

(defn tables
  "List table names in the current schema: Oracle's user_tables, or the
   'public' schema in PostgreSQL. Returns a vector of names."
  []
  (let [[query-sql column]
        (if (= *dbtype* :oracle)
          ["SELECT table_name FROM user_tables ORDER BY table_name" :table-name]
          ["SELECT tablename FROM pg_tables WHERE schemaname = 'public' ORDER BY tablename" :tablename])]
    (mapv column (sql! query-sql))))

(defn describe
  "Return the column metadata rows for `table-name` from the database
   dictionary (user_tab_columns on Oracle, information_schema.columns in
   the PostgreSQL 'public' schema), in column order."
  [table-name]
  (if (= *dbtype* :oracle)
    ;; The table name is upper-cased for the Oracle lookup.
    (sql! "SELECT column_name, data_type, nullable, data_length, data_precision
           FROM user_tab_columns WHERE table_name = ? ORDER BY column_id"
          (-> table-name name str/upper-case))
    (sql! "SELECT column_name, data_type, is_nullable, character_maximum_length
           FROM information_schema.columns
           WHERE table_name = ? AND table_schema = 'public'
           ORDER BY ordinal_position"
          (-> table-name name))))

(defn connected?
  "Probe the default datasource with a trivial query; true when it runs,
   false if it throws."
  []
  (let [probe (if (= *dbtype* :oracle) "SELECT 1 FROM dual" "SELECT 1")]
    (try
      (sql-one! probe)
      true
      (catch Exception _ false))))
