Skip to content

Commit f891e4c

Browse files
committed
Permit handling observables and defaults
1 parent 55941b5 commit f891e4c

File tree

1 file changed

+78
-16
lines changed

1 file changed

+78
-16
lines changed

ext/MTKBifurcationKitExt.jl

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,64 @@ module MTKBifurcationKitExt
66
using ModelingToolkit, Setfield
77
import BifurcationKit
88

9+
### Observable Plotting Handling ###
10+
11+
# Functor used when the plotting variable is an observable. Keeps track of the required information for computing the observable's value at each point of the bifurcation diagram.
12+
struct ObservableRecordFromSolution{S,T}
13+
# The equations determining the observables values.
14+
obs_eqs::S
15+
# The index of the observable that we wish to plot.
16+
target_obs_idx::Int64
17+
# The final index in subs_vals that contains a state.
18+
state_end_idxs::Int64
19+
# The final index in subs_vals that contains a param.
20+
param_end_idxs::Int64
21+
# The index (in subs_vals) that contain the bifurcation parameter.
22+
bif_par_idx::Int64
23+
# A Vector of pairs (Symbolic => value) with teh default values of all system variables and parameters.
24+
subs_vals::T
25+
26+
function ObservableRecordFromSolution(nsys::NonlinearSystem, plot_var, bif_idx, u0_vals, p_vals) where {S,T}
27+
obs_eqs = observed(nsys)
28+
target_obs_idx = findfirst(isequal(plot_var, eq.lhs) for eq in observed(nsys))
29+
state_end_idxs = length(states(nsys))
30+
param_end_idxs = state_end_idxs + length(parameters(nsys))
31+
32+
bif_par_idx = state_end_idxs + bif_idx
33+
# Gets the (base) substitution values for states.
34+
subs_vals_states = Pair.(states(nsys),u0_vals)
35+
# Gets the (base) substitution values for parameters.
36+
subs_vals_params = Pair.(parameters(nsys),p_vals)
37+
# Gets the (base) substitution values for observables.
38+
subs_vals_obs = [obs.lhs => substitute(obs.rhs, [subs_vals_states; subs_vals_params]) for obs in observed(nsys)]
39+
# Sometimes observables depend on other observables, hence we make a second upate to this vector.
40+
subs_vals_obs = [obs.lhs => substitute(obs.rhs, [subs_vals_states; subs_vals_params; subs_vals_obs]) for obs in observed(nsys)]
41+
# During the bifurcation process, teh value of some states, parameters, and observables may vary (and are calculated in each step). Those that are not are stored in this vector
42+
subs_vals = [subs_vals_states; subs_vals_params; subs_vals_obs]
43+
44+
param_end_idxs = state_end_idxs + length(parameters(nsys))
45+
new{typeof(obs_eqs),typeof(subs_vals)}(obs_eqs, target_obs_idx, state_end_idxs, param_end_idxs, bif_par_idx, subs_vals)
46+
end
47+
end
48+
# Functor function that computes the value.
49+
function (orfs::ObservableRecordFromSolution)(x, p)
50+
# Updates the state values (in subs_vals).
51+
for state_idx in 1:orfs.state_end_idxs
52+
orfs.subs_vals[state_idx] = orfs.subs_vals[state_idx][1] => x[state_idx]
53+
end
54+
55+
# Updates the bifurcation parameters value (in subs_vals).
56+
orfs.subs_vals[orfs.bif_par_idx] = orfs.subs_vals[orfs.bif_par_idx][1] => p
57+
58+
# Updates the observable values (in subs_vals).
59+
for (obs_idx, obs_eq) in enumerate(orfs.obs_eqs)
60+
orfs.subs_vals[orfs.param_end_idxs+obs_idx] = orfs.subs_vals[orfs.param_end_idxs+obs_idx][1] => substitute(obs_eq.rhs, orfs.subs_vals)
61+
end
62+
63+
# Substitutes in the value for all states, parameters, and observables into the equation for the designated observable.
64+
return substitute(orfs.obs_eqs[orfs.target_obs_idx].rhs, orfs.subs_vals)
65+
end
66+
967
### Creates BifurcationProblem Overloads ###
1068

1169
# When input is a NonlinearSystem.
@@ -23,25 +81,29 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
2381
F = ofun.f
2482
J = jac ? ofun.jac : nothing
2583

26-
# Computes bifurcation parameter and plot var indexes.
84+
# Converts the input state guess.
85+
u0_bif_vals = ModelingToolkit.varmap_to_vars(u0_bif, states(nsys); defaults=nsys.defaults)
86+
p_vals = ModelingToolkit.varmap_to_vars(ps, parameters(nsys); defaults=nsys.defaults)
87+
88+
# Computes bifurcation parameter and the plotting function.
2789
bif_idx = findfirst(isequal(bif_par), parameters(nsys))
28-
if !isnothing(plot_var)
29-
plot_idx = findfirst(isequal(plot_var), states(nsys))
30-
record_from_solution = (x, p) -> x[plot_idx]
31-
end
90+
if !isnothing(plot_var)
91+
# If the plot var is a normal state.
92+
if any(isequal(plot_var, var) for var in states(nsys))
93+
plot_idx = findfirst(isequal(plot_var), states(nsys))
94+
record_from_solution = (x, p) -> x[plot_idx]
3295

33-
# Converts the input state guess.
34-
u0_bif = ModelingToolkit.varmap_to_vars(u0_bif, states(nsys))
35-
ps = ModelingToolkit.varmap_to_vars(ps, parameters(nsys))
96+
# If the plot var is an observed state.
97+
elseif any(isequal(plot_var, eq.lhs) for eq in observed(nsys))
98+
record_from_solution = ObservableRecordFromSolution(nsys, plot_var, bif_idx, u0_bif_vals, p_vals)
99+
100+
# If neither an variable nor observable, throw an error.
101+
else
102+
error("The plot variable ($plot_var) was neither recognised as a system state nor observable.")
103+
end
104+
end
36105

37-
return BifurcationKit.BifurcationProblem(F,
38-
u0_bif,
39-
ps,
40-
(@lens _[bif_idx]),
41-
args...;
42-
record_from_solution = record_from_solution,
43-
J = J,
44-
kwargs...)
106+
return BifurcationKit.BifurcationProblem(F, u0_bif_vals, p_vals, (@lens _[bif_idx]), args...; record_from_solution = record_from_solution, J = J, kwargs...)
45107
end
46108

47109
# When input is a ODESystem.

0 commit comments

Comments
 (0)