@@ -10466,20 +10466,14 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
 
             if (il == n_layer - 1 && inp_out_ids) {
                 cur      = ggml_get_rows(ctx0, cur,      inp_out_ids);
-                inpL     = ggml_get_rows(ctx0, inpL,     inp_out_ids);
+                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
             }
 
             // residual connection
             cur = ggml_add(ctx0, cur, residual);
             cb(cur, "ffn_residual", il);
 
             inpL = cur;
-
-            if (il == 1) {
-                cur  = ggml_get_rows(ctx0, cur,  inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-                break;
-            }
         }
 
         cur = inpL;
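Reviewer note on this hunk: the fix hinges on what ggml_get_rows does to the token dimension. A minimal stand-alone sketch of the invariant (hypothetical names, assuming ggml.h): once cur has been narrowed to the output tokens, the skip tensor later added to it must be narrowed with the same ids, or ggml_add() sees mismatched shapes, which is why residual rather than inpL has to be gathered here.

#include "ggml.h"

// act: {n_embd, n_tokens} activations; ids: {n_outputs} I32 indices.
// ggml_get_rows keeps only the indexed rows: result is {n_embd, n_outputs}.
// Any tensor combined with the result must be narrowed identically.
static ggml_tensor * select_outputs(ggml_context * ctx, ggml_tensor * act, ggml_tensor * ids) {
    return ggml_get_rows(ctx, act, ids);
}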
@@ -10627,7 +10621,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
         ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
         cb(zx, "mamba_in_proj", il);
-
+        // {8192, 5, 1, 1} -> {8192, 1, 5, 1}
         zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
         zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
         cb(zx, "mamba_in_proj_out", il);
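The shape comment added above follows from ggml_permute's argument convention: argument i gives the destination axis of source dim i, so (0, 2, 1, 3) swaps dims 1 and 2. A small runnable sketch (sizes taken from the comment, otherwise hypothetical):

#include "ggml.h"
#include <cstdio>

int main() {
    ggml_init_params params = { /*mem_size=*/ 64*1024*1024, /*mem_buffer=*/ nullptr, /*no_alloc=*/ false };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8192, 5, 1, 1);
    ggml_tensor * p = ggml_permute(ctx, t, 0, 2, 1, 3); // dims 1 and 2 swap

    // prints: 8192 1 5 1
    printf("%lld %lld %lld %lld\n", (long long) p->ne[0], (long long) p->ne[1],
                                    (long long) p->ne[2], (long long) p->ne[3]);
    ggml_free(ctx);
    return 0;
}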
@@ -10636,14 +10630,11 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         // => {head_dim * n_heads, n_seq_tokens, n_seqs}
         ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx));
         x = ggml_cont(ctx0, x);
-        x = ggml_reshape_4d(ctx0, x, head_dim * n_heads, 1, n_seq_tokens, n_seqs);
-        x = ggml_permute(ctx0, x, 0, 2, 1, 3);
+        x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs);
+        // x = ggml_permute(ctx0, x, 0, 2, 1, 3);
         cb(x, "mamba_x_split", il);
 
         ggml_tensor * z = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0);
-        z = ggml_cont(ctx0, z);
-        z = ggml_reshape_4d(ctx0, z, head_dim * n_heads, 1, n_seq_tokens, n_seqs);
-        z = ggml_permute(ctx0, z, 0, 2, 1, 3);
         cb(z, "mamba_z_split", il);
 
         // conv1d
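Why x's view starts at a byte offset while z's starts at zero (my reading of the layout, not stated in the diff): after the reshape, dim 0 of zx packs each head as [z half | x half], so both views reuse zx's strides and differ only in the starting offset. Restated in the same scope as the code above:

// per head along dim 0: [ z (head_dim) | x (head_dim) ]
const size_t off_x = head_dim * ggml_element_size(zx); // skip the z half of each head
ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs,
                               zx->nb[1], zx->nb[2], zx->nb[3], off_x);
ggml_tensor * z = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs,
                               zx->nb[1], zx->nb[2], zx->nb[3], /*offset=*/ 0);
// z can stay a strided view here: no cont/reshape needed before the later SiLU gating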
@@ -10699,11 +10690,10 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
         cb(dt, "mamba_dt_proj", il);
 
-        ggml_tensor * A = ggml_new_tensor_2d(ctx0, model.layers[il].ssm_a->type, d_state, n_heads);
-        A = ggml_repeat(ctx0, model.layers[il].ssm_a, A);
+        ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads);
         cb(A, "mamba_A", il);
 
-        x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * x->nb[0], x->nb[1], x->nb[2], 0);
+        x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
         B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0);
         C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0);
 
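The A change swaps a materialized ggml_repeat for an implicit broadcast. A stand-alone sketch of the rule it relies on (my reading of ggml's element-wise ops, names hypothetical): src1 is repeated across src0 whenever each src1 dimension divides the matching src0 dimension, so a {1, n_heads} view of the per-head parameter is enough.

#include "ggml.h"

// a: {d_state, n_heads}, b: {1, n_heads}. ggml_mul broadcasts b along
// dim 0, so no d_state-wide copy of the per-head value is materialized.
static ggml_tensor * scale_per_head(ggml_context * ctx, ggml_tensor * a, ggml_tensor * b) {
    return ggml_mul(ctx, a, b);
}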
@@ -10725,22 +10715,22 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         // store last states
         ggml_build_forward_expand(gf,
             ggml_cpy(ctx0,
-                ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
+                ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
                 ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs,
                     kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
 
-        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * x->nb[0], head_dim * n_heads * x->nb[1], head_dim * n_heads * n_seq_tokens * x->nb[2], 0);
+        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
         cb(y, "mamba_y_view", il);
 
         // Add D parameter and apply gating with z
         // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
-        y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
-        cb(y, "mamba_y_with_D", il);
+        ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads);
+        y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D));
+        cb(y, "mamba_y_add_d", il);
 
         ggml_tensor * z_silu = ggml_silu(ctx0, ggml_cont(ctx0, z));
         cb(z_silu, "mamba_z_silu", il);
 
-        y = ggml_reshape_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs);
         y = ggml_mul(ctx0, y, z_silu);
         cb(y, "mamba_y_gated", il);
 
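The offset fix in the state-store view assumes y_ssm packs all of y first and the final recurrent states after it, which is what the surrounding code implies. Since x and y share a shape, x->nb[3] * x->ne[3] is y's total byte size, i.e. where the states begin; the old offset x->nb[3] skipped only a single dim-3 slice. Restated in the same scope (last_states is a hypothetical name):

// total bytes of y = stride of dim 3 times its extent; the packed final
// states start immediately after that block inside y_ssm
const size_t y_bytes = x->nb[3] * x->ne[3];
ggml_tensor * last_states = ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, y_bytes);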