Skip to content

Commit 9bd5290

Browse files
committed
Simplify training
1 parent 0b2d50c commit 9bd5290

File tree

9 files changed

+58
-333
lines changed

9 files changed

+58
-333
lines changed

README.md

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,19 @@ python download_data.py -d era5 -s 1987-01-01 -e 2024-05-31
4646

4747
Mediterranean reanalysis
4848
```
49-
python prepare_states.py -d data/mediterranean/raw/reanalysis -o data/mediterranean/samples/train -n 6 -p rea_data -s 1987-01-01 -e 2021-11-30
50-
python prepare_states.py -d data/mediterranean/raw/reanalysis -o data/mediterranean/samples/val -n 6 -p rea_data -s 2021-11-01 -e 2021-12-31
51-
python prepare_states.py -d data/mediterranean/raw/reanalysis -o data/mediterranean/samples/test -n 16 -p rea_data -s 2022-01-01 -e 2022-07-31
49+
python prepare_states.py -d data/mediterranean/raw/reanalysis -o data/mediterranean/samples/train -n 6 -p rea_data -s 1987-01-01 -e 2021-12-31
5250
```
5351

5452
Mediterranean analysis
5553
```
56-
python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/train -n 6 -p ana_data -s 2021-11-01 -e 2024-03-31
57-
python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/val -n 6 -p ana_data -s 2024-04-01 -e 2024-05-31
54+
python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/train -n 6 -p ana_data -s 2022-01-01 -e 2024-04-30
55+
python prepare_states.py -d data/mediterranean/raw/analysis -o data/mediterranean/samples/val -n 6 -p ana_data -s 2024-05-01 -e 2024-06-30
5856
```
5957

6058
ERA5
6159
```
62-
python prepare_states.py -d data/mediterranean/raw/era5 -o data/mediterranean/samples/train -n 6 -p forcing -s 1987-01-01 -e 2024-03-31
63-
python prepare_states.py -d data/mediterranean/raw/era5 -o data/mediterranean/samples/val -n 6 -p forcing -s 2021-11-01 -e 2021-12-31
64-
python prepare_states.py -d data/mediterranean/raw/era5 -o data/mediterranean/samples/val -n 6 -p forcing -s 2024-04-01 -e 2024-05-31
65-
python prepare_states.py -d data/mediterranean/raw/era5 -o data/mediterranean/samples/test -n 16 -p forcing -s 2022-01-01 -e 2022-07-31
60+
python prepare_states.py -d data/mediterranean/raw/era5 -o data/mediterranean/samples/train -n 6 -p forcing -s 1987-01-01 -e 2024-04-30
61+
python prepare_states.py -d data/mediterranean/raw/era5 -o data/mediterranean/samples/val -n 6 -p forcing -s 2024-05-01 -e 2024-06-30
6662
```
6763

6864
Forecast data
@@ -111,13 +107,13 @@ wandb off
111107
SeaCast was trained on 4 nodes with 8 GPUs each:
112108
```
113109
python train_model.py \
114-
--data_subset reanalysis \
115110
--epochs 200 \
116111
--n_workers 4 \
117112
--batch_size 1 \
118113
--step_length 1 \
119114
--ar_steps 4 \
120115
--lr 0.001 \
116+
--optimizer momo_adam \
121117
--scheduler cosine \
122118
--finetune_start 0.6 \
123119
--model hi_lam \
@@ -126,16 +122,15 @@ python train_model.py \
126122
--hidden_dim 128 \
127123
--n_nodes 4
128124
```
129-
For finetuning, update the arguments `--data_subset analysis`, `--epochs 50` and `--lr 0.0001`.
130125

131126
For a full list of possible training options, check `python train_model.py --help`.
132127

133128
## Evaluation
134129

135-
SeaCast was evaluated on 1 GPU using `--eval test`, and by choosing the correct data subset + loading the appropriate model:
130+
SeaCast was evaluated on 1 GPU using `--eval test`:
136131
```
137132
python train_model.py \
138-
--data_subset analysis \
133+
--forcing_prefix aifs_forcing \
139134
--n_workers 4 \
140135
--batch_size 1 \
141136
--step_length 1 \
@@ -188,7 +183,8 @@ data
188183
│ ├── parameter_std.pt - Std.-dev. of state parameters (create_parameter_weights.py)
189184
│ ├── diff_mean.pt - Means of one-step differences (create_parameter_weights.py)
190185
│ ├── diff_std.pt - Std.-dev. of one-step differences (create_parameter_weights.py)
191-
│ ├── forcing_stats.pt - Mean and std.-dev. of forcing (create_parameter_weights.py)
186+
│ ├── forcing_mean.pt - Means of atmospheric forcing (create_parameter_weights.py)
187+
│ ├── forcing_std.pt - Std.-dev. of atmospheric forcing (create_parameter_weights.py)
192188
│ └── parameter_weights.npy - Loss weights for different state parameters (create_parameter_weights.py)
193189
├── baltic
194190
├── ...

create_parameter_weights.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,23 +90,29 @@ def main():
9090
) # (N_batch, d_features,)
9191

9292
# Atmospheric forcing at 1st windowed position
93-
forcing_batch = forcing_batch[:, :, :, :6]
94-
forcing_means.append(torch.mean(forcing_batch)) # (,)
95-
forcing_squares.append(torch.mean(forcing_batch**2)) # (,)
93+
forcing_batch = forcing_batch[
94+
:, :, :, :4
95+
] # (N_batch, N_t-2, N_grid, d_atm)
96+
forcing_means.append(
97+
torch.mean(forcing_batch, dim=(1, 2))
98+
) # (N_batch, d_atm)
99+
forcing_squares.append(
100+
torch.mean(forcing_batch**2, dim=(1, 2))
101+
) # (N_batch, d_atm)
96102

97103
mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features)
98104
second_moment = torch.mean(torch.cat(squares, dim=0), dim=0)
99105
std = torch.sqrt(second_moment - mean**2) # (d_features)
100106

101-
forcing_mean = torch.mean(torch.stack(forcing_means)) # (,)
102-
forcing_second_moment = torch.mean(torch.stack(forcing_squares)) # (,)
103-
forcing_std = torch.sqrt(forcing_second_moment - forcing_mean**2) # (,)
104-
forcing_stats = torch.stack((forcing_mean, forcing_std))
107+
forcing_mean = torch.mean(torch.cat(forcing_means, dim=0), dim=0) # (d_atm)
108+
forcing_second_moment = torch.mean(torch.cat(forcing_squares, dim=0), dim=0)
109+
forcing_std = torch.sqrt(forcing_second_moment - forcing_mean**2) # (d_atm)
105110

106111
print("Saving mean, std.-dev...")
107112
torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt"))
108113
torch.save(std, os.path.join(static_dir_path, "parameter_std.pt"))
109-
torch.save(forcing_stats, os.path.join(static_dir_path, "forcing_stats.pt"))
114+
torch.save(forcing_mean, os.path.join(static_dir_path, "forcing_mean.pt"))
115+
torch.save(forcing_std, os.path.join(static_dir_path, "forcing_std.pt"))
110116

111117
# Compute mean and std.-dev. of one-step differences across the dataset
112118
print("Computing mean and std.-dev. for one-step differences...")

0 commit comments

Comments
 (0)