Commit 72097b9

use torchmetrics ppl logging
pstjohn committed Oct 30, 2024
1 parent 41ef909 commit 72097b9
Showing 3 changed files with 21 additions and 5 deletions.
3 changes: 1 addition & 2 deletions scripts/protein/esm2/esm2_pretrain.py
@@ -31,7 +31,6 @@
 from bionemo.esm2.data.dataset import RandomMaskStrategy
 from bionemo.esm2.data.tokenizer import get_tokenizer
 from bionemo.esm2.model.lr_scheduler import WarmupAnnealDecayHoldScheduler
-from bionemo.llm.lightning import PerplexityLoggingCallback
 from bionemo.llm.model.biobert.lightning import biobert_lightning_module
 from bionemo.llm.model.biobert.model import BiobertSpecOption
 from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size
@@ -163,7 +162,7 @@ def main(
     )
 
     callbacks = [
-        PerplexityLoggingCallback(log_train=False, log_val=True),
+        # PerplexityLoggingCallback(log_train=False, log_val=True),
         RichModelSummary(max_depth=4),
         LearningRateMonitor(),
     ]
2 changes: 1 addition & 1 deletion sub-packages/bionemo-llm/src/bionemo/llm/data/collate.py
@@ -78,7 +78,7 @@ def bert_padding_collate_fn(
         "text": padding_value,
         "types": 0,
         "attention_mask": False,
-        "labels": -1,
+        "labels": -100,
         "loss_mask": False,
         "is_random": 0,
     }
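
The switch from -1 to -100 as the label padding value lines the collate function up with the new metric: -100 is the default ignore_index of PyTorch's cross-entropy loss and matches the ignore_index=-100 passed to the torchmetrics.text.Perplexity instances added in lightning.py below. A minimal sketch of the effect, with toy shapes and values that are illustrative only (and assuming a torchmetrics version that, like this commit, accepts raw logits):

import torch
import torchmetrics.text

# Toy batch: 1 sequence, 4 token positions, vocabulary of 5.
logits = torch.randn(1, 4, 5)
labels = torch.tensor([[2, 4, -100, -100]])  # last two positions are padding

# Positions labeled -100 are skipped by the metric, just as cross-entropy
# skips them via its default ignore_index of -100.
ppl = torchmetrics.text.Perplexity(ignore_index=-100)
print(ppl(logits, labels))  # perplexity over the two real tokens only
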
21 changes: 19 additions & 2 deletions sub-packages/bionemo-llm/src/bionemo/llm/lightning.py
@@ -17,6 +17,7 @@
 
 import pytorch_lightning as pl
 import torch.distributed
+import torchmetrics.text
 from megatron.core import parallel_state
 from megatron.core.optimizer.optimizer_config import OptimizerConfig
 from nemo.lightning import io as nlio
@@ -228,6 +229,8 @@ def __init__(
         self._data_step = data_step
         self._forward_step = forward_step
         self.model_transform = model_transform
+        self.train_ppl = torchmetrics.text.Perplexity(ignore_index=-100)
+        self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=-100)
 
     def configure_model(self) -> None:
         """Updates internal state: instantiates the model from the object's config, assigns to `model` attribute.
@@ -275,11 +278,17 @@ def forward_step(self, batch) -> Tensor:
 
     def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
         """In mcore the loss-function is part of the forward-pass when labels are provided."""
-        return self.forward_step(batch)
+        outputs = self.forward_step(batch)
+        self.train_ppl(outputs["token_logits"], batch["labels"])
+        self.log("train_ppl_step", self.train_ppl)
+        return outputs
 
     def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
         """In mcore the loss-function is part of the forward-pass when labels are provided."""
-        return self.forward_step(batch)
+        outputs = self.forward_step(batch)
+        self.valid_ppl(outputs["token_logits"], batch["labels"])
+        self.log("valid_ppl_step", self.valid_ppl)
+        return outputs
 
     def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
         """Alias for forward_step."""
@@ -295,6 +304,14 @@ def validation_loss_reduction(self) -> MegatronLossType:  # noqa: D102
     def test_loss_reduction(self) -> MegatronLossType:  # noqa: D102
         return self.loss_reduction_class(validation_step=True)
 
+    def on_train_epoch_end(self):
+        """Log perplexity at the end of the epoch."""
+        self.log("train_ppl_step", self.train_ppl)
+
+    def on_valid_epoch_end(self):
+        """Log perplexity at the end of the epoch."""
+        self.log("valid_ppl_step", self.valid_ppl)
+
 
 def default_megatron_optimizer() -> MegatronOptimizerModule:
     """Default distributed optimizer uses Adam with a 1e-4 learning rate."""
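
Taken together, the lightning.py changes replace the PerplexityLoggingCallback with module-level torchmetrics logging: one Perplexity instance per phase is created in __init__, updated in training_step/validation_step with the raw token_logits and labels, and passed to self.log so that Lightning and torchmetrics handle accumulation, cross-device synchronization, and per-epoch reset. A stripped-down sketch of the same pattern follows; the module, layer sizes, and optimizer are hypothetical and not the BioNeMo class. Note also that PyTorch Lightning's validation hook is named on_validation_epoch_end, so the on_valid_epoch_end method added above would not be invoked by the standard Lightning loop.

import pytorch_lightning as pl
import torch
import torchmetrics.text


class ToyMaskedLM(pl.LightningModule):
    """Hypothetical module illustrating the torchmetrics perplexity-logging pattern."""

    def __init__(self, vocab_size: int = 128, hidden: int = 32):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.head = torch.nn.Linear(hidden, vocab_size)
        # One metric per phase, matching the -100 label padding from collate.py.
        self.train_ppl = torchmetrics.text.Perplexity(ignore_index=-100)
        self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=-100)

    def forward(self, tokens):
        return self.head(self.embed(tokens))  # [batch, seq, vocab] logits

    def training_step(self, batch, batch_idx):
        logits = self(batch["text"])
        # cross_entropy ignores targets equal to -100 by default.
        loss = torch.nn.functional.cross_entropy(logits.transpose(1, 2), batch["labels"])
        # Update the metric and log the metric object itself; Lightning
        # accumulates and resets it across the epoch.
        self.train_ppl(logits, batch["labels"])
        self.log("train_ppl_step", self.train_ppl)
        return loss

    def validation_step(self, batch, batch_idx):
        logits = self(batch["text"])
        self.valid_ppl.update(logits, batch["labels"])

    def on_validation_epoch_end(self):
        # Lightning's actual hook name; logs the epoch-level aggregate.
        self.log("valid_ppl", self.valid_ppl)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

Logging the Metric object rather than a precomputed float is what lets torchmetrics keep per-epoch state and synchronize across ranks; the commented-out PerplexityLoggingCallback in esm2_pretrain.py becomes redundant once the module logs perplexity itself.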
