#=
test_snap_to_knots.jl — Regression test for the ZeroOrderPulse snap_to_knots fix.

Run:
    JULIA_DEPOT_PATH=/home/agent/content/workspace/depot/ \
    JULIA_LOAD_PATH=/home/agent/content/workspace/env/:@:@stdlib \
    /home/agent/content/workspace/usr/bin/julia scratch/test_snap_to_knots.jl
=#

using Piccolo
using Test

# Evaluate `pulse` at each time in `times` and stack the results as matrix columns.
# `reduce(hcat, ...)` avoids splatting a runtime-length collection into varargs.
_eval_columns(pulse, times) = reduce(hcat, [pulse(t) for t in times])

"""
    test_snap_to_knots()

Run the `snap_to_knots` regression checks for `ZeroOrderPulse`.

Two time grids over `[0, T]` with `N` points are built, one with `cumsum` of a
constant step and one with `range`, and the checks compare pulse evaluations on
the `range` grid against the controls stored in a pulse constructed from the
`cumsum` grid:

- with the default `snap_to_knots=true`, sampled values must equal the stored
  controls exactly (pointwise, via `sample`, and after a `NamedTrajectory`
  round-trip);
- with `snap_to_knots=false`, at least one sample must differ and every
  differing sample must equal the previous knot's control;
- a small construction/evaluation smoke test of the pre-existing API.

All groups are nested in a single parent testset, which is returned. As with any
un-nested `@testset`, a failure raises once the parent testset finishes.
"""
function test_snap_to_knots()
    N = 51
    T = 10.0

    # Controls u[k] = k — makes any index shift visually obvious
    controls = Float64.(reshape(1:N, 1, N))

    # cumsum path (what get_times(traj) produces for variable-Δt trajectories)
    Δt = T / (N - 1)
    times_cumsum = cumsum([0.0; fill(Δt, N - 1)])

    # range path (what _sample_times uses) — different floats
    times_range = collect(range(0.0, T, length=N))

    # One parent testset so every group runs and is reported together. Un-nested
    # testsets each throw when they finish with a failure, which would skip all
    # groups after the first failing one.
    return @testset "ZeroOrderPulse snap_to_knots" begin
        @testset "Time vectors differ" begin
            @test times_cumsum != times_range
            @test count(!iszero, times_range .- times_cumsum) > 0
        end

        @testset "snap_to_knots=true (default)" begin
            pulse = ZeroOrderPulse(controls, times_cumsum)
            @test pulse.snap_to_knots == true

            stored = Matrix(pulse.controls.u)

            # Pointwise evaluation at range times — should match stored
            sampled = _eval_columns(pulse, times_range)
            @test sampled == stored

            # Batch sample(pulse, N) — should also match
            sampled_batch, _ = sample(pulse, N)
            @test sampled_batch == stored
        end

        @testset "Full round-trip: NamedTrajectory ↔ ZeroOrderPulse" begin
            Δt_vec = fill(Δt, N)
            data = (u = controls, Δt = reshape(Δt_vec, 1, N))
            traj = NamedTrajectory(data; timestep=:Δt, controls=(:Δt, :u))

            pulse_rt = ZeroOrderPulse(traj; drive_name=:u)
            @test pulse_rt.snap_to_knots == true

            rt_range = collect(range(0.0, duration(pulse_rt), length=N))
            rt_sampled = _eval_columns(pulse_rt, rt_range)
            @test rt_sampled == Matrix(pulse_rt.controls.u)
        end

        @testset "snap_to_knots=false reproduces off-by-one" begin
            pulse_raw = ZeroOrderPulse(controls, times_cumsum; snap_to_knots=false)
            @test pulse_raw.snap_to_knots == false

            stored = Matrix(pulse_raw.controls.u)
            sampled_raw = _eval_columns(pulse_raw, times_range)

            # Off-by-one should be present without snapping
            mismatched = [k for k in 1:N if sampled_raw[1, k] != stored[1, k]]
            @test !isempty(mismatched)

            # Every mismatch should be exactly u[k-1]
            for k in mismatched
                @test k > 1
                @test sampled_raw[1, k] == stored[1, k - 1]
            end
        end

        @testset "Physical-scale controls" begin
            # Alternating ±600 so a one-knot shift shows up as a full sign flip
            phys_controls = reshape([iseven(k) ? 600.0 : -600.0 for k in 1:N], 1, N)

            # With snapping (default): no errors
            pulse_snap = ZeroOrderPulse(phys_controls, times_cumsum)
            snap_sampled = _eval_columns(pulse_snap, times_range)
            @test snap_sampled == Matrix(pulse_snap.controls.u)

            # Without snapping: sign-flip errors
            pulse_nosnap = ZeroOrderPulse(phys_controls, times_cumsum; snap_to_knots=false)
            nosnap_sampled = _eval_columns(pulse_nosnap, times_range)
            max_err = maximum(abs, nosnap_sampled .- Matrix(pulse_nosnap.controls.u))
            @test max_err == 1200.0  # full sign flip
        end

        @testset "Existing API unchanged" begin
            # Basic construction and evaluation (from existing @testitem "ZeroOrderPulse")
            ctrl = [0.0 1.0 0.5 0.0; 0.0 -1.0 -0.5 0.0]
            ts = [0.0, 0.25, 0.5, 1.0]
            pulse = ZeroOrderPulse(ctrl, ts)

            @test duration(pulse) == 1.0
            @test n_drives(pulse) == 2
            @test drive_name(pulse) == :u
            @test pulse(0.0) ≈ [0.0, 0.0]
            @test pulse(0.1) ≈ [0.0, 0.0]
            @test pulse(0.3) ≈ [1.0, -1.0]
            @test pulse(1.0) ≈ [0.0, 0.0]

            sampled, sample_ts = sample(pulse, 5)
            @test size(sampled) == (2, 5)
            @test length(sample_ts) == 5

            pulse_custom = ZeroOrderPulse(ctrl, ts; drive_name=:Ω)
            @test drive_name(pulse_custom) == :Ω
        end
    end
end

# Script entry point: run the regression checks. A failing testset throws, so the
# confirmation line below is only printed when nothing failed.
test_snap_to_knots()
println(stdout, "\nAll tests passed.")
