| Line | Code | Exclusive | Inclusive |
|---|---|---|---|
| 1 | ## Common Interface Solve Functions | ||
| 2 | |||
| 3 | function DiffEqBase.__solve(prob::Union{DiffEqBase.AbstractODEProblem, | ||
| 4 | DiffEqBase.AbstractDAEProblem}, | ||
| 5 | alg::algType, | ||
| 6 | timeseries = [], | ||
| 7 | ts = [], | ||
| 8 | ks = [], | ||
| 9 | recompile::Type{Val{recompile_flag}} = Val{true}; | ||
| 10 | calculate_error = true, | ||
| 11 | kwargs...) where { | ||
| 12 | algType <: Union{SundialsODEAlgorithm, | ||
| 13 | SundialsDAEAlgorithm}, | ||
| 14 | recompile_flag} | ||
| 15 | integrator = DiffEqBase.__init(prob, alg, timeseries, ts, ks; kwargs...) | ||
| 16 | if integrator.sol.retcode == ReturnCode.Default | ||
| 17 | solve!(integrator; early_free = true, calculate_error = calculate_error) | ||
| 18 | end | ||
| 19 | integrator.sol | ||
| 20 | end | ||
| 21 | |||
| 22 | function DiffEqBase.__solve(prob::Union{ | ||
| 23 | DiffEqBase.AbstractSteadyStateProblem{uType, | ||
| 24 | isinplace}, | ||
| 25 | DiffEqBase.AbstractNonlinearProblem{uType, | ||
| 26 | isinplace}}, | ||
| 27 | alg::algType, | ||
| 28 | timeseries = [], | ||
| 29 | ts = [], | ||
| 30 | ks = [], | ||
| 31 | recompile::Type{Val{recompile_flag}} = Val{true}; | ||
| 32 | abstol = eps(Float64) ^ (4 // 5), | ||
| 33 | maxiters = 1000, | ||
| 34 | kwargs...) where {algType <: SundialsNonlinearSolveAlgorithm, | ||
| 35 | recompile_flag, uType, isinplace} | ||
| 36 | if prob.u0 isa Number | ||
| 37 | u0 = [prob.u0] | ||
| 38 | else | ||
| 39 | u0 = deepcopy(prob.u0) | ||
| 40 | end | ||
| 41 | |||
| 42 | p = prob.p | ||
| 43 | userdata = alg.userdata | ||
| 44 | linsolve = linear_solver(alg) | ||
| 45 | jac_upper = alg.jac_upper | ||
| 46 | jac_lower = alg.jac_lower | ||
| 47 | |||
| 48 | ### Fix the more general function to Sundials allowed style | ||
| 49 | if prob.f isa ODEFunction | ||
| 50 | t = Inf | ||
| 51 | if !isinplace && prob.u0 isa Number | ||
| 52 | f! = (du, u) -> (du .= prob.f(first(u), p, t); Cint(0)) | ||
| 53 | elseif !isinplace | ||
| 54 | f! = (du, u) -> (du .= prob.f(u, p, t); Cint(0)) | ||
| 55 | else # Then it's an in-place function on an abstract array | ||
| 56 | f! = (du, u) -> prob.f(du, u, p, t) | ||
| 57 | end | ||
| 58 | elseif prob.f isa NonlinearFunction | ||
| 59 | if !isinplace && prob.u0 isa Number | ||
| 60 | f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0)) | ||
| 61 | elseif !isinplace | ||
| 62 | f! = (du, u) -> (du .= prob.f(u, p); Cint(0)) | ||
| 63 | else # Then it's an in-place function on an abstract array | ||
| 64 | f! = (du, u) -> prob.f(du, u, p) | ||
| 65 | end | ||
| 66 | end | ||
| 67 | u = zero(u0) | ||
| 68 | resid = similar(u) | ||
| 69 | u,flag = ___kinsol(f!, u0; | ||
| 70 | userdata = userdata, | ||
| 71 | linear_solver = linsolve, | ||
| 72 | jac_upper = jac_upper, | ||
| 73 | jac_lower = jac_lower, | ||
| 74 | abstol, | ||
| 75 | prob.f.jac_prototype, | ||
| 76 | alg.prec_side, | ||
| 77 | alg.krylov_dim, | ||
| 78 | maxiters, | ||
| 79 | strategy = alg.globalization_strategy) | ||
| 80 | |||
| 81 | f!(resid, u) | ||
| 82 | retcode = interpret_sundials_retcode(flag) | ||
| 83 | if prob.u0 isa Number | ||
| 84 | DiffEqBase.build_solution(prob, alg, u[1], resid[1]; retcode = retcode) | ||
| 85 | else | ||
| 86 | DiffEqBase.build_solution(prob, alg, u, resid; retcode = retcode) | ||
| 87 | end | ||
| 88 | end | ||
| 89 | |||
| 90 | function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, isinplace}, | ||
| 91 | alg::SundialsODEAlgorithm{Method, LinearSolver}, | ||
| 92 | timeseries = [], | ||
| 93 | ts = [], | ||
| 94 | ks = []; | ||
| 95 | verbose = true, | ||
| 96 | callback = nothing, | ||
| 97 | abstol = 1 / 10^6, | ||
| 98 | reltol = 1 / 10^3, | ||
| 99 | saveat = Float64[], | ||
| 100 | d_discontinuities = Float64[], | ||
| 101 | tstops = Float64[], | ||
| 102 | maxiters = Int(1e5), | ||
| 103 | dt = nothing, | ||
| 104 | dtmin = 0.0, | ||
| 105 | dtmax = 0.0, | ||
| 106 | timeseries_errors = true, | ||
| 107 | dense_errors = false, | ||
| 108 | save_everystep = isempty(saveat), save_idxs = nothing, | ||
| 109 | save_on = true, | ||
| 110 |                   save_start = save_everystep \|\| isempty(saveat) \|\| | ||
| 111 | saveat isa Number ? true : | ||
| 112 | prob.tspan[1] in saveat, | ||
| 113 |                   save_end = save_everystep \|\| isempty(saveat) \|\| | ||
| 114 | saveat isa Number ? true : | ||
| 115 | prob.tspan[2] in saveat, | ||
| 116 | dense = save_everystep && isempty(saveat), | ||
| 117 | progress = false, | ||
| 118 | progress_steps=1000, | ||
| 119 | progress_name = "ODE", | ||
| 120 | progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE, | ||
| 121 | progress_id = gensym("Sundials"), | ||
| 122 | save_timeseries = nothing, | ||
| 123 | advance_to_tstop = false, | ||
| 124 | stop_at_next_tstop = false, | ||
| 125 | userdata = nothing, | ||
| 126 | alias_u0 = false, | ||
| 127 | kwargs...) where {uType, tupType, isinplace, Method, LinearSolver | ||
| 128 | } | ||
| 129 | tType = eltype(tupType) | ||
| 130 | |||
| 131 | if verbose | ||
| 132 | warned = !isempty(kwargs) && DiffEqBase.check_keywords(alg, kwargs, warnlist) | ||
| 133 | warned && DiffEqBase.warn_compat() | ||
| 134 | end | ||
| 135 | |||
| 136 | if prob.f.mass_matrix != LinearAlgebra.I | ||
| 137 | error("This solver is not able to use mass matrices.") | ||
| 138 | end | ||
| 139 | |||
| 140 | if reltol isa AbstractArray | ||
| 141 | error("Sundials only allows scalar reltol.") | ||
| 142 | end | ||
| 143 | |||
| 144 | if length(prob.u0) <= 0 | ||
| 145 | error("Sundials requires at least one state variable.") | ||
| 146 | end | ||
| 147 | |||
| 148 | progress && Logging.@logmsg(Logging.LogLevel(-1), progress_name, _id=progress_id, progress=0) | ||
| 149 | |||
| 150 | tstops = vcat(tstops, d_discontinuities) | ||
| 151 | callbacks_internal = DiffEqBase.CallbackSet(callback) | ||
| 152 | |||
| 153 | max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal) | ||
| 154 | if max_len_cb isa VectorContinuousCallback | ||
| 155 | callback_cache = DiffEqBase.CallbackCache(max_len_cb.len, Float64, Float64) | ||
| 156 | else | ||
| 157 | callback_cache = nothing | ||
| 158 | end | ||
| 159 | |||
| 160 | tspan = Float64.(prob.tspan) | ||
| 161 | t0 = tspan[1] | ||
| 162 | |||
| 163 | tdir = sign(tspan[2] - tspan[1]) | ||
| 164 | |||
| 165 | tstops_internal, saveat_internal = tstop_saveat_disc_handling(tstops, saveat, tdir, | ||
| 166 | tspan, tType) | ||
| 167 | |||
| 168 | if prob.u0 isa Number | ||
| 169 | u0 = [prob.u0] | ||
| 170 | else | ||
| 171 | if alias_u0 | ||
| 172 | u0 = prob.u0 | ||
| 173 | else | ||
| 174 | u0 = copy(prob.u0) | ||
| 175 | end | ||
| 176 | end | ||
| 177 | |||
| 178 | ### Fix the more general function to Sundials allowed style | ||
| 179 | if !isinplace && prob.u0 isa Number | ||
| 180 | f! = (du, u, p, t) -> (du .= prob.f(first(u), p, t); Cint(0)) | ||
| 181 | elseif !isinplace | ||
| 182 | f! = (du, u, p, t) -> (du .= prob.f(u, p, t); Cint(0)) | ||
| 183 | else # Then it's an in-place function on an abstract array | ||
| 184 | f! = prob.f | ||
| 185 | end | ||
| 186 | |||
| 187 | if alg isa CVODE_BDF | ||
| 188 | alg_code = CV_BDF | ||
| 189 | elseif alg isa CVODE_Adams | ||
| 190 | alg_code = CV_ADAMS | ||
| 191 | end | ||
| 192 | |||
| 193 | #if Method == :Newton | ||
| 194 | # method_code = CV_NEWTON | ||
| 195 | #elseif Method == :Functional | ||
| 196 | # method_code = CV_FUNCTIONAL | ||
| 197 | #end | ||
| 198 | |||
| 199 | mem_ptr = CVodeCreate(alg_code) | ||
| 200 | (mem_ptr == C_NULL) && error("Failed to allocate CVODE solver object") | ||
| 201 | mem = Handle(mem_ptr) | ||
| 202 | |||
| 203 | !verbose && CVodeSetErrHandlerFn(mem, | ||
| 204 | @cfunction(null_error_handler, Nothing, | ||
| 205 | (Cint, Char, Char, Ptr{Cvoid})), | ||
| 206 | C_NULL) | ||
| 207 | |||
| 208 | save_start ? ts = [t0] : ts = Float64[] | ||
| 209 | |||
| 210 | out = copy(u0) | ||
| 211 | uvec = vec(u0) # aliases u0 | ||
| 212 | utmp = NVector(uvec) # aliases u0 | ||
| 213 | |||
| 214 | use_jac_prototype = (isa(prob.f.jac_prototype, SparseArrays.SparseMatrixCSC) && | ||
| 215 |                          LinearSolver ∈ SPARSE_SOLVERS) \|\| | ||
| 216 | prob.f.jac_prototype isa AbstractSciMLOperator | ||
| 217 | userfun = FunJac(f!, | ||
| 218 | prob.f.jac, | ||
| 219 | prob.p, | ||
| 220 | nothing, | ||
| 221 | use_jac_prototype ? prob.f.jac_prototype : nothing, | ||
| 222 | alg.prec, | ||
| 223 | alg.psetup, | ||
| 224 | u0, | ||
| 225 | out) | ||
| 226 | |||
| 227 | function getcfunf(::T) where {T} | ||
| 228 | @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T})) | ||
| 229 | end | ||
| 230 | |||
| 231 | flag = CVodeInit(mem, getcfunf(userfun), t0, utmp) | ||
| 232 | |||
| 233 | dt !== nothing && (flag = CVodeSetInitStep(mem, Float64(dt))) | ||
| 234 | flag = CVodeSetMinStep(mem, Float64(dtmin)) | ||
| 235 | flag = CVodeSetMaxStep(mem, Float64(dtmax)) | ||
| 236 | flag = CVodeSetUserData(mem, userfun) | ||
| 237 | if abstol isa Array | ||
| 238 | flag = CVodeSVtolerances(mem, reltol, abstol) | ||
| 239 | else | ||
| 240 | flag = CVodeSStolerances(mem, reltol, abstol) | ||
| 241 | end | ||
| 242 | flag = CVodeSetMaxNumSteps(mem, maxiters) | ||
| 243 | flag = CVodeSetMaxOrd(mem, alg.max_order) | ||
| 244 | flag = CVodeSetMaxHnilWarns(mem, alg.max_hnil_warns) | ||
| 245 | flag = CVodeSetStabLimDet(mem, alg.stability_limit_detect) | ||
| 246 | flag = CVodeSetMaxErrTestFails(mem, alg.max_error_test_failures) | ||
| 247 | flag = CVodeSetMaxNonlinIters(mem, alg.max_nonlinear_iters) | ||
| 248 | flag = CVodeSetMaxConvFails(mem, alg.max_convergence_failures) | ||
| 249 | |||
| 250 | nojacobian = true | ||
| 251 | |||
| 252 | if Method == :Newton # Only use a linear solver if it's a Newton-based method | ||
| 253 | if LinearSolver in (:Dense, :LapackDense) | ||
| 254 | nojacobian = false | ||
| 255 | A = SUNDenseMatrix(length(uvec), length(uvec)) | ||
| 256 | _A = MatrixHandle(A, DenseMatrix()) | ||
| 257 | if LinearSolver === :Dense | ||
| 258 | LS = SUNLinSol_Dense(uvec, A) | ||
| 259 | _LS = LinSolHandle(LS, Dense()) | ||
| 260 | else | ||
| 261 | LS = SUNLinSol_LapackDense(uvec, A) | ||
| 262 | _LS = LinSolHandle(LS, LapackDense()) | ||
| 263 | end | ||
| 264 | elseif LinearSolver in (:Band, :LapackBand) | ||
| 265 | nojacobian = false | ||
| 266 | A = SUNBandMatrix(length(uvec), alg.jac_upper, alg.jac_lower) | ||
| 267 | _A = MatrixHandle(A, BandMatrix()) | ||
| 268 | if LinearSolver === :Band | ||
| 269 | LS = SUNLinSol_Band(uvec, A) | ||
| 270 | _LS = LinSolHandle(LS, Band()) | ||
| 271 | else | ||
| 272 | LS = SUNLinSol_LapackBand(uvec, A) | ||
| 273 | _LS = LinSolHandle(LS, LapackBand()) | ||
| 274 | end | ||
| 275 | elseif LinearSolver == :Diagonal | ||
| 276 | nojacobian = false | ||
| 277 | flag = CVDiag(mem) | ||
| 278 | _A = nothing | ||
| 279 | _LS = nothing | ||
| 280 | elseif LinearSolver == :GMRES | ||
| 281 | LS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.krylov_dim) | ||
| 282 | _A = nothing | ||
| 283 | _LS = Sundials.LinSolHandle(LS, Sundials.SPGMR()) | ||
| 284 | elseif LinearSolver == :FGMRES | ||
| 285 | LS = SUNLinSol_SPFGMR(uvec, alg.prec_side, alg.krylov_dim) | ||
| 286 | _A = nothing | ||
| 287 | _LS = LinSolHandle(LS, SPFGMR()) | ||
| 288 | elseif LinearSolver == :BCG | ||
| 289 | LS = SUNLinSol_SPBCGS(uvec, alg.prec_side, alg.krylov_dim) | ||
| 290 | _A = nothing | ||
| 291 | _LS = LinSolHandle(LS, SPBCGS()) | ||
| 292 | elseif LinearSolver == :PCG | ||
| 293 | LS = SUNLinSol_PCG(uvec, alg.prec_side, alg.krylov_dim) | ||
| 294 | _A = nothing | ||
| 295 | _LS = LinSolHandle(LS, PCG()) | ||
| 296 | elseif LinearSolver == :TFQMR | ||
| 297 | LS = SUNLinSol_SPTFQMR(uvec, alg.prec_side, alg.krylov_dim) | ||
| 298 | _A = nothing | ||
| 299 | _LS = LinSolHandle(LS, PTFQMR()) | ||
| 300 | elseif LinearSolver == :KLU | ||
| 301 | nojacobian = false | ||
| 302 | nnz = length(SparseArrays.nonzeros(prob.f.jac_prototype)) | ||
| 303 | A = SUNSparseMatrix(length(uvec), length(uvec), nnz, CSC_MAT) | ||
| 304 | LS = SUNLinSol_KLU(uvec, A) | ||
| 305 | _A = MatrixHandle(A, SparseMatrix()) | ||
| 306 | _LS = LinSolHandle(LS, KLU()) | ||
| 307 | end | ||
| 308 | if LinearSolver !== :Diagonal | ||
| 309 | flag = CVodeSetLinearSolver(mem, LS, _A === nothing ? C_NULL : A) | ||
| 310 | end | ||
| 311 | NLS = SUNNonlinSol_Newton(uvec) | ||
| 312 | else | ||
| 313 | _A = nothing | ||
| 314 | _LS = nothing | ||
| 315 | # TODO: Anderson Acceleration | ||
| 316 | anderson_m = 0 | ||
| 317 | NLS = SUNNonlinSol_FixedPoint(uvec, anderson_m) | ||
| 318 | end | ||
| 319 | CVodeSetNonlinearSolver(mem, NLS) | ||
| 320 | |||
| 321 | if DiffEqBase.has_jac(prob.f) && Method == :Newton | ||
| 322 | function getcfunjac(::T) where {T} | ||
| 323 | @cfunction(cvodejac, | ||
| 324 | Cint, | ||
| 325 | (realtype, | ||
| 326 | N_Vector, | ||
| 327 | N_Vector, | ||
| 328 | SUNMatrix, | ||
| 329 | Ref{T}, | ||
| 330 | N_Vector, | ||
| 331 | N_Vector, | ||
| 332 | N_Vector)) | ||
| 333 | end | ||
| 334 | jac = getcfunjac(userfun) | ||
| 335 | flag = CVodeSetUserData(mem, userfun) | ||
| 336 | nojacobian || (flag = CVodeSetJacFn(mem, jac)) | ||
| 337 | else | ||
| 338 | jac = nothing | ||
| 339 | end | ||
| 340 | |||
| 341 | if prob.f.jac_prototype isa AbstractSciMLOperator | ||
| 342 | "here!!!!" | ||
| 343 | function getcfunjtimes(::T) where {T} | ||
| 344 | @cfunction(jactimes, | ||
| 345 | Cint, | ||
| 346 | (N_Vector, N_Vector, realtype, N_Vector, N_Vector, Ref{T}, N_Vector)) | ||
| 347 | end | ||
| 348 | jtimes = getcfunjtimes(userfun) | ||
| 349 | CVodeSetJacTimes(mem, C_NULL, jtimes) | ||
| 350 | end | ||
| 351 | |||
| 352 | if alg.prec !== nothing | ||
| 353 | function getpercfun(::T) where {T} | ||
| 354 | @cfunction(precsolve, | ||
| 355 | Cint, | ||
| 356 | (Float64, | ||
| 357 | N_Vector, | ||
| 358 | N_Vector, | ||
| 359 | N_Vector, | ||
| 360 | N_Vector, | ||
| 361 | Float64, | ||
| 362 | Float64, | ||
| 363 | Int, | ||
| 364 | Ref{T})) | ||
| 365 | end | ||
| 366 | precfun = getpercfun(userfun) | ||
| 367 | |||
| 368 | function getpsetupfun(::T) where {T} | ||
| 369 | @cfunction(precsetup, | ||
| 370 | Cint, | ||
| 371 | (Float64, N_Vector, N_Vector, Int, Ptr{Int}, Float64, Ref{T})) | ||
| 372 | end | ||
| 373 | psetupfun = alg.psetup === nothing ? C_NULL : getpsetupfun(userfun) | ||
| 374 | |||
| 375 | CVodeSetPreconditioner(mem, psetupfun, precfun) | ||
| 376 | end | ||
| 377 | |||
| 378 | tmp = isnothing(callbacks_internal) ? u0 : similar(u0) | ||
| 379 | uprev = isnothing(callbacks_internal) ? u0 : similar(u0) | ||
| 380 | tout = [tspan[1]] | ||
| 381 | |||
| 382 | if save_start | ||
| 383 | if save_idxs === nothing | ||
| 384 | ures = Vector{uType}() | ||
| 385 | dures = Vector{uType}() | ||
| 386 | save_value!(ures, u0, uType, save_idxs) | ||
| 387 | if dense | ||
| 388 | f!(out, u0, prob.p, tspan[1]) | ||
| 389 | save_value!(dures, out, uType, save_idxs) | ||
| 390 | end | ||
| 391 | else | ||
| 392 | ures = [u0[save_idxs]] | ||
| 393 | if dense | ||
| 394 | f!(out, u0, prob.p, tspan[1]) | ||
| 395 | dures = [out[save_idxs]] | ||
| 396 | end | ||
| 397 | end | ||
| 398 | else | ||
| 399 | ures = Vector{uType}() | ||
| 400 | dures = Vector{uType}() | ||
| 401 | end | ||
| 402 | |||
| 403 | sol = DiffEqBase.build_solution(prob, | ||
| 404 | alg, | ||
| 405 | ts, | ||
| 406 | ures; | ||
| 407 | dense = dense, | ||
| 408 | interp = dense ? | ||
| 409 | DiffEqBase.HermiteInterpolation(ts, ures, | ||
| 410 | dures) : | ||
| 411 | DiffEqBase.LinearInterpolation(ts, ures), | ||
| 412 | timeseries_errors = timeseries_errors, | ||
| 413 | stats = SciMLBase.DEStats(0), | ||
| 414 | calculate_error = false) | ||
| 415 | opts = DEOptions(saveat_internal, | ||
| 416 | tstops_internal, | ||
| 417 | saveat, tstops, save_start, | ||
| 418 | save_everystep, save_idxs, | ||
| 419 | dense, | ||
| 420 | timeseries_errors, | ||
| 421 | dense_errors, | ||
| 422 | save_on, | ||
| 423 | save_end, | ||
| 424 | callbacks_internal, | ||
| 425 | abstol, | ||
| 426 | reltol, | ||
| 427 | verbose, | ||
| 428 | advance_to_tstop, | ||
| 429 | stop_at_next_tstop, | ||
| 430 | progress, | ||
| 431 | progress_steps, | ||
| 432 | progress_name, | ||
| 433 | progress_message, | ||
| 434 | progress_id, | ||
| 435 | maxiters) | ||
| 436 | integrator = CVODEIntegrator(u0, | ||
| 437 | utmp, | ||
| 438 | prob.p, | ||
| 439 | t0, | ||
| 440 | t0, | ||
| 441 | mem, | ||
| 442 | _LS, | ||
| 443 | _A, | ||
| 444 | sol, | ||
| 445 | alg, | ||
| 446 | f!, | ||
| 447 | userfun, | ||
| 448 | jac, | ||
| 449 | opts, | ||
| 450 | tout, | ||
| 451 | tdir, | ||
| 452 | false, | ||
| 453 | tmp, | ||
| 454 | uprev, | ||
| 455 | Cint(flag), | ||
| 456 | false, | ||
| 457 | 0, | ||
| 458 | 1, | ||
| 459 | callback_cache, | ||
| 460 | 0.0) | ||
| 461 | initialize_callbacks!(integrator) | ||
| 462 | integrator | ||
| 463 | end # function solve | ||
| 464 | |||
| 465 | function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, isinplace}, | ||
| 466 | alg::ARKODE{Method, LinearSolver, MassLinearSolver}, | ||
| 467 | timeseries = [], | ||
| 468 | ts = [], | ||
| 469 | ks = []; | ||
| 470 | verbose = true, | ||
| 471 | callback = nothing, | ||
| 472 | abstol = 1 / 10^6, | ||
| 473 | reltol = 1 / 10^3, | ||
| 474 | saveat = Float64[], | ||
| 475 | tstops = Float64[], | ||
| 476 | d_discontinuities = Float64[], | ||
| 477 | maxiters = Int(1e5), | ||
| 478 | dt = nothing, | ||
| 479 | dtmin = 0.0, | ||
| 480 | dtmax = 0.0, | ||
| 481 | timeseries_errors = true, | ||
| 482 | dense_errors = false, | ||
| 483 | save_everystep = isempty(saveat), save_idxs = nothing, | ||
| 484 | dense = save_everystep, | ||
| 485 | save_on = true, | ||
| 486 | save_start = true, | ||
| 487 | save_end = true, | ||
| 488 | save_timeseries = nothing, | ||
| 489 | progress = false, | ||
| 490 | progress_steps = 1000, | ||
| 491 | progress_name = "ODE", | ||
| 492 | progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE, | ||
| 493 | progress_id = gensym("Sundials"), | ||
| 494 | advance_to_tstop = false, | ||
| 495 | stop_at_next_tstop = false, | ||
| 496 | userdata = nothing, | ||
| 497 | alias_u0 = false, | ||
| 498 | kwargs...) where {uType, tupType, isinplace, Method, | ||
| 499 | LinearSolver, | ||
| 500 | MassLinearSolver} | ||
| 501 | tType = eltype(tupType) | ||
| 502 | |||
| 503 | if verbose | ||
| 504 | warned = !isempty(kwargs) && DiffEqBase.check_keywords(alg, kwargs, warnlist) | ||
| 505 | warned && DiffEqBase.warn_compat() | ||
| 506 | end | ||
| 507 | |||
| 508 | if reltol isa AbstractArray | ||
| 509 | error("Sundials only allows scalar reltol.") | ||
| 510 | end | ||
| 511 | |||
| 512 | if length(prob.u0) <= 0 | ||
| 513 | error("Sundials requires at least one state variable.") | ||
| 514 | end | ||
| 515 | |||
| 516 | progress && Logging.@logmsg(Logging.LogLevel(-1), progress_name, _id=progress_id, progress=0) | ||
| 517 | |||
| 518 | tstops = vcat(tstops, d_discontinuities) | ||
| 519 | callbacks_internal = DiffEqBase.CallbackSet(callback) | ||
| 520 | |||
| 521 | max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal) | ||
| 522 | if max_len_cb isa VectorContinuousCallback | ||
| 523 | callback_cache = DiffEqBase.CallbackCache(max_len_cb.len, Float64, Float64) | ||
| 524 | else | ||
| 525 | callback_cache = nothing | ||
| 526 | end | ||
| 527 | |||
| 528 | tspan = prob.tspan | ||
| 529 | t0 = tspan[1] | ||
| 530 | |||
| 531 | tdir = sign(tspan[2] - tspan[1]) | ||
| 532 | |||
| 533 | tstops_internal, saveat_internal = tstop_saveat_disc_handling(tstops, saveat, tdir, | ||
| 534 | tspan, tType) | ||
| 535 | |||
| 536 | if prob.u0 isa Number | ||
| 537 | u0 = [prob.u0] | ||
| 538 | else | ||
| 539 | if alias_u0 | ||
| 540 | u0 = prob.u0 | ||
| 541 | else | ||
| 542 | u0 = copy(prob.u0) | ||
| 543 | end | ||
| 544 | end | ||
| 545 | |||
| 546 | save_start ? ts = [t0] : ts = Float64[] | ||
| 547 | out = copy(u0) | ||
| 548 | uvec = vec(u0) | ||
| 549 | utmp = NVector(uvec) | ||
| 550 | |||
| 551 | function arkodemem(; fe = C_NULL, fi = C_NULL, t0 = t0, u0 = utmp) | ||
| 552 | mem_ptr = ARKStepCreate(fe, fi, t0, u0) | ||
| 553 | (mem_ptr == C_NULL) && error("Failed to allocate ARKODE solver object") | ||
| 554 | mem = Handle(mem_ptr) | ||
| 555 | |||
| 556 | !verbose && ARKStepSetErrHandlerFn(mem, | ||
| 557 | @cfunction(null_error_handler, Nothing, | ||
| 558 | (Cint, Char, Char, Ptr{Cvoid})), | ||
| 559 | C_NULL) | ||
| 560 | return mem | ||
| 561 | end | ||
| 562 | |||
| 563 | ### Fix the more general function to Sundials allowed style | ||
| 564 | if !isinplace && prob.u0 isa Number | ||
| 565 | f! = (du, u, p, t) -> (du .= prob.f(first(u), p, t); Cint(0)) | ||
| 566 | elseif !isinplace | ||
| 567 | f! = (du, u, p, t) -> (du .= prob.f(u, p, t); Cint(0)) | ||
| 568 | else # Then it's an in-place function on an abstract array | ||
| 569 | f! = prob.f | ||
| 570 | end | ||
| 571 | |||
| 572 | if prob.problem_type isa SplitODEProblem | ||
| 573 | |||
| 574 | ### Fix the more general function to Sundials allowed style | ||
| 575 | if !isinplace && prob.u0 isa Number | ||
| 576 | f1! = (du, u, p, t) -> (du .= prob.f.f1(first(u), p, t); Cint(0)) | ||
| 577 | f2! = (du, u, p, t) -> (du .= prob.f.f2(first(u), p, t); Cint(0)) | ||
| 578 | elseif !isinplace | ||
| 579 | f1! = (du, u, p, t) -> (du .= prob.f.f1(u, p, t); Cint(0)) | ||
| 580 | f2! = (du, u, p, t) -> (du .= prob.f.f2(u, p, t); Cint(0)) | ||
| 581 | else # Then it's an in-place function on an abstract array | ||
| 582 | f1! = prob.f.f1 | ||
| 583 | f2! = prob.f.f2 | ||
| 584 | end | ||
| 585 | |||
| 586 | use_jac_prototype = (isa(prob.f.f1.jac_prototype, SparseArrays.SparseMatrixCSC) && | ||
| 587 | LinearSolver ∈ SPARSE_SOLVERS) | ||
| 588 | userfun = FunJac(f1!, | ||
| 589 | f2!, | ||
| 590 | prob.f.f1.jac, | ||
| 591 | prob.p, | ||
| 592 | prob.f.mass_matrix, | ||
| 593 | use_jac_prototype ? prob.f.f1.jac_prototype : nothing, | ||
| 594 | alg.prec, | ||
| 595 | alg.psetup, | ||
| 596 | u0, | ||
| 597 | out, | ||
| 598 | nothing) | ||
| 599 | |||
| 600 | function getcfunjac(::T) where {T} | ||
| 601 | @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T})) | ||
| 602 | end | ||
| 603 | function getcfunjac2(::T) where {T} | ||
| 604 | @cfunction(cvodefunjac2, Cint, (realtype, N_Vector, N_Vector, Ref{T})) | ||
| 605 | end | ||
| 606 | cfj1 = getcfunjac(userfun) | ||
| 607 | cfj2 = getcfunjac2(userfun) | ||
| 608 | |||
| 609 | mem = arkodemem(; fi = cfj1, fe = cfj2) | ||
| 610 | else | ||
| 611 | use_jac_prototype = (isa(prob.f.jac_prototype, SparseArrays.SparseMatrixCSC) && | ||
| 612 | LinearSolver ∈ SPARSE_SOLVERS) | ||
| 613 | userfun = FunJac(f!, | ||
| 614 | prob.f.jac, | ||
| 615 | prob.p, | ||
| 616 | prob.f.mass_matrix, | ||
| 617 | use_jac_prototype ? prob.f.jac_prototype : nothing, | ||
| 618 | alg.prec, | ||
| 619 | alg.psetup, | ||
| 620 | u0, | ||
| 621 | out) | ||
| 622 | if alg.stiffness == Explicit() | ||
| 623 | function getcfun1(::T) where {T} | ||
| 624 | @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T})) | ||
| 625 | end | ||
| 626 | cfj1 = getcfun1(userfun) | ||
| 627 | mem = arkodemem(; fe = cfj1) | ||
| 628 | elseif alg.stiffness == Implicit() | ||
| 629 | function getcfun2(::T) where {T} | ||
| 630 | @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T})) | ||
| 631 | end | ||
| 632 | cfj2 = getcfun2(userfun) | ||
| 633 | mem = arkodemem(; fi = cfj2) | ||
| 634 | end | ||
| 635 | end | ||
| 636 | |||
| 637 | dt !== nothing && (flag = ARKStepSetInitStep(mem, Float64(dt))) | ||
| 638 | flag = ARKStepSetMinStep(mem, Float64(dtmin)) | ||
| 639 | flag = ARKStepSetMaxStep(mem, Float64(dtmax)) | ||
| 640 | flag = ARKStepSetUserData(mem, userfun) | ||
| 641 | if abstol isa Array | ||
| 642 | flag = ARKStepSVtolerances(mem, reltol, abstol) | ||
| 643 | else | ||
| 644 | flag = ARKStepSStolerances(mem, reltol, abstol) | ||
| 645 | end | ||
| 646 | flag = ARKStepSetMaxNumSteps(mem, maxiters) | ||
| 647 | flag = ARKStepSetMaxHnilWarns(mem, alg.max_hnil_warns) | ||
| 648 | flag = ARKStepSetMaxErrTestFails(mem, alg.max_error_test_failures) | ||
| 649 | flag = ARKStepSetMaxConvFails(mem, alg.max_convergence_failures) | ||
| 650 | flag = ARKStepSetPredictorMethod(mem, alg.predictor_method) | ||
| 651 | flag = ARKStepSetNonlinConvCoef(mem, alg.nonlinear_convergence_coefficient) | ||
| 652 | flag = ARKStepSetDenseOrder(mem, alg.dense_order) | ||
| 653 | |||
| 654 | #= | ||
| 655 | Reference from Manual on ARKODE | ||
| 656 | To choose an explicit table, set itable to a negative value. This automatically calls ARKStepSetExplicit(). However, if the problem is posed in explicit form, i.e. 𝑦 ̇ = 𝑓 (𝑡, 𝑦), then we recommend that the ERKStep time- stepper module be used instead of ARKStep. | ||
| 657 | To select an implicit table, set etable to a negative value. This automatically calls ARKStepSetImplicit(). If both itable and etable are non-negative, then these should match an existing implicit/explicit pair, listed in the section Additive Butcher tables. This automatically calls ARKStepSetImEx(). | ||
| 658 | =# | ||
| 659 | if alg.itable === nothing && alg.etable === nothing | ||
| 660 | flag = ARKStepSetOrder(mem, alg.order) | ||
| 661 | elseif alg.itable === nothing && alg.etable !== nothing | ||
| 662 | flag = ARKStepSetTableNum(mem, -1, alg.etable) | ||
| 663 | elseif alg.itable !== nothing && alg.etable === nothing | ||
| 664 | flag = ARKStepSetTableNum(mem, alg.itable, -1) | ||
| 665 | else | ||
| 666 | flag = ARKStepSetTableNum(mem, alg.itable, alg.etable) | ||
| 667 | end | ||
| 668 | |||
| 669 | flag = ARKStepSetNonlinCRDown(mem, alg.crdown) | ||
| 670 | flag = ARKStepSetNonlinRDiv(mem, alg.rdiv) | ||
| 671 | flag = ARKStepSetDeltaGammaMax(mem, alg.dgmax) | ||
| 672 | flag = ARKStepSetMaxStepsBetweenLSet(mem, alg.msbp) | ||
| 673 | #flag = ARKStepSetAdaptivityMethod(mem,alg.adaptivity_method,1,0) | ||
| 674 | |||
| 675 | #flag = ARKStepSetFixedStep(mem,) | ||
| 676 | alg.set_optimal_params && (flag = ARKStepSetOptimalParams(mem)) | ||
| 677 | |||
| 678 | if Method == :Newton && alg.stiffness !== Explicit() # Only use a linear solver if it's a Newton-based method | ||
| 679 | if LinearSolver in (:Dense, :LapackDense) | ||
| 680 | nojacobian = false | ||
| 681 | A = SUNDenseMatrix(length(uvec), length(uvec)) | ||
| 682 | _A = MatrixHandle(A, DenseMatrix()) | ||
| 683 | if LinearSolver === :Dense | ||
| 684 | LS = SUNLinSol_Dense(uvec, A) | ||
| 685 | _LS = LinSolHandle(LS, Dense()) | ||
| 686 | else | ||
| 687 | LS = SUNLinSol_LapackDense(uvec, A) | ||
| 688 | _LS = LinSolHandle(LS, LapackDense()) | ||
| 689 | end | ||
| 690 | elseif LinearSolver in (:Band, :LapackBand) | ||
| 691 | nojacobian = false | ||
| 692 | A = SUNBandMatrix(length(uvec), alg.jac_upper, alg.jac_lower) | ||
| 693 | _A = MatrixHandle(A, BandMatrix()) | ||
| 694 | if LinearSolver === :Band | ||
| 695 | LS = SUNLinSol_Band(uvec, A) | ||
| 696 | _LS = LinSolHandle(LS, Band()) | ||
| 697 | else | ||
| 698 | LS = SUNLinSol_LapackBand(uvec, A) | ||
| 699 | _LS = LinSolHandle(LS, LapackBand()) | ||
| 700 | end | ||
| 701 | elseif LinearSolver == :GMRES | ||
| 702 | LS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.krylov_dim) | ||
| 703 | _A = nothing | ||
| 704 | _LS = Sundials.LinSolHandle(LS, Sundials.SPGMR()) | ||
| 705 | elseif LinearSolver == :FGMRES | ||
| 706 | LS = SUNLinSol_SPFGMR(uvec, alg.prec_side, alg.krylov_dim) | ||
| 707 | _A = nothing | ||
| 708 | _LS = LinSolHandle(LS, SPFGMR()) | ||
| 709 | elseif LinearSolver == :BCG | ||
| 710 | LS = SUNLinSol_SPBCGS(uvec, alg.prec_side, alg.krylov_dim) | ||
| 711 | _A = nothing | ||
| 712 | _LS = LinSolHandle(LS, SPBCGS()) | ||
| 713 | elseif LinearSolver == :PCG | ||
| 714 | LS = SUNLinSol_PCG(uvec, alg.prec_side, alg.krylov_dim) | ||
| 715 | _A = nothing | ||
| 716 | _LS = LinSolHandle(LS, PCG()) | ||
| 717 | elseif LinearSolver == :TFQMR | ||
| 718 | LS = SUNLinSol_SPTFQMR(uvec, alg.prec_side, alg.krylov_dim) | ||
| 719 | _A = nothing | ||
| 720 | _LS = LinSolHandle(LS, PTFQMR()) | ||
| 721 | elseif LinearSolver == :KLU | ||
| 722 | nnz = length(SparseArrays.nonzeros(prob.f.jac_prototype)) | ||
| 723 | A = SUNSparseMatrix(length(uvec), length(uvec), nnz, CSC_MAT) | ||
| 724 | LS = SUNLinSol_KLU(uvec, A) | ||
| 725 | _A = MatrixHandle(A, SparseMatrix()) | ||
| 726 | _LS = LinSolHandle(LS, KLU()) | ||
| 727 | end | ||
| 728 | flag = ARKStepSetLinearSolver(mem, LS, _A === nothing ? C_NULL : A) | ||
| 729 | flag = ARKStepSetMaxNonlinIters(mem, alg.max_nonlinear_iters) | ||
| 730 | elseif Method == :Functional && alg.stiffness !== Explicit() | ||
| 731 | ARKStepSetFixedPoint(mem, Clong(alg.krylov_dim)) | ||
| 732 | else | ||
| 733 | _A = nothing | ||
| 734 | _LS = nothing | ||
| 735 | end | ||
| 736 | |||
| 737 | if (prob.problem_type isa SplitODEProblem && | ||
| 738 |          prob.f.f1.jac_prototype isa AbstractSciMLOperator) \|\| | ||
| 739 | (!(prob.problem_type isa SplitODEProblem) && | ||
| 740 | prob.f.jac_prototype isa AbstractSciMLOperator) && | ||
| 741 | alg.stiffness !== Explicit() | ||
| 742 | function getcfunjtimes(::T) where {T} | ||
| 743 | @cfunction(jactimes, | ||
| 744 | Cint, | ||
| 745 | (N_Vector, N_Vector, realtype, N_Vector, N_Vector, Ref{T}, N_Vector)) | ||
| 746 | end | ||
| 747 | jtimes = getcfunjtimes(userfun) | ||
| 748 | ARKStepSetJacTimes(mem, C_NULL, jtimes) | ||
| 749 | end | ||
| 750 | |||
| 751 | if prob.f.mass_matrix != LinearAlgebra.I && alg.stiffness !== Explicit() | ||
| 752 | if MassLinearSolver in (:Dense, :LapackDense) | ||
| 753 | nojacobian = false | ||
| 754 | M = SUNDenseMatrix(length(uvec), length(uvec)) | ||
| 755 | _M = MatrixHandle(M, DenseMatrix()) | ||
| 756 | if MassLinearSolver === :Dense | ||
| 757 | MLS = SUNLinSol_Dense(uvec, M) | ||
| 758 | _MLS = LinSolHandle(MLS, Dense()) | ||
| 759 | else | ||
| 760 | MLS = SUNLinSol_LapackDense(uvec, M) | ||
| 761 | _MLS = LinSolHandle(MLS, LapackDense()) | ||
| 762 | end | ||
| 763 | elseif MassLinearSolver in (:Band, :LapackBand) | ||
| 764 | nojacobian = false | ||
| 765 | M = SUNBandMatrix(length(uvec), alg.jac_upper, alg.jac_lower) | ||
| 766 | _M = MatrixHandle(M, BandMatrix()) | ||
| 767 | if MassLinearSolver === :Band | ||
| 768 | MLS = SUNLinSol_Band(uvec, M) | ||
| 769 | _MLS = LinSolHandle(MLS, Band()) | ||
| 770 | else | ||
| 771 | MLS = SUNLinSol_LapackBand(uvec, M) | ||
| 772 | _MLS = LinSolHandle(MLS, LapackBand()) | ||
| 773 | end | ||
| 774 | elseif MassLinearSolver == :GMRES | ||
| 775 | MLS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.mass_krylov_dim) | ||
| 776 | _M = nothing | ||
| 777 | _MLS = LinSolHandle(MLS, SPGMR()) | ||
| 778 | elseif MassLinearSolver == :FGMRES | ||
| 779 | MLS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.mass_krylov_dim) | ||
| 780 | _M = nothing | ||
| 781 | _MLS = LinSolHandle(MLS, SPFGMR()) | ||
| 782 | elseif MassLinearSolver == :BCG | ||
| 783 | MLS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.mass_krylov_dim) | ||
| 784 | _M = nothing | ||
| 785 | _MLS = LinSolHandle(MLS, SPBCGS()) | ||
| 786 | elseif MassLinearSolver == :PCG | ||
| 787 | MLS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.mass_krylov_dim) | ||
| 788 | _M = nothing | ||
| 789 | _MLS = LinSolHandle(MLS, PCG()) | ||
| 790 | elseif MassLinearSolver == :TFQMR | ||
| 791 | MLS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.mass_krylov_dim) | ||
| 792 | _M = nothing | ||
| 793 | _MLS = LinSolHandle(MLS, PTFQMR()) | ||
| 794 | elseif MassLinearSolver == :KLU | ||
| 795 | nnz = length(SparseArrays.nonzeros(prob.f.mass_matrix)) | ||
| 796 | M = SUNSparseMatrix(length(uvec), length(uvec), nnz, CSC_MAT) | ||
| 797 | MLS = SUNLinSol_KLU(uvec, M) | ||
| 798 | _M = MatrixHandle(M, SparseMatrix()) | ||
| 799 | _MLS = LinSolHandle(MLS, KLU()) | ||
| 800 | end | ||
| 801 | flag = ARKStepSetMassLinearSolver(mem, MLS, _M === nothing ? C_NULL : M, false) | ||
| 802 | function getmatfun(::T) where {T} | ||
| 803 | @cfunction(massmat, | ||
| 804 | Cint, | ||
| 805 | (realtype, SUNMatrix, Ref{T}, N_Vector, N_Vector, N_Vector)) | ||
| 806 | end | ||
| 807 | matfun = getmatfun(userfun) | ||
| 808 | ARKStepSetMassFn(mem, matfun) | ||
| 809 | else | ||
| 810 | _M = nothing | ||
| 811 | _MLS = nothing | ||
| 812 | end | ||
| 813 | |||
| 814 | if DiffEqBase.has_jac(prob.f) && alg.stiffness !== Explicit() | ||
| 815 | function getfunjac(::T) where {T} | ||
| 816 | @cfunction(cvodejac, | ||
| 817 | Cint, | ||
| 818 | (realtype, | ||
| 819 | N_Vector, | ||
| 820 | N_Vector, | ||
| 821 | SUNMatrix, | ||
| 822 | Ref{T}, | ||
| 823 | N_Vector, | ||
| 824 | N_Vector, | ||
| 825 | N_Vector)) | ||
| 826 | end | ||
| 827 | jac = getfunjac(userfun) | ||
| 828 | flag = ARKStepSetUserData(mem, userfun) | ||
| 829 | flag = ARKStepSetJacFn(mem, jac) | ||
| 830 | else | ||
| 831 | jac = nothing | ||
| 832 | end | ||
| 833 | |||
| 834 | if alg.prec !== nothing && alg.stiffness !== Explicit() | ||
| 835 | function getpercfun(::T) where {T} | ||
| 836 | @cfunction(precsolve, | ||
| 837 | Cint, | ||
| 838 | (Float64, | ||
| 839 | N_Vector, | ||
| 840 | N_Vector, | ||
| 841 | N_Vector, | ||
| 842 | N_Vector, | ||
| 843 | Float64, | ||
| 844 | Float64, | ||
| 845 | Int, | ||
| 846 | Ref{T})) | ||
| 847 | end | ||
| 848 | precfun = getpercfun(userfun) | ||
| 849 | |||
| 850 | function getpsetupfun(::T) where {T} | ||
| 851 | @cfunction(precsetup, | ||
| 852 | Cint, | ||
| 853 | (Float64, N_Vector, N_Vector, Int, Ptr{Int}, Float64, Ref{T})) | ||
| 854 | end | ||
| 855 | psetupfun = alg.psetup === nothing ? C_NULL : getpsetupfun(userfun) | ||
| 856 | |||
| 857 | ARKStepSetPreconditioner(mem, psetupfun, precfun) | ||
| 858 | end | ||
| 859 | |||
| 860 | tmp = isnothing(callbacks_internal) ? u0 : similar(u0) | ||
| 861 | uprev = isnothing(callbacks_internal) ? u0 : similar(u0) | ||
| 862 | tout = [tspan[1]] | ||
| 863 | |||
| 864 | if save_start | ||
| 865 | if save_idxs === nothing | ||
| 866 | ures = Vector{uType}() | ||
| 867 | dures = Vector{uType}() | ||
| 868 | save_value!(ures, u0, uType, save_idxs) | ||
| 869 | if dense | ||
| 870 | f!(out, u0, prob.p, tspan[1]) | ||
| 871 | save_value!(dures, out, uType, save_idxs) | ||
| 872 | end | ||
| 873 | else | ||
| 874 | ures = [u0[save_idxs]] | ||
| 875 | if dense | ||
| 876 | f!(out, u0, prob.p, tspan[1]) | ||
| 877 | dures = [out[save_idxs]] | ||
| 878 | end | ||
| 879 | end | ||
| 880 | else | ||
| 881 | ures = Vector{uType}() | ||
| 882 | dures = Vector{uType}() | ||
| 883 | end | ||
| 884 | |||
| 885 | sol = DiffEqBase.build_solution(prob, | ||
| 886 | alg, | ||
| 887 | ts, | ||
| 888 | ures; | ||
| 889 | dense = dense, | ||
| 890 | interp = dense ? | ||
| 891 | DiffEqBase.HermiteInterpolation(ts, ures, | ||
| 892 | dures) : | ||
| 893 | DiffEqBase.LinearInterpolation(ts, ures), | ||
| 894 | timeseries_errors = timeseries_errors, | ||
| 895 | stats = SciMLBase.DEStats(0), | ||
| 896 | calculate_error = false) | ||
| 897 | opts = DEOptions(saveat_internal, | ||
| 898 | tstops_internal, | ||
| 899 | saveat, tstops, save_start, | ||
| 900 | save_everystep, save_idxs, | ||
| 901 | dense, | ||
| 902 | timeseries_errors, | ||
| 903 | dense_errors, | ||
| 904 | save_on, | ||
| 905 | save_end, | ||
| 906 | callbacks_internal, | ||
| 907 | abstol, | ||
| 908 | reltol, | ||
| 909 | verbose, | ||
| 910 | advance_to_tstop, | ||
| 911 | stop_at_next_tstop, | ||
| 912 | progress, | ||
| 913 | progress_steps, | ||
| 914 | progress_name, | ||
| 915 | progress_message, | ||
| 916 | progress_id, | ||
| 917 | maxiters) | ||
| 918 | integrator = ARKODEIntegrator(u0, | ||
| 919 | utmp, | ||
| 920 | prob.p, | ||
| 921 | t0, | ||
| 922 | t0, | ||
| 923 | mem, | ||
| 924 | _LS, | ||
| 925 | _A, | ||
| 926 | _MLS, | ||
| 927 | _M, | ||
| 928 | sol, | ||
| 929 | alg, | ||
| 930 | f!, | ||
| 931 | userfun, | ||
| 932 | jac, | ||
| 933 | opts, | ||
| 934 | tout, | ||
| 935 | tdir, | ||
| 936 | false, | ||
| 937 | tmp, | ||
| 938 | uprev, | ||
| 939 | Cint(flag), | ||
| 940 | false, | ||
| 941 | 0, | ||
| 942 | 1, | ||
| 943 | callback_cache, | ||
| 944 | 0.0) | ||
| 945 | |||
| 946 | initialize_callbacks!(integrator) | ||
| 947 | integrator | ||
| 948 | end # function solve | ||
| 949 | |||
| 950 | function tstop_saveat_disc_handling(tstops, saveat, tdir, tspan, tType) | ||
| 951 | tstops_internal = DataStructures.BinaryHeap{tType}(DataStructures.FasterForward()) | ||
| 952 | saveat_internal = DataStructures.BinaryHeap{tType}(DataStructures.FasterForward()) | ||
| 953 | |||
| 954 | t0, tf = tspan | ||
| 955 | tdir_t0 = tdir * t0 | ||
| 956 | tdir_tf = tdir * tf | ||
| 957 | |||
| 958 | for t in tstops | ||
| 959 | tdir_t = tdir * t | ||
| 960 | tdir_t0 < tdir_t ≤ tdir_tf && push!(tstops_internal, tdir_t) | ||
| 961 | end | ||
| 962 | push!(tstops_internal, tdir_tf) | ||
| 963 | |||
| 964 | if saveat isa Number | ||
| 965 | saveat = (t0:tdir*abs(saveat):tf)[2:end] | ||
| 966 | end | ||
| 967 | for t in saveat | ||
| 968 | tdir_t = tdir * t | ||
| 969 | tdir_t0 < tdir_t ≤ tdir_tf && push!(saveat_internal, tdir_t) | ||
| 970 | end | ||
| 971 | |||
| 972 | tstops_internal, saveat_internal | ||
| 973 | end | ||
| 974 | |||
| 975 | ## Solve for DAEs uses IDA | ||
| 976 | |||
| 977 | function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tupType, | ||
| 978 | isinplace}, | ||
| 979 | alg::SundialsDAEAlgorithm{LinearSolver}, | ||
| 980 | timeseries = [], | ||
| 981 | ts = [], | ||
| 982 | ks = []; | ||
| 983 | verbose = true, | ||
| 984 | dt = nothing, | ||
| 985 | dtmax = 0.0, | ||
| 986 | save_on = true, | ||
| 987 | save_start = true, | ||
| 988 | callback = nothing, | ||
| 989 | abstol = 1 / 10^6, | ||
| 990 | reltol = 1 / 10^3, | ||
| 991 | saveat = Float64[], | ||
| 992 | tstops = Float64[], | ||
| 993 | d_discontinuities = Float64[], | ||
| 994 | maxiters = Int(1e5), | ||
| 995 | timeseries_errors = true, | ||
| 996 | dense_errors = false, | ||
| 997 | save_everystep = isempty(saveat), save_idxs = nothing, | ||
| 998 | dense = save_everystep, | ||
| 999 | save_timeseries = nothing, | ||
| 1000 | save_end = true, | ||
| 1001 | progress = false, | ||
| 1002 | progress_steps = 1000, | ||
| 1003 | progress_name = "DAE IDA", | ||
| 1004 | progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE, | ||
| 1005 | progress_id = gensym("Sundials"), | ||
| 1006 | advance_to_tstop = false, | ||
| 1007 | stop_at_next_tstop = false, | ||
| 1008 | userdata = nothing, | ||
| 1009 | initializealg = IDADefaultInit(), | ||
| 1010 | kwargs...) where {uType, duType, tupType, isinplace, LinearSolver | ||
| 1011 | } | ||
| 1012 | tType = eltype(tupType) | ||
| 1013 | |||
| 1014 | if verbose | ||
| 1015 | warned = !isempty(kwargs) && DiffEqBase.check_keywords(alg, kwargs, warnida) | ||
| 1016 | warned && DiffEqBase.warn_compat() | ||
| 1017 | end | ||
| 1018 | |||
| 1019 | if reltol isa AbstractArray | ||
| 1020 | error("Sundials only allows scalar reltol.") | ||
| 1021 | end | ||
| 1022 | |||
| 1023 | if length(prob.u0) == 0 | ||
| 1024 | error("Sundials requires at least one state variable.") | ||
| 1025 | end | ||
| 1026 | |||
| 1027 | progress && Logging.@logmsg(Logging.LogLevel(-1), progress_name, _id=progress_id, progress=0) | ||
| 1028 | |||
| 1029 | tstops = vcat(tstops, d_discontinuities) | ||
| 1030 | callbacks_internal = DiffEqBase.CallbackSet(callback) | ||
| 1031 | |||
| 1032 | max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal) | ||
| 1033 | if max_len_cb isa VectorContinuousCallback | ||
| 1034 | callback_cache = DiffEqBase.CallbackCache(max_len_cb.len, Float64, Float64) | ||
| 1035 | else | ||
| 1036 | callback_cache = nothing | ||
| 1037 | end | ||
| 1038 | |||
| 1039 | tspan = prob.tspan | ||
| 1040 | t0 = tspan[1] | ||
| 1041 | |||
| 1042 | tdir = sign(tspan[2] - tspan[1]) | ||
| 1043 | |||
| 1044 | tstops_internal, saveat_internal = tstop_saveat_disc_handling(tstops, saveat, tdir, | ||
| 1045 | tspan, tType) | ||
| 1046 | @assert size(prob.u0) == size(prob.du0) | ||
| 1047 | if prob.u0 isa Number | ||
| 1048 | u0 = [prob.u0] | ||
| 1049 | du0 = [prob.du0] | ||
| 1050 | else | ||
| 1051 | u0 = copy(prob.u0) | ||
| 1052 | du0 = copy(prob.du0) | ||
| 1053 | end | ||
| 1054 | |||
| 1055 | ### Fix the more general function to Sundials allowed style | ||
| 1056 | if !isinplace && prob.u0 isa Number | ||
| 1057 | f! = (out, du, u, p, t) -> (out .= prob.f(first(du), first(u), p, t); Cint(0)) | ||
| 1058 | elseif !isinplace | ||
| 1059 | f! = (out, du, u, p, t) -> (out .= prob.f(du, u, p, t); Cint(0)) | ||
| 1060 | else # Then it's an in-place function on an abstract array | ||
| 1061 | f! = prob.f | ||
| 1062 | end | ||
| 1063 | |||
| 1064 | mem_ptr = IDACreate() | ||
| 1065 | (mem_ptr == C_NULL) && error("Failed to allocate IDA solver object") | ||
| 1066 | mem = Handle(mem_ptr) | ||
| 1067 | |||
| 1068 | !verbose && IDASetErrHandlerFn(mem, | ||
| 1069 | @cfunction(null_error_handler, Nothing, | ||
| 1070 | (Cint, Char, Char, Ptr{Cvoid})), | ||
| 1071 | C_NULL) | ||
| 1072 | |||
| 1073 | ts = [t0] | ||
| 1074 | |||
| 1075 | # vec shares memory | ||
| 1076 | utmp = NVector(vec(u0)) | ||
| 1077 | dutmp = NVector(vec(du0)) | ||
| 1078 | rtest = zeros(size(u0)) | ||
| 1079 | |||
| 1080 | use_jac_prototype = (isa(prob.f.jac_prototype, SparseArrays.SparseMatrixCSC) && | ||
| 1081 | LinearSolver ∈ SPARSE_SOLVERS) | ||
| 1082 | userfun = FunJac(f!, | ||
| 1083 | prob.f.jac, | ||
| 1084 | prob.p, | ||
| 1085 | nothing, | ||
| 1086 | use_jac_prototype ? prob.f.jac_prototype : nothing, | ||
| 1087 | alg.prec, | ||
| 1088 | alg.psetup, | ||
| 1089 | u0, | ||
| 1090 | du0, | ||
| 1091 | rtest) | ||
| 1092 | |||
| 1093 | function getcfun(::T) where {T} | ||
| 1094 | @cfunction(idasolfun, Cint, (realtype, N_Vector, N_Vector, N_Vector, Ref{T})) | ||
| 1095 | end | ||
| 1096 | cfun = getcfun(userfun) | ||
| 1097 | flag = IDAInit(mem, cfun, t0, utmp, dutmp) | ||
| 1098 | dt !== nothing && (flag = IDASetInitStep(mem, dt)) | ||
| 1099 | flag = IDASetUserData(mem, userfun) | ||
| 1100 | flag = IDASetMaxStep(mem, dtmax) | ||
| 1101 | if abstol isa Array | ||
| 1102 | flag = IDASVtolerances(mem, reltol, abstol) | ||
| 1103 | else | ||
| 1104 | flag = IDASStolerances(mem, reltol, abstol) | ||
| 1105 | end | ||
| 1106 | flag = IDASetMaxNumSteps(mem, maxiters) | ||
| 1107 | flag = IDASetMaxOrd(mem, alg.max_order) | ||
| 1108 | flag = IDASetMaxErrTestFails(mem, alg.max_error_test_failures) | ||
| 1109 | flag = IDASetNonlinConvCoef(mem, alg.nonlinear_convergence_coefficient) | ||
| 1110 | flag = IDASetMaxNonlinIters(mem, alg.max_nonlinear_iters) | ||
| 1111 | flag = IDASetMaxConvFails(mem, alg.max_convergence_failures) | ||
| 1112 | flag = IDASetNonlinConvCoefIC(mem, alg.nonlinear_convergence_coefficient_ic) | ||
| 1113 | flag = IDASetMaxNumStepsIC(mem, alg.max_num_steps_ic) | ||
| 1114 | flag = IDASetMaxNumJacsIC(mem, alg.max_num_jacs_ic) | ||
| 1115 | flag = IDASetMaxNumItersIC(mem, alg.max_num_iters_ic) | ||
| 1116 | #flag = IDASetMaxBacksIC(mem,alg.max_num_backs_ic) # Needs newer version? | ||
| 1117 | flag = IDASetLineSearchOffIC(mem, alg.use_linesearch_ic) | ||
| 1118 | |||
| 1119 | prec_side = isnothing(alg.prec) ? 0 : 1 # IDA only supports left preconditioning (prec_side = 1) | ||
| 1120 | if LinearSolver in (:Dense, :LapackDense) | ||
| 1121 | nojacobian = false | ||
| 1122 | A = SUNDenseMatrix(length(u0), length(u0)) | ||
| 1123 | _A = MatrixHandle(A, DenseMatrix()) | ||
| 1124 | if LinearSolver === :Dense | ||
| 1125 | LS = SUNLinSol_Dense(utmp, A) | ||
| 1126 | _LS = LinSolHandle(LS, Dense()) | ||
| 1127 | else | ||
| 1128 | LS = SUNLinSol_LapackDense(u0, A) | ||
| 1129 | _LS = LinSolHandle(LS, LapackDense()) | ||
| 1130 | end | ||
| 1131 | elseif LinearSolver in (:Band, :LapackBand) | ||
| 1132 | nojacobian = false | ||
| 1133 | A = SUNBandMatrix(length(u0), alg.jac_upper, alg.jac_lower) | ||
| 1134 | _A = MatrixHandle(A, BandMatrix()) | ||
| 1135 | if LinearSolver === :Band | ||
| 1136 | LS = SUNLinSol_Band(utmp, A) | ||
| 1137 | _LS = LinSolHandle(LS, Band()) | ||
| 1138 | else | ||
| 1139 | LS = SUNLinSol_LapackBand(utmp, A) | ||
| 1140 | _LS = LinSolHandle(LS, LapackBand()) | ||
| 1141 | end | ||
| 1142 | elseif LinearSolver == :GMRES | ||
| 1143 | LS = SUNLinSol_SPGMR(utmp, prec_side, alg.krylov_dim) | ||
| 1144 | _A = nothing | ||
| 1145 | _LS = LinSolHandle(LS, SPGMR()) | ||
| 1146 | elseif LinearSolver == :FGMRES | ||
| 1147 | LS = SUNLinSol_SPFGMR(utmp, prec_side, alg.krylov_dim) | ||
| 1148 | _A = nothing | ||
| 1149 | _LS = LinSolHandle(LS, SPFGMR()) | ||
| 1150 | elseif LinearSolver == :BCG | ||
| 1151 | LS = SUNLinSol_SPBCGS(utmp, prec_side, alg.krylov_dim) | ||
| 1152 | _A = nothing | ||
| 1153 | _LS = LinSolHandle(LS, SPBCGS()) | ||
| 1154 | elseif LinearSolver == :PCG | ||
| 1155 | LS = SUNLinSol_PCG(utmp, prec_side, alg.krylov_dim) | ||
| 1156 | _A = nothing | ||
| 1157 | _LS = LinSolHandle(LS, PCG()) | ||
| 1158 | elseif LinearSolver == :TFQMR | ||
| 1159 | LS = SUNLinSol_SPTFQMR(utmp, prec_side, alg.krylov_dim) | ||
| 1160 | _A = nothing | ||
| 1161 | _LS = LinSolHandle(LS, PTFQMR()) | ||
| 1162 | elseif LinearSolver == :KLU | ||
| 1163 | nnz = length(SparseArrays.nonzeros(prob.f.jac_prototype)) | ||
| 1164 | A = SUNSparseMatrix(length(u0), length(u0), nnz, Sundials.CSC_MAT) | ||
| 1165 | LS = SUNLinSol_KLU(utmp, A) | ||
| 1166 | _A = MatrixHandle(A, SparseMatrix()) | ||
| 1167 | _LS = LinSolHandle(LS, KLU()) | ||
| 1168 | end | ||
| 1169 | flag = IDASetLinearSolver(mem, LS, _A === nothing ? C_NULL : A) | ||
| 1170 | |||
| 1171 | if prob.f.jac_prototype isa AbstractSciMLOperator | ||
| 1172 | function getcfunjtimes(::T) where {T} | ||
| 1173 | @cfunction(idajactimes, | ||
| 1174 | Cint, | ||
| 1175 | (realtype, | ||
| 1176 | N_Vector, | ||
| 1177 | N_Vector, | ||
| 1178 | N_Vector, | ||
| 1179 | N_Vector, | ||
| 1180 | N_Vector, | ||
| 1181 | realtype, | ||
| 1182 | Ref{T}, | ||
| 1183 | N_Vector, | ||
| 1184 | N_Vector)) | ||
| 1185 | end | ||
| 1186 | jtimes = getcfunjtimes(userfun) | ||
| 1187 | IDASetJacTimes(mem, C_NULL, jtimes) | ||
| 1188 | end | ||
| 1189 | |||
| 1190 | if alg.prec !== nothing | ||
| 1191 | function getprecfun(::T) where {T} | ||
| 1192 | @cfunction(idaprecsolve, | ||
| 1193 | Cint, | ||
| 1194 | (Float64, | ||
| 1195 | N_Vector, | ||
| 1196 | N_Vector, | ||
| 1197 | N_Vector, | ||
| 1198 | N_Vector, | ||
| 1199 | N_Vector, | ||
| 1200 | Float64, | ||
| 1201 | Float64, | ||
| 1202 | Ref{T})) | ||
| 1203 | end | ||
| 1204 | precfun = getprecfun(userfun) | ||
| 1205 | |||
| 1206 | function getpsetupfun(::T) where {T} | ||
| 1207 | @cfunction(idaprecsetup, | ||
| 1208 | Cint, | ||
| 1209 | (Float64, N_Vector, N_Vector, N_Vector, Float64, Ref{T})) | ||
| 1210 | end | ||
| 1211 | psetupfun = alg.psetup === nothing ? C_NULL : getpsetupfun(userfun) | ||
| 1212 | |||
| 1213 | IDASetPreconditioner(mem, psetupfun, precfun) | ||
| 1214 | end | ||
| 1215 | |||
| 1216 | if DiffEqBase.has_jac(prob.f) | ||
| 1217 | function getcfunjacc(::T) where {T} | ||
| 1218 | @cfunction(idajac, | ||
| 1219 | Cint, | ||
| 1220 | (realtype, | ||
| 1221 | realtype, | ||
| 1222 | N_Vector, | ||
| 1223 | N_Vector, | ||
| 1224 | N_Vector, | ||
| 1225 | SUNMatrix, | ||
| 1226 | Ref{T}, | ||
| 1227 | N_Vector, | ||
| 1228 | N_Vector, | ||
| 1229 | N_Vector)) | ||
| 1230 | end | ||
| 1231 | jac = getcfunjacc(userfun) | ||
| 1232 | flag = IDASetUserData(mem, userfun) | ||
| 1233 | flag = IDASetJacFn(mem, jac) | ||
| 1234 | else | ||
| 1235 | jac = nothing | ||
| 1236 | end | ||
| 1237 | |||
| 1238 | tout = Float64[first(tspan)] | ||
| 1239 | if save_idxs isa Integer | ||
| 1240 | ures = Vector{eltype(uType)}() | ||
| 1241 | dures = Vector{eltype(uType)}() | ||
| 1242 | else | ||
| 1243 | ures = Vector{uType}() | ||
| 1244 | dures = Vector{uType}() | ||
| 1245 | end | ||
| 1246 | tmp = isnothing(callbacks_internal) ? u0 : similar(u0) | ||
| 1247 | uprev = isnothing(callbacks_internal) ? u0 : similar(u0) | ||
| 1248 | retcode = flag >= 0 ? ReturnCode.Default : ReturnCode.InitialFailure | ||
| 1249 | sol = DiffEqBase.build_solution(prob, | ||
| 1250 | alg, | ||
| 1251 | ts, | ||
| 1252 | ures, dense ? dures : nothing; | ||
| 1253 | dense = dense, | ||
| 1254 | calculate_error = false, | ||
| 1255 | timeseries_errors = timeseries_errors, | ||
| 1256 | retcode = retcode, | ||
| 1257 | stats = SciMLBase.DEStats(0), | ||
| 1258 | dense_errors = dense_errors) | ||
| 1259 | |||
| 1260 | opts = DEOptions(saveat_internal, | ||
| 1261 | tstops_internal, | ||
| 1262 | saveat, tstops, save_start, | ||
| 1263 | save_everystep, save_idxs, | ||
| 1264 | dense, | ||
| 1265 | timeseries_errors, | ||
| 1266 | dense_errors, | ||
| 1267 | save_on, | ||
| 1268 | save_end, | ||
| 1269 | callbacks_internal, | ||
| 1270 | abstol, | ||
| 1271 | reltol, | ||
| 1272 | verbose, | ||
| 1273 | advance_to_tstop, | ||
| 1274 | stop_at_next_tstop, | ||
| 1275 | progress, | ||
| 1276 | progress_steps, | ||
| 1277 | progress_name, | ||
| 1278 | progress_message, | ||
| 1279 | progress_id, | ||
| 1280 | maxiters) | ||
| 1281 | |||
| 1282 | integrator = IDAIntegrator(u0, | ||
| 1283 | du0, | ||
| 1284 | prob.p, | ||
| 1285 | t0, | ||
| 1286 | t0, | ||
| 1287 | mem, | ||
| 1288 | _LS, | ||
| 1289 | _A, | ||
| 1290 | sol, | ||
| 1291 | alg, | ||
| 1292 | f!, | ||
| 1293 | userfun, | ||
| 1294 | jac, | ||
| 1295 | opts, | ||
| 1296 | tout, | ||
| 1297 | tdir, | ||
| 1298 | false, | ||
| 1299 | tmp, | ||
| 1300 | uprev, | ||
| 1301 | Cint(flag), | ||
| 1302 | 0, | ||
| 1303 | false, | ||
| 1304 | 0, | ||
| 1305 | 1, | ||
| 1306 | callback_cache, | ||
| 1307 | 0.0, | ||
| 1308 | utmp, | ||
| 1309 | dutmp, | ||
| 1310 | initializealg) | ||
| 1311 | |||
| 1312 | DiffEqBase.initialize_dae!(integrator, initializealg) | ||
| 1313 | integrator.u_modified && IDAReinit!(integrator) | ||
| 1314 | |||
| 1315 | if save_start | ||
| 1316 | save_value!(ures, integrator.u, uType, save_idxs) | ||
| 1317 | save_value!(dures, integrator.du, duType, save_idxs) | ||
| 1318 | end | ||
| 1319 | |||
| 1320 | initialize_callbacks!(integrator) | ||
| 1321 | integrator | ||
| 1322 | end # function solve | ||
| 1323 | |||
| 1324 | ## Common calls | ||
| 1325 | |||
| 1326 | function interpret_sundials_retcode(flag) | ||
| 1327 | flag >= 0 && return ReturnCode.Success | ||
| 1328 | flag == -1 && return ReturnCode.MaxIters | ||
| 1329 | (flag == -2 || flag == -3) && return ReturnCode.Unstable | ||
| 1330 | flag == -4 && return ReturnCode.ConvergenceFailure | ||
| 1331 | return ReturnCode.Failure | ||
| 1332 | end | ||
| 1333 | |||
| 1334 | function solver_step(integrator::CVODEIntegrator, tstop) | ||
| 1335 | integrator.flag = CVode(integrator.mem, tstop, integrator.u_nvec, integrator.tout, | ||
| 1336 | CV_ONE_STEP) | ||
| 1337 | if integrator.opts.progress | ||
| 1338 | Logging.@logmsg(Logging.LogLevel(-1), | ||
| 1339 | integrator.opts.progress_name, | ||
| 1340 | _id=integrator.opts.progress_id, | ||
| 1341 | message=integrator.opts.progress_message(integrator.dt, | ||
| 1342 | integrator.u, | ||
| 1343 | integrator.p, | ||
| 1344 | integrator.t), | ||
| 1345 | progress=integrator.t / integrator.sol.prob.tspan[2]) | ||
| 1346 | end | ||
| 1347 | end | ||
| 1348 | function solver_step(integrator::ARKODEIntegrator, tstop) | ||
| 1349 | integrator.flag = ARKStepEvolve(integrator.mem, tstop, integrator.u_nvec, | ||
| 1350 | integrator.tout, ARK_ONE_STEP) | ||
| 1351 | if integrator.opts.progress | ||
| 1352 | Logging.@logmsg(Logging.LogLevel(-1), | ||
| 1353 | integrator.opts.progress_name, | ||
| 1354 | _id=integrator.opts.progress_id, | ||
| 1355 | message=integrator.opts.progress_message(integrator.dt, | ||
| 1356 | integrator.u_nvec, | ||
| 1357 | integrator.p, | ||
| 1358 | integrator.t), | ||
| 1359 | progress=integrator.t / integrator.sol.prob.tspan[2]) | ||
| 1360 | end | ||
| 1361 | end | ||
| 1362 | function solver_step(integrator::IDAIntegrator, tstop) | ||
| 1363 | 58 (100 %) |
58 (100 %)
samples spent calling
IDASolve
integrator.flag = IDASolve(integrator.mem,
|
|
| 1364 | tstop, | ||
| 1365 | integrator.tout, | ||
| 1366 | integrator.u_nvec, | ||
| 1367 | integrator.du_nvec, | ||
| 1368 | IDA_ONE_STEP) | ||
| 1369 | integrator.iter += 1 | ||
| 1370 | if integrator.opts.progress && integrator.iter % integrator.opts.progress_steps == 0 | ||
| 1371 | Logging.@logmsg(Logging.LogLevel(-1), | ||
| 1372 | integrator.opts.progress_name, | ||
| 1373 | _id=integrator.opts.progress_id, | ||
| 1374 | message=integrator.opts.progress_message(integrator.dt, | ||
| 1375 | integrator.u, | ||
| 1376 | integrator.p, | ||
| 1377 | integrator.t), | ||
| 1378 | progress=integrator.t / integrator.sol.prob.tspan[2]) | ||
| 1379 | end | ||
| 1380 | end | ||
| 1381 | |||
| 1382 | function set_stop_time(integrator::CVODEIntegrator, tstop) | ||
| 1383 | CVodeSetStopTime(integrator.mem, tstop) | ||
| 1384 | end | ||
| 1385 | function set_stop_time(integrator::ARKODEIntegrator, tstop) | ||
| 1386 | ARKStepSetStopTime(integrator.mem, tstop) | ||
| 1387 | end | ||
| 1388 | function set_stop_time(integrator::IDAIntegrator, tstop) | ||
| 1389 | IDASetStopTime(integrator.mem, tstop) | ||
| 1390 | end | ||
| 1391 | |||
| 1392 | function get_iters!(integrator::CVODEIntegrator, iters) | ||
| 1393 | CVodeGetNumSteps(integrator.mem, iters) | ||
| 1394 | end | ||
| 1395 | function get_iters!(integrator::ARKODEIntegrator, iters) | ||
| 1396 | ARKStepGetNumSteps(integrator.mem, iters) | ||
| 1397 | end | ||
| 1398 | function get_iters!(integrator::IDAIntegrator, iters) | ||
| 1399 | IDAGetNumSteps(integrator.mem, iters) | ||
| 1400 | end | ||
| 1401 | |||
| 1402 | function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free = false, | ||
| 1403 | calculate_error = false) | ||
| 1404 | uType = eltype(integrator.sol.u) | ||
| 1405 | iters = Ref(Clong(-1)) | ||
| 1406 | while !isempty(integrator.opts.tstops) | ||
| 1407 | # The call to first is an overload of Base.first implemented in DataStructures | ||
| 1408 | while integrator.tdir * integrator.t < first(integrator.opts.tstops) | ||
| 1409 | tstop = integrator.tdir * first(integrator.opts.tstops) | ||
| 1410 | set_stop_time(integrator, tstop) | ||
| 1411 | integrator.tprev = integrator.t | ||
| 1412 | if !(integrator.opts.callback.continuous_callbacks isa Tuple{}) | ||
| 1413 | integrator.uprev .= integrator.u | ||
| 1414 | end | ||
| 1415 | integrator.userfun.p = integrator.p | ||
| 1416 | solver_step(integrator, tstop) | ||
| 1417 | integrator.t = first(integrator.tout) | ||
| 1418 | # NB: CVode, ARKode may warn and then recover if integrator.t == integrator.tprev so don't flag this as an error | ||
| 1419 | integrator.flag < 0 && break | ||
| 1420 | handle_callbacks!(integrator) # this also updates the interpolation | ||
| 1421 | integrator.flag < 0 && break | ||
| 1422 | if isempty(integrator.opts.tstops) | ||
| 1423 | break | ||
| 1424 | end | ||
| 1425 | get_iters!(integrator, iters) | ||
| 1426 | if iters[] + 1 > integrator.opts.maxiters | ||
| 1427 | integrator.flag = -1 # MaxIters: -1 | ||
| 1428 | break | ||
| 1429 | end | ||
| 1430 | end | ||
| 1431 | integrator.flag < 0 && break | ||
| 1432 | handle_tstop!(integrator) | ||
| 1433 | end | ||
| 1434 | |||
| 1435 | DiffEqBase.finalize!(integrator.opts.callback, integrator.u, integrator.t, integrator) | ||
| 1436 | tend = integrator.t | ||
| 1437 | if integrator.opts.save_end && | ||
| 1438 | (isempty(integrator.sol.t) || integrator.sol.t[end] != tend) | ||
| 1439 | save_value!(integrator.sol.u, integrator.u, uType, | ||
| 1440 | integrator.opts.save_idxs) | ||
| 1441 | push!(integrator.sol.t, tend) | ||
| 1442 | if integrator.opts.dense | ||
| 1443 | save_value!(integrator.sol.interp.du, get_du(integrator), uType, | ||
| 1444 | integrator.opts.save_idxs) | ||
| 1445 | end | ||
| 1446 | end | ||
| 1447 | |||
| 1448 | if integrator.opts.progress | ||
| 1449 | Logging.@logmsg(Logging.LogLevel(-1), | ||
| 1450 | integrator.opts.progress_name, | ||
| 1451 | _id=integrator.opts.progress_id, | ||
| 1452 | message=integrator.opts.progress_message(integrator.dt, | ||
| 1453 | integrator.u, | ||
| 1454 | integrator.p, | ||
| 1455 | integrator.t), | ||
| 1456 | progress="done") | ||
| 1457 | end | ||
| 1458 | |||
| 1459 | fill_stats!(integrator) | ||
| 1460 | |||
| 1461 | if early_free | ||
| 1462 | empty!(integrator.mem) | ||
| 1463 | integrator.A !== nothing && empty!(integrator.A) | ||
| 1464 | integrator.LS !== nothing && empty!(integrator.LS) | ||
| 1465 | end | ||
| 1466 | |||
| 1467 | if DiffEqBase.has_analytic(integrator.sol.prob.f) && calculate_error | ||
| 1468 | DiffEqBase.calculate_solution_errors!(integrator.sol; | ||
| 1469 | timeseries_errors = integrator.opts.timeseries_errors, | ||
| 1470 | dense_errors = integrator.opts.dense_errors) | ||
| 1471 | end | ||
| 1472 | |||
| 1473 | if integrator.sol.retcode == ReturnCode.Default | ||
| 1474 | integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, | ||
| 1475 | interpret_sundials_retcode(integrator.flag)) | ||
| 1476 | end | ||
| 1477 | |||
| 1478 | return integrator.sol | ||
| 1479 | end | ||
| 1480 | |||
| 1481 | function handle_tstop!(integrator::AbstractSundialsIntegrator) | ||
| 1482 | tstops = integrator.opts.tstops | ||
| 1483 | if !isempty(tstops) && integrator.tdir * integrator.t >= first(tstops) | ||
| 1484 | pop!(tstops) | ||
| 1485 | # If we passed multiple tstops at once (possible if Sundials ignores us or we had redundant tstops) | ||
| 1486 | while !isempty(tstops) && integrator.tdir * integrator.t >= first(tstops) | ||
| 1487 | pop!(tstops) | ||
| 1488 | end | ||
| 1489 | integrator.just_hit_tstop = true | ||
| 1490 | end | ||
| 1491 | end | ||
| 1492 | |||
| 1493 | function fill_stats!(integrator::AbstractSundialsIntegrator) end | ||
| 1494 | |||
| 1495 | function fill_stats!(integrator::CVODEIntegrator) | ||
| 1496 | stats = integrator.sol.stats | ||
| 1497 | mem = integrator.mem | ||
| 1498 | tmp = Ref(Clong(-1)) | ||
| 1499 | CVodeGetNumRhsEvals(mem, tmp) | ||
| 1500 | stats.nf = tmp[] | ||
| 1501 | CVodeGetNumLinSolvSetups(mem, tmp) | ||
| 1502 | stats.nw = tmp[] | ||
| 1503 | CVodeGetNumErrTestFails(mem, tmp) | ||
| 1504 | stats.nreject = tmp[] | ||
| 1505 | CVodeGetNumSteps(mem, tmp) | ||
| 1506 | stats.naccept = tmp[] - stats.nreject | ||
| 1507 | CVodeGetNumNonlinSolvIters(mem, tmp) | ||
| 1508 | stats.nnonliniter = tmp[] | ||
| 1509 | CVodeGetNumNonlinSolvConvFails(mem, tmp) | ||
| 1510 | stats.nnonlinconvfail = tmp[] | ||
| 1511 | if method_choice(integrator.alg) == :Newton | ||
| 1512 | CVodeGetNumJacEvals(mem, tmp) | ||
| 1513 | stats.njacs = tmp[] | ||
| 1514 | end | ||
| 1515 | end | ||
| 1516 | |||
| 1517 | function fill_stats!(integrator::ARKODEIntegrator) | ||
| 1518 | stats = integrator.sol.stats | ||
| 1519 | mem = integrator.mem | ||
| 1520 | tmp = Ref(Clong(-1)) | ||
| 1521 | tmp2 = Ref(Clong(-1)) | ||
| 1522 | ARKStepGetNumRhsEvals(mem, tmp, tmp2) | ||
| 1523 | stats.nf = tmp[] | ||
| 1524 | stats.nf2 = tmp2[] | ||
| 1525 | integrator.alg.stiffness !== Explicit() && ARKStepGetNumLinSolvSetups(mem, tmp) | ||
| 1526 | stats.nw = tmp[] | ||
| 1527 | ARKStepGetNumErrTestFails(mem, tmp) | ||
| 1528 | stats.nreject = tmp[] | ||
| 1529 | ARKStepGetNumSteps(mem, tmp) | ||
| 1530 | stats.naccept = tmp[] - stats.nreject | ||
| 1531 | integrator.alg.stiffness !== Explicit() && ARKStepGetNumNonlinSolvIters(mem, tmp) | ||
| 1532 | stats.nnonliniter = tmp[] | ||
| 1533 | integrator.alg.stiffness !== Explicit() && ARKStepGetNumNonlinSolvConvFails(mem, tmp) | ||
| 1534 | stats.nnonlinconvfail = tmp[] | ||
| 1535 | if integrator.alg.stiffness !== Explicit() && method_choice(integrator.alg) == :Newton | ||
| 1536 | ARKStepGetNumJacEvals(mem, tmp) | ||
| 1537 | stats.njacs = tmp[] | ||
| 1538 | end | ||
| 1539 | end | ||
| 1540 | |||
| 1541 | function fill_stats!(integrator::IDAIntegrator) | ||
| 1542 | stats = integrator.sol.stats | ||
| 1543 | mem = integrator.mem | ||
| 1544 | tmp = Ref(Clong(-1)) | ||
| 1545 | IDAGetNumResEvals(mem, tmp) | ||
| 1546 | stats.nf = tmp[] | ||
| 1547 | IDAGetNumLinSolvSetups(mem, tmp) | ||
| 1548 | stats.nw = tmp[] | ||
| 1549 | IDAGetNumErrTestFails(mem, tmp) | ||
| 1550 | stats.nreject = tmp[] | ||
| 1551 | IDAGetNumSteps(mem, tmp) | ||
| 1552 | stats.naccept = tmp[] - stats.nreject | ||
| 1553 | IDAGetNumNonlinSolvIters(mem, tmp) | ||
| 1554 | stats.nnonliniter = tmp[] | ||
| 1555 | IDAGetNumNonlinSolvConvFails(mem, tmp) | ||
| 1556 | stats.nnonlinconvfail = tmp[] | ||
| 1557 | if method_choice(integrator.alg) == :Newton | ||
| 1558 | IDAGetNumJacEvals(mem, tmp) | ||
| 1559 | stats.njacs = tmp[] | ||
| 1560 | end | ||
| 1561 | end | ||
| 1562 | |||
| 1563 | function initialize_callbacks!(integrator, initialize_save = true) | ||
| 1564 | t = integrator.t | ||
| 1565 | u = integrator.u | ||
| 1566 | callbacks = integrator.opts.callback | ||
| 1567 | integrator.u_modified = true | ||
| 1568 | |||
| 1569 | u_modified = initialize!(callbacks, u, t, integrator) | ||
| 1570 | |||
| 1571 | # if the user modifies u, we need to fix current values | ||
| 1572 | if u_modified | ||
| 1573 | handle_callback_modifiers!(integrator) | ||
| 1574 | |||
| 1575 | if initialize_save && | ||
| 1576 | (any((c) -> c.save_positions[2], callbacks.discrete_callbacks) || | ||
| 1577 | any((c) -> c.save_positions[2], callbacks.continuous_callbacks)) | ||
| 1578 | savevalues!(integrator, true) | ||
| 1579 | end | ||
| 1580 | end | ||
| 1581 | |||
| 1582 | # reset this as it is now handled so the integrators should proceed as normal | ||
| 1583 | integrator.u_modified = false | ||
| 1584 | end |