fix(ssd-cache): preserve original bfloat16 dtype across quantized spill#612
Open
CBribiescas wants to merge 1 commit into
Open
fix(ssd-cache): preserve original bfloat16 dtype across quantized spill#612CBribiescas wants to merge 1 commit into
CBribiescas wants to merge 1 commit into
Conversation
Per @waybarrios's follow-up on PR waybarrios#605: the dequant-on-spill path casts bf16 → fp16 (to dodge numpy's PEP 3118 buffer-protocol mismatch) but didn't record the original dtype. `_mx_to_numpy_safe` only stamps a dtype hint when *it* did the upcast; fp16 IS numpy-supported so it returned `None`, and the snapshot carried no signal. The scheduler-side `_reconstruct_ssd_layers` then handed the model fp16 KV where it had computed bf16 — silent precision regression on every SSD reload of a quantized-cache model. Mirror the existing plain-cache pattern: stash an explicit dtype sentinel on the layer before the cast, and have `KVCacheSerializer.snapshot_layer` honor it with priority over the autodetect path. Sentinel name is namespaced (`_ssd_keys_original_dtype` / `_ssd_values_original_dtype`) to avoid colliding with any mlx-lm-side cache attribute. Also add `"QuantizedKVCache": "supported_via_dequant_on_spill"` to SERIALIZER_SUPPORT_MATRIX so the diagnostics table matches what enqueue_spill actually handles. Reload-side change unnecessary: scheduler.py:3019-3028 already reads `keys_original_dtype` from the deserialized dict and casts back via `mx.array(...).astype(mx.bfloat16)` — only the spill-side signal was missing. Four regression tests in TestLayerSerializer: - `test_support_matrix_includes_quantized_kv_cache` — guards the matrix entry against drift. - `test_snapshot_layer_honors_original_dtype_sentinel` — sentinel set, no autodetect signal: must record "bfloat16". - `test_snapshot_layer_sentinel_overrides_autodetect` — sentinel wins even when _mx_to_numpy_safe also produces a hint. - `test_snapshot_layer_no_sentinel_falls_back_to_autodetect` — plain fp16 layers still record no dtype hint. 63 ssd_cache tests pass; black clean.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Follow-up to #605 per @waybarrios's review comment: the dequant-on-spill path casts bf16 → fp16 (to dodge numpy's PEP 3118 buffer-protocol mismatch) but didn't record the original dtype. Reload then hands the model fp16 KV where it had computed bf16 — silent precision regression on every SSD reload of a quantized-cache model.
Root cause
_mx_to_numpy_safeonly stamps a dtype hint when it did the upcast (bf16 → fp32 fallback). The quantized-spill path pre-casts bf16 → fp16 beforesnapshot_layeris called, and since fp16 IS numpy-supported,_mx_to_numpy_safereturnsdtype=None. The snapshot then carries no signal that the data originated as bf16, andscheduler._reconstruct_ssd_layershas nothing to cast back from.Fix
Mirror the existing plain-cache pattern: stash an explicit dtype sentinel on the layer before the cast, and have
KVCacheSerializer.snapshot_layerhonor it with priority over the autodetect path:Sentinel name is
_ssd_-prefixed to avoid colliding with any mlx-lm-side cache attribute.Also add
"QuantizedKVCache": "supported_via_dequant_on_spill"toSERIALIZER_SUPPORT_MATRIXso the diagnostics table matches whatenqueue_spillactually handles.Reload side
No change.
scheduler.py:3019-3028already readskeys_original_dtypefrom the deserialized dict and casts back viamx.array(...).astype(mx.bfloat16). Only the spill-side signal was missing.Tests
Four regression tests in
TestLayerSerializer:test_support_matrix_includes_quantized_kv_cache— guards the matrix entry against drift.test_snapshot_layer_honors_original_dtype_sentinel— sentinel set, no autodetect signal: must record"bfloat16".test_snapshot_layer_sentinel_overrides_autodetect— sentinel wins even when_mx_to_numpy_safealso produces a hint.test_snapshot_layer_no_sentinel_falls_back_to_autodetect— plain fp16 layers still record no dtype hint (no regression on the existing path).Verification
Refs #605 #563.