Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
vyudu committed Sep 19, 2024
1 parent 154bb6f commit d86bf72
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 49 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
<!-- [![Coverage Status](https://coveralls.io/repos/github/SciML/JumpProcesses.jl/badge.svg?branch=master)](https://coveralls.io/github/SciML/JumpProcesses.jl?branch=master)
[![codecov](https://codecov.io/gh/SciML/JumpProcesses.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/SciML/JumpProcesses.jl) -->
<!-- [![Join the chat at https://julialang.zulipchat.com #sciml-bridged](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/279055-sciml-bridged) -->

[![Build Status](https://github.com/SciML/JumpProcesses.jl/workflows/CI/badge.svg)](https://github.com/SciML/JumpProcesses.jl/actions?query=workflow%3ACI)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)
Expand Down
4 changes: 2 additions & 2 deletions src/aggregators/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ evolution, Journal of Machine Learning Research 18(1), 1305–1353 (2017). doi:
struct Coevolve <: AbstractAggregatorAlgorithm end

"""
A constant-complexity NRM method. Stores next reaction times in a table with a specified bin width.
A constant-complexity NRM method. Stores next reaction times in a table with a specified bin width.
Kevin R. Sanft and Hans G. Othmer, Constant-complexity stochastic simulation
algorithm with optimal binning, Journal of Chemical Physics 143, 074108
Expand Down Expand Up @@ -173,7 +173,7 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108
struct DirectCRDirect <: AbstractAggregatorAlgorithm end

const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(),
FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve(), CCNRM())
FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve(), CCNRM())

# For JumpProblem construction without an aggregator
struct NullAggregator <: AbstractAggregatorAlgorithm end
Expand Down
12 changes: 7 additions & 5 deletions src/aggregators/ccnrm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ function CCNRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T,
add_self_dependencies!(dg)
end

ptt = PriorityTimeTable(zeros(T, length(crs)), 0., 1.) # We will re-initialize this in initialize!()
ptt = PriorityTimeTable(zeros(T, length(crs)), 0.0, 1.0) # We will re-initialize this in initialize!()

affecttype = F2 <: Tuple ? F2 : Any
CCNRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(ptt)}(nj, nj, njt, et,
CCNRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(ptt)}(
nj, nj, njt, et,
crs, sr, maj,
rs, affs!, sps,
rng, dg, ptt)
Expand Down Expand Up @@ -104,7 +105,8 @@ function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t)

@inbounds for rx in dep_rxs
oldrate = cur_rates[rx]
times = ptt.times; oldtime = times[rx]
times = ptt.times
oldtime = times[rx]

# update the jump rate
@inbounds cur_rates[rx] = calculate_jump_rate(ma_jumps, num_majumps, rates, u,
Expand All @@ -115,13 +117,13 @@ function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t)
if cur_rates[rx] > zero(eltype(cur_rates))
update!(ptt, rx, oldtime, t + oldrate / cur_rates[rx] * (times[rx] - t))
else
update!(ptt, rx, oldtime, 2*end_time)
update!(ptt, rx, oldtime, 2 * end_time)
end
else
if cur_rates[rx] > zero(eltype(cur_rates))
update!(ptt, rx, oldtime, t + randexp(p.rng) / cur_rates[rx])
else
update!(ptt, rx, oldtime, 2*end_time)
update!(ptt, rx, oldtime, 2 * end_time)
end
end
end
Expand Down
64 changes: 36 additions & 28 deletions src/aggregators/prioritytable.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
"""
Dynamic table data structure to store and update priorities
Implementation
Expand Down Expand Up @@ -142,7 +142,6 @@ end
pt.gsum
end


"""
Adds extra groups to the table to accommodate a new maxpriority.
"""
Expand Down Expand Up @@ -185,7 +184,7 @@ function insert!(pt::PriorityTable, pid, priority)
push!(pidtogroup, (gid, pididx))
end

gid
nothing
end

function update!(pt::PriorityTable, pid, oldpriority, newpriority)
Expand Down Expand Up @@ -231,7 +230,6 @@ function update!(pt::PriorityTable, pid, oldpriority, newpriority)
nothing
end


function reset!(pt::PriorityTable{F, S, T, U}) where {F, S, T, U}
@unpack groups, gsums, pidtogroup = pt
pt.gsum = zero(F)
Expand Down Expand Up @@ -328,7 +326,7 @@ mutable struct TimeGrouper{T <: Number}
timestep::T
end

function (t::TimeGrouper)(time)
function (t::TimeGrouper)(time)
return floor(Int, (time - t.mintime) / t.timestep) + 1
end

Expand All @@ -346,7 +344,7 @@ end
# Construct the time table with the default optimal bin width and number of bins.
# DEFAULT NUMBINS: 20 * √length(times)
# DEFAULT BINWIDTH: 16 / sum(propensities)
function PriorityTimeTable(times::AbstractVector, mintime, timestep)
function PriorityTimeTable(times::AbstractVector, mintime, timestep)
numbins = floor(Int, 20 * sqrt(length(times)))
maxtime = mintime + numbins * timestep

Expand All @@ -362,7 +360,7 @@ function PriorityTimeTable(times::AbstractVector, mintime, timestep)

# Create the groups, [t_min, t_min + τ), [t_min + τ, t_min + 2τ)...
for i in 1:numbins
push!(groups, PriorityGroup{pidtype}(mintime + i*timestep))
push!(groups, PriorityGroup{pidtype}(mintime + i * timestep))
end

pt = PriorityTable(mintime, maxtime, groups, gsums, gsum, pidtogroup, timetogroup)
Expand All @@ -376,46 +374,53 @@ function PriorityTimeTable(times::AbstractVector, mintime, timestep)
gid = insert!(pt, pid, time)
end

minbin = findfirst(g -> g.numpids >(0), pt.groups); minbin === nothing && (minbin = 0)
minbin = findfirst(g -> g.numpids > (0), pt.groups)
minbin === nothing && (minbin = 0)
ptt = PriorityTimeTable(pt, times, ttgdata, minbin, 0)
end

# Rebuild the table when there are no more reaction times within the current
# time window.
function rebuild!(ptt::PriorityTimeTable, mintime, timestep)
@unpack pt, times, timegrouper = ptt
reset!(pt); groups = pt.groups
reset!(pt)
groups = pt.groups

numbins = floor(Int, 20 * sqrt(length(times)))
pt.minpriority = mintime; pt.maxpriority = mintime + numbins*timestep
timegrouper.mintime = mintime; timegrouper.timestep = timestep
pt.minpriority = mintime
pt.maxpriority = mintime + numbins * timestep
timegrouper.mintime = mintime
timegrouper.timestep = timestep
for i in 1:numbins
groups[i].maxpriority = mintime + i*timestep
groups[i].maxpriority = mintime + i * timestep
end

# Reinsert the times into the groups.
for (id, time) in enumerate(times)
time > pt.maxpriority && continue
insert!(pt, id, time)
end
ptt.minbin = findfirst(g -> g.numpids >(0), pt.groups); ptt.minbin === nothing && (ptt.minbin = 0)
ptt.minbin = findfirst(g -> g.numpids > (0), pt.groups)
ptt.minbin === nothing && (ptt.minbin = 0)
ptt.steps = 0
end

# Get the reaction with the earliest timestep.
function getfirst(ptt::PriorityTimeTable)
@unpack pt, times, minbin = ptt
minbin == 0 && return (nothing, nothing)
groups = pt.groups; gsums = pt.gsums
groups = pt.groups
gsums = pt.gsums

while groups[minbin].numpids == 0
minbin += 1
if minbin > length(groups)
return (nothing, nothing)
return (nothing, nothing)
end
end

ptt.minbin = minbin; ptt.steps += 1

ptt.minbin = minbin
ptt.steps += 1
pids = ids(groups[minbin])
min_time, min_idx = findmin(times[pids])
return pids[min_idx], min_time
Expand All @@ -424,26 +429,28 @@ end
# Update the priority table when a reaction time gets updated. We only shift
# between bins if the new time is within the current time window; otherwise
# we remove the reaction and wait until rebuild.
function update!(ptt::PriorityTimeTable, pid, oldtime, newtime)
function update!(ptt::PriorityTimeTable, pid, oldtime, newtime)
@unpack pt, times, timegrouper = ptt
maxtime = pt.maxpriority; pidtogroup = pt.pidtogroup; groups = pt.groups
maxtime = pt.maxpriority
pidtogroup = pt.pidtogroup
groups = pt.groups

times[pid] = newtime
if oldtime >= maxtime
if oldtime >= maxtime
# If a reaction comes back into the time window, insert it.
newtime < maxtime ? insert!(pt, pid, newtime) : return
elseif newtime >= maxtime
# If the new time lands outside of current window, remove it.
pop!(pt, pid, oldtime)
pop!(pt, pid, oldtime)
else
# Move bins if the reaction was already inside.
move_bins!(pt, pid, oldtime, newtime)
move_bins!(pt, pid, oldtime, newtime)
end
end

function pop!(pt::PriorityTable, pid, oldtime)
function pop!(pt::PriorityTable, pid, oldtime)
@unpack pidtogroup, groups = pt
@inbounds begin
@inbounds begin
gid, pidx = pidtogroup[pid]
movedpid = remove!(groups[gid], pidx)
pidtogroup[movedpid] = (gid, pidx)
Expand All @@ -452,10 +459,11 @@ function pop!(pt::PriorityTable, pid, oldtime)
gid
end

function move_bins!(pt::PriorityTable, pid, oldtime, newtime)
@unpack pidtogroup, groups, priortogid = pt
oldgid = priortogid(oldtime); newgid = priortogid(newtime)
oldgid == newgid && return
function move_bins!(pt::PriorityTable, pid, oldtime, newtime)
@unpack pidtogroup, groups, priortogid = pt
oldgid = priortogid(oldtime)
newgid = priortogid(newtime)
oldgid == newgid && return
@inbounds begin
rank = pidtogroup[pid][2]
movedpid = remove!(groups[oldgid], rank)
Expand Down
2 changes: 1 addition & 1 deletion test/bimolerx_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ doprintmeans = false
# SSAs to test
SSAalgs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator())

Nsims = 32000
Nsims = 32000
tf = 0.01
u0 = [200, 100, 150]
expected_avg = 84.876015624999994
Expand Down
24 changes: 12 additions & 12 deletions test/table_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,36 +62,36 @@ for i in 1:Nsamps
end
@test abs(cnt // Nsamps - 0.008968535978248484) / 0.008968535978248484 < 0.05



##### PRIORITY TIME TABLE TESTS FOR CCNRM
mintime = 0.; maxtime = 100.; timestep = 1.5
times = [2., 8., 13., 15., 74.]
mintime = 0.0;
maxtime = 100.0;
timestep = 1.5;
times = [2.0, 8.0, 13.0, 15.0, 74.0]

ptt = DJ.PriorityTimeTable(times, mintime, timestep)
@test DJ.getfirst(ptt) == (1, 2.)
@test DJ.getfirst(ptt) == (1, 2.0)
@test ptt.pt.pidtogroup[5] == (0, 0) # Should not store the last one, outside the window.

# Test update
DJ.update!(ptt, 1, times[1], 10*times[1]) # 2. -> 20., group 2 to group 14
DJ.update!(ptt, 1, times[1], 10 * times[1]) # 2. -> 20., group 2 to group 14
@test ptt.pt.groups[14].numpids == 1
@test DJ.getfirst(ptt) == (2, 8.)
@test DJ.getfirst(ptt) == (2, 8.0)
# Updating beyond the time window should not change the max priority.
DJ.update!(ptt, 1, times[1], 70.) # 20. -> 70.
DJ.update!(ptt, 1, times[1], 70.0) # 20. -> 70.
@test ptt.pt.groups[14].numpids == 0
@test ptt.pt.maxpriority == 66.
@test ptt.pt.maxpriority == 66.0
@test ptt.pt.pidtogroup[1] == (0, 0)

# Test rebuild
for i in 1:4
DJ.update!(ptt, i, times[i], times[i] + 66.)
DJ.update!(ptt, i, times[i], times[i] + 66.0)
end
@test DJ.getfirst(ptt) === (nothing, nothing) # No more left.

mintime = 66.; timestep = .75
mintime = 66.0;
timestep = 0.75;
DJ.rebuild!(ptt, mintime, timestep)
@test ptt.pt.groups[11].numpids == 2 # 73.5-74.25
@test ptt.pt.groups[18].numpids == 1
@test ptt.pt.groups[21].numpids == 1
@test ptt.pt.pidtogroup[1] == (0, 0)

0 comments on commit d86bf72

Please sign in to comment.