@@ -1005,8 +1005,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1005
1005
inp->self_k_idxs = mctx_cur->get_attn ()->build_input_k_idxs (ctx0, ubatch);
1006
1006
inp->self_v_idxs = mctx_cur->get_attn ()->build_input_v_idxs (ctx0, ubatch);
1007
1007
1008
- inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1009
- // cb(inp->self_kq_mask, "KQ_mask", -1);
1008
+ inp->self_kq_mask = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD), 1 , 1 );
1010
1009
ggml_set_input (inp->self_kq_mask );
1011
1010
1012
1011
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
@@ -1143,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1143
1142
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1144
1143
1145
1144
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1146
- inp->kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1147
- // cb(inp_kq_mask, "KQ_mask", -1);
1145
+ inp->kq_mask = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD), 1 , 1 );
1148
1146
ggml_set_input (inp->kq_mask );
1149
1147
1150
1148
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->kq_mask , GGML_TYPE_F16) : inp->kq_mask ;
@@ -1209,7 +1207,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
1209
1207
inp->self_k_idxs = mctx_cur->build_input_k_idxs (ctx0, ubatch);
1210
1208
inp->self_v_idxs = mctx_cur->build_input_v_idxs (ctx0, ubatch);
1211
1209
1212
- inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1210
+ inp->self_kq_mask = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD), 1 , 1 );
1213
1211
ggml_set_input (inp->self_kq_mask );
1214
1212
1215
1213
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
@@ -1343,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1343
1341
1344
1342
const int32_t n_enc = !cross->v_embd .empty () ? cross->n_enc : hparams.n_ctx_train ;
1345
1343
1346
- inp->cross_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_enc, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1344
+ inp->cross_kq_mask = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_enc, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD), 1 , 1 );
1347
1345
ggml_set_input (inp->cross_kq_mask );
1348
1346
1349
1347
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->cross_kq_mask , GGML_TYPE_F16) : inp->cross_kq_mask ;
@@ -1457,7 +1455,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1457
1455
inp->self_k_idxs = mctx_cur->get_base ()->build_input_k_idxs (ctx0, ubatch);
1458
1456
inp->self_v_idxs = mctx_cur->get_base ()->build_input_v_idxs (ctx0, ubatch);
1459
1457
1460
- inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1458
+ inp->self_kq_mask = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD), 1 , 1 );
1461
1459
ggml_set_input (inp->self_kq_mask );
1462
1460
1463
1461
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
@@ -1471,7 +1469,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1471
1469
inp->self_k_idxs_swa = mctx_cur->get_swa ()->build_input_k_idxs (ctx0, ubatch);
1472
1470
inp->self_v_idxs_swa = mctx_cur->get_swa ()->build_input_v_idxs (ctx0, ubatch);
1473
1471
1474
- inp->self_kq_mask_swa = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1472
+ inp->self_kq_mask_swa = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD), 1 , 1 );
1475
1473
ggml_set_input (inp->self_kq_mask_swa );
1476
1474
1477
1475
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask_swa , GGML_TYPE_F16) : inp->self_kq_mask_swa ;
0 commit comments