(ns webnf.base.unroll
  (:require [clojure.tools.logging :as log]))

(defmacro defunrolled
  "Define a function `self` whose fixed arities are generated (unrolled)
   at macro-expansion time.

   The trailing forms are the arities of a *generator* function: it is
   evaluated at macro-expansion time, called with the parameter symbols
   of each generated arity, and must return the form to use as that
   arity's body.

   Leading arguments may be an optional docstring and these options:

  - :min          -- number of arguments of the smallest arity. default 0
  - :unroll       -- number of arities to unroll on top of :min; the generated
                     fixed arities take :min up to (:min + :unroll) arguments.
                     default 8
  - :more-arities -- vararg arity, to generate a body for a vararg catch-all arity
                     this takes a single collection argument with the varargs, but,
                     if that argument is a vector of (:min + :unroll) symbols
                     followed by `& rest`, it is used directly as the parameter
                     vector and the intermediate repackaging is optimized away.
                     When omitted, no vararg arity is generated.
   Example:
    (defunrolled comp*
      :unroll 3
      :more-arities ([[f g h & args]] (apply comp* (comp* f g h)
                                         (for [pa (partition-all 3 args)]
                                           (apply comp* pa))))
      ([] `identity)
      ([f] f)
      ([f & fs] (let [arg (gensym \"arg-\")]
              `(fn* [~arg] ~(reduce #(list %2 %1) arg (reverse (cons f fs)))))))"
  {:arglists '([self & [{:keys [min unroll more-arities] :or {min 0 unroll 8}} & macro-arities]]
               [self & macro-arities])}
  [self & flags+arities]
  (let [{:keys [min unroll more-arities doc gen-fn]
         :or {min 0 unroll 8}}
        ;; Consume the optional docstring and keyword/value options; whatever
        ;; remains is the arity list of the generator function.
        (loop [flags {} [f v :as fas] flags+arities]
          (cond
            (string? f)
            (recur (assoc flags :doc f) (next fas))
            (keyword? f)
            (recur (assoc flags f v) (nnext fas))
            :else
            ;; The generator has to run during expansion, hence the eval.
            ;; Its input is the macro call's own source, not external data.
            (assoc flags :gen-fn (eval (cons 'fn* fas)))))
        ;; Parameters shared by every generated arity.
        fixed (vec (repeatedly min #(gensym "a-")))
        ;; Parameters added one at a time by the unrolled arities.
        vars (repeatedly unroll #(gensym "v-"))
        ;; Number of positional params of the largest fixed arity; a vararg
        ;; arity must declare exactly this many before its rest arg.
        n-fixed (+ min unroll)]
    `(defn ~self ~@(when doc [doc])
       ~@(for [n-v (range (inc unroll))
               :let [args (into fixed (take n-v vars))]]
           (list args (apply gen-fn args)))
       ;; Spliced, so that no arity at all is emitted (instead of a stray
       ;; nil signature) when :more-arities is not given.
       ~@(when-let [[[ma-arg :as ma-args] & ma-body] more-arities]
           (assert (= 1 (count ma-args)))
           [(if (and (vector? ma-arg)
                     (= (count ma-arg) (+ 2 n-fixed))
                     (= '& (-> ma-arg butlast last)))
              ;; The destructuring vector already has the shape of the
              ;; required parameter vector: use it as-is.
              (do
                (assert (every? symbol? ma-arg))
                (cons ma-arg ma-body))
              ;; Otherwise re-package all arguments into one sequence and
              ;; bind it to the user's argument form.
              (let [vararg (gensym "va-")]
                (when (vector? ma-arg)
                  (log/warn self :more-arities "destructuring needs to take"
                            n-fixed "args + a rest arg"
                            (str "(has " (pr-str ma-arg) ").")
                            "Falling back on generic vararg repackaging"))
                (list (-> fixed
                          (into vars)
                          (into ['& vararg]))
                      `(let [~ma-arg (list* ~@fixed ~@vars ~vararg)]
                         ~@ma-body))))]))))

(comment
 ;; Usage example: a variadic `comp` whose arities 0..3 are unrolled into
 ;; direct nested calls; longer argument lists are folded three at a time.
 ;; Every generator arity must return a *form*, hence the quoted `identity`
 ;; (returning the function object would embed it in the generated code).
 (defunrolled comp*
   :unroll 3
   :more-arities ([[f g h & args]] (apply comp* (comp* f g h)
                                          (for [pa (partition-all 3 args)]
                                            (apply comp* pa))))
   ([] `identity)
   ([f] f)
   ([f & fs] (let [arg (gensym "arg-")]
               `(fn* [~arg] ~(reduce #(list %2 %1) arg (reverse (cons f fs)))))))

 ;; => 6
 ((comp* inc dec inc dec inc) 5))
