-
Notifications
You must be signed in to change notification settings - Fork 540
[Common] Split cast/gated kernels by scaling mode #2248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
36533fe
d7d8e43
30484be
1bac1f8
40cc697
bcdbee0
67e32c3
a72d53a
3e5edd2
bcbdf8a
3dea6bd
4a05df0
2433529
3526e54
6c0ed53
aa5ec2e
bcd55f6
7158b9d
0c4314e
41416b3
91ea154
e14ae55
f7225e9
37747dd
9afdba1
b764dea
703556c
ec58510
b84301b
ebcfafe
80df2fd
deb012b
6142ff7
25e9b48
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| #include <cuda.h> | ||
| #include <cudaTypedefs.h> | ||
| #include <cuda_runtime.h> | ||
| #include <transformer_engine/cast.h> | ||
| #include <transformer_engine/multi_stream.h> | ||
|
|
||
| #include "../common.h" | ||
| #include "../transpose/cast_transpose.h" | ||
| #include "../util/multi_stream.h" | ||
| #include "../utils.cuh" | ||
| #include "dispatch/dequantize.cuh" | ||
| #include "dispatch/quantize.cuh" | ||
| #include "transformer_engine/transpose.h" | ||
|
|
||
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize);
  using namespace transformer_engine;

  // Plain forward quantization: the IS_ACT template flag is false, so no
  // fused activation is applied and no quantization config is supplied.
  dispatch::quantize_fwd_helper<false, Empty, nullptr>(input, output, nullptr, stream);
}
|
|
||
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
                        cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_noop);
  using namespace transformer_engine;

  // Wrap the noop flag tensor in a quantization config and defer to the
  // config-based v2 entry point.
  QuantizationConfig cfg;
  cfg.noop_tensor = noop;
  nvte_quantize_v2(input, output, reinterpret_cast<NVTEQuantizationConfig>(&cfg), stream);
}
|
|
||
void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
                      const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_v2);
  using namespace transformer_engine;

  // Forward quantization with an explicit config; IS_ACT = false means no
  // fused activation is applied.
  dispatch::quantize_fwd_helper<false, Empty, nullptr>(input, output, quant_config, stream);
}
|
|
||
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
                         NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias);
  using namespace transformer_engine;

  // Backward quantization with a fused dbias reduction (IS_DBIAS = true) but
  // no activation gradient (IS_DACT = false), so the activation-input slot is
  // left empty.
  constexpr NVTETensor kNoActivationInput = nullptr;
  dispatch::quantize_bwd_helper<true, false, Empty, nullptr>(
      input, kNoActivationInput, output, dbias, workspace, nullptr, stream);
}
|
|
||
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_dequantize);
  using namespace transformer_engine;
  // Validate both handles up front, then hand off to the dispatch layer.
  const Tensor &input_tensor = *convertNVTETensorCheck(input);
  Tensor *const output_tensor = convertNVTETensorCheck(output);
  dispatch::dequantize_helper(input_tensor, output_tensor, stream);
}
|
|
||
/* Quantizes `num_tensors` input tensors concurrently by round-robining them
 * across the library's internal compute streams. The caller's `stream` is
 * fully ordered with the work: compute streams wait on it before starting,
 * and it waits on all of them before returning control downstream. */
void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
                                const NVTEQuantizationConfig quant_configs,
                                const size_t num_tensors, cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_tensor_quantize);
  using namespace transformer_engine;

  constexpr bool IS_ACT = false;

  const size_t num_streams = nvte_get_num_compute_streams();

  // Use size_t throughout: the original mixed `int` counters with size_t
  // bounds (signed/unsigned comparison) and narrowed std::min's size_t
  // result into an int.
  const size_t num_streams_used = std::min(num_streams, num_tensors);

  // Make every compute stream wait for work already queued on the caller's
  // stream.
  NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
  for (size_t s = 0; s < num_streams_used; s++) {
    NVTE_CHECK_CUDA(
        cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
  }

  // Distribute the tensors across the compute streams round-robin.
  for (size_t i = 0; i < num_tensors; i++) {
    dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(
        inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams));
  }

  // Record completion on each compute stream that was used...
  for (size_t s = 0; s < num_streams_used; s++) {
    NVTE_CHECK_CUDA(
        cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
  }
  // ...and make the caller's stream wait for all of them.
  for (size_t s = 0; s < num_streams_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| /*! \file common.cuh | ||
| * \brief Common functions in quantize. | ||
| */ | ||
|
|
||
| #ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. syntax: header guard mismatch – file is |
||
| #define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ | ||
|
|
||
| #include <cuda.h> | ||
| #include <cudaTypedefs.h> | ||
| #include <cuda_runtime.h> | ||
| #include <transformer_engine/transformer_engine.h> | ||
|
|
||
| #include "../../common.h" | ||
| #include "../../utils.cuh" | ||
|
|
||
| namespace transformer_engine { | ||
| namespace dispatch { | ||
| namespace common { | ||
// Returns true when the tensor's flat element count divides evenly into
// blocks of `elems_per_block` elements, i.e. every block is a full tile.
inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
  const size_t total_elems = product(t->data.shape);
  return (total_elems % elems_per_block) == 0;
}
|
|
||
// Returns true when the tensor's innermost dimension is a whole multiple of
// the 16-byte TMA transfer granularity for the tensor's element width.
inline bool dimensions_supported_by_TMA(const Tensor *const t) {
  constexpr size_t TMA_bytes = 16;
  const size_t elems_per_chunk = (TMA_bytes * 8) / typeToNumBits(t->dtype());
  const size_t inner_dim = t->flat_last_dim();
  return (inner_dim % elems_per_chunk) == 0;
}
|
|
||
namespace kernel {

constexpr size_t THREADS_PER_BLOCK = 256;

/* Reduces the [rows x cols] FP32 partial-dbias workspace along the row axis
 * and stores the result, cast to OType, into the cols-element dbias output.
 * Launch contract: 1D grid, THREADS_PER_BLOCK threads per block; each thread
 * owns one nvec-wide column slice. `cols` must be a multiple of nvec
 * (enforced by the caller). */
template <int nvec, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
                        const size_t rows, const size_t cols) {
  using ComputeVec = Vec<float, nvec>;
  using OutputVec = Vec<OType, nvec>;

  const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

  // Guard the grid tail: the last block may reach past the column count.
  if (thread_id * nvec >= cols) {
    return;
  }

  const float *const thread_in_base = dbias_partial + thread_id * nvec;
  OType *const thread_out_base = dbias_output + thread_id * nvec;

  ComputeVec ldg_vec;
  ComputeVec acc_vec;
  acc_vec.clear();
  // size_t row index: `rows` is size_t, so the original `int` counter caused
  // a signed/unsigned comparison and could overflow for very large inputs.
  for (size_t i = 0; i < rows; ++i) {
    ldg_vec.load_from(thread_in_base + i * cols);
#pragma unroll
    for (int e = 0; e < nvec; ++e) {
      acc_vec.data.elt[e] += ldg_vec.data.elt[e];
    }
  }

  OutputVec stg_vec;
#pragma unroll
  for (int e = 0; e < nvec; ++e) {
    stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]);
  }
  stg_vec.store_to(thread_out_base);
}
}  // namespace kernel
|
|
||
// Host-side launcher for reduce_dbias_kernel: reduces the [rows x cols] FP32
// workspace along rows and writes the IType result into `dbias`.
// Requires cols to be a multiple of the per-thread vector width.
template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
                  cudaStream_t stream) {
  using namespace kernel;
  // Each thread issues one 8-byte store (stg.64), i.e. 8/sizeof(IType)
  // elements per thread.
  constexpr size_t STORE_BYTES = 8;  // stg.64
  constexpr size_t NVEC = STORE_BYTES / sizeof(IType);

  NVTE_CHECK(cols % NVEC == 0, "Unsupported shape.");
  const size_t num_blocks = DIVUP(cols, THREADS_PER_BLOCK * NVEC);

  reduce_dbias_kernel<NVEC, IType><<<num_blocks, THREADS_PER_BLOCK, 0, stream>>>(
      reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
|
|
||
| } // namespace common | ||
| } // namespace dispatch | ||
| } // namespace transformer_engine | ||
|
|
||
| #endif // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: file comment says "in quantize" but this lives in
cast/core/– update to "Common functions in cast." or similar