diff --git a/config/default_config.yml b/config/default_config.yml index 56a7c3e25..c7459c5aa 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -66,6 +66,7 @@ latent_noise_gamma: 2.0 latent_noise_saturate_encodings: 5 latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True +encode_targets_latent: False loss_fcts: - diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index c41916f9d..a996806e7 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -31,6 +31,7 @@ from weathergen.datasets.utils import ( compute_idxs_predict, compute_offsets_scatter_embed, + compute_offsets_scatter_embed_target_source_like, compute_source_cell_lens, ) from weathergen.utils.logger import logger @@ -348,6 +349,7 @@ def __iter__(self): rdata.datetimes, (time_win1.start, time_win1.end), ds, + "source_normalizer", ) stream_data.add_source(rdata_wrapped, ss_lens, ss_cells, ss_centroids) @@ -367,6 +369,7 @@ def __iter__(self): if rdata.is_empty(): stream_data.add_empty_target(fstep) + stream_data.add_empty_target_source_like(fstep) else: (tt_cells, tc, tt_c, tt_t) = self.tokenizer.batchify_target( stream_info, @@ -379,7 +382,32 @@ def __iter__(self): ds, ) + target_raw_source_like = torch.from_numpy( + np.concatenate((rdata.coords, rdata.geoinfos, rdata.data), 1) + ) + ( + tt_cells_source_like, + tt_lens_source_like, + tt_centroids_source_like, + ) = self.tokenizer.batchify_source( + stream_info, + torch.from_numpy(rdata.coords), + torch.from_numpy(rdata.geoinfos), + torch.from_numpy(rdata.data), + rdata.datetimes, + (time_win2.start, time_win2.end), + ds, + "target_normalizer", + ) + stream_data.add_target(fstep, tt_cells, tc, tt_c, tt_t) + stream_data.add_target_source_like( + fstep, + target_raw_source_like, + tt_lens_source_like, + tt_cells_source_like, + tt_centroids_source_like, + ) # merge inputs 
for sources and targets for current stream stream_data.merge_inputs() @@ -398,6 +426,7 @@ def __iter__(self): # compute offsets for scatter computation after embedding batch = compute_offsets_scatter_embed(batch) + batch = compute_offsets_scatter_embed_target_source_like(batch) # compute offsets and auxiliary data needed for prediction computation # (info is not per stream so separate data structure) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index a5f12327e..e096a230b 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -66,6 +66,17 @@ def __init__(self, idx: int, forecast_steps: int, nhc_source: int, nhc_target: i self.source_idxs_embed = torch.tensor([]) self.source_idxs_embed_pe = torch.tensor([]) + # below are for targets which are tokenized like sources + self.target_source_like_raw = [[] for _ in range(forecast_steps + 1)] + self.target_source_like_tokens_lens = [[] for _ in range(forecast_steps + 1)] + self.target_source_like_tokens_cells = [[] for _ in range(forecast_steps + 1)] + self.target_source_like_centroids = [[] for _ in range(forecast_steps + 1)] + + self.target_source_like_idxs_embed = [torch.tensor([]) for _ in range(forecast_steps + 1)] + self.target_source_like_idxs_embed_pe = [ + torch.tensor([]) for _ in range(forecast_steps + 1) + ] + def to_device(self, device="cuda") -> None: """ Move data to GPU @@ -91,6 +102,26 @@ def to_device(self, device="cuda") -> None: self.source_idxs_embed = self.source_idxs_embed.to(device, non_blocking=True) self.source_idxs_embed_pe = self.source_idxs_embed_pe.to(device, non_blocking=True) + self.target_source_like_raw = [ + t.to(device, non_blocking=True) for t in self.target_source_like_raw + ] + self.target_source_like_tokens_lens = [ + t.to(device, non_blocking=True) for t in self.target_source_like_tokens_lens + ] + self.target_source_like_tokens_cells = [ + t.to(device, non_blocking=True) for t in 
def add_empty_target_source_like(self, fstep: int) -> None:
    """
    Add empty placeholders for one input of the source-like tokenized target.

    Parameters
    ----------
    fstep : int
        forecast step whose per-input lists receive the placeholders

    Returns
    -------
    None
    """

    self.target_source_like_raw[fstep].append(torch.tensor([]))
    # one zero token count per source healpix cell so that the per-input lens
    # can still be stacked and summed when the inputs are merged
    self.target_source_like_tokens_lens[fstep].append(
        torch.zeros([self.nhc_source], dtype=torch.int32)
    )
    self.target_source_like_tokens_cells[fstep].append(torch.tensor([]))
    self.target_source_like_centroids[fstep].append(torch.tensor([]))


def add_target_source_like(
    self,
    fstep: int,
    tt_raw: torch.Tensor,
    tt_lens: torch.Tensor,
    tt_cells: list,
    tt_centroids: list,
) -> None:
    """
    Add target data that was tokenized like a source for one input.

    Parameters
    ----------
    fstep : int
        forecast step the data belongs to
    tt_raw : torch.Tensor
        raw target data; the caller passes coords, geoinfos and data concatenated along dim 1
    tt_lens : torch.Tensor
        token counts as returned by the tokenizer's batchify_source
        (presumably one entry per healpix cell -- confirm against the tokenizer)
    tt_cells : list
        tokens as returned by batchify_source (presumably one tensor per healpix cell)
    tt_centroids : list
        centroids as returned by batchify_source (presumably one tensor per healpix cell)

    Returns
    -------
    None
    """

    self.target_source_like_raw[fstep].append(tt_raw)
    self.target_source_like_tokens_lens[fstep].append(tt_lens)
    self.target_source_like_tokens_cells[fstep].append(tt_cells)
    self.target_source_like_centroids[fstep].append(tt_centroids)
self.target_source_like_tokens_lens[fstep] = torch.zeros([self.nhc_source]) + self.target_source_like_tokens_cells[fstep] = torch.tensor([]) + self.target_source_like_centroids[fstep] = torch.tensor([]) + # targets for fstep in range(len(self.target_coords)): # collect all targets in current stream and add to batch sample list when non-empty diff --git a/src/weathergen/datasets/tokenizer_forecast.py b/src/weathergen/datasets/tokenizer_forecast.py index 3b17fddb2..be1d4103f 100644 --- a/src/weathergen/datasets/tokenizer_forecast.py +++ b/src/weathergen/datasets/tokenizer_forecast.py @@ -41,13 +41,20 @@ def batchify_source( source: np.array, times: np.array, time_win: tuple, - normalizer, # dataset + normalizer, # dataset, + use_normalizer: str, # "source_normalizer" or "target_normalizer" ): init_loggers() token_size = stream_info["token_size"] is_diagnostic = stream_info.get("diagnostic", False) tokenize_spacetime = stream_info.get("tokenize_spacetime", False) + channel_normalizer = ( + normalizer.normalize_source_channels + if use_normalizer == "source_normalizer" + else normalizer.normalize_target_channels + ) + tokenize_window = partial( tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space, time_win=time_win, @@ -56,7 +63,7 @@ def batchify_source( hpy_verts_rots=self.hpy_verts_rots_source[-1], n_coords=normalizer.normalize_coords, n_geoinfos=normalizer.normalize_geoinfos, - n_data=normalizer.normalize_source_channels, + n_data=channel_normalizer, enc_time=encode_times_source, ) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 3da8508f6..cab667d9e 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -48,12 +48,19 @@ def batchify_source( times: np.array, time_win: tuple, normalizer, # dataset + use_normalizer: str, # "source_normalizer" or "target_normalizer" ): init_loggers() token_size = stream_info["token_size"] 
is_diagnostic = stream_info.get("diagnostic", False) tokenize_spacetime = stream_info.get("tokenize_spacetime", False) + channel_normalizer = ( + normalizer.normalize_source_channels + if use_normalizer == "source_normalizer" + else normalizer.normalize_target_channels + ) + tokenize_window = partial( tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space, time_win=time_win, @@ -62,7 +69,7 @@ def batchify_source( hpy_verts_rots=self.hpy_verts_rots_source[-1], n_coords=normalizer.normalize_coords, n_geoinfos=normalizer.normalize_geoinfos, - n_data=normalizer.normalize_source_channels, + n_data=channel_normalizer, enc_time=encode_times_source, ) diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index b5d2279b8..1f9760c49 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -677,6 +677,84 @@ def compute_offsets_scatter_embed(batch: StreamData) -> StreamData: return batch +def compute_offsets_scatter_embed_target_source_like(batch: StreamData) -> StreamData: + """ + Compute auxiliary information for scatter operation that changes from stream-centric to + cell-centric computations + + Parameters + ---------- + batch : str + batch of stream data information for which offsets have to be computed + + Returns + ------- + StreamData + stream data with offsets added as members + """ + + # collect source_tokens_lens for all stream datas + target_source_like_tokens_lens = torch.stack( + [ + torch.stack( + [ + torch.stack( + [ + s.target_source_like_tokens_lens[fstep] + if len(s.target_source_like_tokens_lens[fstep]) > 0 + else torch.tensor([]) + for fstep in range(len(s.target_source_like_tokens_lens)) + ] + ) + for s in stl_b + ] + ) + for stl_b in batch + ] + ) + + # precompute index sets for scatter operation after embed + offsets_base = target_source_like_tokens_lens.sum(1).sum(0).cumsum(1) + # shift the offsets for each fstep by one to the right, add a zero to the + # beginning as the 
first token starts at 0 + zeros_col = torch.zeros( + (offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device + ) + offsets = torch.cat([zeros_col, offsets_base[:, :-1]], dim=1) + offsets_pe = torch.zeros_like(offsets) + + for ib, sb in enumerate(batch): + for itype, s in enumerate(sb): + for fstep in range(offsets.shape[0]): + if target_source_like_tokens_lens[ib, itype, fstep].sum() != 0: # if not empty + s.target_source_like_idxs_embed[fstep] = torch.cat( + [ + torch.arange(offset, offset + token_len, dtype=torch.int64) + for offset, token_len in zip( + offsets[fstep], + target_source_like_tokens_lens[ib, itype, fstep], + strict=False, + ) + ] + ) + s.target_source_like_idxs_embed_pe[fstep] = torch.cat( + [ + torch.arange(offset, offset + token_len, dtype=torch.int32) + for offset, token_len in zip( + offsets_pe[fstep], + target_source_like_tokens_lens[ib][itype][fstep], + strict=False, + ) + ] + ) + + # advance offsets + offsets[fstep] += target_source_like_tokens_lens[ib][itype][fstep] + offsets_pe[fstep] += target_source_like_tokens_lens[ib][itype][fstep] + + return batch + + def compute_idxs_predict(forecast_dt: int, batch: StreamData) -> list: """ Compute auxiliary information for prediction diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 29bae5806..1bb2d6830 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -40,7 +40,7 @@ def __init__(self, cf: Config, sources_size) -> None: :param sources_size: List of source sizes for each stream. """ self.cf = cf - self.sources_size = sources_size # KCT:iss130, what is this? 
+ self.sources_size = sources_size self.embeds = torch.nn.ModuleList() def create(self) -> torch.nn.ModuleList: diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 136d96149..5f068e9d9 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -520,6 +520,7 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca # roll-out in latent space preds_all = [] + tokens_all = [tokens] for fstep in range(forecast_offset, forecast_offset + forecast_steps): # prediction preds_all += [ @@ -533,6 +534,7 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca ] tokens = self.forecast(model_params, tokens) + tokens_all += [tokens] # prediction for final step preds_all += [ @@ -545,7 +547,31 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca ) ] - return preds_all, posteriors + # now encode the targets into the latent space for all fsteps + if self.cf.get("encode_targets_latent", False): + with torch.no_grad(): + tokens_targets = [] + tokens_targets_source_like = self.embed_cells_targets_source_like( + model_params, streams_data + ) + for fstep in range(len(tokens_targets_source_like)): + if tokens_targets_source_like[fstep].sum() == 0: + # if the input is empty, return an empty tensor + tokens_targets.append(torch.tensor([]).detach()) + else: + tokens_target, _ = self.assimilate_local( + model_params, tokens_targets_source_like[fstep], source_cell_lens + ) + tokens_target = self.assimilate_global(model_params, tokens_target) + tokens_target_det = tokens_target.detach() # explicitly detach as well + tokens_targets.append(tokens_target_det) + + return_dict = {"preds_all": preds_all, "posteriors": posteriors} + if self.cf.get("encode_targets_latent", False): + return_dict["tokens_all"] = tokens_all + return_dict["tokens_targets"] = tokens_targets + + return return_dict ######################################### def embed_cells(self, 
model_params: ModelParams, streams_data) -> torch.Tensor: @@ -600,6 +626,76 @@ def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: return tokens_all + def embed_cells_targets_source_like( + self, model_params: ModelParams, streams_data + ) -> torch.Tensor: + """Embeds target data similar to source tokens for each fstep and stream separately and + rearranges it to cell-wise order + Args: + model_params : Query and embedding parameters + streams_data : Used to initialize first tokens for pre-processing + Returns: + Tokens for local assimilation + """ + with torch.no_grad(): + target_source_like_tokens_lens = torch.stack( + [ + torch.stack( + [ + torch.stack( + [ + s.target_source_like_tokens_lens[fstep] + if len(s.target_source_like_tokens_lens[fstep]) > 0 + else torch.tensor([]) + for fstep in range(len(s.target_source_like_tokens_lens)) + ] + ) + for s in stl_b + ] + ) + for stl_b in streams_data + ] + ) + offsets_base = target_source_like_tokens_lens.sum(1).sum(0).cumsum(1) + num_fsteps = target_source_like_tokens_lens.shape[2] + tokens_all = [] + for fstep in range(num_fsteps): + tokens_all.append( + torch.empty( + (int(offsets_base[fstep][-1]), self.cf.ae_local_dim_embed), + dtype=self.dtype, + device="cuda", + ) + ) + + tokens_all_scattered = [] + for _, sb in enumerate(streams_data): + for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): + for fstep in range(num_fsteps): + if s.target_source_like_tokens_lens[fstep].sum() != 0: + idxs = s.target_source_like_idxs_embed[fstep] + idxs_pe = s.target_source_like_idxs_embed_pe[fstep] + + # create full scatter index + # (there's no broadcasting which is likely highly inefficient) + idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) + + x_embed = embed( + s.target_source_like_tokens_cells[fstep], + s.target_source_like_centroids[fstep], + ).flatten(0, 1) + + # scatter write to reorder from per stream to per cell ordering + tokens_all_fstep = tokens_all[fstep] + 
tokens_all_fstep.scatter_( + 0, idxs, x_embed + model_params.pe_embed[idxs_pe] + ) + tokens_all_scattered.append(tokens_all_fstep) + else: + tokens_all_scattered.append(torch.tensor([])) + + return tokens_all_scattered + ######################################### def assimilate_local( self, model_params: ModelParams, tokens: torch.Tensor, cell_lens: torch.Tensor diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f3eb8850e..6047d5162 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -501,9 +501,12 @@ def train(self, epoch): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - preds, posteriors = self.ddp_model( + model_output = self.ddp_model( self.model_params, batch, cf.forecast_offset, forecast_steps ) + preds = model_output["preds_all"] + posteriors = model_output["posteriors"] + loss_values = self.loss_calculator.compute_loss( preds=preds, streams_data=batch[0], @@ -569,9 +572,10 @@ def validate(self, epoch): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - preds, _ = self.ddp_model( + model_output = self.ddp_model( self.model_params, batch, cf.forecast_offset, forecast_steps ) + preds = model_output["preds_all"] # compute loss and log output if bidx < cf.log_validation: