(ns orcl.naive
  (:require [orcl.naive.impl :as impl]
            [orcl.naive.vars :as vars]
            [orcl.naive.lib]
    #?(:cljs [cljs.js])
            [orcl.utils :as utils]
            [orcl.naive.lib :as lib])

  (:refer-clojure :exclude [compile]))

(defn pattern-bindings
  "Returns the symbols bound by the pattern AST `p`, in left-to-right order.

  :var binds the symbol named by `:var`; :wildcard and :const bind nothing;
  :list/:tuple concatenate the bindings of `:patterns`; :record concatenates
  the bindings of the pattern in each `[field pattern]` entry of `:pairs`;
  :as puts `:alias` in front of the bindings of `:pattern`; :cons yields the
  bindings of `:head` followed by those of `:tail`. Any other `:type` makes
  `case` throw."
  [p]
  (let [nested (fn [k] (pattern-bindings (get p k)))]
    (case (:type p)
      (:wildcard :const) []
      :var               [(symbol (:var p))]
      (:list :tuple)     (mapcat pattern-bindings (:patterns p))
      :record            (mapcat #(pattern-bindings (second %)) (:pairs p))
      :as                (cons (symbol (:alias p)) (nested :pattern))
      :cons              (concat (nested :head) (nested :tail)))))

(defn lenient-binding
  "Returns the single binding form for a non-strict pattern `p`: the symbol
  named by `:var` for a :var pattern, or `_` for a :wildcard. Any other
  `:type` makes `case` throw."
  [p]
  (let [kind (:type p)]
    (case kind
      :wildcard '_
      :var      (symbol (:var p)))))

(defn sequential
  "Emits the form for the Orc sequential combinator `left >pattern> right`.

  For a non-strict pattern (per `impl/strict-pattern?`), `right` is wrapped in
  a one-argument fn whose parameter is the `lenient-binding` of the pattern.
  For a strict pattern, each value of `left` is first piped through
  `lib/extract-pattern` (called with the pattern AST and the value via
  `impl/call`). `right` is then wrapped in a fn that destructures a vector of
  the pattern's bound names. NOTE(review): this assumes extract-pattern
  publishes that vector of bound values in `pattern-bindings` order; confirm
  in orcl.naive.lib."
  [pattern left right]
  (if-not (impl/strict-pattern? pattern)
    `(impl/sequential ~left (fn [~(lenient-binding pattern)] ~right))
    (let [extracted `(impl/sequential ~left
                                      (fn [v#] (impl/call lib/extract-pattern [~pattern v#])))
          names     (vec (pattern-bindings pattern))]
      `(impl/sequential ~extracted (fn [~names] ~right)))))

(defn pruning
  "Emits the form for the Orc pruning combinator `left <pattern< right`.

  The pattern AST itself is passed as the first argument to `impl/pruning`.
  For a strict pattern (per `impl/strict-pattern?`), `right` is piped through
  `lib/check-pattern` (called with the pattern AST and each value via
  `impl/call`). `left` is wrapped in a fn that destructures a vector of the
  pattern's bound names. For a non-strict pattern, `right` is passed as-is and
  `left` is wrapped in a fn that destructures a one-element vector into the
  `lenient-binding`. NOTE(review): this presumes `impl/pruning` always calls
  the left fn with a vector of bound values; confirm in orcl.naive.impl."
  [pattern left right]
  (if-not (impl/strict-pattern? pattern)
    `(impl/pruning ~pattern (fn [[~(lenient-binding pattern)]] ~left) ~right)
    (let [checked `(impl/sequential ~right
                                    (fn [v#] (impl/call lib/check-pattern [~pattern v#])))
          names   (vec (pattern-bindings pattern))]
      `(impl/pruning ~pattern (fn [~names] ~left) ~checked))))

(defn parallel
  "Emits `(impl/parallel left right)`, the form for the Orc parallel
  combinator `left | right`."
  [left right]
  (list `impl/parallel left right))

(defn otherwise
  "Emits `(impl/otherwise left right)`, the form for the Orc otherwise
  combinator `left ; right`."
  [left right]
  (list `impl/otherwise left right))

(defn conditional
  "Emits the form for `if x then ... else ...`.

  Two branches are run in `impl/parallel`. One calls `lib/Ift` on `x` and
  continues with `then` for each value it publishes. The other calls
  `lib/Iff` on `x` and continues with `else`. The published values themselves
  are ignored."
  [x then else]
  (let [branch (fn [guard body]
                 `(impl/sequential (impl/call ~guard [~x]) (fn [_#] ~body)))]
    `(impl/parallel ~(branch `lib/Ift then)
                    ~(branch `lib/Iff else))))



(defn defs-group
  "Emits `(letfn [defs...] expr)`. Each element of `defs` must be a `letfn`
  binding spec such as those produced by `compile-def`."
  [defs expr]
  (list `letfn (vec defs) expr))

(declare compile)

(defn compile-instance
  "Compiles one clause (instance) of a definition.

  Returns a map with `:params` (the parameter pattern ASTs, unchanged) and
  `:body`, a `fn` form. That fn takes one argument per parameter; each
  argument is destructured as a vector of that pattern's bound names (see
  `pattern-bindings`), and the body is the compiled `body` AST.
  NOTE(review): presumably `impl/function` calls it with the per-parameter
  match results; confirm in orcl.naive.impl."
  [{:keys [params body]}]
  (let [arglist (mapv (fn [param] (vec (pattern-bindings param))) params)]
    {:params params
     :body   (list `fn arglist (compile body))}))

(defn compile-def
  "Compiles a definition AST (`:name`, `:arity`, `:instances`) into a `letfn`
  binding spec:

    (name [g1 ... gN] ((impl/function [instance ...]) g1 ... gN))

  where N = `arity`, the gi are fresh gensyms, and each instance is compiled
  with `compile-instance`."
  [{:keys [name arity instances]}]
  ;; Keep the original evaluation order (symbol, gensyms, instances) so that
  ;; gensym numbering for nested definitions does not change.
  (let [fname  (symbol name)
        params (vec (repeatedly arity gensym))
        impls  (mapv compile-instance instances)]
    (list fname
          params
          (list* (list `impl/function impls) params))))

(defn compile-prelude-site
  "Compiles a reference to a prelude site into a lookup form.

  Emits `(get @vars/prelude <definition>)`, where `<definition>` is the
  `:definition` value of the site source map `s` (the only key read). The
  lookup is performed when the emitted form is evaluated, not at compile
  time."
  [s]
  `(get @vars/prelude ~(:definition s)))

(defn compile-import-site
  "Placeholder for compiling sites that do not come from the prelude. Not
  implemented: the argument is ignored and the result of
  `utils/todo-exception` is returned. NOTE(review): presumably that call
  throws a not-implemented error; confirm in orcl.utils."
  [_site]
  (utils/todo-exception))

(defn compile-primitive
  "Compiles an atomic AST node into a Clojure value or symbol.

  :const yields its `:value` as-is. For :var, the node's `:source` decides:
  a `:site` source is compiled with `compile-prelude-site` when its own
  `:source` has `:type :prelude`, otherwise with `compile-import-site`. Sources
  of type :def, :sequential, :pruning or :argument yield the plain symbol
  named by `:var`. Any other node or source type makes `case` throw."
  [ast]
  (case (:node ast)
    :const (:value ast)
    :var   (let [{source-type :type :as source} (:source ast)]
             (case source-type
               (:def :sequential :pruning :argument) (symbol (:var ast))
               :site (if (not= :prelude (-> source :source :type))
                       (compile-import-site source)
                       (compile-prelude-site source))))))

(defn compile
  "Translates an Orc AST node into a Clojure form built on `orcl.naive.impl`.

  Combinator nodes (:pruning, :sequential, :otherwise, :parallel,
  :conditional, :defs-group) compile their sub-expressions recursively and
  delegate to the emitter function of the same name.

  Value nodes (:call, :tuple, :list, :record, :field-access, :var) become
  `impl/call` forms whose arguments are compiled with `compile-primitive`.
  :const becomes `(impl/constant value)` and :stop becomes the `impl/stop`
  symbol. Any other `:node` makes `case` throw."
  [ast]
  (let [sub   (fn [k] (compile (get ast k)))
        prims (fn [xs] (mapv compile-primitive xs))
        call  (fn [target args] `(impl/call ~target ~args))]
    (case (:node ast)
      :pruning      (pruning (:pattern ast) (sub :left) (sub :right))
      :sequential   (sequential (:pattern ast) (sub :left) (sub :right))
      :otherwise    (otherwise (sub :left) (sub :right))
      :parallel     (parallel (sub :left) (sub :right))
      :conditional  (conditional (compile-primitive (:if ast)) (sub :then) (sub :else))
      ;; Deliberately lazy: the defs are realized by defs-group after the body
      ;; expression has been compiled, matching the established gensym order.
      :defs-group   (defs-group (map compile-def (:defs ast)) (sub :expr))
      :call         (call (compile-primitive (:target ast)) (prims (:args ast)))
      :tuple        (call `lib/make-tuple (prims (:values ast)))
      :list         (call `lib/make-list (prims (:values ast)))
      :record       (let [pairs (:pairs ast)]
                      (call `(lib/make-record ~(mapv first pairs))
                            (mapv (comp compile-primitive second) pairs)))
      :field-access (call `lib/field-access [(compile-primitive (:target ast)) (:field ast)])
      :stop         `impl/stop
      :const        `(impl/constant ~(compile-primitive ast))
      :var          (call `lib/Let [(compile-primitive ast)]))))

#?(:cljs
   (defn cljs-eval
     "Evaluates the compiled `program` form with the self-hosted ClojureScript
     compiler (`cljs.js/eval`) and returns the resulting value.

     Each call starts from a fresh compiler state (`cljs.js/empty-state`) and
     uses `cljs.js/js-eval` with `:def-emits-var true`. The value is returned
     from the result callback. NOTE(review): this relies on `cljs.js/eval`
     invoking the callback synchronously for non-`ns` forms and returning its
     result; confirm against the cljs.js version in use.

     Throws the error reported by `cljs.js/eval`, if any."
     [program]
     (cljs.js/eval (cljs.js/empty-state)
                   program
                   {:eval cljs.js/js-eval
                    ;:ns 'cljs.user
                    :def-emits-var true}
                   (fn [{:keys [value error]}]
                     ;; Surface compile/eval failures instead of returning nil,
                     ;; which `run` would otherwise try to invoke as a function.
                     (if error
                       (throw error)
                       value)))))

(defn make-res
  "Builds the result map returned by `run` and `unblock`.

  `values` is the atom collecting published values. `prev-coeffects` maps
  coeffect keys to channels carried over from the previous state. Reads the
  current binding of `impl/*coeffects*`, so it must be called inside a
  `binding` of that var, as both callers do.

  Returns:
    :values           - current contents of `values`
    :coeffects        - `[key definition]` pairs for the coeffects registered
                        in `impl/*coeffects*`
    :killed-coeffects - keys of `prev-coeffects` whose dereferenced channel is
                        no longer `impl/open?`
    :state            - `{:values values, :coeffects {key channel}}`, merging
                        the newly registered coeffects with the surviving
                        previous ones (the latter win on a key clash)"
  [values prev-coeffects]
  (let [current @impl/*coeffects*
        ;; Lazy on purpose: realized below by `apply dissoc`, after `values`
        ;; has been dereferenced.
        killed  (->> prev-coeffects
                     (remove (fn [[_ ch]] (impl/open? @ch)))
                     (keep first))]
    {:values           @values
     :coeffects        (for [[k {d :definition}] current] [k d])
     :killed-coeffects killed
     :state            {:values    values
                        :coeffects (into {} (concat (for [[k {ch :channel}] current] [k ch])
                                                    (apply dissoc prev-coeffects killed)))}}))

(defn run
  "Evaluates and executes a compiled `program` form.

  The form is evaluated with `eval` on the JVM or `cljs-eval` in
  ClojureScript, under a fresh `impl/*coeffects*` atom. The resulting function
  is invoked via `trampoline` with an `impl/channel` whose first callback
  appends every published value to a fresh atom. NOTE(review): the second
  no-op callback is presumably the close handler; confirm in orcl.naive.impl.
  `impl/execution-loop` is then drained, and the `make-res` result (with no
  previous coeffects) is returned."
  [program]
  (let [values  (atom [])
        publish (fn [_ v] (swap! values conj v))]
    (binding [impl/*coeffects* (atom {})]
      (let [evaluate #?(:clj eval :cljs cljs-eval)
            entry    (evaluate program)
            out      (impl/channel publish (fn []))]
        (trampoline entry out)
        (impl/execution-loop)
        (make-res values {})))))

(defn unblock
  "Resumes execution by delivering `value` to the pending coeffect `coeffect`.

  The first argument is the `:state` map from a previous `run`/`unblock`
  result: `:values` is the values atom and `:coeffects` maps coeffect keys to
  channels. The steps are:
    1. Clear the values atom, so the result reports only values published
       during this step.
    2. Pass the dereferenced channel and `value` to `impl/write` under a fresh
       `impl/*coeffects*` binding.
    3. Drain `impl/execution-loop`.
    4. Build the result with `make-res`, dropping `coeffect` from the
       carried-over coeffects.

  Throws ex-info (with `:coeffect` and `:pending` keys) when `coeffect` is not
  a key of the state's coeffects. The check runs before any state is mutated."
  [{:keys [values coeffects]} coeffect value]
  (let [channel (get coeffects coeffect)]
    ;; Fail fast with a descriptive error instead of dereferencing nil after
    ;; the caller's values atom has already been cleared.
    (when (nil? channel)
      (throw (ex-info (str "Unknown coeffect: " (pr-str coeffect))
                      {:coeffect coeffect
                       :pending  (keys coeffects)})))
    (reset! values [])
    (binding [impl/*coeffects* (atom {})]
      (impl/write @channel value)
      (impl/execution-loop)
      (make-res values (dissoc coeffects coeffect)))))

