
Commit a5fa4d8

refactor: rename attention func.
1 parent ffb63b4 commit a5fa4d8

5 files changed: +138 additions, −140 deletions
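Taken together, the diffs below rename the attention entry points at two layers (the MLU wrappers in the mlu namespace and the backend-agnostic dispatchers in xllm::kernel) without changing any signatures. A sketch of the call-site mapping implied by the hunks, argument lists elided:

// Renames introduced by this commit (signatures unchanged):
//   mlu::flash_attention(...)               -> mlu::batch_prefill(...)
//   mlu::single_query_cached_kv_attn(...)   -> mlu::batch_decode(...)
//   xllm::kernel::prefill_attention(params) -> xllm::kernel::batch_prefill(params)
//   xllm::kernel::decode_attention(params)  -> xllm::kernel::batch_decode(params)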

xllm/core/kernels/mlu/attention.cpp

Lines changed: 42 additions & 43 deletions
@@ -28,28 +28,28 @@ void reshape_paged_cache(torch::Tensor& key,
       key, value, k_cache, v_cache, slot_mapping, direction);
 }
 
-void flash_attention(const torch::Tensor& query,
-                     const torch::Tensor& key,
-                     const torch::Tensor& value,
-                     torch::Tensor& output,
-                     torch::Tensor& output_lse,
-                     const std::optional<torch::Tensor>& query_start_loc,
-                     const std::optional<torch::Tensor>& seq_start_loc,
-                     const std::optional<torch::Tensor>& alibi_slope,
-                     const std::optional<torch::Tensor>& attn_bias,
-                     const std::optional<torch::Tensor>& q_quant_scale,
-                     const std::optional<torch::Tensor>& k_quant_scale,
-                     const std::optional<torch::Tensor>& v_quant_scale,
-                     const std::optional<torch::Tensor>& out_quant_scale,
-                     const std::optional<torch::Tensor>& block_table,
-                     int max_query_len,
-                     int max_seq_len,
-                     float scale,
-                     bool is_causal,
-                     int window_size_left,
-                     int window_size_right,
-                     const std::string& compute_dtype,
-                     bool return_lse) {
+void batch_prefill(const torch::Tensor& query,
+                   const torch::Tensor& key,
+                   const torch::Tensor& value,
+                   torch::Tensor& output,
+                   torch::Tensor& output_lse,
+                   const std::optional<torch::Tensor>& query_start_loc,
+                   const std::optional<torch::Tensor>& seq_start_loc,
+                   const std::optional<torch::Tensor>& alibi_slope,
+                   const std::optional<torch::Tensor>& attn_bias,
+                   const std::optional<torch::Tensor>& q_quant_scale,
+                   const std::optional<torch::Tensor>& k_quant_scale,
+                   const std::optional<torch::Tensor>& v_quant_scale,
+                   const std::optional<torch::Tensor>& out_quant_scale,
+                   const std::optional<torch::Tensor>& block_table,
+                   int max_query_len,
+                   int max_seq_len,
+                   float scale,
+                   bool is_causal,
+                   int window_size_left,
+                   int window_size_right,
+                   const std::string& compute_dtype,
+                   bool return_lse) {
   tmo::torch_api::flash_attention(query,
                                   key,
                                   value,
@@ -74,27 +74,26 @@ void flash_attention(const torch::Tensor& query,
                                   return_lse);
 }
 
-void single_query_cached_kv_attn(
-    const torch::Tensor& query,
-    const torch::Tensor& k_cache,
-    torch::Tensor& output,
-    const torch::Tensor& block_table,
-    const torch::Tensor& seq_lens,
-    const torch::Tensor& v_cache,
-    torch::Tensor& output_lse,
-    const std::optional<torch::Tensor>& q_quant_scale,
-    const std::optional<torch::Tensor>& k_cache_quant_scale,
-    const std::optional<torch::Tensor>& v_cache_quant_scale,
-    const std::optional<torch::Tensor>& out_quant_scale,
-    const std::optional<torch::Tensor>& alibi_slope,
-    const std::optional<torch::Tensor>& mask,
-    const std::string& compute_dtype,
-    int max_seq_len,
-    int window_size_left,
-    int window_size_right,
-    float scale,
-    bool return_lse,
-    int kv_cache_quant_bit_size) {
+void batch_decode(const torch::Tensor& query,
+                  const torch::Tensor& k_cache,
+                  torch::Tensor& output,
+                  const torch::Tensor& block_table,
+                  const torch::Tensor& seq_lens,
+                  const torch::Tensor& v_cache,
+                  torch::Tensor& output_lse,
+                  const std::optional<torch::Tensor>& q_quant_scale,
+                  const std::optional<torch::Tensor>& k_cache_quant_scale,
+                  const std::optional<torch::Tensor>& v_cache_quant_scale,
+                  const std::optional<torch::Tensor>& out_quant_scale,
+                  const std::optional<torch::Tensor>& alibi_slope,
+                  const std::optional<torch::Tensor>& mask,
+                  const std::string& compute_dtype,
+                  int max_seq_len,
+                  int window_size_left,
+                  int window_size_right,
+                  float scale,
+                  bool return_lse,
+                  int kv_cache_quant_bit_size) {
   tmo::torch_api::single_query_cached_kv_attn(query,
                                               k_cache,
                                               output,
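For orientation, a minimal sketch of how the renamed prefill wrapper might be called for a varlen batch follows. The tensor shapes, the int32 cumulative-offset convention for query_start_loc/seq_start_loc, the "bfloat16" compute_dtype string, and the -1 "no sliding window" convention are illustrative assumptions, not taken from this diff; the wrapper itself still forwards to tmo::torch_api::flash_attention exactly as before, only its xllm-side name changes.

#include <cmath>
#include <torch/torch.h>
// #include "xllm/core/kernels/mlu/mlu_ops_api.h"  // declares mlu::batch_prefill after this commit

// Minimal sketch, not a verified call: shapes, dtype, device placement and
// the meaning of the window/scale arguments are assumptions for illustration.
void prefill_sketch() {
  const int64_t total_tokens = 32, num_heads = 8, head_dim = 128;
  auto q = torch::randn({total_tokens, num_heads, head_dim});
  auto k = torch::randn({total_tokens, num_heads, head_dim});
  auto v = torch::randn({total_tokens, num_heads, head_dim});
  auto out = torch::empty_like(q);
  torch::Tensor out_lse;  // stays empty because return_lse is false below

  // Two sequences of 16 tokens each, expressed as cumulative int32 offsets.
  auto q_start_loc = torch::tensor({0, 16, 32}, torch::kInt32);
  auto kv_start_loc = q_start_loc.clone();

  mlu::batch_prefill(q, k, v, out, out_lse,
                     q_start_loc, kv_start_loc,
                     /*alibi_slope=*/std::nullopt,
                     /*attn_bias=*/std::nullopt,
                     /*q_quant_scale=*/std::nullopt,
                     /*k_quant_scale=*/std::nullopt,
                     /*v_quant_scale=*/std::nullopt,
                     /*out_quant_scale=*/std::nullopt,
                     /*block_table=*/std::nullopt,
                     /*max_query_len=*/16,
                     /*max_seq_len=*/16,
                     /*scale=*/1.0f / std::sqrt(static_cast<float>(head_dim)),
                     /*is_causal=*/true,
                     /*window_size_left=*/-1,       // assumed "no sliding window"
                     /*window_size_right=*/-1,
                     /*compute_dtype=*/"bfloat16",  // assumed value
                     /*return_lse=*/false);
}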

xllm/core/kernels/mlu/mlu_ops_api.h

Lines changed: 43 additions & 44 deletions
@@ -58,50 +58,49 @@ void reshape_paged_cache(torch::Tensor& key,
                          const torch::Tensor& slot_mapping,
                          bool direction);
 
-void flash_attention(const torch::Tensor& query,
-                     const torch::Tensor& key,
-                     const torch::Tensor& value,
-                     torch::Tensor& output,
-                     torch::Tensor& output_lse,
-                     const std::optional<torch::Tensor>& query_start_loc,
-                     const std::optional<torch::Tensor>& seq_start_loc,
-                     const std::optional<torch::Tensor>& alibi_slope,
-                     const std::optional<torch::Tensor>& attn_bias,
-                     const std::optional<torch::Tensor>& q_quant_scale,
-                     const std::optional<torch::Tensor>& k_quant_scale,
-                     const std::optional<torch::Tensor>& v_quant_scale,
-                     const std::optional<torch::Tensor>& out_quant_scale,
-                     const std::optional<torch::Tensor>& block_tables,
-                     int max_query_len,
-                     int max_seq_len,
-                     float scale,
-                     bool is_causal,
-                     int window_size_left,
-                     int window_size_right,
-                     const std::string& compute_dtype,
-                     bool return_lse);
-
-void single_query_cached_kv_attn(
-    const torch::Tensor& query,
-    const torch::Tensor& k_cache,
-    torch::Tensor& output,
-    const torch::Tensor& block_table,
-    const torch::Tensor& seq_lens,
-    const torch::Tensor& v_cache,
-    torch::Tensor& output_lse,
-    const std::optional<torch::Tensor>& q_quant_scale,
-    const std::optional<torch::Tensor>& k_cache_quant_scale,
-    const std::optional<torch::Tensor>& v_cache_quant_scale,
-    const std::optional<torch::Tensor>& out_quant_scale,
-    const std::optional<torch::Tensor>& alibi_slope,
-    const std::optional<torch::Tensor>& mask,
-    const std::string& compute_dtype,
-    int max_seq_len,
-    int window_size_left,
-    int window_size_right,
-    float scale,
-    bool return_lse,
-    int kv_cache_quant_bit_size);
+void batch_prefill(const torch::Tensor& query,
+                   const torch::Tensor& key,
+                   const torch::Tensor& value,
+                   torch::Tensor& output,
+                   torch::Tensor& output_lse,
+                   const std::optional<torch::Tensor>& query_start_loc,
+                   const std::optional<torch::Tensor>& seq_start_loc,
+                   const std::optional<torch::Tensor>& alibi_slope,
+                   const std::optional<torch::Tensor>& attn_bias,
+                   const std::optional<torch::Tensor>& q_quant_scale,
+                   const std::optional<torch::Tensor>& k_quant_scale,
+                   const std::optional<torch::Tensor>& v_quant_scale,
+                   const std::optional<torch::Tensor>& out_quant_scale,
+                   const std::optional<torch::Tensor>& block_tables,
+                   int max_query_len,
+                   int max_seq_len,
+                   float scale,
+                   bool is_causal,
+                   int window_size_left,
+                   int window_size_right,
+                   const std::string& compute_dtype,
+                   bool return_lse);
+
+void batch_decode(const torch::Tensor& query,
+                  const torch::Tensor& k_cache,
+                  torch::Tensor& output,
+                  const torch::Tensor& block_table,
+                  const torch::Tensor& seq_lens,
+                  const torch::Tensor& v_cache,
+                  torch::Tensor& output_lse,
+                  const std::optional<torch::Tensor>& q_quant_scale,
+                  const std::optional<torch::Tensor>& k_cache_quant_scale,
+                  const std::optional<torch::Tensor>& v_cache_quant_scale,
+                  const std::optional<torch::Tensor>& out_quant_scale,
+                  const std::optional<torch::Tensor>& alibi_slope,
+                  const std::optional<torch::Tensor>& mask,
+                  const std::string& compute_dtype,
+                  int max_seq_len,
+                  int window_size_left,
+                  int window_size_right,
+                  float scale,
+                  bool return_lse,
+                  int kv_cache_quant_bit_size);
 
 void fused_layernorm(const torch::Tensor& input,
                      torch::Tensor& output,
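The header pairs batch_prefill with batch_decode, which reads keys and values from a paged cache addressed by block_table and seq_lens. A companion sketch of a single decode step follows; the cache layout, block size, dtype, and the zero value for kv_cache_quant_bit_size are all assumptions for illustration, not taken from this commit.

#include <cmath>
#include <torch/torch.h>
// #include "xllm/core/kernels/mlu/mlu_ops_api.h"  // declares mlu::batch_decode after this commit

// Minimal sketch, not a verified call: the paged-cache layout
// [num_blocks, block_size, num_heads, head_dim] and all concrete values
// are assumptions for illustration only.
void decode_sketch() {
  const int64_t num_seqs = 2, num_heads = 8, head_dim = 128;
  const int64_t num_blocks = 64, block_size = 16;

  auto q = torch::randn({num_seqs, 1, num_heads, head_dim});  // one new token per sequence
  auto out = torch::empty_like(q);
  torch::Tensor out_lse;  // unused because return_lse is false below
  auto k_cache = torch::randn({num_blocks, block_size, num_heads, head_dim});
  auto v_cache = torch::randn({num_blocks, block_size, num_heads, head_dim});

  auto block_table = torch::zeros({num_seqs, 4}, torch::kInt32);  // physical block ids per sequence
  auto seq_lens = torch::tensor({9, 17}, torch::kInt32);          // tokens already held in the cache

  mlu::batch_decode(q, k_cache, out, block_table, seq_lens, v_cache, out_lse,
                    /*q_quant_scale=*/std::nullopt,
                    /*k_cache_quant_scale=*/std::nullopt,
                    /*v_cache_quant_scale=*/std::nullopt,
                    /*out_quant_scale=*/std::nullopt,
                    /*alibi_slope=*/std::nullopt,
                    /*mask=*/std::nullopt,
                    /*compute_dtype=*/"bfloat16",    // assumed value
                    /*max_seq_len=*/64,
                    /*window_size_left=*/-1,         // assumed "no sliding window"
                    /*window_size_right=*/-1,
                    /*scale=*/1.0f / std::sqrt(static_cast<float>(head_dim)),
                    /*return_lse=*/false,
                    /*kv_cache_quant_bit_size=*/0);  // assumed "cache not quantized"
}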

xllm/core/kernels/ops_api.cpp

Lines changed: 48 additions & 48 deletions
@@ -69,67 +69,67 @@ void reshape_paged_cache(ReshapePagedCacheParams& params) {
 #endif
 }
 
-void prefill_attention(AttentionParams& params) {
+void batch_prefill(AttentionParams& params) {
 #if defined(USE_MLU)
   torch::Tensor lse = params.output_lse.value_or(torch::Tensor());
-  mlu::flash_attention(params.query,
-                       params.key,
-                       params.value,
-                       params.output,
-                       lse,
-                       params.query_start_loc,
-                       params.seq_start_loc,
-                       params.alibi_slope,
-                       params.attn_bias,
-                       params.q_quant_scale,
-                       params.k_quant_scale,
-                       params.v_quant_scale,
-                       params.out_quant_scale,
-                       params.block_table,
-                       params.max_query_len,
-                       params.max_seq_len,
-                       params.scale,
-                       params.is_causal,
-                       params.window_size_left,
-                       params.window_size_right,
-                       params.compute_dtype,
-                       params.return_lse);
+  mlu::batch_prefill(params.query,
+                     params.key,
+                     params.value,
+                     params.output,
+                     lse,
+                     params.query_start_loc,
+                     params.seq_start_loc,
+                     params.alibi_slope,
+                     params.attn_bias,
+                     params.q_quant_scale,
+                     params.k_quant_scale,
+                     params.v_quant_scale,
+                     params.out_quant_scale,
+                     params.block_table,
+                     params.max_query_len,
+                     params.max_seq_len,
+                     params.scale,
+                     params.is_causal,
+                     params.window_size_left,
+                     params.window_size_right,
+                     params.compute_dtype,
+                     params.return_lse);
   params.output_lse = lse;
 #elif defined(USE_CUDA)
-  throw std::runtime_error("prefill_attention for cuda not implemented");
+  throw std::runtime_error("batch_prefill for cuda not implemented");
 #else
-  throw std::runtime_error("prefill_attention not implemented");
+  throw std::runtime_error("batch_prefill not implemented");
 #endif
 }
 
-void decode_attention(AttentionParams& params) {
+void batch_decode(AttentionParams& params) {
 #if defined(USE_MLU)
   torch::Tensor lse = params.output_lse.value_or(torch::Tensor());
-  mlu::single_query_cached_kv_attn(params.query,
-                                   params.k_cache,
-                                   params.output,
-                                   params.block_table,
-                                   params.seq_lens,
-                                   params.v_cache,
-                                   lse,
-                                   params.q_quant_scale,
-                                   params.k_cache_quant_scale,
-                                   params.v_cache_quant_scale,
-                                   params.out_quant_scale,
-                                   params.alibi_slope,
-                                   params.mask,
-                                   params.compute_dtype,
-                                   params.max_seq_len,
-                                   params.window_size_left,
-                                   params.window_size_right,
-                                   params.scale,
-                                   params.return_lse,
-                                   params.kv_cache_quant_bit_size);
+  mlu::batch_decode(params.query,
+                    params.k_cache,
+                    params.output,
+                    params.block_table,
+                    params.seq_lens,
+                    params.v_cache,
+                    lse,
+                    params.q_quant_scale,
+                    params.k_cache_quant_scale,
+                    params.v_cache_quant_scale,
+                    params.out_quant_scale,
+                    params.alibi_slope,
+                    params.mask,
+                    params.compute_dtype,
+                    params.max_seq_len,
+                    params.window_size_left,
+                    params.window_size_right,
+                    params.scale,
+                    params.return_lse,
+                    params.kv_cache_quant_bit_size);
   params.output_lse = lse;
 #elif defined(USE_CUDA)
-  throw std::runtime_error("decode_attention for cuda not implemented");
+  throw std::runtime_error("batch_decode for cuda not implemented");
 #else
-  throw std::runtime_error("decode_attention not implemented");
+  throw std::runtime_error("batch_decode not implemented");
 #endif
 }
 
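At the dispatch layer the only visible change is the function names; callers still pass a single AttentionParams. A rough sketch of driving the renamed dispatcher, restricted to fields that appear in the diff above (the AttentionParams definition is outside this commit, so its namespace, field types, and the concrete values below are assumptions):

#include <cmath>
#include <torch/torch.h>
// #include "xllm/core/kernels/ops_api.h"  // declares xllm::kernel::batch_prefill / batch_decode

// Sketch only: AttentionParams is defined elsewhere in xllm; the fields set
// here are the ones the dispatcher forwards in the diff above.
void run_prefill_sketch(xllm::kernel::AttentionParams& params,
                        const torch::Tensor& q,
                        const torch::Tensor& k,
                        const torch::Tensor& v,
                        torch::Tensor& out) {
  params.query = q;
  params.key = k;
  params.value = v;
  params.output = out;
  params.max_query_len = 16;              // illustrative values
  params.max_seq_len = 16;
  params.scale = 1.0f / std::sqrt(128.0f);
  params.is_causal = true;
  params.return_lse = false;

  // Formerly xllm::kernel::prefill_attention(params); on MLU builds this now
  // forwards to mlu::batch_prefill, while CUDA and other builds still throw.
  xllm::kernel::batch_prefill(params);
}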
xllm/core/kernels/ops_api.h

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,8 @@ namespace kernel {
 void apply_rotary(RotaryParams& params);
 void active(ActivationParams& params);
 void reshape_paged_cache(ReshapePagedCacheParams& params);
-void prefill_attention(AttentionParams& params);
-void decode_attention(AttentionParams& params);
+void batch_prefill(AttentionParams& params);
+void batch_decode(AttentionParams& params);
 void fused_layernorm(FusedLayerNormParams& params);
 torch::Tensor matmul(MatmulParams& params);
 torch::Tensor fused_moe(FusedMoEParams& params);

xllm/core/layers/mlu/attention.cpp

Lines changed: 3 additions & 3 deletions
@@ -105,7 +105,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
     attention_params.seq_start_loc = attn_metadata.seq_start_loc;
     attention_params.max_query_len = attn_metadata.max_query_len;
 
-    xllm::kernel::prefill_attention(attention_params);
+    xllm::kernel::batch_prefill(attention_params);
   } else if (attn_metadata.is_chunked_prefill) {
     attention_params.key = k_cache;
     attention_params.value = v_cache;
@@ -114,7 +114,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
     attention_params.max_query_len = attn_metadata.max_query_len;
     attention_params.block_table = attn_metadata.block_table;
 
-    xllm::kernel::prefill_attention(attention_params);
+    xllm::kernel::batch_prefill(attention_params);
   } else {
     query = query.view({-1, 1, num_heads_, head_size_});
     output = output.view({-1, 1, num_heads_, head_size_});
@@ -134,7 +134,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
     attention_params.paged_kv_last_page_len =
         attn_metadata.paged_kv_last_page_len;
 
-    xllm::kernel::decode_attention(attention_params);
+    xllm::kernel::batch_decode(attention_params);
   }
 
   output = output.view({-1, num_heads_ * head_size_});
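Condensing the three hunks above, the forward() path now dispatches as follows. This is a fragment rather than a standalone function: only the lines visible in the hunks are verbatim, and the guard of the first branch is paraphrased because it lies outside the diff context (treat the flag name as an assumption).

// Fragment, not self-contained: surrounding types and the first guard are assumed.
if (is_prefill /* assumed flag on attn_metadata */) {
  attention_params.seq_start_loc = attn_metadata.seq_start_loc;
  attention_params.max_query_len = attn_metadata.max_query_len;
  xllm::kernel::batch_prefill(attention_params);          // was prefill_attention
} else if (attn_metadata.is_chunked_prefill) {
  attention_params.key = k_cache;                          // chunked prefill attends over the cache
  attention_params.value = v_cache;
  attention_params.block_table = attn_metadata.block_table;
  xllm::kernel::batch_prefill(attention_params);           // was prefill_attention
} else {
  query = query.view({-1, 1, num_heads_, head_size_});     // one token per sequence for decode
  output = output.view({-1, 1, num_heads_, head_size_});
  xllm::kernel::batch_decode(attention_params);            // was decode_attention
}
output = output.view({-1, num_heads_ * head_size_});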
