Skip to content

Commit 0bd8f2c

Browse files
superbobry authored and Google-ML-Automation committed
[mosaic_gpu] Added debug_assert
The XLA GPU runtime does not yet handle device assertions well and will hang if the assert is triggered. However, the assertion output still appears in stderr, so I think having `debug_assert` is still useful.

PiperOrigin-RevId: 740717697
1 parent 48001a2 commit 0bd8f2c

File tree

7 files changed

+47
-7
lines changed

7 files changed

+47
-7
lines changed

jax/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,7 @@ py_library_providing_imports_info(
866866
"//jax/_src/lib",
867867
"//jaxlib/mlir:arithmetic_dialect",
868868
"//jaxlib/mlir:builtin_dialect",
869+
"//jaxlib/mlir:control_flow_dialect",
869870
"//jaxlib/mlir:func_dialect",
870871
"//jaxlib/mlir:gpu_dialect",
871872
"//jaxlib/mlir:ir",

jax/_src/lib/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ py_library_providing_imports_info(
5252
"//jaxlib/mlir:arithmetic_dialect",
5353
"//jaxlib/mlir:builtin_dialect",
5454
"//jaxlib/mlir:chlo_dialect",
55+
"//jaxlib/mlir:control_flow_dialect",
5556
"//jaxlib/mlir:func_dialect",
5657
"//jaxlib/mlir:ir",
5758
"//jaxlib/mlir:math_dialect",

jax/_src/lib/mlir/dialects/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,9 @@
5555

5656
# Alias that is set up to abstract away the transition from MHLO to StableHLO.
5757
from jaxlib.mlir.dialects import stablehlo as hlo
58+
59+
from jax._src import lib
60+
if lib.version >= (0, 6, 1):
61+
from jaxlib.mlir.dialects import cf
62+
else:
63+
cf = None # type: ignore[no-redef]

jaxlib/mlir/_mlir_libs/register_jax_dialects.cc

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#include <nanobind/nanobind.h>
1919

2020
#include "mlir-c/Dialect/Arith.h" // IWYU pragma: keep
21+
#include "mlir-c/Dialect/ControlFlow.h"
2122
#include "mlir-c/Dialect/Func.h" // IWYU pragma: keep
2223
#include "mlir-c/Dialect/GPU.h" // IWYU pragma: keep
2324
#include "mlir-c/Dialect/LLVM.h" // IWYU pragma: keep
@@ -50,6 +51,7 @@ NB_MODULE(register_jax_dialects, m) {
5051
REGISTER_DIALECT(scf);
5152
REGISTER_DIALECT(vector);
5253
// For Mosaic GPU
54+
REGISTER_DIALECT(cf);
5355
REGISTER_DIALECT(gpu);
5456
REGISTER_DIALECT(nvgpu);
5557
REGISTER_DIALECT(nvvm);

jaxlib/mosaic/gpu/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ cc_library(
155155
"@llvm-project//mlir:ArithTransforms",
156156
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
157157
"@llvm-project//mlir:ComplexToLLVM",
158+
"@llvm-project//mlir:ControlFlowDialect",
158159
"@llvm-project//mlir:ControlFlowToLLVM",
159160
"@llvm-project//mlir:ConversionPasses",
160161
"@llvm-project//mlir:ExecutionEngine",

jaxlib/mosaic/gpu/custom_call.cc

+7-6
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ limitations under the License.
6060
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
6161
#include "mlir/Dialect/Arith/IR/Arith.h"
6262
#include "mlir/Dialect/Arith/Transforms/Passes.h"
63+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
6364
#include "mlir/Dialect/Func/IR/FuncOps.h"
6465
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
6566
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -228,12 +229,12 @@ mlir::LogicalResult RunPasses(mlir::OpPassManager&& passes,
228229

229230
void InitContext(mlir::MLIRContext* context) {
230231
mlir::DialectRegistry registry;
231-
registry.insert<mlir::arith::ArithDialect, mlir::func::FuncDialect,
232-
mlir::math::MathDialect, mlir::memref::MemRefDialect,
233-
mlir::scf::SCFDialect, mlir::vector::VectorDialect,
234-
mlir::gpu::GPUDialect, mlir::nvgpu::NVGPUDialect,
235-
mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect,
236-
mosaic_gpu::MosaicGPUDialect>();
232+
registry.insert<mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect,
233+
mlir::func::FuncDialect, mlir::math::MathDialect,
234+
mlir::memref::MemRefDialect, mlir::scf::SCFDialect,
235+
mlir::vector::VectorDialect, mlir::gpu::GPUDialect,
236+
mlir::nvgpu::NVGPUDialect, mlir::NVVM::NVVMDialect,
237+
mlir::LLVM::LLVMDialect, mosaic_gpu::MosaicGPUDialect>();
237238
mlir::registerConvertNVVMToLLVMInterface(registry);
238239
mlir::registerConvertComplexToLLVMInterface(registry);
239240
mlir::registerConvertMemRefToLLVMInterface(registry);

tests/mosaic/gpu_test.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from jax._src.lib.mlir import ir
3333
from jax._src.lib.mlir import passmanager
3434
from jax._src.lib.mlir.dialects import arith
35+
from jax._src.lib.mlir.dialects import cf
3536
from jax._src.lib.mlir.dialects import scf
3637
from jax._src.lib.mlir.dialects import vector
3738
from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member
@@ -237,7 +238,6 @@ def capture_stdout(self):
237238
mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices()
238239

239240

240-
241241
class Sm90ATestCase(TestCase, jtu.CudaArchSpecificTest):
242242

243243
def setUp(self):
@@ -3320,6 +3320,34 @@ def test_parse_indices_oob(self, indices):
33203320
with self.assertRaisesRegex(IndexError, "out of bounds"):
33213321
utils.parse_indices(indices, (2, 3, 4))
33223322

3323+
@jtu.thread_unsafe_test() # Modifies ``os.environ``.
3324+
def test_assert(self):
3325+
if cf is None:
3326+
self.skipTest("``cf`` is not available")
3327+
3328+
def kernel(ctx: mgpu.LaunchContext, x_ref, out, scratch) -> None:
3329+
del ctx, out # Unused.
3330+
# TODO(b/408271232): Use a False condition once the bug is fixed.
3331+
x = mgpu.FragmentedArray.load_strided(x_ref)
3332+
cond = x.reduce_sum(*scratch) != 42.0
3333+
cf.assert_(cond.registers.item(), "OOOPS")
3334+
3335+
f = mgpu.as_gpu_kernel(
3336+
kernel,
3337+
grid=(1, 1, 1),
3338+
block=(128, 1, 1),
3339+
in_shape=(jax.ShapeDtypeStruct((128,), jnp.float32),),
3340+
out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
3341+
smem_scratch_shape=(jax.ShapeDtypeStruct((4,), jnp.float32),),
3342+
)
3343+
3344+
with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass:
3345+
f(jnp.ones((128,), jnp.float32))
3346+
3347+
# SASS doesn't seem to include the assertion message, so we are just
3348+
# checking that __assertfail appears in the symbol table for the kernel.
3349+
self.assertIn("__assertfail", sass())
3350+
33233351

33243352
class SerializationTest(absltest.TestCase):
33253353

0 commit comments

Comments
 (0)