@@ -46,7 +46,9 @@ static inline void cpu_cast_impl_incremental(
4646 const std::vector<ptrdiff_t > &in_stride = info.in_stride ;
4747 const std::vector<ptrdiff_t > &out_stride = info.out_stride ;
4848
49- if (n == 0 ) return ;
49+ if (n == 0 ) {
50+ return ;
51+ }
5052
5153 std::vector<size_t > idx (ndim, 0 );
5254 ptrdiff_t in_off = 0 ;
@@ -59,15 +61,23 @@ static inline void cpu_cast_impl_incremental(
5961
6062 for (int d = static_cast <int >(ndim) - 1 ; d >= 0 ; --d) {
6163 idx[d] += 1 ;
62- if (in_stride[d] != 0 ) in_off += in_stride[d];
63- if (out_stride[d] != 0 ) out_off += out_stride[d];
64+ if (in_stride[d] != 0 ) {
65+ in_off += in_stride[d];
66+ }
67+ if (out_stride[d] != 0 ) {
68+ out_off += out_stride[d];
69+ }
6470
6571 if (idx[d] < shape[d]) {
6672 break ;
6773 } else {
6874 idx[d] = 0 ;
69- if (in_stride[d] != 0 ) in_off -= static_cast <ptrdiff_t >(shape[d]) * in_stride[d];
70- if (out_stride[d] != 0 ) out_off -= static_cast <ptrdiff_t >(shape[d]) * out_stride[d];
75+ if (in_stride[d] != 0 ) {
76+ in_off -= static_cast <ptrdiff_t >(shape[d]) * in_stride[d];
77+ }
78+ if (out_stride[d] != 0 ) {
79+ out_off -= static_cast <ptrdiff_t >(shape[d]) * out_stride[d];
80+ }
7181 }
7282 }
7383 }
@@ -80,39 +90,39 @@ infiniStatus_t Descriptor::calculate(
8090 const void *input,
8191 void *stream) const {
8292
83- if (output == const_cast <void *>(input)) {
93+ if (output == const_cast <void *>(input)) {
8494 return INFINI_STATUS_BAD_PARAM; // or INFINI_STATUS_INPLACE_NOT_SUPPORTED
8595 }
8696
87- #define CASE_OUT (DT_OUT, TOUT ) \
88- case DT_OUT: { \
89- switch (_info.dt_in ) { \
90- case INFINI_DTYPE_I32: \
91- cpu_cast_impl_incremental<TOUT, int32_t >(output, input, _info); \
92- break ; \
93- case INFINI_DTYPE_I64: \
94- cpu_cast_impl_incremental<TOUT, int64_t >(output, input, _info); \
95- break ; \
96- case INFINI_DTYPE_U32: \
97- cpu_cast_impl_incremental<TOUT, uint32_t >(output, input, _info); \
98- break ; \
99- case INFINI_DTYPE_U64: \
100- cpu_cast_impl_incremental<TOUT, uint64_t >(output, input, _info); \
101- break ; \
102- case INFINI_DTYPE_F16: \
103- cpu_cast_impl_incremental<TOUT, fp16_t >(output, input, _info); \
104- break ; \
105- case INFINI_DTYPE_F32: \
106- cpu_cast_impl_incremental<TOUT, float >(output, input, _info); \
107- break ; \
108- case INFINI_DTYPE_F64: \
109- cpu_cast_impl_incremental<TOUT, double >(output, input, _info); \
110- break ; \
111- default : \
112- return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
113- } \
114- break ; \
115- }
97+ #define CASE_OUT (DT_OUT, TOUT ) \
98+ case DT_OUT: { \
99+ switch (_info.dt_in ) { \
100+ case INFINI_DTYPE_I32: \
101+ cpu_cast_impl_incremental<TOUT, int32_t >(output, input, _info); \
102+ break ; \
103+ case INFINI_DTYPE_I64: \
104+ cpu_cast_impl_incremental<TOUT, int64_t >(output, input, _info); \
105+ break ; \
106+ case INFINI_DTYPE_U32: \
107+ cpu_cast_impl_incremental<TOUT, uint32_t >(output, input, _info); \
108+ break ; \
109+ case INFINI_DTYPE_U64: \
110+ cpu_cast_impl_incremental<TOUT, uint64_t >(output, input, _info); \
111+ break ; \
112+ case INFINI_DTYPE_F16: \
113+ cpu_cast_impl_incremental<TOUT, fp16_t >(output, input, _info); \
114+ break ; \
115+ case INFINI_DTYPE_F32: \
116+ cpu_cast_impl_incremental<TOUT, float >(output, input, _info); \
117+ break ; \
118+ case INFINI_DTYPE_F64: \
119+ cpu_cast_impl_incremental<TOUT, double >(output, input, _info); \
120+ break ; \
121+ default : \
122+ return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
123+ } \
124+ break ; \
125+ }
116126
117127 switch (_info.dt_out ) {
118128 CASE_OUT (INFINI_DTYPE_I32, int32_t );
@@ -126,10 +136,9 @@ infiniStatus_t Descriptor::calculate(
126136 return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
127137 }
128138
129- #undef CASE_OUT
139+ #undef CASE_OUT
130140
131141 return INFINI_STATUS_SUCCESS;
132142}
133143
134-
135144} // namespace op::cast::cpu
0 commit comments