diff --git a/config/default_config.yml b/config/default_config.yml index 679f58dd3..bf1f1f160 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -24,7 +24,7 @@ ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 ae_global_dim_embed: 2048 -ae_global_num_blocks: 8 +ae_global_num_blocks: 2 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True @@ -42,12 +42,12 @@ pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 0 +forecast_offset : 1 forecast_delta_hrs: 0 -forecast_steps: 0 -forecast_policy: null +forecast_steps: 1 +forecast_policy: "fixed" forecast_att_dense_rate: 1.0 -fe_num_blocks: 0 +fe_num_blocks: 2 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -93,7 +93,7 @@ ema_halflife_in_thousands: 1e-3 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "masking" +training_mode: "forecast" # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # sample the masking rate (with normal distribution centered at masking_rate) @@ -160,3 +160,36 @@ train_log_freq: terminal: 10 metrics: 20 checkpoint: 250 + +# Parameters for logging/printing in the training loop +train_log: + # The period to log metrics (in number of batch steps) + log_interval: 20 + +# Forecast MLP type: "dense" (default) or "moe" +fe_mlp_type: "dense" # set to "moe" to enable MoE +ae_global_mlp_type: "dense" # set to "moe" to enable MoE +ffn_mlp_type: "dense" # set to "moe" to enable MoE in the feed-forward network of the decoder blocks +decoder_mlp_type: "dense" # set to "moe" to enable MoE in the decoder prediction MLP +moe_lambda: 0.02 # coefficient for the MoE load balancing loss + +# MoE-only params (ignored when fe_mlp_type != "moe") +fe_moe_num_experts: 2 +fe_moe_top_k: 1 +fe_moe_hidden_factor: 0.5 # = HF_dense / 4 + +# MoE-only params (ignored when ae_global_mlp_type != "moe") +ae_global_moe_num_experts: 4 +ae_global_moe_top_k: 2 +ae_global_moe_hidden_factor: 0.5 # = HF_dense / 4 + +# MoE-only params (ignored when ffn_mlp_type != "moe") +ffn_moe_num_experts: 2 +ffn_moe_top_k: 1 +ffn_moe_hidden_factor: 0.5 # = HF_dense / 4 + +# MoE-only params (ignored when decoder_mlp_type != "moe") +decoder_moe_num_experts: 2 +decoder_moe_top_k: 1 +decoder_moe_hidden_factor: 0.5 # = HF_dense / 4 +tr_mlp_hidden_factor: 2 diff --git a/src/weathergen/model/blocks.py b/src/weathergen/model/blocks.py index 061928f64..94dc73e9d 100644 --- a/src/weathergen/model/blocks.py +++ b/src/weathergen/model/blocks.py @@ -14,10 +14,12 @@ MultiCrossAttentionHeadVarlen, MultiSelfAttentionHeadVarlen, ) -from weathergen.model.layers import MLP +from weathergen.model.layers import MLP, MoEMLP from weathergen.model.norms import AdaLayerNormLayer from weathergen.utils.utils import get_dtype +import logging +logger = logging.getLogger(__name__) class SelfAttentionBlock(nn.Module): """ @@ -43,14 +45,32 @@ def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs self.mhsa_block = lambda x, _, **kwargs: self.mhsa(self.ln_sa(x), **kwargs) + x approx_gelu = lambda: nn.GELU(approximate="tanh") - self.mlp = MLP( - dim_in=dim, - dim_out=dim, - hidden_factor=4, - dropout_rate=0.1, - nonlin=approx_gelu, - with_residual=False, - ) + use_moe_ffn = (kwargs.get("ffn_mlp_type", "dense") == "moe") + ffn_hidden_factor = 
kwargs.get("ffn_hidden_factor", 4) + moe_kwargs = kwargs.get("moe_kwargs", {}) # e.g. num_experts, top_k, router_noisy_std + + if use_moe_ffn: + self.mlp = MoEMLP( + dim_in=dim, + dim_out=dim, + hidden_factor=ffn_hidden_factor, + dropout_rate=0.1, + nonlin=nn.GELU, # internal block constructs nonlin() + with_residual=False, + norm_type=kwargs["attention_kwargs"]["norm_type"], + dim_aux=(dim_aux if self.with_adanorm else None), + norm_eps=kwargs["attention_kwargs"]["norm_eps"], + **moe_kwargs, # <- e.g. num_experts=8, top_k=2, router_noisy_std=0.01 + ) + else: + self.mlp = MLP( + dim_in=dim, + dim_out=dim, + hidden_factor=4, + dropout_rate=0.1, + nonlin=approx_gelu, + with_residual=False, + ) if self.with_adanorm: self.mlp_fn = lambda x, **kwargs: self.mlp(x) self.mlp_block = AdaLayerNormLayer(dim, dim_aux, self.mlp_fn, dropout_rate) @@ -104,7 +124,7 @@ def __init__( self.with_adanorm = with_adanorm self.with_self_attn = with_self_attn - self.with_mlp = with_self_attn + self.with_mlp = with_mlp if with_self_attn: self.mhsa = MultiSelfAttentionHeadVarlen( @@ -136,18 +156,37 @@ def __init__( if self.with_mlp: approx_gelu = lambda: nn.GELU(approximate="tanh") - self.mlp = MLP( - dim_in=dim_q, - dim_out=dim_q, - hidden_factor=4, - nonlin=approx_gelu, - with_residual=False, - ) + + use_moe_ffn = (kwargs.get("ffn_mlp_type", "dense") == "moe") + ffn_hidden_factor = kwargs.get("ffn_hidden_factor", 4) + moe_kwargs = kwargs.get("moe_kwargs", {}) + + if use_moe_ffn: + self.mlp = MoEMLP( + dim_in=dim_q, + dim_out=dim_q, + hidden_factor=ffn_hidden_factor, + dropout_rate=0.1, + nonlin=nn.GELU, # internal block constructs nonlin() + with_residual=False, + norm_type=kwargs["attention_kwargs"]["norm_type"], + dim_aux=(dim_aux if self.with_adanorm else None), + norm_eps=kwargs["attention_kwargs"]["norm_eps"], + **moe_kwargs, # <- e.g. 
num_experts=8, top_k=2, router_noisy_std=0.01
+            )
+        else:
+            self.mlp = MLP(
+                dim_in=dim_q,
+                dim_out=dim_q,
+                hidden_factor=4,
+                nonlin=approx_gelu,
+                with_residual=False,
+            )
 
             if self.with_adanorm:
                 self.mlp_fn = lambda x, **kwargs: self.mlp(x)
                 self.mlp_block = AdaLayerNormLayer(dim_q, dim_aux, self.mlp_fn, dropout_rate)
             else:
                 self.ln_mlp = nn.LayerNorm(dim_q, eps=kwargs["attention_kwargs"]["norm_eps"])
                 self.mlp_block = lambda x, _, **kwargs: self.mlp(self.ln_mlp(x)) + x
         else:
             self.mlp_block = lambda x, _, **kwargs: x
@@ -191,6 +230,7 @@ def __init__(
         tr_mlp_hidden_factor,
         tro_type,
         mlp_norm_eps=1e-6,
+        **kwargs,
     ):
         super().__init__()
@@ -237,19 +277,46 @@ def __init__(
             )
 
         # MLP Block
-        self.block.append(
-            MLP(
-                dim_in,
-                dim_out,
-                with_residual=True,
-                hidden_factor=self.tr_mlp_hidden_factor,
-                dropout_rate=0.1,  # Assuming dropout_rate is 0.1
-                norm_type=self.cf.norm_type,
-                dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None),
-                norm_eps=self.cf.mlp_norm_eps,
-            )
+        # Add MoE option
+        use_moe = getattr(self.cf, "decoder_mlp_type", "dense") == "moe"
+        logger.info(
+            "[MoE] Decoder head: type=%s%s",
+            "moe" if use_moe else "dense",
+            ("" if not use_moe else
+             f" (experts={getattr(self.cf,'moe_num_experts',None)}, top_k={getattr(self.cf,'moe_top_k',None)})"),
+        )
+        if use_moe:
+            self.block.append(
+                MoEMLP(
+                    dim_in,
+                    dim_out,
+                    hidden_factor=self.tr_mlp_hidden_factor,
+                    dropout_rate=0.1,
+                    with_residual=True,  # mirror dense
+                    norm_type=self.cf.norm_type,
+                    dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None),
+                    norm_eps=self.cf.mlp_norm_eps,
+                    num_experts=getattr(self.cf, "moe_num_experts", 8),
+                    top_k=getattr(self.cf, "moe_top_k", 2),
+                    router_noisy_std=getattr(self.cf, "moe_router_noisy_std", 0.0),
+                    use_checkpoint=getattr(self.cf, "moe_use_checkpoint", False),
+                )
+            )
+        else:
+            self.block.append(
+                MLP(
+                    dim_in,
+                    dim_out,
+                    with_residual=True,
+                    hidden_factor=self.tr_mlp_hidden_factor,
+                    dropout_rate=0.1,  # Assuming dropout_rate is 0.1
+                    norm_type=self.cf.norm_type,
+                    dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None),
+                    norm_eps=self.cf.mlp_norm_eps,
+                )
+            )
+
     def forward(self, latent, output, coords, latent_lens, output_lens):
         for layer in self.block:
             if isinstance(layer, MultiCrossAttentionHeadVarlen):
diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py
index 78d11a4a6..2a7a9721e 100644
--- a/src/weathergen/model/engines.py
+++ b/src/weathergen/model/engines.py
@@ -24,10 +24,12 @@
     StreamEmbedLinear,
     StreamEmbedTransformer,
 )
-from weathergen.model.layers import MLP
+from weathergen.model.layers import MLP, MoEMLP
 from weathergen.model.utils import ActivationFactory
 from weathergen.utils.utils import get_dtype
 
+import logging
+logger = logging.getLogger(__name__)
 
 class EmbeddingEngine:
     name: "EmbeddingEngine"
@@ -249,17 +251,50 @@ def create(self) -> torch.nn.ModuleList:
                 )
             )
             # MLP block
-            self.ae_global_blocks.append(
-                MLP(
-                    self.cf.ae_global_dim_embed,
-                    self.cf.ae_global_dim_embed,
-                    with_residual=True,
-                    dropout_rate=self.cf.ae_global_dropout_rate,
-                    hidden_factor=self.cf.ae_global_mlp_hidden_factor,
-                    norm_type=self.cf.norm_type,
-                    norm_eps=self.cf.mlp_norm_eps,
-                )
+            # Add MoE option
+            use_moe = getattr(self.cf, "ae_global_mlp_type", "dense") == "moe"
+            mlp_common_kwargs = dict(
+                dim_in=self.cf.ae_global_dim_embed,
+                dim_out=self.cf.ae_global_dim_embed,
+                with_residual=True,
+                dropout_rate=self.cf.ae_global_dropout_rate,
+                norm_type=self.cf.norm_type,
+
norm_eps=self.cf.mlp_norm_eps, ) + if use_moe: + self.ae_global_blocks.append( + MoEMLP( + **mlp_common_kwargs, + num_experts=getattr(self.cf, "ae_global_moe_num_experts", 2), + top_k=getattr(self.cf, "ae_global_moe_top_k", 1), + router_noisy_std=getattr(self.cf, "ae_global_moe_router_noisy_std", 0.0), + hidden_factor=getattr(self.cf, "ae_global_moe_hidden_factor", 2), + ) + ) + else: + self.ae_global_blocks.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.ae_global_dropout_rate, + hidden_factor=self.cf.ae_global_mlp_hidden_factor, + norm_type=self.cf.norm_type, + norm_eps=self.cf.mlp_norm_eps, + ) + ) + # Count MoE blocks + num_moe = sum(1 for m in self.ae_global_blocks if isinstance(m, MoEMLP)) + logger.info( + "[MoE] GlobalAssimilationEngine: %d MoEMLP blocks " + "(ae_global_mlp_type=%s, experts=%s, top_k=%s, hidden_factor=%s)", + num_moe, + getattr(self.cf, "ae_global_mlp_type", "dense"), + getattr(self.cf, "ae_global_moe_num_experts", None), + getattr(self.cf, "ae_global_moe_top_k", None), + getattr(self.cf, "ae_global_moe_hidden_factor", None), + ) + return self.ae_global_blocks @@ -318,27 +353,68 @@ def create(self) -> torch.nn.ModuleList: ) ) # Add MLP block - self.fe_blocks.append( - MLP( - self.cf.ae_global_dim_embed, - self.cf.ae_global_dim_embed, - with_residual=True, - dropout_rate=self.cf.fe_dropout_rate, - norm_type=self.cf.norm_type, - dim_aux=1, - norm_eps=self.cf.mlp_norm_eps, - ) + use_moe = getattr(self.cf, "fe_mlp_type", "dense") == "moe" + mlp_common_kwargs = dict( + dim_in=self.cf.ae_global_dim_embed, + dim_out=self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=1, + norm_eps=self.cf.mlp_norm_eps, ) - - def init_weights_final(m): - if isinstance(m, torch.nn.Linear): - torch.nn.init.normal_(m.weight, mean=0, std=0.001) - if m.bias is not None: - torch.nn.init.normal_(m.bias, mean=0, std=0.001) - - for block in self.fe_blocks: - block.apply(init_weights_final) - + # self.fe_blocks.append( + # MLP( + # self.cf.ae_global_dim_embed, + # self.cf.ae_global_dim_embed, + # with_residual=True, + # dropout_rate=self.cf.fe_dropout_rate, + # norm_type=self.cf.norm_type, + # dim_aux=1, + # norm_eps=self.cf.mlp_norm_eps, + # ) + # ) + if use_moe: + self.fe_blocks.append( + MoEMLP( + **mlp_common_kwargs, + num_experts=getattr(self.cf, "fe_moe_num_experts", 2), + top_k=getattr(self.cf, "fe_moe_top_k", 2), + router_noisy_std=getattr(self.cf, "fe_moe_router_noisy_std", 0.0), + hidden_factor=getattr(self.cf, "fe_moe_hidden_factor", 2), + ) + ) + else: + self.fe_blocks.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=1, + norm_eps=self.cf.mlp_norm_eps, + ) + ) + # ------------------------------------------------------------------ + # def init_weights_final(m): + # if isinstance(m, torch.nn.Linear) and not getattr(m, "is_moe_router", False): + # torch.nn.init.normal_(m.weight, mean=0, std=0.001) + # if m.bias is not None: + # torch.nn.init.normal_(m.bias, mean=0, std=0.001) + + # for block in self.fe_blocks: + # block.apply(init_weights_final) + num_moe = sum(1 for m in self.fe_blocks if isinstance(m, MoEMLP)) + logger.info( + "[MoE] ForecastingEngine: %d MoEMLP blocks " + "(fe_mlp_type=%s, experts=%s, top_k=%s, hidden_factor=%s)", + num_moe, + getattr(self.cf, "fe_mlp_type", "dense"), + getattr(self.cf, 
"fe_moe_num_experts", None), + getattr(self.cf, "fe_moe_top_k", None), + getattr(self.cf, "fe_moe_hidden_factor", None), + ) return self.fe_blocks @@ -587,6 +663,14 @@ def __init__( with_adanorm=False, with_mlp=False, attention_kwargs=attention_kwargs, + ffn_mlp_type=getattr(self.cf, "decoder_ffn_mlp_type", "dense"), + ffn_hidden_factor=getattr(self.cf, "decoder_ffn_hidden_factor", 4), + moe_kwargs=dict( + num_experts=getattr(self.cf, "decoder_moe_num_experts", 2), + top_k=getattr(self.cf, "decoder_moe_top_k", 2), + router_noisy_std=getattr(self.cf, "decoder_moe_router_noisy_std", 0.0), + use_checkpoint=getattr(self.cf, "decoder_moe_use_checkpoint", False), + ) ) ) elif self.cf.decoder_type == "AdaLayerNormConditioning": @@ -642,6 +726,14 @@ def __init__( tr_mlp_hidden_factor=tr_mlp_hidden_factor, tro_type=tro_type, mlp_norm_eps=self.cf.mlp_norm_eps, + ffn_mlp_type=getattr(self.cf, "decoder_ffn_mlp_type", "dense"), + ffn_hidden_factor=getattr(self.cf, "decoder_ffn_hidden_factor", 4), + moe_kwargs=dict( + num_experts=getattr(self.cf, "decoder_moe_num_experts", 2), + top_k=getattr(self.cf, "decoder_moe_top_k", 2), + router_noisy_std=getattr(self.cf, "decoder_moe_router_noisy_std", 0.0), + use_checkpoint=getattr(self.cf, "decoder_moe_use_checkpoint", False), + ) ) ) else: diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 1f7b8df5d..acd06b721 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -10,6 +10,7 @@ import torch import torch.nn as nn +from typing import Optional, Tuple, Dict, Any from weathergen.model.norms import AdaLayerNorm, RMSNorm @@ -93,3 +94,261 @@ def forward(self, *args): x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]]) return x + +class _DenseBlock(nn.Module): + """A tiny FFN that mirrors the structure of the current MLP stack.""" + def __init__(self, dim_in, dim_hidden, dim_out, num_layers=2, + nonlin=nn.GELU, dropout_rate=0.0): + super().__init__() + layers = [nn.Linear(dim_in, dim_hidden), nonlin(), nn.Dropout(dropout_rate)] + for _ in range(num_layers - 2): + layers += [nn.Linear(dim_hidden, dim_hidden), nonlin(), nn.Dropout(dropout_rate)] + layers += [nn.Linear(dim_hidden, dim_out)] + self.net = nn.Sequential(*layers) + + def forward(self, x): + return self.net(x) + + +class MoEMLP(nn.Module): + """ + Memory-friendly MoE MLP. + + Features + -------- + - Matches MLP call pattern: forward(*args) where args=(x, ...) and optional aux at the end + - Optional AdaLayerNorm pre-norm when dim_aux is provided + - Top-k routing with softmax over selected logits + - Streams experts and accumulates outputs (no large [E, ..., D] stacks) + - Optional auxiliary outputs (gate loss, route histogram) via `return_aux` + + Notes + ----- + - If `return_aux=False` (default), we still *compute* the aux loss (with grads) and stash it + on `self.last_aux` and `self.last_aux_loss` so you can read it after forward if desired. + - To actively use the load-balancing loss in training, either set `return_aux=True` and add it + to your loss, or read `self.last_aux['gate_loss']` from the module instance. 
+ """ + def __init__( + self, + dim_in: int, + dim_out: int, + num_layers: int = 2, + hidden_factor: float = 2.0, + pre_layer_norm: bool = True, + dropout_rate: float = 0.0, + nonlin=nn.GELU, + with_residual: bool = False, + norm_type: str = "LayerNorm", + dim_aux: Optional[int] = None, + norm_eps: float = 1e-5, + name: Optional[str] = None, + # MoE + num_experts: int = 8, + top_k: int = 4, + router_noisy_std: float = 0.0, + # Memory + use_checkpoint: bool = False, + # API + return_aux: bool = False, + ): + super().__init__() + if name is not None: + self.name = name + + assert num_layers >= 2, "MoEMLP requires at least 2 layers" + assert 1 <= top_k <= num_experts, "top_k must be in [1, num_experts]" + + self.with_residual = with_residual + self.with_aux = dim_aux is not None + self.pre_layer_norm = pre_layer_norm + self.top_k = top_k + self.num_experts = num_experts + self.router_noisy_std = router_noisy_std + self.use_checkpoint = use_checkpoint + self.return_aux = return_aux + self.enable_gate_loss = True + + self.register_buffer("usage_buf", torch.zeros(num_experts), persistent=False) + dim_hidden = int(dim_in * hidden_factor) + + # Norm (match MLP behavior) + Norm = nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm + if pre_layer_norm: + self.norm = ( + Norm(dim_in, eps=norm_eps) + if dim_aux is None + else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) + ) + else: + self.norm = None + + # Router + self.router = nn.Linear(dim_in, num_experts) + # Recommended init: small std, zero bias + nn.init.normal_(self.router.weight, mean=0.0, std=1e-2) + nn.init.constant_(self.router.bias, 0.0) + + # Experts + self.experts = nn.ModuleList( + [ + _DenseBlock( + dim_in=dim_in, + dim_hidden=dim_hidden, + dim_out=dim_out, + num_layers=num_layers, + nonlin=nonlin, + dropout_rate=dropout_rate, + ) + for _ in range(num_experts) + ] + ) + + # Stashed aux for consumers that don't use return_aux + self.register_buffer("last_aux_loss", torch.zeros((), dtype=torch.float32)) + self.last_aux: Dict[str, torch.Tensor] = {} + + def _gate(self, x_norm: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Returns: + weights: [..., E] if top_k == E else [..., K] + top_idx: None if full softmax, else [..., K] int indices + """ + logits = self.router(x_norm) + if self.router_noisy_std > 0: + logits = logits + torch.randn_like(logits) * self.router_noisy_std + + if self.top_k == self.num_experts: + weights = torch.softmax(logits, dim=-1) + top_idx = None + else: + top_vals, top_idx = torch.topk(logits, k=self.top_k, dim=-1) + weights = torch.softmax(top_vals, dim=-1) + return weights, top_idx + + def _compute_load_balance_aux( + self, weights: torch.Tensor, top_idx: Optional[torch.Tensor], num_experts: int + ) -> torch.Tensor: + """ + Cross-entropy between observed expert usage and uniform 1/E target. + Works for both full-softmax and top-k. 
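+
+        Concretely, aux = sum_e p_e * (log p_e - log(1/E)): the KL divergence of the observed
+        usage distribution from the uniform target (up to the 1e-6 smoothing). It is 0 for
+        perfectly balanced routing and log E when a single expert receives all tokens.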
+ """ + if top_idx is None: + # weights over E -> average across batch/time dims + probs = weights.mean(dim=tuple(range(weights.dim() - 1))) # [E] + else: + # Aggregate usage from top-k selections + if weights.shape != top_idx.shape: + raise ValueError("Top-k weights and indices must share the same shape") + K = weights.shape[-1] + flat_w = weights.reshape(-1, K) # [N, K] + flat_i = top_idx.reshape(-1, K) # [N, K] + usage = torch.zeros(num_experts, device=weights.device, dtype=weights.dtype) + usage.scatter_add_(0, flat_i.reshape(-1), flat_w.reshape(-1)) + probs = usage / usage.sum().clamp_min(1e-6) # [E] + + E = num_experts + target = torch.full_like(probs, 1.0 / E) + aux = (probs * (probs.add(1e-6).log() - target.add(1e-6).log())).sum() + return aux + + def forward(self, *args): + """ + Args: + *args: expects x first; if AdaLN is enabled (dim_aux != None), the last arg is aux. + + Returns: + y or (y, aux_out) depending on `self.return_aux`. + aux_out = {"gate_loss": ..., "route_hist": ...} + """ + x = args[0] + x_in = x + aux_in = args[-1] if self.with_aux else None + + # Optional pre-norm + if self.norm is not None: + x = self.norm(x, aux_in) if self.with_aux else self.norm(x) + + # Routing + weights, top_idx = self._gate(x) # [..., E] or [..., K] + + # Build full weights when in top-k mode to stream experts + if top_idx is None: + w_full = weights # [..., E] + else: + E = self.num_experts + w_full = torch.zeros(*weights.shape[:-1], E, device=weights.device, dtype=weights.dtype) + w_full.scatter_(-1, top_idx, weights) + + # Accumulate outputs without stacking + out_dim = self.experts[0].net[-1].out_features # last Linear of _DenseBlock + y = x.new_zeros(*x.shape[:-1], out_dim) + + if self.use_checkpoint and self.training: + from torch.utils.checkpoint import checkpoint + + for e, expert in enumerate(self.experts): + w_e = w_full[..., e] # [...] 
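+                # Streaming accumulation: each expert is applied to the full (normed) input and
+                # its output is scaled by w_e, which is zero for tokens not routed to expert e;
+                # only one [..., dim_out] expert activation is alive at a time instead of an
+                # [E, ..., dim_out] stack.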
+ # Skip experts with (near) zero mass + if w_e.abs().max() <= 1e-12: + continue + y_e = expert(x) if not (self.use_checkpoint and self.training) else checkpoint(expert, x) + y = y + y_e * w_e.unsqueeze(-1) + + # Residual + if self.with_residual: + if y.shape[-1] == x_in.shape[-1]: + y = x_in + y + else: + assert y.shape[-1] % x_in.shape[-1] == 0 + y = y + x_in.repeat([*[1 for _ in y.shape[:-1]], y.shape[-1] // x_in.shape[-1]]) + + # # Aux outputs (WITH grads so router learns; also stash for external access) + # aux_out: Dict[str, Any] = {} + # gate_loss = self._compute_load_balance_aux(weights, top_idx, self.num_experts) + # aux_out["gate_loss"] = gate_loss + + # # utilization histogram (for logging) + # if top_idx is None: + # aux_out["route_hist"] = weights.mean(dim=tuple(range(weights.dim() - 1))) # [E] + # else: + # K = weights.shape[-1] + # flat_w = weights.reshape(-1, K) + # flat_i = top_idx.reshape(-1, K) + # usage = torch.zeros(self.num_experts, device=weights.device, dtype=weights.dtype) + # usage.scatter_add_(0, flat_i.reshape(-1), flat_w.reshape(-1)) + # aux_out["route_hist"] = usage / usage.sum().clamp_min(1e-6) # [E] + + # # stash for consumers that don't use return_aux + # self.last_aux = aux_out + # self.last_aux_loss = gate_loss + + # return (y, aux_out) if self.return_aux else y + # --- Aux outputs (gate loss + route hist) --- + aux_out: Dict[str, Any] = {} + if self.enable_gate_loss: + gate_loss = self._compute_load_balance_aux(weights, top_idx, self.num_experts) + aux_out["gate_loss"] = gate_loss + + # utilization histogram (for logging / debugging only) + if top_idx is None: + aux_out["route_hist"] = weights.mean(dim=tuple(range(weights.dim() - 1))) # [E] + else: + K = weights.shape[-1] + flat_w = weights.reshape(-1, K) + flat_i = top_idx.reshape(-1, K) + usage = self.usage_buf + usage = usage.to(weights.device, dtype=weights.dtype) + usage.zero_() + usage.scatter_add_(0, flat_i.reshape(-1), flat_w.reshape(-1)) + aux_out["route_hist"] = usage / usage.sum().clamp_min(1e-6) # [E] + else: + # no aux computation this step + pass + + # stash + self.last_aux = aux_out + if "gate_loss" in aux_out: + self.last_aux_loss = aux_out["gate_loss"] + + return (y, aux_out) if self.return_aux else y \ No newline at end of file diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 8f26da14d..2c2396c84 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -40,7 +40,7 @@ MultiSelfAttentionHeadVarlen, ) from weathergen.model.ema import EMAModel -from weathergen.model.layers import MLP +from weathergen.model.layers import MLP, MoEMLP from weathergen.model.model import Model, ModelParams from weathergen.model.utils import freeze_weights from weathergen.train.loss_calculator import LossCalculator @@ -154,6 +154,14 @@ def inference(self, cf, devices, run_id_trained, epoch): self.validate(epoch=0) logger.info(f"Finished inference run with id: {cf.run_id}") + def _ensure_moe_modules_cached(self): + # Works with plain, DDP-wrapped, FSDP, or compiled models + from weathergen.model.layers import MoEMLP + m = self.model + if hasattr(m, "module"): # DDP + m = m.module + self.moe_modules = [x for x in m.modules() if isinstance(x, MoEMLP)] + def init_model_and_shard(self, cf, devices): sources_size = self.dataset.get_sources_size() targets_num_channels = self.dataset.get_targets_num_channels() @@ -197,6 +205,7 @@ def init_model_and_shard(self, cf, devices): MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, 
MultiSelfAttentionHeadVarlen, + MoEMLP, ) for module in model.ae_local_blocks.modules(): @@ -239,6 +248,7 @@ def init_model_and_shard(self, cf, devices): fully_shard(model) for tensor in itertools.chain(model.parameters(), model.buffers()): assert tensor.device == torch.device("meta") + return model, model_params def run(self, cf, devices, run_id_contd=None, epoch_contd=None): @@ -282,7 +292,7 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): ) self.model, self.model_params = self.init_model_and_shard(cf, devices) - + self._ensure_moe_modules_cached() if run_id_contd is None: self.model.to_empty(device="cuda") self.model.reset_parameters() @@ -560,7 +570,12 @@ def train(self, epoch): for bidx, batch in enumerate(dataset_iter): forecast_steps = batch[-1] batch = self.batch_to_device(batch) - + interval = max(1, int(getattr(self.cf, "moe_loss_interval", 1))) + collect = (self.cf.istep % interval) == 0 + for m in self.moe_modules: + # only set if MoEMLP implements the flag + if hasattr(m, "enable_gate_loss"): + m.enable_gate_loss = collect # evaluate model with torch.autocast( device_type=f"cuda:{cf.local_rank}", @@ -577,50 +592,126 @@ def train(self, epoch): if cf.latent_noise_kl_weight > 0.0: kl = torch.cat([posterior.kl() for posterior in posteriors]) loss_values.loss += cf.latent_noise_kl_weight * kl.mean() + + # MoE gate loss + moe_lambda_base = float(getattr(self.cf, "moe_lambda", 0.02)) + if moe_lambda_base and collect: + # optional warmup + warm = int(getattr(self.cf, "moe_lambda_warmup_steps", 0)) + warm_mult = 1.0 if warm <= 0 else min(1.0, self.cf.istep / float(warm)) + + gate_loss = None + for m in self.moe_modules: + la = getattr(m, "last_aux", None) + if isinstance(la, dict) and ("gate_loss" in la): + gate_loss = la["gate_loss"] if gate_loss is None else (gate_loss + la["gate_loss"]) + + if gate_loss is not None: + # scale λ by interval so average gradient matches per-step application + effective_lambda = moe_lambda_base * interval * warm_mult + loss_values.loss = loss_values.loss + effective_lambda * gate_loss + + # moe_lambda = getattr(self.cf, "moe_lambda", 0.0) + # gate_loss = None + # if moe_lambda != 0.0: + # gate_loss = torch.zeros((), device=self.device) + # route_hists = [] + + # for m in self.model.modules(): + # if isinstance(m, MoEMLP) and hasattr(m, "last_aux"): + # la = m.last_aux + # if isinstance(la, dict): + # if "gate_loss" in la: + # gate_loss = gate_loss + la["gate_loss"] + # if "route_hist" in la: + # # route_hist: [E] + # route_hists.append(la["route_hist"].detach()) + + # loss_values.loss = loss_values.loss + moe_lambda * gate_loss + + # # Lightweight logging every metrics interval + # if (self.cf.istep % self.train_log_freq.metrics) == 0: + # # summarize routing (entropy and max-util) + # # summarize routing (entropy and max-util) without stacking, since E can differ per block + # if route_hists: + # entropies = [] + # max_utils = [] + # sizes = [] + # for rh in route_hists: + # p = rh.float() # [E], sums ~1 + # ent = (-(p * p.clamp_min(1e-6).log())).sum() # scalar + # entropies.append(ent.item()) + # max_utils.append(p.max().item()) + # sizes.append(p.numel()) + + # # averages across blocks + # entropy_mean = torch.tensor(entropies, device=self.device).mean().item() + # max_util_mean = torch.tensor(max_utils, device=self.device).mean().item() + + # # optional: quick distribution of expert counts across MoE modules + # # (kept tiny for logging) + # unique_E = sorted(set(sizes)) + # logger.info( + # "[MoE] step=%d | gate_loss=%.4e 
(λ=%.3g) | blocks=%d | route: entropy=%.3f, max_util=%.3f | E=%s", + # self.cf.istep, + # gate_loss.item(), + # moe_lambda, + # len(route_hists), + # entropy_mean, + # max_util_mean, + # unique_E, + # ) + # else: + # logger.info( + # "[MoE] step=%d | gate_loss=%.4e (λ=%.3g) | blocks=0 (no route_hist yet)", + # self.cf.istep, + # gate_loss.item(), + # moe_lambda, + # ) + + # backward pass + self.optimizer.zero_grad() + self.grad_scaler.scale(loss_values.loss).backward() + # loss_values.loss.backward() + + # gradient clipping + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=cf.grad_clip) + + # optimizer step + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + # self.optimizer.step() + + # update learning rate + self.lr_scheduler.step() + + # EMA update + if self.validate_with_ema: + self.ema_model.update( + self.cf.istep * self.world_size_original * self.cf.batch_size_per_gpu, + self.world_size_original * self.cf.batch_size_per_gpu, + ) - # backward pass - self.optimizer.zero_grad() - self.grad_scaler.scale(loss_values.loss).backward() - # loss_values.loss.backward() - - # gradient clipping - self.grad_scaler.unscale_(self.optimizer) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=cf.grad_clip) - - # optimizer step - self.grad_scaler.step(self.optimizer) - self.grad_scaler.update() - # self.optimizer.step() - - # update learning rate - self.lr_scheduler.step() - - # EMA update - if self.validate_with_ema: - self.ema_model.update( - self.cf.istep * self.world_size_original * self.cf.batch_size_per_gpu, - self.world_size_original * self.cf.batch_size_per_gpu, - ) - - self.loss_unweighted_hist += [loss_values.losses_all] - self.loss_model_hist += [loss_values.loss.item()] - self.stdev_unweighted_hist += [loss_values.stddev_all] + self.loss_unweighted_hist += [loss_values.losses_all] + self.loss_model_hist += [loss_values.loss.item()] + self.stdev_unweighted_hist += [loss_values.stddev_all] - perf_gpu, perf_mem = self.get_perf() - self.perf_gpu = ddp_average(torch.tensor([perf_gpu], device=self.device)).item() - self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item() + perf_gpu, perf_mem = self.get_perf() + self.perf_gpu = ddp_average(torch.tensor([perf_gpu], device=self.device)).item() + self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item() - self._log_terminal(bidx, epoch, TRAIN) - if bidx % self.train_log_freq.metrics == 0: - self._log(TRAIN) + self._log_terminal(bidx, epoch, TRAIN) + if bidx % self.train_log_freq.metrics == 0: + self._log(TRAIN) - # save model checkpoint (with designation _latest) - if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0: - self.save_model(-1) + # save model checkpoint (with designation _latest) + if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0: + self.save_model(-1) - self.cf.istep += 1 + self.cf.istep += 1 - self.dataset.advance() + self.dataset.advance() def validate(self, epoch): cf = self.cf
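Taken together, the diff wires optional MoE MLPs into the global assimilation engine, the forecasting engine, and the decoder, and folds the router's load-balancing loss into the training loss via moe_lambda. The sketch below is illustrative only: it assumes this branch is installed (so weathergen.model.layers.MoEMLP exists with the signature added above), uses toy dimensions rather than the config values, and stands in a dummy task loss for the real one. It exercises the forward path and shows how the gate loss is combined with the task loss.

import torch

from weathergen.model.layers import MoEMLP  # provided by this diff

# Toy MoE block mirroring the config defaults (4 experts, top-2 routing, hidden_factor 0.5).
moe = MoEMLP(
    dim_in=64,
    dim_out=64,
    hidden_factor=0.5,
    num_experts=4,
    top_k=2,
    with_residual=True,
    return_aux=True,                    # return (y, aux) instead of only stashing on .last_aux
)

x = torch.randn(8, 128, 64)             # [batch, tokens, dim]
y, aux = moe(x)                         # y: [8, 128, 64]

task_loss = y.pow(2).mean()             # hypothetical stand-in for the real training loss
loss = task_loss + 0.02 * aux["gate_loss"]   # 0.02 mirrors moe_lambda in default_config.yml
loss.backward()

print(aux["route_hist"])                # per-expert utilization over this batch, sums to ~1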