99
1010import helion
1111from helion import exc
12+ from helion ._testing import DEVICE
1213from helion ._testing import TestCase
1314from helion ._testing import assert_ref_eager_mode
1415import helion .language as hl
@@ -32,8 +33,8 @@ def print_intermediate_tensor_kernel(
3233 out [tile_m , tile_n ] = sum_val
3334 return out
3435
35- x = torch .ones ([2 , 2 ], device = "cuda" , dtype = torch .float32 ) * 10.0
36- y = torch .ones ([2 , 2 ], device = "cuda" , dtype = torch .float32 ) * 5.0
36+ x = torch .ones ([2 , 2 ], device = DEVICE , dtype = torch .float32 ) * 10.0
37+ y = torch .ones ([2 , 2 ], device = DEVICE , dtype = torch .float32 ) * 5.0
3738 expected = x + y
3839
3940 # Capture stdout to check print output
@@ -67,7 +68,7 @@ def incorrect_kernel(x: torch.Tensor) -> torch.Tensor:
6768 pass # noqa: PIE790
6869 return x
6970
70- x = torch .ones ([2 , 2 ], device = "cuda" , dtype = torch .float32 ) * math .pi
71+ x = torch .ones ([2 , 2 ], device = DEVICE , dtype = torch .float32 ) * math .pi
7172
7273 # Capture stdout to check print output
7374 captured_output = io .StringIO ()
@@ -89,7 +90,7 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
8990 return out
9091
9192 with assert_ref_eager_mode ():
92- x = torch .randn (128 , 128 , device = "cuda" )
93+ x = torch .randn (128 , 128 , device = DEVICE )
9394 result = kernel (x )
9495 expected = x * 2.0
9596 torch .testing .assert_close (result , expected )
@@ -107,7 +108,7 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
107108 # Run the kernel to capture the warning message
108109 captured_stderr = io .StringIO ()
109110 with contextlib .redirect_stderr (captured_stderr ):
110- x = torch .randn (128 , 128 , device = "cuda" )
111+ x = torch .randn (128 , 128 , device = DEVICE )
111112 kernel (x )
112113
113114 stderr_output = captured_stderr .getvalue ()
0 commit comments