
Loss class refactoring #533


Draft: wants to merge 50 commits into base: develop
Conversation


@Jubeku (Contributor) commented Jul 16, 2025

Description

In this PR we refactor the trainer.compute_loss() function into a standalone LossModule class.
Note: This PR works off of Kacper's branch kacpnowak:kacpnowak/develop/per-channel-logginig

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update

Issue Number

Closes #568

Code Compatibility

  • I have performed a self-review of my code

Code Performance and Testing

  • I ran uv run train and (if necessary) uv run evaluate on at least one GPU node and it works
  • If the new feature introduces modifications at the config level, I have made sure to have notified the other software developers through Mattermost and updated the paths in the $WEATHER_GENERATOR_PRIVATE directory

Dependencies

  • I have ensured that the code is still pip-installable after the changes and runs
  • I have tested that new dependencies themselves are pip-installable.
  • I have not introduced new dependencies in the inference portion of the pipeline

Documentation

  • My code follows the style guidelines of this project
  • I have updated the documentation and docstrings to reflect the changes
  • I have added comments to my code, particularly in hard-to-understand areas

Additional Notes

kacpnowak and others added 30 commits May 30, 2025 14:48
@Jubeku Jubeku added the enhancement New feature or request label Jul 16, 2025
@clessig (Collaborator) left a comment:

Great progress, thanks! Two more points:

  • Please add doc strings to all functions
  • loss_module is not a great name; "module" is an overloaded term in CS and the class is not a module in most definitions. LossComputer is not great but descriptive. Open for other suggestions.
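For illustration, a minimal skeleton of what such a renamed class could look like (all names are hypothetical, and plain Python floats stand in for torch tensors):

```python
from typing import Callable, Sequence


class LossComputer:
    """Illustrative skeleton for the class replacing trainer.compute_loss().

    Purely a sketch: the real class works on torch tensors and the
    project's stream configuration.
    """

    def __init__(self, loss_fcts: dict[str, Callable]):
        self.loss_fcts = loss_fcts

    def compute_loss(self, pred: Sequence[float], target: Sequence[float],
                     mask: Sequence[bool]) -> float:
        # Keep only the points selected by the boolean mask.
        p = [x for x, m in zip(pred, mask) if m]
        t = [x for x, m in zip(target, mask) if m]
        if not p:
            # No valid data under the mask: contribute nothing to the loss.
            return 0.0
        # Average over all configured loss functions.
        return sum(f(p, t) for f in self.loss_fcts.values()) / len(self.loss_fcts)


def mse(pred, target):
    return sum((a - b) ** 2 for a, b in zip(pred, target)) / len(pred)


lc = LossComputer({"mse": mse})
```

Whatever the final name, keeping the masking and averaging inside one class is what lets the trainer shed its compute_loss() method.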

i_batch = 0  # TODO: Iterate over batch dimension here in future
for i_strm, strm in enumerate(self.cf.streams):
    targets = streams_data[i_batch][i_strm].target_tokens[self.cf.forecast_offset:]
    # assert len(targets) == self.cf.forecast_steps + 1, "Length of targets does not match number of forecast_steps."
Collaborator:

Why is this commented out?

@MatKbauer (Contributor) commented Jul 22, 2025:

While this assertion works well for different configurations under training_mode: "forecast", it crashes with

training_mode: "masking"
forecast_offset : 0
forecast_steps: 0

since len(targets) = 1 and forecast_offset + forecast_steps = 0. The assertion therefore seems incompatible with autoencoder training. I have put the assertion inside an if training_mode == "masking" check.

Contributor:

Edit: After testing more configurations, this assertion does not seem valid in general. It only seems to hold for training_mode: "forecast" with forecast_policy: "fixed". Furthermore, it should then be

if self.cf.training_mode == "forecast" and self.cf.forecast_policy == "fixed":
    assert len(targets) == self.forecast_steps

But that seems too constraining.

targets = streams_data[i_batch][i_strm].target_tokens[self.cf.forecast_offset:]
# assert len(targets) == self.cf.forecast_steps + 1, "Length of targets does not match number of forecast_steps."

for fstep, target in enumerate(targets):
Collaborator:

Add more comments to the individual lines, e.g. why 108.

Collaborator:

l114-122: overwriting in the VAL case is not very clean; better to have

if TRAIN:
    ...
elif VAL:
    ...

Collaborator:

l. 125: we should document what the shape of pred at this point is: (ensemble, target_points, target_channels)

ctr_chs += 1
else:
Collaborator:

I think the code would be simpler if mask is just identity when tok_spacetime = False and this could be handled completely in _construct_masks.
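The suggestion could be sketched roughly like this (hypothetical helper name and mask pattern; plain Python for brevity):

```python
def construct_mask(n_points: int, tok_spacetime: bool, spacing: int = 2) -> list[bool]:
    """Hypothetical _construct_masks-style helper.

    When tok_spacetime is False the mask is simply the identity
    (all True), so the calling loss code needs no special case.
    """
    if not tok_spacetime:
        return [True] * n_points
    # Illustrative spacetime pattern; the real mask logic is more involved.
    return [i % spacing == 0 for i in range(n_points)]
```

With this shape, the per-channel loop can apply the mask unconditionally.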


### Original logging preparation
# TODO: remove dependency from old trainer.compute_loss() function
_, _, _, logging_items = trainer.compute_loss(
Collaborator:

This would be high priority since we cannot merge with the old dependency.

pred: The prediction tensor, potentially with an ensemble dimension.
mask: A boolean mask tensor, indicating which elements to consider for loss computation.
i_ch: The index of the channel for which to compute the loss.
loss_fct: The specific loss function to apply. It is expected to accept
Contributor:

The loss function definition is very rigid here: for instance, a latent loss cannot work this way, and indeed many regularisation losses might not work with this either.

Contributor (author):

Wondering if we are introducing a dead end here, or if we can work on this generalization in a new PR?

Collaborator:

you mean the KL divergence in the ELBO, in case of a VAE? Since it is coming from an intermediate stage in the model, I would say that we can add it later as an extra parameter.

Thinking about it, we have roughly 3 losses depending on the stage in the model:

  • initial (stage 0): regularization terms (no need for samples)
  • mid-way: variational terms in the ELBO
  • end-to-end: expected empirical risk
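Combining the three stages might look like this (illustrative sketch; the weighting scheme and names are assumptions, not part of the PR):

```python
def total_loss(reg_terms, kl_term, risk_term, beta=1.0, lambda_reg=0.01):
    """Combine the three loss stages sketched above (illustrative only).

    reg_terms : stage-0 regularization terms (list of floats)
    kl_term   : mid-way variational term (KL divergence in the ELBO)
    risk_term : end-to-end expected empirical risk
    """
    return risk_term + beta * kl_term + lambda_reg * sum(reg_terms)
```

The point of the split is that each stage can be supplied (or omitted) independently, e.g. kl_term is only present for a VAE.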

Contributor:

Regularisation terms can very much need samples, e.g. the z-loss needed for stability in multi-modal transformers (https://arxiv.org/pdf/2405.09818), or representation regularisation such as a dispersion loss (https://arxiv.org/abs/2506.09027).

# If no valid data under the mask, return 0 to avoid errors and not contribute to loss
return 0

def compute_loss(
Contributor:

My suggestion would be to rename this to compute_input_space_loss and let people register custom functions that all get executed in a wrapper compute_loss function.

Each registered function could come with a list of arguments to recover out of an argument dict (kwargs), or we could extend the loss function through inheritance, for instance.
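The registration idea might be sketched as follows (class and function names are hypothetical):

```python
from typing import Callable


class LossRegistry:
    """Sketch of the proposed pattern: an input-space loss plus user-registered
    custom losses, all executed by a wrapper compute_loss (names hypothetical)."""

    def __init__(self):
        self._loss_fns: dict[str, Callable] = {}

    def register(self, name: str, fn: Callable) -> None:
        self._loss_fns[name] = fn

    def compute_loss(self, **kwargs) -> float:
        # Wrapper: each registered function recovers what it needs from kwargs.
        return sum(fn(**kwargs) for fn in self._loss_fns.values())


reg = LossRegistry()
reg.register("input_space", lambda pred, target, **_: abs(pred - target))
reg.register("latent", lambda latent_norm, **_: 0.1 * latent_norm)
```

Passing everything through kwargs is what lets a latent loss and an input-space loss coexist behind one wrapper.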

Contributor (author):

Isn't src/weathergen/train/loss.py such a register of custom functions?

Collaborator:

I am not a fan of registration because it introduces state to think about. Could it be described fully in the input of the class constructor?
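In the spirit of this comment, a stateless variant where the full set of losses is fixed at construction time (hypothetical sketch):

```python
class StatelessLossComputer:
    """Hypothetical: all loss functions given to the constructor; no register()
    method, so there is no mutable registry state to reason about."""

    def __init__(self, loss_fns):
        self._loss_fns = dict(loss_fns)  # copied once, never mutated afterwards

    def compute_loss(self, **kwargs) -> float:
        return sum(fn(**kwargs) for fn in self._loss_fns.values())


loss_computer = StatelessLossComputer(
    {"l1": lambda pred, target, **_: abs(pred - target)}
)
```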

Contributor:

I think there should be an easy way to compute a loss on the latents, given that we are doing representation learning and in particular SSL. This will come quite soon, I suspect. We can leave it for a future PR, but my point is that this is quite restrictive.

@tjhunter (Collaborator) left a comment:

I had a light review, but thanks for this hard work. Make sure it does not expand in scope, it is already heavy.

@@ -40,7 +40,7 @@ pred_mlp_adaln: True

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
forecast_offset : 0
Collaborator:

just to check: are we ok with changing the default?

Contributor (author):

No, we have to revert to the default configs before merging.

Collaborator:

This comes from #440. We need to clean up before it's merged. Same below.

@@ -9,7 +9,7 @@

FESOM :
  type : fesom
- filenames : ['fesom_ifs_awi']
+ filenames : ['test4.zarr']
Collaborator:

needed?

@@ -11,6 +11,8 @@
import numpy as np
import torch

stat_loss_fcts = ["stats", "kernel_crps"] # Names of loss functions that need std computed
Collaborator:

what do you mean by std computed?

loss: Tensor
# Dictionaries containing detailed loss values and standard deviation statistics for each
# stream, channel, and loss function.
losses_all: dict
Collaborator:

dict[str, Tensor] ? Be more precise.
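For illustration, a precisely-typed sketch of the result container (field names from the snippet above; float stands in for Tensor, and the key convention is an assumption):

```python
from dataclasses import dataclass, field


@dataclass
class LossOutput:
    # Overall scalar loss (Tensor in the real code; float here for the sketch).
    loss: float
    # Detailed loss values keyed per stream/channel/loss-function
    # (hypothetical "stream.channel.loss_fct" key convention).
    losses_all: dict[str, float] = field(default_factory=dict)


out = LossOutput(loss=0.25, losses_all={"ERA5.t2m.mse": 0.1})
```

Spelling out the value type in the annotation is all the reviewer is asking for.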

Returns:
int: world size
"""
if not dist.is_available():
Collaborator:

there is a _is_distributed_initialized below.

return dist.get_world_size()


def get_rank() -> int:
Collaborator:

Why not pass process groups? See is_root below. I still don't fully grasp why we need them.

@sophie-xhonneux or Sebastian Hoffmann would probably know better.

Contributor:

get_rank is, afaik, often used for things you only want to do on one GPU, e.g. logging.

Collaborator:

This is again #440. Wrong place to discuss.

We need this function because if distributed is not initialized when we run interactively, all the calls to distributed will fail.
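The guard being described might look roughly like this (standalone sketch; the import fallback is only for illustration, the real helper lives in the PR):

```python
try:
    import torch.distributed as dist
except ImportError:  # running in an environment without torch
    dist = None


def get_rank() -> int:
    """Return the distributed rank, falling back to 0 for interactive runs.

    Without this guard, calling dist.get_rank() before init_process_group()
    raises an error when running interactively.
    """
    if dist is None or not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank()


# Typical use: act only on one process, e.g. for logging.
if get_rank() == 0:
    print("logging only on rank 0")
```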

@@ -388,6 +410,13 @@ def _key_loss(st_name: str, lf_name: str) -> str:
    return f"stream.{st_name}.loss_{lf_name}.loss_avg"


def _key_loss_chn(st_name: str, lf_name: str, ch_name: str) -> str:
    st_name = _clean_name(st_name)
    lf_name = _clean_name(lf_name)
Collaborator:

do not clean the loss or channel names, they should be standard enough.

@Jubeku (Contributor, author) commented Jul 23, 2025:

Thanks @tjhunter. I think some of the additions come from @kacpnowak's PR, so it probably doesn't fully make sense to review everything before his PR is merged.

@clessig (Collaborator) commented Jul 24, 2025:

This PR is just refactoring, so we should merge without adding functionality for latent losses. It's already too big, if anything, and designing something now without actually using it will just lead to misalignment between what we think is needed and how we actually use it.

There will be a subsequent PR for the weighting functions but I would still keep it in physical (output) space. Once this is merged, we can add the latent loss.

@Jubeku (Contributor, author) commented Jul 24, 2025:

Tests are currently failing in masking mode because of issue #553

Labels: enhancement (New feature or request)

Successfully merging this pull request may close these issues: Loss function refactoring.

6 participants