(ns farbetter.freedomdb.schemas
  (:require
   [clojure.set :refer [union]]
   [schema.core :as s :include-macros true]))

;; Forward declarations. Expression and JoinExpressionOrClause are referenced
;; through their vars (#'...) by recursive schemas defined before them, and
;; FieldName / OrderByDirection are used inside valid-order-by? before their
;; own defs appear further down in this namespace.
(declare Expression FieldName JoinExpressionOrClause
         OrderByDirection)

;;;;;;;;; Helper fns ;;;;;;;;;

(defn op-matches?
  "Builds a one-argument predicate over an expression. The predicate is true
  when the first element of the expression validates against the schema
  `op-type` (i.e. `s/check` reports no error for it)."
  [op-type]
  (fn [expr]
    (let [operator (first expr)]
      (nil? (s/check op-type operator)))))

(defn valid-order-by?
  "Returns true when `expr` is a non-empty sequential collection of
  alternating entries [field-name direction field-name direction ...] in
  which every field-name validates against FieldName and every direction
  validates against OrderByDirection. Returns false otherwise, including
  for non-sequential, empty, or odd-length input."
  [expr]
  (and (sequential? expr)
       ;; An empty clause is rejected, matching the previous behaviour.
       (pos? (count expr))
       (even? (count expr))
       ;; Validate every pair explicitly. The earlier loop used a nil
       ;; direction as its termination signal, so input such as
       ;; [:a :asc :b nil] was accepted once the first pair had passed.
       (every? (fn [[field-name direction]]
                 (and (nil? (s/check FieldName field-name))
                      (nil? (s/check OrderByDirection direction))))
               (partition 2 expr))))

;;;;;;;;; Frontend Schemas ;;;;;;;;;

;; A field value. Integers, strings, keywords and booleans are routed to
;; their specific schemas; the final (constantly true) branch sends anything
;; else to s/Any.
(def Value (s/conditional
            integer? s/Int
            string? s/Str
            keyword? s/Keyword
            true? s/Bool
            false? s/Bool
            (constantly true) s/Any))
;; The set of keywords accepted as a field's declared type.
(def FieldType (s/enum :int4 :int8 :num :str1000 :kw :bool :any))
;; Tables are named by keyword.
(def TableName s/Keyword)
;; Either a single TableName or a sequential collection of TableNames.
(def TableNameOrTableNameSeq (s/if sequential?
                               [TableName]
                               TableName))
;; Fields are named by keyword.
(def FieldName s/Keyword)
;; Operator keywords that may appear as the first element of an expression.
(def NotOperator (s/eq :not))
(def EqualityOperator (s/eq :=))
(def InequalityOperator (s/enum :< :> :<= :>=))
(def CombinationOperator (s/enum :and :or))
;; The set of keywords accepted as a join type.
(def JoinType (s/enum :inner :left :right :full
                      :left-outer :right-outer :full-outer))
(def JoinExpression
  "Exactly three elements: the equality operator followed by two TableName
  entries, labelled :left-table and :right-table."
  [(s/one EqualityOperator :=)
   (s/one TableName :left-table)
   (s/one TableName :right-table)])

(def CombinationJoin
  "A combination operator followed by any number of nested join
  expressions or clauses."
  [(s/one CombinationOperator :combination-operator)
   (s/recursive #'JoinExpressionOrClause)])

(def NotJoin
  "Exactly two elements: the :not operator and one nested join expression
  or clause."
  [(s/one NotOperator :not-operator)
   (s/one (s/recursive #'JoinExpressionOrClause) :clause)])

(def JoinExpressionOrClause
  "Dispatches on the first element of the sequence to pick the matching
  join schema."
  (s/conditional
   (op-matches? EqualityOperator)    JoinExpression
   (op-matches? CombinationOperator) CombinationJoin
   (op-matches? NotOperator)         NotJoin))

(def JoinMap
  "Map form of a join; both :on and :type are required."
  {:on   JoinExpression
   :type JoinType})

(def JoinClause
  "A join given either as a JoinMap or as a sequential
  JoinExpressionOrClause."
  (s/conditional
   map?        JoinMap
   sequential? JoinExpressionOrClause
   ;;sequential? JoinExpression
   ))
(def ValueMap
  "Map from FieldName to Value."
  {FieldName Value})

(def AsKeyword
  "The literal keyword :as."
  (s/eq :as))

(def AggregateFnKeyword
  "Keywords accepted as an aggregate function."
  (s/enum :sum :count :min :max))

(def Aggregate
  "Exactly four elements: aggregate function keyword, field name, the :as
  keyword, and a second field name labelled :name."
  [(s/one AggregateFnKeyword :agg-fn)
   (s/one FieldName :field-name)
   (s/one AsKeyword :as-keyword)
   (s/one FieldName :name)])

(def FieldAttrsMap
  "Attributes of a field; :type is required, :indexed and :default are
  optional."
  {:type                     FieldType
   (s/optional-key :indexed) s/Bool
   (s/optional-key :default) Value})

(def ModifyFieldAttrsMap
  "Same keys as FieldAttrsMap, but every key is optional."
  {(s/optional-key :type)    FieldType
   (s/optional-key :indexed) s/Bool
   (s/optional-key :default) Value})

(def FieldsMap
  "Map from FieldName to that field's FieldAttrsMap."
  {FieldName FieldAttrsMap})

(def RowId
  "Row identifiers are numbers."
  s/Num)

(def RowIdsOrAll
  "Either the keyword :all or a set of RowIds."
  (s/if keyword?
    (s/eq :all)
    #{RowId}))

(def FieldNamesOrAll
  "Either the keyword :all or a set of FieldNames."
  (s/if keyword?
    (s/eq :all)
    #{FieldName}))

(def OrderByDirection
  "Keywords accepted as a sort direction."
  (s/enum :asc :desc))

(def OrderByClause
  "Any value accepted by the valid-order-by? predicate."
  (s/pred valid-order-by?))
(def BinaryOperator
  "The equality operator or one of the inequality operators."
  (s/if (fn [op] (= := op))
    EqualityOperator
    InequalityOperator))

(def ScanFnOperator
  "The literal keyword :scan-fn."
  (s/enum :scan-fn))

(def ScanFn
  "Function schema: returns s/Bool and takes a variable number of Value
  arguments."
  (s/=> s/Bool & [Value]))

(def NotExpression
  "Exactly two elements: the :not operator and one nested Expression."
  [(s/one NotOperator :not-operator)
   (s/one (s/recursive #'Expression) :expression)])

(def BinaryExpression
  "Exactly three elements: a binary operator, a field name, and a Value
  argument."
  [(s/one BinaryOperator :binary-operator)
   (s/one FieldName :field-name)
   (s/one Value :argument)])

(def CombinationExpression
  "A combination operator followed by any number of nested Expressions."
  [(s/one CombinationOperator :combination-operator)
   (s/recursive #'Expression)])

(def ScanFnExpression
  "The :scan-fn operator, a scan function, then any number of field
  names."
  [(s/one ScanFnOperator :scan-fn-operator)
   (s/one ScanFn :scan-function)
   FieldName])

(def Expression
  "Dispatches on the first element of the sequence to pick the matching
  expression schema."
  (s/conditional
   (op-matches? NotOperator)         NotExpression
   (op-matches? BinaryOperator)      BinaryExpression
   (op-matches? CombinationOperator) CombinationExpression
   (op-matches? ScanFnOperator)      ScanFnExpression))

(def ExpressionOrValueMap
  "A ValueMap when given a map, otherwise an Expression."
  (s/if map?
    ValueMap
    Expression))
(def SelectQuery
  "Options map for a select; every key is optional. :fields may be nil, a
  single FieldName, or a sequential collection of FieldNames."
  {(s/optional-key :fields)    (s/maybe (s/if sequential?
                                          [FieldName]
                                          FieldName))
   (s/optional-key :join)      JoinClause
   (s/optional-key :where)     ExpressionOrValueMap
   (s/optional-key :aggregate) Aggregate
   (s/optional-key :order-by)  OrderByClause})

(def SelectOneReturn
  "nil, a ValueMap, a sequence of Values, or a single Value."
  (s/maybe
   (s/conditional
    map?        ValueMap
    sequential? [Value]
    :else       Value)))

(def SelectReturn
  "nil or a sequence of SelectOneReturn entries."
  (s/maybe [SelectOneReturn]))

(def UpdateQuery
  "Options map for an update; :set is required, :where is optional."
  {:set                    ValueMap
   (s/optional-key :where) ExpressionOrValueMap})

(def TypeMap
  "Map from FieldName to FieldType."
  {FieldName FieldType})

(def DBTableMetadata
  "Per-table metadata map; all five keys are required."
  {:table-name     TableName
   :all-fields     #{FieldName}
   :indexed-fields #{FieldName}
   :type-map       TypeMap
   :defaults       {FieldName Value}})

(def DB
  "Any value satisfying associative?."
  (s/pred associative?))

(def RowStoreType
  "Keywords accepted as a row store type."
  (s/enum :mem :durable-rs-mem))

;;;;;;;;; Backend Schemas ;;;;;;;;;

;; Exactly two elements: a RowId followed by a ValueMap.
(def RowIdValueMapPair [(s/one RowId :row-id)
                        (s/one ValueMap :value-map)])
;; Keywords accepted as an output style; the three names line up with the
;; three element shapes allowed by GetAllRowsReturn below.
(def GetAllRowsOutputStyle (s/enum :row-ids-only
                                   :value-maps-only
                                   :row-ids-and-value-maps))
;; A sequence whose elements are each a RowIdValueMapPair (when sequential),
;; a RowId (when a number), or a ValueMap (when a map).
(def GetAllRowsReturn [(s/conditional
                        sequential? RowIdValueMapPair
                        number? RowId
                        map? ValueMap)])

;;;;;;;;;;;;;;;;;;;; Transit Schemas ;;;;;;;;;;;;;;;;;;;;

;; Tags are strings; Type and Handler are left unconstrained (s/Any).
(def Tag s/Str)
(def Type s/Any)
(def Handler s/Any)
;; Write handlers are keyed by Type, read handlers by Tag.
(def TransitWriteHandlersMap {Type Handler})
(def TransitReadHandlersMap {Tag Handler})
