Skip to content

get_rows & dequantize function implementation for repacked weights of type q4_0 (q4_0x8) #3223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 154 additions & 1 deletion ggml/src/ggml-cpu/repack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,11 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR

size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);

return true;
}
case GGML_OP_GET_ROWS:
{
size = 0; // GET_ROWS (standard and repacked) doesn't need a work buffer
return true;
}
default:
Expand All @@ -1197,6 +1202,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
case GGML_OP_MUL_MAT_ID:
forward_mul_mat_id(params, op);
return true;
case GGML_OP_GET_ROWS:
forward_get_rows(params, op);
return true;
default:
// GGML_ABORT("fatal error");
break;
Expand Down Expand Up @@ -1405,6 +1413,140 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
#undef MMID_MATRIX_ROW
}

/**
 * Dispatches a GET_ROWS operation on the type of its first source tensor (dst->src[0]).
 *
 * Only GGML_TYPE_Q4_0 is handled, and only through the block_q4_0x8 kernel, which is
 * selected when the CPU reports AVX2 and src0->ne[1] is a multiple of 8. Every other
 * combination aborts instead of returning with `dst` left unwritten.
 *
 * @param params Compute parameters forwarded unchanged to the kernel.
 * @param dst    Destination tensor; its src[0] is the tensor whose type is inspected.
 */
void forward_get_rows(const ggml_compute_params * params,
                      ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_Q4_0: {
            // Both conditions must hold for the 8-row interleaved kernel. Previously an AVX2
            // host with ne[1] not divisible by 8 fell through silently and produced no output.
            const bool can_use_q4_0x8 = ggml_cpu_has_avx2() && (src0->ne[1] % 8 == 0);
            if (can_use_q4_0x8) {
                ggml_compute_forward_get_rows_q4_0<block_q4_0x8>(params, dst, 8);
            } else {
                GGML_ABORT("Unsupported block interleaved size for get_rows function");
            }
        } break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

/**
 * GET_ROWS kernel for a source tensor stored as groups of `nrows_interleaved` rows
 * packed into BLOCK_TYPE blocks.
 *
 * For every element of dst->src[1] (read as an int32 row index) the matching logical row
 * of dst->src[0] is dequantized to float and written into `dst`. The work is divided
 * between threads in contiguous ranges of ceil(nr / nth) output rows.
 *
 * @tparam BLOCK_TYPE        Repacked block type; its size is used as the per-column-block stride.
 *                           The pointer is handed to dequantize_row_q4_0, so it must be a type
 *                           that function accepts.
 * @param params             Supplies this thread's index (ith) and the thread count (nth).
 * @param dst                Destination tensor; src[0] holds the repacked data, src[1] the row indices.
 * @param nrows_interleaved  Number of logical rows packed together in one row group.
 */
template<typename BLOCK_TYPE>
static void ggml_compute_forward_get_rows_q4_0(
        const ggml_compute_params * params,
        ggml_tensor * dst,
        int nrows_interleaved) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_TENSOR_BINARY_OP_LOCALS

    const int64_t nc = ne00;
    const int64_t nr = ggml_nelements(src1);

    assert(ne0 == nc);
    assert(ne02 == ne11);
    assert(nb00 == ggml_type_size(src0->type));
    assert(ggml_nrows(dst) == nr);

    const int ith = params->ith;
    const int nth = params->nth;

    // Row bookkeeping stays 64-bit: nr and ne01 are int64_t, so narrowing to int
    // would truncate on very large tensors.

    // rows per thread
    const int64_t dr = (nr + nth - 1) / nth;

    // row range for this thread
    const int64_t ir0 = dr * ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    // One repacked block covers QK4_0 columns of a whole row group, so a row group
    // occupies (nc / QK4_0) consecutive blocks.
    const int64_t num_repacked_blocks_per_row_width = nc / QK4_0;
    const size_t  stride_between_actual_row_groups  = (size_t) num_repacked_blocks_per_row_width * sizeof(BLOCK_TYPE);

    for (int64_t i = ir0; i < ir1; ++i) {
        // Decompose the flat index into src1 coordinates (i10, i11, i12).
        const int64_t i12 = i / (ne11 * ne10);
        const int64_t i11 = (i - i12 * ne11 * ne10) / ne10;
        const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10);
        const int64_t i01 = *(int32_t *)((char *)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); // original logical row

        GGML_ASSERT(i01 >= 0 && i01 < ne01);

        // Which row group holds the requested row, and its position inside that group.
        const int64_t row_group_idx    = i01 / nrows_interleaved;
        const int     row_idx_in_group = (int) (i01 % nrows_interleaved);

        const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11 * nb02 + i12 * nb03;

        // Pointer to the first <BLOCK_TYPE> of the identified row_group_idx
        const BLOCK_TYPE * p_first_repacked_block_of_group_block_type =
            (const BLOCK_TYPE *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);

        dequantize_row_q4_0(
            p_first_repacked_block_of_group_block_type,
            (float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group);
    }
}

/**
* Dequantizes a single logical row from data repacked with quant interleaving for repacked block_q4_0x8
*
* @param p_repacked_group_column_blocks Pointer to the start of 'block_q4_0x8' for the row group.
* @param y Output buffer for the dequantized float values.
* @param k Total number of elements (columns) in the logical row.
* @param row_idx_in_group Index (0-7) of the logical row to dequantize.
*/
static void dequantize_row_q4_0(
const block_q4_0x8 * GGML_RESTRICT p_repacked_group_column_blocks,
float * GGML_RESTRICT y,
int64_t k,
int row_idx_in_group) {
const int GGML_Q4_0_X8_INTERLEAVE_SIZE = 8;
assert(k % QK4_0 == 0);
assert(row_idx_in_group >= 0 && row_idx_in_group < GGML_Q4_0_X8_INTERLEAVE_SIZE);

const int nb = k / QK4_0;
const int bytes_for_half_elements = (QK4_0 / 2) / 2;

const int offset_to_second_half_data = bytes_for_half_elements * GGML_Q4_0_X8_INTERLEAVE_SIZE;
const uint64_t xor_mask = 0x8888888888888888ULL;
const int qk4_0_half_elements = QK4_0 / 2;

for (int i = 0; i < nb; ++i) {
const block_q4_0x8 * current_column_repacked_block = &p_repacked_group_column_blocks[i];
const float d_val = GGML_FP16_TO_FP32(current_column_repacked_block->d[row_idx_in_group]);
float * y_curr = y + i * QK4_0;

const int8_t * qs_first_half_repacked_ptr = &(current_column_repacked_block->qs[row_idx_in_group * bytes_for_half_elements]);

uint64_t first_half_chunk_u64;
memcpy(&first_half_chunk_u64, qs_first_half_repacked_ptr, sizeof(uint64_t));
first_half_chunk_u64 ^= xor_mask; // Reverse the XOR
const uint8_t * original_qs_first_half_bytes = (const uint8_t *)&first_half_chunk_u64;

const int8_t * qs_second_half_repacked_ptr = &(current_column_repacked_block->qs[offset_to_second_half_data + (row_idx_in_group * bytes_for_half_elements)]);

uint64_t second_half_chunk_u64;
memcpy(&second_half_chunk_u64, qs_second_half_repacked_ptr, sizeof(uint64_t));
second_half_chunk_u64 ^= xor_mask; // Reverse the XOR
const uint8_t * original_qs_second_half_bytes = (const uint8_t *)&second_half_chunk_u64;

// dequantizing all QK4_0's for this block.
for (int j = 0; j < bytes_for_half_elements; ++j) {
const uint8_t quant_byte_first = original_qs_first_half_bytes[j];
y_curr[j] = ((quant_byte_first & 0x0F) - 8) * d_val;
y_curr[j + qk4_0_half_elements] = ((quant_byte_first >> 4) - 8) * d_val;

const uint8_t quant_byte_second = original_qs_second_half_bytes[j];
const int out_idx_base_second_half = j + bytes_for_half_elements; // Offset for the second set of low nibbles
y_curr[out_idx_base_second_half] = ((quant_byte_second & 0x0F) - 8) * d_val;
y_curr[out_idx_base_second_half + qk4_0_half_elements] = ((quant_byte_second >> 4) - 8) * d_val;
}
}
}

int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
(int) NB_COLS, (int) INTER_SIZE);
Expand Down Expand Up @@ -1538,12 +1680,23 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
//if (op->src[1]->type == GGML_TYPE_Q8_0) {
// return true;
//}
} else if (op->op == GGML_OP_GET_ROWS
&& op->src[0]->buffer
&& (ggml_n_dims(op->src[0]) == 2)
&& op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()
&& ggml_repack_get_optimal_repack_type(op->src[0])) {
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
if (op->src[0]->type == GGML_TYPE_Q4_0) {
return true;
}
}
return false;
}

ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_GET_ROWS) {
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
}
Expand Down