Support NVFlare sequence-level classification fine-tuning (#664)
### Description
This PR adds support for sequence-level classification fine-tuning with
ESM2 and refactors the existing APIs for token-level classification and
sequence-level regression.

### Type of changes
<!-- Mark the relevant option with an [x] -->

- [ ]  Bug fix (non-breaking change which fixes an issue)
- [x]  New feature (non-breaking change which adds functionality)
- [x]  Refactor
- [x]  Documentation update
- [ ]  Other (please describe):

### CI Pipeline Configuration
Configure CI behavior by applying the relevant labels:

- [SKIP_CI](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#skip_ci) - Skip all continuous integration tests
- [INCLUDE_NOTEBOOKS_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_notebooks_tests) - Execute notebook validation tests in pytest
- [INCLUDE_SLOW_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_slow_tests) - Execute tests labelled as slow in pytest for extensive testing


> [!NOTE]
> By default, the notebook validation tests are skipped unless explicitly enabled.

### Usage
<!--- How does a user interact with the changed code -->
```python
TODO: Add code snippet
```
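
As a rough illustration of the new `task_type` option, here is a minimal sketch based on the dataset changes below. The import path, CSV column names, and example data are assumptions and are not taken from this PR:

```python
# Hypothetical usage sketch; module path and example data are assumptions.
import pandas as pd

from bionemo.esm2.model.finetune.dataset import InMemorySingleValueDataset  # assumed path

# Sequence-level classification: one class label per protein sequence.
sequences = pd.Series(["MKTVRQERLKSIVRILERSKEPVSGAQL", "MSILVTRPSPAGEELVSRLRTLGQVA"])
labels = pd.Series(["membrane", "soluble"])  # illustrative class names

dataset = InMemorySingleValueDataset(
    sequences=sequences,
    labels=labels,
    task_type="classification",  # new in this PR; "regression" keeps the previous behavior
)

# Or from a CSV file with "sequences" and "labels" columns (column names assumed):
# dataset = InMemorySingleValueDataset.from_csv("data.csv", task_type="classification")
```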

### Pre-submit Checklist
<!--- Ensure all items are completed before submitting -->

 - [x] I have tested these changes locally
 - [x] I have updated the documentation accordingly
 - [x] I have added/updated tests as needed
 - [x] All existing tests pass successfully

---------

Signed-off-by: Farhad Ramezanghorbani <[email protected]>
farhadrgh authored Feb 3, 2025
1 parent b338825 commit a55bea5
Showing 12 changed files with 540 additions and 206 deletions.
196 changes: 151 additions & 45 deletions docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb

Large diffs are not rendered by default.

@@ -15,7 +15,7 @@


import os
from typing import Sequence
from typing import Literal, Sequence

import numpy as np
import pandas as pd
@@ -44,6 +44,7 @@ def __init__(
self,
sequences: pd.Series,
labels: pd.Series | None = None,
task_type: Literal["classification", "regression", None] = None,
tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
seed: int = np.random.SeedSequence().entropy, # type: ignore
):
@@ -55,13 +56,15 @@
Args:
sequences (pd.Series): A pandas Series containing protein sequences.
labels (pd.Series, optional): A pandas Series containing labels. Defaults to None.
task_type (str, optional): Fine-tuning task type. Defaults to None.
tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
generated.
"""
self.sequences = sequences
self.labels = labels
self.task_type = task_type

self.seed = seed
self._len = len(self.sequences)
@@ -71,13 +74,15 @@ def __init__(
def from_csv(
cls,
csv_path: str | os.PathLike,
task_type: Literal["classification", "regression", None] = None,
tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
ignore_labels: bool = False,
):
"""Class method to create a ProteinDataset instance from a CSV file.
Args:
csv_path: path to CSV file containing sequences and optionally labels column.
task_type (str, optional): Fine-tuning task type. Defaults to None.
tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
ignore_labels (bool): ignore the labels column if it exists (to avoid reading labels during inference)
"""
@@ -92,7 +97,7 @@ def from_csv(
if not ignore_labels:
labels = df["labels"]

return cls(sequences, labels, tokenizer)
return cls(sequences, labels=labels, task_type=task_type, tokenizer=tokenizer)

def __len__(self) -> int:
"""The size of the dataset."""
@@ -148,35 +153,48 @@ class InMemorySingleValueDataset(InMemoryProteinDataset):
def __init__(
self,
sequences: pd.Series,
labels: pd.Series | None = None,
labels: pd.Series,
task_type: Literal["classification", "regression"] = "regression",
tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
seed: int = np.random.SeedSequence().entropy, # type: ignore
):
"""Initializes a dataset for single-value regression fine-tuning.
"""Initializes a dataset for single-value fine-tuning.
This is an in-memory dataset that does not apply masking to the sequence, but keeps track of <mask> tokens in the dataset sequences provided.
Args:
sequences (pd.Series): A pandas Series containing protein sequences.
labels (pd.Series, optional): A pandas Series containing labels. Defaults to None.
task_type (str): Fine-tuning task type. Defaults to regression.
tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
generated.
"""
super().__init__(sequences, labels, tokenizer, seed)
super().__init__(sequences, labels, task_type, tokenizer, seed)

self.task_type = task_type
if self.task_type == "classification":
label_tokenizer = Label2IDTokenizer()
self.label_tokenizer = label_tokenizer.build_vocab(self.labels.values.reshape(-1, 1))

def transform_label(self, label: float) -> Tensor:
def transform_label(self, label: float | str) -> Tensor:
"""Transform the regression label.
Args:
label: regression value
label: single regression/classification value
Returns:
tokenized label
"""
return torch.tensor([label], dtype=torch.float)
if self.task_type == "regression":
return torch.tensor([label], dtype=torch.float)
elif self.task_type == "classification":
tokenized_label = torch.tensor(self.label_tokenizer.text_to_ids([label]))
return tokenized_label
else:
raise ValueError(f"{self.task_type} task type is not supported with {self.__class__.__name__}")


class InMemoryPerTokenValueDataset(InMemoryProteinDataset):
@@ -185,7 +203,8 @@ class InMemoryPerTokenValueDataset(InMemoryProteinDataset):
def __init__(
self,
sequences: pd.Series,
labels: pd.Series | None = None,
labels: pd.Series,
task_type: Literal["classification", "regression"] = "classification",
tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
seed: int = np.random.SeedSequence().entropy, # type: ignore
):
@@ -197,35 +216,36 @@ def __init__(
Args:
sequences (pd.Series): A pandas Series containing protein sequences.
labels (pd.Series, optional): A pandas Series containing labels. Defaults to None.
task_type (str): Fine-tuning task type. Defaults to classification. Regression per-token values are not supported.
tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
generated.
"""
super().__init__(sequences, labels, tokenizer, seed)
super().__init__(sequences, labels, task_type, tokenizer, seed)

self.task_type = task_type
if not task_type == "classification":
raise ValueError(f"{task_type} task type is not supported with {self.__class__.__name__}")

label_tokenizer = Label2IDTokenizer()
self.label_tokenizer = label_tokenizer.build_vocab("CHE")
self.label_tokenizer = label_tokenizer.build_vocab(self.labels.values)
self.label_cls_eos_id = MLM_LOSS_IGNORE_INDEX

def transform_label(self, label: str) -> Tensor:
"""Transform the sequence label by tokenizing them.
This method tokenizes the secondary structure token sequences.
This method tokenizes a sequence of labels into a tensor of tokens and adds CLS/EOS tokens.
Args:
label: secondary structure token sequences to be transformed
label: label sequence to be transformed
Returns:
tokenized label
"""
label_ids = torch.tensor(self.label_tokenizer.text_to_ids(label))

# # for multi-label classification with BCEWithLogitsLoss
# tokenized_labels = torch.nn.functional.one_hot(label_ids, num_classes=self.label_tokenizer.vocab_size)
# cls_eos = torch.full((1, self.label_tokenizer.vocab_size), self.label_cls_eos_id, dtype=tokenized_labels.dtype)
tokenized_labels = torch.tensor(self.label_tokenizer.text_to_ids(label))

# for multi-class (mutually exclusive) classification with CrossEntropyLoss
tokenized_labels = label_ids
cls_eos = torch.tensor([self.label_cls_eos_id], dtype=tokenized_labels.dtype)

# add cls / eos label ids with padding value -100 to have the same shape as tokenized_sequence
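
To make the CLS/EOS padding described in the comment above concrete, here is a minimal standalone sketch. The label ids are made up; the -100 ignore index is the MLM_LOSS_IGNORE_INDEX value referenced in the diff:

```python
import torch

MLM_LOSS_IGNORE_INDEX = -100  # positions with this label id are ignored by the loss

# Toy per-token label ids for a 5-residue sequence (three classes: 0, 1, 2)
label_ids = torch.tensor([0, 2, 2, 1, 0])

# Pad with the ignore index at the CLS and EOS positions so the label tensor
# has the same length as the tokenized sequence ([CLS] + residues + [EOS]).
cls_eos = torch.tensor([MLM_LOSS_IGNORE_INDEX], dtype=label_ids.dtype)
labels = torch.cat([cls_eos, label_ids, cls_eos])
print(labels)  # tensor([-100, 0, 2, 2, 1, 0, -100])
```
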
132 changes: 132 additions & 0 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/loss.py
@@ -0,0 +1,132 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict, Sequence, Tuple

import torch
from megatron.core import parallel_state
from torch import Tensor

from bionemo.llm.model.loss import BERTMLMLossWithReduction, PerTokenLossDict, SameSizeLossDict


__all__: Sequence[str] = (
"RegressorLossReduction",
"ClassifierLossReduction",
)


class RegressorLossReduction(BERTMLMLossWithReduction):
"""A class for calculating the MSE loss of regression output.
This class is used for calculating the loss, and for logging the reduced loss across micro batches.
"""

def forward(
self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]:
"""Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
Args:
batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
forward_out: the output of the forward method inside classification head.
Returns:
A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
backpropagation and the ReductionT will be passed to the reduce method
(which currently only works for logging).
"""
regression_output = forward_out["regression_output"]
targets = batch["labels"].to(dtype=regression_output.dtype) # [b, 1]

cp_size = parallel_state.get_context_parallel_world_size()
if cp_size == 1:
loss = torch.nn.functional.mse_loss(regression_output, targets)
else:
raise NotImplementedError("Context Parallel support is not implemented for this loss")

return loss, {"avg": loss}

def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
"""Works across micro-batches. (data on single gpu).
Note: This currently only works for logging and this loss will not be used for backpropagation.
Args:
losses_reduced_per_micro_batch: a list of the outputs of forward
Returns:
A tensor that is the mean of the losses. (used for logging).
"""
losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
return losses.mean()


class ClassifierLossReduction(BERTMLMLossWithReduction):
"""A class for calculating the cross entropy loss of classification output.
This class is used for calculating the loss, and for logging the reduced loss across micro batches.
"""

def forward(
self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]:
"""Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
Args:
batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
forward_out: the output of the forward method inside classification head.
Returns:
A tuple where the loss tensor will be used for backpropagation and the dict will be passed to
the reduce method, which currently only works for logging.
"""
targets = batch["labels"].squeeze() # [b] or [b, s] for sequence-level or token-level classification

classification_output = forward_out["classification_output"] # [b, num_class] or [b, s, num_class]
# [b, s, num_class] -> [b, num_class, s] to satisfy token-level input dims for cross_entropy loss
if classification_output.dim() == 3:
classification_output = classification_output.permute(0, 2, 1)

loss_mask = batch["loss_mask"] # [b, s]

cp_size = parallel_state.get_context_parallel_world_size()
if cp_size == 1:
losses = torch.nn.functional.cross_entropy(classification_output, targets, reduction="none")
# token-level losses may contain NaNs at masked locations. We use masked_select to filter out these NaNs
if classification_output.dim() == 3:
masked_loss = torch.masked_select(losses, loss_mask)
loss = masked_loss.sum() / loss_mask.sum()
else:
loss = losses.mean() # sequence-level single value classification
else:
raise NotImplementedError("Context Parallel support is not implemented for this loss")

return loss, {"avg": loss}

def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
"""Works across micro-batches. (data on single gpu).
Note: This currently only works for logging and this loss will not be used for backpropagation.
Args:
losses_reduced_per_micro_batch: a list of the outputs of forward
Returns:
A tensor that is the mean of the losses. (used for logging).
"""
losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
return losses.mean()
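
A minimal standalone sketch mirroring the token-level branch of `ClassifierLossReduction.forward` above; the shapes and tensors are illustrative, not taken from the PR:

```python
import torch

b, s, num_class = 2, 8, 3
classification_output = torch.randn(b, s, num_class)      # token-level classification head output
targets = torch.randint(0, num_class, (b, s))              # per-token class ids
loss_mask = torch.ones(b, s, dtype=torch.bool)             # all positions valid in this toy example

# cross_entropy expects the class dimension second, hence the permute in the code above.
losses = torch.nn.functional.cross_entropy(
    classification_output.permute(0, 2, 1), targets, reduction="none"
)  # [b, s]

# Average only over unmasked token positions.
loss = torch.masked_select(losses, loss_mask).sum() / loss_mask.sum()
print(loss)
```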