Amgen's AMPLIFY Port #442

Draft: wants to merge 116 commits into base: main

Commits (116)
9755977
adding bionemo-scclip to sub-packages
ynashed Sep 11, 2024
843ce40
Initial model change
ynashed Sep 30, 2024
9fe70d6
Add bionemo-amplify sub-package and its requirements
ynashed Oct 3, 2024
e21c692
Added AMPLIFY tokenizer
ynashed Oct 3, 2024
2e9e299
Added AMPLIFY dataset and dataloader
ynashed Oct 4, 2024
8b41f03
Finalizing amplify model and train script
ynashed Oct 4, 2024
82928a4
Add bionemo-esm2 to local requirements
ynashed Oct 4, 2024
2d0181e
Update biobert_spec_option to use esm2_bert_layer_local_spec
ynashed Oct 4, 2024
87b259f
Fix syntax error
ynashed Oct 4, 2024
fa8daad
Update lr_scheduler import in amplify_pretrain.py
ynashed Oct 4, 2024
794a40b
Update amplify_pretrain.py to use hf_dataset_name instead of individu…
ynashed Oct 4, 2024
c1a03d2
Fixing errors related to hf_dataset_name
ynashed Oct 5, 2024
94b9649
Update optimizer and lr_scheduler in amplify_pretrain.py
ynashed Oct 5, 2024
d8e4d1a
Bugfixes
ynashed Oct 5, 2024
552a952
Merge pull request #1 from NVIDIA/main
ynashed Oct 8, 2024
0cbc475
adding bionemo-scclip to sub-packages
ynashed Sep 11, 2024
a1e323f
Initial model change
ynashed Sep 30, 2024
bd7041c
Add bionemo-amplify sub-package and its requirements
ynashed Oct 3, 2024
897601e
Added AMPLIFY tokenizer
ynashed Oct 3, 2024
1dc780c
Added AMPLIFY dataset and dataloader
ynashed Oct 4, 2024
88ebe21
Finalizing amplify model and train script
ynashed Oct 4, 2024
ee6e91b
Add bionemo-esm2 to local requirements
ynashed Oct 4, 2024
a7f8ff4
Update biobert_spec_option to use esm2_bert_layer_local_spec
ynashed Oct 4, 2024
1a1c4e8
Fix syntax error
ynashed Oct 4, 2024
3c9ffc7
Update lr_scheduler import in amplify_pretrain.py
ynashed Oct 4, 2024
52809c7
Update amplify_pretrain.py to use hf_dataset_name instead of individu…
ynashed Oct 4, 2024
9a4aec8
Fixing errors related to hf_dataset_name
ynashed Oct 5, 2024
d6eada9
Update optimizer and lr_scheduler in amplify_pretrain.py
ynashed Oct 5, 2024
0901c5d
Bugfixes
ynashed Oct 5, 2024
d261500
Update tach.toml to include bionemo.amplify module and its dependencies
ynashed Oct 8, 2024
579f16a
Updated bionemo-amplify to sync with upstream esm2 changes
ynashed Oct 10, 2024
0c9de08
Ignore test_experiment directory in git
ynashed Oct 10, 2024
abfa9da
solving merge conflicts
ynashed Oct 10, 2024
3820690
Update BioNeMoAMPLIFYTokenizer to use EsmTokenizer
ynashed Oct 10, 2024
059cc8c
Update BioNeMoAMPLIFYTokenizer to use chandar-lab/AMPLIFY_350M
ynashed Oct 10, 2024
b8fd917
Update BioNeMoAMPLIFYTokenizer to fix serialization issue
ynashed Oct 10, 2024
80947b0
Fix range for random_tokens in AMPLIFYMaskedResidueDataset
ynashed Oct 10, 2024
f8eac7c
Refactor index variable in AMPLIFYMaskedResidueDataset's __getitem__ …
ynashed Oct 10, 2024
121ff5f
Adding AMPLIFY specific config parameters
ynashed Oct 14, 2024
c758656
Amplify doesn't inherit from ESM2Model anymore
ynashed Oct 17, 2024
1e36af1
removed extra layernorm, added gradient clipping, configure cosine lr…
ynashed Oct 19, 2024
273cf6a
reducing attention block ffn_hidden_size to match the paper
ynashed Oct 22, 2024
26b12a0
Dataset resampling with MultiEpochDatasetResampler
ynashed Oct 22, 2024
21083c0
Merge branch 'NVIDIA:main' into v2-main
ynashed Oct 22, 2024
e47def5
cast np.int64 to int in dataset __getitem__
ynashed Oct 22, 2024
522e028
Merge branch 'v2-main' into ynashed/v2-main/amplify
ynashed Oct 22, 2024
eb2d658
Update amplify to match latest esm2 code
ynashed Oct 22, 2024
b822461
Revert to PRNGResampleDataset
ynashed Oct 22, 2024
7f40015
optimize multi_epoch_dataset for constant memory and space usage
pstjohn Oct 23, 2024
e09db4e
Merge pull request #2 from NVIDIA/pstjohn/main/optimize-multi-epoch-d…
ynashed Oct 23, 2024
70284e0
Trying out upstream MultiEpochDatasetResampler optimization
ynashed Oct 23, 2024
89166e4
Added dataset_subset to AMPLIFYMaskedResidueDataset
ynashed Oct 23, 2024
f699da2
Changed defaults to 120M Model. Added final_step to lr scheduler (cos…
ynashed Nov 2, 2024
97b4bf6
enabled bf16 in the optimizer
ynashed Nov 4, 2024
77040d5
LR scheduler warmup starts from min_lr 0 by default
ynashed Nov 4, 2024
af44848
Fixing cosine lr scheduler to match HF implementation
ynashed Nov 4, 2024
c52d5ba
Turning off CosineAnnealingScheduler constant_steps
ynashed Nov 5, 2024
2305b17
RandomMaskStrategy defaults to AMINO_ACIDS_ONLY
ynashed Nov 6, 2024
429bbdc
Make sure <mask> token exists in masked sequence
ynashed Nov 6, 2024
c5ef9d4
Forgot self. (doh)
ynashed Nov 6, 2024
c65f86e
revert the masking check
ynashed Nov 7, 2024
2f1b149
Turning off weight decay
ynashed Nov 7, 2024
16d0a30
[WIP] fix tests
ynashed Nov 8, 2024
3260d37
Merge remote-tracking branch 'upstream/main'
ynashed Nov 8, 2024
ece8de4
Merge branch 'main' into ynashed/v2-main/amplify
ynashed Nov 8, 2024
d64740e
esm2 updates
ynashed Nov 8, 2024
82de2ab
Fixes after testing
ynashed Nov 8, 2024
cdba113
adding log-every-n-steps argument
ynashed Nov 8, 2024
6c52266
Merge remote-tracking branch 'upstream/main'
ynashed Nov 12, 2024
9b7d5ff
Merge branch 'main' into ynashed/v2-main/amplify
ynashed Nov 12, 2024
efc53eb
added nsys profiling arguments
ynashed Nov 13, 2024
e803606
removed slowdown in dataset getitem
ynashed Nov 13, 2024
3f649ab
Trying different optimizer and model configs
ynashed Nov 14, 2024
40566ec
roll back config changes
ynashed Nov 14, 2024
2a2cbac
Removed abandoned bionemo-scclip
ynashed Nov 15, 2024
2f051f8
Patching Megatron-LM to include pytorch optimizers as default
ynashed Dec 10, 2024
d10ab06
Adding run:ai submit scripts [WIP]
ynashed Dec 10, 2024
3e7f959
Added Megatron optimizer patch file
ynashed Dec 10, 2024
50025eb
changes to match https://github.com/NVIDIA/NeMo/pull/11252
ynashed Dec 12, 2024
4280418
Adding WANDB_API_KEY run:ai secret
ynashed Dec 12, 2024
5d03f8d
Switch back to pytorch_lightning.callbacks import
ynashed Dec 12, 2024
3339571
trying to get rid of loss spikes
ynashed Dec 13, 2024
faed442
pip installing megatron-lm just for good measure
ynashed Dec 13, 2024
85e1d78
faster weight decay
ynashed Dec 13, 2024
f98f02d
Changing core attention to default
ynashed Dec 13, 2024
c161652
Trying training in fp32
ynashed Dec 13, 2024
192e121
OOM, trying fp32-mixed
ynashed Dec 13, 2024
5d37fea
OOM, fp16-mixed
ynashed Dec 13, 2024
32aa173
lower initial loss_scaling
ynashed Dec 14, 2024
5927369
fp16-mixed precision with constant loss scaling
ynashed Dec 14, 2024
ce2298c
loss_scale passed to the right class
ynashed Dec 14, 2024
77d3239
MegatronMixedPrecision class argument name fix
ynashed Dec 14, 2024
432acf0
reverting back to bf16-mixed
ynashed Dec 14, 2024
572511b
Increasing adam_eps to try counter grad norm explosion
ynashed Dec 14, 2024
7dcd4e1
reverting adam_eps to 1e-8
ynashed Dec 14, 2024
d7e1cc0
Turning off fusions, turning on attention_softmax_in_fp32
ynashed Dec 14, 2024
a02b4ee
Use esm2 LM_Head and layernorm instead of RMSNorm
ynashed Dec 16, 2024
98a62d7
Trying esm2_bert_layer_with_transformer_engine_spec
ynashed Dec 16, 2024
a027e2b
Allowing AmplifyConfig to accept esm2 bert spec
ynashed Dec 16, 2024
e7db7d3
Removing LM_head, using esm2_bert_layer_with_transformer_engine_spec,…
ynashed Dec 16, 2024
0f7ef3f
kubectl script changes to add random-mask-strategy argument
ynashed Dec 23, 2024
d60e76d
Merge remote-tracking branch 'upstream/main' into ynashed/v2-main/amp…
ynashed Dec 23, 2024
62caa27
Merge branch 'ynashed/amplify/runai' into ynashed/v2-main/amplify
ynashed Dec 24, 2024
b1836eb
Updates after sync with upstream
ynashed Dec 25, 2024
169e5bb
updating the pytorch_lightning import
ynashed Dec 30, 2024
ec10b36
Merge remote-tracking branch 'upstream/main'
ynashed Dec 30, 2024
b7e35ac
Merge branch 'main' into ynashed/v2-main/amplify
ynashed Dec 30, 2024
d0b837d
Merge remote-tracking branch 'upstream/main' into ynashed/v2-main/amp…
ynashed Dec 30, 2024
f6e0f37
bump NeMo version to match upstream
ynashed Dec 30, 2024
8f15b80
Add handling for special token IDs in BioNeMoAMPLIFYTokenizer
ynashed Dec 31, 2024
0107bc6
Add DDP configuration options for gradient reduction and parameter ga…
ynashed Dec 31, 2024
1a3ffa9
Adding train script entrypoint
ynashed Dec 31, 2024
7e8cddf
Conforming with upstream changes
ynashed Dec 31, 2024
5942afa
Removing runai scripts from git
ynashed Dec 31, 2024
cbf3535
AMPLIFYConfig::__post_init__ was called twice leading to decreased mo…
ynashed Jan 1, 2025
7fe9c46
Disable distributed optimizer in training script
ynashed Jan 1, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -194,3 +194,4 @@ coverage.xml
Thumbs.db

.python_history
test_experiment/*
1 change: 1 addition & 0 deletions sub-packages/bionemo-amplify/LICENSE
13 changes: 13 additions & 0 deletions sub-packages/bionemo-amplify/README.md
@@ -0,0 +1,13 @@
# bionemo-amplify


### Setup
To install, execute the following:
```bash
pip install -e .
```

To run unit tests, execute:
```bash
pytest -v .
```
1 change: 1 addition & 0 deletions sub-packages/bionemo-amplify/VERSION
38 changes: 38 additions & 0 deletions sub-packages/bionemo-amplify/pyproject.toml
@@ -0,0 +1,38 @@
[build-system]
requires = ["setuptools>=64", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "bionemo-amplify"
readme = "README.md"
description = "BioNeMo AMPLIFY"
authors = [{ name = "BioNeMo Team", email = "[email protected]" }]
requires-python = ">=3.10"
license = { file = "LICENSE" }
dynamic = ["version"]
dependencies = [
# bionemo sub-packages
'bionemo-core',
'bionemo-esm2',
'bionemo-llm',
# external
]

[project.scripts]
train_amplify = "bionemo.amplify.scripts.train_amplify:train_amplify_entrypoint"

# Make sure that the tokenizer files are included along with the python files during installation.
[tool.setuptools.package-data]
"bionemo.amplify" = ["data/tokenizer/*.json", "data/tokenizer/*.txt"]

[tool.setuptools.packages.find]
where = ["src"]
include = ["bionemo.*"]
namespaces = true
exclude = ["test*."]

[tool.setuptools.dynamic]
version = { file = "VERSION" }

[tool.uv]
cache-keys = [{ git = true }]
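
The `[project.scripts]` and `[tool.setuptools.package-data]` entries above wire up a `train_amplify` console command and bundle the tokenizer JSON/TXT files into the installed package. A minimal, hedged sketch of checking that wiring after `pip install -e .`; it only resolves the entrypoint path declared in the TOML and does not run training (the entrypoint's CLI flags are defined elsewhere in the PR):

```python
# Sanity-check the console-script wiring declared in pyproject.toml:
# resolve the module path from [project.scripts] and confirm the callable exists.
import importlib

module = importlib.import_module("bionemo.amplify.scripts.train_amplify")
entrypoint = getattr(module, "train_amplify_entrypoint")
print(callable(entrypoint))  # True if the sub-package installed correctly
```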
14 changes: 14 additions & 0 deletions sub-packages/bionemo-amplify/src/bionemo/amplify/__init__.py
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
25 changes: 25 additions & 0 deletions sub-packages/bionemo-amplify/src/bionemo/amplify/api.py
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Sequence

from bionemo.amplify.model.model import AMPLIFYConfig, AMPLIFYModel


__all__: Sequence[str] = (
"AMPLIFYConfig",
"AMPLIFYModel",
)
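
A small usage sketch of the public exports above; only the import surface shown in this diff is used, and no constructor arguments are assumed since they are not visible on this page:

```python
# The api module re-exports the model classes so downstream code can depend on
# a stable path instead of the internal bionemo.amplify.model.model module.
from bionemo.amplify.api import AMPLIFYConfig, AMPLIFYModel

# Both names resolve to the classes defined in bionemo.amplify.model.model.
print(AMPLIFYConfig.__module__, AMPLIFYModel.__module__)
```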
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
203 changes: 203 additions & 0 deletions sub-packages/bionemo-amplify/src/bionemo/amplify/data/datamodule.py
@@ -0,0 +1,203 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
from typing import Literal

from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from nemo.lightning.data import WrappedDataLoader
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging

from bionemo.core.data.multi_epoch_dataset import MultiEpochDatasetResampler
from bionemo.amplify.data import dataset, tokenizer
from bionemo.llm.data import collate
from bionemo.llm.data.datamodule import MegatronDataModule
from bionemo.llm.utils.datamodule_utils import infer_num_samples


Mode = Literal["train", "validation", "test"]


class AMPLIFYDataModule(MegatronDataModule):
"""LightningDataModule wrapper of `AMPLIFYDataset`."""
def __init__(
self,
hf_dataset_name: str = "chandar-lab/UR100P",
seed: int | None = 42,
min_seq_length: int | None = None,
max_seq_length: int = 512,
micro_batch_size: int = 512,
global_batch_size: int = 4096,
num_workers: int = 10, # TODO(@jomitchell) can this be automatically set?
persistent_workers: bool = True,
pin_memory: bool = True,
rampup_batch_size: list[int] | None = None,
mask_prob: float = 0.15,
mask_token_prob: float = 0.8,
mask_random_prob: float = 0.1,
random_mask_strategy: dataset.RandomMaskStrategy = dataset.RandomMaskStrategy.AMINO_ACIDS_ONLY,
tokenizer: tokenizer.BioNeMoAMPLIFYTokenizer = tokenizer.get_tokenizer(),
dataloader_type: Literal["single", "cyclic"] = "single",
) -> None:
"""Initialize the AMPLIFYDataModule.

Args:
hf_dataset_name: The name of the HuggingFace dataset. Defaults to "chandar-lab/UR100P".
seed: Input random seed. If None, initializes randomly. Defaults to 42.
min_seq_length: Whether to pad sequences to a minimum length. If None, no extra padding is added. Defaults
to None.
max_seq_length: The maximum context length for the AMPLIFY transformer. Defaults to 512.
micro_batch_size: Passed to MegatronDataSampler. Defaults to 512.
global_batch_size: Passed to MegatronDataSampler. Defaults to 4096.
num_workers: The number of workers for the pytorch Dataloaders. Defaults to 10.
persistent_workers: Whether to keep the workers alive between epochs. Defaults to True.
pin_memory: Whether to pin GPU memory in the pytorch Dataloaders. Defaults to True.
rampup_batch_size: Passed to MegatronDataSampler. Defaults to None.
mask_prob: The overall chance of masking a token and having it appear in the loss fn. Defaults to 0.15.
mask_token_prob: Percentage of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
mask_random_prob: Percentage of masked tokens assigned to a random amino acid. Defaults to 0.1.
random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.AMINO_ACIDS_ONLY.
tokenizer: The AMPLIFY tokenizer. Defaults to the one returned by `tokenizer.get_tokenizer()`.
dataloader_type: The type of dataloader to use. Defaults to "single".
"""
super().__init__()
self._hf_dataset_name = hf_dataset_name
self._seed = seed
self._min_seq_length = min_seq_length
self._max_seq_length = max_seq_length
self._mask_prob = mask_prob
self._mask_token_prob = mask_token_prob
self._mask_random_prob = mask_random_prob
self._random_mask_strategy = random_mask_strategy
self._tokenizer = tokenizer

self._micro_batch_size = micro_batch_size
self._num_workers = num_workers
self._persistent_workers = persistent_workers
self._pin_memory = pin_memory

self.data_sampler = MegatronDataSampler(
seq_len=max_seq_length,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
dataloader_type=dataloader_type, # `MegatronPretrainingRandomSampler` from "cyclic" is failing.
rampup_batch_size=rampup_batch_size,
)

@property
def tokenizer(self) -> tokenizer.BioNeMoAMPLIFYTokenizer:
"""Returns the tokenizer."""
return self._tokenizer

def setup(self, stage: str = "") -> None:
"""Setup the AMPLIFYDataModule.

Args:
stage: Unused.

Raises:
RuntimeError: If the trainer is not attached, or if the trainer's max_steps is not set.
"""
del stage # Unused.

if not hasattr(self, "trainer") or self.trainer is None:
raise RuntimeError("Setup should be completed when trainer and config are attached.")

if self.trainer.max_epochs is not None and self.trainer.max_epochs > 1:
logging.warning(
"Trainer is set to run for multiple epochs. This is not recommended due to the same shuffle being used "
"in each. Instead set max_epochs to 1 and increase the number of max_steps."
)

max_train_steps = self.trainer.max_steps
if max_train_steps <= 0:
raise RuntimeError("Please specify trainer.max_steps")

# Create training dataset
num_train_samples = int(
max_train_steps * self.data_sampler.global_batch_size
) # training data requires upsampling (multiply by max_train_steps) on single MegatronPretrainingRandomSampler
_train_ds = dataset.AMPLIFYMaskedResidueDataset(hf_dataset_name=self._hf_dataset_name,
dataset_subset=None,
split="train",
seed=self._seed,
max_seq_length=self._max_seq_length,
mask_prob=self._mask_prob,
mask_token_prob=self._mask_token_prob,
mask_random_prob=self._mask_random_prob,
random_mask_strategy=self._random_mask_strategy,
tokenizer=self._tokenizer)
self._train_ds = MultiEpochDatasetResampler(_train_ds, num_samples=num_train_samples, shuffle=True, seed=self._seed)

# Create validation dataset
_valid_ds = dataset.AMPLIFYMaskedResidueDataset(hf_dataset_name=self._hf_dataset_name,
dataset_subset="UniProt",
split="test",
seed=self._seed,
max_seq_length=self._max_seq_length,
mask_prob=self._mask_prob,
mask_token_prob=self._mask_token_prob,
mask_random_prob=self._mask_random_prob,
random_mask_strategy=self._random_mask_strategy,
tokenizer=self._tokenizer)
num_val_samples = infer_num_samples(limit_batches=self.trainer.limit_val_batches,
num_samples_in_dataset=len(_valid_ds),
global_batch_size=self.data_sampler.global_batch_size,
stage="val")
self._valid_ds = MultiEpochDatasetResampler(_valid_ds, num_samples=num_val_samples, shuffle=False, seed=self._seed)

assert (
hasattr(self, "trainer") and self.trainer is not None
), "Setup should be completed when trainer and config are attached."

def _create_dataloader(self, dataset, mode: Mode, **kwargs) -> WrappedDataLoader:
"""Create dataloader for train, validation, and test stages.

Args:
dataset: The dataset to create the dataloader for.
mode: Stage of training, used to determine whether consumed_samples in MegatronPretrainingSampler should be initialized to 0 (validation/test) or set to the previous value from state_dict in the case of checkpoint resumption (train).
**kwargs: Additional arguments to pass to the dataloader.
"""
self.update_init_global_step()
assert self._tokenizer.pad_token_id is not None, "Tokenizer must have a pad token id."

return WrappedDataLoader(
mode=mode,
dataset=dataset,
num_workers=self._num_workers,
pin_memory=self._pin_memory,
persistent_workers=self._persistent_workers,
collate_fn=functools.partial(
collate.bert_padding_collate_fn,
padding_value=self._tokenizer.pad_token_id,
min_length=self._min_seq_length,
max_length=self._max_seq_length,
),
**kwargs,
)

def train_dataloader(self) -> TRAIN_DATALOADERS:
"""Returns the dataloader for training data."""
return self._create_dataloader(self._train_ds, mode="train")

def val_dataloader(self) -> EVAL_DATALOADERS:
"""Returns the dataloader for validation data."""
return self._create_dataloader(self._valid_ds, mode="validation")

def test_dataloader(self) -> EVAL_DATALOADERS:
"""Raises a not implemented error."""
raise NotImplementedError("No test dataset provided for AMPLIFY")
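
For orientation, a minimal construction sketch for the data module above, using only the defaults and parameter names visible in this diff. Trainer wiring is omitted; in practice the module is handed to a NeMo/Lightning trainer, since `setup()` requires an attached trainer with `max_steps` set:

```python
# Sketch only: instantiate AMPLIFYDataModule with the defaults shown above.
# All argument names are taken from the __init__ signature in this diff.
from bionemo.amplify.data.datamodule import AMPLIFYDataModule

data_module = AMPLIFYDataModule(
    hf_dataset_name="chandar-lab/UR100P",  # HuggingFace dataset used for pretraining
    max_seq_length=512,
    micro_batch_size=512,
    global_batch_size=4096,
    mask_prob=0.15,        # fraction of tokens selected for the MLM loss
    mask_token_prob=0.8,   # of those, fraction replaced with the <mask> token
    mask_random_prob=0.1,  # of those, fraction replaced with a random amino acid
)
# data_module.setup() raises RuntimeError unless a trainer is attached and
# trainer.max_steps > 0, as enforced in setup() above.
```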