Skip to content

fix(ssd-cache): preserve original bfloat16 dtype across quantized spill#612

Open
CBribiescas wants to merge 1 commit into
waybarrios:mainfrom
CBribiescas:fix/ssd-quantized-dtype-roundtrip
Open

fix(ssd-cache): preserve original bfloat16 dtype across quantized spill#612
CBribiescas wants to merge 1 commit into
waybarrios:mainfrom
CBribiescas:fix/ssd-quantized-dtype-roundtrip

Conversation

@CBribiescas

Copy link
Copy Markdown
Contributor

Summary

Follow-up to #605 per @waybarrios's review comment: the dequant-on-spill path casts bf16 → fp16 (to dodge numpy's PEP 3118 buffer-protocol mismatch) but didn't record the original dtype. Reload then hands the model fp16 KV where it had computed bf16 — silent precision regression on every SSD reload of a quantized-cache model.

Root cause

_mx_to_numpy_safe only stamps a dtype hint when it did the upcast (bf16 → fp32 fallback). The quantized-spill path pre-casts bf16 → fp16 before snapshot_layer is called, and since fp16 IS numpy-supported, _mx_to_numpy_safe returns dtype=None. The snapshot then carries no signal that the data originated as bf16, and scheduler._reconstruct_ssd_layers has nothing to cast back from.

Fix

Mirror the existing plain-cache pattern: stash an explicit dtype sentinel on the layer before the cast, and have KVCacheSerializer.snapshot_layer honor it with priority over the autodetect path:

# enqueue_spill (dequant branch):
if str(getattr(k, "dtype", "")).endswith("bfloat16"):
    layer._ssd_keys_original_dtype = "bfloat16"
    layer._ssd_values_original_dtype = "bfloat16"
    layer.keys = k.astype(mx.float16)
    layer.values = v.astype(mx.float16)

# KVCacheSerializer.snapshot_layer:
keys_orig_dtype = (
    getattr(layer, "_ssd_keys_original_dtype", None) or keys_orig_dtype
)

Sentinel name is _ssd_-prefixed to avoid colliding with any mlx-lm-side cache attribute.

Also add "QuantizedKVCache": "supported_via_dequant_on_spill" to SERIALIZER_SUPPORT_MATRIX so the diagnostics table matches what enqueue_spill actually handles.

Reload side

No change. scheduler.py:3019-3028 already reads keys_original_dtype from the deserialized dict and casts back via mx.array(...).astype(mx.bfloat16). Only the spill-side signal was missing.

Tests

Four regression tests in TestLayerSerializer:

  • test_support_matrix_includes_quantized_kv_cache — guards the matrix entry against drift.
  • test_snapshot_layer_honors_original_dtype_sentinel — sentinel set, no autodetect signal: must record "bfloat16".
  • test_snapshot_layer_sentinel_overrides_autodetect — sentinel wins even when _mx_to_numpy_safe also produces a hint.
  • test_snapshot_layer_no_sentinel_falls_back_to_autodetect — plain fp16 layers still record no dtype hint (no regression on the existing path).

Verification

pytest tests/test_ssd_cache.py
63 passed in 3.09s

black --check vllm_mlx/ssd_cache.py tests/test_ssd_cache.py
2 files would be left unchanged

Refs #605 #563.

Per @waybarrios's follow-up on PR waybarrios#605: the dequant-on-spill path casts
bf16 → fp16 (to dodge numpy's PEP 3118 buffer-protocol mismatch) but
didn't record the original dtype. `_mx_to_numpy_safe` only stamps a
dtype hint when *it* did the upcast; fp16 IS numpy-supported so it
returned `None`, and the snapshot carried no signal. The
scheduler-side `_reconstruct_ssd_layers` then handed the model fp16
KV where it had computed bf16 — silent precision regression on every
SSD reload of a quantized-cache model.

Mirror the existing plain-cache pattern: stash an explicit dtype
sentinel on the layer before the cast, and have
`KVCacheSerializer.snapshot_layer` honor it with priority over the
autodetect path. Sentinel name is namespaced (`_ssd_keys_original_dtype`
/ `_ssd_values_original_dtype`) to avoid colliding with any
mlx-lm-side cache attribute.

Also add `"QuantizedKVCache": "supported_via_dequant_on_spill"` to
SERIALIZER_SUPPORT_MATRIX so the diagnostics table matches what
enqueue_spill actually handles.

Reload-side change unnecessary: scheduler.py:3019-3028 already reads
`keys_original_dtype` from the deserialized dict and casts back via
`mx.array(...).astype(mx.bfloat16)` — only the spill-side signal was
missing.

Four regression tests in TestLayerSerializer:
- `test_support_matrix_includes_quantized_kv_cache` — guards the
  matrix entry against drift.
- `test_snapshot_layer_honors_original_dtype_sentinel` — sentinel set,
  no autodetect signal: must record "bfloat16".
- `test_snapshot_layer_sentinel_overrides_autodetect` — sentinel wins
  even when _mx_to_numpy_safe also produces a hint.
- `test_snapshot_layer_no_sentinel_falls_back_to_autodetect` — plain
  fp16 layers still record no dtype hint.

63 ssd_cache tests pass; black clean.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant