diff --git a/Dockerfile b/Dockerfile index 835300319..fcaeb56b5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,17 +1,20 @@ -FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu18.04 +FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04 # metainformation -LABEL org.opencontainers.image.version = "1.0.0" -LABEL org.opencontainers.image.authors = "Gustaf Ahdritz" +LABEL org.opencontainers.image.version = "2.0.0" +LABEL org.opencontainers.image.authors = "OpenFold Team" LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfold" LABEL org.opencontainers.image.licenses = "Apache License 2.0" -LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04" +LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:12.4.1-devel-ubuntu22.04" + +RUN apt-get update && apt-get install -y wget RUN apt-key del 7fa2af80 -RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub -RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb +RUN dpkg -i cuda-keyring_1.0-1_all.deb + +RUN apt-get install -y libxml2 cuda-minimal-build-12-1 libcusparse-dev-12-1 libcublas-dev-12-1 libcusolver-dev-12-1 git -RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git RUN wget -P /tmp \ "https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \ && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \ diff --git a/docs/source/Inference.md b/docs/source/Inference.md index 8c179bca9..1e40f59ff 100644 --- a/docs/source/Inference.md +++ b/docs/source/Inference.md @@ -62,7 +62,7 @@ python3 run_pretrained_openfold.py \ $TEMPLATE_MMCIF_DIR --output_dir $OUTPUT_DIR \ --config_preset model_1_ptm \ - --uniref90_database_path $BASE_DATA_DIR/uniref90 \ + --uniref90_database_path $BASE_DATA_DIR/uniref90/uniref90.fasta \ --mgnify_database_path $BASE_DATA_DIR/mgnify/mgy_clusters_2018_12.fa \ --pdb70_database_path $BASE_DATA_DIR/pdb70 \ --uniclust30_database_path $BASE_DATA_DIR/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \ diff --git a/docs/source/Installation.md b/docs/source/Installation.md index f2f90202d..6b9599e83 100644 --- a/docs/source/Installation.md +++ b/docs/source/Installation.md @@ -4,7 +4,7 @@ In this guide, we will OpenFold and its dependencies. **Pre-requisites** -This package is currently supported for CUDA 11 and Pytorch 1.12. All dependencies are listed in the [`environment.yml`](https://github.com/aqlaboratory/openfold/blob/main/environment.yml). To install OpenFold for CUDA 12, please refer to the [Environment specific modifications](#Environment-specific-modifications) section. +This package is currently supported for CUDA 12 and Pytorch 2. All dependencies are listed in the [`environment.yml`](https://github.com/aqlaboratory/openfold/blob/main/environment.yml). At this time, only Linux systems are supported. @@ -53,12 +53,6 @@ Certain tests perform equivalence comparisons with the AlphaFold implementation. ## Environment specific modifications -### CUDA 12 -To use OpenFold on CUDA 12 environment rather than a CUDA 11 environment. - In step 1, use the branch [`pl_upgrades`](https://github.com/aqlaboratory/openfold/tree/pl_upgrades) rather than the main branch, i.e. replace the command in step 1 with `git clone -b pl_upgrades https://github.com/aqlaboratory/openfold.git` - and follow the rest of the steps of [Installation Guide](#Installation) - - ### MPI To use OpenFold with MPI support, you will need to add the package [`mpi4py`](https://pypi.org/project/mpi4py/). This can be done with pip in your OpenFold environment, e.g. `$ pip install mpi4py`. @@ -71,4 +65,4 @@ If you don't have access to `aws` on your system, you can use a different downlo ### Docker setup -A [`Dockerfile`] is provided to build an OpenFold Docker image. Additional notes for setting up a docker container for OpenFold and running inference can be found [here](original_readme.md#building-and-using-the-docker-container). +A [`Dockerfile`](https://github.com/aqlaboratory/openfold/blob/main/Dockerfile) is provided to build an OpenFold Docker image. Additional notes for setting up a docker container for OpenFold and running inference can be found [here](original_readme.md#building-and-using-the-docker-container). diff --git a/environment.yml b/environment.yml index c5cf4104c..448959007 100644 --- a/environment.yml +++ b/environment.yml @@ -3,36 +3,38 @@ channels: - conda-forge - bioconda - pytorch + - nvidia dependencies: - - python=3.9 - - libgcc=7.2 + - cuda + - gcc=12.4 + - python=3.10 - setuptools=59.5.0 - pip - - openmm=7.7 + - openmm - pdbfixer - pytorch-lightning - biopython - numpy - pandas - - PyYAML==5.4.1 + - PyYAML - requests - - scipy==1.7 - - tqdm==4.62.2 - - typing-extensions==4.0 + - scipy + - tqdm + - typing-extensions - wandb - modelcif==0.7 - awscli - ml-collections - aria2 - - mkl=2024.0 + - mkl - git - - bioconda::hmmer==3.3.2 - - bioconda::hhsuite==3.3.0 - - bioconda::kalign2==2.04 - - bioconda::mmseqs2 - - pytorch::pytorch=1.12.* + - bioconda::hmmer + - bioconda::hhsuite + - bioconda::kalign2 + - pytorch::pytorch=2.5 + - pytorch::pytorch-cuda=12.4 - pip: - - deepspeed==0.12.4 + - deepspeed==0.14.5 - dm-tree==0.1.6 - git+https://github.com/NVIDIA/dllogger.git - - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8 + - flash-attn diff --git a/notebooks/OpenFold.ipynb b/notebooks/OpenFold.ipynb index de5d4539c..dfdaa1022 100644 --- a/notebooks/OpenFold.ipynb +++ b/notebooks/OpenFold.ipynb @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, { "cell_type": "markdown", "metadata": { @@ -107,11 +117,11 @@ "\n", "python_version = f\"{version_info.major}.{version_info.minor}\"\n", "\n", - "\n", - "os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n", - "os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n", + "os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh\")\n", + "os.system(\"bash Miniforge3-Linux-x86_64.sh -bfp /usr/local\")\n", + "os.environ[\"PATH\"] = \"/usr/local/bin:\" + os.environ[\"PATH\"]\n", "os.system(\"mamba config --set auto_update_conda false\")\n", - "os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.83\")\n", + "os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=8.2.0 python={python_version} pdbfixer biopython=1.83\")\n", "os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n", "\n", "try:\n", @@ -127,7 +137,7 @@ "\n", " %shell mkdir -p /content/openfold/openfold/resources\n", "\n", - " commit = \"3bec3e9b2d1e8bdb83887899102eff7d42dc2ba9\"\n", + " commit = \"1ffd197489aa5f35a5fbce1f00d7dd49bce1bd2f\"\n", " os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n", "\n", " os.system(f\"cp -f -p /content/stereo_chemical_props.txt /usr/local/lib/python{python_version}/site-packages/openfold/resources/\")\n", @@ -893,8 +903,7 @@ "metadata": { "colab": { "provenance": [], - "gpuType": "T4", - "toc_visible": true + "gpuType": "T4" }, "kernelspec": { "display_name": "Python 3", diff --git a/openfold/config.py b/openfold/config.py index 7bf30e391..a738b9f07 100644 --- a/openfold/config.py +++ b/openfold/config.py @@ -660,7 +660,7 @@ def model_config( }, "relax": { "max_iterations": 0, # no max - "tolerance": 2.39, + "tolerance": 10.0, "stiffness": 10.0, "max_outer_iterations": 20, "exclude_residues": [], diff --git a/openfold/model/primitives.py b/openfold/model/primitives.py index ea38cb34a..c35472539 100644 --- a/openfold/model/primitives.py +++ b/openfold/model/primitives.py @@ -28,7 +28,7 @@ fa_is_installed = importlib.util.find_spec("flash_attn") is not None if fa_is_installed: from flash_attn.bert_padding import unpad_input - from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func + from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func import torch import torch.nn as nn @@ -808,10 +808,10 @@ def _flash_attn(q, k, v, kv_mask): # [B_flat, N, 2 * H * C] kv = kv.reshape(*kv.shape[:-3], -1) - kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask) + kv_unpad, _, kv_cu_seqlens, kv_max_s, _ = unpad_input(kv, kv_mask) kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:]) - out = flash_attn_unpadded_kvpacked_func( + out = flash_attn_varlen_kvpacked_func( q, kv_unpad, q_cu_seqlens, diff --git a/openfold/np/relax/amber_minimize.py b/openfold/np/relax/amber_minimize.py index 02816bb81..43d9337e4 100644 --- a/openfold/np/relax/amber_minimize.py +++ b/openfold/np/relax/amber_minimize.py @@ -34,6 +34,7 @@ from openmm.app.internal.pdbstructure import PdbStructure ENERGY = unit.kilocalories_per_mole +FORCE = unit.kilojoules_per_mole / unit.nanometer LENGTH = unit.angstroms @@ -439,7 +440,7 @@ def _run_one_iteration( exclude_residues = exclude_residues or [] # Assign physical dimensions. - tolerance = tolerance * ENERGY + tolerance = tolerance * FORCE stiffness = stiffness * ENERGY / (LENGTH ** 2) start = time.perf_counter() diff --git a/openfold/utils/superimposition.py b/openfold/utils/superimposition.py index 9fe794fff..d1dca2718 100644 --- a/openfold/utils/superimposition.py +++ b/openfold/utils/superimposition.py @@ -35,10 +35,10 @@ def _superimpose_np(reference, coords): def _superimpose_single(reference, coords): - reference_np = reference.detach().to(torch.float).cpu().numpy() - coords_np = coords.detach().to(torch.float).cpu().numpy() - superimposed, rmsd = _superimpose_np(reference_np, coords_np) - return coords.new_tensor(superimposed), coords.new_tensor(rmsd) + reference_np = reference.detach().to(torch.float).cpu().numpy() + coords_np = coords.detach().to(torch.float).cpu().numpy() + superimposed, rmsd = _superimpose_np(reference_np, coords_np) + return coords.new_tensor(superimposed), coords.new_tensor(rmsd) def superimpose(reference, coords, mask): diff --git a/scripts/install_third_party_dependencies.sh b/scripts/install_third_party_dependencies.sh index fe2a6a0ba..e9d91002a 100755 --- a/scripts/install_third_party_dependencies.sh +++ b/scripts/install_third_party_dependencies.sh @@ -14,7 +14,7 @@ gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats. python setup.py install echo "Download CUTLASS, required for Deepspeed Evoformer attention kernel" -git clone https://github.com/NVIDIA/cutlass --depth 1 +git clone https://github.com/NVIDIA/cutlass --branch v3.6.0 --depth 1 conda env config vars set CUTLASS_PATH=$PWD/cutlass # This setting is used to fix a worker assignment issue during data loading diff --git a/setup.py b/setup.py index bec986254..3750d9fe9 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ ] extra_cuda_flags = [ - '-std=c++14', + '-std=c++17', '-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', @@ -52,9 +52,9 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_major, bare_metal_minor compute_capabilities = set([ - (3, 7), # K80, e.g. (5, 2), # Titan X (6, 1), # GeForce 1000-series + (9, 0), # Hopper ]) compute_capabilities.add((7, 0)) @@ -113,7 +113,7 @@ def get_cuda_bare_metal_version(cuda_dir): setup( name='openfold', - version='2.0.0', + version='2.2.0', description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2', author='OpenFold Team', author_email='jennifer.wei@omsf.io', @@ -130,7 +130,7 @@ def get_cuda_bare_metal_version(cuda_dir): classifiers=[ 'License :: OSI Approved :: Apache Software License', 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3.9,' + 'Programming Language :: Python :: 3.10,' 'Topic :: Scientific/Engineering :: Artificial Intelligence', ], ) diff --git a/tests/test_deepspeed_evo_attention.py b/tests/test_deepspeed_evo_attention.py index 5474f98f8..a65a76317 100644 --- a/tests/test_deepspeed_evo_attention.py +++ b/tests/test_deepspeed_evo_attention.py @@ -306,7 +306,6 @@ def test_compare_model(self): batch["residx_atom37_to_atom14"] = batch[ "residx_atom37_to_atom14" ].long() - # print(batch["target_feat"].shape) batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32) batch["template_all_atom_mask"] = batch["template_all_atom_masks"] batch.update( @@ -316,8 +315,9 @@ def test_compare_model(self): # Move the recycling dimension to the end move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0) batch = tensor_tree_map(move_dim, batch) - with torch.no_grad(): - with torch.cuda.amp.autocast(dtype=torch.bfloat16): + # Restrict this test to use only torch.float32 precision due to instability with torch.bfloat16 + # https://github.com/aqlaboratory/openfold/issues/532 + with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float32): model = compare_utils.get_global_pretrained_openfold() model.globals.use_deepspeed_evo_attention = False out_repro = model(batch) diff --git a/tests/test_model.py b/tests/test_model.py index 3d19f14ed..ecf5af13f 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -202,4 +202,4 @@ def run_alphafold(batch): out_repro = out_repro["sm"]["positions"][-1] out_repro = out_repro.squeeze(0) - self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3) + compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 1e-3) diff --git a/train_openfold.py b/train_openfold.py index 168a4b43f..c55de9db3 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -21,7 +21,6 @@ from openfold.model.model import AlphaFold from openfold.model.torchscript import script_preset_ from openfold.np import residue_constants -from openfold.utils.argparse_utils import remove_arguments from openfold.utils.callbacks import ( EarlyStoppingVerbose, ) @@ -55,7 +54,7 @@ def __init__(self, config): self.ema = ExponentialMovingAverage( model=self.model, decay=config.ema.decay ) - + self.cached_weights = None self.last_lr_step = -1 self.save_hyperparameters() @@ -73,7 +72,7 @@ def _log(self, loss_breakdown, batch, outputs, train=True): on_step=train, on_epoch=(not train), logger=True, sync_dist=False, ) - if(train): + if (train): self.log( f"{phase}/{loss_name}_epoch", indiv_loss, @@ -82,12 +81,12 @@ def _log(self, loss_breakdown, batch, outputs, train=True): with torch.no_grad(): other_metrics = self._compute_validation_metrics( - batch, + batch, outputs, superimposition_metrics=(not train) ) - for k,v in other_metrics.items(): + for k, v in other_metrics.items(): self.log( f"{phase}/{k}", torch.mean(v), @@ -96,7 +95,7 @@ def _log(self, loss_breakdown, batch, outputs, train=True): ) def training_step(self, batch, batch_idx): - if(self.ema.device != batch["aatype"].device): + if (self.ema.device != batch["aatype"].device): self.ema.to(batch["aatype"].device) ground_truth = batch.pop('gt_features', None) @@ -127,12 +126,13 @@ def on_before_zero_grad(self, *args, **kwargs): def validation_step(self, batch, batch_idx): # At the start of validation, load the EMA weights - if(self.cached_weights is None): + if (self.cached_weights is None): # model.state_dict() contains references to model weights rather - # than copies. Therefore, we need to clone them before calling + # than copies. Therefore, we need to clone them before calling # load_state_dict(). - clone_param = lambda t: t.detach().clone() - self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict()) + def clone_param(t): return t.detach().clone() + self.cached_weights = tensor_tree_map( + clone_param, self.model.state_dict()) self.model.load_state_dict(self.ema.state_dict()["params"]) ground_truth = batch.pop('gt_features', None) @@ -160,17 +160,17 @@ def on_validation_epoch_end(self): self.model.load_state_dict(self.cached_weights) self.cached_weights = None - def _compute_validation_metrics(self, - batch, - outputs, - superimposition_metrics=False - ): + def _compute_validation_metrics(self, + batch, + outputs, + superimposition_metrics=False + ): metrics = {} - + gt_coords = batch["all_atom_positions"] pred_coords = outputs["final_atom_positions"] all_atom_mask = batch["all_atom_mask"] - + # This is super janky for superimposition. Fix later gt_coords_masked = gt_coords * all_atom_mask[..., None] pred_coords_masked = pred_coords * all_atom_mask[..., None] @@ -178,7 +178,7 @@ def _compute_validation_metrics(self, gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :] pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :] all_atom_mask_ca = all_atom_mask[..., ca_pos] - + lddt_ca_score = lddt_ca( pred_coords, gt_coords, @@ -186,18 +186,18 @@ def _compute_validation_metrics(self, eps=self.config.globals.eps, per_residue=False, ) - + metrics["lddt_ca"] = lddt_ca_score - + drmsd_ca_score = drmsd( pred_coords_masked_ca, gt_coords_masked_ca, - mask=all_atom_mask_ca, # still required here to compute n + mask=all_atom_mask_ca, # still required here to compute n ) - + metrics["drmsd_ca"] = drmsd_ca_score - - if(superimposition_metrics): + + if (superimposition_metrics): superimposed_pred, alignment_rmsd = superimpose( gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca, ) @@ -211,7 +211,7 @@ def _compute_validation_metrics(self, metrics["alignment_rmsd"] = alignment_rmsd metrics["gdt_ts"] = gdt_ts_score metrics["gdt_ha"] = gdt_ha_score - + return metrics def configure_optimizers(self, @@ -220,8 +220,8 @@ def configure_optimizers(self, ) -> torch.optim.Adam: # Ignored as long as a DeepSpeed optimizer is configured optimizer = torch.optim.Adam( - self.model.parameters(), - lr=learning_rate, + self.model.parameters(), + lr=learning_rate, eps=eps ) @@ -246,8 +246,9 @@ def configure_optimizers(self, def on_load_checkpoint(self, checkpoint): ema = checkpoint["ema"] - if(not self.model.template_config.enabled): - ema["params"] = {k:v for k,v in ema["params"].items() if not "template" in k} + if (not self.model.template_config.enabled): + ema["params"] = {k: v for k, + v in ema["params"].items() if not "template" in k} self.ema.load_state_dict(ema) def on_save_checkpoint(self, checkpoint): @@ -258,13 +259,13 @@ def resume_last_lr_step(self, lr_step): def load_from_jax(self, jax_path): model_basename = os.path.splitext( - os.path.basename( - os.path.normpath(jax_path) - ) + os.path.basename( + os.path.normpath(jax_path) + ) )[0] model_version = "_".join(model_basename.split("_")[1:]) import_jax_weights_( - self.model, jax_path, version=model_version + self.model, jax_path, version=model_version ) def get_model_state_dict_from_ds_checkpoint(checkpoint_dir): @@ -331,30 +332,31 @@ def main(args): if args.resume_from_jax_params: model_module.load_from_jax(args.resume_from_jax_params) - logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") - + logging.info( + f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") + # TorchScript components of the model - if(args.script_modules): + if (args.script_modules): script_preset_(model_module) if "multimer" in args.config_preset: data_module = OpenFoldMultimerDataModule( - config=config.data, - batch_seed=args.seed, - **vars(args) - ) + config=config.data, + batch_seed=args.seed, + **vars(args) + ) else: data_module = OpenFoldDataModule( - config=config.data, + config=config.data, batch_seed=args.seed, **vars(args) ) data_module.prepare_data() data_module.setup() - + callbacks = [] - if(args.checkpoint_every_epoch): + if (args.checkpoint_every_epoch): mc = ModelCheckpoint( every_n_epochs=1, auto_insert_metric_name=False, @@ -362,7 +364,7 @@ def main(args): ) callbacks.append(mc) - if(args.early_stopping): + if (args.early_stopping): es = EarlyStoppingVerbose( monitor="val/lddt_ca", min_delta=args.min_delta, @@ -374,7 +376,7 @@ def main(args): ) callbacks.append(es) - if(args.log_performance): + if (args.log_performance): global_batch_size = args.num_nodes * args.gpus perf = PerformanceLoggingCallback( log_file=os.path.join(args.output_dir, "performance_log.json"), @@ -382,7 +384,7 @@ def main(args): ) callbacks.append(perf) - if(args.log_lr): + if (args.log_lr): lr_monitor = LearningRateMonitor(logging_interval="step") callbacks.append(lr_monitor) @@ -448,7 +450,7 @@ def main(args): ckpt_path = args.resume_from_ckpt trainer.fit( - model_module, + model_module, datamodule=data_module, ckpt_path=ckpt_path, ) @@ -680,22 +682,22 @@ def bool_type(bool_str: str): trainer_group.add_argument( "--reload_dataloaders_every_n_epochs", type=int, default=1, ) - - trainer_group.add_argument("--accumulate_grad_batches", type=int, default=1, - help="Accumulate gradients over k batches before next optimizer step.") + trainer_group.add_argument( + "--accumulate_grad_batches", type=int, default=1, + help="Accumulate gradients over k batches before next optimizer step.") args = parser.parse_args() - if(args.seed is None and - ((args.gpus is not None and args.gpus > 1) or + if (args.seed is None and + ((args.gpus is not None and args.gpus > 1) or (args.num_nodes is not None and args.num_nodes > 1))): raise ValueError("For distributed training, --seed must be specified") - if(str(args.precision) == "16" and args.deepspeed_config_path is not None): + if (str(args.precision) == "16" and args.deepspeed_config_path is not None): raise ValueError("DeepSpeed and FP16 training are not compatible") - if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None): - raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path") - + if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None): + raise ValueError( + "Choose between loading pretrained Jax-weights and a checkpoint-path") main(args)