Merged
116 commits
2ed8115
initial compatibility changes for upgrading multimer
jnwei Jan 12, 2024
201eafd
np type update in openfold.np.relax
jnwei Jan 16, 2024
e71c1b1
initial compatibility changes for upgrading multimer
jnwei Jan 12, 2024
91776cd
np type update in openfold.np.relax
jnwei Jan 16, 2024
427a6ee
update deprecated jax.numpy.DeviceArray to jax.Array
jnwei Jan 23, 2024
9c94078
Merge branch 'multimer-pytorch-update' of https://github.com/aqlabora…
jnwei Jan 23, 2024
e813bb5
Additional fix for multimer deepspeed test
christinaflo Jan 23, 2024
df4dfac
first pass changes to run with pl 2.1
jnwei Jan 24, 2024
456103d
initial compatibility changes for upgrading multimer
jnwei Jan 12, 2024
ff36800
first pass changes to run with pl 2.1
jnwei Jan 24, 2024
f0fc7d9
merging changes from main
jnwei Feb 19, 2024
49ab053
Merge pull request #407 from jnwei/pl_upgrades
jnwei Feb 19, 2024
5f5a79a
initial compatibility changes for upgrading multimer
jnwei Jan 12, 2024
6dc34d7
first pass changes to run with pl 2.1
jnwei Jan 24, 2024
36cd9eb
Merge branch 'pl_upgrades' of https://github.com/jnwei/openfold into …
jnwei Mar 19, 2024
a317ad2
superimposition fix from Aymen
jnwei Mar 21, 2024
cfd2e71
seed workers fix and validation_epoch_end extra argument
jnwei Mar 25, 2024
0c3435c
add metric logging to progress bar.
jnwei Mar 25, 2024
5ff5177
more logging changes
jnwei Mar 27, 2024
8626358
add paren to save_hyperparameters
jnwei Apr 2, 2024
577219c
Removes OF copy of zero_to_fp32.py favoring deepspeed.util version
jnwei Apr 2, 2024
523adaf
adds reload_dataloaders_every_n_epochs flag
jnwei Apr 11, 2024
80e6341
change message for test_model.py compare
jnwei Apr 17, 2024
1ae833b
Updates low_precision check to use current precision settings.
jnwei Apr 19, 2024
ea142a0
fixes deepspeed function definition.
jnwei Apr 19, 2024
5ccb7de
updates Dockerfile
jnwei Apr 19, 2024
3cab807
fix mkl version to 2024.0.0
jnwei Apr 19, 2024
866477a
Update gpg keys for Docker build
jnwei Apr 20, 2024
addb80a
changes to Dockerfile ane pin mkl to 2024
jnwei Apr 20, 2024
793eb96
adjust pytorch version number
jnwei Apr 20, 2024
ad34fc3
updates Bio.PDBData call and environment.yml
jnwei Apr 22, 2024
ed5261f
Split cuda install commands in Dockerfile
jnwei Apr 22, 2024
0b11ced
change mamba version
jnwei Apr 22, 2024
1d22373
upgrading hmmer hhsuite and kalign2 packages
jnwei Apr 22, 2024
cf0cc8b
small edit to Dockerfile
jnwei Apr 22, 2024
12eb81b
Reset miniforge version to 23.3.1-1
jnwei Apr 22, 2024
4ee9943
Remove nvcc compute capability 37 which caused kernel build issues
jnwei Apr 23, 2024
76fb7ce
remove test print statements
jnwei May 6, 2024
a51b08c
initial compatibility changes for upgrading multimer
jnwei Jan 12, 2024
7de0ab0
first pass changes to run with pl 2.1
jnwei Jan 24, 2024
4f2f069
Additional fix for multimer deepspeed test
christinaflo Jan 23, 2024
53cdb24
add metric logging to progress bar.
jnwei Mar 25, 2024
3eed6cb
more logging changes
jnwei Mar 27, 2024
19c8158
adds reload_dataloaders_every_n_epochs flag
jnwei Apr 11, 2024
cdf6039
change message for test_model.py compare
jnwei Apr 17, 2024
0a8ae6a
Updates low_precision check to use current precision settings.
jnwei Apr 19, 2024
ee502a2
updates Dockerfile
jnwei Apr 19, 2024
4631b54
Update gpg keys for Docker build
jnwei Apr 20, 2024
1b7f8f4
changes to Dockerfile ane pin mkl to 2024
jnwei Apr 20, 2024
8d816e3
adjust pytorch version number
jnwei Apr 20, 2024
571ae26
updates Bio.PDBData call and environment.yml
jnwei Apr 22, 2024
6b9c61c
Split cuda install commands in Dockerfile
jnwei Apr 22, 2024
435ec2f
change mamba version
jnwei Apr 22, 2024
2b2f7b7
upgrading hmmer hhsuite and kalign2 packages
jnwei Apr 22, 2024
52d4bb8
small edit to Dockerfile
jnwei Apr 22, 2024
4172f34
Reset miniforge version to 23.3.1-1
jnwei Apr 22, 2024
e3e09c4
Remove nvcc compute capability 37 which caused kernel build issues
jnwei Apr 23, 2024
c715b13
Merge remote-tracking branch 'refs/remotes/jnwei/pl_upgrades' into pl…
jnwei May 6, 2024
f10f662
pins mkl version to 2022 to avoid conda environment conflict
jnwei May 6, 2024
ed69f06
make space for docker CI
jnwei May 6, 2024
12eae13
Update docker-image.yml
jnwei May 11, 2024
0eaf08a
make sure padded asym_id won't affect permutation steps
dingquanyu Feb 15, 2024
54ec5c4
fixed bugs in unittests for multi-chain permutation. now working on e…
dingquanyu Feb 15, 2024
b542701
remove unnecessary lines
dingquanyu Feb 15, 2024
2669287
restore to the verison on main
dingquanyu Feb 15, 2024
9f964fe
added typing hints and fixed some comments
dingquanyu Feb 16, 2024
939fd0a
make sure no padded features are going to be selected as anchors
dingquanyu Feb 20, 2024
9597368
fixed typing errors; added more comments
dingquanyu Mar 21, 2024
515b082
added comments
dingquanyu Mar 21, 2024
eb262d2
update comments;fixed typos
dingquanyu May 10, 2024
ad414ec
Update tests and comments
dingquanyu May 10, 2024
dc2da1f
fixed typing error of anchor_gt_residue
dingquanyu May 10, 2024
5221ed4
Update test_permutation.py
jnwei May 11, 2024
26f8761
Initial commit for sphinx documentation.
jnwei Mar 20, 2024
4873c02
Rough draft dump of docs and readthedocs build
jnwei May 8, 2024
0b724be
fix typo in readthedocs.yaml
jnwei May 8, 2024
d4a14b6
replace doc environment pip dependencies with conda builds
jnwei May 8, 2024
9f1e0a8
cleanup makefiles and original readme
jnwei May 8, 2024
9c98e57
updates to Inference.md
jnwei May 8, 2024
44e5733
Add addtional inference pages
jnwei May 8, 2024
6a52cc4
add convert v1 weights instructions
jnwei May 8, 2024
d64dffd
Adds FAQ section
jnwei May 8, 2024
f1175ab
creates link to FAQ in documentation
jnwei May 8, 2024
6261c95
small edits to main page
jnwei May 8, 2024
55c1e0e
minor language edits
jnwei May 8, 2024
a744abe
Adds mkl version to environment.yml
jnwei Apr 30, 2024
89c756d
make space for docker CI
jnwei May 6, 2024
78644cd
Shorten README.md main page.
jnwei May 9, 2024
e338f20
adds mmseqs2 to environment.yml for clustering
jnwei May 9, 2024
0b30bb8
Update training OpenFold docs with correct paths.
jnwei May 10, 2024
cc565fd
Adds example directory
jnwei May 10, 2024
3ed09c6
Update docker-image.yml
jnwei May 13, 2024
e5ce219
Update docker-image.yml
jnwei May 13, 2024
a8c61c6
Merge branch 'main' into pl_upgrades
jnwei May 13, 2024
1647ec9
fix typo in environment.yml
jnwei May 13, 2024
dc93d33
remove mpipy from pip install requirements
jnwei May 13, 2024
c07075c
in scripts/utils.py account for case where no conda environment is sp…
jnwei May 13, 2024
3bec3e9
Merge pull request #438 from jnwei/pl_upgrades
jnwei May 13, 2024
6bdbd48
Fix pl_upgrades enwironment (numpy, cuda, gcc)
vaclavhanzl Oct 22, 2024
23cf2f6
Merge pull request #496 from vaclavhanzl/vh-fix-pl_upgrades-env
jnwei Nov 7, 2024
100a309
Maintainance to pl_upgrades
jnwei Apr 23, 2025
ab4a245
Merge branch 'main' into pl_upgrades
jnwei Apr 23, 2025
0c2d455
fix environment to support tests
jnwei Apr 23, 2025
9caf30a
Change casting for deepspeed compare model test to fp32
jnwei Apr 24, 2025
7e06ed9
support openmm>8 and fix tolerance units in amber minimization
jnwei Apr 24, 2025
cb899a5
Merge pull request #2 from aqlaboratory/pl_upgrades
jnwei Apr 24, 2025
da37880
Allow numpy>2 and support compute capability >9
jnwei Apr 24, 2025
4312aec
Add link to issue for deepspeed_evo_attention test.
jnwei Apr 25, 2025
0672517
Update installation docs to build CUDA12 version
jnwei Apr 25, 2025
16af434
fix inference documentation
jnwei Apr 25, 2025
fe10216
update version number.
jnwei Apr 25, 2025
a5433c3
Update config.py
jnwei Apr 25, 2025
620a54f
Update amber_minimize.py
jnwei Apr 25, 2025
1ffd197
Merge pull request #533 from jnwei/pl_upgrades
jnwei Apr 25, 2025
50a2e75
Update OpenFold notebook to updated pytorch2 commit
jnwei Apr 26, 2025
c587b06
add Open In Colab banner to notebook.
jnwei Apr 26, 2025
17 changes: 10 additions & 7 deletions Dockerfile
@@ -1,17 +1,20 @@
FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu18.04
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04

# metainformation
LABEL org.opencontainers.image.version = "1.0.0"
LABEL org.opencontainers.image.authors = "Gustaf Ahdritz"
LABEL org.opencontainers.image.version = "2.0.0"
LABEL org.opencontainers.image.authors = "OpenFold Team"
LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfold"
LABEL org.opencontainers.image.licenses = "Apache License 2.0"
LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04"
LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:12.4.1-devel-ubuntu22.04"

RUN apt-get update && apt-get install -y wget

RUN apt-key del 7fa2af80
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb
RUN dpkg -i cuda-keyring_1.0-1_all.deb

RUN apt-get install -y libxml2 cuda-minimal-build-12-1 libcusparse-dev-12-1 libcublas-dev-12-1 libcusolver-dev-12-1 git

RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
RUN wget -P /tmp \
"https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \
&& bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \
2 changes: 1 addition & 1 deletion docs/source/Inference.md
@@ -62,7 +62,7 @@ python3 run_pretrained_openfold.py \
$TEMPLATE_MMCIF_DIR
--output_dir $OUTPUT_DIR \
--config_preset model_1_ptm \
--uniref90_database_path $BASE_DATA_DIR/uniref90 \
--uniref90_database_path $BASE_DATA_DIR/uniref90/uniref90.fasta \
--mgnify_database_path $BASE_DATA_DIR/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path $BASE_DATA_DIR/pdb70 \
--uniclust30_database_path $BASE_DATA_DIR/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
10 changes: 2 additions & 8 deletions docs/source/Installation.md
@@ -4,7 +4,7 @@ In this guide, we will OpenFold and its dependencies.

**Pre-requisites**

This package is currently supported for CUDA 11 and Pytorch 1.12. All dependencies are listed in the [`environment.yml`](https://github.com/aqlaboratory/openfold/blob/main/environment.yml). To install OpenFold for CUDA 12, please refer to the [Environment specific modifications](#Environment-specific-modifications) section.
This package is currently supported for CUDA 12 and Pytorch 2. All dependencies are listed in the [`environment.yml`](https://github.com/aqlaboratory/openfold/blob/main/environment.yml).

At this time, only Linux systems are supported.

@@ -53,12 +53,6 @@ Certain tests perform equivalence comparisons with the AlphaFold implementation.

## Environment specific modifications

### CUDA 12
To use OpenFold on CUDA 12 environment rather than a CUDA 11 environment.
In step 1, use the branch [`pl_upgrades`](https://github.com/aqlaboratory/openfold/tree/pl_upgrades) rather than the main branch, i.e. replace the command in step 1 with `git clone -b pl_upgrades https://github.com/aqlaboratory/openfold.git`
and follow the rest of the steps of [Installation Guide](#Installation)


### MPI
To use OpenFold with MPI support, you will need to add the package [`mpi4py`](https://pypi.org/project/mpi4py/). This can be done with pip in your OpenFold environment, e.g. `$ pip install mpi4py`.

@@ -71,4 +65,4 @@ If you don't have access to `aws` on your system, you can use a different downlo

### Docker setup

A [`Dockerfile`] is provided to build an OpenFold Docker image. Additional notes for setting up a docker container for OpenFold and running inference can be found [here](original_readme.md#building-and-using-the-docker-container).
A [`Dockerfile`](https://github.com/aqlaboratory/openfold/blob/main/Dockerfile) is provided to build an OpenFold Docker image. Additional notes for setting up a docker container for OpenFold and running inference can be found [here](original_readme.md#building-and-using-the-docker-container).
32 changes: 17 additions & 15 deletions environment.yml
@@ -3,36 +3,38 @@ channels:
- conda-forge
- bioconda
- pytorch
- nvidia
dependencies:
- python=3.9
- libgcc=7.2
- cuda
- gcc=12.4
- python=3.10
- setuptools=59.5.0
- pip
- openmm=7.7
- openmm
- pdbfixer
- pytorch-lightning
- biopython
- numpy
- pandas
- PyYAML==5.4.1
- PyYAML
- requests
- scipy==1.7
- tqdm==4.62.2
- typing-extensions==4.0
- scipy
- tqdm
- typing-extensions
- wandb
- modelcif==0.7
- awscli
- ml-collections
- aria2
- mkl=2024.0
- mkl
- git
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- bioconda::mmseqs2
- pytorch::pytorch=1.12.*
- bioconda::hmmer
- bioconda::hhsuite
- bioconda::kalign2
- pytorch::pytorch=2.5
- pytorch::pytorch-cuda=12.4
- pip:
- deepspeed==0.12.4
- deepspeed==0.14.5
- dm-tree==0.1.6
- git+https://github.com/NVIDIA/dllogger.git
- git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
- flash-attn
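Because this change drops most version pins (PyYAML, scipy, tqdm, openmm, and others), the versions that actually get resolved can differ between installs. A minimal sketch for reporting what ended up in an environment, using only the standard library; the `PACKAGES` list and the `installed_version` helper are illustrative names, not part of this PR:

```python
from importlib import metadata

# Distribution names picked from the environment.yml above for illustration;
# extend or trim the list as needed.
PACKAGES = ["deepspeed", "dm-tree", "flash-attn", "torch", "numpy"]


def installed_version(name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


for name in PACKAGES:
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

Running this inside the freshly created conda environment gives a quick record to compare against the pins that remain (`deepspeed==0.14.5`, `dm-tree==0.1.6`).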
23 changes: 16 additions & 7 deletions notebooks/OpenFold.ipynb
@@ -1,5 +1,15 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/aqlaboratory/OpenFold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
@@ -107,11 +117,11 @@
"\n",
"python_version = f\"{version_info.major}.{version_info.minor}\"\n",
"\n",
"\n",
"os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
"os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
"os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh\")\n",
"os.system(\"bash Miniforge3-Linux-x86_64.sh -bfp /usr/local\")\n",
"os.environ[\"PATH\"] = \"/usr/local/bin:\" + os.environ[\"PATH\"]\n",
"os.system(\"mamba config --set auto_update_conda false\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.83\")\n",
"os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=8.2.0 python={python_version} pdbfixer biopython=1.83\")\n",
"os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n",
"\n",
"try:\n",
@@ -127,7 +137,7 @@
"\n",
" %shell mkdir -p /content/openfold/openfold/resources\n",
"\n",
" commit = \"3bec3e9b2d1e8bdb83887899102eff7d42dc2ba9\"\n",
" commit = \"1ffd197489aa5f35a5fbce1f00d7dd49bce1bd2f\"\n",
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
"\n",
" os.system(f\"cp -f -p /content/stereo_chemical_props.txt /usr/local/lib/python{python_version}/site-packages/openfold/resources/\")\n",
@@ -893,8 +903,7 @@
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"toc_visible": true
"gpuType": "T4"
},
"kernelspec": {
"display_name": "Python 3",
2 changes: 1 addition & 1 deletion openfold/config.py
@@ -660,7 +660,7 @@ def model_config(
},
"relax": {
"max_iterations": 0, # no max
"tolerance": 2.39,
"tolerance": 10.0,
"stiffness": 10.0,
"max_outer_iterations": 20,
"exclude_residues": [],
6 changes: 3 additions & 3 deletions openfold/model/primitives.py
@@ -28,7 +28,7 @@
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if fa_is_installed:
from flash_attn.bert_padding import unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func

import torch
import torch.nn as nn
@@ -808,10 +808,10 @@ def _flash_attn(q, k, v, kv_mask):
# [B_flat, N, 2 * H * C]
kv = kv.reshape(*kv.shape[:-3], -1)

kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
kv_unpad, _, kv_cu_seqlens, kv_max_s, _ = unpad_input(kv, kv_mask)
kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])

out = flash_attn_unpadded_kvpacked_func(
out = flash_attn_varlen_kvpacked_func(
q,
kv_unpad,
q_cu_seqlens,
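The two flash-attn changes in this file (the `flash_attn_unpadded_kvpacked_func` to `flash_attn_varlen_kvpacked_func` rename, and the extra fifth return value from `unpad_input`) hard-code the newer API. If both older and newer flash-attn releases had to be supported, one option is a small shim like the sketch below. This is an assumption-laden illustration, not what the PR does: `kvpacked_func` and `split_unpad_output` are hypothetical names, and the fallback only relies on the two function names that appear in this diff.

```python
import importlib.util

# Hypothetical compatibility shim; the PR itself imports the new name directly.
kvpacked_func = None
if importlib.util.find_spec("flash_attn") is not None:
    try:
        # Name used after this change.
        from flash_attn.flash_attn_interface import (
            flash_attn_varlen_kvpacked_func as kvpacked_func,
        )
    except ImportError:
        # Name used before this change.
        from flash_attn.flash_attn_interface import (
            flash_attn_unpadded_kvpacked_func as kvpacked_func,
        )


def split_unpad_output(result):
    """Keep the first four fields of an unpad_input-style result.

    Accepts either a 4-tuple (older layout in this diff) or a 5-tuple
    (newer layout) and discards any trailing extras.
    """
    hidden, indices, cu_seqlens, max_seqlen, *_ = result
    return hidden, indices, cu_seqlens, max_seqlen
```

With such a helper, the call site would read `kv_unpad, _, kv_cu_seqlens, kv_max_s = split_unpad_output(unpad_input(kv, kv_mask))` regardless of which tuple length is returned.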
3 changes: 2 additions & 1 deletion openfold/np/relax/amber_minimize.py
@@ -34,6 +34,7 @@
from openmm.app.internal.pdbstructure import PdbStructure

ENERGY = unit.kilocalories_per_mole
FORCE = unit.kilojoules_per_mole / unit.nanometer
LENGTH = unit.angstroms


@@ -439,7 +440,7 @@ def _run_one_iteration(
exclude_residues = exclude_residues or []

# Assign physical dimensions.
tolerance = tolerance * ENERGY
tolerance = tolerance * FORCE
stiffness = stiffness * ENERGY / (LENGTH ** 2)

start = time.perf_counter()
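A note on how this hunk relates to the `tolerance` default change in `openfold/config.py` (2.39 to 10.0): numerically, 2.39 kcal/mol is about 10.0 kJ/mol, so the magnitude is preserved while the unit attached to it switches from an energy (`ENERGY`, kcal/mol) to a force (`FORCE`, kJ/mol/nm). This presumably follows the "support openmm>8 and fix tolerance units in amber minimization" commit. The arithmetic can be checked without OpenMM; `KJ_PER_KCAL` and `kcal_to_kj` are illustrative names, not code from the repository:

```python
# Thermochemical calorie: 1 kcal = 4.184 kJ by definition.
KJ_PER_KCAL = 4.184


def kcal_to_kj(value_kcal_per_mol: float) -> float:
    """Convert an energy in kcal/mol to kJ/mol."""
    return value_kcal_per_mol * KJ_PER_KCAL


old_tolerance_kcal = 2.39   # previous config default, multiplied by ENERGY
new_tolerance_kj = 10.0     # new config default, multiplied by FORCE

converted = kcal_to_kj(old_tolerance_kcal)
print(f"{old_tolerance_kcal} kcal/mol = {converted:.3f} kJ/mol")
```

The conversion only shows that the two numbers agree in magnitude; an energy tolerance and a force tolerance are still different convergence criteria.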
8 changes: 4 additions & 4 deletions openfold/utils/superimposition.py
@@ -35,10 +35,10 @@ def _superimpose_np(reference, coords):


def _superimpose_single(reference, coords):
reference_np = reference.detach().to(torch.float).cpu().numpy()
coords_np = coords.detach().to(torch.float).cpu().numpy()
superimposed, rmsd = _superimpose_np(reference_np, coords_np)
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
reference_np = reference.detach().to(torch.float).cpu().numpy()
coords_np = coords.detach().to(torch.float).cpu().numpy()
superimposed, rmsd = _superimpose_np(reference_np, coords_np)
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)


def superimpose(reference, coords, mask):
2 changes: 1 addition & 1 deletion scripts/install_third_party_dependencies.sh
@@ -14,7 +14,7 @@ gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.
python setup.py install

echo "Download CUTLASS, required for Deepspeed Evoformer attention kernel"
git clone https://github.com/NVIDIA/cutlass --depth 1
git clone https://github.com/NVIDIA/cutlass --branch v3.6.0 --depth 1
conda env config vars set CUTLASS_PATH=$PWD/cutlass

# This setting is used to fix a worker assignment issue during data loading
8 changes: 4 additions & 4 deletions setup.py
@@ -29,7 +29,7 @@
]

extra_cuda_flags = [
'-std=c++14',
'-std=c++17',
'-maxrregcount=50',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
@@ -52,9 +52,9 @@ def get_cuda_bare_metal_version(cuda_dir):
return raw_output, bare_metal_major, bare_metal_minor

compute_capabilities = set([
(3, 7), # K80, e.g.
(5, 2), # Titan X
(6, 1), # GeForce 1000-series
(9, 0), # Hopper
])

compute_capabilities.add((7, 0))
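For context on what the `(major, minor)` tuples are for: each compute capability typically becomes one `-gencode` pair on the nvcc command line. The part of `setup.py` that builds those flags is not visible in this hunk, so the following is only a sketch of the usual pattern; `gencode_flags` is a hypothetical helper name.

```python
def gencode_flags(capabilities):
    """Turn (major, minor) compute capabilities into nvcc -gencode arguments."""
    flags = []
    for major, minor in sorted(capabilities):
        arch = f"{major}{minor}"
        flags += ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]
    return flags


# The set as it stands after this change: (3, 7) removed, (9, 0) added.
compute_capabilities = {(5, 2), (6, 1), (9, 0)}
compute_capabilities.add((7, 0))

print(gencode_flags(compute_capabilities))
```

Dropping `(3, 7)` therefore removes the `sm_37` target from the build, in line with the commit "Remove nvcc compute capability 37 which caused kernel build issues", and adding `(9, 0)` adds a Hopper target.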
@@ -113,7 +113,7 @@ def get_cuda_bare_metal_version(cuda_dir):

setup(
name='openfold',
version='2.0.0',
version='2.2.0',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='OpenFold Team',
author_email='jennifer.wei@omsf.io',
@@ -130,7 +130,7 @@ def get_cuda_bare_metal_version(cuda_dir):
classifiers=[
'License :: OSI Approved :: Apache Software License',
'Operating System :: POSIX :: Linux',
'Programming Language :: Python :: 3.9,'
'Programming Language :: Python :: 3.10,'
'Topic :: Scientific/Engineering :: Artificial Intelligence',
],
)
6 changes: 3 additions & 3 deletions tests/test_deepspeed_evo_attention.py
@@ -306,7 +306,6 @@ def test_compare_model(self):
batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14"
].long()
# print(batch["target_feat"].shape)
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update(
@@ -316,8 +315,9 @@ def test_compare_model(self):
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch)
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
# Restrict this test to use only torch.float32 precision due to instability with torch.bfloat16
# https://github.com/aqlaboratory/openfold/issues/532
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float32):
model = compare_utils.get_global_pretrained_openfold()
model.globals.use_deepspeed_evo_attention = False
out_repro = model(batch)
2 changes: 1 addition & 1 deletion tests/test_model.py
@@ -202,4 +202,4 @@ def run_alphafold(batch):
out_repro = out_repro["sm"]["positions"][-1]
out_repro = out_repro.squeeze(0)

self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 1e-3)
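Note that this swaps a max-based check for what the helper's name suggests is a mean-based one, which is a looser criterion at the same threshold. The body of `compare_utils.assert_mean_abs_diff_small` is not shown in this diff, so the sketch below only assumes that it compares the mean absolute difference against `eps`; the functions and data here are illustrative, written in plain Python rather than torch.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference."""
    return max(abs(x - y) for x, y in zip(a, b))


def mean_abs_diff(a, b):
    """Average element-wise absolute difference."""
    diffs = [abs(x - y) for x, y in zip(a, b)]
    return sum(diffs) / len(diffs)


expected = [0.0] * 1000
actual = [0.0] * 999 + [0.01]  # a single outlier coordinate
eps = 1e-3

print(max_abs_diff(expected, actual) < eps)   # False: the outlier exceeds eps
print(mean_abs_diff(expected, actual) < eps)  # True: the outlier is averaged out
```

In other words, a few poorly matching coordinates no longer fail the test as long as the average error stays under `1e-3`.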