
Flatten() + expand() operations in the RDKit exporter scale poorly for large polymer systems #62

@joelaforet


Problem

The current MuPT -> RDKit -> SDF workflow flattens the MuPT hierarchy, exports one monolithic RDKit molecule, writes one SDF record, reloads it, and then calls Chem.GetMolFrags(..., asMols=True) to recover the original chains.

There are two slow operations in this workflow. First, Primitive.flattened() copies the tree and repeatedly calls expand(...); each expand(...) ends with full consistency validation via check_self_consistent() (which is expensive for large trees). Second, the OpenFF handoff workflow has to rediscover chains after flattening by calling Chem.GetMolFrags(..., asMols=True).

def flatten(self) -> None:
    while (target_handles := self.expandable_children):
        for child_handle in target_handles:
            self.expand(child_handle)

def flattened(self) -> 'Primitive':
    clone_primitive = self.copy()
    clone_primitive.flatten()
    return clone_primitive

# Legacy OpenFF handoff: split the reloaded monolithic molecule back into chains,
# then convert each RDKit mol to an OpenFF Molecule, parameterize it, and bundle everything into one Interchange
# ref: https://github.com/MuPT-hub/mupt-examples/blob/main/examples_system/mdfiles_with_openff.ipynb
import logging
from functools import reduce

from openff.interchange import Interchange
from openff.toolkit import ForceField, Molecule
from rdkit import Chem
from rich.progress import track

# logging.basicConfig(level=logging.DEBUG, force=True)

ff = ForceField('openff-2.2.1.offxml')

# `mol` is the single monolithic RDKit molecule reloaded from the exported SDF record
mol_incs: list[Interchange] = []
mols = Chem.GetMolFrags(mol, asMols=True)  # <----- THIS STEP IS SLOW!!!!
for chain in track(mols, description='Adding molecules to OpenFF Topology', auto_refresh=False):
    offmol = Molecule.from_rdkit(chain, allow_undefined_stereo=True, hydrogens_are_explicit=True)
    offmol.assign_partial_charges(partial_charge_method='openff-gnn-am1bcc-1.0.0.pt')  # alternative: partial_charge_method='gasteiger'
    mol_incs.append(
        ff.create_interchange(offmol.to_topology(), charge_from_molecules=[offmol])
    )

logging.warning('Combining Interchanges per-molecule into master Interchange')
# NOTE: found significantly faster to parameterize individually and combine, rather than parameterizing the full system at once
inc: Interchange = reduce(Interchange.combine, mol_incs)
logging.warning('Successfully created Interchange for entire chemical system')

On the same polymer used in multiscale_chain_placement.ipynb, this legacy path scales superlinearly and becomes dominated by repeated MuPT tree expansion/validation plus RDKit fragment generation after chain identity has been discarded.
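A toy cost model makes the superlinear growth concrete. This is a hypothetical sketch, not MuPT code:
- Nested Python lists stand in for the chain → residue → atom hierarchy.
- "Expanding" a child splices its children into the root.
- A full-tree walk after every expand stands in for check_self_consistent().

For 8 chains of 12 residues the model performs 8 + 96 = 104 expands, which matches the expand count reported for the 8x12 profile. The number of expands grows with system size, and each expand is followed by a walk over the whole tree, so total work grows roughly quadratically.

```python
from typing import Union

# Hypothetical toy model, not MuPT code: a node is either a leaf atom label
# (str) or a list of child nodes.
Node = Union[str, list]


def count_nodes(node: Node) -> int:
    """Full-tree walk; stands in for one check_self_consistent() pass."""
    if isinstance(node, str):
        return 1
    return 1 + sum(count_nodes(child) for child in node)


def build_system(n_chains: int, residues_per_chain: int, atoms_per_residue: int) -> list:
    """Root -> chains -> residues -> atoms, mirroring a chains x monomers case."""
    return [
        [[f"a{c}.{r}.{a}" for a in range(atoms_per_residue)] for r in range(residues_per_chain)]
        for c in range(n_chains)
    ]


def flatten_with_validation(root: list) -> tuple[int, int]:
    """Flatten root in place, mimicking the loop structure of Primitive.flatten().

    Returns (number of expand calls, total nodes visited by validation).
    """
    expands = 0
    work = 0
    while (expandable := [c for c in root if isinstance(c, list)]):
        for child in expandable:
            # Locate the child by identity, then splice its children into root ("expand").
            i = next(j for j, c in enumerate(root) if c is child)
            root[i:i + 1] = child
            expands += 1
            # Re-validate the entire tree after every single expand.
            work += count_nodes(root)
    return expands, work


for n_chains in (4, 8, 16):
    expands, work = flatten_with_validation(build_system(n_chains, 12, 10))
    print(f"{n_chains:>2} chains: {expands} expands, {work:,} nodes validated")
```

In this model, doubling the number of chains roughly quadruples the validation work.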

Proposed Solution

The role-aware SDF writer avoids both costs. It traverses the unflattened SAAMR hierarchy and streams one SDF record per chain/segment, so chain identity is never discarded and never has to be rediscovered. The implementation is the streaming write_primitive_to_sdf(...) path, which also preserves residue/chain metadata in SDF atom-property lists.
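The traversal idea can be sketched independently of MuPT. The model below is hypothetical: the `Node` class and its role names are assumptions, not MuPT's API. It turns each chain into one record in a single pass over the tree and carries per-atom residue labels along with it. Nothing is flattened, and no fragments need to be rediscovered afterwards. Converting a record into an actual RDKit molecule is omitted here. In RDKit, a per-atom list like `residue` could be stored as an SDF atom-property list, for example with `Chem.CreateAtomStringPropertyList`.

```python
from dataclasses import dataclass, field
from typing import Iterator


@dataclass
class Node:
    """Hypothetical stand-in for a MuPT Primitive; the role names are assumed."""
    label: str
    role: str  # "system", "chain", "residue", or "atom"
    children: list["Node"] = field(default_factory=list)


def iter_chain_records(root: Node) -> Iterator[dict]:
    """Yield one record per chain, visiting each node once and never flattening."""
    stack = [root]
    while stack:
        node = stack.pop()
        if node.role == "chain":
            atoms, residue_of_atom = [], []
            for residue in node.children:
                for atom in residue.children:
                    atoms.append(atom.label)
                    residue_of_atom.append(residue.label)  # per-atom metadata
            yield {"chain": node.label, "atoms": atoms, "residue": residue_of_atom}
        else:
            # Push children in reverse so chains come out in their original order.
            stack.extend(reversed(node.children))


# Two chains (A, B), each with two residues of three atoms.
system = Node("sys", "system", [
    Node(chain_id, "chain", [
        Node(f"{chain_id}{r}", "residue", [Node(f"{chain_id}{r}_{a}", "atom") for a in range(3)])
        for r in range(2)
    ])
    for chain_id in "AB"
])

for record in iter_chain_records(system):  # each record would become one SDF entry
    print(record["chain"], len(record["atoms"]), record["residue"])
```

Every node is visited once, so the cost is linear in the size of the hierarchy. This is consistent with the near-linear role_aware fit.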

The benchmark below isolates the MuPT/RDKit/SDF interface and excludes OpenFF conversion, since both paths hand OpenFF the same number of molecules and atoms. The largest monodisperse PES/PSU case tested here has 32 chains, 576 residues, and 16,182 atoms:
- The legacy interface takes 468.7 s and peaks at 122.8 MB of traced Python allocation.
- The role-aware path takes 3.15 s and peaks at 6.54 MB.

A log-log fit gives T ∝ N^1.8096 for legacy_flattened and T ∝ N^0.9672 for role_aware. This is consistent with the legacy path doing repeated global work, while the new path is approximately linear in atom count.
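The memory figures are described as traced Python allocation, which suggests a tracemalloc-based measurement. Below is a minimal harness of that kind, written under that assumption; the benchmark script's actual measurement code is not shown in this issue, and the `measure` helper is hypothetical.

```python
import time
import tracemalloc
from typing import Any, Callable


def measure(fn: Callable[[], Any]) -> tuple[float, float]:
    """Run fn() once; return (wall-clock seconds, peak traced Python allocation in MB).

    MB is taken as 10**6 bytes here (an assumption). tracemalloc only sees memory
    allocated through Python's allocators, not native C++ allocations inside RDKit.
    """
    tracemalloc.start()
    try:
        start = time.perf_counter()
        fn()
        elapsed = time.perf_counter() - start
        _, peak_bytes = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return elapsed, peak_bytes / 1e6


elapsed, peak_mb = measure(lambda: [list(range(100)) for _ in range(10_000)])
print(f"{elapsed:.3f} s, peak {peak_mb:.2f} MB")
```

One caveat applies to any measurement done this way: tracemalloc does not see memory allocated inside RDKit's C++ code. Peaks measured this way therefore understate total process memory.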

Benchmark code

Full benchmark script: benchmark_rdkit_export_workflows.py

NOTE: Place the script in the directory one level above your mupt clone. I assumed this would be tested from a clone of the mupt-examples repo, as in the steps below.

git clone https://github.com/MuPT-hub/mupt-examples # this repo
cd mupt-examples

# place benchmark_rdkit_export_workflows.py in this folder

mamba env create -f conda-envs/release-env.yml
mamba activate mupt-env
git clone https://github.com/MuPT-hub/mupt # the toolkit
cd mupt
git checkout issue-51/muptio-sdf-export # check out the branch before installing
pip install .
cd .. # back to mupt-examples

If you already have mupt-examples cloned and mupt-env set up, just check out the issue-51/muptio-sdf-export branch in your mupt clone and reinstall it (pip install .).

mamba run -n mupt-env python benchmark_rdkit_export_workflows.py \
  --sizes all \
  --outdir examples_system/rdkit_export_benchmark_issue

This writes:

examples_system/rdkit_export_benchmark_issue/rdkit_export_benchmark.csv
examples_system/rdkit_export_benchmark_issue/rdkit_export_benchmark.png

Attached: rdkit_export_benchmark.csv

[Image: benchmark plot]

Runtime data

| case (chains x monomers) | atoms | legacy_s | role_aware_s | speedup | legacy_peak_mb | role_aware_peak_mb | memory_reduction |
|---|---:|---:|---:|---:|---:|---:|---:|
| 2x5 | 288 | 0.647 | 0.060 | 10.8x | 2.30 | 0.30 | 7.6x |
| 4x8 | 888 | 2.359 | 0.187 | 12.6x | 6.92 | 0.61 | 11.4x |
| 8x12 | 2,736 | 10.554 | 0.540 | 19.5x | 21.18 | 1.62 | 13.0x |
| 16x15 | 6,774 | 56.020 | 1.256 | 44.6x | 52.57 | 3.09 | 17.0x |
| 32x18 | 16,182 | 468.653 | 3.147 | 148.9x | 122.81 | 6.54 | 18.8x |

Power-law fits of the form T = a * N^b, discarding the first point (288 atoms) as a small-system outlier dominated by fixed overhead:

legacy_flattened: T ~= 8.467e-06 * N^1.8096
role_aware:       T ~= 0.0002583 * N^0.9672
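The fits can be checked against the table with an ordinary least-squares line through (ln N, ln T). This sketch assumes that is the fitting procedure; the script's exact fitting code is not shown in this issue. Using the rounded table values, the exponents and prefactors should agree closely with those reported, up to rounding of the tabulated timings.

```python
import math

# Benchmark table values: (atoms, legacy_s, role_aware_s)
ROWS = [
    (288, 0.647, 0.060),
    (888, 2.359, 0.187),
    (2736, 10.554, 0.540),
    (6774, 56.020, 1.256),
    (16182, 468.653, 3.147),
]


def fit_power_law(ns: list[float], ts: list[float]) -> tuple[float, float]:
    """Least-squares fit of ln(T) = ln(a) + b*ln(N); returns (a, b)."""
    xs = [math.log(n) for n in ns]
    ys = [math.log(t) for t in ts]
    mx = sum(xs) / len(xs)
    my = sum(ys) / len(ys)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    b = sxy / sxx
    a = math.exp(my - b * mx)
    return a, b


# Discard the smallest case, as in the fits above.
used = ROWS[1:]
atoms = [row[0] for row in used]
for name, col in (("legacy_flattened", 1), ("role_aware", 2)):
    a, b = fit_power_law(atoms, [row[col] for row in used])
    print(f"{name}: T ~= {a:.4g} * N^{b:.4f}")
```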

The 800x35 ionomer system I was testing has roughly 330,000 atoms, about 20x the largest benchmarked case. Extrapolating the fits above to that size (well beyond the fitted range) gives:

legacy_flattened: $T \approx 8.467\times10^{-6} \cdot (330{,}000)^{1.8096}\ \text{s} \approx 82{,}041\ \text{s} \approx 22.8\ \text{h}$

vs

role_aware: $T \approx 2.583\times10^{-4} \cdot (330{,}000)^{0.9672}\ \text{s} \approx 56\ \text{s}$
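The extrapolation is plain arithmetic on the reported coefficients and can be re-run for other target sizes. The `predict_seconds` helper below is only an illustration:

```python
# Reported power-law fits: T = a * N**b, with T in seconds and N in atoms.
FITS = {
    "legacy_flattened": (8.467e-06, 1.8096),
    "role_aware": (0.0002583, 0.9672),
}


def predict_seconds(a: float, b: float, n_atoms: float) -> float:
    """Evaluate the power-law fit at n_atoms (an extrapolation outside the fitted range)."""
    return a * n_atoms ** b


n = 330_000
for name, (a, b) in FITS.items():
    t = predict_seconds(a, b, n)
    print(f"{name}: {t:,.0f} s ({t / 3600:.2f} h)")
```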

Profiling code

The same script can dump cProfile output and CSV call tables for individual stages:

mamba run -n mupt-env python examples_system/benchmark_rdkit_export_workflows.py \
  --profile-stage legacy-flatten \
  --profile-n-chains 8 \
  --profile-chain-len 12 \
  --outdir examples_system/rdkit_export_profile_issue

mamba run -n mupt-env python examples_system/benchmark_rdkit_export_workflows.py \
  --profile-stage legacy-split \
  --profile-n-chains 8 \
  --profile-chain-len 12 \
  --outdir examples_system/rdkit_export_profile_issue

This writes .prof, _flat.csv, and _callgraph.csv files. The flattened-path hot stack is:

Primitive.flattened
  -> Primitive.flatten
    -> Primitive.expand
      -> Primitive.check_self_consistent
        -> Primitive.check_connectors
          -> Primitive.check_internal_connection_references_valid
            -> Primitive.check_internally_connectable
              -> Primitive.fetch_connector_on_child

For the 8x12 profile, Primitive.expand is called 104 times and check_self_consistent is called 208 times. That validation cascade produces 292,644 calls to check_internally_connectable and 591,656 calls to fetch_connector_on_child. The second legacy bottleneck is Chem.GetMolFrags(..., asMols=True), which is a native RDKit call that must rediscover and copy chains from the monolithic molecule; its isolated runtime grows from 0.004 s at 288 atoms to 251.6 s at 16,182 atoms. The role-aware writer does not call flatten, expand, or GetMolFrags; it streams segment records directly via write_primitive_to_mupt_sdf(...).
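Call counts like the ones quoted above can be read back from a .prof dump with the standard-library pstats module. Below is a minimal sketch. The helper names are illustrative, and the dump filename in the usage comment is hypothetical, since the script's exact output names are not given here.

```python
import pstats


def call_counts(prof_path: str, names: list[str]) -> dict[str, str]:
    """Return pstats' ncalls string for each requested function name found in the dump.

    ncalls is reported as "total/primitive" for recursive functions, e.g. "12/3".
    Entries are keyed by bare function name, so same-named functions in different
    files overwrite each other.
    """
    profile = pstats.Stats(prof_path).get_stats_profile()
    return {name: profile.func_profiles[name].ncalls
            for name in names if name in profile.func_profiles}


def print_hot_stack(prof_path: str, limit: int = 15) -> None:
    """Print the top entries sorted by cumulative time."""
    pstats.Stats(prof_path).sort_stats(pstats.SortKey.CUMULATIVE).print_stats(limit)


# Usage (hypothetical filename):
# call_counts("examples_system/rdkit_export_profile_issue/legacy-flatten.prof",
#             ["expand", "check_self_consistent", "fetch_connector_on_child"])
```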

Labels: optimization (Enhancement which speeds up some existing functionality), priority:medium