Skip to content

Commit 2c85a25

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic TPU] Allow indexing refs with narrow integers
PiperOrigin-RevId: 751362838
1 parent d8238ec commit 2c85a25

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

tests/pallas/tpu_pallas_test.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -1987,7 +1987,7 @@ def test_scalar_load_upcast(self, in_dtype):
19871987
if not jtu.if_cloud_tpu_at_least(2025, 4, 25):
19881988
self.skipTest("Needs a newer libTPU")
19891989
if in_dtype == jnp.int4 and not jtu.is_device_tpu_at_least(4):
1990-
self.skipTest("Triggers an XLA bug")
1990+
self.skipTest("Triggers an XLA bug") # TODO(b/413602952)
19911991
def kernel(x_ref, o_ref):
19921992
o_ref[0, 0] = x_ref[0, 0].astype(o_ref.dtype)
19931993
x = jnp.asarray([[-1]], dtype=in_dtype)
@@ -1999,6 +1999,23 @@ def kernel(x_ref, o_ref):
19991999
)(x)
20002000
self.assertEqual(y, x.astype(jnp.int32))
20012001

2002+
@parameterized.product(in_dtype=[jnp.int4, jnp.int8, jnp.int16, jnp.int32])
2003+
def test_scalar_indirect_load(self, in_dtype):
2004+
if not jtu.if_cloud_tpu_at_least(2025, 4, 27):
2005+
self.skipTest("Needs a newer libTPU")
2006+
def kernel(x_ref, o_ref):
2007+
o_ref[0, 0] = x_ref[0, x_ref[0, 0].astype(jnp.int32)].astype(o_ref.dtype)
2008+
if in_dtype == jnp.int4 and not jtu.is_device_tpu_at_least(4):
2009+
self.skipTest("Triggers an XLA bug") # TODO(b/413602952)
2010+
x = jnp.asarray([[3, 0, 0, 1]], dtype=in_dtype)
2011+
y = pl.pallas_call(
2012+
kernel,
2013+
in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
2014+
out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),
2015+
out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32),
2016+
)(x)
2017+
self.assertEqual(y, x[0, x[0, 0]].astype(jnp.int32)[None, None])
2018+
20022019
def test_masked_store(self):
20032020
shape = (16, 256)
20042021
mask_shape = (10, 130)

0 commit comments

Comments
 (0)