Skip to content

Commit 7c599b3

Browse files
committed
check: add test cases to the scripts/python_test.py one-click test script
1 parent ba1d8c0 commit 7c599b3

File tree

4 files changed

+35
-4
lines changed

4 files changed

+35
-4
lines changed

scripts/python_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ def run_tests(args):
2424
"rope.py",
2525
"sub.py",
2626
"swiglu.py",
27+
"silu.py",
28+
"div.py",
29+
"logical_and.py",
30+
"logical_or.py",
31+
"equal.py",
32+
"all_equal.py",
33+
"relu_backward.py",
34+
"gelu.py",
35+
"gelu_backward.py",
36+
"cross_entropy_loss_backward.py"
2737
]:
2838
result = subprocess.run(
2939
f"python {test} {args} --debug", text=True, encoding="utf-8", shell=True

test/infiniop/gelu.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
infiniopOperatorDescriptor_t,
2020
profile_operation,
2121
test_operator,
22+
to_torch_dtype,
23+
torch_device_map,
2224
)
2325

2426
# ==============================================================================
@@ -101,8 +103,14 @@ def test(
101103
f"dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
102104
)
103105

104-
new_tensor = torch.nn.functional.gelu(input.torch_tensor(), approximate="tanh")
105-
output.update_torch_tensor(new_tensor)
106+
# ans的shape对齐至input,而input可能存在广播维度
107+
ans = torch.nn.functional.gelu(input.torch_tensor(), approximate="tanh")
108+
# 利用add(+)计算的自动广播机制,确保output的torch_tensor与actual_tensor shape一致,以通过debug模式的shape检查
109+
zero = torch.zeros(
110+
*shape, dtype=to_torch_dtype(dtype), device=torch_device_map[device]
111+
)
112+
new_output = ans + zero
113+
output.update_torch_tensor(new_output)
106114

107115
if sync is not None:
108116
sync()

test/infiniop/libinfiniop/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def __init__(
110110
)
111111
else:
112112
raise ValueError("Unsupported mode")
113-
113+
114114
if is_bool:
115115
self._torch_tensor = self._torch_tensor > 0.5
116116

@@ -367,6 +367,11 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
367367
actual = actual.to(torch.float32)
368368
desired = desired.to(torch.float32)
369369

370+
# 如果是BOOL,全部转成FP32再比对
371+
if actual.dtype == torch.bool or desired.dtype == torch.bool:
372+
actual = actual.to(torch.float32)
373+
desired = desired.to(torch.float32)
374+
370375
print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose)
371376
np.testing.assert_allclose(
372377
actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True

test/infiniop/silu.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
infiniopOperatorDescriptor_t,
2020
profile_operation,
2121
test_operator,
22+
to_torch_dtype,
23+
torch_device_map,
2224
)
2325

2426
# ==============================================================================
@@ -101,7 +103,13 @@ def test(
101103
f"dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
102104
)
103105

104-
new_output = torch.nn.functional.silu(input.torch_tensor())
106+
# ans的shape对齐至input,而input可能存在广播维度
107+
ans = torch.nn.functional.silu(input.torch_tensor())
108+
# 利用add(+)计算的自动广播机制,确保output的torch_tensor与actual_tensor shape一致,以通过debug模式的shape检查
109+
zero = torch.zeros(
110+
*shape, dtype=to_torch_dtype(dtype), device=torch_device_map[device]
111+
)
112+
new_output = ans + zero
105113
output.update_torch_tensor(new_output)
106114

107115
if sync is not None:

0 commit comments

Comments
 (0)