@@ -28,28 +28,28 @@ void reshape_paged_cache(torch::Tensor& key,
2828 key, value, k_cache, v_cache, slot_mapping, direction);
2929}
3030
31- void flash_attention (const torch::Tensor& query,
32- const torch::Tensor& key,
33- const torch::Tensor& value,
34- torch::Tensor& output,
35- torch::Tensor& output_lse,
36- const std::optional<torch::Tensor>& query_start_loc,
37- const std::optional<torch::Tensor>& seq_start_loc,
38- const std::optional<torch::Tensor>& alibi_slope,
39- const std::optional<torch::Tensor>& attn_bias,
40- const std::optional<torch::Tensor>& q_quant_scale,
41- const std::optional<torch::Tensor>& k_quant_scale,
42- const std::optional<torch::Tensor>& v_quant_scale,
43- const std::optional<torch::Tensor>& out_quant_scale,
44- const std::optional<torch::Tensor>& block_table,
45- int max_query_len,
46- int max_seq_len,
47- float scale,
48- bool is_causal,
49- int window_size_left,
50- int window_size_right,
51- const std::string& compute_dtype,
52- bool return_lse) {
31+ void batch_prefill (const torch::Tensor& query,
32+ const torch::Tensor& key,
33+ const torch::Tensor& value,
34+ torch::Tensor& output,
35+ torch::Tensor& output_lse,
36+ const std::optional<torch::Tensor>& query_start_loc,
37+ const std::optional<torch::Tensor>& seq_start_loc,
38+ const std::optional<torch::Tensor>& alibi_slope,
39+ const std::optional<torch::Tensor>& attn_bias,
40+ const std::optional<torch::Tensor>& q_quant_scale,
41+ const std::optional<torch::Tensor>& k_quant_scale,
42+ const std::optional<torch::Tensor>& v_quant_scale,
43+ const std::optional<torch::Tensor>& out_quant_scale,
44+ const std::optional<torch::Tensor>& block_table,
45+ int max_query_len,
46+ int max_seq_len,
47+ float scale,
48+ bool is_causal,
49+ int window_size_left,
50+ int window_size_right,
51+ const std::string& compute_dtype,
52+ bool return_lse) {
5353 tmo::torch_api::flash_attention (query,
5454 key,
5555 value,
@@ -74,27 +74,26 @@ void flash_attention(const torch::Tensor& query,
7474 return_lse);
7575}
7676
77- void single_query_cached_kv_attn (
78- const torch::Tensor& query,
79- const torch::Tensor& k_cache,
80- torch::Tensor& output,
81- const torch::Tensor& block_table,
82- const torch::Tensor& seq_lens,
83- const torch::Tensor& v_cache,
84- torch::Tensor& output_lse,
85- const std::optional<torch::Tensor>& q_quant_scale,
86- const std::optional<torch::Tensor>& k_cache_quant_scale,
87- const std::optional<torch::Tensor>& v_cache_quant_scale,
88- const std::optional<torch::Tensor>& out_quant_scale,
89- const std::optional<torch::Tensor>& alibi_slope,
90- const std::optional<torch::Tensor>& mask,
91- const std::string& compute_dtype,
92- int max_seq_len,
93- int window_size_left,
94- int window_size_right,
95- float scale,
96- bool return_lse,
97- int kv_cache_quant_bit_size) {
77+ void batch_decode (const torch::Tensor& query,
78+ const torch::Tensor& k_cache,
79+ torch::Tensor& output,
80+ const torch::Tensor& block_table,
81+ const torch::Tensor& seq_lens,
82+ const torch::Tensor& v_cache,
83+ torch::Tensor& output_lse,
84+ const std::optional<torch::Tensor>& q_quant_scale,
85+ const std::optional<torch::Tensor>& k_cache_quant_scale,
86+ const std::optional<torch::Tensor>& v_cache_quant_scale,
87+ const std::optional<torch::Tensor>& out_quant_scale,
88+ const std::optional<torch::Tensor>& alibi_slope,
89+ const std::optional<torch::Tensor>& mask,
90+ const std::string& compute_dtype,
91+ int max_seq_len,
92+ int window_size_left,
93+ int window_size_right,
94+ float scale,
95+ bool return_lse,
96+ int kv_cache_quant_bit_size) {
9897 tmo::torch_api::single_query_cached_kv_attn (query,
9998 k_cache,
10099 output,
0 commit comments