fix: resolve Metal kernel runtime crash with bfloat16 dtype #228
Summary
This PR fixes a runtime crash in the MLX Metal kernel when processing `bfloat16` tensors, where the incompatibility between the `bfloat16` and `float` types causes Metal compilation failures at runtime.
Problem
The Metal kernel in `outlines_core/kernels/mlx.py` was directly using `-INFINITY` (a float literal) in the kernel source. This causes a runtime crash when the input tensor has dtype `bfloat16`, as Metal doesn't support implicit conversion between `bfloat16` and `float` types.
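For context, here is a hypothetical reduction of the pattern described above, written against MLX's `mx.fast.metal_kernel` API; the kernel name, inputs, and masking logic are invented for illustration and are not the actual source in `outlines_core/kernels/mlx.py`:

```python
import mlx.core as mx

# Hypothetical reduction of the failing pattern (not the real kernel): the
# Metal source hard-codes the float literal -INFINITY, so when the logits
# buffer has dtype bfloat16 the expression mixes bfloat and float types.
source = """
    uint elem = thread_position_in_grid.x;
    out[elem] = mask[elem] ? logits[elem] : -INFINITY;
"""

kernel = mx.fast.metal_kernel(
    name="mask_logits",  # illustrative name
    input_names=["logits", "mask"],
    output_names=["out"],
    source=source,
)
```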
Solution
A minor change following the same approach applied in the LLGuidance package: provide a tensor with a compatible dtype for the -INF value.
Changes
- `neg_inf` passed as an additional input parameter (see the sketch below)
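A minimal sketch of that change, again with illustrative names rather than the PR's actual code: the sentinel is built on the Python side with the logits' dtype and passed to the kernel as an extra input, so the Metal source never mixes `bfloat` and `float` values.

```python
import mlx.core as mx

# Sketch of the fix: pass the -inf sentinel as a one-element tensor whose
# dtype matches the logits instead of hard-coding -INFINITY in the source.
source = """
    uint elem = thread_position_in_grid.x;
    out[elem] = mask[elem] ? logits[elem] : neg_inf[0];
"""

kernel = mx.fast.metal_kernel(
    name="mask_logits",  # illustrative name
    input_names=["logits", "mask", "neg_inf"],
    output_names=["out"],
    source=source,
)

def apply_mask(logits: mx.array, mask: mx.array) -> mx.array:
    # neg_inf shares the logits' dtype (bfloat16, float16, float32, ...).
    neg_inf = mx.array([-float("inf")], dtype=logits.dtype)
    (out,) = kernel(
        inputs=[logits, mask, neg_inf],
        grid=(logits.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[logits.shape],
        output_dtypes=[logits.dtype],
    )
    return out
```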
Testing
✅ All tests pass.
Type of Change
To the best of my knowledge, this is a minimal, focused fix that resolves the runtime crash without affecting the kernel's logic or performance characteristics.