Skip to content

Commit 397ab12

Browse files
authored
Merge pull request #2337 from TorkelE/bif_extension
Improve BifurcationKit extension
2 parents 70c3252 + b7ab44f commit 397ab12

File tree

2 files changed

+207
-19
lines changed

2 files changed

+207
-19
lines changed

ext/MTKBifurcationKitExt.jl

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,76 @@ module MTKBifurcationKitExt
66
using ModelingToolkit, Setfield
77
import BifurcationKit
88

9+
### Observable Plotting Handling ###
10+
11+
# Functor used when the plotting variable is an observable. Keeps track of the required information for computing the observable's value at each point of the bifurcation diagram.
12+
struct ObservableRecordFromSolution{S, T}
13+
# The equations determining the observables values.
14+
obs_eqs::S
15+
# The index of the observable that we wish to plot.
16+
target_obs_idx::Int64
17+
# The final index in subs_vals that contains a state.
18+
state_end_idxs::Int64
19+
# The final index in subs_vals that contains a param.
20+
param_end_idxs::Int64
21+
# The index (in subs_vals) that contain the bifurcation parameter.
22+
bif_par_idx::Int64
23+
# A Vector of pairs (Symbolic => value) with teh default values of all system variables and parameters.
24+
subs_vals::T
25+
26+
function ObservableRecordFromSolution(nsys::NonlinearSystem,
27+
plot_var,
28+
bif_idx,
29+
u0_vals,
30+
p_vals) where {S, T}
31+
obs_eqs = observed(nsys)
32+
target_obs_idx = findfirst(isequal(plot_var, eq.lhs) for eq in observed(nsys))
33+
state_end_idxs = length(states(nsys))
34+
param_end_idxs = state_end_idxs + length(parameters(nsys))
35+
36+
bif_par_idx = state_end_idxs + bif_idx
37+
# Gets the (base) substitution values for states.
38+
subs_vals_states = Pair.(states(nsys), u0_vals)
39+
# Gets the (base) substitution values for parameters.
40+
subs_vals_params = Pair.(parameters(nsys), p_vals)
41+
# Gets the (base) substitution values for observables.
42+
subs_vals_obs = [obs.lhs => substitute(obs.rhs,
43+
[subs_vals_states; subs_vals_params]) for obs in observed(nsys)]
44+
# Sometimes observables depend on other observables, hence we make a second upate to this vector.
45+
subs_vals_obs = [obs.lhs => substitute(obs.rhs,
46+
[subs_vals_states; subs_vals_params; subs_vals_obs]) for obs in observed(nsys)]
47+
# During the bifurcation process, teh value of some states, parameters, and observables may vary (and are calculated in each step). Those that are not are stored in this vector
48+
subs_vals = [subs_vals_states; subs_vals_params; subs_vals_obs]
49+
50+
param_end_idxs = state_end_idxs + length(parameters(nsys))
51+
new{typeof(obs_eqs), typeof(subs_vals)}(obs_eqs,
52+
target_obs_idx,
53+
state_end_idxs,
54+
param_end_idxs,
55+
bif_par_idx,
56+
subs_vals)
57+
end
58+
end
59+
# Functor function that computes the value.
60+
function (orfs::ObservableRecordFromSolution)(x, p)
61+
# Updates the state values (in subs_vals).
62+
for state_idx in 1:(orfs.state_end_idxs)
63+
orfs.subs_vals[state_idx] = orfs.subs_vals[state_idx][1] => x[state_idx]
64+
end
65+
66+
# Updates the bifurcation parameters value (in subs_vals).
67+
orfs.subs_vals[orfs.bif_par_idx] = orfs.subs_vals[orfs.bif_par_idx][1] => p
68+
69+
# Updates the observable values (in subs_vals).
70+
for (obs_idx, obs_eq) in enumerate(orfs.obs_eqs)
71+
orfs.subs_vals[orfs.param_end_idxs + obs_idx] = orfs.subs_vals[orfs.param_end_idxs + obs_idx][1] => substitute(obs_eq.rhs,
72+
orfs.subs_vals)
73+
end
74+
75+
# Substitutes in the value for all states, parameters, and observables into the equation for the designated observable.
76+
return substitute(orfs.obs_eqs[orfs.target_obs_idx].rhs, orfs.subs_vals)
77+
end
78+
979
### Creates BifurcationProblem Overloads ###
1080

1181
# When input is a NonlinearSystem.
@@ -23,20 +93,37 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
2393
F = ofun.f
2494
J = jac ? ofun.jac : nothing
2595

26-
# Computes bifurcation parameter and plot var indexes.
96+
# Converts the input state guess.
97+
u0_bif_vals = ModelingToolkit.varmap_to_vars(u0_bif,
98+
states(nsys);
99+
defaults = nsys.defaults)
100+
p_vals = ModelingToolkit.varmap_to_vars(ps, parameters(nsys); defaults = nsys.defaults)
101+
102+
# Computes bifurcation parameter and the plotting function.
27103
bif_idx = findfirst(isequal(bif_par), parameters(nsys))
28104
if !isnothing(plot_var)
29-
plot_idx = findfirst(isequal(plot_var), states(nsys))
30-
record_from_solution = (x, p) -> x[plot_idx]
31-
end
105+
# If the plot var is a normal state.
106+
if any(isequal(plot_var, var) for var in states(nsys))
107+
plot_idx = findfirst(isequal(plot_var), states(nsys))
108+
record_from_solution = (x, p) -> x[plot_idx]
32109

33-
# Converts the input state guess.
34-
u0_bif = ModelingToolkit.varmap_to_vars(u0_bif, states(nsys))
35-
ps = ModelingToolkit.varmap_to_vars(ps, parameters(nsys))
110+
# If the plot var is an observed state.
111+
elseif any(isequal(plot_var, eq.lhs) for eq in observed(nsys))
112+
record_from_solution = ObservableRecordFromSolution(nsys,
113+
plot_var,
114+
bif_idx,
115+
u0_bif_vals,
116+
p_vals)
117+
118+
# If neither an variable nor observable, throw an error.
119+
else
120+
error("The plot variable ($plot_var) was neither recognised as a system state nor observable.")
121+
end
122+
end
36123

37124
return BifurcationKit.BifurcationProblem(F,
38-
u0_bif,
39-
ps,
125+
u0_bif_vals,
126+
p_vals,
40127
(@lens _[bif_idx]),
41128
args...;
42129
record_from_solution = record_from_solution,

test/extensions/bifurcationkit.jl

Lines changed: 111 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,136 @@
11
using BifurcationKit, ModelingToolkit, Test
22

3-
# Checks pitchfork diagram and that there are the correct number of branches (a main one and two children)
3+
# Simple pitchfork diagram, compares solution to native BifurcationKit, checks they are identical.
4+
# Checks using `jac=false` option.
45
let
6+
# Creates model.
57
@variables t x(t) y(t)
68
@parameters μ α
79
eqs = [0 ~ μ * x - x^3 + α * y,
810
0 ~ -y]
911
@named nsys = NonlinearSystem(eqs, [x, y], [μ, α])
1012

13+
# Creates BifurcationProblem
1114
bif_par = μ
1215
p_start ==> -1.0, α => 1.0]
1316
u0_guess = [x => 1.0, y => 1.0]
1417
plot_var = x
15-
16-
using BifurcationKit
1718
bprob = BifurcationProblem(nsys,
1819
u0_guess,
1920
p_start,
2021
bif_par;
2122
plot_var = plot_var,
2223
jac = false)
2324

25+
# Conputes bifurcation diagram.
2426
p_span = (-4.0, 6.0)
27+
opts_br = ContinuationPar(max_steps = 500, p_min = p_span[1], p_max = p_span[2])
28+
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)
29+
30+
# Computes bifurcation diagram using BifurcationKit directly (without going through MTK).
31+
function f_BK(u, p)
32+
x, y = u
33+
μ, α = p
34+
return* x - x^3 + α * y, -y]
35+
end
36+
bprob_BK = BifurcationProblem(f_BK,
37+
[1.0, 1.0],
38+
[-1.0, 1.0],
39+
(@lens _[1]);
40+
record_from_solution = (x, p) -> x[1])
41+
bif_dia_BK = bifurcationdiagram(bprob_BK,
42+
PALC(),
43+
2,
44+
(args...) -> opts_br;
45+
bothside = true)
46+
47+
# Compares results.
48+
@test getfield.(bif_dia.γ.branch, :x) getfield.(bif_dia_BK.γ.branch, :x)
49+
@test getfield.(bif_dia.γ.branch, :param) getfield.(bif_dia_BK.γ.branch, :param)
50+
@test bif_dia.γ.specialpoint[1].x == bif_dia_BK.γ.specialpoint[1].x
51+
@test bif_dia.γ.specialpoint[1].param == bif_dia_BK.γ.specialpoint[1].param
52+
@test bif_dia.γ.specialpoint[1].type == bif_dia_BK.γ.specialpoint[1].type
53+
end
54+
55+
# Lotka–Volterra model, checks exact position of bifurcation variable and bifurcation points.
56+
# Checks using ODESystem input.
57+
let
58+
# Creates a Lotka–Volterra model.
59+
@parameters α a b
60+
@variables t x(t) y(t) z(t)
61+
D = Differential(t)
62+
eqs = [D(x) ~ -x + a * y + x^2 * y,
63+
D(y) ~ b - a * y - x^2 * y]
64+
@named sys = ODESystem(eqs)
65+
66+
# Creates BifurcationProblem
67+
bprob = BifurcationProblem(sys,
68+
[x => 1.5, y => 1.0],
69+
[a => 0.1, b => 0.5],
70+
b;
71+
plot_var = x)
72+
73+
# Computes bifurcation diagram.
74+
p_span = (0.0, 2.0)
75+
opt_newton = NewtonPar(tol = 1e-9, max_iterations = 2000)
76+
opts_br = ContinuationPar(dsmax = 0.05,
77+
max_steps = 500,
78+
newton_options = opt_newton,
79+
p_min = p_span[1],
80+
p_max = p_span[2],
81+
n_inversion = 4)
82+
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)
83+
84+
# Tests that the diagram has the correct values (x = b)
85+
all([b.x b.param for b in bif_dia.γ.branch])
86+
87+
# Tests that we get two Hopf bifurcations at the correct positions.
88+
hopf_points = sort(getfield.(filter(sp -> sp.type == :hopf, bif_dia.γ.specialpoint),
89+
:x);
90+
by = x -> x[1])
91+
@test length(hopf_points) == 2
92+
@test hopf_points[1] [0.41998733080424205, 1.5195495712453098]
93+
@test hopf_points[2] [0.7899715592573977, 1.0910379583813192]
94+
end
95+
96+
# Simple fold bifurcation model, checks exact position of bifurcation variable and bifurcation points.
97+
# Checks that default parameter values are accounted for.
98+
# Checks that observables (that depend on other observables, as in this case) are accounted for.
99+
let
100+
# Creates model, and uses `structural_simplify` to generate observables.
101+
@parameters μ p=2
102+
@variables t x(t) y(t) z(t)
103+
D = Differential(t)
104+
eqs = [0 ~ μ - x^3 + 2x^2,
105+
0 ~ p * μ - y,
106+
0 ~ y - z]
107+
@named nsys = NonlinearSystem(eqs, [x, y, z], [μ, p])
108+
nsys = structural_simplify(nsys)
109+
110+
# Creates BifurcationProblem.
111+
bif_par = μ
112+
p_start ==> 1.0]
113+
u0_guess = [x => 1.0, y => 0.1, z => 0.1]
114+
plot_var = x
115+
bprob = BifurcationProblem(nsys, u0_guess, p_start, bif_par; plot_var = plot_var)
116+
117+
# Computes bifurcation diagram.
118+
p_span = (-4.3, 12.0)
25119
opt_newton = NewtonPar(tol = 1e-9, max_iterations = 20)
26-
opts_br = ContinuationPar(dsmin = 0.001, dsmax = 0.05, ds = 0.01,
27-
max_steps = 100, nev = 2, newton_options = opt_newton,
28-
p_min = p_span[1], p_max = p_span[2],
29-
detect_bifurcation = 3, n_inversion = 4, tol_bisection_eigenvalue = 1e-8,
30-
dsmin_bisection = 1e-9)
120+
opts_br = ContinuationPar(dsmax = 0.05,
121+
max_steps = 500,
122+
newton_options = opt_newton,
123+
p_min = p_span[1],
124+
p_max = p_span[2],
125+
n_inversion = 4)
126+
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)
31127

32-
bf = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)
128+
# Tests that the diagram has the correct values (x = b)
129+
all([b.x 2 * b.param for b in bif_dia.γ.branch])
33130

34-
@test length(bf.child) == 2
131+
# Tests that we get two fold bifurcations at the correct positions.
132+
fold_points = sort(getfield.(filter(sp -> sp.type == :bp, bif_dia.γ.specialpoint),
133+
:param))
134+
@test length(fold_points) == 2
135+
@test fold_points [-1.1851851706940317, -5.6734983580551894e-6] # test that they occur at the correct parameter values).
35136
end

0 commit comments

Comments
 (0)