Skip to content

Commit 5c89a40

Browse files
zhuyuegongchensu
authored andcommitted
Add mul python interface and tests.
1 parent a5e20fc commit 5c89a40

File tree

8 files changed

+237
-0
lines changed

8 files changed

+237
-0
lines changed

include/infinicore/ops/mul.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
class Mul {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor);
10+
static void execute(Tensor c, Tensor a, Tensor b);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
14+
Tensor mul(Tensor a, Tensor b);
15+
void mul_(Tensor c, Tensor a, Tensor b);
16+
} // namespace infinicore::op

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from infinicore.ops.attention import attention
3030
from infinicore.ops.causal_softmax import causal_softmax
3131
from infinicore.ops.matmul import matmul
32+
from infinicore.ops.mul import mul
3233
from infinicore.ops.rearrange import rearrange
3334
from infinicore.ops.rms_norm import rms_norm
3435
from infinicore.ops.silu import silu
@@ -76,6 +77,7 @@
7677
"attention",
7778
"causal_softmax",
7879
"matmul",
80+
"mul",
7981
"rearrange",
8082
"rms_norm",
8183
"silu",

python/infinicore/ops/mul.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def mul(input, other, *, out=None):
6+
if out is None:
7+
return Tensor(_infinicore.mul(input._underlying, other._underlying))
8+
9+
_infinicore.mul_(out._underlying, input._underlying, other._underlying)

src/infinicore/ops/mul/mul.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include "infinicore/ops/mul.hpp"
2+
3+
namespace infinicore::op {
4+
5+
common::OpDispatcher<Mul::schema> &Mul::dispatcher() {
6+
static common::OpDispatcher<Mul::schema> dispatcher_;
7+
return dispatcher_;
8+
};
9+
10+
void Mul::execute(Tensor c, Tensor a, Tensor b) {
11+
dispatcher().lookup(context::getDevice().getType())(c, a, b);
12+
}
13+
14+
Tensor mul(Tensor a, Tensor b) {
15+
auto c = Tensor::empty(a->shape(), a->dtype(), a->device());
16+
mul_(c, a, b);
17+
return c;
18+
}
19+
20+
void mul_(Tensor c, Tensor a, Tensor b) {
21+
Mul::execute(c, a, b);
22+
}
23+
24+
} // namespace infinicore::op
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include "../../utils.hpp"
2+
#include "infinicore/common/hash.hpp"
3+
#include "infinicore/ops/common/cache.hpp"
4+
#include "infinicore/ops/mul.hpp"
5+
#include <infiniop.h>
6+
7+
namespace infinicore::op::mul_impl::infiniop {
8+
9+
thread_local common::OpCache<size_t, infiniopMulDescriptor_t> caches(
10+
100, // capacity
11+
[](infiniopMulDescriptor_t &desc) {
12+
if (desc != nullptr) {
13+
INFINICORE_CHECK_ERROR(infiniopDestroyMulDescriptor(desc));
14+
desc = nullptr;
15+
}
16+
});
17+
18+
void calculate(Tensor c, Tensor a, Tensor b) {
19+
size_t seed = hash_combine(c, b, a);
20+
21+
auto device_type = context::getDevice().getType();
22+
auto device_index = context::getDevice().getIndex();
23+
24+
auto &cache = caches.getCache(device_type, device_index);
25+
26+
auto desc_opt = cache.get(seed);
27+
infiniopMulDescriptor_t desc = nullptr;
28+
29+
if (!desc_opt) {
30+
INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
31+
context::getInfiniopHandle(), &desc,
32+
c->desc(), a->desc(), b->desc()));
33+
cache.put(seed, desc);
34+
} else {
35+
desc = *desc_opt;
36+
}
37+
38+
size_t workspace_size = 0;
39+
INFINICORE_CHECK_ERROR(infiniopGetMulWorkspaceSize(desc, &workspace_size));
40+
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
41+
42+
INFINICORE_CHECK_ERROR(infiniopMul(
43+
desc, workspace->data(), workspace_size,
44+
c->data(), a->data(), b->data(), context::getStream()));
45+
}
46+
47+
static bool registered = []() {
48+
Mul::dispatcher().registerAll(&calculate, false);
49+
return true;
50+
}();
51+
52+
} // namespace infinicore::op::mul_impl::infiniop

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "ops/attention.hpp"
77
#include "ops/causal_softmax.hpp"
88
#include "ops/matmul.hpp"
9+
#include "ops/mul.hpp"
910
#include "ops/rearrange.hpp"
1011
#include "ops/rms_norm.hpp"
1112
#include "ops/silu.hpp"
@@ -20,6 +21,7 @@ inline void bind(py::module &m) {
2021
bind_attention(m);
2122
bind_causal_softmax(m);
2223
bind_matmul(m);
24+
bind_mul(m);
2325
bind_rearrange(m);
2426
bind_rms_norm(m);
2527
bind_silu(m);
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once
2+
3+
#include <pybind11/pybind11.h>
4+
5+
#include "infinicore/ops/mul.hpp"
6+
7+
namespace py = pybind11;
8+
9+
namespace infinicore::ops {
10+
11+
inline void bind_mul(py::module &m) {
12+
m.def("mul",
13+
&op::mul,
14+
py::arg("a"),
15+
py::arg("b"),
16+
R"doc(Element-wise multiplication of two tensors.)doc");
17+
18+
m.def("mul_",
19+
&op::mul_,
20+
py::arg("c"),
21+
py::arg("a"),
22+
py::arg("b"),
23+
R"doc(In-place element-wise tensor multiplication.)doc");
24+
}
25+
26+
} // namespace infinicore::ops

test/infinicore/ops/mul.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import sys
2+
import os
3+
4+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
5+
6+
import torch
7+
import infinicore
8+
from framework.base import BaseOperatorTest, TensorSpec, TestCase
9+
from framework.runner import GenericTestRunner
10+
11+
# ==============================================================================
12+
# Operator-specific configuration
13+
# ==============================================================================
14+
15+
# Test cases format: (operation_mode, shape, a_strides, b_strides, c_strides)
16+
_TEST_CASES_DATA = [
17+
(TestCase.BOTH, (13, 4), None, None, None),
18+
(TestCase.BOTH, (13, 4), (10, 1), (10, 1), (10, 1)),
19+
(TestCase.BOTH, (13, 4), (0, 1), None, None),
20+
(TestCase.BOTH, (13, 4, 4), None, None, None),
21+
(TestCase.BOTH, (13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
22+
(TestCase.BOTH, (13, 4, 4), (4, 0, 1), (0, 4, 1), None),
23+
(TestCase.BOTH, (16, 5632), None, None, None),
24+
(TestCase.BOTH, (16, 5632), (13312, 1), (13312, 1), (13312, 1)),
25+
]
26+
27+
28+
def parse_test_cases(data):
29+
"""
30+
Parse mul test case data according to format:
31+
(operation_mode, shape, a_strides, b_strides, c_strides)
32+
"""
33+
operation_mode = data[0]
34+
shape = data[1]
35+
a_strides = data[2] if len(data) > 2 else None
36+
b_strides = data[3] if len(data) > 3 else None
37+
c_strides = data[4] if len(data) > 4 else None
38+
39+
# Create input specifications
40+
inputs = []
41+
42+
# Input tensor a
43+
if a_strides is not None:
44+
inputs.append(TensorSpec.from_strided_tensor(shape, a_strides))
45+
else:
46+
inputs.append(TensorSpec.from_tensor(shape))
47+
48+
# Input tensor b (same shape as a)
49+
if b_strides is not None:
50+
inputs.append(TensorSpec.from_strided_tensor(shape, b_strides))
51+
else:
52+
inputs.append(TensorSpec.from_tensor(shape))
53+
54+
# Output tensor
55+
if c_strides is not None:
56+
output = TensorSpec.from_strided_tensor(shape, c_strides)
57+
else:
58+
output = TensorSpec.from_tensor(shape)
59+
60+
return TestCase(operation_mode, inputs, output)
61+
62+
63+
# Parse test cases
64+
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
65+
66+
# Data types
67+
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
68+
69+
# Tolerance
70+
_TOLERANCE_MAP = {
71+
infinicore.float16: {"atol": 0, "rtol": 1e-2},
72+
infinicore.float32: {"atol": 0, "rtol": 1e-3},
73+
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
74+
}
75+
76+
77+
class OpTest(BaseOperatorTest):
78+
"""Mul test with simplified test case parsing"""
79+
80+
def __init__(self):
81+
super().__init__("Mul")
82+
83+
def get_test_cases(self):
84+
return _TEST_CASES
85+
86+
def get_tensor_dtypes(self):
87+
return _TENSOR_DTYPES
88+
89+
def get_tolerance_map(self):
90+
return _TOLERANCE_MAP
91+
92+
def torch_operator(self, a, b, out=None, **kwargs):
93+
return torch.mul(a, b, out=out)
94+
95+
def infinicore_operator(self, a, b, out=None, **kwargs):
96+
return infinicore.mul(a, b, out=out)
97+
98+
99+
def main():
100+
"""Main entry point"""
101+
runner = GenericTestRunner(OpTest)
102+
runner.run_and_exit()
103+
104+
105+
if __name__ == "__main__":
106+
main()

0 commit comments

Comments
 (0)