
Commit 943b38e

Merge pull request #535 from gongchensu/feature/add_gemm_python_api
Add gemm operator python interface and tests.
2 parents 89b42a8 + 001141c commit 943b38e


9 files changed (+245, -29 lines)


include/infinicore/ops/gemm.hpp

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+#pragma once
+
+#include "../device.hpp"
+#include "common/op.hpp"
+
+namespace infinicore::op {
+
+class Gemm {
+public:
+    using schema = void (*)(Tensor, Tensor, Tensor, float, float);
+    static void execute(Tensor c, Tensor a, Tensor b, float alpha, float beta);
+    static common::OpDispatcher<schema> &dispatcher();
+};
+
+Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
+void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);
+
+} // namespace infinicore::op
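
The header exposes both an out-of-place gemm (which allocates and returns the result) and an in-place gemm_ (which writes into an existing c). As a rough reference for the intended semantics only, not the actual InfiniCore implementation, the operation corresponds to the following torch sketch (names here are illustrative):

import torch

def gemm_reference(a, b, alpha=1.0, beta=0.0, c=None):
    # C = alpha * (A @ B) + beta * C, per the binding's docstring below.
    mm = torch.matmul(a, b)
    if c is None:
        return mm.mul(alpha)                        # out-of-place: nothing pre-existing to scale by beta
    return c.mul_(beta).add_(mm, alpha=alpha)       # in-place update of c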

include/infinicore/ops/matmul.hpp

Lines changed: 1 addition & 6 deletions
@@ -4,13 +4,8 @@
 #include "common/op.hpp"

 namespace infinicore::op {
-class Matmul {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor);
-    static void execute(Tensor c, Tensor a, Tensor b);
-    static common::OpDispatcher<schema> &dispatcher();
-};

 Tensor matmul(Tensor a, Tensor b);
 void matmul_(Tensor c, Tensor a, Tensor b);
+
 } // namespace infinicore::op

python/infinicore/ops/gemm.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+
+def gemm(input, other, alpha=1.0, beta=0.0, *, out=None):
+    if out is None:
+        return Tensor(_infinicore.gemm(input._underlying, other._underlying, alpha, beta))
+
+    _infinicore.gemm_(out._underlying, input._underlying, other._underlying, alpha, beta)
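
A minimal usage sketch of the new Python entry point, assuming a, b, and c are pre-existing infinicore.Tensor objects of compatible shape and dtype (tensor construction is outside this diff):

from infinicore.ops.gemm import gemm

d = gemm(a, b, alpha=0.5)                # out-of-place: returns a new Tensor holding 0.5 * (a @ b)
gemm(a, b, alpha=1.0, beta=1.0, out=c)   # in-place: updates c to (a @ b) + c; note the wrapper returns None here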

src/infinicore/ops/gemm/gemm.cc

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+#include "infinicore/ops/gemm.hpp"
+
+namespace infinicore::op {
+
+common::OpDispatcher<Gemm::schema> &Gemm::dispatcher() {
+    static common::OpDispatcher<Gemm::schema> dispatcher_;
+    return dispatcher_;
+};
+
+void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    dispatcher().lookup(context::getDevice().getType())(c, a, b, alpha, beta);
+}
+
+Tensor gemm(Tensor a, Tensor b, float alpha, float beta) {
+    Shape shape = a->shape();
+    Size size = a->ndim();
+    shape[size - 1] = b->size(size - 1);
+    auto c = Tensor::empty(shape, a->dtype(), a->device());
+    gemm_(c, a, b, alpha, beta);
+    return c;
+}
+
+void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    Gemm::execute(c, a, b, alpha, beta);
+}
+
+} // namespace infinicore::op
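
The out-of-place overload derives the output shape by copying a's shape and replacing its last dimension with b's last dimension, then dispatches on the current device type. A small Python sketch of that shape rule (illustrative only, using plain tuples rather than the Shape type):

def gemm_out_shape(a_shape, b_shape):
    # Mirror of: shape = a->shape(); shape[ndim - 1] = b->size(ndim - 1);
    out = list(a_shape)
    out[-1] = b_shape[-1]
    return tuple(out)

assert gemm_out_shape((1, 2048), (2048, 2048)) == (1, 2048)
assert gemm_out_shape((4, 48, 64), (4, 64, 6)) == (4, 48, 6)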

src/infinicore/ops/matmul/matmul_infiniop.cc renamed to src/infinicore/ops/gemm/gemm_infiniop.cc

Lines changed: 7 additions & 7 deletions
@@ -1,10 +1,10 @@
 #include "../../utils.hpp"
 #include "infinicore/common/hash.hpp"
 #include "infinicore/ops/common/cache.hpp"
-#include "infinicore/ops/matmul.hpp"
+#include "infinicore/ops/gemm.hpp"
 #include <infiniop.h>

-namespace infinicore::op::matmul_impl::infiniop {
+namespace infinicore::op::gemm_impl::infiniop {

 thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches(
     100, // capacity
@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches(
     }
 });

-void calculate(Tensor c, Tensor a, Tensor b) {
-    size_t seed = hash_combine(c, b, a);
+void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
+    size_t seed = hash_combine(c, b, a, alpha, beta);

     auto device_type = context::getDevice().getType();
     auto device_index = context::getDevice().getIndex();
@@ -41,12 +41,12 @@ void calculate(Tensor c, Tensor a, Tensor b) {

     INFINICORE_CHECK_ERROR(infiniopGemm(
         desc, workspace->data(), workspace_size,
-        c->data(), a->data(), b->data(), 1.f, 0.f, context::getStream()));
+        c->data(), a->data(), b->data(), alpha, beta, context::getStream()));
 }

 static bool registered = []() {
-    Matmul::dispatcher().registerAll(&calculate, false);
+    Gemm::dispatcher().registerAll(&calculate, false);
     return true;
 }();

-} // namespace infinicore::op::matmul_impl::infiniop
+} // namespace infinicore::op::gemm_impl::infiniop
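
The backend keeps a thread-local, capacity-bounded cache of infiniop GEMM descriptors, and the cache key now also hashes alpha and beta, so calls with different scaling factors get distinct entries. A toy sketch of that keyed-cache pattern (assuming LRU-style eviction; the real common::OpCache may evict differently):

from collections import OrderedDict

class ToyOpCache:
    """Illustrative bounded cache keyed by a hash; not the real OpCache."""

    def __init__(self, capacity=100):
        self.capacity = capacity
        self.entries = OrderedDict()

    def get_or_create(self, key, create):
        if key in self.entries:
            self.entries.move_to_end(key)      # mark as most recently used
            return self.entries[key]
        value = create()                       # e.g. build a GEMM descriptor
        self.entries[key] = value
        if len(self.entries) > self.capacity:
            self.entries.popitem(last=False)   # drop the oldest entry
        return value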
src/infinicore/ops/matmul/matmul.cc

Lines changed: 3 additions & 16 deletions
@@ -1,26 +1,13 @@
 #include "infinicore/ops/matmul.hpp"
+#include "infinicore/ops/gemm.hpp"

 namespace infinicore::op {

-common::OpDispatcher<Matmul::schema> &Matmul::dispatcher() {
-    static common::OpDispatcher<Matmul::schema> dispatcher_;
-    return dispatcher_;
-};
-
-void Matmul::execute(Tensor c, Tensor a, Tensor b) {
-    dispatcher().lookup(context::getDevice().getType())(c, a, b);
-}
-
 Tensor matmul(Tensor a, Tensor b) {
-    Shape shape = a->shape();
-    Size size = a->ndim();
-    shape[size - 1] = b->size(size - 1);
-    auto c = Tensor::empty(shape, a->dtype(), a->device());
-    matmul_(c, a, b);
-    return c;
+    return gemm(a, b, 1.0f, 0.0f);
 }

 void matmul_(Tensor c, Tensor a, Tensor b) {
-    Matmul::execute(c, a, b);
+    Gemm::execute(c, a, b, 1.0f, 0.0f);
 }
 } // namespace infinicore::op
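
With the Matmul class removed, matmul is now a thin alias for gemm with the neutral scaling factors alpha=1.0, beta=0.0. Restated at the Python level (illustrative only; this commit does not add a matmul Python wrapper):

from infinicore.ops.gemm import gemm

def matmul(a, b):
    # Same delegation as the C++ code above: a plain matrix product.
    return gemm(a, b, alpha=1.0, beta=0.0)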

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include "ops/add.hpp"
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
+#include "ops/gemm.hpp"
 #include "ops/matmul.hpp"
 #include "ops/rearrange.hpp"
 #include "ops/rms_norm.hpp"
@@ -19,6 +20,7 @@ inline void bind(py::module &m) {
     bind_add(m);
     bind_attention(m);
     bind_causal_softmax(m);
+    bind_gemm(m);
     bind_matmul(m);
     bind_rearrange(m);
     bind_rms_norm(m);
src/infinicore/pybind11/ops/gemm.hpp

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+#include "infinicore/ops/gemm.hpp"
+
+namespace py = pybind11;
+
+namespace infinicore::ops {
+
+inline void bind_gemm(py::module &m) {
+    m.def("gemm",
+          &op::gemm,
+          py::arg("a"),
+          py::arg("b"),
+          py::arg("alpha") = 1.0f,
+          py::arg("beta") = 0.0f,
+          R"doc(General matrix multiplication: C = alpha * A @ B + beta * C.)doc");
+
+    m.def("gemm_",
+          &op::gemm_,
+          py::arg("c"),
+          py::arg("a"),
+          py::arg("b"),
+          py::arg("alpha"),
+          py::arg("beta"),
+          R"doc(In-place general matrix multiplication.)doc");
+}
+
+} // namespace infinicore::ops

test/infinicore/ops/gemm.py

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+import infinicore
+from infinicore.ops.gemm import gemm as ic_gemm
+from framework.base import BaseOperatorTest, TensorSpec, TestCase
+from framework.tensor import TensorInitializer
+from framework.runner import GenericTestRunner
+
+# ==============================================================================
+# Operator-specific configuration
+# ==============================================================================
+
+# Test cases format: (alpha, beta, operation_mode, nbatch, m, n, k, a_strides, b_strides, c_strides)
+# If nbatch is None: a_shape=(m, k), b_shape=(k, n), c_shape=(m, n)
+# If nbatch is provided: a_shape=(nbatch, m, k), b_shape=(nbatch, k, n), c_shape=(nbatch, m, n)
+# Aligned with test/infiniop/gemm.py shapes/strides and per-case alpha/beta
+# Each item: (alpha, beta, operation_mode, nbatch, m, n, k, a_strides, b_strides, c_strides)
+_TEST_CASES_DATA = [
+    # (1) alpha=1.0, beta=0.0, a=(1,2048), b=(2048,2048), c=(1,2048)
+    (1.0, 0.0, TestCase.BOTH, None, 1, 2048, 2048, None, None, None),
+    # (2) alpha=1.0, beta=0.0, a=(2,4,2048), b=(2,2048,2048), c=(2,4,2048)
+    (1.0, 0.0, TestCase.BOTH, 2, 4, 2048, 2048, None, None, None),
+    # (3) alpha=1.0, beta=0.0, strided (4096,1)
+    (1.0, 0.0, TestCase.BOTH, None, 1, 2048, 2048, (4096, 1), (4096, 1), (4096, 1)),
+    # (4) alpha=1.0, beta=1.0, only meaningful for IN_PLACE (needs existing C)
+    (1.0, 1.0, TestCase.IN_PLACE, None, 6, 2560, 2048, (2048, 1), (1, 2048), (2560, 1)),
+    # (5) alpha=1.0/8.0, beta=0.0, a=(4,48,64), b=(4,64,6), c=(4,48,6)
+    (1.0 / 8.0, 0.0, TestCase.BOTH, 4, 48, 6, 64, None, None, None),
+]
+
+
+def parse_test_cases(data):
+    """
+    Parse gemm test case data according to format:
+    (alpha, beta, operation_mode, nbatch, m, n, k, a_strides, b_strides, c_strides)
+    """
+    alpha = data[0]
+    beta = data[1]
+    operation_mode = data[2]
+    nbatch = data[3]
+    m, n, k = data[4], data[5], data[6]
+    a_strides = data[7] if len(data) > 7 else None
+    b_strides = data[8] if len(data) > 8 else None
+    c_strides = data[9] if len(data) > 9 else None
+
+    # Determine shapes based on batch dimension
+    if nbatch is None:
+        a_shape = (m, k)
+        b_shape = (k, n)
+        c_shape = (m, n)
+    else:
+        a_shape = (nbatch, m, k)
+        b_shape = (nbatch, k, n)
+        c_shape = (nbatch, m, n)
+
+    # Create input specifications
+    inputs = []
+
+    # Tensor a
+    if a_strides is not None:
+        inputs.append(TensorSpec.from_strided_tensor(a_shape, a_strides))
+    else:
+        inputs.append(TensorSpec.from_tensor(a_shape))
+
+    # Tensor b
+    if b_strides is not None:
+        inputs.append(TensorSpec.from_strided_tensor(b_shape, b_strides))
+    else:
+        inputs.append(TensorSpec.from_tensor(b_shape))
+
+    # Output tensor
+    if c_strides is not None:
+        output = TensorSpec.from_strided_tensor(
+            c_shape,
+            c_strides,
+            init_mode=TensorInitializer.ONES if beta != 0.0 else TensorInitializer.RANDOM,
+        )
+    else:
+        output = TensorSpec.from_tensor(
+            c_shape,
+            init_mode=TensorInitializer.ONES if beta != 0.0 else TensorInitializer.RANDOM,
+        )
+
+    return TestCase(operation_mode, inputs, output, alpha=alpha, beta=beta)
+
+
+# Parse test cases
+_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
+
+# Data types
+_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+
+# Tolerance
+_TOLERANCE_MAP = {
+    infinicore.float16: {"atol": 0, "rtol": 1e-2},
+    infinicore.float32: {"atol": 0, "rtol": 1e-3},
+    infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
+}
+
+
+class OpTest(BaseOperatorTest):
+    """GEMM test with simplified test case parsing
+
+    Note: alpha and beta come from each test case; with alpha=1.0 and beta=0.0 the result should match torch.matmul.
+    """
+
+    def __init__(self):
+        super().__init__("Gemm")
+
+    def get_test_cases(self):
+        return _TEST_CASES
+
+    def get_tensor_dtypes(self):
+        return _TENSOR_DTYPES
+
+    def get_tolerance_map(self):
+        return _TOLERANCE_MAP
+
+    def torch_operator(self, a, b, out=None, **kwargs):
+        alpha = kwargs.get("alpha", 1.0)
+        beta = kwargs.get("beta", 0.0)
+        mm = torch.matmul(a, b)
+        if out is None:
+            return mm.mul(alpha)
+        out.mul_(beta)
+        out.add_(mm, alpha=alpha)
+        return out
+
+    def infinicore_operator(self, a, b, out=None, **kwargs):
+        alpha = kwargs.get("alpha", 1.0)
+        beta = kwargs.get("beta", 0.0)
+        if out is None:
+            return ic_gemm(a, b, alpha=alpha, beta=beta)
+        return ic_gemm(a, b, alpha=alpha, beta=beta, out=out)
+
+
+def main():
+    """Main entry point"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()

0 commit comments
