fix(ssd-cache): spill native QuantizedKVCache + handle bfloat16 buffer #605

Merged
waybarrios merged 2 commits into waybarrios:main from CBribiescas:pr-602-ssd-quantized-spill
Jun 11, 2026

Conversation

@CBribiescas

Contributor

Recreates #602 (branch was deleted) on top of latest main.

Problem

With --kv-cache-quantization enabled, KV layers are QuantizedKVCache whose keys/values are tuples of (packed, scales, biases). The SSD-spill serializer didn't recognize this layout, so spilling a quantized KV cache failed. Separately, numpy's PEP 3118 buffer protocol can't read bfloat16, which broke serialization of bf16 tensors.
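To make the bfloat16 point concrete: bfloat16 is the upper 16 bits of an IEEE float32, and NumPy has no built-in dtype for it, so the data has to be cast before NumPy can read it. The sketch below uses only NumPy and raw bit patterns to compare two possible cast targets. The helper names are hypothetical and this is not the code in this PR; it only illustrates that float32 holds every bfloat16 value exactly, while float16 is smaller but has a narrower range.

```python
import numpy as np


def float32_to_bf16_bits(x):
    """Truncate float32 values to bfloat16 bit patterns (uint16).

    Uses truncation rather than round-to-nearest to keep the sketch short.
    """
    as_u32 = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (as_u32 >> 16).astype(np.uint16)


def bf16_bits_to_float32(bits):
    """Widen bfloat16 bit patterns to float32 (exact for every bf16 value)."""
    as_u32 = np.asarray(bits, dtype=np.uint16).astype(np.uint32)
    return (as_u32 << 16).view(np.float32)


bits = float32_to_bf16_bits([1.0, -2.5, 1.0e38])
as_f32 = bf16_bits_to_float32(bits)   # lossless widening
with np.errstate(over="ignore"):
    as_f16 = as_f32.astype(np.float16)  # large magnitudes overflow to inf
```

Which target the serializer should pick is a trade-off between spill size and range; the sketch takes no position on that.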

Fix

  • Detect quantized layers (_is_quantized_layer) and route them through a serializer that handles the (packed, scales, biases) tuple.
  • Cast bfloat16 to a numpy-readable dtype before serializing.
  • Test coverage in tests/test_ssd_cache.py.
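
A minimal sketch of the first bullet, under stated assumptions: NumPy arrays stand in for MLX arrays, the `PlainLayer` and `QuantizedLayer` classes are hypothetical stand-ins for the real cache classes, and the quantization is a simplified group-wise affine scheme with one integer code per element (no bit-packing). The actual `_is_quantized_layer` and the `mx.dequantize` call in the PR may differ in detail.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class PlainLayer:  # hypothetical stand-in for a non-quantized KV layer
    keys: np.ndarray
    values: np.ndarray


@dataclass
class QuantizedLayer:  # hypothetical stand-in for QuantizedKVCache
    keys: tuple  # (packed, scales, biases)
    values: tuple  # (packed, scales, biases)
    group_size: int = 4


def is_quantized_layer(layer) -> bool:
    """Duck-typed check: quantized layers hold 3-tuples instead of arrays."""
    return all(
        isinstance(t, tuple) and len(t) == 3 for t in (layer.keys, layer.values)
    )


def dequantize(packed, scales, biases, group_size):
    """Group-wise affine dequantization: x ~= code * scale + bias."""
    codes = packed.astype(np.float32)
    grouped = codes.reshape(*packed.shape[:-1], -1, group_size)
    out = grouped * scales[..., None] + biases[..., None]
    return out.reshape(packed.shape)


def layer_to_spillable(layer) -> dict:
    """Return plain arrays for either layer kind, ready for a serializer."""
    if is_quantized_layer(layer):
        g = layer.group_size
        return {
            "keys": dequantize(*layer.keys, g),
            "values": dequantize(*layer.values, g),
        }
    return {"keys": np.asarray(layer.keys), "values": np.asarray(layer.values)}
```

Routing on the tuple shape rather than on the class keeps the serializer working for any layer type that exposes the same `(packed, scales, biases)` layout.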

Supersedes the closed #602; rebased clean on current main.

CBribiescas and others added 2 commits June 10, 2026 12:13
When --kv-cache-quantization is enabled, KV layers are QuantizedKVCache whose
keys/values are tuples of (packed, scales, biases); the spill serializer didn't
recognize them. Also: numpy's PEP 3118 buffer protocol can't read bfloat16, so
cast before serialize. Recreates closed waybarrios#602 on top of latest main.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Single-line wrap on the `mx.dequantize(*layer.keys, ...)` call in the
native QuantizedKVCache spill path. CI lint was failing on the original
b51d8f4 push because the line exceeded black's 88-char limit.
@waybarrios merged commit b67edee into waybarrios:main on Jun 11, 2026
9 checks passed
@waybarrios

Owner

Merged, thanks. One follow-up worth a small PR: the dequant path casts bf16 to fp16 before spilling but doesn't record the original dtype, so restore hands back fp16 where the model computed bf16. The plain-cache path already solves this with metadata:

# existing pattern in _mx_to_numpy_safe
snapshot["keys_original_dtype"] = "bfloat16"  # restored on reload

Mirroring that for the dequantized layers would keep the round-trip dtype-faithful. While at it, SERIALIZER_SUPPORT_MATRIX could gain a "QuantizedKVCache": "supported_via_dequant_on_spill" entry so the diagnostics table matches reality.
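
A rough sketch of what that follow-up could look like, assuming a snapshot dict keyed like the `keys_original_dtype` example above. The `spill` and `restore` helpers are hypothetical, and because NumPy has no bfloat16, float32 stands in for the original dtype here. The support-matrix entry is copied from the suggestion; the real shape of `SERIALIZER_SUPPORT_MATRIX` is an assumption.

```python
import numpy as np

SPILL_DTYPE = np.float16  # dtype written to disk in this sketch

# Assumed to be a plain dict; the real table may be structured differently.
SERIALIZER_SUPPORT_MATRIX = {
    "QuantizedKVCache": "supported_via_dequant_on_spill",
}


def spill(name, array, snapshot):
    """Store the array in the spill dtype and remember what it was originally."""
    snapshot[f"{name}_original_dtype"] = str(array.dtype)
    snapshot[name] = array.astype(SPILL_DTYPE)


def restore(name, snapshot):
    """Cast back to the recorded dtype; pass through if none was recorded."""
    array = snapshot[name]
    original = snapshot.get(f"{name}_original_dtype")
    return array.astype(original) if original else array


snapshot = {}
spill("keys", np.array([1.5, 2.0], dtype=np.float32), snapshot)
restored = restore("keys", snapshot)
```

The restored dtype matches the original, but any precision lost in the spill dtype is not recovered; the metadata only keeps the round trip dtype-faithful.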
