diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..91ab3f622 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -300,6 +300,8 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ if kv_bits is None: return for e, c in enumerate(prompt_cache): + if not getattr(c, "supports_quantized_kv_cache", True): + continue if hasattr(c, "to_quantized") and c.offset >= quantized_kv_start: prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index b84c9d650..8d6971480 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -407,6 +407,13 @@ def nbytes(self): return self.keys.nbytes + self.values.nbytes +class MLACache(KVCache): + supports_quantized_kv_cache = False + + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: + raise NotImplementedError("MLA cache quantization is not supported") + + class RotatingKVCache(_BaseCache): step = 256 diff --git a/mlx_lm/models/deepseek_v3.py b/mlx_lm/models/deepseek_v3.py index a2766e59e..1d39a8f36 100644 --- a/mlx_lm/models/deepseek_v3.py +++ b/mlx_lm/models/deepseek_v3.py @@ -11,6 +11,7 @@ from .activations import swiglu from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .cache import MLACache from .mla import MultiLinear from .pipeline import PipelineMixin from .rope_utils import initialize_rope @@ -360,6 +361,9 @@ def __call__( return self.norm(h) + def make_cache(self): + return [MLACache() for _ in self.pipeline_layers] + class Model(nn.Module): def __init__(self, config: ModelArgs): @@ -377,6 +381,9 @@ def __call__( out = self.model(inputs, cache) return self.lm_head(out) + def make_cache(self): + return self.model.make_cache() + def sanitize(self, weights): def dequant(weight, scale_inv): dtype = mx.bfloat16 diff --git a/mlx_lm/models/deepseek_v32.py b/mlx_lm/models/deepseek_v32.py index 7c97682e7..7a6dc7897 100644 --- a/mlx_lm/models/deepseek_v32.py +++ b/mlx_lm/models/deepseek_v32.py @@ -10,7 +10,7 @@ from .activations import swiglu from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from .cache import CacheList, KVCache +from .cache import CacheList, KVCache, MLACache from .mla import MultiLinear from .rope_utils import initialize_rope from .switch_layers import SwitchGLU @@ -651,4 +651,4 @@ def predicate(k): return predicate def make_cache(self): - return [CacheList(KVCache(), KVCache()) for _ in self.layers] + return [CacheList(MLACache(), KVCache()) for _ in self.layers] diff --git a/mlx_lm/models/kimi_k25.py b/mlx_lm/models/kimi_k25.py index 089830380..b573fb190 100644 --- a/mlx_lm/models/kimi_k25.py +++ b/mlx_lm/models/kimi_k25.py @@ -38,6 +38,9 @@ def __call__( out = self.model(inputs, cache) return self.lm_head(out) + def make_cache(self): + return self.model.make_cache() + class Model(nn.Module): def __init__(self, config: ModelArgs): @@ -53,6 +56,9 @@ def __call__( ): return self.language_model(inputs, cache) + def make_cache(self): + return self.language_model.make_cache() + def sanitize(self, weights): weights = tree_unflatten(list(weights.items())) weights.pop("vision_tower", None) diff --git a/tests/test_models.py b/tests/test_models.py index 6e1fcd96e..0469ad0ab 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,7 +9,13 @@ from mlx_lm.models import rope_utils from mlx_lm.models.base import create_causal_mask, scaled_dot_product_attention -from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache +from mlx_lm.models.cache import ( + KVCache, + MLACache, + QuantizedKVCache, + RotatingKVCache, + make_prompt_cache, +) from mlx_lm.models.gated_delta import ( gated_delta_kernel, gated_delta_ops, @@ -1422,6 +1428,98 @@ def test_deepseek_v3(self): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_mla_models_make_mla_caches(self): + from mlx_lm.models import deepseek_v3, deepseek_v32, kimi_k25 + + v3_args = deepseek_v3.ModelArgs( + model_type="deepseek_v3", + vocab_size=64, + hidden_size=64, + intermediate_size=128, + moe_intermediate_size=16, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + n_routed_experts=None, + n_shared_experts=None, + kv_lora_rank=32, + q_lora_rank=32, + qk_rope_head_dim=32, + v_head_dim=32, + qk_nope_head_dim=32, + max_position_embeddings=64, + ) + self.assertIsInstance( + make_prompt_cache(deepseek_v3.Model(v3_args))[0], + MLACache, + ) + kimi = kimi_k25.Model(kimi_k25.ModelArgs(text_config=v3_args)) + self.assertIsInstance(make_prompt_cache(kimi)[0], MLACache) + + v32_args = deepseek_v32.ModelArgs( + model_type="deepseek_v32", + vocab_size=64, + hidden_size=64, + intermediate_size=128, + moe_intermediate_size=16, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + n_routed_experts=None, + n_shared_experts=None, + kv_lora_rank=32, + q_lora_rank=32, + qk_rope_head_dim=32, + v_head_dim=32, + qk_nope_head_dim=32, + index_head_dim=32, + index_n_heads=2, + max_position_embeddings=64, + ) + v32_cache = make_prompt_cache(deepseek_v32.Model(v32_args)) + self.assertIsInstance(v32_cache[0][0], MLACache) + self.assertIsInstance(v32_cache[0][1], KVCache) + + def test_deepseek_v3_kv_bits_skips_mla_cache(self): + from mlx_lm.generate import generate_step + from mlx_lm.models import deepseek_v3 + + args = deepseek_v3.ModelArgs( + model_type="deepseek_v3", + vocab_size=64, + hidden_size=64, + intermediate_size=128, + moe_intermediate_size=16, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + n_routed_experts=None, + n_shared_experts=None, + kv_lora_rank=32, + q_lora_rank=32, + qk_rope_head_dim=32, + v_head_dim=32, + qk_nope_head_dim=32, + max_position_embeddings=64, + ) + model = deepseek_v3.Model(args) + prompt_cache = make_prompt_cache(model) + + self.assertIsInstance(prompt_cache[0], MLACache) + next( + generate_step( + mx.array([1, 2, 3]), + model, + max_tokens=1, + prompt_cache=prompt_cache, + prefill_step_size=2, + kv_bits=4, + kv_group_size=32, + ) + ) + self.assertIsInstance(prompt_cache[0], MLACache) + self.assertNotIsInstance(prompt_cache[0], QuantizedKVCache) + def test_gemma2(self): from mlx_lm.models import gemma2