Skip to content
87 changes: 82 additions & 5 deletions ext/MTKFMIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,21 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,
FMI3CSFunctor(state_value_references, output_value_references)
end
@parameters (functor::(typeof(_functor)))(..)[1:(length(__mtk_internal_u) + length(__mtk_internal_o))] = _functor
# for co-simulation, we need to ensure the output buffer is solved for
# during initialization

diffeqs = Equation[]
for (i, x) in enumerate(collect(__mtk_internal_o))
# for co-simulation, we need to ensure the output buffer is solved for
# during initialization
push!(initialization_eqs,
x ~ functor(
wrapper, __mtk_internal_u, __mtk_internal_x, __mtk_internal_p, t)[i])
end
wrapper, (__mtk_internal_u), __mtk_internal_x, __mtk_internal_p, t)[i])

diffeqs = Equation[]
# also add equations for output derivatives
push!(diffeqs,
D(x) ~ term(
getOutputDerivative, functor, wrapper, i, 1, collect(__mtk_internal_u),
__mtk_internal_x, __mtk_internal_p, t; type = Real))
end

# use `ImperativeAffect` for instance management here
cb_observed = (; inputs = __mtk_internal_x, params = copy(params),
Expand Down Expand Up @@ -739,6 +745,15 @@ struct FMI2CSFunctor
The value references of output variables in the FMU.
"""
output_value_references::Vector{FMI.fmi2ValueReference}
"""
Simply a buffer to store the order of output derivative required from
`getRealOutputderivatives` and avoid unnecessary allocations.
"""
output_derivative_order_buffer::Vector{FMI.fmi2Integer}
end

function FMI2CSFunctor(svref, ovref)
FMI2CSFunctor(svref, ovref, FMI.fmi2Integer[1])
end

function (fn::FMI2CSFunctor)(wrapper::FMI2InstanceWrapper, states, inputs, params, t)
Expand All @@ -764,6 +779,41 @@ end
ndims = 1
end

"""
$(TYPEDSIGNATURES)

Calculate the `order` order derivative of the `var`th output of the FMU.
"""
function getOutputDerivative(fn::FMI2CSFunctor, wrapper::FMI2InstanceWrapper, var::Int,
order::FMI.fmi2Integer, states, inputs, params, t)
states = states isa SubArray ? copy(states) : states
inputs = inputs isa SubArray ? copy(inputs) : inputs
params = params isa SubArray ? copy(params) : params
instance = get_instance_CS!(wrapper, states, inputs, params, t)
fn.output_derivative_order_buffer[] = order
return FMI.fmi2GetRealOutputDerivatives(
instance, fn.output_value_references[var], fn.output_derivative_order_buffer)
end

# @register_symbolic getOutputDerivative(fn::FMI2CSFunctor, w::FMI2InstanceWrapper, var::Int, order::FMI.fmi2Integer, states::Vector{<:Real}, inputs::Vector{<:Real}, params::Vector{<:Real}, t::Real)

# HACK-ish for allowing higher order output derivatives
# The first `D(output)` will result in a `getOutputDerivatives` expression.
# Subsequent differentiations of this expression will expand to
# `Σ_{i = 1:8} Differential(args[i])(getOutputDerivative(args...)) * D(args[i])`
# using the chain rule. `i = 1:4` are not time-dependent (or real). We define
# the derivatives for `i = 5:7` to be zero, and the derivative for `i = 8` (w.r.t `t`)
# to be the same `getOutputDerivative` call but with the order increased. This results
# in `D(output) = getOutputDerivative(fn, w, var, order + 1, states, inputs, params, t) * 1`
# which is exactly what we want.
for i in 1:7
@eval Symbolics.derivative(::typeof(getOutputDerivative), args::NTuple{8, Any}, ::Val{$i}) = 0
end
function Symbolics.derivative(::typeof(getOutputDerivative), args::NTuple{8, Any}, ::Val{8})
term(getOutputDerivative, args[1], args[2], args[3],
args[4] + 1, args[5], args[6], args[7], args[8])
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -848,6 +898,15 @@ struct FMI3CSFunctor
The value references of output variables in the FMU.
"""
output_value_references::Vector{FMI.fmi3ValueReference}
"""
Simply a buffer to store the order of output derivative required from
`getRealOutputderivatives` and avoid unnecessary allocations.
"""
output_derivative_order_buffer::Vector{FMI.fmi3Int32}
end

function FMI3CSFunctor(svref, ovref)
FMI3CSFunctor(svref, ovref, FMI.fmi3Int32[1])
end

function (fn::FMI3CSFunctor)(wrapper::FMI3InstanceWrapper, states, inputs, params, t)
Expand All @@ -871,6 +930,24 @@ end
ndims = 1
end

"""
$(TYPEDSIGNATURES)

Calculate the `order` order derivative of the `var`th output of the FMU.
"""
function getOutputDerivative(fn::FMI3CSFunctor, wrapper::FMI3InstanceWrapper, var::Int,
order::FMI.fmi3Int32, states, inputs, params, t)
states = states isa SubArray ? copy(states) : states
inputs = inputs isa SubArray ? copy(inputs) : inputs
params = params isa SubArray ? copy(params) : params
instance = get_instance_CS!(wrapper, states, inputs, params, t)
fn.output_derivative_order_buffer[] = order
return FMI.fmi3GetOutputDerivatives(
instance, fn.output_value_references[var], fn.output_derivative_order_buffer)
end

# @register_symbolic getOutputDerivative(fn::FMI3CSFunctor, w::FMI3InstanceWrapper, var::Int, order::FMI.fmi3Int32, states::Vector{<:Real}, inputs::Vector{<:Real}, params::Vector{<:Real}, t::Real) false

"""
$(TYPEDSIGNATURES)
"""
Expand Down
1 change: 1 addition & 0 deletions src/structural_transformation/pantelides.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
D(eq.lhs)
end
rhs = ModelingToolkit.expand_derivatives(D(eq.rhs))
rhs = fast_substitute(rhs, state.param_derivative_map)
substitution_dict = Dict(x.lhs => x.rhs
for x in out_eqs if x !== nothing && x.lhs isa Symbolic)
sub_rhs = substitute(rhs, substitution_dict)
Expand Down
18 changes: 17 additions & 1 deletion src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,23 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int; kwargs...)

sys = ts.sys
eq = equations(ts)[ieq]
eq = 0 ~ Symbolics.derivative(eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true)
eq = 0 ~ fast_substitute(
ModelingToolkit.derivative(
eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true), ts.param_derivative_map)

vs = ModelingToolkit.vars(eq.rhs)
for v in vs
# parameters with unknown derivatives have a value of `nothing` in the map,
# so use `missing` as the default.
get(ts.param_derivative_map, v, missing) === nothing || continue
_original_eq = equations(ts)[ieq]
error("""
Encountered derivative of discrete variable `$(only(arguments(v)))` when \
differentiating equation `$(_original_eq)`. This may indicate a model error or a \
missing equation of the form `$v ~ ...` that defines this derivative.
""")
end

push!(equations(ts), eq)
# Analyze the new equation and update the graph/solvable_graph
# First, copy the previous incidence and add the derivative terms.
Expand Down
25 changes: 24 additions & 1 deletion src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
fullvars::Vector
structure::SystemStructure
extra_eqs::Vector
param_derivative_map::Dict{BasicSymbolic, Any}
end

TransformationState(sys::AbstractSystem) = TearingState(sys)
Expand Down Expand Up @@ -253,6 +254,12 @@ function Base.push!(ev::EquationsView, eq)
push!(ev.ts.extra_eqs, eq)
end

function is_time_dependent_parameter(p, iv)
return iv !== nothing && isparameter(p) && iscall(p) &&
(operation(p) === getindex && is_time_dependent_parameter(arguments(p)[1], iv) ||
(args = arguments(p); length(args)) == 1 && isequal(only(args), iv))
end

function TearingState(sys; quick_cancel = false, check = true)
sys = flatten(sys)
ivs = independent_variables(sys)
Expand All @@ -264,6 +271,7 @@ function TearingState(sys; quick_cancel = false, check = true)
var2idx = Dict{Any, Int}()
symbolic_incidence = []
fullvars = []
param_derivative_map = Dict{BasicSymbolic, Any}()
var_counter = Ref(0)
var_types = VariableType[]
addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
Expand All @@ -276,11 +284,17 @@ function TearingState(sys; quick_cancel = false, check = true)

vars = OrderedSet()
varsvec = []
eqs_to_retain = trues(length(eqs))
for (i, eq′) in enumerate(eqs)
if eq′.lhs isa Connection
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
return nothing
end
if iscall(eq′.lhs) && (op = operation(eq′.lhs)) isa Differential &&
isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq′.lhs)), iv)
param_derivative_map[eq′.lhs] = eq′.rhs
eqs_to_retain[i] = false
end
if _iszero(eq′.lhs)
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
eq = eq′
Expand All @@ -295,6 +309,12 @@ function TearingState(sys; quick_cancel = false, check = true)
any(isequal(_var), ivs) && continue
if isparameter(_var) ||
(iscall(_var) && isparameter(operation(_var)) || isconstant(_var))
if is_time_dependent_parameter(_var, iv) &&
!haskey(param_derivative_map, Differential(iv)(_var))
# default to `nothing` since it is ignored during substitution,
# so `D(_var)` is retained in the expression.
param_derivative_map[Differential(iv)(_var)] = nothing
end
continue
end
v = scalarize(v)
Expand Down Expand Up @@ -351,6 +371,9 @@ function TearingState(sys; quick_cancel = false, check = true)
eqs[i] = eqs[i].lhs ~ rhs
end
end
eqs = eqs[eqs_to_retain]
neqs = length(eqs)
symbolic_incidence = symbolic_incidence[eqs_to_retain]

### Handle discrete variables
lowest_shift = Dict()
Expand Down Expand Up @@ -438,7 +461,7 @@ function TearingState(sys; quick_cancel = false, check = true)
ts = TearingState(sys, fullvars,
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem),
Any[])
Any[], param_derivative_map)
if sys isa DiscreteSystem
ts = shift_discrete_system(ts)
end
Expand Down
2 changes: 1 addition & 1 deletion test/state_selection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using ModelingToolkit, OrdinaryDiffEq, Test
using ModelingToolkit: t_nounits as t, D_nounits as D

sts = @variables x1(t) x2(t) x3(t) x4(t)
params = @parameters u1(t) u2(t) u3(t) u4(t)
params = @parameters u1 u2 u3 u4
eqs = [x1 + x2 + u1 ~ 0
x1 + x2 + x3 + u2 ~ 0
x1 + D(x3) + x4 + u3 ~ 0
Expand Down
84 changes: 84 additions & 0 deletions test/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using SparseArrays
using UnPack
using ModelingToolkit: t_nounits as t, D_nounits as D, default_toterm
using Symbolics: unwrap
using DataInterpolations
const ST = StructuralTransformations

# Define some variables
Expand Down Expand Up @@ -282,3 +283,86 @@ end
@test length(mapping) == 3
end
end

@testset "Issue#3480: Derivatives of time-dependent parameters" begin
@component function FilteredInput(; name, x0 = 0, T = 0.1)
params = @parameters begin
k(t) = x0
T = T
end
vars = @variables begin
x(t) = k
dx(t) = 0
ddx(t)
end
systems = []
eqs = [D(x) ~ dx
D(dx) ~ ddx
dx ~ (k - x) / T]
return ODESystem(eqs, t, vars, params; systems, name)
end

@component function FilteredInputFix(; name, x0 = 0, T = 0.1)
params = @parameters begin
k(t) = x0
T = T
end
vars = @variables begin
x(t) = k
dx(t) = 0
ddx(t)
end
systems = []
eqs = [D(x) ~ dx
D(dx) ~ ddx
dx ~ (k - x) / T
D(k) ~ 0]
return ODESystem(eqs, t, vars, params; systems, name)
end

@named sys = FilteredInput()
@test_throws ["derivative of discrete variable", "k(t)"] structural_simplify(sys)

@mtkbuild sys = FilteredInputFix()
vs = Set()
for eq in equations(sys)
ModelingToolkit.vars!(vs, eq)
end
for eq in observed(sys)
ModelingToolkit.vars!(vs, eq)
end

@test !(D(sys.k) in vs)

@testset "Called parameter still has derivative" begin
@component function FilteredInput2(; name, x0 = 0, T = 0.1)
ts = collect(0.0:0.1:10.0)
spline = LinearInterpolation(ts .^ 2, ts)
params = @parameters begin
(k::LinearInterpolation)(..) = spline
T = T
end
vars = @variables begin
x(t) = k(t)
dx(t) = 0
ddx(t)
end
systems = []
eqs = [D(x) ~ dx
D(dx) ~ ddx
dx ~ (k(t) - x) / T]
return ODESystem(eqs, t, vars, params; systems, name)
end

@mtkbuild sys = FilteredInput2()
vs = Set()
for eq in equations(sys)
ModelingToolkit.vars!(vs, eq)
end
for eq in observed(sys)
ModelingToolkit.vars!(vs, eq)
end

@test D(sys.k(t)) in vs
end
end
Loading