@@ -5901,8 +5901,8 @@ struct ggml_tensor * ggml_ssm_scan(
5901
5901
GGML_ASSERT(ggml_is_contiguous(dt));
5902
5902
GGML_ASSERT(ggml_is_contiguous(A));
5903
5903
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
5904
- ggml_are_same_shape(x, dt);
5905
- GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1 ); // the ssm_state should be 2D
5904
+ GGML_ASSERT( ggml_are_same_shape(x, dt) );
5905
+ GGML_ASSERT(ggml_is_matrix(s) ); // the ssm_state should be 2D
5906
5906
5907
5907
{
5908
5908
const int64_t d_state = s->ne[0];
@@ -5919,6 +5919,7 @@ struct ggml_tensor * ggml_ssm_scan(
5919
5919
bool is_node = false;
5920
5920
5921
5921
if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
5922
+ GGML_ASSERT(false); // TODO: implement
5922
5923
is_node = true;
5923
5924
}
5924
5925
@@ -14177,7 +14178,7 @@ static void ggml_compute_forward_ssm_scan_f32(
14177
14178
14178
14179
// first batch
14179
14180
{
14180
- float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
14181
+ float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
14181
14182
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
14182
14183
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
14183
14184
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
@@ -14191,14 +14192,14 @@ static void ggml_compute_forward_ssm_scan_f32(
14191
14192
for (int i0 = 0; i0 < nc; ++i0) {
14192
14193
int i = i0 + i1*nc;
14193
14194
// ssm_state * dA + dB * x
14194
- dest [i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14195
+ pdst [i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14195
14196
}
14196
14197
}
14197
14198
}
14198
14199
14199
14200
// compute state for rest of tokens, previous state comes from dest
14200
14201
for (int i2 = 1; i2 < n_t; ++i2) {
14201
- float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
14202
+ float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
14202
14203
float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
14203
14204
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
14204
14205
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
@@ -14212,7 +14213,7 @@ static void ggml_compute_forward_ssm_scan_f32(
14212
14213
for (int i0 = 0; i0 < nc; ++i0) {
14213
14214
int i = i0 + i1*nc;
14214
14215
// ssm_state * dA + dB * x
14215
- dest [i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14216
+ pdst [i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
14216
14217
}
14217
14218
}
14218
14219
}
0 commit comments