Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BVProblem with constraints #3323

Merged
merged 60 commits into from
Feb 22, 2025
Merged
Changes from 56 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
9733460
init
vyudu Nov 22, 2024
d95e4a7
Merge remote-tracking branch 'origin/master' into MTK
vyudu Nov 22, 2024
b3da813
up
vyudu Dec 1, 2024
86c82ce
Merge remote-tracking branch 'origin/master' into MTK
vyudu Dec 1, 2024
4affeac
up
vyudu Dec 1, 2024
a3429ea
up
vyudu Dec 1, 2024
f751fbb
up
vyudu Dec 2, 2024
a9fdfd6
up
vyudu Dec 3, 2024
a9f2106
up
vyudu Dec 4, 2024
18fdd5f
up
vyudu Dec 13, 2024
9d65a33
fixing create_array
vyudu Dec 16, 2024
999ec30
revert Project.toml
vyudu Dec 16, 2024
9226ad6
Up
vyudu Dec 16, 2024
0cb4893
Merge remote-tracking branch 'origin/master' into MTK
vyudu Dec 16, 2024
67d8164
formatting
vyudu Dec 16, 2024
25988f3
up
vyudu Dec 17, 2024
bb28d4f
up
vyudu Dec 17, 2024
b2bf7c0
fix
vyudu Dec 17, 2024
3751c2a
up
vyudu Dec 20, 2024
ef1f089
up
vyudu Jan 8, 2025
d23d6f7
Merge remote-tracking branch 'origin/master' into MTK
vyudu Jan 8, 2025
2a25200
extend BVProblem for constraint equations
vyudu Jan 9, 2025
50504ab
adding tests
vyudu Jan 11, 2025
5d082ab
up
vyudu Jan 11, 2025
b83e003
refactor the bc creation function
vyudu Jan 14, 2025
db5eb66
up
vyudu Jan 14, 2025
e802946
test update
vyudu Jan 15, 2025
e74e047
fix
vyudu Jan 15, 2025
86d4144
test more solvers:
vyudu Jan 17, 2025
ec386fe
Refactor constraints
vyudu Jan 28, 2025
90ce80d
refactor tests
vyudu Jan 28, 2025
a15c670
fix sym validation
vyudu Jan 28, 2025
c6ef04a
remove file
vyudu Jan 28, 2025
7878225
up
vyudu Jan 28, 2025
5bcfdff
up
vyudu Jan 28, 2025
0493b5d
remove lines
vyudu Jan 28, 2025
1d32b6e
up
vyudu Jan 28, 2025
2b3ca96
up
vyudu Jan 28, 2025
0324522
fix typo
vyudu Jan 28, 2025
2a079be
Fix setter
vyudu Jan 28, 2025
d70a470
fix
vyudu Jan 28, 2025
37092f1
lower tol
vyudu Jan 29, 2025
e5eb8bd
fix Project.toml
vyudu Jan 29, 2025
2ae79ae
revert to OrdinaryDiffEq
vyudu Jan 30, 2025
8ae2803
merge master
vyudu Feb 3, 2025
13a242c
update to use updated codegen
vyudu Feb 3, 2025
2fcb9c9
up
vyudu Feb 3, 2025
25b56d7
working codegen
vyudu Feb 4, 2025
c35b797
revert to OrdinaryDiffEqDefault
vyudu Feb 4, 2025
25e84db
use MIRK
vyudu Feb 4, 2025
e6a6932
up
vyudu Feb 4, 2025
5e5c24c
revert to OrdinaryDiffEq
vyudu Feb 4, 2025
5338d4f
tests passing
vyudu Feb 4, 2025
810d4fa
remove problematic tests, codegen assumes MTKParameters
vyudu Feb 4, 2025
6740b8c
test fix
vyudu Feb 4, 2025
603c894
Update src/systems/diffeqs/odesystem.jl
ChrisRackauckas Feb 10, 2025
b10a4a6
Merge branch 'master' into BVP-with-constraints
ChrisRackauckas Feb 10, 2025
9b492cd
Merge remote-tracking branch 'vyudu/BVP-with-constraints' into BVP-wi…
vyudu Feb 11, 2025
dce19c7
Merge remote-tracking branch 'origin' into BVP-with-constraints
vyudu Feb 17, 2025
3642e1b
Merge branch 'master' into BVP-with-constraints
vyudu Feb 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -83,6 +83,8 @@ AbstractTrees = "0.3, 0.4"
ArrayInterface = "6, 7"
BifurcationKit = "0.4"
BlockArrays = "1.1"
BoundaryValueDiffEqAscher = "1.1.0"
BoundaryValueDiffEqMIRK = "1.4.0"
ChainRulesCore = "1"
Combinatorics = "1"
CommonSolve = "0.2.4"
@@ -104,11 +106,11 @@ DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
EnumX = "1.0.4"
ExprTools = "0.1.10"
Expronicon = "0.8"
FMI = "0.14"
FindFirstFunctions = "1"
ForwardDiff = "0.10.3"
FunctionWrappers = "1.1"
FunctionWrappersWrappers = "0.1"
FMI = "0.14"
Graphs = "1.5.2"
HomotopyContinuation = "2.11"
InfiniteOpt = "0.5"
@@ -143,7 +145,6 @@ SimpleNonlinearSolve = "0.1.0, 1, 2"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
StochasticDiffEq = "6.72.1"
StochasticDelayDiffEq = "1.8.1"
SymbolicIndexingInterface = "0.3.37"
SymbolicUtils = "3.10.1"
@@ -156,6 +157,8 @@ julia = "1.9"
[extras]
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BoundaryValueDiffEqMIRK = "1a22d4ce-7765-49ea-b6f2-13c8438986a6"
BoundaryValueDiffEqAscher = "7227322d-7511-4e07-9247-ad6ff830280e"
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
@@ -187,4 +190,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
8 changes: 4 additions & 4 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
@@ -153,6 +153,10 @@ include("systems/callbacks.jl")
include("systems/codegen_utils.jl")
include("systems/problem_utils.jl")

include("systems/optimization/constraints_system.jl")
include("systems/optimization/optimizationsystem.jl")
include("systems/optimization/modelingtoolkitize.jl")

include("systems/nonlinear/nonlinearsystem.jl")
include("systems/nonlinear/homotopy_continuation.jl")
include("systems/diffeqs/odesystem.jl")
@@ -168,10 +172,6 @@ include("systems/discrete_system/discrete_system.jl")

include("systems/jumps/jumpsystem.jl")

include("systems/optimization/constraints_system.jl")
include("systems/optimization/optimizationsystem.jl")
include("systems/optimization/modelingtoolkitize.jl")

include("systems/pde/pdesystem.jl")

include("systems/sparsematrixclil.jl")
1 change: 1 addition & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
@@ -964,6 +964,7 @@ for prop in [:eqs
:structure
:op
:constraints
:constraintsystem
:controls
:loss
:bcs
164 changes: 164 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
@@ -748,6 +748,12 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
end

if !isnothing(get_constraintsystem(sys))
error("An ODESystem with constraints cannot be used to construct a regular ODEProblem.
Consider a BVProblem instead.")
end

f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
@@ -770,6 +776,164 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

"""
```julia
SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
parammap = DiffEqBase.NullParameters();
constraints = nothing, guesses = nothing,
version = nothing, tgrad = false,
jac = true, sparse = true,
simplify = false,
kwargs...) where {iip}
```
Create a boundary value problem from the [`ODESystem`](@ref).
`u0map` is used to specify fixed initial values for the states. Every variable
must have either an initial guess supplied using `guesses` or a fixed initial
value specified using `u0map`.
Boundary value conditions are supplied to ODESystems
in the form of a ConstraintsSystem. These equations
should specify values that state variables should
take at specific points, as in `x(0.5) ~ 1`). More general constraints that
should hold over the entire solution, such as `x(t)^2 + y(t)^2`, should be
specified as one of the equations used to build the `ODESystem`.
If an ODESystem without `constraints` is specified, it will be treated as an initial value problem.
```julia
@parameters g t_c = 0.5
@variables x(..) y(t) [state_priority = 10] λ(t)
eqs = [D(D(x(t))) ~ λ * x(t)
D(D(y)) ~ λ * y - g
x(t)^2 + y^2 ~ 1]
cstr = [x(0.5) ~ 1]
@named cstrs = ConstraintsSystem(cstr, t)
@mtkbuild pend = ODESystem(eqs, t)
tspan = (0.0, 1.5)
u0map = [x(t) => 0.6, y => 0.8]
parammap = [g => 1]
guesses = [λ => 1]
constraints = [x(0.5) ~ 1]
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
```
If the `ODESystem` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting
`BVProblem` must be solved using BVDAE solvers, such as Ascher.
"""
function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem(sys::AbstractODESystem,
u0map::StaticArray,
args...;
kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
end

function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
guesses = Dict(),
version = nothing, tgrad = false,
callback = nothing,
check_length = true,
warn_initialize_determined = true,
eval_expression = false,
eval_module = @__MODULE__,
kwargs...) where {iip, specialize}

if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
end
!isnothing(callback) && error("BVP solvers do not support callbacks.")

has_alg_eqs(sys) && error("The BVProblem constructor currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.

sts = unknowns(sys)
ps = parameters(sys)
constraintsys = get_constraintsystem(sys)

if !isnothing(constraintsys)
(length(constraints(constraintsys)) + length(u0map) > length(sts)) &&
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
end

# ODESystems without algebraic equations should use both fixed values + guesses
# for initialization.
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan, guesses,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)

stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]

fns = generate_function_bc(sys, u0, u0_idxs, tspan)
bc_oop, bc_iip = eval_or_rgf.(fns; eval_expression, eval_module)
bc(sol, p, t) = bc_oop(sol, p, t)
bc(resid, u, p, t) = bc_iip(resid, u, p, t)

return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
end

get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")

"""
generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan)
Given an ODESystem with constraints, generate the boundary condition function to pass to boundary value problem solvers.
Expression uses the constraints and the provided initial conditions.
"""
function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan; kwargs...)
iv = get_iv(sys)
sts = unknowns(sys)
ps = parameters(sys)
np = length(ps)
ns = length(sts)
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])

@variables sol(..)[1:ns]

conssys = get_constraintsystem(sys)
cons = Any[]
if !isnothing(conssys)
cons = [con.lhs - con.rhs for con in constraints(conssys)]

for st in get_unknowns(conssys)
x = operation(st)
t = only(arguments(st))
idx = stidxmap[x(iv)]

cons = map(c -> Symbolics.substitute(c, Dict(x(t) => sol(t)[idx])), cons)
end
end

init_conds = Any[]
for i in u0_idxs
expr = sol(tspan[1])[i] - u0[i]
push!(init_conds, expr)
end

exprs = vcat(init_conds, cons)
_p = reorder_parameters(sys, ps)

build_function_wrapper(sys, exprs, sol, _p..., t; output_type = Array, kwargs...)
end

"""
```julia
DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
76 changes: 70 additions & 6 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
@@ -49,6 +49,8 @@ struct ODESystem <: AbstractODESystem
ctrls::Vector
"""Observed variables."""
observed::Vector{Equation}
"""System of constraints that must be satisfied by the solution to the system."""
constraintsystem::Union{Nothing, ConstraintsSystem}
"""
Time-derivative matrix. Note: this field will not be defined until
[`calculate_tgrad`](@ref) is called on the system.
@@ -186,7 +188,7 @@ struct ODESystem <: AbstractODESystem
"""
parent::Any

function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
torn_matching, initializesystem, initialization_eqs, schedule,
connector_type, preface, cevents,
@@ -207,7 +209,7 @@ struct ODESystem <: AbstractODESystem
u = __get_unit_type(dvs, ps, iv)
check_units(u, deqs)
end
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
initializesystem, initialization_eqs, schedule, connector_type, preface,
cevents, devents, parameter_dependencies, metadata,
@@ -219,6 +221,7 @@ end
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
controls = Num[],
observed = Equation[],
constraintsystem = nothing,
systems = ODESystem[],
tspan = nothing,
name = nothing,
@@ -286,16 +289,27 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
if is_dde === nothing
is_dde = _check_if_dde(deqs, iv′, systems)
end

if !isempty(systems) && !isnothing(constraintsystem)
conssystems = ConstraintsSystem[]
for sys in systems
cons = get_constraintsystem(sys)
cons !== nothing && push!(conssystems, cons)
end
@show conssystems
@set! constraintsystem.systems = conssystems
end

ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, description, systems,
defaults, guesses, nothing, initializesystem,
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
disc_callbacks, parameter_dependencies,
metadata, gui_metadata, is_dde, tstops, checks = checks)
end

function ODESystem(eqs, iv; kwargs...)
function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)

for eq in get(kwargs, :parameter_dependencies, Equation[])
@@ -327,8 +341,22 @@ function ODESystem(eqs, iv; kwargs...)
end
algevars = setdiff(allunknowns, diffvars)

return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))),
collect(new_ps); kwargs...)
consvars = OrderedSet()
constraintsystem = nothing
if !isempty(constraints)
constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
for st in get_unknowns(constraintsystem)
iscall(st) ?
!in(operation(st)(iv), allunknowns) && push!(consvars, st) :
!in(st, allunknowns) && push!(consvars, st)
end
for p in parameters(constraintsystem)
!in(p, new_ps) && push!(new_ps, p)
end
end

return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
collect(new_ps); constraintsystem, kwargs...)
end

# NOTE: equality does not check cached Jacobian
@@ -649,3 +677,39 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,

return nothing
end

# Validate that all the variables in the BVP constraints are well-formed states or parameters.
# - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
# - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
function process_constraint_system(constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
isempty(constraints) && return nothing

constraintsts = OrderedSet()
constraintps = OrderedSet()

for cons in constraints
collect_vars!(constraintsts, constraintps, cons, iv)
end

# Validate the states.
for var in constraintsts
if !iscall(var)
occursin(iv, var) && (var sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
elseif length(arguments(var)) > 1
throw(ArgumentError("Too many arguments for variable $var."))
elseif length(arguments(var)) == 1
arg = only(arguments(var))
operation(var)(iv) sts ||
throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))

isequal(arg, iv) || isparameter(arg) || arg isa Integer || arg isa AbstractFloat ||
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))

isparameter(arg) && push!(constraintps, arg)
else
var sts && @warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
end
end

ConstraintsSystem(constraints, collect(constraintsts), collect(constraintps); name = consname)
end
2 changes: 1 addition & 1 deletion src/systems/optimization/constraints_system.jl
Original file line number Diff line number Diff line change
@@ -123,7 +123,7 @@ function ConstraintsSystem(constraints, unknowns, ps;
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))

cstr = value.(Symbolics.canonical_form.(scalarize(constraints)))
cstr = value.(Symbolics.canonical_form.(vcat(scalarize(constraints)...)))
unknowns′ = value.(scalarize(unknowns))
ps′ = value.(ps)

265 changes: 265 additions & 0 deletions test/bvproblem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
### TODO: update when BoundaryValueDiffEqAscher is updated to use the normal boundary condition conventions

using OrdinaryDiffEq
using BoundaryValueDiffEqMIRK, BoundaryValueDiffEqAscher
using BenchmarkTools
using ModelingToolkit
using SciMLBase
using ModelingToolkit: t_nounits as t, D_nounits as D

### Test Collocation solvers on simple problems
solvers = [MIRK4]
daesolvers = [Ascher2, Ascher4, Ascher6]

let
@parameters α=7.5 β=4.0 γ=8.0 δ=5.0
@variables x(t)=1.0 y(t)=2.0

eqs = [D(x) ~ α * x - β * x * y,
D(y) ~ -γ * y + δ * x * y]

u0map = [x => 1.0, y => 2.0]
parammap ==> 7.5, β => 4, γ => 8.0, δ => 5.0]
tspan = (0.0, 10.0)

@mtkbuild lotkavolterra = ODESystem(eqs, t)
op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
osol = solve(op, Vern9())

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap)

for solver in solvers
sol = solve(bvp, solver(), dt = 0.01)
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test sol.u[1] == [1.0, 2.0]
end

# Test out of place
bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap)

for solver in solvers
sol = solve(bvp2, solver(), dt = 0.01)
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test sol.u[1] == [1.0, 2.0]
end
end

### Testing on pendulum
let
@parameters g=9.81 L=1.0
@variables θ(t) = π / 2 θ_t(t)

eqs = [D(θ) ~ θ_t
D(θ_t) ~ -(g / L) * sin(θ)]

@mtkbuild pend = ODESystem(eqs, t)

u0map ==> π / 2, θ_t => π / 2]
parammap = [:L => 1.0, :g => 9.81]
tspan = (0.0, 6.0)

op = ODEProblem(pend, u0map, tspan, parammap)
osol = solve(op, Vern9())

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap)
for solver in solvers
sol = solve(bvp, solver(), dt = 0.01)
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test sol.u[1] ==/ 2, π / 2]
end

# Test out-of-place
bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)

for solver in solvers
sol = solve(bvp2, solver(), dt = 0.01)
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test sol.u[1] ==/ 2, π / 2]
end
end

##################################################################
### ODESystem with constraint equations, DAEs with constraints ###
##################################################################

# Test generation of boundary condition function using `generate_function_bc`. Compare solutions to manually written boundary conditions
let
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
@variables x(..) y(..)
eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]

tspan = (0., 1.)
@mtkbuild lksys = ODESystem(eqs, t)

function lotkavolterra!(du, u, p, t)
du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
du[2] = -p[4]*u[2] + p[3]*u[1]*u[2]
end

function lotkavolterra(u, p, t)
[p[1]*u[1] - p[2]*u[1]*u[2], -p[4]*u[2] + p[3]*u[1]*u[2]]
end

# Test with a constraint.
constr = [y(0.5) ~ 2.]
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)

function bc!(resid, u, p, t)
resid[1] = u(0.0)[1] - 1.
resid[2] = u(0.5)[2] - 2.
end
function bc(u, p, t)
[u(0.0)[1] - 1., u(0.5)[2] - 2.]
end

u0 = [1., 1.]
tspan = (0., 1.)
p = [1.5, 1., 1., 3.]
bvpi1 = SciMLBase.BVProblem(lotkavolterra!, bc!, u0, tspan, p)
bvpi2 = SciMLBase.BVProblem(lotkavolterra, bc, u0, tspan, p)
bvpi3 = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])
bvpi4 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, [x(t) => 1.], tspan; guesses = [y(t) => 1.])

sol1 = @btime solve($bvpi1, MIRK4(), dt = 0.01)
sol2 = @btime solve($bvpi2, MIRK4(), dt = 0.01)
sol3 = @btime solve($bvpi3, MIRK4(), dt = 0.01)
sol4 = @btime solve($bvpi4, MIRK4(), dt = 0.01)
@test sol1 sol2 sol3 sol4 # don't get true equality here, not sure why
end

function test_solvers(solvers, prob, u0map, constraints, equations = []; dt = 0.05, atol = 1e-2)
for solver in solvers
println("Solver: $solver")
sol = @btime solve($prob, $solver(), dt = $dt, abstol = $atol)
@test SciMLBase.successful_retcode(sol.retcode)
p = prob.p; t = sol.t; bc = prob.f.bc
ns = length(prob.u0)
if isinplace(prob.f)
resid = zeros(ns)
bc(resid, sol, p, t)
@test isapprox(zeros(ns), resid; atol)
@show resid
else
@test isapprox(zeros(ns), bc(sol, p, t); atol)
@show bc(sol, p, t)
end

for (k, v) in u0map
@test sol[k][1] == v
end

# for cons in constraints
# @test sol[cons.rhs - cons.lhs] ≈ 0
# end

for eq in equations
@test sol[eq] 0
end
end
end

# Simple ODESystem with BVP constraints.
let
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
@variables x(..) y(..)

eqs = [D(x(t)) ~ α * x(t) - β * x(t) * y(t),
D(y(t)) ~ -γ * y(t) + δ * x(t) * y(t)]

u0map = []
tspan = (0.0, 1.0)
guess = [x(t) => 4.0, y(t) => 2.0]
constr = [x(.6) ~ 3.5, x(.3) ~ 7.]
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses = guess)
test_solvers(solvers, bvp, u0map, constr; dt = 0.05)

# Testing that more complicated constraints give correct solutions.
constr = [y(.2) + x(.8) ~ 3., y(.3) ~ 2.]
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
bvp = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(lksys, u0map, tspan; guesses = guess)
test_solvers(solvers, bvp, u0map, constr; dt = 0.05)

constr =* β - x(.6) ~ 0.0, y(.2) ~ 3.]
@mtkbuild lksys = ODESystem(eqs, t; constraints = constr)
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lksys, u0map, tspan; guesses = guess)
test_solvers(solvers, bvp, u0map, constr)
end

# Cartesian pendulum from the docs.
# DAE IVP solved using BoundaryValueDiffEq solvers.
# let
# @parameters g
Comment on lines +191 to +194
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's up with these ones?

Copy link
Member Author

@vyudu vyudu Jan 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The interface for specifying boundary conditions is different for the BVDAE solvers at the moment (you specify the time points as a separate argument), was gonna wait until they get updated to be the same.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On it, Ascher methods need lots of refactorizations

# @variables x(t) y(t) [state_priority = 10] λ(t)
# eqs = [D(D(x)) ~ λ * x
# D(D(y)) ~ λ * y - g
# x^2 + y^2 ~ 1]
# @mtkbuild pend = ODESystem(eqs, t)
#
# tspan = (0.0, 1.5)
# u0map = [x => 1, y => 0]
# pmap = [g => 1]
# guess = [λ => 1]
#
# prob = ODEProblem(pend, u0map, tspan, pmap; guesses = guess)
# osol = solve(prob, Rodas5P())
#
# zeta = [0., 0., 0., 0., 0.]
# bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses = guess)
#
# for solver in solvers
# sol = solve(bvp, solver(zeta), dt = 0.001)
# @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
# conditions = getfield.(equations(pend)[3:end], :rhs)
# @test isapprox([sol[conditions][1]; sol[x][1] - 1; sol[y][1]], zeros(5), atol = 0.001)
# end
#
# bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
# for solver in solvers
# sol = solve(bvp, solver(zeta), dt = 0.01)
# @test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
# conditions = getfield.(equations(pend)[3:end], :rhs)
# @test [sol[conditions][1]; sol[x][1] - 1; sol[y][1]] ≈ 0
# end
# end

# Adding a midpoint boundary constraint.
# Solve using BVDAE solvers.
# let
# @parameters g
# @variables x(..) y(t) [state_priority = 10] λ(t)
# eqs = [D(D(x(t))) ~ λ * x(t)
# D(D(y)) ~ λ * y - g
# x(t)^2 + y^2 ~ 1]
# constr = [x(0.5) ~ 1]
# @mtkbuild pend = ODESystem(eqs, t; constr)
#
# tspan = (0.0, 1.5)
# u0map = [x(t) => 0.6, y => 0.8]
# parammap = [g => 1]
# guesses = [λ => 1]
#
# bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, check_length = false)
# test_solvers(daesolvers, bvp, u0map, constr)
#
# bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)
# test_solvers(daesolvers, bvp2, u0map, constr, get_alg_eqs(pend))
#
# # More complicated constr.
# u0map = [x(t) => 0.6]
# guesses = [λ => 1, y(t) => 0.8]
#
# constr = [x(0.5) ~ 1,
# x(0.3)^3 + y(0.6)^2 ~ 0.5]
# @mtkbuild pend = ODESystem(eqs, t; constr)
# bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, check_length = false)
# test_solvers(daesolvers, bvp, u0map, constr, get_alg_eqs(pend))
#
# constr = [x(0.4) * g ~ y(0.2),
# y(0.7) ~ 0.3]
# @mtkbuild pend = ODESystem(eqs, t; constr)
# bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, check_length = false)
# test_solvers(daesolvers, bvp, u0map, constr, get_alg_eqs(pend))
# end
47 changes: 47 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
@@ -1636,3 +1636,50 @@ end
prob = ODEProblem{false}(lowered_dae_sys; u0_constructor = x -> SVector(x...))
@test prob.u0 isa SVector
end

@testset "Constraint system construction" begin
@variables x(..) y(..) z(..)
@parameters a b c d e
eqs = [D(x(t)) ~ 3*a*y(t), D(y(t)) ~ x(t) - z(t), D(z(t)) ~ e*x(t)^2]
cons = [x(0.3) ~ c*d, y(0.7) ~ 3]

# Test variables + parameters infer correctly.
@mtkbuild sys = ODESystem(eqs, t; constraints = cons)
@test issetequal(parameters(sys), [a, c, d, e])
@test issetequal(unknowns(sys), [x(t), y(t), z(t)])

@parameters t_c
cons = [x(t_c) ~ 3]
@mtkbuild sys = ODESystem(eqs, t; constraints = cons)
@test issetequal(parameters(sys), [a, e, t_c])

@parameters g(..) h i
cons = [g(h, i) * x(3) ~ c]
@mtkbuild sys = ODESystem(eqs, t; constraints = cons)
@test issetequal(parameters(sys), [g, h, i, a, e, c])

# Test that bad constraints throw errors.
cons = [x(3, 4) ~ 3] # unknowns cannot have multiple args.
@test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)

cons = [x(y(t)) ~ 2] # unknown arg must be parameter, value, or t
@test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)

@variables u(t) v
cons = [x(t) * u ~ 3]
@test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons)
cons = [x(t) * v ~ 3]
@test_throws ArgumentError @mtkbuild sys = ODESystem(eqs, t; constraints = cons) # Need time argument.

# Test array variables
@variables x(..)[1:5]
mat = [1 2 0 3 2
0 0 3 2 0
0 1 3 0 4
2 0 0 2 1
0 0 2 0 5]
eqs = D(x(t)) ~ mat * x(t)
cons = [x(3) ~ [2,3,3,5,4]]
@mtkbuild ode = ODESystem(D(x(t)) ~ mat * x(t), t; constraints = cons)
@test length(constraints(ModelingToolkit.get_constraintsystem(ode))) == 5
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -87,6 +87,7 @@ end
@safetestset "SCCNonlinearProblem Test" include("scc_nonlinear_problem.jl")
@safetestset "PDE Construction Test" include("pde.jl")
@safetestset "JumpSystem Test" include("jumpsystem.jl")
@safetestset "BVProblem Test" include("bvproblem.jl")
@safetestset "print_tree" include("print_tree.jl")
@safetestset "Constraints Test" include("constraints.jl")
@safetestset "IfLifting Test" include("if_lifting.jl")