Skip to content

Commit a8d00b1

Browse files
committed
Add support for None bond anchors
Add test for single topology end-states
1 parent bb07be7 commit a8d00b1

File tree

4 files changed

+135
-14
lines changed

4 files changed

+135
-14
lines changed

tests/test_dummy.py

+63-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import pytest
22
from rdkit import Chem
33

4-
from timemachine.fe.dummy import generate_dummy_group_assignments
4+
from timemachine.fe.dummy import (
5+
ZeroBondAnchorWarning,
6+
generate_anchored_dummy_group_assignments,
7+
generate_dummy_group_assignments,
8+
)
59
from timemachine.graph_utils import convert_to_nx
610

711
# These tests check the various utilities used to turn off interactions
@@ -92,4 +96,61 @@ def test_generate_dummy_group_assignments():
9296
def test_generate_dummy_group_assignments_empty_core():
9397
g = convert_to_nx(Chem.MolFromSmiles("OC1COO1"))
9498
core = []
95-
assert list(generate_dummy_group_assignments(g, core)) == []
99+
100+
with pytest.warns(ZeroBondAnchorWarning):
101+
dgas = list(generate_dummy_group_assignments(g, core))
102+
103+
assert equivalent_assignment(dgas, [{None: {0, 1, 2, 3, 4}}])
104+
105+
106+
def test_generate_dummy_group_assignments_full_core():
107+
g = convert_to_nx(Chem.MolFromSmiles("OC1COO1"))
108+
core = [0, 1, 2, 3, 4]
109+
dgas = list(generate_dummy_group_assignments(g, core))
110+
assert equivalent_assignment(dgas, [{}])
111+
112+
113+
def test_generate_angle_anchor_dummy_group_assignments():
114+
# Test that if we break a core-core bond, we only have one valid
115+
# choice of the angle anchor
116+
#
117+
# O0 O0
118+
# | |
119+
# C1 C1
120+
# / \ \
121+
# O4 C2 -> O4 C2
122+
# \ / \ /
123+
# O3 O3
124+
g_a = convert_to_nx(Chem.MolFromSmiles("OC1COO1"))
125+
g_b = convert_to_nx(Chem.MolFromSmiles("OCCOO"))
126+
core_a = [1, 2, 3, 4]
127+
core_b = [1, 2, 3, 4]
128+
129+
dgas = list(generate_dummy_group_assignments(g_a, core_a))
130+
expected_dga = {1: {0}}
131+
132+
assert equivalent_assignment(dgas, [expected_dga])
133+
134+
# forward direction
135+
anchored_dummy_group_assignments = generate_anchored_dummy_group_assignments(expected_dga, g_a, g_b, core_a, core_b)
136+
137+
anchored_dummy_group_assignments = list(anchored_dummy_group_assignments)
138+
assert len(anchored_dummy_group_assignments) == 1
139+
assert anchored_dummy_group_assignments[0] == {1: (2, frozenset({0}))}
140+
141+
# reverse direction
142+
anchored_dummy_group_assignments = generate_anchored_dummy_group_assignments(expected_dga, g_b, g_a, core_b, core_a)
143+
144+
anchored_dummy_group_assignments = list(anchored_dummy_group_assignments)
145+
assert len(anchored_dummy_group_assignments) == 1
146+
assert anchored_dummy_group_assignments[0] == {1: (2, frozenset({0}))}
147+
148+
# Test that providing an empty core results in None values for both the bond anchor
149+
# and the angle anchor
150+
anchored_dummy_group_assignments = generate_anchored_dummy_group_assignments(
151+
{None: {0, 1, 2, 3, 4}}, g_a, g_b, core_a, core_b
152+
)
153+
154+
anchored_dummy_group_assignments = list(anchored_dummy_group_assignments)
155+
156+
assert anchored_dummy_group_assignments == [{None: (None, frozenset({0, 1, 2, 3, 4}))}]

tests/test_single_topology.py

+49-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
NBParamIdx,
2121
)
2222
from timemachine.fe import atom_mapping, single_topology
23-
from timemachine.fe.dummy import MultipleAnchorWarning, canonicalize_bond
23+
from timemachine.fe.dummy import MultipleBondAnchorWarning, canonicalize_bond
2424
from timemachine.fe.interpolate import align_nonbonded_idxs_and_params, linear_interpolation
2525
from timemachine.fe.single_topology import (
2626
AtomMapMixin,
@@ -36,6 +36,7 @@
3636
setup_dummy_interactions_from_ff,
3737
)
3838
from timemachine.fe.system import minimize_scipy, simulate_system
39+
from timemachine.fe.topology import BaseTopology
3940
from timemachine.fe.utils import get_mol_name, get_romol_conf, read_sdf, read_sdf_mols_by_name, set_mol_name
4041
from timemachine.ff import Forcefield
4142
from timemachine.md import minimizer
@@ -330,7 +331,7 @@ def test_find_dummy_groups_and_multiple_anchors():
330331

331332
core_pairs = np.array([[1, 1], [2, 2]])
332333

333-
with pytest.warns(MultipleAnchorWarning):
334+
with pytest.warns(MultipleBondAnchorWarning):
334335
dgs = single_topology.find_dummy_groups_and_anchors(mol_a, mol_b, core_pairs[:, 0], core_pairs[:, 1])
335336
assert dgs == {1: (2, {0})} or dgs == {2: (1, {0})}
336337

@@ -353,7 +354,7 @@ def test_find_dummy_groups_and_multiple_anchors():
353354
core_a = [0, 1, 2, 3]
354355
core_b = [2, 1, 4, 3]
355356

356-
with pytest.warns(MultipleAnchorWarning):
357+
with pytest.warns(MultipleBondAnchorWarning):
357358
dgs = single_topology.find_dummy_groups_and_anchors(mol_a, mol_b, core_a, core_b)
358359
assert dgs == {1: (2, {0})}
359360

@@ -1743,3 +1744,48 @@ def test_hif2a_end_state_symmetry_nightly_test():
17431744
print("testing", mol_a.GetProp("_Name"), "->", mol_b.GetProp("_Name"))
17441745
core = atom_mapping.get_cores(mol_a, mol_b, **DEFAULT_ATOM_MAPPING_KWARGS)[0]
17451746
assert_symmetric_interpolation(mol_a, mol_b, core)
1747+
1748+
1749+
def test_empty_core():
1750+
# test that an empty core results in an end-state with identical energies and forces to independent molecules
1751+
with path_to_internal_file("timemachine.testsystems.data", "ligands_40.sdf") as path_to_ligand:
1752+
mols = read_sdf(path_to_ligand)
1753+
1754+
mol_a = mols[0]
1755+
mol_b = mols[1]
1756+
1757+
x_a = get_romol_conf(mol_a)
1758+
x_b = get_romol_conf(mol_b)
1759+
core = np.zeros((0, 2))
1760+
1761+
ff = Forcefield.load_default()
1762+
st = SingleTopology(mol_a, mol_b, core, ff)
1763+
lhs = st.setup_intermediate_state(0.0)
1764+
rhs = st.setup_intermediate_state(1.0)
1765+
1766+
x_0 = st.combine_confs(x_a, x_b, 0.0)
1767+
x_1 = st.combine_confs(x_a, x_b, 1.0)
1768+
1769+
np.testing.assert_array_equal(x_0, x_1)
1770+
1771+
# lhs and rhs do not have identical total energies since dummy molecule is in a softened state
1772+
lhs_U_fn = lhs.get_U_fn()
1773+
rhs_U_fn = rhs.get_U_fn()
1774+
1775+
assert lhs_U_fn(x_0) != rhs_U_fn(x_0)
1776+
1777+
# however, the bond, angle, improper energies should be bitwise identical
1778+
assert lhs.bond(x_0, None) == rhs.bond(x_0, None)
1779+
assert lhs.angle(x_0, None) == rhs.angle(x_0, None)
1780+
assert lhs.improper(x_0, None) == rhs.improper(x_0, None)
1781+
1782+
# we should also be consistent with vanilla base topologies
1783+
bt_a = BaseTopology(mol_a, ff)
1784+
bt_b = BaseTopology(mol_b, ff)
1785+
1786+
ref_a = bt_a.setup_end_state()
1787+
ref_b = bt_b.setup_end_state()
1788+
1789+
np.testing.assert_almost_equal(lhs.bond(x_0, None), ref_a.bond(x_a, None) + ref_b.bond(x_b, None))
1790+
np.testing.assert_almost_equal(lhs.angle(x_0, None), ref_a.angle(x_a, None) + ref_b.angle(x_b, None))
1791+
np.testing.assert_almost_equal(lhs.improper(x_0, None), ref_a.improper(x_a, None) + ref_b.improper(x_b, None))

timemachine/fe/dummy.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,18 @@
77
import networkx as nx
88

99

10-
class MultipleAnchorWarning(UserWarning):
10+
class MultipleBondAnchorWarning(UserWarning):
11+
pass
12+
13+
14+
class ZeroBondAnchorWarning(UserWarning):
1115
pass
1216

1317

1418
def generate_dummy_group_assignments(
1519
bond_graph: nx.Graph, core_atoms: Collection[int]
1620
) -> Iterator[dict[int, frozenset[int]]]:
17-
"""Returns an iterator over dummy group assignments (i.e., candidate partitionings of dummy atoms with each
21+
"""Returns an iterator over all possible dummy group assignments (i.e., candidate partitionings of dummy atoms with each
1822
partition assigned a bond anchor atom) for a given molecule (represented as a bond graph) and set of core atoms.
1923
2024
A dummy group is a set of dummy atoms that are inserted or deleted in alchemical free energy calculations. The
@@ -72,10 +76,16 @@ def generate_dummy_group_assignments(
7276

7377
def get_bond_anchors(dummy_group):
7478
bond_anchors = {n for dummy_atom in dummy_group for n in bond_graph.neighbors(dummy_atom) if n in core_atoms}
75-
if len(bond_anchors) > 1:
79+
if len(bond_anchors) == 0:
80+
warnings.warn(
81+
f"No bond anchors found for dummy group: {dummy_group}",
82+
ZeroBondAnchorWarning,
83+
)
84+
bond_anchors = set([None])
85+
elif len(bond_anchors) > 1:
7686
warnings.warn(
7787
f"Multiple bond anchors {bond_anchors} found for dummy group: {dummy_group}",
78-
MultipleAnchorWarning,
88+
MultipleBondAnchorWarning,
7989
)
8090
return bond_anchors
8191

@@ -138,11 +148,16 @@ def generate_anchored_dummy_group_assignments(
138148
each element is a mapping from bond anchor atom to the pair (angle anchor atom, dummy group)
139149
"""
140150

151+
assert len(core_atoms_a) == len(core_atoms_b)
152+
141153
core_bonds_c = get_core_bonds(bond_graph_a.edges(), bond_graph_b.edges(), core_atoms_a, core_atoms_b)
142154
c_to_b = {c: b for c, b in enumerate(core_atoms_b)}
143155
core_bonds_b = frozenset(translate_bonds(core_bonds_c, c_to_b))
144156

145157
def get_angle_anchors(bond_anchor):
158+
if bond_anchor is None:
159+
return [None]
160+
146161
valid_angle_anchors = [
147162
angle_anchor
148163
for angle_anchor in [n for n in bond_graph_b.neighbors(bond_anchor) if n in core_atoms_b]

timemachine/fe/single_topology.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def setup_dummy_bond_and_chiral_interactions(
144144
root_anchor_atom: int,
145145
core_atoms: NDArray,
146146
):
147-
assert root_anchor_atom in core_atoms
147+
assert root_anchor_atom is None or root_anchor_atom in core_atoms
148148

149149
dummy_group_arr = np.array(list(dummy_group))
150150

@@ -286,7 +286,7 @@ def setup_dummy_interactions(
286286
(bonded_idxs, bonded_params)
287287
Returns bonds, angles, and improper idxs and parameters.
288288
"""
289-
assert root_anchor_atom in core_atoms
289+
assert root_anchor_atom is None or root_anchor_atom in core_atoms
290290

291291
dummy_angle_idxs = []
292292
dummy_angle_params = []
@@ -664,9 +664,8 @@ def concatenate(arrays, empty_shape, empty_dtype):
664664
)
665665

666666
num_atoms = mol_a.GetNumAtoms() + mol_b.GetNumAtoms() - len(core)
667-
assert get_num_connected_components(num_atoms, bond_potential.potential.idxs) == 1, (
668-
"hybrid molecule has multiple connected components"
669-
)
667+
if get_num_connected_components(num_atoms, bond_potential.potential.idxs) > 1:
668+
warnings.warn("Hybrid molecule has multiple connected components")
670669

671670
return GuestSystem(
672671
bond=bond_potential,

0 commit comments

Comments
 (0)