@@ -23,7 +23,11 @@ constexpr int amax_kernel_threads = 512;
2323template <int nvec, bool aligned, typename InputType>
2424__launch_bounds__ (amax_kernel_threads) __global__
2525 void amax_kernel (const InputType *input, float *amax, const size_t N,
26- const size_t num_aligned_elements) {
26+ const size_t num_aligned_elements, const float *noop_ptr) {
27+ if (noop_ptr != nullptr && noop_ptr[0 ] == 1 .0f ) {
28+ return ;
29+ }
30+
2731 VectorizedLoader<InputType, nvec, aligned> loader (input, N);
2832 InputType max = 0 .f ;
2933 const int warp_id = threadIdx .x / THREADS_PER_WARP;
@@ -58,7 +62,8 @@ __launch_bounds__(amax_kernel_threads) __global__
5862}
5963
6064template <int nvec, typename InputType>
61- void launch_amax_kernel (const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
65+ void launch_amax_kernel (const InputType *input, float *amax, const size_t N, const float *noop_ptr,
66+ cudaStream_t stream) {
6267 // Zero out amax so we can update with atomic max
6368 cudaMemsetAsync (amax, 0 , sizeof (float ), stream);
6469
@@ -81,16 +86,17 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
8186 switch (align) {
8287 case Alignment::SAME_ALIGNED:
8388 amax_kernel<nvec, true , InputType>
84- <<<num_blocks, threads, 0 , stream>>> (input, amax, N, num_aligned_elements);
89+ <<<num_blocks, threads, 0 , stream>>> (input, amax, N, num_aligned_elements, noop_ptr );
8590 break ;
8691 case Alignment::SAME_UNALIGNED:
8792 amax_kernel<nvec, false , InputType>
88- <<<num_blocks, threads, 0 , stream>>> (input, amax, N, num_aligned_elements);
93+ <<<num_blocks, threads, 0 , stream>>> (input, amax, N, num_aligned_elements, noop_ptr );
8994 break ;
9095 case Alignment::DIFFERENT: {
9196 // This case is a logic error, since there is only one pointer (input)
9297 // in the alignment check. Still safe to process without vectorization.
93- amax_kernel<1 , true , InputType><<<num_blocks, threads, 0 , stream>>> (input, amax, N, N);
98+ amax_kernel<1 , true , InputType>
99+ <<<num_blocks, threads, 0 , stream>>> (input, amax, N, N, noop_ptr);
94100 break ;
95101 }
96102 }
@@ -102,8 +108,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
102108} // namespace
103109} // namespace transformer_engine
104110
105- void nvte_compute_amax (const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
106- NVTE_API_CALL (nvte_compute_amax);
111+ namespace {
112+
113+ void compute_amax_impl (const NVTETensor input_, const NVTETensor output_, cudaStream_t stream,
114+ const NVTEQuantizationConfig config_) {
107115 using namespace transformer_engine ;
108116
109117 // Check input tensor
@@ -138,20 +146,49 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
138146 to_string (output.amax .dtype ), " )" );
139147 CheckOutputTensor (output, " output_compute_amax" , true );
140148
// Extract the optional noop tensor from config_ (used to skip work, e.g. during
// CUDA graph replay); a null config_ or null noop tensor disables the early exit.
151+ float *noop_ptr = nullptr ;
152+ if (config_ != nullptr ) {
153+ const QuantizationConfig *config_cpp = reinterpret_cast <const QuantizationConfig *>(config_);
154+
155+ // extract noop tensor from quant_config_cpp if it's not null
156+ const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr ;
157+ noop_ptr = reinterpret_cast <float *>(
158+ (noop != nullptr ? convertNVTETensorCheck (noop)->data .dptr : nullptr ));
159+ }
160+
141161 // Compute amax
142162 TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT (
143163 input.data .dtype , IType, constexpr int nvec = 32 / sizeof (IType);
144164 launch_amax_kernel<nvec>(reinterpret_cast <const IType *>(input.data .dptr ),
145165 reinterpret_cast <float *>(output.amax .dptr ), input.data .numel (),
146- stream);); // NOLINT(*)
166+ noop_ptr, stream);); // NOLINT(*)
167+ }
168+
169+ } // anonymous namespace
170+
/*! \brief Compute the absolute maximum of the input tensor into the output's amax field.
 *
 * Thin public entry point: delegates to the shared implementation with a null
 * quantization config, so no noop/early-exit tensor is consulted.
 */
void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_amax);
  compute_amax_impl(input_, output_, stream, nullptr);
}
175+
/*! \brief Compute amax like nvte_compute_amax, but honoring a quantization config.
 *
 * The config may carry a noop tensor; the visible kernels early-return when its
 * first element equals 1.0f (CUDA-graph-friendly skip). Delegates to the shared
 * implementation with the caller-provided config.
 */
void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor output_,
                                   const NVTEQuantizationConfig config_, cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_amax_with_config);
  compute_amax_impl(input_, output_, stream, config_);
}
148181
149182namespace transformer_engine {
150183namespace {
151184
// Derives the FP8 scale from a previously computed amax value.
// Launched with a single thread (<<<1, 1>>>) by the host wrapper below.
//   amax_ptr           - device pointer to the amax value
//   scale_ptr          - device pointer receiving the computed scale
//   max_fp8            - representable maximum of the target FP8 format
//   force_pow_2_scales - restrict the scale to powers of two
//   epsilon            - lower clamp applied during scale computation
//   noop_ptr           - optional device flag; when *noop_ptr == 1.0f the
//                        kernel returns without touching scale_ptr
__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
                                               const float max_fp8, const bool force_pow_2_scales,
                                               const float epsilon, const float *noop_ptr) {
  // Honor the caller's noop flag: skip all work when it is raised.
  const bool skip_work = (noop_ptr != nullptr) && (noop_ptr[0] == 1.0f);
  if (skip_work) {
    return;
  }

  *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon,
                                       std::numeric_limits<float>::max());
}
@@ -197,10 +234,21 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
197234 TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY (output.data .dtype , DType,
198235 max_fp8 = Quantized_Limits<DType>::max_norm;);
199236
237+ // noop tensor for cuda graph
238+ float *noop_ptr = nullptr ;
239+ if (config_ != nullptr ) {
240+ const QuantizationConfig *config_cpp = reinterpret_cast <const QuantizationConfig *>(config_);
241+
242+ // extract noop tensor from quant_config_cpp if it's not null
243+ const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr ;
244+ noop_ptr = reinterpret_cast <float *>(
245+ (noop != nullptr ? convertNVTETensorCheck (noop)->data .dptr : nullptr ));
246+ }
247+
200248 // Update scale
201249 compute_scale_from_amax_kernel<<<1 , 1 , 0 , stream>>> (
202250 reinterpret_cast <const float *>(output.amax .dptr ),
203251 reinterpret_cast <float *>(output.scale .dptr ), max_fp8, config.force_pow_2_scales ,
204- config.amax_epsilon );
252+ config.amax_epsilon , noop_ptr );
205253 NVTE_CHECK_CUDA (cudaGetLastError ());
206254}
0 commit comments