@@ -1987,7 +1987,7 @@ def test_scalar_load_upcast(self, in_dtype):
1987
1987
if not jtu .if_cloud_tpu_at_least (2025 , 4 , 25 ):
1988
1988
self .skipTest ("Needs a newer libTPU" )
1989
1989
if in_dtype == jnp .int4 and not jtu .is_device_tpu_at_least (4 ):
1990
- self .skipTest ("Triggers an XLA bug" )
1990
+ self .skipTest ("Triggers an XLA bug" ) # TODO(b/413602952)
1991
1991
def kernel (x_ref , o_ref ):
1992
1992
o_ref [0 , 0 ] = x_ref [0 , 0 ].astype (o_ref .dtype )
1993
1993
x = jnp .asarray ([[- 1 ]], dtype = in_dtype )
@@ -1999,6 +1999,23 @@ def kernel(x_ref, o_ref):
1999
1999
)(x )
2000
2000
self .assertEqual (y , x .astype (jnp .int32 ))
2001
2001
2002
@parameterized.product(in_dtype=[jnp.int4, jnp.int8, jnp.int16, jnp.int32])
def test_scalar_indirect_load(self, in_dtype):
  """Checks a scalar SMEM load whose column index is itself loaded from the ref.

  The kernel reads ``x_ref[0, 0]``, casts it to int32, uses it as the column
  index of a second read from ``x_ref`` and stores that element, cast to the
  output dtype, into ``o_ref[0, 0]``. With the input ``[[3, 0, 0, 1]]`` the
  first read yields 3, so the second read picks the element at column 3.

  Args:
    in_dtype: Integer dtype of the input array (int4, int8, int16 or int32).
  """
  # Both guards run before any kernel is built; skipTest aborts the test.
  if not jtu.if_cloud_tpu_at_least(2025, 4, 27):
    self.skipTest("Needs a newer libTPU")
  if in_dtype == jnp.int4 and not jtu.is_device_tpu_at_least(4):
    self.skipTest("Triggers an XLA bug")  # TODO(b/413602952)

  def kernel(x_ref, o_ref):
    # The first element doubles as the column index for the second load.
    col = x_ref[0, 0].astype(jnp.int32)
    o_ref[0, 0] = x_ref[0, col].astype(o_ref.dtype)

  x = jnp.asarray([[3, 0, 0, 1]], dtype=in_dtype)
  indirect_load = pl.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
      out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),
      out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32),
  )
  y = indirect_load(x)

  # Reference result computed outside the kernel, reshaped to (1, 1).
  expected = x[0, x[0, 0]].astype(jnp.int32)[None, None]
  self.assertEqual(y, expected)

2018
+
2002
2019
def test_masked_store (self ):
2003
2020
shape = (16 , 256 )
2004
2021
mask_shape = (10 , 130 )
0 commit comments