Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/weathergen/model/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
super(StreamEmbedTransformer, self).__init__()

self.name = f"StreamEmbedder_{stream_name}"

self.mode = mode
self.num_tokens = num_tokens
self.token_size = token_size
self.num_channels = num_channels
Expand Down Expand Up @@ -113,9 +113,8 @@ def __init__(
)

else:
assert False
raise ValueError(f"Unknown unembed mode: {unembed_mode}")

self.forward = self.forward_channels

elif mode == "columns":
assert embed_size_centroids == 0
Expand All @@ -130,7 +129,6 @@ def __init__(
self.num_tokens * ((self.dim_out - embed_size_centroids) // token_size),
)
self.ln_final = norm(dim_out, eps=1e-6)
self.forward = self.forward_columns

# TODO: factorization when sqrt is not int
dim1 = int(np.sqrt(dim_out))
Expand All @@ -140,7 +138,7 @@ def __init__(
self.unembed2 = torch.nn.Linear(self.token_size, dim1)

else:
assert False
raise ValueError(f"Unknown mode: {mode}")

self.dropout_final = torch.nn.Dropout(0.1)
self.embed_centroids = torch.nn.Linear(5, embed_size_centroids)
Expand All @@ -164,7 +162,7 @@ def forward_channels(self, x_in, centroids):
]
out = torch.stack(out, dim=1).flatten(-2, -1)
else:
assert False
raise ValueError(f"Unknown unembed mode: {self.unembed_mode}")

# append centroids
if self.embed_size_centroids > 0:
Expand Down Expand Up @@ -195,6 +193,13 @@ def forward_columns(self, x_in, centroids):

return out.to(torch.float16)

def forward(self, x_in, centroids):
    """Dispatch to the mode-specific forward implementation.

    :param x_in: input tokens for this stream.
    :param centroids: per-token centroid information used by the embedders.
    :return: output of the mode-specific forward pass.
    :raises ValueError: if ``self.mode`` is neither "channels" nor "columns".
    """
    # NOTE(review): removed review-page text that was interleaved between the
    # signature and the body in the pasted diff — it was not valid Python.
    if self.mode == "channels":
        return self.forward_channels(x_in, centroids)
    elif self.mode == "columns":
        return self.forward_columns(x_in, centroids)
    else:
        raise ValueError(f"Unknown mode {self.mode}")

class StreamEmbedLinear(torch.nn.Module):
def __init__(self, dim_in, dim_out, stream_name="stream_embed"):
Expand Down
123 changes: 82 additions & 41 deletions src/weathergen/model/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from weathergen.utils.utils import get_dtype


class EmbeddingEngine:
class EmbeddingEngine(torch.nn.Module):
name: "EmbeddingEngine"

def __init__(self, cf: Config, sources_size) -> None:
Expand All @@ -39,16 +39,11 @@ def __init__(self, cf: Config, sources_size) -> None:
:param cf: Configuration object containing parameters for the engine.
:param sources_size: List of source sizes for each stream.
"""
super(EmbeddingEngine, self).__init__()
self.cf = cf
self.sources_size = sources_size # KCT:iss130, what is this?
self.embeds = torch.nn.ModuleList()

def create(self) -> torch.nn.ModuleList:
"""
Creates and returns the module list (embeds).

:return: torch.nn.ModuleList containing the embedding layers.
"""
for i, si in enumerate(self.cf.streams):
stream_name = si.get("name", i)

Expand Down Expand Up @@ -84,10 +79,53 @@ def create(self) -> torch.nn.ModuleList:
)
else:
raise ValueError("Unsupported embedding network type")
return self.embeds


def forward(self, streams_data, pe_embed, dtype, device):
    """Embed all streams and scatter the results into per-cell token order.

    :param streams_data: nested (batch x stream) collection of stream-data
        objects exposing ``source_tokens_lens``, ``source_empty()``,
        ``source_idxs_embed``, ``source_idxs_embed_pe``,
        ``source_tokens_cells`` and ``source_centroids``.
    :param pe_embed: positional-encoding table indexed by ``source_idxs_embed_pe``.
    :param dtype: dtype of the output token buffer.
    :param device: device of the output token buffer.
    :return: tensor of shape (total_tokens, ae_local_dim_embed) holding all
        embedded tokens, reordered from per-stream to per-cell ordering.
    """
    # Per-cell token lengths stacked over streams and batch entries; an empty
    # stream contributes an empty tensor so the stack stays well-formed.
    source_tokens_lens = torch.stack(
        [
            torch.stack(
                [
                    s.source_tokens_lens if len(s.source_tokens_lens) > 0 else torch.tensor([])
                    for s in stl_b
                ]
            )
            for stl_b in streams_data
        ]
    )
    # Cumulative per-cell offsets; the last entry is the total token count.
    offsets_base = source_tokens_lens.sum(1).sum(0).cumsum(0)

    tokens_all = torch.empty(
        (int(offsets_base[-1]), self.cf.ae_local_dim_embed), dtype=dtype, device=device
    )

    # NOTE(review): a stray diff-context line (an old class header) sat in the
    # middle of this body in the pasted diff and was removed; the unused
    # enumerate() indices were dropped as well.
    for sb in streams_data:
        for s, embed in zip(sb, self.embeds, strict=False):
            if not s.source_empty():
                idxs = s.source_idxs_embed.to(device)
                idxs_pe = s.source_idxs_embed_pe.to(device)

                # create full scatter index
                # (there's no broadcasting which is likely highly inefficient)
                idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed))
                x_embed = embed(s.source_tokens_cells, s.source_centroids).flatten(0, 1)
                # there's undocumented limitation in flash_attn that will make embed fail if
                # #tokens is too large; code below is a work around
                # x_embed = torch.cat(
                #     [
                #         embed(s_c, c_c).flatten(0, 1)
                #         for s_c, c_c in zip(
                #             torch.split(s.source_tokens_cells, 49152),
                #             torch.split(s.source_centroids, 49152),
                #         )
                #     ]
                # )

                # scatter write to reorder from per stream to per cell ordering
                tokens_all.scatter_(0, idxs, x_embed + pe_embed[idxs_pe])
    return tokens_all


class LocalAssimilationEngine(torch.nn.Module):
name: "LocalAssimilationEngine"

def __init__(self, cf: Config) -> None:
Expand All @@ -96,15 +134,10 @@ def __init__(self, cf: Config) -> None:

:param cf: Configuration object containing parameters for the engine.
"""
super(LocalAssimilationEngine, self).__init__()
self.cf = cf
self.ae_local_blocks = torch.nn.ModuleList()

def create(self) -> torch.nn.ModuleList:
"""
Creates and returns the module list (ae_local_blocks).

:return: torch.nn.ModuleList containing the local assimilation blocks.
"""
for _ in range(self.cf.ae_local_num_blocks):
self.ae_local_blocks.append(
MultiSelfAttentionHeadVarlen(
Expand All @@ -128,10 +161,14 @@ def create(self) -> torch.nn.ModuleList:
norm_eps=self.cf.mlp_norm_eps,
)
)
return self.ae_local_blocks

def forward(self, tokens_c, cell_lens_c, use_reentrant):
    """Run the local assimilation blocks over the per-cell tokens.

    Each block executes under activation checkpointing, trading recompute
    for activation memory during backprop.

    :param tokens_c: per-cell token tensor.
    :param cell_lens_c: per-cell sequence lengths forwarded to each block.
    :param use_reentrant: forwarded to the checkpoint call.
    :return: tokens after all local assimilation blocks.
    """
    out = tokens_c
    for layer in self.ae_local_blocks:
        out = checkpoint(layer, out, cell_lens_c, use_reentrant=use_reentrant)
    return out


class Local2GlobalAssimilationEngine:
class Local2GlobalAssimilationEngine(torch.nn.Module):
name: "Local2GlobalAssimilationEngine"

def __init__(self, cf: Config) -> None:
Expand All @@ -140,15 +177,10 @@ def __init__(self, cf: Config) -> None:

:param cf: Configuration object containing parameters for the engine.
"""
super(Local2GlobalAssimilationEngine, self).__init__()
self.cf = cf
self.ae_adapter = torch.nn.ModuleList()

def create(self) -> torch.nn.ModuleList:
"""
Creates and returns the module list (ae_adapter).

:return: torch.nn.ModuleList containing the local-to-global assimilation adapter blocks.
"""
self.ae_adapter.append(
MultiCrossAttentionHeadVarlenSlicedQ(
self.cf.ae_global_dim_embed,
Expand Down Expand Up @@ -191,10 +223,21 @@ def create(self) -> torch.nn.ModuleList:
attention_dtype=get_dtype(self.cf.attention_dtype),
)
)
return self.ae_adapter

def forward(self, tokens_c, tokens_global_c, q_cells_lens_c, cell_lens_c, use_reentrant):
    """Apply the adapter blocks that map local tokens onto the global tokens.

    Every adapter block runs under activation checkpointing.

    :param tokens_c: local (per-cell) token tensor, used as keys/values.
    :param tokens_global_c: global token tensor, used as queries and updated.
    :param q_cells_lens_c: query-side per-cell lengths for the adapter blocks.
    :param cell_lens_c: key/value-side per-cell lengths for the adapter blocks.
    :param use_reentrant: forwarded to the checkpoint call.
    :return: updated global tokens.
    """
    out = tokens_global_c
    for adapter_block in self.ae_adapter:
        out = checkpoint(
            adapter_block,
            out,
            tokens_c,
            q_cells_lens_c,
            cell_lens_c,
            use_reentrant=use_reentrant,
        )
    return out


class GlobalAssimilationEngine:
class GlobalAssimilationEngine(torch.nn.Module):
name: "GlobalAssimilationEngine"

def __init__(self, cf: Config, num_healpix_cells: int) -> None:
Expand All @@ -204,17 +247,12 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
:param cf: Configuration object containing parameters for the engine.
:param num_healpix_cells: Number of healpix cells used for local queries.
"""
super(GlobalAssimilationEngine, self).__init__()
self.cf = cf
self.num_healpix_cells = num_healpix_cells

self.ae_global_blocks = torch.nn.ModuleList()

def create(self) -> torch.nn.ModuleList:
"""
Creates and returns the module list (ae_global_blocks).

:return: torch.nn.ModuleList containing the global assimilation blocks.
"""
global_rate = int(1 / self.cf.ae_global_att_dense_rate)
for i in range(self.cf.ae_global_num_blocks):
## Alternate between local and global attention
Expand Down Expand Up @@ -260,10 +298,14 @@ def create(self) -> torch.nn.ModuleList:
norm_eps=self.cf.mlp_norm_eps,
)
)
return self.ae_global_blocks

def forward(self, tokens, use_reentrant):
    """Run all global assimilation blocks with activation checkpointing.

    :param tokens: global token tensor.
    :param use_reentrant: forwarded to the checkpoint call.
    :return: tokens after all global assimilation blocks.
    """
    out = tokens
    for layer in self.ae_global_blocks:
        out = checkpoint(layer, out, use_reentrant=use_reentrant)
    return out


class ForecastingEngine:
class ForecastingEngine(torch.nn.Module):
name: "ForecastingEngine"

def __init__(self, cf: Config, num_healpix_cells: int) -> None:
Expand All @@ -273,16 +315,11 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
:param cf: Configuration object containing parameters for the engine.
:param num_healpix_cells: Number of healpix cells used for local queries.
"""
super(ForecastingEngine, self).__init__()
self.cf = cf
self.num_healpix_cells = num_healpix_cells
self.fe_blocks = torch.nn.ModuleList()

def create(self) -> torch.nn.ModuleList:
"""
Creates and returns the module list (fe_blocks).

:return: torch.nn.ModuleList containing the forecasting blocks.
"""
global_rate = int(1 / self.cf.forecast_att_dense_rate)
if self.cf.forecast_policy is not None:
for i in range(self.cf.fe_num_blocks):
Expand Down Expand Up @@ -339,7 +376,11 @@ def init_weights_final(m):
for block in self.fe_blocks:
block.apply(init_weights_final)

return self.fe_blocks
def forward(self, tokens, use_reentrant):
    """Advance the tokens through the forecasting blocks.

    :param tokens: global token tensor to be advanced.
    :param use_reentrant: forwarded to the checkpoint call.
    :return: tokens after all forecasting blocks.
    """
    for it, block in enumerate(self.fe_blocks):
        # Per-block auxiliary input carrying the block index. Allocate it on
        # the same device as the tokens instead of hard-coding "cuda", so the
        # model also runs on CPU and other accelerators.
        aux_info = torch.tensor([it], dtype=torch.float32, device=tokens.device)
        tokens = checkpoint(block, tokens, aux_info, use_reentrant=use_reentrant)
    return tokens


class EnsPredictionHead(torch.nn.Module):
Expand Down
Loading