
Adding ligands to dataset processing #78

Open

vratins wants to merge 5 commits into main from dev_ligands

Conversation

@vratins
Contributor

@vratins vratins commented Mar 30, 2026

This PR includes ligands in the static structure when encoding proteins. It is a v1 change, so I tried to keep it as simple as possible.

  • Ligands are included in the protein node type and tagged with an is_ligand flag.
  • To minimize refactoring of the codebase in this PR, no new edge types or node types are introduced.
  • Ligand atoms are encoded with a one-hot element type (see the sketch below).
  • Added tests, as well as a new test PDB file that contains ligands (the PDB file is ~3.2k lines, hence the huge diff).
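
A minimal sketch of the encoding idea (the element vocabulary, tensor shapes, and helper name are illustrative assumptions; the actual logic lives in src/dataset.py):

import torch
import torch.nn.functional as F

ELEMENTS = ["C", "N", "O", "S", "P", "OTHER"]  # hypothetical element vocabulary

def one_hot_elements(symbols):
    # One-hot encode ligand atoms by element type, bucketing unknowns into OTHER.
    idx = [ELEMENTS.index(s) if s in ELEMENTS else len(ELEMENTS) - 1 for s in symbols]
    return F.one_hot(torch.tensor(idx), num_classes=len(ELEMENTS)).float()

# Fold ligand nodes into the existing protein node tensors and tag them.
protein_x = torch.randn(10, len(ELEMENTS))    # placeholder protein features
ligand_x = one_hot_elements(["C", "O", "N"])  # three ligand atoms
x = torch.cat([protein_x, ligand_x], dim=0)
is_ligand = torch.zeros(x.size(0), dtype=torch.bool)
is_ligand[-ligand_x.size(0):] = True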

Summary by CodeRabbit

  • New Features

    • Added support for ligand molecule extraction and processing in datasets.
    • Introduced optional parameter to include ligand atoms when loading protein structures.
    • Added learnable ligand atom embeddings.
  • Dependencies

    • Updated pymol package with minimum version constraint.
    • Added rdkit dependency.

Copilot AI review requested due to automatic review settings March 30, 2026 20:14
@coderabbitai

coderabbitai Bot commented Mar 30, 2026

📝 Walkthrough

The pull request adds comprehensive ligand support to a protein dataset pipeline. Dependencies were updated (pymol package variant and rdkit added), dataset parsing extended to extract ligands, the dataset class modified to optionally include ligands in training data, and the encoder updated to process ligand features via learned projections. Tests were expanded with ligand parsing and integration validation.
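
A hedged usage sketch of the surface described above (the data_dir argument is hypothetical; only include_ligands and the is_ligand/residue_index attributes come from this PR):

from src.dataset import ProteinWaterDataset

ds = ProteinWaterDataset(data_dir="data/", include_ligands=True)  # other arguments assumed to keep their defaults
data = ds[0]

lig_mask = data["protein"].is_ligand  # boolean mask over protein nodes
print(int(lig_mask.sum()), "ligand nodes")
print(bool((data["protein"].residue_index[lig_mask] == -1).all()))  # sentinel check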

Changes

  • Dependency Updates (pyproject.toml): Replaced pymol-open-source with pymol-open-source-whl>=3.1.0.4 and added an rdkit dependency.
  • Dataset Ligand Support (src/dataset.py): Extended parse_asu_with_biotite() to return ligand atoms; added an include_ligands parameter to ProteinWaterDataset; preprocessing now appends ligand coordinates and features to the protein node tensors when enabled; introduced an is_ligand boolean node attribute with a residue-index sentinel (-1) for ligand nodes.
  • Encoder Ligand Handling (src/encoder_base.py): Added a ligand_embed learnable projection to CachedEmbeddingEncoder; updated forward() to replace cached embeddings for ligand nodes with learned projections when an is_ligand mask is present (see the sketch after this list).
  • Test Fixtures & Infrastructure (tests/conftest.py): Added a pdb_4h0b pytest fixture for test data resolution.
  • Test Coverage (tests/test_dataset.py): Updated existing tests to handle the 3-tuple return from parse_asu_with_biotite(); added a TestLigandParsing suite validating ligand extraction semantics; added a TestLigandNodeIntegration suite validating dataset behavior with ligand inclusion (include_ligands parameter).

Sequence Diagrams

sequenceDiagram
    participant User
    participant Dataset as ProteinWaterDataset
    participant Parser as parse_asu_with_biotite()
    participant Encoder as CachedEmbeddingEncoder
    participant Output as Graph Data

    User->>Dataset: init(include_ligands=True)
    Dataset->>Parser: parse_asu_with_biotite(pdb_path)
    Parser-->>Dataset: (protein_atoms, water_atoms, ligand_atoms)
    
    alt include_ligands=True & ligands exist
        Dataset->>Dataset: Append ligand coordinates to protein nodes
        Dataset->>Dataset: Set is_ligand mask & residue_index=-1 for ligands
    end
    
    Dataset-->>User: Graph data with optional ligand nodes
    
    User->>Encoder: forward(data)
    
    alt is_ligand mask present & any True
        Encoder->>Encoder: ligand_embed(protein.x[ligand_mask])
        Encoder->>Encoder: Replace cached embeddings for ligand rows
    end
    
    Encoder-->>Output: Embeddings with learned ligand projections

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

The changes introduce new feature logic (ligand extraction and encoding) across multiple interconnected files with heterogeneous modifications. Dataset parsing returns an additional value type, preprocessing conditionally appends ligand data with sentinel markers, and the encoder adds learnable projection logic with conditional execution paths. Comprehensive test coverage adds validation complexity. The variety of changes and logic density across dataset, encoder, and test layers warrants substantial review effort.

Poem

A ligand hops in, tail held high,
No longer lost, no longer shy,
With is_ligand marks, bright and true,
The encoder learns what ligands do—
One hop, one prep, one learned projection,
Our protein graph finds new direction! 🐰✨

🚥 Pre-merge checks | ✅ 3 passed

  • Description Check (✅ Passed): Check skipped; CodeRabbit’s high-level summary is enabled.
  • Title Check (✅ Passed): The title 'Adding ligands to dataset processing' directly and clearly describes the main change: ligand support is being added to the dataset processing pipeline.
  • Docstring Coverage (✅ Passed): Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.



@vratins vratins linked an issue Mar 30, 2026 that may be closed by this pull request

Copilot AI left a comment


Pull request overview

This PR extends dataset preprocessing to include non-water ligands in the static structure representation, folding ligand atoms into the existing protein node type and tagging them via an is_ligand boolean mask.

Changes:

  • Update PDB parsing to return (protein_atoms, water_atoms, ligand_atoms) and (optionally) append ligands to protein nodes via include_ligands.
  • Add protein.is_ligand to cached geometry and HeteroData for downstream masking/conditioning.
  • Add integration tests and a new ligand-containing PDB fixture (4h0b) to validate parsing and dataset behavior.

Reviewed changes

Copilot reviewed 5 out of 7 changed files in this pull request and generated 4 comments.

Summary per file:

  • src/dataset.py: Adds ligand parsing/output, include_ligands option, and persists is_ligand in caches and HeteroData.
  • src/encoder_base.py: Adds a learned projection for ligand rows when using cached-embedding encoders (ESM/SLAE).
  • tests/test_dataset.py: Expands parsing tests for 3-way partitioning and adds dataset integration tests for ligand inclusion.
  • tests/conftest.py: Adds a pdb_4h0b fixture for ligand test coverage.
  • pyproject.toml: Adds rdkit and swaps the PyMOL package reference to pymol-open-source-whl.
  • uv.lock: Updates the locked dependency set to reflect dependency changes (incl. rdkit/jaxtyping/pymol package rename).


Comment thread src/dataset.py
Comment on lines 73 to +79
protein_mask = bts.filter_amino_acids(atoms)
water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT")
ligand_mask = ~protein_mask & ~water_mask

protein_atoms = atoms[protein_mask]
water_atoms = atoms[water_mask]
ligand_atoms = atoms[ligand_mask]

Copilot AI Mar 30, 2026


ligand_mask = ~protein_mask & ~water_mask will classify any non-amino-acid, non-water atoms as ligands, including non-protein polymers recorded as ATOM (e.g., DNA/RNA) and other non-hetero records. This contradicts the docstring claim of “non-water HETATM records” and can silently mix non-ligand atoms into the ligand set. Consider restricting ligands to atoms.hetero (and excluding HOH/WAT), or otherwise explicitly encoding the intended selection criteria.

Comment thread src/dataset.py
Comment on lines 48 to 51
def parse_asu_with_biotite(
path: str,
-) -> tuple[bts.AtomArray, bts.AtomArray]:
+) -> tuple[bts.AtomArray, bts.AtomArray, bts.AtomArray]:
"""

Copilot AI Mar 30, 2026


parse_asu_with_biotite() now returns a 3-tuple, which is a breaking API change. There are existing call sites (e.g., embedding generation scripts) that still unpack only 2 values and will crash at runtime unless updated to accept/ignore the new ligand_atoms return value.
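
For illustration, updated call sites could unpack all three values or discard the new one (the fixture path is the ligand-containing file added in this PR; which scripts need updating is not verified here):

from src.dataset import parse_asu_with_biotite

pdb_path = "tests/test_files/4h0b/4h0b_final.pdb"

# Unpack all three values...
protein_atoms, water_atoms, ligand_atoms = parse_asu_with_biotite(pdb_path)

# ...or ignore ligands where they are not needed (e.g., embedding generation scripts).
protein_atoms, water_atoms, *_ = parse_asu_with_biotite(pdb_path)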

Comment thread src/dataset.py
Comment on lines +1070 to +1071
# Ligand atoms get residue_index = -1 (sentinel; no residue embedding)
ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long)

Copilot AI Mar 30, 2026


Assigning residue_index = -1 for ligand atoms will break any downstream code that uses residue_index for scatter/gather (e.g., residue pooling via torch_scatter), since those ops generally require non-negative indices. Either mask out ligand atoms wherever residue pooling/indexing is used, or map ligands to a valid non-negative index (and update num_residues accordingly) to keep the tensor safe for scatter-based pooling.

Suggested change
-# Ligand atoms get residue_index = -1 (sentinel; no residue embedding)
-ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long)
+# Assign ligands to valid residue indices after existing residues
+max_res_idx = (
+    final_protein_res_idx.max().item()
+    if final_protein_res_idx.numel() > 0
+    else -1
+)
+ligand_res_idx = torch.full(
+    (len(ligand_atoms),),
+    max_res_idx + 1,
+    dtype=torch.long,
+)

Comment thread pyproject.toml
Comment on lines +14 to 19
"pymol-open-source-whl>=3.1.0.4",
"scipy",
"pandas",
"numpy",
"rdkit",
"matplotlib",

Copilot AI Mar 30, 2026


rdkit is added as a required runtime dependency, but it isn’t referenced anywhere in the repo (src/scripts/tests). Since RDKit significantly increases install size and can complicate builds on some platforms, consider removing it until it’s actually needed, or making it an optional extra for ligand-specific workflows.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/test_dataset.py (1)

566-697: Add one cached-embedding ligand regression test.

These cases only exercise encoder_type="gvp", but the new ligand-specific behavior lives in CachedEmbeddingEncoder.forward() for esm/slae. A tiny fixture-backed test that asserts ligand rows are replaced would cover the path most likely to regress.
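
One possible shape for such a test, using a stand-in projection rather than the real CachedEmbeddingEncoder (the actual test would construct the encoder with encoder_type="esm" or "slae" and a fixture-backed cached-embedding file):

import torch
import torch.nn as nn

def test_ligand_rows_are_replaced():
    ligand_embed = nn.Linear(6, 128, bias=False)  # stand-in for encoder.ligand_embed
    cached = torch.zeros(4, 128)                  # ligand rows cached as zero padding
    x = torch.randn(4, 6)
    is_ligand = torch.tensor([False, False, False, True])

    out = cached.clone()
    out[is_ligand] = ligand_embed(x[is_ligand])

    assert not torch.allclose(out[is_ligand], cached[is_ligand])  # ligand rows replaced
    assert torch.equal(out[~is_ligand], cached[~is_ligand])       # protein rows untouched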

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_dataset.py` around lines 566 - 697, Add a small integration test
in tests/test_dataset.py that exercises CachedEmbeddingEncoder.forward for
encoder_type="esm" (or "slae") to prevent regressions: construct a
ProteinWaterDataset with include_ligands=True and preprocess=True using a
fixture-backed cached-embedding file (or reuse the existing pdb fixtures), then
fetch data = ds[0], locate ligand rows via data["protein"].is_ligand, and assert
that those ligand-row embeddings in data["protein"].x match the expected cached
embeddings (i.e., are replaced by the cached values) and differ from the
original one-hot element encoding; reference CachedEmbeddingEncoder.forward and
ProteinWaterDataset to find the encoder flow. Ensure the test is small, uses the
tmp_path fixture for any temp files, and only targets the cached-embedding path
for esm/slae.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: abe0f086-5c49-47da-8543-bf16a9391a11

📥 Commits

Reviewing files that changed from the base of the PR and between b9c0ce9 and bebf742.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (6)
  • pyproject.toml
  • src/dataset.py
  • src/encoder_base.py
  • tests/conftest.py
  • tests/test_dataset.py
  • tests/test_files/4h0b/4h0b_final.pdb

Comment thread src/dataset.py
Comment on lines 73 to +81
protein_mask = bts.filter_amino_acids(atoms)
water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT")
ligand_mask = ~protein_mask & ~water_mask

protein_atoms = atoms[protein_mask]
water_atoms = atoms[water_mask]
ligand_atoms = atoms[ligand_mask]

-return protein_atoms, water_atoms
+return protein_atoms, water_atoms, ligand_atoms


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In Biotite's PDB parser, does get_structure() preserve ATOM vs HETATM via AtomArray.hetero, and is atoms.hetero & ~water_mask the correct way to isolate non-water HETATM ligands?

💡 Result:

Yes, Biotite's PDB parser via get_structure preserves the ATOM vs HETATM distinction using the AtomArray.hetero annotation, which is set to True for HETATM records and False for ATOM records. This is evidenced by the Atom constructor accepting a hetero parameter (result 5, result 14), and examples where hetero=True is explicitly set for added HETATM atoms, with the parser maintaining this in loaded structures (results 1,4,7,8). Additionally, Biotite provides struc.filter_solvent(atoms) to create a water_mask (results 18,19,22), typically identifying water molecules (e.g., res_name "HOH"). Thus, atoms.hetero & ~water_mask correctly isolates non-water HETATM atoms (hetero=True and not solvent), which are typically ligands or other hetero compounds.
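
A quick standalone Biotite check of that claim, assuming the ligand-containing fixture added in this PR:

import biotite.structure as struc
import biotite.structure.io.pdb as pdb

atoms = pdb.PDBFile.read("tests/test_files/4h0b/4h0b_final.pdb").get_structure(model=1)

protein_mask = struc.filter_amino_acids(atoms)
water_mask = struc.filter_solvent(atoms)  # HOH and other solvent residues
ligand_mask = atoms.hetero & ~water_mask & ~protein_mask

print(len(atoms[ligand_mask]), "non-water HETATM ligand atoms")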



🏁 Script executed:

# First, check the actual code in src/dataset.py around lines 73-81
cat -n src/dataset.py | sed -n '48,85p'

Repository: diff-use/WaterFlow

Length of output: 1662


🏁 Script executed:

# Also check the function definition and docstring to understand the intended behavior
cat -n src/dataset.py | sed -n '48,100p'

Repository: diff-use/WaterFlow

Length of output: 2332


🏁 Script executed:

# Check if there's any documentation about what ligands should be
grep -n "ligand" src/dataset.py | head -20

Repository: diff-use/WaterFlow

Length of output: 1458


Use atoms.hetero flag to properly isolate HETATM ligands.

The docstring explicitly states that ligand_atoms should contain "non-protein, non-water HETATM records", but the current logic ~protein_mask & ~water_mask only uses negation and will also include ATOM records like nucleic acids or other polymers that are neither proteins nor water. This silently breaks the documented contract and changes node counts for mixed complexes.

Biotite's PDB parser preserves the ATOM vs HETATM distinction via the atoms.hetero flag, so the fix is straightforward:

Suggested fix
     protein_mask = bts.filter_amino_acids(atoms)
     water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT")
-    ligand_mask = ~protein_mask & ~water_mask
+    hetero_mask = atoms.hetero.astype(bool)
+    ligand_mask = hetero_mask & ~protein_mask & ~water_mask
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/dataset.py` around lines 73 - 81, The current ligand selection uses
ligand_mask = ~protein_mask & ~water_mask which includes any
non-protein/non-water ATOM records; instead restrict ligands to HETATM entries
by using the Biotite flag atoms.hetero: update ligand_mask to also require
atoms.hetero (e.g., ligand_mask = (~protein_mask & ~water_mask) & atoms.hetero)
so ligand_atoms contains only non-protein, non-water HETATM records; keep
protein_mask, water_mask, protein_atoms, water_atoms assignments unchanged and
return protein_atoms, water_atoms, ligand_atoms.

Comment thread src/dataset.py
Comment on lines +1070 to +1078
# Ligand atoms get residue_index = -1 (sentinel; no residue embedding)
ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long)
final_protein_res_idx = torch.cat(
[final_protein_res_idx, ligand_res_idx], dim=0
)
is_ligand = torch.zeros(final_protein_pos.size(0), dtype=torch.bool)
is_ligand[-len(ligand_atoms) :] = True
else:
is_ligand = torch.zeros(final_protein_pos.size(0), dtype=torch.bool)


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Find every downstream read of residue_index and confirm negative values are
# masked before indexing, scattering, or pooling.
rg -n --type=py -C3 '\bresidue_index\b'

Repository: diff-use/WaterFlow

Length of output: 7577


🏁 Script executed:

# Check how _pool_by_residue is called and whether ligands are masked
rg -n -A 20 'def forward' src/gvp_encoder.py | head -100

Repository: diff-use/WaterFlow

Length of output: 1902


🏁 Script executed:

# Check if is_ligand is used anywhere to mask before pooling
rg -n 'is_ligand' src/gvp_encoder.py

Repository: diff-use/WaterFlow

Length of output: 44


🏁 Script executed:

# Look at the exact lines 1199-1224 in dataset.py mentioned in "Also applies to"
sed -n '1199,1224p' src/dataset.py

Repository: diff-use/WaterFlow

Length of output: 1126


🏁 Script executed:

# Check PyTorch scatter behavior with negative indices - look for any defensive code
rg -n -B 5 -A 5 'scatter_mean\|scatter_add\|scatter_max' src/gvp_encoder.py

Repository: diff-use/WaterFlow

Length of output: 44


🏁 Script executed:

# Check make_gvp_encoder_data function - does it exclude ligands?
rg -n -A 30 'def make_gvp_encoder_data' src/

Repository: diff-use/WaterFlow

Length of output: 1662


🏁 Script executed:

# Look at the actual _pool_by_residue implementation more carefully
sed -n '260,290p' src/gvp_encoder.py

Repository: diff-use/WaterFlow

Length of output: 1208


🏁 Script executed:

# Check if there's any validation or masking of negative indices in scatter operations
rg -n -B 10 -A 10 'def _pool_by_residue' src/gvp_encoder.py

Repository: diff-use/WaterFlow

Length of output: 862


🏁 Script executed:

# Test what happens with negative indices in scatter operations - check if there's error handling
python3 << 'EOF'
import torch
from torch_scatter import scatter_mean

# Test if scatter_mean handles -1 indices
try:
    atom_embed = torch.randn(5, 16)
    residue_index = torch.tensor([0, 1, -1, 1, -1], dtype=torch.long)
    result = scatter_mean(atom_embed, residue_index, dim=0, dim_size=2)
    print("scatter_mean with -1 indices succeeded")
    print(f"Result shape: {result.shape}")
    print(f"Result:\n{result}")
except Exception as e:
    print(f"scatter_mean with -1 indices failed: {e}")
EOF

Repository: diff-use/WaterFlow

Length of output: 179


🏁 Script executed:

# Check the complete make_gvp_encoder_data to see if it excludes ligands
sed -n '60,110p' src/gvp_encoder.py

Repository: diff-use/WaterFlow

Length of output: 1473


🏁 Script executed:

# Check the forward method to see when _pool_by_residue is called and if ligands are excluded
sed -n '304,360p' src/gvp_encoder.py

Repository: diff-use/WaterFlow

Length of output: 2596


🏁 Script executed:

# Search for any filtering of ligands before pooling
rg -n 'is_ligand' src/gvp_encoder.py -B 5 -A 5

Repository: diff-use/WaterFlow

Length of output: 44


🏁 Script executed:

# Check dataset.py to see where pool_residue is set and what assumptions it makes
rg -n 'pool_residue' src/

Repository: diff-use/WaterFlow

Length of output: 331


🏁 Script executed:

# Look for any scatter operation masking or negative index handling in the codebase
rg -n 'scatter_mean|scatter_add|scatter_max' src/ -B 3 -A 3

Repository: diff-use/WaterFlow

Length of output: 4617


🏁 Script executed:

# Check if residue_index is passed to make_gvp_encoder_data or the encoder
sed -n '60,100p' src/gvp_encoder.py | grep -n residue_index

Repository: diff-use/WaterFlow

Length of output: 44


🏁 Script executed:

# Check the call to make_gvp_encoder_data to see if residue_index is preserved
rg -n 'make_gvp_encoder_data' src/ -B 2 -A 5

Repository: diff-use/WaterFlow

Length of output: 895


🏁 Script executed:

# Check the ProteinGVPEncoder forward method to see full implementation
sed -n '488,520p' src/gvp_encoder.py

Repository: diff-use/WaterFlow

Length of output: 1216


🏁 Script executed:

# Let's check the full flow - where pool_residue is actually called
rg -n 'pool_residue=True' src/ -B 3 -A 3

Repository: diff-use/WaterFlow

Length of output: 44


🏁 Script executed:

# Look more carefully at make_gvp_encoder_data - does it copy residue_index?
sed -n '60,105p' src/gvp_encoder.py

Repository: diff-use/WaterFlow

Length of output: 1394


🏁 Script executed:

# Check where ProteinGVPEncoder is instantiated - does it ever set pool_residue=True?
rg -n 'ProteinGVPEncoder\|pool_residue' src/ -B 2 -A 2 | head -60

Repository: diff-use/WaterFlow

Length of output: 44


🏁 Script executed:

# Check if residue_index is ever added to enc_data after make_gvp_encoder_data
rg -n 'enc_data.residue_index\|enc_data.num_residues' src/

Repository: diff-use/WaterFlow

Length of output: 44


🏁 Script executed:

# Check the actual encoder instantiation in the GVPEncoder (wrapper) class
sed -n '430,480p' src/gvp_encoder.py

Repository: diff-use/WaterFlow

Length of output: 1600


🏁 Script executed:

# Verify pool_residue is never set to True anywhere
rg -n 'pool_residue.*=.*True\|pool_residue=True' src/

Repository: diff-use/WaterFlow

Length of output: 44


🏁 Script executed:

# Check what the default pool_residue value is and in what context
rg -n 'pool_residue' src/ -B 1 -A 1

Repository: diff-use/WaterFlow

Length of output: 930


🏁 Script executed:

# Now let's look at the lines around 1070 and 1199-1224 to understand the full context
sed -n '1060,1080p' src/dataset.py

Repository: diff-use/WaterFlow

Length of output: 1367


🏁 Script executed:

# Check lines 1199-1224 as mentioned in "Also applies to"
sed -n '1190,1230p' src/dataset.py

Repository: diff-use/WaterFlow

Length of output: 1726


🏁 Script executed:

# Check if residue_index is used anywhere else with negative value handling
rg -n '\-1.*residue|residue.*\-1' src/

Repository: diff-use/WaterFlow

Length of output: 164


🏁 Script executed:

# Check if there are any tests or code that actually enable pool_residue
git log --all --oneline -S 'pool_residue=True' 2>/dev/null || echo "No git log available"
rg -r 'pool_residue\s*=\s*True' src/ tests/ 2>/dev/null || echo "Not found"

Repository: diff-use/WaterFlow

Length of output: 148


🏁 Script executed:

# Check what happens with max() of tensor with -1 values
python3 << 'EOF'
import torch

# Simulate residue_index with -1 values for ligands
res_idx = torch.tensor([0, 1, 2, 1, 0, -1, -1])
max_val = res_idx.max().item()
num_residues = int(max_val + 1)
print(f"residue_index: {res_idx}")
print(f"max value: {max_val}")
print(f"num_residues: {num_residues}")

# Edge case: all ligands
res_idx_ligand_only = torch.tensor([-1, -1, -1])
max_val2 = res_idx_ligand_only.max().item()
num_residues2 = int(max_val2 + 1)
print(f"\nAll ligands case:")
print(f"residue_index: {res_idx_ligand_only}")
print(f"max value: {max_val2}")
print(f"num_residues: {num_residues2}")
EOF

Repository: diff-use/WaterFlow

Length of output: 179


🏁 Script executed:

# Check if residue_index is used anywhere for indexing operations besides scatter
rg -n 'residue_index\[' src/

Repository: diff-use/WaterFlow

Length of output: 44


🏁 Script executed:

# Look for any masking of is_ligand before residue operations
rg -n 'is_ligand' src/ -A 3 -B 3 | grep -A 5 -B 5 residue

Repository: diff-use/WaterFlow

Length of output: 712


🏁 Script executed:

# Check all residue_index usages more comprehensively
rg -n 'residue_index' src/ tests/ --type py -C 2

Repository: diff-use/WaterFlow

Length of output: 6025


🏁 Script executed:

# Specifically check if residue_index is ever used as an index (array indexing)
rg -n '\[.*residue_index\]|\[residue_index' src/

Repository: diff-use/WaterFlow

Length of output: 44


🏁 Script executed:

# Check if there are any comments or TODOs about handling -1 in residue_index
rg -n 'sentinel|-1.*residue|residue.*-1' src/ tests/

Repository: diff-use/WaterFlow

Length of output: 444


🏁 Script executed:

# Check the test at lines 2395-2398 that checks residue indices
sed -n '2390,2410p' src/dataset.py

Repository: diff-use/WaterFlow

Length of output: 44


🏁 Script executed:

# Check if there's any edge case test with ligands
rg -n 'ligand' tests/test_dataset.py | head -20

Repository: diff-use/WaterFlow

Length of output: 1444


🏁 Script executed:

# Final check: are there any other scatter operations that might use residue_index indirectly?
rg -n 'scatter_' src/ -B 2 -A 2

Repository: diff-use/WaterFlow

Length of output: 3766


Verify that _pool_by_residue masks out ligand atoms before scatter operations.

This change introduces -1 sentinel values in residue_index for ligand atoms, but the scatter operations in _pool_by_residue (lines 276–282 in gvp_encoder.py) cannot handle negative indices. Although pool_residue is currently disabled (pool_residue=False by default), if this feature is enabled in the future, the code will fail when calling scatter_mean, scatter_add, or scatter_max with a mixed residue_index containing -1 values.

Add a mask using is_ligand before the scatter operations:

mask = ~is_ligand  # or retrieve from data
atom_embed_masked = atom_embed[mask]
residue_index_masked = residue_index[mask]
res_embed = scatter_mean(atom_embed_masked, residue_index_masked, dim=0, dim_size=num_residues)

Also verify the same masking is applied at any other scatter/pooling operations keyed by residue_index if pool_residue becomes enabled.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/dataset.py` around lines 1070 - 1078, _pool_by_residue will fail if
residue_index contains -1 for ligands because scatter operations (scatter_mean,
scatter_add, scatter_max) cannot accept negative indices; before any scatter
keyed by residue_index (in _pool_by_residue and any other pooling guarded by
pool_residue), filter out ligand atoms using the is_ligand mask (e.g., mask =
~is_ligand) and pass atom embeddings and residue_index masked arrays
(atom_embed[mask], residue_index[mask]) to scatter with dim_size=num_residues;
apply the same masking wherever residue_index is used for pooling to avoid
negative-index errors.

Comment thread src/encoder_base.py
Comment on lines +182 to +190
# Learnable projection for ligand atoms (element one-hot -> embedding space).
# Ligands have no ESM/SLAE embeddings; this replaces zero-padding with a
# learned representation parameterized only by element type.
# Lazily initialized on first forward when embedding_dim becomes known.
self.ligand_embed: nn.Linear | None = (
nn.Linear(NODE_FEATURE_DIM, embedding_dim, bias=False)
if embedding_dim is not None
else None
)


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In PyTorch, if an nn.Module creates a new submodule or Parameter inside forward() after the optimizer or DDP/FSDP wrapper has already been constructed, will that parameter be optimized and replicated?

💡 Result:

No, parameters or submodules created inside forward after the optimizer and DDP/FSDP wrapper have been constructed will not be optimized or replicated. PyTorch optimizers capture the parameters at construction time via model.parameters, which yields only parameters registered at that moment (those assigned as nn.Parameter attributes before wrapper/optimizer creation). New parameters registered dynamically in forward are not included in this iterator, so the optimizer does not track them and step skips them. DDP and FSDP similarly build their internal state (buckets, hooks, sharding) based on parameters present at construction; new parameters added later are not hooked for gradient synchronization or sharding/replication. Official docs confirm optimizers use param_groups from the initial parameters call, with state keyed by parameter IDs captured then. Forum discussions and GitHub issues (e.g., dynamic parametrizations in FSDP requiring manual ignored_params + DDP) show post-construction parameters need explicit handling (e.g., optimizer.add_param_group), but even then replication may fail due to missed hooks. Creating parameters in forward is discouraged; use init instead.



🏁 Script executed:

# Check the file exists and inspect the relevant lines
if [ -f "src/encoder_base.py" ]; then
  echo "=== Lines 182-190 (declaration) ==="
  sed -n '182,190p' src/encoder_base.py
  echo ""
  echo "=== Lines 237-242 (instantiation) ==="
  sed -n '237,242p' src/encoder_base.py
  echo ""
  echo "=== Context around line 237 to understand forward() method ==="
  sed -n '220,250p' src/encoder_base.py
else
  echo "File not found"
fi

Repository: diff-use/WaterFlow

Length of output: 2363


Move ligand_embed initialization to __init__.

The module is created inside forward() after the optimizer has already captured model parameters. Since PyTorch optimizers iterate over model.parameters() at construction time, any submodules registered later in forward() are not included in param_groups and will not be updated. Similarly, DDP and FSDP will not replicate or synchronize these parameters. Move the nn.Linear initialization to __init__, or call optimizer.add_param_group() with the new parameters before training (though full DDP replication may still fail without explicit hooks).

Affects lines 182–190 (declaration) and 237–242 (forward instantiation).
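
A tiny self-contained check of the underlying failure mode (this does not touch the project code):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# A module created after the optimizer exists is not tracked by it...
late = nn.Linear(6, 4, bias=False)
tracked = {id(p) for g in opt.param_groups for p in g["params"]}
assert not any(id(p) in tracked for p in late.parameters())

# ...unless its parameters are added explicitly; note this still does not
# retrofit DDP/FSDP gradient hooks or sharding.
opt.add_param_group({"params": list(late.parameters())})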

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/encoder_base.py` around lines 182 - 190, The ligand_embed Linear is being
created lazily in forward(), so its parameters are not registered when
optimizers/DDP/FSDP capture model parameters; move the nn.Linear initialization
into __init__ by creating self.ligand_embed there (using the given
embedding_dim) and remove the forward-time instantiation, or if embedding_dim
truly isn’t known at construction, add an explicit
initialize_ligand_embed(embedding_dim) method that constructs and registers
self.ligand_embed before the optimizer is created (call it in setup code),
ensuring parameters are registered with model.parameters(); as a last-resort
alternative, if you cannot construct before optimizer, call
optimizer.add_param_group(...) with the new parameters immediately after
creation (note this still may not work with DDP/FSDP), and update references to
ligand_embed in forward() to assume it already exists.



Development

Successfully merging this pull request may close these issues.

Ligand Encoding

2 participants