use torchmetrics ppl logging #382

Draft · pstjohn wants to merge 1 commit into main from pstjohn/main/torchmetrics-ppl-logging
Conversation

@pstjohn (Collaborator) commented Oct 30, 2024

Some experiments using torchmetrics for perplexity logging.

Example wandb chart for single-GPU training: https://wandb.ai/clara-discovery/bionemo-hf-pretraining/runs/fgn32r4p?nw=nwuserpstjohn
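
For reference, a minimal single-GPU sketch of the pattern this PR experiments with; the module and tensor names below are illustrative, not the actual BioNeMo classes:

import torch
import torchmetrics
from lightning.pytorch import LightningModule  # assumes Lightning >= 2.0 import path


class ToyMaskedLMModule(LightningModule):
    """Illustrative module, not the BioNeMo implementation."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model
        # Registering the metric as an attribute lets Lightning move and reset it automatically.
        self.train_ppl = torchmetrics.text.Perplexity(ignore_index=-100)

    def training_step(self, batch, batch_idx):
        logits = self.model(batch["text"])  # assumed shape [batch, seq, vocab]
        loss = torch.nn.functional.cross_entropy(
            logits.transpose(1, 2), batch["labels"], ignore_index=-100
        )
        # Update the metric and log the metric object itself; Lightning handles compute()/reset().
        self.train_ppl(logits, batch["labels"])
        self.log("train_ppl_step", self.train_ppl)
        self.log("train_loss", loss)
        return loss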

@@ -78,7 +78,7 @@ def bert_padding_collate_fn(
     "text": padding_value,
     "types": 0,
     "attention_mask": False,
-    "labels": -1,
+    "labels": -100,
Collaborator Author:

Whoops 😄. We should fix this; currently labels are padded with both -1 and -100.
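
The fix is probably just to make the label padding value agree with the ignore_index used by the loss and the metric; a rough sketch (the helper below is illustrative, not the real bert_padding_collate_fn):

import torch

LABEL_PAD = -100  # must match ignore_index in the loss and in torchmetrics Perplexity


def pad_label_rows(label_rows: list[torch.Tensor], max_len: int) -> torch.Tensor:
    # Pad every label row with the same sentinel so downstream ignore_index masking works.
    out = torch.full((len(label_rows), max_len), LABEL_PAD, dtype=torch.long)
    for i, row in enumerate(label_rows):
        out[i, : row.shape[0]] = row
    return out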

Comment on lines 286 to 323
 def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
     """In mcore the loss-function is part of the forward-pass when labels are provided."""
-    return self.forward_step(batch)
+    outputs = self.forward_step(batch)
+    self.valid_ppl(outputs["token_logits"], batch["labels"])
+    self.log("valid_ppl_step", self.valid_ppl)
+    return outputs
Collaborator Author:

Where does validation_step get called, and how does this interact with Megatron parallelism? Just wondering whether torchmetrics would correctly detect the parallelism settings and collect data from all ranks with this configuration.

Collaborator:

Maybe this would be a good place for a comment about token_logits being [s, b, *] while other things are [b, s, *].
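
Something along these lines, perhaps (a sketch of the suggested comment, not the final wording):

# token_logits comes out of the Megatron forward pass sequence-first, [s, b, vocab],
# while batch["labels"] and most other batch tensors are batch-first, [b, s].
# torchmetrics.text.Perplexity expects batch-first inputs, hence the transpose.
logits = outputs["token_logits"].transpose(0, 1)  # [s, b, vocab] -> [b, s, vocab]
self.valid_ppl(logits, batch["labels"])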

@pstjohn (Collaborator Author) commented Nov 12, 2024

Running with tp=2, I get the following error:

  File "/usr/local/lib/python3.10/dist-packages/bionemo/llm/lightning.py", line 319, in validation_step
    self.valid_ppl(outputs["token_logits"], batch["labels"])
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchmetrics/metric.py", line 312, in forward
    self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchmetrics/metric.py", line 381, in _forward_reduce_state_update
    self.update(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchmetrics/metric.py", line 483, in wrapped_func
    update(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchmetrics/text/perplexity.py", line 82, in update
    total_log_probs, count = _perplexity_update(preds, target, self.ignore_index)
  File "/usr/local/lib/python3.10/dist-packages/torchmetrics/functional/text/perplexity.py", line 83, in _perplexity_update
    _check_shape_and_type_consistency(preds, target)
  File "/usr/local/lib/python3.10/dist-packages/torchmetrics/functional/text/perplexity.py", line 55, in _check_shape_and_type_consistency
    raise ValueError(
ValueError: Input tensors `preds` and `target` are expected to have equaling first two dimensions, [batch_size, seq_len], but got torch.Size([1024, 16]) and torch.Size([16, 1024]).

@pstjohn force-pushed the pstjohn/main/torchmetrics-ppl-logging branch from 72097b9 to 238a30f on November 12, 2024 at 20:27
transposing labels
@pstjohn force-pushed the pstjohn/main/torchmetrics-ppl-logging branch from 238a30f to 11169e0 on November 20, 2024 at 20:02
Comment on lines +312 to +314
logits = outputs["token_logits"].transpose(0, 1) # [s, b] -> [b, s]
self.train_ppl(logits, batch["labels"])
self.log("train_ppl_step", self.train_ppl)
Collaborator:

I believe NeMo returns 0 by default when it isn't at the last PP stage, but don't rely on my vague memory here. I don't know how that translates to other outputs such as token_logits.
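
If that's the case, one option would be to guard the metric update so only ranks holding the last pipeline stage feed it (a sketch, assuming Megatron-Core's parallel_state helpers are usable here):

from megatron.core import parallel_state

# Only the last PP stage has real logits; other stages would otherwise push
# placeholder values (e.g. 0) into the metric state.
if parallel_state.is_pipeline_last_stage():
    logits = outputs["token_logits"].transpose(0, 1)  # [s, b, vocab] -> [b, s, vocab]
    self.train_ppl(logits, batch["labels"])
    self.log("train_ppl_step", self.train_ppl)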

Collaborator:

If you need a device group for each parallel group, check out parallel_states.get_pipeline_model_parallel_group. You might have multiple PP groups if PP and TP are used at the same time: each TP group sits under a PP group, or the other way around.
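
For reference, a sketch of how those groups can be fetched from Megatron-Core; which one (if any) should back the metric's synchronization is exactly the open question here:

from megatron.core import parallel_state

# Each rank belongs to one group per parallel dimension.
pp_group = parallel_state.get_pipeline_model_parallel_group()
tp_group = parallel_state.get_tensor_model_parallel_group()
dp_group = parallel_state.get_data_parallel_group()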

Collaborator:

I am not sure how VPP plays a role here.

Comment on lines +262 to +263
self.train_ppl = torchmetrics.text.Perplexity(ignore_index=-100)
self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=-100)
Collaborator:

It would be great for this to be configurable once the PR has matured.

Collaborator Author:

For sure. I think we should make an ESM2LightningModule with a lot of these args pre-defined, where we can then attach metrics in an obvious way.
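
A rough sketch of what that might look like; the class and the base-class import are hypothetical, following the idea in this thread rather than existing code:

import torchmetrics
from bionemo.llm.lightning import BionemoLightningModule  # assumed base-class location


class ESM2LightningModule(BionemoLightningModule):  # hypothetical subclass
    def __init__(self, *args, label_ignore_index: int = -100, **kwargs):
        super().__init__(*args, **kwargs)
        # Attach metrics in one obvious place, with the padding sentinel configurable.
        self.train_ppl = torchmetrics.text.Perplexity(ignore_index=label_ignore_index)
        self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=label_ignore_index)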

Comment on lines +311 to +314
outputs = self.forward_step(batch)
logits = outputs["token_logits"].transpose(0, 1) # [s, b] -> [b, s]
self.train_ppl(logits, batch["labels"])
self.log("train_ppl_step", self.train_ppl)
Collaborator:

And similarly, we need a different mechanism for CP, which isn't implemented even in the current PerplexityLoggingCallback.
https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py#L429

Comment on lines +311 to +314
outputs = self.forward_step(batch)
logits = outputs["token_logits"].transpose(0, 1) # [s, b] -> [b, s]
self.train_ppl(logits, batch["labels"])
self.log("train_ppl_step", self.train_ppl)
Collaborator:

For TP, things are easier since there is already a TP-aware method to get the loss.
https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py#L262

Comment on lines +311 to +314
outputs = self.forward_step(batch)
logits = outputs["token_logits"].transpose(0, 1) # [s, b] -> [b, s]
self.train_ppl(logits, batch["labels"])
self.log("train_ppl_step", self.train_ppl)
Collaborator:

We don't need to sync_dist here.

Comment on lines +339 to +345
def on_train_epoch_end(self):
    """Log perplexity at the end of the epoch."""
    self.log("train_ppl_step", self.train_ppl)

def on_valid_epoch_end(self):
    """Log perplexity at the end of the epoch."""
    self.log("valid_ppl_step", self.valid_ppl)
Collaborator:

We only need to log once, in the step methods.

outputs = self.forward_step(batch)
logits = outputs["token_logits"].transpose(0, 1) # [s, b] -> [b, s]
self.valid_ppl(logits, batch["labels"])
self.log("valid_ppl_step", self.valid_ppl)
Collaborator:

But we will need sync_dist=True and on_epoch=True here.

@pstjohn (Collaborator Author) commented Nov 21, 2024:

The sync_dist, sync_dist_group and reduce_fx flags from self.log(...) don’t affect the metric logging in any manner. The metric class contains its own distributed synchronization logic.

https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html#logging-torchmetrics

So I think we do want on_epoch=True here (I'm not sure why my example run was so close without it...), but I think sync_dist is handled internally?
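
If that reading is right, the validation step would reduce to something like this sketch, leaving the cross-rank reduction to the metric itself:

logits = outputs["token_logits"].transpose(0, 1)  # [s, b, vocab] -> [b, s, vocab]
self.valid_ppl(logits, batch["labels"])
# sync_dist/sync_dist_group on self.log are ignored for Metric objects; on_epoch=True
# makes Lightning call compute() (which syncs internally) at epoch end.
self.log("valid_ppl", self.valid_ppl, on_epoch=True)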

Collaborator:

You need sync_dist_group because you don't want to sync with 0 (or some other spurious value) from non-last-PP-stage devices.

Collaborator:

Maybe they are the same on a single device because you had self.log("valid_ppl_step", self.valid_ppl) in on_valid_epoch_end?

Comment on lines +262 to +263
self.train_ppl = torchmetrics.text.Perplexity(ignore_index=-100)
self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=-100)
Collaborator Author:

Any changes to the defaults we should make there? I used torchmetrics with the HF pretraining and it seemed to work well, so I don't think the oversampling is too big of an issue.

Collaborator:

sync_on_compute is on by default, so we are good on that end. I wonder if we need to specify process_group to skip non-last-PP-stage devices.
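
That is, something like this sketch; the choice of group is a guess and would need checking against the PP/TP layout discussed above:

from megatron.core import parallel_state
import torchmetrics

valid_ppl = torchmetrics.text.Perplexity(
    ignore_index=-100,
    sync_on_compute=True,  # the default, shown here only for clarity
    # Placeholder group choice: restrict the reduction so non-last-PP-stage and
    # duplicate TP ranks don't contribute bogus state.
    process_group=parallel_state.get_data_parallel_group(),
)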

Collaborator:

I also don't have a clear idea of how to treat TP devices.

@sichu2023 (Collaborator) commented:

Related effort: #525
