Commit 272ffdb: Fix z shape

Parent: dbce4e2

2 files changed: 12 additions, 22 deletions

examples/eval-callback/eval-callback.cpp

Lines changed: 1 addition & 1 deletion

@@ -122,7 +122,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
 
     if (!ggml_is_quantized(t->type)) {
        uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
-       ggml_print_tensor(data, t->type, t->ne, t->nb, 256);
+       ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
    }
 
    return true;
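Note: as used in this example, the final argument of ggml_print_tensor caps how many elements are printed per dimension before the rest is elided, so this drops each tensor dump from up to 256 values per dimension to 3, which is far easier to scan while debugging the graph changes below.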

src/llama-model.cpp

Lines changed: 11 additions & 21 deletions

@@ -10466,20 +10466,14 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
 
            if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-               inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+               residual = ggml_get_rows(ctx0, residual, inp_out_ids);
            }
 
            // residual connection
            cur = ggml_add(ctx0, cur, residual);
            cb(cur, "ffn_residual", il);
 
            inpL = cur;
-
-           if (il == 1) {
-               cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-               inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-               break;
-           }
        }
 
        cur = inpL;
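Reading of this hunk: when inp_out_ids selects only the rows needed for the model output, the gather has to apply to residual rather than inpL, because cur is added to residual immediately afterwards (and inpL is then overwritten with cur anyway). The deleted "if (il == 1) ... break;" block, which cut the graph off after two layers, looks like a leftover debugging shortcut.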
@@ -10627,7 +10621,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
        // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
        ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
        cb(zx, "mamba_in_proj", il);
-
+       // {8192, 5, 1, 1} -> {8192, 1, 5, 1}
        zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
        zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
        cb(zx, "mamba_in_proj_out", il);
@@ -10636,14 +10630,11 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
        // => {head_dim * n_heads, n_seq_tokens, n_seqs}
        ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx));
        x = ggml_cont(ctx0, x);
-       x = ggml_reshape_4d(ctx0, x, head_dim * n_heads, 1, n_seq_tokens, n_seqs);
-       x = ggml_permute(ctx0, x, 0, 2, 1, 3);
+       x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs);
+       // x = ggml_permute(ctx0, x, 0, 2, 1, 3);
        cb(x, "mamba_x_split", il);
 
        ggml_tensor * z = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0);
-       z = ggml_cont(ctx0, z);
-       z = ggml_reshape_4d(ctx0, z, head_dim * n_heads, 1, n_seq_tokens, n_seqs);
-       z = ggml_permute(ctx0, z, 0, 2, 1, 3);
        cb(z, "mamba_z_split", il);
 
        // conv1d
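How the two views relate (my reading of the hunk, not text from the commit): after the reshape, dim 0 of zx packs the z half and the x half back-to-back for each head, so both halves are cheap strided views over the same buffer. Only x is made contiguous here; with this fix z stays a {head_dim, n_heads, n_seq_tokens, n_seqs} view until the ggml_cont right before the SiLU below, which is what keeps its shape aligned with y for the gating.

// One dim-0 row of zx, per (head, token, seq), holds head_dim*2 elements:
//
//   [ z_0 ... z_{head_dim-1} | x_0 ... x_{head_dim-1} ]
//     ^ byte offset 0          ^ byte offset head_dim * ggml_element_size(zx)
//
// which is exactly the offset difference between the two views above:
//   z: ggml_view_4d(..., zx->nb[1], zx->nb[2], zx->nb[3], 0)
//   x: ggml_view_4d(..., zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx))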
@@ -10699,11 +10690,10 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
        dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
        cb(dt, "mamba_dt_proj", il);
 
-       ggml_tensor * A = ggml_new_tensor_2d(ctx0, model.layers[il].ssm_a->type, d_state, n_heads);
-       A = ggml_repeat(ctx0, model.layers[il].ssm_a, A);
+       ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads);
        cb(A, "mamba_A", il);
 
-       x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * x->nb[0], x->nb[1], x->nb[2], 0);
+       x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
        B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0);
        C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0);
 
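Two independent fixes meet in this hunk. First, ssm_a is reinterpreted as a {1, n_heads} view instead of being materialized at {d_state, n_heads} via ggml_repeat; as I understand ggml's ssm_scan, a size-1 state dimension in A is treated as one decay scalar per head (the Mamba-2 convention), so the copy was never needed. Second, the x view now derives its byte strides from ggml_element_size(x): after the earlier reshape_3d, x->nb[1] and x->nb[2] describe the flattened {head_dim*n_heads, n_seq_tokens, n_seqs} layout, so reusing them for a 4-D per-head view would mis-stride the data. A minimal sketch of the corrected stride arithmetic, assuming f32 and the same illustrative sizes as before:

#include <cstddef>
#include <cstdio>

int main() {
    // Illustrative sizes (assumptions, not read from the model)
    const size_t ts           = 4;   // bytes per f32 element
    const size_t head_dim     = 128;
    const size_t n_heads      = 32;
    const size_t n_seq_tokens = 5;

    // Contiguous byte strides for a {head_dim, n_heads, n_seq_tokens, n_seqs}
    // tensor, matching the expressions in the fixed ggml_view_4d call:
    const size_t nb1 = head_dim * ts;                          // step to next head
    const size_t nb2 = head_dim * n_heads * ts;                // step to next token
    const size_t nb3 = head_dim * n_heads * n_seq_tokens * ts; // step to next sequence

    printf("nb1=%zu nb2=%zu nb3=%zu\n", nb1, nb2, nb3);
}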

@@ -10725,22 +10715,22 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
        // store last states
        ggml_build_forward_expand(gf,
                ggml_cpy(ctx0,
-                   ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
+                   ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
                    ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs,
                        kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
 
-       ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * x->nb[0], head_dim * n_heads * x->nb[1], head_dim * n_heads * n_seq_tokens * x->nb[2], 0);
+       ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
        cb(y, "mamba_y_view", il);
 
        // Add D parameter and apply gating with z
        // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
-       y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
-       cb(y, "mamba_y_with_D", il);
+       ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads);
+       y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D));
+       cb(y, "mamba_y_add_d", il);
 
        ggml_tensor * z_silu = ggml_silu(ctx0, ggml_cont(ctx0, z));
        cb(z_silu, "mamba_z_silu", il);
 
-       y = ggml_reshape_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs);
        y = ggml_mul(ctx0, y, z_silu);
        cb(y, "mamba_y_gated", il);
 