Conversation
📝 Walkthrough

The pull request adds comprehensive ligand support to a protein dataset pipeline. Dependencies were updated (pymol package variant and rdkit added), dataset parsing extended to extract ligands, the dataset class modified to optionally include ligands in training data, and the encoder updated to process ligand features via learned projections. Tests were expanded with ligand parsing and integration validation.

Changes
Sequence Diagrams

```mermaid
sequenceDiagram
    participant User
    participant Dataset as ProteinWaterDataset
    participant Parser as parse_asu_with_biotite()
    participant Encoder as CachedEmbeddingEncoder
    participant Output as Graph Data
    User->>Dataset: init(include_ligands=True)
    Dataset->>Parser: parse_asu_with_biotite(pdb_path)
    Parser-->>Dataset: (protein_atoms, water_atoms, ligand_atoms)
    alt include_ligands=True & ligands exist
        Dataset->>Dataset: Append ligand coordinates to protein nodes
        Dataset->>Dataset: Set is_ligand mask & residue_index=-1 for ligands
    end
    Dataset-->>User: Graph data with optional ligand nodes
    User->>Encoder: forward(data)
    alt is_ligand mask present & any True
        Encoder->>Encoder: ligand_embed(protein.x[ligand_mask])
        Encoder->>Encoder: Replace cached embeddings for ligand rows
    end
    Encoder-->>Output: Embeddings with learned ligand projections
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

The changes introduce new feature logic (ligand extraction and encoding) across multiple interconnected files with heterogeneous modifications. Dataset parsing returns an additional value type, preprocessing conditionally appends ligand data with sentinel markers, and the encoder adds learnable projection logic with conditional execution paths. Comprehensive test coverage adds validation complexity. The variety of changes and logic density across dataset, encoder, and test layers warrants substantial review effort.
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Pull request overview
This PR extends dataset preprocessing to include non-water ligands in the static structure representation, folding ligand atoms into the existing protein node type and tagging them via an is_ligand boolean mask.
Changes:
- Update PDB parsing to return `(protein_atoms, water_atoms, ligand_atoms)` and (optionally) append ligands to protein nodes via `include_ligands`.
- Add `protein.is_ligand` to cached geometry and `HeteroData` for downstream masking/conditioning.
- Add integration tests and a new ligand-containing PDB fixture (`4h0b`) to validate parsing and dataset behavior.
Reviewed changes
Copilot reviewed 5 out of 7 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| `src/dataset.py` | Adds ligand parsing/output, `include_ligands` option, and persists `is_ligand` in caches and `HeteroData`. |
| `src/encoder_base.py` | Adds a learned projection for ligand rows when using cached-embedding encoders (ESM/SLAE). |
| `tests/test_dataset.py` | Expands parsing tests for 3-way partitioning and adds dataset integration tests for ligand inclusion. |
| `tests/conftest.py` | Adds a `pdb_4h0b` fixture for ligand test coverage. |
| `pyproject.toml` | Adds `rdkit` and swaps the PyMOL package reference to `pymol-open-source-whl`. |
| `uv.lock` | Updates the locked dependency set to reflect dependency changes (incl. `rdkit`/`jaxtyping`/pymol package rename). |
```python
protein_mask = bts.filter_amino_acids(atoms)
water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT")
ligand_mask = ~protein_mask & ~water_mask

protein_atoms = atoms[protein_mask]
water_atoms = atoms[water_mask]
ligand_atoms = atoms[ligand_mask]
```
`ligand_mask = ~protein_mask & ~water_mask` will classify any non-amino-acid, non-water atoms as ligands, including non-protein polymers recorded as ATOM (e.g., DNA/RNA) and other non-hetero records. This contradicts the docstring claim of "non-water HETATM records" and can silently mix non-ligand atoms into the ligand set. Consider restricting ligands to `atoms.hetero` (and excluding HOH/WAT), or otherwise explicitly encoding the intended selection criteria.
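The difference between the two selection rules can be shown with a dependency-free toy sketch (plain Python lists stand in for Biotite's per-atom boolean annotations; the atom roles and names are illustrative only):

```python
# Four toy atoms: a protein atom, a water, a true HETATM ligand, and a
# nucleotide recorded as ATOM (hetero=False, not protein, not water).
hetero     = [False, True,  True,  False]  # ATOM vs HETATM flag
is_protein = [True,  False, False, False]
is_water   = [False, True,  False, False]

# Current PR logic: everything that is neither protein nor water
loose_ligand = [not p and not w for p, w in zip(is_protein, is_water)]

# Reviewer's suggestion: additionally require the HETATM flag
strict_ligand = [h and not p and not w
                 for h, p, w in zip(hetero, is_protein, is_water)]

# The nucleotide is the difference: loose selects it, strict does not.
```

Under the loose rule the last atom is (wrongly) tagged as a ligand; the hetero-aware rule excludes it.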
```diff
 def parse_asu_with_biotite(
     path: str,
-) -> tuple[bts.AtomArray, bts.AtomArray]:
+) -> tuple[bts.AtomArray, bts.AtomArray, bts.AtomArray]:
     """
```
parse_asu_with_biotite() now returns a 3-tuple, which is a breaking API change. There are existing call sites (e.g., embedding generation scripts) that still unpack only 2 values and will crash at runtime unless updated to accept/ignore the new ligand_atoms return value.
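One low-risk migration pattern for the old 2-tuple call sites is star-unpacking, which tolerates both arities. A minimal sketch (`parse_asu_with_biotite` is stubbed here with toy return values for illustration):

```python
def parse_asu_with_biotite(path):
    # Stub of the updated parser, which now returns three values.
    return ["N", "CA", "C"], ["O_water"], ["ZN"]

# Old call sites that do `protein, water = parse_asu_with_biotite(path)` now
# raise "too many values to unpack". Star-unpacking accepts either arity:
protein_atoms, water_atoms, *rest = parse_asu_with_biotite("some.pdb")
ligand_atoms = rest[0] if rest else []
```

This keeps ligand-unaware call sites (e.g., embedding generation scripts) working without forcing them to handle the new return value.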
```python
# Ligand atoms get residue_index = -1 (sentinel; no residue embedding)
ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long)
```
Assigning residue_index = -1 for ligand atoms will break any downstream code that uses residue_index for scatter/gather (e.g., residue pooling via torch_scatter), since those ops generally require non-negative indices. Either mask out ligand atoms wherever residue pooling/indexing is used, or map ligands to a valid non-negative index (and update num_residues accordingly) to keep the tensor safe for scatter-based pooling.
```diff
-# Ligand atoms get residue_index = -1 (sentinel; no residue embedding)
-ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long)
+# Assign ligands to valid residue indices after existing residues
+max_res_idx = (
+    final_protein_res_idx.max().item()
+    if final_protein_res_idx.numel() > 0
+    else -1
+)
+ligand_res_idx = torch.full(
+    (len(ligand_atoms),),
+    max_res_idx + 1,
+    dtype=torch.long,
+)
```
```toml
"pymol-open-source-whl>=3.1.0.4",
"scipy",
"pandas",
"numpy",
"rdkit",
"matplotlib",
```
rdkit is added as a required runtime dependency, but it isn’t referenced anywhere in the repo (src/scripts/tests). Since RDKit significantly increases install size and can complicate builds on some platforms, consider removing it until it’s actually needed, or making it an optional extra for ligand-specific workflows.
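If RDKit is only needed for ligand-specific workflows, one way to address this is an optional extra in `pyproject.toml`. A sketch (the extra name `ligand` is hypothetical):

```toml
[project.optional-dependencies]
ligand = ["rdkit"]
```

Users would then opt in with `pip install "<package>[ligand]"` instead of paying the install cost unconditionally.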
Actionable comments posted: 3
🧹 Nitpick comments (1)
tests/test_dataset.py (1)
Lines 566-697: Add one cached-embedding ligand regression test. These cases only exercise `encoder_type="gvp"`, but the new ligand-specific behavior lives in `CachedEmbeddingEncoder.forward()` for `esm`/`slae`. A tiny fixture-backed test that asserts ligand rows are replaced would cover the path most likely to regress.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_dataset.py` around lines 566 - 697, Add a small integration test in tests/test_dataset.py that exercises CachedEmbeddingEncoder.forward for encoder_type="esm" (or "slae") to prevent regressions: construct a ProteinWaterDataset with include_ligands=True and preprocess=True using a fixture-backed cached-embedding file (or reuse the existing pdb fixtures), then fetch data = ds[0], locate ligand rows via data["protein"].is_ligand, and assert that those ligand-row embeddings in data["protein"].x match the expected cached embeddings (i.e., are replaced by the cached values) and differ from the original one-hot element encoding; reference CachedEmbeddingEncoder.forward and ProteinWaterDataset to find the encoder flow. Ensure the test is small, uses the tmp_path fixture for any temp files, and only targets the cached-embedding path for esm/slae.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/dataset.py`:
- Around line 73-81: The current ligand selection uses ligand_mask =
~protein_mask & ~water_mask which includes any non-protein/non-water ATOM
records; instead restrict ligands to HETATM entries by using the Biotite flag
atoms.hetero: update ligand_mask to also require atoms.hetero (e.g., ligand_mask
= (~protein_mask & ~water_mask) & atoms.hetero) so ligand_atoms contains only
non-protein, non-water HETATM records; keep protein_mask, water_mask,
protein_atoms, water_atoms assignments unchanged and return protein_atoms,
water_atoms, ligand_atoms.
- Around line 1070-1078: _pool_by_residue will fail if residue_index contains -1
for ligands because scatter operations (scatter_mean, scatter_add, scatter_max)
cannot accept negative indices; before any scatter keyed by residue_index (in
_pool_by_residue and any other pooling guarded by pool_residue), filter out
ligand atoms using the is_ligand mask (e.g., mask = ~is_ligand) and pass atom
embeddings and residue_index masked arrays (atom_embed[mask],
residue_index[mask]) to scatter with dim_size=num_residues; apply the same
masking wherever residue_index is used for pooling to avoid negative-index
errors.
In `@src/encoder_base.py`:
- Around line 182-190: The ligand_embed Linear is being created lazily in
forward(), so its parameters are not registered when optimizers/DDP/FSDP capture
model parameters; move the nn.Linear initialization into __init__ by creating
self.ligand_embed there (using the given embedding_dim) and remove the
forward-time instantiation, or if embedding_dim truly isn’t known at
construction, add an explicit initialize_ligand_embed(embedding_dim) method that
constructs and registers self.ligand_embed before the optimizer is created (call
it in setup code), ensuring parameters are registered with model.parameters();
as a last-resort alternative, if you cannot construct before optimizer, call
optimizer.add_param_group(...) with the new parameters immediately after
creation (note this still may not work with DDP/FSDP), and update references to
ligand_embed in forward() to assume it already exists.
---
Nitpick comments:
In `@tests/test_dataset.py`:
- Around line 566-697: Add a small integration test in tests/test_dataset.py
that exercises CachedEmbeddingEncoder.forward for encoder_type="esm" (or "slae")
to prevent regressions: construct a ProteinWaterDataset with
include_ligands=True and preprocess=True using a fixture-backed cached-embedding
file (or reuse the existing pdb fixtures), then fetch data = ds[0], locate
ligand rows via data["protein"].is_ligand, and assert that those ligand-row
embeddings in data["protein"].x match the expected cached embeddings (i.e., are
replaced by the cached values) and differ from the original one-hot element
encoding; reference CachedEmbeddingEncoder.forward and ProteinWaterDataset to
find the encoder flow. Ensure the test is small, uses the tmp_path fixture for
any temp files, and only targets the cached-embedding path for esm/slae.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: abe0f086-5c49-47da-8543-bf16a9391a11
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (6)
- `pyproject.toml`
- `src/dataset.py`
- `src/encoder_base.py`
- `tests/conftest.py`
- `tests/test_dataset.py`
- `tests/test_files/4h0b/4h0b_final.pdb`
```diff
 protein_mask = bts.filter_amino_acids(atoms)
 water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT")
 ligand_mask = ~protein_mask & ~water_mask

 protein_atoms = atoms[protein_mask]
 water_atoms = atoms[water_mask]
 ligand_atoms = atoms[ligand_mask]

-return protein_atoms, water_atoms
+return protein_atoms, water_atoms, ligand_atoms
```
🧩 Analysis chain
🌐 Web query:
In Biotite's PDB parser, does `get_structure()` preserve ATOM vs HETATM via `AtomArray.hetero`, and is `atoms.hetero & ~water_mask` the correct way to isolate non-water HETATM ligands?
💡 Result:
Yes, Biotite's PDB parser via get_structure preserves the ATOM vs HETATM distinction using the AtomArray.hetero annotation, which is set to True for HETATM records and False for ATOM records. This is evidenced by the Atom constructor accepting a hetero parameter (result 5, result 14), and examples where hetero=True is explicitly set for added HETATM atoms, with the parser maintaining this in loaded structures (results 1,4,7,8). Additionally, Biotite provides struc.filter_solvent(atoms) to create a water_mask (results 18,19,22), typically identifying water molecules (e.g., res_name "HOH"). Thus, atoms.hetero & ~water_mask correctly isolates non-water HETATM atoms (hetero=True and not solvent), which are typically ligands or other hetero compounds.
Citations:
- 1: https://www.biotite-python.org/latest/apidoc/biotite.structure.io.pdb.get_structure.html
- 2: https://www.biotite-python.org/latest/apidoc/biotite.structure.html
- 3: https://www.biotite-python.org/latest/apidoc/biotite.structure.io.pdb.html
- 4: https://stackoverflow.com/questions/58783906/add-heteroatom-to-pdb-file
- 5: https://www.biotite-python.org/latest/tutorial/structure/atoms.html
- 6: https://www.biotite-python.org/latest/tutorial/structure/io.html
- 7: https://www.biotite-python.org/latest/apidoc/biotite.structure.io.pdb.PDBFile.html
- 8: https://www.biotite-python.org/latest/apidoc/biotite.structure.AtomArray.html
- 9: https://www.biotite-python.org/latest/apidoc/biotite.structure.Atom.html
- 10: https://www.biotite-python.org/latest/examples/gallery/structure/modeling/solvation_shells.html
- 11: https://www.biotite-python.org/latest/examples/gallery/structure/modeling/water_exchange_noexec.html
- 12: https://www.biotite-python.org/latest/tutorial/structure/filter.html
🏁 Script executed:

```shell
# First, check the actual code in src/dataset.py around lines 73-81
cat -n src/dataset.py | sed -n '48,85p'
```

Repository: diff-use/WaterFlow
Length of output: 1662

🏁 Script executed:

```shell
# Also check the function definition and docstring to understand the intended behavior
cat -n src/dataset.py | sed -n '48,100p'
```

Repository: diff-use/WaterFlow
Length of output: 2332

🏁 Script executed:

```shell
# Check if there's any documentation about what ligands should be
grep -n "ligand" src/dataset.py | head -20
```

Repository: diff-use/WaterFlow
Length of output: 1458
Use the `atoms.hetero` flag to properly isolate HETATM ligands.

The docstring explicitly states that `ligand_atoms` should contain "non-protein, non-water HETATM records", but the current logic `~protein_mask & ~water_mask` only uses negation and will also include ATOM records like nucleic acids or other polymers that are neither proteins nor water. This silently breaks the documented contract and changes node counts for mixed complexes.
Biotite's PDB parser preserves the ATOM vs HETATM distinction via the atoms.hetero flag, so the fix is straightforward:
Suggested fix:

```diff
 protein_mask = bts.filter_amino_acids(atoms)
 water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT")
-ligand_mask = ~protein_mask & ~water_mask
+hetero_mask = atoms.hetero.astype(bool)
+ligand_mask = hetero_mask & ~protein_mask & ~water_mask
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/dataset.py` around lines 73 - 81, The current ligand selection uses
ligand_mask = ~protein_mask & ~water_mask which includes any
non-protein/non-water ATOM records; instead restrict ligands to HETATM entries
by using the Biotite flag atoms.hetero: update ligand_mask to also require
atoms.hetero (e.g., ligand_mask = (~protein_mask & ~water_mask) & atoms.hetero)
so ligand_atoms contains only non-protein, non-water HETATM records; keep
protein_mask, water_mask, protein_atoms, water_atoms assignments unchanged and
return protein_atoms, water_atoms, ligand_atoms.
```python
    # Ligand atoms get residue_index = -1 (sentinel; no residue embedding)
    ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long)
    final_protein_res_idx = torch.cat(
        [final_protein_res_idx, ligand_res_idx], dim=0
    )
    is_ligand = torch.zeros(final_protein_pos.size(0), dtype=torch.bool)
    is_ligand[-len(ligand_atoms) :] = True
else:
    is_ligand = torch.zeros(final_protein_pos.size(0), dtype=torch.bool)
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Find every downstream read of residue_index and confirm negative values are
# masked before indexing, scattering, or pooling.
rg -n --type=py -C3 '\bresidue_index\b'Repository: diff-use/WaterFlow
Length of output: 7577
🏁 Script executed:
# Check how _pool_by_residue is called and whether ligands are masked
rg -n -A 20 'def forward' src/gvp_encoder.py | head -100Repository: diff-use/WaterFlow
Length of output: 1902
🏁 Script executed:
# Check if is_ligand is used anywhere to mask before pooling
rg -n 'is_ligand' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Look at the exact lines 1199-1224 in dataset.py mentioned in "Also applies to"
sed -n '1199,1224p' src/dataset.pyRepository: diff-use/WaterFlow
Length of output: 1126
🏁 Script executed:
# Check PyTorch scatter behavior with negative indices - look for any defensive code
rg -n -B 5 -A 5 'scatter_mean\|scatter_add\|scatter_max' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check make_gvp_encoder_data function - does it exclude ligands?
rg -n -A 30 'def make_gvp_encoder_data' src/Repository: diff-use/WaterFlow
Length of output: 1662
🏁 Script executed:
# Look at the actual _pool_by_residue implementation more carefully
sed -n '260,290p' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 1208
🏁 Script executed:
# Check if there's any validation or masking of negative indices in scatter operations
rg -n -B 10 -A 10 'def _pool_by_residue' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 862
🏁 Script executed:
# Test what happens with negative indices in scatter operations - check if there's error handling
python3 << 'EOF'
import torch
from torch_scatter import scatter_mean
# Test if scatter_mean handles -1 indices
try:
atom_embed = torch.randn(5, 16)
residue_index = torch.tensor([0, 1, -1, 1, -1], dtype=torch.long)
result = scatter_mean(atom_embed, residue_index, dim=0, dim_size=2)
print("scatter_mean with -1 indices succeeded")
print(f"Result shape: {result.shape}")
print(f"Result:\n{result}")
except Exception as e:
print(f"scatter_mean with -1 indices failed: {e}")
EOFRepository: diff-use/WaterFlow
Length of output: 179
🏁 Script executed:
# Check the complete make_gvp_encoder_data to see if it excludes ligands
sed -n '60,110p' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 1473
🏁 Script executed:
# Check the forward method to see when _pool_by_residue is called and if ligands are excluded
sed -n '304,360p' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 2596
🏁 Script executed:
# Search for any filtering of ligands before pooling
rg -n 'is_ligand' src/gvp_encoder.py -B 5 -A 5Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check dataset.py to see where pool_residue is set and what assumptions it makes
rg -n 'pool_residue' src/Repository: diff-use/WaterFlow
Length of output: 331
🏁 Script executed:
# Look for any scatter operation masking or negative index handling in the codebase
rg -n 'scatter_mean|scatter_add|scatter_max' src/ -B 3 -A 3Repository: diff-use/WaterFlow
Length of output: 4617
🏁 Script executed:
# Check if residue_index is passed to make_gvp_encoder_data or the encoder
sed -n '60,100p' src/gvp_encoder.py | grep -n residue_indexRepository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check the call to make_gvp_encoder_data to see if residue_index is preserved
rg -n 'make_gvp_encoder_data' src/ -B 2 -A 5Repository: diff-use/WaterFlow
Length of output: 895
🏁 Script executed:
# Check the ProteinGVPEncoder forward method to see full implementation
sed -n '488,520p' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 1216
🏁 Script executed:
# Let's check the full flow - where pool_residue is actually called
rg -n 'pool_residue=True' src/ -B 3 -A 3Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Look more carefully at make_gvp_encoder_data - does it copy residue_index?
sed -n '60,105p' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 1394
🏁 Script executed:
# Check where ProteinGVPEncoder is instantiated - does it ever set pool_residue=True?
rg -n 'ProteinGVPEncoder\|pool_residue' src/ -B 2 -A 2 | head -60Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check if residue_index is ever added to enc_data after make_gvp_encoder_data
rg -n 'enc_data.residue_index\|enc_data.num_residues' src/Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check the actual encoder instantiation in the GVPEncoder (wrapper) class
sed -n '430,480p' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 1600
🏁 Script executed:
# Verify pool_residue is never set to True anywhere
rg -n 'pool_residue.*=.*True\|pool_residue=True' src/Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check what the default pool_residue value is and in what context
rg -n 'pool_residue' src/ -B 1 -A 1Repository: diff-use/WaterFlow
Length of output: 930
🏁 Script executed:
# Now let's look at the lines around 1070 and 1199-1224 to understand the full context
sed -n '1060,1080p' src/dataset.pyRepository: diff-use/WaterFlow
Length of output: 1367
🏁 Script executed:
# Check lines 1199-1224 as mentioned in "Also applies to"
sed -n '1190,1230p' src/dataset.pyRepository: diff-use/WaterFlow
Length of output: 1726
🏁 Script executed:
# Check if residue_index is used anywhere else with negative value handling
rg -n '\-1.*residue|residue.*\-1' src/Repository: diff-use/WaterFlow
Length of output: 164
🏁 Script executed:
# Check if there are any tests or code that actually enable pool_residue
git log --all --oneline -S 'pool_residue=True' 2>/dev/null || echo "No git log available"
rg -r 'pool_residue\s*=\s*True' src/ tests/ 2>/dev/null || echo "Not found"Repository: diff-use/WaterFlow
Length of output: 148
🏁 Script executed:
# Check what happens with max() of tensor with -1 values
python3 << 'EOF'
import torch
# Simulate residue_index with -1 values for ligands
res_idx = torch.tensor([0, 1, 2, 1, 0, -1, -1])
max_val = res_idx.max().item()
num_residues = int(max_val + 1)
print(f"residue_index: {res_idx}")
print(f"max value: {max_val}")
print(f"num_residues: {num_residues}")
# Edge case: all ligands
res_idx_ligand_only = torch.tensor([-1, -1, -1])
max_val2 = res_idx_ligand_only.max().item()
num_residues2 = int(max_val2 + 1)
print(f"\nAll ligands case:")
print(f"residue_index: {res_idx_ligand_only}")
print(f"max value: {max_val2}")
print(f"num_residues: {num_residues2}")
EOFRepository: diff-use/WaterFlow
Length of output: 179
🏁 Script executed:
# Check if residue_index is used anywhere for indexing operations besides scatter
rg -n 'residue_index\[' src/Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Look for any masking of is_ligand before residue operations
rg -n 'is_ligand' src/ -A 3 -B 3 | grep -A 5 -B 5 residueRepository: diff-use/WaterFlow
Length of output: 712
🏁 Script executed:
# Check all residue_index usages more comprehensively
rg -n 'residue_index' src/ tests/ --type py -C 2Repository: diff-use/WaterFlow
Length of output: 6025
🏁 Script executed:
# Specifically check if residue_index is ever used as an index (array indexing)
rg -n '\[.*residue_index\]|\[residue_index' src/Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check if there are any comments or TODOs about handling -1 in residue_index
rg -n 'sentinel|-1.*residue|residue.*-1' src/ tests/Repository: diff-use/WaterFlow
Length of output: 444
🏁 Script executed:
# Check the test at lines 2395-2398 that checks residue indices
sed -n '2390,2410p' src/dataset.pyRepository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check if there's any edge case test with ligands
rg -n 'ligand' tests/test_dataset.py | head -20Repository: diff-use/WaterFlow
Length of output: 1444
🏁 Script executed:
# Final check: are there any other scatter operations that might use residue_index indirectly?
rg -n 'scatter_' src/ -B 2 -A 2Repository: diff-use/WaterFlow
Length of output: 3766
Verify that _pool_by_residue masks out ligand atoms before scatter operations.
This change introduces -1 sentinel values in residue_index for ligand atoms, but the scatter operations in _pool_by_residue (lines 276–282 in gvp_encoder.py) cannot handle negative indices. Although pool_residue is currently disabled (pool_residue=False by default), if this feature is enabled in the future, the code will fail when calling scatter_mean, scatter_add, or scatter_max with a mixed residue_index containing -1 values.
Add a mask using `is_ligand` before the scatter operations:

```python
mask = ~is_ligand  # or retrieve from data
atom_embed_masked = atom_embed[mask]
residue_index_masked = residue_index[mask]
res_embed = scatter_mean(atom_embed_masked, residue_index_masked, dim=0, dim_size=num_residues)
```

Also verify the same masking is applied at any other scatter/pooling operations keyed by `residue_index` if `pool_residue` becomes enabled.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/dataset.py` around lines 1070 - 1078, _pool_by_residue will fail if
residue_index contains -1 for ligands because scatter operations (scatter_mean,
scatter_add, scatter_max) cannot accept negative indices; before any scatter
keyed by residue_index (in _pool_by_residue and any other pooling guarded by
pool_residue), filter out ligand atoms using the is_ligand mask (e.g., mask =
~is_ligand) and pass atom embeddings and residue_index masked arrays
(atom_embed[mask], residue_index[mask]) to scatter with dim_size=num_residues;
apply the same masking wherever residue_index is used for pooling to avoid
negative-index errors.
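The mask-before-pooling idea from this comment can be illustrated without torch or torch_scatter. A plain-Python sketch of residue mean-pooling that skips ligand rows (all names and inputs are toy stand-ins, not the project's real API):

```python
def pool_by_residue(atom_embed, residue_index, is_ligand, num_residues):
    """Mean-pool per-atom embeddings into per-residue embeddings,
    skipping ligand atoms (whose residue_index may be the -1 sentinel)."""
    dim = len(atom_embed[0])
    sums = [[0.0] * dim for _ in range(num_residues)]
    counts = [0] * num_residues
    for emb, idx, lig in zip(atom_embed, residue_index, is_ligand):
        if lig:  # mask out ligands before any index-keyed accumulation
            continue
        counts[idx] += 1
        for j, v in enumerate(emb):
            sums[idx][j] += v
    return [[v / c if c else 0.0 for v in row]
            for row, c in zip(sums, counts)]

embeds = [[2.0], [4.0], [10.0], [7.0]]
res_idx = [0, 0, 1, -1]               # last atom is a ligand with sentinel -1
is_lig = [False, False, False, True]
pooled = pool_by_residue(embeds, res_idx, is_lig, num_residues=2)
```

Without the `if lig: continue` guard, the sentinel `-1` would silently write into the last residue slot (Python's negative indexing), which is the same class of bug the scatter-based version would hit.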
```python
# Learnable projection for ligand atoms (element one-hot -> embedding space).
# Ligands have no ESM/SLAE embeddings; this replaces zero-padding with a
# learned representation parameterized only by element type.
# Lazily initialized on first forward when embedding_dim becomes known.
self.ligand_embed: nn.Linear | None = (
    nn.Linear(NODE_FEATURE_DIM, embedding_dim, bias=False)
    if embedding_dim is not None
    else None
)
```
🧩 Analysis chain
🌐 Web query:
In PyTorch, if an `nn.Module` creates a new submodule or Parameter inside `forward()` after the optimizer or DDP/FSDP wrapper has already been constructed, will that parameter be optimized and replicated?
💡 Result:
No, parameters or submodules created inside forward after the optimizer and DDP/FSDP wrapper have been constructed will not be optimized or replicated. PyTorch optimizers capture the parameters at construction time via model.parameters, which yields only parameters registered at that moment (those assigned as nn.Parameter attributes before wrapper/optimizer creation). New parameters registered dynamically in forward are not included in this iterator, so the optimizer does not track them and step skips them. DDP and FSDP similarly build their internal state (buckets, hooks, sharding) based on parameters present at construction; new parameters added later are not hooked for gradient synchronization or sharding/replication. Official docs confirm optimizers use param_groups from the initial parameters call, with state keyed by parameter IDs captured then. Forum discussions and GitHub issues (e.g., dynamic parametrizations in FSDP requiring manual ignored_params + DDP) show post-construction parameters need explicit handling (e.g., optimizer.add_param_group), but even then replication may fail due to missed hooks. Creating parameters in forward is discouraged; use init instead.
Citations:
- 1: http://docs.pytorch.org/docs/main/notes/modules.html
- 2: https://pytorch.org/docs/2.1/generated/torch.optim.Optimizer.state_dict.html
- 3: https://stackoverflow.com/questions/63047762/correct-way-to-register-a-parameter-for-model-in-pytorch
- 4: https://stackoverflow.com/questions/59234238/how-to-add-parameters-in-module-class-in-pytorch-custom-model
- 5: https://docs.pytorch.ac.cn/docs/2.8/generated/torch.nn.utils.parametrize.register_parametrization.html
- 6: https://docs.pytorch.org/docs/2.5/notes/ddp.html
- 7: https://docs.pytorch.org/docs/main/distributed.fsdp.fully%5Fshard.html
🏁 Script executed:
# Check the file exists and inspect the relevant lines
if [ -f "src/encoder_base.py" ]; then
echo "=== Lines 182-190 (declaration) ==="
sed -n '182,190p' src/encoder_base.py
echo ""
echo "=== Lines 237-242 (instantiation) ==="
sed -n '237,242p' src/encoder_base.py
echo ""
echo "=== Context around line 237 to understand forward() method ==="
sed -n '220,250p' src/encoder_base.py
else
echo "File not found"
fiRepository: diff-use/WaterFlow
Length of output: 2363
Move ligand_embed initialization to __init__.
The module is created inside forward() after the optimizer has already captured model parameters. Since PyTorch optimizers iterate over model.parameters() at construction time, any submodules registered later in forward() are not included in param_groups and will not be updated. Similarly, DDP and FSDP will not replicate or synchronize these parameters. Move the nn.Linear initialization to __init__, or call optimizer.add_param_group() with the new parameters before training (though full DDP replication may still fail without explicit hooks).
Affects lines 182–190 (declaration) and 237–242 (forward instantiation).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/encoder_base.py` around lines 182 - 190, The ligand_embed Linear is being
created lazily in forward(), so its parameters are not registered when
optimizers/DDP/FSDP capture model parameters; move the nn.Linear initialization
into __init__ by creating self.ligand_embed there (using the given
embedding_dim) and remove the forward-time instantiation, or if embedding_dim
truly isn’t known at construction, add an explicit
initialize_ligand_embed(embedding_dim) method that constructs and registers
self.ligand_embed before the optimizer is created (call it in setup code),
ensuring parameters are registered with model.parameters(); as a last-resort
alternative, if you cannot construct before optimizer, call
optimizer.add_param_group(...) with the new parameters immediately after
creation (note this still may not work with DDP/FSDP), and update references to
ligand_embed in forward() to assume it already exists.
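The failure mode described in this comment is easy to reproduce without PyTorch. A minimal stand-in optimizer that snapshots the parameter set at construction (as `torch.optim` does) never sees a parameter created later inside `forward()` — all class and attribute names below are illustrative:

```python
class Model:
    def __init__(self):
        self.params = {"w": 1.0}          # registered at construction time

    def forward(self, x):
        # Anti-pattern: lazily creating a parameter on first forward()
        if "ligand_w" not in self.params:
            self.params["ligand_w"] = 0.5
        return x * self.params["ligand_w"]


class Optimizer:
    def __init__(self, model):
        # Mimics torch.optim: captures the parameter set once, up front
        self.tracked = set(model.params)


model = Model()
opt = Optimizer(model)                    # built before the first forward()
model.forward(2.0)                        # "ligand_w" springs into existence here
missing = "ligand_w" not in opt.tracked   # True: the optimizer never updates it
```

Constructing the submodule in `__init__` (or in an explicit setup method called before the optimizer is created) makes the parameter visible at capture time.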
PR to include ligands in the static structure when encoding proteins. This is a v1 change, so I tried to keep it as simple a change as possible. Ligand atoms are folded into the existing `protein` node type with an `is_ligand` flag.

Summary by CodeRabbit
New Features
Dependencies