
Commit 1b74b9d

ggml : extend support for n_seq for soft_max and fattn
ggml-ci
1 parent 8c68219 commit 1b74b9d

1 file changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 9 additions & 4 deletions
@@ -4853,7 +4853,8 @@ static void ggml_compute_forward_soft_max_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
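
For context on the hunk above: in ggml, ne[] holds element counts per dimension while nb[] holds byte strides, which is why the new code addresses the mask through nb11/nb12 rather than the previously commented-out ne11. A minimal sketch of the stride convention, assuming a contiguous f32 tensor (this helper is illustrative, not part of ops.cpp):

    #include <stddef.h>
    #include <stdint.h>

    // For a contiguous f32 tensor: nb[0] is the element size in bytes and
    // each higher stride is the previous stride times the previous extent.
    static void contiguous_strides_f32(const int64_t ne[4], size_t nb[4]) {
        nb[0] = sizeof(float);
        for (int i = 1; i < 4; i++) {
            nb[i] = nb[i-1] * (size_t) ne[i-1];
        }
    }

Working in byte strides rather than recomputing offsets from ne00 also keeps the addressing valid when the mask is a non-contiguous view.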
@@ -4878,6 +4879,10 @@ static void ggml_compute_forward_soft_max_f32(
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
+        const int64_t i11 = (i1%ne01);
+        //const int64_t i12 = (i1/ne01)%ne02;
+        const int64_t i13 = (i1/ne01)/ne02;
+
         // ALiBi
         const uint32_t h = (i1/ne01)%ne02; // head
         const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
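
The new i11/i13 indices split the flattened row index i1 into a row within the mask and a sequence; the middle term (the head, kept as the commented-out i12) is deliberately unused for the mask. A standalone sketch of the mapping, with hypothetical sizes (ne01 rows per head, ne02 heads, ne03 sequences):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne01 = 4; // rows per head (hypothetical)
        const int64_t ne02 = 2; // heads         (hypothetical)
        const int64_t ne03 = 3; // sequences     (hypothetical)

        for (int64_t i1 = 0; i1 < ne01*ne02*ne03; i1++) {
            const int64_t i11 = i1 % ne01;          // row within the mask
            const int64_t i12 = (i1 / ne01) % ne02; // head (unused for the mask)
            const int64_t i13 = (i1 / ne01) / ne02; // sequence
            printf("i1=%2lld -> row=%lld head=%lld seq=%lld\n",
                   (long long) i1, (long long) i11, (long long) i12,
                   (long long) i13);
        }
        return 0;
    }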
@@ -4886,8 +4891,8 @@ static void ggml_compute_forward_soft_max_f32(
         float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
 
         // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i13*nb12) : NULL;
+        float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i13*nb12) : NULL;
 
         ggml_vec_cpy_f32 (nc, wp, sp);
         ggml_vec_scale_f32(nc, wp, scale);
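
The rewritten mp_f16/mp_f32 expressions are the core of the fix: the old element-offset arithmetic (i1%ne01)*ne00 assumed one contiguous mask shared by all sequences, while the new byte-stride form picks the mask plane for the row's sequence. A sketch of the addressing (hypothetical helper name, mirroring the expressions above):

    #include <stddef.h>
    #include <stdint.h>

    // The offset uses the row (i11) and sequence (i13) but not the head, so
    // one mask plane per sequence is broadcast across all heads. nb11/nb12
    // are byte strides, hence the char* arithmetic before the final cast.
    static inline const void * soft_max_mask_row(const void * data,
                                                 size_t nb11, size_t nb12,
                                                 int64_t i11, int64_t i13) {
        return (const char *) data + i11*nb11 + i13*nb12;
    }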
@@ -7227,7 +7232,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + iq3*mask->nb[2]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
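
The flash-attention change is the same idea in one line: adding iq3*mask->nb[2] brings the sequence index into the mask offset, and with iq3 == 0 the expression reduces to the old iq1*mask->nb[1], so single-sequence behavior is unchanged. An illustrative helper (not actual ops.cpp code; uint16_t stands in for ggml_fp16_t):

    #include <stddef.h>
    #include <stdint.h>

    typedef uint16_t fp16_bits; // stand-in for ggml_fp16_t in this sketch

    // Byte strides nb1 (row) and nb2 (sequence) select the mask row for
    // query row iq1 of sequence iq3; iq3 == 0 gives the pre-change pointer.
    static inline const fp16_bits * fattn_mask_row(const char * base,
                                                   size_t nb1, size_t nb2,
                                                   int64_t iq1, int64_t iq3) {
        return (const fp16_bits *)(base + iq1*nb1 + iq3*nb2);
    }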
