#version 450

#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_integer_dot_product : require

#define MMQ
#define B_TYPE block_q8_1_x4

#include "mul_mat_vec_base.comp"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
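// local_size_x is a specialization constant (constant_id = 0), so the host can
// pick the workgroup width at pipeline-creation time; mul_mat_vec_base.comp
// presumably exposes the same constant as BLOCK_SIZE, which the indexing below
// relies on.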

#define K_PER_ITER 8

#include "mul_mmq_funcs.comp"
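// K_PER_ITER = 8: each invocation consumes 8 int8 activation values (two packed
// int32s) per call to iter(). mul_mmq_funcs.comp is expected to supply the
// quant-specific helpers used below: repack() to turn the A-side weights into
// packed int32s matching that layout, and mul_q8_1() to apply the q8_1
// scale/sum correction.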

uint a_offset, b_offset, d_offset;
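
// Per-invocation cache of the data_b slice being processed: two packed int32s
// holding 8 q8_1 values, plus the owning block's ds pair (scale and
// precomputed sum term).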
int32_t cache_b_qs[2];
vec2 cache_b_ds;

void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;

        // Preload data_b block
        const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
        const uint b_qs_idx = tid % 4;
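        // data_b is packed as block_q8_1_x4, i.e. four q8_1 blocks per array
        // element, so the linear block index splits into an outer (x4) index and
        // an inner 0..3 index. Each q8_1 block holds 32 values = 8 packed int32s;
        // with K_PER_ITER = 8 one invocation reads two of them, so tid % 4 picks
        // which pair this thread covers (four consecutive threads share a block).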
        const uint b_block_idx_outer = b_block_idx / 4;
        const uint b_block_idx_inner = b_block_idx % 4;
        cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
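
        // The two load patterns below mirror the A-side packing: with QUANT_R == 2
        // each weight byte presumably packs two values half a block apart, so the
        // matching activation words are 4 int32s (16 values) apart; with QUANT_R == 1
        // the 8 activation values this thread needs are two consecutive int32s.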
#if QUANT_R == 2
        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
#else
        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
#endif

        uint ibi = first_row*p.ncols;
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
            ibi += p.ncols;

            int32_t q_sum = 0;
#if QUANT_R == 2
            const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
            q_sum += dotPacked4x8EXT(data_a_qs.x,
                                     cache_b_qs[0]);
            q_sum += dotPacked4x8EXT(data_a_qs.y,
                                     cache_b_qs[1]);
#else
            int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
            q_sum += dotPacked4x8EXT(data_a_qs,
                                     cache_b_qs[0]);
            data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
            q_sum += dotPacked4x8EXT(data_a_qs,
                                     cache_b_qs[1]);
#endif

#if QUANT_AUXF == 1
            temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
#else
            temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
#endif
        }
    }
}
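
// Per (column j, row n) pair, each call to iter() effectively accumulates
//   temp[j][n] += d_a * d_b * q_sum                (QUANT_AUXF == 1)
//   temp[j][n] += d_a * d_b * q_sum + m_a * s_b    (otherwise)
// where q_sum is the signed 8-bit dot product of the 8 values handled by this
// thread and (d_b, s_b) = cache_b_ds. The exact form is left to mul_q8_1() in
// mul_mmq_funcs.comp; the trailing argument (4) presumably compensates for each
// call covering only a quarter of the 32-value q8_1 block.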

void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    const uint tid = gl_LocalInvocationID.x;

    get_offsets(a_offset, b_offset, d_offset);
    a_offset /= QUANT_K;
    b_offset /= QUANT_K_Q8_1;

    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            temp[j][n] = FLOAT_TYPE(0.0f);
        }
    }

    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
        num_iters++;
    }
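    // Peel the column loop by hand: chunks of four iterations, then two, then
    // one at a time, so the [[unroll]] inner loops have a fixed trip count
    // while the scalar tail handles whatever remains.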
    int unroll_count = 4;
    uint unrolled_iters = num_iters & ~(unroll_count - 1);

    uint i = 0;
    while (i < unrolled_iters) {
        // Manually partially unroll the loop
        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
            iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
            i++;
        }
    }

    unroll_count = 2;
    unrolled_iters = num_iters & ~(unroll_count - 1);
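
    // Only relevant when K_PER_ITER == 2 (it is 8 in this shader); the guard
    // appears to be carried over from the non-MMQ mul_mat_vec variant, where an
    // odd column count must be left to the non-unrolled tail loop.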
#if K_PER_ITER == 2
    if ((p.ncols & 1) != 0 &&
        unrolled_iters == num_iters &&
        unrolled_iters > 0) {
        unrolled_iters -= unroll_count;
    }
#endif

    while (i < unrolled_iters) {
        // Manually partially unroll the loop
        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
            iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
            i++;
        }
    }
    while (i < num_iters) {
        iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
        i++;
    }

    reduce_result(temp, d_offset, first_row, num_rows, tid);
}
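
// One workgroup produces NUM_ROWS consecutive output rows; the row-tile index
// is spread over gl_WorkGroupID.x and gl_WorkGroupID.z, presumably so large
// dispatches can stay under the per-dimension workgroup-count limit.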
void main() {
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

    // do NUM_ROWS at a time, unless there aren't enough remaining rows
    if (first_row + NUM_ROWS <= p.stride_d) {
        compute_outputs(first_row, NUM_ROWS);
    } else {
        if (first_row >= p.stride_d) {
            return;
        }
        compute_outputs(first_row, p.stride_d - first_row);
    }
}
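
// Symbols used here but defined elsewhere (assumed to come from the includes and
// the shader generator's per-quant defines): QUANT_K, QUANT_R, QUANT_AUXF and the
// data_a layout of the weight type; NUM_COLS, NUM_ROWS, FLOAT_TYPE, BLOCK_SIZE,
// data_b, p, get_offsets() and reduce_result() from mul_mat_vec_base.comp; and
// repack(), mul_q8_1(), get_d(), get_dm() from mul_mmq_funcs.comp.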