Skip to content

Add support for an external synchronous compiler to discrete and hybrid systems #3399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,3 +365,14 @@ function input_timedomain(x)
throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression"))
end
end

function ZeroCrossing(expr; name = gensym(), up = true, down = true, kwargs...)
return SymbolicContinuousCallback(
[expr ~ 0], up ? ImperativeAffect(Returns(nothing)) : nothing;
affect_neg = down ? ImperativeAffect(Returns(nothing)) : nothing,
kwargs..., zero_crossing_id = name)
end

function SciMLBase.Clocks.EventClock(cb::SymbolicContinuousCallback)
return SciMLBase.Clocks.EventClock(cb.zero_crossing_id)
end
1 change: 1 addition & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,7 @@ function namespace_expr(
O
end
end

_nonum(@nospecialize x) = x isa Num ? x.val : x

"""
Expand Down
26 changes: 17 additions & 9 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
finalize::Union{Affect, Nothing}
rootfind::Union{Nothing, SciMLBase.RootfindOpt}
reinitializealg::SciMLBase.DAEInitializationAlgorithm
zero_crossing_id::Symbol

function SymbolicContinuousCallback(
conditions::Union{Equation, Vector{Equation}},
Expand All @@ -174,6 +175,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
finalize = nothing,
rootfind = SciMLBase.LeftRootFind,
reinitializealg = nothing,
zero_crossing_id = gensym(),
kwargs...)
conditions = (conditions isa AbstractVector) ? conditions : [conditions]

Expand All @@ -190,7 +192,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
make_affect(affect_neg; kwargs...),
make_affect(initialize; kwargs...), make_affect(
finalize; kwargs...),
rootfind, reinitializealg)
rootfind, reinitializealg, zero_crossing_id)
end # Default affect to nothing
end

Expand Down Expand Up @@ -466,7 +468,8 @@ function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuo
affect_neg = namespace_affects(affect_negs(cb), s),
initialize = namespace_affects(initialize_affects(cb), s),
finalize = namespace_affects(finalize_affects(cb), s),
rootfind = cb.rootfind, reinitializealg = cb.reinitializealg)
rootfind = cb.rootfind, reinitializealg = cb.reinitializealg,
zero_crossing_id = cb.zero_crossing_id)
end

function namespace_conditions(condition, s)
Expand All @@ -490,6 +493,8 @@ function Base.hash(cb::AbstractCallback, s::UInt)
s = hash(finalize_affects(cb), s)
!is_discrete(cb) && (s = hash(cb.rootfind, s))
hash(cb.reinitializealg, s)
!is_discrete(cb) && (s = hash(cb.zero_crossing_id, s))
return s
end

###########################
Expand Down Expand Up @@ -524,13 +529,16 @@ function finalize_affects(cbs::Vector{<:AbstractCallback})
end

function Base.:(==)(e1::AbstractCallback, e2::AbstractCallback)
(is_discrete(e1) === is_discrete(e2)) || return false
(isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) &&
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)) &&
isequal(e1.reinitializealg, e2.reinitializealg) ||
return false
is_discrete(e1) ||
(isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind))
is_discrete(e1) === is_discrete(e2) || return false
isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) || return false
isequal(e1.initialize, e2.initialize) || return false
isequal(e1.finalize, e2.finalize) || return false
isequal(e1.reinitializealg, e2.reinitializealg) || return false
if !is_discrete(e1)
isequal(e1.affect_neg, e2.affect_neg) || return false
isequal(e1.rootfind, e2.rootfind) || return false
isequal(e1.zero_crossing_id, e2.zero_crossing_id) || return false
end
end

Base.isempty(cb::AbstractCallback) = isempty(cb.conditions)
Expand Down
15 changes: 12 additions & 3 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function infer_clocks!(ci::ClockInference)
c = BitSet(c′)
idxs = intersect(c, inferred)
isempty(idxs) && continue
if !allequal(var_domain[i] for i in idxs)
if !allequal(iscontinuous(var_domain[i]) for i in idxs)
display(fullvars[c′])
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])"))
end
Expand Down Expand Up @@ -155,6 +155,9 @@ function split_system(ci::ClockInference{S}) where {S}
cid_to_var = Vector{Int}[]
# cid_counter = number of clocks
cid_counter = Ref(0)

# populates clock_to_id and id_to_clock
# checks if there is a continuous_id (for some reason? clock to id does this too)
for (i, d) in enumerate(eq_domain)
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
continuous_id = continuous_id
Expand All @@ -174,9 +177,13 @@ function split_system(ci::ClockInference{S}) where {S}
resize_or_push!(cid_to_eq, i, cid)
end
continuous_id = continuous_id[]
# for each clock partition what are the input (indexes/vars)
input_idxs = map(_ -> Int[], 1:cid_counter[])
inputs = map(_ -> Any[], 1:cid_counter[])
# var_domain corresponds to fullvars/all variables in the system
nvv = length(var_domain)
# put variables into the right clock partition
# keep track of inputs to each partition
for i in 1:nvv
d = var_domain[i]
cid = get(clock_to_id, d, 0)
Expand All @@ -190,15 +197,17 @@ function split_system(ci::ClockInference{S}) where {S}
resize_or_push!(cid_to_var, i, cid)
end

# breaks the system up into a continous and 0 or more discrete systems
tss = similar(cid_to_eq, S)
for (id, ieqs) in enumerate(cid_to_eq)
ts_i = system_subset(ts, ieqs)
for (id, (ieqs, ivars)) in enumerate(zip(cid_to_eq, cid_to_var))
ts_i = system_subset(ts, ieqs, ivars)
if id != continuous_id
ts_i = shift_discrete_system(ts_i)
@set! ts_i.structure.only_discrete = true
end
tss[id] = ts_i
end
# put the continous system at the back
if continuous_id != 0
tss[continuous_id], tss[end] = tss[end], tss[continuous_id]
inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id]
Expand Down
4 changes: 3 additions & 1 deletion src/systems/imperative_affect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ function compile_functional_affect(
upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ)

# write the new values back to the integrator
_generated_writeback(integ, upd_funs, upd_vals)
if !isnothing(upd_vals)
_generated_writeback(integ, upd_funs, upd_vals)
end

reset_jumps && reset_aggregated_jumps!(integ)
end
Expand Down
4 changes: 4 additions & 0 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1257,6 +1257,8 @@ function get_p_constructor(p_constructor, pType::Type, floatT::Type)
end
end

abstract type ProblemConstructionHook end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -1309,6 +1311,8 @@ function process_SciMLProblem(

check_inputmap_keys(sys, op)

op = getmetadata(sys, ProblemConstructionHook, identity)(op)

defs = add_toterms(recursive_unwrap(defaults(sys)); replace = is_discrete_system(sys))
kwargs = NamedTuple(kwargs)

Expand Down
33 changes: 33 additions & 0 deletions src/systems/state_machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,36 @@ entry

When used in a finite state machine, this operator returns `true` if the queried state is active and false otherwise.
""" activeState

function vars!(vars, O::Transition; op = Differential)
vars!(vars, O.from)
vars!(vars, O.to)
vars!(vars, O.cond; op)
return vars
end
function vars!(vars, O::InitialState; op = Differential)
vars!(vars, O.s; op)
return vars
end
function vars!(vars, O::StateMachineOperator; op = Differential)
error("Unhandled state machine operator")
end

function namespace_expr(
O::Transition, sys, n = nameof(sys); ivs = independent_variables(sys))
return Transition(
O.from === nothing ? O.from : renamespace(sys, O.from),
O.to === nothing ? O.to : renamespace(sys, O.to),
O.cond === nothing ? O.cond : namespace_expr(O.cond, sys),
O.immediate, O.reset, O.synchronize, O.priority
)
end

function namespace_expr(
O::InitialState, sys, n = nameof(sys); ivs = independent_variables(sys))
return InitialState(O.s === nothing ? O.s : renamespace(sys, O.s))
end

function namespace_expr(O::StateMachineOperator, sys, n = nameof(sys); kwargs...)
error("Unhandled state machine operator")
end
15 changes: 11 additions & 4 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function mtkcompile(
isscheduled(sys) && throw(RepeatedStructuralSimplificationError())
newsys′ = __mtkcompile(sys; simplify,
allow_symbolic, allow_parameter, conservative, fully_determined,
inputs, outputs, disturbance_inputs,
inputs, outputs, disturbance_inputs, additional_passes,
kwargs...)
if newsys′ isa Tuple
@assert length(newsys′) == 2
Expand Down Expand Up @@ -75,12 +75,13 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify)
end

sys, statemachines = extract_top_level_statemachines(sys)
sys = expand_connections(sys)
state = TearingState(sys; sort_eqs)
state = TearingState(sys)
append!(state.statemachines, statemachines)

@unpack structure, fullvars = state
@unpack graph, var_to_diff, var_types = structure
eqs = equations(state)
brown_vars = Int[]
new_idxs = zeros(Int, length(var_types))
idx = 0
Expand All @@ -98,7 +99,8 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
Is = Int[]
Js = Int[]
vals = Num[]
new_eqs = copy(eqs)
make_eqs_zero_equals!(state)
new_eqs = copy(equations(state))
dvar2eq = Dict{Any, Int}()
for (v, dv) in enumerate(var_to_diff)
dv === nothing && continue
Expand Down Expand Up @@ -291,3 +293,8 @@ function map_variables_to_equations(sys::AbstractSystem; rename_dummy_derivative

return mapping
end

"""
Mark whether an extra pass `p` can support compiling discrete systems.
"""
discrete_compile_pass(p) = false
Loading