Skip to content

Commit dcdc25b

Browse files
superbobry authored and Google-ML-Automation committed
[mosaic_gpu] Use jtu helpers instead of get_sass
PiperOrigin-RevId: 753639002
1 parent 9107c63 commit dcdc25b

File tree

1 file changed

+4
-17
lines changed

1 file changed

+4
-17
lines changed

tests/mosaic/gpu_test.py

+4 -17
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020
import itertools
2121
import math
2222
import operator
23-
import os
2423
import re
2524
import unittest
2625

@@ -88,20 +87,6 @@ def mlir_sum(elems):
8887
return total
8988

9089

91-
@contextlib.contextmanager
92-
def get_sass():
93-
prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None)
94-
os.environ["MOSAIC_GPU_DUMP_SASS"] = "1"
95-
try:
96-
with jtu.capture_stdout() as output:
97-
yield output
98-
finally:
99-
if prev_dump is not None:
100-
os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump
101-
else:
102-
del os.environ["MOSAIC_GPU_DUMP_SASS"]
103-
104-
10590
def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None):
10691
index = ir.IndexType.get()
10792
thread_id = gpu.thread_id(gpu.Dimension.x)
@@ -2430,6 +2415,7 @@ def kernel(ctx, dst, _):
24302415
num_col_tiles=[1, 2, 3],
24312416
row_tiling=[8, 64],
24322417
)
2418+
@jtu.thread_unsafe_test() # Modifies ``os.environ``.
24332419
def test_copy_tiled(self, dtype, swizzle, num_col_tiles, row_tiling):
24342420
mlir_dtype = utils.dtype_to_ir_type(dtype)
24352421
bw = bytewidth(mlir_dtype)
@@ -2455,7 +2441,7 @@ def kernel(ctx, in_, out, smems):
24552441
.transpose(0, 2, 1, 3)
24562442
)
24572443

2458-
with get_sass() as sass:
2444+
with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass:
24592445
iota = mgpu.as_gpu_kernel(
24602446
kernel, (1, 1, 1), (128, 1, 1), expected, expected,
24612447
[expected, expected, mgpu.TMABarrier()],
@@ -2554,6 +2540,7 @@ def kernel(ctx, in_, out, smems):
25542540
(fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 0.5),
25552541
(fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2),
25562542
)
2543+
@jtu.thread_unsafe_test() # Modifies ``os.environ``.
25572544
def test_upcast_to_wgmma(
25582545
self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg
25592546
):
@@ -2597,7 +2584,7 @@ def tile(x, tiling):
25972584
f = mgpu.as_gpu_kernel(
25982585
kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()],
25992586
)
2600-
with get_sass() as sass:
2587+
with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass:
26012588
yt_kernel = f(xt)
26022589
np.testing.assert_array_equal(yt_kernel, yt)
26032590
self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg)

0 commit comments

Comments
 (0)