22
22
#include " ../transpose/cast_transpose.h"
23
23
#include " ../util/vectorized_pointwise.h"
24
24
#include " ../utils.cuh"
25
+ #include " cast_kernels_spec.cuh"
25
26
#include " math.h"
26
27
#include " ptx.cuh"
27
28
#include " transformer_engine/transformer_engine.h"
28
- #include " cast_kernels_spec.cuh"
29
29
30
30
namespace transformer_engine {
31
31
@@ -1083,31 +1083,21 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
1083
1083
output->dtype (), OType,
1084
1084
1085
1085
if (spec::hasSpec<IS_DBIAS, IS_DACT, IS_ACT, IType, OType>()) {
1086
-
1087
1086
switch (scaling_type) {
1088
1087
case ScalingType::ROWWISE: {
1089
1088
using traits = spec::CastTraits<IType, OType, true , false >;
1090
1089
auto kernel = spec::cast_mxfp8_kernel<traits>;
1091
1090
1092
- cudaFuncSetAttribute (
1093
- kernel,
1094
- cudaFuncAttributeMaxDynamicSharedMemorySize,
1095
- traits::smem);
1091
+ cudaFuncSetAttribute (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
1092
+ traits::smem);
1096
1093
1097
- dim3 block (traits::threadLayout::num,
1098
- traits::warpLayout::N,
1099
- traits::warpLayout::M);
1094
+ dim3 block (traits::threadLayout::num, traits::warpLayout::N, traits::warpLayout::M);
1100
1095
dim3 grid ((cols + traits::blockDimN - 1 ) / traits::blockDimN,
1101
1096
(rows + traits::blockDimM - 1 ) / traits::blockDimM);
1102
1097
kernel<<<grid, block, traits::smem, stream>>> (
1103
- reinterpret_cast <typename traits::IType *>(input.data .dptr ),
1104
- reinterpret_cast <typename traits::OType *>(output->data .dptr ),
1105
- scales_rowwise_ptr,
1106
- rows,
1107
- cols,
1108
- scale_stride_rowwise,
1109
- scale_stride_colwise
1110
- );
1098
+ reinterpret_cast <typename traits::IType *>(input.data .dptr ),
1099
+ reinterpret_cast <typename traits::OType *>(output->data .dptr ),
1100
+ scales_rowwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
1111
1101
1112
1102
break ;
1113
1103
}
@@ -1119,55 +1109,35 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
1119
1109
using traits = spec::CastTraits<IType, OType, true , true >;
1120
1110
auto kernel = spec::cast_mxfp8_kernel<traits>;
1121
1111
1122
- cudaFuncSetAttribute (
1123
- kernel,
1124
- cudaFuncAttributeMaxDynamicSharedMemorySize,
1125
- traits::smem
1126
- );
1112
+ cudaFuncSetAttribute (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
1113
+ traits::smem);
1127
1114
// TMA for loading, so that we don't need STS for transposing
1128
1115
alignas (64 ) CUtensorMap tensor_map_input{};
1129
1116
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
1130
- create_2D_tensor_map (tensor_map_input,
1131
- input.data ,
1132
- traits::input_swizzle_pattern,
1133
- rows, cols,
1134
- traits::blockIterDim::M, traits::blockIterDim::N,
1135
- /* stride_elems=*/ cols,
1136
- /* offset_elems=*/ 0 ,
1137
- input_type_bit_size);
1117
+ create_2D_tensor_map (tensor_map_input, input.data , traits::input_swizzle_pattern,
1118
+ rows, cols, traits::blockIterDim::M, traits::blockIterDim::N,
1119
+ /* stride_elems=*/ cols,
1120
+ /* offset_elems=*/ 0 , input_type_bit_size);
1138
1121
alignas (64 ) CUtensorMap tensor_map_rowwise_output{};
1139
1122
alignas (64 ) CUtensorMap tensor_map_colwise_output{};
1140
1123
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
1141
- create_2D_tensor_map (tensor_map_rowwise_output,
1142
- output->data ,
1143
- traits::output_swizzle_pattern,
1144
- rows, cols,
1124
+ create_2D_tensor_map (tensor_map_rowwise_output, output->data ,
1125
+ traits::output_swizzle_pattern, rows, cols,
1145
1126
traits::blockIterDim::M, traits::blockIterDim::N,
1146
- /* stride_elems=*/ cols,
1147
- /* offset_elems=*/ 0 ,
1127
+ /* stride_elems=*/ cols,
1128
+ /* offset_elems=*/ 0 , output_type_bit_size);
1129
+ create_2D_tensor_map (tensor_map_colwise_output, output->columnwise_data ,
1130
+ traits::output_swizzle_pattern, rows, cols,
1131
+ traits::blockIterDim::M, traits::blockIterDim::N, cols, 0 ,
1148
1132
output_type_bit_size);
1149
- create_2D_tensor_map (tensor_map_colwise_output,
1150
- output->columnwise_data ,
1151
- traits::output_swizzle_pattern,
1152
- rows, cols,
1153
- traits::blockIterDim::M, traits::blockIterDim::N,
1154
- cols, 0 , output_type_bit_size);
1155
1133
1156
- dim3 block (traits::rowThreadLayout::num,
1157
- traits::numWarps);
1134
+ dim3 block (traits::rowThreadLayout::num, traits::numWarps);
1158
1135
dim3 grid ((cols + traits::blockDIM::N - 1 ) / traits::blockDIM::N,
1159
1136
(rows + traits::blockDIM::M - 1 ) / traits::blockDIM::M);
1160
1137
kernel<<<grid, block, traits::smem, stream>>> (
1161
- tensor_map_input,
1162
- tensor_map_rowwise_output,
1163
- tensor_map_colwise_output,
1164
- scales_rowwise_ptr,
1165
- scales_colwise_ptr,
1166
- rows,
1167
- cols,
1168
- scale_stride_rowwise,
1169
- scale_stride_colwise
1170
- );
1138
+ tensor_map_input, tensor_map_rowwise_output, tensor_map_colwise_output,
1139
+ scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
1140
+ scale_stride_colwise);
1171
1141
1172
1142
break ;
1173
1143
}
0 commit comments