Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
339 changes: 250 additions & 89 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

30 changes: 29 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require
#if USE_SUBGROUP_ADD

#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#endif
Expand All @@ -12,10 +13,19 @@

#include "types.comp"

#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#endif

layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
#ifdef B_TYPE_VEC2
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
#endif
#ifdef B_TYPE_VEC4
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
#endif

layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
Expand Down Expand Up @@ -92,6 +102,23 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
layout (constant_id = 2) const uint NUM_COLS = 1;

#ifdef USE_SUBGROUP_ADD_NO_SHMEM
void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
temp[j][n] = subgroupAdd(temp[j][n]);
}
}

if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
}
}
}
}
#else
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];

void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
Expand Down Expand Up @@ -152,3 +179,4 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
}
#endif
}
#endif
140 changes: 140 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_integer_dot_product : require

#define MMQ
#define B_TYPE block_q8_1_x4

#include "mul_mat_vec_base.comp"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

#define K_PER_ITER 8

#include "mul_mmq_funcs.comp"

uint a_offset, b_offset, d_offset;

int32_t cache_b_qs[2];
vec2 cache_b_ds;

void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;

// Preload data_b block
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
const uint b_qs_idx = tid % 4;
const uint b_block_idx_outer = b_block_idx / 4;
const uint b_block_idx_inner = b_block_idx % 4;
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);

#if QUANT_R == 2
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
#else
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
#endif

uint ibi = first_row*p.ncols;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
ibi += p.ncols;

int32_t q_sum = 0;
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
q_sum += dotPacked4x8EXT(data_a_qs.x,
cache_b_qs[0]);
q_sum += dotPacked4x8EXT(data_a_qs.y,
cache_b_qs[1]);
#else
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[0]);
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[1]);
#endif

#if QUANT_AUXF == 1
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
#else
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
#endif
}
}
}

void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;

get_offsets(a_offset, b_offset, d_offset);
a_offset /= QUANT_K;
b_offset /= QUANT_K_Q8_1;

FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
temp[j][n] = FLOAT_TYPE(0.0f);
}
}

uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
num_iters++;
}
int unroll_count = 4;
uint unrolled_iters = num_iters & ~(unroll_count - 1);

uint i = 0;
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
i++;
}
}

unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);

#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif

while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
i++;
}
}
while (i < num_iters) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
i++;
}

reduce_result(temp, d_offset, first_row, num_rows, tid);
}

void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
17 changes: 12 additions & 5 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

#ifdef MUL_MAT_ID
Expand Down Expand Up @@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
#endif

#define LOAD_VEC_A (4 * QUANT_R)
#define LOAD_VEC_B 4
#define LOAD_VEC_B 16

#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
Expand Down Expand Up @@ -270,15 +270,22 @@ void main() {
const uint iqs = idx & 0x7;
#else
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;

const uint iqs = loadr_b;
#endif

const uint buf_ib = loadc_b + l;

if (iqs == 0) {
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
}

barrier();
Expand Down Expand Up @@ -349,7 +356,7 @@ void main() {
cache_b_qs[cc * (BK / 4) + idx_k]);
}

sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
}
}
}
Expand Down
18 changes: 9 additions & 9 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ i32vec2 repack(uint ib, uint iqs) {
(vui >> 4) & 0x0F0F0F0F);
}

ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y));
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#endif

Expand All @@ -29,8 +29,8 @@ i32vec2 repack(uint ib, uint iqs) {
(vui >> 4) & 0x0F0F0F0F);
}

ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif

Expand All @@ -50,8 +50,8 @@ i32vec2 repack(uint ib, uint iqs) {
return i32vec2(v0, v1);
}

ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y));
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#endif

Expand All @@ -69,8 +69,8 @@ i32vec2 repack(uint ib, uint iqs) {
return i32vec2(v0, v1);
}

ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif

Expand All @@ -81,7 +81,7 @@ int32_t repack(uint ib, uint iqs) {
data_a[ib].qs[iqs * 2 + 1]));
}

ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
#endif
Expand Down
Loading
Loading