(ns dictionary.approximate)


(defn- zero-indizes
  "Returns a vector of the row-major linear indices (row * n + col) of every
  zero entry of matrix m, where n is the length of m's first row. The matrix
  is scanned as an n x n square, row by row."
  [m]
  (let [n (count (first m))]
    (vec (for [row (range n)
               col (range n)
               :when (zero? (get-in m [row col]))]
           (+ col (* row n))))))

;; Breadth-first search for an alternating row/column path in the zero graph.
;;
;; r2c    - map row -> seq of columns reachable from that row.
;; c2r    - map column -> seq of rows reachable from that column. A -1 entry
;;          marks a column at which a path may end (find-max-zero-assignment
;;          appends -1 to every column's list; continuous-augment-path replaces
;;          the list with the single matched row once the column is matched).
;; fringe - vector of search entries [node rows-visited cols-visited path]:
;;          rows-visited / cols-visited are sets, path is the vector of nodes
;;          leading up to (not including) node.
;;
;; Returns the first path found as a vector [r0 c0 r1 c1 ... ck] that alternates
;; rows and columns and ends at a column whose c2r entry contains -1, or nil
;; once the fringe is exhausted.
;;
;; NOTE(review): visited sets are carried per entry rather than shared across
;; the whole search, so a node can be expanded once per distinct path reaching
;; it. This may get expensive on larger/denser matrices; a single global
;; visited set is the usual choice. Confirm before relying on it at scale.
(defn- find-path [r2c c2r fringe]
  (when (seq fringe)
    (let [[n rows-visited cols-visited path] (first fringe)
          ;; every row step grows rows-visited and every column step grows
          ;; cols-visited, so equal counts (as in the empty start entries)
          ;; mean this entry's node is a row; otherwise it is a column
          row? (= (count rows-visited) (count cols-visited))
          ;; dequeue from the front; new entries are appended at the back (FIFO)
          rest-fringe (subvec fringe 1)
          reachable (if row?
                      (remove cols-visited (get r2c n))
                      (remove rows-visited (get c2r n)))
          continue? (seq reachable)
          ;; with no successors these stay nil and add-fringe is empty,
          ;; so the entry is simply dropped
          npath (when continue?
                  (conj path n))
          [nrv ncv] (when continue?
                      (if row?
                        [(conj rows-visited n) cols-visited]
                        [rows-visited (conj cols-visited n)]))
          add-fringe (map #(vector % nrv ncv npath) reachable)]
      (if row?
        (recur r2c c2r (reduce conj rest-fringe add-fringe))
        ;; -1 among a column's successors: the path ends here (npath includes n)
        (if (some #(== -1 %) reachable)
          npath
          (recur r2c c2r (reduce conj rest-fringe add-fringe)))))))

(defn- continuous-augment-path
  "Grows a zero matching by repeatedly augmenting along alternating paths.

  r2c maps row -> reachable columns, c2r maps column -> reachable rows (with
  -1 present while the column is unmatched), assignments maps row -> column.
  Each round asks find-path for a path starting at any key of r2c that has no
  assignment yet. When one is found, the edges along it are re-oriented in
  r2c/c2r and its (row, column) pairs are merged into assignments. Returns
  assignments once find-path finds nothing."
  [r2c c2r assignments]
  (let [fringe (vec (for [row (keys r2c)
                          :when (not (assignments row))]
                      [row #{} #{} []]))]
    (if-let [path (find-path r2c c2r fringe)]
      (let [flip-edge (fn [[rows->cols cols->rows] [from to step]]
                        (if (even? step)
                          ;; row -> column edge: column `to` is now matched to row `from`
                          [(update-in rows->cols [from] #(remove #{to} %))
                           (assoc cols->rows to [from])]
                          ;; column -> row edge: row `to` may reach column `from` again
                          [(update-in rows->cols [to] conj from)
                           cols->rows]))
            [r2c' c2r'] (reduce flip-edge
                                [r2c c2r]
                                (map vector path (rest path) (range)))]
        (recur r2c' c2r' (into assignments (map vec (partition 2 path)))))
      assignments)))


(defn- find-max-zero-assignment
  "Builds the bipartite zero graph from the linear zero indices idz of a
  cnt-wide matrix and returns the matching produced by continuous-augment-path,
  as a map of row -> column.

  The row graph maps each row to the vector of its zero columns. The column
  graph maps each column to the vector of its zero rows followed by -1, the
  marker for a column that is still unmatched."
  [idz cnt]
  (let [cells (for [idx idz]
                [(quot idx cnt) (mod idx cnt)])
        rows->cols (into {} (for [[row row-cells] (group-by first cells)]
                              [row (mapv second row-cells)]))
        cols->rows (into {} (for [[col col-cells] (group-by second cells)]
                              [col (conj (mapv first col-cells) -1)]))]
    (continuous-augment-path rows->cols cols->rows {})))

(defn- mark-columns
  "Returns the set of columns that contain a zero in one of marked-rows.
  idz holds the row-major linear indices of the zeros of a cnt-wide matrix."
  [idz cnt marked-rows]
  (set (for [idx idz
             :when (contains? marked-rows (quot idx cnt))]
         (mod idx cnt))))

(defn- mark-rows
  "Returns the set of rows whose assigned cell lies in one of marked-cols.
  assigned-cells holds the row-major linear indices of the assigned cells of a
  cnt-wide matrix."
  [assigned-cells cnt marked-cols]
  (set (for [cell assigned-cells
             :let [col (mod cell cnt)]
             :when (contains? marked-cols col)]
         (quot cell cnt))))

(defn- cover-zeros
  "Computes lines covering the zero cells of a cnt x cnt matrix.

  idz are the row-major linear indices of the zeros; assignments is a map
  row -> column. Starting from the unassigned rows, it alternately marks the
  columns holding a zero in a marked row (mark-columns) and the rows whose
  assigned cell lies in a marked column (mark-rows), until neither set grows.
  Returns [covered-rows covered-cols]: the unmarked rows and the marked
  columns, both as sets."
  [idz assignments cnt]
  (let [assigned-cells (map (fn [[row col]] (+ (* row cnt) col)) assignments)
        step (fn [[rows cols]]
               (let [new-cols (mark-columns idz cnt rows)
                     new-rows (mark-rows assigned-cells cnt new-cols)]
                 [(into rows new-rows) (into cols new-cols)]))
        [marked-rows marked-cols]
        (loop [state [(set (remove assignments (range cnt))) #{}]]
          (let [[rows cols] state
                [rows' cols' :as next-state] (step state)]
            ;; both sets only ever grow, so equal sizes mean a fixed point
            (if (and (== (count rows') (count rows))
                     (== (count cols') (count cols)))
              next-state
              (recur next-state))))]
    [(set (remove marked-rows (range cnt))) marked-cols]))

(defn- subtract-min-from-row
  "Returns row (a vector of numbers) with its smallest entry subtracted from
  every element, so that its minimum becomes 0. A row whose minimum is zero or
  negative is returned unchanged."
  [row]
  {:pre [(every? number? row)]}
  (let [lowest (reduce min (nth row 0) (subvec row 1))]
    (if-not (pos? lowest)
      row
      (mapv (fn [x] (- x lowest)) row))))

(defn- subtract-min-from-cols
  "Subtracts from every entry of matrix m (a vector of row vectors) the
  minimum of its column, so each column's smallest entry becomes 0.

  The column count is taken from the first row. Each row is paired
  element-wise with the column minima, so a row longer than the first row is
  truncated to that width."
  [m]
  {:pre [(every? vector? m)]}
  (let [width (count (first m))
        ;; realized once as a vector and shared by every row
        col-mins (mapv (fn [i] (reduce min (map #(get % i) m)))
                       (range width))]
    (mapv (fn [row] (mapv - row col-mins)) m)))

(defn- prepare-matrix
  "Normalizes cost matrix m before the zero-matching phase. Each row is first
  passed through subtract-min-from-row, then the whole matrix through
  subtract-min-from-cols."
  [m]
  {:pre [(every? (every-pred vector? #(every? number? %)) m)]}
  (let [rows-reduced (mapv subtract-min-from-row m)]
    (subtract-min-from-cols rows-reduced)))

(defn- minimize-step
  "Iterates the Hungarian adjustment on the square matrix m (in minimize, the
  output of prepare-matrix) until the zero matching assigns every row. Returns
  that matching as a map of row -> column.

  Each round finds a zero matching. If it is incomplete, cover-zeros provides
  covering rows and columns. The smallest value among the uncovered cells is
  then subtracted from every uncovered cell and added to every cell that lies
  on both a covered row and a covered column."
  [m]
  (let [n (count (first m))
        zeros (zero-indizes m)
        matching (find-max-zero-assignment zeros n)]
    (if (>= (count matching) n)
      matching
      (let [[line-rows line-cols] (cover-zeros zeros matching n)
            open-cells (for [row (range n)
                             :when (not (line-rows row))
                             col (range n)
                             :when (not (line-cols col))]
                         [row col])
            crossings (for [row line-rows
                            col line-cols]
                        [row col])
            delta (apply min (map #(get-in m %) open-cells))
            shift (fn [mat cells op]
                    (reduce (fn [acc cell] (update-in acc cell op delta))
                            mat
                            cells))]
        (recur (-> m
                   (shift open-cells -)
                   (shift crossings +)))))))


(defn minimize
  "Solves the square assignment problem for cost matrix m, a vector of numeric
  row vectors with as many rows as the first row has entries. It applies the
  Hungarian method: row/column reduction, zero matching, and line-cover
  adjustment.

  Returns a map with
    :assignments  map of row index -> column index with one entry per row; its
                  column indices form a permutation of (range (count m))
                  (checked by the :post condition)
    :cost         sum of the ORIGINAL entries of m at those positions.

  When assertions are enabled, throws AssertionError if m is not square (row
  count vs. first-row length). It also throws if m is not a vector of numeric
  vectors, through prepare-matrix's precondition."
  [m]
  {:pre [(= (count m) (count (first m)))]
   :post [(= (count m)
             (count (:assignments %)))
          (= (range (count m)) (sort (vals (:assignments %))))]}
  (let [assignments (minimize-step (prepare-matrix m))]
    ;; the cost is read from the untouched input, not the reduced matrix
    {:cost (reduce + (map #(get-in m %) assignments))
     :assignments assignments}))

(defn minimize-rectangle
  "Solves the assignment problem for a possibly rectangular cost matrix r (a
  vector of numeric row vectors; the width is taken from the first row). It
  pads r to a square and delegates to minimize.

  If there are fewer rows than columns, rows filled with the largest entry of r
  are appended. If there are more rows than columns, every row is extended with
  that same value. A square input (including the empty matrix []) is passed
  through unchanged. When r contains no entries at all, 0 is used as the
  padding value.

  NOTE(review): the result is minimize's result on the padded matrix. Its
  :assignments therefore include entries for padding rows/columns, and its
  :cost includes (largest entry) * |rows - cols| contributed by the padding.
  Confirm callers expect that."
  [r]
  (let [rs (count r)
        cs (count (first r))
        ;; Only evaluated when padding is actually needed: (reduce max ...) on
        ;; an empty seq throws, which previously broke (minimize-rectangle [])
        ;; and matrices whose rows are empty.
        pad-value (fn []
                    (if-let [entries (seq (apply concat r))]
                      (reduce max entries)
                      0))
        m (cond
            (== rs cs) r
            (< rs cs) (let [v (pad-value)]
                        (into r (repeat (- cs rs) (vec (repeat cs v)))))
            :else (let [v (pad-value)]
                    (mapv #(into % (repeat (- rs cs) v)) r)))]
    (minimize m)))