diff --git a/.gitignore b/.gitignore
index dbde1b11..a343313a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,6 @@
 build/
 **.so
 *.hip
-*_hip.*
\ No newline at end of file
+*_hip.*
+.idea/
+dist/
\ No newline at end of file
diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh
index c720ba28..18539d69 100755
--- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh
+++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh
@@ -1,6 +1,6 @@
 /******************************************************************************
  * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
+ *****************************************************************************/
 
 #pragma once
 
@@ -9,6 +9,10 @@
 #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
 #include <ATen/cuda/Atomic.cuh>  // For atomicAdd on complex
 
+#ifndef M_LOG2E
+#define M_LOG2E 1.4426950408889634074f
+#endif
+
 #ifndef USE_ROCM
     #include <cub/block/block_load.cuh>
     #include <cub/block/block_store.cuh>
@@ -28,6 +32,20 @@ template<typename scalar_t> __device__ __forceinline__ scalar_t conj(scalar_t x)
 template<> __device__ __forceinline__ float conj(float x) { return x; }
 template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); }
 
+// Helper: set the kernel's max dynamic shared memory size.
+// This helper is defined at global scope so that preprocessor directives are not inside lambdas.
+template <typename KernelT>
+__host__ inline void setDynamicSharedMemoryAttr(KernelT kernel, int smemSize) {
+    if (smemSize >= 48 * 1024) {
+        #ifndef USE_ROCM
+        C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize));
+        #else
+        C10_CUDA_CHECK(cudaFuncSetAttribute((void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize));
+        std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
+        #endif
+    }
+}
+
 template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,
          bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>
 struct Selective_Scan_bwd_kernel_traits {
@@ -94,10 +112,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
 
     // Shared memory.
     extern __shared__ char smem_[];
-    // cast to lvalue reference of expected type
-    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
-    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
-    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
     auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
     auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
     auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
@@ -158,7 +172,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
         u -= kChunkSize;
         __syncthreads();
         load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
-        // Will reload delta at the same location if kDeltaSoftplus
         if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
         __syncthreads();
         load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
@@ -198,13 +211,10 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
             }
             __syncthreads();
            store_output<Ktraits>(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize);
-            if (params.out_z_ptr != nullptr) { // Recompute and store out_z
+            if (params.out_z_ptr != nullptr) {
                 float out_z_vals[kNItems];
                 #pragma unroll
                 for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; }
-                // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
-                //     printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]);
-                // }
                 input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
                     + dim_id * params.out_z_d_stride + chunk * kChunkSize;
                 __syncthreads();
@@ -245,7 +255,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                 load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals, smem_load_weight_C,
                     (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
             }
-            // const weight_t A_val = smem_a[state_idx];
             scan_t thread_data[kNItems], thread_reverse_data[kNItems];
             if constexpr (!kIsComplex) {
                 #pragma unroll
@@ -266,7 +275,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                 thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
                     ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
                     : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
-                // Initialize running total
                 scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
                 SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                 typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
@@ -289,9 +297,9 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                 ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
                 dA_val += dx * delta_vals[i] * a;
                 if constexpr (!kIsVariableB || !kIsVariableC) {
-                    if constexpr (!kIsVariableB) { // dBC_val is dB_val
+                    if constexpr (!kIsVariableB) {
                         dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]);
-                    } else { // dBC_val is dC_val
+                    } else {
                         dBC_val += dout_vals[i] * thread_data[i].y;
                     }
                 }
@@ -300,7 +308,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                     dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y);
                 }
             }
-            // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
             if constexpr (kIsVariableB || kIsVariableC) {
                 if constexpr (kIsVariableB) {
                     typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
@@ -336,7 +343,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
         } else {
             #pragma unroll
             for (int i = 0; i < kNItems; ++i) {
-                // Pytorch's implementation of complex exp (which calls thrust) is very slow
                 complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
                 weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
                 thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
@@ -359,7 +365,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                     : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
                 thread_reverse_data[kNItems - 1].x = delta_a_exp.real_;
                 thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_;
-                // Initialize running total
                 scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
                 SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                 typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
@@ -379,9 +384,9 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                 complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w);
                 float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_;
                 if constexpr (!kIsVariableB || !kIsVariableC) {
-                    if constexpr (!kIsVariableB) { // dBC_val is dB_val
+                    if constexpr (!kIsVariableB) {
                         dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]);
-                    } else { // dBC_val is dC_val
+                    } else {
                         dBC_val += (2 * dout_vals[i]) * conj(x);
                     }
                 }
@@ -394,7 +399,6 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                     dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x);
                 }
             }
-            // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
             if constexpr (kIsVariableB || kIsVariableC) {
                 float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2];
                 if constexpr (kIsVariableB) {
@@ -431,7 +435,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                 dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
                 dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
                 if (threadIdx.x == 0) {
-                    smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
+                    smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
                 }
             } else {
                 dA_val = typename Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
@@ -502,27 +506,10 @@ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
             BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
                 BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
                     using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
-                    // using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
-                    // TODO: check this
                     constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
-
                     dim3 grid(params.batch, params.dim);
-
                     auto kernel = &selective_scan_bwd_kernel<Ktraits>;
-
-                    if (kSmemSize >= 48 * 1024) {
-
-                        #ifndef USE_ROCM
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                        #else
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                        std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
-                        #endif
-
-                    }
-
+                    setDynamicSharedMemoryAttr(kernel, kSmemSize);
                     kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
                     C10_CUDA_KERNEL_LAUNCH_CHECK();
                 });
@@ -534,7 +521,6 @@ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
 
 template<typename input_t, typename weight_t>
 void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
-
     #ifndef USE_ROCM
     if (params.seqlen <= 128) {
         selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
@@ -547,7 +533,7 @@ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
     } else {
         selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
     }
-    #else
+    #else
     if (params.seqlen <= 256) {
         selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream);
     } else if (params.seqlen <= 512) {
@@ -558,4 +544,4 @@ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
         selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
     }
     #endif
-}
\ No newline at end of file
+}
diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
index 80e9e37e..6b340634 100755
--- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh
+++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
@@ -1,6 +1,6 @@
 /******************************************************************************
  * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
+ *****************************************************************************/
 
 #pragma once
 
@@ -8,6 +8,11 @@
 #include <c10/util/Half.h>
 #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
 
+// Define math constant if not already defined
+#ifndef M_LOG2E
+#define M_LOG2E 1.4426950408889634074f
+#endif
+
 #ifndef USE_ROCM
     #include <cub/block/block_load.cuh>
     #include <cub/block/block_store.cuh>
@@ -21,6 +26,20 @@
 #include "selective_scan_common.h"
 #include "static_switch.h"
 
+// Helper: set the kernel's max dynamic shared memory size.
+// This is defined at global scope so that preprocessor directives are not inside a lambda.
+template <typename KernelT>
+__host__ inline void setDynamicSharedMemoryAttr(KernelT kernel, int smemSize) {
+    if (smemSize >= 48 * 1024) {
+        #ifndef USE_ROCM
+        C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize));
+        #else
+        C10_CUDA_CHECK(cudaFuncSetAttribute((void *)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize));
+        std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
+        #endif
+    }
+}
+
 template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
          bool kIsVariableB_, bool kIsVariableC_, bool kHasZ_,
          typename input_t_, typename weight_t_>
@@ -86,17 +105,11 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
 
     // Shared memory.
     extern __shared__ char smem_[];
-    // cast to lvalue reference of expected type
-    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
-    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
-    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
     auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
     auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
     auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
     auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
     auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
-    // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
-    // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
     scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
 
     const int batch_id = blockIdx.x;
@@ -128,11 +141,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
         }
     }
 
-    // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
-    //     smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
-    //     smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
-    // }
-
     constexpr int kChunkSize = kNThreads * kNItems;
     for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
         input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
@@ -148,7 +156,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
         }
         u += kChunkSize;
         delta += kChunkSize;
-        
+
         float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
         #pragma unroll
         for (int r = 0; r < kNRows; ++r) {
@@ -178,9 +186,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
                 A_val[r].real_ *= kLog2e;
             }
         }
-        // This variable holds B * C if both B and C are constant across seqlen. If only B varies
-        // across seqlen, this holds C. If only C varies across seqlen, this holds B.
-        // If both B and C vary, this is unused.
         weight_t BC_val[kNRows];
         weight_t B_vals[kNItems], C_vals[kNItems];
         if constexpr (kIsVariableB) {
@@ -213,46 +218,39 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
 
         #pragma unroll
         for (int r = 0; r < kNRows; ++r) {
-            if (r > 0) { __syncthreads(); } // Scan could be using the same smem
+            if (r > 0) { __syncthreads(); }
             scan_t thread_data[kNItems];
             #pragma unroll
             for (int i = 0; i < kNItems; ++i) {
                 if constexpr (!kIsComplex) {
                     thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
                                                  !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
-                    if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
+                    if constexpr (!Ktraits::kIsEvenLen) {
                         if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
                             thread_data[i] = make_float2(1.f, 0.f);
                         }
                     }
                 } else {
-                    // Pytorch's implementation of complex exp (which calls thrust) is very slow
                     complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
                     weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
                     thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
-                    if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
+                    if constexpr (!Ktraits::kIsEvenLen) {
                        if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
                             thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);
                         }
                     }
                 }
             }
-            // Initialize running total
            scan_t running_prefix;
             if constexpr (!kIsComplex) {
-                // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
                 running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
-                // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
             } else {
                 running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
-                // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
             }
             SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
             typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
                 thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
             );
-            // There's a syncthreads in the scan op, so we don't need to sync here.
-            // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
             if (threadIdx.x == 0) {
                 smem_running_prefix[state_idx] = prefix_op.running_prefix;
                 x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
@@ -270,7 +268,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
                 }
             }
         }
-        
+
         input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
             + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
         __syncthreads();
@@ -309,35 +307,21 @@
 template<int kNThreads, int kNItems, typename input_t, typename weight_t>
 void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
-    // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
-    // processing 1 row.
-    constexpr int kNRows = 1;
+    const static int kNRows = 1;
     BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
         BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
             BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
                 BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
                     using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
-
-                    constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
-                    dim3 grid(params.batch, params.dim / kNRows);
-                    // Had to change this substantially since potentially the hip
-                    // interface for setting kernel launch attributes is slightly different from
-                    // cuda's. In particualar, it seems to expect a plain const void * pointer.
+                    const static int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
+
+                    dim3 grid(params.batch, params.dim / kNRows);
                     auto kernel = &selective_scan_fwd_kernel<Ktraits>;
-
-                    if (kSmemSize >= 48 * 1024) {
-                        #ifndef USE_ROCM
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                        #else
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-                        std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
-                        #endif
-                    }
+                    // Set the kernel launch attribute using the helper function.
+                    setDynamicSharedMemoryAttr(kernel, kSmemSize);
                     kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
                     C10_CUDA_KERNEL_LAUNCH_CHECK();
diff --git a/csrc/selective_scan/static_switch.h b/csrc/selective_scan/static_switch.h
index 7920ac04..87493ef0 100644
--- a/csrc/selective_scan/static_switch.h
+++ b/csrc/selective_scan/static_switch.h
@@ -16,10 +16,10 @@
 #define BOOL_SWITCH(COND, CONST_NAME, ...)                                      \
     [&] {                                                                       \
         if (COND) {                                                             \
-            constexpr bool CONST_NAME = true;                                   \
+            const static bool CONST_NAME = true;                                \
             return __VA_ARGS__();                                               \
         } else {                                                                \
-            constexpr bool CONST_NAME = false;                                  \
+            const static bool CONST_NAME = false;                               \
             return __VA_ARGS__();                                               \
         }                                                                       \
     }()
diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index c51ec40d..95ca87be 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -107,7 +107,7 @@ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_
     last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
     not considered in the backward pass.
""" - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, diff --git a/pyproject.toml b/pyproject.toml index 5831fe66..4de24a4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,47 +1,52 @@ +[build-system] +requires = [ + "setuptools>=61.0", + "wheel", + "torch>=2.4.0", + "packaging", + "ninja" +] +build-backend = "setuptools.build_meta" + [project] name = "mamba_ssm" -description = "Mamba state-space model" +dynamic = ["version"] +description = "Efficient implementation of selective state space models (Mamba)" readme = "README.md" +requires-python = ">=3.8" +license = {text = "MIT"} authors = [ - { name = "Tri Dao", email = "tri@tridao.me" }, - { name = "Albert Gu", email = "agu@cs.cmu.edu" } + {name = "Albert Gu", email = "albertgu@stanford.edu"}, + {name = "Tri Dao", email = "trid@cs.stanford.edu"} ] -requires-python = ">= 3.9" -dynamic = ["version"] -license = { file = "LICENSE" } # Include a LICENSE file in your repo -keywords = ["cuda", "pytorch", "state-space model"] classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", - "Operating System :: Unix" + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "torch", - "triton", - "ninja", + "torch>=2.4.0", "einops", - "transformers", - "packaging", - "setuptools>=61.0.0", + "transformers>=4.51.3", + "triton-windows; platform_system=='Windows'", + "triton; platform_system!='Windows'" ] + [project.urls] -Repository = "https://github.com/state-spaces/mamba" +"Homepage" = "https://github.com/state-spaces/mamba" +"Bug Tracker" = "https://github.com/state-spaces/mamba/issues" -[project.optional-dependencies] -causal-conv1d = [ - "causal-conv1d>=1.2.0" -] -dev = [ - "pytest" -] +[tool.setuptools] +include-package-data = true +[tool.setuptools.dynamic] +version = {attr = "mamba_ssm.__version__"} -[build-system] -requires = [ - "setuptools>=61.0.0", - "wheel", - "torch", - "packaging", - "ninja", -] -build-backend = "setuptools.build_meta" +[tool.setuptools.packages.find] +include = ["mamba_ssm*"] diff --git a/setup.py b/setup.py index 7c6196d7..b2ca4e40 100755 --- a/setup.py +++ b/setup.py @@ -207,6 +207,7 @@ def append_nvcc_threads(nvcc_extra_args): "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-fgpu-flush-denormals-to-zero", + "-DWIN32_LEAN_AND_MEAN", ] + cc_flag, } @@ -367,12 +368,13 @@ def run(self): }, python_requires=">=3.9", install_requires=[ - "torch", + "torch>=2.4.0", "packaging", "ninja", "einops", - "triton", - "transformers", - # "causal_conv1d>=1.4.0", + "transformers>=4.51.3" ], + extras_require={'win32': ['triton-windows'], + 'linux': ['triton'] + } )