Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine selection of atoms in REST region using simple heuristic #1524

Merged
merged 23 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 60 additions & 25 deletions tests/rest/test_single_topology_rest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Sequence
from dataclasses import replace
from functools import cache

import jax
Expand All @@ -11,14 +13,20 @@
from timemachine.fe import atom_mapping
from timemachine.fe.free_energy import HostConfig
from timemachine.fe.rbfe import Host, setup_optimized_host
from timemachine.fe.rest.interpolation import Exponential, Linear, Quadratic, Symmetric, plot_interpolation_fxn
from timemachine.fe.rest.single_topology import InterpolationFxnName, SingleTopologyREST
from timemachine.fe.rest.interpolation import (
Exponential,
InterpolationFxnName,
Linear,
Quadratic,
Symmetric,
plot_interpolation_fxn,
)
from timemachine.fe.rest.single_topology import SingleTopologyREST
from timemachine.fe.single_topology import SingleTopology
from timemachine.fe.system import GuestSystem
from timemachine.fe.utils import get_romol_conf, read_sdf_mols_by_name
from timemachine.ff import Forcefield
from timemachine.md import builders
from timemachine.potentials import PeriodicTorsion
from timemachine.utils import path_to_internal_file

with path_to_internal_file("timemachine.testsystems.fep_benchmark.hif2a", "ligands.sdf") as ligands_path:
Expand Down Expand Up @@ -71,9 +79,14 @@ def test_single_topology_rest_vacuum(mol_pair, temperature_scale_interpolation_f
st = get_single_topology(mol_a, mol_b, core)
st_rest = get_single_topology_rest(mol_a, mol_b, core, 2.0, temperature_scale_interpolation_fxn)

# NOTE: This assertion is not guaranteed to hold in general (i.e. the REST region may be empty, or the whole
# combined ligand), but it does hold for typical cases, including the edges tested here. The stronger assertion here
# ensures that later assertions (e.g. that we only soften interactions in the REST region) are not trivially true.
assert 0 < len(st_rest.rest_region_atom_idxs) < st_rest.get_num_atoms()

state = st_rest.setup_intermediate_state(lamb)
state_ref = st.setup_intermediate_state(lamb)
assert len(st_rest.target_propers) < len(state_ref.proper.potential.idxs)
assert len(st_rest.candidate_propers) < len(state_ref.proper.potential.idxs)

ligand_conf = st.combine_confs(get_romol_conf(mol_a), get_romol_conf(mol_b))

Expand All @@ -99,23 +112,33 @@ def test_single_topology_rest_vacuum(mol_pair, temperature_scale_interpolation_f
assert energy_scale < 1.0

if has_rotatable_bonds or has_aliphatic_rings:
assert 0 < len(st_rest.target_propers)
assert 0 < len(st_rest.candidate_propers)

if energy_scale < 1.0:
assert U_proper < U_proper_ref

def get_proper_subset_energy(state: GuestSystem, ixn_idxs):
def compute_proper_energy(state: GuestSystem, ixn_idxs: Sequence[int]):
assert state.proper
idxs = state.proper.potential.idxs[ixn_idxs, :]
params = state.proper.params[ixn_idxs, :]
potential = PeriodicTorsion(idxs).bind(params)
return potential(ligand_conf, None)

U_proper_subset = get_proper_subset_energy(state, st_rest.target_proper_idxs)
U_proper_subset_ref = get_proper_subset_energy(state_ref, st_rest.target_proper_idxs)
np.testing.assert_allclose(U_proper_subset, energy_scale * U_proper_subset_ref)

np.testing.assert_allclose(U_nonbonded, energy_scale * U_nonbonded_ref)
proper = replace(
state.proper,
potential=replace(state.proper.potential, idxs=state.proper.potential.idxs[ixn_idxs, :]),
params=state.proper.params[ixn_idxs, :],
)
return proper(ligand_conf, None)

# check that propers in the REST region are scaled appropriately
rest_proper_idxs = st_rest.target_proper_idxs
U_proper_rest = compute_proper_energy(state, rest_proper_idxs)
U_proper_rest_ref = compute_proper_energy(state_ref, rest_proper_idxs)
np.testing.assert_allclose(U_proper_rest, energy_scale * U_proper_rest_ref)

# check that propers outside of the REST region are not scaled
num_propers = len(st_rest.propers)
rest_complement_proper_idxs = set(range(num_propers)) - set(rest_proper_idxs)
rest_complement_proper_idxs = list(rest_complement_proper_idxs)
U_proper_complement = compute_proper_energy(state, rest_complement_proper_idxs)
U_proper_complement_ref = compute_proper_energy(state_ref, rest_complement_proper_idxs)
np.testing.assert_array_equal(U_proper_complement, U_proper_complement_ref)


@cache
Expand Down Expand Up @@ -143,16 +166,28 @@ def test_single_topology_rest_solvent(mol_pair, temperature_scale_interpolation_
ligand_conf = st.combine_confs(get_romol_conf(mol_a), get_romol_conf(mol_b))
conf = np.concatenate([host.conf, ligand_conf])

def get_nonbonded_host_guest_ixn_energy(st: SingleTopology):
def compute_host_guest_ixn_energy(st: SingleTopology, ligand_idxs: set[int]):
hgs = st.combine_with_host(host.system, lamb, host_config.num_water_atoms, st.ff, host_config.omm_topology)
return hgs.nonbonded_ixn_group(conf, host_config.box)

U = get_nonbonded_host_guest_ixn_energy(st_rest)
U_ref = get_nonbonded_host_guest_ixn_energy(st)

num_atoms_host = host.system.nonbonded_all_pairs.potential.num_atoms
ligand_idxs_ = np.array(list(ligand_idxs), dtype=np.int32) + num_atoms_host
nonbonded_ixn_group = replace(
hgs.nonbonded_ixn_group,
potential=replace(hgs.nonbonded_ixn_group.potential, row_atom_idxs=ligand_idxs_),
)
return nonbonded_ixn_group(conf, host_config.box)

# check that interactions involving atoms in the REST region are scaled appropriately
U = compute_host_guest_ixn_energy(st_rest, st_rest.rest_region_atom_idxs)
U_ref = compute_host_guest_ixn_energy(st, st_rest.rest_region_atom_idxs)
energy_scale = st_rest.get_energy_scale_factor(lamb)
np.testing.assert_allclose(U, energy_scale * U_ref, rtol=1e-5)

# check that interactions involving atoms outside of the REST region are not scaled
rest_complement_atom_idxs = set(range(st_rest.get_num_atoms())) - st_rest.rest_region_atom_idxs
U_complement = compute_host_guest_ixn_energy(st_rest, rest_complement_atom_idxs)
U_complement_ref = compute_host_guest_ixn_energy(st, rest_complement_atom_idxs)
np.testing.assert_array_equal(U_complement, U_complement_ref)


def get_mol(smiles: str):
mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
Expand All @@ -170,17 +205,17 @@ def test_single_topology_rest_propers():
# benzene: no propers are scaled
benzene = get_mol("c1ccccc1")
st = get_identity_transformation(benzene)
assert st.target_propers == set()
assert len(st.candidate_propers) == 0

# cyclohexane: all 9 * 6 ring propers are scaled (|{H1, H2, C1}-C2-C3-{C4, H3, H4}| = 9 propers per C-C bond)
cyclohexane = get_mol("C1CCCCC1")
st = get_identity_transformation(cyclohexane)
assert len(st.target_propers) == 9 * 6
assert len(set(st.candidate_propers.values())) == 9 * 6

# phenylcyclohexane: all 9 * 6 cyclohexane ring propers and 6 rotatable bond propers are scaled
phenylcyclohexane = get_mol("c1ccc(C2CCCCC2)cc1")
st = get_identity_transformation(phenylcyclohexane)
assert len(st.target_propers) == 9 * 6 + 6
assert len(set(st.candidate_propers.values())) == 9 * 6 + 6


@pytest.mark.parametrize(
Expand Down
77 changes: 63 additions & 14 deletions timemachine/fe/rest/single_topology.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from dataclasses import replace
from dataclasses import astuple, replace
from functools import cached_property

import jax.numpy as jnp
import numpy as np
from numpy.typing import NDArray
from openmm import app
from rdkit import Chem

from timemachine.constants import NBParamIdx
from timemachine.fe.single_topology import SingleTopology
from timemachine.fe.single_topology import AlignedPotential, SingleTopology
from timemachine.fe.system import GuestSystem, HostGuestSystem, HostSystem
from timemachine.ff import Forcefield

Expand Down Expand Up @@ -79,37 +80,82 @@ def __init__(
max_temperature_scale, temperature_scale_interpolation
)

@cached_property
def rest_region_atom_idxs(self) -> set[int]:
"""Returns the set of indices of atoms in the combined ligand that are in the REST region.

Here the REST region is defined to include combined ligand atoms involved in bond, angle, or improper torsion
interactions that differ in the end states. Note that proper torsions are omitted from this heuristic as this
tends to result in larger REST regions than seem desirable.
"""

aligned_potentials: list[AlignedPotential] = [
self.aligned_bond,
self.aligned_angle,
self.aligned_improper,
]

idxs = {
int(idx)
for aligned in aligned_potentials
for idxs, params_a, params_b in zip(aligned.idxs, aligned.src_params, aligned.dst_params)
if not np.all(params_a == params_b)
for idx in idxs # type: ignore[attr-defined]
}

# Ensure all dummy atoms are included in the REST region
idxs |= self.get_dummy_atoms_a()
idxs |= self.get_dummy_atoms_b()

return idxs

@cached_property
def aliphatic_ring_bonds(self) -> set[CanonicalBond]:
"""Returns the set of aliphatic ring bonds in the combined ligand."""
ring_bonds_a = {bond.translate(self.a_to_c) for bond in get_aliphatic_ring_bonds(self.mol_a)}
ring_bonds_b = {bond.translate(self.b_to_c) for bond in get_aliphatic_ring_bonds(self.mol_b)}
ring_bonds_c = ring_bonds_a | ring_bonds_b
return ring_bonds_c

@cached_property
def rotatable_bonds(self) -> set[CanonicalBond]:
"""Returns the set of rotatable bonds in the combined ligand."""
rotatable_bonds_a = {bond.translate(self.a_to_c) for bond in get_rotatable_bonds(self.mol_a)}
rotatable_bonds_b = {bond.translate(self.b_to_c) for bond in get_rotatable_bonds(self.mol_b)}
rotatable_bonds_c = rotatable_bonds_a | rotatable_bonds_b
return rotatable_bonds_c

@cached_property
def propers(self) -> list[CanonicalProper]:
"""Returns a list of proper torsions in the combined ligand."""
# TODO: refactor SingleTopology to compute src and dst alignment at initialization
return [mkproper(*idxs) for idxs in super().setup_intermediate_state(0.0).proper.potential.idxs]

@cached_property
def target_proper_idxs(self) -> list[int]:
return [
idx
def candidate_propers(self) -> dict[int, CanonicalProper]:
"""Returns a dict of propers in the combined ligand, keyed on index, that are candidates for softening."""
return {
idx: proper
for idx, proper in enumerate(self.propers)
for bond in [mkbond(proper.j, proper.k)]
if bond in self.rotatable_bonds or bond in self.aliphatic_ring_bonds
]
}

@cached_property
def target_propers(self) -> set[CanonicalProper]:
return {self.propers[i] for i in self.target_proper_idxs}
def target_propers(self) -> dict[int, CanonicalProper]:
"""Returns a dict of propers in the combined ligand, keyed on index, that are candidates for softening and
involve an atom in the REST region."""
return {
idx: proper
for (idx, proper) in self.candidate_propers.items()
if any(idx in self.rest_region_atom_idxs for idx in astuple(proper))
}

@cached_property
def target_proper_idxs(self) -> list[int]:
"""Returns a list of indices of propers in the combined ligand that are candidates for softening and involve an
atom in the REST region."""
return list(self.target_propers.keys())

def get_energy_scale_factor(self, lamb: float) -> float:
temperature_factor = float(self._temperature_scale_interpolation_fxn(lamb))
Expand Down Expand Up @@ -146,25 +192,28 @@ def combine_with_host(
) -> HostGuestSystem:
ref_state = super().combine_with_host(host_system, lamb, num_water_atoms, ff, omm_topology)

num_host_atoms = host_system.nonbonded_all_pairs.params.shape[0]
# compute indices corresponding to REST-region ligand atoms in the host-guest interaction potential
num_atoms_host = host_system.nonbonded_all_pairs.potential.num_atoms
rest_region_atom_idxs = np.array(sorted(self.rest_region_atom_idxs)) + num_atoms_host

# NOTE: the following methods of scaling the ligand-environment interaction energy are all equivalent:
#
# 1. scaling ligand charges and LJ epsilons by energy_scale
# 2. scaling environment charges and LJ epsilons by energy_scale
# 3. scaling all charges and LJ epsilons by sqrt(energy_scale)
#
# Here, we choose (1) because water sampling infers parameters from the NonbondedInteractionGroup.Changing
# the environment parameters prevents easy construction of equivalent parameters for water sampling, which
# leads to incorrect sampling.
# However, (2) and (3) are incompatible with the current water sampling implementation, which assumes that the
# parameters corresponding to water atoms are identical in the host-host all-pairs potential and the host-guest
# interaction group potential. Therefore we choose option (1).

energy_scale = self.get_energy_scale_factor(lamb)

nonbonded_host_guest_ixn = replace(
ref_state.nonbonded_ixn_group,
params=jnp.asarray(ref_state.nonbonded_ixn_group.params)
.at[num_host_atoms:, NBParamIdx.Q_IDX]
.at[rest_region_atom_idxs, NBParamIdx.Q_IDX]
.mul(energy_scale) # scale ligand charges
.at[num_host_atoms:, NBParamIdx.LJ_EPS_IDX]
.at[rest_region_atom_idxs, NBParamIdx.LJ_EPS_IDX]
.mul(energy_scale), # scale ligand epsilons
)

Expand Down
58 changes: 55 additions & 3 deletions timemachine/fe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,13 @@ def plot_atom_mapping_grid(
)


type _Core = Sequence[Sequence[int]] | NDArray


def view_atom_mapping_3d(
mol_a: Chem.rdchem.Mol,
mol_b: Chem.rdchem.Mol,
cores: Sequence[Sequence[Sequence[int]]] | NDArray = (),
cores: Sequence[_Core] | NDArray = (),
colors: Sequence[str] = (
# https://colorbrewer2.org/#type=qualitative&scheme=Paired&n=12
"#a6cee3",
Expand Down Expand Up @@ -335,8 +338,12 @@ def view_atom_mapping_3d(
for core in cores:
assert np.asarray(core).ndim == 2, "expect a list of cores"

make_style = lambda props: {"stick": props}
atom_style = lambda color: make_style({"color": color})
def make_style(props):
return {"stick": props}

def atom_style(color):
return make_style({"color": color})

dummy_style = atom_style("white")

num_rows = 1 + len(cores)
Expand Down Expand Up @@ -382,6 +389,51 @@ def add_mol(mol, viewer):
return view


def view_rest_region_3d(
mol_a: Chem.rdchem.Mol,
mol_b: Chem.rdchem.Mol,
rest_region_atom_idxs_a: Sequence[int],
rest_region_atom_idxs_b: Sequence[int],
show_atom_idx_labels: bool = False,
):
try:
import py3Dmol
except ImportError as e:
raise RuntimeError("requires py3Dmol to be installed") from e

def make_style(props):
return {"stick": props}

def atom_style(color):
return make_style({"color": color})

view = py3Dmol.view(viewergrid=(2, 2))

def add_mol(mol, viewer):
view.addModel(Chem.MolToMolBlock(mol), "mol", viewer=viewer)

add_mol(mol_a, (0, 0))
add_mol(mol_b, (0, 1))
view.setStyle(make_style({}))

add_mol(mol_a, (1, 0))
view.setStyle(atom_style("white"), viewer=(1, 0))
for idx in rest_region_atom_idxs_a:
view.setStyle({"serial": idx}, {"stick": {"color": "red"}}, viewer=(1, 0))

add_mol(mol_b, (1, 1))
view.setStyle(atom_style("white"), viewer=(1, 1))
for idx in rest_region_atom_idxs_b:
view.setStyle({"serial": idx}, atom_style("red"), viewer=(1, 1))

view.zoomTo()

if show_atom_idx_labels:
view.addPropertyLabels("serial", "", {"alignment": "center", "fontSize": 10})

return view


def get_romol_bonds(mol):
"""
Return bond idxs given a mol. These are not canonicalized.
Expand Down