Skip to content

Commit 00fbe4f

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent eef3c93 commit 00fbe4f

File tree

7 files changed

+2605
-2824
lines changed

7 files changed

+2605
-2824
lines changed

transformer_engine/common/common.cu

Lines changed: 26 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -197,17 +197,11 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
197197
}
198198

199199
// TMA descriptor with swizzle
200-
void create_2D_tensor_map(
201-
CUtensorMap &tensorMap,
202-
const SimpleTensor &tensor,
203-
CUtensorMapSwizzle swizzle,
204-
const uint64_t globalY,
205-
const uint64_t globalX,
206-
const uint32_t shmemY,
207-
const uint32_t shmemX,
208-
const uint32_t stride_elems,
209-
const uint32_t offset_elems,
210-
const size_t type_num_bits) {
200+
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
201+
CUtensorMapSwizzle swizzle, const uint64_t globalY,
202+
const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX,
203+
const uint32_t stride_elems, const uint32_t offset_elems,
204+
const size_t type_num_bits) {
211205
static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
212206
void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled");
213207
return reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(driver_ptr);
@@ -227,46 +221,37 @@ void create_2D_tensor_map(
227221

228222
const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype);
229223
void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
230-
(offset_elems * type_num_bits) / 8);
231-
224+
(offset_elems * type_num_bits) / 8);
225+
232226
NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT),
233227
"Tensor data pointer must be 16B aligned");
234228
const int32_t TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits;
235229
NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
236230
"-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
237231

238232
int32_t swizzle_size = [&]() {
239-
switch (swizzle) {
240-
case CU_TENSOR_MAP_SWIZZLE_32B:
241-
return 32;
242-
case CU_TENSOR_MAP_SWIZZLE_64B:
243-
return 64;
244-
case CU_TENSOR_MAP_SWIZZLE_128B:
245-
return 128;
246-
case CU_TENSOR_MAP_SWIZZLE_NONE:
247-
default:
248-
return 0;
249-
}
250-
}();
233+
switch (swizzle) {
234+
case CU_TENSOR_MAP_SWIZZLE_32B:
235+
return 32;
236+
case CU_TENSOR_MAP_SWIZZLE_64B:
237+
return 64;
238+
case CU_TENSOR_MAP_SWIZZLE_128B:
239+
return 128;
240+
case CU_TENSOR_MAP_SWIZZLE_NONE:
241+
default:
242+
return 0;
243+
}
244+
}();
251245
if (swizzle != CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE) {
252-
NVTE_CHECK(boxSize[0] * (type_num_bits / 8) <= swizzle_size,
253-
"boxSize[0]:", boxSize[0], " must be less than swizzle size:", swizzle_size);
246+
NVTE_CHECK(boxSize[0] * (type_num_bits / 8) <= swizzle_size, "boxSize[0]:", boxSize[0],
247+
" must be less than swizzle size:", swizzle_size);
254248
}
255249

256-
NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled(
257-
&tensorMap,
258-
tensorDataType,
259-
rank,
260-
dataPtr,
261-
size,
262-
stride,
263-
boxSize,
264-
elemStride,
265-
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
266-
swizzle,
267-
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
268-
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
269-
));
250+
NVTE_CHECK_CUDA_DRIVER(
251+
cuDriverTensorMapEncodeTiled(&tensorMap, tensorDataType, rank, dataPtr, size, stride, boxSize,
252+
elemStride, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
253+
swizzle, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
254+
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
270255
}
271256

272257
bool is_supported_by_CC_100() {

transformer_engine/common/common.h

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -710,17 +710,11 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
710710
const uint32_t shmemX, const uint32_t stride_elems,
711711
const uint32_t offset_elems, const size_t type_num_bits);
712712

713-
void create_2D_tensor_map(
714-
CUtensorMap &tensorMap,
715-
const SimpleTensor &tensor,
716-
CUtensorMapSwizzle swizzle,
717-
const uint64_t globalY,
718-
const uint64_t globalX,
719-
const uint32_t shmemY,
720-
const uint32_t shmemX,
721-
const uint32_t stride_elems,
722-
const uint32_t offset_elems,
723-
const size_t type_num_bits);
713+
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
714+
CUtensorMapSwizzle swizzle, const uint64_t globalY,
715+
const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX,
716+
const uint32_t stride_elems, const uint32_t offset_elems,
717+
const size_t type_num_bits);
724718

725719
bool is_supported_by_CC_100();
726720

transformer_engine/common/util/cast_kernels.cuh

Lines changed: 24 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
#include "../transpose/cast_transpose.h"
2323
#include "../util/vectorized_pointwise.h"
2424
#include "../utils.cuh"
25+
#include "cast_kernels_spec.cuh"
2526
#include "math.h"
2627
#include "ptx.cuh"
2728
#include "transformer_engine/transformer_engine.h"
28-
#include "cast_kernels_spec.cuh"
2929

3030
namespace transformer_engine {
3131

@@ -1083,31 +1083,21 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
10831083
output->dtype(), OType,
10841084

10851085
if (spec::hasSpec<IS_DBIAS, IS_DACT, IS_ACT, IType, OType>()) {
1086-
10871086
switch (scaling_type) {
10881087
case ScalingType::ROWWISE: {
10891088
using traits = spec::CastTraits<IType, OType, true, false>;
10901089
auto kernel = spec::cast_mxfp8_kernel<traits>;
10911090

1092-
cudaFuncSetAttribute(
1093-
kernel,
1094-
cudaFuncAttributeMaxDynamicSharedMemorySize,
1095-
traits::smem);
1091+
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
1092+
traits::smem);
10961093

1097-
dim3 block(traits::threadLayout::num,
1098-
traits::warpLayout::N,
1099-
traits::warpLayout::M);
1094+
dim3 block(traits::threadLayout::num, traits::warpLayout::N, traits::warpLayout::M);
11001095
dim3 grid((cols + traits::blockDimN - 1) / traits::blockDimN,
11011096
(rows + traits::blockDimM - 1) / traits::blockDimM);
11021097
kernel<<<grid, block, traits::smem, stream>>>(
1103-
reinterpret_cast<typename traits::IType *>(input.data.dptr),
1104-
reinterpret_cast<typename traits::OType *>(output->data.dptr),
1105-
scales_rowwise_ptr,
1106-
rows,
1107-
cols,
1108-
scale_stride_rowwise,
1109-
scale_stride_colwise
1110-
);
1098+
reinterpret_cast<typename traits::IType *>(input.data.dptr),
1099+
reinterpret_cast<typename traits::OType *>(output->data.dptr),
1100+
scales_rowwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
11111101

11121102
break;
11131103
}
@@ -1119,55 +1109,35 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
11191109
using traits = spec::CastTraits<IType, OType, true, true>;
11201110
auto kernel = spec::cast_mxfp8_kernel<traits>;
11211111

1122-
cudaFuncSetAttribute(
1123-
kernel,
1124-
cudaFuncAttributeMaxDynamicSharedMemorySize,
1125-
traits::smem
1126-
);
1112+
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
1113+
traits::smem);
11271114
// TMA for loading, so that we don't need STS for transposing
11281115
alignas(64) CUtensorMap tensor_map_input{};
11291116
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
1130-
create_2D_tensor_map(tensor_map_input,
1131-
input.data,
1132-
traits::input_swizzle_pattern,
1133-
rows, cols,
1134-
traits::blockIterDim::M, traits::blockIterDim::N,
1135-
/*stride_elems=*/cols,
1136-
/*offset_elems=*/0,
1137-
input_type_bit_size);
1117+
create_2D_tensor_map(tensor_map_input, input.data, traits::input_swizzle_pattern,
1118+
rows, cols, traits::blockIterDim::M, traits::blockIterDim::N,
1119+
/*stride_elems=*/cols,
1120+
/*offset_elems=*/0, input_type_bit_size);
11381121
alignas(64) CUtensorMap tensor_map_rowwise_output{};
11391122
alignas(64) CUtensorMap tensor_map_colwise_output{};
11401123
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
1141-
create_2D_tensor_map(tensor_map_rowwise_output,
1142-
output->data,
1143-
traits::output_swizzle_pattern,
1144-
rows, cols,
1124+
create_2D_tensor_map(tensor_map_rowwise_output, output->data,
1125+
traits::output_swizzle_pattern, rows, cols,
11451126
traits::blockIterDim::M, traits::blockIterDim::N,
1146-
/*stride_elems=*/cols,
1147-
/*offset_elems=*/0,
1127+
/*stride_elems=*/cols,
1128+
/*offset_elems=*/0, output_type_bit_size);
1129+
create_2D_tensor_map(tensor_map_colwise_output, output->columnwise_data,
1130+
traits::output_swizzle_pattern, rows, cols,
1131+
traits::blockIterDim::M, traits::blockIterDim::N, cols, 0,
11481132
output_type_bit_size);
1149-
create_2D_tensor_map(tensor_map_colwise_output,
1150-
output->columnwise_data,
1151-
traits::output_swizzle_pattern,
1152-
rows, cols,
1153-
traits::blockIterDim::M, traits::blockIterDim::N,
1154-
cols, 0, output_type_bit_size);
11551133

1156-
dim3 block(traits::rowThreadLayout::num,
1157-
traits::numWarps);
1134+
dim3 block(traits::rowThreadLayout::num, traits::numWarps);
11581135
dim3 grid((cols + traits::blockDIM::N - 1) / traits::blockDIM::N,
11591136
(rows + traits::blockDIM::M - 1) / traits::blockDIM::M);
11601137
kernel<<<grid, block, traits::smem, stream>>>(
1161-
tensor_map_input,
1162-
tensor_map_rowwise_output,
1163-
tensor_map_colwise_output,
1164-
scales_rowwise_ptr,
1165-
scales_colwise_ptr,
1166-
rows,
1167-
cols,
1168-
scale_stride_rowwise,
1169-
scale_stride_colwise
1170-
);
1138+
tensor_map_input, tensor_map_rowwise_output, tensor_map_colwise_output,
1139+
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
1140+
scale_stride_colwise);
11711141

11721142
break;
11731143
}

0 commit comments

Comments
 (0)