introducing pereceiver-io style module for embedding layer #438
Conversation
…stream replaced cross-attentionheadvarlen with cross-attentionhead (ecmwf#265)

* Adding base class for data readers, with anemoi data reader as example. Also implements (trivial) required changes in MultiStreamDataSampler.
* Ruffed
* Fixed various issues. Training with loss starting to converge. Much more testing needed.
* Fixing bug where cf.shuffle was not passed to data loader.
* Cleaned up problems that occurred with arbitrary-length windows and offsets. Also introduced a base class for time-stepped datasets and moved the function to get dataset indices there. Other smaller code improvements and documentation.
* Changes due to changed interface in TimeWindowHandler
* Fixed typo
* Fixed documentation
* Renaming
* Adapted obs data reader
* [100] Refactor of data readers into clessig's branch (ecmwf#306)
* changes
* changes
* configs
* changes
* changes
* changes
* changes
* better interface
* comments
* Fixed formatting of comment
* Re-enabled obs data reader
* Switched to use of TRAIN and VAL stage flags.
* Fixed problem with target and source channels that can differ between train and val
* Clean-up; also added check_reader_data
* Fixed problem that occurred when a dataset has no overlap with the training/validation range. Also some restructuring to handle this properly.
* Added check for proper time inclusion in interval to check_data_reader. Code adaptations for handling of empty datasets.
* Fixed handling of subsampling_rate / frequency for anemoi dataset.
* Cleaned up special-case handling and removed old comments
* Added missing handling of the special case where t_win is in between two timesteps
* Renamed FESOM data reader and adapted it to the base class etc.
* Fixed incorrect handling of shuffling
* Re-enabled FESOM data reader
* Added warning when a dataset does not overlap the time window of the data loader
* Fixed missing check
* Adding fixed FESOM data reader. Needs to be verified
* Removing spurious files
* small changes
* style
* restored defaults
* Removing log messages that are too verbose
* Fixing sub-optimal solution with placement of warning
* Removing logging that breaks metric plotting (and seems not sensible)
* Re-enabling performance logging
* removing incorrect formatting
* Removed unused variable
* Fixed bug in evaluation with missing usage of new stages
* Removing stream files that need to be considered in more detail before going to develop
* Removing debug dependency
* Restoring defaults
* Introduced stream_info to base class. Cleanup.

---------

Co-authored-by: Timothy Hunter <[email protected]>
@@ -52,6 +77,29 @@ def __init__(
        self.dim_embed = dim_embed
        self.dim_out = dim_out
        self.num_blocks = num_blocks
        if cross_attn_params is not None:
For a couple of reasons it would be better to have an explicit flag (use Perceiver in EmbeddingNetwork) than trying to check whether the config exists. Looking back at logs, it is easier to always check the same flag than to verify that something is not present.
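A minimal sketch of the reviewer's suggestion: gate the Perceiver path on an explicit boolean flag rather than on the presence of the config dict. The class and flag names here are illustrative, not the PR's actual API.

```python
import torch


class EmbeddingNetwork(torch.nn.Module):
    """Sketch: an explicit use_perceiver flag controls the Perceiver path,
    so logs and configs always show the same, easily greppable switch."""

    def __init__(self, dim_embed, use_perceiver=False, cross_attn_params=None):
        super().__init__()
        self.dim_embed = dim_embed
        self.use_perceiver = use_perceiver
        if self.use_perceiver:
            # defaults still apply when the params dict is absent or partial
            params = cross_attn_params or {}
            self.cross_attn_num_blocks = params.get("num_blocks", 1)
            self.cross_attn_num_heads = params.get("num_heads", 1)
```

With this shape, `cross_attn_params` can stay optional while the behaviour is decided by one flag that is always present in the config.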
num_channels = self.num_queries
self.cross_attn_num_blocks = cross_attn_params.get("num_blocks", 1) if cross_attn_params else 1
self.cross_attn_num_heads = cross_attn_params.get("num_heads", 1) if cross_attn_params else 1
self.queries = torch.nn.Parameter(
Pretty sure this code wasn't formatted
self.cross_attn_num_blocks = cross_attn_params.get("num_blocks", 1) if cross_attn_params else 1
self.cross_attn_num_heads = cross_attn_params.get("num_heads", 1) if cross_attn_params else 1
self.queries = torch.nn.Parameter(
    torch.randn(1, self.num_queries, self.dim_embed)
These parameters should be initialised to be much smaller than randn; please use an explicit init strategy! For a Gaussian, the standard deviation should be something like 0.01 to 0.0005, but really there should be an init_weights function that explicitly initialises all the weights of the module.
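A sketch of the requested pattern: allocate the parameter uninitialised and set it in a dedicated `init_weights` method with an explicit, small standard deviation. Class and method names are illustrative; the `1/sqrt(dim_embed)` choice mirrors the line the PR later adds, while a fixed small std in the reviewer's suggested range would also fit.

```python
import numpy as np
import torch


class PerceiverQueries(torch.nn.Module):
    """Sketch: learned latent queries with an explicit init strategy."""

    def __init__(self, num_queries, dim_embed):
        super().__init__()
        self.dim_embed = dim_embed
        # allocate without implicit randn init; init_weights decides the scale
        self.queries = torch.nn.Parameter(torch.empty(1, num_queries, dim_embed))
        self.init_weights()

    @torch.no_grad()
    def init_weights(self):
        # explicit small Gaussian; std could also be a fixed value in the
        # 0.01 to 0.0005 range the reviewer suggests
        self.queries.normal_(mean=0.0, std=1.0 / np.sqrt(self.dim_embed))
```

Centralising initialisation in one method makes the scale of every weight auditable in a single place instead of being spread across constructor expressions.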
self.queries.normal_(mean=0.0, std=1.0 / np.sqrt(self.dim_embed))
self.perceiver_io = PerceiverBlock(self.dim_embed, self.cross_attn_num_heads)
else:
    self.perceiver_io = torch.nn.Identity()
Shouldn't we also set self.selu = torch.nn.Identity() here? Otherwise there is a difference between num_queries = 0 and no cross-attention parameters, which is odd.
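A small sketch of what the reviewer is asking for, under the assumption that the forward pass applies `self.selu` around the embedding: when the Perceiver path is disabled, both the Perceiver block and the activation become identities, so the two "no Perceiver" configurations behave identically. Names are illustrative.

```python
import torch


class EmbedHead(torch.nn.Module):
    """Sketch: keep the no-Perceiver branch consistent with num_queries == 0."""

    def __init__(self, use_perceiver):
        super().__init__()
        if use_perceiver:
            self.selu = torch.nn.SELU()
            # self.perceiver_io = PerceiverBlock(...) in the real code
        else:
            self.perceiver_io = torch.nn.Identity()
            self.selu = torch.nn.Identity()  # avoid a hidden behavioural difference
```

With both set to `Identity`, disabling cross-attention is a true no-op rather than silently keeping an extra nonlinearity.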
x = peh(self.selu(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False)))

if self.num_queries > 0:
Please annotate the tensor shapes (batch dimension etc.); right now this is very hard to review/follow.
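An example of the kind of shape annotation being requested, with assumed dimension names (B = batch, C = channels, D = embedding dim; the sizes and the `Linear` stand-in for `self.embed` are illustrative only):

```python
import torch

B, C, D = 2, 96, 32            # batch, channels, embedding dim (illustrative)
x_in = torch.randn(B, D, C)    # (B, D, C): channels last, as before transpose

embed = torch.nn.Linear(D, D)  # stand-in for self.embed
x = embed(x_in.transpose(-2, -1))  # (B, C, D): channels become tokens
```

A one-line shape comment per tensor-producing statement is usually enough to make a forward pass reviewable.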
Description
For channel-wise attention on ERA5, we currently use only 96 channels. As more channels are incorporated in the future, the computational burden will increase significantly. To address this, we introduce a Perceiver IO module before the channel-wise attention stage.
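The idea can be sketched as follows: a small set of learned queries cross-attends to the per-channel tokens, so the downstream channel-wise attention runs over `num_queries` tokens instead of the full (and growing) channel count. This is a minimal, self-contained illustration, not the PR's `PerceiverBlock`; all names and sizes are assumptions.

```python
import torch


class PerceiverReducer(torch.nn.Module):
    """Sketch: compress C channel tokens into Q latent tokens via cross-attention."""

    def __init__(self, dim_embed, num_queries, num_heads=1):
        super().__init__()
        # small explicit init for the learned latent queries
        self.queries = torch.nn.Parameter(torch.randn(1, num_queries, dim_embed) * 0.02)
        self.cross_attn = torch.nn.MultiheadAttention(dim_embed, num_heads, batch_first=True)

    def forward(self, x):                               # x: (B, C, D)
        q = self.queries.expand(x.shape[0], -1, -1)     # (B, Q, D)
        out, _ = self.cross_attn(q, x, x)               # (B, Q, D)
        return out


x = torch.randn(4, 96, 32)      # e.g. 96 ERA5 channels
reducer = PerceiverReducer(dim_embed=32, num_queries=8)
y = reducer(x)                  # (4, 8, 32): attention cost now scales with Q, not C
```

Because the latent set size is fixed, the cost of the channel-wise attention stage stays constant as more channels are added.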
Type of Change

Issue Number

Open #225

Code Compatibility

Code Performance and Testing

- I have run `uv run train` and (if necessary) `uv run evaluate` on at least one GPU node and it works
- `$WEATHER_GENERATOR_PRIVATE` directory

Dependencies

Documentation

Additional Notes