|
32 | 32 | from jax._src.lib.mlir import ir
|
33 | 33 | from jax._src.lib.mlir import passmanager
|
34 | 34 | from jax._src.lib.mlir.dialects import arith
|
| 35 | +from jax._src.lib.mlir.dialects import cf |
35 | 36 | from jax._src.lib.mlir.dialects import scf
|
36 | 37 | from jax._src.lib.mlir.dialects import vector
|
37 | 38 | from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member
|
@@ -237,7 +238,6 @@ def capture_stdout(self):
|
237 | 238 | mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices()
|
238 | 239 |
|
239 | 240 |
|
240 |
| - |
241 | 241 | class Sm90ATestCase(TestCase, jtu.CudaArchSpecificTest):
|
242 | 242 |
|
243 | 243 | def setUp(self):
|
@@ -3320,6 +3320,34 @@ def test_parse_indices_oob(self, indices):
|
3320 | 3320 | with self.assertRaisesRegex(IndexError, "out of bounds"):
|
3321 | 3321 | utils.parse_indices(indices, (2, 3, 4))
|
3322 | 3322 |
|
| 3323 | + @jtu.thread_unsafe_test() # Modifies ``os.environ``. |
| 3324 | + def test_assert(self): |
| 3325 | + if cf is None: |
| 3326 | + self.skipTest("``cf`` is not available") |
| 3327 | + |
| 3328 | + def kernel(ctx: mgpu.LaunchContext, x_ref, out, scratch) -> None: |
| 3329 | + del ctx, out # Unused. |
| 3330 | + # TODO(b/408271232): Use a False condition once the bug is fixed. |
| 3331 | + x = mgpu.FragmentedArray.load_strided(x_ref) |
| 3332 | + cond = x.reduce_sum(*scratch) != 42.0 |
| 3333 | + cf.assert_(cond.registers.item(), "OOOPS") |
| 3334 | + |
| 3335 | + f = mgpu.as_gpu_kernel( |
| 3336 | + kernel, |
| 3337 | + grid=(1, 1, 1), |
| 3338 | + block=(128, 1, 1), |
| 3339 | + in_shape=(jax.ShapeDtypeStruct((128,), jnp.float32),), |
| 3340 | + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), |
| 3341 | + smem_scratch_shape=(jax.ShapeDtypeStruct((4,), jnp.float32),), |
| 3342 | + ) |
| 3343 | + |
| 3344 | + with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: |
| 3345 | + f(jnp.ones((128,), jnp.float32)) |
| 3346 | + |
| 3347 | + # SASS doesn't seem to include the assertion message, so we are just |
| 3348 | + # checking that __assertfail appears in the symbol table for the kernel. |
| 3349 | + self.assertIn("__assertfail", sass()) |
| 3350 | + |
3323 | 3351 |
|
3324 | 3352 | class SerializationTest(absltest.TestCase):
|
3325 | 3353 |
|
|
0 commit comments