Skip to content

Commit 179aca2

Browse files
committed
Fixed some bugs in how the levers were printing to the screen, plus fixed a bug
in the new functionality where a default lever for a force was not being picked up (i.e. `ghost/ghost::*` was not taking precedence over, e.g. `*::charge`) I've updated the unit tests to catch these edge cases
1 parent 5867df0 commit 179aca2

File tree

2 files changed

+115
-18
lines changed

2 files changed

+115
-18
lines changed

corelib/src/libs/SireCAS/lambdaschedule.cpp

+57-18
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,12 @@ QString LambdaSchedule::toString() const
238238
auto keys = this->stage_equations[i].keys();
239239
std::sort(keys.begin(), keys.end());
240240

241-
for (auto lever : keys)
241+
for (const auto &lever : keys)
242242
{
243+
auto output_name = lever;
244+
output_name.replace("*::", "");
243245
lines.append(QString(" %1: %2")
244-
.arg(lever.replace("*::", ""))
246+
.arg(output_name)
245247
.arg(this->stage_equations[i][lever].toOpenMMString()));
246248
}
247249
}
@@ -956,20 +958,35 @@ SireCAS::Expression LambdaSchedule::_getEquation(int stage,
956958
CODELOC);
957959

958960
const auto default_lever = _get_lever_name("*", lever);
961+
const auto default_force = _get_lever_name(force, "*");
962+
const auto lever_name = _get_lever_name(force, lever);
959963

960-
if (force == "*")
964+
const auto equations = this->stage_equations[stage];
965+
966+
// search from most specific to least specific
967+
auto it = equations.find(lever_name);
968+
969+
if (it != equations.end())
961970
{
962-
return this->stage_equations[stage].value(
963-
default_lever, this->default_equations[stage]);
971+
return it.value();
964972
}
965-
else
973+
974+
it = equations.find(default_force);
975+
976+
if (it != equations.end())
977+
{
978+
return it.value();
979+
}
980+
981+
it = equations.find(default_lever);
982+
983+
if (it != equations.end())
966984
{
967-
return this->stage_equations[stage].value(
968-
_get_lever_name(force, lever),
969-
this->stage_equations[stage].value(
970-
default_lever,
971-
this->default_equations[stage]));
985+
return it.value();
972986
}
987+
988+
// we don't have any match, so return the default equation for this stage
989+
return this->default_equations[stage];
973990
}
974991

975992
/** Return the equation used to control the specified 'lever'
@@ -1142,15 +1159,33 @@ QHash<QString, QVector<double>> LambdaSchedule::getLeverValues(
11421159
QVector<double> values(lambda_values.count(), NAN);
11431160

11441161
QHash<QString, QVector<double>> lever_values;
1145-
lever_values.reserve(this->lever_names.count() + 1);
1162+
1163+
// get all of the lever / force combinations in use
1164+
QSet<QString> all_levers;
1165+
1166+
for (const auto &equations : this->stage_equations)
1167+
{
1168+
for (const auto &lever : equations.keys())
1169+
{
1170+
all_levers.insert(lever);
1171+
}
1172+
}
1173+
1174+
QStringList levers = all_levers.values();
1175+
std::sort(levers.begin(), levers.end());
1176+
1177+
lever_values.reserve(levers.count() + 2);
11461178

11471179
lever_values.insert("λ", values);
11481180

11491181
lever_values.insert("default", values);
11501182

1151-
for (const auto &lever_name : this->lever_names)
1183+
for (const auto &lever : levers)
11521184
{
1153-
lever_values.insert(lever_name, values);
1185+
if (lever.startsWith("*::"))
1186+
lever_values.insert(lever.mid(3), values);
1187+
else
1188+
lever_values.insert(lever, values);
11541189
}
11551190

11561191
if (this->nStages() == 0)
@@ -1174,12 +1209,16 @@ QHash<QString, QVector<double>> LambdaSchedule::getLeverValues(
11741209
const auto equation = this->default_equations[stage];
11751210
lever_values["default"][i] = equation(input_values);
11761211

1177-
for (const auto &lever_name : lever_names)
1212+
for (const auto &lever : levers)
11781213
{
1179-
const auto equation = this->stage_equations[stage].value(
1180-
lever_name, this->default_equations[stage]);
1214+
auto parts = lever.split("::");
1215+
1216+
const auto equation = this->_getEquation(stage, parts[0], parts[1]);
11811217

1182-
lever_values[lever_name][i] = equation(input_values);
1218+
if (lever.startsWith("*::"))
1219+
lever_values[lever.mid(3)][i] = equation(input_values);
1220+
else
1221+
lever_values[lever][i] = equation(input_values);
11831222
}
11841223
}
11851224

tests/cas/test_lambdaschedule.py

+58
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,64 @@ def _assert_same_equation(x, eq1, eq2):
1010
assert eq1.evaluate(val) == pytest.approx(eq2.evaluate(val), 1e-5)
1111

1212

13+
def test_charge_scale():
14+
l = sr.cas.LambdaSchedule.standard_morph()
15+
16+
morph_equation = l.get_equation(stage="morph")
17+
18+
l.add_charge_scale_stages()
19+
20+
_assert_same_equation(l.lam(), l.get_equation(stage="morph"), morph_equation)
21+
22+
assert l.get_stages() == ["decharge", "morph", "recharge"]
23+
24+
assert l.get_levers() == ["charge"]
25+
26+
assert l.get_constant("γ") == 0.2
27+
28+
gamma = l.get_constant_symbol("γ")
29+
30+
scaled_morph = gamma * morph_equation
31+
32+
_assert_same_equation(
33+
l.lam(), l.get_equation(stage="morph", lever="charge"), scaled_morph
34+
)
35+
36+
_assert_same_equation(
37+
gamma, l.get_equation(stage="morph", lever="charge"), scaled_morph
38+
)
39+
40+
l.set_equation(stage="recharge", force="ghost/ghost", lever="charge", equation=0.5)
41+
42+
l.set_equation(stage="*", force="ghost/ghost", lever="*", equation=1.5)
43+
44+
_assert_same_equation(
45+
l.lam(),
46+
l.get_equation(stage="recharge", force="ghost/ghost", lever="charge"),
47+
sr.cas.Expression(0.5),
48+
)
49+
50+
_assert_same_equation(
51+
l.lam(),
52+
l.get_equation(stage="recharge", lever="LJ"),
53+
(1.0 - ((1.0 - gamma) * (1.0 - l.lam()))) * l.final(),
54+
)
55+
56+
_assert_same_equation(l.lam(), l.get_equation(stage="recharge"), l.final())
57+
58+
_assert_same_equation(
59+
l.lam(),
60+
l.get_equation(stage="recharge", force="ghost/non-ghost", lever="LJ"),
61+
l.final(),
62+
)
63+
64+
_assert_same_equation(
65+
l.lam(),
66+
l.get_equation(stage="recharge", force="ghost/ghost", lever="LJ"),
67+
sr.cas.Expression(1.5),
68+
)
69+
70+
1371
def test_lambdaschedule():
1472
l = sr.cas.LambdaSchedule.standard_morph()
1573

0 commit comments

Comments
 (0)