Skip to content

Commit e67a44b

Browse files
committed
[T1-1-1]: operators clang-format
1 parent 158d5a2 commit e67a44b

File tree

35 files changed

+533
-435
lines changed

35 files changed

+533
-435
lines changed

src/infiniop/ops/cast/cpu/cast_cpu.cc

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ static inline void cpu_cast_impl_incremental(
4646
const std::vector<ptrdiff_t> &in_stride = info.in_stride;
4747
const std::vector<ptrdiff_t> &out_stride = info.out_stride;
4848

49-
if (n == 0) return;
49+
if (n == 0) {
50+
return;
51+
}
5052

5153
std::vector<size_t> idx(ndim, 0);
5254
ptrdiff_t in_off = 0;
@@ -59,15 +61,23 @@ static inline void cpu_cast_impl_incremental(
5961

6062
for (int d = static_cast<int>(ndim) - 1; d >= 0; --d) {
6163
idx[d] += 1;
62-
if (in_stride[d] != 0) in_off += in_stride[d];
63-
if (out_stride[d] != 0) out_off += out_stride[d];
64+
if (in_stride[d] != 0) {
65+
in_off += in_stride[d];
66+
}
67+
if (out_stride[d] != 0) {
68+
out_off += out_stride[d];
69+
}
6470

6571
if (idx[d] < shape[d]) {
6672
break;
6773
} else {
6874
idx[d] = 0;
69-
if (in_stride[d] != 0) in_off -= static_cast<ptrdiff_t>(shape[d]) * in_stride[d];
70-
if (out_stride[d] != 0) out_off -= static_cast<ptrdiff_t>(shape[d]) * out_stride[d];
75+
if (in_stride[d] != 0) {
76+
in_off -= static_cast<ptrdiff_t>(shape[d]) * in_stride[d];
77+
}
78+
if (out_stride[d] != 0) {
79+
out_off -= static_cast<ptrdiff_t>(shape[d]) * out_stride[d];
80+
}
7181
}
7282
}
7383
}
@@ -80,39 +90,39 @@ infiniStatus_t Descriptor::calculate(
8090
const void *input,
8191
void *stream) const {
8292

83-
if (output == const_cast<void*>(input)) {
93+
if (output == const_cast<void *>(input)) {
8494
return INFINI_STATUS_BAD_PARAM; // or INFINI_STATUS_INPLACE_NOT_SUPPORTED
8595
}
8696

87-
#define CASE_OUT(DT_OUT, TOUT) \
88-
case DT_OUT: { \
89-
switch (_info.dt_in) { \
90-
case INFINI_DTYPE_I32: \
91-
cpu_cast_impl_incremental<TOUT, int32_t>(output, input, _info); \
92-
break; \
93-
case INFINI_DTYPE_I64: \
94-
cpu_cast_impl_incremental<TOUT, int64_t>(output, input, _info); \
95-
break; \
96-
case INFINI_DTYPE_U32: \
97-
cpu_cast_impl_incremental<TOUT, uint32_t>(output, input, _info); \
98-
break; \
99-
case INFINI_DTYPE_U64: \
100-
cpu_cast_impl_incremental<TOUT, uint64_t>(output, input, _info); \
101-
break; \
102-
case INFINI_DTYPE_F16: \
103-
cpu_cast_impl_incremental<TOUT, fp16_t>(output, input, _info); \
104-
break; \
105-
case INFINI_DTYPE_F32: \
106-
cpu_cast_impl_incremental<TOUT, float>(output, input, _info); \
107-
break; \
108-
case INFINI_DTYPE_F64: \
109-
cpu_cast_impl_incremental<TOUT, double>(output, input, _info); \
110-
break; \
111-
default: \
112-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
113-
} \
114-
break; \
115-
}
97+
#define CASE_OUT(DT_OUT, TOUT) \
98+
case DT_OUT: { \
99+
switch (_info.dt_in) { \
100+
case INFINI_DTYPE_I32: \
101+
cpu_cast_impl_incremental<TOUT, int32_t>(output, input, _info); \
102+
break; \
103+
case INFINI_DTYPE_I64: \
104+
cpu_cast_impl_incremental<TOUT, int64_t>(output, input, _info); \
105+
break; \
106+
case INFINI_DTYPE_U32: \
107+
cpu_cast_impl_incremental<TOUT, uint32_t>(output, input, _info); \
108+
break; \
109+
case INFINI_DTYPE_U64: \
110+
cpu_cast_impl_incremental<TOUT, uint64_t>(output, input, _info); \
111+
break; \
112+
case INFINI_DTYPE_F16: \
113+
cpu_cast_impl_incremental<TOUT, fp16_t>(output, input, _info); \
114+
break; \
115+
case INFINI_DTYPE_F32: \
116+
cpu_cast_impl_incremental<TOUT, float>(output, input, _info); \
117+
break; \
118+
case INFINI_DTYPE_F64: \
119+
cpu_cast_impl_incremental<TOUT, double>(output, input, _info); \
120+
break; \
121+
default: \
122+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
123+
} \
124+
break; \
125+
}
116126

117127
switch (_info.dt_out) {
118128
CASE_OUT(INFINI_DTYPE_I32, int32_t);
@@ -126,10 +136,9 @@ infiniStatus_t Descriptor::calculate(
126136
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
127137
}
128138

129-
#undef CASE_OUT
139+
#undef CASE_OUT
130140

131141
return INFINI_STATUS_SUCCESS;
132142
}
133143

134-
135144
} // namespace op::cast::cpu

src/infiniop/ops/cast/cuda/kernel.cuh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,12 @@ __global__ void cast_kernel(
6565
} else {
6666
idx_d = 0;
6767
}
68-
if (in_stride[d] != 0) in_off += static_cast<long long>(idx_d) * in_stride[d];
69-
if (out_stride[d] != 0) out_off += static_cast<long long>(idx_d) * out_stride[d];
68+
if (in_stride[d] != 0) {
69+
in_off += static_cast<long long>(idx_d) * in_stride[d];
70+
}
71+
if (out_stride[d] != 0) {
72+
out_off += static_cast<long long>(idx_d) * out_stride[d];
73+
}
7074
}
7175
out[static_cast<size_t>(out_off)] = device_cast<ToutDev, TinDev>(in[static_cast<size_t>(in_off)]);
7276
}

src/infiniop/ops/cast/info.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace op::cast {
1010
class CastInfo {
1111
CastInfo() = default;
1212

13-
public:
13+
public:
1414
infiniDtype_t dt_in;
1515
infiniDtype_t dt_out;
1616
std::vector<size_t> shape;
@@ -21,9 +21,9 @@ class CastInfo {
2121
static utils::Result<CastInfo> create(
2222
infiniopTensorDescriptor_t out_desc,
2323
infiniopTensorDescriptor_t in_desc) {
24-
24+
2525
auto dt_out = out_desc->dtype();
26-
auto dt_in = in_desc->dtype();
26+
auto dt_in = in_desc->dtype();
2727

2828
CHECK_DTYPE(dt_in,
2929
INFINI_DTYPE_I32, INFINI_DTYPE_I64,
@@ -40,14 +40,16 @@ class CastInfo {
4040
}
4141

4242
size_t n = 1;
43-
for (size_t i = 0; i < in_desc->ndim(); ++i) n *= static_cast<size_t>(in_desc->dim(i));
43+
for (size_t i = 0; i < in_desc->ndim(); ++i) {
44+
n *= static_cast<size_t>(in_desc->dim(i));
45+
}
4446

4547
return utils::Result<CastInfo>(CastInfo{
46-
dt_in,
47-
dt_out,
48-
out_desc->shape(),
49-
in_desc->strides(),
50-
out_desc->strides(),
48+
dt_in,
49+
dt_out,
50+
out_desc->shape(),
51+
in_desc->strides(),
52+
out_desc->strides(),
5153
n,
5254
});
5355
}

0 commit comments

Comments
 (0)