config/default_config.yml (45 changes: 39 additions & 6 deletions)
@@ -24,7 +24,7 @@ ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
ae_global_num_blocks: 8
ae_global_num_blocks: 2
ae_global_num_heads: 32
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
@@ -42,12 +42,12 @@ pred_mlp_adaln: True

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
forecast_offset : 0
forecast_offset : 1
forecast_delta_hrs: 0
forecast_steps: 0
forecast_policy: null
forecast_steps: 1
forecast_policy: "fixed"
forecast_att_dense_rate: 1.0
fe_num_blocks: 0
fe_num_blocks: 2
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
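
To make the rule in the forecast_offset comment above concrete: auto-encoder training corresponds to forecast_offset == 0 together with forecast_steps == 0, while the new defaults switch to single-step forecasting. The small check below is purely illustrative (the helper name is hypothetical and not part of this PR):

def is_autoencoder_mode(forecast_offset: int, forecast_steps: int) -> bool:
    # Mirrors the config comment: both values at zero means the model is
    # trained as an auto-encoder rather than a forecaster.
    return forecast_offset == 0 and forecast_steps == 0

assert is_autoencoder_mode(0, 0)        # old defaults: auto-encoder training
assert not is_autoencoder_mode(1, 1)    # new defaults: one-step forecast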
@@ -93,7 +93,7 @@ ema_halflife_in_thousands: 1e-3

# training mode: "forecast" or "masking" (masked token modeling)
# for "masking" to train with auto-encoder mode, forecast_offset should be 0
training_mode: "masking"
training_mode: "forecast"
# masking rate when training mode is "masking"; ignored in forecast mode
masking_rate: 0.6
# sample the masking rate (with normal distribution centered at masking_rate)
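
The sampled-masking-rate option described in the comment above can be pictured with a short sketch (assumed behaviour only; the actual sampling code lives elsewhere in the repo, and the standard deviation here is an arbitrary placeholder):

import random

def sample_masking_rate(masking_rate: float = 0.6, std: float = 0.1) -> float:
    # Draw a rate from a normal distribution centred at masking_rate and
    # clamp it to the valid range [0, 1].
    return min(1.0, max(0.0, random.gauss(masking_rate, std)))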
@@ -160,3 +160,36 @@ train_log_freq:
terminal: 10
metrics: 20
checkpoint: 250

# Parameters for logging/printing in the training loop
train_log:
  # The period to log metrics (in number of batch steps)
  log_interval: 20

# MLP type per component: "dense" (default) or "moe"
fe_mlp_type: "dense" # set to "moe" to enable MoE
ae_global_mlp_type: "dense" # set to "moe" to enable MoE
ffn_mlp_type: "dense" # set to "moe" to enable MoE in the feed-forward network of the decoder blocks
decoder_mlp_type: "dense" # set to "moe" to enable MoE in the decoder prediction MLP
moe_lambda: 0.02 # coefficient for the MoE load balancing loss

# MoE-only params (ignored when fe_mlp_type != "moe")
fe_moe_num_experts: 2
fe_moe_top_k: 1
fe_moe_hidden_factor: 0.5 # = HF_dense / 4

# MoE-only params (ignored when ae_global_mlp_type != "moe")
ae_global_moe_num_experts: 4
ae_global_moe_top_k: 2
ae_global_moe_hidden_factor: 0.5 # = HF_dense / 4

# MoE-only params (ignored when ffn_mlp_type != "moe")
ffn_moe_num_experts: 2
ffn_moe_top_k: 1
ffn_moe_hidden_factor: 0.5 # = HF_dense / 4

# MoE-only params (ignored when decoder_mlp_type != "moe")
decoder_moe_num_experts: 2
decoder_moe_top_k: 1
decoder_moe_hidden_factor: 0.5 # = HF_dense / 4
tr_mlp_hidden_factor: 2
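
The four *_mlp_type switches above select, per component, between the existing dense MLP and the new MoE MLP; the matching *_moe_* keys only take effect when their switch is set to "moe". The sketch below shows one way these keys could be folded into the ffn_mlp_type / moe_kwargs constructor arguments read by blocks.py (next file). The helper is illustrative only and not part of this PR, which instead reads the values from the model Config (self.cf) or forwards them via **kwargs:

config = {
    "ffn_mlp_type": "moe",
    "ffn_moe_num_experts": 2,
    "ffn_moe_top_k": 1,
    "ffn_moe_hidden_factor": 0.5,
}

def moe_block_kwargs(cfg: dict, prefix: str) -> dict:
    # Collect the MoE settings for one component (e.g. "ffn", "fe", "ae_global").
    if cfg.get(f"{prefix}_mlp_type", "dense") != "moe":
        return {f"{prefix}_mlp_type": "dense"}
    return {
        f"{prefix}_mlp_type": "moe",
        f"{prefix}_hidden_factor": cfg[f"{prefix}_moe_hidden_factor"],
        "moe_kwargs": {
            "num_experts": cfg[f"{prefix}_moe_num_experts"],
            "top_k": cfg[f"{prefix}_moe_top_k"],
        },
    }

print(moe_block_kwargs(config, "ffn"))
# -> {'ffn_mlp_type': 'moe', 'ffn_hidden_factor': 0.5,
#     'moe_kwargs': {'num_experts': 2, 'top_k': 1}}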
src/weathergen/model/blocks.py (125 changes: 96 additions & 29 deletions)
@@ -14,10 +14,12 @@
MultiCrossAttentionHeadVarlen,
MultiSelfAttentionHeadVarlen,
)
from weathergen.model.layers import MLP
from weathergen.model.layers import MLP, MoEMLP
from weathergen.model.norms import AdaLayerNormLayer
from weathergen.utils.utils import get_dtype

import logging

logger = logging.getLogger(__name__)

class SelfAttentionBlock(nn.Module):
"""
@@ -43,14 +45,32 @@ def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs
self.mhsa_block = lambda x, _, **kwargs: self.mhsa(self.ln_sa(x), **kwargs) + x

approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = MLP(
dim_in=dim,
dim_out=dim,
hidden_factor=4,
dropout_rate=0.1,
nonlin=approx_gelu,
with_residual=False,
)
use_moe_ffn = (kwargs.get("ffn_mlp_type", "dense") == "moe")
ffn_hidden_factor = kwargs.get("ffn_hidden_factor", 4)
moe_kwargs = kwargs.get("moe_kwargs", {}) # e.g. num_experts, top_k, router_noisy_std

if use_moe_ffn:
self.mlp = MoEMLP(
dim_in=dim,
dim_out=dim,
hidden_factor=ffn_hidden_factor,
dropout_rate=0.1,
nonlin=nn.GELU, # internal block constructs nonlin()
with_residual=False,
norm_type=kwargs["attention_kwargs"]["norm_type"],
dim_aux=(dim_aux if self.with_adanorm else None),
norm_eps=kwargs["attention_kwargs"]["norm_eps"],
**moe_kwargs, # <- e.g. num_experts=8, top_k=2, router_noisy_std=0.01
)
else:
self.mlp = MLP(
dim_in=dim,
dim_out=dim,
hidden_factor=4,
dropout_rate=0.1,
nonlin=approx_gelu,
with_residual=False,
)
if self.with_adanorm:
self.mlp_fn = lambda x, **kwargs: self.mlp(x)
self.mlp_block = AdaLayerNormLayer(dim, dim_aux, self.mlp_fn, dropout_rate)
@@ -104,7 +124,7 @@ def __init__(

self.with_adanorm = with_adanorm
self.with_self_attn = with_self_attn
self.with_mlp = with_self_attn
self.with_mlp = with_mlp

if with_self_attn:
self.mhsa = MultiSelfAttentionHeadVarlen(
@@ -136,18 +156,37 @@

if self.with_mlp:
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = MLP(
dim_in=dim_q,
dim_out=dim_q,
hidden_factor=4,
nonlin=approx_gelu,
with_residual=False,
)

use_moe_ffn = (kwargs.get("ffn_mlp_type", "dense") == "moe")
ffn_hidden_factor = kwargs.get("ffn_hidden_factor", 4)
moe_kwargs = kwargs.get("moe_kwargs", {})

if use_moe_ffn:
self.mlp = MoEMLP(
dim_in=dim_q,
dim_out=dim_q,
hidden_factor=ffn_hidden_factor,
dropout_rate=0.1,
nonlin=nn.GELU, # internal block constructs nonlin()
with_residual=False,
norm_type=kwargs["attention_kwargs"]["norm_type"],
dim_aux=(dim_aux if self.with_adanorm else None),
norm_eps=kwargs["attention_kwargs"]["norm_eps"],
**moe_kwargs, # <- e.g. num_experts=8, top_k=2, router_noisy_std=0.01
)
else:
self.mlp = MLP(
dim_in=dim_q,
dim_out=dim_q,
hidden_factor=4,
nonlin=approx_gelu,
with_residual=False,
)
if self.with_adanorm:
self.mlp_fn = lambda x, **kwargs: self.mlp(x)
self.mlp_block = AdaLayerNormLayer(dim_q, dim_aux, self.mlp_fn, dropout_rate)
else:
self.ln_mlp = nn.LayerNorm(dim_q, eps=kwargs["attention_kwargs"]["norm_eps"])
self.mlp_block = lambda x, _, **kwargs: self.mlp(self.ln_mlp(x)) + x
else:
self.mlp_block = lambda x, _, **kwargs: x
@@ -191,6 +230,7 @@ def __init__(
tr_mlp_hidden_factor,
tro_type,
mlp_norm_eps=1e-6,
**kwargs,
):
super().__init__()

@@ -237,19 +277,46 @@
)

# MLP Block
self.block.append(
MLP(
dim_in,
dim_out,
with_residual=True,
hidden_factor=self.tr_mlp_hidden_factor,
dropout_rate=0.1, # Assuming dropout_rate is 0.1
norm_type=self.cf.norm_type,
dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None),
norm_eps=self.cf.mlp_norm_eps,
)
# Add MoE option
use_moe = getattr(self.cf, "decoder_mlp_type", "dense") == "moe"
logger.info(
"[MoE] Decoder head: type=%s%s",
"moe" if use_moe else "dense",
("" if not use_moe else
f" (experts={getattr(self.cf,'moe_num_experts',None)}, top_k={getattr(self.cf,'moe_top_k',None)})"),
)

if use_moe:
self.block.append(
MoEMLP(
dim_in,
dim_out,
hidden_factor=self.tr_mlp_hidden_factor,
dropout_rate=0.1,
with_residual=True, # mirror dense
norm_type=self.cf.norm_type,
dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None),
norm_eps=self.cf.mlp_norm_eps,
num_experts=getattr(self.cf, "moe_num_experts", 8),
top_k=getattr(self.cf, "moe_top_k", 2),
router_noisy_std=getattr(self.cf, "moe_router_noisy_std", 0.0),
use_checkpoint=getattr(self.cf, "moe_use_checkpoint", False),
)
)
else:
self.block.append(
MLP(
dim_in,
dim_out,
with_residual=True,
hidden_factor=self.tr_mlp_hidden_factor,
dropout_rate=0.1, # Assuming dropout_rate is 0.1
norm_type=self.cf.norm_type,
dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None),
norm_eps=self.cf.mlp_norm_eps,
)
)

def forward(self, latent, output, coords, latent_lens, output_lens):
for layer in self.block:
if isinstance(layer, MultiCrossAttentionHeadVarlen):
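
MoEMLP is imported from weathergen.model.layers but its definition is not part of this diff. The self-contained sketch below shows the kind of top-k routed MLP the new num_experts / top_k / moe_lambda settings suggest; the class name, routing details, and auxiliary-loss bookkeeping are assumptions for illustration, not the repository's actual implementation. The training loop would be expected to add moe_lambda * aux_loss to the main objective to keep the experts load-balanced.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoEMLPSketch(nn.Module):
    def __init__(self, dim_in, dim_out, hidden_factor=0.5, num_experts=2, top_k=1):
        super().__init__()
        hidden = max(1, int(dim_in * hidden_factor))
        self.router = nn.Linear(dim_in, num_experts)  # per-token expert scores
        self.experts = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(dim_in, hidden), nn.GELU(), nn.Linear(hidden, dim_out))
                for _ in range(num_experts)
            ]
        )
        self.top_k = top_k

    def forward(self, x):
        # x: (tokens, dim_in); every token is dispatched to its top-k experts.
        gates = self.router(x).softmax(dim=-1)                # (tokens, num_experts)
        top_gates, top_idx = gates.topk(self.top_k, dim=-1)
        top_gates = top_gates / top_gates.sum(dim=-1, keepdim=True)

        out = torch.zeros(
            x.shape[0], self.experts[0][-1].out_features, dtype=x.dtype, device=x.device
        )
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = top_idx[:, k] == e
                if mask.any():
                    out[mask] += top_gates[mask, k : k + 1] * expert(x[mask])

        # Switch-Transformer-style load-balancing term: fraction of tokens whose
        # first choice is each expert times the mean router probability for it.
        frac_tokens = F.one_hot(top_idx[:, 0], len(self.experts)).float().mean(dim=0)
        mean_probs = gates.mean(dim=0)
        aux_loss = len(self.experts) * (frac_tokens * mean_probs).sum()
        return out, aux_loss

tokens = torch.randn(16, 64)
y, aux = TopKMoEMLPSketch(64, 64, num_experts=2, top_k=1)(tokens)
print(y.shape, float(aux))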