@@ -4802,14 +4802,17 @@ static void ggml_compute_forward_soft_max_f32(
4802
4802
memcpy (&scale, (float *) dst->op_params + 0 , sizeof (float ));
4803
4803
memcpy (&max_bias, (float *) dst->op_params + 1 , sizeof (float ));
4804
4804
4805
- // TODO: handle transposed/permuted matrices
4806
-
4807
4805
const int ith = params->ith ;
4808
4806
const int nth = params->nth ;
4809
4807
4810
4808
GGML_TENSOR_UNARY_OP_LOCALS
4811
4809
4812
- // const int64_t ne11 = src1 ? src1->ne[1] : 1;
4810
+ const int64_t nb11 = src1 ? src1->nb [1 ] : 1 ;
4811
+ const int64_t nb12 = src1 ? src1->nb [2 ] : 1 ;
4812
+ const int64_t nb13 = src1 ? src1->nb [3 ] : 1 ;
4813
+
4814
+ const int64_t ne12 = src1 ? src1->ne [2 ] : 1 ;
4815
+ const int64_t ne13 = src1 ? src1->ne [3 ] : 1 ;
4813
4816
4814
4817
// TODO: is this supposed to be ceil instead of floor?
4815
4818
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -4819,68 +4822,66 @@ static void ggml_compute_forward_soft_max_f32(
4819
4822
const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
4820
4823
const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
4821
4824
4822
- const int nc = src0->ne [0 ];
4823
- const int nr = ggml_nrows (src0);
4824
-
4825
- // rows per thread
4826
- const int dr = (nr + nth - 1 )/nth;
4827
-
4828
- // row range for this thread
4829
- const int ir0 = dr*ith;
4830
- const int ir1 = MIN (ir0 + dr, nr);
4831
-
4832
- float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
4825
+ float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
4833
4826
4834
4827
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
4835
4828
4836
- for (int i1 = ir0; i1 < ir1; i1++) {
4837
- // ALiBi
4838
- const uint32_t h = (i1/ne01)%ne02; // head
4839
- const float slope = (max_bias > 0 .0f ) ? h < n_head_log2 ? powf (m0, h + 1 ) : powf (m1, 2 *(h - n_head_log2) + 1 ) : 1 .0f ;
4840
-
4841
- float * sp = (float *)((char *) src0->data + i1*src0->nb [1 ]);
4842
- float * dp = (float *)((char *) dst->data + i1*dst->nb [1 ]);
4843
-
4844
- // broadcast the mask across rows
4845
- ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data ) + (i1%ne01)*ne00 : NULL ;
4846
- float * mp_f32 = src1 ? (float *)((char *) src1->data ) + (i1%ne01)*ne00 : NULL ;
4847
-
4848
- ggml_vec_cpy_f32 (nc, wp, sp);
4849
- ggml_vec_scale_f32 (nc, wp, scale);
4850
- if (mp_f32) {
4851
- if (use_f16) {
4852
- for (int i = 0 ; i < nc; ++i) {
4853
- wp[i] += slope*GGML_CPU_FP16_TO_FP32 (mp_f16[i]);
4854
- }
4855
- } else {
4856
- for (int i = 0 ; i < nc; ++i) {
4857
- wp[i] += slope*mp_f32[i];
4829
+ for (int64_t i03 = 0 ; i03 < ne03; i03++) {
4830
+ for (int64_t i02 = 0 ; i02 < ne02; i02++) {
4831
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
4832
+ const int64_t i11 = i01;
4833
+ const int64_t i12 = i02%ne12;
4834
+ const int64_t i13 = i03%ne13;
4835
+
4836
+ // ALiBi
4837
+ const uint32_t h = i02; // head
4838
+ const float slope = (max_bias > 0 .0f ) ? h < n_head_log2 ? powf (m0, h + 1 ) : powf (m1, 2 *(h - n_head_log2) + 1 ) : 1 .0f ;
4839
+
4840
+ float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
4841
+ float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
4842
+
4843
+ // broadcast the mask across rows
4844
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL ;
4845
+ float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL ;
4846
+
4847
+ ggml_vec_cpy_f32 (ne00, wp, sp);
4848
+ ggml_vec_scale_f32 (ne00, wp, scale);
4849
+ if (mp_f32) {
4850
+ if (use_f16) {
4851
+ for (int i = 0 ; i < ne00; ++i) {
4852
+ wp[i] += slope*GGML_CPU_FP16_TO_FP32 (mp_f16[i]);
4853
+ }
4854
+ } else {
4855
+ for (int i = 0 ; i < ne00; ++i) {
4856
+ wp[i] += slope*mp_f32[i];
4857
+ }
4858
+ }
4858
4859
}
4859
- }
4860
- }
4861
4860
4862
4861
#ifndef NDEBUG
4863
- for (int i = 0 ; i < nc ; ++i) {
4864
- // printf("p[%d] = %f\n", i, p[i]);
4865
- assert (!isnan (wp[i]));
4866
- }
4862
+ for (int i = 0 ; i < ne00 ; ++i) {
4863
+ // printf("p[%d] = %f\n", i, p[i]);
4864
+ assert (!isnan (wp[i]));
4865
+ }
4867
4866
#endif
4868
4867
4869
- float max = -INFINITY;
4870
- ggml_vec_max_f32 (nc , &max, wp);
4868
+ float max = -INFINITY;
4869
+ ggml_vec_max_f32 (ne00 , &max, wp);
4871
4870
4872
- ggml_float sum = ggml_vec_soft_max_f32 (nc , dp, wp, max);
4873
- assert (sum > 0.0 );
4871
+ ggml_float sum = ggml_vec_soft_max_f32 (ne00 , dp, wp, max);
4872
+ assert (sum > 0.0 );
4874
4873
4875
- sum = 1.0 /sum;
4876
- ggml_vec_scale_f32 (nc , dp, sum);
4874
+ sum = 1.0 /sum;
4875
+ ggml_vec_scale_f32 (ne00 , dp, sum);
4877
4876
4878
4877
#ifndef NDEBUG
4879
- for (int i = 0 ; i < nc ; ++i) {
4880
- assert (!isnan (dp[i]));
4881
- assert (!isinf (dp[i]));
4882
- }
4878
+ for (int i = 0 ; i < ne00 ; ++i) {
4879
+ assert (!isnan (dp[i]));
4880
+ assert (!isinf (dp[i]));
4881
+ }
4883
4882
#endif
4883
+ }
4884
+ }
4884
4885
}
4885
4886
}
4886
4887
@@ -7151,7 +7152,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7151
7152
const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
7152
7153
const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
7153
7154
7154
- ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu (k->type )->vec_dot_type ;
7155
+ ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu (k->type )->vec_dot_type ;
7155
7156
ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu (k_vec_dot_type)->from_float ;
7156
7157
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu (k->type )->vec_dot ;
7157
7158
ggml_to_float_t const v_to_float = ggml_get_type_traits (v->type )->to_float ;
@@ -7183,7 +7184,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
7183
7184
memset (VKQ32, 0 , DV*sizeof (float ));
7184
7185
}
7185
7186
7186
- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb [1 ]) : NULL ;
7187
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb [1 ] + (iq3%mask-> ne [ 2 ])*mask-> nb [ 2 ] ) : NULL ;
7187
7188
7188
7189
// k indices
7189
7190
const int ik3 = iq3 / rk3;
0 commit comments