|
| 1 | +import sys |
| 2 | +import os |
| 3 | + |
| 4 | +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) |
| 5 | + |
| 6 | +import torch |
| 7 | +import infinicore |
| 8 | +from framework.base import BaseOperatorTest, TensorSpec, TestCase |
| 9 | +from framework.runner import GenericTestRunner |
| 10 | +from framework.utils import is_broadcast |
| 11 | + |
| 12 | +# ============================================================================== |
| 13 | +# Operator-specific configuration |
| 14 | +# ============================================================================== |
| 15 | + |
| 16 | +# Test cases format: (shape, a_strides, b_strides, c_strides) |
| 17 | +_TEST_CASES_DATA = [ |
| 18 | + ((13, 4), None, None, None), |
| 19 | + ((13, 4), (10, 1), (10, 1), (10, 1)), |
| 20 | + ((13, 4), (0, 1), None, None), |
| 21 | + ((13, 4, 4), None, None, None), |
| 22 | + ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)), |
| 23 | + ((13, 4, 4), (4, 0, 1), (0, 4, 1), None), |
| 24 | + ((16, 5632), None, None, None), |
| 25 | + ((16, 5632), (13312, 1), (13312, 1), (13312, 1)), |
| 26 | +] |
| 27 | + |
| 28 | +# Data types |
| 29 | +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] |
| 30 | + |
| 31 | +# Tolerance |
| 32 | +_TOLERANCE_MAP = { |
| 33 | + infinicore.float16: {"atol": 0, "rtol": 1e-2}, |
| 34 | + infinicore.float32: {"atol": 0, "rtol": 1e-3}, |
| 35 | + infinicore.bfloat16: {"atol": 0, "rtol": 5e-2}, |
| 36 | +} |
| 37 | + |
| 38 | + |
| 39 | +def build_test_cases(): |
| 40 | + test_cases = [] |
| 41 | + |
| 42 | + for data in _TEST_CASES_DATA: |
| 43 | + shape = data[0] |
| 44 | + a_strides = data[1] if len(data) > 1 else None |
| 45 | + b_strides = data[2] if len(data) > 2 else None |
| 46 | + c_strides = data[3] if len(data) > 3 else None |
| 47 | + |
| 48 | + a_supports_inplace = not is_broadcast(a_strides) |
| 49 | + b_supports_inplace = not is_broadcast(b_strides) |
| 50 | + c_supports_inplace = not is_broadcast(c_strides) |
| 51 | + |
| 52 | + for dtype in _TENSOR_DTYPES: |
| 53 | + tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3}) |
| 54 | + |
| 55 | + a_spec = TensorSpec.from_tensor(shape, a_strides, dtype) |
| 56 | + b_spec = TensorSpec.from_tensor(shape, b_strides, dtype) |
| 57 | + c_spec = TensorSpec.from_tensor(shape, c_strides, dtype) |
| 58 | + |
| 59 | + # Out-of-place (return value) |
| 60 | + test_cases.append( |
| 61 | + TestCase( |
| 62 | + inputs=[a_spec, b_spec], |
| 63 | + kwargs={}, |
| 64 | + output_spec=None, |
| 65 | + comparison_target=None, |
| 66 | + tolerance=tolerance, |
| 67 | + description=f"Mul - OUT_OF_PLACE (dtype={dtype})", |
| 68 | + ) |
| 69 | + ) |
| 70 | + |
| 71 | + # With explicit output tensor (mul(a, b, out=c)) |
| 72 | + if c_supports_inplace: |
| 73 | + test_cases.append( |
| 74 | + TestCase( |
| 75 | + inputs=[a_spec, b_spec], |
| 76 | + kwargs={}, |
| 77 | + output_spec=c_spec, |
| 78 | + comparison_target="out", |
| 79 | + tolerance=tolerance, |
| 80 | + description=f"Mul - INPLACE(out) (dtype={dtype})", |
| 81 | + ) |
| 82 | + ) |
| 83 | + |
| 84 | + # In-place on first input (mul(a, b, out=a)) |
| 85 | + if a_supports_inplace: |
| 86 | + test_cases.append( |
| 87 | + TestCase( |
| 88 | + inputs=[a_spec, b_spec], |
| 89 | + kwargs={"out": 0}, |
| 90 | + output_spec=None, |
| 91 | + comparison_target=0, |
| 92 | + tolerance=tolerance, |
| 93 | + description=f"Mul - INPLACE(a) (dtype={dtype})", |
| 94 | + ) |
| 95 | + ) |
| 96 | + |
| 97 | + # In-place on second input (mul(a, b, out=b)) |
| 98 | + if b_supports_inplace: |
| 99 | + test_cases.append( |
| 100 | + TestCase( |
| 101 | + inputs=[a_spec, b_spec], |
| 102 | + kwargs={"out": 1}, |
| 103 | + output_spec=None, |
| 104 | + comparison_target=1, |
| 105 | + tolerance=tolerance, |
| 106 | + description=f"Mul - INPLACE(b) (dtype={dtype})", |
| 107 | + ) |
| 108 | + ) |
| 109 | + |
| 110 | + return test_cases |
| 111 | + |
| 112 | + |
| 113 | +_TEST_CASES = build_test_cases() |
| 114 | + |
| 115 | + |
| 116 | +class OpTest(BaseOperatorTest): |
| 117 | + """Mul test with simplified test case parsing""" |
| 118 | + |
| 119 | + def __init__(self): |
| 120 | + super().__init__("Mul") |
| 121 | + |
| 122 | + def get_test_cases(self): |
| 123 | + return _TEST_CASES |
| 124 | + |
| 125 | + def torch_operator(self, a, b, out=None, **kwargs): |
| 126 | + return torch.mul(a, b, out=out) |
| 127 | + |
| 128 | + def infinicore_operator(self, a, b, out=None, **kwargs): |
| 129 | + try: |
| 130 | + return infinicore.mul(a, b, out=out) |
| 131 | + except AttributeError as exc: |
| 132 | + raise NotImplementedError("InfiniCore mul operator not available") from exc |
| 133 | + |
| 134 | + |
| 135 | +def main(): |
| 136 | + """Main entry point""" |
| 137 | + runner = GenericTestRunner(OpTest) |
| 138 | + runner.run_and_exit() |
| 139 | + |
| 140 | + |
| 141 | +if __name__ == "__main__": |
| 142 | + main() |
0 commit comments