Skip to content

[pallas/mosaic] emulate fp32<->ui32 casts in terms of fp32<->si32 casts #28257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2143,6 +2143,31 @@ def _convert_helper(x, *, to_dtype):
return x.astype(to_dtype)
raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}")


def fp32_to_ui32(x):
"""Implement fp32 -> ui32 conversion using the signed integer casting path."""
assert x.dtype == jnp.float32
umax = np.uint32(2 ** 31)
fmax = np.float32(umax)
return jnp.where(
x < fmax,
jnp.where(x <= 0, 0, x.astype('int32').astype('uint32')),
(x - fmax).astype('int32').astype('uint32') + umax,
)


def ui32_to_fp32(x):
"""Implement ui32 -> fp32 conversion using the signed integer casting path."""
assert x.dtype == jnp.uint32
umax = np.uint32(2 ** 31)
fmax = np.float32(umax)
return jnp.where(
x < umax,
x.astype('int32').astype('float32'),
(x - umax).astype('int32').astype('float32') + fmax,
)


def _convert_element_type_lowering_rule(
ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
):
Expand All @@ -2166,6 +2191,7 @@ def _convert_element_type_lowering_rule(
floating = jnp.floating
integer = jnp.integer
signed = jnp.signedinteger
unsigned = jnp.unsignedinteger
both_32bit = old_dtype.itemsize == 4 and new_dtype.itemsize == 4
if _from(floating) and _to(floating):
if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4:
Expand All @@ -2185,8 +2211,14 @@ def _convert_element_type_lowering_rule(
# TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer.
elif _from(floating) and _to(signed):
return arith.fptosi(out_type, x)
elif _from(floating) and _to(unsigned) and both_32bit:
# return arith.fptoui(out_type, x) # no llo lowering path.
return lower_fun(fp32_to_ui32, multiple_results=False)(ctx, x)
elif _from(signed) and _to(floating) and both_32bit:
return arith.sitofp(out_type, x)
elif _from(unsigned) and _to(floating) and both_32bit:
# return arith.uitofp(out_type, x) # no llo lowering path.
return lower_fun(ui32_to_fp32, multiple_results=False)(ctx, x)
elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4:
return arith.extui(out_type, x)
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
Expand Down
18 changes: 18 additions & 0 deletions tests/pallas/tpu_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,24 @@ def body(x_ref, o_ref):
result = self.pallas_call(body, out_shape=out)(x)
np.testing.assert_array_equal(result, x.astype(jnp.float32) + 1.0)

def test_tpu_f32_to_u32(self):
def body(x_ref, o_ref):
# Test cast from float32 -> uint32
o_ref[...] = lax.convert_element_type(x_ref[...], jnp.uint32)
out = jax.ShapeDtypeStruct((8, 128), jnp.uint32)
x = np.arange(0, 2 ** 32, 2 ** 22).reshape(8, 128).astype('float32')
result = self.pallas_call(body, out_shape=out)(x)
np.testing.assert_array_equal(result, x.astype(jnp.uint32))

def test_tpu_u32_to_f32(self):
def body(x_ref, o_ref):
# Test cast from uint32 -> float32
o_ref[...] = lax.convert_element_type(x_ref[...], jnp.float32)
out = jax.ShapeDtypeStruct((8, 128), jnp.float32)
x = np.arange(0, 2 ** 32, 2 ** 22, dtype='uint32').reshape(8, 128)
result = self.pallas_call(body, out_shape=out)(x)
np.testing.assert_array_equal(result, x.astype(jnp.float32))

def test_tpu_signed_int_upcast(self):
if not jtu.is_device_tpu_at_least(version=5):
self.skipTest("TPUv5+ needed for integer matmuls")
Expand Down
Loading