
Commit 7b50f7c

graph : prepare for 4D mask (#14515)
ggml-ci
1 parent c79184d commit 7b50f7c

2 files changed: +18 -20 lines changed


src/llama-graph.cpp

Lines changed: 6 additions & 8 deletions
@@ -1005,8 +1005,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);

-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);

     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1143,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);

     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp_kq_mask, "KQ_mask", -1);
+    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->kq_mask);

     inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1209,7 +1207,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);

     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1343,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {

     const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;

-    inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->cross_kq_mask);

     inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1457,7 +1455,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);

     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1471,7 +1469,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

-    inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask_swa);

     inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;

src/llama-graph.h

Lines changed: 12 additions & 12 deletions
@@ -228,8 +228,8 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i {

     ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }

-    ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
-    ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
+    ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
+    ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]

     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -257,8 +257,8 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
     ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]

-    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]

     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -293,10 +293,10 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
     ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
     ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]

-    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
+    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]

     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -313,8 +313,8 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {

     ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }

-    ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
-    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
+    ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
+    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]

     const llama_cross * cross = nullptr;
 };
@@ -343,8 +343,8 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
     ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]

-    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]

     const llama_hparams & hparams;
     const llama_cparams & cparams;
