@@ -80,8 +80,10 @@ struct Multihead_attention_params_base {
int batch_size = 0;
// The beam width
int beam_width = 0;
// The sequence length.
// The cache length.
int memory_max_len = 0;
// The whole sequence length, which includes context and output.
int session_len = 0;
// The number of heads (H).
int num_heads = 0;
// The hidden dimension per head (Dh).
@@ -91,6 +93,10 @@ struct Multihead_attention_params_base {
bool neox_rotary_style = false;
// The maximum length of input sentences.
int max_input_length = 0;
// The number of cache slots reserved for "important" entries; when the cache is full, the least important entry among them is replaced. If 0, the kernel falls back to the circular cache.
int important_kv_cache_size = 0;
// The buffer storing, for each cache slot, the token position of the key/value it holds. The shape is [B, H, memory_max_len].
int* kv_indices = nullptr;
// The current timestep. TODO(bhsueh): check whether we only use this param in cross attention.
int timestep = 0;
// The current timestep of each sentence (supports different timesteps for different sentences)
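To make the new fields concrete: slots `[0, important_kv_cache_size)` hold "important" entries and the remaining slots form a FIFO ring. Below is a minimal host-side sketch (not part of this PR; names are mine) of the write-slot arithmetic the kernel implements further down. It covers the circular and hybrid regimes; in the pure important-cache regime (`important_kv_cache_size == memory_max_len`) the kernel instead writes directly to the least-important slot.

```cpp
#include <cstdio>

struct CacheConfig {
    int memory_max_len          = 8;  // total K/V cache slots per head
    int important_kv_cache_size = 3;  // slots reserved for important entries
};

// Size of the FIFO region once the cache is full; 0 means the whole cache is
// the important region.
int fifo_cache_size(const CacheConfig& c, int tlength) {
    return (tlength >= c.memory_max_len && c.important_kv_cache_size > 0) ?
               c.memory_max_len - c.important_kv_cache_size :
               c.memory_max_len;
}

// Slot targeted by the circular write at the current timestep, mirroring the
// kernel's tlength_circ computation.
int write_slot(const CacheConfig& c, int tlength) {
    const int fifo = fifo_cache_size(c, tlength);
    if (tlength >= c.memory_max_len && fifo > 0) {
        return (tlength - c.important_kv_cache_size) % fifo + c.important_kv_cache_size;
    }
    return tlength % c.memory_max_len;
}

int main() {
    CacheConfig c;
    // While the cache fills (t < 8) slots advance 0..7; afterwards the write
    // pointer cycles through the FIFO region 3..7 only.
    for (int t = 0; t < 14; ++t) {
        std::printf("t=%2d -> slot %d\n", t, write_slot(c, t));
    }
    return 0;
}
```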
@@ -1159,8 +1159,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
// The shared memory to do the final reduction for the output values. Reuse qk_smem.
Tk* out_smem = reinterpret_cast<Tk*>(smem_);

// The shared memory buffers for the block-wide reductions. One for max, one for sum.
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
// The shared memory buffers for the block-wide reductions. One for max, one for sum, one for min.
__shared__ float red_smem[WARPS_PER_BLOCK * 3];
// The shared memory buffers for the qk_min index block-wide reduction.
__shared__ int red_int_smem[WARPS_PER_BLOCK];

// A vector of Q or K elements for the current timestep.
using Qk_vec_k = typename Qk_vec_k_<T, Dh_MAX>::Type; // with kernel-used precision
@@ -1213,20 +1215,51 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,

// While doing the product Q*K^T for the different keys we track the max.
float qk_max = -FLT_MAX;
float qk_min = FLT_MAX;
int qk_min_idx = -1;

float qk = 0.0F;

int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;

const size_t bi_seq_len_offset = bi * params.memory_max_len;
const size_t bi_session_len_offset = bi * params.session_len;

// int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 :
(params.length_per_sample == nullptr) ?
params.timestep :
params.length_per_sample[bi] + params.max_prefix_prompt_length;
const int first_step = max(0, tlength + 1 - params.memory_max_len);
const int tlength_circ = tlength % params.memory_max_len;
const int fifo_cache_size =
tlength >= params.memory_max_len && params.important_kv_cache_size > 0 ?
params.memory_max_len - params.important_kv_cache_size :
params.memory_max_len;

// Make sure the following params have correct values in all 5 cases:
// 1. full cache: tlength < params.memory_max_len
// 1.1 params.important_kv_cache_size == 0
// 1.2 params.important_kv_cache_size > 0
// 2. fifo cache: tlength >= params.memory_max_len, params.important_kv_cache_size == 0
// 3. important cache: tlength >= params.memory_max_len, params.important_kv_cache_size == params.memory_max_len
// 4. hybrid cache: tlength >= params.memory_max_len, 0 < params.important_kv_cache_size < params.memory_max_len

// const int tlength_circ_offset = tlength > params.memory_max_len ? params.important_kv_cache_size : 0;
// const int tlength_circ =
// fifo_cache_size > 0 ? tlength % fifo_cache_size + tlength_circ_offset : tlength % params.memory_max_len;

const int tlength_circ_with_important_cache =
fifo_cache_size > 0 && tlength >= params.memory_max_len ?
(tlength - params.important_kv_cache_size) % fifo_cache_size + params.important_kv_cache_size :
tlength % params.memory_max_len;

// tlength_circ is the index relative to the beginning of the k/v cache.
const int tlength_circ =
tlength >= params.memory_max_len ? tlength_circ_with_important_cache : tlength % params.memory_max_len;
const bool do_important_cache = handle_kv && params.important_kv_cache_size > 0 && tlength >= params.memory_max_len;
int* kv_index_buffer = params.kv_indices == nullptr ? nullptr : params.kv_indices + bhi * params.memory_max_len;

assert(!(do_important_cache && params.kv_indices == nullptr));

// First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
const bool is_masked = tidx >= QK_VECS_PER_WARP;
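As a reading aid, the five cases enumerated in the comment above collapse to four regimes (cases 1.1 and 1.2 behave identically while the cache is still filling). A tiny sketch, with names of my own choosing rather than the PR's:

```cpp
// Classification of the cache regimes from the comment above.
enum class CacheMode { Growing, Fifo, Important, Hybrid };

CacheMode cache_mode(int tlength, int memory_max_len, int important_kv_cache_size) {
    if (tlength < memory_max_len)                  return CacheMode::Growing;    // case 1 (1.1 and 1.2)
    if (important_kv_cache_size == 0)              return CacheMode::Fifo;       // case 2
    if (important_kv_cache_size == memory_max_len) return CacheMode::Important;  // case 3
    return CacheMode::Hybrid;                                                    // case 4
}
```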
@@ -1239,6 +1272,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;

assert(!(do_ia3 && do_important_cache));
// Without the kv_indices buffer, the important cache breaks the ordered-index assumption,
// so it cannot be enabled together with the features below.
assert(!(params.linear_bias_slopes != nullptr && do_important_cache && params.kv_indices == nullptr));
assert(!(params.relative_attention_bias != nullptr && do_important_cache && params.kv_indices == nullptr));

// Trigger the loads from the Q and K buffers.
Qk_vec_k q;
zero(q);
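The asserts above only fire inside the kernel; a host-side pre-launch check could reject the same parameter combinations earlier. A hedged sketch, with a hypothetical struct and function name:

```cpp
// Subset of the attention params relevant to the important cache (sketch).
struct ImportantCacheParams {
    int        important_kv_cache_size = 0;
    int*       kv_indices              = nullptr;
    const int* ia3_tasks               = nullptr;
};

// The important cache reorders slots, so it requires kv_indices to recover
// token positions and is incompatible with IA3 task weighting.
bool validate_important_cache(const ImportantCacheParams& p) {
    if (p.important_kv_cache_size == 0) return true;   // feature disabled
    if (p.kv_indices == nullptr)        return false;  // need the remap buffer
    if (p.ia3_tasks != nullptr)         return false;  // unsupported combination
    return true;
}
```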
@@ -1384,6 +1423,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
__syncthreads();
}

Qk_vec_k fifo_out_k;
zero(fifo_out_k);

if (!is_masked) {
// Store the Q values to shared memory.
*reinterpret_cast<Qk_vec_k*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
@@ -1412,9 +1454,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
tlength_circ * QK_ELTS_IN_16B + ci;

if (handle_kv) {
// Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
if (!do_important_cache) {
// Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
}
}
else if (fifo_cache_size > 0) {
fifo_out_k =
(Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_cache[offset])) :
fifo_out_k;
}
}

@@ -1515,7 +1565,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,

for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
const int ti_circ = ti % params.memory_max_len;
bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_session_len_offset + ti];

// The keys loaded from the key cache.
K_vec_k k[K_VECS_PER_THREAD];
@@ -1583,11 +1633,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
// token: i i i i p p p o o o where i=input, p=pad, o=output.
// e.g. ti = 2, dist = (9 - 3) - 2 = 4.
int max_context_length = params.max_prefix_prompt_length + params.max_input_length;
float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength;
const int real_index = do_important_cache ? kv_index_buffer[ti_circ] : ti;
float dist = (real_index < max_context_length ? real_index + padd_len : real_index) - tlength;

qk += mul<float, T, float>(params.linear_bias_slopes[hi], dist);
}
qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
if (!is_mask && do_important_cache && ti_circ < params.important_kv_cache_size) {
qk_min_idx = qk_min > qk ? ti_circ : qk_min_idx;
qk_min = qk_min > qk ? qk : qk_min;
}
qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
qk_smem[ti - first_step] = qk;
}
}
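The `kv_index_buffer` lookup above is what keeps position-dependent biases correct once slots are reordered. A minimal sketch of the idea (my code; the kernel's prefix-prompt padding handling is omitted):

```cpp
// With the hybrid cache, slot order no longer equals token order, so the
// ALiBi-style distance must come from the stored token position rather than
// the slot index.
float alibi_bias(float slope, int slot, int tlength,
                 const int* kv_index_buffer, bool do_important_cache) {
    const int real_index = do_important_cache ? kv_index_buffer[slot] : slot;
    return slope * static_cast<float>(real_index - tlength);  // dist <= 0
}
```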
@@ -1599,6 +1654,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
if (do_important_cache) {
float temp_qk_min = __shfl_xor_sync(uint32_t(-1), qk_min, mask);
int temp_qk_min_idx = __shfl_xor_sync(uint32_t(-1), qk_min_idx, mask);
qk_min_idx = qk_min > temp_qk_min ? temp_qk_min_idx : qk_min_idx;
qk_min = fminf(qk_min, temp_qk_min);
}
}

// Decompose the thread index into warp and lane.
@@ -1608,26 +1669,43 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
// The warp leader writes the max to shared memory.
if (lane == 0) {
red_smem[warp] = qk_max;
if (do_important_cache) {
red_int_smem[warp] = qk_min_idx;
red_smem[WARPS_PER_BLOCK * 2 + warp] = qk_min;
}
}

// Make sure the products are in shared memory.
__syncthreads();

// The warps finalize the reduction.
qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
if (do_important_cache) {
qk_min_idx = lane < WARPS_PER_BLOCK ? red_int_smem[lane] : -1;
qk_min = lane < WARPS_PER_BLOCK ? red_smem[WARPS_PER_BLOCK * 2 + lane] : FLT_MAX;
}
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
if (do_important_cache) {
float temp_qk_min = __shfl_xor_sync(uint32_t(-1), qk_min, mask);
int temp_qk_min_idx = __shfl_xor_sync(uint32_t(-1), qk_min_idx, mask);
qk_min_idx = qk_min > temp_qk_min ? temp_qk_min_idx : qk_min_idx;
qk_min = fminf(qk_min, temp_qk_min);
}
}

// Broadcast to all the threads in the warp.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
if (do_important_cache) {
qk_min_idx = __shfl_sync(uint32_t(-1), qk_min_idx, 0);
}

// Compute the logits and start the sum.
float sum = 0.f;
// for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_session_len_offset + ti];
#ifdef FP8_MHA
float logit = 0.f;
if (FP8_MHA_KERNEL) {
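The min/argmin reduction added above piggybacks on the existing max reduction: a butterfly of `__shfl_xor_sync` within each warp, a leader exchange through `red_smem`/`red_int_smem`, then a lane-0 broadcast. A standalone, runnable toy sketch of the same two-stage pattern (my code, not the PR's):

```cpp
#include <cfloat>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int WARP_SIZE       = 32;
constexpr int WARPS_PER_BLOCK = 4;

__global__ void block_argmin(const float* vals, int n, int* out_idx)
{
    // Per-thread strided scan, carrying the index alongside the value.
    float v_min = FLT_MAX;
    int   i_min = -1;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        if (vals[i] < v_min) { v_min = vals[i]; i_min = i; }
    }
    // Stage 1: butterfly reduction within each warp.
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
        const float other_v = __shfl_xor_sync(0xffffffffu, v_min, mask);
        const int   other_i = __shfl_xor_sync(0xffffffffu, i_min, mask);
        if (other_v < v_min) { v_min = other_v; i_min = other_i; }
    }
    // Stage 2: warp leaders exchange through shared memory.
    __shared__ float red_smem[WARPS_PER_BLOCK];
    __shared__ int   red_int_smem[WARPS_PER_BLOCK];
    const int warp = threadIdx.x / WARP_SIZE;
    const int lane = threadIdx.x % WARP_SIZE;
    if (lane == 0) { red_smem[warp] = v_min; red_int_smem[warp] = i_min; }
    __syncthreads();
    v_min = lane < WARPS_PER_BLOCK ? red_smem[lane] : FLT_MAX;
    i_min = lane < WARPS_PER_BLOCK ? red_int_smem[lane] : -1;
#pragma unroll
    for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
        const float other_v = __shfl_xor_sync(0xffffffffu, v_min, mask);
        const int   other_i = __shfl_xor_sync(0xffffffffu, i_min, mask);
        if (other_v < v_min) { v_min = other_v; i_min = other_i; }
    }
    // Broadcast the winner from lane 0, as the kernel does for qk_min_idx.
    i_min = __shfl_sync(0xffffffffu, i_min, 0);
    if (threadIdx.x == 0) { *out_idx = i_min; }
}

int main()
{
    const int n = 1000;
    std::vector<float> h(n);
    for (int i = 0; i < n; ++i) { h[i] = float((i - 700) * (i - 700)); }
    float* d_vals = nullptr;
    int*   d_idx  = nullptr;
    cudaMalloc(&d_vals, n * sizeof(float));
    cudaMalloc(&d_idx, sizeof(int));
    cudaMemcpy(d_vals, h.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    block_argmin<<<1, WARP_SIZE * WARPS_PER_BLOCK>>>(d_vals, n, d_idx);
    int idx = -1;
    cudaMemcpy(&idx, d_idx, sizeof(int), cudaMemcpyDeviceToHost);
    std::printf("argmin = %d (expect 700)\n", idx);
    cudaFree(d_vals);
    cudaFree(d_idx);
    return 0;
}
```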
@@ -1685,6 +1763,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
// The number of values processed per iteration of the loop.
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;

V_vec_k fifo_out_v;
zero(fifo_out_v);
if (do_important_cache && fifo_cache_size > 0) {
// Use the same thread-group condition as the cache store later, so the thread that
// will overwrite this slot is the one that saves the evicted value.
if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
fifo_out_v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]));
}
}

// One group of threads computes the product(s) for the current timestep.
V_vec_k v_bias;
zero(v_bias);
@@ -1814,6 +1901,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
}
}

__syncthreads();

// One group of threads computes the product(s) for the current timestep.
// if( vo == params.timestep % V_PER_ITER ) {
if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
@@ -1852,9 +1941,27 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
&params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
}

// Store the values with bias back to global memory in the cache for V.
//*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
*reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
if (!do_important_cache || fifo_cache_size > 0) {
// Store the values with bias back to global memory in the cache for V.
//*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
*reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
if (kv_index_buffer != nullptr && vi == 0) {
kv_index_buffer[tlength_circ] = tlength;
}
}
if (do_important_cache) {
assert(qk_min_idx < params.important_kv_cache_size);
assert(qk_min_idx >= 0);
// TO CHECK: if every cache entry is masked, qk_min_idx stays -1 and the write below would target invalid cache space.
if (fifo_cache_size > 0) {
*reinterpret_cast<V_vec_m*>(&v_cache[qk_min_idx * Dh]) = vec_conversion<V_vec_m, V_vec_k>(fifo_out_v);
} else {
*reinterpret_cast<V_vec_m*>(&v_cache[qk_min_idx * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
}
if (vi == 0) {
kv_index_buffer[qk_min_idx] = tlength - fifo_cache_size;
}
}
}

// Initialize the output value with the current timestep.
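Putting the V-side writes together: the newest token overwrites the FIFO slot, and the token evicted from that slot replaces the least-important entry. A scalar host-side sketch of the same bookkeeping (my code; the kernel operates on per-head-dim vectors):

```cpp
#include <vector>

// Scalar sketch of the hybrid V-cache update above. v_cache/kv_index are one
// (batch, head) row; tlength_circ and qk_min_idx come from the kernel logic.
void hybrid_v_update(std::vector<float>& v_cache, std::vector<int>& kv_index,
                     float v, int tlength, int tlength_circ, int qk_min_idx,
                     int fifo_cache_size, bool do_important_cache)
{
    float fifo_out_v = 0.0f;
    if (do_important_cache && fifo_cache_size > 0) {
        fifo_out_v = v_cache[tlength_circ];  // value about to be evicted from the FIFO
    }
    if (!do_important_cache || fifo_cache_size > 0) {
        v_cache[tlength_circ]  = v;          // newest token enters the FIFO slot
        kv_index[tlength_circ] = tlength;
    }
    if (do_important_cache) {
        // The FIFO-evicted value (or, with no FIFO region, the newest value)
        // replaces the least-important slot found by the qk_min reduction.
        v_cache[qk_min_idx]  = fifo_cache_size > 0 ? fifo_out_v : v;
        kv_index[qk_min_idx] = tlength - fifo_cache_size;
    }
}
```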
@@ -1880,6 +1987,32 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS
}

if (!is_masked && do_important_cache) {
assert(qk_min_idx < params.important_kv_cache_size);
assert(qk_min_idx >= 0);

// The 16B chunk written by the thread.
int co = tidx / QK_VECS_IN_16B;
// The position of the thread in that 16B chunk.
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;

// Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
int base_offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B;
int fifo_offset = base_offset + tlength_circ * QK_ELTS_IN_16B + ci;
int important_offset = base_offset + qk_min_idx * QK_ELTS_IN_16B + ci;

if (fifo_cache_size > 0) {
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache[fifo_offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
*reinterpret_cast<Qk_vec_m*>(&params.k_cache[important_offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(fifo_out_k);
}
} else {
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache[important_offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
}
}
}

// Make sure we can start writing to shared memory.
__syncthreads();

20 changes: 20 additions & 0 deletions src/fastertransformer/kernels/gpt_kernels.cu
@@ -27,6 +27,26 @@

namespace fastertransformer {

__global__ void initiate_indices(int* kv_indices, const int loop_size, const size_t total_length)
{
// Use size_t for the grid-stride loop: total_length is B * H * memory_max_len
// and may exceed INT_MAX; an int counter would overflow and the signed/unsigned
// comparison would warn.
const size_t idx = threadIdx.x + blockIdx.x * static_cast<size_t>(blockDim.x);
const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
for (size_t i = idx; i < total_length; i += stride) {
kv_indices[i] = static_cast<int>(i % loop_size);
}
}

void invokeInitiateIndices(int* kv_indices,
const int loop_size,
const size_t total_length,
cudaStream_t stream)
{
const int block_size = 256;
const int grid_size = min(65535, static_cast<int>((total_length + block_size - 1) / block_size));

initiate_indices<<<grid_size, block_size, 0, stream>>>(kv_indices, loop_size, total_length);
}

// PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts
template<typename T, bool OUTPUT_ID, int PROMPT_SRC>
__global__ void start_id_embedding_position_lookups_kernel(T* from_tensor,
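A hypothetical call site (not in this PR) for the new helper: the buffer has one row of `memory_max_len` slots per (batch, head), and before the cache wraps, slot i simply holds token i.

```cpp
#include <cuda_runtime.h>
#include "src/fastertransformer/kernels/gpt_kernels.h"

// Hypothetical wrapper: initialize kv_indices once per session so that each
// [batch, head] row starts as the identity mapping slot -> token (mod cache size).
void init_kv_indices(int* d_kv_indices, int batch_size, int num_heads,
                     int memory_max_len, cudaStream_t stream)
{
    const size_t total_length =
        static_cast<size_t>(batch_size) * num_heads * memory_max_len;
    fastertransformer::invokeInitiateIndices(d_kv_indices, memory_max_len,
                                             total_length, stream);
}
```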
5 changes: 5 additions & 0 deletions src/fastertransformer/kernels/gpt_kernels.h
@@ -25,6 +25,11 @@

namespace fastertransformer {

void invokeInitiateIndices(int* kv_indices,
const int loop_size,
const size_t total_length,
cudaStream_t stream);

template<typename T>
struct inputIdsEmbeddingLookupPosEncodingSoftPromptParam {
T* from_tensor;