diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu
new file mode 100644
index 000000000000..fae627be657f
--- /dev/null
+++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -0,0 +1,322 @@
+// clang-format off
+// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
+// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
+#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "causal_conv1d.h"
+#include <c10/util/BFloat16.h>
+#include <c10/util/Half.h>
+#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
+
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_store.cuh>
+
+#include "../static_switch.h"
+
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+
+#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
+    if (ITYPE == at::ScalarType::Half) { \
+        using input_t = at::Half; \
+        using weight_t = at::Half; \
+        __VA_ARGS__(); \
+    } else if (ITYPE == at::ScalarType::BFloat16) { \
+        using input_t = at::BFloat16; \
+        using weight_t = at::BFloat16; \
+        __VA_ARGS__(); \
+    } else if (ITYPE == at::ScalarType::Float) { \
+        using input_t = float; \
+        using weight_t = float; \
+        __VA_ARGS__(); \
+    } else { \
+        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
+    }
+
+
+template<typename input_t, typename weight_t>
+void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
+
+template<typename input_t, typename weight_t>
+void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
+
+// (APC writeback is implemented in Python wrapper for functional equivalence)
+void set_conv_params_fwd(ConvParamsBase &params,
+                         // sizes
+                         const size_t batch,
+                         const size_t dim,
+                         const size_t seqlen,
+                         const size_t width,
+                         // device tensors
+                         const at::Tensor& x,
+                         const at::Tensor& weight,
+                         const at::Tensor& out,
+                         // optional pointers (can be nullptr)
+                         void* bias_ptr,
+                         bool silu_activation,
+                         int64_t pad_slot_id,
+                         void* query_start_loc_ptr,
+                         void* cache_indices_ptr,
+                         void* has_initial_state_ptr) {
+
+    // Reset the parameters
+    memset(&params, 0, sizeof(params));
+
+    params.batch = batch;
+    params.dim = dim;
+    params.seqlen = seqlen;
+    params.width = width;
+    params.pad_slot_id = pad_slot_id;
+
+    params.silu_activation = silu_activation;
+
+    // Set the pointers and strides.
+    params.x_ptr = const_cast<void*>(x.const_data_ptr());
+    params.weight_ptr = const_cast<void*>(weight.const_data_ptr());
+    params.bias_ptr = bias_ptr;
+    params.out_ptr = const_cast<void*>(out.const_data_ptr());
+    // All strides are in elements, not bytes.
+    params.query_start_loc_ptr = query_start_loc_ptr;
+    params.cache_indices_ptr = cache_indices_ptr;
+    params.has_initial_state_ptr = has_initial_state_ptr;
+    const bool varlen = params.query_start_loc_ptr != nullptr;
+    params.x_batch_stride = x.stride(varlen ? 1 : 0);
+    params.x_c_stride = x.stride(varlen ? 0 : 1);
+    params.x_l_stride = x.stride(varlen ? 1 : -1);
+    params.weight_c_stride = weight.stride(0);
+    params.weight_width_stride = weight.stride(1);
+    params.out_batch_stride = out.stride(varlen ? 1 : 0);
+    params.out_c_stride = out.stride(varlen ? 0 : 1);
+    params.out_l_stride = out.stride(varlen ? 1 : -1);
+}
+
+
+
+
+
+void causal_conv1d_update(const at::Tensor &x,
+                          const at::Tensor &conv_state,
+                          const at::Tensor &weight,
+                          const std::optional<at::Tensor> &bias_,
+                          bool silu_activation,
+                          const std::optional<at::Tensor> &cache_seqlens_,
+                          const std::optional<at::Tensor> &conv_state_indices_,
+                          // used to identify padding entries if cache_indices provided
+                          // in case of padding, the kernel will return early
+                          int64_t pad_slot_id) {
+    auto input_type = x.scalar_type();
+    auto weight_type = weight.scalar_type();
+    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
+    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
+    TORCH_CHECK(weight_type == input_type, "weight type must equal input type; other variations are disabled due to binary size limitations");
+    TORCH_CHECK(conv_state.scalar_type() == input_type);
+
+    TORCH_CHECK(x.is_cuda());
+    TORCH_CHECK(conv_state.is_cuda());
+    TORCH_CHECK(weight.is_cuda());
+
+    const auto sizes = x.sizes();
+    const int batch_size = sizes[0];
+    const int dim = sizes[1];
+    const int seqlen = sizes[2];
+    const int width = weight.size(-1);
+    const int conv_state_len = conv_state.size(2);
+    TORCH_CHECK(conv_state_len >= width - 1);
+
+    CHECK_SHAPE(x, batch_size, dim, seqlen);
+    CHECK_SHAPE(weight, dim, width);
+
+    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
+
+    if (bias_.has_value()) {
+        auto bias = bias_.value();
+        TORCH_CHECK(bias.scalar_type() == weight_type);
+        TORCH_CHECK(bias.is_cuda());
+        TORCH_CHECK(bias.stride(-1) == 1);
+        CHECK_SHAPE(bias, dim);
+    }
+
+    at::Tensor out = x;
+
+    ConvParamsBase params;
+    void* bias_ptr = bias_.has_value() ? bias_.value().data_ptr() : nullptr;
+    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
+                        bias_ptr,
+                        silu_activation,
+                        pad_slot_id,
+                        /*query_start_loc_ptr*/ nullptr,
+                        /*cache_indices_ptr*/ nullptr,
+                        /*has_initial_state_ptr*/ nullptr);
+    params.conv_state_ptr = conv_state.data_ptr();
+    params.conv_state_len = conv_state_len;
+    // All strides are in elements, not bytes.
+    params.conv_state_batch_stride = conv_state.stride(0);
+    params.conv_state_c_stride = conv_state.stride(1);
+    params.conv_state_l_stride = conv_state.stride(2);
+
+    if (cache_seqlens_.has_value()) {
+        auto cache_seqlens = cache_seqlens_.value();
+        TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
+        TORCH_CHECK(cache_seqlens.is_cuda());
+        TORCH_CHECK(cache_seqlens.stride(-1) == 1);
+        CHECK_SHAPE(cache_seqlens, batch_size);
+        params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
+    } else {
+        params.cache_seqlens = nullptr;
+    }
+
+    if (conv_state_indices_.has_value()) {
+        auto conv_state_indices = conv_state_indices_.value();
+        TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32);
+        TORCH_CHECK(conv_state_indices.is_cuda());
+        TORCH_CHECK(conv_state_indices.stride(0) == 1);
+        CHECK_SHAPE(conv_state_indices, batch_size);
+
+        int conv_state_entries = conv_state.size(0);
+        CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
+
+        params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
+    } else {
+        CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
+        params.conv_state_indices_ptr = nullptr;
+    }
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
+        causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
+    });
+}
+
+
+template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
+struct Causal_conv1d_update_kernel_traits {
+    using input_t = input_t_;
+    using weight_t = weight_t_;
+    static constexpr int kNThreads = kNThreads_;
+    static constexpr int kWidth = kWidth_;
+    static constexpr int kNBytes = sizeof(input_t);
+    static_assert(kNBytes == 2 || kNBytes == 4);
+};
+
+template<typename Ktraits, bool kIsCircularBuffer>
+__global__ __launch_bounds__(Ktraits::kNThreads)
+void causal_conv1d_update_kernel(ConvParamsBase params) {
+    constexpr int kWidth = Ktraits::kWidth;
+    constexpr int kNThreads = Ktraits::kNThreads;
+    using input_t = typename Ktraits::input_t;
+    using weight_t = typename Ktraits::weight_t;
+
+    const int tidx = threadIdx.x;
+    const int batch_id = blockIdx.x;
+    const int channel_id = blockIdx.y * kNThreads + tidx;
+    if (channel_id >= params.dim) return;
+
+    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+        + channel_id * params.x_c_stride;
+    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+        + channel_id * params.out_c_stride;
+
+    // If params.conv_state_indices_ptr is set, then the conv state is gathered from the conv state tensor
+    // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
+    const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
+        ? batch_id
+        : params.conv_state_indices_ptr[batch_id];
+
+    // Skip padding tokens when the selected coord equals pad_slot_id.
+    if (conv_state_batch_coord == params.pad_slot_id) {
+        #pragma unroll 2
+        for (int i = 0; i < params.seqlen; ++i) {
+            out[i * params.out_l_stride] = input_t(0.f);
+        }
+        return;
+    }
+    input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+        + conv_state_batch_coord * params.conv_state_batch_stride
+        + channel_id * params.conv_state_c_stride;
+    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
+    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
+
+    int state_len = params.conv_state_len;
+    int advance_len = params.seqlen;
+    int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
+    int update_idx = cache_seqlen - (kWidth - 1);
+    update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
+
+    float weight_vals[kWidth] = {0};
+    #pragma unroll
+    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
+
+    float x_vals[kWidth] = {0};
+    if constexpr (!kIsCircularBuffer) {
+        // Shift the conv state left by seqlen, then load the last (kWidth - 1)
+        // cached elements into x_vals.
+        #pragma unroll 2
+        for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
+            conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
+        }
+        #pragma unroll
+        for (int i = 0; i < kWidth - 1; ++i) {
+            input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
+            if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
+                conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
+            }
+            x_vals[i] = float(state_val);
+        }
+    } else {
+        // Circular buffer: read the (kWidth - 1) most recent entries, wrapping
+        // update_idx around state_len.
+        #pragma unroll
+        for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
+            input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
+            x_vals[i] = float(state_val);
+        }
+    }
+    #pragma unroll 2
+    for (int i = 0; i < params.seqlen; ++i) {
+        input_t x_val = x[i * params.x_l_stride];
+        if constexpr (!kIsCircularBuffer) {
+            if (i < advance_len && state_len - advance_len + i >= 0) {
+                conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
+            }
+        } else {
+            conv_state[update_idx * params.conv_state_l_stride] = x_val;
+            ++update_idx;
+            update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
+        }
+        x_vals[kWidth - 1] = float(x_val);
+        float out_val = bias_val;
+        #pragma unroll
+        for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
+        if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
+        out[i * params.out_l_stride] = input_t(out_val);
+        // Shift the input buffer by 1
+        #pragma unroll
+        for (int j = 0; j < kWidth - 1; ++j) { x_vals[j] = x_vals[j + 1]; }
+    }
+}
+
+template<int kNThreads, int kWidth, typename input_t, typename weight_t>
+void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
+    using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
+    dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
+    auto kernel = params.cache_seqlens == nullptr
+        ? &causal_conv1d_update_kernel<Ktraits, false>
+        : &causal_conv1d_update_kernel<Ktraits, true>;
+    kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<typename input_t, typename weight_t>
+void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
+    if (params.width == 2) {
+        causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
+    } else if (params.width == 3) {
+        causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
+    } else if (params.width == 4) {
+        causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
+    }
+}
+
+template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
+template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h
new file mode 100644
index 000000000000..17bd59258dae
--- /dev/null
+++ b/csrc/mamba/causal_conv1d/causal_conv1d.h
@@ -0,0 +1,158 @@
+// clang-format off
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
+#pragma once
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+struct ConvParamsBase {
+    using index_t = uint32_t;
+
+    int batch, dim, seqlen, width;
+    int64_t pad_slot_id;
+    bool silu_activation;
+
+    index_t x_batch_stride;
+    index_t x_c_stride;
+    index_t x_l_stride;
+    index_t weight_c_stride;
+    index_t weight_width_stride;
+    index_t out_batch_stride;
+    index_t out_c_stride;
+    index_t out_l_stride;
+
+    int conv_state_len;
+    index_t conv_state_batch_stride;
+    index_t conv_state_c_stride;
+    index_t conv_state_l_stride;
+
+    // Common data pointers.
+    void *__restrict__ x_ptr;
+    void *__restrict__ weight_ptr;
+    void *__restrict__ bias_ptr;
+    void *__restrict__ out_ptr;
+
+    void *__restrict__ conv_state_ptr;
+    void *__restrict__ query_start_loc_ptr;
+    void *__restrict__ has_initial_state_ptr;
+    void *__restrict__ cache_indices_ptr;
+    int32_t *__restrict__ cache_seqlens;
+
+    // For the continuous batching case. Makes it so that the mamba state for
+    // the current batch doesn't need to be a contiguous tensor.
+    int32_t *__restrict__ conv_state_indices_ptr;
+
+    void *__restrict__ seq_idx_ptr;
+
+    // No __restrict__ since initial_states could be the same as final_states.
+    void * initial_states_ptr;
+    index_t initial_states_batch_stride;
+    index_t initial_states_l_stride;
+    index_t initial_states_c_stride;
+
+    void * final_states_ptr;
+    index_t final_states_batch_stride;
+    index_t final_states_l_stride;
+    index_t final_states_c_stride;
+
+    void * conv_states_ptr;
+    index_t conv_states_batch_stride;
+    index_t conv_states_l_stride;
+    index_t conv_states_c_stride;
+};
+
+
+#ifndef USE_ROCM
+    #include <cuda_bf16.h>
+
+    template<typename T>
+    __device__ inline T shuffle_xor(T val, int offset) {
+        return __shfl_xor_sync(uint32_t(-1), val, offset);
+    }
+
+    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
+    {
+        return std::max(ilist);
+    }
+
+    template<typename T>
+    constexpr T constexpr_min(T a, T b) {
+        return std::min(a, b);
+    }
+
+#else
+    #include <hip/hip_bf16.h>
+
+    template<typename T>
+    __device__ inline T shuffle_xor(T val, int offset) {
+        return __shfl_xor(val, offset);
+    }
+    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
+    {
+        return *std::max_element(ilist.begin(), ilist.end());
+    }
+
+    template<typename T>
+    constexpr T constexpr_min(T a, T b) {
+        return a < b ? a : b;
+    }
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int BYTES> struct BytesToType {};
+
+template<> struct BytesToType<16> {
+    using Type = uint4;
+    static_assert(sizeof(Type) == 16);
+};
+
+template<> struct BytesToType<8> {
+    using Type = uint64_t;
+    static_assert(sizeof(Type) == 8);
+};
+
+template<> struct BytesToType<4> {
+    using Type = uint32_t;
+    static_assert(sizeof(Type) == 4);
+};
+
+template<> struct BytesToType<2> {
+    using Type = uint16_t;
+    static_assert(sizeof(Type) == 2);
+};
+
+template<> struct BytesToType<1> {
+    using Type = uint8_t;
+    static_assert(sizeof(Type) == 1);
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct SumOp {
+    __device__ inline T operator()(T const & x, T const & y) { return x + y; }
+};
+
+template<int THREADS>
+struct Allreduce {
+    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+    template<typename T, typename Operator>
+    static __device__ inline T run(T x, Operator &op) {
+        constexpr int OFFSET = THREADS / 2;
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
+        return Allreduce<OFFSET>::run(x, op);
+    }
+};
+
+template<>
+struct Allreduce<2> {
+    template<typename T, typename Operator>
+    static __device__ inline T run(T x, Operator &op) {
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
+        return x;
+    }
+};
\ No newline at end of file
diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
index d534e138d26d..bb69beb31d4e 100644
--- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -23,7 +23,7 @@
 #endif
 
 #include "selective_scan.h"
-#include "static_switch.h"
+#include "../static_switch.h"
 
 template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
          bool kIsVariableB_, bool kIsVariableC_,
diff --git a/csrc/ops.h b/csrc/ops.h
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ ... @@
                         const std::optional<torch::Tensor>& has_initial_state,
                         const torch::Tensor& ssm_states, int64_t pad_slot_id);
 
+// Update path for causal conv1d (decode). Writes results in-place to x.
+void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
+                          const at::Tensor& weight,
+                          const std::optional<at::Tensor>& bias_,
+                          bool silu_activation,
+                          const std::optional<at::Tensor>& cache_seqlens_,
+                          const std::optional<at::Tensor>& conv_state_indices_,
+                          int64_t pad_slot_id);
+
 torch::Tensor dynamic_4bit_int_moe_cpu(
     torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
     torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 8f091a429fbe..0485a2230018 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -614,6 +614,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "int pad_slot_id) -> ()");
   ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
 
+  // Mamba causal conv1d CUDA kernels
+  ops.def(
+      "causal_conv1d_update("
+      " Tensor! x, Tensor conv_state, Tensor weight, Tensor? bias,"
+      " bool silu_activation, Tensor? cache_seqlens,"
+      " Tensor? conv_state_indices, int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
+
   // Hadamard transforms
   ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
 
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 9110b0573fc9..20083711d30f 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1738,6 +1738,29 @@ def selective_scan_fwd(
     )
 
 
+def causal_conv1d_update_cuda(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    silu_activation: bool,
+    cache_seqlens: torch.Tensor | None,
+    conv_state_indices: torch.Tensor | None,
+    pad_slot_id: int,
+):
+    torch.ops._C.causal_conv1d_update(
+        x,
+        conv_state,
+        weight,
+        bias,
+        silu_activation,
+        cache_seqlens,
+        conv_state_indices,
+        pad_slot_id,
+    )
+    return x
+
+
 # ROCm skinny gemms
 def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor:
     return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index 83c2c5f11e18..d178ecb88ec0 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -1080,6 +1080,7 @@ def causal_conv1d_update(
     block_idx_last_scheduled_token: torch.Tensor | None = None,
     initial_state_idx: torch.Tensor | None = None,
     validate_data=False,
+    use_cuda=False,
 ):
     """
     x: Input tensor which can take the following shapes:
@@ -1159,6 +1160,22 @@ def causal_conv1d_update(
     assert num_cache_lines >= batch
     assert weight.stride(1) == 1  # Need this
 
+    if use_cuda:
+        from vllm._custom_ops import causal_conv1d_update_cuda
+
+        causal_conv1d_update_cuda(
+            x,
+            conv_state,
+            weight,
+            bias,
+            activation == "silu",
+            cache_seqlens=None,
+            conv_state_indices=conv_state_indices,
+            pad_slot_id=pad_slot_id,
+        )
+        if unsqueeze:
+            x = x.squeeze(-1)
+        return x.to(original_x_dtype)
 
     # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
     out = x
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index e81ad5f68d8f..4135af000d55 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -535,6 +535,7 @@ def _forward(
                 : attn_metadata.num_decodes
             ],
             validate_data=True,
+            use_cuda=True,
         )
     else:
         mixed_qkv_non_spec = None
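For anyone who wants to exercise the new op end to end, a minimal smoke test of the decode path might look like the sketch below. It is not part of this diff: the shapes, dtype, and state length are illustrative assumptions, and it presumes a CUDA device and a vLLM build that includes the binding added above.

# Smoke test for the new causal_conv1d_update CUDA path (hypothetical sizes).
import torch
from vllm._custom_ops import causal_conv1d_update_cuda

batch, dim, width, state_len = 4, 64, 4, 8
x = torch.randn(batch, dim, 1, device="cuda", dtype=torch.float16)
conv_state = torch.randn(batch, dim, state_len, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)

ref = conv_state.clone()
out = causal_conv1d_update_cuda(
    x, conv_state, weight, bias,
    silu_activation=True,
    cache_seqlens=None,          # None selects the linear (non-circular) state update
    conv_state_indices=None,     # state row i then belongs to batch row i
    pad_slot_id=-1,              # no padded slots in this toy batch
)
assert out is x                  # the op writes its result in-place to x
# The state should have been shifted left by seqlen=1, with x appended at the end.
assert torch.equal(conv_state[:, :, :-1], ref[:, :, 1:])
assert torch.equal(conv_state[:, :, -1], x[:, :, 0])

With cache_seqlens supplied, the kernel instead treats conv_state as a circular buffer indexed per batch row, which is the path the torch binding exposes for speculative-decoding-style usage.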