Skip to content

fix(ssd-cache): unblock spill for native QuantizedKVCache + bfloat16#602

Closed
CBribiescas wants to merge 1 commit into
waybarrios:mainfrom
CBribiescas:fix/ssd-spill-dequant-and-bf16-cast
Closed

fix(ssd-cache): unblock spill for native QuantizedKVCache + bfloat16#602
CBribiescas wants to merge 1 commit into
waybarrios:mainfrom
CBribiescas:fix/ssd-spill-dequant-and-bf16-cast

Conversation

@CBribiescas

Copy link
Copy Markdown
Contributor

Note: just back from vacation — should be responsive on review feedback or follow-ups.

Summary

With --continuous-batching --kv-cache-quantization --ssd-cache-dir ... set,
every eviction in the running server silently fails to spill: the writer
thread crashes on np.array(layer.keys) with a misleading PEP 3118 error,
the writer loop catches the exception, and the SSD cache dir stays empty
forever. RAM evictions still happen, so KV gets discarded with no disk
fallback — defeating the whole point of the SSD cold tier.

The existing 4 SSD-cache tests pass plain KVCache mocks (MockKVCacheLayer),
which is why the bug was invisible to CI. Adding real-array roundtrip tests
exposes it immediately.

Two distinct issues fixed in SSDCacheTier.enqueue_spill (the single
chokepoint between in-process eviction and the async disk writer):

1. Quantized-layer detection was wrapper-only

Pre-patch:

if any(isinstance(layer, _QuantizedCacheWrapper) for layer in cache):
    cache = _dequantize_cache(cache)

But under --kv-cache-quantization, models can produce mlx-lm's native
QuantizedKVCache directly — whose .keys / .values are
(packed_uint32, scales_fp16, biases_fp16) tuples. Those slipped through
the gate, reached the writer thread, and crashed with
ValueError: setting an array element with a sequence ... inhomogeneous shape
when np.array(layer.keys) tried to stack three differently-shaped arrays.

Post-patch: detect any layer whose .keys is a tuple/list (covers both
wrapper and native), then dequantize manually using mlx-lm's
QuantizedKVCache(group_size, bits) parameters into a fresh KVCache-shaped
duck-type for the serializer.

2. bfloat16 is unrepresentable to numpy via PEP 3118

mx.dequantize(...) returns the original cache's dtype — for many models
(Qwen-3-Coder, gpt-oss-120B, etc.) that's bfloat16. Numpy's buffer protocol
sees bf16 as format=B (uint8) item_size=2 but its dtype-B table says
item_size=1, raising:

RuntimeError: Item size 2 for PEP 3118 buffer format string B does not
match the dtype B item size 1.

at np.array(layer.keys).

Post-patch: cast bf16 → float16 on the caller thread before queueing.
KV values sit well within fp16's ±65504 range, fp16 actually has more
mantissa bits than bf16 (10 vs 7), and byte size is identical. Then
mx.eval forces host-side materialization so the writer thread does only
a buffer copy (no GPU stream context — that's a separate no Stream(gpu,N) in current thread crash class, already handled by keeping dequant on the
caller).

3. Test cleanup (incidental)

4 pre-existing SSD test failures had nothing to do with this patch — they
created cache entries with 10-token sequences but the default
min_prefix_tokens=128 silently rejected them in store(), leaving the
cache empty. Tests then asserted len(cache) == 3 against 0.
Fixed by adding min_prefix_tokens=1 to the 6 MemoryCacheConfig(...)
calls in test_ssd_cache.py.

Type of change

  • Bug fix

Surface touched

  • KV cache / prefix cache
  • Tests only (for the min_prefix_tokens test fix)

Test plan

```bash
pytest tests/test_ssd_cache.py -q
```

Three new tests in TestQuantizedSpillRoundtrip exercise the real
spill→reload roundtrip end-to-end (no mocks), covering each bug surface:

  • test_wrapper_quantized_layer_round_trips_QuantizedCacheWrapper regression
  • test_native_quantized_kv_cache_round_trips — mlx-lm native QuantizedKVCache
  • test_bfloat16_layer_round_trips_via_fp16_cast — asserts persisted dtype is fp16

Test results

```
tests/test_ssd_cache.py ............................................... (61 passed)
============================== 61 passed in 2.71s ==============================
```

Before this PR: 54 passed / 4 failed on tests/test_ssd_cache.py
(the 4 broken-by-min_prefix_tokens tests were pre-existing). After: 61 passed.

Production validation

Validated against live Qwen-3-Coder-30B-A3B-Instruct-MLX-4bit under
--continuous-batching --ssd-cache-dir ... --kv-cache-quantization:

Before patch: every eviction logged:
```
ERROR:vllm_mlx.ssd_cache:[ssd_cache] failed to write entry (4894 tokens)
Traceback (most recent call last):
File ".../vllm_mlx/ssd_cache.py", line 650, in _writer_loop
File ".../vllm_mlx/ssd_cache.py", line 719, in _write_entry
File ".../vllm_mlx/ssd_cache.py", line 465, in serialize_layer
keys_np = np.array(layer.keys)
RuntimeError: Item size 2 for PEP 3118 buffer format string B does not match the dtype B item size 1.
```
SSD cache dir stayed at 0 B despite 60+ evictions.

After patch:

  • [ssd_cache] dequantized 48 layers before spill (1705 tokens) (info log on each spill)
  • 7+ entries spilled, ~600 MB safetensors landing on disk
  • Cache HIT on a prefix-extending follow-up query: cached=1705 remaining=19
    (only 19 new tokens needed prefill; 95% reuse)
  • Zero failed-write errors, zero crashes, zero quarantined entries

Performance impact

N/A — eviction path only, off the request hot path. The dequantize is exactly
what an SSD-promoted entry would do anyway on reload; this just moves the
cost from "first reload" to "spill time" while ensuring the spill can land
at all. The fp16 cast adds one astype per layer (48 layers for Qwen-30B,
~36 for gpt-oss-120B, ~16 GB/s GPU mem bandwidth — sub-millisecond).

Output parity

The fp16 cast is lossy relative to the original bf16 — values outside
±65504 become inf. In practice KV values are normalized and never approach
that range; tested against Qwen-3-Coder with normalized prompts and saw no
divergence in re-served output vs cold prefill. Future improvement: store
the bf16 buffer as a uint16 view + dtype metadata, which is lossless but
needs a matching deserialize-side change (out of scope for this fix).

Risk notes

  • Patch is local to SSDCacheTier.enqueue_spill; no scheduler / engine / sampler interaction.
  • Three new info-level log lines per spill: dequantized N layers before spill. Quiet at production levels but noisy at debug. (Trim to debug if reviewers prefer.)
  • The fp16 cast is lossy in the theoretical worst case but lossless in practice for KV-normalized values — see Output Parity above.
  • Pre-patch, --kv-cache-quantization + --ssd-cache-dir was effectively a footgun: it appeared to enable a cold tier but silently dropped every eviction. This PR makes the flags do what they advertise.

When `--continuous-batching --kv-cache-quantization --ssd-cache-dir ...` is
set, every eviction in the running server silently failed to spill: the writer
thread crashed on `np.array(layer.keys)` with a misleading PEP 3118 error,
the writer loop caught it, and the SSD cache dir stayed empty forever. RAM
evictions still happened, so KV got discarded with no disk fallback. The
existing 4 SSD-cache tests in this file pass plain `KVCache` mocks and never
exercised the quantized path, so the bug was invisible to CI.

Two distinct issues fixed in `SSDCacheTier.enqueue_spill` (single chokepoint
between the in-process eviction and the async disk writer):

1. **Quantized-layer detection gate was wrapper-only.**
   The previous gate matched only our `_QuantizedCacheWrapper` (set when our
   `_quantize_cache` wraps a plain `KVCache`). But under `--kv-cache-quantization`
   models can ALSO produce mlx-lm's native `QuantizedKVCache` directly, whose
   `.keys` / `.values` are `(packed, scales, biases)` tuples. Those slipped
   through the gate, reached the writer thread, and crashed with
   `ValueError: setting an array element with a sequence ... inhomogeneous shape`
   when `np.array(layer.keys)` tried to stack the three arrays. The new gate
   detects any layer whose `.keys` is a tuple/list (covers both wrapper and
   native) and dequantizes in place — using `mlx-lm.models.cache.QuantizedKVCache`'s
   shape — into a fresh `KVCache`-shaped duck-type for the serializer.

2. **bfloat16 is unrepresentable to numpy via PEP 3118.**
   `mx.dequantize(...)` returns the original cache's dtype, which for many
   models (Qwen-3-Coder, gpt-oss-120B, etc.) is bfloat16. Numpy's buffer
   protocol reads bf16 as `format=B (uint8) item_size=2` but its dtype-B
   table says item_size=1, raising
   `RuntimeError: Item size 2 for PEP 3118 buffer format string B does not
   match the dtype B item size 1.` at `np.array(layer.keys)`.
   Fix: cast bf16 → float16 on the caller thread before queueing. KV values
   sit well within fp16's ±65504 range, fp16 actually has *more* mantissa
   bits than bf16 (10 vs 7), and byte size is identical. After cast,
   `mx.eval` forces host-side materialization so the writer thread does only
   a buffer copy (no GPU stream context needed — that's a separate class of
   crash, "no Stream(gpu,N) in current thread", already handled by keeping
   dequant on the caller).

Also fixes 4 pre-existing test failures in `test_ssd_cache.py` that were
unrelated to the patch — the test entries used 10-token sequences but the
default `min_prefix_tokens=128` silently rejected them in `store()`, leaving
the cache empty so eviction tests asserted `len(cache) == 3` against `0`.
Added `min_prefix_tokens=1` to the 6 `MemoryCacheConfig(...)` calls in the
test file. With the patches applied, 61/61 SSD tests pass (was 54/58 before).

3 new tests added in `TestQuantizedSpillRoundtrip` exercise the actual
spill→reload roundtrip end-to-end (no mocks) for:
  - `_QuantizedCacheWrapper` (fork wrapper, regression)
  - mlx-lm native `QuantizedKVCache` (the new path)
  - bfloat16 + fp16 cast (asserts persisted dtype)

Validated against live prod (Qwen-3-Coder-30B at 4-bit) — pre-patch every
eviction logged `[ssd_cache] failed to write entry`; post-patch 7+ entries
spilled with ~600 MB safetensors landing on disk and a cache hit on a
prefix-extending follow-up query reused 1705/1724 prompt tokens (only 19
new tokens needed prefill).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@CBribiescas

Copy link
Copy Markdown
Contributor Author

Closing as a duplicate of #563, which I opened myself on May 22 and overlooked when filing this. #563 takes a cleaner architectural approach (snapshot_layer() on producer thread) and uses an fp32 fallback that round-trips the original bf16 dtype losslessly via manifest metadata; this PR's fp16 cast is lossy and the producer-thread fix here is redundant with #563. Please review #563 instead.

@CBribiescas CBribiescas closed this Jun 9, 2026
@CBribiescas CBribiescas deleted the fix/ssd-spill-dequant-and-bf16-cast branch June 9, 2026 10:04
waybarrios pushed a commit that referenced this pull request Jun 11, 2026
#605)

* fix(ssd-cache): spill native QuantizedKVCache + handle bfloat16 buffer

When --kv-cache-quantization is enabled, KV layers are QuantizedKVCache whose
keys/values are tuples of (packed, scales, biases); the spill serializer didn't
recognize them. Also: numpy's PEP 3118 buffer protocol can't read bfloat16, so
cast before serialize. Recreates closed #602 on top of latest main.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* fix lint: black-format ssd_cache.py

Single-line wrap on the `mx.dequantize(*layer.keys, ...)` call in the
native QuantizedKVCache spill path. CI lint was failing on the original
b51d8f4 push because the line exceeded black's 88-char limit.

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant