3 changes: 2 additions & 1 deletion pyproject.toml
@@ -11,10 +11,11 @@ dependencies = [
"e3nn",
"esm",
"biotite",
"pymol-open-source",
"pymol-open-source-whl>=3.1.0.4",
"scipy",
"pandas",
"numpy",
"rdkit",
"matplotlib",
Comment on lines +14 to 19
Copilot AI Mar 30, 2026

rdkit is added as a required runtime dependency, but it isn’t referenced anywhere in the repo (src/scripts/tests). Since RDKit significantly increases install size and can complicate builds on some platforms, consider removing it until it’s actually needed, or making it an optional extra for ligand-specific workflows.
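If RDKit is kept for upcoming ligand work, one way to act on this suggestion is a PEP 621 optional extra (a sketch against the pyproject.toml above; the extra name `ligand` is hypothetical, not from the repo):

```toml
[project.optional-dependencies]
# Hypothetical extra: installed on demand via `pip install .[ligand]`
ligand = ["rdkit"]
```

This keeps the default install lean while letting ligand-specific workflows pull RDKit in explicitly.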
"pillow",
"loguru",
51 changes: 44 additions & 7 deletions src/dataset.py
@@ -47,16 +47,17 @@ def element_onehot(symbols: list[str]) -> Tensor:

def parse_asu_with_biotite(
path: str,
) -> tuple[bts.AtomArray, bts.AtomArray]:
) -> tuple[bts.AtomArray, bts.AtomArray, bts.AtomArray]:
"""
Comment on lines 48 to 51
Copilot AI Mar 30, 2026

parse_asu_with_biotite() now returns a 3-tuple, which is a breaking API change. There are existing call sites (e.g., embedding generation scripts) that still unpack only 2 values and will crash at runtime unless updated to accept/ignore the new ligand_atoms return value.
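The breakage and its minimal fix can be sketched as follows (the stub below stands in for the real parser; it is illustration only, not repo code):

```python
# Stub standing in for the real parse_asu_with_biotite (illustration only).
def parse_asu_with_biotite(path):
    return "protein_atoms", "water_atoms", "ligand_atoms"

# Old 2-value call sites now raise "too many values to unpack".
try:
    protein, water = parse_asu_with_biotite("1abc.pdb")
except ValueError as exc:
    print(type(exc).__name__)  # ValueError

# Minimal fix: accept and ignore the new ligand return value.
protein, water, _ = parse_asu_with_biotite("1abc.pdb")
print(protein, water)  # protein_atoms water_atoms
```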
Parse PDB file and extract protein and water atoms.
Parse PDB file and extract protein, water, and ligand atoms.

Args:
path: Path to PDB file

Returns:
Tuple of (protein_atoms, water_atoms) as biotite AtomArrays.
Hydrogen atoms are excluded.
Tuple of (protein_atoms, water_atoms, ligand_atoms) as biotite AtomArrays.
Hydrogen atoms are excluded. ligand_atoms contains all non-protein, non-water
HETATM records (small molecules, ions, cofactors).

Notes:
- model=1: Uses first model in PDB (standard for X-ray structures)
@@ -71,11 +71,13 @@ def parse_asu_with_biotite(

protein_mask = bts.filter_amino_acids(atoms)
water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT")
ligand_mask = ~protein_mask & ~water_mask

protein_atoms = atoms[protein_mask]
water_atoms = atoms[water_mask]
ligand_atoms = atoms[ligand_mask]
Comment on lines 73 to +79
Copilot AI Mar 30, 2026

ligand_mask = ~protein_mask & ~water_mask will classify any non-amino-acid, non-water atoms as ligands, including non-protein polymers recorded as ATOM (e.g., DNA/RNA) and other non-hetero records. This contradicts the docstring claim of “non-water HETATM records” and can silently mix non-ligand atoms into the ligand set. Consider restricting ligands to atoms.hetero (and excluding HOH/WAT), or otherwise explicitly encoding the intended selection criteria.

return protein_atoms, water_atoms
return protein_atoms, water_atoms, ligand_atoms
Comment on lines 73 to +81

⚠️ Potential issue | 🟠 Major

Use atoms.hetero flag to properly isolate HETATM ligands.

The docstring explicitly states that ligand_atoms should contain "non-protein, non-water HETATM records", but the current logic ~protein_mask & ~water_mask only uses negation and will also include ATOM records like nucleic acids or other polymers that are neither proteins nor water. This silently breaks the documented contract and changes node counts for mixed complexes.

Biotite's PDB parser preserves the ATOM vs HETATM distinction via the atoms.hetero flag, so the fix is straightforward:

Suggested fix
     protein_mask = bts.filter_amino_acids(atoms)
     water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT")
-    ligand_mask = ~protein_mask & ~water_mask
+    hetero_mask = atoms.hetero.astype(bool)
+    ligand_mask = hetero_mask & ~protein_mask & ~water_mask
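The difference between the two mask formulations can be sketched with toy NumPy stand-ins for the biotite annotations (the atom records below are invented, not from any real PDB entry):

```python
import numpy as np

# Toy stand-ins for the biotite annotations (invented atoms, not a real PDB).
res_name = np.array(["ALA", "GLY", "HOH", "DA", "HEM"])
is_amino = np.array([True, True, False, False, False])   # filter_amino_acids result
hetero   = np.array([False, False, True, False, True])    # True for HETATM records

water_mask = (res_name == "HOH") | (res_name == "WAT")

# Original PR logic: negation alone also sweeps in the DNA ATOM record ("DA").
loose_ligand = ~is_amino & ~water_mask

# Suggested fix: additionally require the HETATM flag.
strict_ligand = hetero & ~is_amino & ~water_mask

print(loose_ligand.tolist())   # [False, False, False, True, True]
print(strict_ligand.tolist())  # [False, False, False, False, True]
```

Only the HEM cofactor survives the stricter mask; the nucleic-acid atom is correctly excluded.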



def get_crystal_contacts_pymol(
@@ -665,6 +668,7 @@ def __init__(
base_pdb_dir: str = "/sb/wankowicz_lab/data/srivasv/pdb_redo_data",
cutoff: float = 8.0,
include_mates: bool = True,
include_ligands: bool = False,
geometry_cache_name: str = "geometry",
preprocess: bool = True,
duplicate_single_sample: int = 1,
@@ -691,8 +695,12 @@ def __init__(
base_pdb_dir: Base directory containing PDB subdirectories
cutoff: Distance cutoff for PP edges and crystal contacts (Angstroms)
include_mates: If True, include symmetry mate atoms as protein nodes
include_ligands: If True, include non-water het atoms (ligands, ions,
cofactors) as protein nodes. They are appended after
protein (and mate) atoms with a boolean is_ligand mask.
geometry_cache_name: Base name for geometry cache directory. When
include_mates=True, "_mates" is appended automatically.
When include_ligands=True, "_lig" is appended.
Default is "geometry", resulting in "geometry/" or
"geometry_mates/" subdirectories.
preprocess: If True, run preprocessing on missing cached files
@@ -720,8 +728,10 @@ def __init__(
"""

self.cache_dir = Path(processed_dir)
# Directory-based separation: geometry/ vs geometry_mates/
# Directory-based separation: geometry/ vs geometry_mates/ vs geometry_lig/ etc.
cache_suffix = "_mates" if include_mates else ""
if include_ligands:
cache_suffix += "_lig"
self.geometry_dir = self.cache_dir / f"{geometry_cache_name}{cache_suffix}"
self.base_pdb_dir = Path(base_pdb_dir)
self.cutoff = cutoff
@@ -731,6 +741,7 @@
else:
self.embedding_dir = None
self.include_mates = include_mates
self.include_ligands = include_ligands
self.duplicate_single_sample = duplicate_single_sample

self.max_com_dist = max_com_dist
@@ -867,7 +878,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
"""
pdb_path = str(entry["pdb_path"])

protein_atoms, water_atoms = parse_asu_with_biotite(pdb_path)
protein_atoms, water_atoms, ligand_atoms = parse_asu_with_biotite(pdb_path)

# check inter-chain interactions for multi-chain proteins
chain_valid, chain_reason, _ = check_chain_interactions(
@@ -1046,6 +1057,26 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
final_protein_x = protein_x
final_protein_res_idx = protein_res_idx

# Append ligand atoms after protein (and mate) atoms when enabled.
# is_ligand mask marks which protein-type nodes are ligand atoms.
# Ligands always go last so num_asu_protein and mate counts are unaffected,
# preserving ESM/SLAE embedding alignment via _pad_atom_embeddings_for_mates.
if self.include_ligands and ligand_atoms:
ligand_pos = torch.tensor(ligand_atoms.coord, dtype=torch.float32) - center
ligand_elements = [str(e).upper() for e in ligand_atoms.element]
ligand_x = element_onehot(ligand_elements)
final_protein_pos = torch.cat([final_protein_pos, ligand_pos], dim=0)
final_protein_x = torch.cat([final_protein_x, ligand_x], dim=0)
# Ligand atoms get residue_index = -1 (sentinel; no residue embedding)
ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long)
Comment on lines +1070 to +1071
Copilot AI Mar 30, 2026

Assigning residue_index = -1 for ligand atoms will break any downstream code that uses residue_index for scatter/gather (e.g., residue pooling via torch_scatter), since those ops generally require non-negative indices. Either mask out ligand atoms wherever residue pooling/indexing is used, or map ligands to a valid non-negative index (and update num_residues accordingly) to keep the tensor safe for scatter-based pooling.

Suggested change
-            # Ligand atoms get residue_index = -1 (sentinel; no residue embedding)
-            ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long)
+            # Assign ligands to valid residue indices after existing residues
+            max_res_idx = (
+                final_protein_res_idx.max().item()
+                if final_protein_res_idx.numel() > 0
+                else -1
+            )
+            ligand_res_idx = torch.full(
+                (len(ligand_atoms),),
+                max_res_idx + 1,
+                dtype=torch.long,
+            )
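The alternative in the suggested change — mapping ligands to the next valid residue index rather than -1 — can be sketched with NumPy (names mirror the diff, but the values are invented):

```python
import numpy as np

# Toy residue indices for the existing protein atoms (values invented).
final_protein_res_idx = np.array([0, 0, 1, 1])
n_ligand_atoms = 2

# Next free residue index; the -1 fallback covers an empty protein array.
max_res_idx = int(final_protein_res_idx.max()) if final_protein_res_idx.size else -1
ligand_res_idx = np.full(n_ligand_atoms, max_res_idx + 1)

combined = np.concatenate([final_protein_res_idx, ligand_res_idx])
print(combined.tolist())  # [0, 0, 1, 1, 2, 2]
```

All ligand atoms share one synthetic residue index, so every entry stays non-negative and safe for scatter-style pooling.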
final_protein_res_idx = torch.cat(
[final_protein_res_idx, ligand_res_idx], dim=0
)
is_ligand = torch.zeros(final_protein_pos.size(0), dtype=torch.bool)
is_ligand[-len(ligand_atoms) :] = True
else:
is_ligand = torch.zeros(final_protein_pos.size(0), dtype=torch.bool)
Comment on lines +1070 to +1078

⚠️ Potential issue | 🔴 Critical

Verify that _pool_by_residue masks out ligand atoms before scatter operations.

This change introduces -1 sentinel values in residue_index for ligand atoms, but the scatter operations in _pool_by_residue (lines 276–282 in gvp_encoder.py) cannot handle negative indices. Although pool_residue is currently disabled (pool_residue=False by default), if this feature is enabled in the future, the code will fail when calling scatter_mean, scatter_add, or scatter_max with a mixed residue_index containing -1 values.

Add a mask using is_ligand before the scatter operations:

mask = ~is_ligand  # or retrieve from data
atom_embed_masked = atom_embed[mask]
residue_index_masked = residue_index[mask]
res_embed = scatter_mean(atom_embed_masked, residue_index_masked, dim=0, dim_size=num_residues)

Also verify the same masking is applied at any other scatter/pooling operations keyed by residue_index if pool_residue becomes enabled.
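The masking step the reviewer asks for can be sketched with a NumPy stand-in for scatter_mean (shapes and names are illustrative, not the project's API):

```python
import numpy as np

# Toy atom features and residue indices; -1 marks ligand atoms (invented data).
atom_embed = np.arange(10, dtype=float).reshape(5, 2)   # 5 atoms, 2 features
residue_index = np.array([0, 0, 1, -1, -1])
is_ligand = residue_index < 0

# Mask ligands out BEFORE any scatter keyed by residue_index.
mask = ~is_ligand
idx = residue_index[mask]
emb = atom_embed[mask]

# Mean-pool per residue: scatter-add then divide by per-residue counts.
num_residues = int(idx.max()) + 1
sums = np.zeros((num_residues, 2))
np.add.at(sums, idx, emb)                               # scatter-add analogue
counts = np.bincount(idx, minlength=num_residues)[:, None]
res_embed = sums / counts

print(res_embed.tolist())  # [[1.0, 2.0], [4.0, 5.0]]
```

Without the mask, the -1 entries would index the last row of the output buffer (NumPy) or raise (torch_scatter), silently or loudly corrupting the pooled embeddings.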


# Compute PP edges and features
if final_protein_pos.size(0) > 0:
pp_edge_index = radius_graph(final_protein_pos, r=self.cutoff, loop=False)
@@ -1071,6 +1102,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
"protein_pos": final_protein_pos,
"protein_x": final_protein_x,
"protein_res_idx": final_protein_res_idx,
"is_ligand": is_ligand,
"water_pos": water_pos,
"water_x": water_x,
# PP topology and features (precomputed)
@@ -1164,6 +1196,10 @@ def __getitem__(self, idx: int) -> HeteroData:
protein_pos = cached["protein_pos"]
protein_x = cached["protein_x"]
protein_res_idx = cached["protein_res_idx"]
# is_ligand may be absent in older caches (backward compat: all False)
is_ligand = cached.get(
"is_ligand", torch.zeros(protein_pos.size(0), dtype=torch.bool)
)
pp_edge_index = cached["pp_edge_index"]
pp_edge_unit_vectors = cached["pp_edge_unit_vectors"]
pp_edge_rbf = cached["pp_edge_rbf"]
@@ -1185,6 +1221,7 @@ def __getitem__(self, idx: int) -> HeteroData:
data["protein"].x = protein_x
data["protein"].pos = protein_pos
data["protein"].residue_index = protein_res_idx
data["protein"].is_ligand = is_ligand
data["protein"].num_nodes = protein_pos.size(0)
data["protein"].num_residues = num_residues
data["protein"].num_protein_residues = num_protein_residues
30 changes: 27 additions & 3 deletions src/encoder_base.py
@@ -15,6 +15,8 @@
import torch
import torch.nn as nn

from src.constants import NODE_FEATURE_DIM


if TYPE_CHECKING:
from torch_geometric.data import HeteroData
@@ -177,6 +179,15 @@ def __init__(
self._embedding_dim: int | None = embedding_dim
self._embedding_key = embedding_key
self._encoder_type = encoder_type
# Learnable projection for ligand atoms (element one-hot -> embedding space).
# Ligands have no ESM/SLAE embeddings; this replaces zero-padding with a
# learned representation parameterized only by element type.
# Lazily initialized on first forward when embedding_dim becomes known.
self.ligand_embed: nn.Linear | None = (
nn.Linear(NODE_FEATURE_DIM, embedding_dim, bias=False)
if embedding_dim is not None
else None
)
Comment on lines +182 to +190

⚠️ Potential issue | 🟠 Major

Move ligand_embed initialization to __init__.

The module is created inside forward() after the optimizer has already captured model parameters. Since PyTorch optimizers iterate over model.parameters() at construction time, any submodules registered later in forward() are not included in param_groups and will not be updated. Similarly, DDP and FSDP will not replicate or synchronize these parameters. Move the nn.Linear initialization to __init__, or call optimizer.add_param_group() with the new parameters before training (though full DDP replication may still fail without explicit hooks).

Affects lines 182–190 (declaration) and 237–242 (forward instantiation).
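The failure mode — parameters created after the optimizer snapshot — can be illustrated without PyTorch (the classes below are toys standing in for nn.Module and torch.optim, not real torch APIs):

```python
# Toy illustration (plain Python, not torch): an optimizer snapshots the
# parameter list at construction, so anything a module registers later in
# forward() is invisible to it.
class Model:
    def __init__(self):
        self.params = {"weight": 1.0}          # registered up front

    def forward(self):
        # Lazily registered parameter, mimicking the ligand_embed pattern.
        self.params.setdefault("ligand_embed", 0.5)

class Optimizer:
    def __init__(self, params):
        self.tracked = list(params)            # snapshot, like torch.optim

model = Model()
opt = Optimizer(model.params)                  # captures only "weight"
model.forward()                                # registers ligand_embed too late

print(opt.tracked)                             # ['weight']
print(sorted(model.params))                    # ['ligand_embed', 'weight']
```

The model ends up holding a parameter the optimizer never saw, which is exactly why the review asks for construction-time registration.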


@property
def output_dims(self) -> tuple[int, int]:
@@ -203,13 +214,15 @@ def forward(
"""
Read cached embeddings and return (s, V, None).

On first call, infers embedding dimension from the data.
On first call, infers embedding dimension from the data. If ligand atoms
are present (data['protein'].is_ligand), their zero-padded embedding rows
are replaced with a learned projection from element one-hot features.

Args:
data: HeteroData with cached embeddings in data['protein']

Returns:
s: (N, embedding_dim) — raw embeddings
s: (N, embedding_dim) — embeddings (ligand rows via learned projection)
V: (N, 0, 3) — empty vector features
pp_edge_attr: None — cached embedding encoders don't process edges
"""
@@ -221,9 +234,20 @@ def forward(

embeddings = data["protein"][self._embedding_key]

# Infer dimension on first forward
# Infer dimension on first forward and lazily init ligand head
if self._embedding_dim is None:
self._embedding_dim = embeddings.size(-1)
self.ligand_embed = nn.Linear(
NODE_FEATURE_DIM, self._embedding_dim, bias=False
).to(embeddings.device)

# Replace zero-padded ligand rows with learned element projection
lig_mask = getattr(data["protein"], "is_ligand", None)
if lig_mask is not None and lig_mask.any():
embeddings = embeddings.clone()
embeddings[lig_mask] = self.ligand_embed(
data["protein"].x[lig_mask].to(embeddings.device)
)

V = embeddings.new_empty(embeddings.size(0), 0, 3)
return embeddings, V, None
6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -67,6 +67,12 @@ def pdb_1deu():
return _resolve_pdb_path("1deu")


@pytest.fixture
def pdb_4h0b():
"""4h0b - has non-water ligand HETATMs for ligand support tests."""
return _resolve_pdb_path("4h0b")


# ============== Shared encoder fixtures ==============

