Commit 0bf841e

add noop to comp amax

Signed-off-by: zhongboz <[email protected]>

1 parent 07db17b

3 files changed: +76, -11 lines

transformer_engine/common/include/transformer_engine/recipe.h

Lines changed: 15 additions & 0 deletions
@@ -84,6 +84,21 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
  */
 void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 
+/*! \brief Compute an FP8 tensor's amax with a quantization config.
+ *
+ *  The amax (maximum absolute value) of the input tensor is computed
+ *  and written to the amax buffer of the output tensor, using the provided
+ *  quantization configuration. One useful config option is the noop tensor,
+ *  which is needed for CUDA graph support.
+ *
+ *  \param[in]     input   Input tensor. Must be unquantized.
+ *  \param[in,out] output  Output tensor. Must be an FP8 tensor with per-tensor scaling.
+ *  \param[in]     config  Quantization configuration.
+ *  \param[in]     stream  CUDA stream used for the operation.
+ */
+void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
+                                   const NVTEQuantizationConfig config, cudaStream_t stream);
+
 /*! \brief Update an FP8 tensor's scale based on its amax.
  *
  *  This is only supported for FP8 tensors with per-tensor scaling.
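A hedged usage sketch for the new entry point (not part of this commit). The call signature is as declared above; make_config_with_noop is a hypothetical placeholder for however the caller builds an NVTEQuantizationConfig whose noop tensor wraps a one-element float buffer in device memory, and input, output, and stream are assumed to be set up already:

  #include <transformer_engine/recipe.h>

  // d_noop is a single float on device: the gated kernels treat exactly 1.0f
  // as "skip this launch"; any other value lets the computation run.
  NVTEQuantizationConfig cfg = make_config_with_noop(d_noop);  // hypothetical helper
  nvte_compute_amax_with_config(input, output, cfg, stream);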

transformer_engine/common/recipe/current_scaling.cu

Lines changed: 58 additions & 10 deletions
@@ -23,7 +23,11 @@ constexpr int amax_kernel_threads = 512;
 template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
     void amax_kernel(const InputType *input, float *amax, const size_t N,
-                     const size_t num_aligned_elements) {
+                     const size_t num_aligned_elements, const float *noop_ptr) {
+  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
+    return;
+  }
+
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max = 0.f;
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
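The early return above is what makes the kernel usable under CUDA graphs: a captured graph replays a fixed set of launches, so work can only be disabled per replay through a flag the kernel reads from device memory at run time. A minimal sketch of that pattern, assuming input, output, and a cfg_with_noop config whose noop tensor wraps the device flag d_noop (names illustrative, not from this commit):

  #include <cuda_runtime.h>
  #include <transformer_engine/recipe.h>

  // Capture the amax computation into a graph once...
  cudaGraph_t graph;
  cudaGraphExec_t exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  nvte_compute_amax_with_config(input, output, cfg_with_noop, stream);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);

  // ...then toggle the device-side flag between replays instead of re-capturing.
  const float kSkip = 1.0f, kRun = 0.0f;
  cudaMemcpyAsync(d_noop, &kSkip, sizeof(float), cudaMemcpyHostToDevice, stream);
  cudaGraphLaunch(exec, stream);  // replay: amax_kernel early-returns
  cudaMemcpyAsync(d_noop, &kRun, sizeof(float), cudaMemcpyHostToDevice, stream);
  cudaGraphLaunch(exec, stream);  // replay: amax_kernel does the work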
@@ -58,7 +62,8 @@ __launch_bounds__(amax_kernel_threads) __global__
 }
 
 template <int nvec, typename InputType>
-void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
+void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr,
+                        cudaStream_t stream) {
   // Zero out amax so we can update with atomic max
   cudaMemsetAsync(amax, 0, sizeof(float), stream);
 
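The "zero out, then atomic max" idiom in this launcher needs an atomic float max, which CUDA does not provide directly. One standard trick for non-negative values (always the case for an amax, which is a maximum of absolute values) is to reuse the integer atomicMax, since IEEE-754 bit patterns of non-negative floats are monotonic when compared as signed ints. A sketch of that trick, not necessarily this library's exact implementation:

  __device__ void atomic_max_float_nonneg(float *addr, float val) {
    // Valid only when *addr and val are both >= 0.
    atomicMax(reinterpret_cast<int *>(addr), __float_as_int(val));
  }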
@@ -81,16 +86,17 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
   switch (align) {
     case Alignment::SAME_ALIGNED:
       amax_kernel<nvec, true, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements, noop_ptr);
       break;
     case Alignment::SAME_UNALIGNED:
       amax_kernel<nvec, false, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements, noop_ptr);
       break;
     case Alignment::DIFFERENT: {
       // This case is a logic error, since there is only one pointer (input)
       // in the alignment check. Still safe to process without vectorization.
-      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
+      amax_kernel<1, true, InputType>
+          <<<num_blocks, threads, 0, stream>>>(input, amax, N, N, noop_ptr);
       break;
     }
   }
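For context, the dispatch above follows the usual vectorized-load rule: the wide path is safe only when the pointer is aligned to the full vector width. A simplified version of that check, independent of TE's actual Alignment helper:

  #include <cstdint>

  // True if p can be loaded nvec elements at a time.
  template <typename T, int nvec>
  bool is_vector_aligned(const T *p) {
    return reinterpret_cast<std::uintptr_t>(p) % (nvec * sizeof(T)) == 0;
  }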
@@ -102,8 +108,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
 }  // namespace
 }  // namespace transformer_engine
 
-void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
-  NVTE_API_CALL(nvte_compute_amax);
+namespace {
+
+void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream,
+                       const NVTEQuantizationConfig config_) {
   using namespace transformer_engine;
 
   // Check input tensor
@@ -138,20 +146,49 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
                to_string(output.amax.dtype), ")");
   CheckOutputTensor(output, "output_compute_amax", true);
 
+  // Extract the optional noop tensor from the quantization config
+  // (used to skip the amax kernel, e.g. under CUDA graph capture/replay).
+  float *noop_ptr = nullptr;
+  if (config_ != nullptr) {
+    const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
+
+    // extract noop tensor from config_cpp if it's not null
+    const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr;
+    noop_ptr = reinterpret_cast<float *>(
+        (noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr));
+  }
+
   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
       launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
                                reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
-                               stream););  // NOLINT(*)
+                               noop_ptr, stream););  // NOLINT(*)
+}
+
+}  // anonymous namespace
+
+void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_compute_amax);
+  compute_amax_impl(input_, output_, stream, nullptr);
+}
+
+void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor output_,
+                                   const NVTEQuantizationConfig config_, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_compute_amax_with_config);
+  compute_amax_impl(input_, output_, stream, config_);
 }
 
 namespace transformer_engine {
 namespace {
 
 __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
                                                const float max_fp8, const bool force_pow_2_scales,
-                                               const float epsilon) {
+                                               const float epsilon, const float *noop_ptr) {
+  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
+    return;
+  }
+
   *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon,
                                        std::numeric_limits<float>::max());
 }
@@ -197,10 +234,21 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
   TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType,
                                          max_fp8 = Quantized_Limits<DType>::max_norm;);
 
+  // noop tensor for cuda graph
+  float *noop_ptr = nullptr;
+  if (config_ != nullptr) {
+    const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
+
+    // extract noop tensor from config_cpp if it's not null
+    const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr;
+    noop_ptr = reinterpret_cast<float *>(
+        (noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr));
+  }
+
   // Update scale
   compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
       reinterpret_cast<const float *>(output.amax.dptr),
       reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales,
-      config.amax_epsilon);
+      config.amax_epsilon, noop_ptr);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
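Note that nvte_compute_amax_with_config and nvte_compute_scale_from_amax both read the noop flag out of the same config, so when the two calls share one config, a single one-float device write gates the whole amax-plus-scale-update step between graph replays (d_noop names the buffer the config's noop tensor wraps, for illustration):

  const float kSkip = 1.0f;  // the gated kernels treat exactly 1.0f as "skip"
  cudaMemcpyAsync(d_noop, &kSkip, sizeof(float), cudaMemcpyHostToDevice, stream);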

transformer_engine/pytorch/csrc/quantizer.cpp

Lines changed: 3 additions & 1 deletion
@@ -518,7 +518,9 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te
 
   // Compute amax
   if (compute_amax) {
-    NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); });
+    // NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); });
+    NVTE_SCOPED_GIL_RELEASE(
+        { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); });
   }
 
   // Perform amax reduction if needed
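Since compute_amax_impl treats a null config as "no noop tensor", the new call is a strict superset of the old one: with an empty config the two entry points take the same path, and the noop gate only activates when quant_config actually carries a noop tensor.

  // Equivalent, per compute_amax_impl above:
  nvte_compute_amax(input, output, stream);
  nvte_compute_amax_with_config(input, output, /*config=*/nullptr, stream);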
