From 638b533eca3ab58d9ebc08be849473eaa758a19d Mon Sep 17 00:00:00 2001
From: dasistwo <99160400+dasistwo@users.noreply.github.com>
Date: Thu, 11 Apr 2024 06:24:04 +0000
Subject: [PATCH 01/28] Change config files to suit the A100 server

---
 docker/Dockerfile.multi | 2 +-
 docker/Makefile         | 4 ++--
 examples/summarize.py   | 1 +
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/docker/Dockerfile.multi b/docker/Dockerfile.multi
index 7f256b79b..e1ad78ce7 100644
--- a/docker/Dockerfile.multi
+++ b/docker/Dockerfile.multi
@@ -60,7 +60,7 @@ COPY tensorrt_llm tensorrt_llm
 COPY 3rdparty 3rdparty
 COPY setup.py requirements.txt requirements-dev.txt ./

-ARG BUILD_WHEEL_ARGS="--clean --trt_root /usr/local/tensorrt --python_bindings --benchmarks"
+ARG BUILD_WHEEL_ARGS="--clean --trt_root /usr/local/tensorrt --nvtx -a "80-real" --python_bindings --benchmarks"
 RUN python3 scripts/build_wheel.py ${BUILD_WHEEL_ARGS}

 FROM ${DEVEL_IMAGE} as release
diff --git a/docker/Makefile b/docker/Makefile
index e219a9df2..100911816 100644
--- a/docker/Makefile
+++ b/docker/Makefile
@@ -96,8 +96,8 @@ endef
 	@echo "Pulling docker image: $(IMAGE_WITH_TAG)"
 	docker pull $(IMAGE_WITH_TAG)

-DOCKER_RUN_OPTS ?= --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864
-DOCKER_RUN_ARGS ?=
+DOCKER_RUN_OPTS ?= -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864
+DOCKER_RUN_ARGS ?= --volume /data/storage1/model:/mnt/model --privileged=true
 GPU_OPTS ?= --gpus=all
 SOURCE_DIR ?= $(shell readlink -f ..)
 CODE_DIR ?= /code/tensorrt_llm
diff --git a/examples/summarize.py b/examples/summarize.py
index 13e5de43b..2f5bf5ad3 100644
--- a/examples/summarize.py
+++ b/examples/summarize.py
@@ -363,6 +363,7 @@ def eval_hf(datapoint,
         max_output_len=output_len,
         max_beam_width=num_beams,
         max_attention_window_size=max_attention_window_size,
+        free_gpu_memory_fraction=0.75,
         sink_token_length=sink_token_length)
     runner = runner_cls.from_dir(**runner_kwargs)
     assert not (args.eval_ppl and not (runner.gather_context_logits and runner.gather_generation_logits)), \

From 3d643cc0b1071c95d5ae2fbe1fc8d2c831a1e534 Mon Sep 17 00:00:00 2001
From: dasistwo <99160400+dasistwo@users.noreply.github.com>
Date: Thu, 11 Apr 2024 06:25:06 +0000
Subject: [PATCH 02/28] Update submodules

---
 3rdparty/cutlass | 2 +-
 3rdparty/json    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index a8f2c80db..bbe579a9e 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit a8f2c80db0564c74f4efccac71993b971dfc448b
+Subproject commit bbe579a9e3beb6ea6626d9227ec32d0dae119a49
diff --git a/3rdparty/json b/3rdparty/json
index bc889afb4..9cca280a4 160000
--- a/3rdparty/json
+++ b/3rdparty/json
@@ -1 +1 @@
-Subproject commit bc889afb4c5bf1c0d8ee29ef35eaaf4c8bef8a5d
+Subproject commit 9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03

From c03b40778ad01dd44d7618330c22b326c5a6b093 Mon Sep 17 00:00:00 2001
From: dasistwo <99160400+dasistwo@users.noreply.github.com>
Date: Thu, 18 Apr 2024 08:31:39 +0000
Subject: [PATCH 03/28] Change summarization task default setting

---
 .gitignore            | 3 +++
 examples/summarize.py | 3 ++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index cb9aee85b..2d7f47d2e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,6 +33,9 @@ tensorrt_llm/libs
 tensorrt_llm/bindings.pyi
 tensorrt_llm/bindings/*.pyi

+# Debugging Purpose
+.env
+
 # Testing
 .coverage.*
 results_trt/
diff --git a/examples/summarize.py b/examples/summarize.py
index af926fc24..4e44dec80
100644 --- a/examples/summarize.py +++ b/examples/summarize.py @@ -88,6 +88,7 @@ def main(args): dataset = load_dataset(dataset_name, dataset_revision, cache_dir=args.dataset_path, + trust_remote_code=True, split=dataset_split) max_batch_size = args.batch_size @@ -363,7 +364,7 @@ def eval_hf(datapoint, max_output_len=output_len, max_beam_width=num_beams, max_attention_window_size=max_attention_window_size, - free_gpu_memory_fraction=0.75, + free_gpu_memory_fraction=0.6, sink_token_length=sink_token_length) runner = runner_cls.from_dir(**runner_kwargs) assert not (args.eval_ppl and not (runner.gather_context_logits and runner.gather_generation_logits)), \ From 298fcc176b745c948b4af96f8092f36a471af024 Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Tue, 23 Apr 2024 11:08:00 +0000 Subject: [PATCH 04/28] Change CMakeLists for a debugging purpose --- cpp/CMakeLists.txt | 14 ++++++++++++++ cpp/tensorrt_llm/plugins/CMakeLists.txt | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4cffce5a3..a0acde279 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -273,6 +273,16 @@ if(FAST_MATH) message("CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") endif() +# Add the option with CMAKE_CUDA_FLAGS causes error. +if(${CMAKE_BUILD_TYPE} MATCHES "Debug") + add_compile_options("$<$:-G>" + "$<$:--host-linker-script>") +elseif(${CMAKE_BUILD_TYPE} MATCHES "RelWithDebInfo") + add_compile_options("$<$:-lineinfo>" + "$<$:--host-linker-script>") +endif() + + set(COMMON_HEADER_DIRS ${PROJECT_SOURCE_DIR} ${CUDAToolkit_INCLUDE_DIR}) message(STATUS "COMMON_HEADER_DIRS: ${COMMON_HEADER_DIRS}") @@ -325,6 +335,10 @@ if(BUILD_PYT) link_directories("${Python3_LIBRARY_DIRS}") list(APPEND COMMON_HEADER_DIRS ${Python3_INCLUDE_DIRS}) + # Let torch find the cudnn and cusparselt libraries + set(CAFFE2_USE_CUDNN ON) + set(CAFFE2_USE_CUSPARSELT ON) + execute_process( COMMAND ${Python3_EXECUTABLE} "-c" diff --git a/cpp/tensorrt_llm/plugins/CMakeLists.txt b/cpp/tensorrt_llm/plugins/CMakeLists.txt index 3afb88422..e8b090d53 100755 --- a/cpp/tensorrt_llm/plugins/CMakeLists.txt +++ b/cpp/tensorrt_llm/plugins/CMakeLists.txt @@ -101,7 +101,7 @@ else() ${PLUGIN_SHARED_TARGET} PROPERTIES LINK_FLAGS - "-Wl,--exclude-libs,ALL -Wl,--version-script=${PLUGIN_EXPORT_MAP} -Wl,-rpath,'$ORIGIN' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}" + "-Wl,--no-relax -Wl,--exclude-libs,ALL -Wl,--version-script=${PLUGIN_EXPORT_MAP} -Wl,-rpath,'$ORIGIN' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}" ) endif() From dfc29ad61b9460a594fed7fd7a19d1f7721f4443 Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Tue, 23 Apr 2024 11:08:51 +0000 Subject: [PATCH 05/28] Apply shared mem to scale factor of quantization. 
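
This patch adds an SHMemIterator and stages the per-group dequantization scale factors in
dynamic shared memory (sized at kernel launch) so the threads of a CTA can share them
instead of each re-reading global memory. The fragment below is only a rough,
self-contained sketch of that idea; the kernel name, tile size, and launch shape are
illustrative and do not reflect the actual SHMemIterator interface in the diff:

    // Stage one tile of scale factors in dynamic shared memory, then read it back.
    template <typename T, int TileN>
    __global__ void stage_scales(T const* __restrict__ g_scales, T* out, int n)
    {
        extern __shared__ unsigned char smem_raw[];   // sized by the launch below
        T* s_scales = reinterpret_cast<T*>(smem_raw);
        int const tile_base = blockIdx.x * TileN;
        // Cooperative copy: each thread loads a strided subset of the tile.
        for (int i = threadIdx.x; i < TileN && tile_base + i < n; i += blockDim.x)
            s_scales[i] = g_scales[tile_base + i];
        __syncthreads();
        // Every thread can now read any scale of the tile from shared memory;
        // copying the tile back out stands in for the dequantize loop.
        for (int i = threadIdx.x; i < TileN && tile_base + i < n; i += blockDim.x)
            out[tile_base + i] = s_scales[i];
    }

    // The dynamic shared-memory size is the third <<<>>> launch argument, mirroring the
    // size computation this patch adds to exec_kernel:
    //   stage_scales<half, 256><<<(n + 255) / 256, 128, 256 * sizeof(half)>>>(d_scales, d_out, n);
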
--- .../kernels/weightOnlyBatchedGemv/common.h | 4 ++ .../kernels/weightOnlyBatchedGemv/kernel.h | 25 +++++++--- .../kernels/weightOnlyBatchedGemv/utility.h | 49 +++++++++++++++++++ 3 files changed, 71 insertions(+), 7 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h index db0762351..702cc2dc2 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h @@ -24,6 +24,10 @@ #include #include +#ifdef __CUDACC__ +#include +#endif + namespace tensorrt_llm { namespace kernels diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 29dfcbdc1..0592cd0e9 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -71,9 +71,11 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca GMemIterator weight_iterator(weight, (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW); - GMemIterator scales_iterator(scales, - (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, - (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); + SHMemIterator scales_iterator(scales, + (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + interleaved_offset_n * Details::kInterleave, + ((tid * StepK / Details::LayoutDeatils::kTileSize) % Details::kInterleave), + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave, + (GroupSize != 0 ? real_offset_k / GroupSize : 0)); GMemIterator zeros_iterator(zeros, (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); @@ -87,25 +89,30 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca TypeA tile_acc[CtaM * CtaN]; fill(tile_acc, static_cast(0.f)); + extern __shared__ TypeA shmem_sz []; + // dimension of each is [kInterleave * CtaN, blockDim.x / (GroupSize * kInterleave / StepK)] + TypeA* vec_scale = (TypeA*)shmem_sz; for (int idx_k = tid * StepK, iter = 0; idx_k < interleaved_k; idx_k += CtaK, ++iter) { TypeA vec_act_scale[StepK]; - TypeA vec_scale[CtaN], vec_zero[CtaN]; + TypeA vec_zero[CtaN]; TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK]; uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; + #pragma unroll for (int i = 0; i < CtaN; ++i) { - scales_iterator.load(vec_scale + i, iter, i); zeros_iterator.load(vec_zero + i, iter, i); } + scales_iterator.load(vec_scale, iter); act_scale_iterator.load(vec_act_scale, iter); #pragma unroll for (int i = 0; i < CtaN; ++i) { weight_iterator.load(tile_w_quantized, iter, i); dequantize( - tile_w, tile_w_quantized, vec_scale + i, vec_zero + i, alpha); + tile_w, tile_w_quantized, scales_iterator.stride_iter(i), + vec_zero + i, alpha); pack_to_vec2(tile_w_pack2, tile_w, i); } #pragma unroll @@ -131,7 +138,11 @@ void exec_kernel(Params& params, cudaStream_t s) dim3 grid(params.m / CtaM, params.n / (CtaN * Details::kInterleave)); dim3 block(Threads); // clang-format off - kernel<<>>( + kernel<<< + grid, block, + CtaN * (GroupSize != 0 ? 
Details::kStepK * Threads / GroupSize : Details::kInterleave) * sizeof(T), + s + >>>( reinterpret_cast(params.act), reinterpret_cast(params.act_scale), reinterpret_cast(params.weight), diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index ddaf34cd1..d243b4c5b 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -306,6 +306,55 @@ class GMemIterator int step_; int stride_; }; + +template +class SHMemIterator +{ +public: + __device__ SHMemIterator(T* g_addr, int g_offset, int sh_offset, + int step, int stride, int sz_group) + : g_addr_(Enable ? (g_addr + g_offset) : nullptr) + , sh_addr_(nullptr) + , sh_offset_(sh_offset) + , step_(step) + , stride_(stride) + , sz_group_(sz_group) + { + + } + + __device__ void load(T* dst, int iter) + { + if constexpr (Enable) + { + sh_addr_ = dst; + // TODO: Can we make async copy here? + // TODO: Duplicated work for some threads + #pragma unroll + for (int i = 0; i < Continuous; ++i) + { + sh_addr_[Continuous * sz_group_ + i] = g_addr_[iter * step_ + i]; + } + } + } + + __device__ T* stride_iter(int ii = 0) + { + if constexpr (Enable) { + return &sh_addr_[Continuous * sz_group_ + stride_ * ii + sh_offset_]; + } + else return nullptr; + } + +private: + T* g_addr_; + int step_; + T* sh_addr_; + int sh_offset_; + int stride_; + int sz_group_; +}; + } // namespace weight_only } // namespace kernels } // namespace tensorrt_llm From 40b1cfb97b1a6d037fc24d4d4876c2f122b82c72 Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Thu, 25 Apr 2024 02:22:41 +0000 Subject: [PATCH 06/28] Remove redundancy of loading scale factors --- .../kernels/weightOnlyBatchedGemv/kernel.h | 9 ++++---- .../kernels/weightOnlyBatchedGemv/utility.h | 22 ++++++++++++------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 0592cd0e9..a19e6ab05 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -73,9 +73,11 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca interleaved_k / Details::kElemsPerByteW); SHMemIterator scales_iterator(scales, (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + interleaved_offset_n * Details::kInterleave, - ((tid * StepK / Details::LayoutDeatils::kTileSize) % Details::kInterleave), + ((tid / Details::kThreadsPerInterleavedTile) % Details::kInterleave), (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave, - (GroupSize != 0 ? real_offset_k / GroupSize : 0)); + (GroupSize != 0 ? real_offset_k / GroupSize : 0), + (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : CtaN * Details::kInterleave) + ); GMemIterator zeros_iterator(zeros, (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, (GroupSize != 0 ? 
CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); @@ -98,13 +100,12 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca TypeA vec_zero[CtaN]; TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK]; uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; - #pragma unroll for (int i = 0; i < CtaN; ++i) { zeros_iterator.load(vec_zero + i, iter, i); } - scales_iterator.load(vec_scale, iter); + scales_iterator.load(vec_scale, iter, tid); act_scale_iterator.load(vec_act_scale, iter); #pragma unroll for (int i = 0; i < CtaN; ++i) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index d243b4c5b..565f65b21 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -311,34 +311,39 @@ template class SHMemIterator { public: - __device__ SHMemIterator(T* g_addr, int g_offset, int sh_offset, - int step, int stride, int sz_group) + __device__ __forceinline__ SHMemIterator(T* g_addr, int g_offset, int sh_offset, + int step, int stride, int sz_group, int sz_size) : g_addr_(Enable ? (g_addr + g_offset) : nullptr) , sh_addr_(nullptr) , sh_offset_(sh_offset) , step_(step) , stride_(stride) , sz_group_(sz_group) + , sz_size_(sz_size) { } - __device__ void load(T* dst, int iter) + __device__ __forceinline__ void load(T* dst, int iter, int tid) { if constexpr (Enable) { sh_addr_ = dst; // TODO: Can we make async copy here? - // TODO: Duplicated work for some threads #pragma unroll - for (int i = 0; i < Continuous; ++i) - { - sh_addr_[Continuous * sz_group_ + i] = g_addr_[iter * step_ + i]; + // for (int i = 0; i < Continuous; ++i) + // { + // sh_addr_[Continuous * sz_group_ + i] = g_addr_[iter * step_ + i]; + // } + + for (int i = 0; i < Continuous; i+=sz_size_) { + int ii = i + (tid % sz_size_); + sh_addr_[Continuous * sz_group_ + ii] = g_addr_[iter * step_ + ii]; } } } - __device__ T* stride_iter(int ii = 0) + __device__ __forceinline__ T* stride_iter(int ii = 0) { if constexpr (Enable) { return &sh_addr_[Continuous * sz_group_ + stride_ * ii + sh_offset_]; @@ -353,6 +358,7 @@ class SHMemIterator int sh_offset_; int stride_; int sz_group_; + int sz_size_; }; } // namespace weight_only From cc1d2c1cc441f6beb84dd8dcef23c2f9a1aa5155 Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Thu, 25 Apr 2024 07:51:42 +0000 Subject: [PATCH 07/28] Apply asyncs to scale factors --- .../kernels/weightOnlyBatchedGemv/kernel.h | 15 ++++--- .../kernels/weightOnlyBatchedGemv/utility.h | 39 +++++++++---------- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index a19e6ab05..4c42f1814 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -64,6 +64,9 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca = (tid * StepK / (Details::kInterleave * Details::LayoutDeatils::kTileSize)) * Details::LayoutDeatils::kTileSize + ((tid * StepK) % Details::LayoutDeatils::kTileSize); + extern __shared__ TypeA shmem_sz[]; + // dimension of each is [kInterleave * CtaN, blockDim.x / (GroupSize * kInterleave / StepK)] + TypeA* vec_scale = shmem_sz + CtaN * Details::kInterleave * (GroupSize != 0 ? 
real_offset_k / GroupSize : 0); GMemIterator act_iterator( act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); GMemIterator act_scale_iterator( @@ -71,11 +74,10 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca GMemIterator weight_iterator(weight, (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW); - SHMemIterator scales_iterator(scales, + SHMemIterator scales_iterator(scales, (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + interleaved_offset_n * Details::kInterleave, - ((tid / Details::kThreadsPerInterleavedTile) % Details::kInterleave), + vec_scale, ((tid / Details::kThreadsPerInterleavedTile) % Details::kInterleave), (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave, - (GroupSize != 0 ? real_offset_k / GroupSize : 0), (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : CtaN * Details::kInterleave) ); GMemIterator zeros_iterator(zeros, @@ -91,9 +93,6 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca TypeA tile_acc[CtaM * CtaN]; fill(tile_acc, static_cast(0.f)); - extern __shared__ TypeA shmem_sz []; - // dimension of each is [kInterleave * CtaN, blockDim.x / (GroupSize * kInterleave / StepK)] - TypeA* vec_scale = (TypeA*)shmem_sz; for (int idx_k = tid * StepK, iter = 0; idx_k < interleaved_k; idx_k += CtaK, ++iter) { TypeA vec_act_scale[StepK]; @@ -105,14 +104,14 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca { zeros_iterator.load(vec_zero + i, iter, i); } - scales_iterator.load(vec_scale, iter, tid); + scales_iterator.load(iter); act_scale_iterator.load(vec_act_scale, iter); #pragma unroll for (int i = 0; i < CtaN; ++i) { weight_iterator.load(tile_w_quantized, iter, i); dequantize( - tile_w, tile_w_quantized, scales_iterator.stride_iter(i), + tile_w, tile_w_quantized, scales_iterator.iter(i), vec_zero + i, alpha); pack_to_vec2(tile_w_pack2, tile_w, i); } diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index 565f65b21..c18d09e31 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -307,48 +307,46 @@ class GMemIterator int stride_; }; -template +template class SHMemIterator { public: - __device__ __forceinline__ SHMemIterator(T* g_addr, int g_offset, int sh_offset, - int step, int stride, int sz_group, int sz_size) + __device__ __forceinline__ SHMemIterator(T* g_addr, int g_offset, T* sh_addr, int sh_offset, + int step, int stride, int sz_size) : g_addr_(Enable ? (g_addr + g_offset) : nullptr) - , sh_addr_(nullptr) + , sh_addr_(Enable ? sh_addr : nullptr) , sh_offset_(sh_offset) , step_(step) , stride_(stride) - , sz_group_(sz_group) , sz_size_(sz_size) { } - __device__ __forceinline__ void load(T* dst, int iter, int tid) + __device__ __forceinline__ void load(int iter) { if constexpr (Enable) { - sh_addr_ = dst; - // TODO: Can we make async copy here? 
- #pragma unroll - // for (int i = 0; i < Continuous; ++i) - // { - // sh_addr_[Continuous * sz_group_ + i] = g_addr_[iter * step_ + i]; - // } - - for (int i = 0; i < Continuous; i+=sz_size_) { - int ii = i + (tid % sz_size_); - sh_addr_[Continuous * sz_group_ + ii] = g_addr_[iter * step_ + ii]; +#pragma unroll + for (int i = threadIdx.x % sz_size_; i < Continuous; i += sz_size_) { + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_) + i, + reinterpret_cast(g_addr_ + iter * step_) + i, + sizeof(TVec), + 0 + ); + __pipeline_commit(); } + __pipeline_wait_prior(0); } } - __device__ __forceinline__ T* stride_iter(int ii = 0) + __device__ __forceinline__ T* iter(int ii = 0) { if constexpr (Enable) { - return &sh_addr_[Continuous * sz_group_ + stride_ * ii + sh_offset_]; + return &sh_addr_[ii * stride_ + sh_offset_]; } - else return nullptr; + return nullptr; } private: @@ -357,7 +355,6 @@ class SHMemIterator T* sh_addr_; int sh_offset_; int stride_; - int sz_group_; int sz_size_; }; From 3fa16994047ebb936ea3eb271fe40caa6fcd73bb Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Fri, 26 Apr 2024 06:09:16 +0000 Subject: [PATCH 08/28] Apply shared mem asyncs to zeropoints --- .../kernels/weightOnlyBatchedGemv/kernel.h | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 4c42f1814..2273101c2 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -59,14 +59,20 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca int const tile_id_m = blockIdx.x, tile_id_n = blockIdx.y, tid = threadIdx.x; int const offset_m = tile_id_m * CtaM, interleaved_offset_n = tile_id_n * CtaN; int const real_offset_n = interleaved_offset_n * Details::kInterleave - + ((tid * StepK / Details::LayoutDeatils::kTileSize) % Details::kInterleave); + + (tid / Details::kThreadsPerInterleavedTile) % Details::kInterleave; int const real_offset_k = (tid * StepK / (Details::kInterleave * Details::LayoutDeatils::kTileSize)) * Details::LayoutDeatils::kTileSize + ((tid * StepK) % Details::LayoutDeatils::kTileSize); extern __shared__ TypeA shmem_sz[]; - // dimension of each is [kInterleave * CtaN, blockDim.x / (GroupSize * kInterleave / StepK)] + // dimension of each is [kInterleave * CtaN, {Threads * kInterleave * kThreadsPerInterleavedTile / GroupSize or 1}] TypeA* vec_scale = shmem_sz + CtaN * Details::kInterleave * (GroupSize != 0 ? real_offset_k / GroupSize : 0); + TypeA* vec_zero = nullptr; + if constexpr (EnableZero) + { + vec_zero = shmem_sz + CtaN * Details::kInterleave * (GroupSize != 0 ? Threads * Details::kInterleave * Details::kThreadsPerInterleavedTile / GroupSize : 1); + vec_zero += CtaN * Details::kInterleave * (GroupSize != 0 ? real_offset_k / GroupSize : 0); + } GMemIterator act_iterator( act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); GMemIterator act_scale_iterator( @@ -76,13 +82,16 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca interleaved_k / Details::kElemsPerByteW); SHMemIterator scales_iterator(scales, (GroupSize != 0 ? 
real_offset_k / GroupSize * n : 0) + interleaved_offset_n * Details::kInterleave, - vec_scale, ((tid / Details::kThreadsPerInterleavedTile) % Details::kInterleave), + vec_scale, (tid / Details::kThreadsPerInterleavedTile) % Details::kInterleave, + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave, + (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : CtaN * Details::kInterleave) + ); + SHMemIterator zeros_iterator(zeros, + (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + interleaved_offset_n * Details::kInterleave, + vec_zero, (tid / Details::kThreadsPerInterleavedTile) % Details::kInterleave, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave, (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : CtaN * Details::kInterleave) ); - GMemIterator zeros_iterator(zeros, - (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, - (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; if constexpr (EnableBias) @@ -96,15 +105,11 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca for (int idx_k = tid * StepK, iter = 0; idx_k < interleaved_k; idx_k += CtaK, ++iter) { TypeA vec_act_scale[StepK]; - TypeA vec_zero[CtaN]; TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK]; uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; -#pragma unroll - for (int i = 0; i < CtaN; ++i) - { - zeros_iterator.load(vec_zero + i, iter, i); - } + scales_iterator.load(iter); + zeros_iterator.load(iter); act_scale_iterator.load(vec_act_scale, iter); #pragma unroll for (int i = 0; i < CtaN; ++i) @@ -112,7 +117,7 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca weight_iterator.load(tile_w_quantized, iter, i); dequantize( tile_w, tile_w_quantized, scales_iterator.iter(i), - vec_zero + i, alpha); + zeros_iterator.iter(i), alpha); pack_to_vec2(tile_w_pack2, tile_w, i); } #pragma unroll @@ -137,12 +142,12 @@ void exec_kernel(Params& params, cudaStream_t s) } dim3 grid(params.m / CtaM, params.n / (CtaN * Details::kInterleave)); dim3 block(Threads); - // clang-format off + const int shmem_size = + (EnableZero ? 2 : 1) * CtaN * Details::kInterleave * sizeof(T) + * (GroupSize != 0 ? Threads * Details::kInterleave * Details::kThreadsPerInterleavedTile / GroupSize : 1); + // clang-format off kernel<<< - grid, block, - CtaN * (GroupSize != 0 ? 
Details::kStepK * Threads / GroupSize : Details::kInterleave) * sizeof(T), - s - >>>( + grid, block, shmem_size, s>>>( reinterpret_cast(params.act), reinterpret_cast(params.act_scale), reinterpret_cast(params.weight), From cde2e2ec1b3533a48a5f210d3bcfba51e71c6b04 Mon Sep 17 00:00:00 2001 From: "chenwei.gavin" Date: Tue, 30 Apr 2024 12:00:09 +0800 Subject: [PATCH 09/28] [feat]: Support weight only gemm with 2bit --- .../gemm/kernel/default_fpA_intB_traits.h | 2 + .../gemm/kernel/mixed_gemm_B_layout.h | 27 ++ .../threadblock/default_dq_mma_multistage.h | 5 +- .../gemm/threadblock/default_mma.h | 154 +++++++++++ .../gemm/threadblock/default_mma_bf16.h | 101 ++++++++ .../dq_mma_multistage_finegrained.h | 32 ++- .../interleaved_numeric_conversion.h | 239 ++++++++++++++++++ .../cutlass_kernels/cutlass_preprocessors.cpp | 107 ++++++++ .../cutlass_kernels/cutlass_preprocessors.h | 4 +- .../cutlass_kernels/cutlass_type_conversion.h | 2 + .../bf16_int2_gemm_fg_scalebias.cu | 31 +++ .../bf16_int2_gemm_fg_scaleonly.cu | 31 +++ .../fpA_intB_gemm/bf16_int2_gemm_per_col.cu | 31 +++ .../fp16_int2_gemm_fg_scalebias.cu | 29 +++ .../fp16_int2_gemm_fg_scaleonly.cu | 28 ++ .../fpA_intB_gemm/fp16_int2_gemm_per_col.cu | 28 ++ .../fpA_intB_gemm/fpA_intB_gemm_template.h | 9 +- .../python/generate_kernels.py | 3 + 18 files changed, 846 insertions(+), 17 deletions(-) create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scalebias.cu create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scaleonly.cu create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_per_col.cu create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scalebias.cu create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scaleonly.cu create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_per_col.cu diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h index ee084116a..ee1189c51 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -141,6 +141,8 @@ struct MixedGemmArchTraits::value #ifdef ENABLE_FP8 || cutlass::platform::is_same::value>::type +#else + >::type #endif > { diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index a1712431e..8ac7f0723 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -113,6 +113,24 @@ template using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; +template + struct LayoutDetailsB < TypeA, + uint2b_t, Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using 
Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + + template struct LayoutDetailsB= 90>::type> { @@ -131,6 +149,15 @@ struct LayoutDetailsB +struct LayoutDetailsB= 90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + } // namespace kernel } // namespace gemm } // namespace cutlass diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h index 17c634655..59d372e4f 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -230,8 +230,9 @@ struct DqMma::value, "Mma multistage must dequantize after ldsm"); - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); + static_assert(platform::is_same::value || platform::is_same::value + || platform::is_same::value, + "Element B must be uint8, uint4 or uint2"); static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h index ad6c7496e..d7f0736b8 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h @@ -124,6 +124,54 @@ struct DefaultMma +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + //////////////////////////////////////////////////////////////////////////////// /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage /// (stage>=3) @@ -232,6 +280,59 @@ struct DefaultMma=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// 
Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; #ifdef ENABLE_FP8 //////////////////////////////////////////////////////////////////////////////// /// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage @@ -287,6 +388,59 @@ struct DefaultMma=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; #endif // fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h index 77af81005..876d23258 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -244,6 +244,54 @@ struct DefaultMma +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + template < /// Layout type for A matrix operand typename LayoutA, @@ -348,6 +396,59 @@ struct DefaultMma +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + } // namespace threadblock } // namespace gemm } // namespace cutlass diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h index 2d34d43cb..abd690804 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -562,6 +562,7 @@ class DqMmaMultistage (-Base::kStages + 1);) { @@ -569,6 +570,8 @@ class DqMmaMultistage; + FragmentOperandB converted_frag_B_operand; // Computes a warp-level GEMM on data held in shared memory // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate CUTLASS_PRAGMA_UNROLL @@ -588,23 +591,26 @@ class DqMmaMultistagewarp_tile_iterator_B_.set_kgroup_index( (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(idx + 1) % 2]); ++this->warp_tile_iterator_B_; } + if (warp_tileB_k_compute_offset == 0) { + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[idx % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); + + constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + using Converter + = cutlass::NumericArrayConverter; + converted_frag_B_operand = Converter::convert(converted_frag_B); + } - typename 
TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); - - using FragmentOperandB = cutlass::Array; - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; - static_assert(ConversionVectorWidth == FragmentOperandB::kElements); - - using Converter - = cutlass::NumericArrayConverter; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + idx++; + } - FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, warp_tileB_k_compute_offset); diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h index 44ba79680..d18f883f7 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h @@ -440,6 +440,245 @@ struct FastInterleavedAndBiasedNumericArrayConverter } }; +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i2s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM0_MASK = 0x00030003; + static constexpr uint32_t BOTTOM1_MASK = 0x000c000c; + static constexpr uint32_t TOP0_MASK = 0x00300030; + static constexpr uint32_t TOP1_MASK = 0x00c000c0; + static constexpr uint32_t I2s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. 
+ const uint32_t top_i2s = i2s >> 8; + // Extract elt_01 - (i2s & 0x00020002) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i2s), "n"(BOTTOM0_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i2s), "n"(BOTTOM1_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(i2s), "n"(TOP0_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(i2s), "n"(TOP1_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // Extract elt_89 - (i2s & 0x00020002) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[4]) + : "r"(top_i2s), "n"(BOTTOM0_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[5]) + : "r"(top_i2s), "n"(BOTTOM1_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[6]) + : "r"(top_i2s), "n"(TOP0_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[7]) + : "r"(top_i2s), "n"(TOP1_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1026, 1026} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64026402; + // This is the half2 {1 / 4, 1 / 4} represented as an integer. + static constexpr uint32_t ONE_FOUR = 0x34003400; + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + static constexpr uint32_t ONE_SIXTY_FOUR = 0x24002400; + // This is the half2 {-72, -72} represented as an integer. + //static constexpr uint32_t NEG_72 = 0xd480d480; + static constexpr uint32_t NEG_258 = 0xdc08dc08; + static constexpr uint32_t NEG_66 = 0xd420d420; + static constexpr uint32_t NEG_18 = 0xcc80cc80; + + // Finally, we construct the output numbers. 
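+    // At this point every 16-bit lane of h[k] holds the fp16 value 1024 + v * 4^(k % 4), where v is
+    // the unsigned 2-bit element selected by that lane's mask. The sub/fma below scale by 4^-(k % 4)
+    // and subtract (1024 * 4^-(k % 4) + 2), so each lane ends up holding v - 2, the original signed
+    // int2 value before the +2 bias added during preprocessing.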
+ // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_FOUR), "r"(NEG_258)); + // Convert elt_45 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(ONE_SIXTEENTH), "r"(NEG_66)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTY_FOUR), "r"(NEG_18)); + + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(h[5]), "r"(ONE_FOUR), "r"(NEG_258)); + // Convert elt_45 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(h[6]), "r"(ONE_SIXTEENTH), "r"(NEG_66)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(h[7]), "r"(ONE_SIXTY_FOUR), "r"(NEG_18)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 16; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t source_i2s = reinterpret_cast(source); + + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x00030003; + static constexpr uint32_t I2s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. + uint32_t i2s = source_i2s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i2s), "n"(MASK), "n"(I2s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) + { + i2s >>= sizeof_bits::value; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i2s), "n"(MASK), "n"(I2s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-130, -130} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC302C302; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. 
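+    // Each lane of h[ii] now holds the bf16 value 128 + v: the shift loop above moves the ii-th 2-bit
+    // element into the low mantissa bits before each lop3. The fma below computes h * 1.0 + (-130),
+    // leaving v - 2 in every lane (undoing the +2 bias added during preprocessing).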
+ CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) + { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } + //*/ +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 16; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + //printf("convert uint2 to bfloat16\n"); + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp index 84cb50917..f1410fbd3 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp @@ -114,6 +114,9 @@ LayoutDetails getLayoutDetailsForArch(QuantType quant_type) case QuantType::W4_AFP8: details = getLayoutDetailsForArchAndQuantType(); break; + case QuantType::W2_A16: + details = getLayoutDetailsForArchAndQuantType(); + break; default: TLLM_THROW("Unsupported quantization type"); } return details; @@ -173,6 +176,12 @@ std::vector get_permutation_map(QuantType quant_type) return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; } + else if (quant_type == QuantType::W2_A16) + { + return {0, 1, 8, 9, 16, 17, 24, 25, 32, 33, 40, 41, 48, 49, 56, 57, 2, 3, 10, 11, 18, 19, 26, 27, 34, 35, + 42, 43, 50, 51, 58, 59, 4, 5, 12, 13, 20, 21, 28, 29, 36, 37, 44, 45, 52, 53, 60, 61, 6, 7, 14, 15, 22, + 23, 30, 31, 38, 39, 46, 47, 54, 55, 62, 63}; + } else { TLLM_THROW("Invalid quantization type for LDSM permutation"); @@ -350,6 +359,32 @@ void subbyte_transpose_impl( } } } + else if constexpr (bits_per_elt == 2) + { + + for (int ii = 0; ii < M_TILE_L1; ++ii) + { + // Using M_TILE_L1 here is deliberate since we assume that the cache tile + // is square in the number of elements (not necessarily the number of bytes). 
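+            // Swap the 2-bit elements at (ii, jj) and (jj, ii): read both values out of their packed
+            // bytes, clear the two 2-bit slots, then write each value into the other element's slot.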
+ for (int jj = ii + 1; jj < M_TILE_L1; ++jj) + { + const int ii_byte = ii / ELTS_PER_BYTE; + const int ii_bit_offset = ii % ELTS_PER_BYTE; + + const int jj_byte = jj / ELTS_PER_BYTE; + const int jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = 0x3 & (cache_buf[ii][jj_byte] >> (2 * jj_bit_offset)); + uint8_t tgt_elt = 0x3 & (cache_buf[jj][ii_byte] >> (2 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (~(0x3 << (2 * jj_bit_offset))); + cache_buf[jj][ii_byte] &= (~(0x3 << (2 * ii_bit_offset))); + + cache_buf[ii][jj_byte] |= (tgt_elt << (2 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (2 * ii_bit_offset)); + } + } + } else { TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type."); @@ -400,6 +435,10 @@ void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quanti { subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); } + else if (quant_type == QuantType::W2_A16) + { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } else { TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); @@ -485,6 +524,70 @@ void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const siz } } +void add_bias_and_interleave_int2s_inplace(int8_t* packed_int2_tensor, const size_t num_elts) +{ + const int num_bytes = num_elts / 4; + + // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little + // instructions as possible in the CUDA code. + for (size_t ii = 0; ii < num_bytes; ++ii) + { + int8_t transformed_packed_int2s = 0; + int8_t transformed_elt0 + = (int8_t(packed_int2_tensor[ii] << 6) >> 6) + 2; // The double shift here is to ensure sign extension + int8_t transformed_elt1 + = (int8_t(packed_int2_tensor[ii] << 4) >> 6) + 2; // The double shift here is to ensure sign extension + int8_t transformed_elt2 + = (int8_t(packed_int2_tensor[ii] << 2) >> 6) + 2; // The double shift here is to ensure sign extension + int8_t transformed_elt3 = (packed_int2_tensor[ii] >> 6) + 2; + + TLLM_CHECK_WITH_INFO( + transformed_elt0 >= 0 && transformed_elt0 <= 3, "Illegal result for int2 transform (elt0)"); + TLLM_CHECK_WITH_INFO( + transformed_elt1 >= 0 && transformed_elt1 <= 3, "Illegal result for int2 transform (elt1)"); + TLLM_CHECK_WITH_INFO( + transformed_elt2 >= 0 && transformed_elt2 <= 3, "Illegal result for int2 transform (elt2)"); + TLLM_CHECK_WITH_INFO( + transformed_elt3 >= 0 && transformed_elt3 <= 3, "Illegal result for int2 transform (elt3)"); + + // We don't need to mask in these ops since everything should be in the range 0-3 + transformed_packed_int2s |= transformed_elt0; + transformed_packed_int2s |= (transformed_elt1 << 2); + transformed_packed_int2s |= (transformed_elt2 << 4); + transformed_packed_int2s |= (transformed_elt3 << 6); + packed_int2_tensor[ii] = transformed_packed_int2s; + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical + // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the + // following: Take as input a 32 bit register with layout: bit 32 0 + // [elt15 ... elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_15 ... elt_5 elt_3 elt_1 elt_14 ... 
elt_4 elt_2 elt_0] (each elt occupies 4 bits) + + TLLM_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int2 tensor must be a multiple of 16 for register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(packed_int2_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) + { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 16; ++dest_idx) + { + const int src_idx = dest_idx < 8 ? 2 * dest_idx : 2 * (dest_idx - 8) + 1; + const int src_shift = 2 * src_idx; + const int dest_shift = 2 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0x3; + transformed_register |= (src_bits << dest_shift); + } + register_ptr[ii] = transformed_register; + } +} void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type) { if (quant_type == QuantType::W8_A16) @@ -499,6 +602,10 @@ void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size // for conversion to FP16. add_bias_and_interleave_int4s_inplace(tensor, num_elts); } + else if (quant_type == QuantType::W2_A16) + { + add_bias_and_interleave_int2s_inplace(tensor, num_elts); + } else { TLLM_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving."); diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h index b12fd7372..5b3e62332 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h @@ -33,7 +33,8 @@ enum class QuantType { W8_A16, W4_A16, - W4_AFP8 + W4_AFP8, + W2_A16 }; constexpr int get_weight_quant_bits(QuantType quant_type) @@ -43,6 +44,7 @@ constexpr int get_weight_quant_bits(QuantType quant_type) case QuantType::W8_A16: return 8; case QuantType::W4_A16: return 4; case QuantType::W4_AFP8: return 4; + case QuantType::W2_A16: return 2; default: TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); return -1; } } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h index 0ec8ab2e3..501ff6526 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h @@ -18,7 +18,9 @@ #include #include +#if defined(ENABLE_FP8) #include +#endif #include "cutlass/bfloat16.h" #include "cutlass/float8.h" diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scalebias.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scalebias.cu new file mode 100644 index 000000000..a5b523ea4 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scalebias.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint2b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scaleonly.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scaleonly.cu new file mode 100644 index 000000000..60d2e0903 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scaleonly.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint2b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_per_col.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_per_col.cu new file mode 100644 index 000000000..1d00561be --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_per_col.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint2b_t, + cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scalebias.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scalebias.cu new file mode 100644 index 000000000..0966ed761 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scalebias.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scaleonly.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scaleonly.cu new file mode 100644 index 000000000..20f17fadb --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scaleonly.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_per_col.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_per_col.cu new file mode 100644 index 000000000..a000397f8 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_per_col.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index 0d32045eb..aba7a7468 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -78,7 +78,8 @@ void generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, ""); // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. @@ -263,6 +264,7 @@ void filter_and_run_mixed_gemm(ActivationType const* A, WeightType const* B, Sca + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); } +#ifdef ENABLE_FP8 else if constexpr (cutlass::platform::is_same::value && arch::kMinComputeCapability < 89) { @@ -271,6 +273,7 @@ void filter_and_run_mixed_gemm(ActivationType const* A, WeightType const* B, Sca + std::to_string(arch::kMinComputeCapability) + " with activation type set to FP8"; throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); } +#endif else { generic_mixed_gemm_kernelLauncher constexpr bool is_fp8() { +#ifdef ENABLE_FP8 return std::is_same_v || std::is_same_v; +#else + return false; +#endif } template Date: Mon, 13 May 2024 14:13:38 +0900 Subject: [PATCH 10/28] refactoring offset --- .../kernels/weightOnlyBatchedGemv/kernel.h | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 2273101c2..70486b3ee 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -58,20 +58,21 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca int const tile_id_m = blockIdx.x, tile_id_n = blockIdx.y, tid = threadIdx.x; int const offset_m = tile_id_m * CtaM, interleaved_offset_n = tile_id_n * CtaN; - int const real_offset_n = interleaved_offset_n * Details::kInterleave - + (tid / Details::kThreadsPerInterleavedTile) % Details::kInterleave; + int const blk_offset_n = interleaved_offset_n * Details::kInterleave; + int const thr_offset_n = (tid / Details::kThreadsPerInterleavedTile) % 
Details::kInterleave; int const real_offset_k = (tid * StepK / (Details::kInterleave * Details::LayoutDeatils::kTileSize)) * Details::LayoutDeatils::kTileSize + ((tid * StepK) % Details::LayoutDeatils::kTileSize); + int const offset_k_group = (GroupSize != 0 ? real_offset_k / GroupSize : 0); extern __shared__ TypeA shmem_sz[]; // dimension of each is [kInterleave * CtaN, {Threads * kInterleave * kThreadsPerInterleavedTile / GroupSize or 1}] - TypeA* vec_scale = shmem_sz + CtaN * Details::kInterleave * (GroupSize != 0 ? real_offset_k / GroupSize : 0); + TypeA* vec_scale = shmem_sz + CtaN * Details::kInterleave * offset_k_group; TypeA* vec_zero = nullptr; if constexpr (EnableZero) { vec_zero = shmem_sz + CtaN * Details::kInterleave * (GroupSize != 0 ? Threads * Details::kInterleave * Details::kThreadsPerInterleavedTile / GroupSize : 1); - vec_zero += CtaN * Details::kInterleave * (GroupSize != 0 ? real_offset_k / GroupSize : 0); + vec_zero += CtaN * Details::kInterleave * offset_k_group; } GMemIterator act_iterator( act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); @@ -81,14 +82,12 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW); SHMemIterator scales_iterator(scales, - (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + interleaved_offset_n * Details::kInterleave, - vec_scale, (tid / Details::kThreadsPerInterleavedTile) % Details::kInterleave, + offset_k_group * n + blk_offset_n, vec_scale, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave, (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : CtaN * Details::kInterleave) ); SHMemIterator zeros_iterator(zeros, - (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + interleaved_offset_n * Details::kInterleave, - vec_zero, (tid / Details::kThreadsPerInterleavedTile) % Details::kInterleave, + offset_k_group * n + blk_offset_n, vec_zero, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave, (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : CtaN * Details::kInterleave) ); @@ -142,10 +141,9 @@ void exec_kernel(Params& params, cudaStream_t s) } dim3 grid(params.m / CtaM, params.n / (CtaN * Details::kInterleave)); dim3 block(Threads); - const int shmem_size = - (EnableZero ? 2 : 1) * CtaN * Details::kInterleave * sizeof(T) + int const shmem_size = (EnableZero ? 2 : 1) * CtaN * Details::kInterleave * sizeof(T) * (GroupSize != 0 ? 
Threads * Details::kInterleave * Details::kThreadsPerInterleavedTile / GroupSize : 1); - // clang-format off + // clang-format off kernel<<< grid, block, shmem_size, s>>>( reinterpret_cast(params.act), From cb76c98be2ea1f8831b68f7eff5ec10cb51152d1 Mon Sep 17 00:00:00 2001 From: dasistwo Date: Wed, 31 Jul 2024 08:12:24 +0000 Subject: [PATCH 11/28] Fix GCC 13 compile error --- cpp/CMakeLists.txt | 9 +++++++++ cpp/include/tensorrt_llm/common/stringUtils.h | 1 + requirements.txt | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 84d96061f..88f0a254c 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -363,6 +363,15 @@ endif() # "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} --generate-line-info") # set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") +# Add the option with CMAKE_CUDA_FLAGS causes error. Add G and Lineinfo. +if(${CMAKE_BUILD_TYPE} MATCHES "Debug") + add_compile_options("$<$:-G>" + "$<$:--host-linker-script>") +elseif(${CMAKE_BUILD_TYPE} MATCHES "RelWithDebInfo") + add_compile_options("$<$:-lineinfo>" + "$<$:--host-linker-script>") +endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_SYSTEM=cmake_oss -DENABLE_MULTI_DEVICE=${ENABLE_MULTI_DEVICE}" ) diff --git a/cpp/include/tensorrt_llm/common/stringUtils.h b/cpp/include/tensorrt_llm/common/stringUtils.h index a9b213c74..fbeb73eba 100644 --- a/cpp/include/tensorrt_llm/common/stringUtils.h +++ b/cpp/include/tensorrt_llm/common/stringUtils.h @@ -23,6 +23,7 @@ #include // std::make_unique #include // std::stringstream +#include #include #include #include diff --git a/requirements.txt b/requirements.txt index 550c90aff..5f39bd6f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,7 @@ sentencepiece>=0.1.99 tensorrt~=10.2.0 # https://github.com/pytorch/pytorch/blob/v2.3.1/version.txt uses 2.3.0a0. # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-05.html#rel-24-05 uses 2.4.0a0. -torch>=2.3.0a0,<=2.4.0a0 +torch>=2.3.0a0,<=2.4.0 nvidia-modelopt~=0.13,<0.14 transformers>=4.38.2 pillow==10.3.0 From c8c643262341f1b8ee081091f4054d5b24e56f5a Mon Sep 17 00:00:00 2001 From: Jaeyoung Choi Date: Thu, 1 Aug 2024 16:40:45 +0900 Subject: [PATCH 12/28] Fix TensorRT layermap error --- tensorrt_llm/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index 5aed71389..0df10e921 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ -389,7 +389,7 @@ def build_engine(self, network: Network, param._get_weights(), name): raise RuntimeError(f'Failed to set weight: {name}') # This mark_weights_refittable has no side effect when refit_individual is not enabled. 
- network.trt_network.mark_weights_refittable(name) + # network.trt_network.mark_weights_refittable(name) network._fill_weights() # Build engine From a0f8499e0e491d415d9ec68ca08a2a0a7538d0e8 Mon Sep 17 00:00:00 2001 From: dasistwo Date: Tue, 6 Aug 2024 12:51:40 +0900 Subject: [PATCH 13/28] Fix bug: loading ModelSpec in test --- cpp/tests/CMakeLists.txt | 1 - cpp/tests/runtime/gptSessionTest.cpp | 207 +++++++++++++++++++-------- tensorrt_llm/builder.py | 2 +- 3 files changed, 147 insertions(+), 63 deletions(-) diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 8ada97551..e5950f156 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -80,7 +80,6 @@ add_gtest(transposeKVKernelTest runtime/transposeKVKernelTest.cpp) add_gtest(gptDecoderTest runtime/gptDecoderTest.cpp) add_gtest(gptDecoderBatchTest runtime/gptDecoderBatchTest.cpp) add_gtest(gptSessionTest runtime/gptSessionTest.cpp) -target_link_libraries(gptSessionTest PRIVATE modelSpecStatic) add_gtest(memoryUtilsTest common/memoryUtilsTest.cu) if(ENABLE_MULTI_DEVICE EQUAL 1) add_gtest(mpiUtilsTest common/mpiUtilsTest.cpp) diff --git a/cpp/tests/runtime/gptSessionTest.cpp b/cpp/tests/runtime/gptSessionTest.cpp index 6ef3830c8..5ba2051db 100644 --- a/cpp/tests/runtime/gptSessionTest.cpp +++ b/cpp/tests/runtime/gptSessionTest.cpp @@ -19,7 +19,6 @@ #include -#include "modelSpec.h" #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/mpiUtils.h" #include "tensorrt_llm/common/stlUtils.h" @@ -36,10 +35,6 @@ using namespace tensorrt_llm::runtime; namespace tc = tensorrt_llm::common; namespace fs = std::filesystem; -using tensorrt_llm::testing::ModelSpec; -using tensorrt_llm::testing::KVCacheType; -using tensorrt_llm::testing::QuantMethod; -using tensorrt_llm::testing::OutputContentType; namespace { @@ -54,10 +49,6 @@ auto const CHATGLM_MODEL_DIR = "chatglm-6b"; auto const CHATGLM2_MODEL_DIR = "chatglm2-6b"; auto const CHATGLM3_MODEL_DIR = "chatglm3-6b"; auto const MAMBA_MODEL_DIR = "mamba-2.8b-hf"; -auto const INPUT_FILE = "input_tokens.npy"; -auto const CHATGLM_INPUT_FILE = "input_tokens_chatglm-6b.npy"; -auto const CHATGLM2_INPUT_FILE = "input_tokens_chatglm2-6b.npy"; -auto const CHATGLM3_INPUT_FILE = "input_tokens_chatglm3-6b.npy"; // Engines need to be generated using cpp/tests/resources/scripts/build_*_engines.py. 
auto const FP32_GPT_DIR = "fp32-default"; @@ -90,6 +81,77 @@ struct ModelParams ModelIds ids; }; +class ModelSpec +{ +public: + ModelSpec(fs::path modelPath, fs::path resultsFile, nvinfer1::DataType dtype) + : mModelPath{std::move(modelPath)} + , mResultsFile{std::move(resultsFile)} + , mDataType{dtype} + , mUseGptAttentionPlugin{false} + , mUsePackedInput{false} + , mUsePagedKvCache{false} + , mDecoderPerRequest{false} + , mPPSize(1) + , mTPSize(1) + , mRandomEndId(false) + { + } + + ModelSpec& useGptAttentionPlugin() + { + mUseGptAttentionPlugin = true; + return *this; + } + + ModelSpec& usePackedInput() + { + mUsePackedInput = true; + return *this; + } + + ModelSpec& usePagedKvCache() + { + mUsePagedKvCache = true; + return *this; + } + + ModelSpec& useDecoderPerRequest() + { + mDecoderPerRequest = true; + return *this; + } + + ModelSpec& usePipelineParallelism(int ppSize) + { + mPPSize = ppSize; + return *this; + } + + ModelSpec& useTensorParallelism(int tpSize) + { + mTPSize = tpSize; + return *this; + } + + ModelSpec& useRandomEndId() + { + mRandomEndId = true; + return *this; + } + + fs::path mModelPath; + fs::path mResultsFile; + nvinfer1::DataType mDataType; + bool mUseGptAttentionPlugin; + bool mUsePackedInput; + bool mUsePagedKvCache; + bool mDecoderPerRequest; + int mPPSize; + int mTPSize; + bool mRandomEndId; +}; + struct MicroBatchSizes { std::optional ctxMicroBatchSize{std::nullopt}; @@ -124,7 +186,7 @@ void verifyModelConfig(ModelConfig const& modelConfig, ModelSpec const& modelSpe { ASSERT_EQ(modelSpec.mUseGptAttentionPlugin, modelConfig.useGptAttentionPlugin()); ASSERT_EQ(modelSpec.mUsePackedInput, modelConfig.usePackedInput()); - ASSERT_EQ(modelSpec.mKVCacheType == KVCacheType::kPAGED, modelConfig.usePagedKvCache()); + ASSERT_EQ(modelSpec.mUsePagedKvCache, modelConfig.usePagedKvCache()); ASSERT_EQ(modelSpec.mDataType, modelConfig.getDataType()); } @@ -137,9 +199,7 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model ASSERT_TRUE(fs::exists(DATA_PATH)); std::string modelName{isChatGlmTest ? resultsFile.parent_path().parent_path().filename().string() : ""}; - - fs::path inputPath = DATA_PATH / modelSpec.mInputFile; - + fs::path inputPath = DATA_PATH / (isChatGlmTest ? "input_tokens_" + modelName + ".npy" : "input_tokens.npy"); auto const& givenInput = utils::loadNpy(manager, inputPath.string(), MemoryType::kCPU); auto const& inputShape = givenInput->getShape(); ASSERT_EQ(inputShape.nbDims, 2); @@ -403,7 +463,7 @@ std::string generateTestName(testing::TestParamInfo const& info) name.append("AttentionPlugin"); if (modelSpec.mUsePackedInput) name.append("Packed"); - if (modelSpec.mKVCacheType == KVCacheType::kPAGED) + if (modelSpec.mUsePagedKvCache) name.append("PagedKvCache"); if (modelSpec.mDecoderPerRequest) name.append("DecoderBatch"); @@ -451,10 +511,10 @@ TEST_P(ParamTest, Test) std::ostringstream gpuSizePath; gpuSizePath << "tp" << modelSpec.mTPSize << "-pp" << modelSpec.mPPSize << "-gpu"; - auto const modelPath{ENGINE_PATH / modelDir / modelSpec.getModelPath() / gpuSizePath.str()}; + auto const modelPath{ENGINE_PATH / modelDir / modelSpec.mModelPath / gpuSizePath.str()}; auto const resultsPath = DATA_PATH / modelDir / ((beamWidth == 1) ? "sampling" : "beam_search_" + std::to_string(beamWidth)); - fs::path const resultsFile{resultsPath / modelSpec.getResultsFile()}; + fs::path const resultsFile{resultsPath / modelSpec.mResultsFile}; // Warning: This should be the last check before running the test. 
// It will initialize MPI which can take significant time. @@ -472,10 +532,11 @@ INSTANTIATE_TEST_SUITE_P(GptSessionOtbTest, ParamTest, testing::Combine(testing::Values(ModelParams{GPT_MODEL_DIR, {50256, 50256}}), testing::Values( // single decoder - ModelSpec{INPUT_FILE, nvinfer1::DataType::kFLOAT}, ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}, + ModelSpec{FP32_GPT_DIR, FP32_RESULT_FILE, nvinfer1::DataType::kFLOAT}, + ModelSpec{FP16_GPT_DIR, FP16_RESULT_FILE, nvinfer1::DataType::kHALF}, // decoderBatch - ModelSpec{INPUT_FILE, nvinfer1::DataType::kFLOAT}.useDecoderPerRequest(), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useDecoderPerRequest() + ModelSpec{FP32_GPT_DIR, FP32_RESULT_FILE, nvinfer1::DataType::kFLOAT}.useDecoderPerRequest(), + ModelSpec{FP16_GPT_DIR, FP16_RESULT_FILE, nvinfer1::DataType::kHALF}.useDecoderPerRequest() ), testing::Values(1), // beamWidth @@ -492,30 +553,40 @@ INSTANTIATE_TEST_SUITE_P(GptSessionTest, ParamTest, // Disabled because of flakey beam search test // ModelSpec{FP32_GPT_ATTENTION_DIR, FP32_PLUGIN_RESULT_FILE, nvinfer1::DataType::kFLOAT} // .useGptAttentionPlugin(), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin(), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().usePackedInput(), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().usePackedInput().setKVCacheType( - KVCacheType::kPAGED), + ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} + .useGptAttentionPlugin(), + ModelSpec{FP16_GPT_ATTENTION_PACKED_DIR, FP16_PLUGIN_PACKED_RESULT_FILE, nvinfer1::DataType::kHALF} + .useGptAttentionPlugin() + .usePackedInput(), + ModelSpec{ + FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} + .useGptAttentionPlugin() + .usePackedInput() + .usePagedKvCache(), // decoderBatch // Disabled because of flakey beam search test // ModelSpec{FP32_GPT_ATTENTION_DIR, FP32_PLUGIN_RESULT_FILE, nvinfer1::DataType::kFLOAT} // .useGptAttentionPlugin() // .useDecoderPerRequest(), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().useDecoderPerRequest(), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} + .useGptAttentionPlugin() + .useDecoderPerRequest(), + ModelSpec{FP16_GPT_ATTENTION_PACKED_DIR, FP16_PLUGIN_PACKED_RESULT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() .useDecoderPerRequest(), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{ + FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .setKVCacheType(KVCacheType::kPAGED) + .usePagedKvCache() .useDecoderPerRequest(), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{ + FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .setKVCacheType(KVCacheType::kPAGED) + .usePagedKvCache() .useDecoderPerRequest() .useRandomEndId() @@ -531,20 +602,29 @@ INSTANTIATE_TEST_SUITE_P(GptjSessionTest, ParamTest, testing::Combine(testing::Values(ModelParams{GPTJ_MODEL_DIR, {50256, 50256}}), testing::Values( // single decoder - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin(), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().usePackedInput(), - ModelSpec{INPUT_FILE, 
nvinfer1::DataType::kHALF}.useGptAttentionPlugin().usePackedInput().setKVCacheType( - KVCacheType::kPAGED), + ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} + .useGptAttentionPlugin(), + ModelSpec{FP16_GPT_ATTENTION_PACKED_DIR, FP16_PLUGIN_PACKED_RESULT_FILE, nvinfer1::DataType::kHALF} + .useGptAttentionPlugin() + .usePackedInput(), + ModelSpec{ + FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} + .useGptAttentionPlugin() + .usePackedInput() + .usePagedKvCache(), // decoderBatch - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().useDecoderPerRequest(), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} + .useGptAttentionPlugin() + .useDecoderPerRequest(), + ModelSpec{FP16_GPT_ATTENTION_PACKED_DIR, FP16_PLUGIN_PACKED_RESULT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() .useDecoderPerRequest(), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{ + FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .setKVCacheType(KVCacheType::kPAGED) + .usePagedKvCache() .useDecoderPerRequest() ), @@ -559,7 +639,7 @@ INSTANTIATE_TEST_SUITE_P(MambaSessionOOTBTest, ParamTest, testing::Combine(testing::Values(ModelParams{MAMBA_MODEL_DIR, {0, 1}}), testing::Values( // single decoder - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}), + ModelSpec{FP16_GPT_DIR, FP16_RESULT_FILE, nvinfer1::DataType::kHALF}), testing::Values(1), // beamWidth testing::Values(false), // cudaGraphMode testing::Values(MicroBatchSizes()), @@ -568,7 +648,7 @@ INSTANTIATE_TEST_SUITE_P(MambaSessionOOTBTest, ParamTest, generateTestName); INSTANTIATE_TEST_SUITE_P(MambaSessionPluginTest, ParamTest, testing::Combine(testing::Values(ModelParams{MAMBA_MODEL_DIR, {0, 1}}), - testing::Values(ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useMambaPlugin()), + testing::Values(ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF}), testing::Values(1), // beamWidth testing::Values(false), // cudaGraphMode testing::Values(MicroBatchSizes()), @@ -580,30 +660,37 @@ INSTANTIATE_TEST_SUITE_P(LlamaSessionTest, ParamTest, testing::Combine(testing::Values(ModelParams{LLAMA_MODEL_DIR, {2, 2}}), testing::Values( // single decoder - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().usePackedInput().setKVCacheType( - KVCacheType::kPAGED), + ModelSpec{ + FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} + .useGptAttentionPlugin() + .usePackedInput() + .usePagedKvCache(), // decoderBatch - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{ + FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .setKVCacheType(KVCacheType::kPAGED) + .usePagedKvCache() .useDecoderPerRequest(), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_TP1_PP4_FILE, + nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .setKVCacheType(KVCacheType::kPAGED) + .usePagedKvCache() .useDecoderPerRequest() .usePipelineParallelism(4), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{FP16_GPT_ATTENTION_PACKED_PAGED_DIR, 
FP16_PLUGIN_PACKED_PAGED_RESULT_TP4_PP1_FILE, + nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .setKVCacheType(KVCacheType::kPAGED) + .usePagedKvCache() .useDecoderPerRequest() .useTensorParallelism(4), - ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_TP2_PP2_FILE, + nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .setKVCacheType(KVCacheType::kPAGED) + .usePagedKvCache() .useDecoderPerRequest() .usePipelineParallelism(2) .useTensorParallelism(2) @@ -618,7 +705,8 @@ INSTANTIATE_TEST_SUITE_P(LlamaSessionTest, ParamTest, INSTANTIATE_TEST_SUITE_P(ChatGlmSessionTest, ParamTest, testing::Combine(testing::Values(ModelParams{CHATGLM_MODEL_DIR, {130005, 3}}), // end_id, pad_id - testing::Values(ModelSpec{CHATGLM_INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin() + testing::Values(ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} + .useGptAttentionPlugin() ), testing::Values(1, 2), // beamWidth @@ -630,7 +718,8 @@ INSTANTIATE_TEST_SUITE_P(ChatGlmSessionTest, ParamTest, INSTANTIATE_TEST_SUITE_P(ChatGlm2SessionTest, ParamTest, testing::Combine(testing::Values(ModelParams{CHATGLM2_MODEL_DIR, {2, 0}}), // end_id, pad_id - testing::Values(ModelSpec{CHATGLM2_INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin() + testing::Values(ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} + .useGptAttentionPlugin() ), testing::Values(1, 2), // beamWidth @@ -642,7 +731,8 @@ INSTANTIATE_TEST_SUITE_P(ChatGlm2SessionTest, ParamTest, INSTANTIATE_TEST_SUITE_P(ChatGlm3SessionTest, ParamTest, testing::Combine(testing::Values(ModelParams{CHATGLM3_MODEL_DIR, {2, 0}}), // end_id, pad_id - testing::Values(ModelSpec{CHATGLM3_INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin() + testing::Values(ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} + .useGptAttentionPlugin() ), testing::Values(1, 2), // beamWidth @@ -663,12 +753,11 @@ TEST_F(LlamaSessionOnDemandTest, SamplingFP16WithAttentionPlugin) auto const engineDir = "llama_7bf_outputs_tp1"; auto const modelPath{ENGINE_PATH / modelDir / engineDir}; SizeType32 constexpr beamWidth{1}; + fs::path resultsFile{DATA_PATH / modelDir / FP16_RESULT_FILE}; auto const batchSizes = {8}; auto constexpr dtype = nvinfer1::DataType::kHALF; - auto otherModelSpecPtr = std::make_shared(INPUT_FILE, dtype); - auto const modelSpec = ModelSpec{INPUT_FILE, dtype}.useGptAttentionPlugin(); - fs::path resultsFile{DATA_PATH / modelDir / modelSpec.getResultsFile()}; + auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin(); auto const modeIds = ModelIds{2, 2}; testGptSession( @@ -681,15 +770,11 @@ TEST_F(LlamaSessionOnDemandTest, SamplingFP16AttentionPluginDecoderBatch) auto const modelDir = "llamav2"; auto const modelPath{ENGINE_PATH / modelDir}; SizeType32 constexpr beamWidth{1}; + fs::path resultsFile{DATA_PATH / modelDir / FP16_RESULT_FILE}; auto const batchSizes = {8}; auto constexpr dtype = nvinfer1::DataType::kHALF; - auto otherModelSpecPtr = std::make_shared(INPUT_FILE, dtype); - auto const modelSpec = ModelSpec{INPUT_FILE, dtype, otherModelSpecPtr} - .useGptAttentionPlugin() - .usePackedInput() - .useDecoderPerRequest(); - fs::path resultsFile{DATA_PATH / modelDir / modelSpec.getResultsFile()}; + auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin().usePackedInput().useDecoderPerRequest(); auto 
const modeIds = ModelIds{2, 2}; testGptSession( diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index 0df10e921..5aed71389 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ -389,7 +389,7 @@ def build_engine(self, network: Network, param._get_weights(), name): raise RuntimeError(f'Failed to set weight: {name}') # This mark_weights_refittable has no side effect when refit_individual is not enabled. - # network.trt_network.mark_weights_refittable(name) + network.trt_network.mark_weights_refittable(name) network._fill_weights() # Build engine From a6fe44dac6988b99705391d6460c71da5d0baa52 Mon Sep 17 00:00:00 2001 From: dasistwo Date: Tue, 6 Aug 2024 16:47:26 +0900 Subject: [PATCH 14/28] Fix L1 shared bank conflict --- .../kernels/weightOnlyBatchedGemv/kernel.h | 4 +-- .../kernels/weightOnlyBatchedGemv/utility.h | 28 +++++++++++-------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 1d6d4f107..ff501bb0f 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -81,12 +81,12 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca GMemIterator weight_iterator(weight, (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW); - SHMemIterator scales_iterator(scales, + SHMemIterator scales_iterator(scales, offset_k_group * n + blk_offset_n, vec_scale, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave, (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : CtaN * Details::kInterleave) ); - SHMemIterator zeros_iterator(zeros, + SHMemIterator zeros_iterator(zeros, offset_k_group * n + blk_offset_n, vec_zero, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave, (GroupSize != 0 ? 
GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : CtaN * Details::kInterleave) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index 54e3b2853..cc511c678 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -133,6 +133,7 @@ __device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* sca Converter::convert(reinterpret_cast(quantized_w) + n * K / Details::kElemsPerByteW, reinterpret_cast(w) + n * K); Type2 vec_scale, vec_zero; + __pipeline_wait_prior(0); if constexpr (ApplyAlphaInAdvance) { vec_scale = MathWrapper::to_vec2( @@ -307,7 +308,7 @@ class GMemIterator int stride_; }; -template +template class SHMemIterator { public: @@ -327,17 +328,20 @@ class SHMemIterator { if constexpr (Enable) { -#pragma unroll - for (int i = threadIdx.x % sz_size_; i < Continuous; i += sz_size_) { - __pipeline_memcpy_async( - reinterpret_cast(sh_addr_) + i, - reinterpret_cast(g_addr_ + iter * step_) + i, - sizeof(TVec), - 0 - ); - __pipeline_commit(); - } - __pipeline_wait_prior(0); + const int loadunit = 128 / sizeof(T); + int chunks = Continuous / sz_size_; + int i = (threadIdx.x % sz_size_) * chunks; + while (chunks >= loadunit) { + __pipeline_memcpy_async(sh_addr_ + i, g_addr_ + iter * step_ + i, + sizeof(T) * loadunit, 0); + chunks -= loadunit; + i += loadunit; + } + if (chunks > 0) { + __pipeline_memcpy_async(sh_addr_ + i, g_addr_ + iter * step_ + i, + sizeof(T) * chunks, 0); + } + __pipeline_commit(); } } From bfa1b747bb8cf853d0a31189bf8997c390f3c66a Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Fri, 16 Aug 2024 11:31:38 +0900 Subject: [PATCH 15/28] Revoke private changes --- cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h | 1 + cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h | 2 +- cpp/tensorrt_llm/plugins/CMakeLists.txt | 2 +- docker/Dockerfile.multi | 2 +- docker/Makefile | 4 ++-- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index ff501bb0f..a88dad6b2 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -109,6 +109,7 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca scales_iterator.load(iter); zeros_iterator.load(iter); + __pipeline_wait_prior(0); act_scale_iterator.load(vec_act_scale, iter); #pragma unroll for (int i = 0; i < CtaN; ++i) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index cc511c678..13bf4f366 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -133,7 +133,6 @@ __device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* sca Converter::convert(reinterpret_cast(quantized_w) + n * K / Details::kElemsPerByteW, reinterpret_cast(w) + n * K); Type2 vec_scale, vec_zero; - __pipeline_wait_prior(0); if constexpr (ApplyAlphaInAdvance) { vec_scale = MathWrapper::to_vec2( @@ -331,6 +330,7 @@ class SHMemIterator const int loadunit = 128 / sizeof(T); int chunks = Continuous / sz_size_; int i = (threadIdx.x % sz_size_) * chunks; + // __pipeline_memcpy_async is implemented as a synced version in sm < 80 while (chunks 
>= loadunit) { __pipeline_memcpy_async(sh_addr_ + i, g_addr_ + iter * step_ + i, sizeof(T) * loadunit, 0); diff --git a/cpp/tensorrt_llm/plugins/CMakeLists.txt b/cpp/tensorrt_llm/plugins/CMakeLists.txt index 9f51ab918..5af9b78ee 100755 --- a/cpp/tensorrt_llm/plugins/CMakeLists.txt +++ b/cpp/tensorrt_llm/plugins/CMakeLists.txt @@ -113,7 +113,7 @@ else() ${PLUGIN_SHARED_TARGET} PROPERTIES LINK_FLAGS - "-Wl,--no-relax -Wl,--exclude-libs,ALL -Wl,--version-script=${PLUGIN_EXPORT_MAP} -Wl,-rpath,'$ORIGIN' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}" + "-Wl,--exclude-libs,ALL -Wl,--version-script=${PLUGIN_EXPORT_MAP} -Wl,-rpath,'$ORIGIN' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}" ) endif() diff --git a/docker/Dockerfile.multi b/docker/Dockerfile.multi index e683fec17..db3509bd2 100644 --- a/docker/Dockerfile.multi +++ b/docker/Dockerfile.multi @@ -68,7 +68,7 @@ COPY setup.py requirements.txt requirements-dev.txt ./ RUN mkdir -p /root/.cache/pip /root/.cache/ccache ENV CCACHE_DIR=/root/.cache/ccache # Build the TRT-LLM wheel -ARG BUILD_WHEEL_ARGS="--clean --trt_root /usr/local/tensorrt --nvtx -a "80-real" --python_bindings --benchmarks" +ARG BUILD_WHEEL_ARGS="--clean --trt_root /usr/local/tensorrt --python_bindings --benchmarks" RUN --mount=type=cache,target=/root/.cache/pip --mount=type=cache,target=/root/.cache/ccache \ python3 scripts/build_wheel.py ${BUILD_WHEEL_ARGS} diff --git a/docker/Makefile b/docker/Makefile index 0ec836b84..1c9066274 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -99,8 +99,8 @@ endef @echo "Pulling docker image: $(IMAGE_WITH_TAG)" docker pull $(IMAGE_WITH_TAG) -DOCKER_RUN_OPTS ?= -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -DOCKER_RUN_ARGS ?= --volume /data/storage1/model:/mnt/model --privileged=true +DOCKER_RUN_OPTS ?= --rm -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 +DOCKER_RUN_ARGS ?= GPU_OPTS ?= --gpus=all SOURCE_DIR ?= $(shell readlink -f ..) CODE_DIR ?= /code/tensorrt_llm From 5971baf0e74bbb972adf152149300a985e04de88 Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Tue, 27 Aug 2024 17:05:34 +0900 Subject: [PATCH 16/28] Refactor & Revoke commit 'Fix L1 shared bank conflict' Found some error cases with unit test cases with small load chunks. --- .../kernels/weightOnlyBatchedGemv/kernel.h | 61 ++++++++++++------- .../kernels/weightOnlyBatchedGemv/utility.h | 51 ++++++++-------- 2 files changed, 64 insertions(+), 48 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index a88dad6b2..7ed8548ce 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -55,6 +55,7 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca } int const origin_k = k, interleaved_k = k * Details::kInterleave; + int const sh_k_iter = GroupSize != 0 ? interleaved_k / CtaK : 1; int const tile_id_m = blockIdx.x, tile_id_n = blockIdx.y, tid = threadIdx.x; int const offset_m = tile_id_m * CtaM, interleaved_offset_n = tile_id_n * CtaN; @@ -64,15 +65,19 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca = (tid * StepK / (Details::kInterleave * Details::LayoutDetails::kTileSize)) * Details::LayoutDetails::kTileSize + ((tid * StepK) % Details::LayoutDetails::kTileSize); int const offset_k_group = (GroupSize != 0 ? 
real_offset_k / GroupSize : 0); + // number of threads with neighboring scale or zero values + int const near_sz_group = (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : 32); + // number of scale or zero values needed in a single block per k iterations + int const sz_per_iter = CtaN * Details::kInterleave; extern __shared__ TypeA shmem_sz[]; - // dimension of each is [kInterleave * CtaN, {Threads * kInterleave * kThreadsPerInterleavedTile / GroupSize or 1}] - TypeA* vec_scale = shmem_sz + CtaN * Details::kInterleave * offset_k_group; + // dimension of each is [sz_per_iter, {Threads / near_sz_group or 1}, {sh_k_iter or 1}] + TypeA* vec_scale = shmem_sz + sz_per_iter * offset_k_group; TypeA* vec_zero = nullptr; if constexpr (EnableZero) { - vec_zero = shmem_sz + CtaN * Details::kInterleave * (GroupSize != 0 ? Threads * Details::kInterleave * Details::kThreadsPerInterleavedTile / GroupSize : 1); - vec_zero += CtaN * Details::kInterleave * offset_k_group; + vec_zero = shmem_sz + sz_per_iter * (GroupSize != 0 ? Threads / near_sz_group * sh_k_iter: 1); + vec_zero += sz_per_iter * offset_k_group; } GMemIterator act_iterator( act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); @@ -81,16 +86,16 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca GMemIterator weight_iterator(weight, (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW); - SHMemIterator scales_iterator(scales, - offset_k_group * n + blk_offset_n, vec_scale, thr_offset_n, - (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave, - (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : CtaN * Details::kInterleave) - ); - SHMemIterator zeros_iterator(zeros, - offset_k_group * n + blk_offset_n, vec_zero, thr_offset_n, - (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave, - (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : CtaN * Details::kInterleave) - ); + SHMemIterator scales_iterator( + scales, offset_k_group * n + blk_offset_n, vec_scale, thr_offset_n, + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), + (GroupSize != 0 ? sz_per_iter * Threads / near_sz_group : 0), + Details::kInterleave); + SHMemIterator zeros_iterator( + zeros, offset_k_group * n + blk_offset_n, vec_zero, thr_offset_n, + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), + (GroupSize != 0 ? 
sz_per_iter * Threads / near_sz_group : 0), + Details::kInterleave); out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; if constexpr (EnableBias) @@ -101,23 +106,33 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca TypeA tile_acc[CtaM * CtaN]; fill(tile_acc, static_cast(0.f)); +#pragma unroll + for (int iter = 0; iter < sh_k_iter; ++iter) + { + scales_iterator.copy_to_shmem(iter); + zeros_iterator.copy_to_shmem(iter); + } + __pipeline_wait_prior(0); + for (int idx_k = tid * StepK, iter = 0; idx_k < interleaved_k; idx_k += CtaK, ++iter) { TypeA vec_act_scale[StepK]; + TypeA vec_scale[CtaN], vec_zero[CtaN]; TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK]; uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; - - scales_iterator.load(iter); - zeros_iterator.load(iter); - __pipeline_wait_prior(0); +#pragma unroll + for (int i = 0; i < CtaN; ++i) + { + scales_iterator.load(vec_scale + i, iter, i); + zeros_iterator.load(vec_zero + i, iter, i); + } act_scale_iterator.load(vec_act_scale, iter); #pragma unroll for (int i = 0; i < CtaN; ++i) { weight_iterator.load(tile_w_quantized, iter, i); dequantize( - tile_w, tile_w_quantized, scales_iterator.iter(i), - zeros_iterator.iter(i), alpha); + tile_w, tile_w_quantized, vec_scale + i, vec_zero + i, alpha); pack_to_vec2(tile_w_pack2, tile_w, i); } #pragma unroll @@ -142,11 +157,11 @@ void exec_kernel(Params& params, cudaStream_t s) } dim3 grid(params.m / CtaM, params.n / (CtaN * Details::kInterleave)); dim3 block(Threads); - int const shmem_size = (EnableZero ? 2 : 1) * CtaN * Details::kInterleave * sizeof(T) + int const shmem_size = (EnableZero ? 2 : 1) * CtaN * Details::kInterleave * sizeof(T) + * (GroupSize != 0 ? params.k * Details::kInterleave / Details::kStepK / Threads : 1) * (GroupSize != 0 ? Threads * Details::kInterleave * Details::kThreadsPerInterleavedTile / GroupSize : 1); // clang-format off - kernel<<< - grid, block, shmem_size, s>>>( + kernel<<>>( reinterpret_cast(params.act), reinterpret_cast(params.act_scale), reinterpret_cast(params.weight), diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index 13bf4f366..09c36ef4e 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -301,65 +301,66 @@ class GMemIterator } } + __device__ __forceinline__ void copy_to_shmem(int iter) + { + if constexpr (Enable) + { + // Do nothing + } + } + private: T* addr_; int step_; int stride_; }; -template +template class SHMemIterator { public: __device__ __forceinline__ SHMemIterator(T* g_addr, int g_offset, T* sh_addr, int sh_offset, - int step, int stride, int sz_size) + int g_step, int sh_step, int stride) : g_addr_(Enable ? (g_addr + g_offset) : nullptr) , sh_addr_(Enable ? 
sh_addr : nullptr) , sh_offset_(sh_offset) - , step_(step) + , g_step_(g_step) + , sh_step_(sh_step) , stride_(stride) - , sz_size_(sz_size) { } - __device__ __forceinline__ void load(int iter) + __device__ __forceinline__ void copy_to_shmem(int iter) { if constexpr (Enable) { - const int loadunit = 128 / sizeof(T); - int chunks = Continuous / sz_size_; - int i = (threadIdx.x % sz_size_) * chunks; - // __pipeline_memcpy_async is implemented as a synced version in sm < 80 - while (chunks >= loadunit) { - __pipeline_memcpy_async(sh_addr_ + i, g_addr_ + iter * step_ + i, - sizeof(T) * loadunit, 0); - chunks -= loadunit; - i += loadunit; - } - if (chunks > 0) { - __pipeline_memcpy_async(sh_addr_ + i, g_addr_ + iter * step_ + i, - sizeof(T) * chunks, 0); - } - __pipeline_commit(); +#pragma unroll + for (int i = threadIdx.x % Strided; i < Continuous; i += Strided) { + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_ + iter * sh_step_) + i, + reinterpret_cast(g_addr_ + iter * g_step_) + i, + sizeof(TVec) + ); + } + __pipeline_commit(); } } - __device__ __forceinline__ T* iter(int ii = 0) + __device__ __forceinline__ void load(void* dst, int iter, int ii = 0) { if constexpr (Enable) { - return &sh_addr_[ii * stride_ + sh_offset_]; + reinterpret_cast(dst)[0] = sh_addr_[iter * sh_step_ + ii * stride_ + sh_offset_]; } - return nullptr; } private: T* g_addr_; - int step_; + int g_step_; + int sh_step_; T* sh_addr_; int sh_offset_; int stride_; - int sz_size_; }; } // namespace weight_only From d5ecf9283cb76c35e9e43a92f08e3032ad3b6961 Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Wed, 28 Aug 2024 13:39:26 +0900 Subject: [PATCH 17/28] Copy to shared memory within K iteration --- .../kernels/weightOnlyBatchedGemv/kernel.h | 17 +++---- .../kernels/weightOnlyBatchedGemv/utility.h | 51 +++++++++++++++---- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 7ed8548ce..716d21bef 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -86,16 +86,16 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca GMemIterator weight_iterator(weight, (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW); - SHMemIterator scales_iterator( + SHMemIterator scales_iterator( scales, offset_k_group * n + blk_offset_n, vec_scale, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), (GroupSize != 0 ? sz_per_iter * Threads / near_sz_group : 0), - Details::kInterleave); - SHMemIterator zeros_iterator( + Details::kInterleave, sh_k_iter); + SHMemIterator zeros_iterator( zeros, offset_k_group * n + blk_offset_n, vec_zero, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), (GroupSize != 0 ? 
sz_per_iter * Threads / near_sz_group : 0), - Details::kInterleave); + Details::kInterleave, sh_k_iter); out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; if constexpr (EnableBias) @@ -106,13 +106,8 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca TypeA tile_acc[CtaM * CtaN]; fill(tile_acc, static_cast(0.f)); -#pragma unroll - for (int iter = 0; iter < sh_k_iter; ++iter) - { - scales_iterator.copy_to_shmem(iter); - zeros_iterator.copy_to_shmem(iter); - } - __pipeline_wait_prior(0); + scales_iterator.copy_to_shmem(); + zeros_iterator.copy_to_shmem(); for (int idx_k = tid * StepK, iter = 0; idx_k < interleaved_k; idx_k += CtaK, ++iter) { diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index 09c36ef4e..2ae1f7b5a 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -320,30 +320,62 @@ class SHMemIterator { public: __device__ __forceinline__ SHMemIterator(T* g_addr, int g_offset, T* sh_addr, int sh_offset, - int g_step, int sh_step, int stride) + int g_step, int sh_step, int stride, int sh_iter) : g_addr_(Enable ? (g_addr + g_offset) : nullptr) , sh_addr_(Enable ? sh_addr : nullptr) , sh_offset_(sh_offset) , g_step_(g_step) , sh_step_(sh_step) , stride_(stride) + , sh_iter_(sh_iter) { } - __device__ __forceinline__ void copy_to_shmem(int iter) + __device__ __forceinline__ void copy_to_shmem() + // __pipeline_memcpy_async will use synced version in sm < 80 { if constexpr (Enable) { + if constexpr (Continuous < sizeof(TVec) / sizeof(T)) + { // Uncommon slow case + int const c = Continuous * sizeof(T); + static_assert(c % 4 == 0); + int const s = threadIdx.x % Strided; +#pragma unroll + for (int iter = 0; iter < sh_iter_; iter += Strided) + { + if (s < sh_iter_) + { + __pipeline_memcpy_async( + sh_addr_ + (iter + s) * sh_step_, + g_addr_ + (iter + s) * g_step_, + c + ); + __pipeline_commit(); + } + } + } + else + { + int const c = Continuous * sizeof(T) / sizeof(TVec); + int const s = (threadIdx.x % Strided) / c; + int const i = threadIdx.x % c; #pragma unroll - for (int i = threadIdx.x % Strided; i < Continuous; i += Strided) { - __pipeline_memcpy_async( - reinterpret_cast(sh_addr_ + iter * sh_step_) + i, - reinterpret_cast(g_addr_ + iter * g_step_) + i, - sizeof(TVec) - ); + for (int iter = 0; iter < sh_iter_; iter += Strided / c) + { + if (s < sh_iter_) + { + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_ + (iter + s) * sh_step_) + i, + reinterpret_cast(g_addr_ + (iter + s) * g_step_ ) + i, + sizeof(TVec) + ); + __pipeline_commit(); + } + } } - __pipeline_commit(); + __pipeline_wait_prior(0); } } @@ -361,6 +393,7 @@ class SHMemIterator T* sh_addr_; int sh_offset_; int stride_; + int sh_iter_; }; } // namespace weight_only From c2ccb90a4e3683ffb6223f7285c23a7e62dbdca8 Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Wed, 4 Sep 2024 15:42:15 +0900 Subject: [PATCH 18/28] Refactoring & Apply double buffering for weight --- .../kernels/weightOnlyBatchedGemv/kernel.h | 80 ++++++++++++------ .../kernels/weightOnlyBatchedGemv/utility.h | 83 +++++++++++-------- 2 files changed, 101 insertions(+), 62 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 716d21bef..4f7f05db6 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ 
b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -19,6 +19,7 @@ #include "tensorrt_llm/kernels/weightOnlyBatchedGemv/converter.h" #include "tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h" #include "tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h" +#define W_STAGE 2 namespace tensorrt_llm { @@ -55,7 +56,6 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca } int const origin_k = k, interleaved_k = k * Details::kInterleave; - int const sh_k_iter = GroupSize != 0 ? interleaved_k / CtaK : 1; int const tile_id_m = blockIdx.x, tile_id_n = blockIdx.y, tid = threadIdx.x; int const offset_m = tile_id_m * CtaM, interleaved_offset_n = tile_id_n * CtaN; @@ -69,33 +69,39 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca int const near_sz_group = (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : 32); // number of scale or zero values needed in a single block per k iterations int const sz_per_iter = CtaN * Details::kInterleave; + // number of k-iterations which would be loaded in a single shared memory block load + int const sh_sz_group = (GroupSize != 0 ? near_sz_group / (sz_per_iter > (sizeof(AccessTypeA) / sizeof(TypeA)) ? + sz_per_iter / (sizeof(AccessTypeA) / sizeof(TypeA)) : 1) : 1); - extern __shared__ TypeA shmem_sz[]; - // dimension of each is [sz_per_iter, {Threads / near_sz_group or 1}, {sh_k_iter or 1}] - TypeA* vec_scale = shmem_sz + sz_per_iter * offset_k_group; - TypeA* vec_zero = nullptr; + __shared__ TypeA shmem_sz[sz_per_iter * (GroupSize != 0 ? Threads / near_sz_group : 1) * sh_sz_group * (EnableZero ? 2 : 1)]; + __shared__ uint8_t shmem_w[CtaK / Details::kElemsPerByteW * CtaN * W_STAGE]; + + TypeA* sh_scale = shmem_sz + sz_per_iter * offset_k_group; + TypeA* sh_zero = nullptr; if constexpr (EnableZero) { - vec_zero = shmem_sz + sz_per_iter * (GroupSize != 0 ? Threads / near_sz_group * sh_k_iter: 1); - vec_zero += sz_per_iter * offset_k_group; + sh_zero = shmem_sz + sz_per_iter * (GroupSize != 0 ? Threads / near_sz_group * sh_sz_group : 1); + sh_zero += sz_per_iter * offset_k_group; } + GMemIterator act_iterator( act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); GMemIterator act_scale_iterator( act_scale, real_offset_k, CtaK / Details::kInterleave, 0); - GMemIterator weight_iterator(weight, - (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, - interleaved_k / Details::kElemsPerByteW); - SHMemIterator scales_iterator( - scales, offset_k_group * n + blk_offset_n, vec_scale, thr_offset_n, + SHMemIterator weight_iterator( + weight, interleaved_offset_n * interleaved_k / Details::kElemsPerByteW, shmem_w, tid * StepK / Details::kElemsPerByteW, + CtaK / Details::kElemsPerByteW, CtaK * CtaN / Details::kElemsPerByteW, + interleaved_k / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, interleaved_k / CtaK); + SHMemIterator scales_iterator( + scales, offset_k_group * n + blk_offset_n, sh_scale, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), (GroupSize != 0 ? sz_per_iter * Threads / near_sz_group : 0), - Details::kInterleave, sh_k_iter); - SHMemIterator zeros_iterator( - zeros, offset_k_group * n + blk_offset_n, vec_zero, thr_offset_n, + Details::kInterleave, Details::kInterleave, (GroupSize != 0 ? 
interleaved_k / CtaK : 1)); + SHMemIterator zeros_iterator( + zeros, offset_k_group * n + blk_offset_n, sh_zero, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), (GroupSize != 0 ? sz_per_iter * Threads / near_sz_group : 0), - Details::kInterleave, sh_k_iter); + Details::kInterleave, Details::kInterleave, (GroupSize != 0 ? interleaved_k / CtaK : 1)); out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; if constexpr (EnableBias) @@ -103,29 +109,54 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca bias += tile_id_n * CtaN * Details::kInterleave; } +// Prefetch stage 0 +#pragma unroll + for (int i = 0; i < CtaN; ++i) + { + weight_iterator.copy_to_shmem(0, 0, i); + } + __pipeline_commit(); + TypeA tile_acc[CtaM * CtaN]; fill(tile_acc, static_cast(0.f)); - scales_iterator.copy_to_shmem(); - zeros_iterator.copy_to_shmem(); - for (int idx_k = tid * StepK, iter = 0; idx_k < interleaved_k; idx_k += CtaK, ++iter) { TypeA vec_act_scale[StepK]; TypeA vec_scale[CtaN], vec_zero[CtaN]; TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK]; uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; + + if (iter % sh_sz_group == 0) + { + scales_iterator.copy_to_shmem(iter); + zeros_iterator.copy_to_shmem(iter); + __pipeline_commit(); + } + + // Prefetch next stage + if (idx_k + CtaK < interleaved_k) + { +#pragma unroll + for (int i = 0; i < CtaN; ++i) + { + weight_iterator.copy_to_shmem(iter + 1, (iter + 1) % W_STAGE, i); + } + __pipeline_commit(); + } + __pipeline_wait_prior(1); + #pragma unroll for (int i = 0; i < CtaN; ++i) { - scales_iterator.load(vec_scale + i, iter, i); - zeros_iterator.load(vec_zero + i, iter, i); + scales_iterator.load(vec_scale + i, iter % sh_sz_group, i); + zeros_iterator.load(vec_zero + i, iter % sh_sz_group, i); } act_scale_iterator.load(vec_act_scale, iter); #pragma unroll for (int i = 0; i < CtaN; ++i) { - weight_iterator.load(tile_w_quantized, iter, i); + weight_iterator.load(tile_w_quantized, iter % W_STAGE, i); dequantize( tile_w, tile_w_quantized, vec_scale + i, vec_zero + i, alpha); pack_to_vec2(tile_w_pack2, tile_w, i); @@ -152,11 +183,8 @@ void exec_kernel(Params& params, cudaStream_t s) } dim3 grid(params.m / CtaM, params.n / (CtaN * Details::kInterleave)); dim3 block(Threads); - int const shmem_size = (EnableZero ? 2 : 1) * CtaN * Details::kInterleave * sizeof(T) - * (GroupSize != 0 ? params.k * Details::kInterleave / Details::kStepK / Threads : 1) - * (GroupSize != 0 ? Threads * Details::kInterleave * Details::kThreadsPerInterleavedTile / GroupSize : 1); // clang-format off - kernel<<>>( + kernel<<>>( reinterpret_cast(params.act), reinterpret_cast(params.act_scale), reinterpret_cast(params.weight), diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index 2ae1f7b5a..02235783b 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -315,85 +315,96 @@ class GMemIterator int stride_; }; -template +template class SHMemIterator { public: __device__ __forceinline__ SHMemIterator(T* g_addr, int g_offset, T* sh_addr, int sh_offset, - int g_step, int sh_step, int stride, int sh_iter) + int g_step, int sh_step, int g_stride, int sh_stride, float k_max_iter) : g_addr_(Enable ? (g_addr + g_offset) : nullptr) , sh_addr_(Enable ? 
sh_addr : nullptr) , sh_offset_(sh_offset) , g_step_(g_step) , sh_step_(sh_step) - , stride_(stride) - , sh_iter_(sh_iter) + , g_stride_(g_stride) + , sh_stride_(sh_stride) + , k_max_iter_(k_max_iter) { } - __device__ __forceinline__ void copy_to_shmem() + __device__ __forceinline__ void copy_to_shmem(int g_iter, int sh_iter = 0, int ii = 0) // __pipeline_memcpy_async will use synced version in sm < 80 { if constexpr (Enable) { - if constexpr (Continuous < sizeof(TVec) / sizeof(T)) + if constexpr (Elements < VecSize) { // Uncommon slow case - int const c = Continuous * sizeof(T); + int const c = Elements * sizeof(T); static_assert(c % 4 == 0); - int const s = threadIdx.x % Strided; -#pragma unroll - for (int iter = 0; iter < sh_iter_; iter += Strided) + int const s = threadIdx.x % Grouped; + if (threadIdx.x % Grouped + g_iter <= k_max_iter_) { - if (s < sh_iter_) - { - __pipeline_memcpy_async( - sh_addr_ + (iter + s) * sh_step_, - g_addr_ + (iter + s) * g_step_, - c - ); - __pipeline_commit(); - } + __pipeline_memcpy_async( + sh_addr_ + (sh_iter + s) * sh_step_ + ii * sh_stride_, + g_addr_ + (g_iter + s) * g_step_ + ii * g_stride_ , + c + ); } } else { - int const c = Continuous * sizeof(T) / sizeof(TVec); - int const s = (threadIdx.x % Strided) / c; + int const c = Elements / VecSize; + int const s = threadIdx.x % Grouped / c; int const i = threadIdx.x % c; -#pragma unroll - for (int iter = 0; iter < sh_iter_; iter += Strided / c) + if (threadIdx.x % Grouped / c + g_iter <= k_max_iter_) { - if (s < sh_iter_) - { - __pipeline_memcpy_async( - reinterpret_cast(sh_addr_ + (iter + s) * sh_step_) + i, - reinterpret_cast(g_addr_ + (iter + s) * g_step_ ) + i, - sizeof(TVec) - ); - __pipeline_commit(); - } + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_ + (sh_iter + s) * sh_step_ + ii * sh_stride_) + i, + reinterpret_cast(g_addr_ + (g_iter + s) * g_step_ + ii * g_stride_ ) + i, + sizeof(TVec) + ); } } - __pipeline_wait_prior(0); } } __device__ __forceinline__ void load(void* dst, int iter, int ii = 0) { if constexpr (Enable) { - reinterpret_cast(dst)[0] = sh_addr_[iter * sh_step_ + ii * stride_ + sh_offset_]; + if constexpr (Continuous < VecSize) + { +#pragma unroll + for (int jj = 0; jj < Continuous; ++jj) + { + reinterpret_cast(dst)[jj] = sh_addr_[iter * sh_step_ + ii * sh_stride_ + sh_offset_ + jj]; + } + } + else + { + static_assert(Continuous % VecSize == 0); + int const c = Continuous / VecSize; +#pragma unroll + for (int jj = 0; jj < c; ++jj) + { + reinterpret_cast(dst)[jj] = reinterpret_cast(sh_addr_ + iter * sh_step_ + ii * sh_stride_ + sh_offset_)[jj]; + } + } } } private: + static constexpr int VecSize = sizeof(TVec) / sizeof(T); + T* g_addr_; int g_step_; int sh_step_; T* sh_addr_; int sh_offset_; - int stride_; - int sh_iter_; + int g_stride_; + int sh_stride_; + // Decimal value represents that the last k iteration will only use a few warps + float k_max_iter_; }; } // namespace weight_only From 1488f3f852cb8c01cf967ecee17b0d7579350438 Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Thu, 5 Sep 2024 02:01:29 +0000 Subject: [PATCH 19/28] Debug ColumnMajor Test Case --- .../kernels/weightOnlyBatchedGemv/kernel.h | 2 +- .../kernels/weightOnlyBatchedGemv/utility.h | 20 +++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 4f7f05db6..e73016cec 100644 --- 
a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -88,7 +88,7 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); GMemIterator act_scale_iterator( act_scale, real_offset_k, CtaK / Details::kInterleave, 0); - SHMemIterator weight_iterator( + SHMemIterator weight_iterator( weight, interleaved_offset_n * interleaved_k / Details::kElemsPerByteW, shmem_w, tid * StepK / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, CtaK * CtaN / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, interleaved_k / CtaK); diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index 02235783b..d959813d5 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -343,7 +343,7 @@ class SHMemIterator int const c = Elements * sizeof(T); static_assert(c % 4 == 0); int const s = threadIdx.x % Grouped; - if (threadIdx.x % Grouped + g_iter <= k_max_iter_) + if (s + g_iter <= k_max_iter_) { __pipeline_memcpy_async( sh_addr_ + (sh_iter + s) * sh_step_ + ii * sh_stride_, @@ -356,14 +356,18 @@ class SHMemIterator { int const c = Elements / VecSize; int const s = threadIdx.x % Grouped / c; - int const i = threadIdx.x % c; - if (threadIdx.x % Grouped / c + g_iter <= k_max_iter_) + // this for loop is mostly single iteration + for (int i = threadIdx.x % c; i < c; i += blockDim.x * blockDim.y * blockDim.z) { - __pipeline_memcpy_async( - reinterpret_cast(sh_addr_ + (sh_iter + s) * sh_step_ + ii * sh_stride_) + i, - reinterpret_cast(g_addr_ + (g_iter + s) * g_step_ + ii * g_stride_ ) + i, - sizeof(TVec) - ); + // s should be float to compare with k_max_iter_ + if (threadIdx.x % Grouped / c + g_iter <= k_max_iter_) + { + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_ + (sh_iter + s) * sh_step_ + ii * sh_stride_) + i, + reinterpret_cast(g_addr_ + (g_iter + s) * g_step_ + ii * g_stride_ ) + i, + sizeof(TVec) + ); + } } } } From e3e6d93dc5a1b219718d9d6b06b87f4d4f115522 Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Thu, 5 Sep 2024 12:39:21 +0900 Subject: [PATCH 20/28] Apply double buffering for Act --- .../kernels/weightOnlyBatchedGemv/kernel.h | 46 ++++++++++++++----- .../kernels/weightOnlyBatchedGemv/utility.h | 6 +-- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index e73016cec..0fbc461d3 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -72,35 +72,45 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca // number of k-iterations which would be loaded in a single shared memory block load int const sh_sz_group = (GroupSize != 0 ? near_sz_group / (sz_per_iter > (sizeof(AccessTypeA) / sizeof(TypeA)) ? sz_per_iter / (sizeof(AccessTypeA) / sizeof(TypeA)) : 1) : 1); - + __shared__ TypeA shmem_sz[sz_per_iter * (GroupSize != 0 ? Threads / near_sz_group : 1) * sh_sz_group * (EnableZero ? 2 : 1)]; __shared__ uint8_t shmem_w[CtaK / Details::kElemsPerByteW * CtaN * W_STAGE]; - + __shared__ TypeA shmem_a[CtaK / Details::kInterleave * W_STAGE * (EnableActScale ? 
CtaM + 1 : CtaM)]; + TypeA* sh_scale = shmem_sz + sz_per_iter * offset_k_group; TypeA* sh_zero = nullptr; + TypeA* sh_actscale = nullptr; + if constexpr (EnableZero) { sh_zero = shmem_sz + sz_per_iter * (GroupSize != 0 ? Threads / near_sz_group * sh_sz_group : 1); sh_zero += sz_per_iter * offset_k_group; } + + if constexpr (EnableActScale) + { + sh_actscale = shmem_a + CtaK / Details::kInterleave * CtaM * W_STAGE; + } - GMemIterator act_iterator( - act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); - GMemIterator act_scale_iterator( - act_scale, real_offset_k, CtaK / Details::kInterleave, 0); + SHMemIterator act_iterator( + act, 0, shmem_a, offset_m * origin_k + real_offset_k, + CtaK / Details::kInterleave, CtaK * CtaM / Details::kInterleave, + origin_k, CtaK / Details::kInterleave, interleaved_k / CtaK); + SHMemIterator act_scale_iterator( + act_scale, 0, sh_actscale, real_offset_k, + CtaK / Details::kInterleave, CtaK / Details::kInterleave, + 0, 0, interleaved_k / CtaK); SHMemIterator weight_iterator( weight, interleaved_offset_n * interleaved_k / Details::kElemsPerByteW, shmem_w, tid * StepK / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, CtaK * CtaN / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, interleaved_k / CtaK); SHMemIterator scales_iterator( scales, offset_k_group * n + blk_offset_n, sh_scale, thr_offset_n, - (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), - (GroupSize != 0 ? sz_per_iter * Threads / near_sz_group : 0), + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), (GroupSize != 0 ? sz_per_iter * Threads / near_sz_group : 0), Details::kInterleave, Details::kInterleave, (GroupSize != 0 ? interleaved_k / CtaK : 1)); SHMemIterator zeros_iterator( zeros, offset_k_group * n + blk_offset_n, sh_zero, thr_offset_n, - (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), - (GroupSize != 0 ? sz_per_iter * Threads / near_sz_group : 0), + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), (GroupSize != 0 ? sz_per_iter * Threads / near_sz_group : 0), Details::kInterleave, Details::kInterleave, (GroupSize != 0 ? 
interleaved_k / CtaK : 1)); out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; @@ -115,6 +125,12 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca { weight_iterator.copy_to_shmem(0, 0, i); } +#pragma unroll + for (int i = 0; i < CtaM; ++i) + { + act_iterator.copy_to_shmem(0, 0, i); + } + act_scale_iterator.copy_to_shmem(0); __pipeline_commit(); TypeA tile_acc[CtaM * CtaN]; @@ -142,6 +158,12 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca { weight_iterator.copy_to_shmem(iter + 1, (iter + 1) % W_STAGE, i); } +#pragma unroll + for (int i = 0; i < CtaM; ++i) + { + act_iterator.copy_to_shmem(iter + 1, (iter + 1) % W_STAGE, i); + } + act_scale_iterator.copy_to_shmem(iter + 1, (iter + 1) % W_STAGE); __pipeline_commit(); } __pipeline_wait_prior(1); @@ -152,7 +174,6 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca scales_iterator.load(vec_scale + i, iter % sh_sz_group, i); zeros_iterator.load(vec_zero + i, iter % sh_sz_group, i); } - act_scale_iterator.load(vec_act_scale, iter); #pragma unroll for (int i = 0; i < CtaN; ++i) { @@ -161,10 +182,11 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca tile_w, tile_w_quantized, vec_scale + i, vec_zero + i, alpha); pack_to_vec2(tile_w_pack2, tile_w, i); } + act_scale_iterator.load(vec_act_scale, iter % W_STAGE); #pragma unroll for (int i = 0; i < CtaM; ++i) { - act_iterator.load(tile_a, iter, i); + act_iterator.load(tile_a, iter % W_STAGE, i); apply_scale(tile_a, vec_act_scale); mma(tile_acc + i * CtaN, tile_w_pack2, tile_a); } diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index d959813d5..b4b4abb14 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -303,10 +303,7 @@ class GMemIterator __device__ __forceinline__ void copy_to_shmem(int iter) { - if constexpr (Enable) - { - // Do nothing - } + // Do nothing } private: @@ -330,7 +327,6 @@ class SHMemIterator , sh_stride_(sh_stride) , k_max_iter_(k_max_iter) { - } __device__ __forceinline__ void copy_to_shmem(int g_iter, int sh_iter = 0, int ii = 0) From fb8ab20c54f9ff36fed9ba531395074668f651fa Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Thu, 5 Sep 2024 13:46:15 +0900 Subject: [PATCH 21/28] Reduce shared memory size & increase grid size --- .../weightOnlyBatchedGemv/kernelDispatcher.h | 24 +++++-------------- 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h index 1300732e1..b1a07fd6c 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h @@ -41,24 +41,12 @@ void dispatcher(Params& params, cudaStream_t s) return; \ } \ } while (0); - if constexpr (EnableZero) - { - // clang-format off - DISPATCHER_FOR_M(1, 1, 4, 128); - DISPATCHER_FOR_M(2, 2, 4, 128); - DISPATCHER_FOR_M(3, 3, 4, 128); - DISPATCHER_FOR_M(4, 4, 4, 128); - // clang-format on - } - else - { - // clang-format off - DISPATCHER_FOR_M(1, 1, 8, 128); - DISPATCHER_FOR_M(2, 2, 8, 128); - DISPATCHER_FOR_M(3, 3, 8, 128); - DISPATCHER_FOR_M(4, 4, 8, 128); - // clang-format on - } + // clang-format off + DISPATCHER_FOR_M(1, 1, 4, 128); + DISPATCHER_FOR_M(2, 2, 4, 
128); + DISPATCHER_FOR_M(3, 3, 4, 128); + DISPATCHER_FOR_M(4, 4, 4, 128); + // clang-format on throw std::runtime_error("unsupported m"); #undef DISPATCHER_FOR_M } From 599f5e8c906c8ea7b0e5d839c52ec3476772853c Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Fri, 6 Sep 2024 11:42:23 +0900 Subject: [PATCH 22/28] Revert "Increase grid size" & reduce shared memory buffer TODO: Increased instructions hide the ShMem advantage --- .../kernels/weightOnlyBatchedGemv/kernel.h | 72 ++++++++++--------- .../weightOnlyBatchedGemv/kernelDispatcher.h | 24 +++++-- .../kernels/weightOnlyBatchedGemv/utility.h | 14 ++-- 3 files changed, 64 insertions(+), 46 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 0fbc461d3..7859aa61f 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -19,7 +19,6 @@ #include "tensorrt_llm/kernels/weightOnlyBatchedGemv/converter.h" #include "tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h" #include "tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h" -#define W_STAGE 2 namespace tensorrt_llm { @@ -74,8 +73,8 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca sz_per_iter / (sizeof(AccessTypeA) / sizeof(TypeA)) : 1) : 1); __shared__ TypeA shmem_sz[sz_per_iter * (GroupSize != 0 ? Threads / near_sz_group : 1) * sh_sz_group * (EnableZero ? 2 : 1)]; - __shared__ uint8_t shmem_w[CtaK / Details::kElemsPerByteW * CtaN * W_STAGE]; - __shared__ TypeA shmem_a[CtaK / Details::kInterleave * W_STAGE * (EnableActScale ? CtaM + 1 : CtaM)]; + __shared__ uint8_t shmem_w[CtaK / Details::kElemsPerByteW * CtaN]; + __shared__ TypeA shmem_a[CtaK / Details::kInterleave * (EnableActScale ? CtaM + 1 : CtaM)]; TypeA* sh_scale = shmem_sz + sz_per_iter * offset_k_group; TypeA* sh_zero = nullptr; @@ -89,29 +88,31 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca if constexpr (EnableActScale) { - sh_actscale = shmem_a + CtaK / Details::kInterleave * CtaM * W_STAGE; + sh_actscale = shmem_a + CtaK / Details::kInterleave * CtaM; } SHMemIterator act_iterator( - act, 0, shmem_a, offset_m * origin_k + real_offset_k, - CtaK / Details::kInterleave, CtaK * CtaM / Details::kInterleave, - origin_k, CtaK / Details::kInterleave, interleaved_k / CtaK); + act, 0, shmem_a, real_offset_k, + CtaK / Details::kInterleave, 0, origin_k, CtaK / Details::kInterleave, + interleaved_k / CtaK); SHMemIterator act_scale_iterator( act_scale, 0, sh_actscale, real_offset_k, - CtaK / Details::kInterleave, CtaK / Details::kInterleave, - 0, 0, interleaved_k / CtaK); + CtaK / Details::kInterleave, 0, 0, 0, + interleaved_k / CtaK); SHMemIterator weight_iterator( weight, interleaved_offset_n * interleaved_k / Details::kElemsPerByteW, shmem_w, tid * StepK / Details::kElemsPerByteW, - CtaK / Details::kElemsPerByteW, CtaK * CtaN / Details::kElemsPerByteW, - interleaved_k / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, interleaved_k / CtaK); + CtaK / Details::kElemsPerByteW, 0, interleaved_k / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, + interleaved_k / CtaK); SHMemIterator scales_iterator( scales, offset_k_group * n + blk_offset_n, sh_scale, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), (GroupSize != 0 ? 
sz_per_iter * Threads / near_sz_group : 0), - Details::kInterleave, Details::kInterleave, (GroupSize != 0 ? interleaved_k / CtaK : 1)); + Details::kInterleave, Details::kInterleave, + (GroupSize != 0 ? interleaved_k / CtaK : 1)); SHMemIterator zeros_iterator( zeros, offset_k_group * n + blk_offset_n, sh_zero, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), (GroupSize != 0 ? sz_per_iter * Threads / near_sz_group : 0), - Details::kInterleave, Details::kInterleave, (GroupSize != 0 ? interleaved_k / CtaK : 1)); + Details::kInterleave, Details::kInterleave, + (GroupSize != 0 ? interleaved_k / CtaK : 1)); out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; if constexpr (EnableBias) @@ -123,12 +124,12 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca #pragma unroll for (int i = 0; i < CtaN; ++i) { - weight_iterator.copy_to_shmem(0, 0, i); + weight_iterator.copy_to_shmem(0, i); } #pragma unroll for (int i = 0; i < CtaM; ++i) { - act_iterator.copy_to_shmem(0, 0, i); + act_iterator.copy_to_shmem(0, i); } act_scale_iterator.copy_to_shmem(0); __pipeline_commit(); @@ -140,8 +141,8 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca { TypeA vec_act_scale[StepK]; TypeA vec_scale[CtaN], vec_zero[CtaN]; - TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK]; - uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; + TypeA tile_a[StepK * CtaM], tile_w[StepK], tile_w_pack2[CtaN * StepK]; + uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW * CtaN]; if (iter % sh_sz_group == 0) { @@ -149,6 +150,21 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca zeros_iterator.copy_to_shmem(iter); __pipeline_commit(); } + __pipeline_wait_prior(0); + +#pragma unroll + for (int i = 0; i < CtaN; ++i) + { + scales_iterator.load(vec_scale + i, iter % sh_sz_group, i); + zeros_iterator.load(vec_zero + i, iter % sh_sz_group, i); + weight_iterator.load(tile_w_quantized + i * StepK / Details::kElemsPerByteW, 0, i); + } +#pragma unroll + for (int i = 0; i < CtaM; ++i) + { + act_iterator.load(tile_a + i * StepK, 0, i); + } + act_scale_iterator.load(vec_act_scale, 0); // Prefetch next stage if (idx_k + CtaK < interleaved_k) @@ -156,39 +172,29 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca #pragma unroll for (int i = 0; i < CtaN; ++i) { - weight_iterator.copy_to_shmem(iter + 1, (iter + 1) % W_STAGE, i); + weight_iterator.copy_to_shmem(iter + 1, i); } #pragma unroll for (int i = 0; i < CtaM; ++i) { - act_iterator.copy_to_shmem(iter + 1, (iter + 1) % W_STAGE, i); + act_iterator.copy_to_shmem(iter + 1, i); } - act_scale_iterator.copy_to_shmem(iter + 1, (iter + 1) % W_STAGE); + act_scale_iterator.copy_to_shmem(iter + 1); __pipeline_commit(); } - __pipeline_wait_prior(1); #pragma unroll for (int i = 0; i < CtaN; ++i) { - scales_iterator.load(vec_scale + i, iter % sh_sz_group, i); - zeros_iterator.load(vec_zero + i, iter % sh_sz_group, i); - } -#pragma unroll - for (int i = 0; i < CtaN; ++i) - { - weight_iterator.load(tile_w_quantized, iter % W_STAGE, i); dequantize( - tile_w, tile_w_quantized, vec_scale + i, vec_zero + i, alpha); + tile_w, tile_w_quantized + i * StepK / Details::kElemsPerByteW, vec_scale + i, vec_zero + i, alpha); pack_to_vec2(tile_w_pack2, tile_w, i); } - act_scale_iterator.load(vec_act_scale, iter % W_STAGE); #pragma unroll for (int i = 0; i < CtaM; ++i) { - act_iterator.load(tile_a, iter % W_STAGE, i); - apply_scale(tile_a, 
vec_act_scale); - mma(tile_acc + i * CtaN, tile_w_pack2, tile_a); + apply_scale(tile_a + i * StepK, vec_act_scale); + mma(tile_acc + i * CtaN, tile_w_pack2, tile_a + i * StepK); } } epilogue(out, n, tile_acc, bias, alpha); diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h index b1a07fd6c..1300732e1 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h @@ -41,12 +41,24 @@ void dispatcher(Params& params, cudaStream_t s) return; \ } \ } while (0); - // clang-format off - DISPATCHER_FOR_M(1, 1, 4, 128); - DISPATCHER_FOR_M(2, 2, 4, 128); - DISPATCHER_FOR_M(3, 3, 4, 128); - DISPATCHER_FOR_M(4, 4, 4, 128); - // clang-format on + if constexpr (EnableZero) + { + // clang-format off + DISPATCHER_FOR_M(1, 1, 4, 128); + DISPATCHER_FOR_M(2, 2, 4, 128); + DISPATCHER_FOR_M(3, 3, 4, 128); + DISPATCHER_FOR_M(4, 4, 4, 128); + // clang-format on + } + else + { + // clang-format off + DISPATCHER_FOR_M(1, 1, 8, 128); + DISPATCHER_FOR_M(2, 2, 8, 128); + DISPATCHER_FOR_M(3, 3, 8, 128); + DISPATCHER_FOR_M(4, 4, 8, 128); + // clang-format on + } throw std::runtime_error("unsupported m"); #undef DISPATCHER_FOR_M } diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index b4b4abb14..dc41aefcb 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -329,7 +329,7 @@ class SHMemIterator { } - __device__ __forceinline__ void copy_to_shmem(int g_iter, int sh_iter = 0, int ii = 0) + __device__ __forceinline__ void copy_to_shmem(int iter, int ii = 0) // __pipeline_memcpy_async will use synced version in sm < 80 { if constexpr (Enable) @@ -339,11 +339,11 @@ class SHMemIterator int const c = Elements * sizeof(T); static_assert(c % 4 == 0); int const s = threadIdx.x % Grouped; - if (s + g_iter <= k_max_iter_) + if (s + iter <= k_max_iter_) { __pipeline_memcpy_async( - sh_addr_ + (sh_iter + s) * sh_step_ + ii * sh_stride_, - g_addr_ + (g_iter + s) * g_step_ + ii * g_stride_ , + sh_addr_ + s * sh_step_ + ii * sh_stride_, + g_addr_ + (iter + s) * g_step_ + ii * g_stride_ , c ); } @@ -356,11 +356,11 @@ class SHMemIterator for (int i = threadIdx.x % c; i < c; i += blockDim.x * blockDim.y * blockDim.z) { // s should be float to compare with k_max_iter_ - if (threadIdx.x % Grouped / c + g_iter <= k_max_iter_) + if (threadIdx.x % Grouped / c + iter <= k_max_iter_) { __pipeline_memcpy_async( - reinterpret_cast(sh_addr_ + (sh_iter + s) * sh_step_ + ii * sh_stride_) + i, - reinterpret_cast(g_addr_ + (g_iter + s) * g_step_ + ii * g_stride_ ) + i, + reinterpret_cast(sh_addr_ + s * sh_step_ + ii * sh_stride_) + i, + reinterpret_cast(g_addr_ + (iter + s) * g_step_ + ii * g_stride_) + i, sizeof(TVec) ); } From 51b0df6c33998e7f17621b9296f1f4c4ad4559aa Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Fri, 6 Sep 2024 16:02:38 +0900 Subject: [PATCH 23/28] Compute memory address at compile time --- .../kernels/weightOnlyBatchedGemv/details.h | 8 ++ .../kernels/weightOnlyBatchedGemv/kernel.h | 57 +++++---- .../kernels/weightOnlyBatchedGemv/utility.h | 109 ++++++++++++------ 3 files changed, 111 insertions(+), 63 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h index 
9ee3085dd..94d3cfaff 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h @@ -114,6 +114,14 @@ struct KernelDetails static constexpr bool kUseInterleavedConverter = UseInterleavedConverter; }; +template +struct ShMemOptimizer +{ + static constexpr bool GtoShStrided = GtoShStrided_; + static constexpr int ShStep = ShStep_; + static constexpr int ShStride = ShStride_; +}; + } // namespace weight_only } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 7859aa61f..d721530d5 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -64,12 +64,13 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca = (tid * StepK / (Details::kInterleave * Details::LayoutDetails::kTileSize)) * Details::LayoutDetails::kTileSize + ((tid * StepK) % Details::LayoutDetails::kTileSize); int const offset_k_group = (GroupSize != 0 ? real_offset_k / GroupSize : 0); + // number of threads with neighboring scale or zero values - int const near_sz_group = (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : 32); + static constexpr int near_sz_group = (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : 32); // number of scale or zero values needed in a single block per k iterations - int const sz_per_iter = CtaN * Details::kInterleave; + static constexpr int sz_per_iter = CtaN * Details::kInterleave; // number of k-iterations which would be loaded in a single shared memory block load - int const sh_sz_group = (GroupSize != 0 ? near_sz_group / (sz_per_iter > (sizeof(AccessTypeA) / sizeof(TypeA)) ? + static constexpr int sh_sz_group = (GroupSize != 0 ? near_sz_group / (sz_per_iter > (sizeof(AccessTypeA) / sizeof(TypeA)) ? sz_per_iter / (sizeof(AccessTypeA) / sizeof(TypeA)) : 1) : 1); __shared__ TypeA shmem_sz[sz_per_iter * (GroupSize != 0 ? Threads / near_sz_group : 1) * sh_sz_group * (EnableZero ? 2 : 1)]; @@ -91,28 +92,24 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca sh_actscale = shmem_a + CtaK / Details::kInterleave * CtaM; } - SHMemIterator act_iterator( - act, 0, shmem_a, real_offset_k, - CtaK / Details::kInterleave, 0, origin_k, CtaK / Details::kInterleave, - interleaved_k / CtaK); - SHMemIterator act_scale_iterator( - act_scale, 0, sh_actscale, real_offset_k, - CtaK / Details::kInterleave, 0, 0, 0, - interleaved_k / CtaK); - SHMemIterator weight_iterator( - weight, interleaved_offset_n * interleaved_k / Details::kElemsPerByteW, shmem_w, tid * StepK / Details::kElemsPerByteW, - CtaK / Details::kElemsPerByteW, 0, interleaved_k / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, - interleaved_k / CtaK); - SHMemIterator scales_iterator( - scales, offset_k_group * n + blk_offset_n, sh_scale, thr_offset_n, - (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), (GroupSize != 0 ? sz_per_iter * Threads / near_sz_group : 0), - Details::kInterleave, Details::kInterleave, - (GroupSize != 0 ? interleaved_k / CtaK : 1)); - SHMemIterator zeros_iterator( - zeros, offset_k_group * n + blk_offset_n, sh_zero, thr_offset_n, - (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), (GroupSize != 0 ? 
sz_per_iter * Threads / near_sz_group : 0), - Details::kInterleave, Details::kInterleave, - (GroupSize != 0 ? interleaved_k / CtaK : 1)); + SHMemIterator, TypeA> act_iterator( + act, 0, shmem_a, real_offset_k, CtaK / Details::kInterleave, origin_k, interleaved_k / CtaK); + SHMemIterator, TypeA> act_scale_iterator( + act_scale, 0, sh_actscale, real_offset_k, CtaK / Details::kInterleave, 0, interleaved_k / CtaK); + SHMemIterator, uint8_t> weight_iterator( + weight, interleaved_offset_n * interleaved_k / Details::kElemsPerByteW, shmem_w, tid * StepK / Details::kElemsPerByteW, + CtaK / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW, interleaved_k / CtaK); + SHMemIterator, TypeA> scales_iterator( + scales, offset_k_group * n + blk_offset_n, sh_scale, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), + 0, (GroupSize != 0 ? interleaved_k / CtaK : 1)); + SHMemIterator, TypeA> zeros_iterator( + zeros, offset_k_group * n + blk_offset_n, sh_zero, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), + 0, (GroupSize != 0 ? interleaved_k / CtaK : 1)); out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; if constexpr (EnableBias) @@ -155,16 +152,16 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca #pragma unroll for (int i = 0; i < CtaN; ++i) { - scales_iterator.load(vec_scale + i, iter % sh_sz_group, i); - zeros_iterator.load(vec_zero + i, iter % sh_sz_group, i); - weight_iterator.load(tile_w_quantized + i * StepK / Details::kElemsPerByteW, 0, i); + scales_iterator.load(vec_scale + i, i, iter % sh_sz_group); + zeros_iterator.load(vec_zero + i, i, iter % sh_sz_group); + weight_iterator.load(tile_w_quantized + i * StepK / Details::kElemsPerByteW, i); } #pragma unroll for (int i = 0; i < CtaM; ++i) { - act_iterator.load(tile_a + i * StepK, 0, i); + act_iterator.load(tile_a + i * StepK, i); } - act_scale_iterator.load(vec_act_scale, 0); + act_scale_iterator.load(vec_act_scale); // Prefetch next stage if (idx_k + CtaK < interleaved_k) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index dc41aefcb..723cc3a59 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -312,19 +312,17 @@ class GMemIterator int stride_; }; -template +template class SHMemIterator { public: __device__ __forceinline__ SHMemIterator(T* g_addr, int g_offset, T* sh_addr, int sh_offset, - int g_step, int sh_step, int g_stride, int sh_stride, float k_max_iter) + int g_step, int g_stride, float k_max_iter) : g_addr_(Enable ? (g_addr + g_offset) : nullptr) , sh_addr_(Enable ? 
sh_addr : nullptr) , sh_offset_(sh_offset) , g_step_(g_step) - , sh_step_(sh_step) , g_stride_(g_stride) - , sh_stride_(sh_stride) , k_max_iter_(k_max_iter) { } @@ -334,42 +332,70 @@ class SHMemIterator { if constexpr (Enable) { - if constexpr (Elements < VecSize) - { // Uncommon slow case - int const c = Elements * sizeof(T); - static_assert(c % 4 == 0); - int const s = threadIdx.x % Grouped; - if (s + iter <= k_max_iter_) - { - __pipeline_memcpy_async( - sh_addr_ + s * sh_step_ + ii * sh_stride_, - g_addr_ + (iter + s) * g_step_ + ii * g_stride_ , - c - ); + if constexpr (Grouped == 0) + { // Grouped == 0 is for weight / act case + if (iter <= k_max_iter_) + { // this for loop is mostly single iteration + for (int i = threadIdx.x; i < c_sh; i += CtaSize) + { + if constexpr (ShTraits::GtoShStrided) + { // W, A + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_ + ii * sh_stride_) + i, + reinterpret_cast(g_addr_ + iter * g_step_ + ii * g_stride_) + i, + sizeof(TVec) + ); + } + else + { // As + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_) + i, + reinterpret_cast(g_addr_ + iter * g_step_) + i, + sizeof(TVec) + ); + } + } } } else - { - int const c = Elements / VecSize; - int const s = threadIdx.x % Grouped / c; - // this for loop is mostly single iteration - for (int i = threadIdx.x % c; i < c; i += blockDim.x * blockDim.y * blockDim.z) - { - // s should be float to compare with k_max_iter_ - if (threadIdx.x % Grouped / c + iter <= k_max_iter_) + { // Grouped != 0 is for scale / zero + if constexpr (Elements < VecSize) + { // Uncommon slow case + static_assert(c_sh % 4 == 0); + int const s = threadIdx.x % Grouped; + if (s + iter <= k_max_iter_) { + static_assert(!ShTraits::GtoShStrided); __pipeline_memcpy_async( - reinterpret_cast(sh_addr_ + s * sh_step_ + ii * sh_stride_) + i, - reinterpret_cast(g_addr_ + (iter + s) * g_step_ + ii * g_stride_) + i, - sizeof(TVec) + sh_addr_ + s * sh_step_, + g_addr_ + (iter + s) * g_step_, + c_sh ); } } + else + { + // s should be float to compare with k_max_iter_ + if (threadIdx.x % Grouped / c_sh + iter <= k_max_iter_) + { + int const s = threadIdx.x % Grouped / c_sh; + // this for loop is mostly single iteration + for (int i = threadIdx.x % c_sh; i < c_sh; i += CtaSize) + { + static_assert(!ShTraits::GtoShStrided); + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_ + s * sh_step_) + i, + reinterpret_cast(g_addr_ + (iter + s) * g_step_) + i, + sizeof(TVec) + ); + } + } + } } } } - __device__ __forceinline__ void load(void* dst, int iter, int ii = 0) + __device__ __forceinline__ void load(void* dst, int ii = 0, int iter = 0) { if constexpr (Enable) { if constexpr (Continuous < VecSize) @@ -383,11 +409,24 @@ class SHMemIterator else { static_assert(Continuous % VecSize == 0); - int const c = Continuous / VecSize; #pragma unroll - for (int jj = 0; jj < c; ++jj) + for (int jj = 0; jj < c_load; ++jj) { + if constexpr (sh_step_ == 0) + { + if constexpr (sh_stride_ == 0) + { + reinterpret_cast(dst)[jj] = reinterpret_cast(sh_addr_ + sh_offset_)[jj]; + } + else + { + reinterpret_cast(dst)[jj] = reinterpret_cast(sh_addr_ + ii * sh_stride_ + sh_offset_)[jj]; + } + } + else + { reinterpret_cast(dst)[jj] = reinterpret_cast(sh_addr_ + iter * sh_step_ + ii * sh_stride_ + sh_offset_)[jj]; + } } } } @@ -395,16 +434,20 @@ class SHMemIterator private: static constexpr int VecSize = sizeof(TVec) / sizeof(T); + static constexpr int c_sh = Elements * sizeof(T) / (Elements < VecSize ? 
1 : sizeof(TVec)); + static constexpr int c_load = Continuous / VecSize; + const int CtaSize = blockDim.x * blockDim.y * blockDim.z; T* g_addr_; - int g_step_; - int sh_step_; T* sh_addr_; int sh_offset_; + int g_step_; int g_stride_; - int sh_stride_; // Decimal value represents that the last k iteration will only use a few warps float k_max_iter_; + + static constexpr int sh_step_ = ShTraits::ShStep; + static constexpr int sh_stride_ = ShTraits::ShStride; }; } // namespace weight_only From 44c6699f2925c08c5083d926d556008edd034023 Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Mon, 9 Sep 2024 13:12:40 +0900 Subject: [PATCH 24/28] Apply compile-time calculation for less instruction --- .../kernels/weightOnlyBatchedGemv/details.h | 10 ++- .../kernels/weightOnlyBatchedGemv/kernel.h | 33 ++++----- .../kernels/weightOnlyBatchedGemv/utility.h | 69 +++++++++---------- 3 files changed, 58 insertions(+), 54 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h index 94d3cfaff..75792e294 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h @@ -114,12 +114,20 @@ struct KernelDetails static constexpr bool kUseInterleavedConverter = UseInterleavedConverter; }; -template +template struct ShMemOptimizer { + using T = T_; + using TVec = TVec_; static constexpr bool GtoShStrided = GtoShStrided_; static constexpr int ShStep = ShStep_; static constexpr int ShStride = ShStride_; + static constexpr int Elements = Elements_; + static constexpr int Continuous = Continuous_; + static constexpr int VecSize = sizeof(TVec) / sizeof(T); + static constexpr int c_sh = Elements_ * sizeof(T) / (Elements_ < VecSize ? 1 : sizeof(TVec)); + static constexpr int c_load = Continuous_ / VecSize; + }; } // namespace weight_only diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index d721530d5..e8e07ba1b 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -72,10 +72,11 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca // number of k-iterations which would be loaded in a single shared memory block load static constexpr int sh_sz_group = (GroupSize != 0 ? near_sz_group / (sz_per_iter > (sizeof(AccessTypeA) / sizeof(TypeA)) ? sz_per_iter / (sizeof(AccessTypeA) / sizeof(TypeA)) : 1) : 1); + static constexpr int CtaKInterleave = CtaK / Details::kInterleave; __shared__ TypeA shmem_sz[sz_per_iter * (GroupSize != 0 ? Threads / near_sz_group : 1) * sh_sz_group * (EnableZero ? 2 : 1)]; __shared__ uint8_t shmem_w[CtaK / Details::kElemsPerByteW * CtaN]; - __shared__ TypeA shmem_a[CtaK / Details::kInterleave * (EnableActScale ? CtaM + 1 : CtaM)]; + __shared__ TypeA shmem_a[CtaKInterleave * (EnableActScale ? 
CtaM + 1 : CtaM)]; TypeA* sh_scale = shmem_sz + sz_per_iter * offset_k_group; TypeA* sh_zero = nullptr; @@ -89,26 +90,26 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca if constexpr (EnableActScale) { - sh_actscale = shmem_a + CtaK / Details::kInterleave * CtaM; + sh_actscale = shmem_a + CtaKInterleave * CtaM; } - SHMemIterator, TypeA> act_iterator( - act, 0, shmem_a, real_offset_k, CtaK / Details::kInterleave, origin_k, interleaved_k / CtaK); - SHMemIterator, TypeA> act_scale_iterator( - act_scale, 0, sh_actscale, real_offset_k, CtaK / Details::kInterleave, 0, interleaved_k / CtaK); - SHMemIterator, uint8_t> weight_iterator( + SHMemIterator> act_iterator( + act, 0, shmem_a, real_offset_k, CtaKInterleave, origin_k, interleaved_k / CtaK); + SHMemIterator> act_scale_iterator( + act_scale, 0, sh_actscale, real_offset_k, CtaKInterleave, 0, interleaved_k / CtaK); + SHMemIterator> weight_iterator( weight, interleaved_offset_n * interleaved_k / Details::kElemsPerByteW, shmem_w, tid * StepK / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW, interleaved_k / CtaK); - SHMemIterator, TypeA> scales_iterator( - scales, offset_k_group * n + blk_offset_n, sh_scale, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), + SHMemIterator> scales_iterator( + scales, offset_k_group * n + blk_offset_n, sh_scale, thr_offset_n, (GroupSize != 0 ? CtaKInterleave / GroupSize * n : 0), 0, (GroupSize != 0 ? interleaved_k / CtaK : 1)); - SHMemIterator, TypeA> zeros_iterator( - zeros, offset_k_group * n + blk_offset_n, sh_zero, thr_offset_n, (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), + SHMemIterator> zeros_iterator( + zeros, offset_k_group * n + blk_offset_n, sh_zero, thr_offset_n, (GroupSize != 0 ? CtaKInterleave / GroupSize * n : 0), 0, (GroupSize != 0 ? 
interleaved_k / CtaK : 1)); out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index 723cc3a59..6d1d198cd 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -312,9 +312,12 @@ class GMemIterator int stride_; }; -template +template class SHMemIterator { + using T = typename ShTraits::T; + using TVec = typename ShTraits::TVec; + public: __device__ __forceinline__ SHMemIterator(T* g_addr, int g_offset, T* sh_addr, int sh_offset, int g_step, int g_stride, float k_max_iter) @@ -334,27 +337,22 @@ class SHMemIterator { if constexpr (Grouped == 0) { // Grouped == 0 is for weight / act case - if (iter <= k_max_iter_) - { // this for loop is mostly single iteration - for (int i = threadIdx.x; i < c_sh; i += CtaSize) - { - if constexpr (ShTraits::GtoShStrided) - { // W, A - __pipeline_memcpy_async( - reinterpret_cast(sh_addr_ + ii * sh_stride_) + i, - reinterpret_cast(g_addr_ + iter * g_step_ + ii * g_stride_) + i, - sizeof(TVec) - ); - } - else - { // As - __pipeline_memcpy_async( - reinterpret_cast(sh_addr_) + i, - reinterpret_cast(g_addr_ + iter * g_step_) + i, - sizeof(TVec) - ); - } - } + int const i = threadIdx.x; + if constexpr (ShTraits::GtoShStrided) + { // W, A + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_ + ii * sh_stride_) + i, + reinterpret_cast(g_addr_ + iter * g_step_ + ii * g_stride_) + i, + sizeof(TVec) + ); + } + else + { // As + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_) + i, + reinterpret_cast(g_addr_ + iter * g_step_) + i, + sizeof(TVec) + ); } } else @@ -379,16 +377,13 @@ class SHMemIterator if (threadIdx.x % Grouped / c_sh + iter <= k_max_iter_) { int const s = threadIdx.x % Grouped / c_sh; - // this for loop is mostly single iteration - for (int i = threadIdx.x % c_sh; i < c_sh; i += CtaSize) - { - static_assert(!ShTraits::GtoShStrided); - __pipeline_memcpy_async( - reinterpret_cast(sh_addr_ + s * sh_step_) + i, - reinterpret_cast(g_addr_ + (iter + s) * g_step_) + i, - sizeof(TVec) - ); - } + int const i = threadIdx.x % c_sh; + static_assert(!ShTraits::GtoShStrided); + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_ + s * sh_step_) + i, + reinterpret_cast(g_addr_ + (iter + s) * g_step_) + i, + sizeof(TVec) + ); } } } @@ -433,11 +428,6 @@ class SHMemIterator } private: - static constexpr int VecSize = sizeof(TVec) / sizeof(T); - static constexpr int c_sh = Elements * sizeof(T) / (Elements < VecSize ? 
1 : sizeof(TVec)); - static constexpr int c_load = Continuous / VecSize; - const int CtaSize = blockDim.x * blockDim.y * blockDim.z; - T* g_addr_; T* sh_addr_; int sh_offset_; @@ -446,6 +436,11 @@ class SHMemIterator // Decimal value represents that the last k iteration will only use a few warps float k_max_iter_; + static constexpr int VecSize = ShTraits::VecSize; + static constexpr int c_sh = ShTraits::c_sh; + static constexpr int c_load = ShTraits::c_load; + static constexpr int Continuous = ShTraits::Continuous; + static constexpr int Elements = ShTraits::Elements; static constexpr int sh_step_ = ShTraits::ShStep; static constexpr int sh_stride_ = ShTraits::ShStride; }; From 90c798c7ba4ac31c33623b5f93d571aec5d466a5 Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Mon, 9 Sep 2024 09:53:23 +0000 Subject: [PATCH 25/28] Debug for ColumnMajor Case --- cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h index 75792e294..ab0f0755b 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h @@ -53,7 +53,7 @@ struct ColumnMajor using DetailsA = TypeDetailsA; using DetailsW = TypeDetailsW; using AccessTypeA = float4; - using AccessTypeW = int; + using AccessTypeW = typename std::conditional::type; static constexpr int kAccessSize = 128; static constexpr int kStepK = kAccessSize / TypeDetailsA::kElemBits; static constexpr int kTileSize = TileSizeK; From a50ccee9e04d658886312be02a1801a8b04fde65 Mon Sep 17 00:00:00 2001 From: dasistwo <99160400+dasistwo@users.noreply.github.com> Date: Tue, 10 Sep 2024 12:30:52 +0900 Subject: [PATCH 26/28] Revoke irrelevant commits --- .gitignore | 3 - cpp/CMakeLists.txt | 23 --- cpp/tests/CMakeLists.txt | 1 + cpp/tests/runtime/gptSessionTest.cpp | 205 ++++++++------------------- 4 files changed, 61 insertions(+), 171 deletions(-) diff --git a/.gitignore b/.gitignore index 3f10399da..d9463eeb2 100644 --- a/.gitignore +++ b/.gitignore @@ -41,9 +41,6 @@ docs/source/llm-api docs/source/llm-api-examples/llm_*.rst *.swp -# Debugging Purpose -.env - # Testing .coverage.* results_trt/ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 205df273e..f4c74b0ef 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -380,15 +380,6 @@ endif() # "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} --generate-line-info") # set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") -# Add the option with CMAKE_CUDA_FLAGS causes error. Add G and Lineinfo. -if(${CMAKE_BUILD_TYPE} MATCHES "Debug") - add_compile_options("$<$:-G>" - "$<$:--host-linker-script>") -elseif(${CMAKE_BUILD_TYPE} MATCHES "RelWithDebInfo") - add_compile_options("$<$:-lineinfo>" - "$<$:--host-linker-script>") -endif() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_SYSTEM=cmake_oss -DENABLE_MULTI_DEVICE=${ENABLE_MULTI_DEVICE}" ) @@ -450,16 +441,6 @@ if(FAST_MATH) message("CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") endif() -# Add the option with CMAKE_CUDA_FLAGS causes error. 
-if(${CMAKE_BUILD_TYPE} MATCHES "Debug") - add_compile_options("$<$:-G>" - "$<$:--host-linker-script>") -elseif(${CMAKE_BUILD_TYPE} MATCHES "RelWithDebInfo") - add_compile_options("$<$:-lineinfo>" - "$<$:--host-linker-script>") -endif() - - set(COMMON_HEADER_DIRS ${PROJECT_SOURCE_DIR} ${CUDAToolkit_INCLUDE_DIR}) message(STATUS "COMMON_HEADER_DIRS: ${COMMON_HEADER_DIRS}") @@ -512,10 +493,6 @@ if(BUILD_PYT) link_directories("${Python3_LIBRARY_DIRS}") list(APPEND COMMON_HEADER_DIRS ${Python3_INCLUDE_DIRS}) - # Let torch find the cudnn and cusparselt libraries - set(CAFFE2_USE_CUDNN ON) - set(CAFFE2_USE_CUSPARSELT ON) - execute_process( COMMAND ${Python3_EXECUTABLE} "-c" diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 2e9ac67a1..34034fa7d 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -81,6 +81,7 @@ add_gtest(transposeKVKernelTest runtime/transposeKVKernelTest.cpp) add_gtest(gptDecoderTest runtime/gptDecoderTest.cpp) add_gtest(gptDecoderBatchedTest runtime/gptDecoderBatchedTest.cpp) add_gtest(gptSessionTest runtime/gptSessionTest.cpp) +target_link_libraries(gptSessionTest PRIVATE modelSpecStatic) add_gtest(memoryUtilsTest common/memoryUtilsTest.cu) if(ENABLE_MULTI_DEVICE) add_gtest(mpiUtilsTest common/mpiUtilsTest.cpp) diff --git a/cpp/tests/runtime/gptSessionTest.cpp b/cpp/tests/runtime/gptSessionTest.cpp index 43ea481d4..79e21f97b 100644 --- a/cpp/tests/runtime/gptSessionTest.cpp +++ b/cpp/tests/runtime/gptSessionTest.cpp @@ -19,6 +19,7 @@ #include +#include "modelSpec.h" #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/mpiUtils.h" #include "tensorrt_llm/common/stlUtils.h" @@ -35,6 +36,10 @@ using namespace tensorrt_llm::runtime; namespace tc = tensorrt_llm::common; namespace fs = std::filesystem; +using tensorrt_llm::testing::ModelSpec; +using tensorrt_llm::testing::KVCacheType; +using tensorrt_llm::testing::QuantMethod; +using tensorrt_llm::testing::OutputContentType; namespace { @@ -49,6 +54,10 @@ auto const CHATGLM_MODEL_DIR = "chatglm-6b"; auto const CHATGLM2_MODEL_DIR = "chatglm2-6b"; auto const CHATGLM3_MODEL_DIR = "chatglm3-6b"; auto const MAMBA_MODEL_DIR = "mamba-2.8b-hf"; +auto const INPUT_FILE = "input_tokens.npy"; +auto const CHATGLM_INPUT_FILE = "input_tokens_chatglm-6b.npy"; +auto const CHATGLM2_INPUT_FILE = "input_tokens_chatglm2-6b.npy"; +auto const CHATGLM3_INPUT_FILE = "input_tokens_chatglm3-6b.npy"; // Engines need to be generated using cpp/tests/resources/scripts/build_*_engines.py. 
auto const FP32_GPT_DIR = "fp32-default"; @@ -81,77 +90,6 @@ struct ModelParams ModelIds ids; }; -class ModelSpec -{ -public: - ModelSpec(fs::path modelPath, fs::path resultsFile, nvinfer1::DataType dtype) - : mModelPath{std::move(modelPath)} - , mResultsFile{std::move(resultsFile)} - , mDataType{dtype} - , mUseGptAttentionPlugin{false} - , mUsePackedInput{false} - , mUsePagedKvCache{false} - , mDecoderPerRequest{false} - , mPPSize(1) - , mTPSize(1) - , mRandomEndId(false) - { - } - - ModelSpec& useGptAttentionPlugin() - { - mUseGptAttentionPlugin = true; - return *this; - } - - ModelSpec& usePackedInput() - { - mUsePackedInput = true; - return *this; - } - - ModelSpec& usePagedKvCache() - { - mUsePagedKvCache = true; - return *this; - } - - ModelSpec& useDecoderPerRequest() - { - mDecoderPerRequest = true; - return *this; - } - - ModelSpec& usePipelineParallelism(int ppSize) - { - mPPSize = ppSize; - return *this; - } - - ModelSpec& useTensorParallelism(int tpSize) - { - mTPSize = tpSize; - return *this; - } - - ModelSpec& useRandomEndId() - { - mRandomEndId = true; - return *this; - } - - fs::path mModelPath; - fs::path mResultsFile; - nvinfer1::DataType mDataType; - bool mUseGptAttentionPlugin; - bool mUsePackedInput; - bool mUsePagedKvCache; - bool mDecoderPerRequest; - int mPPSize; - int mTPSize; - bool mRandomEndId; -}; - struct MicroBatchSizes { std::optional ctxMicroBatchSize{std::nullopt}; @@ -199,7 +137,9 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model ASSERT_TRUE(fs::exists(DATA_PATH)); std::string modelName{isChatGlmTest ? resultsFile.parent_path().parent_path().filename().string() : ""}; - fs::path inputPath = DATA_PATH / (isChatGlmTest ? "input_tokens_" + modelName + ".npy" : "input_tokens.npy"); + + fs::path inputPath = DATA_PATH / modelSpec.mInputFile; + auto const& givenInput = utils::loadNpy(manager, inputPath.string(), MemoryType::kCPU); auto const& inputShape = givenInput->getShape(); ASSERT_EQ(inputShape.nbDims, 2); @@ -463,7 +403,7 @@ std::string generateTestName(testing::TestParamInfo const& info) name.append("AttentionPlugin"); if (modelSpec.mUsePackedInput) name.append("Packed"); - if (modelSpec.mUsePagedKvCache) + if (modelSpec.mKVCacheType == KVCacheType::kPAGED) name.append("PagedKvCache"); if (modelSpec.mDecoderPerRequest) name.append("DecoderBatch"); @@ -511,10 +451,10 @@ TEST_P(ParamTest, Test) std::ostringstream gpuSizePath; gpuSizePath << "tp" << modelSpec.mTPSize << "-pp" << modelSpec.mPPSize << "-gpu"; - auto const modelPath{ENGINE_PATH / modelDir / modelSpec.mModelPath / gpuSizePath.str()}; + auto const modelPath{ENGINE_PATH / modelDir / modelSpec.getModelPath() / gpuSizePath.str()}; auto const resultsPath = DATA_PATH / modelDir / ((beamWidth == 1) ? "sampling" : "beam_search_" + std::to_string(beamWidth)); - fs::path const resultsFile{resultsPath / modelSpec.mResultsFile}; + fs::path const resultsFile{resultsPath / modelSpec.getResultsFile()}; // Warning: This should be the last check before running the test. // It will initialize MPI which can take significant time. 
@@ -532,11 +472,10 @@ INSTANTIATE_TEST_SUITE_P(GptSessionOtbTest, ParamTest, testing::Combine(testing::Values(ModelParams{GPT_MODEL_DIR, {50256, 50256}}), testing::Values( // single decoder - ModelSpec{FP32_GPT_DIR, FP32_RESULT_FILE, nvinfer1::DataType::kFLOAT}, - ModelSpec{FP16_GPT_DIR, FP16_RESULT_FILE, nvinfer1::DataType::kHALF}, + ModelSpec{INPUT_FILE, nvinfer1::DataType::kFLOAT}, ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}, // decoderBatch - ModelSpec{FP32_GPT_DIR, FP32_RESULT_FILE, nvinfer1::DataType::kFLOAT}.useDecoderPerRequest(), - ModelSpec{FP16_GPT_DIR, FP16_RESULT_FILE, nvinfer1::DataType::kHALF}.useDecoderPerRequest() + ModelSpec{INPUT_FILE, nvinfer1::DataType::kFLOAT}.useDecoderPerRequest(), + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useDecoderPerRequest() ), testing::Values(1), // beamWidth @@ -553,40 +492,30 @@ INSTANTIATE_TEST_SUITE_P(GptSessionTest, ParamTest, // Disabled because of flakey beam search test // ModelSpec{FP32_GPT_ATTENTION_DIR, FP32_PLUGIN_RESULT_FILE, nvinfer1::DataType::kFLOAT} // .useGptAttentionPlugin(), - ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} - .useGptAttentionPlugin(), - ModelSpec{FP16_GPT_ATTENTION_PACKED_DIR, FP16_PLUGIN_PACKED_RESULT_FILE, nvinfer1::DataType::kHALF} - .useGptAttentionPlugin() - .usePackedInput(), - ModelSpec{ - FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} - .useGptAttentionPlugin() - .usePackedInput() - .usePagedKvCache(), + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin(), + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().usePackedInput(), + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().usePackedInput().setKVCacheType( + KVCacheType::kPAGED), // decoderBatch // Disabled because of flakey beam search test // ModelSpec{FP32_GPT_ATTENTION_DIR, FP32_PLUGIN_RESULT_FILE, nvinfer1::DataType::kFLOAT} // .useGptAttentionPlugin() // .useDecoderPerRequest(), - ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} - .useGptAttentionPlugin() - .useDecoderPerRequest(), - ModelSpec{FP16_GPT_ATTENTION_PACKED_DIR, FP16_PLUGIN_PACKED_RESULT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().useDecoderPerRequest(), + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() .useDecoderPerRequest(), - ModelSpec{ - FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .usePagedKvCache() + .setKVCacheType(KVCacheType::kPAGED) .useDecoderPerRequest(), - ModelSpec{ - FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .usePagedKvCache() + .setKVCacheType(KVCacheType::kPAGED) .useDecoderPerRequest() .useRandomEndId() @@ -602,29 +531,20 @@ INSTANTIATE_TEST_SUITE_P(GptjSessionTest, ParamTest, testing::Combine(testing::Values(ModelParams{GPTJ_MODEL_DIR, {50256, 50256}}), testing::Values( // single decoder - ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} - .useGptAttentionPlugin(), - ModelSpec{FP16_GPT_ATTENTION_PACKED_DIR, FP16_PLUGIN_PACKED_RESULT_FILE, nvinfer1::DataType::kHALF} - .useGptAttentionPlugin() - 
.usePackedInput(), - ModelSpec{ - FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} - .useGptAttentionPlugin() - .usePackedInput() - .usePagedKvCache(), + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin(), + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().usePackedInput(), + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().usePackedInput().setKVCacheType( + KVCacheType::kPAGED), // decoderBatch - ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} - .useGptAttentionPlugin() - .useDecoderPerRequest(), - ModelSpec{FP16_GPT_ATTENTION_PACKED_DIR, FP16_PLUGIN_PACKED_RESULT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().useDecoderPerRequest(), + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() .useDecoderPerRequest(), - ModelSpec{ - FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .usePagedKvCache() + .setKVCacheType(KVCacheType::kPAGED) .useDecoderPerRequest() ), @@ -639,7 +559,7 @@ INSTANTIATE_TEST_SUITE_P(MambaSessionOOTBTest, ParamTest, testing::Combine(testing::Values(ModelParams{MAMBA_MODEL_DIR, {0, 1}}), testing::Values( // single decoder - ModelSpec{FP16_GPT_DIR, FP16_RESULT_FILE, nvinfer1::DataType::kHALF}), + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}), testing::Values(1), // beamWidth testing::Values(false), // cudaGraphMode testing::Values(MicroBatchSizes()), @@ -648,7 +568,7 @@ INSTANTIATE_TEST_SUITE_P(MambaSessionOOTBTest, ParamTest, generateTestName); INSTANTIATE_TEST_SUITE_P(MambaSessionPluginTest, ParamTest, testing::Combine(testing::Values(ModelParams{MAMBA_MODEL_DIR, {0, 1}}), - testing::Values(ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF}), + testing::Values(ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useMambaPlugin()), testing::Values(1), // beamWidth testing::Values(false), // cudaGraphMode testing::Values(MicroBatchSizes()), @@ -660,37 +580,30 @@ INSTANTIATE_TEST_SUITE_P(LlamaSessionTest, ParamTest, testing::Combine(testing::Values(ModelParams{LLAMA_MODEL_DIR, {2, 2}}), testing::Values( // single decoder - ModelSpec{ - FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} - .useGptAttentionPlugin() - .usePackedInput() - .usePagedKvCache(), + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin().usePackedInput().setKVCacheType( + KVCacheType::kPAGED), // decoderBatch - ModelSpec{ - FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_FILE, nvinfer1::DataType::kHALF} + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .usePagedKvCache() + .setKVCacheType(KVCacheType::kPAGED) .useDecoderPerRequest(), - ModelSpec{FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_TP1_PP4_FILE, - nvinfer1::DataType::kHALF} + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .usePagedKvCache() + .setKVCacheType(KVCacheType::kPAGED) .useDecoderPerRequest() .usePipelineParallelism(4), - ModelSpec{FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_TP4_PP1_FILE, - nvinfer1::DataType::kHALF} + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} 
.useGptAttentionPlugin() .usePackedInput() - .usePagedKvCache() + .setKVCacheType(KVCacheType::kPAGED) .useDecoderPerRequest() .useTensorParallelism(4), - ModelSpec{FP16_GPT_ATTENTION_PACKED_PAGED_DIR, FP16_PLUGIN_PACKED_PAGED_RESULT_TP2_PP2_FILE, - nvinfer1::DataType::kHALF} + ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF} .useGptAttentionPlugin() .usePackedInput() - .usePagedKvCache() + .setKVCacheType(KVCacheType::kPAGED) .useDecoderPerRequest() .usePipelineParallelism(2) .useTensorParallelism(2) @@ -705,8 +618,7 @@ INSTANTIATE_TEST_SUITE_P(LlamaSessionTest, ParamTest, INSTANTIATE_TEST_SUITE_P(ChatGlmSessionTest, ParamTest, testing::Combine(testing::Values(ModelParams{CHATGLM_MODEL_DIR, {130005, 3}}), // end_id, pad_id - testing::Values(ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} - .useGptAttentionPlugin() + testing::Values(ModelSpec{CHATGLM_INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin() ), testing::Values(1, 2), // beamWidth @@ -718,8 +630,7 @@ INSTANTIATE_TEST_SUITE_P(ChatGlmSessionTest, ParamTest, INSTANTIATE_TEST_SUITE_P(ChatGlm2SessionTest, ParamTest, testing::Combine(testing::Values(ModelParams{CHATGLM2_MODEL_DIR, {2, 0}}), // end_id, pad_id - testing::Values(ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} - .useGptAttentionPlugin() + testing::Values(ModelSpec{CHATGLM2_INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin() ), testing::Values(1, 2), // beamWidth @@ -731,8 +642,7 @@ INSTANTIATE_TEST_SUITE_P(ChatGlm2SessionTest, ParamTest, INSTANTIATE_TEST_SUITE_P(ChatGlm3SessionTest, ParamTest, testing::Combine(testing::Values(ModelParams{CHATGLM3_MODEL_DIR, {2, 0}}), // end_id, pad_id - testing::Values(ModelSpec{FP16_GPT_ATTENTION_DIR, FP16_PLUGIN_RESULT_FILE, nvinfer1::DataType::kHALF} - .useGptAttentionPlugin() + testing::Values(ModelSpec{CHATGLM3_INPUT_FILE, nvinfer1::DataType::kHALF}.useGptAttentionPlugin() ), testing::Values(1, 2), // beamWidth @@ -753,11 +663,12 @@ TEST_F(LlamaSessionOnDemandTest, SamplingFP16WithAttentionPlugin) auto const engineDir = "llama_7bf_outputs_tp1"; auto const modelPath{ENGINE_PATH / modelDir / engineDir}; SizeType32 constexpr beamWidth{1}; - fs::path resultsFile{DATA_PATH / modelDir / FP16_RESULT_FILE}; auto const batchSizes = {8}; auto constexpr dtype = nvinfer1::DataType::kHALF; - auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin(); + auto otherModelSpecPtr = std::make_shared(INPUT_FILE, dtype); + auto const modelSpec = ModelSpec{INPUT_FILE, dtype}.useGptAttentionPlugin(); + fs::path resultsFile{DATA_PATH / modelDir / modelSpec.getResultsFile()}; auto const modeIds = ModelIds{2, 2}; testGptSession( @@ -770,11 +681,15 @@ TEST_F(LlamaSessionOnDemandTest, SamplingFP16AttentionPluginDecoderBatch) auto const modelDir = "llamav2"; auto const modelPath{ENGINE_PATH / modelDir}; SizeType32 constexpr beamWidth{1}; - fs::path resultsFile{DATA_PATH / modelDir / FP16_RESULT_FILE}; auto const batchSizes = {8}; auto constexpr dtype = nvinfer1::DataType::kHALF; - auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin().usePackedInput().useDecoderPerRequest(); + auto otherModelSpecPtr = std::make_shared(INPUT_FILE, dtype); + auto const modelSpec = ModelSpec{INPUT_FILE, dtype, otherModelSpecPtr} + .useGptAttentionPlugin() + .usePackedInput() + .useDecoderPerRequest(); + fs::path resultsFile{DATA_PATH / modelDir / modelSpec.getResultsFile()}; auto const modeIds = ModelIds{2, 2}; testGptSession( From 
b19f748a1a66c6d6f383d0ec22a0de39eb80efd5 Mon Sep 17 00:00:00 2001 From: Jaeyoung Choi Date: Fri, 6 Dec 2024 11:39:46 +0900 Subject: [PATCH 27/28] Update submodule cutlass --- 3rdparty/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cutlass b/3rdparty/cutlass index 19b4c5e06..80243e0b8 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit 19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc +Subproject commit 80243e0b8c644f281e2beb0c20fe78cf7b267061 From 2a8779de05f2b36840114919ac99586a45d19d92 Mon Sep 17 00:00:00 2001 From: Jaeyoung Choi Date: Fri, 6 Dec 2024 17:13:01 +0900 Subject: [PATCH 28/28] Debug errors with updated cutlass --- .../threadblock/epilogue_tensor_op_int32.h | 23 ------------------- .../gemm/kernel/default_fpA_intB_traits.h | 7 +----- .../threadblock/default_dq_mma_pipelined.h | 12 ++++++---- .../gemmSwigluPlugin/gemmSwigluPlugin.cu | 2 +- 4 files changed, 10 insertions(+), 34 deletions(-) diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h index 6f26d7901..88a4c7f60 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -87,29 +87,6 @@ namespace epilogue namespace threadblock { -//////////////////////////////////////////////////////////////////////////////// - -namespace detail -{ - -/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. -template -struct DefaultIteratorsTensorOp -{ - using WarpTileIterator - = cutlass::epilogue::warp::TileIteratorTensorOpMixed; - - using SharedLoadIterator - = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; - - static int const kFragmentsPerIteration = 2; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tile iterator used to load output tile from shared memory in epilogue. 
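The hunk above deletes the local bfloat16_t/int32_t partial specialization of DefaultIteratorsTensorOp from cutlass_extensions, presumably because the cutlass submodule bump just above (to commit 80243e0b8) now provides an equivalent specialization, so keeping the downstream copy would redefine it. A minimal, self-contained illustration of that failure mode, using hypothetical names rather than the real cutlass templates:

// Hypothetical, self-contained illustration (not TensorRT-LLM or cutlass code).
template <typename T, int N>
struct DefaultIterators; // primary template, think "declared by the library"

template <typename T>
struct DefaultIterators<T, 8> // specialization now shipped upstream
{
    static constexpr int kFragmentsPerIteration = 2;
};

// Keeping an identical partial specialization in a downstream extension header
// becomes a hard error ("redefinition of 'DefaultIterators<T, 8>'") as soon as
// both headers land in the same translation unit:
//
// template <typename T>
// struct DefaultIterators<T, 8> { static constexpr int kFragmentsPerIteration = 2; };

int main()
{
    static_assert(DefaultIterators<float, 8>::kFragmentsPerIteration == 2, "sanity check");
    return 0;
}
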
diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h
index 9a6b32eaa..3f8347027 100644
--- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h
+++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h
@@ -139,12 +139,7 @@ struct MixedGemmArchTraits
 struct MixedGemmArchTraits::value
-        || cutlass::platform::is_same::value
-#ifdef ENABLE_FP8
-        || cutlass::platform::is_same::value>::type>
-#else
-        >::type>
-#endif
+        || cutlass::platform::is_same::value>::type>
 {
 private:
     using LayoutDetails = LayoutDetailsB;
diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h
index 345cd2eec..1d2f7b81f 100644
--- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h
+++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h
@@ -126,8 +126,10 @@ struct DqMma
     static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
         "Element A must be fp16 or bf16");
 
-    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
-        "Element B must be uint8 or uint4");
+    static_assert(platform::is_same<ElementB, uint8_t>::value ||
+                platform::is_same<ElementB, uint4b_t>::value ||
+                platform::is_same<ElementB, uint2b_t>::value,
+        "Element B must be uint8, uint4 or uint2");
 
     using OperatorInfo = arch::DetagOperator;
     using Operator = typename OperatorInfo::Operator;
@@ -213,8 +215,10 @@ struct DqMma
     static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
         "Element A must be fp16 or bf16");
 
-    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
-        "Element B must be uint8 or uint4");
+    static_assert(platform::is_same<ElementB, uint8_t>::value ||
+                platform::is_same<ElementB, uint4b_t>::value ||
+                platform::is_same<ElementB, uint2b_t>::value,
+        "Element B must be uint8, uint4 or uint2");
 
     using OperatorInfo = arch::DetagOperator;
     using Operator = typename OperatorInfo::Operator;
diff --git a/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cu b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cu
index 1fe21bd91..339c432b1 100644
--- a/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cu
+++ b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cu
@@ -36,6 +36,6 @@ void GemmSwigluPluginProfiler::initTmpData(int m, int n, int k, char* workspace,
     if (mType == nvinfer1::DataType::kFP8)
     {
         cutlass::reference::device::BlockFillRandomUniform(reinterpret_cast<cutlass::float_e4m3_t*>(workspace),
-            m * k + n * k + 1 * n, 42, cutlass::float_e4m3_t{128}, -cutlass::float_e4m3_t{128}, -1, stream);
+            m * k + n * k + 1 * n, 42, cutlass::float_e4m3_t{128}, -cutlass::float_e4m3_t{128}, -1, 0, stream);
     }
 }
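In default_dq_mma_pipelined.h above, the patch widens the element-B gate by chaining one more is_same check so that 2-bit weight types are accepted alongside uint8 and uint4. As a note on keeping such gates readable while the list grows, a small variadic helper states the same condition once; the following is a sketch only, assuming C++17, with placeholder types rather than the real cutlass ones:

// Sketch only, assuming C++17; WeightOnlyMma, uint4b_t and uint2b_t here are
// stand-ins, not the real cutlass types or the DqMma specialization itself.
#include <cstdint>
#include <type_traits>

template <typename T, typename... Allowed>
inline constexpr bool is_any_of_v = (std::is_same_v<T, Allowed> || ...);

struct uint4b_t {}; // placeholder for cutlass::uint4b_t
struct uint2b_t {}; // placeholder for cutlass::uint2b_t

template <typename ElementB>
struct WeightOnlyMma
{
    // Same gate as the chained is_same checks in the hunk above, one edit per new type.
    static_assert(is_any_of_v<ElementB, std::uint8_t, uint4b_t, uint2b_t>,
        "Element B must be uint8, uint4 or uint2");
};

int main()
{
    WeightOnlyMma<uint2b_t> ok{}; // compiles: 2-bit weights are now in the allowed set
    (void) ok;
    return 0;
}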