
Commit 11169e0

use torchmetrics ppl logging
transposing labels
pstjohn committed Nov 20, 2024
1 parent d35fe41 commit 11169e0
Showing 2 changed files with 22 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sub-packages/bionemo-llm/src/bionemo/llm/data/collate.py
@@ -101,7 +101,7 @@ def bert_padding_collate_fn(
         "text": padding_value,
         "types": 0,
         "attention_mask": False,
-        "labels": -1,
+        "labels": -100,
         "loss_mask": False,
         "is_random": 0,
     }
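The new padding value lines up with the ignore index used elsewhere: PyTorch's cross-entropy loss skips targets equal to -100 by default, and the Perplexity metrics added below are constructed with `ignore_index=-100`, so padded label positions drop out of both the loss and the perplexity. A minimal sketch with toy shapes (not the BioNeMo collate function itself):

```python
import torch
import torchmetrics.text

# Toy example: label positions padded with -100 are excluded both by
# F.cross_entropy (whose default ignore_index is -100) and by the
# torchmetrics Perplexity metric configured with the same ignore_index.
logits = torch.randn(2, 4, 10)               # [batch, seq, vocab], raw logits
labels = torch.tensor([[1, 2, 3, -100],
                       [4, 5, -100, -100]])  # -100 marks padded positions

loss = torch.nn.functional.cross_entropy(logits.reshape(-1, 10), labels.reshape(-1))
ppl = torchmetrics.text.Perplexity(ignore_index=-100)
print(loss, ppl(logits, labels))
```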
23 changes: 21 additions & 2 deletions sub-packages/bionemo-llm/src/bionemo/llm/lightning.py
@@ -17,6 +17,7 @@

 import pytorch_lightning as pl
 import torch.distributed
+import torchmetrics.text
 from megatron.core import parallel_state
 from megatron.core.optimizer.optimizer_config import OptimizerConfig
 from nemo.lightning import io as nlio
@@ -258,6 +259,8 @@ def __init__(
         self._data_step = data_step
         self._forward_step = forward_step
         self.model_transform = model_transform
+        self.train_ppl = torchmetrics.text.Perplexity(ignore_index=-100)
+        self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=-100)

     def configure_model(self) -> None:
         """Updates internal state: instantiates the model from the object's config, assigns to `model` attribute.
@@ -305,11 +308,19 @@ def forward_step(self, batch) -> Tensor:

     def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
         """In mcore the loss-function is part of the forward-pass when labels are provided."""
-        return self.forward_step(batch)
+        outputs = self.forward_step(batch)
+        logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]
+        self.train_ppl(logits, batch["labels"])
+        self.log("train_ppl_step", self.train_ppl)
+        return outputs

     def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
         """In mcore the loss-function is part of the forward-pass when labels are provided."""
-        return self.forward_step(batch)
+        outputs = self.forward_step(batch)
+        logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]
+        self.valid_ppl(logits, batch["labels"])
+        self.log("valid_ppl_step", self.valid_ppl)
+        return outputs

     def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
         """Alias for forward_step."""
@@ -325,6 +336,14 @@ def validation_loss_reduction(self) -> MegatronLossType:  # noqa: D102
     def test_loss_reduction(self) -> MegatronLossType:  # noqa: D102
         return self.loss_reduction_class(validation_step=True)

+    def on_train_epoch_end(self):
+        """Log perplexity at the end of the epoch."""
+        self.log("train_ppl_step", self.train_ppl)
+
+    def on_valid_epoch_end(self):
+        """Log perplexity at the end of the epoch."""
+        self.log("valid_ppl_step", self.valid_ppl)
+

 def default_megatron_optimizer() -> MegatronOptimizerModule:
     """Default distributed optimizer uses Adam with a 1e-4 learning rate."""
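Because the perplexity metrics are handed to self.log as torchmetrics objects rather than pre-computed values, Lightning handles accumulation, cross-device reduction, and resetting; the epoch-end hooks above log the same objects so an aggregated value is emitted per epoch. One caveat: the LightningModule hook for validation is named `on_validation_epoch_end`, so the `on_valid_epoch_end` spelling in this diff would not be invoked by the Trainer. A stripped-down sketch of the pattern (illustrative names, not the BioNeMo module):

```python
import pytorch_lightning as pl
import torchmetrics.text


class PerplexityLoggingModule(pl.LightningModule):
    """Illustrative-only module showing the metric-object logging pattern."""

    def __init__(self):
        super().__init__()
        self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=-100)

    def validation_step(self, batch, batch_idx=None):
        logits, labels = batch               # assumed batch-first shapes
        self.valid_ppl(logits, labels)       # update the metric state
        self.log("valid_ppl_step", self.valid_ppl)

    def on_validation_epoch_end(self):
        # Logging the metric object lets Lightning call compute() and reset().
        self.log("valid_ppl_epoch", self.valid_ppl)
```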
