
Commit 98e6328

mamba : apply suggestions from code review
* mamba : remove unnecessary branch for row-wise ssm_state and C multiplication

  It was previously done to avoid permuting when only one token is processed
  at a time (like when generating text), but permuting is cheap, and
  dynamically changing the compute graph is not future-proof.

* ggml : in ggml_ssm_scan, use more appropriate asserts

* ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32
1 parent 64fbce0
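For context, the discretized selective-scan update these hunks implement can be written as follows. This is a sketch of the standard Mamba formulation, not text from the commit; the symbols are chosen to match the variable names in the diffs below, with the softplus matching the kernel's dt_soft_plus:

$$
h_t = e^{\Delta_t A} \odot h_{t-1} + (\Delta_t\, x_t)\, B_t, \qquad
y_t = C_t \cdot h_t + D\, x_t, \qquad
\Delta_t = \mathrm{softplus}(dt_t)
$$

The ggml.c hunks compute the h_t recurrence element-wise; the llama.cpp hunk computes the C_t · h_t contraction and the D x_t skip connection.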

File tree

2 files changed: +11 −17 lines changed

ggml.c
Lines changed: 7 additions & 6 deletions

@@ -5901,8 +5901,8 @@ struct ggml_tensor * ggml_ssm_scan(
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
-    ggml_are_same_shape(x, dt);
-    GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
+    GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D

     {
         const int64_t d_state = s->ne[0];
@@ -5919,6 +5919,7 @@ struct ggml_tensor * ggml_ssm_scan(
     bool is_node = false;

     if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
+        GGML_ASSERT(false); // TODO: implement
         is_node = true;
     }

@@ -14177,7 +14178,7 @@ static void ggml_compute_forward_ssm_scan_f32(

     // first batch
     {
-        float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
         float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
         float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
         float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
@@ -14191,14 +14192,14 @@ static void ggml_compute_forward_ssm_scan_f32(
             for (int i0 = 0; i0 < nc; ++i0) {
                 int i = i0 + i1*nc;
                 // ssm_state * dA + dB * x
-                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
             }
         }
     }

     // compute state for rest of tokens, previous state comes from dest
     for (int i2 = 1; i2 < n_t; ++i2) {
-        float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
         float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
         float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
         float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
@@ -14212,7 +14213,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             for (int i0 = 0; i0 < nc; ++i0) {
                 int i = i0 + i1*nc;
                 // ssm_state * dA + dB * x
-                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
             }
         }
     }
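To make the renamed kernel's inner loop concrete, here is a minimal standalone C sketch of the same recurrence for a single inner channel. It is an illustration only, not the ggml kernel: the sizes and input values are made up, and any numerical guard the real kernel applies around softplus is reduced here to plain log1pf(expf(...)).

/* Standalone sketch (illustrative values, not the ggml kernel).
 * Per inner channel:
 *   dt'  = softplus(dt)
 *   h[n] = h[n] * exp(dt' * A[n]) + B[n] * (x * dt')   for each state dim n
 */
#include <math.h>
#include <stdio.h>

enum { D_STATE = 4, N_TOK = 3 }; /* toy sizes */

int main(void) {
    float h[D_STATE] = {0};                                      /* ssm_state, starts at zero */
    const float A[D_STATE] = {-1.0f, -0.5f, -0.25f, -2.0f};      /* decay rates (negative)    */
    const float B[N_TOK][D_STATE] = {{0.10f, 0.20f, 0.30f, 0.40f},
                                     {0.40f, 0.30f, 0.20f, 0.10f},
                                     {0.25f, 0.25f, 0.25f, 0.25f}};
    const float x[N_TOK]  = {0.3f, -0.7f, 1.1f};                 /* inputs for this channel   */
    const float dt[N_TOK] = {0.10f, 0.20f, 0.05f};

    for (int t = 0; t < N_TOK; ++t) {
        const float dt_soft_plus = log1pf(expf(dt[t]));          /* softplus(dt) */
        const float x_dt = x[t] * dt_soft_plus;
        for (int n = 0; n < D_STATE; ++n) {
            /* ssm_state * dA + dB * x -- same update as the pdst[i] line above */
            h[n] = h[n] * expf(dt_soft_plus * A[n]) + B[t][n] * x_dt;
        }
        printf("after token %d: h = {%f, %f, %f, %f}\n", t, h[0], h[1], h[2], h[3]);
    }
    return 0;
}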

llama.cpp
Lines changed: 4 additions & 11 deletions

@@ -7012,17 +7012,10 @@ struct llm_build_context {
                 ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]),
                 ggml_view_tensor(ctx0, kv_self.v_l[il])));

-            struct ggml_tensor * y;
-            if (n_tok == 1) {
-                // row-wise dot product ("dn,n->d")
-                // {d_state, d_inner} * {d_state, 1} => {d_inner, 1}
-                y = ggml_mul_mat(ctx0, ssm_state, C);
-            } else {
-                // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
-                y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
-                // => {d_inner, n_tok}
-                y = ggml_permute(ctx0, y, 0, 2, 1, 3);
-            }
+            // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
+            struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
+            // => {d_inner, n_tok}
+            y = ggml_permute(ctx0, y, 0, 2, 1, 3);

             // {d_inner, n_tok} * {d_inner} => {d_inner, n_tok}
             y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
             y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
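A small plain-C sketch (explicit loops, not ggml calls) of why the removed n_tok == 1 branch was redundant: the batched path computes y[t][d] = sum_n ssm_state[t][d][n] * C[t][n] once per token, and at n_tok == 1 that collapses to exactly the row-wise dot product ("dn,n->d") the branch special-cased. Shapes follow the comments in the diff; the values are arbitrary.

/* Illustrative loops only: the per-token contraction behind
 * ggml_mul_mat(ssm_state, permute(C)) and its former n_tok == 1 special case. */
#include <stdio.h>

enum { D_STATE = 3, D_INNER = 2, N_TOK = 1 }; /* single token on purpose */

int main(void) {
    /* ssm_state: {d_state, d_inner, n_tok}; C: {d_state, n_tok} */
    const float s[N_TOK][D_INNER][D_STATE] = {{{1, 2, 3}, {4, 5, 6}}};
    const float C[N_TOK][D_STATE] = {{0.5f, -1.0f, 2.0f}};

    /* batched path: one "dn,n->d" contraction per token */
    float y_batched[N_TOK][D_INNER];
    for (int t = 0; t < N_TOK; ++t) {
        for (int d = 0; d < D_INNER; ++d) {
            y_batched[t][d] = 0.0f;
            for (int n = 0; n < D_STATE; ++n) {
                y_batched[t][d] += s[t][d][n] * C[t][n];
            }
        }
    }

    /* removed special case: a single row-wise dot product, token index dropped */
    float y_single[D_INNER];
    for (int d = 0; d < D_INNER; ++d) {
        y_single[d] = 0.0f;
        for (int n = 0; n < D_STATE; ++n) {
            y_single[d] += s[0][d][n] * C[0][n];
        }
    }

    for (int d = 0; d < D_INNER; ++d) {
        printf("d=%d batched=%g single=%g\n", d, y_batched[0][d], y_single[d]);
    }
    return 0;
}

The permute in the kept path only adjusts tensor strides rather than copying data, which is consistent with the commit message's point that permuting is cheap compared with maintaining a dynamically shaped graph.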
