Skip to content

Commit 0bc198f

Browse files
feat: add respecialize
1 parent 45afbc8 commit 0bc198f

File tree

3 files changed

+122
-1
lines changed

3 files changed

+122
-1
lines changed

docs/src/API/model_building.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ add_accumulations
227227
noise_to_brownians
228228
convert_system_indepvar
229229
subset_tunables
230+
respecialize
230231
```
231232

232233
## Hybrid systems

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc
268268
hasmisc, getmisc, state_priority,
269269
subset_tunables
270270
export liouville_transform, change_independent_variable, substitute_component,
271-
add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables
271+
add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables,
272+
respecialize
272273
export PDESystem
273274
export Differential, expand_derivatives, @derivatives
274275
export Equation, ConstrainedEquation

src/systems/diffeqs/basic_transformations.jl

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,3 +706,122 @@ function convert_system_indepvar(sys::System, t; name = nameof(sys))
706706
@set! sys.var_to_name = var_to_name
707707
return sys
708708
end
709+
710+
"""
711+
$(TYPEDSIGNATURES)
712+
713+
Shorthand for `respecialize(sys, []; all = true)`
714+
"""
715+
respecialize(sys::AbstractSystem) = respecialize(sys, []; all = true)
716+
717+
"""
718+
$(TYPEDSIGNATURES)
719+
720+
Specialize nonnumeric parameters in `sys` by changing their symtype to a concrete type.
721+
`mapping` is an iterable, where each element can be a parameter or a pair mapping a parameter
722+
to a value. If the element is a parameter, it must have a default. Each specified parameter
723+
is updated to have the symtype of the value associated with it (either in `mapping` or in
724+
the defaults). This operation can only be performed on nonnumeric, non-array parameters. The
725+
defaults of respecialized parameters are set to the associated values.
726+
727+
This operation can only be performed on `complete`d systems.
728+
729+
# Keyword arguments
730+
731+
- `all`: Specialize all nonnumeric parameters in the system. This will error if any such
732+
parameter does not have a default.
733+
"""
734+
function respecialize(sys::AbstractSystem, mapping; all = false)
735+
if !iscomplete(sys)
736+
error("""
737+
This operation can only be performed on completed systems. Use `complete(sys)` or
738+
`mtkcompile(sys)`.
739+
""")
740+
end
741+
if !is_split(sys)
742+
error("""
743+
This operation can only be performed on split systems. Use `complete(sys)` or
744+
`mtkcompile(sys)` with the `split = true` keyword argument.
745+
""")
746+
end
747+
748+
new_ps = copy(get_ps(sys))
749+
@set! sys.ps = new_ps
750+
751+
extras = []
752+
if all
753+
for x in filter(!is_variable_numeric, get_ps(sys))
754+
if any(y -> isequal(x, y) || y isa Pair && isequal(x, y[1]), mapping) ||
755+
symbolic_type(x) === ArraySymbolic() ||
756+
iscall(x) && operation(x) === getindex
757+
continue
758+
end
759+
push!(extras, x)
760+
end
761+
end
762+
763+
defs = copy(defaults(sys))
764+
@set! sys.defaults = defs
765+
766+
subrules = Dict()
767+
768+
for element in Iterators.flatten((extras, mapping))
769+
if element isa Pair
770+
k, v = element
771+
else
772+
k = element
773+
v = get(defs, k, nothing)
774+
@assert v !== nothing """
775+
Parameter $k needs an associated value to be respecialized.
776+
"""
777+
end
778+
779+
k = unwrap(k)
780+
T = typeof(v)
781+
782+
@assert !is_variable_numeric(k) """
783+
Numeric types cannot be respecialized - tried to respecialize $k.
784+
"""
785+
@assert symbolic_type(k) !== ArraySymbolic() """
786+
Cannot respecialize array symbolics - tried to respecialize $k.
787+
"""
788+
@assert !iscall(k) || operation(k) !== getindex """
789+
Cannot respecialized scalarized array variables - tried to respecialize $k.
790+
"""
791+
idx = findfirst(isequal(k), get_ps(sys))
792+
@assert idx !== nothing """
793+
Parameter $k does not exist in the system.
794+
"""
795+
796+
if iscall(k)
797+
op = operation(k)
798+
args = arguments(k)
799+
new_p = SymbolicUtils.term(op, args...; type = T)
800+
else
801+
new_p = SymbolicUtils.Sym{T}(getname(k))
802+
end
803+
804+
get_ps(sys)[idx] = new_p
805+
defaults(sys)[new_p] = v
806+
subrules[unwrap(k)] = unwrap(new_p)
807+
end
808+
809+
substituter = Base.Fix2(fast_substitute, subrules)
810+
@set! sys.eqs = map(substituter, get_eqs(sys))
811+
@set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys))
812+
@set! sys.defaults = Dict([k => substituter(v) for (k, v) in defaults(sys)])
813+
@set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)])
814+
@set! sys.continuous_events = map(get_continuous_events(sys)) do cev
815+
SymbolicContinuousCallback(
816+
map(substituter, cev.conditions), substituter(cev.affect),
817+
substituter(cev.affect_neg), substituter(cev.initialize),
818+
substituter(cev.finalize), cev.rootfind,
819+
cev.reinitializealg, cev.zero_crossing_id)
820+
end
821+
@set! sys.discrete_events = map(get_discrete_events(sys)) do dev
822+
SymbolicDiscreteCallback(map(substituter, dev.conditions), substituter(dev.affect),
823+
substituter(dev.initialize), substituter(dev.finalize), dev.reinitializealg)
824+
end
825+
sys = complete(sys; split = is_split(sys))
826+
return sys
827+
end

0 commit comments

Comments
 (0)