(ns org.soulspace.qclojure.domain.math.clojure.complex-linear-algebra
  "Linear algebra operations for complex numbers in Clojure.
   
   This namespace provides implementations of common linear algebra operations
   for complex numbers represented in a structure-of-arrays (SoA) format.
   It includes matrix addition, multiplication, inversion, eigen-decomposition,
   and other utilities necessary for quantum computing applications.
   
   Complex Number Representation:
   - Complex numbers are represented as maps with :real and :imag keys.
   - Vectors and matrices of complex numbers are represented as maps with
     :real and :imag keys containing vectors or matrices of real numbers.
   
   Linear Algebra Operations:
   - Matrix addition, subtraction, scaling, multiplication
   - Matrix-vector products
   - Kronecker products
   - Transpose and conjugate transpose
   - Inner products
   - Hermitian checks
   - Solving linear systems
   - Matrix inversion
   - Spectral norm computation
   - Eigen-decomposition for Hermitian matrices"
  (:require
   [clojure.math :as math]))

;;;
;;; Configuration and utilities
;;;
(def ^:const ^double default-tolerance
  "Default absolute tolerance for the numerical predicates and
  decompositions in this namespace when the caller supplies none."
  1e-12)

;;;
;;; Complex number utilities and predicates
;;;
(defn complex-scalar?
  "True when x is a complex scalar: a map with numeric :real and :imag entries."
  [x]
  (boolean
   (when (map? x)
     (let [{re :real im :imag} x]
       (and (contains? x :real)
            (contains? x :imag)
            (number? re)
            (number? im))))))

(defn complex-vector?
  "True when v is a complex vector in SoA form: a map whose :real and :imag
  entries are vectors of equal length."
  [v]
  (if-not (and (map? v) (contains? v :real) (contains? v :imag))
    false
    (let [re (:real v)
          im (:imag v)]
      (and (vector? re)
           (vector? im)
           (== (count re) (count im))))))

(defn complex-matrix?
  "True when m is a complex matrix in SoA form: a map whose :real and :imag
  entries are vectors of row vectors with the same number of rows, and with
  row i of :real as long as row i of :imag."
  [m]
  (let [rows? (fn [x] (and (vector? x) (every? vector? x)))]
    (boolean
     (and (map? m)
          (contains? m :real)
          (contains? m :imag)
          (let [re (:real m)
                im (:imag m)]
            (and (rows? re)
                 (rows? im)
                 (= (count re) (count im))
                 (= (mapv count re) (mapv count im))))))))

(defn make-complex
  "Build a complex scalar {:real r :imag i}, coercing both parts to double."
  [r i]
  (let [re (double r)
        im (double i)]
    {:real re :imag im}))

(defn ensure-complex-scalar
  "Coerce x to a complex scalar map.
  Complex scalars are returned unchanged and real numbers become x + 0i.
  Anything else raises ex-info."
  [x]
  (if (complex-scalar? x)
    x
    (if (number? x)
      (make-complex x 0.0)
      (throw (ex-info "Unsupported scalar for complex coercion" {:value x})))))

(defn ensure-complex-vector
  "Coerce v to a complex vector in SoA form.
  Complex vectors are returned unchanged. A plain vector of real numbers gets
  its entries coerced to double as :real and an all-zero :imag part.
  Anything else raises ex-info."
  [v]
  (cond
    (complex-vector? v) v
    (vector? v) (let [re (mapv double v)]
                  {:real re :imag (vec (repeat (count v) 0.0))})
    :else (throw (ex-info "Unsupported vector for complex coercion" {:value v}))))

(defn ensure-complex-matrix
  "Coerce m to a complex matrix in SoA form.
  Complex matrices are returned unchanged. A plain vector of real rows gets
  its entries coerced to double as :real and a zero :imag part of the same
  shape. Anything else raises ex-info."
  [m]
  (cond
    (complex-matrix? m) m
    (vector? m) {:real (mapv #(mapv double %) m)
                 :imag (mapv #(vec (repeat (count %) 0.0)) m)}
    :else (throw (ex-info "Unsupported matrix for complex coercion" {:value m}))))


;;;
;;; Matrix algebra helper functions
;;;
(defn identity-matrix
  "Return the n x n identity matrix as a vector of row vectors of doubles."
  [n]
  (let [idx (range n)]
    (mapv (fn [i] (mapv (fn [j] (if (== i j) 1.0 0.0)) idx)) idx)))

(defn matrix-shape
  "Return [rows cols] of A. For a complex SoA matrix the shape of its :real
  part is used. cols is the length of the first row (0 when there are no rows)."
  [A]
  (let [rows (if (complex-matrix? A) (:real A) A)]
    [(count rows) (count (first rows))]))

(defn close-matrices?
  "True when matrices A and B have the same shape and agree entry-wise within
  the absolute tolerance tol.

  Each argument may be a real matrix (vector of row vectors) or a complex
  matrix in SoA form. A real matrix is treated as having an all-zero
  imaginary part, so both real and imaginary parts are compared. Matrices of
  different shape (as reported by `matrix-shape`) are never close."
  [A B tol]
  (let [split (fn [M] (if (complex-matrix? M) [(:real M) (:imag M)] [M nil]))
        [Ar Ai] (split A)
        [Br Bi] (split B)
        [r c :as shape] (matrix-shape A)
        close? (fn [x y] (< (abs (- (double x) (double y))) tol))]
    (and (= shape (matrix-shape B))
         (every? (fn [[i j]]
                   (and (close? (get-in Ar [i j]) (get-in Br [i j]))
                        ;; a real matrix has an implicit zero imaginary part
                        (close? (get-in Ai [i j] 0.0) (get-in Bi [i j] 0.0))))
                 (for [i (range r) j (range c)] [i j])))))

;;
;; Real matrix operations
;;
(defn real-add
  "Entry-wise sum A + B of two real matrices (vectors of row vectors), with
  every entry coerced to double. As with `mapv`, the result is truncated to
  the common number of rows and row length."
  [A B]
  (letfn [(add-row [row-a row-b]
            (mapv (fn [a b] (+ (double a) (double b))) row-a row-b))]
    (mapv add-row A B)))

(defn real-sub
  "Entry-wise difference A - B of two real matrices, with every entry coerced
  to double (truncated to the common shape, as with `mapv`)."
  [A B]
  (vec (for [[row-a row-b] (map vector A B)]
         (mapv (fn [a b] (- (double a) (double b))) row-a row-b))))

(defn real-scale
  "Multiply every entry of real matrix A by scalar a. Entries become doubles."
  [A a]
  (let [factor (double a)
        scale-row (fn [row] (mapv (fn [x] (* factor (double x))) row))]
    (mapv scale-row A)))

(defn real-mul
  "Matrix product A * B of real matrices given as vectors of row vectors.

  The result has (count A) rows and (count (first B)) columns. Entry (i, j)
  is the dot product of row i of A with column j of B. Numbers are combined
  with `*`/`+` without coercion, so integer inputs give integer results."
  [A B]
  (let [n (count (first B))
        ;; Columns of B, realized once as a vector so every lookup is O(1);
        ;; nth on a lazy seq would rescan it for each entry. A B without rows
        ;; has no columns, and (apply mapv vector []) would throw.
        bt (if (seq B) (apply mapv vector B) [])]
    (mapv (fn [row]
            (mapv (fn [j] (reduce + (map * row (nth bt j)))) (range n)))
          A)))

(defn real-transpose
  "Transpose of a real matrix given as a vector of row vectors.
  A matrix without rows transposes to [] (mapv needs at least one
  collection, so it cannot be applied to zero rows). Ragged rows are
  truncated to the shortest row, as with `mapv`."
  [A]
  (if (seq A)
    (apply mapv vector A)
    []))

(defn real-kronecker
  "Kronecker (tensor) product A ⊗ B of two real matrices. With B of shape
  br x bc, entry (i*br + k, j*bc + l) is A[i][j] * B[k][l] as a double."
  [A B]
  (let [[a-rows a-cols] (matrix-shape A)
        [b-rows b-cols] (matrix-shape B)
        col-pairs (for [j (range a-cols) l (range b-cols)] [j l])]
    (vec
     (for [i (range a-rows)
           k (range b-rows)]
       (mapv (fn [[j l]] (* (double (get-in A [i j])) (double (get-in B [k l]))))
             col-pairs)))))


;;
;; Complex (SoA) matrix operations
;;
(defn matrix-add
  "Entry-wise sum A + B of two complex SoA matrices."
  [A B]
  (let [part-sum (fn [k] (real-add (k A) (k B)))]
    {:real (part-sum :real)
     :imag (part-sum :imag)}))

(defn matrix-subtract
  "Entry-wise difference A - B of two complex SoA matrices."
  [A B]
  (let [re (real-sub (:real A) (:real B))
        im (real-sub (:imag A) (:imag B))]
    {:real re :imag im}))

(defn matrix-scale
  "Scale complex SoA matrix A by scalar a.

  a may be a complex scalar {:real ar :imag ai} or a real number.
  - Complex scalar: each entry x + iy becomes (ar*x - ai*y) + i(ar*y + ai*x).
  - Real number: both parts are multiplied by (double a)."
  [A a]
  (if (complex-scalar? a)
    (let [{ar :real ai :imag} a
          Ar (:real A) Ai (:imag A)
          ;; (ar + i ai)(x + i y) = (ar x - ai y) + i (ar y + ai x)
          re (mapv (fn [xr xi]
                     (mapv (fn [x y] (- (* ar x) (* ai y))) xr xi)) Ar Ai)
          im (mapv (fn [xr xi]
                     (mapv (fn [x y] (+ (* ar y) (* ai x))) xr xi)) Ar Ai)]
      {:real re :imag im})
    ;; a is not a complex scalar here, so it is used as a real factor
    (let [a-real (double a)]
      {:real (real-scale (:real A) a-real)
       :imag (real-scale (:imag A) a-real)})))

(defn matrix-multiply
  "Complex matrix product A * B in SoA form. With A = Ar + i Ai and
  B = Br + i Bi:
    real = Ar Br - Ai Bi,  imag = Ar Bi + Ai Br."
  [A B]
  (let [[Ar Ai] ((juxt :real :imag) A)
        [Br Bi] ((juxt :real :imag) B)]
    {:real (real-sub (real-mul Ar Br) (real-mul Ai Bi))
     :imag (real-add (real-mul Ar Bi) (real-mul Ai Br))}))

(defn matrix-vector-product
  "Complex matrix-vector product A x.

  A is a complex SoA matrix and x a complex SoA vector. Entry i of the
  result is sum_j A_ij x_j, i.e. with A = Ar + i Ai and x = xr + i xi:
    real_i = Ar_i . xr - Ai_i . xi,  imag_i = Ar_i . xi + Ai_i . xr
  The result has one entry per row of A. Empty rows and an empty x give
  zero entries."
  [A x]
  (let [xr (:real x) xi (:imag x)
        ;; dot product summed left to right (0 for empty inputs)
        dot (fn [row v] (reduce + (map * row v)))]
    {:real (mapv (fn [ar ai] (- (dot ar xr) (dot ai xi))) (:real A) (:imag A))
     :imag (mapv (fn [ar ai] (+ (dot ar xi) (dot ai xr))) (:real A) (:imag A))}))

(defn hadamard-product
  "Entry-wise (Hadamard) product of two complex SoA matrices:
  (a + ib)(c + id) = (ac - bd) + i(ad + bc) for each entry."
  [A B]
  (let [entrywise (fn [f]
                    (mapv (fn [ra ia rb ib] (mapv f ra ia rb ib))
                          (:real A) (:imag A) (:real B) (:imag B)))]
    {:real (entrywise (fn [a b c d] (- (* a c) (* b d))))
     :imag (entrywise (fn [a b c d] (+ (* a d) (* b c))))}))

(defn kronecker-product
  "Kronecker product A ⊗ B of two complex SoA matrices, assembled from real
  Kronecker products: real = Ar⊗Br - Ai⊗Bi, imag = Ar⊗Bi + Ai⊗Br."
  [A B]
  (let [kron (fn [ka kb] (real-kronecker (ka A) (kb B)))]
    {:real (real-sub (kron :real :real) (kron :imag :imag))
     :imag (real-add (kron :real :imag) (kron :imag :real))}))

(defn transpose
  "Plain (non-conjugating) transpose of a complex SoA matrix."
  [A]
  (let [t (fn [k] (real-transpose (k A)))]
    {:real (t :real)
     :imag (t :imag)}))

(defn conjugate-transpose
  "Conjugate (Hermitian) transpose A^H of a complex SoA matrix: both parts
  are transposed and the imaginary part is negated."
  [A]
  (let [re-t (real-transpose (:real A))
        im-t (real-transpose (:imag A))]
    {:real re-t
     :imag (real-scale im-t -1.0)}))

(defn inner-product
  "Complex inner product <x|y> = sum_k conj(x_k) y_k of two SoA vectors. It
  is conjugate-linear in x. Returns a complex scalar with double parts."
  [x y]
  (let [xr (:real x) xi (:imag x) yr (:real y) yi (:imag y)
        sum-of (fn [term] (reduce + (map term xr yr xi yi)))]
    (make-complex
     (sum-of (fn [xr-k yr-k xi-k yi-k] (+ (* xr-k yr-k) (* xi-k yi-k))))
     (sum-of (fn [xr-k yr-k xi-k yi-k] (- (* xr-k yi-k) (* xi-k yr-k)))))))

(defn hermitian?
  "True when complex SoA matrix A is square and Hermitian within tol. For
  every i <= j this requires |Re a_ij - Re a_ji| < tol and
  |Im a_ij + Im a_ji| < tol; on the diagonal, |2 Im a_ii| < tol. Squareness
  is judged from the row count and the length of the first row."
  [A tol]
  (let [re (:real A)
        im (:imag A)
        n (count re)
        pair-ok? (fn [[i j]]
                   (and (< (abs (double (- (get-in re [i j]) (get-in re [j i])))) tol)
                        (< (abs (double (+ (get-in im [i j]) (get-in im [j i])))) tol)))]
    (if (= n (count (first re)))
      (every? pair-ok? (for [i (range n) j (range i n)] [i j]))
      false)))

;; Complex Gaussian elimination helpers
;;
;; These implement Gaussian elimination with partial pivoting for a single
;; right-hand-side vector, and Gauss-Jordan inversion, for complex matrices
;; represented in SoA form.
;; They are intentionally straightforward (no blocking / BLAS) and target
;; small to medium matrix sizes.

(defn forward-elimination
  "Forward elimination with partial pivoting for the complex system A x = b.

  Ar/Ai are the real/imaginary parts of the square matrix A (vectors of row
  vectors); br/bi are those of the right-hand side b. For each column k:
  - the row i >= k with the largest |a_ik| is swapped into row k;
  - row k and b_k are divided by the pivot a_kk;
  - a_ik * row_k (and a_ik * b_k) is subtracted from every row i below.

  Returns [Ar Ai br bi] for the transformed system. It is upper triangular
  with a unit diagonal (up to round-off) and is consumed by
  `back-substitution` in `solve-linear`.

  Throws ex-info when the squared pivot magnitude is exactly zero."
  [Ar Ai br bi]
  (let [n (count Ar)]
    ;; k is the current pivot column; matrix and RHS parts are threaded through the loop.
    (loop [k 0 Ar Ar Ai Ai br br bi bi]
      (if (= k n)
        [Ar Ai br bi]
        ;; Partial pivoting: pick the row (>= k) with the largest |a_ik| and
        ;; swap it into row k in all four parts.
        (let [pivot (apply max-key #(math/hypot (get-in Ar [% k]) (get-in Ai [% k])) (range k n))
              swap-row (fn [M] (if (not= pivot k) (-> M (assoc k (M pivot)) (assoc pivot (M k))) M))
              Ar (swap-row Ar) Ai (swap-row Ai)
              br (if (not= pivot k) (-> br (assoc k (br pivot)) (assoc pivot (br k))) br)
              bi (if (not= pivot k) (-> bi (assoc k (bi pivot)) (assoc pivot (bi k))) bi)
              akk-r (get-in Ar [k k]) akk-i (get-in Ai [k k])
              ;; denom = |a_kk|^2; only an exactly zero value is rejected below
              denom (+ (* akk-r akk-r) (* akk-i akk-i))]
          (when (zero? denom) (throw (ex-info "Singular complex matrix (zero pivot)" {:k k})))
          ;; Divide row k and b_k by the pivot:
          ;; (x + iy)/(p + iq) = ((xp + yq) + i(yp - xq)) / (p^2 + q^2)
          (let [rowr (Ar k) rowi (Ai k)
                Ar (assoc Ar k (mapv (fn [ar ai] (/ (+ (* ar akk-r) (* ai akk-i)) denom)) rowr rowi))
                Ai (assoc Ai k (mapv (fn [ar ai] (/ (- (* ai akk-r) (* ar akk-i)) denom)) rowr rowi))
                brk (br k) bik (bi k)
                nr (/ (+ (* brk akk-r) (* bik akk-i)) denom)
                ni (/ (- (* bik akk-r) (* brk akk-i)) denom)
                br (assoc br k nr) bi (assoc bi k ni)
                ;; Eliminate column k below the pivot: row_i -= f * row_k and
                ;; b_i -= f * b_k, with f = a_ik (complex multiplication).
                ;; Rows whose a_ik is already exactly zero are skipped.
                [Ar Ai br bi] (loop [i (inc k) Ar Ar Ai Ai br br bi bi]
                                (if (= i n)
                                  [Ar Ai br bi]
                                  (let [f-r (get-in Ar [i k]) f-i (get-in Ai [i k])]
                                    (if (and (zero? f-r) (zero? f-i))
                                      (recur (inc i) Ar Ai br bi)
                                      (let [rowk-r (Ar k) rowk-i (Ai k)
                                            Ar-rowi (vec (map-indexed (fn [j arik]
                                                                        (- arik (- (* f-r (rowk-r j)) (* f-i (rowk-i j))))) (Ar i)))
                                            Ai-rowi (vec (map-indexed (fn [j aiik]
                                                                        (- aiik (+ (* f-r (rowk-i j)) (* f-i (rowk-r j))))) (Ai i)))
                                            br (assoc br i (- (br i) (- (* f-r (br k)) (* f-i (bi k)))))
                                            bi (assoc bi i (- (bi i) (+ (* f-r (bi k)) (* f-i (br k)))))
                                            Ar (assoc Ar i Ar-rowi) Ai (assoc Ai i Ai-rowi)]
                                        (recur (inc i) Ar Ai br bi))))))]
            (recur (inc k) Ar Ai br bi)))))))

(defn back-substitution
  "Back substitution for an upper-triangular complex system U x = b.

  Ar/Ai are the real/imaginary parts of U (vectors of row vectors); br/bi are
  those of b. Only the diagonal and the entries above it are read. Rows are
  solved from last to first:

    x_i = (b_i - sum_{j>i} U_ij x_j) / U_ii

  The complex product uses (a + ib)(x + iy) = (ax - by) + i(ay + bx).

  Returns x as a complex vector {:real [...] :imag [...]} of doubles. The
  diagonal is not checked here; callers (e.g. `forward-elimination`) are
  expected to have rejected zero pivots."
  [Ar Ai br bi]
  (let [n (count Ar)
        xr (double-array n)
        xi (double-array n)]
    (loop [i (dec n)]
      (when (>= i 0)
        (let [js (range (inc i) n)
              ;; real part of sum_j U_ij x_j: a x - b y
              sumr (reduce + 0.0 (map (fn [j] (- (* (get-in Ar [i j]) (aget xr j))
                                                 (* (get-in Ai [i j]) (aget xi j))))
                                      js))
              ;; imaginary part of sum_j U_ij x_j: a y + b x
              sumi (reduce + 0.0 (map (fn [j] (+ (* (get-in Ar [i j]) (aget xi j))
                                                 (* (get-in Ai [i j]) (aget xr j))))
                                      js))
              arii (get-in Ar [i i]) aiii (get-in Ai [i i])
              den (+ (* arii arii) (* aiii aiii))
              nr (- (br i) sumr)
              ni (- (bi i) sumi)]
          ;; x_i = (nr + i ni) / (arii + i aiii). Values are coerced to double
          ;; so exact (integer/ratio) intermediates are stored as doubles.
          (aset xr i (double (/ (+ (* nr arii) (* ni aiii)) den)))
          (aset xi i (double (/ (- (* ni arii) (* nr aiii)) den)))
          (recur (dec i)))))
    {:real (vec xr) :imag (vec xi)}))

(defn solve-linear
  "Solve the complex linear system A x = b.

  A is a square complex SoA matrix, or a plain real matrix (vector of row
  vectors), which is promoted with `ensure-complex-matrix`. Before this,
  (:real A) was nil for a plain matrix and an empty result was returned
  silently. b is a complex SoA vector or a plain real vector, promoted with
  `ensure-complex-vector`.

  Uses `forward-elimination` (partial pivoting) followed by
  `back-substitution`. Returns x as a complex vector {:real [...] :imag [...]}.

  Throws ex-info for unsupported inputs or a zero pivot."
  [A b]
  (let [A* (if (map? A) A (ensure-complex-matrix A))
        b* (ensure-complex-vector b)
        [Ar Ai br bi] (forward-elimination (mapv vec (:real A*)) (mapv vec (:imag A*))
                                           (vec (:real b*)) (vec (:imag b*)))]
    (back-substitution Ar Ai br bi)))

(defn inverse
  "Inverse of a complex SoA matrix A via Gauss-Jordan elimination with
  partial pivoting.

  The identity matrix is carried along as an augmented right-hand side
  (Ir/Ii). For each column k:
  - the row i >= k with the largest |a_ik| is swapped into row k, in both A
    and the augmented part;
  - row k is divided by the pivot;
  - column k is eliminated from every other row.
  When all columns are processed, the augmented part holds A^-1, returned as
  {:real [...] :imag [...]}.

  A is assumed to be square; this is not validated here.
  NOTE(review): A must be in SoA form. For a plain vector of rows, (:real A)
  is nil, so n = 0 and {:real [] :imag []} is returned.

  Throws ex-info if a pivot is exactly zero (singular matrix)."
  [A]
  (let [Ar (mapv vec (:real A)) Ai (mapv vec (:imag A))
        ;; augmented part starts as the identity with a zero imaginary part
        n (count Ar) Ir (identity-matrix n) Ii (vec (repeat n (vec (repeat n 0.0))))]
    (loop [k 0 Ar Ar Ai Ai Ir Ir Ii Ii]
      (if (= k n)
        {:real Ir :imag Ii}
        ;; Partial pivoting: pick the row (>= k) with the largest |a_ik| and
        ;; swap it into row k in A and in the augmented part.
        (let [pivot (apply max-key #(math/hypot (get-in Ar [% k]) (get-in Ai [% k])) (range k n))
              swap-row (fn [M] (if (not= pivot k) (-> M (assoc k (M pivot)) (assoc pivot (M k))) M))
              Ar (swap-row Ar) Ai (swap-row Ai) Ir (swap-row Ir) Ii (swap-row Ii)
              akk-r (get-in Ar [k k]) akk-i (get-in Ai [k k])
              ;; denom = |a_kk|^2; only an exactly zero value is rejected below
              denom (+ (* akk-r akk-r) (* akk-i akk-i))]
          (when (zero? denom) (throw (ex-info "Singular complex matrix (inverse)" {:k k})))
          ;; Divide row k of A and of the augmented part by the pivot:
          ;; (x + iy)/(p + iq) = ((xp + yq) + i(yp - xq)) / (p^2 + q^2)
          (let [rowr (Ar k) rowi (Ai k) ir-row (Ir k) ii-row (Ii k)
                Ar (assoc Ar k (mapv (fn [ar ai] (/ (+ (* ar akk-r) (* ai akk-i)) denom)) rowr rowi))
                Ai (assoc Ai k (mapv (fn [ar ai] (/ (- (* ai akk-r) (* ar akk-i)) denom)) rowr rowi))
                Ir (assoc Ir k (mapv (fn [ar ai] (/ (+ (* ar akk-r) (* ai akk-i)) denom)) ir-row ii-row))
                Ii (assoc Ii k (mapv (fn [ar ai] (/ (- (* ai akk-r) (* ar akk-i)) denom)) ir-row ii-row))
                ;; Gauss-Jordan step: for every row i != k subtract f * row_k
                ;; (f = a_ik, complex multiplication) from both A and the
                ;; augmented part. Rows with a_ik exactly zero are skipped.
                [Ar Ai Ir Ii] (loop [i 0 Ar Ar Ai Ai Ir Ir Ii Ii]
                                (if (= i n)
                                  [Ar Ai Ir Ii]
                                  (if (= i k)
                                    (recur (inc i) Ar Ai Ir Ii)
                                    (let [f-r (get-in Ar [i k]) f-i (get-in Ai [i k])]
                                      (if (and (zero? f-r) (zero? f-i))
                                        (recur (inc i) Ar Ai Ir Ii)
                                        (let [rowk-r (Ar k) rowk-i (Ai k)
                                              irk-row (Ir k) iik-row (Ii k)
                                              Ar-rowi (vec (map-indexed (fn [j arik]
                                                                          (let [rkj (nth rowk-r j) ikj (nth rowk-i j)]
                                                                            (- arik (- (* f-r rkj) (* f-i ikj))))) (Ar i)))
                                              Ai-rowi (vec (map-indexed (fn [j aiik]
                                                                          (let [rkj (nth rowk-r j) ikj (nth rowk-i j)]
                                                                            (- aiik (+ (* f-r ikj) (* f-i rkj))))) (Ai i)))
                                              Ir-rowi (vec (map-indexed (fn [j irik]
                                                                          (let [irkj (nth irk-row j) iikj (nth iik-row j)]
                                                                            (- irik (- (* f-r irkj) (* f-i iikj))))) (Ir i)))
                                              Ii-rowi (vec (map-indexed (fn [j iik]
                                                                          (let [irkj (nth irk-row j) iikj (nth iik-row j)]
                                                                            (- iik (+ (* f-r iikj) (* f-i irkj))))) (Ii i)))
                                              Ar (assoc Ar i Ar-rowi) Ai (assoc Ai i Ai-rowi)
                                              Ir (assoc Ir i Ir-rowi) Ii (assoc Ii i Ii-rowi)]
                                          (recur (inc i) Ar Ai Ir Ii)))))))]
            (recur (inc k) Ar Ai Ir Ii)))))))

(defn spectral-norm
  "Compute the spectral norm ||A||_2 (largest singular value) of a complex
  SoA matrix A.

  Runs power iteration on the Hermitian positive semidefinite matrix A^H A
  and returns the square root of the Rayleigh quotient x^H A^H A x.
  Iteration stops when the quotient changes by less than 1e-12 relative to
  max(1, |lambda|), or after 200 iterations.

  A may be rectangular (m x n); the iteration vector has length n, the
  number of columns.

  The start vector is the normalized all-ones vector. If A^H A maps it to
  zero (e.g. A = [[1 -1] [-1 1]]), the standard basis vectors are tried in
  turn. If every candidate maps to zero, A = 0 and the result is 0.0. A
  matrix without rows or columns also has norm 0.0."
  [A]
  (let [m (count (:real A))
        n (count (first (:real A)))
        tol 1e-12
        max-it 200]
    (if (or (zero? m) (zero? n))
      0.0
      (let [Ah (conjugate-transpose A)
            apply-AhA (fn [x] (matrix-vector-product Ah (matrix-vector-product A x)))
            vnorm (fn [v]
                    (math/sqrt (reduce + (map (fn [a b] (+ (* a a) (* b b))) (:real v) (:imag v)))))
            zeros (vec (repeat n 0.0))
            starts (cons {:real (vec (repeat n (/ 1.0 (math/sqrt n)))) :imag zeros}
                         (map (fn [j] {:real (assoc zeros j 1.0) :imag zeros}) (range n)))
            ;; First start vector with A^H A x != 0. For Hermitian PSD H,
            ;; H x != 0 implies H (H x) != 0, so in exact arithmetic the
            ;; normalization below never divides by zero afterwards.
            x0 (some (fn [x] (when (pos? (vnorm (apply-AhA x))) x)) starts)]
        (if (nil? x0)
          0.0
          (loop [k 0 x x0 lambda-prev nil]
            (let [AhAx (apply-AhA x)
                  ;; Rayleigh quotient Re(x^H A^H A x) for unit x
                  lambda (double (reduce + (map (fn [a b c d] (+ (* a b) (* c d)))
                                                (:real x) (:real AhAx) (:imag x) (:imag AhAx))))
                  nr (vnorm AhAx)
                  x' {:real (mapv #(/ % nr) (:real AhAx)) :imag (mapv #(/ % nr) (:imag AhAx))}
                  conv? (and lambda-prev
                             (< (abs (- lambda lambda-prev)) (* tol (max 1.0 (abs lambda)))))]
              (if (or (>= k max-it) conv?)
                (math/sqrt (max 0.0 lambda))
                (recur (inc k) x' lambda)))))))))

;; Jacobi eigen-decomposition (shared helper)
;;
;; The routine is intentionally simple (no pivot strategies beyond largest
;; off-diagonal, no blocking) and targets small/medium matrices typical for
;; algorithmic construction and test cases. Heavy-duty performance should be
;; delegated to a native/optimized backend.

(defn jacobi-symmetric
  "Compute the eigen-decomposition of a real symmetric matrix via classical
  Jacobi rotations, always rotating on the largest off-diagonal entry.

  Parameters:
  - A      real symmetric square matrix (vector of row vectors)
  - eps    convergence tolerance: iteration stops once the largest absolute
           value in the strict upper triangle is < eps (assumed > 0)
  - max-it maximum number of iterations; each iteration applies exactly one
           rotation

  Returns map:
  {:eigenvalues [...unsorted...] :vectors W :iterations k}
  - :vectors W is a vector of eigenvectors. Row i of W is column i of the
    accumulated rotation matrix V, i.e. the eigenvector for eigenvalue i.
    The rows are orthonormal up to round-off, since V is a product of
    rotations.
  - If max-it is reached before convergence, the current diagonal and
    vectors are returned as-is, and :iterations is the number of rotations
    applied.

  NOTE:
  * The input is not modified; all updates are on persistent vectors.
  * Off-diagonal search is O(n^2) per iteration – acceptable for small n.
  * Only the upper triangle is searched, so A must be symmetric.
  * Sorting of eigenpairs is intentionally left to callers so they can
    perform domain-specific post-processing (e.g. duplicate collapse in
    complex Hermitian embedding).

  Throws ex-info if A is not square."
  [A eps max-it]
  (let [[n m] (matrix-shape A)]
    (when (not= n m) (throw (ex-info "jacobi-symmetric requires square matrix" {:shape [n m]})))
    (if (zero? n)
      {:eigenvalues [] :vectors [] :iterations 0}
      (let [A0 (mapv vec A)
            V0 (identity-matrix n)]
        (loop [iter 0 M A0 V V0]
          (if (>= iter max-it)
            {:eigenvalues (mapv #(get-in M [% %]) (range n))
             ;; transpose V so that each row is one eigenvector
             :vectors (vec (apply mapv vector V))
             :iterations iter}
            ;; Largest |M[i][j]| over the strict upper triangle (first one on ties)
            (let [[p q val] (reduce (fn [[bp bq bv] [i j]]
                                      (let [aij (abs (double (get-in M [i j])))]
                                        (if (> aij bv) [i j aij] [bp bq bv])))
                                    [0 0 0.0]
                                    (for [i (range n) j (range (inc i) n)] [i j]))]
              (if (< val eps)
                {:eigenvalues (mapv #(get-in M [% %]) (range n))
                 ;; transpose V so that each row is one eigenvector
                 :vectors (vec (apply mapv vector V))
                 :iterations iter}
                ;; Rotation that zeroes M[p][q]: tau = cot(2 theta), t = tan(theta)
                ;; (the smaller-magnitude root of t^2 + 2 tau t - 1 = 0),
                ;; c = cos(theta), s = sin(theta).
                (let [app (get-in M [p p]) aqq (get-in M [q q]) apq (get-in M [p q])
                      tau (/ (- aqq app) (* 2.0 apq))
                      t (let [s (if (neg? tau) -1.0 1.0)] (/ s (+ (abs tau) (math/sqrt (+ 1.0 (* tau tau))))))
                      c (/ 1.0 (math/sqrt (+ 1.0 (* t t))))
                      s (* t c)
                      ;; Column update M1 = M J (columns p and q of every row r)
                      rotate-row (fn [M r]
                                   (let [rp (get-in M [r p]) rq (get-in M [r q])]
                                     (-> M
                                         (assoc-in [r p] (- (* c rp) (* s rq)))
                                         (assoc-in [r q] (+ (* s rp) (* c rq))))))
                      M1 (reduce rotate-row M (range n))
                      ;; Row update M2 = J^T M1 (rows p and q), so M2 = J^T M J
                      rotate-col (fn [M r]
                                   (let [pr (get-in M [p r]) qr (get-in M [q r])]
                                     (-> M
                                         (assoc-in [p r] (- (* c pr) (* s qr)))
                                         (assoc-in [q r] (+ (* s pr) (* c qr))))))
                      M2 (reduce rotate-col M1 (range n))
                      ;; NOTE(review): app'/aqq'/apq' are read from the already
                      ;; rotated M2, where apq' is only round-off. This therefore
                      ;; shifts the diagonal by t*apq' rather than applying the
                      ;; textbook update to the pre-rotation values. The (p,q)
                      ;; and (q,p) entries are then forced to exactly 0.0.
                      ;; Confirm intent.
                      apq' (get-in M2 [p q]) app' (get-in M2 [p p]) aqq' (get-in M2 [q q])
                      M3 (-> M2
                             (assoc-in [p p] (- app' (* t apq')))
                             (assoc-in [q q] (+ aqq' (* t apq')))
                             (assoc-in [p q] 0.0)
                             (assoc-in [q p] 0.0))
                      ;; Accumulate V <- V J; columns of V converge to the eigenvectors
                      update-V (fn [V r]
                                 (let [vrp (get-in V [r p]) vrq (get-in V [r q])]
                                   (-> V
                                       (assoc-in [r p] (- (* c vrp) (* s vrq)))
                                       (assoc-in [r q] (+ (* s vrp) (* c vrq))))))
                      V1 (reduce update-V V (range n))]
                  (recur (inc iter) M3 V1))))))))))

;; Complex eigenvector phase normalization
;;
;; Eigenvectors of Hermitian matrices are defined up to a global complex phase.
;; For deterministic downstream processing (e.g. comparison in tests, registry
;; lookups) we canonicalize that phase so that the first component with
;; magnitude > eps becomes real and non-negative (its imaginary part is zero,
;; at least up to round-off).
(defn normalize-complex-phase
  "Normalize the global phase of complex vector v (SoA map) so that its first
  non-negligible component becomes real and positive.

  Parameters:
  - v   {:real [...], :imag [...]} (typically already L2-normalized)
  - eps magnitude threshold; the reference component is the first one with
        |v_i| > eps

  Every component is multiplied by e^{-i phi} = conj(v_ref) / |v_ref|. This
  factor is computed as (a - ib)/r rather than with cos/sin of atan2, which
  avoids trigonometric round-off. The reference component itself is stored
  as exactly |v_ref| + 0i, so the canonical form is bit-for-bit
  deterministic.

  Returns a new complex vector map. If no component exceeds eps, v is
  returned unchanged."
  [v eps]
  (let [xr (:real v) xi (:imag v)
        ;; index of the first component with |v_i| > eps
        idx (first (for [i (range (count xr))
                         :let [a (double (nth xr i)) b (double (nth xi i))]
                         :when (> (+ (* a a) (* b b)) (* eps eps))]
                     i))]
    (if (nil? idx)
      v
      (let [a (double (nth xr idx))
            b (double (nth xi idx))
            ;; r > 0 because a^2 + b^2 > eps^2 >= 0
            r (math/hypot a b)
            ;; cos phi = a/r, sin phi = b/r for the reference a + ib = r e^{i phi}
            c (/ a r)
            s (/ b r)
            ;; (ar + i ai)(c - i s) = (ar c + ai s) + i (ai c - ar s)
            xr' (mapv (fn [ar ai] (+ (* ar c) (* ai s))) xr xi)
            xi' (mapv (fn [ar ai] (- (* ai c) (* ar s))) xr xi)]
        ;; The rotated reference is r + 0i; store it exactly.
        {:real (assoc xr' idx r)
         :imag (assoc xi' idx 0.0)}))))

(defn outer-product
  "Outer product x y^H of complex SoA vectors x (length m) and y (length n).
  Note that y is conjugated: entry (i, j) is x_i * conj(y_j), i.e.
    real = xr_i yr_j + xi_i yi_j,  imag = xi_i yr_j - xr_i yi_j.
  Returns an m x n complex matrix in SoA form."
  [x y]
  (let [xr (:real x) xi (:imag x)
        yr (:real y) yi (:imag y)
        build (fn [entry]
                (mapv (fn [i] (mapv (fn [j] (entry i j)) (range (count yr))))
                      (range (count xr))))]
    {:real (build (fn [i j] (+ (* (nth xr i) (nth yr j)) (* (nth xi i) (nth yi j)))))
     :imag (build (fn [i j] (- (* (nth xi i) (nth yr j)) (* (nth xr i) (nth yi j)))))}))

(defn trace
  "Trace of a complex SoA matrix A, i.e. the sum of its diagonal entries.
  Returns a complex scalar with double parts. The diagonal length is the
  row count of the real part."
  [A]
  (let [idx (range (count (:real A)))
        diag-sum (fn [M] (reduce + (map (fn [i] (get-in M [i i])) idx)))]
    (make-complex (diag-sum (:real A)) (diag-sum (:imag A)))))

(defn norm2
  "Euclidean (L2) norm of a complex SoA vector x: sqrt(Re <x|x>).
  Returns a non-negative double."
  [x]
  (let [re (double (:real (inner-product x x)))]
    ;; Re <x|x> is a sum of squares; the sign flip only guards sqrt.
    (math/sqrt (if (neg? re) (- re) re))))

(defn diagonal?
  "True when square matrix A is diagonal within tolerance tol, i.e. every
  off-diagonal entry has |Re| < tol and |Im| < tol. Non-square matrices are
  never diagonal.

  A may be a complex SoA map or a plain real matrix (vector of row vectors).
  A real matrix has an implicit zero imaginary part; previously it crashed,
  because (:real A) is nil for a vector. tol defaults to `default-tolerance`."
  ([A] (diagonal? A default-tolerance))
  ([A tol]
   (let [[n m] (matrix-shape A)]
     (if (not= n m)
       false
       (let [soa? (map? A)
             Ar (if soa? (:real A) A)
             Ai (when soa? (:imag A))
             small? (fn [x] (< (abs (double x)) tol))]
         (every? (fn [[i j]]
                   (and (small? (get-in Ar [i j]))
                        ;; a real matrix has an implicit zero imaginary part
                        (small? (get-in Ai [i j] 0.0))))
                 (for [i (range n) j (range n) :when (not= i j)] [i j])))))))

(defn unitary?
  "True when square matrix U is unitary within tolerance eps, i.e. U^H U = I.

  U may be a complex SoA matrix or a real matrix, in which case U^T U = I is
  checked. For complex U, the real part of U^H U is compared with I and its
  imaginary part with 0. The imaginary check matters: for the non-unitary
  U = [[1, i/sqrt 2], [0, 1/sqrt 2]], Re(U^H U) = I. Non-square matrices are
  never unitary; the 0 x 0 matrix is. eps defaults to `default-tolerance`."
  ([U] (unitary? U default-tolerance))
  ([U eps]
   (let [[n m] (matrix-shape U)]
     (cond
       (not= n m) false
       (zero? n) true
       (complex-matrix? U)
       (let [P (matrix-multiply (conjugate-transpose U) U)
             Z (vec (repeat n (vec (repeat n 0.0))))]
         ;; compare both parts as real matrices
         (and (close-matrices? (:real P) (identity-matrix n) eps)
              (close-matrices? (:imag P) Z eps)))
       :else
       (close-matrices? (real-mul (real-transpose U) U) (identity-matrix n) eps)))))

(defn eigen-hermitian
  "Compute the eigen-decomposition of a complex Hermitian matrix A (SoA form).

  The n x n matrix H = X + iY is embedded into the real symmetric 2n x 2n
  matrix M = [[X -Y] [Y X]], which is diagonalized with `jacobi-symmetric`.
  If H(x + iy) = lambda (x + iy), then both [x; y] and [-y; x] (the
  embedding of i(x + iy)) are eigenvectors of M for lambda. So M carries
  every eigenvalue of H twice.

  The 2n candidate vectors x + iy are reduced to n orthonormal complex
  vectors by pivoted complex Gram-Schmidt. At each step the candidate with
  the largest remaining norm (the first one on ties) is kept, and its
  complex span is projected out of the others. This drops the duplicate
  partner of every pair but keeps all members of a degenerate eigenspace.
  For example, the identity yields n eigenpairs rather than one, which
  `distinct` on floating-point eigenvalues cannot guarantee. Each kept
  vector carries the eigenvalue of its candidate.

  Parameters:
  - A   Hermitian matrix {:real X :imag Y}
  - eps tolerance (default `default-tolerance`). The Jacobi off-diagonal
        threshold is eps * (2n + 1), with at most 10 (2n)^2 rotations.

  Returns {:eigenvalues  [n doubles, ascending]
           :eigenvectors [n unit complex vectors {:real [...] :imag [...]}]}
  with the global phase of every eigenvector canonicalized by
  `normalize-complex-phase`.

  Throws ex-info if A is not square."
  ([A] (eigen-hermitian A default-tolerance))
  ([A eps]
   (let [X (:real A) Y (:imag A)
         [n m] (matrix-shape X)
         _ (when (not= n m) (throw (ex-info "eigen-hermitian requires square matrix" {:shape [n m]})))
         N (* 2 n)
         ;; Real embedding M = [[X -Y] [Y X]]
         top (vec (for [i (range n)] (vec (concat (nth X i) (map #(- %) (nth Y i))))))
         bottom (vec (for [i (range n)] (vec (concat (nth Y i) (nth X i)))))
         M (vec (concat top bottom))
         {:keys [eigenvalues vectors]} (jacobi-symmetric M (* eps (inc N)) (* 10 N N))
         ;; One candidate x + iy per eigenpair of M; each row of `vectors` is
         ;; an eigenvector [x; y] of length 2n.
         candidates (mapv (fn [lam w] {:lambda (double lam) :re (subvec w 0 n) :im (subvec w n N)})
                          eigenvalues vectors)
         sq-norm (fn [v] (reduce + 0.0 (map (fn [a b] (+ (* a a) (* b b))) (:re v) (:im v))))
         ;; v - <q|v> q for a unit vector q, where <q|v> = sum_k conj(q_k) v_k
         project-out (fn [q v]
                       (let [qr (:re q) qi (:im q) vr (:re v) vi (:im v)
                             cr (reduce + 0.0 (map (fn [a b c d] (+ (* a c) (* b d))) qr qi vr vi))
                             ci (reduce + 0.0 (map (fn [a b c d] (- (* a d) (* b c))) qr qi vr vi))]
                         (assoc v
                                :re (mapv (fn [x a b] (- x (- (* cr a) (* ci b)))) vr qr qi)
                                :im (mapv (fn [y a b] (- y (+ (* cr b) (* ci a)))) vi qr qi))))
         ;; Pivoted Gram-Schmidt. The candidates are orthonormal in R^2n, so
         ;; after j picks the remaining squared residuals sum to 2(n - j) and
         ;; the largest one is bounded away from zero.
         picked (loop [picked [] remaining candidates]
                  (if (or (= (count picked) n) (empty? remaining))
                    picked
                    (let [norms (mapv sq-norm remaining)
                          best-idx (reduce (fn [b i] (if (> (norms i) (norms b)) i b))
                                           0 (range 1 (count remaining)))
                          best (remaining best-idx)
                          nrm (math/sqrt (norms best-idx))
                          nrm (if (pos? nrm) nrm 1.0)
                          q (assoc best
                                   :re (mapv #(/ % nrm) (:re best))
                                   :im (mapv #(/ % nrm) (:im best)))
                          others (into (subvec remaining 0 best-idx)
                                       (subvec remaining (inc best-idx)))]
                      (recur (conj picked q) (mapv #(project-out q %) others)))))
         ;; sort-by is stable: equal eigenvalues keep their selection order
         sorted (sort-by :lambda picked)]
     {:eigenvalues (mapv :lambda sorted)
      :eigenvectors (mapv (fn [v] (normalize-complex-phase {:real (:re v) :imag (:im v)} 1e-14))
                          sorted)})))

(defn positive-semidefinite?
  "True when complex SoA matrix A is Hermitian within eps and all of its
  eigenvalues are >= -eps. eps defaults to `default-tolerance`. The
  eigen-decomposition itself always runs with the default tolerance."
  ([A] (positive-semidefinite? A default-tolerance))
  ([A eps]
   (boolean
    (when (hermitian? A eps)
      (let [lower-bound (- eps)]
        (every? (fn [lam] (>= lam lower-bound))
                (:eigenvalues (eigen-hermitian A))))))))

(defn svd
  "Compute the singular value decomposition A = U S V^H of a complex m x n
  matrix A in SoA form.

  - Right singular vectors are the eigenvectors of the Hermitian matrix
    A^H A (via `eigen-hermitian`).
  - Singular values are the square roots of its eigenvalues, with negative
    round-off clamped to 0. They are sorted descending and truncated to
    min(m, n).
  - For every singular value sigma > eps, the left singular vector is
    u = A v / sigma.
  - U and V are then completed to full orthonormal bases (m and n columns)
    by complex Gram-Schmidt over the standard basis vectors. This handles
    rank-deficient and rectangular input.

  Parameters:
  - A   complex matrix {:real [...] :imag [...]}
  - eps singular values <= eps get no left vector derived from A (default
        `default-tolerance`). Completion vectors are accepted when their
        residual norm exceeds 10 * eps.

  Returns a map:
  - :U   m x m complex matrix whose columns are the left singular vectors
  - :S   vector of the min(m, n) singular values, descending
  - :Vt  n x n complex matrix V^H, the conjugate transpose of V (row i is
         the conjugated i-th right singular vector)
  - :V†  the same matrix as :Vt"
  ([A] (svd A default-tolerance))
  ([A eps]
   (let [Ar (:real A) Ai (:imag A)
         A* {:real Ar :imag Ai}
         m (count Ar) n (count (first Ar))
         ;; A^H A
         Ah {:real (real-transpose Ar) :imag (real-scale (real-transpose Ai) -1.0)}
         AhA (matrix-multiply Ah A*)
         {:keys [eigenvalues eigenvectors]} (eigen-hermitian AhA)
         ;; sigma = sqrt(max(0, lambda)) in descending order. sort-by is stable,
         ;; so ties keep the reversed eigen-hermitian order.
         sv-pairs (->> (map vector eigenvalues eigenvectors)
                       reverse
                       (map (fn [[lam v]] [(Math/sqrt (Math/max 0.0 (double lam))) v]))
                       (sort-by (fn [[s _]] (- s)))
                       (take (min m n)))
         singular-values (mapv first sv-pairs)
         V-cols (mapv second sv-pairs)
         ;; Complex vector helpers
         sub-c (fn [x y]
                 {:real (mapv - (:real x) (:real y))
                  :imag (mapv - (:imag x) (:imag y))})
         scale-c (fn [x alpha]
                   {:real (mapv #(* alpha %) (:real x))
                    :imag (mapv #(* alpha %) (:imag x))})
         mult-cv (fn [c v]
                   (let [ar (:real c) ai (:imag c) vr (:real v) vi (:imag v)]
                     {:real (mapv (fn [r i] (- (* ar r) (* ai i))) vr vi)
                      :imag (mapv (fn [r i] (+ (* ar i) (* ai r))) vr vi)}))
         norm-c (fn [x]
                  (Math/sqrt (reduce + (map (fn [a b] (+ (* a a) (* b b))) (:real x) (:imag x)))))
         ;; Remove from v its components along every (unit) vector in basis.
         project-out (fn [v basis]
                       (reduce (fn [vv u] (sub-c vv (mult-cv (inner-product u vv) u))) v basis))
         ;; u = A v / sigma only for numerically non-zero sigma. The remaining
         ;; left vectors come from the completion below; zero placeholders
         ;; would be counted as basis vectors and leave U non-unitary.
         U-cols (for [[sigma v] sv-pairs :when (> sigma eps)]
                  (scale-c (matrix-vector-product A* v) (/ 1.0 sigma)))
         ;; Re-orthonormalize against round-off (complex Gram-Schmidt).
         U-cols (reduce (fn [acc u]
                          (let [w (project-out u acc)
                                nrm (norm-c w)]
                            (conj acc (if (pos? nrm) (scale-c w (/ 1.0 nrm)) w))))
                        [] U-cols)
         ;; Extend an orthonormal set to dim vectors with standard basis
         ;; vectors. `existing` is a vector, so conj appends the completions
         ;; after the singular vectors they must follow.
         complete-basis (fn [existing dim]
                          (loop [basis existing i 0]
                            (if (or (>= (count basis) dim) (>= i dim))
                              basis
                              (let [zeros (vec (repeat dim 0.0))
                                    e {:real (assoc zeros i 1.0) :imag zeros}
                                    r (project-out e basis)
                                    nrm (norm-c r)]
                                (recur (if (> nrm (* 10 eps)) (conj basis (scale-c r (/ 1.0 nrm))) basis)
                                       (inc i))))))
         U-full (complete-basis U-cols m)
         V-full (complete-basis V-cols n)
         ;; Columns of U are the left singular vectors.
         U-mat {:real (vec (apply map vector (map :real U-full)))
                :imag (vec (apply map vector (map :imag U-full)))}
         ;; Row i of V^H is conj(v_i), so no transpose is needed here.
         Vh {:real (mapv :real V-full)
             :imag (real-scale (mapv :imag V-full) -1.0)}]
     {:U U-mat :S singular-values :Vt Vh :V† Vh})))

(defn lu-decomposition
  "Compute the LU decomposition of a square complex matrix A (SoA form) with
  partial pivoting.

  Returns {:P P :L L :U U} where
  - P is a permutation vector: (P k) is the index of the row of A that ends up
    in row k, so row k of L*U corresponds to row (P k) of A;
  - L (SoA) is unit lower triangular: the elimination multipliers below the
    diagonal and 1.0 on the diagonal of the real part;
  - U (SoA) is the eliminated working copy of A. Entries below the diagonal
    hold the elimination residues a_ik - f*a_kk (approximately zero); they are
    not explicitly zeroed.

  Throws ex-info \"Singular complex matrix in LU\" when a pivot is exactly zero.
  NOTE(review): `eps` is accepted but not used by the body; only an exactly
  zero pivot is treated as singular."
  ([A] (lu-decomposition A default-tolerance))
  ([A eps]
   ;; Complex LU with partial pivoting
   (let [Ar (mapv vec (:real A)) Ai (mapv vec (:imag A))
         n (count Ar)
         P0 (vec (range n))
         Lr (vec (repeat n (vec (repeat n 0.0))))
         Li (vec (repeat n (vec (repeat n 0.0))))]
     (loop [k 0 Ar Ar Ai Ai Lr Lr Li Li P P0]
       (if (= k n)
         ;; Done: put the unit diagonal into L and return the working copy as U.
         {:P P :L {:real (vec (map-indexed (fn [i row] (assoc row i 1.0)) Lr)) :imag Li}
          :U {:real Ar :imag Ai}}
         ;; Pivot: the row in k..n-1 whose column-k entry has the largest squared modulus.
         (let [pivot-row (->> (range k n) (apply max-key (fn [r]
                                                           (let [pr (get-in Ar [r k]) pi (get-in Ai [r k])]
                                                             (+ (* pr pr) (* pi pi))))))
               ;; Swap rows k and pivot-row in the working matrix, in L and in P.
               ;; At this point only columns < k of L are filled, so swapping
               ;; whole L rows swaps exactly the computed multipliers.
               swap-row (fn [M] (if (not= pivot-row k) (-> M (assoc k (M pivot-row)) (assoc pivot-row (M k))) M))
               Ar (swap-row Ar) Ai (swap-row Ai)
               Lr (if (not= pivot-row k)
                    (assoc Lr pivot-row (Lr k) k (Lr pivot-row)) Lr)
               Li (if (not= pivot-row k)
                    (assoc Li pivot-row (Li k) k (Li pivot-row)) Li)
               P (if (not= pivot-row k) (-> P (assoc k (P pivot-row)) (assoc pivot-row (P k))) P)
               pr (get-in Ar [k k]) pi (get-in Ai [k k])
               ;; |pivot|^2, used for complex division by the pivot
               denom (+ (* pr pr) (* pi pi))]
           (when (zero? denom) (throw (ex-info "Singular complex matrix in LU" {:k k})))
           ;; Eliminate column k below the pivot: for each row i > k,
           ;; f = a_ik / a_kk = a_ik * conj(a_kk) / |a_kk|^2, then
           ;; row_i <- row_i - f * row_k over columns k..n-1.
           (let [step (fn [[Ar Ai Lr Li] i]
                        (let [ur (get-in Ar [i k]) ui (get-in Ai [i k])
                              fr (/ (+ (* ur pr) (* ui pi)) denom)
                              fi (/ (- (* ui pr) (* ur pi)) denom)
                              ;; store in L (below diag)
                              Lr (assoc-in Lr [i k] fr)
                              Li (assoc-in Li [i k] fi)
                              ;; a_ij <- a_ij - f * a_kj (complex product)
                              update-row (fn [Ar Ai j]
                                           (let [akr (get-in Ar [k j]) aki (get-in Ai [k j])
                                                 aij-r (get-in Ar [i j]) aij-i (get-in Ai [i j])
                                                 prod-r (- (* fr akr) (* fi aki))
                                                 prod-i (+ (* fr aki) (* fi akr))]
                                             [(assoc-in Ar [i j] (- aij-r prod-r))
                                              (assoc-in Ai [i j] (- aij-i prod-i))]))
                              [Ar Ai] (reduce (fn [[Ar Ai] j] (update-row Ar Ai j)) [Ar Ai] (range k n))]
                          [Ar Ai Lr Li]))
                 [Ar Ai Lr Li] (reduce step [Ar Ai Lr Li] (range (inc k) n))]
             (recur (inc k) Ar Ai Lr Li P))))))))

(defn qr-decomposition
  "Compute the QR decomposition A = Q R of a complex m x n matrix A (SoA form)
  by modified Gram-Schmidt.

  Returns {:Q Q :R R} in SoA form:
  - Q is m x n. Its columns are orthonormal when A has full column rank. A
    column whose Gram-Schmidt residual has zero norm is left as zeros.
  - R is n x n upper triangular. R[i][j] (i < j) holds the projection
    coefficients q_i^H v, and R[j][j] holds the complex 2-norm of the residual
    (real, non-negative).

  An n = 0 input yields a Q with m empty rows and an empty R."
  [A]
  (let [Ar (:real A) Ai (:imag A)
        [m n] (matrix-shape A)
        get-col (fn [j] {:real (mapv #(get-in Ar [% j]) (range m))
                         :imag (mapv #(get-in Ai [% j]) (range m))})
        zeros-nxn (vec (repeat n (vec (repeat n 0.0))))
        ;; column vectors -> row-major matrix; no columns gives m empty rows
        ;; (apply mapv vector ()) would throw an ArityException
        cols->rows (fn [cols]
                     (if (seq cols)
                       (vec (apply mapv vector cols))
                       (vec (repeat m []))))]
    (loop [j 0 Q [] Rr zeros-nxn Ri zeros-nxn]
      (if (= j n)
        {:Q {:real (cols->rows (map :real Q)) :imag (cols->rows (map :imag Q))}
         :R {:real Rr :imag Ri}}
        (let [[v Rr Ri]
              (reduce (fn [[vv Rr Ri] i]
                        (let [qi (nth Q i)
                              ;; r_ij = q_i^H v = sum conj(q_i) * v
                              pr (reduce + (map (fn [a b c d] (+ (* a b) (* c d)))
                                                (:real qi) (:real vv) (:imag qi) (:imag vv)))
                              pi (reduce + (map (fn [a b c d] (- (* a d) (* c b)))
                                                (:real qi) (:real vv) (:imag qi) (:imag vv)))
                              ;; v <- v - q_i * r_ij
                              vr' (mapv (fn [vr qjr qji] (- vr (- (* pr qjr) (* pi qji))))
                                        (:real vv) (:real qi) (:imag qi))
                              vi' (mapv (fn [vi qjr qji] (- vi (+ (* pr qji) (* pi qjr))))
                                        (:imag vv) (:real qi) (:imag qi))]
                          [{:real vr' :imag vi'} (assoc-in Rr [i j] pr) (assoc-in Ri [i j] pi)]))
                      [(get-col j) Rr Ri]
                      (range j))
              ;; complex 2-norm of the residual: sqrt(sum re^2 + im^2).
              ;; Ignoring the imaginary parts here would leave Q unnormalized.
              rjj (Math/sqrt (reduce + 0.0 (map (fn [a b] (+ (* a a) (* b b))) (:real v) (:imag v))))
              qj (if (pos? rjj)
                   {:real (mapv #(/ % rjj) (:real v)) :imag (mapv #(/ % rjj) (:imag v))}
                   {:real (vec (repeat m 0.0)) :imag (vec (repeat m 0.0))})]
          (recur (inc j) (conj Q qj) (assoc-in Rr [j j] rjj) Ri))))))

(defn wilkinson-shift
  "Wilkinson shift for the shifted QR algorithm.

  Uses the trailing 2x2 block of the active part of the complex matrix A
  (SoA form): rows and columns n-2 and n-1, where n is the active size.
  Returns the eigenvalue of that block that is closest to its bottom-right
  entry A[n-1][n-1], as {:real :imag}. When both eigenvalues are equally
  close, (tr - sqrt(disc)) / 2 is returned."
  [A n]
  (let [p (- n 2)
        q (- n 1)
        at (fn [i j] [(get-in (:real A) [i j]) (get-in (:imag A) [i j])])
        [x11 y11] (at p p)
        [x12 y12] (at p q)
        [x21 y21] (at q p)
        [x22 y22] (at q q)
        ;; tr = a11 + a22
        tr-re (+ x11 x22)
        tr-im (+ y11 y22)
        ;; det = a11*a22 - a12*a21
        det-re (- (- (* x11 x22) (* y11 y22))
                  (- (* x12 x21) (* y12 y21)))
        det-im (- (+ (* x11 y22) (* y11 x22))
                  (+ (* x12 y21) (* y12 x21)))
        ;; disc = tr^2 - 4*det
        disc-re (- (- (* tr-re tr-re) (* tr-im tr-im)) (* 4 det-re))
        disc-im (- (* 2 tr-re tr-im) (* 4 det-im))
        ;; principal complex square root of disc
        modulus (Math/sqrt (+ (* disc-re disc-re) (* disc-im disc-im)))
        root-re (Math/sqrt (/ (+ modulus disc-re) 2))
        root-im-abs (Math/sqrt (/ (- modulus disc-re) 2))
        root-im (if (>= disc-im 0) root-im-abs (- root-im-abs))
        ;; (tr op sqrt(disc)) / 2 for op in {+, -}
        candidate (fn [op] {:real (/ (op tr-re root-re) 2)
                            :imag (/ (op tr-im root-im) 2)})
        plus (candidate +)
        minus (candidate -)
        ;; squared distance to a22
        dist2 (fn [{re :real im :imag}]
                (let [dr (- re x22) di (- im y22)]
                  (+ (* dr dr) (* di di))))]
    (if (< (dist2 plus) (dist2 minus)) plus minus)))

(defn matrix-subtract-scalar
  "Return A - s*I for a complex matrix A (SoA form) and a complex scalar
  s given as {:real :imag}.

  Only diagonal entries (i = j) are shifted. Every other entry is copied
  unchanged into the result, which is also in SoA form."
  [A s]
  (let [[rows cols] (matrix-shape A)
        ;; shift one component matrix (real or imaginary) by `offset` on the diagonal
        shift-part (fn [M offset]
                     (mapv (fn [i]
                             (mapv (fn [j]
                                     (let [x (get-in M [i j])]
                                       (if (= i j) (- x offset) x)))
                                   (range cols)))
                           (range rows)))]
    {:real (shift-part (:real A) (:real s))
     :imag (shift-part (:imag A) (:imag s))}))

(defn- closed-form-eigenvalues-2x2
  "Eigenvalues of the complex 2x2 matrix [[a b] [c d]], where each argument is
  a [re im] pair.

  Returns [lambda1 lambda2] as {:real :imag} maps, with
  lambda1/2 = (tr +/- sqrt(tr^2 - 4 det)) / 2 and sqrt the principal complex
  square root."
  [[ar ai] [br bi] [cr ci] [dr di]]
  (let [tr-r (+ ar dr)
        tr-i (+ ai di)
        ;; det = a*d - b*c
        det-r (- (- (* ar dr) (* ai di))
                 (- (* br cr) (* bi ci)))
        det-i (- (+ (* ar di) (* ai dr))
                 (+ (* br ci) (* bi cr)))
        ;; disc = tr^2 - 4*det
        disc-r (- (- (* tr-r tr-r) (* tr-i tr-i)) (* 4 det-r))
        disc-i (- (* 2 tr-r tr-i) (* 4 det-i))
        mag (Math/sqrt (+ (* disc-r disc-r) (* disc-i disc-i)))
        sq-r (Math/sqrt (/ (+ mag disc-r) 2))
        sq-i-abs (Math/sqrt (/ (- mag disc-r) 2))
        sq-i (if (>= disc-i 0) sq-i-abs (- sq-i-abs))]
    [{:real (/ (+ tr-r sq-r) 2) :imag (/ (+ tr-i sq-i) 2)}
     {:real (/ (- tr-r sq-r) 2) :imag (/ (- tr-i sq-i) 2)}]))

(defn qr-eigenvalues
  "Compute the eigenvalues of a square complex matrix A (SoA form) using the
  shifted QR algorithm with Wilkinson shifts and deflation.

  Parameters:
  - A: square complex matrix {:real [[...]] :imag [[...]]}
  - max-iterations: maximum number of QR steps
  - tolerance: absolute threshold. The last row of the active block is
    deflated once the 2-norm of all its entries left of the diagonal drops
    below it. A is not reduced to Hessenberg form, so the whole row, not just
    the subdiagonal entry, must be small.

  Returns a vector of n eigenvalues as {:real :imag} maps, in deflation order
  (bottom of the matrix first). 2x2 inputs, and the final decoupled 2x2 active
  block, are solved in closed form. If max-iterations is reached first, the
  current diagonal entries of the active block are returned as approximations,
  so the result still has n entries.

  Throws ex-info when A is not square."
  [A max-iterations tolerance]
  (let [[n m] (matrix-shape A)
        entry (fn [M i j] [(get-in (:real M) [i j]) (get-in (:imag M) [i j])])
        as-complex (fn [[re im]] {:real re :imag im})
        tol2 (* tolerance tolerance)]
    (when (not= n m) (throw (ex-info "qr-eigenvalues requires square matrix" {:shape [n m]})))
    (if (= n 2)
      (closed-form-eigenvalues-2x2 (entry A 0 0) (entry A 0 1) (entry A 1 0) (entry A 1 1))
      (loop [Ak A
             iteration 0
             active-size n
             eigenvals []]
        (cond
          (<= active-size 1)
          (if (= active-size 1)
            (conj eigenvals (as-complex (entry Ak 0 0)))
            eigenvals)

          ;; All rows below the active block were deflated (their left parts are
          ;; ~0), so the leading 2x2 block holds the last two eigenvalues.
          (= active-size 2)
          (into eigenvals (closed-form-eigenvalues-2x2 (entry Ak 0 0) (entry Ak 0 1)
                                                       (entry Ak 1 0) (entry Ak 1 1)))

          ;; Not converged: report the active diagonal instead of dropping it.
          (>= iteration max-iterations)
          (into eigenvals
                (map #(as-complex (entry Ak % %)))
                (range (dec active-size) -1 -1))

          :else
          (let [last-row (dec active-size)
                off-norm2 (reduce + 0.0
                                  (map (fn [j]
                                         (let [[re im] (entry Ak last-row j)]
                                           (+ (* re re) (* im im))))
                                       (range last-row)))]
            (if (< off-norm2 tol2)
              (recur Ak iteration last-row
                     (conj eigenvals (as-complex (entry Ak last-row last-row))))
              (let [shift (wilkinson-shift Ak active-size)
                    {:keys [Q R]} (qr-decomposition (matrix-subtract-scalar Ak shift))
                    RQ (matrix-multiply R Q)
                    ;; add the shift back: RQ + sI = Q^H A Q
                    restored (matrix-subtract-scalar RQ {:real (- (:real shift))
                                                         :imag (- (:imag shift))})]
                (recur restored (inc iteration) active-size eigenvals)))))))))

(defn- inverse-iteration-eigenvector
  "Approximate a unit-norm eigenvector of the square complex matrix A (SoA form)
  for the eigenvalue estimate `lambda` ({:real :imag}).

  Runs three steps of inverse iteration, x <- normalize((A - mu I)^-1 x),
  starting from the all-ones vector, with mu = lambda + delta. The real
  perturbation delta = max(eps, 1e-10) * (1 + |lambda|) keeps A - mu I from
  being exactly singular. All solves reuse one factorization from
  lu-decomposition. The result is rotated so that its largest-magnitude
  component is real and positive.

  Returns {:real [...] :imag [...]}."
  [A lambda eps]
  (let [n (count (:real A))
        lam-r (double (:real lambda))
        lam-i (double (:imag lambda))
        delta (* (max (double eps) 1.0e-10) (+ 1.0 (Math/hypot lam-r lam-i)))
        ;; M with mu subtracted on the diagonal, as doubles
        shift-diag (fn [M mu]
                     (mapv (fn [i]
                             (mapv (fn [j]
                                     (let [a (double (get-in M [i j]))]
                                       (if (= i j) (- a mu) a)))
                                   (range n)))
                           (range n)))
        {:keys [P L U]} (lu-decomposition {:real (shift-diag (:real A) (+ lam-r delta))
                                           :imag (shift-diag (:imag A) lam-i)})
        Lr (:real L) Li (:imag L)
        Ur (:real U) Ui (:imag U)
        ;; sum over j in js of M[k][j] * x[j], returned as [re im]
        row-dot (fn [Mr Mi k js xr xi]
                  (reduce (fn [[sr si] j]
                            (let [mr (get-in Mr [k j]) mi (get-in Mi [k j])
                                  xjr (nth xr j) xji (nth xi j)]
                              [(+ sr (- (* mr xjr) (* mi xji)))
                               (+ si (+ (* mr xji) (* mi xjr)))]))
                          [0.0 0.0]
                          js))
        ;; solve (A - mu I) x = b using the LU factors: row k of L*U is row (P k)
        solve (fn [[br bi]]
                (let [;; forward substitution L y = P b (L has a unit diagonal)
                      [yr yi] (reduce (fn [[yr yi] k]
                                        (let [[sr si] (row-dot Lr Li k (range k) yr yi)
                                              p (nth P k)]
                                          [(conj yr (- (nth br p) sr))
                                           (conj yi (- (nth bi p) si))]))
                                      [[] []]
                                      (range n))]
                  ;; back substitution U x = y, complex division by U[k][k]
                  (reduce (fn [[xr xi] k]
                            (let [[sr si] (row-dot Ur Ui k (range (inc k) n) xr xi)
                                  nr (- (nth yr k) sr)
                                  ni (- (nth yi k) si)
                                  dr (get-in Ur [k k])
                                  di (get-in Ui [k k])
                                  den (+ (* dr dr) (* di di))]
                              [(assoc xr k (/ (+ (* nr dr) (* ni di)) den))
                               (assoc xi k (/ (- (* ni dr) (* nr di)) den))]))
                          [(vec (repeat n 0.0)) (vec (repeat n 0.0))]
                          (range (dec n) -1 -1))))
        normalize (fn [[xr xi]]
                    (let [nrm (Math/sqrt (reduce + 0.0 (map (fn [a b] (+ (* a a) (* b b))) xr xi)))]
                      (if (pos? nrm)
                        [(mapv #(/ % nrm) xr) (mapv #(/ % nrm) xi)]
                        [xr xi])))
        ;; multiply by conj(x_k)/|x_k| for the largest |x_k| to fix the arbitrary phase
        fix-phase (fn [[xr xi]]
                    (if (zero? n)
                      [xr xi]
                      (let [k (apply max-key
                                     #(+ (* (nth xr %) (nth xr %)) (* (nth xi %) (nth xi %)))
                                     (range n))
                            a (nth xr k)
                            b (nth xi k)
                            mag (Math/hypot a b)]
                        (if (pos? mag)
                          (let [cr (/ a mag) ci (/ (- b) mag)]
                            [(mapv (fn [r i] (- (* r cr) (* i ci))) xr xi)
                             (mapv (fn [r i] (+ (* r ci) (* i cr))) xr xi)])
                          [xr xi]))))
        start (normalize [(vec (repeat n 1.0)) (vec (repeat n 0.0))])
        [vr vi] (fix-phase (nth (iterate (comp normalize solve) start) 3))]
    {:real vr :imag vi}))

(defn eigen-general
  "Compute an eigen-decomposition of a complex square matrix A (SoA form).

  - Hermitian A (within `eps`): delegates to eigen-hermitian and returns its
    result unchanged.
  - Otherwise: eigenvalues come from qr-eigenvalues (at most 1000 QR steps,
    deflation tolerance `eps`) and are sorted by real part, then by imaginary
    part. Each eigenvector is computed by inverse iteration for its
    eigenvalue. It is returned as a unit vector {:real [...] :imag [...]} in
    the same order as :eigenvalues.

  Limitations: every copy of a repeated eigenvalue gets the same eigenvector,
  because inverse iteration starts from a fixed vector. Defective matrices
  have no complete eigenbasis.

  Optionally accepts the tolerance `eps` (default: default-tolerance).
  Throws ex-info when A is not square."
  ([A] (eigen-general A default-tolerance))
  ([A eps]
   (let [[n m] (matrix-shape A)]
     (when (not= n m) (throw (ex-info "eigen-general requires square matrix" {:shape [n m]})))
     (if (hermitian? A eps)
       ;; If the matrix is Hermitian, use the specialized Hermitian method
       (eigen-hermitian A eps)
       (let [eigenvals (qr-eigenvalues A 1000 eps)
             ;; Sort eigenvalues by real part, then imaginary part
             sorted-eigenvals (sort-by (fn [ev] [(:real ev) (:imag ev)]) eigenvals)]
         {:eigenvalues sorted-eigenvals
          :eigenvectors (mapv #(inverse-iteration-eigenvector A % eps) sorted-eigenvals)})))))

(defn cholesky-decomposition
  "Compute the Cholesky decomposition A = L L^H of a complex Hermitian
  positive definite matrix A (SoA form).

  Only the lower triangle of A (j <= i) is read. L is lower triangular with a
  real diagonal:
    L[i][i] = sqrt(A[i][i] - sum_{k<i} |L[i][k]|^2)
    L[i][j] = (A[i][j] - sum_{k<j} L[i][k] * conj(L[j][k])) / L[j][j],  j < i

  Returns {:L {:real [[...]] :imag [[...]]}}.
  Throws ex-info \"Hermitian diag not real\" when a diagonal pivot has an
  imaginary part larger than 1e-12 in magnitude. Throws \"Matrix not HPD\"
  when a diagonal pivot is negative."
  [A]
  (let [Ar (:real A) Ai (:imag A) n (count Ar)
        zeros (vec (repeat n (vec (repeat n 0.0))))]
    (loop [i 0 Lr zeros Li zeros]
      (if (= i n)
        {:L {:real Lr :imag Li}}
        (let [[Lr Li]
              (loop [j 0 Lr Lr Li Li]
                (if (> j i)
                  [Lr Li]
                  (let [;; s = sum_{k<j} L[i][k] * conj(L[j][k]);
                        ;; (a + ib)(c - id) = (ac + bd) + i(bc - ad).
                        ;; On the diagonal (j = i) the imaginary part is exactly 0.
                        [sum-r sum-i] (reduce (fn [[sr si] k]
                                                (let [a (get-in Lr [i k]) b (get-in Li [i k])
                                                      c (get-in Lr [j k]) d (get-in Li [j k])]
                                                  [(+ sr (+ (* a c) (* b d)))
                                                   (+ si (- (* b c) (* a d)))]))
                                              [0.0 0.0]
                                              (range j))
                        diff-r (- (get-in Ar [i j]) sum-r)
                        diff-i (- (get-in Ai [i j]) sum-i)]
                    (if (= i j)
                      (do
                        (when (> (abs diff-i) 1e-12)
                          (throw (ex-info "Hermitian diag not real" {:i i :imag diff-i})))
                        (when (neg? diff-r)
                          (throw (ex-info "Matrix not HPD" {:i i :value diff-r})))
                        (recur (inc j) (assoc-in Lr [i i] (Math/sqrt diff-r)) Li))
                      ;; L[j][j] is real, so complex division reduces to scaling
                      (let [ljj (get-in Lr [j j])]
                        (recur (inc j)
                               (assoc-in Lr [i j] (/ diff-r ljj))
                               (assoc-in Li [i j] (/ diff-i ljj))))))))]
          (recur (inc i) Lr Li))))))

;; Helper functions for real matrix operations
(defn- matrix-add-real
  "Element-wise sum of two real matrices (vectors of row vectors).
  The result has the shape of A."
  [A B]
  (vec (for [i (range (count A))]
         (vec (for [j (range (count (first A)))]
                (+ (get-in A [i j]) (get-in B [i j])))))))

(defn- matrix-scale-real
  "Multiply every entry of the real matrix A by the scalar s."
  [A s]
  (vec (for [i (range (count A))]
         (vec (for [j (range (count (first A)))]
                (* s (get-in A [i j])))))))

(defn- matrix-multiply-real
  "Product A * B of two real matrices, with A n x p and B p x m."
  [A B]
  (let [n (count A) m (count (first B)) p (count B)]
    (vec (for [i (range n)]
           (vec (for [j (range m)]
                  (reduce + (for [k (range p)]
                              (* (get-in A [i k]) (get-in B [k j]))))))))))

(defn- matrix-inverse-real
  "Invert a real square matrix by Gauss-Jordan elimination with partial
  pivoting on the augmented matrix [A | I].

  Throws ex-info \"Matrix is singular\" when the best available pivot has
  absolute value below 1e-12."
  [A]
  (let [n (count A)
        ;; augmented matrix [A | I]
        aug (vec (for [i (range n)]
                   (vec (concat (nth A i)
                                (for [j (range n)] (if (= i j) 1.0 0.0))))))]
    (loop [k 0 aug aug]
      (if (>= k n)
        ;; the right half now holds A^-1
        (vec (for [i (range n)]
               (vec (subvec (nth aug i) n))))
        ;; max-key returns an element of (range k n), i.e. an absolute row
        ;; index. Adding k again would pick the wrong row or run off the end.
        (let [pivot-row (apply max-key #(abs (get-in aug [% k])) (range k n))
              aug (if (= pivot-row k)
                    aug
                    (assoc aug k (nth aug pivot-row) pivot-row (nth aug k)))
              pivot (get-in aug [k k])]
          (when (< (abs pivot) 1e-12)
            (throw (ex-info "Matrix is singular" {})))
          (let [pivot-vals (mapv #(/ % pivot) (nth aug k))
                ;; normalize the pivot row, then clear column k in every other row
                aug (reduce (fn [acc i]
                              (if (= i k)
                                acc
                                (let [factor (get-in acc [i k])]
                                  (assoc acc i (mapv (fn [x pv] (- x (* factor pv)))
                                                     (nth acc i)
                                                     pivot-vals)))))
                            (assoc aug k pivot-vals)
                            (range n))]
            (recur (inc k) aug)))))))

(defn- real-matrix-exp
  "Compute the matrix exponential of a real square matrix using scaling and
  squaring with a diagonal Padé(6,6) approximant.

  Input: a real matrix as a vector of row vectors [[...] [...]].
  Output: exp(A) as a vector of row vectors ([] for an empty input).

  Algorithm (Moler & Van Loan):
  1. choose the smallest s >= 0 with ||A / 2^s||_inf <= 1/2 and set X = A / 2^s;
  2. exp(X) ~= D^-1 N with N = sum_k c_k X^k and D = sum_k (-1)^k c_k X^k,
     k = 0..6, where c_k are the Padé(6,6) coefficients;
  3. exp(A) = exp(X)^(2^s), i.e. square the result s times."
  [A]
  (let [n (count A)]
    (cond
      (zero? n) []

      ;; 1x1 case
      (= n 1)
      [[(math/exp (get-in A [0 0]))]]

      :else
      (let [norm-inf (reduce max 0.0 (map (fn [row] (reduce + 0.0 (map abs row))) A))
            s (if (pos? norm-inf)
                (long (max 0.0 (math/ceil (+ 1.0 (/ (math/log norm-inf) (math/log 2.0))))))
                0)
            X (if (pos? s) (matrix-scale-real A (/ 1.0 (math/pow 2.0 s))) A)
            I (vec (for [i (range n)] (vec (for [j (range n)] (if (= i j) 1.0 0.0)))))
            ;; X^0 .. X^6
            powers (vec (take 7 (iterate #(matrix-multiply-real % X) I)))
            ;; Padé(6,6): c_0 = 1, c_k = c_{k-1} (q-k+1) / (k (2q-k+1)), q = 6.
            ;; (Using Taylor coefficients in both N and D would yield exp(2X).)
            coeffs [1.0 (/ 1.0 2.0) (/ 5.0 44.0) (/ 1.0 66.0)
                    (/ 1.0 792.0) (/ 1.0 15840.0) (/ 1.0 665280.0)]
            weighted-sum (fn [signs]
                           (reduce matrix-add-real
                                   (map (fn [c sign P] (matrix-scale-real P (* sign c)))
                                        coeffs signs powers)))
            N (weighted-sum (repeat 1.0))
            D (weighted-sum (cycle [1.0 -1.0]))
            ;; D is well conditioned for ||X|| <= 1/2
            E (matrix-multiply-real (matrix-inverse-real D) N)]
        ;; undo the scaling: exp(A) = exp(X)^(2^s)
        (nth (iterate #(matrix-multiply-real % %) E) s)))))

(defn matrix-exp
  "Compute the matrix exponential exp(A) of a complex matrix A in SoA form.

  - 1x1 input: exp(a + ib) = e^a (cos b + i sin b).
  - Diagonal input (every off-diagonal real and imaginary entry is zero): the
    scalar formula is applied to each diagonal entry.
  - Otherwise: A = X + iY is embedded in the real 2n x 2n block matrix
    [[X -Y] [Y X]] and exponentiated with real-matrix-exp. The real and
    imaginary parts of exp(A) are read from the top-left and bottom-left
    n x n blocks.

  Returns a complex matrix in SoA form."
  [A]
  (let [X (:real A)
        Y (:imag A)
        n (count X)
        ;; exp of one complex scalar, as [re im]
        scalar-exp (fn [re im]
                     (let [a (double re) b (double im) ea (math/exp a)]
                       [(* ea (math/cos b)) (* ea (math/sin b))]))]
    (cond
      (and (= 1 n) (= 1 (count (first X))))
      (let [[re im] (scalar-exp (get-in X [0 0]) (get-in Y [0 0]))]
        {:real [[re]] :imag [[im]]})

      (->> (for [i (range n) j (range n) :when (not= i j)]
             (and (zero? (double (get-in X [i j]))) (zero? (double (get-in Y [i j])))))
           (every? true?))
      (let [diag (mapv #(scalar-exp (get-in X [% %]) (get-in Y [% %])) (range n))
            ;; place one component of the diagonal exponentials into an n x n matrix
            place (fn [part]
                    (mapv (fn [i]
                            (mapv (fn [j] (if (= i j) (part (diag i)) 0.0)) (range n)))
                          (range n)))]
        {:real (place first) :imag (place second)})

      :else
      (let [upper (mapv (fn [i] (vec (concat (nth X i) (map #(- (double %)) (nth Y i))))) (range n))
            lower (mapv (fn [i] (vec (concat (nth Y i) (nth X i)))) (range n))
            E (real-matrix-exp (into upper lower))
            ;; left n columns of the rows starting at `offset`
            block (fn [offset] (mapv (fn [i] (subvec (E (+ i offset)) 0 n)) (range n)))]
        {:real (block 0) :imag (block n)}))))

(defn matrix-log
  "Compute the principal matrix logarithm of a complex Hermitian matrix A
  (SoA form).

  Uses the spectral decomposition from eigen-hermitian and returns
  log(A) = sum_k log(lambda_k) v_k v_k^H. A positive eigenvalue maps to its
  real logarithm. A negative eigenvalue uses the principal branch
  log|lambda| + i*pi, so the result is Hermitian only for positive definite A.

  Returns {:real [] :imag []} for an empty matrix.
  Throws ex-info when A is not Hermitian (tolerance 1e-10) or has an
  eigenvalue equal to zero."
  [A]
  (let [[n _] (matrix-shape A)]
    (if (zero? n)
      ;; the early result must actually be returned, not computed and discarded
      {:real [] :imag []}
      (do
        (when (not (hermitian? A 1e-10))
          (throw (ex-info "matrix-log implemented only for Hermitian positive definite complex matrices" {:matrix A})))
        (let [{:keys [eigenvalues eigenvectors]} (eigen-hermitian A)
              ;; principal logarithm of a real eigenvalue, as [re im]
              log-eig (fn [l]
                        (cond
                          (pos? l) [(Math/log l) 0.0]
                          (zero? l) (throw (ex-info "matrix-log cannot handle zero eigenvalues" {:lambda l}))
                          :else [(Math/log (abs l)) Math/PI]))
              terms (mapv (fn [l v] [(log-eig l) v]) eigenvalues eigenvectors)
              ;; log(A)[i][j] = sum_k log(l_k) * v_k[i] * conj(v_k[j]), as [re im]
              cell (fn [i j]
                     (reduce (fn [[re im] [[lr li] {vr :real vi :imag}]]
                               (let [pr (+ (* (vr i) (vr j)) (* (vi i) (vi j)))
                                     pi (- (* (vi i) (vr j)) (* (vr i) (vi j)))]
                                 [(+ re (- (* lr pr) (* li pi)))
                                  (+ im (+ (* lr pi) (* li pr)))]))
                             [0.0 0.0]
                             terms))
              grid (mapv (fn [i] (mapv #(cell i %) (range n))) (range n))]
          {:real (mapv #(mapv first %) grid)
           :imag (mapv #(mapv second %) grid)})))))

(defn matrix-sqrt
  "Compute the square root of a complex Hermitian matrix A (SoA form).

  Uses the spectral decomposition from eigen-hermitian and returns
  sum_i sqrt(lambda_i) v_i v_i^H, accumulated with real-add. A negative
  eigenvalue contributes the imaginary root i*sqrt(|lambda_i|), so the result
  is Hermitian only when every eigenvalue is non-negative.

  Throws ex-info when A is not square or not Hermitian (tolerance 1e-10).
  Returns a complex matrix in SoA form."
  [A]
  (let [[n n2] (matrix-shape A)]
    (when (not= n n2) (throw (ex-info "matrix-sqrt requires square" {:shape [n n2]})))
    (when (not (hermitian? A 1e-10)) (throw (ex-info "Matrix square root requires Hermitian matrix" {:matrix A :reason :not-hermitian})))
    (let [{:keys [eigenvalues eigenvectors]} (eigen-hermitian A)
          ;; square root of each real eigenvalue, as [re im]
          roots (mapv (fn [l]
                        (cond
                          (> l 0.0) [(Math/sqrt l) 0.0]
                          (= l 0.0) [0.0 0.0]
                          (< l 0.0) [0.0 (Math/sqrt (Math/abs l))]))
                      eigenvalues)
          ;; rank-one term c * v v^H, split into [real-matrix imag-matrix]
          outer (fn [[cr ci] {vr :real vi :imag}]
                  (let [cells (mapv (fn [i]
                                      (mapv (fn [j]
                                              (let [pr (+ (* (vr i) (vr j)) (* (vi i) (vi j)))
                                                    pi (- (* (vi i) (vr j)) (* (vr i) (vi j)))]
                                                [(- (* cr pr) (* ci pi))
                                                 (+ (* cr pi) (* ci pr))]))
                                            (range n)))
                                    (range n))]
                    [(mapv #(mapv first %) cells)
                     (mapv #(mapv second %) cells)]))
          zeros (vec (repeat n (vec (repeat n 0.0))))
          {:keys [real imag]} (reduce (fn [acc [root v]]
                                        (let [[tr ti] (outer root v)]
                                          {:real (real-add (:real acc) tr)
                                           :imag (real-add (:imag acc) ti)}))
                                      {:real zeros :imag zeros}
                                      (map vector roots eigenvectors))]
      {:real real :imag imag})))

(defn condition-number
  "Spectral condition number ||A||_2 * ||A^-1||_2 of a complex matrix A.

  If computing the inverse throws, A is treated as singular and
  Double/POSITIVE_INFINITY is returned. Otherwise the result is a positive
  real number."
  [A]
  (let [norm-a (spectral-norm A)
        inv-a (try (inverse A)
                   (catch Exception _ nil))]
    (if-not inv-a
      Double/POSITIVE_INFINITY
      (* norm-a (spectral-norm inv-a)))))

;; REPL scratch: compare eigen-hermitian and eigen-general on small real
;; symmetric matrices. Forms inside (comment ...) are not evaluated at load time.
(comment
  (eigen-hermitian {:real [[1.0 2.0]
                           [2.0 -1.0]]
                    :imag  [[0.0 0.0]
                            [0.0 0.0]]})
  (eigen-general {:real [[1.0 2.0]
                         [2.0 -1.0]]
                  :imag  [[0.0 0.0]
                          [0.0 0.0]]})

  (eigen-hermitian {:real [[3.0 1.0]
                           [1.0 3.0]]
                    :imag  [[0.0 0.0]
                            [0.0 0.0]]})
  (eigen-general {:real [[3.0 1.0]
                         [1.0 3.0]]
                  :imag  [[0.0 0.0]
                          [0.0 0.0]]})

  ;; repeated eigenvalue (2I)
  (eigen-hermitian {:real [[2.0 0.0]
                           [0.0 2.0]]
                    :imag  [[0.0 0.0]
                            [0.0 0.0]]})
  (eigen-general {:real [[2.0 0.0]
                         [0.0 2.0]]
                  :imag  [[0.0 0.0]
                          [0.0 0.0]]})


  ;
  )