Skip to content

Commit 4082506

Browse files
feat: add MTKBase ext for OptimizationBase, disable MTKExt post v11
1 parent 2e65fd7 commit 4082506

File tree

3 files changed

+413
-195
lines changed

3 files changed

+413
-195
lines changed

lib/OptimizationBase/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2424
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
2525
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
2626
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
27+
ModelingToolkitBase = "7771a370-6774-4173-bd38-47e70ca0b839"
2728
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2829
SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe"
2930
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -35,6 +36,7 @@ OptimizationForwardDiffExt = "ForwardDiff"
3536
OptimizationMLDataDevicesExt = "MLDataDevices"
3637
OptimizationMLUtilsExt = "MLUtils"
3738
OptimizationMTKExt = "ModelingToolkit"
39+
OptimizationMTKBaseExt = "ModelingToolkitBase"
3840
OptimizationReverseDiffExt = "ReverseDiff"
3941
OptimizationSymbolicAnalysisExt = "SymbolicAnalysis"
4042
OptimizationZygoteExt = "Zygote"
@@ -52,6 +54,7 @@ LinearAlgebra = "1.9, 1.10"
5254
MLDataDevices = "1"
5355
MLUtils = "0.4"
5456
ModelingToolkit = "10.23"
57+
ModelingToolkitBase = "1"
5558
PDMats = "0.11"
5659
Reexport = "1.2"
5760
ReverseDiff = "1.14"
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
module OptimizationMTKBaseExt
2+
3+
import OptimizationBase, OptimizationBase.ArrayInterface
4+
import SciMLBase
5+
import SciMLBase: OptimizationFunction
6+
import OptimizationBase.ADTypes: AutoSymbolics, AutoSparse
7+
using ModelingToolkitBase
8+
9+
# Symbolic sparse-AD backend: lift `f` to a symbolic system via
# `ModelingToolkitBase.modelingtoolkitize` and let MTK generate sparse
# gradient/Hessian (and constraint Jacobian/Hessian) code. The generated
# callables take a parameter argument, so each is wrapped to close over the
# MTK problem's parameters and match the `OptimizationFunction` calling
# convention.
function OptimizationBase.instantiate_function(
        f::OptimizationFunction{true}, x, adtype::AutoSparse{<:AutoSymbolics}, p,
        num_cons = 0;
        g = false, h = false, hv = false, fg = false, fgh = false,
        cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
        lag_h = false)
    p = isnothing(p) ? SciMLBase.NullParameters() : p

    # Placeholder equality bounds (lcons == ucons == 0) make sure symbolic
    # constraints are emitted for all `num_cons` rows.
    sys = complete(ModelingToolkitBase.modelingtoolkitize(OptimizationProblem(f, x, p;
        lcons = fill(0.0, num_cons),
        ucons = fill(0.0, num_cons))))
    # `x` and `p` became defaults of `sys`, so `u0 = nothing` suffices here.
    mtkprob = OptimizationProblem(sys, nothing; grad = g, hess = h,
        sparse = true, cons_j = cons_j, cons_h = cons_h,
        cons_sparse = true)
    mtkf = mtkprob.f

    grad = (G, θ, args...) -> mtkf.grad(G, θ, mtkprob.p, args...)
    hess = (H, θ, args...) -> mtkf.hess(H, θ, mtkprob.p, args...)

    # Hessian-vector product: materialize the sparse Hessian, then multiply.
    hv = (H, θ, v, args...) -> begin
        res = similar(mtkf.hess_prototype, eltype(θ))
        hess(res, θ, args...)
        H .= res * v
    end

    has_cons = !isnothing(mtkf.cons)
    cons = has_cons ? ((res, θ) -> mtkf.cons(res, θ, mtkprob.p)) : nothing
    cons_j = has_cons ? ((J, θ) -> mtkf.cons_j(J, θ, mtkprob.p)) : nothing
    cons_h = has_cons ? ((res, θ) -> mtkf.cons_h(res, θ, mtkprob.p)) : nothing

    return OptimizationFunction{true}(mtkf.f, adtype; grad = grad, hess = hess, hv = hv,
        cons = cons, cons_j = cons_j, cons_h = cons_h,
        hess_prototype = mtkf.hess_prototype,
        cons_jac_prototype = mtkf.cons_jac_prototype,
        cons_hess_prototype = mtkf.cons_hess_prototype,
        expr = OptimizationBase.symbolify(mtkf.expr),
        cons_expr = OptimizationBase.symbolify.(mtkf.cons_expr),
        sys = sys,
        observed = mtkf.observed)
end
59+
60+
# Symbolic sparse-AD backend, re-initialization path: same contract as the
# `x`/`p` sparse method, but the initial point and parameters come from a
# `ReInitCache`.
function OptimizationBase.instantiate_function(
        f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
        adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0;
        g = false, h = false, hv = false, fg = false, fgh = false,
        cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
        lag_h = false)
    # Normalize `nothing` parameters to `NullParameters()` and pass the
    # normalized value on. (Previously `cache.p` was forwarded unchanged,
    # leaving this normalization dead and letting `nothing` leak through.)
    p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p

    # Placeholder equality bounds (lcons == ucons == 0) make sure symbolic
    # constraints are emitted for all `num_cons` rows.
    sys = complete(ModelingToolkitBase.modelingtoolkitize(OptimizationProblem(f,
        cache.u0, p;
        lcons = fill(0.0, num_cons),
        ucons = fill(0.0, num_cons))))
    # `u0` and `p` became defaults of `sys`, so `u0 = nothing` suffices here.
    mtkprob = OptimizationProblem(sys, nothing; grad = g, hess = h,
        sparse = true, cons_j = cons_j, cons_h = cons_h,
        cons_sparse = true)
    f = mtkprob.f

    # MTK-generated callables take a parameter argument; close over the
    # problem parameters to match the `OptimizationFunction` convention.
    grad = (G, θ, args...) -> f.grad(G, θ, mtkprob.p, args...)

    hess = (H, θ, args...) -> f.hess(H, θ, mtkprob.p, args...)

    # Hessian-vector product: materialize the sparse Hessian, then multiply.
    hv = function (H, θ, v, args...)
        res = similar(f.hess_prototype, eltype(θ))
        hess(res, θ, args...)
        H .= res * v
    end

    if !isnothing(f.cons)
        cons = (res, θ) -> f.cons(res, θ, mtkprob.p)
        cons_j = (J, θ) -> f.cons_j(J, θ, mtkprob.p)
        cons_h = (res, θ) -> f.cons_h(res, θ, mtkprob.p)
    else
        cons = nothing
        cons_j = nothing
        cons_h = nothing
    end

    return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
        cons = cons, cons_j = cons_j, cons_h = cons_h,
        hess_prototype = f.hess_prototype,
        cons_jac_prototype = f.cons_jac_prototype,
        cons_hess_prototype = f.cons_hess_prototype,
        expr = OptimizationBase.symbolify(f.expr),
        cons_expr = OptimizationBase.symbolify.(f.cons_expr),
        sys = sys,
        observed = f.observed)
end
110+
111+
# Symbolic dense-AD backend: lift `f` to a symbolic system via
# `ModelingToolkitBase.modelingtoolkitize` and let MTK generate dense
# gradient/Hessian (and constraint Jacobian/Hessian) code. Each generated
# callable is wrapped to close over the MTK problem's parameters so it
# matches the `OptimizationFunction` calling convention.
function OptimizationBase.instantiate_function(
        f::OptimizationFunction{true}, x, adtype::AutoSymbolics, p,
        num_cons = 0; g = false, h = false, hv = false, fg = false, fgh = false,
        cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
        lag_h = false)
    p = isnothing(p) ? SciMLBase.NullParameters() : p

    # Placeholder equality bounds (lcons == ucons == 0) make sure symbolic
    # constraints are emitted for all `num_cons` rows.
    sys = complete(ModelingToolkitBase.modelingtoolkitize(OptimizationProblem(f, x, p;
        lcons = fill(0.0, num_cons),
        ucons = fill(0.0, num_cons))))
    # `x` and `p` became defaults of `sys`, so `u0 = nothing` suffices here.
    mtkprob = OptimizationProblem(sys, nothing; grad = g, hess = h,
        sparse = false, cons_j = cons_j, cons_h = cons_h,
        cons_sparse = false)
    mtkf = mtkprob.f

    grad = (G, θ, args...) -> mtkf.grad(G, θ, mtkprob.p, args...)
    hess = (H, θ, args...) -> mtkf.hess(H, θ, mtkprob.p, args...)

    # Hessian-vector product: build the dense Hessian, then multiply by `v`.
    hv = (H, θ, v, args...) -> begin
        res = ArrayInterface.zeromatrix(θ)
        hess(res, θ, args...)
        H .= res * v
    end

    has_cons = !isnothing(mtkf.cons)
    cons = has_cons ? ((res, θ) -> mtkf.cons(res, θ, mtkprob.p)) : nothing
    cons_j = has_cons ? ((J, θ) -> mtkf.cons_j(J, θ, mtkprob.p)) : nothing
    cons_h = has_cons ? ((res, θ) -> mtkf.cons_h(res, θ, mtkprob.p)) : nothing

    return OptimizationFunction{true}(mtkf.f, adtype; grad = grad, hess = hess, hv = hv,
        cons = cons, cons_j = cons_j, cons_h = cons_h,
        hess_prototype = mtkf.hess_prototype,
        cons_jac_prototype = mtkf.cons_jac_prototype,
        cons_hess_prototype = mtkf.cons_hess_prototype,
        expr = OptimizationBase.symbolify(mtkf.expr),
        cons_expr = OptimizationBase.symbolify.(mtkf.cons_expr),
        sys = sys,
        observed = mtkf.observed)
end
160+
161+
# Symbolic dense-AD backend, re-initialization path: same contract as the
# `x`/`p` dense method, but the initial point and parameters come from a
# `ReInitCache`.
function OptimizationBase.instantiate_function(
        f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
        adtype::AutoSymbolics, num_cons = 0;
        g = false, h = false, hv = false, fg = false, fgh = false,
        cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
        lag_h = false)
    # Normalize `nothing` parameters to `NullParameters()` and pass the
    # normalized value on. (Previously `cache.p` was forwarded unchanged,
    # leaving this normalization dead and letting `nothing` leak through.)
    p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p

    # Placeholder equality bounds (lcons == ucons == 0) make sure symbolic
    # constraints are emitted for all `num_cons` rows.
    sys = complete(ModelingToolkitBase.modelingtoolkitize(OptimizationProblem(f,
        cache.u0, p;
        lcons = fill(0.0, num_cons),
        ucons = fill(0.0, num_cons))))
    # `u0` and `p` became defaults of `sys`, so `u0 = nothing` suffices here.
    mtkprob = OptimizationProblem(sys, nothing; grad = g, hess = h,
        sparse = false, cons_j = cons_j, cons_h = cons_h,
        cons_sparse = false)
    f = mtkprob.f

    # MTK-generated callables take a parameter argument; close over the
    # problem parameters to match the `OptimizationFunction` convention.
    grad = (G, θ, args...) -> f.grad(G, θ, mtkprob.p, args...)

    hess = (H, θ, args...) -> f.hess(H, θ, mtkprob.p, args...)

    # Hessian-vector product: build the dense Hessian, then multiply by `v`.
    hv = function (H, θ, v, args...)
        res = ArrayInterface.zeromatrix(θ)
        hess(res, θ, args...)
        H .= res * v
    end

    if !isnothing(f.cons)
        cons = (res, θ) -> f.cons(res, θ, mtkprob.p)
        cons_j = (J, θ) -> f.cons_j(J, θ, mtkprob.p)
        cons_h = (res, θ) -> f.cons_h(res, θ, mtkprob.p)
    else
        cons = nothing
        cons_j = nothing
        cons_h = nothing
    end

    return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
        cons = cons, cons_j = cons_j, cons_h = cons_h,
        hess_prototype = f.hess_prototype,
        cons_jac_prototype = f.cons_jac_prototype,
        cons_hess_prototype = f.cons_hess_prototype,
        expr = OptimizationBase.symbolify(f.expr),
        cons_expr = OptimizationBase.symbolify.(f.cons_expr),
        sys = sys,
        observed = f.observed)
end
212+
213+
end

0 commit comments

Comments
 (0)