Commit cf38145

vulkan: add q8_1_x4 type with 128-bit alignment, use in mul_mat_vecq shader

1 parent: 39ac703

File tree

5 files changed: +108 -39 lines
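
At a glance: a plain q8_1 block is 36 bytes (32 int8 quants plus an f16vec2 holding d and d*sum), so consecutive blocks never stay on 16-byte boundaries. Packing four of them into one block_q8_1_x4 gives a 144-byte struct, and 144 = 9 * 16, so every packed block starts 128-bit aligned; the mul_mat_vecq shader can then fetch a sub-block's eight quant words as two ivec4 loads instead of eight scalar int32 loads. A quantize_q8_1_x4 pipeline variant writes this layout, and the preallocated Y buffer is rounded up to a 128-byte multiple to match.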

ggml/src/ggml-vulkan/ggml-vulkan.cpp
Lines changed: 23 additions & 21 deletions

@@ -434,6 +434,7 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_matmul_split_k_reduce;
     vk_pipeline pipeline_quantize_q8_1;
+    vk_pipeline pipeline_quantize_q8_1_x4;
 
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
@@ -2934,8 +2935,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     if (device->subgroup_clustered && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_subgroup_len, quantize_q8_1_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
     } else {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
     }
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -5440,20 +5443,20 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
 }
 
-static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
+static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type, bool use_x4_blocks) {
     switch(type) {
         case GGML_TYPE_Q8_1:
-            return ctx->device->pipeline_quantize_q8_1;
+            return use_x4_blocks ? ctx->device->pipeline_quantize_q8_1_x4 : ctx->device->pipeline_quantize_q8_1;
         default:
             std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
             GGML_ABORT("fatal error");
     }
 }
 
-static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
+static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne, bool use_x4_blocks = false) {
     VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
 
-    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+    vk_pipeline pipeline = use_x4_blocks ? ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true) : ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false);
 
     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
@@ -5573,7 +5576,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 
     if (quantize_y) {
-        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false);
     }
 
     if (dryrun) {
@@ -5741,16 +5744,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
 
     const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
-
-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne11 * ne01;
-
-    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
-    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
-    const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
-    const uint64_t d_sz = sizeof(float) * d_ne;
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
 
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
@@ -5763,8 +5757,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
 
-    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
-
     // Check for mmq first
     vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11) : nullptr;
     vk_pipeline to_q8_1 = nullptr;
@@ -5776,7 +5768,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     }
 
     if (quantize_y) {
-        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
     }
 
     const bool qx_needs_dequant = x_non_contig;
@@ -5789,6 +5781,16 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
     GGML_ASSERT(dmmv != nullptr);
 
+    const uint64_t x_ne = ne01 * ne00;
+    const uint64_t y_ne = ne11 * ne10;
+    const uint64_t d_ne = ne11 * ne01;
+
+    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
+    const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
+    const uint64_t d_sz = sizeof(float) * d_ne;
+
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
         const uint64_t y_sz_upd = y_sz * ne12 * ne13;
@@ -5801,7 +5803,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
             ctx->prealloc_size_x = x_sz_upd;
         }
         if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
-            ctx->prealloc_size_y = y_sz_upd;
+            ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
         }
 
         // Request descriptor sets
@@ -5846,7 +5848,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         d_Y = ctx->prealloc_y;
     } else if (quantize_y) {
         d_Y = ctx->prealloc_y;
-        GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
+        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 128) * 128);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -5862,7 +5864,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
     }
     if (quantize_y) {
-        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
+        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
     }
 
     // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
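
For context, a minimal host-side sketch of the new Y-buffer sizing used in the dryrun path above. The helper names here are hypothetical; the constants follow GGML's q8_1 definition (a 32-element block storing 32 int8 quants plus two fp16 values), and the reason given for the 128-byte rounding is my reading of the change, not stated in the commit:

#include <cassert>
#include <cstdint>

// Assumed constants matching GGML's q8_1 definition.
constexpr uint64_t QK8_1         = 32; // elements per q8_1 block
constexpr uint64_t Q8_1_BLOCK_SZ = 36; // 32 int8 quants + 2 * fp16 (d, d*sum)

static uint64_t ceil_div(uint64_t a, uint64_t b) { return (a + b - 1) / b; }

// Size of the quantized Y buffer for y_ne float elements per batch, padded
// the same way prealloc_size_y is padded above: up to a 128-byte multiple,
// presumably so the x4 quantize shader can write a whole trailing block.
static uint64_t quantized_y_size(uint64_t y_ne, uint64_t batches) {
    const uint64_t y_sz = y_ne * Q8_1_BLOCK_SZ / QK8_1; // bytes per batch
    return ceil_div(y_sz * batches, 128) * 128;
}

int main() {
    // A 4096-element activation row: 128 blocks * 36 bytes = 4608 bytes,
    // which is already a 128-byte multiple (4608 = 36 * 128).
    assert(quantized_y_size(4096, 1) == 4608);
    // 40 elements -> 45 bytes, padded up to 128.
    assert(quantized_y_size(40, 1) == 128);
    return 0;
}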

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
Lines changed: 51 additions & 18 deletions

@@ -13,7 +13,7 @@
 #endif
 
 #define MMQ
-#define B_TYPE block_q8_1_packed32
+#define B_TYPE block_q8_1_x4_packed128
 
 #include "mul_mat_vec_base.comp"
 
@@ -80,7 +80,7 @@ void reduce_result_grouped(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const i
 }
 #endif
 
-int32_t cache_b_qs[8];
+ivec4 cache_b_qs[2];
 vec2 cache_b_ds;
 
 void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid_in_group, const uint i) {
@@ -89,10 +89,11 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
 
     // Preload data_b block
    const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
-    cache_b_ds = vec2(data_b[b_block_idx].ds);
-    [[unroll]] for (uint k = 0; k < 8; k++) {
-        cache_b_qs[k] = data_b[b_block_idx].qs[k];
-    }
+    const uint b_block_idx_outer = b_block_idx / 4;
+    const uint b_block_idx_inner = b_block_idx % 4;
+    cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
+    cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 2];
+    cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 2 + 1];
 
     uint ibi = first_row*p.ncols;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
@@ -101,19 +102,51 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
 
         int32_t q_sum = 0;
 #if QUANT_R == 2
-        [[unroll]] for (uint k = 0; k < 4; k++) {
-            const i32vec2 data_a_qs = repack(a_block_idx, k);
-            q_sum += dotPacked4x8EXT(data_a_qs.x,
-                                     cache_b_qs[k]);
-            q_sum += dotPacked4x8EXT(data_a_qs.y,
-                                     cache_b_qs[k + 4]);
-        }
+        i32vec2 data_a_qs = repack(a_block_idx, 0);
+        q_sum += dotPacked4x8EXT(data_a_qs.x,
+                                 cache_b_qs[0].x);
+        q_sum += dotPacked4x8EXT(data_a_qs.y,
+                                 cache_b_qs[1].x);
+        data_a_qs = repack(a_block_idx, 1);
+        q_sum += dotPacked4x8EXT(data_a_qs.x,
+                                 cache_b_qs[0].y);
+        q_sum += dotPacked4x8EXT(data_a_qs.y,
+                                 cache_b_qs[1].y);
+        data_a_qs = repack(a_block_idx, 2);
+        q_sum += dotPacked4x8EXT(data_a_qs.x,
+                                 cache_b_qs[0].z);
+        q_sum += dotPacked4x8EXT(data_a_qs.y,
+                                 cache_b_qs[1].z);
+        data_a_qs = repack(a_block_idx, 3);
+        q_sum += dotPacked4x8EXT(data_a_qs.x,
+                                 cache_b_qs[0].w);
+        q_sum += dotPacked4x8EXT(data_a_qs.y,
+                                 cache_b_qs[1].w);
 #else
-        [[unroll]] for (uint k = 0; k < 8; k++) {
-            const int32_t data_a_qs = repack(a_block_idx, k);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[k]);
-        }
+        int32_t data_a_qs = repack(a_block_idx, 0);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[0].x);
+        data_a_qs = repack(a_block_idx, 1);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[0].y);
+        data_a_qs = repack(a_block_idx, 2);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[0].z);
+        data_a_qs = repack(a_block_idx, 3);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[0].w);
+        data_a_qs = repack(a_block_idx, 4);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[1].x);
+        data_a_qs = repack(a_block_idx, 5);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[1].y);
+        data_a_qs = repack(a_block_idx, 6);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[1].z);
+        data_a_qs = repack(a_block_idx, 7);
+        q_sum += dotPacked4x8EXT(data_a_qs,
+                                 cache_b_qs[1].w);
 #endif
 
 #if QUANT_AUXF == 1
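
The unrolled code above leans on the packed integer dot product instruction, which treats each 32-bit word as four signed 8-bit lanes. A scalar C++ reference of that building block, and of the 32-quant accumulation the QUANT_R == 1 path performs, is sketched below; the helper names are hypothetical, and the real shader gets its A-side words from repack():

#include <cstdint>

// Reference semantics of dotPacked4x8EXT: interpret each 32-bit word as
// four packed signed 8-bit values and return their dot product as int32.
static int32_t dot_packed_4x8(int32_t a, int32_t b) {
    int32_t sum = 0;
    for (int i = 0; i < 4; ++i) {
        const int8_t ai = static_cast<int8_t>((static_cast<uint32_t>(a) >> (8 * i)) & 0xff);
        const int8_t bi = static_cast<int8_t>((static_cast<uint32_t>(b) >> (8 * i)) & 0xff);
        sum += int32_t(ai) * int32_t(bi);
    }
    return sum;
}

// One q8_1 block holds 32 quants = 8 packed int32 words. With the x4 layout
// those 8 words arrive as two 4-word (ivec4) loads, modeled here as arrays:
// a_qs[0..3] pairs with the first load, a_qs[4..7] with the second, exactly
// as in the QUANT_R == 1 branch above.
static int32_t block_dot(const int32_t a_qs[8],
                         const int32_t b_qs0[4], const int32_t b_qs1[4]) {
    int32_t q_sum = 0;
    for (int k = 0; k < 4; ++k) q_sum += dot_packed_4x8(a_qs[k],     b_qs0[k]);
    for (int k = 0; k < 4; ++k) q_sum += dot_packed_4x8(a_qs[k + 4], b_qs1[k]);
    return q_sum;
}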

ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
Lines changed: 19 additions & 0 deletions

@@ -23,7 +23,11 @@ layout(constant_id = 0) const uint GROUP_SIZE = 32;
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {vec4 data_a[];};
+#ifndef QBLOCK_X4
 layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
+#else
+layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
+#endif
 
 #ifndef USE_SUBGROUPS
 shared float shmem[GROUP_SIZE];
@@ -45,6 +49,11 @@ void quantize() {
         return;
     }
 
+#ifdef QBLOCK_X4
+    const uint ibx4_outer = ib / 4;
+    const uint ibx4_inner = ib % 4;
+#endif
+
     const uint a_idx = ib * 8 + iqs;
 
     vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f);
@@ -70,7 +79,13 @@ void quantize() {
     const float d = amax / 127.0;
     const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
     vals = round(vals * d_inv);
+
+#ifndef QBLOCK_X4
     data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
+#else
+    data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(round(vals)));
+#endif
+
     barrier();
 
     // Calculate the sum for each block
@@ -92,7 +107,11 @@ void quantize() {
     const float sum = shmem[tid];
 #endif
 
+#ifndef QBLOCK_X4
     data_b[ib].ds = f16vec2(vec2(d, sum * d));
+#else
+    data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d));
+#endif
 }
 }
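
The only substantive change to the quantize shader is address translation: logical q8_1 block ib becomes (outer, inner) = (ib / 4, ib % 4), and the offsets into the fused struct follow from the x4 layout (four ds entries up front, then 4 * 8 qs words). A small C++ model of that mapping, with a host-side struct mirroring types.comp (field types are assumptions; fp16 values are kept as raw bit patterns):

#include <cstdint>

// Host-side mirror of the shader's block_q8_1_x4 (see types.comp below):
// the four sub-blocks' scales come first, then the four sub-blocks' quants.
struct block_q8_1_x4_host {
    uint16_t ds[4][2]; // (d, d*sum) per inner block, as fp16 bit patterns
    int32_t  qs[32];   // 4 inner blocks * 8 packed int32 words
};

// Where logical q8_1 block `ib`, packed word `iqs` (0..7), lands in the
// x4 layout; matches data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] above.
static void x4_address(uint32_t ib, uint32_t iqs,
                       uint32_t& outer, uint32_t& inner, uint32_t& qs_index) {
    outer    = ib / 4;          // which block_q8_1_x4 struct
    inner    = ib % 4;          // which sub-block within it
    qs_index = inner * 8 + iqs; // index into the fused qs array
}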

ggml/src/ggml-vulkan/vulkan-shaders/types.comp
Lines changed: 12 additions & 0 deletions

@@ -207,6 +207,18 @@ struct block_q8_1_packed32
     int32_t qs[8];
 };
 
+// 4 blocks in one to allow 16-byte/128-bit alignment and loads
+struct block_q8_1_x4
+{
+    f16vec2 ds[4];
+    int32_t qs[32];
+};
+struct block_q8_1_x4_packed128
+{
+    f16vec2 ds[4];
+    ivec4 qs[8];
+};
+
 // K-quants
 #define QUANT_K_Q2_K 256
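
The two structs describe the same 144 bytes; block_q8_1_x4_packed128 merely retypes qs as eight ivec4 so the quant words can be fetched 128 bits at a time. Since 144 = 9 * 16, consecutive x4 blocks all keep 16-byte alignment, which a lone 36-byte q8_1 block cannot. A host-side sanity check of those sizes, as a sketch assuming tight std430-style packing:

#include <cstdint>

struct f16vec2_t { uint16_t v[2]; }; // 4 bytes, stands in for GLSL f16vec2

struct block_q8_1_x4_host {
    f16vec2_t ds[4]; // 16 bytes: four (d, d*sum) pairs
    int32_t   qs[32]; // 128 bytes: four sub-blocks * 8 packed words
};

static_assert(sizeof(block_q8_1_x4_host) == 144,
              "4 x 36-byte q8_1 blocks fused into one struct");
static_assert(sizeof(block_q8_1_x4_host) % 16 == 0,
              "144 = 9 * 16: every x4 block starts on a 128-bit boundary");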

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Lines changed: 3 additions & 0 deletions

@@ -581,6 +581,9 @@ void process_shaders() {
     string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
     string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});
 
+    string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}});
+    string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}});
+
     string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
     string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
