Skip to content

Commit e6faa45

Browse files
committed
ggml : support broadcast for ggml_soft_max_ext and ggml_flash_attn_ext
ggml-ci
1 parent 4367806 commit e6faa45

File tree

7 files changed

+204
-148
lines changed

7 files changed

+204
-148
lines changed

ggml/include/ggml.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,8 +1433,14 @@ extern "C" {
14331433
struct ggml_context * ctx,
14341434
struct ggml_tensor * a);
14351435

1436+
// a [ne0, ne01, ne02, ne03]
1437+
// mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
1438+
//
1439+
// broadcast:
1440+
// ne02 % ne12 == 0
1441+
// ne03 % ne13 == 0
1442+
//
14361443
// fused soft_max(a*scale + mask*(ALiBi slope))
1437-
// mask is optional
14381444
// max_bias = 0.0f for no ALiBi
14391445
GGML_API struct ggml_tensor * ggml_soft_max_ext(
14401446
struct ggml_context * ctx,
@@ -1868,11 +1874,11 @@ extern "C" {
18681874

18691875
#define GGML_KQ_MASK_PAD 64
18701876

1871-
// q: [n_embd_k, n_batch, n_head, 1]
1872-
// k: [n_embd_k, n_kv, n_head_kv, 1]
1873-
// v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
1874-
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
1875-
// res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
1877+
// q: [n_embd_k, n_batch, n_head, ne3]
1878+
// k: [n_embd_k, n_kv, n_head_kv, ne3]
1879+
// v: [n_embd_v, n_kv, n_head_kv, ne3] !! not transposed !!
1880+
// mask: [n_kv, n_batch_pad, ne32, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
1881+
// res: [n_embd_v, n_head, n_batch, ne3] !! permuted !!
18761882
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
18771883
struct ggml_context * ctx,
18781884
struct ggml_tensor * q,

ggml/src/ggml-cpu/ops.cpp

Lines changed: 55 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4802,14 +4802,17 @@ static void ggml_compute_forward_soft_max_f32(
48024802
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
48034803
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
48044804

4805-
// TODO: handle transposed/permuted matrices
4806-
48074805
const int ith = params->ith;
48084806
const int nth = params->nth;
48094807

48104808
GGML_TENSOR_UNARY_OP_LOCALS
48114809

4812-
//const int64_t ne11 = src1 ? src1->ne[1] : 1;
4810+
const int64_t nb11 = src1 ? src1->nb[1] : 1;
4811+
const int64_t nb12 = src1 ? src1->nb[2] : 1;
4812+
const int64_t nb13 = src1 ? src1->nb[3] : 1;
4813+
4814+
const int64_t ne12 = src1 ? src1->ne[2] : 1;
4815+
const int64_t ne13 = src1 ? src1->ne[3] : 1;
48134816

48144817
// TODO: is this supposed to be ceil instead of floor?
48154818
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -4819,68 +4822,66 @@ static void ggml_compute_forward_soft_max_f32(
48194822
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
48204823
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
48214824

4822-
const int nc = src0->ne[0];
4823-
const int nr = ggml_nrows(src0);
4824-
4825-
// rows per thread
4826-
const int dr = (nr + nth - 1)/nth;
4827-
4828-
// row range for this thread
4829-
const int ir0 = dr*ith;
4830-
const int ir1 = MIN(ir0 + dr, nr);
4831-
4832-
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
4825+
float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
48334826

48344827
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
48354828

4836-
for (int i1 = ir0; i1 < ir1; i1++) {
4837-
// ALiBi
4838-
const uint32_t h = (i1/ne01)%ne02; // head
4839-
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
4840-
4841-
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
4842-
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
4843-
4844-
// broadcast the mask across rows
4845-
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
4846-
float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
4847-
4848-
ggml_vec_cpy_f32 (nc, wp, sp);
4849-
ggml_vec_scale_f32(nc, wp, scale);
4850-
if (mp_f32) {
4851-
if (use_f16) {
4852-
for (int i = 0; i < nc; ++i) {
4853-
wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
4854-
}
4855-
} else {
4856-
for (int i = 0; i < nc; ++i) {
4857-
wp[i] += slope*mp_f32[i];
4829+
for (int64_t i03 = 0; i03 < ne03; i03++) {
4830+
for (int64_t i02 = 0; i02 < ne02; i02++) {
4831+
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
4832+
const int64_t i11 = i01;
4833+
const int64_t i12 = i02%ne12;
4834+
const int64_t i13 = i03%ne13;
4835+
4836+
// ALiBi
4837+
const uint32_t h = i02; // head
4838+
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
4839+
4840+
float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
4841+
float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
4842+
4843+
// broadcast the mask across rows
4844+
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
4845+
float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
4846+
4847+
ggml_vec_cpy_f32 (ne00, wp, sp);
4848+
ggml_vec_scale_f32(ne00, wp, scale);
4849+
if (mp_f32) {
4850+
if (use_f16) {
4851+
for (int i = 0; i < ne00; ++i) {
4852+
wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
4853+
}
4854+
} else {
4855+
for (int i = 0; i < ne00; ++i) {
4856+
wp[i] += slope*mp_f32[i];
4857+
}
4858+
}
48584859
}
4859-
}
4860-
}
48614860

48624861
#ifndef NDEBUG
4863-
for (int i = 0; i < nc; ++i) {
4864-
//printf("p[%d] = %f\n", i, p[i]);
4865-
assert(!isnan(wp[i]));
4866-
}
4862+
for (int i = 0; i < ne00; ++i) {
4863+
//printf("p[%d] = %f\n", i, p[i]);
4864+
assert(!isnan(wp[i]));
4865+
}
48674866
#endif
48684867

4869-
float max = -INFINITY;
4870-
ggml_vec_max_f32(nc, &max, wp);
4868+
float max = -INFINITY;
4869+
ggml_vec_max_f32(ne00, &max, wp);
48714870

4872-
ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
4873-
assert(sum > 0.0);
4871+
ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
4872+
assert(sum > 0.0);
48744873

4875-
sum = 1.0/sum;
4876-
ggml_vec_scale_f32(nc, dp, sum);
4874+
sum = 1.0/sum;
4875+
ggml_vec_scale_f32(ne00, dp, sum);
48774876

48784877
#ifndef NDEBUG
4879-
for (int i = 0; i < nc; ++i) {
4880-
assert(!isnan(dp[i]));
4881-
assert(!isinf(dp[i]));
4882-
}
4878+
for (int i = 0; i < ne00; ++i) {
4879+
assert(!isnan(dp[i]));
4880+
assert(!isinf(dp[i]));
4881+
}
48834882
#endif
4883+
}
4884+
}
48844885
}
48854886
}
48864887

@@ -7151,7 +7152,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
71517152
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
71527153
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
71537154

7154-
ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
7155+
ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
71557156
ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
71567157
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
71577158
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
@@ -7183,7 +7184,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
71837184
memset(VKQ32, 0, DV*sizeof(float));
71847185
}
71857186

7186-
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
7187+
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
71877188

71887189
// k indices
71897190
const int ik3 = iq3 / rk3;

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,9 @@ typedef struct {
229229
uint64_t nb21;
230230
uint64_t nb22;
231231
uint64_t nb23;
232+
int32_t ne32;
232233
uint64_t nb31;
234+
uint64_t nb32;
233235
int32_t ne1;
234236
int32_t ne2;
235237
float scale;
@@ -450,9 +452,21 @@ typedef struct {
450452
} ggml_metal_kargs_sum_rows;
451453

452454
typedef struct {
453-
int64_t ne00;
454-
int64_t ne01;
455-
int64_t ne02;
455+
int32_t ne00;
456+
int32_t ne01;
457+
int32_t ne02;
458+
uint64_t nb01;
459+
uint64_t nb02;
460+
uint64_t nb03;
461+
int32_t ne11;
462+
int32_t ne12;
463+
int32_t ne13;
464+
uint64_t nb11;
465+
uint64_t nb12;
466+
uint64_t nb13;
467+
uint64_t nb1;
468+
uint64_t nb2;
469+
uint64_t nb3;
456470
float scale;
457471
float max_bias;
458472
float m0;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2573,10 +2573,7 @@ static bool ggml_metal_encode_node(
25732573
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
25742574
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
25752575

2576-
const int64_t nrows_x = ggml_nrows(src0);
2577-
const int64_t nrows_y = src0->ne[1];
2578-
2579-
const uint32_t n_head = nrows_x/nrows_y;
2576+
const uint32_t n_head = src0->ne[2];
25802577
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
25812578

25822579
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -2636,6 +2633,18 @@ static bool ggml_metal_encode_node(
26362633
/*.ne00 =*/ ne00,
26372634
/*.ne01 =*/ ne01,
26382635
/*.ne02 =*/ ne02,
2636+
/*.nb01 =*/ nb01,
2637+
/*.nb02 =*/ nb02,
2638+
/*.nb03 =*/ nb03,
2639+
/*.ne11 =*/ ne11,
2640+
/*.ne12 =*/ ne12,
2641+
/*.ne13 =*/ ne13,
2642+
/*.nb11 =*/ nb11,
2643+
/*.nb12 =*/ nb12,
2644+
/*.nb13 =*/ nb13,
2645+
/*.nb1 =*/ nb1,
2646+
/*.nb2 =*/ nb2,
2647+
/*.nb3 =*/ nb3,
26392648
/*.scale =*/ scale,
26402649
/*.max_bias =*/ max_bias,
26412650
/*.m0 =*/ m0,
@@ -2655,7 +2664,7 @@ static bool ggml_metal_encode_node(
26552664

26562665
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
26572666

2658-
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2667+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
26592668
} break;
26602669
case GGML_OP_DIAG_MASK_INF:
26612670
{
@@ -4908,7 +4917,9 @@ static bool ggml_metal_encode_node(
49084917
/*.nb21 =*/ nb21,
49094918
/*.nb22 =*/ nb22,
49104919
/*.nb23 =*/ nb23,
4920+
/*.ne32 =*/ ne32,
49114921
/*.nb31 =*/ nb31,
4922+
/*.nb32 =*/ nb32,
49124923
/*.ne1 =*/ ne1,
49134924
/*.ne2 =*/ ne2,
49144925
/*.scale =*/ scale,

0 commit comments

Comments
 (0)