Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions mupt/interfaces/rdkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''Interfaces between the hierarchical MuPT molecular representation and RDKit Mol objects'''

__author__ = 'Timotej Bernat'
__email__ = 'timotej.bernat@colorado.edu'
__author__ = 'Timotej Bernat, Joseph R. Laforet Jr.'
__email__ = 'timotej.bernat@colorado.edu, jola3134@colorado.edu'


from .selection import (
Expand All @@ -26,7 +26,11 @@
connectors_from_rdkit,
)
from .importers import primitive_from_rdkit
from .exporters import primitive_to_rdkit
from .exporters import primitive_to_rdkit, primitive_to_rdkit_mols
from .strategies import (
RDKitExportStrategy,
AllAtomRDKitExportStrategy,
)
from .depiction import (
set_rdkdraw_size,
show_substruct_highlights,
Expand Down Expand Up @@ -54,4 +58,4 @@
set_rdkdraw_size(400, aspect=3/2)
show_atom_indices()
show_substruct_highlights()
disable_kekulized_drawing()
disable_kekulized_drawing()
188 changes: 184 additions & 4 deletions mupt/interfaces/rdkit/exporters.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
'''Writers which convert the MuPT molecular representation out to RDKit Mols'''

__author__ = 'Timotej Bernat'
__email__ = 'timotej.bernat@colorado.edu'
__author__ = 'Timotej Bernat, Joseph R. Laforet Jr.'
__email__ = 'timotej.bernat@colorado.edu, jola3134@colorado.edu'


from collections.abc import Iterator
from typing import Optional

import numpy as np

from rdkit.Chem.rdchem import (
Atom,
Bond,
Mol,
RWMol,
Conformer,
AtomPDBResidueInfo,
)
from rdkit.Geometry import Point3D

from .rdprops import RDPropType, assign_property_to_rdobj
from .labelling import RDMOL_NAME_WRITE_PROP
Expand All @@ -23,6 +25,7 @@
from ...chemistry.conversion import element_to_rdkit_atom
from ...mupr.connection import Connector
from ...mupr.primitives import Primitive, PrimitiveHandle
from .strategies import AllAtomRDKitExportStrategy, RDKitExportStrategy, RDKitMolData


def rdkit_atom_from_atomic_primitive(atomic_primitive : Primitive) -> Atom:
Expand All @@ -44,7 +47,12 @@ def primitive_to_rdkit(
'''
Convert a Primitive hierarchy to an RDKit Mol
Will return a single Mol instance, even if the underlying Primitive represents a collection of multiple disconnected molecules


DEV: This is the legacy flattened exporter. We recommend replacing downstream
workflows with primitive_to_rdkit_mols() and removing this path after reviewer
approval, rather than abstracting shared bond or metadata helpers around code
that is likely to be retired.

Will set spatial positions for each atom ("default_atom_position" if not assigned per atom) to a Conformer bound to the returned Mol
'''
if default_atom_position is None:
Expand Down Expand Up @@ -132,4 +140,176 @@ def primitive_to_rdkit(
mol.SetProp(RDMOL_NAME_WRITE_PROP, primitive.label)

return mol


# Single-character chain identifiers available to _pdb_chain_and_resid(); the
# next letter is used each time a chain's residue numbers are exhausted.
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
Copy link
Copy Markdown
Contributor Author

@joelaforet joelaforet May 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this block is added so that exported systems have valid PDB Chain identifiers. Problems arise when one creates a topology from a file that has unspecified chain information. By having this "wrap around" logic, we ensure each residue has a unique chain/residue index. One chain per "segment" only allows for 26 unique segments in the topology. This messes with downstream analysis, since there are potentially multiple matches for something like "Chain A Residue 1". The wraparound gives us 26 x 9999 ~= 260,000 unique residues in the system, where a residue is the same concept as a repeat_unit. In the future, we may consider switching to PDBx format exclusively, but this works fine for now.

# Largest residue number assigned within one chain before _pdb_chain_and_resid()
# moves on to the next chain identifier (PDB residue numbers are four digits wide).
PDB_MAX_RESIDUE_NUMBER = 9999


def _pdb_chain_and_resid(global_residue_idx: int) -> tuple[str, int]:
"""Return PDB-compliant chain/residue identifiers for a global residue index."""
chain_idx, resid_offset = divmod(global_residue_idx, PDB_MAX_RESIDUE_NUMBER)
if chain_idx >= len(PDB_CHAIN_IDS):
raise ValueError(
"Role-aware RDKit export exceeded PDB chain/residue capacity "
f"({len(PDB_CHAIN_IDS) * PDB_MAX_RESIDUE_NUMBER} residues). "
"Use a topology format with larger identifier fields."
)
return PDB_CHAIN_IDS[chain_idx], resid_offset + 1


def _atom_pdb_name(atom: Primitive, atom_idx_in_residue: int) -> str:
"""Return a PDB-width atom name from element and residue-local index."""
atom_name = f"{atom.element.symbol}{atom_idx_in_residue + 1}"
if len(atom.element.symbol) == 1:
return f" {atom_name:<3}"
return f"{atom_name:<4}"


def _add_rdkit_atoms(
    mol: RWMol,
    conf: Conformer,
    data: RDKitMolData,
    segment_idx: int,
    first_residue_idx: int,
) -> None:
    """Append one RDKit atom per entry of ``data.atoms`` to ``mol``.

    Each atom receives PDB residue info, the plain residue/chain properties,
    the MuPT bookkeeping properties, and a position in ``conf``.

    Args:
        mol: Editable molecule that receives the atoms.
        conf: Conformer that receives the atom positions.
        data: Per-segment topology data; its ``atom_*`` lists are read in
            parallel with ``atoms``.
        segment_idx: Index of the segment, stored on every atom.
        first_residue_idx: Global residue offset added to the segment-local
            residue indices before chain/residue identifiers are derived.
    """
    atoms_seen_per_residue: dict[int, int] = {}

    for particle_idx, particle in enumerate(data.atoms):
        rd_atom = rdkit_atom_from_atomic_primitive(particle)
        segment_resid = data.atom_resids[particle_idx]
        chain_id, pdb_resid = _pdb_chain_and_resid(first_residue_idx + segment_resid - 1)

        # Running count of atoms already emitted for this residue
        position_in_residue = atoms_seen_per_residue.setdefault(segment_resid, 0)
        atoms_seen_per_residue[segment_resid] = position_in_residue + 1

        residue_name = data.atom_resnames[particle_idx]
        residue_info = AtomPDBResidueInfo(
            atomName=_atom_pdb_name(particle, position_in_residue),
            serialNumber=particle_idx + 1,
            residueName=residue_name,
            residueNumber=pdb_resid,
            chainId=chain_id,
            isHeteroAtom=True,
        )
        insertion_code = data.atom_insertion_codes[particle_idx]
        if insertion_code:
            residue_info.SetInsertionCode(insertion_code)
        rd_atom.SetMonomerInfo(residue_info)

        # Properties are written in a fixed order, each with its typed setter
        for set_prop, prop_name, prop_value in (
            (rd_atom.SetProp, "residue_name", residue_name),
            (rd_atom.SetIntProp, "residue_id", pdb_resid),
            (rd_atom.SetProp, "chain_id", chain_id),
            (rd_atom.SetIntProp, "mupt_segment_index", segment_idx),
            (rd_atom.SetProp, "mupt_segment_label", str(data.segment.label)),
            (rd_atom.SetIntProp, "mupt_residue_index", segment_resid),
            (rd_atom.SetProp, "mupt_residue_label", data.atom_residue_labels[particle_idx]),
            (rd_atom.SetIntProp, "mupt_particle_index", particle_idx),
            (rd_atom.SetProp, "mupt_particle_label", data.atom_particle_labels[particle_idx]),
        ):
            set_prop(prop_name, prop_value)

        new_idx = mol.AddAtom(rd_atom)
        coords = data.atom_positions[particle_idx]
        conf.SetAtomPosition(new_idx, Point3D(*(float(coords[axis]) for axis in range(3))))


def _add_rdkit_bonds(mol: RWMol, data: RDKitMolData) -> None:
    """Create the segment-internal bonds on ``mol`` and tag them with metadata.

    ``data.bonds`` and ``data.bond_refs`` are walked in lockstep. The bond
    order is taken from the first connector of each pair; the bond properties
    are the union of both connectors' metadata, with the second connector
    winning on duplicate keys.

    Args:
        mol: Editable molecule that already contains the bonded atoms.
        data: Per-segment topology data providing the atom index pairs and
            the connector references behind each bond.
    """
    for atom_pair, (owner, ref_pair) in zip(data.bonds, data.bond_refs):
        begin_idx, end_idx = atom_pair
        first_conn = owner.fetch_connector_on_child(ref_pair[0])
        second_conn = owner.fetch_connector_on_child(ref_pair[1])

        mol.AddBond(begin_idx, end_idx, order=first_conn.bondtype)
        new_bond = mol.GetBondBetweenAtoms(begin_idx, end_idx)

        merged_metadata: dict[str, RDPropType] = dict(first_conn.metadata)
        merged_metadata.update(second_conn.metadata)
        for prop_name, prop_value in merged_metadata.items():
            assign_property_to_rdobj(new_bond, prop_name, prop_value, preserve_type=True)


def _add_rdkit_linkers(mol: RWMol, conf: Conformer, data: RDKitMolData) -> None:
    """Insert linker atoms and their bond metadata.

    For every entry of ``data.linker_refs`` a dummy atom (atomic number 0) is
    appended to ``mol``, bonded to its anchor atom with the connector's bond
    type, and placed at the connector's linker position in ``conf``. The
    connector's metadata is copied onto the new bond.

    Args:
        mol: Editable molecule that already contains the anchor atoms.
        conf: Conformer receiving the linker positions.
        data: Per-segment topology data; each ``linker_refs`` entry is
            ``(anchor atom index, parent primitive, connector reference)``.
    """
    for atom_idx, parent, conn_ref in data.linker_refs:
        conn = parent.fetch_connector_on_child(conn_ref)
        anchor_atom = mol.GetAtomWithIdx(atom_idx)
        # Atom(0) is the dummy ("*") atom standing in for the open attachment point
        linker_idx = mol.AddAtom(Atom(0))
        anchor_pdb_info = anchor_atom.GetPDBResidueInfo()
        linker_atom = mol.GetAtomWithIdx(linker_idx)
        if anchor_pdb_info is not None:
            # The linker inherits residue name/number, chain and insertion code from its anchor
            # NOTE(review): the atom-name literal below is three characters wide as seen
            # here, whereas PDB atom names occupy four columns - confirm the literal
            pdb_info = AtomPDBResidueInfo(
                atomName=" * ",
                serialNumber=linker_idx + 1,
                residueName=anchor_pdb_info.GetResidueName(),
                residueNumber=anchor_pdb_info.GetResidueNumber(),
                chainId=anchor_pdb_info.GetChainId(),
                isHeteroAtom=True,
            )
            if anchor_pdb_info.GetInsertionCode():
                pdb_info.SetInsertionCode(anchor_pdb_info.GetInsertionCode())
            linker_atom.SetMonomerInfo(pdb_info)
            # NOTE(review): assumes the plain residue/chain properties are copied only
            # when the anchor carries PDB info - confirm this nesting against the intent
            linker_atom.SetProp("residue_name", anchor_atom.GetProp("residue_name"))
            linker_atom.SetIntProp("residue_id", anchor_atom.GetIntProp("residue_id"))
            linker_atom.SetProp("chain_id", anchor_atom.GetProp("chain_id"))
        mol.AddBond(atom_idx, linker_idx, order=conn.bondtype)
        conf.SetAtomPosition(
            linker_idx,
            Point3D(
                float(conn.linker.position[0]),
                float(conn.linker.position[1]),
                float(conn.linker.position[2]),
            ),
        )
        bond = mol.GetBondBetweenAtoms(atom_idx, linker_idx)
        for bond_key, bond_value in conn.metadata.items():
            assign_property_to_rdobj(bond, bond_key, bond_value, preserve_type=True)


def _apply_rdkit_mol_metadata(mol: RWMol, data: RDKitMolData, root_metadata: dict) -> None:
    """Write molecule-level properties onto ``mol``.

    Properties are assigned in this order: the ``origin`` marker, the
    molecule name (only when the segment has a label), the root metadata,
    and finally the segment metadata, so later sources overwrite earlier
    ones on duplicate keys.

    Args:
        mol: Molecule receiving the properties.
        data: Per-segment topology data; its ``segment`` supplies the label
            and the segment-level metadata.
        root_metadata: Metadata of the root primitive being exported.
    """
    segment = data.segment

    assign_property_to_rdobj(mol, 'origin', TOOLKIT_NAME, preserve_type=True)

    segment_label = segment.label
    if segment_label is not None:
        mol.SetProp(RDMOL_NAME_WRITE_PROP, str(segment_label))

    # Root metadata first, then segment metadata
    for metadata in (root_metadata, segment.metadata):
        for prop_name, prop_value in metadata.items():
            assign_property_to_rdobj(mol, prop_name, prop_value, preserve_type=True)


def _mol_from_rdkit_data(
    data: RDKitMolData,
    segment_idx: int,
    first_residue_idx: int,
    root_metadata: dict,
) -> Mol:
    """Assemble a finished RDKit Mol from one segment's topology data.

    Atoms are added first, then the internal bonds, then the linker atoms;
    molecule-level metadata is applied last, before the conformer is attached.

    Args:
        data: Per-segment topology data to convert.
        segment_idx: Index of the segment, recorded on each atom.
        first_residue_idx: Global residue offset for this segment.
        root_metadata: Metadata of the root primitive, copied onto the Mol.

    Returns:
        A ``Mol`` copy of the assembled molecule with its property cache
        updated in strict mode.
    """
    editable = RWMol()
    # One conformer slot per real atom plus one per linker atom
    position_count = len(data.atoms) + len(data.linker_refs)
    conformer = Conformer(position_count)

    _add_rdkit_atoms(editable, conformer, data, segment_idx, first_residue_idx)
    _add_rdkit_bonds(editable, data)
    _add_rdkit_linkers(editable, conformer, data)
    _apply_rdkit_mol_metadata(editable, data, root_metadata)
    editable.AddConformer(conformer, assignId=True)

    finished = Mol(editable)
    finished.UpdatePropertyCache(strict=True)
    return finished


def primitive_to_rdkit_mols(
    primitive: Primitive,
    resname_map: dict[str, str],
    default_atom_position: Optional[np.ndarray[Shape[3], float]] = None,
    strategy: Optional[RDKitExportStrategy] = None,
) -> Iterator[Mol]:
    """Lazily convert a role-annotated Primitive hierarchy into RDKit Mols.

    One Mol is produced per dataset yielded by the export strategy. Because
    this is a generator, nothing runs (including validation) until the first
    Mol is requested.

    Args:
        primitive: Root of the hierarchy to export.
        resname_map: Forwarded to the strategy's ``iter_mol_data``.
        default_atom_position: Only used to build the default
            ``AllAtomRDKitExportStrategy`` when ``strategy`` is None.
        strategy: Export strategy to use instead of the default one.

    Yields:
        One RDKit ``Mol`` per dataset, in the order the strategy produces them.
    """
    export_strategy = (
        AllAtomRDKitExportStrategy(default_atom_position=default_atom_position)
        if strategy is None
        else strategy
    )
    export_strategy.validate(primitive)

    # Global residue offset carried across segments
    residues_emitted = 0
    mol_datasets = export_strategy.iter_mol_data(primitive, resname_map=resname_map)
    for segment_idx, mol_data in enumerate(mol_datasets):
        yield _mol_from_rdkit_data(
            mol_data,
            segment_idx,
            residues_emitted,
            root_metadata=primitive.metadata,
        )
        residues_emitted += max(mol_data.atom_resids, default=0)

143 changes: 143 additions & 0 deletions mupt/interfaces/rdkit/strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Strategy implementations for MuPT -> RDKit export."""

__author__ = "Joseph R. Laforet Jr."
__email__ = "jola3134@colorado.edu"

from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Optional

import numpy as np

from ...mupr.embedding import ConnectorReference
from ...mupr.primitives import Primitive
from .._shared.topology import (
_pdb_resname,
build_saamr_role_topology_index,
connector_reference_sort_key,
iter_saamr_residue_records,
resolve_to_atom_cached,
)


@dataclass
class RDKitMolData:
    """Container for one segment's RDKit-exportable topology data.

    The ``atom_*`` lists are parallel to ``atoms``: entry ``i`` of each list
    describes ``atoms[i]``. ``bonds`` and ``bond_refs`` are parallel to each
    other. Atom indices stored in ``bonds`` and ``linker_refs`` are positions
    in ``atoms``.
    """

    # Segment primitive this dataset was collected from
    segment: Primitive
    # Atomic primitives, in export order
    atoms: list[Primitive] = field(default_factory=list)
    # Position of each atom; read as three coordinates by the exporter
    atom_positions: list[np.ndarray] = field(default_factory=list)
    # PDB residue name of the residue each atom belongs to
    atom_resnames: list[str] = field(default_factory=list)
    # PDB insertion code per atom; an empty string means "none"
    atom_insertion_codes: list[str] = field(default_factory=list)
    # Stringified label of the residue each atom belongs to
    atom_residue_labels: list[str] = field(default_factory=list)
    # Stringified label of each atom
    atom_particle_labels: list[str] = field(default_factory=list)
    # Segment-local residue index per atom (presumably 1-based, since the exporter
    # subtracts 1 before adding its global offset - TODO confirm)
    atom_resids: list[int] = field(default_factory=list)
    # Bonded atom index pairs
    bonds: list[tuple[int, int]] = field(default_factory=list)
    # For each bond: the primitive owning the connection and its two connector references
    bond_refs: list[
        tuple[Primitive, tuple[ConnectorReference, ConnectorReference]]
    ] = field(default_factory=list)
    # For each linker: (anchor atom index, primitive owning the connector, connector reference)
    linker_refs: list[tuple[int, Primitive, ConnectorReference]] = field(default_factory=list)


class RDKitExportStrategy(ABC):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, added for completeness in case we want to export non-AA systems to RDKit. I don't know if that is needed, and this may be overkill. Happy to refactor to remove the ABC if needed.

"""Abstract strategy for collecting RDKit-exportable topology data."""

@abstractmethod
def validate(self, root: Primitive) -> None:
    """Validate role assignment and hierarchy preconditions for export.

    Args:
        root: Root primitive of the hierarchy about to be exported.

    Implementations presumably signal an unusable hierarchy by raising;
    nothing is returned.
    """

@abstractmethod
def iter_mol_data(self, root: Primitive, resname_map: dict[str, str]) -> Iterator[RDKitMolData]:
    """Yield one topology dataset per RDKit Mol to build.

    Args:
        root: Root primitive of the hierarchy to export.
        resname_map: String-to-string mapping used when deriving PDB
            residue names.

    Yields:
        One ``RDKitMolData`` per Mol, in the order the Mols should be built.
    """

@property
@abstractmethod
def label(self) -> str:
    """Human-readable name for this strategy.

    Read-only property; concrete strategies return a short descriptive string.
    """


class AllAtomRDKitExportStrategy(RDKitExportStrategy):
    """Role-aware all-atom RDKit export strategy.

    For every segment found by ``build_saamr_role_topology_index`` this
    strategy collects the atoms, the segment-internal bonds and the external
    connectors needed to build one RDKit Mol.
    """

    def __init__(self, default_atom_position: Optional[np.ndarray] = None) -> None:
        """Store the fallback position used for atoms that have no shape.

        Args:
            default_atom_position: Anything convertible to a float array of
                shape ``(3,)``. Defaults to the origin.

        Raises:
            ValueError: If the given position does not have shape ``(3,)``.
        """
        if default_atom_position is None:
            default_atom_position = np.array([0.0, 0.0, 0.0], dtype=float)
        else:
            default_atom_position = np.asarray(default_atom_position, dtype=float)
            if default_atom_position.shape != (3,):
                raise ValueError('default_atom_position must be a 3-dimensional vector')
        self.default_atom_position = default_atom_position

    @property
    def label(self) -> str:
        """Human-readable strategy name."""
        return "All-atom"

    def validate(self, root: Primitive) -> None:
        """Validate role assignments needed for all-atom RDKit export.

        Builds the role topology index for ``root`` and discards it; any
        exception raised while building the index propagates to the caller.
        """
        build_saamr_role_topology_index(root)

    def iter_mol_data(self, root: Primitive, resname_map: dict[str, str]) -> Iterator[RDKitMolData]:
        """Yield one RDKit topology dataset per SEGMENT-role node.

        Args:
            root: Root primitive of the hierarchy to export.
            resname_map: Mapping passed to ``_pdb_resname`` together with each
                residue label to obtain the PDB residue name.

        Yields:
            One ``RDKitMolData`` per segment, with bonds sorted by atom index
            pair.

        Raises:
            ValueError: If two internal connections of a segment resolve to
                the same pair of atoms.
        """
        index = build_saamr_role_topology_index(root)
        endpoint_cache: dict[tuple[int, object, object], Primitive] = {}

        # Group residue records by owning segment (keyed by object identity)
        residue_records_by_segment = {id(segment): [] for segment in index.segments}
        for residue_record in iter_saamr_residue_records(index):
            residue_records_by_segment[id(residue_record.segment)].append(residue_record)

        for segment in index.segments:
            data = RDKitMolData(segment=segment)
            atom_id_to_local: dict[int, int] = {}

            for residue_record in residue_records_by_segment[id(segment)]:
                # Everything below is constant per residue, so compute it once
                # instead of once per atom.
                residue = residue_record.residue
                resname = _pdb_resname(residue.label, resname_map)
                residue_label = str(residue.label)
                residue_idx = residue_record.residue_idx
                # A missing or None insertion code both mean "no insertion
                # code"; str(None) would otherwise export the text "None".
                raw_insertion_code = residue.metadata.get("pdb_insertion_code")
                insertion_code = "" if raw_insertion_code is None else str(raw_insertion_code)

                for atom in residue_record.particles:
                    atom_id_to_local[id(atom)] = len(data.atoms)
                    data.atoms.append(atom)
                    if atom.shape is not None:
                        data.atom_positions.append(atom.shape.centroid)
                    else:
                        data.atom_positions.append(self.default_atom_position)
                    data.atom_resnames.append(resname)
                    data.atom_insertion_codes.append(insertion_code)
                    data.atom_residue_labels.append(residue_label)
                    data.atom_particle_labels.append(str(atom.label))
                    data.atom_resids.append(residue_idx)

            bonds_set: set[tuple[int, int]] = set()
            for node in index.bond_nodes_by_segment[id(segment)]:
                for conn_ref_pair in node.internal_connections:
                    # Fixed connector order, so that metadata merging by the
                    # exporter does not depend on the order within the pair
                    conn_ref1, conn_ref2 = sorted(
                        conn_ref_pair,
                        key=connector_reference_sort_key,
                    )
                    atom1 = resolve_to_atom_cached(node, conn_ref1, endpoint_cache)
                    atom2 = resolve_to_atom_cached(node, conn_ref2, endpoint_cache)
                    idx1 = atom_id_to_local[id(atom1)]
                    idx2 = atom_id_to_local[id(atom2)]
                    bond_pair = tuple(sorted((idx1, idx2)))
                    if bond_pair in bonds_set:
                        raise ValueError(
                            "Multiple MuPT internal connections resolve to the same "
                            f"RDKit atom pair {bond_pair} in SEGMENT '{segment.label}'. "
                            "Role-aware export cannot choose which connector metadata "
                            "to preserve."
                        )

                    data.bonds.append(bond_pair)
                    data.bond_refs.append((node, (conn_ref1, conn_ref2)))
                    bonds_set.add(bond_pair)

            # External connectors become linker atoms anchored on a segment atom
            for conn_ref in segment.external_connectors.values():
                atom = resolve_to_atom_cached(segment, conn_ref, endpoint_cache)
                data.linker_refs.append((atom_id_to_local[id(atom)], segment, conn_ref))

            if data.bonds:
                # Sort bonds by atom index pair, keeping bond_refs aligned
                sorted_bonds = sorted(
                    zip(data.bonds, data.bond_refs), key=lambda pair: pair[0]
                )
                data.bonds = [bond for bond, _ in sorted_bonds]
                data.bond_refs = [bond_ref for _, bond_ref in sorted_bonds]

            yield data
Loading