(ns co.multiply.scoped.impl
  (:import
    [clojure.lang Counted IDeref IEditableCollection ITransientAssociative ITransientCollection Indexed Var$Unbound]
    [java.lang Runtime$Version]
    [java.util Map]))


(def ^:private use-scoped-value
  "True when carriers should be java.lang.ScopedValue rather than ThreadLocal.

   Requires a runtime feature version of 25 or higher, and is forced off when
   the system property co.multiply.scoped.force-fallback is the string true."
  (let [fallback-forced? (= "true" (System/getProperty "co.multiply.scoped.force-fallback"))]
    (and (not fallback-forced?)
      (<= 25 (Runtime$Version/.feature (Runtime/version))))))


(defmacro create-carrier
  "Expands to a constructor call that produces a fresh carrier object.

   The choice is made when the macro expands, based on `use-scoped-value`:
   when true, a `java.lang.ScopedValue` from `newInstance`; otherwise a
   `java.lang.ThreadLocal` whose initial value is `{}`."
  []
  (if-not use-scoped-value
    `(java.lang.ThreadLocal/withInitial (constantly {}))
    `(java.lang.ScopedValue/newInstance)))


(defonce
  ^{:doc "The carrier holding the current scope map.

         Produced by `create-carrier`: a ScopedValue or a ThreadLocal depending on
         `use-scoped-value`. Defined with defonce, so re-loading this namespace
         keeps the existing carrier instead of creating a new one."}
  carrier
  (create-carrier))


(defmacro current-scope
  "Expands to an expression yielding the active scope map.

   When no scope has been installed, this yields `{}`. The ScopedValue variant
   supplies `{}` through `orElse`. The ThreadLocal variant relies on the carrier
   having been created with an initial value of `{}`."
  []
  (if-not use-scoped-value
    `(java.lang.ThreadLocal/.get carrier)
    `(java.lang.ScopedValue/.orElse carrier {})))


(defn get-scoped-var
  "Look up var `v` in the current scope.

   If the current scope map contains `v`, returns the associated value.
   Otherwise returns `(deref v)`: the var's thread-local binding when one is
   active, else its root value. If `v` is absent from the scope and deref yields
   an unbound marker, throws IllegalStateException.

   This is the runtime implementation for the `ask` macro."
  [v]
  (let [scoped (Map/.getOrDefault (current-scope) v ::not-found)]
    (if-not (identical? ::not-found scoped)
      scoped
      ;; Not in scope: use whatever the var itself currently holds.
      (let [fallback (IDeref/.deref v)]
        (when (instance? Var$Unbound fallback)
          (throw (IllegalStateException. (str "Unbound: " v))))
        fallback))))


(defn- convert-var
  "Convert the element at position `idx` of a bindings vector at compile time.

   Even-indexed elements (keys) must be symbols naming vars; they are replaced
   by the Var they resolve to in the current namespace. Odd-indexed elements
   (value forms) are returned unchanged.

   Throws IllegalArgumentException when a key cannot be resolved, or when it is
   not a symbol. Also throws when it resolves to something other than a Var,
   for example a class symbol such as `String`. Without that check a Class would
   silently become a scope key."
  [idx sym]
  (if (odd? idx)
    sym
    ;; `resolve` only accepts symbols; guard so other key types get a clear error.
    (let [resolved (when (symbol? sym) (resolve sym))]
      (cond
        (var? resolved) resolved
        (nil? resolved) (throw (IllegalArgumentException. (str "Cannot resolve: " sym)))
        :else (throw (IllegalArgumentException. (str "Not a var: " sym)))))))


(defn resolve-bindings
  "Resolve symbols to vars in a bindings vector at compile time.

   Transforms `[sym1 val1 sym2 val2 ...]` into `[#'sym1 val1 #'sym2 val2 ...]`.
   Each element is passed through `convert-var` together with its index; value
   forms are left as-is.

   Throws IllegalArgumentException if `bindings` has an odd number of forms.
   A missing value is then reported at macro-expansion time, rather than
   surfacing later as an index error when the bindings are consumed pairwise."
  [bindings]
  (when (odd? (count bindings))
    (throw (IllegalArgumentException.
             (str "Bindings vector requires an even number of forms: " bindings))))
  (into [] (map-indexed convert-var) bindings))


(defn extend-scope
  "Return `scope` with every key/value pair of `bindings` assoc'd into it.

   `bindings` is a flat `[var1 val1 var2 val2 ...]` collection that must be
   Counted and Indexed. `scope` must be editable (support transients). Pairs are
   applied left to right, so a later pair overrides both earlier pairs and
   existing entries of `scope`. The result is built through a single transient
   and returned as a persistent map; the input `scope` is not changed. For a
   vector with an odd number of elements, the dangling key's missing value
   causes an IndexOutOfBoundsException."
  [scope bindings]
  (let [n (Counted/.count bindings)]
    (loop [acc (IEditableCollection/.asTransient scope)
           i   (unchecked-int 0)]
      (if (>= i n)
        (ITransientCollection/.persistent acc)
        (recur
          (ITransientAssociative/.assoc acc
            (Indexed/.nth bindings i)
            (Indexed/.nth bindings (unchecked-inc-int i)))
          ;; Step over the key and its value.
          (unchecked-inc-int (unchecked-inc-int i)))))))


(defn extend-current-scope
  "Return the current scope map with `bindings` layered on top.

   `bindings` is a flat `[var1 val1 var2 val2 ...]` vector. Its entries are
   added to the current scope, replacing any entries with the same key. This
   only builds the new map; it does not install it as the current scope."
  [bindings]
  (let [base (current-scope)]
    (extend-scope base bindings)))



(defmacro with-scope
  "Evaluate `body` with `scope` installed as the current scope map.

   Returns the value of the last body form, or nil when `body` is empty.

   - ScopedValue carrier: `body` is wrapped in a fn and run through
     ScopedValue.Carrier.call, so the binding is visible only for the
     duration of that call.
   - ThreadLocal carrier: the previous value is saved, `scope` is set, and the
     saved value is put back in a finally block, even if `body` throws."
  [scope & body]
  (if-not use-scoped-value
    `(let [saved# (java.lang.ThreadLocal/.get carrier)]
       (try
         (java.lang.ThreadLocal/.set carrier ~scope)
         ~@body
         (finally
           (java.lang.ThreadLocal/.set carrier saved#))))
    (let [bound-carrier `(java.lang.ScopedValue/where carrier ~scope)]
      `(java.lang.ScopedValue$Carrier/.call ~bound-carrier
         (fn scope-call# [] ~@body)))))
