-
Notifications
You must be signed in to change notification settings - Fork 0
Adding ligands to dataset processing #78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
38c2e0b
b8568a5
7cf6de1
1529943
bebf742
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -47,16 +47,17 @@ def element_onehot(symbols: list[str]) -> Tensor: | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def parse_asu_with_biotite( | ||||||||||||||||||||||||||||
| path: str, | ||||||||||||||||||||||||||||
| ) -> tuple[bts.AtomArray, bts.AtomArray]: | ||||||||||||||||||||||||||||
| ) -> tuple[bts.AtomArray, bts.AtomArray, bts.AtomArray]: | ||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||
|
Comment on lines
48
to
51
|
||||||||||||||||||||||||||||
| Parse PDB file and extract protein and water atoms. | ||||||||||||||||||||||||||||
| Parse PDB file and extract protein, water, and ligand atoms. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||
| path: Path to PDB file | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||
| Tuple of (protein_atoms, water_atoms) as biotite AtomArrays. | ||||||||||||||||||||||||||||
| Hydrogen atoms are excluded. | ||||||||||||||||||||||||||||
| Tuple of (protein_atoms, water_atoms, ligand_atoms) as biotite AtomArrays. | ||||||||||||||||||||||||||||
| Hydrogen atoms are excluded. ligand_atoms contains all non-protein, non-water | ||||||||||||||||||||||||||||
| HETATM records (small molecules, ions, cofactors). | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Notes: | ||||||||||||||||||||||||||||
| - model=1: Uses first model in PDB (standard for X-ray structures) | ||||||||||||||||||||||||||||
|
|
@@ -71,11 +72,13 @@ def parse_asu_with_biotite( | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| protein_mask = bts.filter_amino_acids(atoms) | ||||||||||||||||||||||||||||
| water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT") | ||||||||||||||||||||||||||||
| ligand_mask = ~protein_mask & ~water_mask | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| protein_atoms = atoms[protein_mask] | ||||||||||||||||||||||||||||
| water_atoms = atoms[water_mask] | ||||||||||||||||||||||||||||
| ligand_atoms = atoms[ligand_mask] | ||||||||||||||||||||||||||||
|
Comment on lines
73
to
+79
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| return protein_atoms, water_atoms | ||||||||||||||||||||||||||||
| return protein_atoms, water_atoms, ligand_atoms | ||||||||||||||||||||||||||||
|
Comment on lines
73
to
+81
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🌐 Web query:
💡 Result: Yes, Biotite's PDB parser via get_structure preserves the ATOM vs HETATM distinction using the AtomArray.hetero annotation, which is set to True for HETATM records and False for ATOM records. This is evidenced by the Atom constructor accepting a hetero parameter (result 5, result 14), and examples where hetero=True is explicitly set for added HETATM atoms, with the parser maintaining this in loaded structures (results 1,4,7,8). Additionally, Biotite provides struc.filter_solvent(atoms) to create a water_mask (results 18,19,22), typically identifying water molecules (e.g., res_name "HOH"). Thus, atoms.hetero & ~water_mask correctly isolates non-water HETATM atoms (hetero=True and not solvent), which are typically ligands or other hetero compounds. Citations:
🏁 Script executed: # First, check the actual code in src/dataset.py around lines 73-81
cat -n src/dataset.py | sed -n '48,85p'Repository: diff-use/WaterFlow Length of output: 1662 🏁 Script executed: # Also check the function definition and docstring to understand the intended behavior
cat -n src/dataset.py | sed -n '48,100p'Repository: diff-use/WaterFlow Length of output: 2332 🏁 Script executed: # Check if there's any documentation about what ligands should be
grep -n "ligand" src/dataset.py | head -20Repository: diff-use/WaterFlow Length of output: 1458 Use The docstring explicitly states that Biotite's PDB parser preserves the ATOM vs HETATM distinction via the Suggested fix protein_mask = bts.filter_amino_acids(atoms)
water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT")
- ligand_mask = ~protein_mask & ~water_mask
+ hetero_mask = atoms.hetero.astype(bool)
+ ligand_mask = hetero_mask & ~protein_mask & ~water_mask🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def get_crystal_contacts_pymol( | ||||||||||||||||||||||||||||
|
|
@@ -665,6 +668,7 @@ def __init__( | |||||||||||||||||||||||||||
| base_pdb_dir: str = "/sb/wankowicz_lab/data/srivasv/pdb_redo_data", | ||||||||||||||||||||||||||||
| cutoff: float = 8.0, | ||||||||||||||||||||||||||||
| include_mates: bool = True, | ||||||||||||||||||||||||||||
| include_ligands: bool = False, | ||||||||||||||||||||||||||||
| geometry_cache_name: str = "geometry", | ||||||||||||||||||||||||||||
| preprocess: bool = True, | ||||||||||||||||||||||||||||
| duplicate_single_sample: int = 1, | ||||||||||||||||||||||||||||
|
|
@@ -691,8 +695,12 @@ def __init__( | |||||||||||||||||||||||||||
| base_pdb_dir: Base directory containing PDB subdirectories | ||||||||||||||||||||||||||||
| cutoff: Distance cutoff for PP edges and crystal contacts (Angstroms) | ||||||||||||||||||||||||||||
| include_mates: If True, include symmetry mate atoms as protein nodes | ||||||||||||||||||||||||||||
| include_ligands: If True, include non-water het atoms (ligands, ions, | ||||||||||||||||||||||||||||
| cofactors) as protein nodes. They are appended after | ||||||||||||||||||||||||||||
| protein (and mate) atoms with a boolean is_ligand mask. | ||||||||||||||||||||||||||||
| geometry_cache_name: Base name for geometry cache directory. When | ||||||||||||||||||||||||||||
| include_mates=True, "_mates" is appended automatically. | ||||||||||||||||||||||||||||
| When include_ligands=True, "_lig" is appended. | ||||||||||||||||||||||||||||
| Default is "geometry", resulting in "geometry/" or | ||||||||||||||||||||||||||||
| "geometry_mates/" subdirectories. | ||||||||||||||||||||||||||||
| preprocess: If True, run preprocessing on missing cached files | ||||||||||||||||||||||||||||
|
|
@@ -720,8 +728,10 @@ def __init__( | |||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| self.cache_dir = Path(processed_dir) | ||||||||||||||||||||||||||||
| # Directory-based separation: geometry/ vs geometry_mates/ | ||||||||||||||||||||||||||||
| # Directory-based separation: geometry/ vs geometry_mates/ vs geometry_lig/ etc. | ||||||||||||||||||||||||||||
| cache_suffix = "_mates" if include_mates else "" | ||||||||||||||||||||||||||||
| if include_ligands: | ||||||||||||||||||||||||||||
| cache_suffix += "_lig" | ||||||||||||||||||||||||||||
| self.geometry_dir = self.cache_dir / f"{geometry_cache_name}{cache_suffix}" | ||||||||||||||||||||||||||||
| self.base_pdb_dir = Path(base_pdb_dir) | ||||||||||||||||||||||||||||
| self.cutoff = cutoff | ||||||||||||||||||||||||||||
|
|
@@ -731,6 +741,7 @@ def __init__( | |||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| self.embedding_dir = None | ||||||||||||||||||||||||||||
| self.include_mates = include_mates | ||||||||||||||||||||||||||||
| self.include_ligands = include_ligands | ||||||||||||||||||||||||||||
| self.duplicate_single_sample = duplicate_single_sample | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| self.max_com_dist = max_com_dist | ||||||||||||||||||||||||||||
|
|
@@ -867,7 +878,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path): | |||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||
| pdb_path = str(entry["pdb_path"]) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| protein_atoms, water_atoms = parse_asu_with_biotite(pdb_path) | ||||||||||||||||||||||||||||
| protein_atoms, water_atoms, ligand_atoms = parse_asu_with_biotite(pdb_path) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # check inter-chain interactions for multi-chain proteins | ||||||||||||||||||||||||||||
| chain_valid, chain_reason, _ = check_chain_interactions( | ||||||||||||||||||||||||||||
|
|
@@ -1046,6 +1057,26 @@ def _preprocess_one(self, entry: dict, cache_path: Path): | |||||||||||||||||||||||||||
| final_protein_x = protein_x | ||||||||||||||||||||||||||||
| final_protein_res_idx = protein_res_idx | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Append ligand atoms after protein (and mate) atoms when enabled. | ||||||||||||||||||||||||||||
| # is_ligand mask marks which protein-type nodes are ligand atoms. | ||||||||||||||||||||||||||||
| # Ligands always go last so num_asu_protein and mate counts are unaffected, | ||||||||||||||||||||||||||||
| # preserving ESM/SLAE embedding alignment via _pad_atom_embeddings_for_mates. | ||||||||||||||||||||||||||||
| if self.include_ligands and ligand_atoms: | ||||||||||||||||||||||||||||
| ligand_pos = torch.tensor(ligand_atoms.coord, dtype=torch.float32) - center | ||||||||||||||||||||||||||||
| ligand_elements = [str(e).upper() for e in ligand_atoms.element] | ||||||||||||||||||||||||||||
| ligand_x = element_onehot(ligand_elements) | ||||||||||||||||||||||||||||
| final_protein_pos = torch.cat([final_protein_pos, ligand_pos], dim=0) | ||||||||||||||||||||||||||||
| final_protein_x = torch.cat([final_protein_x, ligand_x], dim=0) | ||||||||||||||||||||||||||||
| # Ligand atoms get residue_index = -1 (sentinel; no residue embedding) | ||||||||||||||||||||||||||||
| ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long) | ||||||||||||||||||||||||||||
|
Comment on lines
+1070
to
+1071
|
||||||||||||||||||||||||||||
| # Ligand atoms get residue_index = -1 (sentinel; no residue embedding) | |
| ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long) | |
| # Assign ligands to valid residue indices after existing residues | |
| max_res_idx = ( | |
| final_protein_res_idx.max().item() | |
| if final_protein_res_idx.numel() > 0 | |
| else -1 | |
| ) | |
| ligand_res_idx = torch.full( | |
| (len(ligand_atoms),), | |
| max_res_idx + 1, | |
| dtype=torch.long, | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Find every downstream read of residue_index and confirm negative values are
# masked before indexing, scattering, or pooling.
rg -n --type=py -C3 '\bresidue_index\b'Repository: diff-use/WaterFlow
Length of output: 7577
🏁 Script executed:
# Check how _pool_by_residue is called and whether ligands are masked
rg -n -A 20 'def forward' src/gvp_encoder.py | head -100Repository: diff-use/WaterFlow
Length of output: 1902
🏁 Script executed:
# Check if is_ligand is used anywhere to mask before pooling
rg -n 'is_ligand' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Look at the exact lines 1199-1224 in dataset.py mentioned in "Also applies to"
sed -n '1199,1224p' src/dataset.pyRepository: diff-use/WaterFlow
Length of output: 1126
🏁 Script executed:
# Check PyTorch scatter behavior with negative indices - look for any defensive code
rg -n -B 5 -A 5 'scatter_mean\|scatter_add\|scatter_max' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check make_gvp_encoder_data function - does it exclude ligands?
rg -n -A 30 'def make_gvp_encoder_data' src/Repository: diff-use/WaterFlow
Length of output: 1662
🏁 Script executed:
# Look at the actual _pool_by_residue implementation more carefully
sed -n '260,290p' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 1208
🏁 Script executed:
# Check if there's any validation or masking of negative indices in scatter operations
rg -n -B 10 -A 10 'def _pool_by_residue' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 862
🏁 Script executed:
# Test what happens with negative indices in scatter operations - check if there's error handling
python3 << 'EOF'
import torch
from torch_scatter import scatter_mean
# Test if scatter_mean handles -1 indices
try:
atom_embed = torch.randn(5, 16)
residue_index = torch.tensor([0, 1, -1, 1, -1], dtype=torch.long)
result = scatter_mean(atom_embed, residue_index, dim=0, dim_size=2)
print("scatter_mean with -1 indices succeeded")
print(f"Result shape: {result.shape}")
print(f"Result:\n{result}")
except Exception as e:
print(f"scatter_mean with -1 indices failed: {e}")
EOFRepository: diff-use/WaterFlow
Length of output: 179
🏁 Script executed:
# Check the complete make_gvp_encoder_data to see if it excludes ligands
sed -n '60,110p' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 1473
🏁 Script executed:
# Check the forward method to see when _pool_by_residue is called and if ligands are excluded
sed -n '304,360p' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 2596
🏁 Script executed:
# Search for any filtering of ligands before pooling
rg -n 'is_ligand' src/gvp_encoder.py -B 5 -A 5Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check dataset.py to see where pool_residue is set and what assumptions it makes
rg -n 'pool_residue' src/Repository: diff-use/WaterFlow
Length of output: 331
🏁 Script executed:
# Look for any scatter operation masking or negative index handling in the codebase
rg -n 'scatter_mean|scatter_add|scatter_max' src/ -B 3 -A 3Repository: diff-use/WaterFlow
Length of output: 4617
🏁 Script executed:
# Check if residue_index is passed to make_gvp_encoder_data or the encoder
sed -n '60,100p' src/gvp_encoder.py | grep -n residue_indexRepository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check the call to make_gvp_encoder_data to see if residue_index is preserved
rg -n 'make_gvp_encoder_data' src/ -B 2 -A 5Repository: diff-use/WaterFlow
Length of output: 895
🏁 Script executed:
# Check the ProteinGVPEncoder forward method to see full implementation
sed -n '488,520p' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 1216
🏁 Script executed:
# Let's check the full flow - where pool_residue is actually called
rg -n 'pool_residue=True' src/ -B 3 -A 3Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Look more carefully at make_gvp_encoder_data - does it copy residue_index?
sed -n '60,105p' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 1394
🏁 Script executed:
# Check where ProteinGVPEncoder is instantiated - does it ever set pool_residue=True?
rg -n 'ProteinGVPEncoder\|pool_residue' src/ -B 2 -A 2 | head -60Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check if residue_index is ever added to enc_data after make_gvp_encoder_data
rg -n 'enc_data.residue_index\|enc_data.num_residues' src/Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check the actual encoder instantiation in the GVPEncoder (wrapper) class
sed -n '430,480p' src/gvp_encoder.pyRepository: diff-use/WaterFlow
Length of output: 1600
🏁 Script executed:
# Verify pool_residue is never set to True anywhere
rg -n 'pool_residue.*=.*True\|pool_residue=True' src/Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check what the default pool_residue value is and in what context
rg -n 'pool_residue' src/ -B 1 -A 1Repository: diff-use/WaterFlow
Length of output: 930
🏁 Script executed:
# Now let's look at the lines around 1070 and 1199-1224 to understand the full context
sed -n '1060,1080p' src/dataset.pyRepository: diff-use/WaterFlow
Length of output: 1367
🏁 Script executed:
# Check lines 1199-1224 as mentioned in "Also applies to"
sed -n '1190,1230p' src/dataset.pyRepository: diff-use/WaterFlow
Length of output: 1726
🏁 Script executed:
# Check if residue_index is used anywhere else with negative value handling
rg -n '\-1.*residue|residue.*\-1' src/Repository: diff-use/WaterFlow
Length of output: 164
🏁 Script executed:
# Check if there are any tests or code that actually enable pool_residue
git log --all --oneline -S 'pool_residue=True' 2>/dev/null || echo "No git log available"
rg -r 'pool_residue\s*=\s*True' src/ tests/ 2>/dev/null || echo "Not found"Repository: diff-use/WaterFlow
Length of output: 148
🏁 Script executed:
# Check what happens with max() of tensor with -1 values
python3 << 'EOF'
import torch
# Simulate residue_index with -1 values for ligands
res_idx = torch.tensor([0, 1, 2, 1, 0, -1, -1])
max_val = res_idx.max().item()
num_residues = int(max_val + 1)
print(f"residue_index: {res_idx}")
print(f"max value: {max_val}")
print(f"num_residues: {num_residues}")
# Edge case: all ligands
res_idx_ligand_only = torch.tensor([-1, -1, -1])
max_val2 = res_idx_ligand_only.max().item()
num_residues2 = int(max_val2 + 1)
print(f"\nAll ligands case:")
print(f"residue_index: {res_idx_ligand_only}")
print(f"max value: {max_val2}")
print(f"num_residues: {num_residues2}")
EOFRepository: diff-use/WaterFlow
Length of output: 179
🏁 Script executed:
# Check if residue_index is used anywhere for indexing operations besides scatter
rg -n 'residue_index\[' src/Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Look for any masking of is_ligand before residue operations
rg -n 'is_ligand' src/ -A 3 -B 3 | grep -A 5 -B 5 residueRepository: diff-use/WaterFlow
Length of output: 712
🏁 Script executed:
# Check all residue_index usages more comprehensively
rg -n 'residue_index' src/ tests/ --type py -C 2Repository: diff-use/WaterFlow
Length of output: 6025
🏁 Script executed:
# Specifically check if residue_index is ever used as an index (array indexing)
rg -n '\[.*residue_index\]|\[residue_index' src/Repository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check if there are any comments or TODOs about handling -1 in residue_index
rg -n 'sentinel|-1.*residue|residue.*-1' src/ tests/Repository: diff-use/WaterFlow
Length of output: 444
🏁 Script executed:
# Check the test at lines 2395-2398 that checks residue indices
sed -n '2390,2410p' src/dataset.pyRepository: diff-use/WaterFlow
Length of output: 44
🏁 Script executed:
# Check if there's any edge case test with ligands
rg -n 'ligand' tests/test_dataset.py | head -20Repository: diff-use/WaterFlow
Length of output: 1444
🏁 Script executed:
# Final check: are there any other scatter operations that might use residue_index indirectly?
rg -n 'scatter_' src/ -B 2 -A 2Repository: diff-use/WaterFlow
Length of output: 3766
Verify that _pool_by_residue masks out ligand atoms before scatter operations.
This change introduces -1 sentinel values in residue_index for ligand atoms, but the scatter operations in _pool_by_residue (lines 276–282 in gvp_encoder.py) cannot handle negative indices. Although pool_residue is currently disabled (pool_residue=False by default), if this feature is enabled in the future, the code will fail when calling scatter_mean, scatter_add, or scatter_max with a mixed residue_index containing -1 values.
Add a mask using is_ligand before the scatter operations:
mask = ~is_ligand # or retrieve from data
atom_embed_masked = atom_embed[mask]
residue_index_masked = residue_index[mask]
res_embed = scatter_mean(atom_embed_masked, residue_index_masked, dim=0, dim_size=num_residues)Also verify the same masking is applied at any other scatter/pooling operations keyed by residue_index if pool_residue becomes enabled.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/dataset.py` around lines 1070 - 1078, _pool_by_residue will fail if
residue_index contains -1 for ligands because scatter operations (scatter_mean,
scatter_add, scatter_max) cannot accept negative indices; before any scatter
keyed by residue_index (in _pool_by_residue and any other pooling guarded by
pool_residue), filter out ligand atoms using the is_ligand mask (e.g., mask =
~is_ligand) and pass atom embeddings and residue_index masked arrays
(atom_embed[mask], residue_index[mask]) to scatter with dim_size=num_residues;
apply the same masking wherever residue_index is used for pooling to avoid
negative-index errors.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,8 @@ | |
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| from src.constants import NODE_FEATURE_DIM | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| from torch_geometric.data import HeteroData | ||
|
|
@@ -177,6 +179,15 @@ def __init__( | |
| self._embedding_dim: int | None = embedding_dim | ||
| self._embedding_key = embedding_key | ||
| self._encoder_type = encoder_type | ||
| # Learnable projection for ligand atoms (element one-hot -> embedding space). | ||
| # Ligands have no ESM/SLAE embeddings; this replaces zero-padding with a | ||
| # learned representation parameterized only by element type. | ||
| # Lazily initialized on first forward when embedding_dim becomes known. | ||
| self.ligand_embed: nn.Linear | None = ( | ||
| nn.Linear(NODE_FEATURE_DIM, embedding_dim, bias=False) | ||
| if embedding_dim is not None | ||
| else None | ||
| ) | ||
|
Comment on lines
+182
to
+190
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🌐 Web query:
💡 Result: No, parameters or submodules created inside forward after the optimizer and DDP/FSDP wrapper have been constructed will not be optimized or replicated. PyTorch optimizers capture the parameters at construction time via model.parameters, which yields only parameters registered at that moment (those assigned as nn.Parameter attributes before wrapper/optimizer creation). New parameters registered dynamically in forward are not included in this iterator, so the optimizer does not track them and step skips them. DDP and FSDP similarly build their internal state (buckets, hooks, sharding) based on parameters present at construction; new parameters added later are not hooked for gradient synchronization or sharding/replication. Official docs confirm optimizers use param_groups from the initial parameters call, with state keyed by parameter IDs captured then. Forum discussions and GitHub issues (e.g., dynamic parametrizations in FSDP requiring manual ignored_params + DDP) show post-construction parameters need explicit handling (e.g., optimizer.add_param_group), but even then replication may fail due to missed hooks. Creating parameters in forward is discouraged; use init instead. Citations:
🏁 Script executed: # Check the file exists and inspect the relevant lines
if [ -f "src/encoder_base.py" ]; then
echo "=== Lines 182-190 (declaration) ==="
sed -n '182,190p' src/encoder_base.py
echo ""
echo "=== Lines 237-242 (instantiation) ==="
sed -n '237,242p' src/encoder_base.py
echo ""
echo "=== Context around line 237 to understand forward() method ==="
sed -n '220,250p' src/encoder_base.py
else
echo "File not found"
fiRepository: diff-use/WaterFlow Length of output: 2363 Move The module is created inside Affects lines 182–190 (declaration) and 237–242 (forward instantiation). 🤖 Prompt for AI Agents |
||
|
|
||
| @property | ||
| def output_dims(self) -> tuple[int, int]: | ||
|
|
@@ -203,13 +214,15 @@ def forward( | |
| """ | ||
| Read cached embeddings and return (s, V, None). | ||
|
|
||
| On first call, infers embedding dimension from the data. | ||
| On first call, infers embedding dimension from the data. If ligand atoms | ||
| are present (data['protein'].is_ligand), their zero-padded embedding rows | ||
| are replaced with a learned projection from element one-hot features. | ||
|
|
||
| Args: | ||
| data: HeteroData with cached embeddings in data['protein'] | ||
|
|
||
| Returns: | ||
| s: (N, embedding_dim) — raw embeddings | ||
| s: (N, embedding_dim) — embeddings (ligand rows via learned projection) | ||
| V: (N, 0, 3) — empty vector features | ||
| pp_edge_attr: None — cached embedding encoders don't process edges | ||
| """ | ||
|
|
@@ -221,9 +234,20 @@ def forward( | |
|
|
||
| embeddings = data["protein"][self._embedding_key] | ||
|
|
||
| # Infer dimension on first forward | ||
| # Infer dimension on first forward and lazily init ligand head | ||
| if self._embedding_dim is None: | ||
| self._embedding_dim = embeddings.size(-1) | ||
| self.ligand_embed = nn.Linear( | ||
| NODE_FEATURE_DIM, self._embedding_dim, bias=False | ||
| ).to(embeddings.device) | ||
|
|
||
| # Replace zero-padded ligand rows with learned element projection | ||
| lig_mask = getattr(data["protein"], "is_ligand", None) | ||
| if lig_mask is not None and lig_mask.any(): | ||
| embeddings = embeddings.clone() | ||
| embeddings[lig_mask] = self.ligand_embed( | ||
| data["protein"].x[lig_mask].to(embeddings.device) | ||
| ) | ||
|
|
||
| V = embeddings.new_empty(embeddings.size(0), 0, 3) | ||
| return embeddings, V, None | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rdkitis added as a required runtime dependency, but it isn’t referenced anywhere in the repo (src/scripts/tests). Since RDKit significantly increases install size and can complicate builds on some platforms, consider removing it until it’s actually needed, or making it an optional extra for ligand-specific workflows.