diff --git a/mlx_lm/models/mellum.py b/mlx_lm/models/mellum.py new file mode 100644 index 000000000..c80736aa7 --- /dev/null +++ b/mlx_lm/models/mellum.py @@ -0,0 +1,264 @@ +# Copyright © 2026 Apple Inc. + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +import mlx.core as mx +import mlx.nn as nn + +from .activations import swiglu +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .cache import KVCache, RotatingKVCache +from .rope_utils import initialize_rope +from .switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + num_experts: int + num_experts_per_tok: int + moe_intermediate_size: int + rms_norm_eps: float + vocab_size: int + num_key_value_heads: int + head_dim: int + tie_word_embeddings: bool + max_position_embeddings: int + norm_topk_prob: bool + sliding_window: int + layer_types: List[str] + rope_parameters: Dict[str, Any] = field(default_factory=dict) + + +def _rope_for(layer_type: str, args: ModelArgs): + params = args.rope_parameters[layer_type] + base = params["rope_theta"] + rope_type = params.get("rope_type", "default") + if rope_type in ("default", "linear"): + return initialize_rope(args.head_dim, base=base, traditional=False) + scaling_config = dict(params) + scaling_config["type"] = rope_type + return initialize_rope( + args.head_dim, + base=base, + traditional=False, + scaling_config=scaling_config, + max_position_embeddings=args.max_position_embeddings, + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + head_dim = args.head_dim + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) + self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps) + + self.rope = _rope_for(args.layer_types[layer_idx], args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + queries = self.q_norm(queries.reshape(B, L, self.n_heads, -1)).transpose( + 0, 2, 1, 3 + ) + keys = self.k_norm(keys.reshape(B, L, self.n_kv_heads, -1)).transpose( + 0, 2, 1, 3 + ) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MellumSparseMoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + self.num_experts = args.num_experts + self.top_k = args.num_experts_per_tok + self.norm_topk_prob = args.norm_topk_prob + + self.gate = nn.Linear(dim, self.num_experts, bias=False) + self.switch_mlp = SwitchGLU(dim, args.moe_intermediate_size, self.num_experts) + + def __call__(self, x: mx.array) -> mx.array: + gates = self.gate(x) + gates = mx.softmax(gates, axis=-1, precise=True) + + k = self.top_k + inds = mx.argpartition(gates, kth=-k, axis=-1)[..., -k:] + scores = mx.take_along_axis(gates, inds, axis=-1) + if self.norm_topk_prob: + scores /= mx.sum(scores, axis=-1, keepdims=True) + + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) + return y + + +class MellumDecoderLayer(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.self_attn = Attention(args, layer_idx) + self.mlp = MellumSparseMoeBlock(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + return h + r + + +class MellumModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + MellumDecoderLayer(args=args, layer_idx=i) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + self._first_full = next( + i for i, t in enumerate(args.layer_types) if t == "full_attention" + ) + self._first_sliding = next( + (i for i, t in enumerate(args.layer_types) if t == "sliding_attention"), + None, + ) + + def __call__( + self, + inputs: mx.array, + cache=None, + input_embeddings: Optional[mx.array] = None, + ) -> mx.array: + if input_embeddings is not None: + h = input_embeddings + else: + h = self.embed_tokens(inputs) + + if cache is None: + cache = [None] * len(self.layers) + + full_mask = create_attention_mask(h, cache[self._first_full]) + if self._first_sliding is not None: + sliding_mask = create_attention_mask( + h, cache[self._first_sliding], window_size=self.args.sliding_window + ) + else: + sliding_mask = None + + for layer, c, t in zip(self.layers, cache, self.args.layer_types): + mask = full_mask if t == "full_attention" else sliding_mask + h = layer(h, mask, c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = MellumModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + input_embeddings: Optional[mx.array] = None, + ) -> mx.array: + out = self.model(inputs, cache, input_embeddings) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + if self.args.tie_word_embeddings: + weights.pop("lm_head.weight", None) + if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: + return weights + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n in ["up_proj", "down_proj", "gate_proj"]: + if f"{prefix}.mlp.experts.0.{n}.weight" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.weight") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{n}.weight"] = mx.stack(to_join) + return weights + + @property + def quant_predicate(self): + def predicate(path, _): + if path.endswith("mlp.gate"): + return {"group_size": 64, "bits": 8} + return True + + return predicate + + @property + def layers(self): + return self.model.layers + + def make_cache(self): + caches = [] + for t in self.args.layer_types: + if t == "full_attention": + caches.append(KVCache()) + else: + caches.append(RotatingKVCache(max_size=self.args.sliding_window)) + return caches