@@ -4853,7 +4853,8 @@ static void ggml_compute_forward_soft_max_f32(
4853
4853
4854
4854
GGML_TENSOR_UNARY_OP_LOCALS
4855
4855
4856
- // const int64_t ne11 = src1 ? src1->ne[1] : 1;
4856
+ const int64_t nb11 = src1 ? src1->nb [1 ] : 1 ;
4857
+ const int64_t nb12 = src1 ? src1->nb [2 ] : 1 ;
4857
4858
4858
4859
// TODO: is this supposed to be ceil instead of floor?
4859
4860
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -4878,6 +4879,10 @@ static void ggml_compute_forward_soft_max_f32(
4878
4879
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
4879
4880
4880
4881
for (int i1 = ir0; i1 < ir1; i1++) {
4882
+ const int64_t i11 = (i1%ne01);
4883
+ // const int64_t i12 = (i1/ne01)%ne02;
4884
+ const int64_t i13 = (i1/ne01)/ne02;
4885
+
4881
4886
// ALiBi
4882
4887
const uint32_t h = (i1/ne01)%ne02; // head
4883
4888
const float slope = (max_bias > 0 .0f ) ? h < n_head_log2 ? powf (m0, h + 1 ) : powf (m1, 2 *(h - n_head_log2) + 1 ) : 1 .0f ;
@@ -4886,8 +4891,8 @@ static void ggml_compute_forward_soft_max_f32(
4886
4891
float * dp = (float *)((char *) dst->data + i1*dst->nb [1 ]);
4887
4892
4888
4893
// broadcast the mask across rows
4889
- ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data ) + (i1%ne01)*ne00 : NULL ;
4890
- float * mp_f32 = src1 ? (float *)((char *) src1->data ) + (i1%ne01)*ne00 : NULL ;
4894
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i13*nb12) : NULL ;
4895
+ float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i13*nb12) : NULL ;
4891
4896
4892
4897
ggml_vec_cpy_f32 (nc, wp, sp);
4893
4898
ggml_vec_scale_f32 (nc, wp, scale);
@@ -7227,7 +7232,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7227
7232
memset (VKQ32, 0 , DV*sizeof (float ));
7228
7233
}
7229
7234
7230
- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb [1 ]) : NULL ;
7235
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb [1 ] + iq3*mask-> nb [ 2 ] ) : NULL ;
7231
7236
7232
7237
// k indices
7233
7238
const int ik3 = iq3 / rk3;
0 commit comments