Commit 4af37e9

Generalize test cases to support XPU (#983)
Parent commit: ad16d6f

File tree: 5 files changed, +31 / -19 lines

test/test_constexpr.py

Lines changed: 0 additions & 2 deletions
@@ -10,7 +10,6 @@
 from helion._testing import TestCase
 from helion._testing import code_and_output
 from helion._testing import skipIfRefEager
-from helion._testing import skipIfXPU
 import helion.language as hl


@@ -95,7 +94,6 @@ def fn(x: torch.Tensor, mode: str) -> torch.Tensor:
         self.assertExpectedJournal(code)

     @skipIfRefEager("Triton codegen does not work in ref eager mode")
-    @skipIfXPU("Failed on XPU due to a different configuration for min dot size")
     def test_block_size_constexpr_assignment_in_host_code(self) -> None:
         @helion.kernel(
             config=helion.Config(
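
For reference, skipIfXPU (like skipIfRefEager) comes from helion._testing, whose implementation is not part of this diff. A minimal sketch of such a backend-skip decorator, assuming it simply wraps unittest.skipIf:

# Illustrative sketch only: a backend-skip decorator in the spirit of
# helion._testing.skipIfXPU, assumed to wrap unittest.skipIf (the real
# helper's implementation is not shown in this diff).
import unittest

import torch


def skipIfXPU(reason: str):
    # hasattr() hedges PyTorch builds that predate the torch.xpu module.
    has_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
    return unittest.skipIf(has_xpu, reason)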

test/test_dot.py

Lines changed: 24 additions & 11 deletions
@@ -199,17 +199,19 @@ def test_hl_dot_codegen_acc_differs_uses_addition(self):
         self.assertIn("out_dtype=tl.float32", code)

         # Test case 2: separate addition (acc_dtype = float16, common dtype = float32)
-        input_dtype_2 = torch.float32
-        acc_dtype_2 = torch.float16
-        x2 = torch.randn(64, 64, device=DEVICE, dtype=input_dtype_2)
-        y2 = torch.randn(64, 64, device=DEVICE, dtype=input_dtype_2)
-        code2, out2 = code_and_output(dot_kernel_acc_arg, (x2, y2, acc_dtype_2))
-        # Validate we use separate addition pattern with cast
-        self.assertIn("tl.dot(", code2)
-        # Check for the addition pattern: acc + result
-        self.assertIn(" + ", code2)
-        # Check that we cast the result to acc_dtype
-        self.assertIn("tl.cast", code2)
+        # TODO(Eikan): Support this case on XPU
+        if not torch.xpu.is_available():
+            input_dtype_2 = torch.float32
+            acc_dtype_2 = torch.float16
+            x2 = torch.randn(64, 64, device=DEVICE, dtype=input_dtype_2)
+            y2 = torch.randn(64, 64, device=DEVICE, dtype=input_dtype_2)
+            code2, out2 = code_and_output(dot_kernel_acc_arg, (x2, y2, acc_dtype_2))
+            # Validate we use separate addition pattern with cast
+            self.assertIn("tl.dot(", code2)
+            # Check for the addition pattern: acc + result
+            self.assertIn(" + ", code2)
+            # Check that we cast the result to acc_dtype
+            self.assertIn("tl.cast", code2)

         # Test case 3: separate addition (acc_dtype = int32, common dtype = int8)
         input_dtype_3 = torch.int8
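
The hunk above guards one sub-case at runtime instead of skipping the whole test. A self-contained sketch of that guard pattern (the tensor math below is only illustrative, not taken from the diff):

# Illustrative sketch of the runtime-guard pattern used above: exercise the
# XPU-unsupported sub-case only when XPU is not the active backend.
import torch


def maybe_run_float16_acc_subcase() -> None:
    # hasattr() hedges PyTorch builds that predate the torch.xpu module.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return  # TODO: enable once this path is supported on XPU
    x = torch.randn(64, 64)
    y = torch.randn(64, 64)
    out = (x @ y).to(torch.float16)
    assert out.dtype == torch.float16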
@@ -951,6 +953,17 @@ def test_matmul_reshape_n_2(self):
         REF_EAGER_TEST_FAILURES_FP8_E4M3FN_LOW_COMPUTE_CAP[test_name]
     )(_test_func)

+    # Apply skipIfXPU decorator if needed
+    if acc_dtype is torch.float16 and input_dtype in (
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+        torch.bfloat16,
+        torch.float32,
+    ):
+        _test_func = skipIfXPU("skip: float16 accumulator for non-fp16 input data types")(
+            _test_func
+        )
+
     # Additional ref eager skips for unsupported accumulator/input combos
     if acc_dtype is torch.float16 and input_dtype in (
         torch.bfloat16,
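
The second hunk applies skipIfXPU programmatically rather than with @-syntax, which works because a skip decorator is just a callable. A small hypothetical illustration of the same idea using plain unittest.skipIf (names here are stand-ins, not from the diff):

# Hypothetical illustration: decorating a dynamically generated test function,
# mirroring how the hunk above wraps _test_func with skipIfXPU.
import unittest

import torch


def _example_test(self) -> None:  # stand-in for a generated test body
    self.assertTrue(True)


_needs_xpu_skip = hasattr(torch, "xpu") and torch.xpu.is_available()
_example_test = unittest.skipIf(
    _needs_xpu_skip, "float16 accumulator unsupported for this input dtype on XPU"
)(_example_test)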

test/test_examples.py

Lines changed: 1 addition & 0 deletions
@@ -433,6 +433,7 @@ def test_low_mem_dropout(self):
         )

     @skipIfRocm("precision differences with bf16xint16 operations on rocm")
+    @skipIfXPU("precision differences with bf16xint16 operations on xpu")
     def test_bf16xint16(self):
         from examples.bf16xint16_gemm import reference_bf16xint16_pytorch

test/test_indexing.py

Lines changed: 5 additions & 5 deletions
@@ -433,23 +433,23 @@ def run_case(
             kernel = make_kernel(index_dtype=index_dtype)
             x = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
             y = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
-            torch.cuda.synchronize()
+            torch.accelerator.synchronize()
             if expect_error:
                 with self.assertRaisesRegex(
                     helion.exc.IndexOffsetOutOfRangeForInt32,
                     f"index_dtype is {index_dtype}",
                 ):
                     code_and_output(kernel, (x, y))
-                torch.cuda.synchronize()
+                torch.accelerator.synchronize()
                 return

             code, out = code_and_output(kernel, (x, y))
-            torch.cuda.synchronize()
+            torch.accelerator.synchronize()
             checker = self.assertIn if expect_int64_in_code else self.assertNotIn
             checker("tl.int64", code)
-            torch.cuda.synchronize()
+            torch.accelerator.synchronize()
             ref_out = torch.add(x, y)
-            torch.cuda.synchronize()
+            torch.accelerator.synchronize()
             torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=1e-2)

         small_shape = (128, 128)
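
torch.accelerator is PyTorch's backend-agnostic device API (available in recent releases), so the same synchronize call covers CUDA, XPU, and other accelerator backends. A short sketch, with a fallback assumption for older PyTorch builds that lack the module:

# Sketch of backend-agnostic synchronization, assuming a PyTorch version that
# ships torch.accelerator; older builds fall back to the CUDA-specific call.
import torch


def sync_device() -> None:
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        torch.accelerator.synchronize()  # dispatches to the current accelerator
    elif torch.cuda.is_available():
        torch.cuda.synchronize()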

test/test_matmul.py

Lines changed: 1 addition & 1 deletion
@@ -243,7 +243,7 @@ def matmul_bf16_packed_int4(
         C = torch.zeros((M, N), dtype=torch.float32, device=DEVICE)

         matmul_bf16_packed_int4(A, B_packed, C)
-        torch.cuda.synchronize()
+        torch.accelerator.synchronize()

         self.assertTrue(torch.isfinite(C).all())
         self.assertFalse(torch.allclose(C, torch.zeros_like(C)))
