Skip to content

Commit e93ae20

Browse files
Merge pull request #3900 from AayushSabharwal/as/respecialize
feat: add `respecialize`
2 parents 2d4a674 + 9ed9aa7 commit e93ae20

File tree

9 files changed

+257
-2
lines changed

9 files changed

+257
-2
lines changed

docs/src/API/model_building.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ add_accumulations
227227
noise_to_brownians
228228
convert_system_indepvar
229229
subset_tunables
230+
respecialize
230231
```
231232

232233
## Hybrid systems

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc
268268
hasmisc, getmisc, state_priority,
269269
subset_tunables
270270
export liouville_transform, change_independent_variable, substitute_component,
271-
add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables
271+
add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables,
272+
respecialize
272273
export PDESystem
273274
export Differential, expand_derivatives, @derivatives
274275
export Equation, ConstrainedEquation

src/systems/callbacks.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ function SymbolicAffect(affect::SymbolicAffect; kwargs...)
2525
end
2626
SymbolicAffect(affect; kwargs...) = make_affect(affect; kwargs...)
2727

28+
function Symbolics.fast_substitute(aff::SymbolicAffect, rules)
29+
substituter = Base.Fix2(fast_substitute, rules)
30+
SymbolicAffect(map(substituter, aff.affect), map(substituter, aff.alg_eqs),
31+
map(substituter, aff.discrete_parameters))
32+
end
33+
2834
struct AffectSystem
2935
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
3036
system::AbstractSystem
@@ -36,6 +42,19 @@ struct AffectSystem
3642
discretes::Vector
3743
end
3844

45+
function Symbolics.fast_substitute(aff::AffectSystem, rules)
46+
substituter = Base.Fix2(fast_substitute, rules)
47+
sys = aff.system
48+
@set! sys.eqs = map(substituter, get_eqs(sys))
49+
@set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys))
50+
@set! sys.defaults = Dict([k => substituter(v) for (k, v) in defaults(sys)])
51+
@set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)])
52+
@set! sys.unknowns = map(substituter, get_unknowns(sys))
53+
@set! sys.ps = map(substituter, get_ps(sys))
54+
AffectSystem(sys, map(substituter, aff.unknowns),
55+
map(substituter, aff.parameters), map(substituter, aff.discretes))
56+
end
57+
3958
function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...)
4059
AffectSystem(spec.affect; alg_eqs = vcat(spec.alg_eqs, alg_eqs), iv,
4160
discrete_parameters = spec.discrete_parameters, kwargs...)

src/systems/diffeqs/basic_transformations.jl

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,3 +706,143 @@ function convert_system_indepvar(sys::System, t; name = nameof(sys))
706706
@set! sys.var_to_name = var_to_name
707707
return sys
708708
end
709+
710+
"""
711+
$(TYPEDSIGNATURES)
712+
713+
Shorthand for `respecialize(sys, []; all = true)`
714+
"""
715+
respecialize(sys::AbstractSystem) = respecialize(sys, []; all = true)
716+
717+
"""
718+
$(TYPEDSIGNATURES)
719+
720+
Specialize nonnumeric parameters in `sys` by changing their symtype to a concrete type.
721+
`mapping` is an iterable, where each element can be a parameter or a pair mapping a parameter
722+
to a value. If the element is a parameter, it must have a default. Each specified parameter
723+
is updated to have the symtype of the value associated with it (either in `mapping` or in
724+
the defaults). This operation can only be performed on nonnumeric, non-array parameters. The
725+
defaults of respecialized parameters are set to the associated values.
726+
727+
This operation can only be performed on `complete`d systems.
728+
729+
# Keyword arguments
730+
731+
- `all`: Specialize all nonnumeric parameters in the system. This will error if any such
732+
parameter does not have a default.
733+
"""
734+
function respecialize(sys::AbstractSystem, mapping; all = false)
735+
if !iscomplete(sys)
736+
error("""
737+
This operation can only be performed on completed systems. Use `complete(sys)` or
738+
`mtkcompile(sys)`.
739+
""")
740+
end
741+
if !is_split(sys)
742+
error("""
743+
This operation can only be performed on split systems. Use `complete(sys)` or
744+
`mtkcompile(sys)` with the `split = true` keyword argument.
745+
""")
746+
end
747+
748+
new_ps = copy(get_ps(sys))
749+
@set! sys.ps = new_ps
750+
751+
extras = []
752+
if all
753+
for x in filter(!is_variable_numeric, get_ps(sys))
754+
if any(y -> isequal(x, y) || y isa Pair && isequal(x, y[1]), mapping) ||
755+
symbolic_type(x) === ArraySymbolic() ||
756+
iscall(x) && operation(x) === getindex
757+
continue
758+
end
759+
push!(extras, x)
760+
end
761+
end
762+
ps_to_specialize = Iterators.flatten((extras, mapping))
763+
764+
defs = copy(defaults(sys))
765+
@set! sys.defaults = defs
766+
final_defs = copy(defs)
767+
evaluate_varmap!(final_defs, ps_to_specialize)
768+
769+
subrules = Dict()
770+
771+
for element in ps_to_specialize
772+
if element isa Pair
773+
k, v = element
774+
else
775+
k = element
776+
v = get(final_defs, k, nothing)
777+
@assert v !== nothing """
778+
Parameter $k needs an associated value to be respecialized.
779+
"""
780+
@assert symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) """
781+
Parameter $k needs an associated value to be respecialized. Found symbolic \
782+
default $v.
783+
"""
784+
end
785+
786+
k = unwrap(k)
787+
T = typeof(v)
788+
789+
@assert !is_variable_numeric(k) """
790+
Numeric types cannot be respecialized - tried to respecialize $k.
791+
"""
792+
@assert symbolic_type(k) !== ArraySymbolic() """
793+
Cannot respecialize array symbolics - tried to respecialize $k.
794+
"""
795+
@assert !iscall(k) || operation(k) !== getindex """
796+
Cannot respecialized scalarized array variables - tried to respecialize $k.
797+
"""
798+
idx = findfirst(isequal(k), get_ps(sys))
799+
@assert idx !== nothing """
800+
Parameter $k does not exist in the system.
801+
"""
802+
803+
if iscall(k)
804+
op = operation(k)
805+
args = arguments(k)
806+
new_p = SymbolicUtils.term(op, args...; type = T)
807+
else
808+
new_p = SymbolicUtils.Sym{T}(getname(k))
809+
end
810+
811+
get_ps(sys)[idx] = new_p
812+
defaults(sys)[new_p] = v
813+
subrules[unwrap(k)] = unwrap(new_p)
814+
end
815+
816+
substituter = Base.Fix2(fast_substitute, subrules)
817+
@set! sys.eqs = map(substituter, get_eqs(sys))
818+
@set! sys.observed = map(substituter, get_observed(sys))
819+
@set! sys.initialization_eqs = map(substituter, get_initialization_eqs(sys))
820+
if get_noise_eqs(sys) !== nothing
821+
@set! sys.noise_eqs = map(substituter, get_noise_eqs(sys))
822+
end
823+
@set! sys.assertions = Dict([substituter(k) => v for (k, v) in assertions(sys)])
824+
@set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys))
825+
@set! sys.defaults = Dict([substituter(k) => substituter(v) for (k, v) in defaults(sys)])
826+
@set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)])
827+
@set! sys.continuous_events = map(get_continuous_events(sys)) do cev
828+
SymbolicContinuousCallback(
829+
map(substituter, cev.conditions), substituter(cev.affect),
830+
substituter(cev.affect_neg), substituter(cev.initialize),
831+
substituter(cev.finalize), cev.rootfind,
832+
cev.reinitializealg, cev.zero_crossing_id)
833+
end
834+
@set! sys.discrete_events = map(get_discrete_events(sys)) do dev
835+
SymbolicDiscreteCallback(map(substituter, dev.conditions), substituter(dev.affect),
836+
substituter(dev.initialize), substituter(dev.finalize), dev.reinitializealg)
837+
end
838+
if get_schedule(sys) !== nothing
839+
sched = get_schedule(sys)
840+
@set! sys.schedule = Schedule(
841+
sched.var_sccs, AnyDict(k => substituter(v) for (k, v) in sched.dummy_sub))
842+
end
843+
@set! sys.constraints = map(substituter, get_constraints(sys))
844+
@set! sys.tstops = map(substituter, get_tstops(sys))
845+
@set! sys.costs = Vector{Union{Real, BasicSymbolic}}(map(substituter, get_costs(sys)))
846+
sys = complete(sys; split = is_split(sys))
847+
return sys
848+
end

src/systems/imperative_affect.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ function ImperativeAffect(; f, kwargs...)
6767
ImperativeAffect(f; kwargs...)
6868
end
6969

70+
function Symbolics.fast_substitute(aff::ImperativeAffect, rules)
71+
substituter = Base.Fix2(fast_substitute, rules)
72+
ImperativeAffect(aff.f, map(substituter, aff.obs), aff.obs_syms,
73+
map(substituter, aff.modified), aff.mod_syms, aff.ctx, aff.skip_checks)
74+
end
75+
7076
function Base.show(io::IO, mfa::ImperativeAffect)
7177
obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
7278
mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ")

src/systems/problem_utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,7 @@ with a constant value.
12191219
"""
12201220
function float_type_from_varmap(varmap, floatT = Bool)
12211221
for (k, v) in varmap
1222+
is_variable_floatingpoint(k) || continue
12221223
symbolic_type(v) == NotSymbolic() || continue
12231224
is_array_of_symbolics(v) && continue
12241225

src/utils.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,27 @@ function is_floatingpoint_symtype(T::Type)
838838
T <: AbstractArray && is_floatingpoint_symtype(eltype(T))
839839
end
840840

841+
"""
842+
$(TYPEDSIGNATURES)
843+
844+
Check if `sym` represents a symbolic number or array of numbers.
845+
"""
846+
function is_variable_numeric(sym)
847+
sym = unwrap(sym)
848+
T = symtype(sym)
849+
is_numeric_symtype(T)
850+
end
851+
852+
"""
853+
$(TYPEDSIGNATURES)
854+
855+
Check if `T` is an appropriate symtype for a symbolic variable representing a number or
856+
array of numbers.
857+
"""
858+
function is_numeric_symtype(T::Type)
859+
return T <: Number || T <: AbstractArray && is_numeric_symtype(eltype(T))
860+
end
861+
841862
"""
842863
$(TYPEDSIGNATURES)
843864

test/basic_transformations.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ModelingToolkit, OrdinaryDiffEq, DataInterpolations, DynamicQuantities, Test
22
using ModelingToolkitStandardLibrary.Blocks: RealInput, RealOutput
3+
using SymbolicUtils: symtype
34

45
@independent_variables t
56
D = Differential(t)
@@ -328,3 +329,68 @@ end
328329
D(x) ~ y]
329330
@test issetequal(equations(asys), eqs)
330331
end
332+
333+
abstract type AbstractFoo end
334+
335+
struct Bar <: AbstractFoo end
336+
struct Baz <: AbstractFoo end
337+
338+
foofn(x) = 4
339+
@register_symbolic foofn(x::AbstractFoo)
340+
341+
@testset "`respecialize`" begin
342+
@parameters p::AbstractFoo p2(t)::AbstractFoo = p q[1:2]::AbstractFoo r
343+
rp,
344+
rp2 = let
345+
only(@parameters p::Bar),
346+
SymbolicUtils.term(operation(p2), arguments(p2)...; type = Baz)
347+
end
348+
@variables x(t) = 1.0
349+
@named sys1 = System([D(x) ~ foofn(p) + foofn(p2) + x], t, [x], [p, p2, q, r])
350+
351+
@test_throws ["completed systems"] respecialize(sys1)
352+
@test_throws ["completed systems"] respecialize(sys1, [])
353+
@test_throws ["split systems"] respecialize(complete(sys1; split = false))
354+
@test_throws ["split systems"] respecialize(complete(sys1; split = false), [])
355+
356+
sys = complete(sys1)
357+
358+
@test_throws ["Parameter p", "associated value"] respecialize(sys)
359+
@test_throws ["Parameter p", "associated value"] respecialize(sys, [p])
360+
361+
@test_throws ["Parameter p2", "symbolic default"] respecialize(sys, [p2])
362+
363+
sys2 = respecialize(sys, [p => Bar()])
364+
@test ModelingToolkit.iscomplete(sys2)
365+
@test ModelingToolkit.is_split(sys2)
366+
ps = ModelingToolkit.get_ps(sys2)
367+
idx = findfirst(isequal(rp), ps)
368+
@test defaults(sys2)[rp] == Bar()
369+
@test symtype(ps[idx]) <: Bar
370+
ic = ModelingToolkit.get_index_cache(sys2)
371+
@test any(x -> x.type == Bar && x.length == 1, ic.nonnumeric_buffer_sizes)
372+
prob = ODEProblem(sys2, [p2 => Bar(), q => [Bar(), Bar()], r => 1], (0.0, 1.0))
373+
@test any(x -> x isa Vector{Bar} && length(x) == 1, prob.p.nonnumeric)
374+
375+
defaults(sys)[p2] = Baz()
376+
sys2 = respecialize(sys, [p => Bar()]; all = true)
377+
@test ModelingToolkit.iscomplete(sys2)
378+
@test ModelingToolkit.is_split(sys2)
379+
ps = ModelingToolkit.get_ps(sys2)
380+
idx = findfirst(isequal(rp2), ps)
381+
@test defaults(sys2)[rp2] == Baz()
382+
@test symtype(ps[idx]) <: Baz
383+
ic = ModelingToolkit.get_index_cache(sys2)
384+
@test any(x -> x.type == Baz && x.length == 1, ic.nonnumeric_buffer_sizes)
385+
delete!(defaults(sys), p2)
386+
prob = ODEProblem(sys2, [q => [Bar(), Bar()], r => 1], (0.0, 1.0))
387+
@test any(x -> x isa Vector{Bar} && length(x) == 1, prob.p.nonnumeric)
388+
@test any(x -> x isa Vector{Baz} && length(x) == 1, prob.p.nonnumeric)
389+
390+
@test_throws ["Numeric types cannot be respecialized"] respecialize(sys, [r => 1])
391+
@test_throws ["array symbolics"] respecialize(sys, [q => Bar[Bar(), Bar()]])
392+
@test_throws ["scalarized array"] respecialize(sys, [q[1] => Bar()])
393+
394+
@parameters foo::AbstractFoo
395+
@test_throws ["does not exist"] respecialize(sys, [foo => Bar()])
396+
end

test/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1568,7 +1568,7 @@ end
15681568

15691569
cmd = `$(Base.julia_cmd()) --project=$(@__DIR__) -e $code`
15701570
proc = run(cmd, stdin, stdout, stderr; wait = false)
1571-
sleep(120)
1571+
sleep(180)
15721572
@test !process_running(proc)
15731573
kill(proc, Base.SIGKILL)
15741574
end

0 commit comments

Comments
 (0)