diff --git a/src/weathergen/model/embeddings.py b/src/weathergen/model/embeddings.py index da916413..7545dce1 100644 --- a/src/weathergen/model/embeddings.py +++ b/src/weathergen/model/embeddings.py @@ -48,7 +48,7 @@ def __init__( super(StreamEmbedTransformer, self).__init__() self.name = f"StreamEmbedder_{stream_name}" - + self.mode = mode self.num_tokens = num_tokens self.token_size = token_size self.num_channels = num_channels @@ -113,9 +113,8 @@ def __init__( ) else: - assert False + raise ValueError(f"Unknown unembed mode: {unembed_mode}") - self.forward = self.forward_channels elif mode == "columns": assert embed_size_centroids == 0 @@ -130,7 +129,6 @@ def __init__( self.num_tokens * ((self.dim_out - embed_size_centroids) // token_size), ) self.ln_final = norm(dim_out, eps=1e-6) - self.forward = self.forward_columns # TODO: factorization when sqrt is not int dim1 = int(np.sqrt(dim_out)) @@ -140,7 +138,7 @@ def __init__( self.unembed2 = torch.nn.Linear(self.token_size, dim1) else: - assert False + raise ValueError(f"Unknown mode: {mode}") self.dropout_final = torch.nn.Dropout(0.1) self.embed_centroids = torch.nn.Linear(5, embed_size_centroids) @@ -164,7 +162,7 @@ def forward_channels(self, x_in, centroids): ] out = torch.stack(out, dim=1).flatten(-2, -1) else: - assert False + raise ValueError(f"Unknown unembed mode: {self.unembed_mode}") # append centroids if self.embed_size_centroids > 0: @@ -195,6 +193,13 @@ def forward_columns(self, x_in, centroids): return out.to(torch.float16) + def forward(self, x_in, centroids): + if self.mode == "channels": + return self.forward_channels(x_in, centroids) + elif self.mode == "columns": + return self.forward_columns(x_in, centroids) + else: + raise ValueError(f"Unknown mode {self.mode}") class StreamEmbedLinear(torch.nn.Module): def __init__(self, dim_in, dim_out, stream_name="stream_embed"): diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 78d11a4a..1f2b3b2b 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -29,7 +29,7 @@ from weathergen.utils.utils import get_dtype -class EmbeddingEngine: +class EmbeddingEngine(torch.nn.Module): name: "EmbeddingEngine" def __init__(self, cf: Config, sources_size) -> None: @@ -39,16 +39,11 @@ def __init__(self, cf: Config, sources_size) -> None: :param cf: Configuration object containing parameters for the engine. :param sources_size: List of source sizes for each stream. """ + super(EmbeddingEngine, self).__init__() self.cf = cf self.sources_size = sources_size # KCT:iss130, what is this? self.embeds = torch.nn.ModuleList() - def create(self) -> torch.nn.ModuleList: - """ - Creates and returns the module list (embeds). - - :return: torch.nn.ModuleList containing the embedding layers. - """ for i, si in enumerate(self.cf.streams): stream_name = si.get("name", i) @@ -84,10 +79,53 @@ def create(self) -> torch.nn.ModuleList: ) else: raise ValueError("Unsupported embedding network type") - return self.embeds - + + def forward(self, streams_data, pe_embed, dtype, device): + source_tokens_lens = torch.stack( + [ + torch.stack( + [ + s.source_tokens_lens if len(s.source_tokens_lens) > 0 else torch.tensor([]) + for s in stl_b + ] + ) + for stl_b in streams_data + ] + ) + offsets_base = source_tokens_lens.sum(1).sum(0).cumsum(0) -class LocalAssimilationEngine: + tokens_all = torch.empty( + (int(offsets_base[-1]), self.cf.ae_local_dim_embed), dtype=dtype, device=device + ) + + for _, sb in enumerate(streams_data): + for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): + if not s.source_empty(): + idxs = s.source_idxs_embed.to(device) + idxs_pe = s.source_idxs_embed_pe.to(device) + + # create full scatter index + # (there's no broadcasting which is likely highly inefficient) + idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) + x_embed = embed(s.source_tokens_cells, s.source_centroids).flatten(0, 1) + # there's undocumented limitation in flash_attn that will make embed fail if + # #tokens is too large; code below is a work around + # x_embed = torch.cat( + # [ + # embed(s_c, c_c).flatten(0, 1) + # for s_c, c_c in zip( + # torch.split(s.source_tokens_cells, 49152), + # torch.split(s.source_centroids, 49152), + # ) + # ] + # ) + + # scatter write to reorder from per stream to per cell ordering + tokens_all.scatter_(0, idxs, x_embed + pe_embed[idxs_pe]) + return tokens_all + + +class LocalAssimilationEngine(torch.nn.Module): name: "LocalAssimilationEngine" def __init__(self, cf: Config) -> None: @@ -96,15 +134,10 @@ def __init__(self, cf: Config) -> None: :param cf: Configuration object containing parameters for the engine. """ + super(LocalAssimilationEngine, self).__init__() self.cf = cf self.ae_local_blocks = torch.nn.ModuleList() - def create(self) -> torch.nn.ModuleList: - """ - Creates and returns the module list (ae_local_blocks). - - :return: torch.nn.ModuleList containing the local assimilation blocks. - """ for _ in range(self.cf.ae_local_num_blocks): self.ae_local_blocks.append( MultiSelfAttentionHeadVarlen( @@ -128,10 +161,14 @@ def create(self) -> torch.nn.ModuleList: norm_eps=self.cf.mlp_norm_eps, ) ) - return self.ae_local_blocks + + def forward(self, tokens_c, cell_lens_c, use_reentrant): + for block in self.ae_local_blocks: + tokens_c = checkpoint(block, tokens_c, cell_lens_c, use_reentrant=use_reentrant) + return tokens_c -class Local2GlobalAssimilationEngine: +class Local2GlobalAssimilationEngine(torch.nn.Module): name: "Local2GlobalAssimilationEngine" def __init__(self, cf: Config) -> None: @@ -140,15 +177,10 @@ def __init__(self, cf: Config) -> None: :param cf: Configuration object containing parameters for the engine. """ + super(Local2GlobalAssimilationEngine, self).__init__() self.cf = cf self.ae_adapter = torch.nn.ModuleList() - def create(self) -> torch.nn.ModuleList: - """ - Creates and returns the module list (ae_adapter). - - :return: torch.nn.ModuleList containing the local-to-global assimilation adapter blocks. - """ self.ae_adapter.append( MultiCrossAttentionHeadVarlenSlicedQ( self.cf.ae_global_dim_embed, @@ -191,10 +223,21 @@ def create(self) -> torch.nn.ModuleList: attention_dtype=get_dtype(self.cf.attention_dtype), ) ) - return self.ae_adapter + + def forward(self, tokens_c, tokens_global_c, q_cells_lens_c, cell_lens_c, use_reentrant): + for block in self.ae_adapter: + tokens_global_c = checkpoint( + block, + tokens_global_c, + tokens_c, + q_cells_lens_c, + cell_lens_c, + use_reentrant=use_reentrant, + ) + return tokens_global_c -class GlobalAssimilationEngine: +class GlobalAssimilationEngine(torch.nn.Module): name: "GlobalAssimilationEngine" def __init__(self, cf: Config, num_healpix_cells: int) -> None: @@ -204,17 +247,12 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: :param cf: Configuration object containing parameters for the engine. :param num_healpix_cells: Number of healpix cells used for local queries. """ + super(GlobalAssimilationEngine, self).__init__() self.cf = cf self.num_healpix_cells = num_healpix_cells self.ae_global_blocks = torch.nn.ModuleList() - def create(self) -> torch.nn.ModuleList: - """ - Creates and returns the module list (ae_global_blocks). - - :return: torch.nn.ModuleList containing the global assimilation blocks. - """ global_rate = int(1 / self.cf.ae_global_att_dense_rate) for i in range(self.cf.ae_global_num_blocks): ## Alternate between local and global attention @@ -260,10 +298,14 @@ def create(self) -> torch.nn.ModuleList: norm_eps=self.cf.mlp_norm_eps, ) ) - return self.ae_global_blocks + + def forward(self, tokens, use_reentrant): + for block in self.ae_global_blocks: + tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) + return tokens -class ForecastingEngine: +class ForecastingEngine(torch.nn.Module): name: "ForecastingEngine" def __init__(self, cf: Config, num_healpix_cells: int) -> None: @@ -273,16 +315,11 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: :param cf: Configuration object containing parameters for the engine. :param num_healpix_cells: Number of healpix cells used for local queries. """ + super(ForecastingEngine, self).__init__() self.cf = cf self.num_healpix_cells = num_healpix_cells self.fe_blocks = torch.nn.ModuleList() - def create(self) -> torch.nn.ModuleList: - """ - Creates and returns the module list (fe_blocks). - - :return: torch.nn.ModuleList containing the forecasting blocks. - """ global_rate = int(1 / self.cf.forecast_att_dense_rate) if self.cf.forecast_policy is not None: for i in range(self.cf.fe_num_blocks): @@ -339,7 +376,11 @@ def init_weights_final(m): for block in self.fe_blocks: block.apply(init_weights_final) - return self.fe_blocks + def forward(self, tokens, use_reentrant): + for it, block in enumerate(self.fe_blocks): + aux_info = torch.tensor([it], dtype=torch.float32, device="cuda") + tokens = checkpoint(block, tokens, aux_info, use_reentrant=use_reentrant) + return tokens class EnsPredictionHead(torch.nn.Module): diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index f049caeb..e38a63db 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -260,11 +260,11 @@ def create(self) -> "Model": cf = self.cf # separate embedding networks for differnt observation types - self.embeds = EmbeddingEngine(cf, self.sources_size).create() + self.embed_engine = EmbeddingEngine(cf, self.sources_size) ############## # local assimilation engine - self.ae_local_blocks = LocalAssimilationEngine(cf).create() + self.ae_local_engine = LocalAssimilationEngine(cf) if cf.latent_noise_kl_weight > 0.0: self.interpolate_latents = LatentInterpolator( @@ -276,7 +276,7 @@ def create(self) -> "Model": ############## # local -> global assimilation engine adapter - self.ae_adapter = Local2GlobalAssimilationEngine(cf).create() + self.ae_local_global_engine = Local2GlobalAssimilationEngine(cf) ############## # learnable queries @@ -308,7 +308,7 @@ def create(self) -> "Model": ############## # global assimilation engine - self.ae_global_blocks = GlobalAssimilationEngine(cf, self.num_healpix_cells).create() + self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) ############### # forecasting engine @@ -321,7 +321,7 @@ def create(self) -> "Model": "Empty forecast engine (fe_num_blocks = 0), but forecast_steps[i] > 0 for some i" ) - self.fe_blocks = ForecastingEngine(cf, self.num_healpix_cells).create() + self.forecast_engine = ForecastingEngine(cf, self.num_healpix_cells) ############### # embed coordinates yielding one query token for each target token @@ -466,15 +466,15 @@ def print_num_parameters(self) -> None: """Print number of parameters for entire model and each module used to build the model""" cf = self.cf - num_params_embed = [get_num_parameters(embed) for embed in self.embeds] + num_params_embed = [get_num_parameters(embed) for embed in self.embed_engine.embeds] num_params_total = get_num_parameters(self) - num_params_ae_local = get_num_parameters(self.ae_local_blocks) - num_params_ae_global = get_num_parameters(self.ae_global_blocks) + num_params_ae_local = get_num_parameters(self.ae_local_engine.ae_local_blocks) + num_params_ae_global = get_num_parameters(self.ae_global_engine.ae_global_blocks) num_params_q_cells = np.prod(self.q_cells.shape) if self.q_cells.requires_grad else 0 - num_params_ae_adapater = get_num_parameters(self.ae_adapter) + num_params_ae_adapater = get_num_parameters(self.ae_local_global_engine.ae_adapter) - num_params_fe = get_num_parameters(self.fe_blocks) + num_params_fe = get_num_parameters(self.forecast_engine.fe_blocks) num_params_pred_adapter = [get_num_parameters(kv) for kv in self.pred_adapter_kv] num_params_embed_tcs = [get_num_parameters(etc) for etc in self.embed_target_coords] @@ -509,6 +509,45 @@ def print_num_parameters(self) -> None: ] print("-----------------") + ######################################### + def rename_old_state_dict(self, params: dict) -> dict: + """Checks if model from checkpoint is from the old model version and if so renames + the parameters accordingly to the new model version. + + Args: + params : Dictionary with (old) model parameters from checkpoint + Returns: + new_params : Dictionary with (renamed) model parameters + """ + params_cleanup = { + "embeds": "embed_engine.embeds", # EmbeddingEngine + "ae_local_blocks": "ae_local_engine.ae_local_blocks", # LocalAssimilationEngine + "ae_adapter": "ae_local_global_engine.ae_adapter", # Local2GlobalAssimilationEngine + "ae_global_blocks": "ae_global_engine.ae_global_blocks", # GlobalAssimilationEngine + "fe_blocks":"forecast_engine.fe_blocks" # ForecastingEngine + } + + new_params = {} + + for k, v in params.items(): + new_k = k + prefix = "" + + # Strip "module." (prefix for DataParallel or DistributedDataParallel) + if new_k.startswith("module."): + prefix = "module." + new_k = new_k[len(prefix):] + + first_w, rest = new_k.split(".", 1) if "." in new_k else (new_k, "") + # Only check first word (root level modules) to avoid false matches. + if first_w in params_cleanup: + new_k = params_cleanup[first_w] + "." + rest + + new_k = prefix + new_k + new_params[new_k] = v + + return new_params + ######################################### def load(self, run_id: str, epoch: str = -1) -> None: """Loads model state from checkpoint and checks for missing and unused keys. @@ -524,9 +563,14 @@ def load(self, run_id: str, epoch: str = -1) -> None: params = torch.load( path_run / filename, map_location=torch.device("cpu"), weights_only=True ) + + # Ensure backward compatibility with old model checkpoints + params = self.rename_old_state_dict(params) + params_renamed = {} for k in params.keys(): params_renamed[k.replace("module.", "")] = params[k] + mkeys, ukeys = self.load_state_dict(params_renamed, strict=False) # mkeys, ukeys = self.load_state_dict( params, strict=False) @@ -621,48 +665,9 @@ def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: Tokens for local assimilation """ - source_tokens_lens = torch.stack( - [ - torch.stack( - [ - s.source_tokens_lens if len(s.source_tokens_lens) > 0 else torch.tensor([]) - for s in stl_b - ] - ) - for stl_b in streams_data - ] - ) - offsets_base = source_tokens_lens.sum(1).sum(0).cumsum(0) device = next(self.parameters()).device - tokens_all = torch.empty( - (int(offsets_base[-1]), self.cf.ae_local_dim_embed), dtype=self.dtype, device=device - ) - - for _, sb in enumerate(streams_data): - for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): - if not s.source_empty(): - idxs = s.source_idxs_embed.to(device) - idxs_pe = s.source_idxs_embed_pe.to(device) - - # create full scatter index - # (there's no broadcasting which is likely highly inefficient) - idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) - x_embed = embed(s.source_tokens_cells, s.source_centroids).flatten(0, 1) - # there's undocumented limitation in flash_attn that will make embed fail if - # #tokens is too large; code below is a work around - # x_embed = torch.cat( - # [ - # embed(s_c, c_c).flatten(0, 1) - # for s_c, c_c in zip( - # torch.split(s.source_tokens_cells, 49152), - # torch.split(s.source_centroids, 49152), - # ) - # ] - # ) - - # scatter write to reorder from per stream to per cell ordering - tokens_all.scatter_(0, idxs, x_embed + model_params.pe_embed[idxs_pe]) - + tokens_all = self.embed_engine(streams_data, model_params.pe_embed, self.dtype, device) + return tokens_all ######################################### @@ -742,9 +747,9 @@ def assimilate_local( if l0 == l1 or tokens_c.shape[0] == 0: tokens_global_all += [tokens_global_c] continue - - for block in self.ae_local_blocks: - tokens_c = checkpoint(block, tokens_c, cell_lens_c, use_reentrant=False) + + # local assimilation model + tokens_c = self.ae_local_engine(tokens_c, cell_lens_c, use_reentrant=False) if self.cf.latent_noise_kl_weight > 0.0: tokens_c, posteriors_c = self.interpolate_latents.interpolate_with_noise( @@ -754,15 +759,7 @@ def assimilate_local( else: tokens_c, posteriors = tokens_c, 0.0 - for block in self.ae_adapter: - tokens_global_c = checkpoint( - block, - tokens_global_c, - tokens_c, - q_cells_lens_c, - cell_lens_c, - use_reentrant=False, - ) + tokens_global_c = self.ae_local_global_engine(tokens_c, tokens_global_c, q_cells_lens_c, cell_lens_c, use_reentrant=False) tokens_global_all += [tokens_global_c] @@ -787,8 +784,7 @@ def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) -> """ # global assimilation engine and adapter - for block in self.ae_global_blocks: - tokens = checkpoint(block, tokens, use_reentrant=False) + tokens = self.ae_global_engine(tokens, use_reentrant=False) return tokens @@ -806,7 +802,7 @@ def forecast(self, model_params: ModelParams, tokens: torch.Tensor, fstep: int) ValueError: For unexpected arguments in checkpoint method """ - for block in self.fe_blocks: + for block in self.forecast_engine.fe_blocks: aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 55d30c7f..ae14c368 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -200,19 +200,19 @@ def init_model_and_shard(self, cf, devices): MultiSelfAttentionHeadVarlen, ) - for module in model.ae_local_blocks.modules(): + for module in model.ae_local_engine.ae_local_blocks.modules(): if isinstance(module, modules_to_shard): fully_shard(module, **fsdp_kwargs) - for module in model.ae_adapter.modules(): + for module in model.ae_local_global_engine.ae_adapter.modules(): if isinstance(module, modules_to_shard): fully_shard(module, **fsdp_kwargs) - for module in model.ae_global_blocks.modules(): + for module in model.ae_global_engine.ae_global_blocks.modules(): if isinstance(module, modules_to_shard): fully_shard(module, **fsdp_kwargs) - for module in model.fe_blocks.modules(): + for module in model.forecast_engine.fe_blocks.modules(): if isinstance(module, modules_to_shard): fully_shard(module, **fsdp_kwargs) @@ -246,7 +246,7 @@ def init_model_and_shard(self, cf, devices): # functions in the embedding engine as forward functions. Thus, yielding a crash # because the input tensors are not converted to DTensors. This seems to primarily # occur during validation. - for embed in model.embeds: + for embed in model.embed_engine.embeds: torch.distributed.fsdp.register_fsdp_forward_method(embed, "forward_channels") torch.distributed.fsdp.register_fsdp_forward_method(embed, "forward_columns") @@ -749,6 +749,9 @@ def load_model(self, run_id: str, epoch=-1): path_run / filename, map_location=torch.device("cpu"), mmap=True, weights_only=True ) + # Ensure backward compatibility with old model checkpoints + params = self.model.rename_old_state_dict(params) + model_state_dict = self.model.state_dict() params = { k: v