20
20
import itertools
21
21
import math
22
22
import operator
23
- import os
24
23
import re
25
24
import unittest
26
25
@@ -88,20 +87,6 @@ def mlir_sum(elems):
88
87
return total
89
88
90
89
91
- @contextlib .contextmanager
92
- def get_sass ():
93
- prev_dump = os .environ .get ("MOSAIC_GPU_DUMP_SASS" , None )
94
- os .environ ["MOSAIC_GPU_DUMP_SASS" ] = "1"
95
- try :
96
- with jtu .capture_stdout () as output :
97
- yield output
98
- finally :
99
- if prev_dump is not None :
100
- os .environ ["MOSAIC_GPU_DUMP_SASS" ] = prev_dump
101
- else :
102
- del os .environ ["MOSAIC_GPU_DUMP_SASS" ]
103
-
104
-
105
90
def copy (src : ir .Value , dst : ir .Value , swizzle : int | None = None ):
106
91
index = ir .IndexType .get ()
107
92
thread_id = gpu .thread_id (gpu .Dimension .x )
@@ -2430,6 +2415,7 @@ def kernel(ctx, dst, _):
2430
2415
num_col_tiles = [1 , 2 , 3 ],
2431
2416
row_tiling = [8 , 64 ],
2432
2417
)
2418
+ @jtu .thread_unsafe_test () # Modifies ``os.environ``.
2433
2419
def test_copy_tiled (self , dtype , swizzle , num_col_tiles , row_tiling ):
2434
2420
mlir_dtype = utils .dtype_to_ir_type (dtype )
2435
2421
bw = bytewidth (mlir_dtype )
@@ -2455,7 +2441,7 @@ def kernel(ctx, in_, out, smems):
2455
2441
.transpose (0 , 2 , 1 , 3 )
2456
2442
)
2457
2443
2458
- with get_sass () as sass :
2444
+ with jtu . set_env ( MOSAIC_GPU_DUMP_SASS = "1" ), self . capture_stdout () as sass :
2459
2445
iota = mgpu .as_gpu_kernel (
2460
2446
kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), expected , expected ,
2461
2447
[expected , expected , mgpu .TMABarrier ()],
@@ -2554,6 +2540,7 @@ def kernel(ctx, in_, out, smems):
2554
2540
(fa .WGMMA_LAYOUT_UPCAST_2X , fa .WGMMA_LAYOUT , jnp .int4 , jnp .int4 , 0.5 ),
2555
2541
(fa .WGMMA_LAYOUT_UPCAST_4X , fa .WGMMA_LAYOUT , jnp .int4 , jnp .int4 , 2 ),
2556
2542
)
2543
+ @jtu .thread_unsafe_test () # Modifies ``os.environ``.
2557
2544
def test_upcast_to_wgmma (
2558
2545
self , start_layout , end_layout , in_dtype , cast_dtype , shfl_per_reg
2559
2546
):
@@ -2597,7 +2584,7 @@ def tile(x, tiling):
2597
2584
f = mgpu .as_gpu_kernel (
2598
2585
kernel , (1 , 1 , 1 ), (128 , 1 , 1 ), xt , yt , [xt , yt , mgpu .TMABarrier ()],
2599
2586
)
2600
- with get_sass () as sass :
2587
+ with jtu . set_env ( MOSAIC_GPU_DUMP_SASS = "1" ), self . capture_stdout () as sass :
2601
2588
yt_kernel = f (xt )
2602
2589
np .testing .assert_array_equal (yt_kernel , yt )
2603
2590
self .assertEqual (sass ().count ("SHFL.BFLY" ), regs_per_thread * shfl_per_reg )
0 commit comments