Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlx_lm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if not getattr(c, "supports_quantized_kv_cache", True):
continue
if hasattr(c, "to_quantized") and c.offset >= quantized_kv_start:
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)

Expand Down
7 changes: 7 additions & 0 deletions mlx_lm/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,13 @@ def nbytes(self):
return self.keys.nbytes + self.values.nbytes


class MLACache(KVCache):
supports_quantized_kv_cache = False

def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
raise NotImplementedError("MLA cache quantization is not supported")


class RotatingKVCache(_BaseCache):
step = 256

Expand Down
7 changes: 7 additions & 0 deletions mlx_lm/models/deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .activations import swiglu
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import MLACache
from .mla import MultiLinear
from .pipeline import PipelineMixin
from .rope_utils import initialize_rope
Expand Down Expand Up @@ -360,6 +361,9 @@ def __call__(

return self.norm(h)

def make_cache(self):
return [MLACache() for _ in self.pipeline_layers]


class Model(nn.Module):
def __init__(self, config: ModelArgs):
Expand All @@ -377,6 +381,9 @@ def __call__(
out = self.model(inputs, cache)
return self.lm_head(out)

def make_cache(self):
return self.model.make_cache()

def sanitize(self, weights):
def dequant(weight, scale_inv):
dtype = mx.bfloat16
Expand Down
4 changes: 2 additions & 2 deletions mlx_lm/models/deepseek_v32.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .activations import swiglu
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import CacheList, KVCache
from .cache import CacheList, KVCache, MLACache
from .mla import MultiLinear
from .rope_utils import initialize_rope
from .switch_layers import SwitchGLU
Expand Down Expand Up @@ -651,4 +651,4 @@ def predicate(k):
return predicate

def make_cache(self):
return [CacheList(KVCache(), KVCache()) for _ in self.layers]
return [CacheList(MLACache(), KVCache()) for _ in self.layers]
6 changes: 6 additions & 0 deletions mlx_lm/models/kimi_k25.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def __call__(
out = self.model(inputs, cache)
return self.lm_head(out)

def make_cache(self):
return self.model.make_cache()


class Model(nn.Module):
def __init__(self, config: ModelArgs):
Expand All @@ -53,6 +56,9 @@ def __call__(
):
return self.language_model(inputs, cache)

def make_cache(self):
return self.language_model.make_cache()

def sanitize(self, weights):
weights = tree_unflatten(list(weights.items()))
weights.pop("vision_tower", None)
Expand Down
100 changes: 99 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@

from mlx_lm.models import rope_utils
from mlx_lm.models.base import create_causal_mask, scaled_dot_product_attention
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
from mlx_lm.models.cache import (
KVCache,
MLACache,
QuantizedKVCache,
RotatingKVCache,
make_prompt_cache,
)
from mlx_lm.models.gated_delta import (
gated_delta_kernel,
gated_delta_ops,
Expand Down Expand Up @@ -1422,6 +1428,98 @@ def test_deepseek_v3(self):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)

def test_mla_models_make_mla_caches(self):
from mlx_lm.models import deepseek_v3, deepseek_v32, kimi_k25

v3_args = deepseek_v3.ModelArgs(
model_type="deepseek_v3",
vocab_size=64,
hidden_size=64,
intermediate_size=128,
moe_intermediate_size=16,
num_hidden_layers=1,
num_attention_heads=2,
num_key_value_heads=2,
n_routed_experts=None,
n_shared_experts=None,
kv_lora_rank=32,
q_lora_rank=32,
qk_rope_head_dim=32,
v_head_dim=32,
qk_nope_head_dim=32,
max_position_embeddings=64,
)
self.assertIsInstance(
make_prompt_cache(deepseek_v3.Model(v3_args))[0],
MLACache,
)
kimi = kimi_k25.Model(kimi_k25.ModelArgs(text_config=v3_args))
self.assertIsInstance(make_prompt_cache(kimi)[0], MLACache)

v32_args = deepseek_v32.ModelArgs(
model_type="deepseek_v32",
vocab_size=64,
hidden_size=64,
intermediate_size=128,
moe_intermediate_size=16,
num_hidden_layers=1,
num_attention_heads=2,
num_key_value_heads=2,
n_routed_experts=None,
n_shared_experts=None,
kv_lora_rank=32,
q_lora_rank=32,
qk_rope_head_dim=32,
v_head_dim=32,
qk_nope_head_dim=32,
index_head_dim=32,
index_n_heads=2,
max_position_embeddings=64,
)
v32_cache = make_prompt_cache(deepseek_v32.Model(v32_args))
self.assertIsInstance(v32_cache[0][0], MLACache)
self.assertIsInstance(v32_cache[0][1], KVCache)

def test_deepseek_v3_kv_bits_skips_mla_cache(self):
from mlx_lm.generate import generate_step
from mlx_lm.models import deepseek_v3

args = deepseek_v3.ModelArgs(
model_type="deepseek_v3",
vocab_size=64,
hidden_size=64,
intermediate_size=128,
moe_intermediate_size=16,
num_hidden_layers=1,
num_attention_heads=2,
num_key_value_heads=2,
n_routed_experts=None,
n_shared_experts=None,
kv_lora_rank=32,
q_lora_rank=32,
qk_rope_head_dim=32,
v_head_dim=32,
qk_nope_head_dim=32,
max_position_embeddings=64,
)
model = deepseek_v3.Model(args)
prompt_cache = make_prompt_cache(model)

self.assertIsInstance(prompt_cache[0], MLACache)
next(
generate_step(
mx.array([1, 2, 3]),
model,
max_tokens=1,
prompt_cache=prompt_cache,
prefill_step_size=2,
kv_bits=4,
kv_group_size=32,
)
)
self.assertIsInstance(prompt_cache[0], MLACache)
self.assertNotIsInstance(prompt_cache[0], QuantizedKVCache)

def test_gemma2(self):
from mlx_lm.models import gemma2

Expand Down