Skip to content

Commit cf86ec6

Browse files
superbobry and Google-ML-Automation
authored and committed
[mosaic_gpu] Added support for using cf.assert in Mosaic GPU kernels
The XLA GPU runtime does not yet handle device assertions well and will hang if an assertion is triggered. However, the assertion output still appears in stderr, so I think having `cf.assert` support is still useful.

PiperOrigin-RevId: 740717697
1 parent 7d395c2 commit cf86ec6

File tree

11 files changed

+76
-19
lines changed

11 files changed

+76
-19
lines changed

jax/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,7 @@ py_library_providing_imports_info(
866866
"//jax/_src/lib",
867867
"//jaxlib/mlir:arithmetic_dialect",
868868
"//jaxlib/mlir:builtin_dialect",
869+
"//jaxlib/mlir:control_flow_dialect",
869870
"//jaxlib/mlir:func_dialect",
870871
"//jaxlib/mlir:gpu_dialect",
871872
"//jaxlib/mlir:ir",

jax/_src/lib/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ py_library_providing_imports_info(
5252
"//jaxlib/mlir:arithmetic_dialect",
5353
"//jaxlib/mlir:builtin_dialect",
5454
"//jaxlib/mlir:chlo_dialect",
55+
"//jaxlib/mlir:control_flow_dialect",
5556
"//jaxlib/mlir:func_dialect",
5657
"//jaxlib/mlir:ir",
5758
"//jaxlib/mlir:math_dialect",

jax/_src/lib/mlir/dialects/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,9 @@
5555

5656
# Alias that is set up to abstract away the transition from MHLO to StableHLO.
5757
from jaxlib.mlir.dialects import stablehlo as hlo
58+
59+
from jax._src import lib
60+
if lib.version >= (0, 6, 1):
61+
from jaxlib.mlir.dialects import cf
62+
else:
63+
cf = None # type: ignore[no-redef]

jaxlib/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ py_library_providing_imports_info(
8282
"//jaxlib/mlir:arithmetic_dialect",
8383
"//jaxlib/mlir:builtin_dialect",
8484
"//jaxlib/mlir:chlo_dialect",
85+
"//jaxlib/mlir:control_flow_dialect",
8586
"//jaxlib/mlir:func_dialect",
8687
"//jaxlib/mlir:gpu_dialect",
8788
"//jaxlib/mlir:ir",

jaxlib/mlir/BUILD.bazel

+24-11
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ symlink_inputs(
6565
name = "func_dialect",
6666
rule = py_library,
6767
symlinked_inputs = {"srcs": {
68-
"dialects": ["@llvm-project//mlir/python:FuncPyFiles"],
68+
"dialects": ["@llvm-project//mlir/python:FuncPyFiles"],
6969
}},
7070
deps = [
7171
":core",
@@ -78,7 +78,7 @@ symlink_inputs(
7878
name = "vector_dialect",
7979
rule = py_library,
8080
symlinked_inputs = {"srcs": {
81-
"dialects": ["@llvm-project//mlir/python:VectorOpsPyFiles"],
81+
"dialects": ["@llvm-project//mlir/python:VectorOpsPyFiles"],
8282
}},
8383
deps = [
8484
":core",
@@ -91,7 +91,7 @@ symlink_inputs(
9191
name = "math_dialect",
9292
rule = py_library,
9393
symlinked_inputs = {"srcs": {
94-
"dialects": ["@llvm-project//mlir/python:MathOpsPyFiles"],
94+
"dialects": ["@llvm-project//mlir/python:MathOpsPyFiles"],
9595
}},
9696
deps = [
9797
":core",
@@ -104,7 +104,7 @@ symlink_inputs(
104104
name = "arithmetic_dialect",
105105
rule = py_library,
106106
symlinked_inputs = {"srcs": {
107-
"dialects": ["@llvm-project//mlir/python:ArithOpsPyFiles"],
107+
"dialects": ["@llvm-project//mlir/python:ArithOpsPyFiles"],
108108
}},
109109
deps = [
110110
":core",
@@ -117,7 +117,20 @@ symlink_inputs(
117117
name = "memref_dialect",
118118
rule = py_library,
119119
symlinked_inputs = {"srcs": {
120-
"dialects": ["@llvm-project//mlir/python:MemRefOpsPyFiles"],
120+
"dialects": ["@llvm-project//mlir/python:MemRefOpsPyFiles"],
121+
}},
122+
deps = [
123+
":core",
124+
":ir",
125+
":mlir",
126+
],
127+
)
128+
129+
symlink_inputs(
130+
name = "control_flow_dialect",
131+
rule = py_library,
132+
symlinked_inputs = {"srcs": {
133+
"dialects": ["@llvm-project//mlir/python:ControlFlowOpsPyFiles"],
121134
}},
122135
deps = [
123136
":core",
@@ -130,7 +143,7 @@ symlink_inputs(
130143
name = "scf_dialect",
131144
rule = py_library,
132145
symlinked_inputs = {"srcs": {
133-
"dialects": ["@llvm-project//mlir/python:SCFPyFiles"],
146+
"dialects": ["@llvm-project//mlir/python:SCFPyFiles"],
134147
}},
135148
deps = [
136149
":core",
@@ -143,7 +156,7 @@ symlink_inputs(
143156
name = "builtin_dialect",
144157
rule = py_library,
145158
symlinked_inputs = {"srcs": {
146-
"dialects": ["@llvm-project//mlir/python:BuiltinOpsPyFiles"],
159+
"dialects": ["@llvm-project//mlir/python:BuiltinOpsPyFiles"],
147160
}},
148161
deps = [
149162
":core",
@@ -157,7 +170,7 @@ symlink_inputs(
157170
name = "chlo_dialect",
158171
rule = py_library,
159172
symlinked_inputs = {"srcs": {
160-
"dialects": ["@stablehlo//:chlo_ops_py_files"],
173+
"dialects": ["@stablehlo//:chlo_ops_py_files"],
161174
}},
162175
deps = [
163176
":core",
@@ -171,7 +184,7 @@ symlink_inputs(
171184
name = "sparse_tensor_dialect",
172185
rule = py_library,
173186
symlinked_inputs = {"srcs": {
174-
"dialects": ["@llvm-project//mlir/python:SparseTensorOpsPyFiles"],
187+
"dialects": ["@llvm-project//mlir/python:SparseTensorOpsPyFiles"],
175188
}},
176189
deps = [
177190
":core",
@@ -186,7 +199,7 @@ symlink_inputs(
186199
name = "mhlo_dialect",
187200
rule = py_library,
188201
symlinked_inputs = {"srcs": {
189-
"dialects": ["@xla//xla/mlir_hlo:MhloOpsPyFiles"],
202+
"dialects": ["@xla//xla/mlir_hlo:MhloOpsPyFiles"],
190203
}},
191204
deps = [
192205
":core",
@@ -228,7 +241,7 @@ symlink_inputs(
228241
name = "stablehlo_dialect",
229242
rule = py_library,
230243
symlinked_inputs = {"srcs": {
231-
"dialects": ["@stablehlo//:stablehlo_ops_py_files"],
244+
"dialects": ["@stablehlo//:stablehlo_ops_py_files"],
232245
}},
233246
deps = [
234247
":core",

jaxlib/mlir/_mlir_libs/BUILD.bazel

+2-1
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ nanobind_pywrap_extension(
208208
deps = [
209209
"//jaxlib/mosaic/gpu:mlir_capi",
210210
"@llvm-project//mlir:CAPIArith",
211+
"@llvm-project//mlir:CAPICF",
211212
"@llvm-project//mlir:CAPIGPU",
212213
"@llvm-project//mlir:CAPIIR",
213214
"@llvm-project//mlir:CAPILLVM",
@@ -297,4 +298,4 @@ nanobind_pywrap_extension(
297298
"@nanobind",
298299
"@stablehlo//:stablehlo_capi",
299300
],
300-
)
301+
)

jaxlib/mlir/_mlir_libs/register_jax_dialects.cc

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#include <nanobind/nanobind.h>
1919

2020
#include "mlir-c/Dialect/Arith.h" // IWYU pragma: keep
21+
#include "mlir-c/Dialect/ControlFlow.h"
2122
#include "mlir-c/Dialect/Func.h" // IWYU pragma: keep
2223
#include "mlir-c/Dialect/GPU.h" // IWYU pragma: keep
2324
#include "mlir-c/Dialect/LLVM.h" // IWYU pragma: keep
@@ -50,6 +51,7 @@ NB_MODULE(register_jax_dialects, m) {
5051
REGISTER_DIALECT(scf);
5152
REGISTER_DIALECT(vector);
5253
// For Mosaic GPU
54+
REGISTER_DIALECT(cf);
5355
REGISTER_DIALECT(gpu);
5456
REGISTER_DIALECT(nvgpu);
5557
REGISTER_DIALECT(nvvm);

jaxlib/mosaic/gpu/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ cc_library(
155155
"@llvm-project//mlir:ArithTransforms",
156156
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
157157
"@llvm-project//mlir:ComplexToLLVM",
158+
"@llvm-project//mlir:ControlFlowDialect",
158159
"@llvm-project//mlir:ControlFlowToLLVM",
159160
"@llvm-project//mlir:ConversionPasses",
160161
"@llvm-project//mlir:ExecutionEngine",

jaxlib/mosaic/gpu/custom_call.cc

+7-6
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ limitations under the License.
6060
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
6161
#include "mlir/Dialect/Arith/IR/Arith.h"
6262
#include "mlir/Dialect/Arith/Transforms/Passes.h"
63+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
6364
#include "mlir/Dialect/Func/IR/FuncOps.h"
6465
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
6566
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -228,12 +229,12 @@ mlir::LogicalResult RunPasses(mlir::OpPassManager&& passes,
228229

229230
void InitContext(mlir::MLIRContext* context) {
230231
mlir::DialectRegistry registry;
231-
registry.insert<mlir::arith::ArithDialect, mlir::func::FuncDialect,
232-
mlir::math::MathDialect, mlir::memref::MemRefDialect,
233-
mlir::scf::SCFDialect, mlir::vector::VectorDialect,
234-
mlir::gpu::GPUDialect, mlir::nvgpu::NVGPUDialect,
235-
mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect,
236-
mosaic_gpu::MosaicGPUDialect>();
232+
registry.insert<mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect,
233+
mlir::func::FuncDialect, mlir::math::MathDialect,
234+
mlir::memref::MemRefDialect, mlir::scf::SCFDialect,
235+
mlir::vector::VectorDialect, mlir::gpu::GPUDialect,
236+
mlir::nvgpu::NVGPUDialect, mlir::NVVM::NVVMDialect,
237+
mlir::LLVM::LLVMDialect, mosaic_gpu::MosaicGPUDialect>();
237238
mlir::registerConvertNVVMToLLVMInterface(registry);
238239
mlir::registerConvertComplexToLLVMInterface(registry);
239240
mlir::registerConvertMemRefToLLVMInterface(registry);

jaxlib/tools/build_wheel.py

+2
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources):
272272
f"{source_file_prefix}jaxlib/mlir/dialects/_arith_enum_gen.py",
273273
f"{source_file_prefix}jaxlib/mlir/dialects/_arith_ops_gen.py",
274274
f"{source_file_prefix}jaxlib/mlir/dialects/_builtin_ops_gen.py",
275+
f"{source_file_prefix}jaxlib/mlir/dialects/_cf_ops_gen.py",
275276
f"{source_file_prefix}jaxlib/mlir/dialects/_chlo_ops_gen.py",
276277
f"{source_file_prefix}jaxlib/mlir/dialects/_func_ops_gen.py",
277278
f"{source_file_prefix}jaxlib/mlir/dialects/_math_ops_gen.py",
@@ -296,6 +297,7 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources):
296297
f"{source_file_prefix}jaxlib/mlir/dialects/_llvm_ops_gen.py",
297298
f"{source_file_prefix}jaxlib/mlir/dialects/arith.py",
298299
f"{source_file_prefix}jaxlib/mlir/dialects/builtin.py",
300+
f"{source_file_prefix}jaxlib/mlir/dialects/cf.py",
299301
f"{source_file_prefix}jaxlib/mlir/dialects/chlo.py",
300302
f"{source_file_prefix}jaxlib/mlir/dialects/func.py",
301303
f"{source_file_prefix}jaxlib/mlir/dialects/math.py",

tests/mosaic/gpu_test.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from jax._src.lib.mlir import ir
3333
from jax._src.lib.mlir import passmanager
3434
from jax._src.lib.mlir.dialects import arith
35+
from jax._src.lib.mlir.dialects import cf
3536
from jax._src.lib.mlir.dialects import scf
3637
from jax._src.lib.mlir.dialects import vector
3738
from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member
@@ -237,7 +238,6 @@ def capture_stdout(self):
237238
mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices()
238239

239240

240-
241241
class Sm90ATestCase(TestCase, jtu.CudaArchSpecificTest):
242242

243243
def setUp(self):
@@ -3320,6 +3320,34 @@ def test_parse_indices_oob(self, indices):
33203320
with self.assertRaisesRegex(IndexError, "out of bounds"):
33213321
utils.parse_indices(indices, (2, 3, 4))
33223322

3323+
@jtu.thread_unsafe_test() # Modifies ``os.environ``.
3324+
def test_assert(self):
3325+
if cf is None:
3326+
self.skipTest("``cf`` is not available")
3327+
3328+
def kernel(ctx: mgpu.LaunchContext, x_ref, out, scratch) -> None:
3329+
del ctx, out # Unused.
3330+
# TODO(b/408271232): Use a False condition once the bug is fixed.
3331+
x = mgpu.FragmentedArray.load_strided(x_ref)
3332+
cond = x.reduce_sum(*scratch) != 42.0
3333+
cf.assert_(cond.registers.item(), "OOOPS")
3334+
3335+
f = mgpu.as_gpu_kernel(
3336+
kernel,
3337+
grid=(1, 1, 1),
3338+
block=(128, 1, 1),
3339+
in_shape=(jax.ShapeDtypeStruct((128,), jnp.float32),),
3340+
out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
3341+
smem_scratch_shape=(jax.ShapeDtypeStruct((4,), jnp.float32),),
3342+
)
3343+
3344+
with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass:
3345+
f(jnp.ones((128,), jnp.float32))
3346+
3347+
# SASS doesn't seem to include the assertion message, so we are just
3348+
# checking that __assertfail appears in the symbol table for the kernel.
3349+
self.assertIn("__assertfail", sass())
3350+
33233351

33243352
class SerializationTest(absltest.TestCase):
33253353

0 commit comments

Comments
 (0)