package co.multiply.pathling;

import clojure.lang.*;
import java.util.ArrayList;
import java.util.Iterator;

/**
 * High-performance scanner for Pathling using Java 21+ pattern matching.
 *
 * <p>Uses pattern switch for type dispatch instead of protocol dispatch,
 * and keyIterator for maps to avoid MapEntry allocation.
 *
 * <p>Traversal is depth-first and post-order. Every child of a collection is scanned before the
 * predicate is applied to the collection itself, so nested matches are appended to
 * {@code matches} before the collections that contain them. The predicate is invoked on every
 * visited value, scalars and collections alike. When {@code includeKeys} is set, it is also
 * invoked on map keys (except for struct maps on the navigation path; see
 * {@link #pathMapStruct}).
 */
public final class Scanner {
    private Scanner() {} // Prevent instantiation

    /**
     * Scan a data structure for values matching the predicate.
     *
     * @param obj         the data structure to scan (may be null)
     * @param matches     ArrayList to accumulate matching values (mutated)
     * @param pred        predicate function (Clojure IFn); its result is interpreted with Clojure truthiness
     * @param includeKeys if true, also match map keys
     * @return navigation structure, or null if neither {@code obj} nor anything inside it matched
     */
    public static Object pathWhen(Object obj, ArrayList<Object> matches, IFn pred, boolean includeKeys) {
        // Case order matters: more specific types (struct maps, vectors, seqs) must precede the
        // broader interfaces they also implement (IPersistentMap, Sequential).
        return switch (obj) {
            case null -> pathScalar(null, matches, pred);
            case PersistentStructMap m -> pathMapStruct(m, matches, pred, includeKeys);
            case IPersistentMap m -> pathMap(m, matches, pred, includeKeys);
            case IPersistentVector v -> pathVector(v, matches, pred, includeKeys);
            case IPersistentSet s -> pathSet(s, matches, pred, includeKeys);
            case ISeq s -> pathSeq(s, matches, pred, includeKeys);
            case Sequential s -> pathSequential(s, matches, pred, includeKeys);
            default -> pathScalar(obj, matches, pred);
        };
    }

    /**
     * Scan a data structure for values matching the predicate (find-only, no navigation).
     *
     * <p>Each matching value {@code x} is appended to {@code matches} as {@code tf.invoke(x)}.
     *
     * @param obj         the data structure to scan (may be null)
     * @param matches     ArrayList to accumulate transformed matching values (mutated)
     * @param pred        predicate function (Clojure IFn); its result is interpreted with Clojure truthiness
     * @param tf          transform function to apply to matches
     * @param includeKeys if true, also match map keys
     */
    public static void findWhen(Object obj, ArrayList<Object> matches, IFn pred, IFn tf, boolean includeKeys) {
        switch (obj) {
            case null -> findScalar(null, matches, pred, tf);
            case IPersistentMap m -> findMap(m, matches, pred, tf, includeKeys);
            case IPersistentVector v -> findVector(v, matches, pred, tf, includeKeys);
            case IPersistentSet s -> findSet(s, matches, pred, tf, includeKeys);
            case ISeq s -> findSeq(s, matches, pred, tf, includeKeys);
            case Sequential s -> findSequential(s, matches, pred, tf, includeKeys);
            default -> findScalar(obj, matches, pred, tf);
        }
    }

    // ========================================================================
    // Map scanning
    // ========================================================================

    /**
     * Scans every entry of a map: the value first, then (if {@code includeKeys}) the key, and
     * finally the map itself.
     *
     * @return a map nav (editable or persistent variant, chosen by whether the map is an
     *         {@link IEditableCollection}), or null if nothing matched
     */
    private static Object pathMap(IPersistentMap m, ArrayList<Object> matches, IFn pred, boolean includeKeys) {
        // Use keyIterator if available (avoids MapEntry allocation)
        Iterator<?> iter = (m instanceof IMapIterable mi)
            ? mi.keyIterator()
            : RT.iter(RT.keys(m));

        ArrayList<Nav.KeyNav> childNavs = null;

        while (iter.hasNext()) {
            Object k = iter.next();
            Nav.Updatable nav = (Nav.Updatable) pathWhen(m.valAt(k), matches, pred, includeKeys);
            boolean termK = includeKeys && RT.booleanCast(pred.invoke(k));

            if (termK) {
                matches.add(k);
            }

            // Record which parts of the entry matched: key and value, value only, or key only.
            if (nav != null || termK) {
                if (childNavs == null) childNavs = new ArrayList<>();
                if (nav != null && termK) {
                    childNavs.add(new Nav.KeyVal(k, nav));
                } else if (nav != null) {
                    childNavs.add(new Nav.Val(k, nav));
                } else {
                    childNavs.add(new Nav.Key(k));
                }
            }
        }

        boolean predRes = RT.booleanCast(pred.invoke(m));
        if (predRes) matches.add(m);

        if (childNavs != null || predRes) {
            return (m instanceof IEditableCollection)
                ? new Nav.MapEditable(childNavs, predRes)
                : new Nav.MapPersistent(childNavs, predRes);
        }
        return null;
    }

    /**
     * Scans a struct map. Only values are scanned at this level. The predicate is never applied
     * to this map's own keys, even when {@code includeKeys} is set, but {@code includeKeys} is
     * still forwarded to nested values.
     *
     * @return a {@code Nav.MapStruct}, or null if nothing matched
     */
    private static Object pathMapStruct(PersistentStructMap m, ArrayList<Object> matches, IFn pred, boolean includeKeys) {
        // Struct maps: keys are fixed, never transform keys
        Iterator<?> iter = RT.iter(RT.keys(m));
        ArrayList<Nav.Val> childNavs = null;

        while (iter.hasNext()) {
            Object k = iter.next();
            Object v = m.valAt(k);
            Nav.Updatable nav = (Nav.Updatable) pathWhen(v, matches, pred, includeKeys);

            if (nav != null) {
                if (childNavs == null) childNavs = new ArrayList<>();
                childNavs.add(new Nav.Val(k, nav));
            }
        }

        boolean predRes = RT.booleanCast(pred.invoke(m));
        if (predRes) matches.add(m);

        if (childNavs != null || predRes) {
            return new Nav.MapStruct(childNavs, predRes);
        }
        return null;
    }

    /**
     * Find-only counterpart of {@link #pathMap}. It visits value, then key (if
     * {@code includeKeys}), then the map itself. Unlike {@link #pathMapStruct}, it also matches
     * the keys of struct maps.
     */
    private static void findMap(IPersistentMap m, ArrayList<Object> matches, IFn pred, IFn tf, boolean includeKeys) {
        Iterator<?> iter = (m instanceof IMapIterable mi)
            ? mi.keyIterator()
            : RT.iter(RT.keys(m));

        while (iter.hasNext()) {
            Object k = iter.next();
            Object v = m.valAt(k);
            findWhen(v, matches, pred, tf, includeKeys);
            if (includeKeys && RT.booleanCast(pred.invoke(k))) {
                matches.add(tf.invoke(k));
            }
        }

        if (RT.booleanCast(pred.invoke(m))) {
            matches.add(tf.invoke(m));
        }
    }

    // ========================================================================
    // Vector scanning
    // ========================================================================

    /**
     * Scans each element of a vector by index, then the vector itself.
     *
     * <p>When some element matched, the nav variant depends on whether the vector is an
     * {@link IEditableCollection}. When only the vector itself matched, a {@code Nav.VecEdit}
     * with no children is returned regardless of editability.
     *
     * @return a vector nav, or null if nothing matched
     */
    private static Object pathVector(IPersistentVector v, ArrayList<Object> matches, IFn pred, boolean includeKeys) {
        int count = v.count();
        ArrayList<Nav.Pos> childNavs = null;

        for (int i = 0; i < count; i++) {
            Object elem = v.nth(i);
            Nav.Updatable nav = (Nav.Updatable) pathWhen(elem, matches, pred, includeKeys);
            if (nav != null) {
                if (childNavs == null) childNavs = new ArrayList<>();
                childNavs.add(new Nav.Pos(i, nav));
            }
        }

        boolean predRes = RT.booleanCast(pred.invoke(v));
        if (predRes) matches.add(v);

        if (childNavs != null) {
            return (v instanceof IEditableCollection)
                ? new Nav.VecEdit(childNavs, predRes)
                : new Nav.VecPersistent(childNavs, predRes);
        } else if (predRes) {
            return new Nav.VecEdit(null, true);
        }
        return null;
    }

    /** Find-only counterpart of {@link #pathVector}: elements in index order, then the vector. */
    private static void findVector(IPersistentVector v, ArrayList<Object> matches, IFn pred, IFn tf, boolean includeKeys) {
        int count = v.count();
        for (int i = 0; i < count; i++) {
            findWhen(v.nth(i), matches, pred, tf, includeKeys);
        }
        if (RT.booleanCast(pred.invoke(v))) {
            matches.add(tf.invoke(v));
        }
    }

    // ========================================================================
    // Set scanning
    // ========================================================================

    /**
     * Scans each member of a set, then the set itself. Child navs are keyed by the member value.
     * When only the set itself matched, a childless {@code Nav.SetEdit} is returned regardless of
     * editability.
     *
     * @return a set nav, or null if nothing matched
     */
    private static Object pathSet(IPersistentSet s, ArrayList<Object> matches, IFn pred, boolean includeKeys) {
        ISeq seq = RT.seq(s);
        ArrayList<Nav.Mem> childNavs = null;

        while (seq != null) {
            Object elem = seq.first();
            Nav.Updatable nav = (Nav.Updatable) pathWhen(elem, matches, pred, includeKeys);
            if (nav != null) {
                if (childNavs == null) childNavs = new ArrayList<>();
                childNavs.add(new Nav.Mem(elem, nav));
            }
            seq = seq.next();
        }

        boolean predRes = RT.booleanCast(pred.invoke(s));
        if (predRes) matches.add(s);

        if (childNavs != null) {
            return (s instanceof IEditableCollection)
                ? new Nav.SetEdit(childNavs, predRes)
                : new Nav.SetPersistent(childNavs, predRes);
        } else if (predRes) {
            return new Nav.SetEdit(null, true);
        }
        return null;
    }

    /** Find-only counterpart of {@link #pathSet}: members in seq order, then the set. */
    private static void findSet(IPersistentSet s, ArrayList<Object> matches, IFn pred, IFn tf, boolean includeKeys) {
        ISeq seq = RT.seq(s);
        while (seq != null) {
            findWhen(seq.first(), matches, pred, tf, includeKeys);
            seq = seq.next();
        }
        if (RT.booleanCast(pred.invoke(s))) {
            matches.add(tf.invoke(s));
        }
    }

    // ========================================================================
    // Sequential scanning (lists, lazy seqs, etc.)
    // ========================================================================

    /**
     * Scans a {@link Sequential} that is neither an {@link ISeq} nor a vector.
     *
     * <p>A non-empty collection is scanned through its seq, and the predicate sees that seq, as
     * before. An empty collection has no seq ({@code RT.seq} returns null), so it is treated as a
     * leaf and the predicate sees the collection itself rather than nil.
     */
    private static Object pathSequential(Sequential coll, ArrayList<Object> matches, IFn pred, boolean includeKeys) {
        ISeq seq = RT.seq(coll);
        return (seq == null)
            ? pathScalar(coll, matches, pred)
            : pathSeq(seq, matches, pred, includeKeys);
    }

    /**
     * Scans each element of a seq by position, then the seq object itself.
     *
     * @param coll the seq to scan; the predicate and {@code matches} see this exact object
     * @return a {@code Nav.SeqNav}, or null if nothing matched
     */
    private static Object pathSeq(ISeq coll, ArrayList<Object> matches, IFn pred, boolean includeKeys) {
        if (coll == null) return pathScalar(null, matches, pred);

        ArrayList<Nav.Pos> childNavs = null;
        int idx = 0;

        // Walk RT.seq(coll) rather than coll itself. An empty seq such as () or an empty lazy seq
        // is a non-null ISeq whose first() returns nil; iterating it directly would scan a
        // phantom nil element. RT.seq returns null for it, so the loop correctly does nothing.
        for (ISeq s = RT.seq(coll); s != null; s = s.next()) {
            Nav.Updatable nav = (Nav.Updatable) pathWhen(s.first(), matches, pred, includeKeys);
            if (nav != null) {
                if (childNavs == null) childNavs = new ArrayList<>();
                childNavs.add(new Nav.Pos(idx, nav));
            }
            idx++;
        }

        boolean predRes = RT.booleanCast(pred.invoke(coll));
        if (predRes) matches.add(coll);

        if (childNavs != null || predRes) {
            return new Nav.SeqNav(childNavs, predRes);
        }
        return null;
    }

    /**
     * Find-only counterpart of {@link #pathSequential}. An empty collection is matched as a leaf,
     * with the predicate seeing the collection itself rather than nil.
     */
    private static void findSequential(Sequential coll, ArrayList<Object> matches, IFn pred, IFn tf, boolean includeKeys) {
        ISeq seq = RT.seq(coll);
        if (seq == null) {
            findScalar(coll, matches, pred, tf);
        } else {
            findSeq(seq, matches, pred, tf, includeKeys);
        }
    }

    /**
     * Find-only counterpart of {@link #pathSeq}: elements in order, then the seq object itself.
     * Empty seqs contribute no phantom nil element.
     */
    private static void findSeq(ISeq coll, ArrayList<Object> matches, IFn pred, IFn tf, boolean includeKeys) {
        if (coll == null) {
            findScalar(null, matches, pred, tf);
            return;
        }

        // See pathSeq: RT.seq is null for an empty seq, avoiding a phantom nil element.
        for (ISeq s = RT.seq(coll); s != null; s = s.next()) {
            findWhen(s.first(), matches, pred, tf, includeKeys);
        }
        if (RT.booleanCast(pred.invoke(coll))) {
            matches.add(tf.invoke(coll));
        }
    }

    // ========================================================================
    // Scalar scanning
    // ========================================================================

    /**
     * Tests a leaf value against the predicate.
     *
     * @return {@code Nav.Scalar.INSTANCE} if the value matched (and was added to
     *         {@code matches}), otherwise null
     */
    private static Object pathScalar(Object obj, ArrayList<Object> matches, IFn pred) {
        if (RT.booleanCast(pred.invoke(obj))) {
            matches.add(obj);
            return Nav.Scalar.INSTANCE;
        }
        return null;
    }

    /** Tests a leaf value against the predicate and, if it matches, appends {@code tf.invoke(obj)}. */
    private static void findScalar(Object obj, ArrayList<Object> matches, IFn pred, IFn tf) {
        if (RT.booleanCast(pred.invoke(obj))) {
            matches.add(tf.invoke(obj));
        }
    }
}
