use torchmetrics ppl logging #382
base: main
Conversation
@@ -78,7 +78,7 @@ def bert_padding_collate_fn(
     "text": padding_value,
     "types": 0,
     "attention_mask": False,
-    "labels": -1,
+    "labels": -100,
Whoops 😄. We should fix this; currently labels are padded with both -1 and -100.
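A minimal sketch of the invariant in question, assuming the label pad value lives in one place and is reused by the metric (constant and placeholder values are illustrative, not the actual bionemo-llm API):

```python
import torchmetrics.text

# Hypothetical single source of truth for the label pad value.
LABEL_PAD_VALUE = -100

padding_values = {
    "text": 0,                  # token pad id, placeholder value
    "types": 0,
    "attention_mask": False,
    "labels": LABEL_PAD_VALUE,  # was -1 in one place and -100 in another
}

# The metric only ignores positions whose label equals its ignore_index,
# so the two values have to match.
ppl = torchmetrics.text.Perplexity(ignore_index=LABEL_PAD_VALUE)
```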
 def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
     """In mcore the loss-function is part of the forward-pass when labels are provided."""
-    return self.forward_step(batch)
+    outputs = self.forward_step(batch)
+    self.valid_ppl(outputs["token_logits"], batch["labels"])
+    self.log("valid_ppl_step", self.valid_ppl)
+    return outputs
Where does validation_step get called, and how does this interact with megatron parallelism? Just wondering whether torchmetrics would correctly detect the parallelism settings and collect data from all ranks with this configuration.
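For context, torchmetrics does not detect megatron's parallelism settings at all; a short sketch of the relevant default:

```python
import torchmetrics.text

# torchmetrics does not inspect megatron's parallel_state: unless you pass a
# process_group when constructing the metric, its state is synced over the
# default torch.distributed world group, i.e. across every TP/PP/DP rank.
ppl = torchmetrics.text.Perplexity(ignore_index=-100)  # process_group defaults to None (= world)
```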
Maybe this would be a good place for a comment about token_logits being [s, b, *] while other things are [b, s, *].
Running with: force-pushed from 72097b9 to 238a30f
transposing labels: force-pushed from 238a30f to 11169e0
logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]
self.train_ppl(logits, batch["labels"])
self.log("train_ppl_step", self.train_ppl)
For PP, we need to check whether this is at the last PP stage.
https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py#L410
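A sketch of what that guard could look like, assuming megatron-core's parallel_state is initialized and that only the last stage actually produces token_logits (not the exact code in this PR):

```python
from megatron.core import parallel_state

# Inside the LightningModule; a sketch only.
def training_step(self, batch, batch_idx=None):
    outputs = self.forward_step(batch)
    # Only the last pipeline stage has real logits; earlier stages should not
    # update the metric at all.
    if parallel_state.is_pipeline_last_stage():
        logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]
        self.train_ppl(logits, batch["labels"])
        self.log("train_ppl_step", self.train_ppl)
    return outputs
```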
I believe NeMo returns 0 by default when it isn't at the last PP stage, but don't rely on my vague memory here. I don't know how it translates to other outputs such as token_logits.
If you need a device group for each parallel group, check out parallel_states.get_pipeline_model_parallel_group. You might have multiple PP groups if PP and TP happen at the same time: each TP under a PP group, or the other way around.
I am not sure how VPP plays a role here.
self.train_ppl = torchmetrics.text.Perplexity(ignore_index=-100)
self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=-100)
Would be great for this to be configurable once the PR has matured.
For sure, I think we should make an ESM2LightningModule with a lot of these args pre-defined, where we can then attach metrics in an obvious way.
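A rough sketch of such a module (the class name and constructor are hypothetical, and in practice it would likely subclass the existing BioNeMo Lightning module rather than pl.LightningModule directly):

```python
import lightning.pytorch as pl
import torchmetrics.text

class ESM2LightningModule(pl.LightningModule):
    """Hypothetical ESM2-specific module with metrics attached in one obvious place."""

    def __init__(self, label_ignore_index: int = -100) -> None:
        super().__init__()
        # Registered as child modules so Lightning moves them to the right
        # device and torchmetrics handles its own cross-rank reduction.
        self.train_ppl = torchmetrics.text.Perplexity(ignore_index=label_ignore_index)
        self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=label_ignore_index)
```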
outputs = self.forward_step(batch)
logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]
self.train_ppl(logits, batch["labels"])
self.log("train_ppl_step", self.train_ppl)
And similarly, we need a different mechanism for CP, which isn't implemented even in the current PerplexityLoggingCallback.
https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py#L429
outputs = self.forward_step(batch)
logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]
self.train_ppl(logits, batch["labels"])
self.log("train_ppl_step", self.train_ppl)
For TP, things are easier since there is already a TP aware method to get the loss.
https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py#L262
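If we go the gather route instead, a sketch of what that could look like, assuming token_logits are vocab-parallel (sharded along the last dim) when TP > 1, which should be verified against the actual model config:

```python
from megatron.core import parallel_state
from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region

# Inside the step, after outputs = self.forward_step(batch); a sketch only.
token_logits = outputs["token_logits"]  # assumed [s, b, vocab/tp] when TP > 1
if parallel_state.get_tensor_model_parallel_world_size() > 1:
    # All-gather the vocab shards so each rank sees the full vocabulary
    # before handing the logits to a rank-local metric.
    token_logits = gather_from_tensor_model_parallel_region(token_logits)
logits = token_logits.transpose(0, 1)  # [s, b] -> [b, s]
```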
outputs = self.forward_step(batch)
logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]
self.train_ppl(logits, batch["labels"])
self.log("train_ppl_step", self.train_ppl)
We don't need to sync_dist here.
def on_train_epoch_end(self):
    """Log perplexity at the end of the epoch."""
    self.log("train_ppl_step", self.train_ppl)

def on_valid_epoch_end(self):
    """Log perplexity at the end of the epoch."""
    self.log("valid_ppl_step", self.valid_ppl)
We only need to log once, in the step.
outputs = self.forward_step(batch)
logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]
self.valid_ppl(logits, batch["labels"])
self.log("valid_ppl_step", self.valid_ppl)
But we will need sync_dist=True and on_epoch=True here.
Why is that? It seems to give similar perplexity values to the callback:
https://wandb.ai/clara-discovery/bionemo-hf-pretraining/reports/val_ppl-valid_ppl_step-24-11-20-18-12-31---VmlldzoxMDI2NjQ0MA
The sync_dist, sync_dist_group and reduce_fx flags from self.log(...) don’t affect the metric logging in any manner. The metric class contains its own distributed synchronization logic.
https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html#logging-torchmetrics
So I think we do want on_epoch=True here (I'm not sure why my example run was so close without it...), but I think sync_dist is handled internally?
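For reference, a minimal sketch of the step with on_epoch=True, leaning on torchmetrics for the reduction (an assumption about how the final code would look, not a verbatim diff from this PR):

```python
def validation_step(self, batch, batch_idx=None):
    outputs = self.forward_step(batch)
    logits = outputs["token_logits"].transpose(0, 1)  # [s, b] -> [b, s]
    self.valid_ppl(logits, batch["labels"])  # updates metric state only
    # Logging the Metric object itself: Lightning calls compute() at epoch end
    # and torchmetrics does its own all-gather; sync_dist here would be a no-op.
    self.log("valid_ppl", self.valid_ppl, on_epoch=True)
    return outputs
```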
You need sync_dist_group because you don't want to sync with 0 or other weird values from non-last-pp-stage devices.
Maybe they are the same on a single device because you had self.log("valid_ppl_step", self.valid_ppl) in on_valid_epoch_end?
self.train_ppl = torchmetrics.text.Perplexity(ignore_index=-100)
self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=-100)
Any changes to the defaults we should make there? I used torchmetrics with the HF pretraining and it seemed to work well, so I don't think the oversampling is too big of an issue.
sync_on_compute is on by default, so we are good on that end. I wonder if we need to specify process_group to skip non-last-pp-stage devices.
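One hedged way to wire that up, combining the last-stage guard discussed above with a narrower sync group (the claim that a last-PP-stage rank's data-parallel group contains only other last-stage ranks is an assumption about megatron's group layout worth double-checking):

```python
import torchmetrics.text
from megatron.core import parallel_state

# Construct the metric so compute() only reduces over data-parallel peers,
# which for a last-PP-stage rank should themselves all be last-stage ranks.
valid_ppl = torchmetrics.text.Perplexity(
    ignore_index=-100,
    process_group=parallel_state.get_data_parallel_group(),
)

# ...and, as discussed above, only update it where logits actually exist:
# if parallel_state.is_pipeline_last_stage():
#     valid_ppl(logits, labels)
```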
I also don't have a clear idea on how to treat TP devices.
Related effort: #525
Some experiments using torchmetrics for perplexity logging.
Example wandb chart for single-GPU training: https://wandb.ai/clara-discovery/bionemo-hf-pretraining/runs/fgn32r4p?nw=nwuserpstjohn