introducing perceiver-io style module for embedding layer #438

Draft
csjfwang wants to merge 12 commits into base: develop

Conversation

@csjfwang (Contributor) commented Jul 3, 2025

Description

For channel-wise attention on ERA5, we currently use only 96 channels. As more channels are incorporated in the future, the computational burden will increase significantly. To address this, we introduce a Perceiver IO module before the channel-wise attention stage.
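
A minimal sketch of the idea (not the PR's actual PerceiverBlock; names, shapes, and the use of MultiheadAttention as a stand-in are assumptions): a small set of learned latent queries cross-attends to the channel tokens, so the cost of the subsequent channel-wise attention depends on num_queries rather than on the growing number of ERA5 channels.

import torch


class PerceiverCompressor(torch.nn.Module):
    def __init__(self, dim_embed: int, num_queries: int, num_heads: int = 1):
        super().__init__()
        # Learned latents: (1, num_queries, dim_embed), broadcast over the batch.
        self.queries = torch.nn.Parameter(torch.empty(1, num_queries, dim_embed))
        torch.nn.init.normal_(self.queries, mean=0.0, std=dim_embed**-0.5)
        self.cross_attn = torch.nn.MultiheadAttention(dim_embed, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_channels, dim_embed) -> (batch, num_queries, dim_embed)
        q = self.queries.expand(x.shape[0], -1, -1)
        out, _ = self.cross_attn(q, x, x, need_weights=False)
        return out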

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update

Issue Number

Open #225

Code Compatibility

  • I have performed a self-review of my code

Code Performance and Testing

  • I ran uv run train and (if necessary) uv run evaluate on at least one GPU node and it works
  • If the new feature introduces modifications at the config level, I have made sure to have notified the other software developers through Mattermost and updated the paths in the $WEATHER_GENERATOR_PRIVATE directory

Dependencies

  • I have ensured that the code is still pip-installable after the changes and runs
  • I have tested that new dependencies themselves are pip-installable.
  • I have not introduced new dependencies in the inference portion of the pipeline

Documentation

  • My code follows the style guidelines of this project
  • I have updated the documentation and docstrings to reflect the changes
  • I have added comments to my code, particularly in hard-to-understand areas

Additional Notes

wang85 and others added 12 commits June 12, 2025 16:56
…stream

replaced cross-attentionheadvarlen with cross-attentionhead
ecmwf#265)

* Adding base class for data readers, with anemoi data reader as example. Also implements (trivial) required changes in MultiStreamDataSampler.

* Ruffed

* Fixed various issues. Training with loss starting to converge. Much more testing needed.

* Fixing bug where cf.shuffle was not passed to data loader.

* Cleaned up problems that occurred with arbitrary length window and offsets. Also introduced base class for time stepped datasets and moved function to get dataset indices there. Other smaller code improvements and documentation.

* Changes due to changed interface in TimeWindowHandler

* Fixed typo

* Fixed documentation

* Renaming

* Adapted obs data reader

* [100] Refactor of data readers into clessig's branch (ecmwf#306)

* changes

* changes

* configs

* changes

* changes

* changes

* changes

* better interface

* comments

* changes

* Fixed formatting of comment

* Reenabled obs data reader

* Switched to use of TRAIN and VAL stages flags.

* Fixed problem with target and source channels that can differ between train and val

* Clean up and also added check_reader_data

* Fixed problem that occurred when dataset has no overlap with training / validation range. Also some restructuring to handle this properly.

* Added check for proper time inclusion in interval to check_data_reader. Code adaptations for handling of empty dataset.

* Fixed handling of subsampling_rate / frequency for anemoi dataset.

* Cleaned up special case handling and removed old comments

* Added missing handling of special case where t_win is in between two timesteps

* Renamed FESOM data reader and adapted to base class etc

* Fixed incorrect handling of shuffling

* Re-enabled FESOM data reader

* Added warning when dataset does not overlap time window of data loader

* Fixed missing check

* Adding fixed FESOM data reader. Needs to be verified

* Removing spurious files

* small changes

* style

* restored defaults

* Removing log messages that are too verbose

* Fixing sub-optimal solution with placement of warning

* Removing logging that breaks metric plotting (and seems not sensible)

* Reenabling performance logging

* removing incorrect formatting

* Removed unused variable

* Fixed bug in evaluation with missing usage of new stages

* Removing stream files that need to be considered in more detail before going to develop

* Removing debug dependency

* Restoring defaults

* Introduced stream_info to base class. Cleanup.

---------

Co-authored-by: Timothy Hunter <[email protected]>
@tjhunter marked this pull request as draft July 3, 2025 09:55
@clessig requested a review from sophie-xhonneux July 7, 2025 14:42
@@ -52,6 +77,29 @@ def __init__(
self.dim_embed = dim_embed
self.dim_out = dim_out
self.num_blocks = num_blocks
if cross_attn_params is not None:

For a couple of reasons it would be better to have an explicit flag to use Perceiver in EmbeddingNetwork than to check whether the config entry exists. When looking back at logs, it is easier to always check the same flag than to verify that something is not present.
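
A sketch of the suggested pattern (the flag name use_perceiver and this class layout are illustrative, not the PR's code): gate the Perceiver path on an explicit boolean rather than on the presence of cross_attn_params.

import torch


class EmbeddingNetworkSketch(torch.nn.Module):
    def __init__(self, dim_embed: int, num_queries: int, use_perceiver: bool = False):
        super().__init__()
        # One explicit flag to check (and to grep for in logs), rather than
        # inferring the mode from whether cross_attn_params is present.
        self.use_perceiver = use_perceiver
        if self.use_perceiver:
            self.queries = torch.nn.Parameter(torch.empty(1, num_queries, dim_embed))
            torch.nn.init.normal_(self.queries, mean=0.0, std=0.01)
            # ... build the PerceiverBlock here ...
        else:
            self.perceiver_io = torch.nn.Identity()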

num_channels = self.num_queries
self.cross_attn_num_blocks = cross_attn_params.get("num_blocks", 1) if cross_attn_params else 1
self.cross_attn_num_heads = cross_attn_params.get("num_heads", 1) if cross_attn_params else 1
self.queries = torch.nn.Parameter(

Pretty sure this code wasn't formatted

self.cross_attn_num_blocks = cross_attn_params.get("num_blocks", 1) if cross_attn_params else 1
self.cross_attn_num_heads = cross_attn_params.get("num_heads", 1) if cross_attn_params else 1
self.queries = torch.nn.Parameter(
torch.randn(1, self.num_queries, self.dim_embed)

These parameters should be initialised to be much smaller than randn, please use an explicit init strategy! For a gaussian, the standard deviation should be something like 0.01 to 0.0005, but really there should be an init_weights function that explicitly initialises all the weights.
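
A sketch of what such an explicit init could look like (the std value follows the comment above; the helper name init_weights is not in the PR):

import torch


def init_weights(module: torch.nn.Module, std: float = 0.01) -> None:
    # Small-gaussian init for linear layers; the learned queries are handled separately below.
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)


# e.g. called from the embedding network's __init__:
#   torch.nn.init.normal_(self.queries, mean=0.0, std=0.01)
#   self.apply(init_weights)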

self.queries.normal_(mean=0.0, std=1.0 / np.sqrt(self.dim_embed))
self.perceiver_io = PerceiverBlock(self.dim_embed, self.cross_attn_num_heads)
else:
self.perceiver_io = torch.nn.Identity()

Shouldn't we also set self.selu = torch.nn.Identity() here? Otherwise there is a difference between num_queries = 0 and no cross-attention parameters, which is odd.
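
A compact sketch of the suggested symmetry (MultiheadAttention stands in for the PR's PerceiverBlock; the class is a hypothetical reduction, not the PR's code):

import torch


class EmbeddingSketch(torch.nn.Module):
    def __init__(self, cross_attn_params=None, dim_embed: int = 128):
        super().__init__()
        if cross_attn_params is not None:
            self.selu = torch.nn.SELU()
            # stand-in for the PR's PerceiverBlock
            self.perceiver_io = torch.nn.MultiheadAttention(dim_embed, 1, batch_first=True)
        else:
            # Disable both modules together so that num_queries = 0 and a
            # missing cross_attn_params behave identically in forward().
            self.selu = torch.nn.Identity()
            self.perceiver_io = torch.nn.Identity()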

x = peh(self.selu(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False)))


if self.num_queries > 0:

Please annotate the tensor shapes (batch dimension etc.); right now this is very hard to review/follow.
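
A sketch of the kind of annotation being asked for, applied to the lines above (the shape names are assumptions, not verified against the PR):

# x_in: (batch_cells, token_dim, num_channels); transpose so the embed MLP acts per channel
x = peh(self.selu(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False)))
# x: (batch_cells, num_channels, dim_embed)

if self.num_queries > 0:
    # queries: (1, num_queries, dim_embed), expanded over the batch;
    # perceiver_io compresses num_channels -> num_queries,
    # output: (batch_cells, num_queries, dim_embed)
    ...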
