(ns materia.middleware-set
  (:require [clojure.tools.namespace.dependency :as dep]
            [ring.util.request :as req]
            [ronda.routing :as ronda])
  (:import [clojure.tools.namespace.dependency
            DependencyGraphUpdate DependencyGraph]))

(defprotocol Wrappable
  "Something that can be wrapped around a handler."
  (wrap [this h]
    "Returns the handler obtained by applying `this` around handler `h`
    (implementations may return `h` itself)."))

(defprotocol Identifiable
  "Something that has a key identifying it inside a middleware set."
  (get-key [this]
    "Returns the key under which this value is stored and referenced as a
    dependency-graph node."))

;; Any Clojure fn can be used directly as a middleware: wrapping a handler
;; simply calls the fn with that handler.
(extend clojure.lang.AFn
  Wrappable
  {:wrap (fn [middleware-fn handler]
           (middleware-fn handler))})

;; Fns and keywords serve as their own key.
(extend clojure.lang.AFn     Identifiable {:get-key identity})
(extend clojure.lang.Keyword Identifiable {:get-key identity})

;; A plain middleware identified by `tag`: wrapping calls `f` with the handler
;; followed by the extra `args`.
(defrecord Middleware [tag f args]
  Identifiable
  (get-key [_]
    tag)

  Wrappable
  (wrap [_ handler]
    (apply f handler args)))

;; Keeps a middleware's key in the set while leaving the handler untouched;
;; `f` and `args` are carried along but never used by `wrap`.
(defrecord DisabledMiddleware [tag f args]
  Identifiable
  (get-key [_]
    tag)

  Wrappable
  (wrap [_ handler]
    handler))

;; A middleware that is applied per request only when `(pred-fn request)` is
;; truthy. `f` is applied to the handler once, at wrap time; on each request
;; either the wrapped or the bare handler is invoked.
(defrecord ConditionalMiddleware [pred-fn tag f args]
  Identifiable
  (get-key [_]
    tag)

  Wrappable
  (wrap [_ handler]
    (let [wrapped (apply f handler args)]
      (fn conditional-handler [request]
        (let [target (if (pred-fn request) wrapped handler)]
          (target request))))))

(defprotocol MiddlewareSet
  "A collection of middlewares that can be extended and read back in order."
  (add [this middleware] [this middleware deps]
    "Returns a new set with `middleware` added, optionally depending on the
    keys in `deps`.")
  (get-all [this]
    "Returns the contained middlewares (ordering is up to the implementation)."))

;; A middleware set backed by a dependency graph.
;;   m         - map from middleware key (see get-key) to the middleware
;;   dep-graph - clojure.tools.namespace dependency graph over those keys
(defrecord Middlewares [m dep-graph]
  ;; Read-only graph queries are forwarded unchanged to `dep-graph`. Unlike the
  ;; update operations below, `node` is NOT passed through get-key, so callers
  ;; must supply the key rather than the middleware.
  DependencyGraph
  (immediate-dependencies [this node]
    (dep/immediate-dependencies dep-graph node))
  (immediate-dependents [this node]
    (dep/immediate-dependents dep-graph node))
  (transitive-dependencies [this node]
    (dep/transitive-dependencies dep-graph node))
  (transitive-dependents [this node]
    (dep/transitive-dependents dep-graph node))
  (nodes [this]
    (dep/nodes dep-graph))

  ;; Update operations accept middlewares or keys (both ends go through
  ;; get-key). The new graph is assoc'ed onto `this` rather than passed to a
  ;; fresh constructor call, so record metadata and extra keys are preserved.
  DependencyGraphUpdate
  (depend [this node dep]
    (assoc this :dep-graph (dep/depend dep-graph (get-key node) (get-key dep))))
  (remove-edge [this node dep]
    (assoc this :dep-graph (dep/remove-edge dep-graph (get-key node) (get-key dep))))
  (remove-all [this node]
    (assoc this :dep-graph (dep/remove-all dep-graph (get-key node))))
  (remove-node [this node]
    (assoc this :dep-graph (dep/remove-node dep-graph (get-key node))))

  Wrappable
  (wrap [this h]
    ;; Fold from the last entry of get-all to the first, so the first entry
    ;; ends up as the outermost wrapper. Entries that are not Wrappable (e.g.
    ;; keywords, which are Identifiable but not Wrappable here) are skipped.
    (reduce (fn [acc middleware]
              (if (satisfies? Wrappable middleware)
                (wrap middleware acc)
                acc))
            h
            (reverse (get-all this))))

  MiddlewareSet
  (add [this middleware]
    (add this middleware []))
  (add [this middleware deps]
    ;; `deps` are used as graph keys as given (no get-key is applied).
    (let [key  (get-key middleware)
          ;; Every middleware also depends on ::root, so one added without any
          ;; explicit deps still becomes a graph node and shows up in get-all.
          deps (conj deps ::root)]
      (assoc this
             :m         (assoc m key middleware)
             :dep-graph (reduce #(dep/depend %1 key %2) dep-graph deps))))
  (get-all [this]
    ;; Middlewares in dep/topo-sort order (presumably dependencies before
    ;; dependents). Graph keys with no entry in `m` (::root, and deps that were
    ;; never added themselves) map to nil and are dropped.
    (filter some? (map #(get m %) (dep/topo-sort dep-graph)))))

(defn middlewares
  "Returns an empty middleware set backed by an empty dependency graph."
  []
  (Middlewares. {} (dep/graph)))

(defn middleware
  "Builds a Middleware keyed by `tag` that wraps a handler via
  `(apply f handler args)`."
  [tag f & args]
  (Middleware. tag f args))

(defn disabled-middleware
  "Builds a DisabledMiddleware keyed by `tag`; it keeps `f` and `args` but
  leaves the handler unchanged when wrapping."
  [tag f & args]
  (DisabledMiddleware. tag f args))

(defn conditional-middleware
  "Builds a ConditionalMiddleware keyed by `tag` that only routes a request
  through `(apply f handler args)` when `(pred-fn request)` is truthy."
  [pred-fn tag f & args]
  (ConditionalMiddleware. pred-fn tag f args))

(defn routed-middleware
  "Conditional middleware restricted to certain requests. If `re-or-endpoint`
  is a regex, it must fully match (re-matches) the request's path info; if it
  is a keyword, it must equal the request's ronda endpoint."
  [re-or-endpoint tag f & args]
  {:pre [(or (instance? java.util.regex.Pattern re-or-endpoint)
             (keyword? re-or-endpoint))]}
  (let [applies? (cond
                   (instance? java.util.regex.Pattern re-or-endpoint)
                   (fn [request]
                     (when-let [path (req/path-info request)]
                       (re-matches re-or-endpoint path)))

                   (keyword? re-or-endpoint)
                   (fn [request]
                     (= (ronda/endpoint request) re-or-endpoint)))]
    (apply conditional-middleware applies? tag f args)))

(defn add-middleware
  "Adds a middleware to `bag`.
  - `(add-middleware bag f)` / `(add-middleware bag f deps)` add `f` as is.
  - `(add-middleware bag key f deps & args)` first builds
    `(middleware key f & args)`."
  ([bag f] (add bag f))
  ([bag f deps] (add bag f deps))
  ;; With no extra args, `(apply middleware key f nil)` is `(middleware key f)`,
  ;; so a single variadic arity covers both the 4-arg and n-arg cases.
  ([bag key f deps & args]
   (add bag (apply middleware key f args) deps)))

(defn append
  "Adds middleware `m` to `bag`, making it depend on every node currently in
  `bag`'s dependency graph other than `m`'s own key."
  [bag m]
  (let [own-key (get-key m)]
    ;; When `m` is already present (re-appended, e.g. to replace it), its own
    ;; key is among the graph nodes; depending on it would make `m` depend on
    ;; itself, i.e. a circular dependency.
    (add bag m (disj (set (dep/nodes bag)) own-key))))

(defn prepend
  "Adds middleware `m` to `bag` and makes every middleware already in `bag`
  (except one registered under `m`'s own key) depend on `m`."
  [bag m]
  (let [own-key (get-key m)]
    (reduce (fn [acc existing]
              ;; The set's `depend` applies get-key to both arguments itself,
              ;; so pass the middleware unchanged. Calling get-key here as
              ;; well would hand it the bare tag, which fails for tags that
              ;; are not keywords or fns (e.g. symbols or strings).
              (dep/depend acc existing m))
            (add bag m)
            ;; Skip an existing entry with m's own key (re-prepending), which
            ;; would otherwise become a self-dependency.
            (remove #(= own-key (get-key %)) (get-all bag)))))

(defn after
  "Adds `m` to `bag` and makes it depend on `anchor`."
  [bag anchor m]
  (dep/depend (add bag m) m anchor))

(defn before
  "Adds `m` to `bag` and makes `anchor` depend on it."
  [bag anchor m]
  (dep/depend (add bag m) anchor m))

(defn between
  "Adds `m` to `bag` so that `m` depends on `from` and `to` depends on `m`."
  [bag from to m]
  (let [with-m     (add bag m)
        after-from (dep/depend with-m m from)]
    (dep/depend after-from to m)))

(defn add-linearly
  "Adds the middlewares `ms` to `bag` as a chain. The first depends on
  `init-deps` (nil when omitted), and every later one depends on the key of
  its predecessor."
  [bag ms & [init-deps]]
  (let [predecessor-deps (mapv (fn [mw] [(get-key mw)]) ms)
        deps-per-mw      (cons init-deps predecessor-deps)]
    (reduce (fn [acc [mw deps]]
              (add acc mw deps))
            bag
            (map vector ms deps-per-mw))))

(defn disable
  "Returns a DisabledMiddleware with `m`'s key, keeping the original `m` in
  its metadata under ::disabled. A DisabledMiddleware is returned as is."
  [m]
  (if-not (instance? DisabledMiddleware m)
    (let [placeholder (map->DisabledMiddleware {:tag (get-key m)})]
      (with-meta placeholder {::disabled m}))
    m))

(defn enable
  "Reverses `disable`.
  - If `m` carries the original under ::disabled metadata, returns it.
  - A metadata-less DisabledMiddleware (e.g. built with disabled-middleware)
    is turned into a Middleware with the same tag, f and args.
  - Any other value (already enabled middlewares, fns, keywords, nil) is
    returned unchanged."
  [m]
  (if-let [origin (::disabled (meta m))]
    origin
    ;; Only DisabledMiddleware is converted. Converting everything would turn
    ;; an enabled ConditionalMiddleware into an unconditional Middleware, and
    ;; nil (a tag missing from the set) into a Middleware whose f is nil.
    (if (instance? DisabledMiddleware m)
      (map->Middleware m)
      m)))

(defn disable-middlewares
  "Applies `disable` to the entry stored under each of `tags` in the set's
  `:m` map."
  [ms tags]
  (reduce #(update-in %1 [:m %2] disable) ms tags))

(defn enable-middlewares
  "Applies `enable` to the entry stored under each of `tags` in the set's
  `:m` map."
  [ms tags]
  (reduce #(update-in %1 [:m %2] enable) ms tags))
