
Commit

Merge pull request #507 from sathvikbhagavan/sb/refactor
refactor: use SurrogatesBase for all surrogates
ChrisRackauckas authored Feb 5, 2025
2 parents 208d501 + d296c35 commit bec1cd9
Showing 41 changed files with 432 additions and 825 deletions.
.typos.toml: 5 changes (4 additions & 1 deletion)
@@ -1,2 +1,5 @@
 [default.extend-words]
-ND = "ND"
+ND = "ND"
+abl = "abl"
+eis = "eis"
+EIN = "EIN"
Project.toml: 14 changes (7 additions & 7 deletions)
@@ -13,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+SurrogatesBase = "89f642e6-4179-4274-8202-c11f4bd9a72c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
@@ -21,21 +22,21 @@ ChainRulesCore = "1.19.1"
 Cubature = "1.5"
 Distributions = "0.25.71"
 ExtendableSparse = "1"
-Flux = "0.13, 0.14"
+Flux = "0.15, 0.16"
 ForwardDiff = "0.10.19"
 GLM = "1.5"
 IterativeSolvers = "0.9"
 LinearAlgebra = "1.10"
-Pkg = "1"
+Pkg = "1.10"
 PolyChaos = "0.2.5"
 QuadGK = "2.4"
 QuasiMonteCarlo = "0.3.1"
 SafeTestsets = "0.1"
 SparseArrays = "1.10"
 Statistics = "1.10"
-Test = "1"
-Tracker = "0.2.18"
-Zygote = "0.6.62, 0.7"
+SurrogatesBase = "1.1.0"
+Test = "1.10"
+Zygote = "0.7"
 julia = "1.10"

 [extras]
@@ -48,7 +49,6 @@ PolyChaos = "8d666b04-775d-5f6e-b778-5ac7c70f65a3"
 QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

 [targets]
-test = ["Aqua", "Cubature", "SafeTestsets", "Flux", "ForwardDiff", "PolyChaos", "QuadGK", "Test", "Tracker", "Pkg"]
+test = ["Aqua", "Cubature", "SafeTestsets", "Flux", "ForwardDiff", "PolyChaos", "QuadGK", "Test", "Pkg"]
docs/Project.toml: 2 changes (1 addition & 1 deletion)
@@ -16,7 +16,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [compat]
 AbstractGPs = "0.5.13"
 Documenter = "1"
-Flux = "0.13.7, 0.14, 0.15, 0.16"
+Flux = "0.15, 0.16"
 Plots = "1.36.2"
 QuadGK = "2.6.0"
 SurrogatesAbstractGPs = "0.1.0"
lib/SurrogatesAbstractGPs/Project.toml: 2 changes (1 addition & 1 deletion)
@@ -12,7 +12,7 @@ AbstractGPs = "0.5"
 SafeTestsets = "0.1"
 Surrogates = "6.9"
 SurrogatesBase = "1.1"
-Zygote = "0.6"
+Zygote = "0.7"
 julia = "1.10"

 [extras]
lib/SurrogatesFlux/Project.toml: 8 changes (7 additions & 1 deletion)
@@ -5,11 +5,17 @@ version = "0.1.1"

 [deps]
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Surrogates = "6fc51010-71bc-11e9-0e15-a3fcc6593c49"
+SurrogatesBase = "89f642e6-4179-4274-8202-c11f4bd9a72c"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
-Flux = "0.13, 0.14, 0.15, 0.16"
+Flux = "0.15, 0.16"
+Optimisers = "0.4.4"
 Surrogates = "6"
+SurrogatesBase = "1.1.0"
+Zygote = "0.7"
 julia = "1.10"

 [extras]
lib/SurrogatesFlux/src/SurrogatesFlux.jl: 119 changes (81 additions & 38 deletions)
@@ -1,66 +1,109 @@
 module SurrogatesFlux

-import Surrogates: add_point!, AbstractSurrogate, _check_dimension
-export NeuralSurrogate
-
+using SurrogatesBase
+using Optimisers
 using Flux

-mutable struct NeuralSurrogate{X, Y, M, L, O, P, N, A, U} <: AbstractSurrogate
+export NeuralSurrogate, update!
+
+mutable struct NeuralSurrogate{X, Y, M, L, O, P, N, A, U} <: AbstractDeterministicSurrogate
     x::X
     y::Y
     model::M
     loss::L
     opt::O
     ps::P
-    n_echos::N
+    n_epochs::N
     lb::A
     ub::U
 end

"""
NeuralSurrogate(x,y,lb,ub,model,loss,opt,n_echos)
NeuralSurrogate(x, y, lb, ub; model = Chain(Dense(length(x[1]), 1), first),
loss = (x, y) -> Flux.mse(model(x), y),
opt = Optimisers.Adam(1e-3),
n_epochs = 10)
## Arguments
- model: Flux layers
- loss: loss function
- opt: optimization function
- `x`: Input data points.
- `y`: Output data points.
- `lb`: Lower bound of input data points.
- `ub`: Upper bound of output data points.
# Keyword Arguments
- `model`: Flux Chain
- `loss`: loss function from minimization
- `opt`: Optimiser defined using Optimisers.jl
- `n_epochs`: number of epochs for training
"""
-function NeuralSurrogate(x, y, lb, ub; model = Chain(Dense(length(x[1]), 1), first),
-    loss = (x, y) -> Flux.mse(model(x), y), opt = Descent(0.01),
-    n_echos::Int = 1)
-    X = vec.(collect.(x))
-    data = zip(X, y)
-    ps = Flux.params(model)
-    for epoch in 1:n_echos
-        Flux.train!(loss, ps, data, opt)
+function NeuralSurrogate(x, y, lb, ub; model = Chain(Dense(length(x[1]), 1)),
+    loss = Flux.mse, opt = Optimisers.Adam(1e-3),
+    n_epochs::Int = 10)
+    if x isa Tuple
+        x = reduce(hcat, x)'
+    elseif x isa Vector{<:Tuple}
+        x = reduce(hcat, collect.(x))
+    elseif x isa Vector
+        if size(x) == (1,) && size(x[1]) == ()
+            x = hcat(x)
+        else
+            x = reduce(hcat, x)
+        end
+    end
+    y = reduce(hcat, y)
+    opt_state = Flux.setup(opt, model)
+    for _ in 1:n_epochs
+        grads = Flux.gradient(model) do m
+            result = m(x)
+            loss(result, y)
+        end
+        Flux.update!(opt_state, model, grads[1])
     end
-    return NeuralSurrogate(x, y, model, loss, opt, ps, n_echos, lb, ub)
+    ps = Flux.trainable(model)
+    return NeuralSurrogate(x, y, model, loss, opt, ps, n_epochs, lb, ub)
 end

 function (my_neural::NeuralSurrogate)(val)
-    # Check to make sure dimensions of input matches expected dimension of surrogate
-    _check_dimension(my_neural, val)
-    v = [val...]
-    out = my_neural.model(v)
-    if length(out) == 1
-        return out[1]
-    else
-        return out
-    end
+    out = my_neural.model(val)
+    return out
 end

+function (my_neural::NeuralSurrogate)(val::Tuple)
+    out = my_neural.model(collect(val))
+    return out
+end
+
+function (my_neural::NeuralSurrogate)(val::Number)
+    out = my_neural(reduce(hcat, [[val]]))
+    return out
+end

-function add_point!(my_n::NeuralSurrogate, x_new, y_new)
-    if eltype(x_new) == eltype(my_n.x)
-        append!(my_n.x, x_new)
-        append!(my_n.y, y_new)
-    else
-        push!(my_n.x, x_new)
-        push!(my_n.y, y_new)
+function SurrogatesBase.update!(my_n::NeuralSurrogate, x_new, y_new)
+    if x_new isa Tuple
+        x_new = reduce(hcat, x_new)'
+    elseif x_new isa Vector{<:Tuple}
+        x_new = reduce(hcat, collect.(x_new))
+    elseif x_new isa Vector
+        if size(x_new) == (1,) && size(x_new[1]) == ()
+            x_new = hcat(x_new)
+        else
+            x_new = reduce(hcat, x_new)
+        end
     end
-    X = vec.(collect.(my_n.x))
-    data = zip(X, my_n.y)
-    for epoch in 1:(my_n.n_echos)
-        Flux.train!(my_n.loss, my_n.ps, data, my_n.opt)
+    y_new = reduce(hcat, y_new)
+    opt_state = Flux.setup(my_n.opt, my_n.model)
+    for _ in 1:(my_n.n_epochs)
+        grads = Flux.gradient(my_n.model) do m
+            result = m(x_new)
+            my_n.loss(result, y_new)
+        end
+        Flux.update!(opt_state, my_n.model, grads[1])
     end
+    my_n.ps = Flux.trainable(my_n.model)
+    my_n.x = hcat(my_n.x, x_new)
+    my_n.y = hcat(my_n.y, y_new)
     nothing
 end
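
For orientation, here is a rough usage sketch of the refactored interface, based only on the constructor, call methods, and update! shown in this diff. The objective function f, the sampling call, and the hyperparameters below are illustrative assumptions, not part of the commit; the sampling follows the usual Surrogates.jl sample(n, lb, ub, SobolSample()) pattern.

using Surrogates, SurrogatesFlux, Flux, Optimisers

f(x) = x^2 + x                           # hypothetical 1-D objective to emulate
lb, ub = 0.0, 10.0
x = sample(20, lb, ub, SobolSample())    # sampled input points (Surrogates.jl)
y = f.(x)

# Construct and train with the new keyword interface (replaces the old positional form).
sur = NeuralSurrogate(x, y, lb, ub;
    model = Chain(Dense(1, 16, relu), Dense(16, 1)),
    loss = Flux.mse,
    opt = Optimisers.Adam(1e-3),
    n_epochs = 100)

sur(5.0)                                 # evaluate the surrogate at a new point

# Retrain in place with fresh samples; update! replaces the removed add_point!.
x_new = sample(5, lb, ub, SobolSample())
update!(sur, x_new, f.(x_new))

As the diff shows, the constructor now flattens the sampled points into a matrix and trains with explicit gradients (Flux.setup, Flux.gradient, Flux.update!) instead of the removed Flux.params/Flux.train! style.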

(The remaining changed files in this commit did not load and are not shown here.)
