From 96657a79286e5ce9c860bedbfeddf8d20e9570dc Mon Sep 17 00:00:00 2001 From: PityQAQ <623679890@qq.com> Date: Mon, 17 Jun 2024 00:53:43 +0800 Subject: [PATCH 1/4] Support position-index at the SSM operator level, enabling SSM operations on packed data without interference between token sequences. --- csrc/selective_scan/selective_scan.cpp | 38 ++++++++--- csrc/selective_scan/selective_scan.h | 1 + .../selective_scan_bwd_kernel.cuh | 64 ++++++++++++++----- csrc/selective_scan/selective_scan_common.h | 17 +++++ .../selective_scan_fwd_kernel.cuh | 63 ++++++++++++++---- mamba_ssm/modules/mamba_simple.py | 4 +- mamba_ssm/ops/selective_scan_interface.py | 43 +++++++------ 7 files changed, 172 insertions(+), 58 deletions(-) diff --git a/csrc/selective_scan/selective_scan.cpp b/csrc/selective_scan/selective_scan.cpp index cde867cd..3f70e825 100644 --- a/csrc/selective_scan/selective_scan.cpp +++ b/csrc/selective_scan/selective_scan.cpp @@ -78,8 +78,9 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, void* D_ptr, void* delta_bias_ptr, void* x_ptr, - bool has_z, - bool delta_softplus) { + bool has_z, + bool delta_softplus, + void* index_ptr) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -109,6 +110,9 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.x_ptr = x_ptr; params.z_ptr = has_z ? z.data_ptr() : nullptr; params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + + params.index_ptr = index_ptr; + // All stride are in elements, not bytes. params.A_d_stride = A.stride(0); params.A_dstate_stride = A.stride(1); @@ -173,7 +177,8 @@ void set_ssm_params_bwd(SSMParamsBwd ¶ms, void* ddelta_bias_ptr, bool has_z, bool delta_softplus, - bool recompute_out_z) { + bool recompute_out_z, + void* index_ptr) { // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, u, delta, A, B, C, has_z ? out : dout, @@ -181,7 +186,7 @@ void set_ssm_params_bwd(SSMParamsBwd ¶ms, // If not recompute_out_z, pass dout instead of out_z. // This won't be used by the bwd kernel recompute_out_z ? out_z : dout, - D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus); + D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus, index_ptr); if (!recompute_out_z) { params.out_z_ptr = nullptr; } // Set the pointers and strides. @@ -229,7 +234,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, const c10::optional &D_, const c10::optional &z_, const c10::optional &delta_bias_, - bool delta_softplus) { + bool delta_softplus, + const c10::optional &index_) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -292,6 +298,12 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); CHECK_SHAPE(delta_bias, dim); } + if (index_.has_value()) { + auto index = index_.value(); + TORCH_CHECK(index.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(index.is_cuda()); + CHECK_SHAPE(index, batch_size, seqlen); + } at::Tensor z, out_z; const bool has_z = z_.has_value(); @@ -319,7 +331,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, x.data_ptr(), has_z, - delta_softplus); + delta_softplus, + index_.has_value() ? 
index_.value().data_ptr() : nullptr); // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing @@ -346,7 +359,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, const c10::optional &out_, c10::optional &dz_, bool delta_softplus, - bool recompute_out_z) { + bool recompute_out_z, + const c10::optional &index_) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -414,8 +428,15 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, CHECK_SHAPE(delta_bias, dim); } + if (index_.has_value()) { + auto index = index_.value(); + TORCH_CHECK(index.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(index.is_cuda()); + CHECK_SHAPE(index, batch_size, seqlen); + } at::Tensor z, out, dz, out_z; const bool has_z = z_.has_value(); + if (has_z) { z = z_.value(); TORCH_CHECK(z.scalar_type() == input_type); @@ -474,7 +495,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, dout, du, ddelta, dA, dB, dC, dz, D_.has_value() ? dD.data_ptr() : nullptr, delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, - has_z, delta_softplus, recompute_out_z); + has_z, delta_softplus, recompute_out_z, + index_.has_value() ? index_.value().data_ptr() : nullptr); // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing diff --git a/csrc/selective_scan/selective_scan.h b/csrc/selective_scan/selective_scan.h index e2c7bcdb..725529bb 100644 --- a/csrc/selective_scan/selective_scan.h +++ b/csrc/selective_scan/selective_scan.h @@ -66,6 +66,7 @@ struct SSMParamsBase { void *__restrict__ x_ptr; void *__restrict__ z_ptr; void *__restrict__ out_z_ptr; + void *__restrict__ index_ptr; }; struct SSMParamsBwd: public SSMParamsBase { diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index 2ed10114..bb7326f6 100644 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -24,7 +24,7 @@ template<> __device__ __forceinline__ float conj(float x) { return x; } template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } template + bool kDeltaSoftplus_, bool kHasZ_, bool kUseIndex_,typename input_t_, typename weight_t_> struct Selective_Scan_bwd_kernel_traits { static_assert(kNItems_ % 4 == 0); using input_t = input_t_; @@ -42,6 +42,7 @@ struct Selective_Scan_bwd_kernel_traits { static constexpr bool kIsVariableC = kIsVariableC_; static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; static constexpr bool kHasZ = kHasZ_; + static constexpr bool kUseIndex = kUseIndex_; // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. // For complex this would lead to massive register spilling, so we keep it at 2. static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 
3 : 2; @@ -49,6 +50,8 @@ struct Selective_Scan_bwd_kernel_traits { using scan_t = std::conditional_t; using BlockLoadT = cub::BlockLoad; using BlockLoadVecT = cub::BlockLoad; + using BlockLoadIndexT = cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; using BlockLoadWeightT = cub::BlockLoad; using BlockLoadWeightVecT = cub::BlockLoad; using BlockStoreT = cub::BlockStore; @@ -80,6 +83,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { constexpr bool kIsVariableC = Ktraits::kIsVariableC; constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kUseIndex = Ktraits::kUseIndex; constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNItems = Ktraits::kNItems; using input_t = typename Ktraits::input_t; @@ -94,6 +98,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { // auto& smem_load = reinterpret_cast(smem_loadstorescan); auto& smem_load = reinterpret_cast(smem_); auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_index = reinterpret_cast(smem_); auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); auto& smem_store = reinterpret_cast(smem_); auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); @@ -136,21 +141,30 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; float dD_val = 0; float ddelta_bias_val = 0; - + int *index = !kUseIndex ? nullptr :reinterpret_cast(params.index_ptr) + batch_id * params.seqlen; constexpr int kChunkSize = kNThreads * kNItems; u += (params.n_chunks - 1) * kChunkSize; + index += (params.n_chunks - 1) * kChunkSize; delta += (params.n_chunks - 1) * kChunkSize; dout += (params.n_chunks - 1) * kChunkSize; Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { input_t u_vals[kNItems]; + input_t delta_vals_load[kNItems]; input_t dout_vals_load[kNItems]; + int index_vals_load[kNItems]; __syncthreads(); load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); u -= kChunkSize; __syncthreads(); + if constexpr (kUseIndex) { + load_index(index, index_vals_load, smem_load_index, params.seqlen - chunk * kChunkSize); + index -= kChunkSize; + } + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); // Will reload delta at the same location if kDeltaSoftplus if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } @@ -244,8 +258,16 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { if constexpr (!kIsComplex) { #pragma unroll for (int i = 0; i < kNItems; ++i) { - const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + + // Reset A bar for cumulative sequences (Real) + if constexpr (kUseIndex) { + if (index_vals_load[i] == 0) { + delta_a_exp = 0.f; + } + } thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + if (i == 0) { smem_delta_a[threadIdx.x == 0 ? 
state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; } else { @@ -332,6 +354,14 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { for (int i = 0; i < kNItems; ++i) { // Pytorch's implementation of complex exp (which calls thrust) is very slow complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled); + + // Reset A bar for cumulative sequences (Complex) + if constexpr (kUseIndex) { + if (index_vals_load[i] == 0) { + delta_a_exp.real_ = 0.f; + delta_a_exp.imag_ = 0.f; + } + } weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); if (i == 0) { @@ -495,19 +525,21 @@ void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_bwd_kernel_traits; - // using Ktraits = Selective_Scan_bwd_kernel_traits; - // TODO: check this - constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t); - // printf("smem_size = %d\n", kSmemSize); - dim3 grid(params.batch, params.dim); - auto kernel = &selective_scan_bwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + // using Ktraits = Selective_Scan_bwd_kernel_traits; + // TODO: check this + constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/csrc/selective_scan/selective_scan_common.h b/csrc/selective_scan/selective_scan_common.h index 9140dcdf..17508166 100644 --- a/csrc/selective_scan/selective_scan_common.h +++ b/csrc/selective_scan/selective_scan_common.h @@ -162,6 +162,23 @@ inline __device__ void load_input(typename Ktraits::input_t *u, } } +template +inline __device__ void load_index(int *u, + int (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_index_vec = reinterpret_cast(smem_load_index); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0); + } +} + template inline __device__ void load_weight(typename Ktraits::input_t *Bvar, typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 440a2091..c0a77af5 100644 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ 
b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -18,7 +18,7 @@ template + bool kHasZ_, bool kUseIndex_, typename input_t_, typename weight_t_> struct Selective_Scan_fwd_kernel_traits { static_assert(kNItems_ % 4 == 0); using input_t = input_t_; @@ -38,6 +38,7 @@ struct Selective_Scan_fwd_kernel_traits { static constexpr bool kIsVariableB = kIsVariableB_; static constexpr bool kIsVariableC = kIsVariableC_; static constexpr bool kHasZ = kHasZ_; + static constexpr bool kUseIndex = kUseIndex_; static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; @@ -46,6 +47,9 @@ struct Selective_Scan_fwd_kernel_traits { using BlockLoadT = cub::BlockLoad; using BlockLoadVecT = cub::BlockLoad; + using BlockLoadIndexT = cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; using BlockLoadWeightT = cub::BlockLoad; using BlockLoadWeightVecT = cub::BlockLoad; @@ -57,6 +61,8 @@ struct Selective_Scan_fwd_kernel_traits { using BlockScanT = cub::BlockScan; static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockLoadVecT::TempStorage), + sizeof(typename BlockLoadIndexT::TempStorage), + sizeof(typename BlockLoadIndexVecT::TempStorage), (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), sizeof(typename BlockStoreT::TempStorage), @@ -71,6 +77,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { constexpr bool kIsVariableB = Ktraits::kIsVariableB; constexpr bool kIsVariableC = Ktraits::kIsVariableC; constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kUseIndex = Ktraits::kUseIndex; constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNItems = Ktraits::kNItems; constexpr int kNRows = Ktraits::kNRows; @@ -87,6 +94,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // auto& smem_load = reinterpret_cast(smem_loadstorescan); auto& smem_load = reinterpret_cast(smem_); auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_index = reinterpret_cast(smem_); auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); auto& smem_store = reinterpret_cast(smem_); auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); @@ -107,7 +115,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; - + int *index = !kUseIndex ? 
nullptr :reinterpret_cast(params.index_ptr) + batch_id * params.seqlen; + float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { #pragma unroll @@ -123,6 +132,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; @@ -131,6 +141,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { constexpr int kChunkSize = kNThreads * kNItems; for (int chunk = 0; chunk < params.n_chunks; ++chunk) { input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + int index_vals_load[kNRows][kNItems]; + __syncthreads(); #pragma unroll for (int r = 0; r < kNRows; ++r) { @@ -140,6 +152,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); if constexpr (!kDirectIO) { __syncthreads(); } load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (kUseIndex) { + load_index(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize); + } + } + if constexpr (kUseIndex) { + index += kChunkSize; } u += kChunkSize; delta += kChunkSize; @@ -211,10 +229,19 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if (r > 0) { __syncthreads(); } // Scan could be using the same smem scan_t thread_data[kNItems]; #pragma unroll + // printf("blockIdx.x:%d,blockIdx.y:%d,threadIdx.x:%d \t index[0][0]: %d\tindex[0][1]: %d\tindex[0][2]: %d\tindex[0][3]: %d\n",blockIdx.x, blockIdx.y,threadIdx.x, index_vals_load[0][0],index_vals_load[0][1],index_vals_load[0][2],index_vals_load[0][3]); for (int i = 0; i < kNItems; ++i) { if constexpr (!kIsComplex) { thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + + // Reset A bar for cumulative sequences (Real) + if constexpr (kUseIndex) { + if (index_vals_load[r][i] == 0) { + thread_data[i].x = 0.f; + } + } + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { thread_data[i] = make_float2(1.f, 0.f); @@ -225,6 +252,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]); weight_t B_delta_u_val = !kIsVariableB ? 
delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if constexpr (kUseIndex) { + if (index_vals_load[r][i] == 0) { + thread_data[i].x = 0.f; + thread_data[i].y = 0.f; + } + } if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); @@ -311,18 +344,20 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); - // printf("smem_size = %d\n", kSmemSize); - dim3 grid(params.batch, params.dim / kNRows); - auto kernel = &selective_scan_fwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 4c8a3882..d6e3323c 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -116,9 +116,10 @@ def __init__( self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - def forward(self, hidden_states, inference_params=None): + def forward(self, hidden_states, position_indices=None, inference_params=None): """ hidden_states: (B, L, D) + position_indices: (B, L) a tensor that stores the positional indexes of elements within sequences. 
Returns: same shape as hidden_states """ batch, seqlen, dim = hidden_states.shape @@ -157,6 +158,7 @@ def forward(self, hidden_states, inference_params=None): self.D.float(), delta_bias=self.dt_proj.bias.float(), delta_softplus=True, + position_indices = position_indices, ) else: x, z = xz.chunk(2, dim=1) diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index c3596bfe..b4a5ba6e 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -20,7 +20,7 @@ class SelectiveScanFn(torch.autograd.Function): @staticmethod def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, position_indices = None): if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: @@ -39,26 +39,26 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") ctx.squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, position_indices) ctx.delta_softplus = delta_softplus ctx.has_z = z is not None last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if not ctx.has_z: - ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x, position_indices) return out if not return_last_state else (out, last_state) else: - ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out, position_indices) out_z = rest[0] return out_z if not return_last_state else (out_z, last_state) @staticmethod def backward(ctx, dout, *args): if not ctx.has_z: - u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + u, delta, A, B, C, D, delta_bias, x, position_indices= ctx.saved_tensors z = None out = None else: - u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + u, delta, A, B, C, D, z, delta_bias, x, out, position_indices= ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the @@ -66,7 +66,8 @@ def backward(ctx, dout, *args): # Here we just pass in None and dz will be allocated in the C++ code. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False # option to recompute out_z, not used here + False, # option to recompute out_z, not used here + position_indices ) dz = rest[0] if ctx.has_z else None dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB @@ -76,20 +77,21 @@ def backward(ctx, dout, *args): dz, ddelta_bias if delta_bias is not None else None, None, + None, None) def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, position_indices = None): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered in the backward pass. 
""" - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, position_indices) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, position_indices = None): """ u: r(B D L) delta: r(B D L) @@ -136,7 +138,10 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) last_state = None for i in range(u.shape[2]): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if position_indices is not None and position_indices[0,i] == 0: + x = deltaB_u[:, :, i] + else: + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: y = torch.einsum('bdn,dn->bd', x, C) else: @@ -164,7 +169,7 @@ class MambaInnerFn(torch.autograd.Function): def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): + C_proj_bias=None, delta_softplus=True, position_indices=None, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ @@ -223,7 +228,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh if D is not None: D = D.contiguous() out, scan_intermediates, out_z = selective_scan_cuda.fwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus + conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus, position_indices ) ctx.delta_softplus = delta_softplus ctx.out_proj_bias_is_None = out_proj_bias is None @@ -232,7 +237,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, out) + A, B, C, D, delta_bias, scan_intermediates, out, position_indices) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @@ -241,7 +246,7 @@ def backward(ctx, dout): # dout: (batch, seqlen, dim) assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, - conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out, position_indices) = ctx.saved_tensors L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) @@ -263,7 +268,7 @@ def backward(ctx, dout): dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, ctx.delta_softplus, - True # option to recompute out_z + True, position_indices # option to recompute out_z ) dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None @@ -305,18 +310,18 @@ def backward(ctx, dout): dout_proj_weight, dout_proj_bias, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, - dB_proj_bias, dC_proj_bias, None) + dB_proj_bias, dC_proj_bias, None, None) def mamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True + C_proj_bias=None, delta_softplus=True, position_indices=None ): return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, position_indices) def mamba_inner_ref( From ac62320887829e81f10a45c720f9fc2477a10340 Mon Sep 17 00:00:00 2001 From: PityQAQ <623679890@qq.com> Date: Fri, 21 Jun 2024 21:15:34 +0800 Subject: [PATCH 2/4] Updated a test case, test_mamba_cu_seqlens_equivalence.py, for end-to-end pack experiments with the mamba block. Added support for position_indices in conv1d within mamba_inner_fn. The conv1d code can be found at https://github.com/ptxu78/causal-conv1d-pack/tree/feat/pack_with_position_indices. --- .../selective_scan_bwd_kernel.cuh | 3 +- csrc/selective_scan/selective_scan_common.h | 5 +- .../selective_scan_fwd_kernel.cuh | 13 +- mamba_ssm/ops/selective_scan_interface.py | 33 +++- .../ops/test_mamba_cu_seqlens_equivalence.py | 187 ++++++++++++++++++ 5 files changed, 227 insertions(+), 14 deletions(-) create mode 100644 tests/ops/test_mamba_cu_seqlens_equivalence.py diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index bb7326f6..15afa080 100644 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -43,6 +43,7 @@ struct Selective_Scan_bwd_kernel_traits { static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; static constexpr bool kHasZ = kHasZ_; static constexpr bool kUseIndex = kUseIndex_; + static constexpr int kNLoadsIndex = kNItems / 4; // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. // For complex this would lead to massive register spilling, so we keep it at 2. static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 
3 : 2; @@ -51,7 +52,7 @@ struct Selective_Scan_bwd_kernel_traits { using BlockLoadT = cub::BlockLoad; using BlockLoadVecT = cub::BlockLoad; using BlockLoadIndexT = cub::BlockLoad; - using BlockLoadIndexVecT = cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; using BlockLoadWeightT = cub::BlockLoad; using BlockLoadWeightVecT = cub::BlockLoad; using BlockStoreT = cub::BlockStore; diff --git a/csrc/selective_scan/selective_scan_common.h b/csrc/selective_scan/selective_scan_common.h index 17508166..2a04eb97 100644 --- a/csrc/selective_scan/selective_scan_common.h +++ b/csrc/selective_scan/selective_scan_common.h @@ -169,10 +169,9 @@ inline __device__ void load_index(int *u, int seqlen) { if constexpr (Ktraits::kIsEvenLen) { auto& smem_load_index_vec = reinterpret_cast(smem_load_index); - using vec_t = typename Ktraits::vec_t; Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load( - reinterpret_cast(u), - reinterpret_cast(u_vals) + reinterpret_cast(u), + reinterpret_cast(u_vals) ); } else { Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0); diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index c0a77af5..e5cb47db 100644 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -41,15 +41,15 @@ struct Selective_Scan_fwd_kernel_traits { static constexpr bool kUseIndex = kUseIndex_; static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; - + static constexpr int kNLoadsIndex = kNItems / 4; using vec_t = typename BytesToType::Type; using scan_t = std::conditional_t; using BlockLoadT = cub::BlockLoad; using BlockLoadVecT = cub::BlockLoad; using BlockLoadIndexT = cub::BlockLoad; - using BlockLoadIndexVecT = cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; using BlockLoadWeightT = cub::BlockLoad; using BlockLoadWeightVecT = cub::BlockLoad; @@ -116,7 +116,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; int *index = !kUseIndex ? nullptr :reinterpret_cast(params.index_ptr) + batch_id * params.seqlen; - + float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { #pragma unroll @@ -142,7 +142,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { for (int chunk = 0; chunk < params.n_chunks; ++chunk) { input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; int index_vals_load[kNRows][kNItems]; - + __syncthreads(); #pragma unroll for (int r = 0; r < kNRows; ++r) { @@ -196,7 +196,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // If both B and C vary, this is unused. weight_t BC_val[kNRows]; weight_t B_vals[kNItems], C_vals[kNItems]; - if constexpr (kIsVariableB) { + if constexpr (kIsVariableB) { load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 
1 : 2)); if constexpr (!kIsVariableC) { @@ -229,7 +229,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if (r > 0) { __syncthreads(); } // Scan could be using the same smem scan_t thread_data[kNItems]; #pragma unroll - // printf("blockIdx.x:%d,blockIdx.y:%d,threadIdx.x:%d \t index[0][0]: %d\tindex[0][1]: %d\tindex[0][2]: %d\tindex[0][3]: %d\n",blockIdx.x, blockIdx.y,threadIdx.x, index_vals_load[0][0],index_vals_load[0][1],index_vals_load[0][2],index_vals_load[0][3]); for (int i = 0; i < kNItems; ++i) { if constexpr (!kIsComplex) { thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index b4a5ba6e..c1c3d8af 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -14,7 +14,22 @@ causal_conv1d_cuda = None import selective_scan_cuda +class AlignTimer: + def __init__(self, message='kernel_no_name'): + self.message = message + def __enter__(self): + torch.cuda.synchronize() + self.starter = torch.cuda.Event(enable_timing=True) + self.starter.record() + return self + + def __exit__(self, type, value, traceback): + self.ender = torch.cuda.Event(enable_timing=True) + self.ender.record() + torch.cuda.synchronize() + self.time = self.starter.elapsed_time(self.ender) + print('{} uses time {:.4f} ms'.format(self.message, self.time)) class SelectiveScanFn(torch.autograd.Function): @@ -189,9 +204,11 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( - x, conv1d_weight, conv1d_bias, None, None, None, True + x, conv1d_weight, conv1d_bias, None, position_indices, None, None, True ) + # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. @@ -255,7 +272,7 @@ def backward(ctx, dout): dout = dout.contiguous() if ctx.checkpoint_lvl == 1: conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( - x, conv1d_weight, conv1d_bias, None, None, None, True + x, conv1d_weight, conv1d_bias, None, position_indices, None, None, True ) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) @@ -302,10 +319,20 @@ def backward(ctx, dout): # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). 
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( - x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True + x, conv1d_weight, conv1d_bias, dconv1d_out, None, position_indices, None, None, dx, False, True ) dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") + + # grad_results = {"dx":dx,"dconv1d_weight":dconv1d_weight,"dconv1d_bias":dconv1d_bias,"dx_proj_weight":dx_proj_weight,"ddelta_proj_weight":ddelta_proj_weight, + # "dout_proj_weight":dout_proj_weight,"dout_proj_bias":dout_proj_bias, + # "dA":dA, "dB":dB, "dC":dC, "dD":dD, "dB_proj_bias":dB_proj_bias,"dC_proj_bias":dC_proj_bias + # } + # if position_indices is None: + # torch.save(grad_results, 'no_position_grad_results.pt') + # else: + # torch.save(grad_results, 'use_position_grad_results.pt') + return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, dout_proj_weight, dout_proj_bias, dA, dB, dC, dD, diff --git a/tests/ops/test_mamba_cu_seqlens_equivalence.py b/tests/ops/test_mamba_cu_seqlens_equivalence.py new file mode 100644 index 00000000..94c955a2 --- /dev/null +++ b/tests/ops/test_mamba_cu_seqlens_equivalence.py @@ -0,0 +1,187 @@ +import random +import torch + +from mamba_ssm.modules.mamba_simple import Mamba + +class AlignTimer: + def __init__(self, message='kernel_no_name'): + self.message = message + + def __enter__(self): + torch.cuda.synchronize() + self.starter = torch.cuda.Event(enable_timing=True) + self.starter.record() + return self + + def __exit__(self, type, value, traceback): + self.ender = torch.cuda.Event(enable_timing=True) + self.ender.record() + torch.cuda.synchronize() + self.time = self.starter.elapsed_time(self.ender) + print('{} uses time {:.4f} ms'.format(self.message, self.time)) +''' +unpack function: convert packed_hidden_states (batch_size=1) to hidden_states +''' +def unpack(packed_hidden_states, cu_seqlens): + batch_size = packed_hidden_states.shape[0] + package_num = cu_seqlens.shape[0] - 1 + seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + hidden_dim = packed_hidden_states.shape[2] + hidden_states = torch.zeros(package_num * batch_size, seq_len, hidden_dim, dtype=packed_hidden_states.dtype, device=packed_hidden_states.device) + for j in range(batch_size): + for i in range(package_num): + line = j * package_num + i + hidden_states[line, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[j, cu_seqlens[i] : cu_seqlens[i + 1], :] + return hidden_states + + +''' +pack function: convert hidden_states to packed_hidden_states (batch_size=1) +''' +def pack(hidden_states, cu_seqlens, batch_size): + package_num, seq_len, hidden_dim = hidden_states.shape + seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] + seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) + indices_3d = ( + torch.arange(seq_len, device=hidden_states.device) + .unsqueeze(0) + .unsqueeze(2) + .repeat(package_num, 1, hidden_dim) + ) + mask_3d = indices_3d < seq_len_list_3d.repeat(batch_size, 1, 1) + packed_hidden_states = hidden_states[mask_3d].view(batch_size,-1, hidden_dim) + return packed_hidden_states + + +''' +Generate random cu_seqlens for testing +''' +def generate_random_cu_seqlens(seq_len, packages_num = 2): + + if packages_num > 1: + ret = sorted(random.sample(range(1, seq_len), packages_num - 1)) + else: + ret = [] + cu_seqlens = [0] + ret + [seq_len] + assert packages_num == len(cu_seqlens) - 1 + index = [] + for i in range(1, len(cu_seqlens)): + token_len = 
cu_seqlens[i] - cu_seqlens[i-1] + index.extend(list(range(token_len))) + return cu_seqlens, index + + +def main(): + # config tested with A100 + hidden_dim = 4 + seq_len = 1024 + batch_size = 2 + device='cuda' + + itype = torch.half + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + packages_num = 8 + # Generate random cu_seqlens for testing + cu_seqlens, index = generate_random_cu_seqlens(seq_len, packages_num = packages_num) + cu_seqlens = torch.tensor(cu_seqlens).cuda() + index = torch.tensor(index, dtype=torch.int32).unsqueeze(0).expand(batch_size, -1).contiguous().cuda() + print("cu_seqlens:", cu_seqlens, "index:",index) + # Generate packed_hidden_states with random values for testing + # packed_hidden_states (batch_size=1) should be forwarded with cu_seqlens + hidden_states_list = [torch.randn(l, hidden_dim, device=device) for l in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()] + packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0) + packed_hidden_states = packed_hidden_states.expand(batch_size, -1,-1).contiguous() + # hidden_states should be forwarded without cu_seqlens + hidden_states = unpack(packed_hidden_states, cu_seqlens) + + + # Check: sum of seq_len of item in hidden_states_list should be equal to seq_len of packed_hidden_states + assert sum([hs.shape[0] for hs in hidden_states_list]) == packed_hidden_states.shape[1] + # Check: max of seq_len of item in hidden_states_list should be equal to seq_len of hidden_states + assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1] + + + grads = {} + + + # creat one simple mamba block + mamba = Mamba( + # This module uses roughly 3 * expand * d_model^2 parameters + d_model=hidden_dim, # Model dimension d_model + d_state=16, # SSM state expansion factor + d_conv=4, # Local convolution width + expand=2, # Block expansion factor + ).to(device) + + mamba_ref = Mamba( + # This module uses roughly 3 * expand * d_model^2 parameters + d_model=hidden_dim, # Model dimension d_model + d_state=16, # SSM state expansion factor + d_conv=4, # Local convolution width + expand=2, # Block expansion factor + ).to(device) + mamba_ref.load_state_dict(mamba.state_dict()) + + # reference output for forwardding hidden_states + with AlignTimer("pack_fwd"): + out = mamba(packed_hidden_states, index) + + with AlignTimer("unpack_fwd"): + out_ref = mamba_ref(hidden_states) + out_ref_pack = pack(out_ref, cu_seqlens, batch_size) + + # with AlignTimer("unpack"): + # out_ref = mamba_ref(hidden_states) + # out_ref_pack = pack(out_ref, cu_seqlens, batch_size) + # output for forwardding packed_hidden_states with cu_seqlens + + + # Testing the max/mean diff + import numpy as np + np.testing.assert_allclose(out.detach().cpu().numpy(), out_ref_pack.detach().cpu().numpy(), rtol = rtol, atol=atol) + print(f'Output max diff: {(out - out_ref_pack).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref_pack).abs().mean().item()}') + assert torch.allclose(out, out_ref_pack, rtol=rtol, atol=atol) + + g = torch.randn(out.shape).to(device) + with AlignTimer("pack_bwd"): + out.backward(g) + gradients = {name: param.grad.clone() for name, param in mamba.named_parameters()} + + g_ref = unpack(g, cu_seqlens) + with AlignTimer("unpack_bwd"): + out_ref.backward(g_ref) + gradients_ref = {name: param.grad.clone() for 
name, param in mamba_ref.named_parameters()} + + + # 比较两组梯度 + for name in gradients_ref: + if name in gradients: + is_equal = torch.allclose(gradients_ref[name], gradients[name], rtol=rtol, atol=atol) + print(f"Gradients for {name} are {'equal' if is_equal else 'not equal'}") + if not is_equal: + print(f"Gradient difference for {name}: {torch.abs(gradients_ref[name] - gradients[name]).max()}") + else: + print(f"Parameter {name} not found in the second set of gradients") + + # grad_results = torch.load('use_position_grad_results.pt') + # grad_results_ref = torch.load('no_position_grad_results.pt') + # print(grad_results) + # for name in grad_results_ref: + # if name in grad_results and grad_results[name] is not None: + # is_equal = torch.allclose(grad_results_ref[name], grad_results[name], rtol=rtol, atol=atol) + # print(f"Gradients for {name} are {'equal' if is_equal else 'not equal'}") + # if not is_equal: + # print(f"Gradient difference for {name}: {torch.abs(grad_results_ref[name] - grad_results[name]).max()}") + # else: + # print(f"Parameter {name} not found in the second set of gradients") + + +if __name__ == "__main__": + main() \ No newline at end of file From 7f86e19a9ba3456b21d4e39b0a6671b939e02b2f Mon Sep 17 00:00:00 2001 From: PityQAQ <623679890@qq.com> Date: Wed, 26 Jun 2024 00:58:27 +0800 Subject: [PATCH 3/4] update test --- tests/ops/test_mamba_cu_seqlens_equivalence.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/ops/test_mamba_cu_seqlens_equivalence.py b/tests/ops/test_mamba_cu_seqlens_equivalence.py index 94c955a2..b57213b8 100644 --- a/tests/ops/test_mamba_cu_seqlens_equivalence.py +++ b/tests/ops/test_mamba_cu_seqlens_equivalence.py @@ -71,14 +71,11 @@ def generate_random_cu_seqlens(seq_len, packages_num = 2): return cu_seqlens, index -def main(): +def test_mamba_block(hidden_dim = 2048, seq_len = 4096, batch_size = 1, packages_num = 8): # config tested with A100 - hidden_dim = 4 - seq_len = 1024 - batch_size = 2 device='cuda' - itype = torch.half + itype = torch.bfloat16 rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 @@ -86,7 +83,6 @@ def main(): # If we have z, the errors on the weights seem higher rtolw = max(rtolw, rtol) atolw = max(atolw, atol) - packages_num = 8 # Generate random cu_seqlens for testing cu_seqlens, index = generate_random_cu_seqlens(seq_len, packages_num = packages_num) cu_seqlens = torch.tensor(cu_seqlens).cuda() @@ -160,7 +156,6 @@ def main(): gradients_ref = {name: param.grad.clone() for name, param in mamba_ref.named_parameters()} - # 比较两组梯度 for name in gradients_ref: if name in gradients: is_equal = torch.allclose(gradients_ref[name], gradients[name], rtol=rtol, atol=atol) @@ -184,4 +179,9 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + # warm up + test_mamba_block(hidden_dim = 2048, seq_len = 4096, batch_size = 1, packages_num = 1) + # compare the duration of the pack process. + test_mamba_block(hidden_dim = 2048, seq_len = 4096, batch_size = 1, packages_num = 1) + # compare the acceleration ratio of the pack under common parameters. 
+ test_mamba_block(hidden_dim = 2048, seq_len = 4096, batch_size = 1, packages_num = 8) \ No newline at end of file From 6b5f07cb81351ba724e938c1646613186360638e Mon Sep 17 00:00:00 2001 From: PityQAQ <623679890@qq.com> Date: Mon, 1 Jul 2024 16:43:57 +0800 Subject: [PATCH 4/4] update setup.py --- setup_onlyCUDA.py | 168 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 setup_onlyCUDA.py diff --git a/setup_onlyCUDA.py b/setup_onlyCUDA.py new file mode 100644 index 00000000..20b89f4e --- /dev/null +++ b/setup_onlyCUDA.py @@ -0,0 +1,168 @@ +# Copyright (c) 2023, Albert Gu, Tri Dao. +import sys +import warnings +import os +import re +import ast +from pathlib import Path +from packaging.version import parse, Version +import platform +import shutil + +from setuptools import setup, find_packages +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import torch +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, + CUDAExtension, + CUDA_HOME, +) + + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + +PACKAGE_NAME = "mamba_ssm" + +BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}" + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE" +# SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE" + + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith("linux"): + return "linux_x86_64" + elif sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + elif sys.platform == "win32": + return "win_amd64" + else: + raise ValueError("Unsupported platform: {}".format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." 
+ ) + + +def append_nvcc_threads(nvcc_extra_args): + return nvcc_extra_args + ["--threads", "4"] + + +cmdclass = {} +ext_modules = [] + + +print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) + +check_if_cuda_home_none(PACKAGE_NAME) +# Check, if CUDA11 is installed for compute capability 8.0 +cc_flag = [] +if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("11.6"): + raise RuntimeError( + f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. " + "Note: make sure nvcc has a supported version by running nvcc -V." + ) + +cc_flag.append("-gencode") +cc_flag.append("arch=compute_70,code=sm_70") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_80,code=sm_80") +if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + +# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as +# torch._C._GLIBCXX_USE_CXX11_ABI +# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 +if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + +ext_modules.append( + CUDAExtension( + name="selective_scan_cuda", + sources=[ + "csrc/selective_scan/selective_scan.cpp", + "csrc/selective_scan/selective_scan_fwd_fp32.cu", + "csrc/selective_scan/selective_scan_fwd_fp16.cu", + "csrc/selective_scan/selective_scan_fwd_bf16.cu", + "csrc/selective_scan/selective_scan_bwd_fp32_real.cu", + "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu", + "csrc/selective_scan/selective_scan_bwd_fp16_real.cu", + "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu", + "csrc/selective_scan/selective_scan_bwd_bf16_real.cu", + "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"], + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + "-lineinfo", + ] + + cc_flag + ), + }, + include_dirs=[Path(this_dir) / "csrc" / "selective_scan"], + ) +) + + + +setup( + name="selective_scan_cuda", + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension}, +)