for now, pass in the representative atom index for each residue manually
lucidrains committed May 20, 2024
1 parent 1618dff commit 7ef2462
Showing 4 changed files with 20 additions and 4 deletions.
13 changes: 12 additions & 1 deletion README.md
@@ -46,7 +46,13 @@ msa = torch.randn(2, 7, seq_len, 64)
# required for training, but omitted on inference

atom_pos = torch.randn(2, atom_seq_len, 3)
residue_atom_indices = torch.randint(0, 27, (2, seq_len))

distance_labels = torch.randint(0, 37, (2, seq_len, seq_len))
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
pde_labels = torch.randint(0, 64, (2, seq_len, seq_len))
plddt_labels = torch.randint(0, 50, (2, seq_len))
resolved_labels = torch.randint(0, 2, (2, seq_len))

# train

@@ -60,7 +66,12 @@ loss = alphafold3(
templates = template_feats,
template_mask = template_mask,
atom_pos = atom_pos,
distance_labels = distance_labels
residue_atom_indices = residue_atom_indices,
distance_labels = distance_labels,
pae_labels = pae_labels,
pde_labels = pde_labels,
plddt_labels = plddt_labels,
resolved_labels = resolved_labels
)

loss.backward()
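
In the snippet above, `residue_atom_indices` is filled with random integers purely for demonstration; the upper bound of 27 presumably matches the atoms-per-window used in the repository's examples. In real training data it would hold, for each residue, the index of that residue's representative atom within its atom window. A minimal sketch of how such a tensor might be assembled, with the representative slot treated as a hypothetical assumption:

```python
import torch

batch, seq_len, atoms_per_window = 2, 16, 27  # sizes assumed from the README example

# hypothetical construction: treat a fixed slot (say 1) as the representative
# atom of every residue; a real pipeline would derive this per residue type
representative_slot = 1
residue_atom_indices = torch.full((batch, seq_len), representative_slot, dtype=torch.long)

# indices must stay inside the per-residue atom window
assert residue_atom_indices.max() < atoms_per_window
```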
7 changes: 5 additions & 2 deletions alphafold3_pytorch/alphafold3.py
@@ -2195,6 +2195,7 @@ def forward(
templates: Float['b t n n dt'],
template_mask: Bool['b t'],
num_recycling_steps: int = 1,
residue_atom_indices: Int['b n'] | None = None,
num_sample_steps: int | None = None,
atom_pos: Float['b m 3'] | None = None,
distance_labels: Int['b n n'] | None = None,
@@ -2347,9 +2348,11 @@ def forward(
if should_call_confidence_head:
assert exists(atom_pos), 'diffusion module needs to have been called'

# fix this to accept representative atom indices for each residue
assert exists(residue_atom_indices)

pred_atom_pos = rearrange(denoised_atom_pos, 'b (n w) d -> b n w d', w = w)[..., 0, :]
windowed_denoised_atom_pos = rearrange(denoised_atom_pos, 'b (n w) c -> b n w c', w = w)

pred_atom_pos = einx.get_at('b n [w] c, b n -> b n c', windowed_denoised_atom_pos, residue_atom_indices)

logits = self.confidence_head(
single_repr = single,
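The key change above replaces the hard-coded first atom of each window with a gather driven by `residue_atom_indices`. As a rough illustration of what the `einx.get_at('b n [w] c, b n -> b n c', ...)` pattern computes, an equivalent plain-PyTorch gather might look like this (tensor shapes and names are assumptions for the sketch):

```python
import torch

b, n, w, c = 2, 16, 27, 3  # batch, residues, atoms per window, coordinate dims

windowed_denoised_atom_pos = torch.randn(b, n, w, c)
residue_atom_indices = torch.randint(0, w, (b, n))

# expand the per-residue index so it can gather along the window (w) axis
index = residue_atom_indices[..., None, None].expand(b, n, 1, c)
pred_atom_pos = windowed_denoised_atom_pos.gather(2, index).squeeze(2)  # (b, n, c)
```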
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.0.3"
version = "0.0.4"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
2 changes: 2 additions & 0 deletions tests/test_readme.py
@@ -257,6 +257,7 @@ def test_alphafold3():
msa = torch.randn(2, 7, seq_len, 64)

atom_pos = torch.randn(2, atom_seq_len, 3)
residue_atom_indices = torch.randint(0, 27, (2, seq_len))

distance_labels = torch.randint(0, 38, (2, seq_len, seq_len))
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
@@ -281,6 +282,7 @@ def test_alphafold3():
templates = template_feats,
template_mask = template_mask,
atom_pos = atom_pos,
residue_atom_indices = residue_atom_indices,
distance_labels = distance_labels,
pae_labels = pae_labels,
pde_labels = pde_labels,
