Skip to content

Commit 08c1753

Browse files
829 Optimization of TestingStrategy (#830)
Performance optimization of TestingStrategies - vector instead of unordered_map for testing schemes): ~25% decrease of run time - switch of ifs in testing strategy: additional ~8 % decrease of run time - CustomIndexArray for GoToWork/GoToSchool paramater: no measureable change - bitset for agegroups in TestingCriteria with fixed number of age groups: ~7% decrease of run time Co-authored-by: DavidKerkmann <[email protected]>
1 parent 8ed8364 commit 08c1753

10 files changed

+165
-62
lines changed

cpp/examples/abm_minimal.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,12 @@ int main()
4141
world.parameters.get<mio::abm::IncubationPeriod>() = 4.;
4242

4343
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
44-
world.parameters.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
44+
world.parameters.get<mio::abm::AgeGroupGotoSchool>() = false;
45+
world.parameters.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
4546
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 and 35-59)
46-
world.parameters.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
47+
world.parameters.get<mio::abm::AgeGroupGotoWork>() = false;
48+
world.parameters.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
49+
world.parameters.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;
4750

4851
// Check if the parameters satisfy their contraints.
4952
world.parameters.check_constraints();
@@ -169,4 +172,4 @@ int main()
169172
std::cout << "Results written to abm_minimal.txt" << std::endl;
170173

171174
return 0;
172-
}
175+
}

cpp/models/abm/config.h

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright (C) 2020-2024 MEmilio
3+
*
4+
* Authors: Daniel Abele
5+
*
6+
* Contact: Martin J. Kuehn <[email protected]>
7+
*
8+
* Licensed under the Apache License, Version 2.0 (the "License");
9+
* you may not use this file except in compliance with the License.
10+
* You may obtain a copy of the License at
11+
*
12+
* http://www.apache.org/licenses/LICENSE-2.0
13+
*
14+
* Unless required by applicable law or agreed to in writing, software
15+
* distributed under the License is distributed on an "AS IS" BASIS,
16+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
* See the License for the specific language governing permissions and
18+
* limitations under the License.
19+
*/
20+
#ifndef MIO_ABM_CONFIG_H
21+
#define MIO_ABM_CONFIG_H
22+
23+
namespace mio
24+
{
25+
namespace abm
26+
{
27+
28+
/**
29+
* Maximum number of age groups allowed in the model.
30+
*/
31+
const constexpr int MAX_NUM_AGE_GROUPS = 64;
32+
33+
}
34+
} // namespace mio
35+
36+
#endif

cpp/models/abm/migration_rules.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ LocationType go_to_school(Person::RandomNumberGenerator& /*rng*/, const Person&
5656
if (current_loc == LocationType::Home && t < params.get<LockdownDate>() && t.day_of_week() < 5 &&
5757
person.get_go_to_school_time(params) >= t.time_since_midnight() &&
5858
person.get_go_to_school_time(params) < t.time_since_midnight() + dt &&
59-
params.get<mio::abm::AgeGroupGotoSchool>().count(person.get_age()) && person.goes_to_school(t, params) &&
59+
params.get<mio::abm::AgeGroupGotoSchool>()[person.get_age()] && person.goes_to_school(t, params) &&
6060
!person.is_in_quarantine()) {
6161
return LocationType::School;
6262
}
@@ -73,7 +73,7 @@ LocationType go_to_work(Person::RandomNumberGenerator& /*rng*/, const Person& pe
7373
auto current_loc = person.get_location().get_type();
7474

7575
if (current_loc == LocationType::Home && t < params.get<LockdownDate>() &&
76-
params.get<mio::abm::AgeGroupGotoWork>().count(person.get_age()) && t.day_of_week() < 5 &&
76+
params.get<mio::abm::AgeGroupGotoWork>()[person.get_age()] && t.day_of_week() < 5 &&
7777
t.time_since_midnight() + dt > person.get_go_to_work_time(params) &&
7878
t.time_since_midnight() <= person.get_go_to_work_time(params) && person.goes_to_work(t, params) &&
7979
!person.is_in_quarantine()) {

cpp/models/abm/parameters.h

+11-6
Original file line numberDiff line numberDiff line change
@@ -497,10 +497,12 @@ struct GotoSchoolTimeMaximum {
497497
* @brief The set of AgeGroups that can go to school.
498498
*/
499499
struct AgeGroupGotoSchool {
500-
using Type = std::set<AgeGroup>;
501-
static Type get_default(AgeGroup /*size*/)
500+
using Type = CustomIndexArray<bool, AgeGroup>;
501+
static Type get_default(AgeGroup num_agegroups)
502502
{
503-
return std::set<AgeGroup>{AgeGroup(1)};
503+
auto a = Type(num_agegroups, false);
504+
a[AgeGroup(1)] = true;
505+
return a;
504506
}
505507
static std::string name()
506508
{
@@ -512,10 +514,13 @@ struct AgeGroupGotoSchool {
512514
* @brief The set of AgeGroups that can go to work.
513515
*/
514516
struct AgeGroupGotoWork {
515-
using Type = std::set<AgeGroup>;
516-
static Type get_default(AgeGroup /*size*/)
517+
using Type = CustomIndexArray<bool, AgeGroup>;
518+
static Type get_default(AgeGroup num_agegroups)
517519
{
518-
return std::set<AgeGroup>{AgeGroup(2), AgeGroup(3)};
520+
auto a = Type(num_agegroups, false);
521+
a[AgeGroup(2)] = true;
522+
a[AgeGroup(3)] = true;
523+
return a;
519524
}
520525
static std::string name()
521526
{

cpp/models/abm/testing_strategy.cpp

+52-25
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace abm
2929
TestingCriteria::TestingCriteria(const std::vector<AgeGroup>& ages, const std::vector<InfectionState>& infection_states)
3030
{
3131
for (auto age : ages) {
32-
m_ages.insert(static_cast<size_t>(age));
32+
m_ages.set(static_cast<size_t>(age), true);
3333
}
3434
for (auto infection_state : infection_states) {
3535
m_infection_states.set(static_cast<size_t>(infection_state), true);
@@ -43,13 +43,12 @@ bool TestingCriteria::operator==(const TestingCriteria& other) const
4343

4444
void TestingCriteria::add_age_group(const AgeGroup age_group)
4545
{
46-
47-
m_ages.insert(static_cast<size_t>(age_group));
46+
m_ages.set(static_cast<size_t>(age_group), true);
4847
}
4948

5049
void TestingCriteria::remove_age_group(const AgeGroup age_group)
5150
{
52-
m_ages.erase(static_cast<size_t>(age_group));
51+
m_ages.set(static_cast<size_t>(age_group), false);
5352
}
5453

5554
void TestingCriteria::add_infection_state(const InfectionState infection_state)
@@ -65,7 +64,7 @@ void TestingCriteria::remove_infection_state(const InfectionState infection_stat
6564
bool TestingCriteria::evaluate(const Person& p, TimePoint t) const
6665
{
6766
// An empty vector of ages or none bitset of #InfectionStates% means that no condition on the corresponding property is set.
68-
return (m_ages.empty() || m_ages.count(static_cast<size_t>(p.get_age()))) &&
67+
return (m_ages.none() || m_ages[static_cast<size_t>(p.get_age())]) &&
6968
(m_infection_states.none() || m_infection_states[static_cast<size_t>(p.get_infection_state(t))]);
7069
}
7170

@@ -104,9 +103,9 @@ void TestingScheme::update_activity_status(TimePoint t)
104103
bool TestingScheme::run_scheme(Person::RandomNumberGenerator& rng, Person& person, TimePoint t) const
105104
{
106105
if (person.get_time_since_negative_test() > m_minimal_time_since_last_test) {
107-
double random = UniformDistribution<double>::get_instance()(rng);
108-
if (random < m_probability) {
109-
if (m_testing_criteria.evaluate(person, t)) {
106+
if (m_testing_criteria.evaluate(person, t)) {
107+
double random = UniformDistribution<double>::get_instance()(rng);
108+
if (random < m_probability) {
110109
return !person.get_tested(rng, t, m_test_type.get_default());
111110
}
112111
}
@@ -116,23 +115,45 @@ bool TestingScheme::run_scheme(Person::RandomNumberGenerator& rng, Person& perso
116115

117116
TestingStrategy::TestingStrategy(
118117
const std::unordered_map<LocationId, std::vector<TestingScheme>>& location_to_schemes_map)
119-
: m_location_to_schemes_map(location_to_schemes_map)
118+
: m_location_to_schemes_map(location_to_schemes_map.begin(), location_to_schemes_map.end())
120119
{
121120
}
122121

123122
void TestingStrategy::add_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme)
124123
{
125-
auto& schemes_vector = m_location_to_schemes_map[loc_id];
126-
if (std::find(schemes_vector.begin(), schemes_vector.end(), scheme) == schemes_vector.end()) {
127-
schemes_vector.emplace_back(scheme);
124+
auto iter_schemes =
125+
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [loc_id](auto& p) {
126+
return p.first == loc_id;
127+
});
128+
if (iter_schemes == m_location_to_schemes_map.end()) {
129+
//no schemes for this location yet, add a new list with one scheme
130+
m_location_to_schemes_map.emplace_back(loc_id, std::vector<TestingScheme>(1, scheme));
131+
}
132+
else {
133+
//add scheme to existing vector if the scheme doesn't exist yet
134+
auto& schemes = iter_schemes->second;
135+
if (std::find(schemes.begin(), schemes.end(), scheme) == schemes.end()) {
136+
schemes.push_back(scheme);
137+
}
128138
}
129139
}
130140

131141
void TestingStrategy::remove_testing_scheme(const LocationId& loc_id, const TestingScheme& scheme)
132142
{
133-
auto& schemes_vector = m_location_to_schemes_map[loc_id];
134-
auto last = std::remove(schemes_vector.begin(), schemes_vector.end(), scheme);
135-
schemes_vector.erase(last, schemes_vector.end());
143+
auto iter_schemes =
144+
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [loc_id](auto& p) {
145+
return p.first == loc_id;
146+
});
147+
if (iter_schemes != m_location_to_schemes_map.end()) {
148+
//remove the scheme from the list
149+
auto& schemes_vector = iter_schemes->second;
150+
auto last = std::remove(schemes_vector.begin(), schemes_vector.end(), scheme);
151+
schemes_vector.erase(last, schemes_vector.end());
152+
//delete the list of schemes for this location if no schemes left
153+
if (schemes_vector.empty()) {
154+
m_location_to_schemes_map.erase(iter_schemes);
155+
}
156+
}
136157
}
137158

138159
void TestingStrategy::update_activity_status(TimePoint t)
@@ -152,16 +173,22 @@ bool TestingStrategy::run_strategy(Person::RandomNumberGenerator& rng, Person& p
152173
return true;
153174
}
154175

155-
// Combine two vectors of schemes at corresponding location and location stype
156-
std::vector<TestingScheme>* schemes_vector[] = {
157-
&m_location_to_schemes_map[LocationId{location.get_index(), location.get_type()}],
158-
&m_location_to_schemes_map[LocationId{INVALID_LOCATION_INDEX, location.get_type()}]};
159-
160-
for (auto vec_ptr : schemes_vector) {
161-
if (!std::all_of(vec_ptr->begin(), vec_ptr->end(), [&rng, &person, t](TestingScheme& ts) {
162-
return !ts.is_active() || ts.run_scheme(rng, person, t);
163-
})) {
164-
return false;
176+
//lookup schemes for this specific location as well as the location type
177+
//lookup in std::vector instead of std::map should be much faster unless for large numbers of schemes
178+
for (auto loc_key : {LocationId{location.get_index(), location.get_type()},
179+
LocationId{INVALID_LOCATION_INDEX, location.get_type()}}) {
180+
auto iter_schemes =
181+
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [loc_key](auto& p) {
182+
return p.first == loc_key;
183+
});
184+
if (iter_schemes != m_location_to_schemes_map.end()) {
185+
//apply all testing schemes that are found
186+
auto& schemes = iter_schemes->second;
187+
if (!std::all_of(schemes.begin(), schemes.end(), [&rng, &person, t](TestingScheme& ts) {
188+
return !ts.is_active() || ts.run_scheme(rng, person, t);
189+
})) {
190+
return false;
191+
}
165192
}
166193
}
167194
return true;

cpp/models/abm/testing_strategy.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#ifndef EPI_ABM_TESTING_SCHEME_H
2121
#define EPI_ABM_TESTING_SCHEME_H
2222

23+
#include "abm/config.h"
2324
#include "abm/parameters.h"
2425
#include "abm/person.h"
2526
#include "abm/location.h"
@@ -91,7 +92,7 @@ class TestingCriteria
9192
bool evaluate(const Person& p, TimePoint t) const;
9293

9394
private:
94-
std::unordered_set<size_t> m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested.
95+
std::bitset<MAX_NUM_AGE_GROUPS> m_ages; ///< Set of #AgeGroup%s that are either allowed or required to be tested.
9596
std::bitset<(size_t)InfectionState::Count>
9697
m_infection_states; /**< BitSet of #InfectionState%s that are either allowed or required to
9798
be tested.*/
@@ -221,7 +222,7 @@ class TestingStrategy
221222
bool run_strategy(Person::RandomNumberGenerator& rng, Person& person, const Location& location, TimePoint t);
222223

223224
private:
224-
std::unordered_map<LocationId, std::vector<TestingScheme>>
225+
std::vector<std::pair<LocationId, std::vector<TestingScheme>>>
225226
m_location_to_schemes_map; ///< Set of schemes that are checked for testing.
226227
};
227228

cpp/models/abm/world.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#ifndef EPI_ABM_WORLD_H
2121
#define EPI_ABM_WORLD_H
2222

23+
#include "abm/config.h"
2324
#include "abm/location_type.h"
2425
#include "abm/parameters.h"
2526
#include "abm/location.h"
@@ -55,14 +56,15 @@ class World
5556

5657
/**
5758
* @brief Create a World.
58-
* @param[in] num_agegroups The number of AgeGroup%s in the simulated World.
59+
* @param[in] num_agegroups The number of AgeGroup%s in the simulated World. Must be less than MAX_NUM_AGE_GROUPS.
5960
*/
6061
World(size_t num_agegroups)
6162
: parameters(num_agegroups)
6263
, m_trip_list()
6364
, m_use_migration_rules(true)
6465
, m_cemetery_id(add_location(LocationType::Cemetery))
6566
{
67+
assert(num_agegroups < MAX_NUM_AGE_GROUPS && "MAX_NUM_AGE_GROUPS exceeded.");
6668
}
6769

6870
/**

cpp/tests/test_abm_lockdown_rules.cpp

+20-8
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,12 @@ TEST(TestLockdownRules, school_closure)
5353
p2.set_assigned_location(school);
5454
mio::abm::Parameters params = mio::abm::Parameters(num_age_groups);
5555
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
56-
params.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
56+
params.get<mio::abm::AgeGroupGotoSchool>() = false;
57+
params.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
5758
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59)
58-
params.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
59+
params.get<mio::abm::AgeGroupGotoWork>() = false;
60+
params.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
61+
params.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;
5962
mio::abm::set_school_closure(t, 0.7, params);
6063

6164
auto p1_rng = mio::abm::Person::RandomNumberGenerator(rng, p1);
@@ -88,9 +91,12 @@ TEST(TestLockdownRules, school_opening)
8891
p.set_assigned_location(school);
8992
mio::abm::Parameters params = mio::abm::Parameters(num_age_groups);
9093
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
91-
params.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
94+
params.get<mio::abm::AgeGroupGotoSchool>() = false;
95+
params.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
9296
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59)
93-
params.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
97+
params.get<mio::abm::AgeGroupGotoWork>() = false;
98+
params.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
99+
params.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;
94100
mio::abm::set_school_closure(t_closing, 1., params);
95101
mio::abm::set_school_closure(t_opening, 0., params);
96102

@@ -110,9 +116,12 @@ TEST(TestLockdownRules, home_office)
110116
mio::abm::Parameters params(num_age_groups);
111117

112118
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
113-
params.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
119+
params.get<mio::abm::AgeGroupGotoSchool>() = false;
120+
params.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
114121
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59)
115-
params.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
122+
params.get<mio::abm::AgeGroupGotoWork>() = false;
123+
params.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
124+
params.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;
116125

117126
mio::abm::set_home_office(t, 0.4, params);
118127

@@ -164,9 +173,12 @@ TEST(TestLockdownRules, no_home_office)
164173
p.set_assigned_location(work);
165174
mio::abm::Parameters params = mio::abm::Parameters(num_age_groups);
166175
// Set the age group the can go to school is AgeGroup(1) (i.e. 5-14)
167-
params.get<mio::abm::AgeGroupGotoSchool>() = {age_group_5_to_14};
176+
params.get<mio::abm::AgeGroupGotoSchool>() = false;
177+
params.get<mio::abm::AgeGroupGotoSchool>()[age_group_5_to_14] = true;
168178
// Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 or 35-59)
169-
params.get<mio::abm::AgeGroupGotoWork>() = {age_group_15_to_34, age_group_35_to_59};
179+
params.get<mio::abm::AgeGroupGotoWork>() = false;
180+
params.get<mio::abm::AgeGroupGotoWork>()[age_group_15_to_34] = true;
181+
params.get<mio::abm::AgeGroupGotoWork>()[age_group_35_to_59] = true;
170182

171183
mio::abm::set_home_office(t_closing, 0.5, params);
172184
mio::abm::set_home_office(t_opening, 0., params);

0 commit comments

Comments
 (0)