
Implement per channel logging again #440


Open
kacpnowak wants to merge 44 commits into base: develop

Conversation

@kacpnowak (Contributor) commented Jul 3, 2025

Description

Implements per-channel loss logging, now taking into account that ranks can have different numbers of samples.
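
As an illustration (not the PR's actual implementation), the key point is that per-channel losses have to be pooled over all samples from all ranks rather than averaging per-rank means with equal weight. A minimal sketch, assuming an initialized torch.distributed process group; the helper name is hypothetical:

# Illustrative sketch only: average per-channel losses across ranks that may
# hold different numbers of samples, by gathering the full per-sample
# histories and averaging after concatenation.
import torch
import torch.distributed as dist

def mean_loss_per_channel(local_losses: torch.Tensor) -> torch.Tensor:
    # local_losses: (num_local_samples, num_channels) on this rank.
    gathered = [None] * dist.get_world_size()
    # all_gather_object handles per-rank tensors of different lengths.
    dist.all_gather_object(gathered, local_losses.cpu())
    return torch.cat(gathered).mean(dim=0)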

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update

Issue Number

Closes #282

Code Compatibility

  • I have performed a self-review of my code

Code Performance and Testing

  • I ran uv run train and (if necessary) uv run evaluate on at least one GPU node and it works
  • If the new feature introduces modifications at the config level, I have made sure to notify the other software developers through Mattermost and to update the paths in the $WEATHER_GENERATOR_PRIVATE directory

Dependencies

  • I have ensured that the code is still pip-installable after the changes and runs
  • I have tested that new dependencies themselves are pip-installable.
  • I have not introduced new dependencies in the inference portion of the pipeline

Documentation

  • My code follows the style guidelines of this project
  • I have updated the documentation and docstrings to reflect the changes
  • I have added comments to my code, particularly in hard-to-understand areas

Additional Notes

kacpnowak and others added 29 commits May 30, 2025 14:48
@kacpnowak kacpnowak marked this pull request as draft July 3, 2025 16:30
@kacpnowak (Contributor, Author)

Closes #282

@kacpnowak kacpnowak marked this pull request as ready for review July 4, 2025 11:03
@kacpnowak (Contributor, Author)

Fixes #282

@tjhunter (Collaborator) left a comment

@kacpnowak I have a few comments. It is tricky code and I am a bit limited for reviewing capacity in the coming days. Is it working as intended? It would be great if someone else tried it out as well. @clessig , if this current implementation is good for you, then I think someone else should have a look and try it too. Any thoughts?

Returns:
int: world size
"""
if not dist.is_available():
Inline review comment (Collaborator):

_is_distributed_initialized

return dist.get_world_size()


def get_rank() -> int:
Inline review comment (Collaborator):

Please update is_root; that function is following the best practices for PyTorch. Maybe Seb Hoffman also has something to say about that part too.

return dist.get_rank()
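
For reference, a minimal sketch (not the PR's code) of guarded rank helpers in the spirit of the _is_distributed_initialized suggestion above; the exact names and fallbacks are illustrative:

import torch.distributed as dist

def _is_distributed_initialized() -> bool:
    # True only when torch.distributed is both available and initialized.
    return dist.is_available() and dist.is_initialized()

def get_rank() -> int:
    # Fall back to rank 0 when not running under torch.distributed.
    return dist.get_rank() if _is_distributed_initialized() else 0

def is_root() -> bool:
    # Rank 0 is treated as the root process for logging and checkpointing.
    return get_rank() == 0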


def all_gather(data: Tensor) -> list[Tensor]:
Inline review comment (Collaborator):

You should make it very explicit that this implementation does not allow gradient propagation (or does it? I would assume it breaks the tape tracking of the tensors, but stranger things have happened in PyTorch).

Inline reply (Contributor, Author):

It's a great question. I couldn't find anything explicit on this topic, but in my understanding it does not allow gradients to flow. The requires_grad flag is preserved, but once the tensor is reconstructed it is detached from the autograd graph.
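
For reference, a minimal sketch (not from this PR) of why the gathered copies do not carry gradients: dist.all_gather writes into freshly allocated buffers, and that write is not tracked by autograd.

import torch
import torch.distributed as dist

def gather_detached(t: torch.Tensor) -> list[torch.Tensor]:
    # The outputs are plain buffers filled by the collective; they are not
    # connected to the autograd graph of the input tensor.
    outs = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(outs, t)
    return outs

# If gradient flow across ranks were ever needed,
# torch.distributed.nn.functional.all_gather is the autograd-aware variant;
# for pure logging it is unnecessary.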

# Convert the list of losses into a tensor; each rank has its own tensor
real_loss = torch.tensor(self.loss_model_hist, device=self.devices[0])
# Gather the tensors from all ranks and concatenate them into one tensor again
real_loss = torch.cat(all_gather(real_loss))
Inline review comment (Collaborator):

I am surprised it works as expected

@clessig (Collaborator) commented Jul 7, 2025 via email

@tjhunter tjhunter requested a review from MatKbauer July 7, 2025 12:52
@Jubeku (Contributor) commented Jul 11, 2025

Testing on a single node on Leonardo: training has been running without error for 3 hours already.

@clessig (Collaborator) left a comment

@kacpnowak : I just tried the code and wanted to plot the loss values. I get:

Traceback (most recent call last):
  File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/.venv/bin/plot_train", line 10, in <module>
    sys.exit(plot_train())
             ^^^^^^^^^^^^
  File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/src/weathergen/utils/plot_training.py", line 671, in plot_train
    runs_data = [TrainLogger.read(run_id, model_path=model_base_dir) for run_id in runs_ids]
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/src/weathergen/utils/train_logger.py", line 249, in read
    log_train_df = read_metrics(cf, run_id, "train", cols1, result_dir_base)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/src/weathergen/utils/train_logger.py", line 371, in read_metrics
    df = clean_df(df, cols)
         ^^^^^^^^^^^^^^^^^^
  File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/src/weathergen/utils/train_logger.py", line 395, in clean_df
    df = df.select(columns)
         ^^^^^^^^^^^^^^^^^^
  File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/.venv/lib/python3.12/site-packages/polars/dataframe/frame.py", line 9632, in select
    return self.lazy().select(*exprs, **named_exprs).collect(_eager=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/.venv/lib/python3.12/site-packages/polars/_utils/deprecation.py", line 88, in wrapper
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/.venv/lib/python3.12/site-packages/polars/lazyframe/frame.py", line 2188, in collect
    return wrap_df(ldf.collect(engine, callback))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
polars.exceptions.ColumnNotFoundError: loss_avg_0_mean

Resolved plan until failure:

        ---> FAILED HERE RESOLVING 'sink' <---
DF ["stream.NPPATMS.loss_mse.loss_obsvaluerawbt3", "stream.SurfaceCombined.loss_mse.loss_obsvaluet2m0", "stream.SurfaceCombined.loss_mse.loss_avg", "weathergen.time", ...]; PROJECT */114 COLUMNS
nacl@ac6-318:WeatherGenerator$ uv run plot_train

Can you implement this patch please:

--- a/src/weathergen/utils/train_logger.py
+++ b/src/weathergen/utils/train_logger.py
@@ -199,7 +199,7 @@ class TrainLogger:
 
         # define cols for training
         cols_train = ["dtime", "samples", "mse", "lr"]
-        cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_0_mean", "learning_rate"]
+        cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"]

plot_training.py and train_logger need to be adapted so that one can select the columns to plot (loss_avg_mean is a good default, but now I also want to plot q850 etc.). Please open a PR on this.
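
Not part of the requested patch, but as an illustration of such a configurable, defensive column selection (hypothetical helper name, using polars as train_logger already does):

import polars as pl

def select_metrics(df: pl.DataFrame, requested: list[str]) -> pl.DataFrame:
    # Keep only the requested columns that actually exist in this run's log,
    # so a missing per-channel column does not raise ColumnNotFoundError as
    # in the traceback above.
    present = [c for c in requested if c in df.columns]
    return df.select(present)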

I still need to test with mlflow.

@kacpnowak (Contributor, Author)

Thanks for finding this bug. I've patched it.

kacpnowak and others added 3 commits July 16, 2025 18:17
@clessig (Collaborator) left a comment

Everything looks good and is working, but can we replace the all_gather with one that doesn't go through raw byte buffers? This would do the job:

import torch
import torch.distributed as dist


def all_gather_vdim(tensor: torch.Tensor, group=None) -> list[torch.Tensor]:
    """Gather tensors whose shapes differ across ranks.

    Relies on an all_gather_vlen helper (not shown here) that gathers
    variable-length 1-D tensors, used for the per-rank shape vectors.
    """
    world_size = dist.get_world_size(group=group)
    # Exchange the per-rank shapes first so every rank can allocate
    # correctly sized receive buffers.
    shapes = all_gather_vlen(
        torch.as_tensor(tensor.shape, device=tensor.device), group=group
    )
    # Send this rank's tensor to every rank and receive one tensor per rank,
    # each with its own shape.
    inputs = [tensor] * world_size
    outputs = [
        torch.empty(*_shape, dtype=tensor.dtype, device=tensor.device)
        for _shape in shapes
    ]
    dist.all_to_all(outputs, inputs, group=group)
    return outputs
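
For context, a hypothetical drop-in usage at the call site quoted earlier (sketch only, names as in that snippet):

# Sketch: replacing torch.cat(all_gather(real_loss)) with the shape-aware variant above.
real_loss = torch.tensor(self.loss_model_hist, device=self.devices[0])
real_loss = torch.cat(all_gather_vdim(real_loss))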

Development

Successfully merging this pull request may close these issues.

Implement logging loss per-channel
4 participants