(ns freebie.interpreter
  (:require [clojure.core.match :refer [match]]
            [clojure.set :as set]
            [clojure.string :as str]
            [freebie.types :refer [eval-instr]]
            [schema.core :as s]))

(defn get-state-from-call
  "Chooses the state to carry forward after a user declaration has run.

  If `output` is a vector tagged with `::return` metadata (the shape built
  by `return-with-state`), its first element is the new state; for any other
  `output`, `current-state` is returned unchanged."
  [output current-state]
  (cond
    (not (vector? output))         current-state
    (not (::return (meta output))) current-state
    :else                          (first output)))

(defn get-result-from-call
  "Extracts the value a user declaration produced.

  A vector tagged with `::return` metadata (see `return-with-state`) holds
  `[state result]`, so its second element is returned; any other `output`
  is returned as-is."
  [output]
  (if-not (and (vector? output)
               (::return (meta output)))
    output
    (second output)))

(defn validate-interpreter-is-total
  "Checks at macro-expansion time that an interpreter's declarations match
  its domain.

  `env` must carry `:interpreter-domain` (whose `:domains-fn-meta` map is
  keyed by the domain's operation symbols), plus `:interpreter-name` and
  `:interpreter-domain-name` for error messages. `interpreter-fns` are the
  raw declaration forms, each starting with the declared symbol; the
  symbol `default` marks the catch-all declaration.

  Returns nil when the declarations are acceptable. Throws an `Exception`
  when:
  - a `default` declaration exists although every operation is declared;
  - a declaration names an operation the domain does not define;
  - some operation is undeclared and there is no `default`;
  - the same name is declared more than once."
  [env interpreter-fns]
  (let [declared-names      (map first interpreter-fns)
        defined-functions0  (set declared-names)
        all-functions       (set (keys (get-in env [:interpreter-domain :domains-fn-meta])))
        has-default?        (contains? defined-functions0 'default)
        defined-functions   (disj defined-functions0 'default)

        missing-functions   (set/difference all-functions defined-functions)
        unknown-functions   (set/difference defined-functions all-functions)

        ;; The sets above hide repeated declarations, yet in the generated
        ;; `condp` dispatch every copy after the first could never run.
        duplicate-functions (for [[fn-name n] (frequencies declared-names)
                                  :when (> n 1)]
                              fn-name)]
    (cond
      (and (empty? missing-functions)
           has-default?)
      (throw
       (Exception.
        (str "Interpreter `" (:interpreter-name env)
             "` has unreachable `default` declaration")))

      (seq unknown-functions)
      (throw
       (Exception.
        (str "Declarations on interpreter `" (:interpreter-name env)
             "` are not defined on domain `" (:interpreter-domain-name env) "`: "
             (str/join ", " unknown-functions) "; available declarations are: "
             (str/join ", " all-functions))))

      (and (seq missing-functions)
           (not has-default?))
      (throw
       (Exception.
        (str "Declarations from domain `" (:interpreter-domain-name env)
             "` are missing on interpreter `" (:interpreter-name env) "`: "
             (str/join ", " missing-functions))))

      (seq duplicate-functions)
      (throw
       (Exception.
        (str "Interpreter `" (:interpreter-name env)
             "` has repeated declarations: "
             (str/join ", " duplicate-functions)))))))


(defn is-next-interpreter-called?
  "Returns true when `fn-body` contains a form whose head is the symbol
  stored under `:interpreter-next-interpreter-sym` in `env` (i.e. a call of
  the next interpreter), and false otherwise.

  The search descends into every nested collection (lists, vectors, maps
  and sets), so calls inside `let` binding vectors or map literals are found
  too. A bare reference to the symbol outside call position does not count."
  [{:keys [interpreter-next-interpreter-sym]} fn-body]
  (boolean
   (some (fn [form]
           (and (seq? form)
                (= interpreter-next-interpreter-sym (first form))))
         ;; Walk the whole form tree, not just nested lists.
         (tree-seq coll? seq fn-body))))

(defn mk-interpreter-domain-fn-cond
  "Builds the `condp` test/result pair that handles one domain operation
  inside the interpreter generated by `build-interpreter`.

  `env` is the interpreter environment assembled by `build-interpreter`.
  The second argument is a user declaration `(fn-name [args...] body...)`.
  Each arg is bound positionally to the free value's field whose keyword is
  the corresponding entry of the domain's `:fn-args`.

  Returns a two-element sequence:
  - first, the quoted `fn-pascal-name` of the operation, which is the value
    `build-interpreter` compares against the free value's class name;
  - second, the code that runs the declaration:
    - When the interpreter has a next-interpreter slot (recur not always
      needed) and the body calls it explicitly, the code is the body alone.
      The body is then responsible for continuing the program; the
      generated code throws at run time if no `:next-interpreter` option
      was supplied.
    - Otherwise, the code evaluates the body, splits its output with
      `get-state-from-call` / `get-result-from-call`, and builds the
      continuation program. It then either `recur`s on that program or
      hands it to the next interpreter.

  Throws `Exception` when the declaration is malformed or its argument
  count differs from the domain's declaration."
  [{:keys [interpreter-domain
           interpreter-name
           interpreter-args-sym
           interpreter-state-sym
           interpreter-recur-always-needed
           interpreter-free-value-sym
           interpreter-next-interpreter-sym] :as env}
   [fn-name fn-args & fn-body]]

  (when-not (symbol? fn-name)
    (throw
     (Exception.
      (str "Definition on interpreter `" interpreter-name
           "` must start with a symbol; found instead: " fn-name))))

  ;; Report the offending argument form itself, not the definition name.
  (when-not (vector? fn-args)
    (throw
     (Exception.
      (str "Second position of definition `" fn-name
           "` on interpreter `" interpreter-name
           "` must be a vector; found instead: " fn-args))))

  (let [{:keys [fn-pascal-name fn-functor-arg]
         domain-fn-args :fn-args}
        (get-in interpreter-domain [:domains-fn-meta fn-name])

        next-state-sym  (gensym (str interpreter-name "-next-state"))
        return-val-sym  (gensym (str interpreter-name "-return-val"))
        new-options-sym (gensym (str interpreter-name "-new-options"))
        new-program-sym (gensym (str interpreter-name "-new-program"))
        user-output-sym (gensym (str interpreter-name "-user-output"))

        _
        (when-not (= (count fn-args) (count domain-fn-args))
          (throw
           (Exception.
            (str "Definition `" fn-name "` on interpreter `" interpreter-name "` expects "
                 (count domain-fn-args) " argument(s), got " (count fn-args) " instead."))))

        ;; [arg1 (:field1 free-value) arg2 (:field2 free-value) ...]
        let-bindings
        (vec (mapcat (fn [arg domain-arg]
                       [arg (list (keyword domain-arg)
                                  interpreter-free-value-sym)])
                     fn-args
                     domain-fn-args))]

    `(
      ;; query
      '~fn-pascal-name

      ;; body
      ~(if (and (not interpreter-recur-always-needed)
                (is-next-interpreter-called? env fn-body))
         ;; The body drives the next interpreter itself, so no state
         ;; threading or recur is generated. The nil check must run in the
         ;; generated code (at run time), hence the syntax-quote; the
         ;; symbols are quoted so the message names them instead of
         ;; printing their run-time values.
         `(if (nil? ~interpreter-next-interpreter-sym)
            (throw
             (Exception.
              (str "`" '~interpreter-next-interpreter-sym "` is being explicitly used"
                   " in interpreter, but its value is nil. "
                   " Please invoke `" '~interpreter-name "` with a `:next-interpreter` argument.")))
            (let ~let-bindings ~@fn-body))

         ;; else
         `(let [~user-output-sym
                (let ~let-bindings ~@fn-body)

                ~next-state-sym
                (get-state-from-call ~user-output-sym ~interpreter-state-sym)

                ~return-val-sym
                (get-result-from-call ~user-output-sym)

                ~new-program-sym
                ~(if (= fn-functor-arg 'next)
                   `(:next ~interpreter-free-value-sym)
                   `((:f-next ~interpreter-free-value-sym)
                     ~return-val-sym))

                ~new-options-sym
                (merge ~interpreter-args-sym
                       {:state ~next-state-sym})]

            (if (nil? ~interpreter-next-interpreter-sym)
              (recur ~new-program-sym ~new-options-sym)
              (~interpreter-next-interpreter-sym ~new-program-sym ~new-options-sym)))))))

(defn mk-interpreter-default-fn-cond
  "Builds the fallback clause of an interpreter's dispatch from a user
  `(default [program] body...)` declaration.

  Returns a sequence holding exactly one form: a `let` that binds the
  declaration's single argument to the interpreter's current program and
  then evaluates the declaration body.

  Throws an `Exception` when the argument vector does not match the
  one-element pattern `[program]`."
  [{:keys [interpreter-name
           interpreter-program-sym] :as env}
   [fn-name fn-args & fn-body]]
  (match [fn-args]
    [[program-binding]]
    (list `(let [~program-binding ~interpreter-program-sym]
             ~@fn-body))

    :else
    (throw
     (Exception.
      (str "`default` declaration on interpreter `" interpreter-name "`"
           " must have a program argument. (e.g. `(default [program])`)")))))

(defn build-interpreter
  "Builds, at macro-expansion time, the code of an interpreter `fn` for the
  domain bound to `domain-name`. Used by `definterpreter` and
  `interpreter*`; `caller` only appears in error messages.

  `interpreter-ctx` is one of:
  - `[]`: no state; a pure program yields its bare `:functor-value`.
  - `[state]`: binds `state` to the `:state` option, which then becomes
    required. A pure program yields `{:state ... :result ...}`.
  - `[state next-interpreter]`: as `[state]`, and also binds
    `next-interpreter` to the `:next-interpreter` option so declarations
    may call it explicitly.

  `interpreter-fns` are the user declarations `(op-name [args] body...)`,
  optionally including a `(default [program] body...)` catch-all.

  The generated fn has arities `[program]` and `[program options]`.

  Throws `Exception` on malformed arguments or when the declarations do not
  match the domain (see `validate-interpreter-is-total`)."
  [caller domain-name interpreter-name interpreter-ctx interpreter-fns]

  (when-not (symbol? interpreter-name)
    (throw
     (Exception.
      (str "First argument of `" caller "` must be a symbol; found instead: " interpreter-name))))

  (when-not (symbol? domain-name)
    (throw
     (Exception.
      (str "Second argument of `" caller "` must be a symbol; found instead: " domain-name))))

  (when-not (vector? interpreter-ctx)
    (throw
     (Exception.
      (str "Third argument of `" caller "` must be a vector; found instead: " interpreter-ctx))))

  (let [domain (eval domain-name)

        env0  (match [interpreter-ctx]

                     [[state-name next-interpreter-name]]
                     {:interpreter-state-sym state-name
                      :interpreter-next-interpreter-sym next-interpreter-name
                      :interpreter-return-state true
                      :interpreter-recur-always-needed false
                      :interpreter-ctx [state-name next-interpreter-name]}

                     [[state-name]]
                     {:interpreter-state-sym state-name
                      :interpreter-return-state true
                      :interpreter-ctx [state-name]}

                     [[]]
                     {}

                     ;; Without this clause a longer vector fails inside
                     ;; core.match with an unhelpful "No matching clause".
                     :else
                     (throw
                      (Exception.
                       (str "Third argument of `" caller
                            "` must be [], [state] or [state next-interpreter]; found instead: "
                            interpreter-ctx))))

        env
        (merge {:interpreter-domain               domain
                :interpreter-domain-name          domain-name
                :interpreter-name                 interpreter-name
                :interpreter-return-state         false
                :interpreter-recur-always-needed  true
                :interpreter-args-sym             (gensym (str interpreter-name "-args"))
                :interpreter-state-sym            (gensym (str interpreter-name "-state"))
                :interpreter-program-sym          (gensym (str interpreter-name "-program"))
                :interpreter-next-interpreter-sym (gensym (str interpreter-name "-next-interpreter"))
                :interpreter-free-value-sym       (gensym (str interpreter-name "-free-value"))
                :interpreter-free-type-sym        (gensym (str interpreter-name "-free-type"))}
               env0)

        default-decl?
        (fn [interpreter-fn] (= 'default (first interpreter-fn)))

        ;; `condp` pairs its clauses as test/result and treats a trailing
        ;; lone expression as the default. A `default` declaration expands
        ;; to a single form, so it must come after every domain clause no
        ;; matter where the user wrote it; otherwise the pairing shifts.
        ordered-fns
        (concat (remove default-decl? interpreter-fns)
                (filter default-decl? interpreter-fns))]

    ;; Validate before generating any clause so that unknown or missing
    ;; declarations get the specific error messages.
    (validate-interpreter-is-total env interpreter-fns)

    `(fn ~interpreter-name
       ([~(:interpreter-program-sym env)]
        (~interpreter-name ~(:interpreter-program-sym env) {}))
       ([~(:interpreter-program-sym env)
         ~(:interpreter-args-sym env)]
        (let [~(:interpreter-state-sym env)
              (:state ~(:interpreter-args-sym env))

              ~(:interpreter-program-sym env)
              (eval-instr ~(:interpreter-program-sym env))

              ~(:interpreter-next-interpreter-sym env)
              (:next-interpreter ~(:interpreter-args-sym env))]

          ~(when (>= (count (:interpreter-ctx env)) 1)
             `(if (nil? ~(:interpreter-state-sym env))
                (throw
                 (Exception.
                  (str "Interpreter `" '~interpreter-name
                       "` requires a `:state` argument to run.")))))

          (condp = (:type ~(:interpreter-program-sym env))
            :pure
            ~(if (:interpreter-return-state env)
               `{:state ~(:interpreter-state-sym env)
                 :result (:functor-value ~(:interpreter-program-sym env))}
               `(:functor-value ~(:interpreter-program-sym env)))

            :free
            (let [~(:interpreter-free-value-sym env)
                  (:functor-value ~(:interpreter-program-sym env))

                  ;; Simple class name of the free value, e.g. `ReadInput`.
                  ~(:interpreter-free-type-sym env)
                  (-> ~(:interpreter-free-value-sym env)
                      type
                      .getName
                      (str/split #"\.")
                      last
                      symbol)]
              (condp = ~(:interpreter-free-type-sym env)
                ;; `vec` keeps clause generation eager, so declaration
                ;; errors surface here rather than when the form is walked.
                ~@(vec (mapcat (fn [interpreter-fn]
                                 (if (default-decl? interpreter-fn)
                                   (mk-interpreter-default-fn-cond env interpreter-fn)
                                   (mk-interpreter-domain-fn-cond env interpreter-fn)))
                               ordered-fns))))

            ;; else
            (throw
             (Exception.
              (str "Interpreter `" '~interpreter-name
                   "` received an invalid program argument: "
                   (pr-str ~(:interpreter-program-sym env)))))))))))


(defn return-with-state
  "Wraps a declaration's output so the interpreter also replaces its state.

  Returns the vector `[new-st result]` carrying `{::return true}` metadata,
  the tag that `get-state-from-call` and `get-result-from-call` look for.
  The one-argument form uses nil as the result."
  ([new-st]
   (with-meta [new-st nil] {::return true}))
  ([new-st result]
   (with-meta [new-st result] {::return true})))

(defmacro definterpreter
  "Defines `interpreter-name` as an interpreter for the domain bound to
  `domain-name`. Expands to `(def interpreter-name <fn>)`, where `<fn>` is
  the code produced by `build-interpreter`.

  Usage: (definterpreter name domain [ctx...] (op [args] body...) ...)"
  [interpreter-name domain-name interpreter-ctx & interpreter-fns]
  (let [interpreter-code (build-interpreter "definterpreter"
                                            domain-name
                                            interpreter-name
                                            interpreter-ctx
                                            interpreter-fns)]
    `(def ~interpreter-name ~interpreter-code)))

(defmacro interpreter*
  "Expands directly to the interpreter `fn` built by `build-interpreter`,
  without defining a var. `interpreter-name` names the `fn` itself, which
  the generated single-arity form uses to call the two-arity form.

  Usage: (interpreter* name domain [ctx...] (op [args] body...) ...)"
  [interpreter-name domain-name interpreter-ctx & interpreter-fns]
  (let [caller "interpreter*"]
    (build-interpreter caller
                       domain-name
                       interpreter-name
                       interpreter-ctx
                       interpreter-fns)))

(comment
  ;; An Example
  ;;
  ;; The forms inside `comment` are not evaluated when the namespace loads.
  ;; `defdomain` and `mdo` are not defined in this namespace; presumably
  ;; they come from elsewhere in freebie (confirm before evaluating).

  ;; Presumably declares a domain with operations `read-input` and
  ;; `print-output` (TODO confirm against `defdomain`).
  (defdomain console-domain
    (read-input [] :- s/Str)
    (print-output [output :- s/Str]))

  ;; A program over that domain: read a message, then print it.
  (def read-and-print
    (mdo
     msg <- (read-input)
     (print-output msg)))

  ;; With `[state next-interpreter]` the interpreter requires a `:state`
  ;; option, may receive a `:next-interpreter` option, and a pure program
  ;; yields `{:state ... :result ...}`.
  ;; e.g. (simple-console read-and-print {:state []})
  ;; `read-input` has no declaration of its own, so it is handled by
  ;; `default`.
  (definterpreter simple-console
    console-domain
    [state next-interpreter]

    ;; Prints the output and conjoins `:printed` onto the state; the
    ;; operation's result is `println`'s return value.
    (print-output [output]
                  (return-with-state (conj state :printed)
                                     (println output)))
    ;; `program` is bound to the program currently being interpreted.
    (default [program]
      (println "failing> " program))))
