config/config_era5_split_causal.yml (154 additions, 0 deletions)
@@ -0,0 +1,154 @@
streams_directory: "./config/streams/era5_split/"

embed_orientation: "channels"
embed_local_coords: True
embed_centroids_local_coords: False
embed_size_centroids: 0
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

target_cell_local_prediction: True

ae_local_dim_embed: 128
ae_local_num_blocks: 2
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
ae_global_num_blocks: 8
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
ae_global_att_dense_rate: 0.2
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # alternative: CrossAttentionAdaNormConditioning
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True

# number of steps offset applied to the first target window; with forecast_offset=0 and
# forecast_steps=0 one is training an auto-encoder
forecast_offset: 0
forecast_delta_hrs: 0
forecast_steps: 0
forecast_policy: null
forecast_freeze_model: False
forecast_att_dense_rate: 1.0
fe_num_blocks: 0
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
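# For contrast, the companion config_era5_split_forecast.yml below activates this
# block: it sets forecast_offset: 1, forecast_steps: 2, forecast_policy: "fixed"
# and fe_num_blocks: 8, so the forecast engine is trained rather than bypassed.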

healpix_level: 5

with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: True
attention_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

loss_fcts:
  -
    - "mse"
    - 1.0
loss_fcts_val:
  -
    - "mse"
    - 1.0
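# A hypothetical sketch of combining several weighted losses (assumption: entries
# are summed with their weights, and the second loss name is illustrative only and
# may not exist in the code base):
# loss_fcts:
#   -
#     - "mse"
#     - 0.8
#   -
#     - "mae"
#     - 0.2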

batch_size_per_gpu: 1
batch_size_validation_per_gpu: 1

# a regex that needs to fully match the name of the modules you want to freeze
# e.g. ".*ERA5" will match any module whose name ends in ERA5\
# encoders and decoders that exist per stream have the stream name attached at the end
freeze_modules: ""
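# e.g. to freeze every ERA5-specific encoder and decoder (per the note above):
# freeze_modules: ".*ERA5"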

# training mode: "forecast" or "masking" (masked token modeling)
# for "masking" to train with auto-encoder mode, forecast_offset should be 0
training_mode: "masking"
# masking rate when training mode is "masking"; ignored in forecast mode
masking_rate: 0.5
# sample the masking rate (from a normal distribution centered at masking_rate);
# note that a sampled masking rate leads to varying resource requirements per batch
masking_rate_sampling: True
# sample a subset of all target points, useful e.g. to reduce memory requirements (also can specify per-stream)
sampling_rate_target: 1.0
# masking strategy; currently supported: "random", "block", "healpix", "channel", "causal" and "combination"
masking_strategy: "causal"
# masking_strategy_config is a dictionary of additional parameters for the masking strategy
# required for "healpix" and "channel" masking strategies
# "healpix": requires healpix mask level to be specified with `hl_mask`
# "channel": requires "mode" to be specified, "per_cell" or "global",
masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
"probabilities": [0.34, 0.33, 0.33],
"hl_mask": 3, "mode": "per_cell",
"same_strategy_per_batch": false
}
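# Minimal sketches, assuming only the fields named in the comments above are needed
# for a single strategy:
#   "healpix" alone: masking_strategy_config: {"hl_mask": 3}
#   "channel" alone: masking_strategy_config: {"mode": "per_cell"}
# (the "combination"-style keys above are presumably ignored by the "causal" strategy)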

num_epochs: 48
samples_per_epoch: 4096
samples_per_validation: 512
shuffle: True

lr_scaling_policy: "sqrt"
lr_start: 1e-6
lr_max: 5e-5
lr_final_decay: 1e-6
lr_final: 0.0
lr_steps_warmup: 512
lr_steps_cooldown: 512
lr_policy_warmup: "cosine"
lr_policy_decay: "linear"
lr_policy_cooldown: "linear"

grad_clip: 1.0
weight_decay: 0.1
norm_type: "LayerNorm"
nn_module: "te"

start_date: 197901010000
end_date: 202012310000
start_date_val: 202101010000
end_date_val: 202201010000
len_hrs: 12
step_hrs: 6
input_window_steps: 1

val_initial: False

loader_num_workers: 8
log_validation: 0
analysis_streams_output: ["ERA5"]

istep: 0
run_history: []

desc: ""
data_loader_rng_seed: ???
run_id: ???

# Parameters for logging/printing in the training loop
train_log:
  # The period to log metrics (in number of batch steps)
  log_interval: 20
config/config_era5_split_forecast.yml (154 additions, 0 deletions)
@@ -0,0 +1,154 @@
streams_directory: "./config/streams/era5_split/"

embed_orientation: "channels"
embed_local_coords: True
embed_centroids_local_coords: False
embed_size_centroids: 0
embed_unembed_mode: "block"
embed_dropout_rate: 0.1

target_cell_local_prediction: True

ae_local_dim_embed: 128
ae_local_num_blocks: 2
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True

ae_local_num_queries: 1
ae_local_queries_per_cell: False
ae_adapter_num_heads: 16
ae_adapter_embed: 128
ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
ae_global_num_blocks: 8
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
ae_global_att_dense_rate: 0.2
ae_global_block_factor: 64
ae_global_mlp_hidden_factor: 2

decoder_type: PerceiverIOCoordConditioning # alternative: CrossAttentionAdaNormConditioning
pred_adapter_kv: False
pred_self_attention: True
pred_dyadic_dims: False
pred_mlp_adaln: True

# number of steps offset applied to the first target window; with forecast_offset=0 and
# forecast_steps=0 one is training an auto-encoder
forecast_offset: 1
forecast_delta_hrs: 0
forecast_steps: 2
forecast_policy: "fixed"
forecast_freeze_model: False
forecast_att_dense_rate: 1.0
fe_num_blocks: 8
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
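# With forecast_offset: 1 and forecast_steps: 2 the first target window starts one
# step after the input and the model is presumably unrolled autoregressively for two
# steps; the exact rollout semantics are defined by the training code.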

healpix_level: 5

with_mixed_precision: True
with_flash_attention: True
compile_model: False
with_fsdp: True
attention_dtype: bf16
mlp_norm_eps: 1e-5
norm_eps: 1e-4

latent_noise_kl_weight: 0.0 # 1e-5
latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

loss_fcts:
  -
    - "mse"
    - 1.0
loss_fcts_val:
  -
    - "mse"
    - 1.0

batch_size_per_gpu: 1
batch_size_validation_per_gpu: 1

# a regex that needs to fully match the name of the modules you want to freeze
# e.g. ".*ERA5" will match any module whose name ends in ERA5\
# encoders and decoders that exist per stream have the stream name attached at the end
freeze_modules: ""

# training mode: "forecast" or "masking" (masked token modeling)
# for "masking" to train with auto-encoder mode, forecast_offset should be 0
training_mode: "forecast"
# masking rate when training mode is "masking"; ignored in forecast mode
masking_rate: 0.8
# sample the masking rate (from a normal distribution centered at masking_rate);
# note that a sampled masking rate leads to varying resource requirements per batch
masking_rate_sampling: True
# sample a subset of all target points, useful e.g. to reduce memory requirements (also can specify per-stream)
sampling_rate_target: 1.0
# masking strategy; currently supported: "random", "block", "healpix", "channel", "causal" and "combination"
masking_strategy: "random"
# masking_strategy_config is a dictionary of additional parameters for the masking strategy
# required for "healpix" and "channel" masking strategies
# "healpix": requires healpix mask level to be specified with `hl_mask`
# "channel": requires "mode" to be specified, "per_cell" or "global",
masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
"probabilities": [0.34, 0.33, 0.33],
"hl_mask": 3, "mode": "per_cell",
"same_strategy_per_batch": false
}
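# Note: training_mode is "forecast" here, so masking_rate and the masking-strategy
# settings above should be inert (per the "ignored in forecast mode" comment).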

num_epochs: 48
samples_per_epoch: 4096
samples_per_validation: 512
shuffle: True

lr_scaling_policy: "sqrt"
lr_start: 1e-6
lr_max: 5e-5
lr_final_decay: 1e-6
lr_final: 0.0
lr_steps_warmup: 512
lr_steps_cooldown: 512
lr_policy_warmup: "cosine"
lr_policy_decay: "linear"
lr_policy_cooldown: "linear"

grad_clip: 1.0
weight_decay: 0.1
norm_type: "LayerNorm"
nn_module: "te"

start_date: 197901010000
end_date: 202012310000
start_date_val: 202101010000
end_date_val: 202201010000
len_hrs: 6
step_hrs: 6
input_window_steps: 1

val_initial: False

loader_num_workers: 8
log_validation: 0
analysis_streams_output: ["ERA5"]

istep: 0
run_history: []

desc: ""
data_loader_rng_seed: ???
run_id: ???

# Parameters for logging/printing in the training loop
train_log:
  # The period to log metrics (in number of batch steps)
  log_interval: 20
config/streams/era5_split/era5_t.yml (36 additions, 0 deletions)
@@ -0,0 +1,36 @@
# (C) Copyright 2024 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

ERA5_t :
  type : anemoi
  filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6.zarr']
  source : ['2t', 't_']
  target : ['2t', 't_']
  loss_weight : 1.
  masking_rate_none : 0.05
  token_size : 8
  tokenize_spacetime : True
  max_num_targets: -1
  embed :
    net : transformer
    num_tokens : 1
    num_heads : 8
    dim_embed : 256
    num_blocks : 2
  embed_target_coords :
    net : linear
    dim_embed : 256
  target_readout :
    type : 'obs_value' # token or obs_value
    num_layers : 2
    num_heads : 4
    # sampling_rate : 0.2
  pred_head :
    ens_size : 1
    num_layers : 1
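# A sketch (assumption, following the commented-out field above and the "also can
# specify per-stream" note in the main configs): uncommenting sampling_rate under
# target_readout would subsample this stream's target points, e.g. to save memory.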
config/streams/era5_split/era5_uv.yml (36 additions, 0 deletions)
@@ -0,0 +1,36 @@
# (C) Copyright 2024 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

ERA5_uv :
  type : anemoi
  filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6.zarr']
  source : ['10u', '10v', 'u_', 'v_']
  target : ['10u', '10v', 'u_', 'v_']
  loss_weight : 1.
  masking_rate_none : 0.05
  token_size : 8
  tokenize_spacetime : True
  max_num_targets: -1
  embed :
    net : transformer
    num_tokens : 1
    num_heads : 8
    dim_embed : 256
    num_blocks : 2
  embed_target_coords :
    net : linear
    dim_embed : 256
  target_readout :
    type : 'obs_value' # token or obs_value
    num_layers : 2
    num_heads : 4
    # sampling_rate : 0.2
  pred_head :
    ens_size : 1
    num_layers : 1
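# Note: the trailing underscore in 'u_', 'v_' (and 't_' in era5_t.yml) presumably
# selects all pressure-level variants of the variable (assumption; verify against
# the channel names in the anemoi zarr dataset).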