Skip to content

Commit d7fd3e4

Browse files
feat: allow adding constant and nonnumeric parameters to IndexCache after construction
1 parent 0656a12 commit d7fd3e4

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

src/systems/index_cache.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,16 @@ struct IndexCache
6363
symbol_to_variable::Dict{Symbol, SymbolicParam}
6464
end
6565

66+
function Base.copy(ic::IndexCache)
67+
IndexCache(copy(ic.unknown_idx), copy(ic.discrete_idx), copy(ic.callback_to_clocks),
68+
copy(ic.tunable_idx), copy(ic.initials_idx), copy(ic.constant_idx),
69+
copy(ic.nonnumeric_idx), copy(ic.observed_syms_to_timeseries),
70+
copy(ic.dependent_pars_to_timeseries), copy(ic.discrete_buffer_sizes),
71+
ic.tunable_buffer_size, ic.initials_buffer_size,
72+
copy(ic.constant_buffer_sizes), copy(ic.nonnumeric_buffer_sizes),
73+
copy(ic.symbol_to_variable))
74+
end
75+
6676
function IndexCache(sys::AbstractSystem)
6777
unks = unknowns(sys)
6878
unk_idxs = UnknownIndexMap()
@@ -716,3 +726,55 @@ function subset_unknowns_observed(
716726
@set! ic.observed_syms_to_timeseries = observed_syms_to_timeseries
717727
return ic
718728
end
729+
730+
function with_additional_constant_parameter(sys::AbstractSystem, par)
731+
par = unwrap(par)
732+
ps = copy(get_ps(sys))
733+
push!(ps, par)
734+
@set! sys.ps = ps
735+
is_split(sys) || return sys
736+
737+
ic = copy(get_index_cache(sys))
738+
T = symtype(par)
739+
bufidx = findfirst(buft -> buft.type == T, ic.constant_buffer_sizes)
740+
if bufidx === nothing
741+
push!(ic.constant_buffer_sizes, BufferTemplate(T, 1))
742+
bufidx = length(ic.constant_buffer_sizes)
743+
idx_in_buf = 1
744+
else
745+
buft = ic.constant_buffer_sizes[bufidx]
746+
ic.constant_buffer_sizes[bufidx] = BufferTemplate(T, buft.length + 1)
747+
idx_in_buf = buft.length + 1
748+
end
749+
750+
ic.constant_idx[par] = ic.constant_idx[renamespace(sys, par)] = (bufidx, idx_in_buf)
751+
@set! sys.index_cache = ic
752+
753+
return sys
754+
end
755+
756+
function with_additional_nonnumeric_parameter(sys::AbstractSystem, par)
757+
par = unwrap(par)
758+
ps = copy(get_ps(sys))
759+
push!(ps, par)
760+
@set! sys.ps = ps
761+
is_split(sys) || return sys
762+
763+
ic = copy(get_index_cache(sys))
764+
T = symtype(par)
765+
bufidx = findfirst(buft -> buft.type == T, ic.nonnumeric_buffer_sizes)
766+
if bufidx === nothing
767+
push!(ic.nonnumeric_buffer_sizes, BufferTemplate(T, 1))
768+
bufidx = length(ic.nonnumeric_buffer_sizes)
769+
idx_in_buf = 1
770+
else
771+
buft = ic.nonnumeric_buffer_sizes[bufidx]
772+
ic.nonnumeric_buffer_sizes[bufidx] = BufferTemplate(T, buft.length + 1)
773+
idx_in_buf = buft.length + 1
774+
end
775+
776+
ic.nonnumeric_idx[par] = ic.nonnumeric_idx[renamespace(sys, par)] = (bufidx, idx_in_buf)
777+
@set! sys.index_cache = ic
778+
779+
return sys
780+
end

0 commit comments

Comments
 (0)