diff --git a/mlx_lm/models/cohere2_moe.py b/mlx_lm/models/cohere2_moe.py new file mode 100644 index 000000000..3850c6247 --- /dev/null +++ b/mlx_lm/models/cohere2_moe.py @@ -0,0 +1,350 @@ +# Copyright © 2026 Apple Inc. + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .cache import KVCache, RotatingKVCache +from .switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int = 1024 + head_dim: int = 128 + num_hidden_layers: int = 36 + intermediate_size: int = 1024 + num_attention_heads: int = 64 + num_key_value_heads: int = 8 + rope_theta: float = 50000.0 + vocab_size: int = 256000 + layer_norm_eps: float = 1e-05 + logit_scale: float = 0.0625 + attention_bias: bool = False + layer_norm_bias: bool = False + sliding_window: int = 4096 + sliding_window_pattern: int = 4 + num_experts: int = 128 + num_experts_per_tok: int = 8 + norm_topk_prob: bool = True + num_shared_experts: Optional[int] = None + moe_num_shared_experts: int = 4 + moe_gate_act: str = "sigmoid" + expert_selection_fn: Optional[str] = None + shared_expert_combination_strategy: str = "average" + rms_norm_eps: Optional[float] = None + first_k_dense_replace: int = 0 + prefix_dense_intermediate_size: Optional[int] = None + prefix_dense_sliding_window_pattern: int = 1 + layer_types: Optional[List[str]] = None + + def __post_init__(self): + if self.num_shared_experts is not None: + self.moe_num_shared_experts = self.num_shared_experts + if self.expert_selection_fn is not None: + self.moe_gate_act = self.expert_selection_fn + if self.prefix_dense_intermediate_size is None: + self.prefix_dense_intermediate_size = self.intermediate_size + + +def is_prefix_dense_layer(args: ModelArgs, layer_idx: int): + return layer_idx < args.first_k_dense_replace + + +def is_sliding_layer(args: ModelArgs, layer_idx: int): + if is_prefix_dense_layer(args, layer_idx): + return False + if args.layer_types is not None: + return args.layer_types[layer_idx] == "sliding_attention" + return (layer_idx + 1) % args.sliding_window_pattern != 0 + + +def norm_layer(args: ModelArgs): + if args.rms_norm_eps is not None: + return nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + return nn.LayerNorm( + args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.args = args + self.layer_idx = layer_idx + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.head_dim = head_dim = args.head_dim + self.scale = head_dim**-0.5 + + attetion_bias = args.attention_bias + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias) + + self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) + + self.use_sliding_window = is_sliding_layer(args, layer_idx) + self.force_rope = ( + is_prefix_dense_layer(args, layer_idx) + and args.prefix_dense_sliding_window_pattern == 1 + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + # Cohere2Moe applies RoPE to sliding layers and optionally to prefix + # dense full-attention layers. + if self.use_sliding_window or self.force_rope: + if cache is None: + queries = self.rope(queries) + keys = self.rope(keys) + else: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + sdpa_type = mx.float32 if queries.dtype == mx.float16 else queries.dtype + output = scaled_dot_product_attention( + queries.astype(sdpa_type), + keys, + values, + cache=cache, + scale=self.scale, + mask=mask, + ).astype(queries.dtype) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + + def __call__(self, x): + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class CohereMoeSparseMoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + intermediate_size = args.intermediate_size + + self.num_experts = num_experts = args.num_experts + self.top_k = args.num_experts_per_tok + self.norm_topk_prob = args.norm_topk_prob + + self.gate = nn.Linear(dim, num_experts, bias=False) + self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) + + if getattr(args, "moe_num_shared_experts", 0) > 0: + shared_intermediate_size = ( + args.intermediate_size * args.moe_num_shared_experts + ) + self.shared_experts = MLP( + args.hidden_size, shared_intermediate_size, + ) + self.shared_expert_combination_strategy = \ + args.shared_expert_combination_strategy + assert self.shared_expert_combination_strategy in [ + "average", "sum" + ], "shared_expert_combination_strategy " + "must be one of ['average', 'sum']" + else: + self.shared_experts = None + self.shared_expert_combination_strategy = None + + if args.moe_gate_act == "softmax": + self.gate_act = nn.Softmax() + elif args.moe_gate_act == "sigmoid": + self.gate_act = nn.Sigmoid() + else: + raise ValueError(f"{args.moe_gate_act} is not supported.") + + def __call__( + self, + x: mx.array, + ): + gates = self.gate(x) + gates = self.gate_act(gates.astype(mx.float32)) + + k = self.top_k + inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) + scores = mx.take_along_axis(gates, inds, axis=-1) + if self.norm_topk_prob: + scores = scores / mx.maximum(scores.sum(axis=-1, keepdims=True), 1e-12) + + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype) + + if self.shared_experts is not None: + if self.shared_expert_combination_strategy == "average": + y = (y + self.shared_experts(x)) / 2 + else: + y = y + self.shared_experts(x) + + return y + + +class CohereMoEDecoderLayer(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.hidden_size = args.hidden_size + self.n_heads = args.num_attention_heads + + self.self_attn = Attention(args, layer_idx) + self.mlp = ( + MLP(args.hidden_size, args.prefix_dense_intermediate_size) + if is_prefix_dense_layer(args, layer_idx) + else CohereMoeSparseMoeBlock(args) + ) + self.input_layernorm = norm_layer(args) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + + h = self.input_layernorm(x) + attn_h = self.self_attn(h, mask, cache) + ff_h = self.mlp(h) + + return attn_h + ff_h + x + + +class CohereModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.window_size = args.sliding_window + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + CohereMoEDecoderLayer(args=args, layer_idx=i) + for i in range(args.num_hidden_layers) + ] + self.norm = norm_layer(args) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + mask = create_attention_mask( + h, + c, + window_size=( + self.window_size if layer.self_attn.use_sliding_window else None + ), + ) + + h = layer(h, mask, c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model_type = args.model_type + self.model = CohereModel(args) + self.args = args + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + out = self.model.embed_tokens.as_linear(out) + out = out * self.model.args.logit_scale + return out + + def make_cache(self): + caches = [] + for i in range(self.args.num_hidden_layers): + if is_sliding_layer(self.args, i): + caches.append( + RotatingKVCache(max_size=self.args.sliding_window, keep=0) + ) + else: + caches.append(KVCache()) + return caches + + def sanitize(self, weights): + for l in range(self.args.num_hidden_layers): + if is_prefix_dense_layer(self.args, l): + continue + prefix = f"model.layers.{l}" + for n in ["up_proj", "down_proj", "gate_proj"]: + for k in ["weight", "scales", "biases"]: + if f"{prefix}.mlp.experts.0.{n}.{k}" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) + + for key in list(weights.keys()): + if "rotary_emb.inv_freq" in key: + weights.pop(key) + elif key.endswith(".bias"): + if ".mlp." in key: + weights.pop(key) + elif ".self_attn." in key and not self.args.attention_bias: + weights.pop(key) + elif "layernorm" in key.lower() and not self.args.layer_norm_bias: + weights.pop(key) + + return weights + + @property + def quant_predicate(self): + def predicate(path, module): + if ".self_attn." in path: + return False + if ".mlp.gate" in path and "gate_proj" not in path: + return False + return True + + return predicate + + @property + def layers(self): + return self.model.layers diff --git a/mlx_lm/utils.py b/mlx_lm/utils.py index ef3d266b9..a3d03579f 100644 --- a/mlx_lm/utils.py +++ b/mlx_lm/utils.py @@ -172,6 +172,169 @@ def _transform_awq_weights( return new_weights, mlx_quantization +def _e4m3_decode_table() -> mx.array: + """Return a 256-entry ``float32`` LUT mapping every E4M3FN byte to its value. + + OCP ``E4M3FN``: 1 sign / 4 exponent (bias 7) / 3 mantissa bits, no infinities, and + ``(exp=15, mant=7)`` reserved for NaN (max finite magnitude is therefore 448). This + matches the byte convention MLX uses for ``nvfp4`` scales (verified empirically). + """ + table = [] + for b in range(256): + s = (b >> 7) & 1 + e = (b >> 3) & 0xF + m = b & 0x7 + if e == 0: + val = (m / 8.0) * 2.0**-6 # subnormal + elif e == 15 and m == 7: + val = float("nan") + else: + val = (1.0 + m / 8.0) * 2.0 ** (e - 7) + table.append(-val if s else val) + return mx.array(table, dtype=mx.float32) + + +# Built once; reused by every NVFP4 fold. +_E4M3_DECODE_LUT = _e4m3_decode_table() + + +def _f32_to_e4m3(x: mx.array) -> mx.array: + """Encode non-negative ``float32`` values to ``E4M3FN`` bytes (round-to-nearest-even). + + Pure-MLX bit manipulation (MLX exposes no float8 dtype). Saturates to 448 on + overflow and flushes to the subnormal grid / zero on underflow. Inputs are assumed + ``>= 0`` (NVFP4 group scales are magnitudes), so the sign bit is always 0. + """ + x = mx.maximum(x.astype(mx.float32), 0.0) + bits = x.view(mx.uint32) + fexp = (bits >> 23) & 0xFF # fp32 exponent, bias 127 + fman = bits & 0x7FFFFF # fp32 mantissa, 23 bits + + # Normal path: target E4M3 biased exponent e = (fexp - 127) + 7. + e = fexp.astype(mx.int32) - 120 + drop = 20 # 23 -> 3 mantissa bits + round_bit = (fman >> (drop - 1)) & 1 + sticky = (fman & ((1 << (drop - 1)) - 1)) != 0 + mant3 = fman >> drop + roundup = round_bit & (sticky.astype(mx.uint32) | (mant3 & 1)) + mant3 = mant3 + roundup + carry = mant3 >> 3 # mantissa overflowed past 7 -> bump exponent + mant3 = mant3 & 0x7 + e = e + carry.astype(mx.int32) + # Saturate: e > 15, or the NaN slot (e == 15, mant == 7), clamps to 448 (e15 m6). + over = (e > 15) | ((e == 15) & (mant3 == 7)) + e = mx.where(over, mx.array(15, mx.int32), e) + mant3 = mx.where(over, mx.array(6, mx.uint32), mant3) + normal_byte = (e.astype(mx.uint32) << 3) | mant3 + normal_valid = e >= 1 + + # Subnormal path: value = m * 2^-9, so m = round(x * 512) (RNE). m == 8 lands + # exactly on the smallest normal (0x08 = e1 m0 = 2^-6), so it is encoded correctly. + sub = x * 512.0 + sub_floor = mx.floor(sub) + frac = sub - sub_floor + sf = sub_floor.astype(mx.uint32) + up = (frac > 0.5) | ((frac == 0.5) & ((sf & 1) == 1)) + sub_byte = sf + up.astype(mx.uint32) + + byte = mx.where(normal_valid, normal_byte, sub_byte) + return byte.astype(mx.uint8) + + +def _transform_compressed_tensors_nvfp4_weights( + weights: Dict[str, mx.array], + quantization_config: Dict[str, Any], +) -> Dict[str, mx.array]: + """Fold compressed-tensors NVFP4 weights into MLX-native ``nvfp4`` weights. + + A ``nvfp4-pack-quantized`` checkpoint stores, per quantized Linear: + + - ``

.weight_packed`` ``uint8`` ``[out, in // 2]`` (2x E2M1 per byte) + - ``

.weight_scale`` ``uint8`` ``[out, in // 16]`` (E4M3 per group of 16, + loaded by ``mx.load`` as raw bytes -- the same byte layout MLX uses for nvfp4 scales) + - ``

.weight_global_scale`` ``float32`` ``[1]`` (per-tensor; the real weight is + ``fp4 * weight_scale / weight_global_scale``) + + MLX ``nvfp4`` ``QuantizedLinear`` expects ``

.weight`` (``uint32``) plus + ``

.scales`` (``uint8`` E4M3) and is single-level: the per-tensor global scale is + not representable (and is rejected on the Metal backend). Both decodes are linear in + the FP4 codes, so the global scale can be folded directly into the per-group E4M3 + scales: ``scale_mlx = E4M3(weight_scale / global_scale)``. We keep the original packed + E2M1 codes bit-exact (only the per-group scales are re-encoded once), avoiding the + weight dequantize/re-quantize round-trip entirely. + """ + packed_suffix = ".weight_packed" + + new_weights = {} + for key in list(weights.keys()): + if key.endswith(packed_suffix): + prefix = key[: -len(packed_suffix)] + packed = weights[key] + scale = weights[f"{prefix}.weight_scale"] + global_scale = weights[f"{prefix}.weight_global_scale"].astype(mx.float32) + + # weight_packed is uint8 [out, in//2]; reinterpret as uint32 [out, in//8] + # to match MLX's nvfp4 weight layout (bit-identical, no data movement). + new_weights[f"{prefix}.weight"] = packed.view(mx.uint32) + + # Fold the per-tensor global scale into the per-group E4M3 scales: + # decode E4M3 -> divide by global scale -> re-encode E4M3. The FP4 codes are + # untouched, so this only re-rounds the (much smaller) scale tensor once. + decoded = _E4M3_DECODE_LUT[scale.astype(mx.uint32)] + new_weights[f"{prefix}.scales"] = _f32_to_e4m3(decoded / global_scale) + elif key.endswith(".weight_scale") or key.endswith(".weight_global_scale"): + # Consumed alongside their ``.weight_packed``. + continue + else: + new_weights[key] = weights[key] + + return new_weights + + +def _transform_compressed_tensors_int4_weights( + weights: Dict[str, mx.array], + quantization_config: Dict[str, Any], +) -> Dict[str, mx.array]: + """Remap compressed-tensors INT4 ``pack-quantized`` weights to MLX affine weights. + + A symmetric int4 ``pack-quantized`` checkpoint stores, per quantized Linear: + + - ``

.weight_packed`` ``int32`` ``[out, in // 8]`` (8x int4 per word, LSB-first) + - ``

.weight_scale`` (``bf16``/``float``) ``[out, in // group_size]`` + - ``

.weight_shape`` ``int64`` ``[2]`` (unused by MLX) + + MLX affine ``QuantizedLinear`` uses the same int4 packing and dequantizes as + ``w * scale + bias``. Symmetric int4 stores values in ``[0, 15]`` representing + ``[-8, 7]``, i.e. ``value = packed - 8``, so we set ``bias = -8 * scale``. The packed + ``int32`` is bit-identical to MLX's ``uint32`` layout (reinterpreted via ``view``). + This is the model-agnostic version of the int4 remap in ``deepseek_v3.sanitize``. + """ + weights_cfg = ( + quantization_config.get("config_groups", {}) + .get("group_0", {}) + .get("weights", {}) + ) + group_size = weights_cfg.get("group_size", 32) + bits = weights_cfg.get("num_bits", 4) + packed_suffix = ".weight_packed" + + new_weights = {} + for key in list(weights.keys()): + if key.endswith(packed_suffix): + prefix = key[: -len(packed_suffix)] + scale = weights[f"{prefix}.weight_scale"] + new_weights[f"{prefix}.weight"] = weights[key].view(mx.uint32) + new_weights[f"{prefix}.scales"] = scale + new_weights[f"{prefix}.biases"] = -(2 ** (bits - 1)) * scale + elif key.endswith(".weight_scale") or key.endswith(".weight_shape"): + # Consumed alongside their ``.weight_packed`` (shape is unused by MLX). + continue + else: + new_weights[key] = weights[key] + + return new_weights, {"group_size": group_size, "bits": bits, "mode": "affine"} + + def _get_classes(config: dict): """ Retrieve the model and model args classes based on the configuration. @@ -342,6 +505,31 @@ def load_model( model = model_class(model_args) + # Transform compressed-tensors weights into MLX-native form before sanitize() so that + # per-expert (MoE) tensors are renamed to .weight/.scales(/.biases) and then stacked + # correctly. Skip if a model already remapped them itself (no .weight_packed left). + if ( + (qc := config.get("quantization_config")) + and isinstance(qc, dict) + and qc.get("quant_method") == "compressed-tensors" + and any(k.endswith(".weight_packed") for k in weights) + ): + ct_format = qc.get("format") + if ct_format == "nvfp4-pack-quantized": + weights = _transform_compressed_tensors_nvfp4_weights(weights, qc) + quantization = {"group_size": 16, "bits": 4, "mode": "nvfp4"} + elif ct_format == "pack-quantized": + weights, quantization = _transform_compressed_tensors_int4_weights( + weights, qc + ) + else: + raise ValueError( + f"Unsupported compressed-tensors format: {ct_format!r}. " + "Supported: 'nvfp4-pack-quantized', 'pack-quantized'." + ) + config["quantization"] = quantization + config["quantization_config"] = quantization + if hasattr(model, "sanitize"): weights = model.sanitize(weights) diff --git a/tests/test_models.py b/tests/test_models.py index 6e1fcd96e..75d74311b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1129,6 +1129,32 @@ def test_cohere(self): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_cohere2_moe(self): + from mlx_lm.models import cohere2_moe + + args = cohere2_moe.ModelArgs( + model_type="cohere2_moe", + hidden_size=64, + head_dim=16, + num_hidden_layers=4, + intermediate_size=128, + num_attention_heads=4, + num_key_value_heads=2, + vocab_size=1000, + sliding_window=4, + sliding_window_pattern=3, + num_experts=4, + num_experts_per_tok=2, + moe_num_shared_experts=0, + first_k_dense_replace=1, + prefix_dense_intermediate_size=96, + rms_norm_eps=1e-5, + ) + model = cohere2_moe.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + def test_dbrx(self): from mlx_lm.models import dbrx