Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mojo_opset/core/operators/moe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(
up_weight_dtype: Union[torch.dtype, str] = torch.int8,
down_quant_group_size: int = -1,
down_weight_dtype: Union[torch.dtype, str] = torch.int8,
process_group: Optional[dist.ProcessGroup] = None,
**kwargs,
):
super().__init__()
Expand All @@ -105,6 +107,7 @@ def __init__(
self.up_weight_dtype = up_weight_dtype
self.down_quant_group_size = down_quant_group_size
self.down_weight_dtype = down_weight_dtype
self.process_group = process_group

self.gating = MojoMoEGating._registry.get(self._backend)(
hidden_size=self.hidden_size,
Expand All @@ -123,6 +126,7 @@ def __init__(
up_weight_dtype=up_weight_dtype,
down_quant_group_size=down_quant_group_size,
down_weight_dtype=down_weight_dtype,
top_k=self.top_k,
**kwargs,
)
self.combine = MojoMoECombine._registry.get(self._backend)(multiply_by_gates=True, **kwargs)
Expand Down
18 changes: 18 additions & 0 deletions mojo_opset/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,27 @@
from .operators.attention import MojoPagedPrefillSWAWithKVDequant
from .operators.attention import MojoPrefillMLA
from .operators.attention import MojoPrefillNSA
from .operators.attention_gate import MojoFusedAttnGateConcat
from .operators.attention_gate import MojoFusedAttnOutputGate
from .operators.attention import MojoPagedPrefillSageGQA
from .operators.compute_with_comm import MojoFusedAGScaleQuant
from .operators.deepep import MojoDeepEPCombine
from .operators.deepep import MojoDeepEPDispatch
from .operators.gemm import MojoQuantBatchGemmReduceSum
from .operators.indexer import MojoIndexer
from .operators.indexer import MojoLightningIndexer
from .operators.kv_cache import MojoGatherRopeStore
from .operators.kv_cache import MojoPagedAttentionStoreKvCache
from .operators.kv_cache import MojoPagedCacheDequant
from .operators.kv_cache import MojoStorePagedMLAKVCache
from .operators.moe import MojoFusedSwiGLUMoEScaleDynamicQuantize
from .operators.moe import MojoMoEInitRoutingDynamicQuant
from .operators.normalization import MojoChannelRMSNorm
from .operators.normalization import MojoGroupLayerNorm
from .operators.normalization import MojoQKInplaceRMSNorm
from .operators.position_embedding import MojoGridRoPE
from .operators.position_embedding import MojoRelativeEmbedding
from .operators.position_embedding import MojoRotaryEmbedding
from .operators.store_lowrank import MojoStoreLowrank

__all__ = [
Expand All @@ -50,15 +59,24 @@
"MojoPagedDecodeGQAWithKVDequant",
"MojoPagedPrefillSWAWithKVDequant",
"MojoPagedDecodeSWAWithKVDequant",
"MojoFusedAttnGateConcat",
"MojoFusedAttnOutputGate",
"MojoFusedAGScaleQuant",
"MojoPagedPrefillSageGQA",
"MojoGatherRopeStore",
"MojoPagedAttentionStoreKvCache",
"MojoPagedCacheDequant",
"MojoStorePagedMLAKVCache",
"MojoMoEInitRoutingDynamicQuant",
"MojoFusedSwiGLUMoEScaleDynamicQuantize",
"MojoGroupLayerNorm",
"MojoChannelRMSNorm",
"MojoQKInplaceRMSNorm",
"MojoRelativeEmbedding",
"MojoGridRoPE",
"MojoRotaryEmbedding",
"MojoStoreLowrank",
"MojoIndexer",
"MojoDeepEPDispatch",
"MojoDeepEPCombine",
]
14 changes: 14 additions & 0 deletions mojo_opset/experimental/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,31 @@
from .attention import MojoPagedPrefillSWAWithKVDequant
from .attention import MojoPrefillMLA
from .attention import MojoPrefillNSA
from .attention_gate import MojoFusedAttnGateConcat
from .attention_gate import MojoFusedAttnOutputGate
from .compute_with_comm import MojoFusedAGScaleQuant
from .gemm import MojoQuantBatchGemmReduceSum
from .indexer import MojoIndexer
from .indexer import MojoLightningIndexer
from .kv_cache import MojoGatherRopeStore
from .kv_cache import MojoPagedAttentionStoreKvCache
from .kv_cache import MojoPagedCacheDequant
from .kv_cache import MojoStorePagedMLAKVCache
from .moe import MojoFusedSwiGLUMoEScaleDynamicQuantize
from .moe import MojoMoEInitRoutingDynamicQuant
from .normalization import MojoChannelRMSNorm
from .normalization import MojoGroupLayerNorm
from .normalization import MojoQKInplaceRMSNorm
from .position_embedding import MojoGridRoPE
from .position_embedding import MojoRelativeEmbedding
from .position_embedding import MojoRotaryEmbedding
from .store_lowrank import MojoStoreLowrank

__all__ = [
"MojoRotateActivation",
"MojoFusedAttnGateConcat",
"MojoFusedAttnOutputGate",
"MojoFusedAGScaleQuant",
"MojoIndexer",
"MojoLightningIndexer",
"MojoPrefillMLA",
Expand All @@ -41,13 +50,18 @@
"MojoPagedDecodeGQAWithKVDequant",
"MojoPagedPrefillSWAWithKVDequant",
"MojoPagedDecodeSWAWithKVDequant",
"MojoGatherRopeStore",
"MojoPagedAttentionStoreKvCache",
"MojoPagedCacheDequant",
"MojoStorePagedMLAKVCache",
"MojoMoEInitRoutingDynamicQuant",
"MojoFusedSwiGLUMoEScaleDynamicQuantize",
"MojoGroupLayerNorm",
"MojoChannelRMSNorm",
"MojoQKInplaceRMSNorm",
"MojoRelativeEmbedding",
"MojoGridRoPE",
"MojoRotaryEmbedding",
"MojoStoreLowrank",
"MojoQuantBatchGemmReduceSum",
]
63 changes: 63 additions & 0 deletions mojo_opset/experimental/operators/attention_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,66 @@ def extra_repr(self) -> str:
f"head_dim={self.head_dim}, "
f"bias={self.full_gate_bias is not None}"
)


class MojoFusedAttnGateConcat(MojoOperator):
    """Gate the full-attention and SWA outputs, then join them along the head axis."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        full_attn_out: torch.Tensor,
        full_attn_gate_score: torch.Tensor,
        swa_attn_out: torch.Tensor,
        swa_attn_gate_score: torch.Tensor,
    ) -> torch.Tensor:
        """
        Multiply each attention output by the sigmoid of its per-head gate
        score (computed in float32) and concatenate the results on dim 1.

        Args:
            full_attn_out: Full attention output,
                ``[total_seq, full_head_num, head_dim]``.
            full_attn_gate_score: Gate score for full attention,
                ``[total_seq, full_head_num]``.
            swa_attn_out: SWA attention output,
                ``[total_seq, swa_head_num, head_dim]``.
            swa_attn_gate_score: Gate score for SWA attention,
                ``[total_seq, swa_head_num]``.

        Returns:
            Tensor of shape
            ``[total_seq, full_head_num + swa_head_num, head_dim]`` cast back
            to the dtype of ``full_attn_out``.

        Raises:
            ValueError: If a gate score is ``None``, a tensor has the wrong
                rank, or the shapes are mutually inconsistent.
        """
        # Presence checks come first so the rank checks below can call .dim().
        if full_attn_gate_score is None:
            raise ValueError("full_attn_gate_score is required.")
        if swa_attn_gate_score is None:
            raise ValueError("swa_attn_gate_score is required.")

        # Rank checks: outputs are 3D, gate scores are 2D.
        if full_attn_out.dim() != 3:
            raise ValueError(f"full_attn_out must be 3D, got {tuple(full_attn_out.shape)}.")
        if swa_attn_out.dim() != 3:
            raise ValueError(f"swa_attn_out must be 3D, got {tuple(swa_attn_out.shape)}.")
        if full_attn_gate_score.dim() != 2:
            raise ValueError(f"full_attn_gate_score must be 2D, got {tuple(full_attn_gate_score.shape)}.")
        if swa_attn_gate_score.dim() != 2:
            raise ValueError(f"swa_attn_gate_score must be 2D, got {tuple(swa_attn_gate_score.shape)}.")

        total_seq, full_head_num, head_dim = full_attn_out.shape
        swa_total_seq, swa_head_num, swa_head_dim = swa_attn_out.shape

        # Both branches must agree on everything except the head count.
        if (swa_total_seq, swa_head_dim) != (total_seq, head_dim):
            raise ValueError(
                "full_attn_out and swa_attn_out must have matching total_seq and head_dim, "
                f"got {tuple(full_attn_out.shape)} and {tuple(swa_attn_out.shape)}."
            )

        expected_full_gate = (total_seq, full_head_num)
        if tuple(full_attn_gate_score.shape) != expected_full_gate:
            raise ValueError(
                f"full_attn_gate_score must have shape [{total_seq}, {full_head_num}], got {tuple(full_attn_gate_score.shape)}."
            )

        expected_swa_gate = (total_seq, swa_head_num)
        if tuple(swa_attn_gate_score.shape) != expected_swa_gate:
            raise ValueError(
                f"swa_attn_gate_score must have shape [{total_seq}, {swa_head_num}], got {tuple(swa_attn_gate_score.shape)}."
            )

        # Gate in float32; the trailing unsqueeze broadcasts each per-head
        # gate over head_dim.
        branches = (
            (full_attn_out, full_attn_gate_score),
            (swa_attn_out, swa_attn_gate_score),
        )
        gated = [
            attn_out.float() * torch.sigmoid(gate_score.float()).unsqueeze(-1)
            for attn_out, gate_score in branches
        ]
        merged = torch.cat(gated, dim=1)
        return merged.to(full_attn_out.dtype)
113 changes: 113 additions & 0 deletions mojo_opset/experimental/operators/compute_with_comm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.distributed_c10d import _get_default_group

from mojo_opset.core.operator import MojoOperator


def _is_dist_initialized() -> bool:
return dist.is_available() and dist.is_initialized()


class MojoFusedAGScaleQuant(MojoOperator):
    """Optional RMSNorm followed by per-token int8 quantization with a team-wide scale.

    The forward pass flattens the input to ``(rows, head_num * head_dim)``,
    multiplies it by ``quant_scale``, derives one scale per row, replaces that
    scale with the element-wise maximum across the process group (when a
    multi-rank group is active), and quantizes to int8.
    """

    def __init__(
        self,
        *,
        team_size: int = 1,
        quant_mode: str = "per_token",
        norm_mode: str = "none",
        eps: float = 1e-5,
        max_tokens: Optional[int] = None,
        process_group: Optional[dist.ProcessGroup] = None,
        comm_context=None,
        **kwargs,
    ):
        """
        Fused AllGather-scale exchange + optional RMSNorm + per-token int8 quantization.

        Args:
            team_size (int): Communication team size.
            quant_mode (str): Quantization mode. Only ``"per_token"`` is supported.
            norm_mode (str): Normalization mode. Supports ``"none"`` and ``"rmsnorm"``.
            eps (float): Epsilon for RMSNorm.
            max_tokens (Optional[int]): Maximum token count expected by backend
                implementations that initialize communication buffers in ``__init__``.
            process_group (Optional[ProcessGroup]): Distributed group for the torch reference.
                ``None`` means the default group.
            comm_context: Optional runtime/context object for backend implementations.

        Raises:
            NotImplementedError: If ``quant_mode`` or ``norm_mode`` is unsupported.
            ValueError: If ``team_size`` or ``max_tokens`` is not positive.
        """
        super().__init__(**kwargs)
        if quant_mode not in ["per_token"]:
            raise NotImplementedError(f"quant_mode {quant_mode} not supported")
        if norm_mode not in ["none", "rmsnorm"]:
            raise NotImplementedError(f"norm_mode {norm_mode} not supported")
        if team_size < 1:
            raise ValueError(f"team_size must be positive, but got {team_size}")
        if max_tokens is not None and max_tokens < 1:
            raise ValueError(f"max_tokens must be positive, but got {max_tokens}")

        self.team_size = team_size
        self.quant_mode = quant_mode
        self.norm_mode = norm_mode
        self.eps = eps
        self.max_tokens = max_tokens
        self.process_group = process_group
        self.comm_context = comm_context

    def _team_max_scale(self, scale: torch.Tensor) -> torch.Tensor:
        """Return the element-wise maximum of ``scale`` across the team.

        ``scale`` is returned unchanged when ``team_size`` is 1, when
        torch.distributed is not initialized, or when the group has a single
        rank.

        Raises:
            ValueError: If the group's world size differs from ``team_size``.
        """
        if self.team_size == 1 or not _is_dist_initialized():
            return scale

        # The public torch.distributed API treats ``group=None`` as the
        # default process group, so the configured value is passed through
        # as-is instead of resolving it via a truthiness test and a private
        # helper.
        process_group = self.process_group
        world_size = dist.get_world_size(group=process_group)
        if world_size == 1:
            return scale
        if world_size != self.team_size:
            raise ValueError(f"process group world size must match team_size={self.team_size}, but got {world_size}")

        # Allocate receive buffers from the tensor that is actually sent so
        # their layout always matches it.
        local_scale = scale.contiguous()
        gathered = [torch.empty_like(local_scale) for _ in range(world_size)]
        dist.all_gather(gathered, local_scale, group=process_group)
        return torch.stack(gathered, dim=0).amax(dim=0)

    def forward(
        self,
        input: torch.Tensor,
        quant_scale: torch.Tensor,
        norm_weight: Optional[torch.Tensor] = None,
    ):
        """Normalize (optionally), scale, and quantize ``input`` to int8.

        Args:
            input: 3-D or 4-D tensor whose last two dims are
                ``(head_num, head_dim)``.
            quant_scale: Tensor with ``head_num * head_dim`` elements that is
                multiplied into every row before quantization.
            norm_weight: Optional RMSNorm weight with ``head_dim`` elements.
                Only used when ``norm_mode == "rmsnorm"``.

        Returns:
            Tuple ``(quantized, scale)``: ``quantized`` is int8 with shape
            ``(rows, head_num * head_dim)`` and ``scale`` is float32 with
            shape ``(rows,)``, where ``rows`` is the product of the leading
            dims of ``input``.

        Raises:
            ValueError: On a bad input rank, mismatched ``quant_scale`` or
                ``norm_weight`` size, or when ``rows`` exceeds ``max_tokens``.
        """
        if input.dim() not in [3, 4]:
            raise ValueError(f"input must be 3-D or 4-D, but got dim={input.dim()}")

        head_num = input.shape[-2]
        head_dim = input.shape[-1]
        hidden_size = head_num * head_dim
        if quant_scale.numel() != hidden_size:
            raise ValueError(f"quant_scale numel must be {hidden_size}, but got {quant_scale.numel()}")
        if self.norm_mode == "rmsnorm" and norm_weight is not None and norm_weight.numel() != head_dim:
            raise ValueError(f"norm_weight numel must be {head_dim}, but got {norm_weight.numel()}")

        input_fp = input.float()
        if self.norm_mode == "rmsnorm":
            # Validation above only checks numel; reshape so a weight such as
            # (1, head_dim) matches the normalized shape rms_norm requires.
            weight = norm_weight.float().reshape(head_dim) if norm_weight is not None else None
            input_fp = F.rms_norm(input_fp, (head_dim,), weight=weight, eps=self.eps)

        rows = input_fp.numel() // hidden_size
        if self.max_tokens is not None and rows > self.max_tokens:
            raise ValueError(f"input token count {rows} exceeds max_tokens={self.max_tokens}")
        scaled = input_fp.reshape(rows, hidden_size) * quant_scale.float().reshape(1, hidden_size)
        # Per-row scale maps the row's max magnitude onto 127; the clamp keeps
        # all-zero rows from producing a zero divisor.
        scale = scaled.abs().amax(dim=-1).clamp(min=1e-12) / 127
        scale = self._team_max_scale(scale)
        quantized = torch.clamp(torch.round(scaled / scale.unsqueeze(-1)), -128, 127).to(torch.int8)

        return quantized, scale

    def extra_repr(self) -> str:
        """Return the configuration as ``name=value`` pairs for ``repr``."""
        # The ``=`` f-string specifier emits "self.name=value"; strip the
        # "self." prefix for readability.
        return (
            f"{self.team_size=}, {self.quant_mode=}, {self.norm_mode=}, {self.eps=}, {self.max_tokens=}"
        ).replace("self.", "")


# Public names exported by this module.
__all__ = ["MojoFusedAGScaleQuant"]
Loading