(ns vlaaad.reveal.pro.sql.connection
  (:require [clojure.spec.alpha :as s]
            [clojure.string :as str]
            [vlaaad.reveal.stream :as stream])
  (:import [java.sql Connection ResultSet SQLException Statement]
           [clojure.lang ILookup Associative MapEntry IReduceInit]))

;; --- Specs describing the schema map produced by `query-schema!` ---

;; A table identifier: a vector of name segments (query-schema! builds it from
;; the non-nil TABLE_SCHEM / TABLE_NAME values, so it is [schema name] or [name]).
(s/def ::table (s/coll-of string? :kind vector?))
;; A column name.
(s/def ::column string?)
;; An ordered vector of column names.
(s/def ::columns (s/coll-of ::column :kind vector?))
;; Primary key columns.
(s/def ::pk ::columns)
;; Columns on the source side of a reference.
(s/def ::src ::columns)
;; Columns on the destination side of a reference.
(s/def ::dst ::columns)
;; A reference to another table: the target :table plus the column vectors
;; (:src on the referencing side, :dst on the target side) that are matched up.
(s/def ::ref (s/keys :req-un [::table ::src ::dst]))
(s/def ::refs (s/coll-of ::ref :kind vector?))
;; Foreign keys declared on a table.
(s/def ::fks ::refs)
;; All references usable for joining from a table (query-schema! fills this
;; with the table's foreign keys followed by the reversed foreign keys of
;; other tables that point at it).
(s/def ::joins ::refs)
;; Table identifier -> description of that table.
(s/def ::tables (s/map-of ::table (s/keys :req-un [::columns ::pk ::fks ::joins])))
;; JDBC URL reported by the database metadata.
(s/def ::url string?)
;; Catalog name of the connection, when the connection reports one.
(s/def ::catalog string?)
;; Identifier quote string reported by the database metadata.
(s/def ::quote string?)
;; Whether LIMIT was found among the SQL keywords reported by the database.
(s/def ::can-limit boolean?)
;; The complete schema description.
(s/def ::schema (s/keys :req-un [::url ::quote ::tables ::can-limit]
                        :opt-un [::catalog]))

;; --- Specs describing a query pattern: [root-options node] ---

;; Raw SQL condition text attached to a selected column; parse-pattern appends
;; it after the quoted column identifier in the WHERE clause.
(s/def ::where string?)
;; Selected columns of a pattern node: column name -> per-column options.
;; This must be a map-of. The functions in this namespace build and read it as
;; {"column" {:where "..."}}, and a plain (s/keys :opt-un [::where]) would
;; accept any map with string keys without checking the per-column options.
(s/def :vlaaad.reveal.pro.sql.pattern/columns
  (s/map-of ::column (s/keys :opt-un [::where])))
;; Joins followed from a node: reference -> nested node for the joined table.
(s/def :vlaaad.reveal.pro.sql.pattern/joins (s/map-of ::ref :vlaaad.reveal.pro.sql.pattern/node))
;; A node of the pattern tree: selected columns and/or nested joins.
(s/def :vlaaad.reveal.pro.sql.pattern/node
  (s/keys :opt-un [:vlaaad.reveal.pro.sql.pattern/columns
                   :vlaaad.reveal.pro.sql.pattern/joins]))

;; Maximum number of rows requested for the root table.
(s/def ::limit pos-int?)
;; A whole pattern: root options (table and limit) paired with the root node.
(s/def ::pattern
  (s/tuple (s/keys :req-un [::table ::limit])
           :vlaaad.reveal.pro.sql.pattern/node))

;; Obtains a connection by calling the supplied zero-argument function.
;; The ^Connection hint on the var lets callers use the result without
;; reflection; nothing here checks that conn-fn really returns a Connection.
(defn ^Connection connect [conn-fn]
  (conn-fn))

(defn- reducible-result-set
  "Wraps a JDBC ResultSet in a single-use reducible (IReduceInit).

  The reducing function receives, for every row, the same lightweight accessor
  object positioned on the current row. The accessor supports keyword/string
  lookup (`get`, keyword invocation) and `find`/`select-keys`; the key's `name`
  is used as the column label. A column whose lookup throws SQLException is
  reported as missing (not-found / no entry). Because the accessor is reused
  and backed by the live cursor, values must be copied out of it inside the
  reducing step (e.g. with `select-keys`), never retained.

  The result set is closed once the reduction finishes, whether it ran to the
  end, stopped early via `reduced`, or threw. Therefore the returned value can
  only be reduced once."
  [^ResultSet rs]
  (let [accessor (reify
                   ILookup
                   (valAt [this k]
                     (.valAt this k nil))
                   (valAt [_ k not-found]
                     (try
                       (.getObject rs (name k))
                       ;; unknown column label -> behave like a missing key
                       (catch SQLException _ not-found)))
                   Associative
                   (entryAt [this k]
                     ;; sentinel distinguishes "missing column" from SQL NULL
                     (let [x (.valAt this k ::not-found)]
                       (when-not (= ::not-found x)
                         (MapEntry/create k x)))))]
    (reify IReduceInit
      (reduce [_ f init]
        (try
          (loop [acc init]
            (if (.next rs)
              (let [ret (f acc accessor)]
                (if (reduced? ret)
                  @ret
                  (recur ret)))
              acc))
          (finally
            ;; release the cursor eagerly instead of waiting for the owning
            ;; statement/connection to be closed
            (.close rs)))))))

(defn query-schema!
  "Opens a connection with `conn-fn`, reads the JDBC metadata of the
  connection's current catalog and returns a schema map:

    :url       - URL reported by the metadata
    :quote     - identifier quote string reported by the metadata
    :can-limit - true when LIMIT (or limit) is one of the reported SQL keywords
    :catalog   - the connection's catalog, only present when non-nil
    :tables    - table identifier -> {:name :columns :pk :fks :joins}

  A table identifier is a vector of the non-nil TABLE_SCHEM / TABLE_NAME values.
  :fks are the table's imported (outgoing) foreign keys; :joins are those plus
  the reversed foreign keys of other tables that reference this table.
  The connection is closed before returning."
  [conn-fn]
  (with-open [c (connect conn-fn)]
    (let [cm (.getMetaData c)
          can-limit (boolean (some #{"LIMIT" "limit"} (str/split (.getSQLKeywords cm) #",")))
          catalog (.getCatalog c)
          ;; only plain tables, in any schema of the current catalog
          tables (into []
                       (map #(select-keys % [:TABLE_SCHEM :TABLE_NAME]))
                       (reducible-result-set
                         (.getTables cm catalog nil nil (into-array String ["TABLE"]))))
          ;; getColumns treats its schema and table arguments as LIKE patterns,
          ;; so a name containing _ or % can also return rows of other tables.
          ;; Those rows are returned again when that other table is queried, so
          ;; de-duplicate to avoid listing the same column twice.
          columns (into []
                        (comp (mapcat
                                #(reducible-result-set
                                   (.getColumns cm catalog (:TABLE_SCHEM %) (:TABLE_NAME %) nil)))
                              (map #(select-keys % [:TABLE_SCHEM :TABLE_NAME :COLUMN_NAME]))
                              (distinct))
                        tables)
          pks (into []
                    (comp
                      (mapcat #(reducible-result-set
                                 (.getPrimaryKeys cm catalog (:TABLE_SCHEM %) (:TABLE_NAME %))))
                      (map #(select-keys % [:TABLE_SCHEM :TABLE_NAME :COLUMN_NAME :KEY_SEQ])))
                    tables)
          fks (into []
                    (comp (mapcat #(reducible-result-set
                                     (.getImportedKeys cm catalog (:TABLE_SCHEM %) (:TABLE_NAME %))))
                          (map #(select-keys % [:FKTABLE_SCHEM :FKTABLE_NAME :FK_NAME :FKCOLUMN_NAME
                                                :PKTABLE_SCHEM :PKTABLE_NAME :PKCOLUMN_NAME :KEY_SEQ])))
                    tables)
          ;; builds a fn that extracts a table identifier from a metadata row,
          ;; dropping nil segments (e.g. a missing schema)
          table-fn (fn [& keys]
                     (fn [x]
                       (vec (keep x keys))))
          table-name->columns (group-by (table-fn :TABLE_SCHEM :TABLE_NAME) columns)
          table-name->pk (group-by (table-fn :TABLE_SCHEM :TABLE_NAME) pks)
          table-name->fks (group-by (table-fn :FKTABLE_SCHEM :FKTABLE_NAME) fks)
          ;; seq of [table-id table-description] without :joins yet
          raw-schema (->> tables
                          (map (table-fn :TABLE_SCHEM :TABLE_NAME))
                          (map
                            (juxt
                              identity
                              (fn [name]
                                {:name name
                                 :columns (->> name table-name->columns (mapv :COLUMN_NAME))
                                 :pk (->> name table-name->pk (sort-by :KEY_SEQ) (mapv :COLUMN_NAME))
                                 ;; one ref per (referenced table, FK name);
                                 ;; multi-column keys are ordered by KEY_SEQ
                                 :fks (->> name
                                           table-name->fks
                                           (group-by (juxt :PKTABLE_SCHEM :PKTABLE_NAME :FK_NAME))
                                           vals
                                           (mapv
                                             (fn [fk]
                                               (let [sorted (sort-by :KEY_SEQ fk)]
                                                 {:table ((table-fn :PKTABLE_SCHEM :PKTABLE_NAME) (first fk))
                                                  :src (mapv :FKCOLUMN_NAME sorted)
                                                  :dst (mapv :PKCOLUMN_NAME sorted)}))))}))))
          ;; referenced table -> refs pointing back at the referencing tables,
          ;; with :src and :dst swapped
          table-name->reverse-fks (->> (for [[fk-table-name {:keys [fks]}] raw-schema
                                             {:keys [table src dst]} fks]
                                         [table {:table fk-table-name
                                                 :src dst
                                                 :dst src}])
                                       (group-by first)
                                       (map (juxt key #(->> % val (mapv second))))
                                       (into {}))]
      (cond->
        {:url (.getURL cm)
         :quote (.getIdentifierQuoteString cm)
         :can-limit can-limit
         :tables (->> raw-schema
                      (map (fn [[k v]]
                             [k (assoc v :joins (into (:fks v) (table-name->reverse-fks k)))]))
                      (into {}))}
        catalog
        (assoc :catalog catalog)))))

;; The unqualified name of a table identifier is its last segment.
(def ^:private get-table-name last)

(defn join-kw
  "Returns the keyword under which rows reached through `join` are nested when
  joining from `src-table`.

  The base name is the target table's segments joined with \"_\". When the
  schema lists several joins from `src-table` to that same target, the name is
  extended with \"_from_<src columns>\" if those joins differ in their source
  columns and with \"_to_<dst columns>\" if they differ in their destination
  columns (columns joined with \"_and_\")."
  [schema src-table join]
  (let [target (:table join)
        candidates (->> (get-in schema [:tables src-table :joins])
                        (filter (fn [candidate] (= target (:table candidate)))))
        ;; true when the candidate joins disagree on the given side
        varies? (fn [side]
                  (> (count (into #{} (map side) candidates)) 1))]
    (keyword
      (cond-> (str/join "_" target)
        (varies? :src) (str "_from_" (str/join "_and_" (:src join)))
        (varies? :dst) (str "_to_" (str/join "_and_" (:dst join)))))))

;; Compiles `pattern` ([root-options node]: root-options carries :table and
;; :limit, node optionally carries :columns and :joins) against `schema` into
;; a map with:
;;   :sql     - a "SELECT DISTINCT ..." statement string (no LIMIT clause),
;;   :limit   - the root :limit, or nil when the schema's :can-limit is falsy,
;;   :pull    - fn taking a row (anything supporting `get` by column alias)
;;              and returning a nested map keyed by column / join keywords,
;;   :columns - column descriptors nested the same way as the pattern.
;; Join alias prefixes come from gensym, so the SQL text differs between calls.
;; NOTE(review): `escape` only wraps identifiers in the quote string; quote
;; characters embedded in an identifier are not doubled.
;; NOTE(review): per-column :where text is spliced into the SQL verbatim.
(defn parse-pattern [schema pattern]
  (let [{:keys [quote]} schema
        ;; wraps an identifier in the schema's quote string
        escape #(str quote % quote)
        ;; table alias: <prefix>__<last segment of the table identifier>
        get-table-alias (fn [prefix table]
                          (str prefix "__" (get-table-name table)))
        ;; result column alias: <prefix>__<column>
        get-column-alias (fn [prefix column]
                           (str prefix "__" column))
        ;; quoted "<table alias>"."<column>" reference
        column-id (fn [prefix table column]
                    (str
                      (escape (get-table-alias prefix table))
                      "."
                      (escape column)))
        ;; SELECT item: column reference AS quoted column alias
        alias-column (fn [prefix table column]
                       (str (column-id prefix table column)
                            " AS "
                            (escape (get-column-alias prefix column))))
        ;; FROM/JOIN item: fully quoted table name AS quoted table alias
        alias-table (fn [prefix table]
                      (str (->> table (map escape) (str/join "."))
                           " AS "
                           (escape (get-table-alias prefix table))))
        ;; walk threads an accumulator (:select :from :where :on vectors plus
        ;; the :pull / :columns of the node just processed) through the tree.
        ;; Its third argument is a [table-holder node] pair: the pattern itself
        ;; at the root, a [join-ref node] map entry for nested joins.
        ret ((fn walk [acc src-prefix [{src-table :table} {:keys [columns joins]}]]
               ;; column name -> position in the schema, used to order descriptors
               (let [column-order (into {}
                                        (map-indexed
                                          (fn [i v]
                                            [v i]))
                                        (get-in schema [:tables src-table :columns]))]
                 (reduce
                   ;; one step per join entry of this node
                   (fn [acc [{:keys [src dst] dst-table :table :as join} :as dst-pattern]]
                     (let [dst-prefix (name (gensym "join"))
                           rel-kw (join-kw schema src-table join)
                           ;; add this join's ON condition, then process the
                           ;; joined node (which appends its own FROM entry)
                           child
                           (-> acc
                               (update :on conj
                                       (str/join
                                         " AND "
                                         (map
                                           (fn [src-column dst-column]
                                             (str (column-id src-prefix src-table src-column)
                                                  " = "
                                                  (column-id dst-prefix dst-table dst-column)))
                                           src dst)))
                               (walk dst-prefix dst-pattern))]
                       ;; keep the SQL parts accumulated by the child, but nest
                       ;; the child's :pull / :columns under rel-kw of this node
                       (-> child
                           (assoc :pull (fn [rs]
                                          (assoc ((:pull acc) rs) rel-kw ((:pull child) rs))))
                           (assoc :columns (conj (:columns acc) {:fn rel-kw
                                                                 :columns (:columns child)})))))
                   ;; initial accumulator for this node, before its joins
                   (-> acc
                       ;; selected columns, or the primary key when none are selected
                       (update :select into
                               (map #(alias-column src-prefix src-table %)
                                    (or (seq (keys columns))
                                        (get-in schema [:tables src-table :pk]))))
                       (update :where into (for [[col {:keys [where]}] columns
                                                 :when where]
                                             (str (column-id src-prefix src-table col)
                                                  " "
                                                  where)))
                       (update :from conj (alias-table src-prefix src-table))
                       ;; pull for this node: keyword of column name -> value
                       ;; looked up in the row by the column's alias
                       (assoc :pull (let [k->alias
                                          (into
                                            {}
                                            (map (juxt keyword #(get-column-alias src-prefix %)))
                                            (keys columns))]
                                      (fn [rs]
                                        (into {} (map (juxt key #(get rs (val %)))) k->alias))))
                       ;; one descriptor per selected column, in schema column
                       ;; order; the header shows the keyword plus the where
                       ;; text when present (built with the stream helpers)
                       (assoc :columns (->> columns
                                            (sort-by (comp column-order key))
                                            (mapv (fn [[col {:keys [where]}]]
                                                    (let [kw (keyword col)]
                                                      {:fn kw
                                                       :header
                                                       (apply stream/horizontal
                                                              (concat
                                                                [(stream/stream kw)]
                                                                (when where
                                                                  [stream/separator
                                                                   (stream/raw-string where
                                                                                      {:fill :util})])))
                                                       :sortable false}))))))

                   joins)))
             {:select []
              :from []
              :where []
              :on []}
             "root"
             pattern)]
    {:sql (str
            "SELECT DISTINCT "
            (str/join ", " (:select ret))
            " FROM "
            ;; the first FROM entry (root table) has no ON condition; every
            ;; later entry is paired with the ON condition added for its join
            (str/join " LEFT JOIN "
                      (map
                        (fn [from on]
                          (str from (when on (str " ON " on))))
                        (:from ret)
                        (cons nil (:on ret))))
            (when (seq (:where ret))
              (str " WHERE "
                   (str/join " AND " (:where ret)))))
     :limit (when (:can-limit schema) (get-in pattern [0 :limit]))
     :pull (:pull ret)
     :columns (:columns ret)}))

(defn execute-parsed-pattern!
  "Runs the query of a parsed pattern on `stmt` and returns a vector with one
  pulled value per result row. When the parsed pattern carries a :limit,
  \" LIMIT <n>\" is appended to its SQL before execution."
  [^Statement stmt {:keys [sql pull limit]}]
  (let [limit-clause (if limit (str " LIMIT " limit) "")
        rows (reducible-result-set (.executeQuery stmt (str sql limit-clause)))]
    (into [] (map pull) rows)))

(defn count-parsed-pattern!
  "Counts the rows the parsed pattern's query yields by wrapping its SQL in a
  SELECT COUNT(*) subquery; returns the RESULT column of the first row."
  [^Statement stmt {:keys [sql]}]
  (let [count-sql (str "SELECT COUNT(*) AS RESULT FROM (" sql ") AS subquery")
        rows (reducible-result-set (.executeQuery stmt count-sql))
        ;; values must be read off the row accessor during the reduction
        results (into [] (map :RESULT) rows)]
    (first results)))

(defn set-column-clause
  "Sets option `k` to `v` on `column` of the node reached by following `path`
  (a sequence of join refs) from the pattern's root node.

  A non-nil `v` is stored under [column k]; a nil `v` removes `k` from the
  column's options if the column is present and otherwise leaves the columns
  untouched. Returns the updated pattern."
  [pattern path column k v]
  (if-let [[join & deeper] (seq path)]
    ;; descend: treat [join sub-node] as a pattern of its own
    (update-in pattern [1 :joins]
               (fn [joins]
                 (let [sub-pattern [join (get joins join)]
                       updated (set-column-clause sub-pattern deeper column k v)]
                   (assoc joins join (second updated)))))
    (update-in pattern [1 :columns]
               (fn [selected]
                 (if (some? v)
                   (assoc-in selected [column k] v)
                   (if (contains? selected column)
                     (update selected column dissoc k)
                     selected))))))

(defn switch-column
  "Toggles selection of `column` on the node reached by following `path`
  (a sequence of join refs) from the pattern's root node: a selected column is
  removed, an unselected one is added with empty options.

  Any join along the path whose node ends up with neither columns nor joins is
  removed from its parent. Returns the updated pattern."
  [pattern path column]
  (if-let [[join & deeper] (seq path)]
    ;; descend: treat [join sub-node] as a pattern of its own
    (update-in pattern [1 :joins]
               (fn [joins]
                 (let [sub-pattern [join (get joins join)]
                       node (second (switch-column sub-pattern deeper column))]
                   (if (or (seq (:columns node))
                           (seq (:joins node)))
                     (assoc joins join node)
                     (dissoc joins join)))))
    (update-in pattern [1 :columns]
               (fn [selected]
                 (if (contains? selected column)
                   (dissoc selected column)
                   (assoc selected column {}))))))

(defn switch-join
  "Toggles the join identified by the last element of `path` (a sequence of
  join refs leading from the pattern's root node to that join).

  A join that is present is removed; an absent one is added with every name in
  `columns` selected with empty options. Any intermediate join along the path
  whose node ends up with neither columns nor joins is removed from its parent.

  An empty or nil `path` names no join, so the pattern is returned unchanged
  (previously this recursed without ever reaching the base case).
  Returns the updated pattern."
  [pattern path columns]
  (cond
    ;; nothing to toggle; also guards the recursion below
    (empty? path)
    pattern

    (= 1 (count path))
    (let [join (first path)]
      (update-in pattern [1 :joins]
                 (fn [m]
                   (if (contains? m join)
                     (dissoc m join)
                     (assoc m join {:columns (zipmap columns (repeat {}))})))))

    :else
    (let [[j & js] path]
      (update-in pattern [1 :joins]
                 (fn [m]
                   ;; treat [j sub-node] as a pattern of its own
                   (let [[_ node] (switch-join [j (get m j)] js columns)]
                     (if (and (empty? (:columns node))
                              (empty? (:joins node)))
                       (dissoc m j)
                       (assoc m j node))))))))