Skip to content

Commit 8e07b04

Browse files
committed
issue/346: Implemented linear, fp8 linear, fp8 blockwise linear with cuBLASLt and fp8 group-wise quant
1 parent 80212cb commit 8e07b04

38 files changed

+2627
-236
lines changed

include/infinicore.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ typedef enum {
3434
INFINI_STATUS_BAD_TENSOR_SHAPE = 11,
3535
INFINI_STATUS_BAD_TENSOR_STRIDES = 12,
3636
INFINI_STATUS_INSUFFICIENT_WORKSPACE = 13,
37+
INFINI_STATUS_NOT_ALIGNED = 14,
3738
} infiniStatus_t;
3839

3940
typedef enum {
@@ -70,6 +71,9 @@ typedef enum {
7071
INFINI_DTYPE_C64 = 17,
7172
INFINI_DTYPE_C128 = 18,
7273
INFINI_DTYPE_BF16 = 19,
74+
INFINI_DTYPE_F8_E4M3 = 20,
75+
INFINI_DTYPE_F8_E5M2 = 21,
76+
INFINI_DTYPE_F8_UE8M0 = 22,
7377
} infiniDtype_t;
7478

7579
#endif // __INFINICORE_API_H__

include/infiniop/operator_descriptor.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
// Base descriptor for all operators
88
struct InfiniopDescriptor;
99

10-
__C __export infiniStatus_t infiniopGetDescriptorDeviceType(const struct InfiniopDescriptor *desc_ptr, infiniDevice_t *device_type);
11-
__C __export infiniStatus_t infiniopGetDescriptorDeviceId(const struct InfiniopDescriptor *desc_ptr, int *device_id);
10+
// Queries the device type associated with an operator descriptor.
// desc_ptr is taken as a pointer-to-const; the result is presumably written
// through the device_type out-pointer (implementation not shown here —
// confirm). Returns an infiniStatus_t code.
__C __export infiniStatus_t infiniopGetDescriptorDeviceType(
11+
const struct InfiniopDescriptor *desc_ptr, infiniDevice_t *device_type);
12+
// Queries the device id associated with an operator descriptor.
// desc_ptr is taken as a pointer-to-const; the result is presumably written
// through the device_id out-pointer (implementation not shown here —
// confirm). Returns an infiniStatus_t code.
__C __export infiniStatus_t infiniopGetDescriptorDeviceId(
13+
const struct InfiniopDescriptor *desc_ptr, int *device_id);
1214

1315
#endif //__INFINIOP_OPERATOR_DESCRIPTOR_API_H__

include/infiniop/ops/linear.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef __INFINIOP_LINEAR_API_H__
2+
#define __INFINIOP_LINEAR_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopLinearDescriptor_t;
7+
8+
// Creates a descriptor for the linear operator.
//
// handle   - library handle the descriptor is created against.
// desc_ptr - out-parameter receiving the new descriptor.
// d_desc, a_desc, b_desc, c_desc - tensor descriptors for the d, a, b and c
//            operands passed later to infiniopLinear. Presumably d is the
//            output and a/b/c are inputs — TODO confirm against the
//            implementation.
// Returns an infiniStatus_t code.
__C __export infiniStatus_t infiniopCreateLinearDescriptor(
9+
infiniopHandle_t handle, infiniopLinearDescriptor_t *desc_ptr,
10+
infiniopTensorDescriptor_t d_desc, infiniopTensorDescriptor_t a_desc,
11+
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t c_desc);
12+
13+
// Reports the workspace size for a linear descriptor through the `size`
// out-pointer. Presumably this is the minimum byte count to pass as
// workspace_size to infiniopLinear — confirm against the implementation.
__C __export infiniStatus_t
14+
infiniopGetLinearWorkspaceSize(infiniopLinearDescriptor_t desc, size_t *size);
15+
16+
// Executes the linear operator described by `desc`.
//
// desc            - descriptor from infiniopCreateLinearDescriptor.
// alpha, beta     - scalar factors. Presumably the op computes
//                   d = alpha * (a x b) + beta * c + bias — TODO confirm.
// a, b, c         - read-only operand buffers.
// a_scale, b_scale, c_scale, d_scale - read-only scale buffers for the
//                   corresponding operands (presumably only used for FP8
//                   inputs — confirm; unclear whether they may be null).
// bias            - read-only bias buffer.
// d               - writable output buffer.
// is_blockwise, is_a_1d_scaled, is_b_1d_scaled - flags selecting the scaling
//                   layout; exact meaning is not visible in this header.
// workspace, workspace_size - scratch buffer and its size.
// stream          - opaque stream handle (void *).
// Returns an infiniStatus_t code.
//
// NOTE(review): this is a `__C` declaration that uses `bool`; C callers need
// <stdbool.h> in scope — confirm it is pulled in via operator_descriptor.h.
__C __export infiniStatus_t infiniopLinear(
17+
infiniopLinearDescriptor_t desc, float alpha, const void *a,
18+
const void *a_scale, const void *b, const void *b_scale, float beta,
19+
const void *c, const void *c_scale, const void *bias, void *d,
20+
const void *d_scale, bool is_blockwise, bool is_a_1d_scaled,
21+
bool is_b_1d_scaled, void *workspace, size_t workspace_size, void *stream);
22+
23+
// Destroys a linear descriptor; the counterpart of
// infiniopCreateLinearDescriptor. Returns an infiniStatus_t code.
__C __export infiniStatus_t
24+
infiniopDestroyLinearDescriptor(infiniopLinearDescriptor_t desc);
25+
26+
#endif

include/infiniop/ops/quantize.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef __INFINIOP_QUANTIZE_API_H__
2+
#define __INFINIOP_QUANTIZE_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopQuantizeDescriptor_t;
7+
8+
// Creates a descriptor for the quantize operator.
//
// handle        - library handle the descriptor is created against.
// desc_ptr      - out-parameter receiving the new descriptor.
// input_desc    - descriptor of the tensor to quantize.
// output_q_desc - descriptor of the `output_q` tensor (presumably the
//                 quantized values — confirm).
// output_s_desc - descriptor of the `output_s` tensor (presumably the
//                 scales — confirm).
// Returns an infiniStatus_t code.
__C __export infiniStatus_t infiniopCreateQuantizeDescriptor(
9+
infiniopHandle_t handle, infiniopQuantizeDescriptor_t *desc_ptr,
10+
infiniopTensorDescriptor_t input_desc,
11+
infiniopTensorDescriptor_t output_q_desc,
12+
infiniopTensorDescriptor_t output_s_desc);
13+
14+
// Reports the workspace size for a quantize descriptor through the `size`
// out-pointer. Presumably this is the minimum byte count to pass as
// workspace_size to infiniopQuantize — confirm against the implementation.
__C __export infiniStatus_t infiniopGetQuantizeWorkspaceSize(
15+
infiniopQuantizeDescriptor_t desc, size_t *size);
16+
17+
// Executes the quantize operator described by `desc`.
//
// desc                      - descriptor from infiniopCreateQuantizeDescriptor.
// workspace, workspace_size - scratch buffer and its size.
// input                     - buffer to quantize.
// output_q, output_s        - output buffers (presumably quantized values and
//                             their scales, respectively — confirm).
// group_size                - presumably the number of elements sharing one
//                             scale (group-wise quantization) — confirm.
// eps, min_8bit, max_8bit   - numeric parameters; presumably a lower bound on
//                             the scale and the clamp range of the 8-bit
//                             format — confirm.
// scale_ue8m0               - flag; presumably selects UE8M0-style scales
//                             (cf. INFINI_DTYPE_F8_UE8M0) — confirm.
// stream                    - opaque stream handle (void *).
// Returns an infiniStatus_t code.
//
// NOTE(review): `input` is `void *` rather than `const void *`; confirm
// whether the operator really needs write access to the input buffer.
__C __export infiniStatus_t infiniopQuantize(
18+
infiniopQuantizeDescriptor_t desc, void *workspace, size_t workspace_size,
19+
void *input, void *output_q, void *output_s, int group_size, double eps,
20+
double min_8bit, double max_8bit, bool scale_ue8m0, void *stream);
21+
22+
// Destroys a quantize descriptor; the counterpart of
// infiniopCreateQuantizeDescriptor. Returns an infiniStatus_t code.
__C __export infiniStatus_t
23+
infiniopDestroyQuantizeDescriptor(infiniopQuantizeDescriptor_t desc);
24+
#endif

scripts/install.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
88
os.chdir(PROJECT_DIR)
99

10+
1011
# Run `cmd` through the system shell and wait for it to finish.
#
# check=True makes subprocess.run raise subprocess.CalledProcessError on a
# non-zero exit status, so a failing step aborts the install script.
# shell=True hands the string to the shell verbatim: only pass trusted,
# script-authored command strings here, never unsanitised external input.
def run_cmd(cmd):
1112
subprocess.run(cmd, text=True, encoding="utf-8", check=True, shell=True)
1213

src/infiniop/devices/nvidia/nvidia_handle.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "../pool.h"
66
#include "nvidia_handle.h"
77
#include <cublas_v2.h>
8+
#include <cuda.h>
9+
#include <cuda_fp8.h>
810
#include <functional>
911

1012
#ifdef ENABLE_CUDNN_API
@@ -13,6 +15,11 @@
1315

1416
#ifdef ENABLE_CUBLASLT_API
1517
#include <cublasLt.h>
18+
// Compile-time feature flag for FP8 block-wise scaling.
// CUDA_VERSION (from <cuda.h>) is encoded as major * 1000 + minor * 10, so
// 12090 corresponds to CUDA 12.9. The flag is 1 on CUDA >= 12.9 and 0
// otherwise; presumably because the cuBLASLt block-scaling modes used by the
// FP8 blockwise linear path need that toolkit version — confirm.
// Note: this lives inside the ENABLE_CUBLASLT_API guard, so the macro is left
// undefined when cuBLASLt is disabled (an undefined macro evaluates as 0 in
// `#if`, but `#ifdef` checks would behave differently).
#if CUDA_VERSION >= 12090
19+
#define SUPPORT_FP8_BLOCKWISE_SCALE 1
20+
#else
21+
#define SUPPORT_FP8_BLOCKWISE_SCALE 0
22+
#endif
1623
#endif
1724

1825
#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)

src/infiniop/elementwise/cpu/elementwise_cpu.h

Lines changed: 62 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,22 @@
66
#include <utility>
77

88
/**
9-
* @brief Define the process for initializing a Descriptor of an elementwise operation
10-
* for its CPU implementation
9+
* @brief Define the process for initializing a Descriptor of an elementwise
10+
* operation for its CPU implementation
1111
*
1212
* @param HANDLE The device handle.
1313
* @param DTYPE The output dtype.
1414
* @param OUT_DESC The output tensor descriptor.
1515
* @param INPUT_DESC_VEC A vector containing input tensor descriptors.
*
* @note The expansion is a sequence of statements, not an expression: it
*       declares a local named `info_result`, runs CHECK_RESULT on it, and
*       assigns a `new Descriptor(...)` through `desc_ptr`. Both `desc_ptr`
*       and `Descriptor` are not macro parameters and must be in scope at
*       the expansion site. The descriptor is allocated with raw `new`, so
*       ownership passes to whoever holds `*desc_ptr`.
1616
*/
17-
#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
17+
#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, \
18+
INPUT_DESC_VEC) \
1819
\
1920
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
2021
CHECK_RESULT(info_result); \
2122
\
22-
*desc_ptr = new Descriptor( \
23-
DTYPE, \
24-
info_result.take(), \
25-
nullptr, \
26-
0, \
27-
HANDLE->device, \
28-
HANDLE->device_id);
23+
*desc_ptr = new Descriptor(DTYPE, info_result.take(), nullptr, 0, \
24+
HANDLE->device, HANDLE->device_id);
2925

3026
namespace op::elementwise::cpu {
3127

@@ -62,18 +58,17 @@ class DeviceImpl final {
6258
* @return infiniStatus_t Status indicating success or failure.
6359
*/
6460
template <typename Op, typename Tdata, typename... Args>
65-
infiniStatus_t calculate(
66-
const op::elementwise::ElementwiseInfo &info,
67-
void *output,
68-
const std::vector<const void *> &inputs,
69-
void *stream,
70-
Args &&...args);
61+
infiniStatus_t calculate(const op::elementwise::ElementwiseInfo &info,
62+
void *output,
63+
const std::vector<const void *> &inputs,
64+
void *stream, Args &&...args);
7165

7266
/**
7367
* @brief Dispatches an elementwise operation with heterogeneous input types.
7468
*
75-
* Supports operations where each input may have a different type, as defined by Op.
76-
* The number of input types must match the operation's expected input count.
69+
* Supports operations where each input may have a different type, as defined
70+
* by Op. The number of input types must match the operation's expected input
71+
* count.
7772
*
7873
* @tparam Op The elementwise operation to perform.
7974
* @tparam Tout Output data type.
@@ -86,15 +81,12 @@ class DeviceImpl final {
8681
* @param args Additional backend-specific arguments.
8782
* @return infiniStatus_t Status indicating success or failure.
8883
*/
89-
template <typename Op, typename Tout, typename... Tin,
90-
typename... Args,
84+
template <typename Op, typename Tout, typename... Tin, typename... Args,
9185
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
92-
infiniStatus_t calculate(
93-
const op::elementwise::ElementwiseInfo &info,
94-
void *output,
95-
const std::vector<const void *> &inputs,
96-
void *stream,
97-
Args &&...args);
86+
infiniStatus_t calculate(const op::elementwise::ElementwiseInfo &info,
87+
void *output,
88+
const std::vector<const void *> &inputs,
89+
void *stream, Args &&...args);
9890
};
9991

10092
// Define the Opaque struct for CPU, which is empty
@@ -106,74 +98,86 @@ utils::Result<DeviceImpl> DeviceImpl::create(Args &&...args) {
10698
}
10799

108100
// Perform elementwise operation for different input types.
//
// Reinterprets `output` as Tout* and inputs[Is] as const Tin*..., then for
// every flat index i in [0, info.getOutputSize()) computes the output offset
// and each input offset (i itself when the tensor is contiguous, otherwise
// op::common_cpu::indexToOffset over shape/strides), calls
// Op{}.operator()<Tout, Tin...> on the gathered input elements plus `args`,
// and stores the result cast to Tout.
//
// The enable_if constraint requires sizeof...(Tin) == Op::num_inputs.
// The loop runs under `#pragma omp parallel for` with no `if` clause.
//
// NOTE(review): std::forward<Args>(args)... is evaluated on every iteration
// (and across the parallel loop). If an Arg is deduced as an rvalue and Op
// moves from it, later iterations would read a moved-from value — confirm
// that Op only takes the extra arguments by value of cheap types or by
// const reference.
109-
template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args,
101+
template <typename Op, typename Tout, typename... Tin, size_t... Is,
102+
typename... Args,
110103
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
111-
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
112-
void *output,
104+
void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output,
113105
const std::vector<const void *> &inputs,
114-
std::index_sequence<Is...>,
115-
Args &&...args) {
106+
std::index_sequence<Is...>, Args &&...args) {
116107

117108
Tout *out = reinterpret_cast<Tout *>(output);
118-
std::tuple<const Tin *...> input_ptrs = {reinterpret_cast<const Tin *>(inputs[Is])...};
109+
std::tuple<const Tin *...> input_ptrs = {
110+
reinterpret_cast<const Tin *>(inputs[Is])...};
119111
ptrdiff_t output_size = info.getOutputSize();
120112

121113
#pragma omp parallel for
122114
for (ptrdiff_t i = 0; i < output_size; ++i) {
123115
size_t out_idx = info.isOutputContiguous()
124116
? i
125-
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
117+
: op::common_cpu::indexToOffset(
118+
i, info.getNdim(), info.getOutputShape(),
119+
info.getOutputStrides());
126120

127121
auto get_input_idx = [&](size_t input_id) {
128122
return info.getInputContiguous()[input_id]
129123
? i
130-
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id));
124+
: op::common_cpu::indexToOffset(
125+
i, info.getNdim(), info.getInputShape(input_id),
126+
info.getInputStrides(input_id));
131127
};
132128

133-
out[out_idx] = utils::cast<Tout>(
134-
Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
129+
out[out_idx] = utils::cast<Tout>(Op{}.template operator()<Tout, Tin...>(
130+
std::get<Is>(input_ptrs)[get_input_idx(Is)]...,
131+
std::forward<Args>(args)...));
135132
}
136133
}
137134

138135
// Invoke elementwise operation for different input types.
//
// Out-of-class definition of the heterogeneous-input DeviceImpl::calculate
// overload. It statically checks that sizeof...(Tin) == Op::num_inputs,
// forwards to calculate_impl<Op, Tout, Tin...> with an index sequence over
// the input types, and always returns INFINI_STATUS_SUCCESS.
// `stream` is accepted for interface uniformity but is not used in this body.
139-
template <typename Op, typename Tout, typename... Tin, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
140-
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
141-
void *output,
142-
const std::vector<const void *> &inputs,
143-
void *stream,
144-
Args &&...args) {
136+
template <typename Op, typename Tout, typename... Tin, typename... Args,
137+
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
138+
infiniStatus_t
139+
DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
140+
void *output, const std::vector<const void *> &inputs,
141+
void *stream, Args &&...args) {
145142

146143
static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch");
147-
calculate_impl<Op, Tout, Tin...>(info, output, inputs, std::make_index_sequence<sizeof...(Tin)>{}, std::forward<Args>(args)...);
144+
calculate_impl<Op, Tout, Tin...>(info, output, inputs,
145+
std::make_index_sequence<sizeof...(Tin)>{},
146+
std::forward<Args>(args)...);
148147
return INFINI_STATUS_SUCCESS;
149148
}
150149

151150
// Perform elementwise operation when all inputs have the same type
152151
template <typename Op, typename Tdata, size_t... Is, typename... Args>
153-
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
154-
void *output,
152+
void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output,
155153
const std::vector<const void *> &inputs,
156-
std::index_sequence<Is...>,
157-
Args &&...args) {
154+
std::index_sequence<Is...>, Args &&...args) {
158155

159156
Tdata *out = reinterpret_cast<Tdata *>(output);
160-
std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
157+
std::array<const Tdata *, sizeof...(Is)> ins = {
158+
reinterpret_cast<const Tdata *>(inputs[Is])...};
161159
const ptrdiff_t output_size = info.getOutputSize();
162160

163161
#pragma omp parallel for if (output_size > 1024)
164162
for (ptrdiff_t i = 0; i < output_size; ++i) {
165163
size_t out_idx = info.isOutputContiguous()
166164
? i
167-
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
165+
: op::common_cpu::indexToOffset(
166+
i, info.getNdim(), info.getOutputShape(),
167+
info.getOutputStrides());
168168

169169
auto get_input_idx = [&](size_t input_id) {
170170
return info.getInputContiguous()[input_id]
171171
? i
172-
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id));
172+
: op::common_cpu::indexToOffset(
173+
i, info.getNdim(), info.getInputShape(input_id),
174+
info.getInputStrides(input_id));
173175
};
174176

175177
if constexpr (std::is_same_v<Tdata, fp16_t> || std::is_same_v<Tdata, bf16_t>) {
176-
out[out_idx] = utils::cast<Tdata>(Op{}(utils::cast<float>(ins[Is][get_input_idx(Is)])..., std::forward<Args>(args)...));
178+
out[out_idx] = utils::cast<Tdata>(
179+
Op{}(utils::cast<float>(ins[Is][get_input_idx(Is)])...,
180+
std::forward<Args>(args)...));
177181
} else {
178182
out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward<Args>(args)...);
179183
}
@@ -182,16 +186,16 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
182186

183187
// Invoke elementwise operation when all inputs have the same type.
//
// Out-of-class definition of the same-type DeviceImpl::calculate overload.
// It builds an index sequence of length Op::num_inputs, forwards to
// calculate_impl<Op, Tdata>, and always returns INFINI_STATUS_SUCCESS.
// `stream` is accepted for interface uniformity but is not used in this body.
184188
template <typename Op, typename Tdata, typename... Args>
185-
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
186-
void *output,
187-
const std::vector<const void *> &inputs,
188-
void *stream,
189-
Args &&...args) {
189+
infiniStatus_t
190+
DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
191+
void *output, const std::vector<const void *> &inputs,
192+
void *stream, Args &&...args) {
190193
constexpr size_t N = Op::num_inputs;
191-
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
194+
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{},
195+
std::forward<Args>(args)...);
192196
return INFINI_STATUS_SUCCESS;
193197
}
194198

195199
} // namespace op::elementwise::cpu
196200

197-
#endif // __INFINIOP_ELEMENTWISE_CPU_H__
201+
#endif // __INFINIOP_ELEMENTWISE_CPU_H__

0 commit comments

Comments
 (0)