StatProfilerHTML.jl report
Generated on Mon, 01 Apr 2024 21:01:18
File source code
Line Exclusive Inclusive Code
## Common Interface Solve Functions

# Solve an ODE or DAE problem with a Sundials integrator.
#
# The integrator is built by the `DiffEqBase.__init` method that matches `prob` and
# `alg`. The time-stepping loop (`solve!`) is only run when the freshly built
# solution still carries `ReturnCode.Default` (presumably any other code means
# initialization already finished or failed). `calculate_error` is forwarded to
# `solve!`; every other keyword is forwarded to `__init`. Returns the integrator's
# solution object.
function DiffEqBase.__solve(prob::Union{DiffEqBase.AbstractODEProblem,
            DiffEqBase.AbstractDAEProblem},
        alg::algType,
        timeseries = [],
        ts = [],
        ks = [],
        recompile::Type{Val{recompile_flag}} = Val{true};
        calculate_error = true,
        kwargs...) where {
        algType <: Union{SundialsODEAlgorithm, SundialsDAEAlgorithm},
        recompile_flag}
    integ = DiffEqBase.__init(prob, alg, timeseries, ts, ks; kwargs...)
    needs_stepping = integ.sol.retcode == ReturnCode.Default
    if needs_stepping
        solve!(integ; early_free = true, calculate_error)
    end
    return integ.sol
end

# Solve a steady-state or nonlinear problem with KINSOL.
#
# The problem's function is adapted to the two-argument in-place form `f!(du, u)`
# that is passed to `___kinsol`:
#   * an `ODEFunction` (steady-state problem) is evaluated at `t = Inf`;
#   * a `NonlinearFunction` is evaluated as `f(u, p)` / `f(du, u, p)`.
# A scalar `u0` is lifted to a one-element vector. The solution and residual are
# unwrapped back to scalars when the solution object is built.
#
# Keywords `abstol` and `maxiters` are forwarded to `___kinsol`. The linear solver,
# the Jacobian bandwidths, the preconditioning side, the Krylov dimension and the
# globalization strategy all come from `alg`. The return code is derived from the
# KINSOL flag via `interpret_sundials_retcode`.
#
# Throws an `ArgumentError` when `prob.f` is neither an `ODEFunction` nor a
# `NonlinearFunction`.
function DiffEqBase.__solve(prob::Union{
            DiffEqBase.AbstractSteadyStateProblem{uType, isinplace},
            DiffEqBase.AbstractNonlinearProblem{uType, isinplace}},
        alg::algType,
        timeseries = [],
        ts = [],
        ks = [],
        recompile::Type{Val{recompile_flag}} = Val{true};
        abstol = eps(Float64)^(4 // 5),
        maxiters = 1000,
        kwargs...) where {algType <: SundialsNonlinearSolveAlgorithm,
        recompile_flag, uType, isinplace}
    # Work on a private copy of the initial guess; scalars become 1-element vectors.
    if prob.u0 isa Number
        u0 = [prob.u0]
    else
        u0 = deepcopy(prob.u0)
    end

    p = prob.p
    userdata = alg.userdata
    linsolve = linear_solver(alg)
    jac_upper = alg.jac_upper
    jac_lower = alg.jac_lower

    ### Fix the more general function to Sundials allowed style
    if prob.f isa ODEFunction
        # Steady state of an ODE: evaluate the right-hand side at t = Inf.
        t = Inf
        if !isinplace && prob.u0 isa Number
            f! = (du, u) -> (du .= prob.f(first(u), p, t); Cint(0))
        elseif !isinplace
            f! = (du, u) -> (du .= prob.f(u, p, t); Cint(0))
        else # Then it's an in-place function on an abstract array
            f! = (du, u) -> prob.f(du, u, p, t)
        end
    elseif prob.f isa NonlinearFunction
        if !isinplace && prob.u0 isa Number
            f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
        elseif !isinplace
            f! = (du, u) -> (du .= prob.f(u, p); Cint(0))
        else # Then it's an in-place function on an abstract array
            f! = (du, u) -> prob.f(du, u, p)
        end
    else
        # Without this branch `f!` would be undefined below.
        throw(ArgumentError("Sundials' KINSOL interface only supports `ODEFunction` " *
                            "and `NonlinearFunction`, got $(typeof(prob.f))."))
    end

    # Residual buffer, filled after the solve by evaluating `f!` at the solution.
    resid = similar(u0)
    u, flag = ___kinsol(f!, u0;
        userdata = userdata,
        linear_solver = linsolve,
        jac_upper = jac_upper,
        jac_lower = jac_lower,
        abstol,
        prob.f.jac_prototype,
        alg.prec_side,
        alg.krylov_dim,
        maxiters,
        strategy = alg.globalization_strategy)

    f!(resid, u)
    retcode = interpret_sundials_retcode(flag)
    if prob.u0 isa Number
        return DiffEqBase.build_solution(prob, alg, u[1], resid[1]; retcode = retcode)
    else
        return DiffEqBase.build_solution(prob, alg, u, resid; retcode = retcode)
    end
end

# Initialize a CVODE integrator (`CVODE_BDF` or `CVODE_Adams`) for an ODE problem.
#
# Outline:
#   * validate the problem: no mass matrix, scalar `reltol`, at least one state;
#   * convert `tstops`/`saveat` into internal heaps via `tstop_saveat_disc_handling`;
#   * wrap `prob.f` into the in-place form `f!(du, u, p, t)`; out-of-place and scalar
#     problems are adapted, and a scalar `u0` is lifted to a one-element vector;
#   * create the CVODE memory and forward step-size, tolerance and iteration options;
#   * for `Method == :Newton`, attach the requested SUNDIALS linear solver and a
#     Newton nonlinear solver; otherwise attach a fixed-point nonlinear solver;
#   * optionally register a user Jacobian, a Jacobian-times-vector callback (when
#     `jac_prototype` is an `AbstractSciMLOperator`) and a preconditioner;
#   * build the initial solution, the `DEOptions` and the `CVODEIntegrator`, then
#     initialize the callbacks.
#
# Unless `alias_u0 = true`, `prob.u0` is copied. The `NVector` handed to CVODE
# shares memory with that (possibly aliased) `u0`.
#
# Throws an `ErrorException` for a mass matrix, an array `reltol` or an empty
# state, and an `ArgumentError` for an unsupported algorithm or linear solver.
function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, isinplace},
        alg::SundialsODEAlgorithm{Method, LinearSolver},
        timeseries = [],
        ts = [],
        ks = [];
        verbose = true,
        callback = nothing,
        abstol = 1 / 10^6,
        reltol = 1 / 10^3,
        saveat = Float64[],
        d_discontinuities = Float64[],
        tstops = Float64[],
        maxiters = Int(1e5),
        dt = nothing,
        dtmin = 0.0,
        dtmax = 0.0,
        timeseries_errors = true,
        dense_errors = false,
        save_everystep = isempty(saveat), save_idxs = nothing,
        save_on = true,
        save_start = save_everystep || isempty(saveat) ||
                     saveat isa Number ? true :
                     prob.tspan[1] in saveat,
        save_end = save_everystep || isempty(saveat) ||
                   saveat isa Number ? true :
                   prob.tspan[2] in saveat,
        dense = save_everystep && isempty(saveat),
        progress = false,
        progress_steps = 1000,
        progress_name = "ODE",
        progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE,
        progress_id = gensym("Sundials"),
        save_timeseries = nothing,
        advance_to_tstop = false,
        stop_at_next_tstop = false,
        userdata = nothing,
        alias_u0 = false,
        kwargs...) where {uType, tupType, isinplace, Method, LinearSolver}
    tType = eltype(tupType)

    if verbose
        warned = !isempty(kwargs) && DiffEqBase.check_keywords(alg, kwargs, warnlist)
        warned && DiffEqBase.warn_compat()
    end

    if prob.f.mass_matrix != LinearAlgebra.I
        error("This solver is not able to use mass matrices.")
    end

    if reltol isa AbstractArray
        error("Sundials only allows scalar reltol.")
    end

    if length(prob.u0) <= 0
        error("Sundials requires at least one state variable.")
    end

    progress && Logging.@logmsg(Logging.LogLevel(-1), progress_name, _id=progress_id, progress=0)

    tstops = vcat(tstops, d_discontinuities)
    callbacks_internal = DiffEqBase.CallbackSet(callback)

    max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal)
    if max_len_cb isa VectorContinuousCallback
        callback_cache = DiffEqBase.CallbackCache(max_len_cb.len, Float64, Float64)
    else
        callback_cache = nothing
    end

    # Times are converted to Float64 here.
    tspan = Float64.(prob.tspan)
    t0 = tspan[1]

    tdir = sign(tspan[2] - tspan[1])

    tstops_internal, saveat_internal = tstop_saveat_disc_handling(tstops, saveat, tdir,
        tspan, tType)

    if prob.u0 isa Number
        u0 = [prob.u0]
    elseif alias_u0
        u0 = prob.u0
    else
        u0 = copy(prob.u0)
    end

    ### Fix the more general function to Sundials allowed style
    if !isinplace && prob.u0 isa Number
        f! = (du, u, p, t) -> (du .= prob.f(first(u), p, t); Cint(0))
    elseif !isinplace
        f! = (du, u, p, t) -> (du .= prob.f(u, p, t); Cint(0))
    else # Then it's an in-place function on an abstract array
        f! = prob.f
    end

    if alg isa CVODE_BDF
        alg_code = CV_BDF
    elseif alg isa CVODE_Adams
        alg_code = CV_ADAMS
    else
        # Without this branch `alg_code` would be undefined in `CVodeCreate`.
        throw(ArgumentError("CVODE initialization expects CVODE_BDF or CVODE_Adams, " *
                            "got $(typeof(alg))."))
    end

    mem_ptr = CVodeCreate(alg_code)
    (mem_ptr == C_NULL) && error("Failed to allocate CVODE solver object")
    mem = Handle(mem_ptr)

    # Silence SUNDIALS' own error printing when not verbose.
    !verbose && CVodeSetErrHandlerFn(mem,
        @cfunction(null_error_handler, Nothing,
            (Cint, Char, Char, Ptr{Cvoid})),
        C_NULL)

    ts = save_start ? [t0] : Float64[]

    out = copy(u0)
    uvec = vec(u0) # aliases u0
    utmp = NVector(uvec) # aliases u0

    use_jac_prototype = (isa(prob.f.jac_prototype, SparseArrays.SparseMatrixCSC) &&
                         LinearSolver ∈ SPARSE_SOLVERS) ||
                        prob.f.jac_prototype isa AbstractSciMLOperator
    userfun = FunJac(f!,
        prob.f.jac,
        prob.p,
        nothing,
        use_jac_prototype ? prob.f.jac_prototype : nothing,
        alg.prec,
        alg.psetup,
        u0,
        out)

    # C-callable right-hand side, specialized on the concrete type of `userfun`.
    function getcfunf(::T) where {T}
        @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T}))
    end

    flag = CVodeInit(mem, getcfunf(userfun), t0, utmp)

    dt !== nothing && (flag = CVodeSetInitStep(mem, Float64(dt)))
    flag = CVodeSetMinStep(mem, Float64(dtmin))
    flag = CVodeSetMaxStep(mem, Float64(dtmax))
    flag = CVodeSetUserData(mem, userfun)
    if abstol isa Array
        flag = CVodeSVtolerances(mem, reltol, abstol)
    else
        flag = CVodeSStolerances(mem, reltol, abstol)
    end
    flag = CVodeSetMaxNumSteps(mem, maxiters)
    flag = CVodeSetMaxOrd(mem, alg.max_order)
    flag = CVodeSetMaxHnilWarns(mem, alg.max_hnil_warns)
    flag = CVodeSetStabLimDet(mem, alg.stability_limit_detect)
    flag = CVodeSetMaxErrTestFails(mem, alg.max_error_test_failures)
    flag = CVodeSetMaxNonlinIters(mem, alg.max_nonlinear_iters)
    flag = CVodeSetMaxConvFails(mem, alg.max_convergence_failures)

    # Becomes false once a solver that accepts a user Jacobian has been attached.
    nojacobian = true

    if Method == :Newton # Only use a linear solver if it's a Newton-based method
        if LinearSolver in (:Dense, :LapackDense)
            nojacobian = false
            A = SUNDenseMatrix(length(uvec), length(uvec))
            _A = MatrixHandle(A, DenseMatrix())
            if LinearSolver === :Dense
                LS = SUNLinSol_Dense(uvec, A)
                _LS = LinSolHandle(LS, Dense())
            else
                LS = SUNLinSol_LapackDense(uvec, A)
                _LS = LinSolHandle(LS, LapackDense())
            end
        elseif LinearSolver in (:Band, :LapackBand)
            nojacobian = false
            A = SUNBandMatrix(length(uvec), alg.jac_upper, alg.jac_lower)
            _A = MatrixHandle(A, BandMatrix())
            if LinearSolver === :Band
                LS = SUNLinSol_Band(uvec, A)
                _LS = LinSolHandle(LS, Band())
            else
                LS = SUNLinSol_LapackBand(uvec, A)
                _LS = LinSolHandle(LS, LapackBand())
            end
        elseif LinearSolver == :Diagonal
            nojacobian = false
            flag = CVDiag(mem)
            _A = nothing
            _LS = nothing
        elseif LinearSolver == :GMRES
            LS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.krylov_dim)
            _A = nothing
            _LS = LinSolHandle(LS, SPGMR())
        elseif LinearSolver == :FGMRES
            LS = SUNLinSol_SPFGMR(uvec, alg.prec_side, alg.krylov_dim)
            _A = nothing
            _LS = LinSolHandle(LS, SPFGMR())
        elseif LinearSolver == :BCG
            LS = SUNLinSol_SPBCGS(uvec, alg.prec_side, alg.krylov_dim)
            _A = nothing
            _LS = LinSolHandle(LS, SPBCGS())
        elseif LinearSolver == :PCG
            LS = SUNLinSol_PCG(uvec, alg.prec_side, alg.krylov_dim)
            _A = nothing
            _LS = LinSolHandle(LS, PCG())
        elseif LinearSolver == :TFQMR
            LS = SUNLinSol_SPTFQMR(uvec, alg.prec_side, alg.krylov_dim)
            _A = nothing
            _LS = LinSolHandle(LS, PTFQMR())
        elseif LinearSolver == :KLU
            nojacobian = false
            nnz = length(SparseArrays.nonzeros(prob.f.jac_prototype))
            A = SUNSparseMatrix(length(uvec), length(uvec), nnz, CSC_MAT)
            LS = SUNLinSol_KLU(uvec, A)
            _A = MatrixHandle(A, SparseMatrix())
            _LS = LinSolHandle(LS, KLU())
        else
            # Without this branch `LS`/`_A` would be undefined below.
            throw(ArgumentError("Unsupported CVODE linear solver: $LinearSolver."))
        end
        # :Diagonal was already attached through CVDiag above.
        if LinearSolver !== :Diagonal
            flag = CVodeSetLinearSolver(mem, LS, _A === nothing ? C_NULL : A)
        end
        NLS = SUNNonlinSol_Newton(uvec)
    else
        _A = nothing
        _LS = nothing
        # TODO: Anderson Acceleration
        anderson_m = 0
        NLS = SUNNonlinSol_FixedPoint(uvec, anderson_m)
    end
    CVodeSetNonlinearSolver(mem, NLS)

    if DiffEqBase.has_jac(prob.f) && Method == :Newton
        function getcfunjac(::T) where {T}
            @cfunction(cvodejac,
                Cint,
                (realtype,
                    N_Vector,
                    N_Vector,
                    SUNMatrix,
                    Ref{T},
                    N_Vector,
                    N_Vector,
                    N_Vector))
        end
        jac = getcfunjac(userfun)
        flag = CVodeSetUserData(mem, userfun)
        # Only register the Jacobian when a solver that accepts one was attached.
        nojacobian || (flag = CVodeSetJacFn(mem, jac))
    else
        jac = nothing
    end

    # A SciMLOperator Jacobian prototype is used through Jacobian-vector products.
    if prob.f.jac_prototype isa AbstractSciMLOperator
        function getcfunjtimes(::T) where {T}
            @cfunction(jactimes,
                Cint,
                (N_Vector, N_Vector, realtype, N_Vector, N_Vector, Ref{T}, N_Vector))
        end
        jtimes = getcfunjtimes(userfun)
        CVodeSetJacTimes(mem, C_NULL, jtimes)
    end

    # Preconditioner solve from `alg.prec`; setup routine only if `alg.psetup` is given.
    if alg.prec !== nothing
        function getpercfun(::T) where {T}
            @cfunction(precsolve,
                Cint,
                (Float64,
                    N_Vector,
                    N_Vector,
                    N_Vector,
                    N_Vector,
                    Float64,
                    Float64,
                    Int,
                    Ref{T}))
        end
        precfun = getpercfun(userfun)

        function getpsetupfun(::T) where {T}
            @cfunction(precsetup,
                Cint,
                (Float64, N_Vector, N_Vector, Int, Ptr{Int}, Float64, Ref{T}))
        end
        psetupfun = alg.psetup === nothing ? C_NULL : getpsetupfun(userfun)

        CVodeSetPreconditioner(mem, psetupfun, precfun)
    end

    tmp = isnothing(callbacks_internal) ? u0 : similar(u0)
    uprev = isnothing(callbacks_internal) ? u0 : similar(u0)
    tout = [tspan[1]]

    # Record the initial state (and derivative, for dense output) if requested.
    if save_start
        if save_idxs === nothing
            ures = Vector{uType}()
            dures = Vector{uType}()
            save_value!(ures, u0, uType, save_idxs)
            if dense
                f!(out, u0, prob.p, tspan[1])
                save_value!(dures, out, uType, save_idxs)
            end
        else
            ures = [u0[save_idxs]]
            if dense
                f!(out, u0, prob.p, tspan[1])
                dures = [out[save_idxs]]
            end
        end
    else
        ures = Vector{uType}()
        dures = Vector{uType}()
    end

    sol = DiffEqBase.build_solution(prob,
        alg,
        ts,
        ures;
        dense = dense,
        interp = dense ?
                 DiffEqBase.HermiteInterpolation(ts, ures,
            dures) :
                 DiffEqBase.LinearInterpolation(ts, ures),
        timeseries_errors = timeseries_errors,
        stats = SciMLBase.DEStats(0),
        calculate_error = false)
    opts = DEOptions(saveat_internal,
        tstops_internal,
        saveat, tstops, save_start,
        save_everystep, save_idxs,
        dense,
        timeseries_errors,
        dense_errors,
        save_on,
        save_end,
        callbacks_internal,
        abstol,
        reltol,
        verbose,
        advance_to_tstop,
        stop_at_next_tstop,
        progress,
        progress_steps,
        progress_name,
        progress_message,
        progress_id,
        maxiters)
    integrator = CVODEIntegrator(u0,
        utmp,
        prob.p,
        t0,
        t0,
        mem,
        _LS,
        _A,
        sol,
        alg,
        f!,
        userfun,
        jac,
        opts,
        tout,
        tdir,
        false,
        tmp,
        uprev,
        Cint(flag),
        false,
        0,
        1,
        callback_cache,
        0.0)
    initialize_callbacks!(integrator)
    return integrator
end # function solve

# Initialize an ARKODE integrator for an ODE problem.
#
# Outline:
#   * validate the options: scalar `reltol`, at least one state variable;
#   * convert `tstops`/`saveat` into internal heaps via `tstop_saveat_disc_handling`;
#   * wrap the right-hand side(s) into the in-place form `f!(du, u, p, t)`;
#   * create the ARKStep memory:
#       - for a `SplitODEProblem`, both an implicit (`fi`, via `cvodefunjac`) and an
#         explicit (`fe`, via `cvodefunjac2`) right-hand side are registered,
#         presumably `f1` and `f2` respectively;
#       - otherwise `alg.stiffness` (`Explicit()` / `Implicit()`) decides which one;
#   * forward step-size, tolerance, Butcher-table and nonlinear-solver options;
#   * for non-explicit methods, attach:
#       - the linear solver (`Method == :Newton`) or the fixed-point solver
#         (`Method == :Functional`);
#       - a Jacobian-times-vector callback;
#       - a mass-matrix solver (when `mass_matrix != I`);
#       - a user Jacobian and a preconditioner;
#   * build the initial solution, the `DEOptions` and the `ARKODEIntegrator`, then
#     initialize the callbacks.
#
# Throws an `ErrorException` for an array `reltol` or an empty state, and an
# `ArgumentError` for an unsupported stiffness, linear solver or mass-matrix solver.
function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, isinplace},
        alg::ARKODE{Method, LinearSolver, MassLinearSolver},
        timeseries = [],
        ts = [],
        ks = [];
        verbose = true,
        callback = nothing,
        abstol = 1 / 10^6,
        reltol = 1 / 10^3,
        saveat = Float64[],
        tstops = Float64[],
        d_discontinuities = Float64[],
        maxiters = Int(1e5),
        dt = nothing,
        dtmin = 0.0,
        dtmax = 0.0,
        timeseries_errors = true,
        dense_errors = false,
        save_everystep = isempty(saveat), save_idxs = nothing,
        dense = save_everystep,
        save_on = true,
        save_start = true,
        save_end = true,
        save_timeseries = nothing,
        progress = false,
        progress_steps = 1000,
        progress_name = "ODE",
        progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE,
        progress_id = gensym("Sundials"),
        advance_to_tstop = false,
        stop_at_next_tstop = false,
        userdata = nothing,
        alias_u0 = false,
        kwargs...) where {uType, tupType, isinplace, Method,
        LinearSolver,
        MassLinearSolver}
    tType = eltype(tupType)

    if verbose
        warned = !isempty(kwargs) && DiffEqBase.check_keywords(alg, kwargs, warnlist)
        warned && DiffEqBase.warn_compat()
    end

    if reltol isa AbstractArray
        error("Sundials only allows scalar reltol.")
    end

    if length(prob.u0) <= 0
        error("Sundials requires at least one state variable.")
    end

    progress && Logging.@logmsg(Logging.LogLevel(-1), progress_name, _id=progress_id, progress=0)

    tstops = vcat(tstops, d_discontinuities)
    callbacks_internal = DiffEqBase.CallbackSet(callback)

    max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal)
    if max_len_cb isa VectorContinuousCallback
        callback_cache = DiffEqBase.CallbackCache(max_len_cb.len, Float64, Float64)
    else
        callback_cache = nothing
    end

    tspan = prob.tspan
    t0 = tspan[1]

    tdir = sign(tspan[2] - tspan[1])

    tstops_internal, saveat_internal = tstop_saveat_disc_handling(tstops, saveat, tdir,
        tspan, tType)

    if prob.u0 isa Number
        u0 = [prob.u0]
    elseif alias_u0
        u0 = prob.u0
    else
        u0 = copy(prob.u0)
    end

    ts = save_start ? [t0] : Float64[]
    out = copy(u0)
    uvec = vec(u0)
    utmp = NVector(uvec)

    # Allocate ARKStep memory for the given explicit (`fe`) and/or implicit (`fi`)
    # right-hand side. The handle uses a closure-local name (not `mem`). Assigning to
    # `mem` here would capture and box the enclosing function's `mem` variable.
    function arkodemem(; fe = C_NULL, fi = C_NULL, t0 = t0, u0 = utmp)
        mem_ptr = ARKStepCreate(fe, fi, t0, u0)
        (mem_ptr == C_NULL) && error("Failed to allocate ARKODE solver object")
        arkmem = Handle(mem_ptr)

        # Silence SUNDIALS' own error printing when not verbose.
        !verbose && ARKStepSetErrHandlerFn(arkmem,
            @cfunction(null_error_handler, Nothing,
                (Cint, Char, Char, Ptr{Cvoid})),
            C_NULL)
        return arkmem
    end

    ### Fix the more general function to Sundials allowed style
    if !isinplace && prob.u0 isa Number
        f! = (du, u, p, t) -> (du .= prob.f(first(u), p, t); Cint(0))
    elseif !isinplace
        f! = (du, u, p, t) -> (du .= prob.f(u, p, t); Cint(0))
    else # Then it's an in-place function on an abstract array
        f! = prob.f
    end

    if prob.problem_type isa SplitODEProblem

        ### Fix the more general function to Sundials allowed style
        if !isinplace && prob.u0 isa Number
            f1! = (du, u, p, t) -> (du .= prob.f.f1(first(u), p, t); Cint(0))
            f2! = (du, u, p, t) -> (du .= prob.f.f2(first(u), p, t); Cint(0))
        elseif !isinplace
            f1! = (du, u, p, t) -> (du .= prob.f.f1(u, p, t); Cint(0))
            f2! = (du, u, p, t) -> (du .= prob.f.f2(u, p, t); Cint(0))
        else # Then it's an in-place function on an abstract array
            f1! = prob.f.f1
            f2! = prob.f.f2
        end

        use_jac_prototype = (isa(prob.f.f1.jac_prototype, SparseArrays.SparseMatrixCSC) &&
                             LinearSolver ∈ SPARSE_SOLVERS)
        userfun = FunJac(f1!,
            f2!,
            prob.f.f1.jac,
            prob.p,
            prob.f.mass_matrix,
            use_jac_prototype ? prob.f.f1.jac_prototype : nothing,
            alg.prec,
            alg.psetup,
            u0,
            out,
            nothing)

        function getcfunjac(::T) where {T}
            @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T}))
        end
        function getcfunjac2(::T) where {T}
            @cfunction(cvodefunjac2, Cint, (realtype, N_Vector, N_Vector, Ref{T}))
        end
        cfj1 = getcfunjac(userfun)
        cfj2 = getcfunjac2(userfun)

        mem = arkodemem(; fi = cfj1, fe = cfj2)
    else
        use_jac_prototype = (isa(prob.f.jac_prototype, SparseArrays.SparseMatrixCSC) &&
                             LinearSolver ∈ SPARSE_SOLVERS)
        userfun = FunJac(f!,
            prob.f.jac,
            prob.p,
            prob.f.mass_matrix,
            use_jac_prototype ? prob.f.jac_prototype : nothing,
            alg.prec,
            alg.psetup,
            u0,
            out)
        if alg.stiffness == Explicit()
            function getcfun1(::T) where {T}
                @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T}))
            end
            cfj1 = getcfun1(userfun)
            mem = arkodemem(; fe = cfj1)
        elseif alg.stiffness == Implicit()
            function getcfun2(::T) where {T}
                @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T}))
            end
            cfj2 = getcfun2(userfun)
            mem = arkodemem(; fi = cfj2)
        else
            # Without this branch `mem` would be undefined below.
            throw(ArgumentError("ARKODE stiffness must be Explicit() or Implicit(), " *
                                "got $(alg.stiffness)."))
        end
    end

    dt !== nothing && (flag = ARKStepSetInitStep(mem, Float64(dt)))
    flag = ARKStepSetMinStep(mem, Float64(dtmin))
    flag = ARKStepSetMaxStep(mem, Float64(dtmax))
    flag = ARKStepSetUserData(mem, userfun)
    if abstol isa Array
        flag = ARKStepSVtolerances(mem, reltol, abstol)
    else
        flag = ARKStepSStolerances(mem, reltol, abstol)
    end
    flag = ARKStepSetMaxNumSteps(mem, maxiters)
    flag = ARKStepSetMaxHnilWarns(mem, alg.max_hnil_warns)
    flag = ARKStepSetMaxErrTestFails(mem, alg.max_error_test_failures)
    flag = ARKStepSetMaxConvFails(mem, alg.max_convergence_failures)
    flag = ARKStepSetPredictorMethod(mem, alg.predictor_method)
    flag = ARKStepSetNonlinConvCoef(mem, alg.nonlinear_convergence_coefficient)
    flag = ARKStepSetDenseOrder(mem, alg.dense_order)

    #=
    Reference from Manual on ARKODE
    To choose an explicit table, set itable to a negative value. This automatically calls ARKStepSetExplicit(). However, if the problem is posed in explicit form, i.e. 𝑦 ̇ = 𝑓 (𝑡, 𝑦), then we recommend that the ERKStep time- stepper module be used instead of ARKStep.
    To select an implicit table, set etable to a negative value. This automatically calls ARKStepSetImplicit(). If both itable and etable are non-negative, then these should match an existing implicit/explicit pair, listed in the section Additive Butcher tables. This automatically calls ARKStepSetImEx().
    =#
    if alg.itable === nothing && alg.etable === nothing
        flag = ARKStepSetOrder(mem, alg.order)
    elseif alg.itable === nothing && alg.etable !== nothing
        flag = ARKStepSetTableNum(mem, -1, alg.etable)
    elseif alg.itable !== nothing && alg.etable === nothing
        flag = ARKStepSetTableNum(mem, alg.itable, -1)
    else
        flag = ARKStepSetTableNum(mem, alg.itable, alg.etable)
    end

    flag = ARKStepSetNonlinCRDown(mem, alg.crdown)
    flag = ARKStepSetNonlinRDiv(mem, alg.rdiv)
    flag = ARKStepSetDeltaGammaMax(mem, alg.dgmax)
    flag = ARKStepSetMaxStepsBetweenLSet(mem, alg.msbp)

    alg.set_optimal_params && (flag = ARKStepSetOptimalParams(mem))

    if Method == :Newton && alg.stiffness !== Explicit() # Only use a linear solver if it's a Newton-based method
        if LinearSolver in (:Dense, :LapackDense)
            A = SUNDenseMatrix(length(uvec), length(uvec))
            _A = MatrixHandle(A, DenseMatrix())
            if LinearSolver === :Dense
                LS = SUNLinSol_Dense(uvec, A)
                _LS = LinSolHandle(LS, Dense())
            else
                LS = SUNLinSol_LapackDense(uvec, A)
                _LS = LinSolHandle(LS, LapackDense())
            end
        elseif LinearSolver in (:Band, :LapackBand)
            A = SUNBandMatrix(length(uvec), alg.jac_upper, alg.jac_lower)
            _A = MatrixHandle(A, BandMatrix())
            if LinearSolver === :Band
                LS = SUNLinSol_Band(uvec, A)
                _LS = LinSolHandle(LS, Band())
            else
                LS = SUNLinSol_LapackBand(uvec, A)
                _LS = LinSolHandle(LS, LapackBand())
            end
        elseif LinearSolver == :GMRES
            LS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.krylov_dim)
            _A = nothing
            _LS = LinSolHandle(LS, SPGMR())
        elseif LinearSolver == :FGMRES
            LS = SUNLinSol_SPFGMR(uvec, alg.prec_side, alg.krylov_dim)
            _A = nothing
            _LS = LinSolHandle(LS, SPFGMR())
        elseif LinearSolver == :BCG
            LS = SUNLinSol_SPBCGS(uvec, alg.prec_side, alg.krylov_dim)
            _A = nothing
            _LS = LinSolHandle(LS, SPBCGS())
        elseif LinearSolver == :PCG
            LS = SUNLinSol_PCG(uvec, alg.prec_side, alg.krylov_dim)
            _A = nothing
            _LS = LinSolHandle(LS, PCG())
        elseif LinearSolver == :TFQMR
            LS = SUNLinSol_SPTFQMR(uvec, alg.prec_side, alg.krylov_dim)
            _A = nothing
            _LS = LinSolHandle(LS, PTFQMR())
        elseif LinearSolver == :KLU
            nnz = length(SparseArrays.nonzeros(prob.f.jac_prototype))
            A = SUNSparseMatrix(length(uvec), length(uvec), nnz, CSC_MAT)
            LS = SUNLinSol_KLU(uvec, A)
            _A = MatrixHandle(A, SparseMatrix())
            _LS = LinSolHandle(LS, KLU())
        else
            # Without this branch `LS`/`_A` would be undefined below.
            throw(ArgumentError("Unsupported ARKODE linear solver: $LinearSolver."))
        end
        flag = ARKStepSetLinearSolver(mem, LS, _A === nothing ? C_NULL : A)
        flag = ARKStepSetMaxNonlinIters(mem, alg.max_nonlinear_iters)
    elseif Method == :Functional && alg.stiffness !== Explicit()
        ARKStepSetFixedPoint(mem, Clong(alg.krylov_dim))
        # No linear solver or matrix is attached. These must still be defined because
        # they are stored in the integrator below.
        _A = nothing
        _LS = nothing
    else
        _A = nothing
        _LS = nothing
    end

    # Register Jacobian-vector products when the relevant Jacobian prototype is a
    # SciMLOperator. The stiffness guard applies to split and non-split problems alike.
    jtimes_prototype = prob.problem_type isa SplitODEProblem ?
                       prob.f.f1.jac_prototype : prob.f.jac_prototype
    if jtimes_prototype isa AbstractSciMLOperator && alg.stiffness !== Explicit()
        function getcfunjtimes(::T) where {T}
            @cfunction(jactimes,
                Cint,
                (N_Vector, N_Vector, realtype, N_Vector, N_Vector, Ref{T}, N_Vector))
        end
        jtimes = getcfunjtimes(userfun)
        ARKStepSetJacTimes(mem, C_NULL, jtimes)
    end

    if prob.f.mass_matrix != LinearAlgebra.I && alg.stiffness !== Explicit()
        if MassLinearSolver in (:Dense, :LapackDense)
            M = SUNDenseMatrix(length(uvec), length(uvec))
            _M = MatrixHandle(M, DenseMatrix())
            if MassLinearSolver === :Dense
                MLS = SUNLinSol_Dense(uvec, M)
                _MLS = LinSolHandle(MLS, Dense())
            else
                MLS = SUNLinSol_LapackDense(uvec, M)
                _MLS = LinSolHandle(MLS, LapackDense())
            end
        elseif MassLinearSolver in (:Band, :LapackBand)
            M = SUNBandMatrix(length(uvec), alg.jac_upper, alg.jac_lower)
            _M = MatrixHandle(M, BandMatrix())
            if MassLinearSolver === :Band
                MLS = SUNLinSol_Band(uvec, M)
                _MLS = LinSolHandle(MLS, Band())
            else
                MLS = SUNLinSol_LapackBand(uvec, M)
                _MLS = LinSolHandle(MLS, LapackBand())
            end
        # Each iterative mass solver is built with its own constructor, so the
        # created solver matches the kind recorded in its handle.
        elseif MassLinearSolver == :GMRES
            MLS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.mass_krylov_dim)
            _M = nothing
            _MLS = LinSolHandle(MLS, SPGMR())
        elseif MassLinearSolver == :FGMRES
            MLS = SUNLinSol_SPFGMR(uvec, alg.prec_side, alg.mass_krylov_dim)
            _M = nothing
            _MLS = LinSolHandle(MLS, SPFGMR())
        elseif MassLinearSolver == :BCG
            MLS = SUNLinSol_SPBCGS(uvec, alg.prec_side, alg.mass_krylov_dim)
            _M = nothing
            _MLS = LinSolHandle(MLS, SPBCGS())
        elseif MassLinearSolver == :PCG
            MLS = SUNLinSol_PCG(uvec, alg.prec_side, alg.mass_krylov_dim)
            _M = nothing
            _MLS = LinSolHandle(MLS, PCG())
        elseif MassLinearSolver == :TFQMR
            MLS = SUNLinSol_SPTFQMR(uvec, alg.prec_side, alg.mass_krylov_dim)
            _M = nothing
            _MLS = LinSolHandle(MLS, PTFQMR())
        elseif MassLinearSolver == :KLU
            nnz = length(SparseArrays.nonzeros(prob.f.mass_matrix))
            M = SUNSparseMatrix(length(uvec), length(uvec), nnz, CSC_MAT)
            MLS = SUNLinSol_KLU(uvec, M)
            _M = MatrixHandle(M, SparseMatrix())
            _MLS = LinSolHandle(MLS, KLU())
        else
            # Without this branch `MLS`/`_M` would be undefined below.
            throw(ArgumentError("Unsupported ARKODE mass-matrix linear solver: " *
                                "$MassLinearSolver."))
        end
        flag = ARKStepSetMassLinearSolver(mem, MLS, _M === nothing ? C_NULL : M, false)
        function getmatfun(::T) where {T}
            @cfunction(massmat,
                Cint,
                (realtype, SUNMatrix, Ref{T}, N_Vector, N_Vector, N_Vector))
        end
        matfun = getmatfun(userfun)
        ARKStepSetMassFn(mem, matfun)
    else
        _M = nothing
        _MLS = nothing
    end

    if DiffEqBase.has_jac(prob.f) && alg.stiffness !== Explicit()
        function getfunjac(::T) where {T}
            @cfunction(cvodejac,
                Cint,
                (realtype,
                    N_Vector,
                    N_Vector,
                    SUNMatrix,
                    Ref{T},
                    N_Vector,
                    N_Vector,
                    N_Vector))
        end
        jac = getfunjac(userfun)
        flag = ARKStepSetUserData(mem, userfun)
        flag = ARKStepSetJacFn(mem, jac)
    else
        jac = nothing
    end

    # Preconditioner solve from `alg.prec`; setup routine only if `alg.psetup` is given.
    if alg.prec !== nothing && alg.stiffness !== Explicit()
        function getpercfun(::T) where {T}
            @cfunction(precsolve,
                Cint,
                (Float64,
                    N_Vector,
                    N_Vector,
                    N_Vector,
                    N_Vector,
                    Float64,
                    Float64,
                    Int,
                    Ref{T}))
        end
        precfun = getpercfun(userfun)

        function getpsetupfun(::T) where {T}
            @cfunction(precsetup,
                Cint,
                (Float64, N_Vector, N_Vector, Int, Ptr{Int}, Float64, Ref{T}))
        end
        psetupfun = alg.psetup === nothing ? C_NULL : getpsetupfun(userfun)

        ARKStepSetPreconditioner(mem, psetupfun, precfun)
    end

    tmp = isnothing(callbacks_internal) ? u0 : similar(u0)
    uprev = isnothing(callbacks_internal) ? u0 : similar(u0)
    tout = [tspan[1]]

    # Record the initial state (and derivative, for dense output) if requested.
    if save_start
        if save_idxs === nothing
            ures = Vector{uType}()
            dures = Vector{uType}()
            save_value!(ures, u0, uType, save_idxs)
            if dense
                f!(out, u0, prob.p, tspan[1])
                save_value!(dures, out, uType, save_idxs)
            end
        else
            ures = [u0[save_idxs]]
            if dense
                f!(out, u0, prob.p, tspan[1])
                dures = [out[save_idxs]]
            end
        end
    else
        ures = Vector{uType}()
        dures = Vector{uType}()
    end

    sol = DiffEqBase.build_solution(prob,
        alg,
        ts,
        ures;
        dense = dense,
        interp = dense ?
                 DiffEqBase.HermiteInterpolation(ts, ures,
            dures) :
                 DiffEqBase.LinearInterpolation(ts, ures),
        timeseries_errors = timeseries_errors,
        stats = SciMLBase.DEStats(0),
        calculate_error = false)
    opts = DEOptions(saveat_internal,
        tstops_internal,
        saveat, tstops, save_start,
        save_everystep, save_idxs,
        dense,
        timeseries_errors,
        dense_errors,
        save_on,
        save_end,
        callbacks_internal,
        abstol,
        reltol,
        verbose,
        advance_to_tstop,
        stop_at_next_tstop,
        progress,
        progress_steps,
        progress_name,
        progress_message,
        progress_id,
        maxiters)
    integrator = ARKODEIntegrator(u0,
        utmp,
        prob.p,
        t0,
        t0,
        mem,
        _LS,
        _A,
        _MLS,
        _M,
        sol,
        alg,
        f!,
        userfun,
        jac,
        opts,
        tout,
        tdir,
        false,
        tmp,
        uprev,
        Cint(flag),
        false,
        0,
        1,
        callback_cache,
        0.0)

    initialize_callbacks!(integrator)
    return integrator
end # function solve

"""
    tstop_saveat_disc_handling(tstops, saveat, tdir, tspan, tType)

Build the internal stop-time and save-time heaps used by the Sundials integrators.

Every time is multiplied by the integration direction `tdir`, so both heaps
(`DataStructures.BinaryHeap{tType}` with `FasterForward()` ordering) work in
"direction-normalized" time. A time is kept only if it lies strictly after `tspan[1]`
and no later than `tspan[2]` in that direction. The final time is always added as a
stop. A scalar `saveat` is expanded to the grid `t0 + Δ, t0 + 2Δ, …` (excluding `t0`)
with step `tdir * abs(saveat)`.

Returns `(tstops_internal, saveat_internal)`.
"""
function tstop_saveat_disc_handling(tstops, saveat, tdir, tspan, tType)
    t_start, t_final = tspan
    lo = tdir * t_start
    hi = tdir * t_final
    inside(s) = lo < s ≤ hi

    stop_heap = DataStructures.BinaryHeap{tType}(DataStructures.FasterForward())
    for s in (tdir * t for t in tstops)
        inside(s) && push!(stop_heap, s)
    end
    push!(stop_heap, hi)  # always stop exactly at the final time

    save_points = saveat isa Number ? (t_start:tdir*abs(saveat):t_final)[2:end] : saveat
    save_heap = DataStructures.BinaryHeap{tType}(DataStructures.FasterForward())
    for s in (tdir * t for t in save_points)
        inside(s) && push!(save_heap, s)
    end

    return stop_heap, save_heap
end

## Solve functions for DAEs (backed by IDA)

977 function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tupType,
978 isinplace},
979 alg::SundialsDAEAlgorithm{LinearSolver},
980 timeseries = [],
981 ts = [],
982 ks = [];
983 verbose = true,
984 dt = nothing,
985 dtmax = 0.0,
986 save_on = true,
987 save_start = true,
988 callback = nothing,
989 abstol = 1 / 10^6,
990 reltol = 1 / 10^3,
991 saveat = Float64[],
992 tstops = Float64[],
993 d_discontinuities = Float64[],
994 maxiters = Int(1e5),
995 timeseries_errors = true,
996 dense_errors = false,
997 save_everystep = isempty(saveat), save_idxs = nothing,
998 dense = save_everystep,
999 save_timeseries = nothing,
1000 save_end = true,
1001 progress = false,
1002 progress_steps = 1000,
1003 progress_name = "DAE IDA",
1004 progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE,
1005 progress_id = gensym("Sundials"),
1006 advance_to_tstop = false,
1007 stop_at_next_tstop = false,
1008 userdata = nothing,
1009 initializealg = IDADefaultInit(),
1010 kwargs...) where {uType, duType, tupType, isinplace, LinearSolver
1011 }
1012 tType = eltype(tupType)
1013
1014 if verbose
1015 warned = !isempty(kwargs) && DiffEqBase.check_keywords(alg, kwargs, warnida)
1016 warned && DiffEqBase.warn_compat()
1017 end
1018
1019 if reltol isa AbstractArray
1020 error("Sundials only allows scalar reltol.")
1021 end
1022
1023 if length(prob.u0) == 0
1024 error("Sundials requires at least one state variable.")
1025 end
1026
1027 progress && Logging.@logmsg(Logging.LogLevel(-1), progress_name, _id=progress_id, progress=0)
1028
1029 tstops = vcat(tstops, d_discontinuities)
1030 callbacks_internal = DiffEqBase.CallbackSet(callback)
1031
1032 max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal)
1033 if max_len_cb isa VectorContinuousCallback
1034 callback_cache = DiffEqBase.CallbackCache(max_len_cb.len, Float64, Float64)
1035 else
1036 callback_cache = nothing
1037 end
1038
1039 tspan = prob.tspan
1040 t0 = tspan[1]
1041
1042 tdir = sign(tspan[2] - tspan[1])
1043
1044 tstops_internal, saveat_internal = tstop_saveat_disc_handling(tstops, saveat, tdir,
1045 tspan, tType)
1046 @assert size(prob.u0) == size(prob.du0)
1047 if prob.u0 isa Number
1048 u0 = [prob.u0]
1049 du0 = [prob.du0]
1050 else
1051 u0 = copy(prob.u0)
1052 du0 = copy(prob.du0)
1053 end
1054
1055 ### Fix the more general function to Sundials allowed style
1056 if !isinplace && prob.u0 isa Number
1057 f! = (out, du, u, p, t) -> (out .= prob.f(first(du), first(u), p, t); Cint(0))
1058 elseif !isinplace
1059 f! = (out, du, u, p, t) -> (out .= prob.f(du, u, p, t); Cint(0))
1060 else # Then it's an in-place function on an abstract array
1061 f! = prob.f
1062 end
1063
1064 mem_ptr = IDACreate()
1065 (mem_ptr == C_NULL) && error("Failed to allocate IDA solver object")
1066 mem = Handle(mem_ptr)
1067
1068 !verbose && IDASetErrHandlerFn(mem,
1069 @cfunction(null_error_handler, Nothing,
1070 (Cint, Char, Char, Ptr{Cvoid})),
1071 C_NULL)
1072
1073 ts = [t0]
1074
1075 # vec shares memory
1076 utmp = NVector(vec(u0))
1077 dutmp = NVector(vec(du0))
1078 rtest = zeros(size(u0))
1079
1080 use_jac_prototype = (isa(prob.f.jac_prototype, SparseArrays.SparseMatrixCSC) &&
1081 LinearSolver ∈ SPARSE_SOLVERS)
1082 userfun = FunJac(f!,
1083 prob.f.jac,
1084 prob.p,
1085 nothing,
1086 use_jac_prototype ? prob.f.jac_prototype : nothing,
1087 alg.prec,
1088 alg.psetup,
1089 u0,
1090 du0,
1091 rtest)
1092
1093 function getcfun(::T) where {T}
1094 @cfunction(idasolfun, Cint, (realtype, N_Vector, N_Vector, N_Vector, Ref{T}))
1095 end
1096 cfun = getcfun(userfun)
1097 flag = IDAInit(mem, cfun, t0, utmp, dutmp)
1098 dt !== nothing && (flag = IDASetInitStep(mem, dt))
1099 flag = IDASetUserData(mem, userfun)
1100 flag = IDASetMaxStep(mem, dtmax)
1101 if abstol isa Array
1102 flag = IDASVtolerances(mem, reltol, abstol)
1103 else
1104 flag = IDASStolerances(mem, reltol, abstol)
1105 end
1106 flag = IDASetMaxNumSteps(mem, maxiters)
1107 flag = IDASetMaxOrd(mem, alg.max_order)
1108 flag = IDASetMaxErrTestFails(mem, alg.max_error_test_failures)
1109 flag = IDASetNonlinConvCoef(mem, alg.nonlinear_convergence_coefficient)
1110 flag = IDASetMaxNonlinIters(mem, alg.max_nonlinear_iters)
1111 flag = IDASetMaxConvFails(mem, alg.max_convergence_failures)
1112 flag = IDASetNonlinConvCoefIC(mem, alg.nonlinear_convergence_coefficient_ic)
1113 flag = IDASetMaxNumStepsIC(mem, alg.max_num_steps_ic)
1114 flag = IDASetMaxNumJacsIC(mem, alg.max_num_jacs_ic)
1115 flag = IDASetMaxNumItersIC(mem, alg.max_num_iters_ic)
1116 #flag = IDASetMaxBacksIC(mem,alg.max_num_backs_ic) # Needs newer version?
1117 flag = IDASetLineSearchOffIC(mem, alg.use_linesearch_ic)
1118
1119 prec_side = isnothing(alg.prec) ? 0 : 1 # IDA only supports left preconditioning (prec_side = 1)
1120 if LinearSolver in (:Dense, :LapackDense)
1121 nojacobian = false
1122 A = SUNDenseMatrix(length(u0), length(u0))
1123 _A = MatrixHandle(A, DenseMatrix())
1124 if LinearSolver === :Dense
1125 LS = SUNLinSol_Dense(utmp, A)
1126 _LS = LinSolHandle(LS, Dense())
1127 else
1128 LS = SUNLinSol_LapackDense(u0, A)
1129 _LS = LinSolHandle(LS, LapackDense())
1130 end
1131 elseif LinearSolver in (:Band, :LapackBand)
1132 nojacobian = false
1133 A = SUNBandMatrix(length(u0), alg.jac_upper, alg.jac_lower)
1134 _A = MatrixHandle(A, BandMatrix())
1135 if LinearSolver === :Band
1136 LS = SUNLinSol_Band(utmp, A)
1137 _LS = LinSolHandle(LS, Band())
1138 else
1139 LS = SUNLinSol_LapackBand(utmp, A)
1140 _LS = LinSolHandle(LS, LapackBand())
1141 end
1142 elseif LinearSolver == :GMRES
1143 LS = SUNLinSol_SPGMR(utmp, prec_side, alg.krylov_dim)
1144 _A = nothing
1145 _LS = LinSolHandle(LS, SPGMR())
1146 elseif LinearSolver == :FGMRES
1147 LS = SUNLinSol_SPFGMR(utmp, prec_side, alg.krylov_dim)
1148 _A = nothing
1149 _LS = LinSolHandle(LS, SPFGMR())
1150 elseif LinearSolver == :BCG
1151 LS = SUNLinSol_SPBCGS(utmp, prec_side, alg.krylov_dim)
1152 _A = nothing
1153 _LS = LinSolHandle(LS, SPBCGS())
1154 elseif LinearSolver == :PCG
1155 LS = SUNLinSol_PCG(utmp, prec_side, alg.krylov_dim)
1156 _A = nothing
1157 _LS = LinSolHandle(LS, PCG())
1158 elseif LinearSolver == :TFQMR
1159 LS = SUNLinSol_SPTFQMR(utmp, prec_side, alg.krylov_dim)
1160 _A = nothing
1161 _LS = LinSolHandle(LS, PTFQMR())
1162 elseif LinearSolver == :KLU
1163 nnz = length(SparseArrays.nonzeros(prob.f.jac_prototype))
1164 A = SUNSparseMatrix(length(u0), length(u0), nnz, Sundials.CSC_MAT)
1165 LS = SUNLinSol_KLU(utmp, A)
1166 _A = MatrixHandle(A, SparseMatrix())
1167 _LS = LinSolHandle(LS, KLU())
1168 end
1169 flag = IDASetLinearSolver(mem, LS, _A === nothing ? C_NULL : A)
1170
1171 if prob.f.jac_prototype isa AbstractSciMLOperator
1172 function getcfunjtimes(::T) where {T}
1173 @cfunction(idajactimes,
1174 Cint,
1175 (realtype,
1176 N_Vector,
1177 N_Vector,
1178 N_Vector,
1179 N_Vector,
1180 N_Vector,
1181 realtype,
1182 Ref{T},
1183 N_Vector,
1184 N_Vector))
1185 end
1186 jtimes = getcfunjtimes(userfun)
1187 IDASetJacTimes(mem, C_NULL, jtimes)
1188 end
1189
1190 if alg.prec !== nothing
1191 function getprecfun(::T) where {T}
1192 @cfunction(idaprecsolve,
1193 Cint,
1194 (Float64,
1195 N_Vector,
1196 N_Vector,
1197 N_Vector,
1198 N_Vector,
1199 N_Vector,
1200 Float64,
1201 Float64,
1202 Ref{T}))
1203 end
1204 precfun = getprecfun(userfun)
1205
1206 function getpsetupfun(::T) where {T}
1207 @cfunction(idaprecsetup,
1208 Cint,
1209 (Float64, N_Vector, N_Vector, N_Vector, Float64, Ref{T}))
1210 end
1211 psetupfun = alg.psetup === nothing ? C_NULL : getpsetupfun(userfun)
1212
1213 IDASetPreconditioner(mem, psetupfun, precfun)
1214 end
1215
1216 if DiffEqBase.has_jac(prob.f)
1217 function getcfunjacc(::T) where {T}
1218 @cfunction(idajac,
1219 Cint,
1220 (realtype,
1221 realtype,
1222 N_Vector,
1223 N_Vector,
1224 N_Vector,
1225 SUNMatrix,
1226 Ref{T},
1227 N_Vector,
1228 N_Vector,
1229 N_Vector))
1230 end
1231 jac = getcfunjacc(userfun)
1232 flag = IDASetUserData(mem, userfun)
1233 flag = IDASetJacFn(mem, jac)
1234 else
1235 jac = nothing
1236 end
1237
1238 tout = Float64[first(tspan)]
1239 if save_idxs isa Integer
1240 ures = Vector{eltype(uType)}()
1241 dures = Vector{eltype(uType)}()
1242 else
1243 ures = Vector{uType}()
1244 dures = Vector{uType}()
1245 end
1246 tmp = isnothing(callbacks_internal) ? u0 : similar(u0)
1247 uprev = isnothing(callbacks_internal) ? u0 : similar(u0)
1248 retcode = flag >= 0 ? ReturnCode.Default : ReturnCode.InitialFailure
1249 sol = DiffEqBase.build_solution(prob,
1250 alg,
1251 ts,
1252 ures, dense ? dures : nothing;
1253 dense = dense,
1254 calculate_error = false,
1255 timeseries_errors = timeseries_errors,
1256 retcode = retcode,
1257 stats = SciMLBase.DEStats(0),
1258 dense_errors = dense_errors)
1259
1260 opts = DEOptions(saveat_internal,
1261 tstops_internal,
1262 saveat, tstops, save_start,
1263 save_everystep, save_idxs,
1264 dense,
1265 timeseries_errors,
1266 dense_errors,
1267 save_on,
1268 save_end,
1269 callbacks_internal,
1270 abstol,
1271 reltol,
1272 verbose,
1273 advance_to_tstop,
1274 stop_at_next_tstop,
1275 progress,
1276 progress_steps,
1277 progress_name,
1278 progress_message,
1279 progress_id,
1280 maxiters)
1281
1282 integrator = IDAIntegrator(u0,
1283 du0,
1284 prob.p,
1285 t0,
1286 t0,
1287 mem,
1288 _LS,
1289 _A,
1290 sol,
1291 alg,
1292 f!,
1293 userfun,
1294 jac,
1295 opts,
1296 tout,
1297 tdir,
1298 false,
1299 tmp,
1300 uprev,
1301 Cint(flag),
1302 0,
1303 false,
1304 0,
1305 1,
1306 callback_cache,
1307 0.0,
1308 utmp,
1309 dutmp,
1310 initializealg)
1311
1312 DiffEqBase.initialize_dae!(integrator, initializealg)
1313 integrator.u_modified && IDAReinit!(integrator)
1314
1315 if save_start
1316 save_value!(ures, integrator.u, uType, save_idxs)
1317 save_value!(dures, integrator.du, duType, save_idxs)
1318 end
1319
1320 initialize_callbacks!(integrator)
1321 integrator
1322 end # function solve
1323
1324 ## Common calls
1325
"""
    interpret_sundials_retcode(flag)

Translate an integer return flag from a Sundials solver into a SciML `ReturnCode`.

- `flag >= 0`       → `ReturnCode.Success`
- `flag == -1`      → `ReturnCode.MaxIters` (`solve!` also stores -1 when its own
                      `maxiters` budget is exhausted)
- `flag in (-2, -3)` → `ReturnCode.Unstable`
- `flag == -4`      → `ReturnCode.ConvergenceFailure`
- anything else     → `ReturnCode.Failure`
"""
function interpret_sundials_retcode(flag)
    if flag >= 0
        return ReturnCode.Success
    elseif flag == -1
        return ReturnCode.MaxIters
    elseif flag in (-2, -3)
        return ReturnCode.Unstable
    elseif flag == -4
        return ReturnCode.ConvergenceFailure
    else
        return ReturnCode.Failure
    end
end
1333
"""
    solver_step(integrator::CVODEIntegrator, tstop)

Take a single internal CVODE step (`CV_ONE_STEP`) toward `tstop` and store the
Sundials return flag in `integrator.flag`. The time actually reached is written by
CVODE into `integrator.tout`. Callers such as `solve!` copy it into `integrator.t`
only after this returns.

When `integrator.opts.progress` is set, a progress log record is emitted. The
reported progress is the fraction of `tspan` covered so far, and the message uses
the time just reached.
"""
function solver_step(integrator::CVODEIntegrator, tstop)
    integrator.flag = CVode(integrator.mem, tstop, integrator.u_nvec, integrator.tout,
        CV_ONE_STEP)
    if integrator.opts.progress
        # `integrator.t` still holds the previous time here, so use the time CVODE
        # just reported. Measure progress from tspan[1] so that intervals not
        # starting at zero (or running backward) report a fraction in [0, 1].
        t0, tf = integrator.sol.prob.tspan
        tnew = first(integrator.tout)
        Logging.@logmsg(Logging.LogLevel(-1),
            integrator.opts.progress_name,
            _id=integrator.opts.progress_id,
            message=integrator.opts.progress_message(integrator.dt,
                integrator.u,
                integrator.p,
                tnew),
            progress=(tnew - t0) / (tf - t0))
    end
end
"""
    solver_step(integrator::ARKODEIntegrator, tstop)

Take a single internal ARKStep step (`ARK_ONE_STEP`) toward `tstop` and store the
Sundials return flag in `integrator.flag`. The time actually reached is written into
`integrator.tout`. Callers such as `solve!` copy it into `integrator.t` only after
this returns.

When `integrator.opts.progress` is set, a progress log record is emitted. The
reported progress is the fraction of `tspan` covered so far, and the message
receives the Julia state `integrator.u` together with the time just reached.
"""
function solver_step(integrator::ARKODEIntegrator, tstop)
    integrator.flag = ARKStepEvolve(integrator.mem, tstop, integrator.u_nvec,
        integrator.tout, ARK_ONE_STEP)
    if integrator.opts.progress
        # `integrator.t` is not advanced yet, so use the time ARKStep reported.
        # Measure progress from tspan[1] so non-zero starts and backward
        # integration stay within [0, 1].
        t0, tf = integrator.sol.prob.tspan
        tnew = first(integrator.tout)
        Logging.@logmsg(Logging.LogLevel(-1),
            integrator.opts.progress_name,
            _id=integrator.opts.progress_id,
            message=integrator.opts.progress_message(integrator.dt,
                integrator.u,
                integrator.p,
                tnew),
            progress=(tnew - t0) / (tf - t0))
    end
end
1362
# Profile (StatProfilerHTML): 58 (100 %) samples spent in solver_step;
# 58 (100 %) (incl.) when called from step! line 247.
"""
    solver_step(integrator::IDAIntegrator, tstop)

Take a single internal IDA step (`IDA_ONE_STEP`) toward `tstop` and store the
Sundials return flag in `integrator.flag`. The reached time is written into
`integrator.tout`, while `u`/`du` are updated through `u_nvec`/`du_nvec`. Callers
such as `solve!` copy the reached time into `integrator.t` only after this returns.

`integrator.iter` counts calls to this method. When progress logging is enabled, a
record is emitted every `opts.progress_steps` steps. It reports the fraction of
`tspan` covered and the time just reached.
"""
function solver_step(integrator::IDAIntegrator, tstop)
    # Profile: 58 (100 %) samples spent calling IDASolve. All sampled time in this
    # method falls inside that call, so the wrapper's own bookkeeping is not a
    # measurable cost.
    integrator.flag = IDASolve(integrator.mem,
        tstop,
        integrator.tout,
        integrator.u_nvec,
        integrator.du_nvec,
        IDA_ONE_STEP)
    integrator.iter += 1
    if integrator.opts.progress && integrator.iter % integrator.opts.progress_steps == 0
        # `integrator.t` is not advanced yet, so use the time IDA reported.
        # Measure progress from tspan[1] so non-zero starts and backward
        # integration stay within [0, 1].
        t0, tf = integrator.sol.prob.tspan
        tnew = first(integrator.tout)
        Logging.@logmsg(Logging.LogLevel(-1),
            integrator.opts.progress_name,
            _id=integrator.opts.progress_id,
            message=integrator.opts.progress_message(integrator.dt,
                integrator.u,
                integrator.p,
                tnew),
            progress=(tnew - t0) / (tf - t0))
    end
end
1381
# `set_stop_time(integrator, tstop)` forwards `tstop` to the solver-specific
# "set stop time" routine for the integrator's Sundials memory. Each method returns
# whatever that routine returns.
set_stop_time(integrator::CVODEIntegrator, tstop) = CVodeSetStopTime(integrator.mem, tstop)
set_stop_time(integrator::ARKODEIntegrator, tstop) = ARKStepSetStopTime(integrator.mem, tstop)
set_stop_time(integrator::IDAIntegrator, tstop) = IDASetStopTime(integrator.mem, tstop)
1391
# `get_iters!(integrator, iters)` asks the underlying Sundials solver for its step
# count and writes it into `iters`, an output reference such as the
# `Ref(Clong(-1))` allocated in `solve!`. Each method returns the solver routine's
# result.
get_iters!(integrator::CVODEIntegrator, iters) = CVodeGetNumSteps(integrator.mem, iters)
get_iters!(integrator::ARKODEIntegrator, iters) = ARKStepGetNumSteps(integrator.mem, iters)
get_iters!(integrator::IDAIntegrator, iters) = IDAGetNumSteps(integrator.mem, iters)
1401
"""
    DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free = false,
                      calculate_error = false)

Drive `integrator` until its stop-time queue `integrator.opts.tstops` is exhausted,
or until a negative Sundials flag appears. Then finalize callbacks, save the end
point if requested, collect statistics, and return `integrator.sol`.

# Keywords
- `early_free`: release the integrator's Sundials memory, matrix and linear-solver
  handles once the run finishes (`__solve` passes `true`).
- `calculate_error`: compute solution errors when the problem function has an
  analytic solution.

If the solution's retcode is still `ReturnCode.Default`, it is derived from the final
`integrator.flag` via `interpret_sundials_retcode`.
"""
function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free = false,
        calculate_error = false)
    state_eltype = eltype(integrator.sol.u)
    nsteps = Ref(Clong(-1))

    while !isempty(integrator.opts.tstops)
        # `first` is the DataStructures overload for the tstops heap. Entries are
        # compared against `tdir * t` and mapped back to real time with `tdir *`.
        while integrator.tdir * integrator.t < first(integrator.opts.tstops)
            target = integrator.tdir * first(integrator.opts.tstops)
            set_stop_time(integrator, target)
            integrator.tprev = integrator.t
            if !(integrator.opts.callback.continuous_callbacks isa Tuple{})
                # uprev is only maintained when continuous callbacks are present
                integrator.uprev .= integrator.u
            end
            integrator.userfun.p = integrator.p
            solver_step(integrator, target)
            integrator.t = first(integrator.tout)
            # CVODE/ARKODE may warn and then recover when t == tprev, so only a
            # negative flag is treated as an error.
            if integrator.flag < 0
                break
            end
            handle_callbacks!(integrator)  # also updates the interpolation
            if integrator.flag < 0 || isempty(integrator.opts.tstops)
                break
            end
            get_iters!(integrator, nsteps)
            if nsteps[] + 1 > integrator.opts.maxiters
                integrator.flag = -1  # MaxIters: -1
                break
            end
        end
        if integrator.flag < 0
            break
        end
        handle_tstop!(integrator)
    end

    DiffEqBase.finalize!(integrator.opts.callback, integrator.u, integrator.t, integrator)

    t_final = integrator.t
    save_final = integrator.opts.save_end &&
                 (isempty(integrator.sol.t) || integrator.sol.t[end] != t_final)
    if save_final
        save_value!(integrator.sol.u, integrator.u, state_eltype,
            integrator.opts.save_idxs)
        push!(integrator.sol.t, t_final)
        if integrator.opts.dense
            save_value!(integrator.sol.interp.du, get_du(integrator), state_eltype,
                integrator.opts.save_idxs)
        end
    end

    if integrator.opts.progress
        Logging.@logmsg(Logging.LogLevel(-1),
            integrator.opts.progress_name,
            _id=integrator.opts.progress_id,
            message=integrator.opts.progress_message(integrator.dt,
                integrator.u,
                integrator.p,
                integrator.t),
            progress="done")
    end

    fill_stats!(integrator)

    if early_free
        # Release the Sundials handles owned by the integrator.
        empty!(integrator.mem)
        isnothing(integrator.A) || empty!(integrator.A)
        isnothing(integrator.LS) || empty!(integrator.LS)
    end

    if DiffEqBase.has_analytic(integrator.sol.prob.f) && calculate_error
        DiffEqBase.calculate_solution_errors!(integrator.sol;
            timeseries_errors = integrator.opts.timeseries_errors,
            dense_errors = integrator.opts.dense_errors)
    end

    if integrator.sol.retcode == ReturnCode.Default
        final_code = interpret_sundials_retcode(integrator.flag)
        integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, final_code)
    end

    return integrator.sol
end
1480
"""
    handle_tstop!(integrator::AbstractSundialsIntegrator)

Pop every stop time the integrator has reached or passed (comparing `tdir * t`
against the front of `integrator.opts.tstops`). Several stops can be passed at once
when Sundials overshoots or the queue holds redundant entries. If at least one stop
was popped, `integrator.just_hit_tstop` is set to `true`.
"""
function handle_tstop!(integrator::AbstractSundialsIntegrator)
    queue = integrator.opts.tstops
    popped_any = false
    while !isempty(queue) && integrator.tdir * integrator.t >= first(queue)
        pop!(queue)
        popped_any = true
    end
    if popped_any
        integrator.just_hit_tstop = true
    end
end
1492
# Fallback: integrators without a specialized method leave their stats untouched.
fill_stats!(integrator::AbstractSundialsIntegrator) = nothing
1494
"""
    fill_stats!(integrator::CVODEIntegrator)

Copy CVODE's counters into `integrator.sol.stats`: RHS evaluations, linear-solver
setups, error-test failures (as rejections), steps minus rejections (as
acceptances), nonlinear iterations and convergence failures. The Jacobian
evaluation count is copied as well when the algorithm uses `:Newton` iteration.
"""
function fill_stats!(integrator::CVODEIntegrator)
    mem = integrator.mem
    stats = integrator.sol.stats
    scratch = Ref(Clong(-1))
    # Call a CVODE getter that writes into `scratch`, then read the value back.
    query(getter) = (getter(mem, scratch); scratch[])

    stats.nf = query(CVodeGetNumRhsEvals)
    stats.nw = query(CVodeGetNumLinSolvSetups)
    stats.nreject = query(CVodeGetNumErrTestFails)
    stats.naccept = query(CVodeGetNumSteps) - stats.nreject
    stats.nnonliniter = query(CVodeGetNumNonlinSolvIters)
    stats.nnonlinconvfail = query(CVodeGetNumNonlinSolvConvFails)
    if method_choice(integrator.alg) == :Newton
        stats.njacs = query(CVodeGetNumJacEvals)
    end
end
1516
"""
    fill_stats!(integrator::ARKODEIntegrator)

Copy ARKStep's counters into `integrator.sol.stats`: explicit/implicit RHS
evaluations (`nf`/`nf2`), error-test failures (as rejections) and steps minus
rejections (as acceptances).

The linear/nonlinear solver counters (`nw`, `nnonliniter`, `nnonlinconvfail`, and
`njacs` for `:Newton`) are only queried and written for non-explicit methods. For
explicit methods those fields are left untouched.
"""
function fill_stats!(integrator::ARKODEIntegrator)
    stats = integrator.sol.stats
    mem = integrator.mem
    has_implicit_part = integrator.alg.stiffness !== Explicit()
    tmp = Ref(Clong(-1))
    tmp2 = Ref(Clong(-1))

    ARKStepGetNumRhsEvals(mem, tmp, tmp2)
    stats.nf = tmp[]
    stats.nf2 = tmp2[]
    ARKStepGetNumErrTestFails(mem, tmp)
    stats.nreject = tmp[]
    ARKStepGetNumSteps(mem, tmp)
    stats.naccept = tmp[] - stats.nreject

    if has_implicit_part
        # These getters are only called for non-explicit methods. Assigning `tmp[]`
        # without calling them would copy whatever the scratch Ref last held (the
        # RHS-eval or step count) into unrelated fields.
        ARKStepGetNumLinSolvSetups(mem, tmp)
        stats.nw = tmp[]
        ARKStepGetNumNonlinSolvIters(mem, tmp)
        stats.nnonliniter = tmp[]
        ARKStepGetNumNonlinSolvConvFails(mem, tmp)
        stats.nnonlinconvfail = tmp[]
        if method_choice(integrator.alg) == :Newton
            ARKStepGetNumJacEvals(mem, tmp)
            stats.njacs = tmp[]
        end
    end
    return nothing
end
1540
"""
    fill_stats!(integrator::IDAIntegrator)

Copy IDA's counters into `integrator.sol.stats`: residual evaluations (as `nf`),
linear-solver setups, error-test failures (as rejections), steps minus rejections
(as acceptances), nonlinear iterations and convergence failures. The Jacobian
evaluation count is copied as well when `method_choice` reports `:Newton`.
"""
function fill_stats!(integrator::IDAIntegrator)
    ida_mem = integrator.mem
    out = integrator.sol.stats
    cell = Ref(Clong(-1))
    # Run one IDA counter getter into `cell` and return the value it wrote.
    readcount(getter) = (getter(ida_mem, cell); cell[])

    out.nf = readcount(IDAGetNumResEvals)
    out.nw = readcount(IDAGetNumLinSolvSetups)
    out.nreject = readcount(IDAGetNumErrTestFails)
    out.naccept = readcount(IDAGetNumSteps) - out.nreject
    out.nnonliniter = readcount(IDAGetNumNonlinSolvIters)
    out.nnonlinconvfail = readcount(IDAGetNumNonlinSolvConvFails)
    if method_choice(integrator.alg) == :Newton
        out.njacs = readcount(IDAGetNumJacEvals)
    end
end
1562
"""
    initialize_callbacks!(integrator, initialize_save = true)

Run the `initialize!` hook of every callback in `integrator.opts.callback`.
`integrator.u_modified` is `true` while the hooks run. If the hooks report that
`u` changed, `handle_callback_modifiers!` is applied. When `initialize_save` is
true and any discrete or continuous callback has `save_positions[2]` set, the
modified state is also saved. `u_modified` is reset to `false` afterwards.
"""
function initialize_callbacks!(integrator, initialize_save = true)
    t = integrator.t
    u = integrator.u
    cbset = integrator.opts.callback
    integrator.u_modified = true

    changed = initialize!(cbset, u, t, integrator)

    if changed
        # A callback altered u during initialization: fix up dependent values.
        handle_callback_modifiers!(integrator)
        saves_after = c -> c.save_positions[2]
        wants_save = any(saves_after, cbset.discrete_callbacks) ||
                     any(saves_after, cbset.continuous_callbacks)
        if initialize_save && wants_save
            savevalues!(integrator, true)
        end
    end

    # Initialization is handled; stepping proceeds normally from here.
    integrator.u_modified = false
end