Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ class vk_memory_logger;
class vk_perf_logger;
static void ggml_vk_destroy_buffer(vk_buffer& buf);

static constexpr uint32_t mul_mat_vec_max_cols = 8;
static constexpr uint32_t mul_mat_vec_max_cols = 16;
static constexpr uint32_t p021_max_gqa_ratio = 8;

enum vk_device_architecture {
Expand Down Expand Up @@ -557,6 +557,7 @@ struct vk_device_struct {
vk_pipeline pipeline_out_prod_f16_f32;
vk_pipeline pipeline_out_prod_q4_0;
vk_pipeline pipeline_out_prod_q8_0;
vk_pipeline pipeline_out_prod_tq2_0;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
Expand Down Expand Up @@ -2798,6 +2799,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0], matmul_tq2_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);

CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
Expand Down Expand Up @@ -3063,6 +3065,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_TQ2_0][i], ("mul_mat_vec_tq2_0_f32_f32_"+std::to_string(i+1)).c_str(), mul_mat_vec_tq2_0_f32_f32_len, mul_mat_vec_tq2_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
Expand Down Expand Up @@ -3149,6 +3152,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TQ2_0], "dequant_tq2_0", dequant_tq2_0_len, dequant_tq2_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
Expand Down Expand Up @@ -3429,6 +3433,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_out_prod_f32, "out_prod_f32", out_prod_f32_len, out_prod_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_out_prod_q4_0, "out_prod_q4_0", out_prod_q4_0_len, out_prod_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_out_prod_q8_0, "out_prod_q8_0", out_prod_q8_0_len, out_prod_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_out_prod_tq2_0, "out_prod_tq2_0", out_prod_tq2_0_len, out_prod_tq2_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1, true);

// TODO: should we have device->subgroup_size here or 0?
ggml_vk_create_pipeline(device, device->pipeline_out_prod_f16_f32, "out_prod_f16_f32", out_prod_f16_f32_len, out_prod_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, { 0 }, 1);
Expand Down Expand Up @@ -4668,6 +4673,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
Expand Down Expand Up @@ -4739,6 +4745,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
Expand Down Expand Up @@ -4796,6 +4803,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
Expand Down Expand Up @@ -4875,6 +4883,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
Expand Down Expand Up @@ -4921,6 +4930,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
Expand Down Expand Up @@ -7797,6 +7807,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (src0->type == GGML_TYPE_F32) return ctx->device->pipeline_out_prod_f32;
if (src0->type == GGML_TYPE_Q4_0) return ctx->device->pipeline_out_prod_q4_0;
if (src0->type == GGML_TYPE_Q8_0) return ctx->device->pipeline_out_prod_q8_0;
if (src0->type == GGML_TYPE_TQ2_0) return ctx->device->pipeline_out_prod_tq2_0;
}
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_out_prod_f16_f32;
Expand Down Expand Up @@ -12308,6 +12319,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
Expand Down Expand Up @@ -12502,8 +12514,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}

if (
src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32 ||
src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32
(src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) ||
(src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32)
) {
return true;
}
Expand Down Expand Up @@ -12547,6 +12559,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ2_0:
return true;
default:
return false;
Expand Down
26 changes: 25 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,30 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
}
#endif

#if defined(DATA_A_TQ2_0)
// TQ2_0 ternary dequantization: {0,1,2} -> {-1,0,+1} via (q-1) mapping
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
const uint c0 = (vui >> 0) & 3;
const uint c1 = (vui >> 2) & 3;
const float q0 = float(c0) - 1.0f;
const float q1 = float(c1) - 1.0f;
return vec2(q0, q1);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
const uint c0 = (vui >> 0) & 3;
const uint c1 = (vui >> 2) & 3;
const uint c2 = (vui >> 4) & 3;
const uint c3 = (vui >> 6) & 3;
const float q0 = float(c0) - 1.0f;
const float q1 = float(c1) - 1.0f;
const float q2 = float(c2) - 1.0f;
const float q3 = float(c3) - 1.0f;
return vec4(q0, q1, q2, q3);
}
#endif

#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
Expand Down Expand Up @@ -461,7 +485,7 @@ vec2 get_dm(uint ib, uint a_offset) {
}
#endif

#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_TQ2_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), 0);
}
Expand Down
21 changes: 21 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,25 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
}
#endif

#if defined(DATA_A_TQ2_0)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ2_0 {
block_tq2_0 block;
};

float16_t dequantFuncTQ2_0(const in decodeBufTQ2_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx / 4;
const uint iqs_offset = idx % 4;
const uint vui = uint(bl.block.qs[iqs]);
const uint c = (vui >> (2 * iqs_offset)) & 3;
const float q = float(c) - 1.0f;
float16_t ret = d * float16_t(q);
return ret;
}
#endif

#if defined(DATA_A_MXFP4)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
block_mxfp4 block;
Expand Down Expand Up @@ -715,6 +734,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
#define dequantFuncA dequantFuncIQ4_XS
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_TQ2_0)
#define dequantFuncA dequantFuncTQ2_0
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#endif
36 changes: 36 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_tq2_0.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require

#include "types.comp"

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};

layout (push_constant) uniform parameter {
uint ne;
} p;

layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

void main() {
const uint i = gl_GlobalInvocationID.x * 4;

if (i >= p.ne) {
return;
}

const uint ib = i / QUANT_K; // block index
const uint iqs = (i % QUANT_K) / 4; // quant index within block (byte index)
const uint bit_pos_base = (i % 4) * 2; // bit position within byte

const float d = float(data_a[ib].d);

for (uint j = 0; j < 4 && (i + j) < p.ne; ++j) {
const uint local_iqs = ((i + j) % QUANT_K) / 4; // byte index for this element
const uint bit_pos = ((i + j) % 4) * 2; // bit position for this element
const uint vui = uint(data_a[ib].qs[local_iqs]);
const uint q = (vui >> bit_pos) & 3;
data_b[i + j] = D_TYPE(d * (float(q) - 1.0f));
}
}
66 changes: 66 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq2_0.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require

#include "mul_mat_vec_base.comp"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);

const uint num_blocks_per_row = p.ncols / QUANT_K;

const uint tid = gl_LocalInvocationID.x;

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}

[[unroll]] for (uint i = tid; i < num_blocks_per_row; i += gl_WorkGroupSize.x) {

[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row + n) * num_blocks_per_row;
const float d = float(data_a[ib0 + i].d);

[[unroll]] for (uint j = 0; j < 64; j += 32) {
[[unroll]] for (uint l = 0; l < 4; ++l) {
[[unroll]] for (uint k = 0; k < 32; ++k) {
// Extract quantized value: ((x[i].qs[j + k] >> (l*2)) & 3) - 1
const uint q_byte = uint(data_a[ib0 + i].qs[j + k]);
const uint shift = l * 2;
const uint q = (q_byte >> shift) & 3;
const FLOAT_TYPE dequant_val = FLOAT_TYPE(d * (float(q) - 1.0f)); // CPU kernel: (q-1)*d

// y-data access pattern: y[i].qs[j*4 + l*32 + k]
const uint b_idx = i * QUANT_K + j * 4 + l * 32 + k;
if (b_idx < p.ncols) {
[[unroll]] for (uint jcol = 0; jcol < NUM_COLS; ++jcol) {
temp[jcol][n] += dequant_val * FLOAT_TYPE(data_b[jcol * p.batch_stride_b + b_offset + b_idx]);
}
}
}
}
}
}
}

reduce_result(temp, d_offset, first_row, num_rows, tid);
}

void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
16 changes: 16 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,22 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_TQ2_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

const uint ib = idx / 128; // 2 values per idx (like Q2_K)
const uint iqs = idx % 128; // 0..127
const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // Q2_K indexing pattern
const uint qsshift = ((iqs % 64) / 16) * 2; // Q2_K shift: 0,2,4,6

const float d = float(data_a[ib].d);

const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
const vec2 v = d * (vec2((qs >> qsshift) & 3) - 1.0f); // (q-1)*d

buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q2_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
Expand Down
Loading
Loading