Loss class refactoring #533
base: develop
Conversation
Apply the review
…ub.com/kacpnowak/WeatherGenerator2 into kacpnowak/develop/per-channel-logginig
Great progress, thanks! Two more points:
- Please add docstrings to all functions (see the sketch below for the intended style).
- loss_module is not a great name; "module" is an overloaded term in CS and the class is not a module under most definitions. LossComputer is not great but descriptive; open to other suggestions.
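A minimal sketch of the docstring coverage meant here, using the hypothetical LossComputer name (the final name is still under discussion, and the method signature is an assumption):

```python
class LossComputer:
    """Computes per-stream, per-channel losses from model predictions.

    Stand-in name for the refactored loss class; see the naming discussion above.
    """

    def compute(self, preds, targets):
        """Compute the total loss.

        Args:
            preds: Model predictions.
            targets: Ground-truth targets.

        Returns:
            A scalar loss tensor.
        """
        ...
```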
src/weathergen/train/loss_module.py
Outdated
i_batch = 0  # TODO: Iterate over batch dimension here in future
for i_strm, strm in enumerate(self.cf.streams):
    targets = streams_data[i_batch][i_strm].target_tokens[self.cf.forecast_offset:]
    # assert len(targets) == self.cf.forecast_steps + 1, "Length of targets does not match number of forecast_steps."
Why is this commented out?
While this assertion works well for different configurations under training_mode: "forecast", it crashes with

training_mode: "masking"
forecast_offset: 0
forecast_steps: 0

since len(targets) = 1 and forecast_offset + forecast_steps = 0. This seems incompatible with autoencoder training. I have put the assertion into an if training_mode == "masking".
Edit: After testing more configurations, this assertion seems not valid in general. It only seems valid for training_mode: "forecast" and forecast_policy: "fixed". Furthermore, it should then be

if self.cf.training_mode == "forecast" and self.cf.forecast_policy == "fixed":
    assert len(targets) == self.forecast_steps

But that seems too constraining.
src/weathergen/train/loss_module.py
Outdated
targets = streams_data[i_batch][i_strm].target_tokens[self.cf.forecast_offset:]
# assert len(targets) == self.cf.forecast_steps + 1, "Length of targets does not match number of forecast_steps."

for fstep, target in enumerate(targets):
Add more comments to the individual lines, e.g. why 108.
l. 114-122: overwriting in the VAL case is not very clean; better to have something like the following (expanded in the sketch below):

if TRAIN:
    ...
elif VAL:
    ...
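A sketch of that explicit branching; the mode values and variable names are illustrative, not the actual ones in loss_module.py:

```python
# Illustrative only: make the train/val paths explicit instead of overwriting.
if mode == "train":
    # loss terms used for backpropagation
    loss = loss_fct(pred, target)
elif mode == "val":
    # validation-only diagnostics, detached from the graph
    loss = loss_fct(pred, target).detach()
else:
    raise ValueError(f"unknown mode: {mode}")
```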
l. 125: we should document what the shape of pred is at this point: (ensemble, target_points, target_channels).
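For example, a shape comment plus a cheap sanity check could document this (a sketch; the shape is taken from the comment above, the check itself is an assumption):

```python
# pred has shape (ensemble, target_points, target_channels) at this point
assert pred.ndim == 3, "expected pred of shape (ensemble, target_points, target_channels)"
n_ens, n_points, n_chs = pred.shape
```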
src/weathergen/train/loss_module.py
Outdated
    ctr_chs += 1
else:
I think the code would be simpler if the mask were just the identity when tok_spacetime = False; this could then be handled completely in _construct_masks.
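A rough sketch of that idea, assuming _construct_masks returns a boolean mask and that an all-True mask is a valid fallback (the NaN-based mask below is purely illustrative):

```python
import torch

def _construct_masks(target: torch.Tensor, tok_spacetime: bool) -> torch.Tensor:
    # When space-time tokenization is off, return an identity (all-True) mask so
    # the caller never has to branch on tok_spacetime.
    if not tok_spacetime:
        return torch.ones_like(target, dtype=torch.bool)
    # ... otherwise build the actual space-time mask; illustrative placeholder:
    return ~torch.isnan(target)
```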
### Original logging preparation
# TODO: remove dependency from old trainer.compute_loss() function
_, _, _, logging_items = trainer.compute_loss(
This would be high priority since we cannot merge with the old dependency.
pred: The prediction tensor, potentially with an ensemble dimension.
mask: A boolean mask tensor, indicating which elements to consider for loss computation.
i_ch: The index of the channel for which to compute the loss.
loss_fct: The specific loss function to apply. It is expected to accept
So the loss function definition is very rigid here: for instance, any latent loss cannot work this way, and indeed many regularisation losses might not work with this.
Wondering if we are introducing a dead end here, or if we can work on this generalization in a new PR?
You mean the KL divergence in the ELBO, in the case of a VAE? Since it is coming from an intermediate stage in the model, I would say that we can add it later as an extra parameter.
Thinking about it, we have roughly 3 losses depending on the stage in the model (see the sketch below):
- initial (stage 0): regularization terms (no need for samples)
- mid-way: variational terms in the ELBO
- end-to-end: expected empirical risk
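Purely as an illustration of how the three stages could be combined into one scalar objective (weights and names are assumptions, not the actual API):

```python
def total_loss(reg_terms, kl_term, risk_term, w_reg=1e-4, w_kl=1.0):
    # stage 0: regularization terms (no samples needed)
    # mid-way: variational term (KL divergence in the ELBO)
    # end-to-end: expected empirical risk (forecast / reconstruction error)
    return w_reg * sum(reg_terms) + w_kl * kl_term + risk_term
```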
Regularisation terms can very much need samples, e.g. the z-loss needed for stability in multi-modal transformers (https://arxiv.org/pdf/2405.09818), or representation regularisation like the dispersion loss in https://arxiv.org/abs/2506.09027.
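For reference, the z-loss mentioned here is roughly the following penalty on the softmax log-partition (a sketch; the weight value is an assumption):

```python
import torch

def z_loss(logits: torch.Tensor, weight: float = 1e-4) -> torch.Tensor:
    # Penalizes large log-partition values so the softmax stays numerically stable.
    log_z = torch.logsumexp(logits, dim=-1)
    return weight * (log_z ** 2).mean()
```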
    # If no valid data under the mask, return 0 to avoid errors and not contribute to loss
    return 0


def compute_loss(
My suggestion would be to rename this to compute_input_space_loss and have people register custom functions that all get executed in a wrapper compute_loss function. Each registered function could come with a list of arguments to recover out of an argument dict kwargs? Or we extend the loss function through inheritance, for instance.
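A hypothetical sketch of the register-and-wrap idea (all names are illustrative, not an existing API):

```python
from collections.abc import Callable

_registered_losses: dict[str, tuple[Callable, list[str]]] = {}

def register_loss(name: str, fn: Callable, arg_names: list[str]) -> None:
    # Each registered function declares which arguments it needs from the kwargs dict.
    _registered_losses[name] = (fn, arg_names)

def compute_loss(**kwargs):
    # Wrapper that executes every registered loss with its requested arguments.
    total = 0.0
    for fn, arg_names in _registered_losses.values():
        total = total + fn(*(kwargs[name] for name in arg_names))
    return total
```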
Isn't src/weathergen/train/loss.py such a register of custom functions?
I am not a fan of registration because it introduces state to think about. Could it be described fully in the input of the class constructor?
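A sketch of that stateless alternative, where every loss term is passed to the constructor (class and argument names are assumptions):

```python
from collections.abc import Callable

class LossComputer:
    def __init__(self, loss_terms: list[tuple[str, Callable, float]]):
        # Each term is (name, fn(preds, targets) -> scalar tensor, weight);
        # the full configuration lives in the constructor, no registry state.
        self.loss_terms = loss_terms

    def compute(self, preds, targets):
        return sum(weight * fn(preds, targets) for _, fn, weight in self.loss_terms)
```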
I think there should be an easy way to compute a loss on the latents, given that we are doing representation learning and in particular SSL. This will come quite soon, I suspect; we can leave it for a future PR, I guess, but my point is that this is quite restrictive.
I only did a light review, but thanks for this hard work. Make sure it does not expand in scope; it is already heavy.
@@ -40,7 +40,7 @@ pred_mlp_adaln: True

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
forecast_offset : 0
just to check: are we ok with changing the default?
No, we have to revert to the default configs before merging.
This comes from #440. We need to clean up before it's merged. Same below.
@@ -9,7 +9,7 @@

FESOM :
  type : fesom
- filenames : ['fesom_ifs_awi']
+ filenames : ['test4.zarr']
needed?
@@ -11,6 +11,8 @@

import numpy as np
import torch

stat_loss_fcts = ["stats", "kernel_crps"]  # Names of loss functions that need std computed
What do you mean by "std computed"?
loss: Tensor
# Dictionaries containing detailed loss values and standard deviation statistics for each
# stream, channel, and loss function.
losses_all: dict
dict[str, Tensor]? Be more precise.
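For instance, assuming the nesting follows the field comment ("for each stream, channel, and loss function"), the annotation could look something like:

```python
from torch import Tensor

# stream name -> channel name -> loss-function name -> scalar loss
losses_all: dict[str, dict[str, dict[str, Tensor]]]
```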
Returns:
    int: world size
"""
if not dist.is_available():
There is a _is_distributed_initialized below.
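A sketch of how the check could reuse that helper (the helper name comes from the comment above; its body and the fallback value are assumptions):

```python
import torch.distributed as dist

def _is_distributed_initialized() -> bool:
    return dist.is_available() and dist.is_initialized()

def get_world_size() -> int:
    # Fall back to a single process when torch.distributed is not set up.
    if not _is_distributed_initialized():
        return 1
    return dist.get_world_size()
```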
    return dist.get_world_size()


def get_rank() -> int:
Why not pass process groups? See is_root below. I still don't fully grasp why we need them. @sophie-xhonneux or Sebastian Hoffmann would probably know better.
get_rank is, afaik, often used for things you only want to do on one GPU, e.g. logging.
This is again #440. Wrong place to discuss. We need this function because, if distributed is not initialized when we run interactively, all calls to distributed will fail.
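A sketch of that behaviour: get_rank degrades gracefully when torch.distributed is not initialized, so interactive single-process runs still work (assumed code, not the actual implementation):

```python
import torch.distributed as dist

def get_rank() -> int:
    # Rank 0 is assumed for non-distributed (e.g. interactive) runs, so that
    # rank-0-only actions such as logging still happen.
    if not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank()
```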
@@ -388,6 +410,13 @@ def _key_loss(st_name: str, lf_name: str) -> str:
    return f"stream.{st_name}.loss_{lf_name}.loss_avg"


def _key_loss_chn(st_name: str, lf_name: str, ch_name: str) -> str:
    st_name = _clean_name(st_name)
    lf_name = _clean_name(lf_name)
Do not clean the loss or channel names; they should be standard enough.
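That would reduce the helper to something like the following (the exact key format for the per-channel case is an assumption, extrapolated from _key_loss above):

```python
def _key_loss_chn(st_name: str, lf_name: str, ch_name: str) -> str:
    # Only the stream name may contain characters that need cleaning.
    st_name = _clean_name(st_name)
    return f"stream.{st_name}.loss_{lf_name}.{ch_name}.loss_avg"
```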
Thanks @tjhunter. I think some of the additions come from @kacpnowak's PR, so it probably doesn't fully make sense to review everything before his PR is merged.
This PR is just refactoring, so we should merge without adding functionality for latent losses. It's already too big, if anything, and designing something now without actually using it will just lead to misalignment between what we think is needed and how we actually use it. There will be a subsequent PR for the weighting functions, but I would still keep it in physical (output) space. Once this is merged, we can add the latent loss.
Tests are currently failing in masking mode because of issue #553.
Description
In this PR we refactor the trainer.compute_loss() function into a standalone LossModule class.
Note: This PR works off of Kacper's branch kacpnowak:kacpnowak/develop/per-channel-logginig
Type of Change
Issue Number
Closes #568
Code Compatibility
Code Performance and Testing
uv run train and (if necessary) uv run evaluate on at least one GPU node and it works
$WEATHER_GENERATOR_PRIVATE directory
directoryDependencies
Documentation
Additional Notes