module PlasmaEquilibriumToolkitZygoteExt

using PlasmaEquilibriumToolkit
isdefined(Base, :get_extension) ? (using Zygote) : (using ..Zygote)
# Imported explicitly so this extension does not rely on `Diagonal` being re-exported.
using LinearAlgebra: Diagonal

# Gradients of the CurvilinearVector representation (a basis plus coefficients) can reach
# `Zygote._gradcopy!` as a NamedTuple with fields `basis` and `coeffs`. The methods below
# collapse that NamedTuple into a single array and hand it back to `Zygote._gradcopy!`,
# which copies it into `dst`.
#
# NOTE(review): `_gradcopy!` is an underscore-prefixed Zygote internal rather than
# documented public API; re-check these methods when upgrading Zygote.

# NamedTuple type with fields `basis::B` and `coeffs::C`. This is a pure alias: it expands
# to exactly `NamedTuple{(:basis, :coeffs), Tuple{B, C}}`, so dispatch is not widened.
const _BasisCoeffs{B, C} = NamedTuple{(:basis, :coeffs), Tuple{B, C}}

# Matrix basis with a coefficient vector: column `j` of `basis` is scaled by `coeffs[j]`.
# `Diagonal` is used instead of the dense `diagm` so no zero-filled matrix is allocated and
# no full matrix product is formed. It also keeps a non-finite entry in one column from
# turning the rest of its row into `NaN` via `Inf * 0`.
function Zygote._gradcopy!(
    dst::AbstractArray,
    src::_BasisCoeffs{MMatrix{3, 3, T, 9}, SizedVector{3, T, Vector{T}}},
) where {T}
    return Zygote._gradcopy!(dst, src.basis * Diagonal(src.coeffs))
end

# Same as above, for coefficients stored in a plain `Vector`. The length is not fixed by the
# type here; multiplying by a `Diagonal` still checks that it matches the basis size.
function Zygote._gradcopy!(
    dst::AbstractArray,
    src::_BasisCoeffs{MMatrix{3, 3, T, 9}, Vector{T}},
) where {T}
    return Zygote._gradcopy!(dst, src.basis * Diagonal(src.coeffs))
end

# Vector basis with a single scalar coefficient: the collapsed array is `basis * coeffs`.
function Zygote._gradcopy!(dst::AbstractArray, src::_BasisCoeffs{Vector{T}, T}) where {T}
    return Zygote._gradcopy!(dst, src.basis * src.coeffs)
end

end
