Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 264 additions & 0 deletions mlx_lm/models/mellum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
# Copyright © 2026 Apple Inc.

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import mlx.core as mx
import mlx.nn as nn

from .activations import swiglu
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache
from .rope_utils import initialize_rope
from .switch_layers import SwitchGLU


@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration values consumed by the Mellum model classes in this file.

    The field order below is the dataclass constructor order, so it must not
    be rearranged.
    """

    # Stored on ``Model.model_type``.
    model_type: str
    # Width of the token embeddings, the RMSNorm layers and the inputs of the
    # attention and MoE projections.
    hidden_size: int
    # Number of ``MellumDecoderLayer`` instances built; also the layer range
    # walked when stacking expert weights in ``Model.sanitize``.
    num_hidden_layers: int
    # NOTE(review): not read by any code visible in this file.
    intermediate_size: int
    # Number of query heads.
    num_attention_heads: int
    # Number of experts: output width of the router gate and expert count
    # passed to ``SwitchGLU``.
    num_experts: int
    # How many experts are selected per token (top-k) by the router.
    num_experts_per_tok: int
    # Second argument of ``SwitchGLU`` (presumably the per-expert hidden
    # width; confirm against ``switch_layers``).
    moe_intermediate_size: int
    # ``eps`` for every ``nn.RMSNorm`` in the model.
    rms_norm_eps: float
    # Number of rows in the embedding table and outputs of ``lm_head``.
    vocab_size: int
    # Number of key/value heads.
    num_key_value_heads: int
    # Per-head dimension; also sets the attention scale ``head_dim ** -0.5``
    # and the RoPE dimension.
    head_dim: int
    # When true, logits come from ``embed_tokens.as_linear`` and no separate
    # ``lm_head`` is created.
    tie_word_embeddings: bool
    # Forwarded to ``initialize_rope`` when a rope scaling config is passed.
    max_position_embeddings: int
    # When true, the selected top-k router scores are renormalized to sum to 1.
    norm_topk_prob: bool
    # Window size for the sliding attention mask and ``max_size`` of the
    # rotating KV cache.
    sliding_window: int
    # One entry per layer. ``"full_attention"`` and ``"sliding_attention"``
    # are the values the code compares against.
    layer_types: List[str]
    # Maps a layer type to its rope settings; each entry must provide
    # ``"rope_theta"`` and may provide ``"rope_type"`` (default ``"default"``).
    rope_parameters: Dict[str, Any] = field(default_factory=dict)


def _rope_for(layer_type: str, args: ModelArgs):
    """Build the rotary position embedding used by one attention layer type.

    The settings are read from ``args.rope_parameters[layer_type]`` and handed
    to ``initialize_rope``.

    Args:
        layer_type: Key into ``args.rope_parameters`` (an entry of
            ``args.layer_types``).
        args: Model configuration; ``head_dim``, ``rope_parameters`` and
            ``max_position_embeddings`` are read.

    Returns:
        Whatever ``initialize_rope`` returns for the selected settings.

    Raises:
        KeyError: If ``layer_type`` has no entry in ``args.rope_parameters``
            or that entry has no ``"rope_theta"``.
    """
    params = args.rope_parameters[layer_type]
    base = params["rope_theta"]
    rope_type = params.get("rope_type", "default")

    # Only skip the scaling config when there is nothing to scale with.
    # Previously "linear" always took this path, which silently dropped a
    # configured ``factor``. Linear entries without a factor keep the old
    # unscaled behaviour.
    # NOTE(review): assumes ``initialize_rope`` applies ``factor`` for type
    # "linear"; confirm against ``rope_utils``.
    linear_without_factor = rope_type == "linear" and "factor" not in params
    if rope_type == "default" or linear_without_factor:
        return initialize_rope(args.head_dim, base=base, traditional=False)

    # Copy so the caller's config dict is not mutated when adding "type".
    scaling_config = dict(params)
    scaling_config["type"] = rope_type
    return initialize_rope(
        args.head_dim,
        base=base,
        traditional=False,
        scaling_config=scaling_config,
        max_position_embeddings=args.max_position_embeddings,
    )


class Attention(nn.Module):
    """Multi-head attention with per-head RMSNorm on queries/keys and RoPE.

    Args:
        args: Model configuration.
        layer_idx: Index of this layer; selects the entry of
            ``args.layer_types`` that decides which rope settings are used.
    """

    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()

        hidden = args.hidden_size
        head_dim = args.head_dim
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        self.scale = head_dim**-0.5

        q_width = self.n_heads * head_dim
        kv_width = self.n_kv_heads * head_dim
        self.q_proj = nn.Linear(hidden, q_width, bias=False)
        self.k_proj = nn.Linear(hidden, kv_width, bias=False)
        self.v_proj = nn.Linear(hidden, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, hidden, bias=False)

        # Normalization is applied over the per-head dimension.
        self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
        self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)

        self.rope = _rope_for(args.layer_types[layer_idx], args)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Apply attention to ``x`` (a rank-3 array) and project the result.

        Args:
            x: Input whose shape unpacks as (batch, sequence, features).
            mask: Optional attention mask forwarded to
                ``scaled_dot_product_attention``.
            cache: Optional KV cache; when given, its ``offset`` positions
                the rotary embedding and the new keys/values are appended
                through ``update_and_fetch``.

        Returns:
            The output of ``o_proj`` applied to the merged attention heads.
        """
        batch, seq_len, _ = x.shape

        def split_heads(t: mx.array, heads: int) -> mx.array:
            # (batch, seq, heads * head_dim) -> (batch, seq, heads, head_dim)
            return t.reshape(batch, seq_len, heads, -1)

        queries = self.q_proj(x)
        keys = self.k_proj(x)
        values = self.v_proj(x)

        # Normalize per head, then move the head axis ahead of the sequence.
        queries = self.q_norm(split_heads(queries, self.n_heads)).transpose(
            0, 2, 1, 3
        )
        keys = self.k_norm(split_heads(keys, self.n_kv_heads)).transpose(
            0, 2, 1, 3
        )
        values = split_heads(values, self.n_kv_heads).transpose(0, 2, 1, 3)

        if cache is None:
            queries = self.rope(queries)
            keys = self.rope(keys)
        else:
            # Rotate with the pre-update offset, then extend the cache.
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)

        attended = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
        merged = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.o_proj(merged)


class MellumSparseMoeBlock(nn.Module):
    """Top-k routed mixture-of-experts feed-forward block.

    A linear gate scores every expert, the ``top_k`` best are evaluated via
    ``SwitchGLU`` and their outputs are combined with the gate scores.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        hidden = args.hidden_size
        self.num_experts = args.num_experts
        self.top_k = args.num_experts_per_tok
        self.norm_topk_prob = args.norm_topk_prob

        self.gate = nn.Linear(hidden, self.num_experts, bias=False)
        self.switch_mlp = SwitchGLU(hidden, args.moe_intermediate_size, self.num_experts)

    def __call__(self, x: mx.array) -> mx.array:
        """Route ``x`` through the selected experts and return the weighted sum."""
        probs = mx.softmax(self.gate(x), axis=-1, precise=True)

        top_k = self.top_k
        # The last ``top_k`` slots after partitioning hold the largest scores.
        expert_ids = mx.argpartition(probs, kth=-top_k, axis=-1)[..., -top_k:]
        weights = mx.take_along_axis(probs, expert_ids, axis=-1)
        if self.norm_topk_prob:
            weights = weights / mx.sum(weights, axis=-1, keepdims=True)

        expert_out = self.switch_mlp(x, expert_ids)
        # Weight each expert's output and collapse the expert axis.
        return (expert_out * weights[..., None]).sum(axis=-2)


class MellumDecoderLayer(nn.Module):
    """One pre-norm decoder layer: attention and MoE, each with a residual."""

    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        eps = args.rms_norm_eps
        self.self_attn = Attention(args, layer_idx)
        self.mlp = MellumSparseMoeBlock(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run the layer on ``x``; ``mask`` and ``cache`` go to the attention."""
        attn_out = self.self_attn(self.input_layernorm(x), mask, cache)
        hidden = x + attn_out
        moe_out = self.mlp(self.post_attention_layernorm(hidden))
        return hidden + moe_out


class MellumModel(nn.Module):
    """Embedding, decoder stack and final norm of the Mellum model.

    Layers listed as ``"full_attention"`` in ``args.layer_types`` receive the
    full causal mask; every other layer receives the sliding-window mask.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            MellumDecoderLayer(args=args, layer_idx=i)
            for i in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

        def first_index(kind: str) -> Optional[int]:
            # ``None`` when no layer has this type. The full-attention lookup
            # used to have no default and raised StopIteration for configs
            # without any full-attention layer.
            return next(
                (i for i, t in enumerate(args.layer_types) if t == kind),
                None,
            )

        # Index of a representative cache entry per mask kind; the mask for
        # each kind is built once from that entry and shared by all layers
        # of the kind.
        self._first_full = first_index("full_attention")
        self._first_sliding = first_index("sliding_attention")

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        input_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        """Return the normalized hidden states for ``inputs``.

        Args:
            inputs: Token ids, embedded with ``embed_tokens`` unless
                ``input_embeddings`` is given.
            cache: Optional list with one cache entry per layer.
            input_embeddings: Optional precomputed embeddings used instead
                of looking up ``inputs``.

        Returns:
            Output of the final RMSNorm over the last layer's hidden states.
        """
        if input_embeddings is not None:
            h = input_embeddings
        else:
            h = self.embed_tokens(inputs)

        if cache is None:
            cache = [None] * len(self.layers)

        full_mask = None
        if self._first_full is not None:
            full_mask = create_attention_mask(h, cache[self._first_full])

        sliding_mask = None
        if self._first_sliding is not None:
            sliding_mask = create_attention_mask(
                h, cache[self._first_sliding], window_size=self.args.sliding_window
            )

        for layer, c, t in zip(self.layers, cache, self.args.layer_types):
            mask = full_mask if t == "full_attention" else sliding_mask
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    """Mellum language model: decoder stack plus output projection.

    Also provides the weight-sanitizing, quantization-predicate and
    cache-construction hooks defined below.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = MellumModel(args)
        # A separate output head exists only when embeddings are not tied.
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        input_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        """Return logits for ``inputs`` (or for ``input_embeddings`` if given)."""
        hidden = self.model(inputs, cache, input_embeddings)
        if not self.args.tie_word_embeddings:
            return self.lm_head(hidden)
        return self.model.embed_tokens.as_linear(hidden)

    def sanitize(self, weights):
        """Convert per-expert weights into the stacked ``switch_mlp`` layout.

        Drops ``lm_head.weight`` when embeddings are tied. If the weights
        contain per-expert entries, each projection of each layer is stacked
        along a new leading axis in expert order. ``weights`` is modified in
        place and returned.
        """
        if self.args.tie_word_embeddings:
            weights.pop("lm_head.weight", None)

        # Already-converted checkpoints have no per-expert keys.
        if "model.layers.0.mlp.experts.0.up_proj.weight" in weights:
            expert_count = self.args.num_experts
            for layer_idx in range(self.args.num_hidden_layers):
                prefix = f"model.layers.{layer_idx}"
                for n in ("up_proj", "down_proj", "gate_proj"):
                    if f"{prefix}.mlp.experts.0.{n}.weight" not in weights:
                        continue
                    per_expert = [
                        weights.pop(f"{prefix}.mlp.experts.{e}.{n}.weight")
                        for e in range(expert_count)
                    ]
                    weights[f"{prefix}.mlp.switch_mlp.{n}.weight"] = mx.stack(
                        per_expert
                    )
        return weights

    @property
    def quant_predicate(self):
        """Predicate giving router gates 8-bit, group-size-64 quantization.

        Every other path gets ``True``.
        """

        def predicate(path, _):
            if not path.endswith("mlp.gate"):
                return True
            return {"group_size": 64, "bits": 8}

        return predicate

    @property
    def layers(self):
        """The decoder layers of the wrapped ``MellumModel``."""
        return self.model.layers

    def make_cache(self):
        """Create one cache per entry of ``layer_types``.

        Full-attention layers get a ``KVCache``; all others get a
        ``RotatingKVCache`` bounded by ``sliding_window``.
        """
        return [
            KVCache()
            if layer_type == "full_attention"
            else RotatingKVCache(max_size=self.args.sliding_window)
            for layer_type in self.args.layer_types
        ]
Loading