66#include < utility>
77
88/* *
9- * @brief Define the process for initializing a Descriptor of an elementwise operation
10- * for its CPU implementation
9+ * @brief Define the process for initializing a Descriptor of an elementwise
10+ * operation for its CPU implementation
1111 *
1212 * @param HANDLE The device handle.
1313 * @param DTYPE The output dtype.
1414 * @param OUT_DESC The output tensor descriptor.
1515 * @param INPUT_DESC_VEC A vector containing input tensor descriptors.
1616 */
17- #define CREATE_ELEMENTWISE_CPU_DESCRIPTOR (HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC ) \
17+ #define CREATE_ELEMENTWISE_CPU_DESCRIPTOR (HANDLE, DTYPE, OUT_DESC, \
18+ INPUT_DESC_VEC) \
1819 \
1920 auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
2021 CHECK_RESULT (info_result); \
2122 \
22- *desc_ptr = new Descriptor( \
23- DTYPE, \
24- info_result.take(), \
25- nullptr , \
26- 0 , \
27- HANDLE->device, \
28- HANDLE->device_id);
23+ *desc_ptr = new Descriptor(DTYPE, info_result.take(), nullptr , 0 , \
24+ HANDLE->device, HANDLE->device_id);
2925
3026namespace op ::elementwise::cpu {
3127
@@ -62,18 +58,17 @@ class DeviceImpl final {
6258 * @return infiniStatus_t Status indicating success or failure.
6359 */
6460 template <typename Op, typename Tdata, typename ... Args>
65- infiniStatus_t calculate (
66- const op::elementwise::ElementwiseInfo &info,
67- void *output,
68- const std::vector<const void *> &inputs,
69- void *stream,
70- Args &&...args);
61+ infiniStatus_t calculate (const op::elementwise::ElementwiseInfo &info,
62+ void *output,
63+ const std::vector<const void *> &inputs,
64+ void *stream, Args &&...args);
7165
7266 /* *
7367 * @brief Dispatches an elementwise operation with heterogeneous input types.
7468 *
75- * Supports operations where each input may have a different type, as defined by Op.
76- * The number of input types must match the operation's expected input count.
69+ * Supports operations where each input may have a different type, as defined
70+ * by Op. The number of input types must match the operation's expected input
71+ * count.
7772 *
7873 * @tparam Op The elementwise operation to perform.
7974 * @tparam Tout Output data type.
@@ -86,15 +81,12 @@ class DeviceImpl final {
8681 * @param args Additional backend-specific arguments.
8782 * @return infiniStatus_t Status indicating success or failure.
8883 */
89- template <typename Op, typename Tout, typename ... Tin,
90- typename ... Args,
84+ template <typename Op, typename Tout, typename ... Tin, typename ... Args,
9185 std::enable_if_t <(sizeof ...(Tin) == Op::num_inputs), int > = 0 >
92- infiniStatus_t calculate (
93- const op::elementwise::ElementwiseInfo &info,
94- void *output,
95- const std::vector<const void *> &inputs,
96- void *stream,
97- Args &&...args);
86+ infiniStatus_t calculate (const op::elementwise::ElementwiseInfo &info,
87+ void *output,
88+ const std::vector<const void *> &inputs,
89+ void *stream, Args &&...args);
9890};
9991
10092// Define the Opaque struct for CPU, which is empty
@@ -106,74 +98,86 @@ utils::Result<DeviceImpl> DeviceImpl::create(Args &&...args) {
10698}
10799
108100// Perform elementwise operation for different input types
109- template <typename Op, typename Tout, typename ... Tin, size_t ... Is, typename ... Args,
101+ template <typename Op, typename Tout, typename ... Tin, size_t ... Is,
102+ typename ... Args,
110103 std::enable_if_t <(sizeof ...(Tin) == Op::num_inputs), int > = 0 >
111- void calculate_impl (const op::elementwise::ElementwiseInfo &info,
112- void *output,
104+ void calculate_impl (const op::elementwise::ElementwiseInfo &info, void *output,
113105 const std::vector<const void *> &inputs,
114- std::index_sequence<Is...>,
115- Args &&...args) {
106+ std::index_sequence<Is...>, Args &&...args) {
116107
117108 Tout *out = reinterpret_cast <Tout *>(output);
118- std::tuple<const Tin *...> input_ptrs = {reinterpret_cast <const Tin *>(inputs[Is])...};
109+ std::tuple<const Tin *...> input_ptrs = {
110+ reinterpret_cast <const Tin *>(inputs[Is])...};
119111 ptrdiff_t output_size = info.getOutputSize ();
120112
121113#pragma omp parallel for
122114 for (ptrdiff_t i = 0 ; i < output_size; ++i) {
123115 size_t out_idx = info.isOutputContiguous ()
124116 ? i
125- : op::common_cpu::indexToOffset (i, info.getNdim (), info.getOutputShape (), info.getOutputStrides ());
117+ : op::common_cpu::indexToOffset (
118+ i, info.getNdim (), info.getOutputShape (),
119+ info.getOutputStrides ());
126120
127121 auto get_input_idx = [&](size_t input_id) {
128122 return info.getInputContiguous ()[input_id]
129123 ? i
130- : op::common_cpu::indexToOffset (i, info.getNdim (), info.getInputShape (input_id), info.getInputStrides (input_id));
124+ : op::common_cpu::indexToOffset (
125+ i, info.getNdim (), info.getInputShape (input_id),
126+ info.getInputStrides (input_id));
131127 };
132128
133- out[out_idx] = utils::cast<Tout>(
134- Op{}.template operator ()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx (Is)]..., std::forward<Args>(args)...));
129+ out[out_idx] = utils::cast<Tout>(Op{}.template operator ()<Tout, Tin...>(
130+ std::get<Is>(input_ptrs)[get_input_idx (Is)]...,
131+ std::forward<Args>(args)...));
135132 }
136133}
137134
138135// Invoke elementwise operation for different input types
139- template <typename Op, typename Tout, typename ... Tin, typename ... Args, std:: enable_if_t <( sizeof ...(Tin) == Op::num_inputs), int >>
140- infiniStatus_t DeviceImpl::calculate ( const op::elementwise::ElementwiseInfo &info,
141- void *output,
142- const std::vector< const void *> &inputs ,
143- void *stream ,
144- Args &&...args) {
136+ template <typename Op, typename Tout, typename ... Tin, typename ... Args,
137+ std:: enable_if_t <( sizeof ...(Tin) == Op::num_inputs), int >>
138+ infiniStatus_t
139+ DeviceImpl::calculate ( const op::elementwise::ElementwiseInfo &info ,
140+ void *output, const std::vector< const void *> &inputs ,
141+ void *stream, Args &&...args) {
145142
146143 static_assert (sizeof ...(Tin) == Op::num_inputs, " Input type count mismatch" );
147- calculate_impl<Op, Tout, Tin...>(info, output, inputs, std::make_index_sequence<sizeof ...(Tin)>{}, std::forward<Args>(args)...);
144+ calculate_impl<Op, Tout, Tin...>(info, output, inputs,
145+ std::make_index_sequence<sizeof ...(Tin)>{},
146+ std::forward<Args>(args)...);
148147 return INFINI_STATUS_SUCCESS;
149148}
150149
151150// Perform elementwise operation when all inputs have the same type
152151template <typename Op, typename Tdata, size_t ... Is, typename ... Args>
153- void calculate_impl (const op::elementwise::ElementwiseInfo &info,
154- void *output,
152+ void calculate_impl (const op::elementwise::ElementwiseInfo &info, void *output,
155153 const std::vector<const void *> &inputs,
156- std::index_sequence<Is...>,
157- Args &&...args) {
154+ std::index_sequence<Is...>, Args &&...args) {
158155
159156 Tdata *out = reinterpret_cast <Tdata *>(output);
160- std::array<const Tdata *, sizeof ...(Is)> ins = {reinterpret_cast <const Tdata *>(inputs[Is])...};
157+ std::array<const Tdata *, sizeof ...(Is)> ins = {
158+ reinterpret_cast <const Tdata *>(inputs[Is])...};
161159 const ptrdiff_t output_size = info.getOutputSize ();
162160
163161#pragma omp parallel for if (output_size > 1024)
164162 for (ptrdiff_t i = 0 ; i < output_size; ++i) {
165163 size_t out_idx = info.isOutputContiguous ()
166164 ? i
167- : op::common_cpu::indexToOffset (i, info.getNdim (), info.getOutputShape (), info.getOutputStrides ());
165+ : op::common_cpu::indexToOffset (
166+ i, info.getNdim (), info.getOutputShape (),
167+ info.getOutputStrides ());
168168
169169 auto get_input_idx = [&](size_t input_id) {
170170 return info.getInputContiguous ()[input_id]
171171 ? i
172- : op::common_cpu::indexToOffset (i, info.getNdim (), info.getInputShape (input_id), info.getInputStrides (input_id));
172+ : op::common_cpu::indexToOffset (
173+ i, info.getNdim (), info.getInputShape (input_id),
174+ info.getInputStrides (input_id));
173175 };
174176
175177 if constexpr (std::is_same_v<Tdata, fp16_t > || std::is_same_v<Tdata, bf16_t >) {
176- out[out_idx] = utils::cast<Tdata>(Op{}(utils::cast<float >(ins[Is][get_input_idx (Is)])..., std::forward<Args>(args)...));
178+ out[out_idx] = utils::cast<Tdata>(
179+ Op{}(utils::cast<float >(ins[Is][get_input_idx (Is)])...,
180+ std::forward<Args>(args)...));
177181 } else {
178182 out[out_idx] = Op{}(ins[Is][get_input_idx (Is)]..., std::forward<Args>(args)...);
179183 }
@@ -182,16 +186,16 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
182186
183187// Invoke elementwise operation when all inputs have the same type
184188template <typename Op, typename Tdata, typename ... Args>
185- infiniStatus_t DeviceImpl::calculate (const op::elementwise::ElementwiseInfo &info,
186- void *output,
187- const std::vector<const void *> &inputs,
188- void *stream,
189- Args &&...args) {
189+ infiniStatus_t
190+ DeviceImpl::calculate (const op::elementwise::ElementwiseInfo &info,
191+ void *output, const std::vector<const void *> &inputs,
192+ void *stream, Args &&...args) {
190193 constexpr size_t N = Op::num_inputs;
191- calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
194+ calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{},
195+ std::forward<Args>(args)...);
192196 return INFINI_STATUS_SUCCESS;
193197}
194198
195199} // namespace op::elementwise::cpu
196200
197- #endif // __INFINIOP_ELEMENTWISE_CPU_H__
201+ #endif // __INFINIOP_ELEMENTWISE_CPU_H__
0 commit comments