From bce54642c8ac6ff41a55140d4f477bee77048e21 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 20 Aug 2024 15:17:24 -0400 Subject: [PATCH 01/25] imatrix : allow processing multiple chunks per batch * perplexity : simplify filling the batch --- examples/imatrix/imatrix.cpp | 94 ++++++++++++++++++++++-------- examples/perplexity/perplexity.cpp | 16 ++--- 2 files changed, 75 insertions(+), 35 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 83b85d72b043a..7b91a7e306f57 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -432,10 +432,9 @@ static void process_logits( } } -static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { +static bool compute_imatrix(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) { const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); - const int n_ctx = llama_n_ctx(ctx); auto tim1 = std::chrono::high_resolution_clock::now(); fprintf(stderr, "%s: tokenizing the input ..\n", __func__); @@ -479,22 +478,28 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { double nll = 0.0; double nll2 = 0.0; - fprintf(stderr, "%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch); - std::vector workers(std::thread::hardware_concurrency() - 1); const int num_batches = (n_ctx + n_batch - 1) / n_batch; + const int n_seq = std::max(1, n_batch / n_ctx); + + GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0); + GGML_ASSERT(params.n_ctx == n_seq * n_ctx); + + llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1); std::vector logits; if (params.compute_ppl && num_batches > 1) { logits.reserve((size_t)n_ctx * n_vocab); } - for (int i = 0; i < n_chunk; ++i) { + fprintf(stderr, "%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); + + for (int i = 0; i < n_chunk; i += n_seq) { const int start = i * n_ctx; const int end = start + n_ctx; - std::vector logits; + const int n_seq_batch = std::min(n_seq, n_chunk - i); const auto t_start = std::chrono::high_resolution_clock::now(); @@ -505,35 +510,50 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - // save original token and restore it after eval - const auto token_org = tokens[batch_start]; + // clear the batch + llama_batch_clear(batch); + + for (int seq = 0; seq < n_seq_batch; seq++) { + int seq_start = batch_start + seq*n_ctx; + + // save original token and restore it after eval + const auto token_org = tokens[seq_start]; - // add BOS token for the first batch of each chunk - if (add_bos && j == 0) { - tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); + // add BOS token for the first batch of each chunk + if (add_bos && j == 0) { + tokens[seq_start] = llama_token_bos(llama_get_model(ctx)); + } + + for (int k = 0; k < batch_size; ++k) { + // NOTE: specifying all logits to get activations for the output.weight tensor + // and also for the perplexity calculation. + // TODO: only get outputs when (params.process_output || params.compute_ppl) + // (not possible when this skips FFN computation of the last layer) + llama_batch_add(batch, tokens[seq_start + k], j*n_batch + k, { seq }, true); + } + + // restore the original token in case it was set to BOS + tokens[seq_start] = token_org; } - // TODO: use batch.logits to save computations instead of relying on logits_all == true - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { + if (llama_decode(ctx, batch)) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } - // restore the original token in case it was set to BOS - tokens[batch_start] = token_org; - if (params.compute_ppl && num_batches > 1) { const auto * batch_logits = llama_get_logits(ctx); logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); } } - const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { + llama_synchronize(ctx); + const auto t_end = std::chrono::high_resolution_clock::now(); const float t_total = std::chrono::duration(t_end - t_start).count(); fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); - int total_seconds = (int)(t_total * n_chunk); + int total_seconds = (int)(t_total*n_chunk/n_seq); if (total_seconds >= 60*60) { fprintf(stderr, "%d hours ", total_seconds / (60*60)); total_seconds = total_seconds % (60*60); @@ -543,12 +563,21 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { if (params.compute_ppl) { const int first = n_ctx/2; - const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); - process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, - workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); - count += n_ctx - first - 1; + for (int seq = 0; seq < n_seq_batch; seq++) { + const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first); + + llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first; - printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + process_logits(n_vocab, all_logits + first*n_vocab, + tokens_data, n_ctx - 1 - first, + workers, nll, nll2, + logit_history.data() + start + seq*n_ctx + first, + prob_history.data() + start + seq*n_ctx + first); + + count += n_ctx - first - 1; + + printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count)); + } fflush(stdout); logits.clear(); @@ -584,7 +613,22 @@ int main(int argc, char ** argv) { return 1; } - params.n_batch = std::min(params.n_batch, params.n_ctx); + const int32_t n_ctx = params.n_ctx; + + if (n_ctx <= 0) { + fprintf(stderr, "%s: imatrix tool requires '--ctx-size' > 0\n", __func__); + return 1; + } + + { + const int32_t n_seq = std::max(1, params.n_batch / n_ctx); + const int32_t n_kv = n_seq * n_ctx; + + params.n_parallel = n_seq; + params.n_ctx = n_kv; + + params.n_batch = std::min(params.n_batch, n_kv); + } g_collector.set_params(params); @@ -632,7 +676,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str()); } - if (!compute_imatrix(ctx, params)) { + if (!compute_imatrix(ctx, params, n_ctx)) { return 1; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 484dd589109c7..0bc0778fc6466 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -583,7 +583,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par int n_outputs = 0; - batch.n_tokens = 0; + // clear the batch + llama_batch_clear(batch); + for (int seq = 0; seq < n_seq_batch; seq++) { int seq_start = batch_start + seq*n_ctx; @@ -596,16 +598,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par } for (int k = 0; k < batch_size; ++k) { - const int idx = seq*n_ctx + k; - batch.token [idx] = tokens[seq_start + k]; - batch.pos [idx] = j*n_batch + k; - batch.n_seq_id[idx] = 1; - batch.seq_id [idx][0] = seq; - batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0; - - n_outputs += batch.logits[idx] != 0; + llama_pos pos = j*n_batch + k; + llama_batch_add(batch, tokens[seq_start + k], pos, { seq }, pos >= first); + n_outputs += (int) (pos >= first); } - batch.n_tokens += batch_size; // restore the original token in case it was set to BOS tokens[seq_start] = token_org; From 347247a24ec0db754216b7d466bac021bef9ae6a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 20 Aug 2024 15:35:56 -0400 Subject: [PATCH 02/25] imatrix : fix segfault when using a single chunk per batch --- examples/imatrix/imatrix.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 7b91a7e306f57..6135f00a7e8c1 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -564,7 +564,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, cons if (params.compute_ppl) { const int first = n_ctx/2; for (int seq = 0; seq < n_seq_batch; seq++) { - const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first); + const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx); llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first; From 3de9300c3786d52fb709596a0c5ac1dc65c9f08d Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 6 Sep 2024 17:17:25 -0400 Subject: [PATCH 03/25] imatrix : use GGUF to store imatrix data --- convert_legacy_imatrix_to_gguf.py | 118 +++++++++++++++ examples/imatrix/imatrix.cpp | 241 ++++++++++++++++++------------ examples/quantize/quantize.cpp | 127 ++++++++++------ gguf-py/gguf/constants.py | 7 + 4 files changed, 348 insertions(+), 145 deletions(-) create mode 100644 convert_legacy_imatrix_to_gguf.py diff --git a/convert_legacy_imatrix_to_gguf.py b/convert_legacy_imatrix_to_gguf.py new file mode 100644 index 0000000000000..939d3695b23ce --- /dev/null +++ b/convert_legacy_imatrix_to_gguf.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import os +import sys +import logging +import argparse + +from typing import Any +from pathlib import Path +from dataclasses import dataclass + +import numpy as np +import numpy.typing as npt + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf + + +logger = logging.getLogger("imatrix-to-gguf") + + +class IMatrixWriter(gguf.GGUFWriter): + def add_architecture(self) -> None: + # no arch is stored in imatrix files + pass + + +@dataclass +class IMatrixEntry: + values: np.ndarray[Any, np.dtype[np.float32]] + counts: np.ndarray[Any, np.dtype[np.float32]] + + +class IMatrixReader: + chunk_size: int = 512 # guess + offset: int = 0 + data: np.ndarray[Any, np.dtype[np.uint8]] + n_enties: int + entries: dict[str, IMatrixEntry] + chunk_count: int + dataset: str + + def _get(self, dtype: npt.DTypeLike, count: int = 1) -> npt.NDArray[Any]: + count = int(count) + itemsize = int(np.empty([], dtype=dtype).itemsize) + offset = self.offset + self.offset = offset + itemsize * count + return self.data[offset:self.offset].view(dtype=dtype)[:count] + + def __init__(self, imatrix: Path): + self.offset = 0 + self.entries = {} + self.data = np.memmap(imatrix) + n_entries = self._get(np.int32).item() + assert n_entries >= 0 + for _ in range(n_entries): + len = self._get(np.int32).item() + name = self._get(np.uint8, len).tobytes().decode("utf-8") + ncall = self._get(np.int32).item() + nval = self._get(np.int32).item() + data = self._get(np.float32, nval) + assert name not in self.entries, f"duplicated name: {name!r}" + + self.entries[name] = IMatrixEntry(data, np.array([ncall * self.chunk_size], dtype=np.float32)) + + self.chunk_count = self._get(np.int32).item() + self.dataset = self._get(np.uint8, self._get(np.int32).item()).tobytes().decode("utf-8") + + def to_writer(self, outfile: Path) -> IMatrixWriter: + writer = IMatrixWriter(path=outfile, arch="") + + writer.add_type(gguf.GGUFType.IMATRIX) + writer.add_key_value(gguf.Keys.IMatrix.CHUNK_COUNT, self.chunk_count, gguf.GGUFValueType.UINT32) + writer.add_key_value(gguf.Keys.IMatrix.CHUNK_SIZE, self.chunk_size, gguf.GGUFValueType.UINT32) + writer.add_key_value(gguf.Keys.IMatrix.DATASET, self.dataset, gguf.GGUFValueType.STRING) + + for name, entry in self.entries.items(): + writer.add_tensor(name + ".sums", entry.values) + writer.add_tensor(name + ".counts", entry.counts) + + return writer + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Convert an old imatrix.dat file to a GGUF compatible file") + parser.add_argument( + "--outfile", type=Path, + help="path to write to; default: based on input.", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + parser.add_argument( + "imatrix", type=Path, + help="path to an imatrix file", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + if args.outfile is None: + input_file: Path = args.imatrix + if input_file.suffix != ".gguf": + args.outfile = input_file.with_suffix(".gguf") + + writer = IMatrixReader(args.imatrix).to_writer(args.outfile) + + writer.write_header_to_file(args.outfile) + writer.write_kv_data_to_file() + writer.write_tensors_to_file() diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 6135f00a7e8c1..2314a035d04fe 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -5,11 +5,9 @@ #include #include #include -#include #include #include #include -#include #include #include @@ -22,16 +20,19 @@ static void print_usage(int argc, char ** argv, const gpt_params & params) { LOG_TEE("\nexample usage:\n"); LOG_TEE("\n %s \\\n" - " -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] [--verbosity 1] \\\n" + " -m model.gguf -f some-text.txt [-o imatrix.gguf] [--process-output] [--verbosity 1] \\\n" " [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n" - " [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...]\n" , argv[0]); + " [--in-file imatrix-prev-0.gguf --in-file imatrix-prev-1.gguf ...]\n" , argv[0]); LOG_TEE("\n"); } +static const char * const LLM_KV_IMATRIX_DATASET = "imatrix.dataset"; +static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count"; +static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size"; + struct Stats { - std::vector values; - std::vector counts; - int ncall = 0; + std::vector values; + std::vector counts; }; class IMatrixCollector { @@ -39,13 +40,13 @@ class IMatrixCollector { IMatrixCollector() = default; void set_params(gpt_params params) { m_params = std::move(params); } bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); - void save_imatrix(int ncall = -1) const; + void save_imatrix(int32_t n_chunk = -1) const; bool load_imatrix(const char * file_name); private: std::unordered_map m_stats; gpt_params m_params; std::mutex m_mutex; - int m_last_call = 0; + int32_t m_last_chunk = 0; std::vector m_src1_data; std::vector m_ids; // the expert ids from ggml_mul_mat_id }; @@ -119,18 +120,24 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * auto & e = m_stats[wname]; - ++e.ncall; - + if (e.counts.size() == 1 && n_as > 1) { + // broadcast, when loading an old imatrix + e.counts.resize(n_as, e.counts[0]); + } if (e.values.empty()) { e.values.resize(src1->ne[0]*n_as, 0); - e.counts.resize(src1->ne[0]*n_as, 0); + e.counts.resize(n_as, 0); } else if (e.values.size() != (size_t)src1->ne[0]*n_as) { fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); exit(1); //GGML_ABORT("fatal error"); } + else if (e.counts.size() != (size_t)n_as) { + fprintf(stderr, "Oops: inconsistent expert count for %s (%d vs %d)\n", wname.c_str(), (int)e.counts.size(), (int)n_as); + exit(1); //GGML_ABORT("fatal error"); + } if (m_params.verbosity > 1) { - printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type); + printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type); } // loop over all possible experts, regardless if they are used or not in the batch for (int ex = 0; ex < n_as; ++ex) { @@ -148,23 +155,26 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * const int64_t i12 = row; const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]); + e.counts[ex]++; + for (int j = 0; j < (int)src1->ne[0]; ++j) { e.values[e_start + j] += x[j]*x[j]; - e.counts[e_start + j]++; - if (!std::isfinite(e.values[e_start + j])) { - fprintf(stderr, "%f detected in %s\n", e.values[e_start + j], wname.c_str()); + if (!std::isfinite((float)e.values[e_start + j])) { + fprintf(stderr, "%f detected in %s\n", (float)e.values[e_start + j], wname.c_str()); exit(1); } } } } - if (e.ncall > m_last_call) { - m_last_call = e.ncall; - if (m_last_call % m_params.n_out_freq == 0) { + const int32_t n_chunk = e.counts[ex] / (m_params.n_ctx / m_params.n_parallel); + if (n_chunk > m_last_chunk) { + const int32_t chunk_step = n_chunk - m_last_chunk; + m_last_chunk = n_chunk; + if ((m_last_chunk % m_params.n_out_freq) / chunk_step == 0) { save_imatrix(); } - if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) { - save_imatrix(m_last_call); + if (m_params.n_save_freq > 0 && (m_last_chunk % m_params.n_save_freq) / chunk_step == 0) { + save_imatrix(m_last_chunk); } } } @@ -172,34 +182,40 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * auto & e = m_stats[wname]; if (e.values.empty()) { e.values.resize(src1->ne[0], 0); - e.counts.resize(src1->ne[0], 0); + e.counts.resize(1, 0); } else if (e.values.size() != (size_t)src1->ne[0]) { fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]); exit(1); //GGML_ABORT("fatal error"); } - ++e.ncall; + else if (e.counts.size() != 1) { + fprintf(stderr, "Oops: inconsistent expert count for %s (%d vs %d)\n", wname.c_str(), (int)e.counts.size(), 1); + exit(1); //GGML_ABORT("fatal error"); + } if (m_params.verbosity > 1) { - printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); + printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); } + // TODO: higher dimensions for (int row = 0; row < (int)src1->ne[1]; ++row) { const float * x = data + row * src1->ne[0]; + e.counts[0]++; for (int j = 0; j < (int)src1->ne[0]; ++j) { e.values[j] += x[j]*x[j]; - e.counts[j]++; - if (!std::isfinite(e.values[j])) { - fprintf(stderr, "%f detected in %s\n", e.values[j], wname.c_str()); + if (!std::isfinite((float)e.values[j])) { + fprintf(stderr, "%f detected in %s\n", (float)e.values[j], wname.c_str()); exit(1); } } } - if (e.ncall > m_last_call) { - m_last_call = e.ncall; - if (m_last_call % m_params.n_out_freq == 0) { + const int32_t n_chunk = e.counts[0] / (m_params.n_ctx / m_params.n_parallel); + if (n_chunk > m_last_chunk) { + const int32_t chunk_step = n_chunk - m_last_chunk; + m_last_chunk = n_chunk; + if ((m_last_chunk % m_params.n_out_freq) / chunk_step == 0) { save_imatrix(); } - if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) { - save_imatrix(m_last_call); + if (m_params.n_save_freq > 0 && (m_last_chunk % m_params.n_save_freq) / chunk_step == 0) { + save_imatrix(m_last_chunk); } } } @@ -207,15 +223,15 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * return true; } -void IMatrixCollector::save_imatrix(int ncall) const { +void IMatrixCollector::save_imatrix(int32_t n_chunk) const { auto fname = m_params.out_file; if (fname.empty()) { - fname = "imatrix.dat"; + fname = "imatrix.gguf"; } - if (ncall > 0) { + if (n_chunk > 0) { fname += ".at_"; - fname += std::to_string(ncall); + fname += std::to_string(n_chunk); } // avoid writing imatrix entries that do not have full data @@ -223,6 +239,7 @@ void IMatrixCollector::save_imatrix(int ncall) const { int n_entries = 0; std::vector to_store; + size_t data_size = 0; bool is_first = true; // for printing for (const auto & kv : m_stats) { @@ -256,100 +273,132 @@ void IMatrixCollector::save_imatrix(int ncall) const { n_entries++; to_store.push_back(kv.first); + data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.values.size(), GGML_MEM_ALIGN); + data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.counts.size(), GGML_MEM_ALIGN); } if (to_store.size() < m_stats.size()) { fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size()); } - std::ofstream out(fname, std::ios::binary); - out.write((const char *) &n_entries, sizeof(n_entries)); + struct ggml_init_params params = { + .mem_size = data_size, + .mem_buffer = NULL, + .no_alloc = false, + }; + struct ggml_context * ctx = ggml_init(params); + struct gguf_context * ctx_gguf = gguf_init_empty(); + + gguf_set_val_str(ctx_gguf, "general.type", "imatrix"); + // Write the input filename to later on specify it in quantize + gguf_set_val_str(ctx_gguf, LLM_KV_IMATRIX_DATASET, m_params.prompt_file.c_str()); + // Write the number of chunks the matrix was computed with + gguf_set_val_u32(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT, m_last_chunk); + gguf_set_val_u32(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE, m_params.n_ctx / m_params.n_parallel); + for (const auto & name : to_store) { const auto & stat = m_stats.at(name); - int len = name.size(); - out.write((const char *) &len, sizeof(len)); - out.write(name.c_str(), len); - out.write((const char *) &stat.ncall, sizeof(stat.ncall)); - int nval = stat.values.size(); - out.write((const char *) &nval, sizeof(nval)); + const int32_t nval = (int32_t) stat.values.size(); + const int32_t nmat = (int32_t) stat.counts.size(); if (nval > 0) { - std::vector tmp(nval); - for (int i = 0; i < nval; i++) { - tmp[i] = (stat.values[i] / static_cast(stat.counts[i])) * static_cast(stat.ncall); + struct ggml_tensor * sums = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nval / nmat, nmat); + struct ggml_tensor * counts = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, nmat); + ggml_set_name(sums, (name + ".sums").c_str()); + ggml_set_name(counts, (name + ".counts").c_str()); + + for (int32_t j = 0; j < nval; ++j) { + ((float *) sums->data)[j] = (float) stat.values[j]; + } + for (int32_t j = 0; j < nmat; ++j) { + ((float *) counts->data)[j] = (float) stat.counts[j]; } - out.write((const char*)tmp.data(), nval*sizeof(float)); + + gguf_add_tensor(ctx_gguf, sums); + gguf_add_tensor(ctx_gguf, counts); } } - // Write the number of call the matrix was computed with - out.write((const char *) &m_last_call, sizeof(m_last_call)); - - // Write the input filename at the end of the file to later on specify it in quantize - { - int len = m_params.prompt_file.size(); - out.write((const char *) &len, sizeof(len)); - out.write(m_params.prompt_file.c_str(), len); - } + gguf_write_to_file(ctx_gguf, fname.c_str(), false); if (m_params.verbosity > 0) { - fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname.c_str()); + fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_chunk, fname.c_str()); } + + gguf_free(ctx_gguf); + ggml_free(ctx); } -bool IMatrixCollector::load_imatrix(const char * fname) { - std::ifstream in(fname, std::ios::binary); - if (!in) { - printf("%s: failed to open %s\n",__func__, fname); +bool IMatrixCollector::load_imatrix(const char * file_name) { + struct ggml_context * ctx = nullptr; + struct gguf_init_params meta_gguf_params = { + /* .no_alloc = */ false, // the data is needed + /* .ctx = */ &ctx, + }; + struct gguf_context * ctx_gguf = gguf_init_from_file(file_name, meta_gguf_params); + if (!ctx_gguf) { return false; } - int n_entries; - in.read((char*)&n_entries, sizeof(n_entries)); - if (in.fail() || n_entries < 1) { - printf("%s: no data in file %s\n", __func__, fname); + const int32_t n_entries = gguf_get_n_tensors(ctx_gguf); + if (n_entries < 2) { + fprintf(stderr, "%s: no data in file %s\n", __func__, file_name); + gguf_free(ctx_gguf); + ggml_free(ctx); return false; } - for (int i = 0; i < n_entries; ++i) { - int len; in.read((char *)&len, sizeof(len)); - std::vector name_as_vec(len+1); - in.read((char *)name_as_vec.data(), len); - if (in.fail()) { - printf("%s: failed reading name for entry %d from %s\n",__func__,i+1, fname); + + const std::string sums_suffix{".sums"}; + const std::string counts_suffix{".counts"}; + + // TODO: allow loading from mis-ordered imatrix files + for (int32_t i = 0; i < n_entries - 1; i += 2) { + std::string sums_name{gguf_get_tensor_name(ctx_gguf, i + 0)}; + std::string counts_name{gguf_get_tensor_name(ctx_gguf, i + 1)}; + + if (sums_name.size() < sums_suffix.size() || + counts_name.size() < counts_suffix.size() || + !std::equal(sums_name.begin(), sums_name.end() - sums_suffix.size(), counts_name.begin()) || + !std::equal(sums_suffix.rbegin(), sums_suffix.rend(), sums_name.rbegin()) || + !std::equal(counts_suffix.rbegin(), counts_suffix.rend(), counts_name.rbegin())) { + fprintf(stderr, "%s: mismatched sums and counts for entry %d\n", __func__, i / 2); + gguf_free(ctx_gguf); + ggml_free(ctx); return false; } - name_as_vec[len] = 0; - std::string name{name_as_vec.data()}; - auto & e = m_stats[std::move(name)]; - int ncall; - in.read((char*)&ncall, sizeof(ncall)); - int nval; - in.read((char *)&nval, sizeof(nval)); - if (in.fail() || nval < 1) { - printf("%s: failed reading number of values for entry %d\n",__func__,i); - m_stats = {}; + + struct ggml_tensor * sums = ggml_get_tensor(ctx, sums_name.c_str()); + struct ggml_tensor * counts = ggml_get_tensor(ctx, counts_name.c_str()); + if (!sums || !counts) { + fprintf(stderr, "%s: failed reading data for entry %d\n", __func__, i / 2); + gguf_free(ctx_gguf); + ggml_free(ctx); return false; } + std::string name = sums_name.substr(0, sums_name.size() - sums_suffix.size()); + auto & e = m_stats[name]; + + int32_t nval = ggml_nelements(sums); if (e.values.empty()) { e.values.resize(nval, 0); - e.counts.resize(nval, 0); } - - std::vector tmp(nval); - in.read((char*)tmp.data(), nval*sizeof(float)); - if (in.fail()) { - printf("%s: failed reading data for entry %d\n",__func__,i); - m_stats = {}; - return false; + int32_t ncounts = ggml_nelements(counts); + if (e.counts.empty()) { + e.counts.resize(ncounts, 0); + } else if (e.counts.size() == 1 && ncounts > 1) { + // broadcast, when loading an old imatrix + e.counts.resize(ncounts, e.counts[0]); } - // Recreate the state as expected by save_imatrix(), and corerct for weighted sum. - for (int i = 0; i < nval; i++) { - e.values[i] += tmp[i]; - e.counts[i] += ncall; + // Recreate the state as expected by save_imatrix() + for (int32_t j = 0; j < nval; j++) { + e.values[j] += ((const float *) sums->data)[j]; + } + for (int32_t j = 0; j < ncounts; j++) { + e.counts[j] += std::lround(((const float *) counts->data)[j]); } - e.ncall += ncall; - } + gguf_free(ctx_gguf); + ggml_free(ctx); return true; } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 7312309aeef98..2df073d45e8f1 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -6,8 +6,6 @@ #include #include #include -#include -#include struct quant_option { std::string name; @@ -61,6 +59,11 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count"; static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS = "quantize.imatrix.chunks_count"; +// TODO: share with imatrix.cpp +static const char * const LLM_KV_IMATRIX_DATASET = "imatrix.dataset"; +static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count"; +static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size"; + static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) { std::string ftype_str; @@ -121,66 +124,92 @@ static void usage(const char * executable) { } static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_dataset, std::unordered_map> & imatrix_data) { - std::ifstream in(imatrix_file.c_str(), std::ios::binary); - if (!in) { - printf("%s: failed to open %s\n",__func__, imatrix_file.c_str()); + + struct ggml_context * ctx = nullptr; + struct gguf_init_params meta_gguf_params = { + /* .no_alloc = */ false, // the data is needed + /* .ctx = */ &ctx, + }; + struct gguf_context * ctx_gguf = gguf_init_from_file(imatrix_file.c_str(), meta_gguf_params); + if (!ctx_gguf) { exit(1); } - int n_entries; - in.read((char *)&n_entries, sizeof(n_entries)); - if (in.fail() || n_entries < 1) { - printf("%s: no data in file %s\n", __func__, imatrix_file.c_str()); + const int32_t n_entries = gguf_get_n_tensors(ctx_gguf); + if (n_entries < 2) { + fprintf(stderr, "%s: no data in file %s\n", __func__, imatrix_file.c_str()); + gguf_free(ctx_gguf); + ggml_free(ctx); exit(1); } - for (int i = 0; i < n_entries; ++i) { - int len; in.read((char *)&len, sizeof(len)); - std::vector name_as_vec(len+1); - in.read((char *)name_as_vec.data(), len); - if (in.fail()) { - printf("%s: failed reading name for entry %d from %s\n", __func__, i+1, imatrix_file.c_str()); - exit(1); - } - name_as_vec[len] = 0; - std::string name{name_as_vec.data()}; - auto & e = imatrix_data[name]; - int ncall; - in.read((char *)&ncall, sizeof(ncall)); - int nval; - in.read((char *)&nval, sizeof(nval)); - if (in.fail() || nval < 1) { - printf("%s: failed reading number of values for entry %d\n", __func__, i); - imatrix_data = {}; + + const int dataset_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_DATASET); + const int chunk_count_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT); + const int chunk_size_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE); + if (dataset_idx < 0 || chunk_count_idx < 0 || chunk_size_idx < 0) { + fprintf(stderr, "%s: missing imatrix metadata in file %s\n", __func__, imatrix_file.c_str()); + gguf_free(ctx_gguf); + ggml_free(ctx); + exit(1); + } + + const uint32_t chunk_size = gguf_get_val_u32(ctx_gguf, chunk_size_idx); + + const std::string sums_suffix{".sums"}; + const std::string counts_suffix{".counts"}; + + // TODO: allow loading from mis-ordered imatrix files + for (int32_t i = 0; i < n_entries - 1; i += 2) { + std::string sums_name{gguf_get_tensor_name(ctx_gguf, i + 0)}; + std::string counts_name{gguf_get_tensor_name(ctx_gguf, i + 1)}; + + if (sums_name.size() < sums_suffix.size() || + counts_name.size() < counts_suffix.size() || + !std::equal(sums_name.begin(), sums_name.end() - sums_suffix.size(), counts_name.begin()) || + !std::equal(sums_suffix.rbegin(), sums_suffix.rend(), sums_name.rbegin()) || + !std::equal(counts_suffix.rbegin(), counts_suffix.rend(), counts_name.rbegin())) { + fprintf(stderr, "%s: mismatched sums and counts for entry %d\n", __func__, i / 2); + gguf_free(ctx_gguf); + ggml_free(ctx); exit(1); } - e.resize(nval); - in.read((char *)e.data(), nval*sizeof(float)); - if (in.fail()) { - printf("%s: failed reading data for entry %d\n", __func__, i); - imatrix_data = {}; + + struct ggml_tensor * sums = ggml_get_tensor(ctx, sums_name.c_str()); + struct ggml_tensor * counts = ggml_get_tensor(ctx, counts_name.c_str()); + if (!sums || !counts) { + fprintf(stderr, "%s: failed reading data for entry %d\n", __func__, i / 2); + gguf_free(ctx_gguf); + ggml_free(ctx); exit(1); } - if (ncall > 0) { - for (auto& v : e) v /= ncall; - } + const int64_t ne0 = sums->ne[0]; + const int64_t ne1 = sums->ne[1]; + std::string name = sums_name.substr(0, sums_name.size() - sums_suffix.size()); + auto & e = imatrix_data[name]; + e.resize(ggml_nelements(sums)); + float max_count = 0.0f; + for (int64_t j = 0; j < ne1; ++j) { + const float count = ((const float *) counts->data)[ne1]; + for (int64_t i = 0; i < ne0; ++i) { + e[ne1*ne0 + ne0] = ((const float *) sums->data)[ne1*ne0 + ne0] / count; + } + if (count > max_count) { + max_count = count; + } + } if (getenv("LLAMA_TRACE")) { - printf("%s: loaded data (size = %6d, ncall = %6d) for '%s'\n", __func__, int(e.size()), ncall, name.c_str()); + printf("%s: loaded data (size = %6d, ncall = %6d) for '%s'\n", __func__, int(e.size()), int(max_count / chunk_size), name.c_str()); } } + gguf_free(ctx_gguf); + ggml_free(ctx); - // latest imatrix version contains the dataset filename at the end of the file - int m_last_call = 0; - if (in.peek() != EOF) { - in.read((char *)&m_last_call, sizeof(m_last_call)); - int dataset_len; - in.read((char *)&dataset_len, sizeof(dataset_len)); - std::vector dataset_as_vec(dataset_len); - in.read(dataset_as_vec.data(), dataset_len); - imatrix_dataset.assign(dataset_as_vec.begin(), dataset_as_vec.end()); - printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str()); - } - printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_call); - return m_last_call; + int m_last_chunk = gguf_get_val_u32(ctx_gguf, chunk_count_idx); + imatrix_dataset = gguf_get_val_str(ctx_gguf, dataset_idx); + + printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str()); + printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_chunk); + return m_last_chunk; } static int prepare_imatrix(const std::string & imatrix_file, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 5541972ce52b0..4fdeddb7c6648 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -167,6 +167,12 @@ class Adapter: TYPE = "adapter.type" LORA_ALPHA = "adapter.lora.alpha" + class IMatrix: + CHUNK_COUNT = "imatrix.chunk_count" + CHUNK_SIZE = "imatrix.chunk_size" + DATASET = "imatrix.dataset" + + # # recommended mapping of model tensor names for storage in gguf # @@ -175,6 +181,7 @@ class Adapter: class GGUFType: MODEL = "model" ADAPTER = "adapter" + IMATRIX = "imatrix" class MODEL_ARCH(IntEnum): From c8ab6a3ba356e902b94499baaf7ab0191c3b6afe Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 8 Sep 2024 10:04:01 -0400 Subject: [PATCH 04/25] imatrix : fix conversion problems --- convert_legacy_imatrix_to_gguf.py | 8 ++++++-- examples/imatrix/imatrix.cpp | 2 +- examples/quantize/quantize.cpp | 11 +++++++---- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/convert_legacy_imatrix_to_gguf.py b/convert_legacy_imatrix_to_gguf.py index 939d3695b23ce..bd72655bf2cc7 100644 --- a/convert_legacy_imatrix_to_gguf.py +++ b/convert_legacy_imatrix_to_gguf.py @@ -64,10 +64,11 @@ def __init__(self, imatrix: Path): data = self._get(np.float32, nval) assert name not in self.entries, f"duplicated name: {name!r}" - self.entries[name] = IMatrixEntry(data, np.array([ncall * self.chunk_size], dtype=np.float32)) + self.entries[name] = IMatrixEntry(data * np.float32(self.chunk_size), np.array([ncall * self.chunk_size], dtype=np.float32)) self.chunk_count = self._get(np.int32).item() - self.dataset = self._get(np.uint8, self._get(np.int32).item()).tobytes().decode("utf-8") + dataset_len = self._get(np.int32).item() + self.dataset = self._get(np.uint8, dataset_len).tobytes().decode("utf-8") def to_writer(self, outfile: Path) -> IMatrixWriter: writer = IMatrixWriter(path=outfile, arch="") @@ -110,6 +111,9 @@ def parse_args(): input_file: Path = args.imatrix if input_file.suffix != ".gguf": args.outfile = input_file.with_suffix(".gguf") + if args.outfile.exists(): + logger.error(f"default file exists, specify with --outfile to overwrite: {args.outfile}") + exit(1) writer = IMatrixReader(args.imatrix).to_writer(args.outfile) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 2314a035d04fe..fea97918a682d 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -31,7 +31,7 @@ static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count"; static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size"; struct Stats { - std::vector values; + std::vector values; std::vector counts; }; diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 2df073d45e8f1..4f7003194d54b 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -132,6 +132,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ }; struct gguf_context * ctx_gguf = gguf_init_from_file(imatrix_file.c_str(), meta_gguf_params); if (!ctx_gguf) { + fprintf(stderr, "%s: if this is an older imatrix file, make sure to convert it to the GGUF-based imatrix format\n", __func__); exit(1); } const int32_t n_entries = gguf_get_n_tensors(ctx_gguf); @@ -189,9 +190,9 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ e.resize(ggml_nelements(sums)); float max_count = 0.0f; for (int64_t j = 0; j < ne1; ++j) { - const float count = ((const float *) counts->data)[ne1]; + const float count = ((const float *) counts->data)[j]; for (int64_t i = 0; i < ne0; ++i) { - e[ne1*ne0 + ne0] = ((const float *) sums->data)[ne1*ne0 + ne0] / count; + e[j*ne0 + i] = ((const float *) sums->data)[j*ne0 + i] / count; } if (count > max_count) { max_count = count; @@ -201,14 +202,16 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ printf("%s: loaded data (size = %6d, ncall = %6d) for '%s'\n", __func__, int(e.size()), int(max_count / chunk_size), name.c_str()); } } - gguf_free(ctx_gguf); - ggml_free(ctx); int m_last_chunk = gguf_get_val_u32(ctx_gguf, chunk_count_idx); imatrix_dataset = gguf_get_val_str(ctx_gguf, dataset_idx); printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str()); printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_chunk); + + gguf_free(ctx_gguf); + ggml_free(ctx); + return m_last_chunk; } From d19101c9a0e38359a303127bb5ccde47395ee083 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 8 Sep 2024 11:03:59 -0400 Subject: [PATCH 05/25] imatrix : use FMA and sort tensor names --- examples/imatrix/imatrix.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index e170714d84871..90ff9280cdda8 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -156,7 +156,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * e.counts[ex]++; for (int j = 0; j < (int)src1->ne[0]; ++j) { - e.values[e_start + j] += x[j]*x[j]; + e.values[e_start + j] = std::fma(x[j], x[j], e.values[e_start + j]); if (!std::isfinite((float)e.values[e_start + j])) { fprintf(stderr, "%f detected in %s\n", (float)e.values[e_start + j], wname.c_str()); exit(1); @@ -198,7 +198,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * const float * x = data + row * src1->ne[0]; e.counts[0]++; for (int j = 0; j < (int)src1->ne[0]; ++j) { - e.values[j] += x[j]*x[j]; + e.values[j] = std::fma(x[j], x[j], e.values[j]); if (!std::isfinite((float)e.values[j])) { fprintf(stderr, "%f detected in %s\n", (float)e.values[j], wname.c_str()); exit(1); @@ -279,6 +279,9 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size()); } + // deterministic tensor name order + std::sort(to_store.begin(), to_store.end()); + struct ggml_init_params params = { .mem_size = data_size, .mem_buffer = NULL, From 503630e88a782184ecf42aaeb34cddd6cf6e9107 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 9 Sep 2024 21:56:04 -0400 Subject: [PATCH 06/25] py : add requirements for legacy imatrix convert script --- requirements/requirements-convert_legacy_imatrix_to_gguf.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 requirements/requirements-convert_legacy_imatrix_to_gguf.txt diff --git a/requirements/requirements-convert_legacy_imatrix_to_gguf.txt b/requirements/requirements-convert_legacy_imatrix_to_gguf.txt new file mode 100644 index 0000000000000..afe2747d448d4 --- /dev/null +++ b/requirements/requirements-convert_legacy_imatrix_to_gguf.txt @@ -0,0 +1 @@ +-r ./requirements-convert_legacy_llama.txt From 9e6b0e9419eb9738af12c8425d979831704c0d4b Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 9 Sep 2024 22:00:37 -0400 Subject: [PATCH 07/25] perplexity : revert changes --- examples/perplexity/perplexity.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index ab5c51352fdf4..570ee8aeba4ae 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -583,9 +583,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par int n_outputs = 0; - // clear the batch - llama_batch_clear(batch); - + batch.n_tokens = 0; for (int seq = 0; seq < n_seq_batch; seq++) { int seq_start = batch_start + seq*n_ctx; @@ -598,10 +596,16 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par } for (int k = 0; k < batch_size; ++k) { - llama_pos pos = j*n_batch + k; - llama_batch_add(batch, tokens[seq_start + k], pos, { seq }, pos >= first); - n_outputs += (int) (pos >= first); + const int idx = seq*n_ctx + k; + batch.token [idx] = tokens[seq_start + k]; + batch.pos [idx] = j*n_batch + k; + batch.n_seq_id[idx] = 1; + batch.seq_id [idx][0] = seq; + batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0; + + n_outputs += batch.logits[idx] != 0; } + batch.n_tokens += batch_size; // restore the original token in case it was set to BOS tokens[seq_start] = token_org; From 894ed8d7b68164852ab1b61600dbc6126d3deb40 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 9 Sep 2024 22:20:18 -0400 Subject: [PATCH 08/25] py : include imatrix converter requirements in toplevel requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 9e190ae27de38..98c53db8179e2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,5 +8,6 @@ -r ./requirements/requirements-convert_hf_to_gguf.txt -r ./requirements/requirements-convert_hf_to_gguf_update.txt +-r ./requirements/requirements-convert_legacy_imatrix_to_gguf.txt -r ./requirements/requirements-convert_llama_ggml_to_gguf.txt -r ./requirements/requirements-convert_lora_to_gguf.txt From efa9186dc861bed7f06057480ab3f208e588a99f Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 9 Sep 2024 22:33:10 -0400 Subject: [PATCH 09/25] imatrix : avoid using designated initializers in C++ --- examples/imatrix/imatrix.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 90ff9280cdda8..758542905f610 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -283,9 +283,9 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { std::sort(to_store.begin(), to_store.end()); struct ggml_init_params params = { - .mem_size = data_size, - .mem_buffer = NULL, - .no_alloc = false, + /* .mem_size = */ data_size, + /* .mem_buffer = */ NULL, + /* .no_alloc = */ false, }; struct ggml_context * ctx = ggml_init(params); struct gguf_context * ctx_gguf = gguf_init_empty(); From 221724705191554b14f112162414f276bfeb2b17 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 9 Sep 2024 22:35:47 -0400 Subject: [PATCH 10/25] imatrix : remove unused n_entries --- examples/imatrix/imatrix.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 758542905f610..bcdc711533d83 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -235,7 +235,6 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { // avoid writing imatrix entries that do not have full data // this can happen with MoE models where some of the experts end up not being exercised by the provided training data - int n_entries = 0; std::vector to_store; size_t data_size = 0; @@ -269,7 +268,6 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { continue; } - n_entries++; to_store.push_back(kv.first); data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.values.size(), GGML_MEM_ALIGN); data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.counts.size(), GGML_MEM_ALIGN); From 8c13e16bb0f6b654ff4774e54fbc3b125ae495a6 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 10 Sep 2024 11:31:49 -0400 Subject: [PATCH 11/25] imatrix : allow loading mis-ordered tensors Sums and counts tensors no longer need to be consecutive. * imatrix : more sanity checks when loading multiple imatrix files * imatrix : use ggml_format_name instead of std::string concatenation Co-authored-by: Xuan Son Nguyen --- examples/imatrix/imatrix.cpp | 75 ++++++++++++++++++++++++---------- examples/quantize/quantize.cpp | 53 ++++++++++++++++-------- 2 files changed, 89 insertions(+), 39 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index bcdc711533d83..0e4cc8e683ec4 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #if defined(_MSC_VER) @@ -24,6 +25,14 @@ static void print_usage(int, char ** argv) { LOG_TEE("\n"); } +static bool str_remove_suffix(std::string & str, const std::string & suffix) { + bool has_suffix = str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), str.size(), suffix) == 0; + if (has_suffix) { + str = str.substr(0, str.size() - suffix.size()); + } + return has_suffix; +} + static const char * const LLM_KV_IMATRIX_DATASET = "imatrix.dataset"; static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count"; static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size"; @@ -302,8 +311,8 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { if (nval > 0) { struct ggml_tensor * sums = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nval / nmat, nmat); struct ggml_tensor * counts = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, nmat); - ggml_set_name(sums, (name + ".sums").c_str()); - ggml_set_name(counts, (name + ".counts").c_str()); + ggml_format_name(sums, "%s.sums", name.c_str()); + ggml_format_name(counts, "%s.counts", name.c_str()); for (int32_t j = 0; j < nval; ++j) { ((float *) sums->data)[j] = (float) stat.values[j]; @@ -338,7 +347,7 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { return false; } const int32_t n_entries = gguf_get_n_tensors(ctx_gguf); - if (n_entries < 2) { + if (n_entries < 1) { fprintf(stderr, "%s: no data in file %s\n", __func__, file_name); gguf_free(ctx_gguf); ggml_free(ctx); @@ -348,51 +357,73 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { const std::string sums_suffix{".sums"}; const std::string counts_suffix{".counts"}; - // TODO: allow loading from mis-ordered imatrix files - for (int32_t i = 0; i < n_entries - 1; i += 2) { - std::string sums_name{gguf_get_tensor_name(ctx_gguf, i + 0)}; - std::string counts_name{gguf_get_tensor_name(ctx_gguf, i + 1)}; - - if (sums_name.size() < sums_suffix.size() || - counts_name.size() < counts_suffix.size() || - !std::equal(sums_name.begin(), sums_name.end() - sums_suffix.size(), counts_name.begin()) || - !std::equal(sums_suffix.rbegin(), sums_suffix.rend(), sums_name.rbegin()) || - !std::equal(counts_suffix.rbegin(), counts_suffix.rend(), counts_name.rbegin())) { - fprintf(stderr, "%s: mismatched sums and counts for entry %d\n", __func__, i / 2); + // Could re-use m_stats instead, but this allows + // checking for completeness of *each* loaded imatrix file + // and also makes it easier to re-use a similar implementation in quantize.cpp + // Using an ordered map to get a deterministic iteration order. + std::map> sums_counts_for; + + for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string name = cur->name; + + if (name.empty()) { continue; } + + if (str_remove_suffix(name, sums_suffix)) { + // sums + sums_counts_for[name].first = cur; + } else if (str_remove_suffix(name, counts_suffix)) { + // counts + sums_counts_for[name].second = cur; + } else { + fprintf(stderr, "%s: invalid imatrix tensor name: %s\n", __func__, name.c_str()); gguf_free(ctx_gguf); ggml_free(ctx); return false; } + } + + for (const auto & sc : sums_counts_for) { + const std::string & name = sc.first; + const struct ggml_tensor * sums = sc.second.first; + const struct ggml_tensor * counts = sc.second.second; - struct ggml_tensor * sums = ggml_get_tensor(ctx, sums_name.c_str()); - struct ggml_tensor * counts = ggml_get_tensor(ctx, counts_name.c_str()); if (!sums || !counts) { - fprintf(stderr, "%s: failed reading data for entry %d\n", __func__, i / 2); + fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str()); gguf_free(ctx_gguf); ggml_free(ctx); return false; } - std::string name = sums_name.substr(0, sums_name.size() - sums_suffix.size()); auto & e = m_stats[name]; - int32_t nval = ggml_nelements(sums); + int64_t nval = ggml_nelements(sums); if (e.values.empty()) { e.values.resize(nval, 0); + } else if ((size_t) nval != e.values.size()) { + fprintf(stderr, "%s: mismatched sums size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) nval, e.values.size()); + gguf_free(ctx_gguf); + ggml_free(ctx); + return false; } - int32_t ncounts = ggml_nelements(counts); + + int64_t ncounts = ggml_nelements(counts); if (e.counts.empty()) { e.counts.resize(ncounts, 0); } else if (e.counts.size() == 1 && ncounts > 1) { // broadcast, when loading an old imatrix e.counts.resize(ncounts, e.counts[0]); + } else if ((size_t) ncounts != e.counts.size()) { + fprintf(stderr, "%s: mismatched counts size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) ncounts, e.counts.size()); + gguf_free(ctx_gguf); + ggml_free(ctx); + return false; } // Recreate the state as expected by save_imatrix() - for (int32_t j = 0; j < nval; j++) { + for (int64_t j = 0; j < nval; j++) { e.values[j] += ((const float *) sums->data)[j]; } - for (int32_t j = 0; j < ncounts; j++) { + for (int64_t j = 0; j < ncounts; j++) { e.counts[j] += std::lround(((const float *) counts->data)[j]); } } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 78f55c4dfe556..99887cc7e1590 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -6,6 +6,7 @@ #include #include #include +#include struct quant_option { std::string name; @@ -125,6 +126,15 @@ static void usage(const char * executable) { exit(1); } +// TODO: share with implementation in imatrix.cpp +static bool str_remove_suffix(std::string & str, const std::string & suffix) { + bool has_suffix = str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), str.size(), suffix) == 0; + if (has_suffix) { + str = str.substr(0, str.size() - suffix.size()); + } + return has_suffix; +} + static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_dataset, std::unordered_map> & imatrix_data) { struct ggml_context * ctx = nullptr; @@ -138,7 +148,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ exit(1); } const int32_t n_entries = gguf_get_n_tensors(ctx_gguf); - if (n_entries < 2) { + if (n_entries < 1) { fprintf(stderr, "%s: no data in file %s\n", __func__, imatrix_file.c_str()); gguf_free(ctx_gguf); ggml_free(ctx); @@ -160,26 +170,35 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ const std::string sums_suffix{".sums"}; const std::string counts_suffix{".counts"}; - // TODO: allow loading from mis-ordered imatrix files - for (int32_t i = 0; i < n_entries - 1; i += 2) { - std::string sums_name{gguf_get_tensor_name(ctx_gguf, i + 0)}; - std::string counts_name{gguf_get_tensor_name(ctx_gguf, i + 1)}; - - if (sums_name.size() < sums_suffix.size() || - counts_name.size() < counts_suffix.size() || - !std::equal(sums_name.begin(), sums_name.end() - sums_suffix.size(), counts_name.begin()) || - !std::equal(sums_suffix.rbegin(), sums_suffix.rend(), sums_name.rbegin()) || - !std::equal(counts_suffix.rbegin(), counts_suffix.rend(), counts_name.rbegin())) { - fprintf(stderr, "%s: mismatched sums and counts for entry %d\n", __func__, i / 2); + // Using an ordered map to get a deterministic iteration order. + std::map> sums_counts_for; + + for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string name = cur->name; + + if (name.empty()) { continue; } + + if (str_remove_suffix(name, sums_suffix)) { + // sums + sums_counts_for[name].first = cur; + } else if (str_remove_suffix(name, counts_suffix)) { + // counts + sums_counts_for[name].second = cur; + } else { + fprintf(stderr, "%s: invalid imatrix tensor name: %s\n", __func__, name.c_str()); gguf_free(ctx_gguf); ggml_free(ctx); exit(1); } + } + + for (const auto & sc : sums_counts_for) { + const std::string & name = sc.first; + const struct ggml_tensor * sums = sc.second.first; + const struct ggml_tensor * counts = sc.second.second; - struct ggml_tensor * sums = ggml_get_tensor(ctx, sums_name.c_str()); - struct ggml_tensor * counts = ggml_get_tensor(ctx, counts_name.c_str()); if (!sums || !counts) { - fprintf(stderr, "%s: failed reading data for entry %d\n", __func__, i / 2); + fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str()); gguf_free(ctx_gguf); ggml_free(ctx); exit(1); @@ -187,7 +206,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ const int64_t ne0 = sums->ne[0]; const int64_t ne1 = sums->ne[1]; - std::string name = sums_name.substr(0, sums_name.size() - sums_suffix.size()); + auto & e = imatrix_data[name]; e.resize(ggml_nelements(sums)); float max_count = 0.0f; @@ -201,7 +220,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ } } if (getenv("LLAMA_TRACE")) { - printf("%s: loaded data (size = %6d, ncall = %6d) for '%s'\n", __func__, int(e.size()), int(max_count / chunk_size), name.c_str()); + printf("%s: loaded data (size = %6d, n_tokens = %6d) for '%s'\n", __func__, int(e.size()), int(max_count), name.c_str()); } } From 2d79a7077cb9a7218c1f40d637233658db7349e0 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 10 Sep 2024 12:09:17 -0400 Subject: [PATCH 12/25] quantize : use unused imatrix chunk_size with LLAMA_TRACE --- examples/quantize/quantize.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 99887cc7e1590..0cde695ed5046 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -220,7 +220,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ } } if (getenv("LLAMA_TRACE")) { - printf("%s: loaded data (size = %6d, n_tokens = %6d) for '%s'\n", __func__, int(e.size()), int(max_count), name.c_str()); + printf("%s: loaded data (size = %6d, n_tokens = %6d, n_chunks = %6d) for '%s'\n", __func__, int(e.size()), int(max_count), int(max_count / chunk_size), name.c_str()); } } From c7a32e761dc96c559d6227b1fe4996ae1445b07a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 30 Jan 2025 19:56:20 -0500 Subject: [PATCH 13/25] common : use GGUF for imatrix output by default --- common/common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/common.h b/common/common.h index d7c08f20a124b..d10ec6235ed18 100644 --- a/common/common.h +++ b/common/common.h @@ -275,7 +275,7 @@ struct gpt_params { int32_t i_pos = -1; // position of the passkey in the junk text // imatrix params - std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file + std::string out_file = "imatrix.gguf"; // save the resulting imatrix to this file int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations From a5165a6ca93a16270deda5feea7a1ae3f876b793 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 15 Apr 2025 17:29:57 -0400 Subject: [PATCH 14/25] imatrix : two-way conversion between old format and GGUF --- examples/imatrix/imatrix.cpp | 293 ++++++++++++++++++++++++++++----- examples/quantize/quantize.cpp | 145 ++++++++++++---- 2 files changed, 365 insertions(+), 73 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 1b537407f5fcb..f49bf9ec41e18 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -29,15 +30,19 @@ static void print_usage(int, char ** argv) { LOG("\n"); } +static bool str_has_suffix(const std::string & str, const std::string & suffix) { + return str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), str.size(), suffix) == 0; +} + static bool str_remove_suffix(std::string & str, const std::string & suffix) { - bool has_suffix = str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), str.size(), suffix) == 0; + bool has_suffix = str_has_suffix(str, suffix); if (has_suffix) { str = str.substr(0, str.size() - suffix.size()); } return has_suffix; } -static const char * const LLM_KV_IMATRIX_DATASET = "imatrix.dataset"; +static const char * const LLM_KV_IMATRIX_DATASETS = "imatrix.datasets"; static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count"; static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size"; @@ -51,12 +56,15 @@ class IMatrixCollector { IMatrixCollector() = default; void set_params(common_params params) { m_params = std::move(params); } bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); + void save_imatrix_legacy(int32_t ncall = -1) const; void save_imatrix(int32_t n_chunk = -1) const; + bool load_imatrix_legacy(const char * fname); bool load_imatrix(const char * file_name); private: std::unordered_map m_stats; common_params m_params; std::mutex m_mutex; + std::vector m_datasets; int32_t m_last_chunk = 0; std::vector m_src1_data; std::vector m_ids; // the expert ids from ggml_mul_mat_id @@ -88,6 +96,8 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * const struct ggml_tensor * src1 = t->src[1]; std::string wname = filter_tensor_name(src0->name); + const int32_t chunk_size = m_params.n_ctx / m_params.n_parallel; + // when ask is true, the scheduler wants to know if we are interested in data from this tensor // if we return true, a follow-up call will be made with ask=false in which we can do the actual collection if (ask) { @@ -175,7 +185,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * } } } - const int32_t n_chunk = e.counts[ex] / (m_params.n_ctx / m_params.n_parallel); + const int32_t n_chunk = e.counts[ex] / chunk_size; if (n_chunk > m_last_chunk) { const int32_t chunk_step = n_chunk - m_last_chunk; m_last_chunk = n_chunk; @@ -214,7 +224,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * } } } - const int32_t n_chunk = e.counts[0] / (m_params.n_ctx / m_params.n_parallel); + const int32_t n_chunk = e.counts[0] / chunk_size; if (n_chunk > m_last_chunk) { const int32_t chunk_step = n_chunk - m_last_chunk; m_last_chunk = n_chunk; @@ -230,19 +240,19 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * return true; } -void IMatrixCollector::save_imatrix(int32_t n_chunk) const { +void IMatrixCollector::save_imatrix_legacy(int32_t ncall) const { auto fname = m_params.out_file; - if (n_chunk > 0) { + if (ncall > 0) { fname += ".at_"; - fname += std::to_string(n_chunk); + fname += std::to_string(ncall); } // avoid writing imatrix entries that do not have full data // this can happen with MoE models where some of the experts end up not being exercised by the provided training data + int n_entries = 0; std::vector to_store; - size_t data_size = 0; bool is_first = true; // for printing for (const auto & kv : m_stats) { @@ -274,9 +284,8 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { continue; } + n_entries++; to_store.push_back(kv.first); - data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.values.size(), GGML_MEM_ALIGN); - data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.counts.size(), GGML_MEM_ALIGN); } if (to_store.size() < m_stats.size()) { @@ -286,6 +295,79 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { // deterministic tensor name order std::sort(to_store.begin(), to_store.end()); + const int32_t chunk_size = m_params.n_ctx / m_params.n_parallel; + + std::ofstream out(fname, std::ios::binary); + out.write((const char *) &n_entries, sizeof(n_entries)); + for (const auto & name : to_store) { + const auto & stat = m_stats.at(name); + const int32_t len = name.size(); + out.write((const char *) &len, sizeof(len)); + out.write(name.c_str(), len); + const int32_t ncall = *std::max_element(stat.counts.begin(), stat.counts.end()) / chunk_size; + out.write((const char *) &ncall, sizeof(ncall)); + const int32_t nval = stat.values.size(); + const int32_t nmat = stat.counts.size(); + out.write((const char *) &nval, sizeof(nval)); + if (nval > 0 && nmat > 0) { + std::vector tmp(nval); + for (int32_t i = 0; i < nval; i++) { + const float counts = static_cast(stat.counts[i / (nval / nmat)]); + tmp[i] = (stat.values[i] / counts) * static_cast(ncall); + } + out.write((const char *) tmp.data(), nval * sizeof(float)); + } + } + + // Write the number of call the matrix was computed with + out.write((const char *) &m_last_chunk, sizeof(m_last_chunk)); + + // Write the input filename at the end of the file to later on specify it in quantize + { + const char * dataset_file = m_params.prompt_file.c_str(); + int32_t len = m_params.prompt_file.size(); + // When there is no prompt but there were other imatrix files loaded, use the last dataset + if (m_params.prompt_file.empty() && !m_datasets.empty()) { + const std::string & dataset_str = m_datasets[m_datasets.size() - 1]; + dataset_file = dataset_str.c_str(); + len = dataset_str.size(); + } + out.write((const char *) &len, sizeof(len)); + out.write(dataset_file, len); + } + + LOGV(1, "\n"); + LOG_DBGV(1, "%s: stored collected data after %d chunks in %s\n", __func__, m_last_chunk, fname.c_str()); +} + +void IMatrixCollector::save_imatrix(int32_t n_chunk) const { + auto fname = m_params.out_file; + + // TODO: use the new format by default also for .imatrix + if (!str_has_suffix(fname, ".gguf")) { + return this->save_imatrix_legacy(n_chunk); + } + + if (n_chunk > 0) { + fname += ".at_"; + fname += std::to_string(n_chunk); + } + + // write imatrix entries even if they don't have full data. (can be corrected when reading) + // this can happen with MoE models where some of the experts end up not being exercised by the provided training data + + std::vector to_store; + size_t data_size = 0; + + for (const auto & kv : m_stats) { + to_store.push_back(kv.first); + data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.values.size(), GGML_MEM_ALIGN); + data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.counts.size(), GGML_MEM_ALIGN); + } + + // deterministic tensor name order + std::sort(to_store.begin(), to_store.end()); + struct ggml_init_params params = { /* .mem_size = */ data_size, /* .mem_buffer = */ NULL, @@ -294,31 +376,42 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { struct ggml_context * ctx = ggml_init(params); struct gguf_context * ctx_gguf = gguf_init_empty(); - gguf_set_val_str(ctx_gguf, "general.type", "imatrix"); - // Write the input filename to later on specify it in quantize - gguf_set_val_str(ctx_gguf, LLM_KV_IMATRIX_DATASET, m_params.prompt_file.c_str()); - // Write the number of chunks the matrix was computed with - gguf_set_val_u32(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT, m_last_chunk); - gguf_set_val_u32(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE, m_params.n_ctx / m_params.n_parallel); + { + std::vector datasets; + datasets.reserve(m_datasets.size() + 1); + for (size_t i = 0; i < m_datasets.size(); ++i) { + datasets.push_back(m_datasets[i].c_str()); + } + if (!m_params.prompt_file.empty()) { + datasets.push_back(m_params.prompt_file.c_str()); + } + + gguf_set_val_str(ctx_gguf, "general.type", "imatrix"); + // Write the dataset paths + gguf_set_arr_str(ctx_gguf, LLM_KV_IMATRIX_DATASETS, datasets.data(), datasets.size()); + // Write the number of chunks the matrix was computed with + gguf_set_val_u32(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT, m_last_chunk); + gguf_set_val_u32(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE, m_params.n_ctx / m_params.n_parallel); + } for (const auto & name : to_store) { const auto & stat = m_stats.at(name); const int32_t nval = (int32_t) stat.values.size(); const int32_t nmat = (int32_t) stat.counts.size(); - if (nval > 0) { - struct ggml_tensor * sums = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nval / nmat, nmat); - struct ggml_tensor * counts = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, nmat); - ggml_format_name(sums, "%s.sums", name.c_str()); + if (nval > 0 && nmat > 0) { + struct ggml_tensor * in_sum2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nval / nmat, nmat); + struct ggml_tensor * counts = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, nmat); + ggml_format_name(in_sum2, "%s.in_sum2", name.c_str()); ggml_format_name(counts, "%s.counts", name.c_str()); for (int32_t j = 0; j < nval; ++j) { - ((float *) sums->data)[j] = (float) stat.values[j]; + ((float *) in_sum2->data)[j] = (float) stat.values[j]; } for (int32_t j = 0; j < nmat; ++j) { ((float *) counts->data)[j] = (float) stat.counts[j]; } - gguf_add_tensor(ctx_gguf, sums); + gguf_add_tensor(ctx_gguf, in_sum2); gguf_add_tensor(ctx_gguf, counts); } } @@ -332,6 +425,105 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { ggml_free(ctx); } +bool IMatrixCollector::load_imatrix_legacy(const char * fname) { + std::ifstream in(fname, std::ios::binary); + if (!in) { + LOG_ERR("%s: failed to open %s\n", __func__, fname); + return false; + } + int n_entries; + in.read((char *) &n_entries, sizeof(n_entries)); + if (in.fail() || n_entries < 1) { + LOG_ERR("%s: no data in file %s\n", __func__, fname); + return false; + } + // Guess the chunk size because it's not stored in the file + const int32_t chunk_size = m_params.n_ctx / m_params.n_parallel; + + for (int i = 0; i < n_entries; ++i) { + int32_t len = 0; + in.read((char *) &len, sizeof(len)); + std::vector name_as_vec(len + 1); + in.read((char *) name_as_vec.data(), len); + if (in.fail()) { + LOG_ERR("%s: failed reading name for entry %d from %s\n", __func__, i + 1, fname); + return false; + } + name_as_vec[len] = 0; + std::string name{ name_as_vec.data() }; + auto & e = m_stats[std::move(name)]; + int32_t ncall = 0; + in.read((char *) &ncall, sizeof(ncall)); + int32_t nval = 0; + in.read((char *) &nval, sizeof(nval)); + if (in.fail() || nval < 1) { + LOG_ERR("%s: failed reading number of values for entry %d\n", __func__, i); + m_stats = {}; + return false; + } + + if (e.values.empty()) { + e.values.resize(nval, 0.0f); + e.counts.resize(1, 0); + } + + std::vector tmp(nval); + in.read((char *) tmp.data(), nval * sizeof(float)); + if (in.fail()) { + LOG_ERR("%s: failed reading data for entry %d\n", __func__, i); + m_stats = {}; + return false; + } + + // Recreate the state as expected by save_imatrix(), and correct for weighted sum. + for (int i = 0; i < nval; i++) { + e.values[i] += tmp[i] * chunk_size; + } + // The legacy format doesn't distinguish the counts for different experts + for (size_t j = 0; j < e.counts.size(); ++j) { + e.counts[j] += ncall * chunk_size; + } + } + + { + // TODO: extract into its own method; this is also used by the GGUF-based format + // Calculate the last chunk count + int64_t max_count = 0; + for (const auto & stats : m_stats) { + for (int64_t count : stats.second.counts) { + if (count > max_count) { + max_count = count; + } + } + } + m_last_chunk = max_count / (chunk_size); + } + + { + // Read the number of calls the matrix was computed with + int32_t n_calls; + in.read((char *) &n_calls, sizeof(n_calls)); + // ignore it because it's not important + } + + // Read the dataset path to include it when writing to GGUF + if (!in.fail()){ + int32_t len = 0; + in.read((char *) &len, sizeof(len)); + if (!in.fail()) { + std::vector dataset; + dataset.resize(len + 1, 0); + in.read(dataset.data(), len); + if (!in.fail()) { + m_datasets.push_back(dataset.data()); + } + } + } + + return true; +} + +// Using GGUF as the file format, for greater extensibility bool IMatrixCollector::load_imatrix(const char * file_name) { struct ggml_context * ctx = nullptr; struct gguf_init_params meta_gguf_params = { @@ -340,7 +532,7 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { }; struct gguf_context * ctx_gguf = gguf_init_from_file(file_name, meta_gguf_params); if (!ctx_gguf) { - return false; + return this->load_imatrix_legacy(file_name); } const int32_t n_entries = gguf_get_n_tensors(ctx_gguf); if (n_entries < 1) { @@ -350,8 +542,17 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { return false; } - const std::string sums_suffix{".sums"}; - const std::string counts_suffix{".counts"}; + const int64_t datasets_key = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_DATASETS); + if (datasets_key != -1 && gguf_get_arr_type(ctx_gguf, datasets_key) == GGUF_TYPE_STRING) { + const int64_t n = gguf_get_arr_n(ctx_gguf, datasets_key); + m_datasets.reserve(m_datasets.size() + n); + for (int64_t i = 0; i < n; ++i) { + m_datasets.push_back(gguf_get_arr_str(ctx_gguf, datasets_key, i)); + } + } + + const std::string in_sum2_suffix{ ".in_sum2" }; + const std::string counts_suffix{ ".counts" }; // Could re-use m_stats instead, but this allows // checking for completeness of *each* loaded imatrix file @@ -364,26 +565,23 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { if (name.empty()) { continue; } - if (str_remove_suffix(name, sums_suffix)) { - // sums - sums_counts_for[name].first = cur; + if (str_remove_suffix(name, in_sum2_suffix)) { + // in_sum2 + sums_counts_for[std::move(name)].first = cur; } else if (str_remove_suffix(name, counts_suffix)) { // counts - sums_counts_for[name].second = cur; + sums_counts_for[std::move(name)].second = cur; } else { - LOG_ERR("%s: invalid imatrix tensor name: %s\n", __func__, name.c_str()); - gguf_free(ctx_gguf); - ggml_free(ctx); - return false; + // ignore other tensors } } for (const auto & sc : sums_counts_for) { - const std::string & name = sc.first; - const struct ggml_tensor * sums = sc.second.first; - const struct ggml_tensor * counts = sc.second.second; + const std::string & name = sc.first; + const struct ggml_tensor * in_sum2 = sc.second.first; + const struct ggml_tensor * counts = sc.second.second; - if (!sums || !counts) { + if (!in_sum2 || !counts) { LOG_ERR("%s: mismatched sums and counts for %s\n", __func__, name.c_str()); gguf_free(ctx_gguf); ggml_free(ctx); @@ -392,9 +590,9 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { auto & e = m_stats[name]; - int64_t nval = ggml_nelements(sums); + int64_t nval = ggml_nelements(in_sum2); if (e.values.empty()) { - e.values.resize(nval, 0); + e.values.resize(nval, 0.0f); } else if ((size_t) nval != e.values.size()) { LOG_ERR("%s: mismatched sums size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) nval, e.values.size()); gguf_free(ctx_gguf); @@ -417,12 +615,25 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { // Recreate the state as expected by save_imatrix() for (int64_t j = 0; j < nval; j++) { - e.values[j] += ((const float *) sums->data)[j]; + e.values[j] += ((const float *) in_sum2->data)[j]; } for (int64_t j = 0; j < ncounts; j++) { e.counts[j] += std::lround(((const float *) counts->data)[j]); } } + + // TODO: extract into its own method; this is also used by the legacy format + // Calculate the last chunk count + int64_t max_count = 0; + for (const auto & stats : m_stats) { + for (int64_t count : stats.second.counts) { + if (count > max_count) { + max_count = count; + } + } + } + m_last_chunk = max_count / (m_params.n_ctx / m_params.n_parallel); + gguf_free(ctx_gguf); ggml_free(ctx); return true; @@ -685,7 +896,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c int main(int argc, char ** argv) { common_params params; - params.out_file = "imatrix.dat" ; + params.out_file = "imatrix.gguf" ; params.n_ctx = 512; params.logits_all = true; diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 7f2afe6575677..1a37cf316f4de 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -64,7 +64,7 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS = "quantize.imatrix.chunks_count"; // TODO: share with imatrix.cpp -static const char * const LLM_KV_IMATRIX_DATASET = "imatrix.dataset"; +static const char * const LLM_KV_IMATRIX_DATASETS = "imatrix.datasets"; static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count"; static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size"; @@ -84,7 +84,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp for (auto ch : ftype_str_in) { ftype_str.push_back(std::toupper(ch)); } - for (auto & it : QUANT_OPTIONS) { + for (const auto & it : QUANT_OPTIONS) { if (striequals(it.name.c_str(), ftype_str.c_str())) { ftype = it.ftype; ftype_str_out = it.name; @@ -93,7 +93,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp } try { int ftype_int = std::stoi(ftype_str); - for (auto & it : QUANT_OPTIONS) { + for (const auto & it : QUANT_OPTIONS) { if (it.ftype == ftype_int) { ftype = it.ftype; ftype_str_out = it.name; @@ -126,7 +126,7 @@ static void usage(const char * executable) { printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n"); printf("Note: --include-weights and --exclude-weights cannot be used together\n"); printf("\nAllowed quantization types:\n"); - for (auto & it : QUANT_OPTIONS) { + for (const auto & it : QUANT_OPTIONS) { if (it.name != "COPY") { printf(" %2d or ", it.ftype); } else { @@ -146,7 +146,71 @@ static bool str_remove_suffix(std::string & str, const std::string & suffix) { return has_suffix; } -static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_dataset, std::unordered_map> & imatrix_data) { +static int load_legacy_imatrix(const std::string & imatrix_file, std::vector & imatrix_datasets, std::unordered_map> & imatrix_data) { + std::ifstream in(imatrix_file.c_str(), std::ios::binary); + if (!in) { + printf("%s: failed to open %s\n",__func__, imatrix_file.c_str()); + exit(1); + } + int n_entries; + in.read((char *)&n_entries, sizeof(n_entries)); + if (in.fail() || n_entries < 1) { + printf("%s: no data in file %s\n", __func__, imatrix_file.c_str()); + exit(1); + } + for (int i = 0; i < n_entries; ++i) { + int len; in.read((char *)&len, sizeof(len)); + std::vector name_as_vec(len+1); + in.read((char *)name_as_vec.data(), len); + if (in.fail()) { + printf("%s: failed reading name for entry %d from %s\n", __func__, i+1, imatrix_file.c_str()); + exit(1); + } + name_as_vec[len] = 0; + std::string name{name_as_vec.data()}; + auto & e = imatrix_data[name]; + int ncall; + in.read((char *)&ncall, sizeof(ncall)); + int nval; + in.read((char *)&nval, sizeof(nval)); + if (in.fail() || nval < 1) { + printf("%s: failed reading number of values for entry %d\n", __func__, i); + imatrix_data = {}; + exit(1); + } + e.resize(nval); + in.read((char *)e.data(), nval*sizeof(float)); + if (in.fail()) { + printf("%s: failed reading data for entry %d\n", __func__, i); + imatrix_data = {}; + exit(1); + } + if (ncall > 0) { + for (auto& v : e) v /= ncall; + } + + if (getenv("LLAMA_TRACE")) { + printf("%s: loaded data (size = %6d, ncall = %6d) for '%s'\n", __func__, int(e.size()), ncall, name.c_str()); + } + } + + // latest imatrix version contains the dataset filename at the end of the file + int m_last_call = 0; + if (in.peek() != EOF) { + in.read((char *)&m_last_call, sizeof(m_last_call)); + int dataset_len; + in.read((char *)&dataset_len, sizeof(dataset_len)); + std::vector dataset_as_vec(dataset_len); + in.read(dataset_as_vec.data(), dataset_len); + imatrix_datasets.resize(1); + imatrix_datasets[0].assign(dataset_as_vec.begin(), dataset_as_vec.end()); + printf("%s: imatrix dataset='%s'\n", __func__, imatrix_datasets[0].c_str()); + } + printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_call); + return m_last_call; +} + +static int load_imatrix(const std::string & imatrix_file, std::vector & imatrix_datasets, std::unordered_map> & imatrix_data) { struct ggml_context * ctx = nullptr; struct gguf_init_params meta_gguf_params = { @@ -155,8 +219,8 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ }; struct gguf_context * ctx_gguf = gguf_init_from_file(imatrix_file.c_str(), meta_gguf_params); if (!ctx_gguf) { - fprintf(stderr, "%s: if this is an older imatrix file, make sure to convert it to the GGUF-based imatrix format\n", __func__); - exit(1); + fprintf(stderr, "%s: imatrix file '%s' is using old format\n", __func__, imatrix_file.c_str()); + return load_legacy_imatrix(imatrix_file, imatrix_datasets, imatrix_data); } const int32_t n_entries = gguf_get_n_tensors(ctx_gguf); if (n_entries < 1) { @@ -166,7 +230,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ exit(1); } - const int dataset_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_DATASET); + const int dataset_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_DATASETS); const int chunk_count_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT); const int chunk_size_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE); if (dataset_idx < 0 || chunk_count_idx < 0 || chunk_size_idx < 0) { @@ -178,8 +242,8 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ const uint32_t chunk_size = gguf_get_val_u32(ctx_gguf, chunk_size_idx); - const std::string sums_suffix{".sums"}; - const std::string counts_suffix{".counts"}; + const std::string sums_suffix{ ".in_sum2" }; + const std::string counts_suffix{ ".counts" }; // Using an ordered map to get a deterministic iteration order. std::map> sums_counts_for; @@ -190,16 +254,13 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ if (name.empty()) { continue; } if (str_remove_suffix(name, sums_suffix)) { - // sums - sums_counts_for[name].first = cur; + // in_sum2 + sums_counts_for[std::move(name)].first = cur; } else if (str_remove_suffix(name, counts_suffix)) { // counts - sums_counts_for[name].second = cur; + sums_counts_for[std::move(name)].second = cur; } else { - fprintf(stderr, "%s: invalid imatrix tensor name: %s\n", __func__, name.c_str()); - gguf_free(ctx_gguf); - ggml_free(ctx); - exit(1); + // ignore other tensors } } @@ -223,8 +284,15 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ float max_count = 0.0f; for (int64_t j = 0; j < ne1; ++j) { const float count = ((const float *) counts->data)[j]; - for (int64_t i = 0; i < ne0; ++i) { - e[j*ne0 + i] = ((const float *) sums->data)[j*ne0 + i] / count; + if (count > 0.0f) { + for (int64_t i = 0; i < ne0; ++i) { + e[j*ne0 + i] = ((const float *) sums->data)[j*ne0 + i] / count; + } + } else { + // Partial imatrix data, this tensor never got any input during calibration + for (int64_t i = 0; i < ne0; ++i) { + e[j*ne0 + i] = 1; + } } if (count > max_count) { max_count = count; @@ -236,9 +304,18 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ } int m_last_chunk = gguf_get_val_u32(ctx_gguf, chunk_count_idx); - imatrix_dataset = gguf_get_val_str(ctx_gguf, dataset_idx); - printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str()); + int64_t n_datasets = gguf_get_arr_n(ctx_gguf, dataset_idx); + imatrix_datasets.resize(n_datasets); + for (int64_t i = 0; i < n_datasets; ++i) { + imatrix_datasets.push_back(gguf_get_val_str(ctx_gguf, dataset_idx)); + } + printf("%s: imatrix datasets=['%s'", __func__, imatrix_datasets[0].c_str()); + for (size_t i = 1; i < imatrix_datasets.size(); ++i) { + printf(", '%s'", imatrix_datasets[i].c_str()); + } + printf("]\n"); + printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_chunk); gguf_free(ctx_gguf); @@ -248,7 +325,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_ } static int prepare_imatrix(const std::string & imatrix_file, - std::string & imatrix_dataset, + std::vector & imatrix_dataset, const std::vector & included_weights, const std::vector & excluded_weights, std::unordered_map> & imatrix_data) { @@ -260,18 +337,21 @@ static int prepare_imatrix(const std::string & imatrix_file, return m_last_call; } if (!excluded_weights.empty()) { - for (auto& name : excluded_weights) { - for (auto it = imatrix_data.begin(); it != imatrix_data.end(); ) { + for (const auto & name : excluded_weights) { + for (auto it = imatrix_data.begin(); it != imatrix_data.end();) { auto pos = it->first.find(name); - if (pos != std::string::npos) it = imatrix_data.erase(it); - else ++it; + if (pos != std::string::npos) { + it = imatrix_data.erase(it); + } else { + ++it; + } } } } if (!included_weights.empty()) { std::unordered_map> tmp; - for (auto& name : included_weights) { - for (auto& e : imatrix_data) { + for (const auto & name : included_weights) { + for (auto & e : imatrix_data) { auto pos = e.first.find(name); if (pos != std::string::npos) { tmp.emplace(std::move(e)); @@ -372,9 +452,9 @@ int main(int argc, char ** argv) { usage(argv[0]); } - std::string imatrix_dataset; + std::vector imatrix_datasets; std::unordered_map> imatrix_data; - int m_last_call = prepare_imatrix(imatrix_file, imatrix_dataset, included_weights, excluded_weights, imatrix_data); + int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, imatrix_data); if (!imatrix_data.empty()) { params.imatrix = &imatrix_data; { @@ -385,11 +465,12 @@ int main(int argc, char ** argv) { kvo.val_str[127] = '\0'; kv_overrides.emplace_back(std::move(kvo)); } - if (!imatrix_dataset.empty()) { + if (!imatrix_datasets.empty()) { llama_model_kv_override kvo; + // TODO: list multiple datasets when there are more than one std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_DATASET); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; - strncpy(kvo.val_str, imatrix_dataset.c_str(), 127); + strncpy(kvo.val_str, imatrix_datasets[0].c_str(), 127); kvo.val_str[127] = '\0'; kv_overrides.emplace_back(std::move(kvo)); } From 635f945ed12ae95d55c733f81e8d96e4802a2a93 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 15 Apr 2025 17:42:26 -0400 Subject: [PATCH 15/25] convert : remove imatrix to gguf python script --- convert_legacy_imatrix_to_gguf.py | 122 ------------------ gguf-py/gguf/constants.py | 2 +- requirements.txt | 1 - ...rements-convert_legacy_imatrix_to_gguf.txt | 1 - 4 files changed, 1 insertion(+), 125 deletions(-) delete mode 100644 convert_legacy_imatrix_to_gguf.py delete mode 100644 requirements/requirements-convert_legacy_imatrix_to_gguf.txt diff --git a/convert_legacy_imatrix_to_gguf.py b/convert_legacy_imatrix_to_gguf.py deleted file mode 100644 index bd72655bf2cc7..0000000000000 --- a/convert_legacy_imatrix_to_gguf.py +++ /dev/null @@ -1,122 +0,0 @@ -#!/usr/bin/env python3 - -from __future__ import annotations - -import os -import sys -import logging -import argparse - -from typing import Any -from pathlib import Path -from dataclasses import dataclass - -import numpy as np -import numpy.typing as npt - -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) -import gguf - - -logger = logging.getLogger("imatrix-to-gguf") - - -class IMatrixWriter(gguf.GGUFWriter): - def add_architecture(self) -> None: - # no arch is stored in imatrix files - pass - - -@dataclass -class IMatrixEntry: - values: np.ndarray[Any, np.dtype[np.float32]] - counts: np.ndarray[Any, np.dtype[np.float32]] - - -class IMatrixReader: - chunk_size: int = 512 # guess - offset: int = 0 - data: np.ndarray[Any, np.dtype[np.uint8]] - n_enties: int - entries: dict[str, IMatrixEntry] - chunk_count: int - dataset: str - - def _get(self, dtype: npt.DTypeLike, count: int = 1) -> npt.NDArray[Any]: - count = int(count) - itemsize = int(np.empty([], dtype=dtype).itemsize) - offset = self.offset - self.offset = offset + itemsize * count - return self.data[offset:self.offset].view(dtype=dtype)[:count] - - def __init__(self, imatrix: Path): - self.offset = 0 - self.entries = {} - self.data = np.memmap(imatrix) - n_entries = self._get(np.int32).item() - assert n_entries >= 0 - for _ in range(n_entries): - len = self._get(np.int32).item() - name = self._get(np.uint8, len).tobytes().decode("utf-8") - ncall = self._get(np.int32).item() - nval = self._get(np.int32).item() - data = self._get(np.float32, nval) - assert name not in self.entries, f"duplicated name: {name!r}" - - self.entries[name] = IMatrixEntry(data * np.float32(self.chunk_size), np.array([ncall * self.chunk_size], dtype=np.float32)) - - self.chunk_count = self._get(np.int32).item() - dataset_len = self._get(np.int32).item() - self.dataset = self._get(np.uint8, dataset_len).tobytes().decode("utf-8") - - def to_writer(self, outfile: Path) -> IMatrixWriter: - writer = IMatrixWriter(path=outfile, arch="") - - writer.add_type(gguf.GGUFType.IMATRIX) - writer.add_key_value(gguf.Keys.IMatrix.CHUNK_COUNT, self.chunk_count, gguf.GGUFValueType.UINT32) - writer.add_key_value(gguf.Keys.IMatrix.CHUNK_SIZE, self.chunk_size, gguf.GGUFValueType.UINT32) - writer.add_key_value(gguf.Keys.IMatrix.DATASET, self.dataset, gguf.GGUFValueType.STRING) - - for name, entry in self.entries.items(): - writer.add_tensor(name + ".sums", entry.values) - writer.add_tensor(name + ".counts", entry.counts) - - return writer - - -def parse_args(): - parser = argparse.ArgumentParser( - description="Convert an old imatrix.dat file to a GGUF compatible file") - parser.add_argument( - "--outfile", type=Path, - help="path to write to; default: based on input.", - ) - parser.add_argument( - "--verbose", action="store_true", - help="increase output verbosity", - ) - parser.add_argument( - "imatrix", type=Path, - help="path to an imatrix file", - ) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - - if args.outfile is None: - input_file: Path = args.imatrix - if input_file.suffix != ".gguf": - args.outfile = input_file.with_suffix(".gguf") - if args.outfile.exists(): - logger.error(f"default file exists, specify with --outfile to overwrite: {args.outfile}") - exit(1) - - writer = IMatrixReader(args.imatrix).to_writer(args.outfile) - - writer.write_header_to_file(args.outfile) - writer.write_kv_data_to_file() - writer.write_tensors_to_file() diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 5b4b0e9edb48a..c2dbf7b643dc5 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -219,7 +219,7 @@ class Adapter: class IMatrix: CHUNK_COUNT = "imatrix.chunk_count" CHUNK_SIZE = "imatrix.chunk_size" - DATASET = "imatrix.dataset" + DATASETS = "imatrix.datasets" # diff --git a/requirements.txt b/requirements.txt index cf3116c6cb6ac..f2a18d62879b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,6 @@ -r ./requirements/requirements-convert_hf_to_gguf.txt -r ./requirements/requirements-convert_hf_to_gguf_update.txt --r ./requirements/requirements-convert_legacy_imatrix_to_gguf.txt -r ./requirements/requirements-convert_llama_ggml_to_gguf.txt -r ./requirements/requirements-convert_lora_to_gguf.txt -r ./requirements/requirements-tool_bench.txt diff --git a/requirements/requirements-convert_legacy_imatrix_to_gguf.txt b/requirements/requirements-convert_legacy_imatrix_to_gguf.txt deleted file mode 100644 index afe2747d448d4..0000000000000 --- a/requirements/requirements-convert_legacy_imatrix_to_gguf.txt +++ /dev/null @@ -1 +0,0 @@ --r ./requirements-convert_legacy_llama.txt From 1d19025909ae3abbc26c50bb8795c2f351fe4ba1 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 15 Apr 2025 17:48:06 -0400 Subject: [PATCH 16/25] imatrix : use the function name in more error messages --- examples/imatrix/imatrix.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index f49bf9ec41e18..fdbc97a5e1513 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -154,7 +154,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * exit(1); //GGML_ABORT("fatal error"); } else if (e.counts.size() != (size_t)n_as) { - LOG_ERR("Oops: inconsistent expert count for %s (%d vs %d)\n", wname.c_str(), (int)e.counts.size(), (int)n_as); + LOG_ERR("%s: inconsistent expert count for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.counts.size(), (int)n_as); exit(1); //GGML_ABORT("fatal error"); } LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type); @@ -208,7 +208,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * exit(1); //GGML_ABORT("fatal error"); } else if (e.counts.size() != 1) { - LOG_ERR("Oops: inconsistent expert count for %s (%d vs %d)\n", wname.c_str(), (int)e.counts.size(), 1); + LOG_ERR("%s: inconsistent expert count for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.counts.size(), 1); exit(1); //GGML_ABORT("fatal error"); } LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); @@ -819,7 +819,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c // (not possible when this skips FFN computation of the last layer) common_batch_add(batch, tokens[seq_start + k], j*n_batch + k, { seq }, true); } - + // restore the original token in case it was set to BOS tokens[seq_start] = token_org; } @@ -896,7 +896,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c int main(int argc, char ** argv) { common_params params; - params.out_file = "imatrix.gguf" ; + params.out_file = "imatrix.gguf"; params.n_ctx = 512; params.logits_all = true; From ba6f6be6ce9bfde9ff763811a462f75b146cb825 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 18 Jun 2025 16:33:37 -0400 Subject: [PATCH 17/25] imatrix : don't use FMA explicitly This should make comparisons between the formats easier because this matches the behavior of the previous version. --- tools/imatrix/imatrix.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index 4250f507ce802..1e640027aa251 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -180,7 +180,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * e.counts[ex]++; for (int j = 0; j < (int)src1->ne[0]; ++j) { - e.values[e_start + j] = std::fma(x[j], x[j], e.values[e_start + j]); + e.values[e_start + j] += x[j] * x[j]; if (!std::isfinite((float)e.values[e_start + j])) { LOG_ERR("%f detected in %s\n", (float)e.values[e_start + j], wname.c_str()); exit(1); @@ -220,7 +220,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * const float * x = (const float *) (data + row * src1->nb[1]); e.counts[0]++; for (int j = 0; j < (int)src1->ne[0]; ++j) { - e.values[j] = std::fma(x[j], x[j], e.values[j]); + e.values[j] += x[j] * x[j]; if (!std::isfinite((float)e.values[j])) { LOG_ERR("%f detected in %s\n", (float)e.values[j], wname.c_str()); exit(1); From 1a9454a3d23564e59b411372d863de899387b70c Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 18 Jun 2025 16:44:41 -0400 Subject: [PATCH 18/25] imatrix : avoid returning from void function save_imatrix --- tools/imatrix/imatrix.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index 1e640027aa251..540687370e06f 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -348,7 +348,8 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { // TODO: use the new format by default also for .imatrix if (!str_has_suffix(fname, ".gguf")) { - return this->save_imatrix_legacy(n_chunk); + this->save_imatrix_legacy(n_chunk); + return; } if (n_chunk > 0) { From 43cd2b3eb58f0eb832472579dea1a097eb12fd7d Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 23 Jun 2025 11:50:54 -0400 Subject: [PATCH 19/25] imatrix : support 3d tensors with MUL_MAT --- tools/imatrix/imatrix.cpp | 85 ++++++++++++++++++++++--------------- tools/quantize/quantize.cpp | 4 +- 2 files changed, 53 insertions(+), 36 deletions(-) diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index 540687370e06f..37152070d887c 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -4,6 +4,7 @@ #include "llama.h" #include "gguf.h" +#include #include #include #include @@ -15,7 +16,6 @@ #include #include #include -#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -124,14 +124,21 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * const char * data = is_host ? (const char *) src1->data : m_src1_data.data(); GGML_ASSERT(src1->nb[0] == ggml_element_size(src1)); + // TODO: 4d? (is that even used in practice?) + // the extra dimension would need to be stored somewhere to be reflected in the imatrix file + if (ggml_nrows(src1) != src1->ne[1] * src1->ne[2]) { + LOG_ERR("%s: tensor has more than 3 dimensions: %s", __func__, wname.c_str()); + GGML_ASSERT(false); + } + // this has been adapted to the new format of storing merged experts in a single 3d tensor // ref: https://github.com/ggml-org/llama.cpp/pull/6387 if (t->op == GGML_OP_MUL_MAT_ID) { // ids -> [n_experts_used, n_tokens] // src1 -> [cols, n_expert_used, n_tokens] const ggml_tensor * ids = t->src[2]; - const int n_as = src0->ne[2]; - const int n_ids = ids->ne[0]; + const int64_t n_as = src0->ne[2]; + const int64_t n_ids = ids->ne[0]; // the top-k selected expert ids are stored in the ids tensor // for simplicity, always copy ids to host, because it is small @@ -153,7 +160,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * e.counts.resize(n_as, 0); } else if (e.values.size() != (size_t)src1->ne[0]*n_as) { - LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); + LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)(src1->ne[0]*n_as)); exit(1); //GGML_ABORT("fatal error"); } else if (e.counts.size() != (size_t)n_as) { @@ -162,11 +169,11 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * } LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type); // loop over all possible experts, regardless if they are used or not in the batch - for (int ex = 0; ex < n_as; ++ex) { + for (int64_t ex = 0; ex < n_as; ++ex) { size_t e_start = ex*src1->ne[0]; - for (int idx = 0; idx < n_ids; ++idx) { - for (int row = 0; row < (int)src1->ne[2]; ++row) { + for (int64_t idx = 0; idx < n_ids; ++idx) { + for (int64_t row = 0; row < src1->ne[2]; ++row) { const int excur = *(const int32_t *) (m_ids.data() + row*ids->nb[1] + idx*ids->nb[0]); GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check @@ -179,7 +186,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * e.counts[ex]++; - for (int j = 0; j < (int)src1->ne[0]; ++j) { + for (int64_t j = 0; j < src1->ne[0]; ++j) { e.values[e_start + j] += x[j] * x[j]; if (!std::isfinite((float)e.values[e_start + j])) { LOG_ERR("%f detected in %s\n", (float)e.values[e_start + j], wname.c_str()); @@ -202,40 +209,48 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * } } else { auto & e = m_stats[wname]; + const int64_t n_mat = src1->ne[2] * src1->ne[3]; + if (e.values.empty()) { - e.values.resize(src1->ne[0], 0); - e.counts.resize(1, 0); + e.values.resize(src1->ne[0] * n_mat, 0); + e.counts.resize(n_mat, 0); } - else if (e.values.size() != (size_t)src1->ne[0]) { - LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]); + else if (e.values.size() != (size_t)(src1->ne[0] * n_mat)) { + LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)(src1->ne[0] * n_mat)); exit(1); //GGML_ABORT("fatal error"); } - else if (e.counts.size() != 1) { - LOG_ERR("%s: inconsistent expert count for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.counts.size(), 1); + else if (e.counts.size() != (size_t)n_mat) { + LOG_ERR("%s: inconsistent expert count for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.counts.size(), (int)n_mat); exit(1); //GGML_ABORT("fatal error"); } - LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); - // TODO: higher dimensions - for (int row = 0; row < (int)src1->ne[1]; ++row) { - const float * x = (const float *) (data + row * src1->nb[1]); - e.counts[0]++; - for (int j = 0; j < (int)src1->ne[0]; ++j) { - e.values[j] += x[j] * x[j]; - if (!std::isfinite((float)e.values[j])) { - LOG_ERR("%f detected in %s\n", (float)e.values[j], wname.c_str()); - exit(1); + LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->ne[2], (int)src1->type); + for (int64_t i3 = 0; i3 < src1->ne[3]; ++i3) { + for (int64_t i2 = 0; i2 < src1->ne[2]; ++i2) { + const int64_t mat_id = i3 * src1->ne[2] + i2; + const int64_t mat_start = mat_id * src1->ne[0]; + + for (int64_t row = 0; row < src1->ne[1]; ++row) { + const float * x = (const float *) (data + row * src1->nb[1] + i2 * src1->nb[2] + i3 * src1->ne[3]); + e.counts[mat_id]++; + for (int64_t j = 0; j < src1->ne[0]; ++j) { + e.values[mat_start + j] += x[j] * x[j]; + if (!std::isfinite((float)e.values[j])) { + LOG_ERR("%f detected in %s\n", (float)e.values[j], wname.c_str()); + exit(1); + } + } + } + const int32_t n_chunk = e.counts[mat_id] / chunk_size; + if (n_chunk > m_last_chunk) { + const int32_t chunk_step = n_chunk - m_last_chunk; + m_last_chunk = n_chunk; + if ((m_last_chunk % m_params.n_out_freq) / chunk_step == 0) { + save_imatrix(); + } + if (m_params.n_save_freq > 0 && (m_last_chunk % m_params.n_save_freq) / chunk_step == 0) { + save_imatrix(m_last_chunk); + } } - } - } - const int32_t n_chunk = e.counts[0] / chunk_size; - if (n_chunk > m_last_chunk) { - const int32_t chunk_step = n_chunk - m_last_chunk; - m_last_chunk = n_chunk; - if ((m_last_chunk % m_params.n_out_freq) / chunk_step == 0) { - save_imatrix(); - } - if (m_params.n_save_freq > 0 && (m_last_chunk % m_params.n_save_freq) / chunk_step == 0) { - save_imatrix(m_last_chunk); } } } diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index c7eaf892e063a..46a2cff96c30d 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -196,7 +196,9 @@ static int load_legacy_imatrix(const std::string & imatrix_file, std::vector 0) { - for (auto& v : e) v /= ncall; + for (auto & v : e) { + v /= ncall; + } } if (getenv("LLAMA_TRACE")) { From 0e7935507587a4025cbd67e3bc3ef16dfa92bbe1 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 23 Jun 2025 12:43:25 -0400 Subject: [PATCH 20/25] quantize : fix dataset name loading from gguf imatrix --- tools/quantize/quantize.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index 46a2cff96c30d..e974978ef3f97 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -206,7 +206,7 @@ static int load_legacy_imatrix(const std::string & imatrix_file, std::vector Date: Mon, 23 Jun 2025 16:22:27 -0400 Subject: [PATCH 21/25] common : move string_remove_suffix from quantize and imatrix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- common/common.cpp | 9 +++++++++ common/common.h | 1 + tools/imatrix/imatrix.cpp | 18 +++--------------- tools/quantize/quantize.cpp | 13 ++----------- 4 files changed, 15 insertions(+), 26 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index e4e71ad13fb59..66462f5d93063 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -448,6 +448,15 @@ void string_replace_all(std::string & s, const std::string & search, const std:: bool string_ends_with(const std::string_view & str, const std::string_view & suffix) { return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; } + +bool string_remove_suffix(std::string & str, const std::string_view & suffix) { + bool has_suffix = string_ends_with(str, suffix); + if (has_suffix) { + str = str.substr(0, str.size() - suffix.size()); + } + return has_suffix; +} + size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) { if (!str.empty() && !stop.empty()) { const char text_last_char = str.back(); diff --git a/common/common.h b/common/common.h index e08a59eae7543..d6e5411e107b6 100644 --- a/common/common.h +++ b/common/common.h @@ -518,6 +518,7 @@ static bool string_starts_with(const std::string & str, // While we wait for C++20's std::string::ends_with... bool string_ends_with(const std::string_view & str, const std::string_view & suffix); +bool string_remove_suffix(std::string & str, const std::string_view & suffix); size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop); bool string_parse_kv_override(const char * data, std::vector & overrides); diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index 37152070d887c..1a16ae2c080a1 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -31,18 +31,6 @@ static void print_usage(int, char ** argv) { LOG("\n"); } -static bool str_has_suffix(const std::string & str, const std::string & suffix) { - return str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), str.size(), suffix) == 0; -} - -static bool str_remove_suffix(std::string & str, const std::string & suffix) { - bool has_suffix = str_has_suffix(str, suffix); - if (has_suffix) { - str = str.substr(0, str.size() - suffix.size()); - } - return has_suffix; -} - static const char * const LLM_KV_IMATRIX_DATASETS = "imatrix.datasets"; static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count"; static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size"; @@ -362,7 +350,7 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { auto fname = m_params.out_file; // TODO: use the new format by default also for .imatrix - if (!str_has_suffix(fname, ".gguf")) { + if (!string_ends_with(fname, ".gguf")) { this->save_imatrix_legacy(n_chunk); return; } @@ -584,10 +572,10 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { if (name.empty()) { continue; } - if (str_remove_suffix(name, in_sum2_suffix)) { + if (string_remove_suffix(name, in_sum2_suffix)) { // in_sum2 sums_counts_for[std::move(name)].first = cur; - } else if (str_remove_suffix(name, counts_suffix)) { + } else if (string_remove_suffix(name, counts_suffix)) { // counts sums_counts_for[std::move(name)].second = cur; } else { diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index 69b03f504a4a8..45c59ecb6fffe 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -147,15 +147,6 @@ static void usage(const char * executable) { exit(1); } -// TODO: share with implementation in imatrix.cpp -static bool str_remove_suffix(std::string & str, const std::string & suffix) { - bool has_suffix = str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), str.size(), suffix) == 0; - if (has_suffix) { - str = str.substr(0, str.size() - suffix.size()); - } - return has_suffix; -} - static int load_legacy_imatrix(const std::string & imatrix_file, std::vector & imatrix_datasets, std::unordered_map> & imatrix_data) { std::ifstream in(imatrix_file.c_str(), std::ios::binary); if (!in) { @@ -265,10 +256,10 @@ static int load_imatrix(const std::string & imatrix_file, std::vector Date: Sat, 12 Jul 2025 13:42:35 -0400 Subject: [PATCH 22/25] imatrix : add warning when legacy format is written --- tools/imatrix/imatrix.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index 1a16ae2c080a1..d98ddce2f6483 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -349,8 +349,9 @@ void IMatrixCollector::save_imatrix_legacy(int32_t ncall) const { void IMatrixCollector::save_imatrix(int32_t n_chunk) const { auto fname = m_params.out_file; - // TODO: use the new format by default also for .imatrix + // TODO: use the new format in more cases if (!string_ends_with(fname, ".gguf")) { + LOG_WRN("\n%s: saving to legacy imatrix format because output suffix is not .gguf\n", __func__); this->save_imatrix_legacy(n_chunk); return; } From 50f53b3e400490b16c6b3c4178c823689b129c0f Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 12 Jul 2025 14:09:28 -0400 Subject: [PATCH 23/25] imatrix : warn when writing partial data, to help guess dataset coverage Also make the legacy format store partial data by using neutral values for missing data. This matches what is done at read-time for the new format, and so should get the same quality in case the old format is still used. --- tools/imatrix/imatrix.cpp | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index d98ddce2f6483..b5bc19a169e08 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -254,7 +254,7 @@ void IMatrixCollector::save_imatrix_legacy(int32_t ncall) const { fname += std::to_string(ncall); } - // avoid writing imatrix entries that do not have full data + // warn when writing imatrix entries that do not have full data // this can happen with MoE models where some of the experts end up not being exercised by the provided training data int n_entries = 0; @@ -286,8 +286,7 @@ void IMatrixCollector::save_imatrix_legacy(int32_t ncall) const { } if (n_zeros > 0) { - LOG_WRN("%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all); - continue; + LOG_WRN("%s: entry '%40s' has partial data (%.2f%%)\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all); } n_entries++; @@ -310,7 +309,8 @@ void IMatrixCollector::save_imatrix_legacy(int32_t ncall) const { const int32_t len = name.size(); out.write((const char *) &len, sizeof(len)); out.write(name.c_str(), len); - const int32_t ncall = *std::max_element(stat.counts.begin(), stat.counts.end()) / chunk_size; + // ceiling division to avoid accidental zeros + const int32_t ncall = (*std::max_element(stat.counts.begin(), stat.counts.end()) + (chunk_size - 1)) / chunk_size; out.write((const char *) &ncall, sizeof(ncall)); const int32_t nval = stat.values.size(); const int32_t nmat = stat.counts.size(); @@ -318,8 +318,14 @@ void IMatrixCollector::save_imatrix_legacy(int32_t ncall) const { if (nval > 0 && nmat > 0) { std::vector tmp(nval); for (int32_t i = 0; i < nval; i++) { - const float counts = static_cast(stat.counts[i / (nval / nmat)]); - tmp[i] = (stat.values[i] / counts) * static_cast(ncall); + float count = static_cast(stat.counts[i / (nval / nmat)]); + float value = stat.values[i]; + if (count == 0.0f) { + // store 1 for partial data + value = 1.0f; + count = 1.0f; + } + tmp[i] = (value / count) * static_cast(ncall); } out.write((const char *) tmp.data(), nval * sizeof(float)); } @@ -367,7 +373,26 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { std::vector to_store; size_t data_size = 0; + bool is_first = true; // for printing for (const auto & kv : m_stats) { + const int n_all = kv.second.counts.size(); + + int n_zeros = 0; + for (const auto c : kv.second.counts) { + if (c == 0) { + n_zeros++; + } + } + + if (n_zeros != 0 && is_first) { + LOG_INF("\n"); + is_first = false; + } + + if (n_zeros > 0) { + LOG_WRN("%s: entry '%40s' has partial data (%.2f%%)\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all); + } + to_store.push_back(kv.first); data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.values.size(), GGML_MEM_ALIGN); data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.counts.size(), GGML_MEM_ALIGN); From 183eeb55187ac523500814ea34f787df16d08a1e Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 12 Jul 2025 14:54:33 -0400 Subject: [PATCH 24/25] imatrix : avoid loading model to convert or combine imatrix --- tools/imatrix/imatrix.cpp | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index b5bc19a169e08..a1f21d7ee56d1 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -967,9 +967,23 @@ int main(int argc, char ** argv) { } } - if (params.in_files.size() > 1) { - LOG_INF("%s : saving combined imatrix to '%s'\n", __func__, params.out_file.c_str()); + if (params.prompt.empty()) { + LOG_INF("No prompt provided; combining precomputed matrices only.\n"); + + if (params.in_files.empty()) { + LOG_ERR("Error: No prompt provided and no precomputed matrices (--in-file) to combine.\n"); + return 1; + } + + if (params.in_files.size() == 1) { + LOG_INF("%s : saving imatrix to '%s'\n", __func__, params.out_file.c_str()); + } else if (params.in_files.size() > 1) { + LOG_INF("%s : saving combined imatrix to '%s'\n", __func__, params.out_file.c_str()); + } + g_collector.save_imatrix(); + + return 0; } llama_backend_init(); @@ -1004,19 +1018,10 @@ int main(int argc, char ** argv) { LOG_INF("%s\n", common_params_get_system_info(params).c_str()); } - if (params.prompt.empty()) { - if (params.in_files.empty()) { - LOG_ERR("Error: No prompt provided and no precomputed matrices (--in-file) to combine.\n"); - return 1; - } - LOG_INF("No prompt provided; combining precomputed matrices only.\n"); - } else { - if (!compute_imatrix(ctx, params, n_ctx)) { - return 1; - } + if (!compute_imatrix(ctx, params, n_ctx)) { + return 1; } - g_collector.save_imatrix(); LOG("\n"); From 942c55cd57ee695d9600584c114cf07de8554031 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 12 Jul 2025 14:56:18 -0400 Subject: [PATCH 25/25] imatrix : avoid using imatrix.dat in README --- tools/imatrix/README.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tools/imatrix/README.md b/tools/imatrix/README.md index 6d8897d98bb61..4ce5ca0ca42fb 100644 --- a/tools/imatrix/README.md +++ b/tools/imatrix/README.md @@ -7,14 +7,15 @@ More information is available here: https://github.com/ggml-org/llama.cpp/pull/4 ``` ./llama-imatrix \ - -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] [--verbosity 1] \ + -m model.gguf -f some-text.txt [-o imatrix.gguf] [--process-output] \ [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \ - [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...] + [--in-file imatrix-prev-0.gguf --in-file imatrix-prev-1.gguf ...] \ + [--parse-special] ``` Here `-m` with a model name and `-f` with a file containing training data (such as e.g. `wiki.train.raw`) are mandatory. The parameters in square brackets are optional and have the following meaning: -* `-o` (or `--output-file`) specifies the name of the file where the computed data will be stored. If missing `imatrix.dat` is used. +* `-o` (or `--output-file`) specifies the name of the file where the computed data will be stored. If missing `imatrix.gguf` is used. * `--verbosity` specifies the verbosity level. If set to `0`, no output other than the perplexity of the processed chunks will be generated. If set to `1`, each time the results are saved a message is written to `stderr`. If `>=2`, a message is output each time data is collected for any tensor. Default verbosity level is `1`. * `--output-frequency` specifies how often the so far computed result is saved to disk. Default is 10 (i.e., every 10 chunks) * `--save-frequency` specifies how often to save a copy of the imatrix in a separate file. Default is 0 (i.e., never) @@ -25,9 +26,9 @@ For faster computation, make sure to use GPU offloading via the `-ngl` argument ## Example ```bash -# generate importance matrix (imatrix.dat) +# generate importance matrix (imatrix.gguf) ./llama-imatrix -m ggml-model-f16.gguf -f train-data.txt -ngl 99 # use the imatrix to perform a Q4_K_M quantization -./llama-quantize --imatrix imatrix.dat ggml-model-f16.gguf ./ggml-model-q4_k_m.gguf q4_k_m +./llama-quantize --imatrix imatrix.gguf ggml-model-f16.gguf ./ggml-model-q4_k_m.gguf q4_k_m ```