Skip to content

Commit 9fa7bbf

Browse files
authored
Generalize the CUDA-biased test cases by replacing the hardcoded "cuda" literal with the DEVICE variable (#775)
1 parent 3322ca9 commit 9fa7bbf

File tree

7 files changed

+74
-69
lines changed

7 files changed

+74
-69
lines changed

test/test_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1132,7 +1132,7 @@ def test_kl_div(self):
11321132
),
11331133
)
11341134
torch_kl_div = torch.nn.KLDivLoss(reduction="batchmean", log_target=False).to(
1135-
"cuda"
1135+
device=DEVICE
11361136
)
11371137
self.assertExpectedJournal(
11381138
check_example(

test/test_misc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def test_tile_begin(x: torch.Tensor) -> torch.Tensor:
259259
out[tile_m.begin, tile_n.begin] = 1
260260
return out
261261

262-
x = torch.randn(64, 64, device="cuda")
262+
x = torch.randn(64, 64, device=DEVICE)
263263
config = helion.Config(block_sizes=[16, 16])
264264
test_tile_begin.bind((x,)).to_triton_code(config)
265265
result = test_tile_begin.bind((x,)).compile_config(config)(x)
@@ -272,7 +272,7 @@ def test_tile_end(x: torch.Tensor) -> torch.Tensor:
272272
out[tile_m.end, tile_n.end] = 1
273273
return out
274274

275-
x = torch.randn(64, 64, device="cuda")
275+
x = torch.randn(64, 64, device=DEVICE)
276276
config = helion.Config(block_sizes=[16, 16])
277277
test_tile_end.bind((x,)).to_triton_code(config)
278278
result = test_tile_end.bind((x,)).compile_config(config)(x)
@@ -285,7 +285,7 @@ def test_tile_id(x: torch.Tensor) -> torch.Tensor:
285285
out[tile_m.id, tile_n.id] = 1
286286
return out
287287

288-
x = torch.randn(64, 64, device="cuda")
288+
x = torch.randn(64, 64, device=DEVICE)
289289
config = helion.Config(block_sizes=[16, 16])
290290
test_tile_id.bind((x,)).to_triton_code(config)
291291
result = test_tile_id.bind((x,)).compile_config(config)(x)

test/test_print_ref_eager_mode.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import helion
1111
from helion import exc
12+
from helion._testing import DEVICE
1213
from helion._testing import TestCase
1314
import helion.language as hl
1415

@@ -35,8 +36,8 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
3536
out[tile] = x[tile] + y[tile]
3637
return out
3738

38-
x = torch.randn([512, 512], device="cuda", dtype=torch.float16)
39-
y = torch.randn([512, 512], device="cuda", dtype=torch.float16)
39+
x = torch.randn([512, 512], device=DEVICE, dtype=torch.float16)
40+
y = torch.randn([512, 512], device=DEVICE, dtype=torch.float16)
4041
torch.testing.assert_close(add(x, y), torch.add(x, y))
4142

4243
def test_normal_mode_code_print(self):
@@ -61,8 +62,8 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
6162
out[tile] = x[tile] + y[tile]
6263
return out
6364

64-
x = torch.randn([512, 512], device="cuda", dtype=torch.float16)
65-
y = torch.randn([512, 512], device="cuda", dtype=torch.float16)
65+
x = torch.randn([512, 512], device=DEVICE, dtype=torch.float16)
66+
y = torch.randn([512, 512], device=DEVICE, dtype=torch.float16)
6667
torch.testing.assert_close(add(x, y), torch.add(x, y))
6768

6869
self.assertNotEqual(

test/test_ref_eager.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import helion
1111
from helion import exc
12+
from helion._testing import DEVICE
1213
from helion._testing import TestCase
1314
from helion._testing import assert_ref_eager_mode
1415
import helion.language as hl
@@ -32,8 +33,8 @@ def print_intermediate_tensor_kernel(
3233
out[tile_m, tile_n] = sum_val
3334
return out
3435

35-
x = torch.ones([2, 2], device="cuda", dtype=torch.float32) * 10.0
36-
y = torch.ones([2, 2], device="cuda", dtype=torch.float32) * 5.0
36+
x = torch.ones([2, 2], device=DEVICE, dtype=torch.float32) * 10.0
37+
y = torch.ones([2, 2], device=DEVICE, dtype=torch.float32) * 5.0
3738
expected = x + y
3839

3940
# Capture stdout to check print output
@@ -67,7 +68,7 @@ def incorrect_kernel(x: torch.Tensor) -> torch.Tensor:
6768
pass # noqa: PIE790
6869
return x
6970

70-
x = torch.ones([2, 2], device="cuda", dtype=torch.float32) * math.pi
71+
x = torch.ones([2, 2], device=DEVICE, dtype=torch.float32) * math.pi
7172

7273
# Capture stdout to check print output
7374
captured_output = io.StringIO()
@@ -89,7 +90,7 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
8990
return out
9091

9192
with assert_ref_eager_mode():
92-
x = torch.randn(128, 128, device="cuda")
93+
x = torch.randn(128, 128, device=DEVICE)
9394
result = kernel(x)
9495
expected = x * 2.0
9596
torch.testing.assert_close(result, expected)
@@ -107,7 +108,7 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
107108
# Run the kernel to capture the warning message
108109
captured_stderr = io.StringIO()
109110
with contextlib.redirect_stderr(captured_stderr):
110-
x = torch.randn(128, 128, device="cuda")
111+
x = torch.randn(128, 128, device=DEVICE)
111112
kernel(x)
112113

113114
stderr_output = captured_stderr.getvalue()

test/test_tensor_descriptor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
259259
block_sizes=[16, 16, 16],
260260
indexing="tensor_descriptor",
261261
)
262-
torch.cuda.synchronize()
262+
torch.accelerator.synchronize()
263263
torch.testing.assert_close(result_large, expected, atol=1e-2, rtol=1e-2)
264264
self.assertIn(get_tensor_descriptor_fn_name(), code_large)
265265

0 commit comments

Comments (0)