Implement per channel logging again #440
base: develop
Conversation
Apply the review
…ub.com/kacpnowak/WeatherGenerator2 into kacpnowak/develop/per-channel-logginig
Closes #282
Fixes #282
@kacpnowak I have a few comments. It is tricky code and I am a bit limited for reviewing capacity in the coming days. Is it working as intended? It would be great if someone else tried it out as well. @clessig , if this current implementation is good for you, then I think someone else should have a look and try it too. Any thoughts?
    Returns:
        int: current rank
    """
    if not dist.is_available():
_is_distributed_initialized
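One way to read this suggestion (a minimal sketch, not the PR's code; the helper name comes from the comment above, the rest is an assumption):

import torch.distributed as dist

def _is_distributed_initialized() -> bool:
    # Single place for the availability/initialization check that is
    # currently repeated in get_world_size() and get_rank()
    return dist.is_available() and dist.is_initialized()

def get_rank() -> int:
    """Get current rank number."""
    if not _is_distributed_initialized():
        return 0
    return dist.get_rank()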
    return dist.get_world_size()


def get_rank() -> int:
please update is_root. That function is following the best practices for pytorch. Maybe Seb Hoffman also has something to say about that part too.
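For reference, a self-contained sketch of what an is_root consistent with these helpers could look like (an assumption, not the code under review):

import torch.distributed as dist

def is_root() -> bool:
    """True on the process that should own logging and I/O (rank 0)."""
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0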
src/weathergen/utils/distributed.py
Outdated
    return dist.get_rank()


def all_gather(data: Tensor) -> list[Tensor]:
you should make it very explicit that this implementation does not allow gradient propagation (or does it? I would assume it breaks the tape tracking of the tensors, but stranger things have happened in pytorch).
It's a great question. I couldn't find anything explicit on this topic, but in my understanding it does not allow gradients to flow: the requires_grad flag is preserved, but once the tensor is reconstructed it is detached from the autograd graph.
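If the goal is to make the no-gradient contract explicit in the code, one option is to detach up front and state it in the docstring. A sketch only (the body is an assumption and, unlike the byte-based variant under review, it requires equal shapes on all ranks):

import torch
import torch.distributed as dist
from torch import Tensor

def all_gather(data: Tensor) -> list[Tensor]:
    """Gather data from every rank.

    Gradients do NOT propagate through this collective: the input is
    detached before communication and the returned tensors are not part
    of any autograd graph.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return [data.detach()]
    data = data.detach()
    gathered = [torch.empty_like(data) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, data)
    return gathered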
src/weathergen/train/trainer.py
Outdated
        # Make list of losses into a tensor. This is individual tensor per rank
        real_loss = torch.tensor(self.loss_model_hist, device=self.devices[0])
        # Gather all tensors from all ranks into a list and stack them into one tensor again
        real_loss = torch.cat(all_gather(real_loss))
I am surprised it works as expected
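For context, a toy illustration of what the two lines above produce on each rank (made-up loss values, not from the PR):

import torch

# Suppose rank 0 logged three loss values and rank 1 logged two
# (ranks can have different numbers of samples, as the PR description notes).
rank0_losses = torch.tensor([0.9, 0.8, 0.7])
rank1_losses = torch.tensor([0.85, 0.75])

# all_gather returns every rank's tensor, so torch.cat over the gathered
# list yields all losses across ranks on every rank:
real_loss = torch.cat([rank0_losses, rank1_losses])
print(real_loss)  # tensor([0.9000, 0.8000, 0.7000, 0.8500, 0.7500])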
I will try the code later today. But yes, would be great if others could test as well. Anyone from JSC?
Testing on a single node on Leonardo: Training without error for 3h already.
@kacpnowak : I just tried the code and wanted to plot the loss values. I get:
Traceback (most recent call last):
File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/.venv/bin/plot_train", line 10, in <module>
sys.exit(plot_train())
^^^^^^^^^^^^
File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/src/weathergen/utils/plot_training.py", line 671, in plot_train
runs_data = [TrainLogger.read(run_id, model_path=model_base_dir) for run_id in runs_ids]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/src/weathergen/utils/train_logger.py", line 249, in read
log_train_df = read_metrics(cf, run_id, "train", cols1, result_dir_base)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/src/weathergen/utils/train_logger.py", line 371, in read_metrics
df = clean_df(df, cols)
^^^^^^^^^^^^^^^^^^
File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/src/weathergen/utils/train_logger.py", line 395, in clean_df
df = df.select(columns)
^^^^^^^^^^^^^^^^^^
File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/.venv/lib/python3.12/site-packages/polars/dataframe/frame.py", line 9632, in select
return self.lazy().select(*exprs, **named_exprs).collect(_eager=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/.venv/lib/python3.12/site-packages/polars/_utils/deprecation.py", line 88, in wrapper
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lus/h2resw01/hpcperm/nacl/WeatherGenerator/.venv/lib/python3.12/site-packages/polars/lazyframe/frame.py", line 2188, in collect
return wrap_df(ldf.collect(engine, callback))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
polars.exceptions.ColumnNotFoundError: loss_avg_0_mean
Resolved plan until failure:
---> FAILED HERE RESOLVING 'sink' <---
DF ["stream.NPPATMS.loss_mse.loss_obsvaluerawbt3", "stream.SurfaceCombined.loss_mse.loss_obsvaluet2m0", "stream.SurfaceCombined.loss_mse.loss_avg", "weathergen.time", ...]; PROJECT */114 COLUMNS
nacl@ac6-318:WeatherGenerator$ uv run plot_train
Can you implement this patch please:
--- a/src/weathergen/utils/train_logger.py
+++ b/src/weathergen/utils/train_logger.py
@@ -199,7 +199,7 @@ class TrainLogger:
# define cols for training
cols_train = ["dtime", "samples", "mse", "lr"]
- cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_0_mean", "learning_rate"]
+ cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"]
plot_training.py and train_logger need to be adapted to allow one to select the columns that one would like to plot (loss_avg_mean is a good default, but now I also want to plot q850 etc.). Please open a PR on this.
I still need to test with mlflow.
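A possible shape for that column selection (a sketch only; the weathergen.time and loss_avg_mean column names appear in the output above, while the function, its signature, and the q850 example column are hypothetical):

import polars as pl

def select_metrics(df: pl.DataFrame, metrics: list[str] | None = None) -> pl.DataFrame:
    """Keep the timestamp plus the requested metric columns.

    metrics defaults to the averaged loss; callers could pass e.g.
    ["stream.ERA5.loss_mse.loss_q850"] to plot a single channel instead.
    """
    metrics = metrics or ["loss_avg_mean"]
    # Fail early with a readable message instead of a ColumnNotFoundError
    missing = [m for m in metrics if m not in df.columns]
    if missing:
        raise ValueError(f"unknown metric columns: {missing}")
    return df.select(["weathergen.time", *metrics])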
Thanks for finding this bug. I've patched it.
…g mean - Switched from all_gather to this function in trainer to robustly average - Some code cleanup
Everything looks good and is working, but can we replace the all_gather with one that doesn't work on raw bytes? This would do the job:
def all_gather_vdim(tensor: torch.Tensor, group=None) -> list[torch.Tensor]:
    """Gather tensors with different number of dimensions."""
    world_size = dist.get_world_size(group=group)
    # Gather shapes first
    shapes = all_gather_vlen(
        torch.as_tensor(tensor.shape, device=tensor.device), group=group
    )
    # Gather data
    inputs = [tensor] * world_size
    outputs = [
        torch.empty(*_shape, dtype=tensor.dtype, device=tensor.device)
        for _shape in shapes
    ]
    dist.all_to_all(outputs, inputs, group=group)
    return outputs
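The snippet calls all_gather_vlen, which is not shown in this thread. A sketch of what such a helper could look like (the name is taken from the usage above; the body is an assumption):

import torch
import torch.distributed as dist

def all_gather_vlen(tensor: torch.Tensor, group=None) -> list[torch.Tensor]:
    """Gather 1-D tensors whose length differs across ranks."""
    world_size = dist.get_world_size(group=group)
    # Exchange lengths first so every rank can allocate receive buffers
    length = torch.tensor([tensor.numel()], device=tensor.device)
    lengths = [torch.empty_like(length) for _ in range(world_size)]
    dist.all_gather(lengths, length, group=group)
    # Exchange the data itself, with per-rank buffer sizes
    inputs = [tensor] * world_size
    outputs = [
        torch.empty(int(n.item()), dtype=tensor.dtype, device=tensor.device)
        for n in lengths
    ]
    dist.all_to_all(outputs, inputs, group=group)
    return outputs

Note that all_to_all with per-rank tensor lists is not available on every backend, so whether this works depends on the process group in use.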
…el-logginig/fix_comms Simpler, more robust communication using standard torch primitives
Description
Implements logging losses per channel, now taking into account that ranks can have different numbers of samples.
Type of Change
Issue Number
Closes #282
Code Compatibility
Code Performance and Testing
I ran uv run train and (if necessary) uv run evaluate on at least one GPU node and it works.
$WEATHER_GENERATOR_PRIVATE directory
Dependencies
Documentation
Additional Notes