#=
reproduce_zop_off_by_one.jl

Demonstrates the ZeroOrderPulse off-by-one bug.

The real-world trigger: after an NLP solve, a NamedTrajectory's times are
recovered via `get_times(traj) = cumsum([0.0, Δt...])`. When a ZeroOrderPulse
is reconstructed from this trajectory, its stored `.t` are cumsum-based.
Later, `_sample_times(traj, N)` recomputes times via `range(0, T, length=N)`,
producing different floats (for Float64 endpoints `range` returns a `StepRangeLen`
that uses compensated `TwicePrecision` arithmetic, unlike cumulative addition).

At ConstantInterpolation's discontinuities (knot times), a difference of O(1e-17)
can cause `pulse(t_k - ε) → u[k-1]` instead of `u[k]`.

Run:
    JULIA_DEPOT_PATH=/home/agent/content/workspace/depot/ \
    JULIA_LOAD_PATH=/home/agent/content/workspace/env/:@:@stdlib \
    /home/agent/content/workspace/usr/bin/julia scratch/reproduce_zop_off_by_one.jl
=#

using Piccolo

"""
    _print_section_header(title)

Print `title` framed by two 72-character dashed rules, followed by a blank line.
"""
function _print_section_header(title::AbstractString)
    println("-"^72)
    println(title)
    println("-"^72)
    println()
    return nothing
end

"""
    _evaluate_at(pulse, times)

Call `pulse(t)` for every `t` in `times` and concatenate the results horizontally,
one column per time. Uses `reduce(hcat, ...)` rather than splatting the collection
into `hcat`.
"""
_evaluate_at(pulse, times) = reduce(hcat, [pulse(t) for t in times])

"""
    main()

Run the ZeroOrderPulse off-by-one reproduction and print a report to stdout.

The report has six parts:

1. Build the same `N` knot times two ways: `cumsum` of a constant step and
   `collect(range(0, T, length=N))`.
2. Print how many of those times differ and by how much.
3. Construct a `ZeroOrderPulse` from the cumsum times, evaluate it at the range
   times, and list knots whose sampled control differs from the stored control.
4. Repeat the comparison through `sample(pulse, N)`.
5. Repeat with alternating ±600 controls to show the magnitude of the error.
6. Round-trip through a `NamedTrajectory` and resample at range-based times.

Returns `nothing`.
"""
function main()
    N = 51
    T = 10.0

    # Controls u[k] = k — makes any index shift visually obvious
    controls = Float64.(reshape(1:N, 1, N))

    # ====================================================================== #
    # 1. Build times via two different methods
    # ====================================================================== #

    # cumsum path: what get_times(traj) produces for variable-Δt trajectories
    Δt = T / (N - 1)
    times_cumsum = cumsum([0.0; fill(Δt, N - 1)])

    # range path: what _sample_times(traj, N) produces
    times_range = collect(range(0.0, T, length=N))

    # ====================================================================== #
    # 2. Show the float disagreement
    # ====================================================================== #

    println("="^72)
    println("ZeroOrderPulse off-by-one reproduction")
    println("="^72)
    println()
    println("N = $N,  T = $T μs,  controls u[k] = k")
    println()

    time_diffs = times_range .- times_cumsum
    nonzero_knots = findall(!iszero, time_diffs)
    n_nonzero = length(nonzero_knots)
    max_diff = maximum(abs, time_diffs)

    println("Time vector comparison:")
    println("  cumsum([0; fill(Δt, N-1)])  vs  collect(range(0, T, N))")
    println("  Non-zero differences: $n_nonzero / $N")
    println("  Max |Δt|:            $max_diff")
    println()

    if n_nonzero > 0
        println("  First 5 non-zero diffs:")
        for k in Iterators.take(nonzero_knots, 5)
            println("    k=$(lpad(k,2)):  cumsum=$(times_cumsum[k])  " *
                    "range=$(times_range[k])  Δ=$(time_diffs[k])")
        end
        println()
    end

    # ====================================================================== #
    # 3. Construct ZeroOrderPulse with cumsum times, evaluate at range times
    # ====================================================================== #

    pulse = ZeroOrderPulse(controls, times_cumsum)
    stored_u = Matrix(pulse.controls.u)
    sampled = _evaluate_at(pulse, times_range)

    mismatches = [k for k in 1:N if stored_u[1, k] != sampled[1, k]]

    println("Control value comparison:")
    println("  stored u[:,k]  vs  pulse(range_time[k])")
    println("  Mismatched knots: $(length(mismatches)) / $N")
    println()

    if !isempty(mismatches)
        println("  Off-by-one knots (first 10):")
        for k in Iterators.take(mismatches, 10)
            Δ = time_diffs[k]
            # Three-way label: a mismatch with Δ == 0 would mean the shift is NOT
            # caused by float disagreement, so it must not be reported as ">".
            direction = if Δ < 0
                "range < cumsum"
            elseif Δ > 0
                "range > cumsum"
            else
                "range == cumsum"
            end
            println("    k=$(lpad(k,2)): stored_u=$(Int(stored_u[1,k]))  " *
                    "sampled=$(Int(sampled[1,k]))  " *
                    "shift=$(Int(stored_u[1,k] - sampled[1,k]))  " *
                    "($(direction), Δ=$(Δ))")
        end
        println()

        # Verify: mismatched knots should return u[k-1]
        all_shifted = all(k -> k > 1 && sampled[1, k] == stored_u[1, k - 1], mismatches)
        println("  All mismatches are exactly u[k-1]? $all_shifted")
        println()
    else
        println("  (No mismatches — times may not trigger the discontinuity edge)")
        println()
    end

    # ====================================================================== #
    # 4. Verify sample(pulse, N) has the same issue
    # ====================================================================== #

    _print_section_header("Downstream: Piccolo.sample(pulse, N)")

    sampled_api, _ = sample(pulse, N)
    api_mismatches = count(k -> sampled_api[1, k] != stored_u[1, k], 1:N)
    println("  sample(pulse, $N) mismatches: $api_mismatches / $N")
    println()

    # ====================================================================== #
    # 5. Physical scale: alternating ±600 rad/μs
    # ====================================================================== #

    _print_section_header("Physical scale: alternating ±600 rad/μs controls")

    phys_controls = reshape([iseven(k) ? 600.0 : -600.0 for k in 1:N], 1, N)
    phys_pulse = ZeroOrderPulse(phys_controls, times_cumsum)
    phys_sampled = _evaluate_at(phys_pulse, times_range)
    phys_stored = Matrix(phys_pulse.controls.u)

    max_err = maximum(abs, phys_sampled .- phys_stored)
    n_wrong = count(k -> phys_sampled[1, k] != phys_stored[1, k], 1:N)
    println("  Max control error: $max_err rad/μs")
    println("  Knots with wrong control: $n_wrong / $N")
    if max_err > 0
        println("  (sign flips of 1200 rad/μs at each affected knot)")
    end
    println()

    # ====================================================================== #
    # 6. Full round-trip: NamedTrajectory → ZeroOrderPulse → resample
    # ====================================================================== #

    _print_section_header("Full round-trip: NamedTrajectory ↔ ZeroOrderPulse")

    # Build NamedTrajectory with variable timestep (the standard post-NLP case)
    Δt_vec = fill(Δt, N)
    data = (u = controls, Δt = reshape(Δt_vec, 1, N))
    traj = NamedTrajectory(data; timestep=:Δt, controls=(:Δt, :u))

    recovered_times = get_times(traj)
    println("  get_times(traj) matches cumsum?  $(recovered_times ≈ times_cumsum)")
    println("  get_times(traj) == range times?  $(recovered_times == times_range)")
    println()

    # Reconstruct ZeroOrderPulse from trajectory
    pulse_rt = ZeroOrderPulse(traj; drive_name=:u)

    # Resample at range-based times (what _sample_times would produce)
    rt_range = collect(range(0.0, duration(pulse_rt), length=N))
    rt_sampled = _evaluate_at(pulse_rt, rt_range)
    rt_stored = Matrix(pulse_rt.controls.u)

    rt_mismatches = count(k -> rt_sampled[1, k] != rt_stored[1, k], 1:N)
    println("  Round-trip mismatches: $rt_mismatches / $N")

    if rt_mismatches > 0
        rt_err = maximum(abs, rt_sampled .- rt_stored)
        println("  Max |u_sampled - u_stored|: $rt_err")
    end

    println()
    println("="^72)
    println("Done.")
    return nothing
end

main()
