(ns com.kurogitsune.ccgjp
  (:require [clojure.core.match :refer [match]]
            [clojure.set]
            [clojure.string]
            [serializable.fn :as s]))

(defn is-compatible
  "Returns true when category `a` is compatible with category `b`.

  - A 2-element vector whose second element is a set, e.g. [\"NP\" #{:nom}],
    is a category annotated with a feature set. When both sides are
    annotated, the bare categories must be compatible and a's feature set
    must be a subset of b's. When only `a` is annotated, its features are
    ignored.
  - Two sequential categories must have the same length and be compatible
    element by element.
  - Otherwise the two are compatible when they are equal, or when either one
    is a T variable (a map containing the key \"T\").

  The relation is not symmetric (see the annotated cases)."
  [a b]
  (let [annotated? (fn [x] (and (vector? x) (= 2 (count x)) (set? (second x))))
        t-var?     (fn [x] (and (map? x) (contains? x "T")))]
    (cond
      (and (annotated? a) (annotated? b))
      (and (is-compatible (first a) (first b))
           (clojure.set/subset? (second a) (second b)))

      (annotated? a)
      (is-compatible (first a) b)

      (and (sequential? a) (sequential? b))
      ;; Without the length check `map` stops at the shorter category, so
      ;; e.g. S/NP would be reported compatible with a longer S/NP/NP.
      (and (= (count a) (count b))
           (every? identity (map is-compatible a b)))

      :else
      (or (= a b) (t-var? a) (t-var? b)))))

(defn extract-t
  "Walks `typed` and `a` in parallel and returns a map from every T variable
  (a map containing the key \"T\") found in `typed` to the part of `a` at the
  same position, or nil when there is none.

  Sequential parts are paired element by element; when the same T variable
  occurs more than once, the later binding wins (maps are merged left to
  right)."
  [typed a]
  (cond
    (and (sequential? typed) (sequential? a))
    (->> (map extract-t typed a)
         (filter some?)
         (apply merge))

    (and (map? typed) (contains? typed "T"))
    {typed a}))

(defn count-t
  "Returns the largest T-variable index occurring anywhere in `typed`, or 0
  when it contains no T variable.

  A T variable is a map with a \"T\" key; its index is the value under that
  key. Non-empty sequential values are searched recursively. Anything else,
  including an empty sequence, counts as 0."
  [typed]
  (cond
    (and (sequential? typed) (seq typed))
    (reduce max (map count-t typed))

    (and (map? typed) (contains? typed "T"))
    (get typed "T")

    :else 0))

(defn replaced-t
  "Renumbers the T variables inside `t` away from the indices 1..n.

  A T variable {\"T\" c} with c <= n becomes {\"T\" (+ n c)}; one with c > n
  is left unchanged. Other maps keep their keys and have their values
  rewritten recursively. Sequential values are rewritten element by element
  into a vector. Anything else is returned unchanged.

  NOTE(review): an untouched variable with c > n can coincide with a shifted
  one (for n = 1, T1 becomes T2 while an existing T2 stays T2). complete-t
  re-applies this function at every nesting level, so any change to the
  shifting rule must stay stable under that repetition."
  [t n]
  (cond
    ;; T variables are recognised with the same explicit test used elsewhere
    ;; in this file, rather than relying on core.match's key-presence
    ;; semantics for a {"T" c} map pattern.
    (and (map? t) (contains? t "T"))
    (let [c (get t "T")]
      (if (<= c n) {"T" (+ n c)} t))

    ;; e.g. a substitution map {T-variable category}: only values are renumbered.
    (map? t)
    (into {} (map (fn [[k v]] [k (replaced-t v n)]) t))

    (sequential? t)
    (mapv (fn [x] (replaced-t x n)) t)

    :else t))

(defn complete-t
  "Instantiates the T variables of `typed` using the substitution map `t`
  (keyed by T variables, e.g. the result of extract-t).

  The values of `t` are first renumbered with replaced-t against the largest
  T index in `typed`, so that T variables coming from the substitution do not
  collide with the ones already in `typed`. Sequential categories are rebuilt
  element by element as vectors; the renumbering is re-applied at each
  nesting level. A T variable bound to a non-nil value is replaced by that
  value; anything else is returned unchanged."
  [typed t]
  ;; Shift substitution-side T indices so they do not collide with `typed`.
  (let [renumbered (replaced-t t (count-t typed))]
    (if (sequential? typed)
      (mapv (fn [part] (complete-t part renumbered)) typed)
      (let [bound-value (when (and (map? typed) (contains? typed "T"))
                          (get renumbered typed))]
        (if (some? bound-value) bound-value typed)))))

(defn type-lifted-normal
  "Forward type raising: turns category `a` into T/(T\\a), where T is a
  fresh T variable numbered one above the largest T index in `a`."
  [a]
  (let [fresh-t {"T" (inc (count-t a))}
        awaiting-a [fresh-t "\\" a]]
    [fresh-t "/" awaiting-a]))
(defn type-lifted-reversed
  "Backward type raising: turns category `a` into T\\(T/a), where T is a
  fresh T variable numbered one above the largest T index in `a`."
  [a]
  (let [fresh-t {"T" (inc (count-t a))}
        awaiting-a [fresh-t "/" a]]
    [fresh-t "\\" awaiting-a]))
															
(defn is-apply-func
  "Returns true when the token pair `s`, shaped
  [[phrase category sem] [phrase category sem]], can combine by function
  application.

  - Forward (X/Y  Y): the left category is [X \"/\" Y] and the right
    category is compatible with Y.
  - Backward (Y  X\\Y): the right category is a backslash functor, and the
    left category is compatible with its argument.

  Returns false for any other shape.

  NOTE(review): core.match does not fall through once a clause's pattern has
  matched. So when the left category has the forward shape, only forward
  application is checked and backward application is never considered."
  [s]
  (match [s]
    ;; Function application rule, forward direction.
    [[[_ [_ "/" wanted] _] [_ given _]]] (is-compatible given wanted)
    ;; Function application rule, backward direction.
    [[[_ given _] [_ [_ "\\" wanted] _]]] (is-compatible given wanted)
    :else false))

(defn is-merge-func
  "Returns true when the token pair `s` can combine by generalized function
  composition.

  Categories are flat vectors [X sep Y sep Z ...] with separators at odd
  indices.
  - Forward: the left category is [A \"/\" B], every separator of the right
    category is a forward slash, and B is compatible with the right
    category's first element.
  - Backward: the right category is A\\B, every separator of the left
    category is a backslash, and B is compatible with the left category's
    first element.

  Whenever the left category has the [_ \"/\" _] shape, the forward clause
  alone decides the result. Returns false for any other shape."
  [s]
  (let [all-seps? (fn [sep category]
                    (apply = sep (take-nth 2 (rest category))))]
    (match [s]
      ;; Generalized function composition, forward direction.
      [[[_ [_ "/" wanted] _] [_ right _]]]
      (and (all-seps? "/" right)
           (is-compatible wanted (first right)))
      ;; Generalized function composition, backward direction.
      [[[_ left _] [_ [_ "\\" wanted] _]]]
      (and (all-seps? "\\" left)
           (is-compatible wanted (first left)))
      :else false)))

(defn is-cross-func
  "Returns true when the token pair `s` can combine by functional crossed
  substitution.

  - Left (X/Y)\\Z with right Y'\\Z': requires Y ~ Y' and Z ~ Z'.
  - Left Y/Z with right (X\\Y')/Z': requires Y ~ Y' and Z ~ Z'.

  Here ~ is is-compatible, called with the left-hand part first. Returns
  false for any other shape."
  [s]
  (match [s]
    [[[_ [[_ "/" y-left] "\\" z-left] _] [_ [y-right "\\" z-right] _]]]
    (and (is-compatible y-left y-right)
         (is-compatible z-left z-right))

    [[[_ [y-left "/" z-left] _] [_ [[_ "\\" y-right] "/" z-right] _]]]
    (and (is-compatible y-left y-right)
         (is-compatible z-left z-right))

    :else false))

(defn mix-normal
  "Given a category [X \"/\" Y], returns (T/Y)/(T/X) as
  [[T \"/\" Y] \"/\" [T \"/\" X]], where T is a fresh T variable numbered one
  above the largest T index in `a`. combine pairs this category with its
  composition-lifted semantics (fm1/fm2)."
  [a]
  (let [fresh-t {"T" (inc (count-t a))}
        result-cat (get a 0)
        arg-cat (get a 2)]
    [[fresh-t "/" arg-cat] "/" [fresh-t "/" result-cat]]))
(defn mix-reversed
  "Given a backslash category X\\Y (as [X sep Y]), returns (T\\Y)\\(T\\X),
  where T is a fresh T variable numbered one above the largest T index in
  `a`. This is the backslash counterpart of mix-normal."
  [a]
  (let [fresh-t {"T" (inc (count-t a))}
        result-cat (get a 0)
        arg-cat (get a 2)]
    [[fresh-t "\\" arg-cat] "\\" [fresh-t "\\" result-cat]]))

(defn is-mix-func
  "Checks whether a token in the pair `s` has a category of the form
  [A slash args] where:
  - every separator in the flat vector `args` is the opposite slash, and
  - A is compatible with the first element of `args`.

  type-lifted-normal and type-lifted-reversed produce this shape. The left
  token is examined first (forward, then backward slash); the right token is
  examined only if the left category has neither shape. Returns false
  otherwise.

  NOTE(review): core.match does not fall through once a clause's pattern has
  matched. A left functor that fails the check therefore hides a right token
  that would pass; confirm this is intended."
  [s]
  (let [opposite-seps? (fn [sep args] (apply = sep (take-nth 2 (rest args))))
        raised? (fn [head sep args]
                  (and (opposite-seps? sep args)
                       (is-compatible head (first args))))]
    (match [s]
      [[[_ [head "/" args] _] _]] (raised? head "\\" args)
      [[[_ [head "\\" args] _] _]] (raised? head "/" args)
      [[_ [_ [head "/" args] _]]] (raised? head "\\" args)
      [[_ [_ [head "\\" args] _]]] (raised? head "/" args)
      :else false)))

(def max-depth
  "Exclusive upper bound on the type-raising `depth` at which combine still
  tries type raising (checked by is-not-type-complex)."
  2)
(def max-t
  "Exclusive upper bound on a category's largest T index (count-t) for
  combine to still type-raise it (checked by is-not-type-complex)."
  2)

(defn is-not-type-complex
  "Returns true when `depth` is still below max-depth and category `a`'s
  largest T index (count-t) is below max-t. count-t is only evaluated when
  the depth test passes."
  [depth a]
  (let [shallow-enough? (< depth max-depth)]
    (and shallow-enough?
         (< (count-t a) max-t))))

(defn n-args
  "Returns the number of parameters of the function object `f`, found by
  reflecting over the methods declared on its class.

  Class.getDeclaredMethods returns methods in no particular order, and a
  compiled fn class may declare methods besides `invoke`. Taking the first
  element, as before, could therefore report the wrong method. Only methods
  named \"invoke\" are considered when any exist; for a multi-arity fn the
  largest parameter count is returned. If the class declares no \"invoke\"
  method, all declared methods are considered. Throws when `f` is nil or its
  class declares no methods."
  [f]
  (let [declared (.getDeclaredMethods (class f))
        invokes (filter (fn [^java.lang.reflect.Method m] (= "invoke" (.getName m)))
                        declared)
        candidates (if (seq invokes) invokes (seq declared))]
    (apply max (map (fn [^java.lang.reflect.Method m] (count (.getParameterTypes m)))
                    candidates))))

(defn applied
  "Applies the semantic value `f1` to `f2` when `f1` is a function (fn?);
  otherwise returns the string \"Failed\"."
  [f1 f2]
  (if-not (fn? f1)
    "Failed"
    (f1 f2)))

(defn combinations
  "Cartesian product of a sequence of alternative lists.

  Returns a lazy sequence of vectors that pick one element from each list in
  `tokens-set`, varying the last position fastest. Any empty list yields an
  empty result, and so does an empty `tokens-set`."
  [tokens-set]
  (reduce (fn [prefixes choices]
            (for [prefix prefixes
                  choice choices]
              (conj (vec prefix) choice)))
          (map vector (first tokens-set))
          (rest tokens-set)))

(defn combine
  "Combines a sequence of tokens with CCG-style combinatory rules.

  `s` is a vector of tokens, each a triple [phrase category semantics].
  `span` is the string placed between the two phrases when two tokens are
  combined. `depth` counts the type-raising steps taken so far. `all-t?`
  chooses what the type-raising branch returns: every result (true), or only
  the first result for which count-t is 0 (false).

  Rules are tried in this order:
    1. function application (X/Y Y => X, Y X\\Y => X)
    2. generalized function composition
    3. functional crossed substitution (X/Y\\Z Y\\Z => X\\Z,
       Y/Z X\\Y/Z => X/Z)
    4. coordination: conjuncts separated by CONJ tokens with the same word
    5. more than two tokens: combine each adjacent pair, splice the results
       back in, and combine the resulting sequences
    6. two tokens passing is-not-type-complex: type raising and composition
       lifting, retried one level deeper

  Returns a collection of result triples (a set for rules 1-4), or nil when
  no rule applies."
  ([s] (combine s "" 0))
  ([s span depth] (combine s span depth true))
  ([s span depth all-t?]
   (let [;; The first result (scanning the candidate result collections in
         ;; order) for which count-t is 0, wrapped in a set. Returns nil
         ;; rather than #{nil} when there is none, so a failed search looks
         ;; empty to callers (e.g. the n-ary rule's fallback pass).
         some-zero-in-combined
         (fn [lz]
           (when-let [hit (not-empty
                            (vec (some (fn [combined]
                                         (some (fn [p] (if (= 0 (count-t p)) p))
                                               combined))
                                       lz)))]
             #{hit}))
         ;; Every result from every candidate result collection.
         all-t-in-combined (fn [lz] (map vec (reduce clojure.set/union lz)))
         ;; Guard for backward application. is-apply-func cannot be reused
         ;; here: when the left category has the [_ "/" _] shape it only
         ;; reports whether forward application works.
         backward-apply?
         (fn [x]
           (match [x]
             [[[_ left _] [_ [_ "\\" arg] _]]] (is-compatible left arg)
             :else false))]
     (match [s]
       ;; Function application rule, forward: X/Y  Y => X
       [([[p1 [a "/" b] f1] [p2 c f2]] :guard is-apply-func)]
       #{[(str p1 span p2) (complete-t a (extract-t b c)) (applied f1 f2)]}

       ;; Function application rule, backward: Y  X\Y => X
       [([[p1 a f1] [p2 [b "\\" c] f2]] :guard backward-apply?)]
       #{[(str p1 span p2) (complete-t b (extract-t c a)) (applied f2 f1)]}

       ;; Generalized function composition, forward.
       [([[p1 [a "/" b] f1] [p2 args f2]] :guard is-merge-func)]
       #{[(str p1 span p2)
          (let [e (extract-t b (first args))]
            (vec (concat [(complete-t a e) "/"]
                         (map (fn [x] (complete-t x e)) (rest (rest args))))))
          (s/fn [x] (if (and (fn? f1) (fn? f2)) (f1 (f2 x))))]}

       ;; Generalized function composition, backward.
       [([[p1 args f1] [p2 [a "\\" b] f2]] :guard is-merge-func)]
       #{[(str p1 span p2)
          (let [e (extract-t b (first args))]
            (vec (concat [(complete-t a e) "\\"]
                         (map (fn [x] (complete-t x e)) (rest (rest args))))))
          (s/fn [x] (if (and (fn? f1) (fn? f2)) (f2 (f1 x))))]}

       ;; Functional crossed substitution: X/Y\Z  Y\Z => X\Z. Both
       ;; constituents take their argument from the left, so the result is a
       ;; backslash category.
       [([[p1 [[a "/" b] "\\" c] f1] [p2 [d "\\" e] f2]] :guard is-cross-func)]
       #{[(str p1 span p2) [a "\\" e] (s/fn [x] (if (and (fn? f1) (fn? f2)) ((f1 x) (f2 x))))]}

       ;; Functional crossed substitution: Y/Z  X\Y/Z => X/Z
       [([[p1 [a "/" b] f1] [p2 [[c "\\" d] "/" e] f2]] :guard is-cross-func)]
       #{[(str p1 span p2) [c "/" b] (s/fn [x] (if (and (fn? f1) (fn? f2)) ((f2 x) (f1 x))))]}

       ;; Coordination: conjuncts at even positions, CONJ tokens with the same
       ;; word at odd positions. The CONJ token's semantics is applied to the
       ;; list of conjunct semantics.
       [(ps :guard
            (fn [x] (let [rels (take-nth 2 (rest x))]
                      (and
                        (> (count rels) 0)
                        (every? (fn [r] (= (second r) "CONJ")) rels)
                        (apply = (map first rels))))))]
       #{[(clojure.string/join " " (map first ps))
          (second (first ps))
          (applied (get (second ps) 2) (map (fn [x] (get x 2)) (take-nth 2 ps)))]}

       [(full :guard empty?)] nil

       ;; More than two tokens: combine each adjacent pair, splice every pair
       ;; result back into the token list (expanding alternatives with
       ;; combinations), then combine each resulting sequence. The first pass
       ;; uses all-t? = false; the all-t? = true pass runs only if the first
       ;; pass yields nothing.
       [(full :guard (fn [x] (and (every? sequential? x) (> (count x) 2))))]
       (let [pairs (map vector full (rest full))
             splice-in (fn [pair-results]
                         (map-indexed
                           (fn [i x]
                             (combinations (concat (map vector (take i full))
                                                   [(into [] x)]
                                                   (map vector (drop (+ i 2) full)))))
                           pair-results))
             run-pass (fn [all?]
                        (let [re-full (splice-in (map (fn [[l r]] (combine [l r] span depth all?))
                                                      pairs))]
                          (vec (reduce clojure.set/union
                                       (mapcat vec (map (fn [x] (map (fn [p] (combine p span depth all?)) x))
                                                        re-full))))))
             result (run-pass false)]
         (if (seq result) result (run-pass true)))

       ;; No rule matched directly: apply the type-raising rules (and
       ;; composition lifting), retrying one level deeper.
       [[[p1 (a :guard (partial is-not-type-complex depth)) f1]
         [p2 (b :guard (partial is-not-type-complex depth)) f2]]]
       (let [aln (type-lifted-normal a)
             alr (type-lifted-reversed a)
             bln (type-lifted-normal b)
             blr (type-lifted-reversed b)
             ft1 (s/fn [f] (if (fn? f) (f f1)))
             ft2 (s/fn [f] (if (fn? f) (f f2)))
             fm1 (s/fn [g] (s/fn [x] (if (and (fn? g) (fn? f1)) (g (f1 x)))))
             fm2 (s/fn [g] (s/fn [x] (if (and (fn? g) (fn? f2)) (g (f2 x)))))
             ;; Every retry goes one level deeper and keeps the caller's
             ;; all-t? mode.
             deeper (fn [left right] (combine [left right] span (+ depth 1) all-t?))
             ;; A local fn, not a nested defn: a nested defn would (re)define a
             ;; namespace-level var on every call.
             some-in-complex
             (fn [filter-func]
               (filter-func
                 (lazy-seq
                   [(deeper [p1 aln ft1] [p2 b f2])
                    (deeper [p1 alr ft1] [p2 b f2])
                    (deeper [p1 a f1] [p2 bln ft2])
                    (deeper [p1 a f1] [p2 blr ft2])
                    (match [a]
                      [[ap "/" args-a]]
                      (let [an (mix-normal a)]
                        (filter-func (lazy-seq [(deeper [p1 an fm1] [p2 b f2])
                                                (deeper [p1 an fm1] [p2 bln ft2])
                                                (deeper [p1 an fm1] [p2 blr ft2])])))
                      [[ap "\\" args-a]]
                      (let [ar (mix-reversed a)]
                        (filter-func (lazy-seq [(deeper [p1 ar fm1] [p2 b f2])
                                                (deeper [p1 ar fm1] [p2 bln ft2])
                                                (deeper [p1 ar fm1] [p2 blr ft2])])))
                      [_] nil)
                    (match [b]
                      [[bp "/" args-b]]
                      (let [bn (mix-normal b)]
                        (filter-func (lazy-seq [(deeper [p1 a f1] [p2 bn fm2])
                                                (deeper [p1 aln ft1] [p2 bn fm2])
                                                (deeper [p1 alr ft1] [p2 bn fm2])])))
                      [[bp "\\" args-b]]
                      (let [br (mix-reversed b)]
                        (filter-func (lazy-seq [(deeper [p1 a f1] [p2 br fm2])
                                                (deeper [p1 aln ft1] [p2 br fm2])
                                                (deeper [p1 alr ft1] [p2 br fm2])])))
                      [_] nil)])))]
         (if all-t?
           (some-in-complex all-t-in-combined)
           (some-in-complex some-zero-in-combined)))

       [_] nil))))

