
@Ahmed-Ali commented Aug 20, 2025

Summary

This PR fixes a runtime crash in the MLX Metal kernel when processing bfloat16 tensors: the kernel mixes bfloat16 and float types, which causes Metal compilation to fail at runtime.

Problem

The Metal kernel in outlines_core/kernels/mlx.py used -INFINITY (a float literal) directly in the kernel source. This caused a runtime crash whenever the input tensor had dtype bfloat16, because Metal does not implicitly convert between bfloat16 and float.
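
For illustration, here is a minimal standalone sketch of the root cause. This is not the outlines kernel itself; the kernel name, body, and shapes below are made up purely to demonstrate the bfloat16/float mismatch described above:

import mlx.core as mx

# Illustration only, NOT the outlines kernel: the ternary mixes a bfloat16
# element with the float literal -INFINITY, which Metal cannot reconcile
# when the input dtype is bfloat16.
_demo = mx.fast.metal_kernel(
    name="neg_inf_demo",
    input_names=["inp"],
    output_names=["out"],
    source="""
        uint i = thread_position_in_grid.x;
        out[i] = (i % 2 == 0) ? inp[i] : -INFINITY;  // bfloat16 vs float
    """,
)

x = mx.ones((8,), dtype=mx.bfloat16)
out = _demo(
    inputs=[x],
    grid=(8, 1, 1),
    threadgroup=(8, 1, 1),
    output_shapes=[x.shape],
    output_dtypes=[x.dtype],
)[0]
mx.eval(out)  # expected to fail at Metal compile time for bfloat16 inputs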

Solution

A minor change that follows the approach used in the LLGuidance package: provide the kernel with a negative-infinity tensor whose dtype is compatible with the input.

Changes

  • Modified the Metal kernel to accept neg_inf as an additional input parameter
  • Updated kernel invocation to pass a properly typed negative-infinity value (see the sketch after this list)
  • Maintained backward compatibility and performance characteristics
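
A minimal sketch of what the change looks like at the call site. Names follow the snippet in the review comment below; _KERNEL is the mx.fast.metal_kernel object already defined in mlx.py (inside the kernel source, the extra input is read as neg_inf[0]), and the exact code in this PR may differ:

import mlx.core as mx

# Sketch only: build a one-element -inf array with the same dtype as the
# logits and hand it to the kernel as an extra input, so the kernel never
# mixes bfloat16 values with a float literal.
@mx.compile
def _apply_token_bitmask_kernel(data: mx.array, mask: mx.array) -> mx.array:
    neg_inf = mx.array([float("-inf")], dtype=data.dtype)  # dtype-matched -inf
    return _KERNEL(
        inputs=[data, mask, neg_inf],
        grid=(data.shape[1], data.shape[0], 1),
        threadgroup=(256, 1, 1),
        output_shapes=[data.shape],
        output_dtypes=[data.dtype],
    )[0]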

Testing

All tests pass:

  • Ran the full test suite described in the README; all tests pass.
  • Also ran the benchmark tests; nothing stood out.

Type of Change

  • Bug fix for a runtime crash on the happy path

To the best of my knowledge, this is a minimal, focused fix that resolves the runtime crash without affecting the kernel's logic or performance characteristics.

@Ahmed-Ali marked this pull request as ready for review August 20, 2025 03:46
@unaidedelf8777 (Contributor) commented:

I would suggest instead using a template in the Metal kernel; that way there is less FFI pass-through, and the float("-inf") is not constructed on every run as it is in the current version. In the kernel you can add:

/// per IEEE 754 this is equivalent to -inf.
T neg_inf = -(T(1.0) / T(0.0));

and change

out[batch * inp_shape[1] + elem] = bit ? inp[batch * inp_shape[1] + elem] : neg_inf[0];
/// to
out[batch * inp_shape[1] + elem] = bit ? inp[batch * inp_shape[1] + elem] : neg_inf;

And when dispatching the kernel you will need to add the template argument:

@mx.compile
def _apply_token_bitmask_kernel(data: mx.array, mask: mx.array) -> mx.array:
    return _KERNEL(
        inputs=[data, mask],
        template=[("T", data.dtype)], # this
        grid=(data.shape[1], data.shape[0], 1),
        threadgroup=(256, 1, 1),
        output_shapes=[data.shape],
        output_dtypes=[data.dtype],
    )[0]
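
To complete the picture, here is a sketch of how the templated kernel could be declared with mx.fast.metal_kernel so that T is available in the source. The bitmask-decoding details below are illustrative and may not match the exact kernel in outlines_core/kernels/mlx.py:

import mlx.core as mx

# Illustrative declaration only: `T` is bound at dispatch time via
# template=[("T", data.dtype)], so the -inf constant takes the input dtype
# and no bfloat16/float mixing occurs.
_KERNEL = mx.fast.metal_kernel(
    name="apply_token_bitmask",
    input_names=["inp", "mask"],
    output_names=["out"],
    source="""
        uint elem = thread_position_in_grid.x;
        uint batch = thread_position_in_grid.y;
        if (elem >= (uint)inp_shape[1]) return;

        /// per IEEE 754 this is equivalent to -inf.
        T neg_inf = -(T(1.0) / T(0.0));

        // Illustrative bitmask lookup: one bit per vocabulary entry, packed into 32-bit words.
        uint word = (uint)mask[batch * mask_shape[1] + (elem / 32)];
        bool bit = (word >> (elem % 32)) & 1;

        out[batch * inp_shape[1] + elem] = bit ? inp[batch * inp_shape[1] + elem] : neg_inf;
    """,
)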
