diff --git a/3rdparty/cutlass b/3rdparty/cutlass index 19b4c5e06..80243e0b8 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit 19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc +Subproject commit 80243e0b8c644f281e2beb0c20fe78cf7b267061 diff --git a/cpp/include/tensorrt_llm/common/stringUtils.h b/cpp/include/tensorrt_llm/common/stringUtils.h index 9c5ecde98..f368dcb62 100644 --- a/cpp/include/tensorrt_llm/common/stringUtils.h +++ b/cpp/include/tensorrt_llm/common/stringUtils.h @@ -23,6 +23,7 @@ #include // std::make_unique #include // std::stringstream +#include #include #include #include diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h index 6f26d7901..88a4c7f60 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -87,29 +87,6 @@ namespace epilogue namespace threadblock { -//////////////////////////////////////////////////////////////////////////////// - -namespace detail -{ - -/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. -template -struct DefaultIteratorsTensorOp -{ - using WarpTileIterator - = cutlass::epilogue::warp::TileIteratorTensorOpMixed; - - using SharedLoadIterator - = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; - - static int const kFragmentsPerIteration = 2; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tile iterator used to load output tile from shared memory in epilogue. 
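The next file, mixed_gemm_B_layout.h, adds LayoutDetailsB specializations for 2-bit (uint2b_t) weights: an interleaved column-major tile layout for SM75 through SM89 and a plain ColumnMajor layout for SM90. As a rough sanity check of the interleave constants, assuming (as in the existing 8-bit and 4-bit specializations) that ThreadblockK is derived from the activation element width while ElementsPerCacheLine comes from the uint2b_t weight width, the numbers work out as in this standalone sketch; it is illustrative only and not part of the patch, and the names simply mirror those in the diff.

    // Sketch only: expected interleave constants for fp16/bf16 activations with
    // 2-bit weights, under the assumption stated above about which element width
    // feeds each constant (the template arguments are not visible in this copy).
    #include <cstdio>

    int main()
    {
        constexpr int kCacheLineBits = 128 * 8; // one 128-byte cache line
        constexpr int kBitsA = 16;              // fp16 / bf16 activations
        constexpr int kBitsB = 2;               // uint2b_t weights

        constexpr int ThreadblockK = kCacheLineBits / kBitsA;                   // 64
        constexpr int ElementsPerCacheLine = kCacheLineBits / kBitsB;           // 512
        constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; // 8

        std::printf("ThreadblockK=%d ElementsPerCacheLine=%d ColumnsInterleaved=%d\n",
            ThreadblockK, ElementsPerCacheLine, ColumnsInterleaved);
        return 0;
    }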
diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index 8ac984faf..e1d8c012e 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -130,6 +130,24 @@ template using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; +template + struct LayoutDetailsB < TypeA, + uint2b_t, Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + + template struct LayoutDetailsB= 90>::type> { @@ -148,6 +166,15 @@ struct LayoutDetailsB +struct LayoutDetailsB= 90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + } // namespace kernel } // namespace gemm } // namespace cutlass diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h index 17c634655..59d372e4f 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -230,8 +230,9 @@ struct DqMma::value, "Mma multistage must dequantize after ldsm"); - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); + static_assert(platform::is_same::value || platform::is_same::value + || platform::is_same::value, + "Element B must be uint8, uint4 or uint2"); static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? 
cutlass::arch::CacheOperation::Global diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h index 345cd2eec..1d2f7b81f 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -126,8 +126,10 @@ struct DqMma::value || platform::is_same::value, "Element A must be fp16 or bf16"); - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); + static_assert(platform::is_same::value || + platform::is_same::value || + platform::is_same::value, + "Element B must be uint8, uint4 or uint2"); using OperatorInfo = arch::DetagOperator; using Operator = typename OperatorInfo::Operator; @@ -213,8 +215,10 @@ struct DqMma::value || platform::is_same::value, "Element A must be fp16 or bf16"); - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); + static_assert(platform::is_same::value || + platform::is_same::value || + platform::is_same::value, + "Element B must be uint8, uint4 or uint2"); using OperatorInfo = arch::DetagOperator; using Operator = typename OperatorInfo::Operator; diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h index ad6c7496e..d7f0736b8 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h @@ -124,6 +124,54 @@ struct DefaultMma +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + //////////////////////////////////////////////////////////////////////////////// /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage /// (stage>=3) @@ -232,6 +280,59 @@ struct DefaultMma=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static 
constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; #ifdef ENABLE_FP8 //////////////////////////////////////////////////////////////////////////////// /// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage @@ -287,6 +388,59 @@ struct DefaultMma=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; #endif // fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h index 77af81005..876d23258 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -244,6 +244,54 @@ struct DefaultMma +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + template < /// Layout type for A matrix operand typename LayoutA, @@ -348,6 +396,59 @@ struct DefaultMma +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + } // namespace threadblock } // namespace gemm } // namespace cutlass diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h index a4f80dc6f..2a7cc02f3 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -562,6 +562,7 @@ class DqMmaMultistage (-Base::kStages + 1);) { @@ -569,6 +570,8 @@ class DqMmaMultistage; + FragmentOperandB converted_frag_B_operand; // Computes a warp-level GEMM on data held in shared memory // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate CUTLASS_PRAGMA_UNROLL @@ -588,23 +591,26 @@ class DqMmaMultistagewarp_tile_iterator_B_.set_kgroup_index( (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(idx + 1) % 2]); ++this->warp_tile_iterator_B_; } + if (warp_tileB_k_compute_offset == 0) { + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[idx % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); + + constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + using Converter + = cutlass::NumericArrayConverter; + converted_frag_B_operand = Converter::convert(converted_frag_B); + } - typename 
TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); - - using FragmentOperandB = cutlass::Array; - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; - static_assert(ConversionVectorWidth == FragmentOperandB::kElements); - - using Converter - = cutlass::NumericArrayConverter; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + idx++; + } - FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, warp_tileB_k_compute_offset); diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h index 44ba79680..d18f883f7 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h @@ -440,6 +440,245 @@ struct FastInterleavedAndBiasedNumericArrayConverter } }; +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i2s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM0_MASK = 0x00030003; + static constexpr uint32_t BOTTOM1_MASK = 0x000c000c; + static constexpr uint32_t TOP0_MASK = 0x00300030; + static constexpr uint32_t TOP1_MASK = 0x00c000c0; + static constexpr uint32_t I2s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. 
+ const uint32_t top_i2s = i2s >> 8; + // Extract elt_01 - (i2s & 0x00030003) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i2s), "n"(BOTTOM0_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i2s), "n"(BOTTOM1_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(i2s), "n"(TOP0_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(i2s), "n"(TOP1_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // Extract elt_89 - (top_i2s & 0x00030003) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[4]) + : "r"(top_i2s), "n"(BOTTOM0_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[5]) + : "r"(top_i2s), "n"(BOTTOM1_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[6]) + : "r"(top_i2s), "n"(TOP0_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[7]) + : "r"(top_i2s), "n"(TOP1_MASK), "n"(I2s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1026, 1026} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64026402; + // These are the half2 scales {1 / 4, 1 / 4}, {1 / 16, 1 / 16} and {1 / 64, 1 / 64} represented as integers. + static constexpr uint32_t ONE_FOUR = 0x34003400; + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + static constexpr uint32_t ONE_SIXTY_FOUR = 0x24002400; + // These are the half2 biases {-258, -258}, {-66, -66} and {-18, -18} represented as integers. + static constexpr uint32_t NEG_258 = 0xdc08dc08; + static constexpr uint32_t NEG_66 = 0xd420d420; + static constexpr uint32_t NEG_18 = 0xcc80cc80; + + // Finally, we construct the output numbers.
+ // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_FOUR), "r"(NEG_258)); + // Convert elt_45 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(ONE_SIXTEENTH), "r"(NEG_66)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTY_FOUR), "r"(NEG_18)); + + // Convert elt_89 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_1011 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(h[5]), "r"(ONE_FOUR), "r"(NEG_258)); + // Convert elt_1213 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(h[6]), "r"(ONE_SIXTEENTH), "r"(NEG_66)); + // Convert elt_1415 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(h[7]), "r"(ONE_SIXTY_FOUR), "r"(NEG_18)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 16; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t source_i2s = reinterpret_cast(source); + + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x00030003; + static constexpr uint32_t I2s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. + uint32_t i2s = source_i2s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i2s), "n"(MASK), "n"(I2s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) + { + i2s >>= sizeof_bits::value; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i2s), "n"(MASK), "n"(I2s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-130, -130} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC302C302; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers.
+ CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) + { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } + //*/ +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 16; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + //printf("convert uint2 to bfloat16\n"); + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp index 6130ad5da..dfa9f9b70 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp @@ -114,6 +114,9 @@ LayoutDetails getLayoutDetailsForArch(QuantType quant_type) case QuantType::W4_AFP8: details = getLayoutDetailsForArchAndQuantType(); break; + case QuantType::W2_A16: + details = getLayoutDetailsForArchAndQuantType(); + break; default: TLLM_THROW("Unsupported quantization type"); } return details; @@ -173,6 +176,12 @@ std::vector get_permutation_map(QuantType quant_type) return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; } + else if (quant_type == QuantType::W2_A16) + { + return {0, 1, 8, 9, 16, 17, 24, 25, 32, 33, 40, 41, 48, 49, 56, 57, 2, 3, 10, 11, 18, 19, 26, 27, 34, 35, + 42, 43, 50, 51, 58, 59, 4, 5, 12, 13, 20, 21, 28, 29, 36, 37, 44, 45, 52, 53, 60, 61, 6, 7, 14, 15, 22, + 23, 30, 31, 38, 39, 46, 47, 54, 55, 62, 63}; + } else { TLLM_THROW("Invalid quantization type for LDSM permutation"); @@ -351,6 +360,32 @@ void subbyte_transpose_impl( } } } + else if constexpr (bits_per_elt == 2) + { + + for (int ii = 0; ii < M_TILE_L1; ++ii) + { + // Using M_TILE_L1 here is deliberate since we assume that the cache tile + // is square in the number of elements (not necessarily the number of bytes). 
+ for (int jj = ii + 1; jj < M_TILE_L1; ++jj) + { + const int ii_byte = ii / ELTS_PER_BYTE; + const int ii_bit_offset = ii % ELTS_PER_BYTE; + + const int jj_byte = jj / ELTS_PER_BYTE; + const int jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = 0x3 & (cache_buf[ii][jj_byte] >> (2 * jj_bit_offset)); + uint8_t tgt_elt = 0x3 & (cache_buf[jj][ii_byte] >> (2 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (~(0x3 << (2 * jj_bit_offset))); + cache_buf[jj][ii_byte] &= (~(0x3 << (2 * ii_bit_offset))); + + cache_buf[ii][jj_byte] |= (tgt_elt << (2 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (2 * ii_bit_offset)); + } + } + } else { TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type."); @@ -402,6 +437,10 @@ void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quanti { subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); } + else if (quant_type == QuantType::W2_A16) + { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } else { TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); @@ -487,6 +526,70 @@ void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const siz } } +void add_bias_and_interleave_int2s_inplace(int8_t* packed_int2_tensor, const size_t num_elts) +{ + const int num_bytes = num_elts / 4; + + // Step 1 will be to transform all the int2s to unsigned in order to make the dequantize take as few + // instructions as possible in the CUDA code. + for (size_t ii = 0; ii < num_bytes; ++ii) + { + int8_t transformed_packed_int2s = 0; + int8_t transformed_elt0 + = (int8_t(packed_int2_tensor[ii] << 6) >> 6) + 2; // The double shift here is to ensure sign extension + int8_t transformed_elt1 + = (int8_t(packed_int2_tensor[ii] << 4) >> 6) + 2; // The double shift here is to ensure sign extension + int8_t transformed_elt2 + = (int8_t(packed_int2_tensor[ii] << 2) >> 6) + 2; // The double shift here is to ensure sign extension + int8_t transformed_elt3 = (packed_int2_tensor[ii] >> 6) + 2; + + TLLM_CHECK_WITH_INFO( + transformed_elt0 >= 0 && transformed_elt0 <= 3, "Illegal result for int2 transform (elt0)"); + TLLM_CHECK_WITH_INFO( + transformed_elt1 >= 0 && transformed_elt1 <= 3, "Illegal result for int2 transform (elt1)"); + TLLM_CHECK_WITH_INFO( + transformed_elt2 >= 0 && transformed_elt2 <= 3, "Illegal result for int2 transform (elt2)"); + TLLM_CHECK_WITH_INFO( + transformed_elt3 >= 0 && transformed_elt3 <= 3, "Illegal result for int2 transform (elt3)"); + + // We don't need to mask in these ops since everything should be in the range 0-3 + transformed_packed_int2s |= transformed_elt0; + transformed_packed_int2s |= (transformed_elt1 << 2); + transformed_packed_int2s |= (transformed_elt2 << 4); + transformed_packed_int2s |= (transformed_elt3 << 6); + packed_int2_tensor[ii] = transformed_packed_int2s; + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical + // instructions that are needed to extract the int2s in the GEMM main loop. Pictorially, the loop below will do the + // following: Take as input a 32 bit register with layout: bit 32 0 + // [elt_15 ... elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 2 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_15 ... elt_5 elt_3 elt_1 elt_14 ...
elt_4 elt_2 elt_0] (each elt occupies 4 bits) + + TLLM_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int2 tensor must be a multiple of 16 for register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(packed_int2_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) + { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 16; ++dest_idx) + { + const int src_idx = dest_idx < 8 ? 2 * dest_idx : 2 * (dest_idx - 8) + 1; + const int src_shift = 2 * src_idx; + const int dest_shift = 2 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0x3; + transformed_register |= (src_bits << dest_shift); + } + register_ptr[ii] = transformed_register; + } +} void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -502,6 +605,10 @@ void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size // for conversion to FP16. add_bias_and_interleave_int4s_inplace(tensor, num_elts); } + else if (quant_type == QuantType::W2_A16) + { + add_bias_and_interleave_int2s_inplace(tensor, num_elts); + } else { TLLM_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving."); diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h index b12fd7372..5b3e62332 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h @@ -33,7 +33,8 @@ enum class QuantType { W8_A16, W4_A16, - W4_AFP8 + W4_AFP8, + W2_A16 }; constexpr int get_weight_quant_bits(QuantType quant_type) @@ -43,6 +44,7 @@ constexpr int get_weight_quant_bits(QuantType quant_type) case QuantType::W8_A16: return 8; case QuantType::W4_A16: return 4; case QuantType::W4_AFP8: return 4; + case QuantType::W2_A16: return 2; default: TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); return -1; } } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h index 0ec8ab2e3..501ff6526 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h @@ -18,7 +18,9 @@ #include #include +#if defined(ENABLE_FP8) #include +#endif #include "cutlass/bfloat16.h" #include "cutlass/float8.h" diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scalebias.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scalebias.cu new file mode 100644 index 000000000..a5b523ea4 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scalebias.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint2b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scaleonly.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scaleonly.cu new file mode 100644 index 000000000..60d2e0903 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_fg_scaleonly.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint2b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_per_col.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_per_col.cu new file mode 100644 index 000000000..1d00561be --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/bf16_int2_gemm_per_col.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint2b_t, + cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>; +#endif +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scalebias.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scalebias.cu new file mode 100644 index 000000000..0966ed761 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scalebias.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scaleonly.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scaleonly.cu new file mode 100644 index 000000000..20f17fadb --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_fg_scaleonly.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_per_col.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_per_col.cu new file mode 100644 index 000000000..a000397f8 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fp16_int2_gemm_per_col.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index f1b7bb0f7..7288b6c8d 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -78,7 +78,8 @@ void generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, ""); // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. @@ -256,6 +257,7 @@ void filter_and_run_mixed_gemm(ActivationType const* A, WeightType const* B, Sca + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); } +#ifdef ENABLE_FP8 else if constexpr (cutlass::platform::is_same::value && arch::kMinComputeCapability < 89) { @@ -264,6 +266,7 @@ void filter_and_run_mixed_gemm(ActivationType const* A, WeightType const* B, Sca + std::to_string(arch::kMinComputeCapability) + " with activation type set to FP8"; throw std::runtime_error("[TensorRT-LLm Error][filter_and_run_mixed_gemm] " + err_msg); } +#endif else { generic_mixed_gemm_kernelLauncher constexpr bool is_fp8() { +#ifdef ENABLE_FP8 return std::is_same_v || std::is_same_v; +#else + return false; +#endif } template #include +#ifdef __CUDACC__ +#include +#endif + namespace tensorrt_llm { namespace kernels diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h index 9ee3085dd..ab0f0755b 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/details.h @@ -53,7 +53,7 @@ struct ColumnMajor using DetailsA = TypeDetailsA; using DetailsW = TypeDetailsW; using AccessTypeA = float4; - using AccessTypeW = int; + using AccessTypeW = typename std::conditional::type; static constexpr int kAccessSize = 128; static constexpr int kStepK = kAccessSize / TypeDetailsA::kElemBits; static constexpr int kTileSize = TileSizeK; @@ -114,6 +114,22 @@ struct KernelDetails static constexpr bool kUseInterleavedConverter = UseInterleavedConverter; }; +template +struct ShMemOptimizer +{ + using T = T_; + using TVec = TVec_; + static constexpr bool GtoShStrided = GtoShStrided_; + static constexpr int ShStep = ShStep_; + static 
constexpr int ShStride = ShStride_; + static constexpr int Elements = Elements_; + static constexpr int Continuous = Continuous_; + static constexpr int VecSize = sizeof(TVec) / sizeof(T); + static constexpr int c_sh = Elements_ * sizeof(T) / (Elements_ < VecSize ? 1 : sizeof(TVec)); + static constexpr int c_load = Continuous_ / VecSize; + +}; + } // namespace weight_only } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index bf91a7bd2..e8e07ba1b 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -58,25 +58,59 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca int const tile_id_m = blockIdx.x, tile_id_n = blockIdx.y, tid = threadIdx.x; int const offset_m = tile_id_m * CtaM, interleaved_offset_n = tile_id_n * CtaN; - int const real_offset_n = interleaved_offset_n * Details::kInterleave - + ((tid * StepK / Details::LayoutDetails::kTileSize) % Details::kInterleave); + int const blk_offset_n = interleaved_offset_n * Details::kInterleave; + int const thr_offset_n = (tid / Details::kThreadsPerInterleavedTile) % Details::kInterleave; int const real_offset_k = (tid * StepK / (Details::kInterleave * Details::LayoutDetails::kTileSize)) * Details::LayoutDetails::kTileSize + ((tid * StepK) % Details::LayoutDetails::kTileSize); + int const offset_k_group = (GroupSize != 0 ? real_offset_k / GroupSize : 0); + + // number of threads with neighboring scale or zero values + static constexpr int near_sz_group = (GroupSize != 0 ? GroupSize / Details::kInterleave / Details::kThreadsPerInterleavedTile : 32); + // number of scale or zero values needed in a single block per k iterations + static constexpr int sz_per_iter = CtaN * Details::kInterleave; + // number of k-iterations which would be loaded in a single shared memory block load + static constexpr int sh_sz_group = (GroupSize != 0 ? near_sz_group / (sz_per_iter > (sizeof(AccessTypeA) / sizeof(TypeA)) ? + sz_per_iter / (sizeof(AccessTypeA) / sizeof(TypeA)) : 1) : 1); + static constexpr int CtaKInterleave = CtaK / Details::kInterleave; + + __shared__ TypeA shmem_sz[sz_per_iter * (GroupSize != 0 ? Threads / near_sz_group : 1) * sh_sz_group * (EnableZero ? 2 : 1)]; + __shared__ uint8_t shmem_w[CtaK / Details::kElemsPerByteW * CtaN]; + __shared__ TypeA shmem_a[CtaKInterleave * (EnableActScale ? CtaM + 1 : CtaM)]; - GMemIterator act_iterator( - act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); - GMemIterator act_scale_iterator( - act_scale, real_offset_k, CtaK / Details::kInterleave, 0); - GMemIterator weight_iterator(weight, - (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, - interleaved_k / Details::kElemsPerByteW); - GMemIterator scales_iterator(scales, - (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, - (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); - GMemIterator zeros_iterator(zeros, - (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, - (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); + TypeA* sh_scale = shmem_sz + sz_per_iter * offset_k_group; + TypeA* sh_zero = nullptr; + TypeA* sh_actscale = nullptr; + + if constexpr (EnableZero) + { + sh_zero = shmem_sz + sz_per_iter * (GroupSize != 0 ? 
Threads / near_sz_group * sh_sz_group : 1); + sh_zero += sz_per_iter * offset_k_group; + } + + if constexpr (EnableActScale) + { + sh_actscale = shmem_a + CtaKInterleave * CtaM; + } + + SHMemIterator> act_iterator( + act, 0, shmem_a, real_offset_k, CtaKInterleave, origin_k, interleaved_k / CtaK); + SHMemIterator> act_scale_iterator( + act_scale, 0, sh_actscale, real_offset_k, CtaKInterleave, 0, interleaved_k / CtaK); + SHMemIterator> weight_iterator( + weight, interleaved_offset_n * interleaved_k / Details::kElemsPerByteW, shmem_w, tid * StepK / Details::kElemsPerByteW, + CtaK / Details::kElemsPerByteW, interleaved_k / Details::kElemsPerByteW, interleaved_k / CtaK); + SHMemIterator> scales_iterator( + scales, offset_k_group * n + blk_offset_n, sh_scale, thr_offset_n, (GroupSize != 0 ? CtaKInterleave / GroupSize * n : 0), + 0, (GroupSize != 0 ? interleaved_k / CtaK : 1)); + SHMemIterator> zeros_iterator( + zeros, offset_k_group * n + blk_offset_n, sh_zero, thr_offset_n, (GroupSize != 0 ? CtaKInterleave / GroupSize * n : 0), + 0, (GroupSize != 0 ? interleaved_k / CtaK : 1)); out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; if constexpr (EnableBias) @@ -84,6 +118,20 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca bias += tile_id_n * CtaN * Details::kInterleave; } +// Prefetch stage 0 +#pragma unroll + for (int i = 0; i < CtaN; ++i) + { + weight_iterator.copy_to_shmem(0, i); + } +#pragma unroll + for (int i = 0; i < CtaM; ++i) + { + act_iterator.copy_to_shmem(0, i); + } + act_scale_iterator.copy_to_shmem(0); + __pipeline_commit(); + TypeA tile_acc[CtaM * CtaN]; fill(tile_acc, static_cast(0.f)); @@ -91,29 +139,60 @@ __global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* sca { TypeA vec_act_scale[StepK]; TypeA vec_scale[CtaN], vec_zero[CtaN]; - TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK]; - uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; + TypeA tile_a[StepK * CtaM], tile_w[StepK], tile_w_pack2[CtaN * StepK]; + uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW * CtaN]; + + if (iter % sh_sz_group == 0) + { + scales_iterator.copy_to_shmem(iter); + zeros_iterator.copy_to_shmem(iter); + __pipeline_commit(); + } + __pipeline_wait_prior(0); + #pragma unroll for (int i = 0; i < CtaN; ++i) { - scales_iterator.load(vec_scale + i, iter, i); - zeros_iterator.load(vec_zero + i, iter, i); + scales_iterator.load(vec_scale + i, i, iter % sh_sz_group); + zeros_iterator.load(vec_zero + i, i, iter % sh_sz_group); + weight_iterator.load(tile_w_quantized + i * StepK / Details::kElemsPerByteW, i); } - act_scale_iterator.load(vec_act_scale, iter); +#pragma unroll + for (int i = 0; i < CtaM; ++i) + { + act_iterator.load(tile_a + i * StepK, i); + } + act_scale_iterator.load(vec_act_scale); + + // Prefetch next stage + if (idx_k + CtaK < interleaved_k) + { +#pragma unroll + for (int i = 0; i < CtaN; ++i) + { + weight_iterator.copy_to_shmem(iter + 1, i); + } +#pragma unroll + for (int i = 0; i < CtaM; ++i) + { + act_iterator.copy_to_shmem(iter + 1, i); + } + act_scale_iterator.copy_to_shmem(iter + 1); + __pipeline_commit(); + } + #pragma unroll for (int i = 0; i < CtaN; ++i) { - weight_iterator.load(tile_w_quantized, iter, i); dequantize( - tile_w, tile_w_quantized, vec_scale + i, vec_zero + i, alpha); + tile_w, tile_w_quantized + i * StepK / Details::kElemsPerByteW, vec_scale + i, vec_zero + i, alpha); pack_to_vec2(tile_w_pack2, tile_w, i); } #pragma unroll for (int i = 0; i < CtaM; ++i) { - 
act_iterator.load(tile_a, iter, i); - apply_scale(tile_a, vec_act_scale); - mma(tile_acc + i * CtaN, tile_w_pack2, tile_a); + apply_scale(tile_a + i * StepK, vec_act_scale); + mma(tile_acc + i * CtaN, tile_w_pack2, tile_a + i * StepK); } } epilogue(out, n, tile_acc, bias, alpha); diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index 470aab759..6d1d198cd 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -301,11 +301,150 @@ class GMemIterator } } + __device__ __forceinline__ void copy_to_shmem(int iter) + { + // Do nothing + } + private: T* addr_; int step_; int stride_; }; + +template +class SHMemIterator +{ + using T = typename ShTraits::T; + using TVec = typename ShTraits::TVec; + +public: + __device__ __forceinline__ SHMemIterator(T* g_addr, int g_offset, T* sh_addr, int sh_offset, + int g_step, int g_stride, float k_max_iter) + : g_addr_(Enable ? (g_addr + g_offset) : nullptr) + , sh_addr_(Enable ? sh_addr : nullptr) + , sh_offset_(sh_offset) + , g_step_(g_step) + , g_stride_(g_stride) + , k_max_iter_(k_max_iter) + { + } + + __device__ __forceinline__ void copy_to_shmem(int iter, int ii = 0) + // __pipeline_memcpy_async will use synced version in sm < 80 + { + if constexpr (Enable) + { + if constexpr (Grouped == 0) + { // Grouped == 0 is for weight / act case + int const i = threadIdx.x; + if constexpr (ShTraits::GtoShStrided) + { // W, A + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_ + ii * sh_stride_) + i, + reinterpret_cast(g_addr_ + iter * g_step_ + ii * g_stride_) + i, + sizeof(TVec) + ); + } + else + { // As + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_) + i, + reinterpret_cast(g_addr_ + iter * g_step_) + i, + sizeof(TVec) + ); + } + } + else + { // Grouped != 0 is for scale / zero + if constexpr (Elements < VecSize) + { // Uncommon slow case + static_assert(c_sh % 4 == 0); + int const s = threadIdx.x % Grouped; + if (s + iter <= k_max_iter_) + { + static_assert(!ShTraits::GtoShStrided); + __pipeline_memcpy_async( + sh_addr_ + s * sh_step_, + g_addr_ + (iter + s) * g_step_, + c_sh + ); + } + } + else + { + // s should be float to compare with k_max_iter_ + if (threadIdx.x % Grouped / c_sh + iter <= k_max_iter_) + { + int const s = threadIdx.x % Grouped / c_sh; + int const i = threadIdx.x % c_sh; + static_assert(!ShTraits::GtoShStrided); + __pipeline_memcpy_async( + reinterpret_cast(sh_addr_ + s * sh_step_) + i, + reinterpret_cast(g_addr_ + (iter + s) * g_step_) + i, + sizeof(TVec) + ); + } + } + } + } + } + + __device__ __forceinline__ void load(void* dst, int ii = 0, int iter = 0) + { + if constexpr (Enable) { + if constexpr (Continuous < VecSize) + { +#pragma unroll + for (int jj = 0; jj < Continuous; ++jj) + { + reinterpret_cast(dst)[jj] = sh_addr_[iter * sh_step_ + ii * sh_stride_ + sh_offset_ + jj]; + } + } + else + { + static_assert(Continuous % VecSize == 0); +#pragma unroll + for (int jj = 0; jj < c_load; ++jj) + { + if constexpr (sh_step_ == 0) + { + if constexpr (sh_stride_ == 0) + { + reinterpret_cast(dst)[jj] = reinterpret_cast(sh_addr_ + sh_offset_)[jj]; + } + else + { + reinterpret_cast(dst)[jj] = reinterpret_cast(sh_addr_ + ii * sh_stride_ + sh_offset_)[jj]; + } + } + else + { + reinterpret_cast(dst)[jj] = reinterpret_cast(sh_addr_ + iter * sh_step_ + ii * sh_stride_ + sh_offset_)[jj]; + } + } + } + } + } + +private: + T* g_addr_; + T* sh_addr_; + int sh_offset_; + int g_step_; + int 
g_stride_; + // Decimal value represents that the last k iteration will only use a few warps + float k_max_iter_; + + static constexpr int VecSize = ShTraits::VecSize; + static constexpr int c_sh = ShTraits::c_sh; + static constexpr int c_load = ShTraits::c_load; + static constexpr int Continuous = ShTraits::Continuous; + static constexpr int Elements = ShTraits::Elements; + static constexpr int sh_step_ = ShTraits::ShStep; + static constexpr int sh_stride_ = ShTraits::ShStride; +}; + } // namespace weight_only } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cu b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cu index 1fe21bd91..339c432b1 100644 --- a/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cu +++ b/cpp/tensorrt_llm/plugins/gemmSwigluPlugin/gemmSwigluPlugin.cu @@ -36,6 +36,6 @@ void GemmSwigluPluginProfiler::initTmpData(int m, int n, int k, char* workspace, if (mType == nvinfer1::DataType::kFP8) { cutlass::reference::device::BlockFillRandomUniform(reinterpret_cast(workspace), - m * k + n * k + 1 * n, 42, cutlass::float_e4m3_t{128}, -cutlass::float_e4m3_t{128}, -1, stream); + m * k + n * k + 1 * n, 42, cutlass::float_e4m3_t{128}, -cutlass::float_e4m3_t{128}, -1, 0, stream); } } diff --git a/cpp/tensorrt_llm/pybind/CMakeLists.txt b/cpp/tensorrt_llm/pybind/CMakeLists.txt index 65f54c0c3..bda47b137 100644 --- a/cpp/tensorrt_llm/pybind/CMakeLists.txt +++ b/cpp/tensorrt_llm/pybind/CMakeLists.txt @@ -23,6 +23,7 @@ else() endif() find_package(pybind11 REQUIRED) +find_package(Python COMPONENTS Interpreter Development REQUIRED) set(SRCS bindings.cpp diff --git a/examples/summarize.py b/examples/summarize.py index 1ede00e7d..ead0fd316 100644 --- a/examples/summarize.py +++ b/examples/summarize.py @@ -107,6 +107,7 @@ def main(args): dataset_name = f"{args.dataset_dir}/{dataset_name}" dataset = load_dataset(dataset_name, dataset_revision, + trust_remote_code=True, cache_dir=args.dataset_cache_dir, split=dataset_split) diff --git a/tensorrt_llm/parameter.py b/tensorrt_llm/parameter.py index f405387a0..bbbd6143c 100644 --- a/tensorrt_llm/parameter.py +++ b/tensorrt_llm/parameter.py @@ -225,8 +225,8 @@ def set_name(self, name: str, network): def _get_weights(self, network) -> trt.Weights | Tensor | None: tensor = network.get_parameter_tensor(self) if self.is_managed(network): - return tensor - elif tensor is not None: +# return tensor + # elif tensor is not None: tensor.producer.__class__ = trt.IConstantLayer return tensor.producer.weights else:
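A note on the magic constants in the new uint2b_t to fp16 converter above: after each lop3, a 16-bit lane holds the fp16 bit pattern 0x6400 | (q << 2k) for the 2-bit field q at position k of the low byte, which equals the value 1024 + q * 4^k because the mantissa ULP of 0x6400 is exactly 1. The subsequent sub/fma constants then map every lane back to q - 2, undoing the +2 bias applied by add_bias_and_interleave_int2s_inplace. Below is a minimal host-side check of that arithmetic; it is an editorial sketch in plain C++, not part of the patch.

    // Checks that the scales {1, 1/4, 1/16, 1/64} with biases {-1026, -258, -66, -18}
    // recover q - 2 from the packed fp16 value 1024 + q * 4^k used by the converter.
    #include <cassert>
    #include <cstdio>

    int main()
    {
        const float scale[4] = {1.0f, 1.0f / 4.0f, 1.0f / 16.0f, 1.0f / 64.0f};
        const float bias[4] = {-1026.0f, -258.0f, -66.0f, -18.0f};

        for (int k = 0; k < 4; ++k)     // position of the 2-bit field within the low byte
        {
            for (int q = 0; q < 4; ++q) // biased (unsigned) 2-bit weight value
            {
                const float packed = 1024.0f + float(q * (1 << (2 * k)));
                const float dequant = packed * scale[k] + bias[k];
                assert(dequant == float(q - 2)); // recovers the original signed int2 value
            }
        }
        std::printf("int2 -> fp16 magic constants verified\n");
        return 0;
    }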