From 271104c65c9b99d5b5aca4489d7bec103cd60db9 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 3 Apr 2024 11:07:16 -0400 Subject: [PATCH 001/117] wip: llama : separate recurrent states from the KV cache This will be necessary to support Jamba (and other recurrent models mixed with Attention). Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states. --- llama.cpp | 1386 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 960 insertions(+), 426 deletions(-) diff --git a/llama.cpp b/llama.cpp index 267ac4cc022a1..9ca8ca0f41320 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1793,14 +1793,14 @@ struct llama_hparams { return n_embd_head_v * n_head_kv; } - uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings + uint32_t n_embd_r() const { // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } - uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings + uint32_t n_embd_s() const { // dimension of the recurrent state embeddings // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } @@ -1904,7 +1904,6 @@ struct llama_layer { struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; - int32_t src = 0; // used by recurrent state models to copy states std::set seq_id; @@ -1925,9 +1924,6 @@ struct llama_kv_cell { struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; - bool do_copy = false; - // with recurrent state models, a cell can hold the state for more than one past token - bool recurrent = false; // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_internal also uses it, so it @@ -1947,9 +1943,365 @@ struct llama_kv_cache { std::vector k_l; // per layer std::vector v_l; + size_t total_size() const { + size_t size = 0; + for (struct ggml_tensor * k : k_l) { + size += ggml_nrows(k) * ggml_row_size(k->type, k->ne[0]); + } + for (struct ggml_tensor * v : v_l) { + size += ggml_nrows(v) * ggml_row_size(v->type, v->ne[0]); + } + return size; + } +}; + +// for recurrent models, use a tree of sequences to simplify +// quickly finding the tail cell of each sequence +// TODO: drop the _rs_ infix +struct llama_rs_seq_node { + llama_seq_id seq_id = -1; + int32_t next_cell = -1; + + // needed for automatic typecasting with .find() + llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {} + + bool operator<(const llama_rs_seq_node & other) const { + return seq_id < other.seq_id; + } + + bool is_tail() const { + return next_cell < 0; + } +}; + +struct llama_rs_cell { + llama_pos pos = -1; + int32_t src = -1; // copy source id (cleared next when -1) + + // Link to previous cell in this sequence. + // Sequences can only diverge, never converge, + // so this works when there are multiple seq_ids per cell too. + int32_t prev = -1; + + // ref count of tails (should match the number of next_cell == -1 in seq_nodes) + uint32_t tail_rc = 0; + + // seq_ids by insertion order, to simplify updating n_cells compared to a set + std::vector seq_nodes; + + llama_rs_seq_node * get_node(const llama_seq_id & id) { + for (size_t i = 0; i < seq_nodes.size(); ++i) { + if (seq_nodes[i].seq_id == id) { + return &seq_nodes[i]; + } + } + return nullptr; + } + + void insert_node(const llama_rs_seq_node & node) { + llama_rs_seq_node * node_dest = get_node(node.seq_id); + if (node_dest == nullptr) { + seq_nodes.push_back(node); + } else { + *node_dest = node; + } + } + + bool remove_node(llama_rs_seq_node * node_ptr) { + if (node_ptr != nullptr && seq_nodes.data() <= node_ptr) { + size_t offset = node_ptr - seq_nodes.data(); + if (offset % sizeof(llama_rs_seq_node) == 0) { + offset /= sizeof(llama_rs_seq_node); + if (offset < seq_nodes.size()) { + for (size_t i = offset + 1; i < seq_nodes.size(); ++i) { + seq_nodes[i - 1] = seq_nodes[i]; + } + seq_nodes.resize(seq_nodes.size() - 1); + return true; + } + } + } + return false; + } + + bool has_seq_id(const llama_seq_id & id) const { + for (size_t i = 0; i < seq_nodes.size(); ++i) { + if (seq_nodes[i].seq_id == id) { + return true; + } + } + return false; + } + + bool is_empty() const { + return seq_nodes.empty(); + } +}; + + +struct llama_rs_seq_meta { + // cell id of the latest state of this seq_id + int32_t tail = -1; + // number of cells for which this seq_id is the first + // (useful to know if cells in this sequence should be pruned) + int32_t n_cells = 0; + // whether the tail is a cell part of multiple sequences + bool shared = false; +}; + +// ring-buffer of cached recurrent state data +struct llama_rs_cache { + bool do_copy = false; + + uint32_t head = 0; // first state used for the last slot + uint32_t size = 0; + uint32_t used = 0; + + // computed when finding a slot + uint32_t n = 0; // range of states used for the last slot + + // useful to know the minimum reserved cell count per seq_id + // only counts sequences with n_cells > 0 + uint32_t n_seqs = 0; + + // with state models, a cell can hold the state for more than one past token + // TODO: it's probably not possible to always use contiguous cells + std::vector cells; + + // find tail cells faster + std::vector seq_tails; // map seq_ids to cell ids + + // per layer + // NOTE: the naming of r and s is arbitrary + std::vector r_l; // rolling/shift states + std::vector s_l; // ssm (recurrent) states + + // returns whether or not a cell was freed + bool clear_cell(uint32_t i) { + if (i < size) { + llama_rs_cell & rs_cell = cells[i]; + if (!rs_cell.is_empty()) { + // update sequence tree links + bool first = true; + for (const llama_rs_seq_node & node : rs_cell.seq_nodes) { + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + // NOTE: if all next cells are the same cell, this should still work + cells[node.next_cell].prev = rs_cell.prev; + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + // update tail + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (seq.tail >= 0 && (uint32_t) seq.tail < size) { + llama_rs_cell & new_tail = cells[seq.tail]; + new_tail.insert_node(node.seq_id); // ensures next_cell == -1 + new_tail.tail_rc += 1; + seq.shared = new_tail.seq_nodes.size() > 1; + } else { + seq.shared = false; + } + } + // cell counts + if (first) { + seq.n_cells -= 1; + if (seq.n_cells == 0) { + GGML_ASSERT(seq.tail < 0); + n_seqs -= 1; + } + first = false; + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + rs_cell.tail_rc = 0; + rs_cell.seq_nodes.clear(); + used -= 1; + return true; + } + } + return false; + } + + // TODO: maybe use a simpler data structure than a tree + // returns whether or not a cell was freed + bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) { + if (i_cell < size && (size_t) id < size) { + llama_rs_cell & rs_cell = cells[i_cell]; + auto * node_ptr = rs_cell.get_node(id); // search once + if (node_ptr != nullptr) { + if (rs_cell.seq_nodes.size() == 1) { + return clear_cell(i_cell); + } else { + // update tree + llama_rs_seq_node node = *node_ptr; + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + cells[node.next_cell].prev = rs_cell.prev; + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2; + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (seq.tail >= 0 && (uint32_t) seq.tail < size) { + llama_rs_cell & new_tail = cells[seq.tail]; + new_tail.insert_node(node.seq_id); // ensures next_cell == -1 + new_tail.tail_rc += 1; + seq.shared = cells[seq.tail].seq_nodes.size() > 1; + } else { + seq.shared = false; + } + GGML_ASSERT(rs_cell.tail_rc > 0); + rs_cell.tail_rc -= 1; + } + if (node_ptr == rs_cell.seq_nodes.data()) { + // this seq_id was the first in the list + seq.n_cells -= 1; + if (seq.n_cells == 0) { + n_seqs -= 1; + } + // the next node is the new first one, so update its n_cells + // (will never be out-of-bounds because the size is > 1) + llama_rs_seq_node next_node = node_ptr[1]; + if ((uint32_t) next_node.seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node.seq_id]; + next_seq.n_cells += 1; + if (next_seq.n_cells == 1) { + n_seqs += 1; + } + if (other_no_longer_shared) { + next_seq.shared = false; + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } else if (other_no_longer_shared) { + llama_rs_seq_node first_node = rs_cell.seq_nodes[0]; + if ((uint32_t) first_node.seq_id < seq_tails.size()) { + seq_tails[first_node.seq_id].shared = false; + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + const bool removed = rs_cell.remove_node(node_ptr); + GGML_ASSERT(removed); + } + } + } + return false; + } + + bool insert_seq_tail_to_cell(uint32_t i_cell, const llama_seq_id & id) { + if (i_cell < size && (size_t) id < seq_tails.size()) { + llama_rs_cell & rs_cell = cells[i_cell]; + auto & seq = seq_tails[id]; + int32_t prev = rs_cell.prev; + if ((uint32_t) seq.tail == i_cell) { + // the cell is already the tail of this seq_id + return false; + } + if (rs_cell.is_empty()) { + prev = seq.tail; + } + // ensure the new tail won't mess up the tree + GGML_ASSERT(seq.tail == -1 || seq.tail == prev); + if (prev >= 0 && (uint32_t) prev < size) { + // the targeted cell has a previous cell + llama_rs_cell & prev_cell = cells[prev]; + llama_rs_seq_node * prev_node = prev_cell.get_node(id); + GGML_ASSERT(prev_node != nullptr); // TODO: recursive insert instead of failing + GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken + if (rs_cell.pos < 0) { + GGML_ASSERT(rs_cell.is_empty()); + rs_cell.pos = prev_cell.pos + 1; + rs_cell.src = prev_cell.src; + } + prev_cell.tail_rc -= 1; + prev_node->next_cell = i_cell; + } + if (rs_cell.is_empty()) { + // only add after potential failures above + if (seq.n_cells == 0) { + n_seqs += 1; + } + seq.n_cells += 1; + // set pos if still unset + if (rs_cell.pos < 0) { + rs_cell.pos = 0; + rs_cell.src = -1; + } + } + // the target cell was not already a tail of this seq_id + rs_cell.insert_node(id); // next_cell == -1 by default + rs_cell.tail_rc += 1; + seq.tail = i_cell; + seq.shared = rs_cell.seq_nodes.size() > 1; + return true; + } + return false; + } + + // each seq_id should have access to at least this many cells + // (to use when pruning (to avoid over-pruning)) + // (but this over-prunes when the system prompt doesn't take lots of cells) + // Hmm. The system prompt does not need checkpoints... + size_t min_cells_per_seq() const { + return size / (n_seqs > 0 ? n_seqs : 1); + } + + // each seq_id can have at most this many cells + // (ignoring seqs which behave as a shared prompt) + // TODO: avoid recalculating system seq_ids + // (to use when pruning (to avoid over-pruning)) + // NOTE: this also limits the shared prompt to at most half the cells + // (but the shared prompt technically needs only one cell...) + // (IDEA: keep only one cell when `llama_kv_cache_seq_cp` is called on a sequence) + size_t max_cells_per_seq() const { + int32_t n_system_seqs = 0; + int32_t n_system_cells = 0; + for (size_t i = 0; i < seq_tails.size(); ++i) { + auto & seq = seq_tails[i]; + if (seq.tail >= 0 && (size_t) seq.tail < size) { + if (seq.shared && seq.n_cells > 0) { + n_system_seqs += 1; + n_system_cells += seq.n_cells; + } + } + } + int32_t n_other_seqs = n_seqs - n_system_seqs; + return (size - n_system_cells) / (n_other_seqs > 0 ? n_other_seqs : 1); + } + + size_t total_size() const { + size_t size = 0; + for (struct ggml_tensor * r : r_l) { + size += ggml_nrows(r) * ggml_row_size(r->type, r->ne[0]); + } + for (struct ggml_tensor * s : s_l) { + size += ggml_nrows(s) * ggml_row_size(s->type, s->ne[0]); + } + return size; + } +}; + +struct llama_cache { + // key + value cache for self attention + llama_kv_cache kv; + + // recurrent state cache for state space models + llama_rs_cache rs; + std::vector ctxs; std::vector bufs; + // NOTE: padding may make this bigger than kv.total_size() + rs.total_size() size_t total_size() const { size_t size = 0; for (ggml_backend_buffer_t buf : bufs) { @@ -1958,7 +2310,7 @@ struct llama_kv_cache { return size; } - ~llama_kv_cache() { + ~llama_cache() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); } @@ -2146,8 +2498,8 @@ struct llama_context { const llama_model & model; - // key + value cache for the self attention - struct llama_kv_cache kv_self; + // key + value cache for self-attention, and/or recurrent state cache + struct llama_cache cache; std::mt19937 rng; @@ -2205,9 +2557,9 @@ struct llama_context { struct ggml_tensor * inp_K_shift; // I32 [kv_size] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] - struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * inp_s_copy; // I32 [n_rs] + struct ggml_tensor * inp_s_mask; // F32 [1, n_rs] + struct ggml_tensor * inp_s_seq; // I32 [n_rs, n_batch] // control vectors struct llama_control_vector cvec; @@ -2221,47 +2573,45 @@ struct llama_context { // kv cache helpers // -static bool llama_kv_cache_init( - struct llama_kv_cache & cache, +static bool llama_cache_init( + struct llama_cache & cache, const llama_model & model, ggml_type type_k, ggml_type type_v, - uint32_t kv_size, + uint32_t n_ctx, + uint32_t n_seq_max, bool offload) { const struct llama_hparams & hparams = model.hparams; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); - const int64_t n_layer = hparams.n_layer; - cache.has_shift = false; + // TODO: per layer n_embd_* + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const uint32_t n_embd_r = hparams.n_embd_r(); + const uint32_t n_embd_s = hparams.n_embd_s(); + const bool has_kv = hparams.n_head != 0 && hparams.causal_attn; + const bool has_r = n_embd_r != 0; + const bool has_s = n_embd_s != 0; + const bool has_rs = has_r || has_s; + const uint32_t kv_size = has_kv ? n_ctx : 0; + const uint32_t rs_size = has_rs ? n_seq_max : 0; + // TODO: per cache type layer count + const int64_t n_layer = hparams.n_layer; - // TODO: find a nicer way to add other recurrent model architectures - cache.recurrent = model.arch == LLM_ARCH_MAMBA; + cache.kv.size = kv_size; - // TODO: support mixed reccurent Transformer architectues - // NOTE: (!a || b) is a logical implication (a -> b) - GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s()); - GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s()); - GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa()); - GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa()); + cache.kv.type_k = type_k; + cache.kv.type_v = type_v; - cache.head = 0; - cache.size = kv_size; - cache.used = 0; + cache.kv.cells.clear(); + cache.kv.cells.resize(kv_size); - cache.type_k = type_k; - cache.type_v = type_v; + cache.rs.size = rs_size; - cache.cells.clear(); - cache.cells.resize(kv_size); - - if (cache.recurrent) { - // init state copy sources - for (uint32_t i = 0; i < cache.size; ++i) { - cache.cells[i].src = i; - } - } + cache.rs.cells.clear(); + cache.rs.cells.resize(rs_size); + cache.rs.seq_tails.clear(); + cache.rs.seq_tails.resize(rs_size); #ifdef GGML_USE_CLBLAST offload = false; @@ -2282,7 +2632,7 @@ static bool llama_kv_cache_init( for (auto & it : buft_layer_count) { int n_layers = it.second; struct ggml_init_params params = { - /*.mem_size =*/ 2u*n_layers*ggml_tensor_overhead(), + /*.mem_size =*/ (2*has_kv + has_r+has_s)*n_layers*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -2295,17 +2645,37 @@ static bool llama_kv_cache_init( cache.ctxs.push_back(ctx); } - cache.k_l.reserve(n_layer); - cache.v_l.reserve(n_layer); + if (has_kv) { + cache.kv.k_l.reserve(n_layer); + cache.kv.v_l.reserve(n_layer); + } + if (has_r) { + cache.rs.r_l.reserve(n_layer); + } + if (has_s) { + cache.rs.s_l.reserve(n_layer); + } for (int i = 0; i < (int) n_layer; i++) { struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - cache.k_l.push_back(k); - cache.v_l.push_back(v); + if (has_kv) { + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + cache.kv.k_l.push_back(k); + cache.kv.v_l.push_back(v); + } + if (has_r) { + ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_r*rs_size); + ggml_format_name(r, "cache_r_l%d", i); + cache.rs.r_l.push_back(r); + } + if (has_s) { + ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_s*rs_size); + ggml_format_name(s, "cache_s_l%d", i); + cache.rs.s_l.push_back(s); + } } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -2330,23 +2700,30 @@ static bool llama_kv_cache_init( // Note: On success, it's important that cache.head points // to the first cell of the slot. static bool llama_kv_cache_find_slot( - struct llama_kv_cache & cache, - const struct llama_batch & batch) { - const uint32_t n_ctx = cache.size; + struct llama_cache & cache, + const struct llama_batch & batch) { + const uint32_t kv_size = cache.kv.size; + const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; - if (cache.recurrent) { + if (rs_size > 0) { // For recurrent state architectures (like Mamba), - // each KV cache cell can store the state for a whole sequence. + // each cache cell can store the state for a whole sequence. + // TODO: real ring-buffer of states + // TODO: state chekpoints (multiple cells per sequence) + // TODO: find a way to always make the rs slot contiguous + + // Okay, need to find a slot. Everything should fit assuming the biggest seq_id < rs_size + - llama_seq_id min = cache.size - 1; + llama_seq_id min = cache.rs.size - 1; llama_seq_id max = 0; for (uint32_t i = 0; i < n_tokens; ++i) { for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; // make sure it's a valid seq_id - if ((uint32_t) seq_id < cache.size) { + if ((uint32_t) seq_id < rs_size) { if (seq_id > max) { max = seq_id; } @@ -2354,83 +2731,93 @@ static bool llama_kv_cache_find_slot( min = seq_id; } // Assuming the tokens are in-order - if (batch.pos[i] != cache.cells[seq_id].pos + 1) { + if (batch.pos[i] != cache.rs.cells[seq_id].pos + 1) { // What should happen when the pos backtracks or skips a value? // Clearing the state mid-batch would require special-casing which isn't done. LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id); + __func__, batch.pos[i], cache.rs.cells[seq_id].pos, seq_id); } - if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { - cache.used += 1; + if (cache.rs.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { + cache.rs.used += 1; } - cache.cells[seq_id].pos = batch.pos[i]; - // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set + cache.rs.cells[seq_id].pos = batch.pos[i]; + cache.rs.cells[seq_id].seq_nodes.insert(seq_id); } else { // too big seq_id // TODO: would it be possible to resize the KV cache size instead? - LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); return false; } } } // allow getting the range of used cells, from head to head + n - cache.head = min; - cache.n = max - min + 1; + cache.rs.head = min; + cache.rs.n = max - min + 1; // sanity check - return max >= min; - } - // otherwise, one cell per token. - - if (n_tokens > n_ctx) { - LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); - return false; + if (max < min) { + return false; + } } - uint32_t n_tested = 0; + if (kv_size > 0) { + // one KV cell per token + if (n_tokens > kv_size) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, kv_size); + return false; + } - while (true) { - if (cache.head + n_tokens > n_ctx) { - n_tested += n_ctx - cache.head; - cache.head = 0; - continue; + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (cache.kv.head > cache.kv.used + 2*n_tokens) { + cache.kv.head = 0; } - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cache.cells[cache.head + i].pos >= 0) { - found = false; - cache.head += i + 1; - n_tested += i + 1; - break; + uint32_t n_tested = 0; + + while (true) { + if (cache.kv.head + n_tokens > kv_size) { + n_tested += kv_size - cache.kv.head; + cache.kv.head = 0; + continue; } - } - if (found) { - break; - } + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.kv.cells[cache.kv.head + i].pos >= 0) { + found = false; + cache.kv.head += i + 1; + n_tested += i + 1; + break; + } + } - if (n_tested >= n_ctx) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; + if (found) { + break; + } + + if (n_tested >= kv_size) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } } - } - for (uint32_t i = 0; i < n_tokens; i++) { - cache.cells[cache.head + i].pos = batch.pos[i]; + for (uint32_t i = 0; i < n_tokens; i++) { + cache.kv.cells[cache.kv.head + i].pos = batch.pos[i]; - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + cache.kv.cells[cache.kv.head + i].seq_id.insert(batch.seq_id[i][j]); + } } - } - cache.used += n_tokens; + cache.kv.used += n_tokens; + } return true; } -// find how many cells are currently in use +// find how many KV cells are currently in use static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { for (uint32_t i = cache.size; i > 0; --i) { const llama_kv_cell & cell = cache.cells[i - 1]; @@ -2443,214 +2830,381 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { return 0; } -static void llama_kv_cache_clear(struct llama_kv_cache & cache) { - for (int32_t i = 0; i < (int32_t) cache.size; ++i) { - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); +static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) { + for (uint32_t i = cache.size; i > 0; --i) { + const llama_rs_cell & cell = cache.cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } + } + + return 0; +} + +static void llama_cache_clear(struct llama_cache & cache) { + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + kv_cell.pos = -1; + kv_cell.delta = 0; + kv_cell.seq_id.clear(); + } + cache.kv.has_shift = false; + cache.kv.do_defrag = false; + cache.kv.head = 0; + cache.kv.used = 0; + } + if (cache.rs.size > 0) { + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.seq_nodes.clear(); + } + cache.rs.do_copy = false; + cache.rs.head = 0; + cache.rs.used = 0; + cache.rs.n_seqs = 0; + cache.rs.seq_tails.clear(); + cache.rs.seq_tails.resize(cache.rs.size); } - cache.head = 0; - cache.used = 0; } -static bool llama_kv_cache_seq_rm( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { - uint32_t new_head = cache.size; +static llama_pos llama_cache_seq_rm( + struct llama_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - // models like Mamba can't have a state partially erased - if (cache.recurrent) { - if (seq_id >= (int64_t) cache.size) { + llama_pos n_past = p0; + + if (cache.rs.size > 0) { + if (seq_id >= (int64_t) cache.rs.size) { // could be fatal - return false; - } - if (0 <= seq_id) { - // partial intersection is invalid - if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) { - return false; - } - } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; + return n_past; + } + uint32_t new_head = cache.rs.size; + // adjust p0 and p1 according to the states found + llama_pos new_p0 = 0; + llama_pos new_p1 = std::numeric_limits::max(); + + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (seq_id < 0 || rs_cell.has_seq_id(seq_id)) { + if (rs_cell.pos < p0) { + // move forward the new p0 further + if (rs_cell.pos >= new_p0) { + new_p0 = rs_cell.pos + 1; + } + } else if (rs_cell.pos >= p1) { + // move back the new p1 further + if (rs_cell.pos < new_p1) { + new_p1 = rs_cell.pos; + } + if (rs_cell.pos >= n_past) { + n_past = rs_cell.pos + 1; + } + } else { // (rs_cell.pos >= p0 && rs_cell.pos < p1) + if (seq_id < 0) { + cache.rs.clear_cell(i); + } else { // (rs_cell.has_seq_id(seq_id)) + cache.rs.remove_seq_from_cell(i, seq_id); + } + if (rs_cell.is_empty() && new_head == cache.rs.size) { + new_head = i; + } + } } } + p0 = new_p0; + p1 = new_p1; + // correctly set n_past when there's nothing after p1 + if (n_past < p0) { n_past = p0; } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; + } } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - if (seq_id < 0) { - cache.cells[i].seq_id.clear(); - } else if (cache.cells[i].has_seq_id(seq_id)) { - cache.cells[i].seq_id.erase(seq_id); - } else { - continue; - } - if (cache.cells[i].is_empty()) { - // keep count of the number of used cells - if (cache.cells[i].pos >= 0) cache.used--; + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; + + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; - cache.cells[i].pos = -1; - if (new_head == cache.size) new_head = i; + if (seq_id < 0 || kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + if (seq_id < 0) { + kv_cell.seq_id.clear(); + } else { // (kv_cell.has_seq_id(seq_id)) + kv_cell.seq_id.erase(seq_id); + } + if (kv_cell.is_empty()) { + // keep count of the number of used cells + if (kv_cell.pos >= 0) { cache.kv.used--; } + + kv_cell.pos = -1; + if (new_head == cache.kv.size) { new_head = i; } + } + } else { + if (kv_cell.pos >= n_past) { + n_past = kv_cell.pos + 1; + } + } } } - } - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.size && new_head < cache.head) cache.head = new_head; + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.kv.size && new_head < cache.kv.head) { + cache.kv.head = new_head; + } + } - return true; + return n_past; } -static void llama_kv_cache_seq_cp( - struct llama_kv_cache & cache, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { +static llama_pos llama_cache_seq_cp( + struct llama_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - if (cache.recurrent) { - if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) { - seq_id_src = cache.cells[seq_id_src].src; - GGML_ASSERT((uint32_t) seq_id_src < cache.size); - // intent to "copy from" - // supports copy chains thanks to taking the source of the source - cache.cells[seq_id_dst].src = seq_id_src; - - // preserve the "keep or clear" status of the copied sequence - if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) { - cache.cells[seq_id_dst].seq_id.insert(seq_id_dst); - } else { - cache.cells[seq_id_dst].seq_id.erase(seq_id_dst); + // TODO: in practice this seems to be only used on whole sequences; + // should partial sequence copy be removed? + + llama_pos n_past = 0; + + if (cache.rs.size > 0) { + // have to start from beginning for recurrent models + p0 = 0; + if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) { + auto seq_src = cache.rs.seq_tails[seq_id_src]; + int32_t src_tail = seq_src.tail; + // find the last tail of src in the pos range + while (src_tail >= 0 && (uint32_t) src_tail < cache.rs.size) { + llama_rs_cell & tail_cell = cache.rs.cells[src_tail]; + if (tail_cell.pos < p1) { + break; + } + src_tail = tail_cell.prev; } - cache.do_copy = true; + uint32_t new_head = cache.rs.size; - cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos; + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) { + if (i == (uint32_t) src_tail) { + // need to be inserted in order, but there's only one + cache.rs.insert_seq_tail_to_cell(i, seq_id_dst); + } else { + // keep only the tail cell of the source + // assuming a copy means no rollback will be attempted afterwards + cache.rs.remove_seq_from_cell(i, seq_id_src); + if (new_head == cache.rs.size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; + } } - return; + p1 = n_past; } - // otherwise, this is the KV cache of a Transformer-like model - - cache.head = 0; - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.cells[i].seq_id.insert(seq_id_dst); + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) { + kv_cell.seq_id.insert(seq_id_dst); + if (kv_cell.pos >= n_past) { + n_past = kv_cell.pos + 1; + } + } } } + + return n_past; } -static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { - uint32_t new_head = cache.size; +static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id) { + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; + + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (!kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= 0) cache.kv.used--; + kv_cell.pos = -1; + kv_cell.seq_id.clear(); + if (new_head == cache.kv.size) new_head = i; + } else { + kv_cell.seq_id.clear(); + kv_cell.seq_id.insert(seq_id); + } + } - for (uint32_t i = 0; i < cache.size; ++i) { - if (!cache.cells[i].has_seq_id(seq_id)) { - if (cache.cells[i].pos >= 0) cache.used--; - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) new_head = i; - } else { - cache.cells[i].seq_id.clear(); - cache.cells[i].seq_id.insert(seq_id); + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.kv.size && new_head < cache.kv.head) { + cache.kv.head = new_head; } } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.size && new_head < cache.head) cache.head = new_head; } -static void llama_kv_cache_seq_add( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - uint32_t new_head = cache.size; +static llama_pos llama_cache_seq_add( + struct llama_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - if (cache.recurrent) { + llama_pos n_past = p0; + + if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) cache.size) { - llama_kv_cell & cell = cache.cells[seq_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; + auto & seq = cache.rs.seq_tails[seq_id]; + // follow the sequence from its tail + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + GGML_ASSERT((uint32_t) cell_id < cache.rs.size); + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + int32_t i = cell_id; + cell_id = rs_cell.prev; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos += delta; + if (rs_cell.pos < 0) { + // NOTE: this affects the other sequences which share the cell + cache.rs.clear_cell(i); + // TODO: update cache.rs.head + } + } + if (n_past <= rs_cell.pos) { + n_past = rs_cell.pos + 1; } } - return; } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.has_shift = true; - cache.cells[i].pos += delta; - cache.cells[i].delta += delta; + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; + + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + cache.kv.has_shift = true; + kv_cell.pos += delta; + kv_cell.delta += delta; - if (cache.cells[i].pos < 0) { - if (!cache.cells[i].is_empty()) { - cache.used--; + if (kv_cell.pos < 0) { + if (!kv_cell.is_empty()) { + cache.kv.used--; + } + kv_cell.pos = -1; + kv_cell.seq_id.clear(); + if (new_head == cache.kv.size) { + new_head = i; + } + } } - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) { - new_head = i; + if (n_past <= kv_cell.pos) { + n_past = kv_cell.pos + 1; } } } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.kv.head = new_head != cache.kv.size ? new_head : 0; } - // If we freed up a slot, set head to it so searching can start there. - // Otherwise we just start the next search from the beginning. - cache.head = new_head != cache.size ? new_head : 0; + return n_past; } -static void llama_kv_cache_seq_div( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { +static llama_pos llama_cache_seq_div( + struct llama_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - if (cache.recurrent) { + llama_pos n_past = p0; + + if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) cache.size) { - llama_kv_cell & cell = cache.cells[seq_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; + auto & seq = cache.rs.seq_tails[seq_id]; + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + GGML_ASSERT((uint32_t) cell_id < cache.rs.size); + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos /= d; + } + cell_id = rs_cell.prev; + if (n_past <= rs_cell.pos) { + n_past = rs_cell.pos + 1; } } - return; } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.has_shift = true; + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + cache.kv.has_shift = true; - { - llama_pos p_old = cache.cells[i].pos; - cache.cells[i].pos /= d; - cache.cells[i].delta += cache.cells[i].pos - p_old; + { + llama_pos p_old = kv_cell.pos; + kv_cell.pos /= d; + kv_cell.delta += kv_cell.pos - p_old; + } + } + if (n_past <= kv_cell.pos) { + n_past = kv_cell.pos + 1; + } } } } + + return n_past; } -static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { +static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { llama_pos result = 0; - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id)) { - result = std::max(result, cache.cells[i].pos); + if (cache.rs.size > 0) { + int32_t cell_id = cache.rs.seq_tails[seq_id].tail; + if (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + result = rs_cell.pos; + } + // exit early + return result; + } + + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + result = std::max(result, kv_cell.pos); + } } } @@ -6009,6 +6563,7 @@ struct llm_build_context { const llama_cparams & cparams; const llama_batch & batch; const llama_kv_cache & kv_self; + const llama_rs_cache & rs_self; const int64_t n_embd; const int64_t n_layer; @@ -6034,8 +6589,10 @@ struct llm_build_context { const int32_t n_tokens; const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) + const int32_t n_rs; const int32_t n_outputs; const int32_t kv_head; // index of where we store new KV data in the cache + const int32_t rs_head; const int32_t n_orig_ctx; const enum llama_pooling_type pooling_type; @@ -6058,7 +6615,8 @@ struct llm_build_context { hparams (model.hparams), cparams (lctx.cparams), batch (batch), - kv_self (lctx.kv_self), + kv_self (lctx.cache.kv), + rs_self (lctx.cache.rs), n_embd (hparams.n_embd), n_layer (hparams.n_layer), n_rot (hparams.n_rot), @@ -6081,8 +6639,10 @@ struct llm_build_context { norm_rms_eps (hparams.f_norm_rms_eps), n_tokens (batch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), - n_outputs (worst_case ? n_tokens : lctx.n_outputs), - kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), + n_rs (worst_case ? rs_self.size : rs_self.n), + n_outputs (worst_case ? n_tokens : lctx.n_outputs), + kv_head (worst_case ? kv_self.size - n_tokens : kv_self.head), + rs_head (worst_case ? 0 : rs_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), @@ -6148,29 +6708,6 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_s_copy() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); - - GGML_ASSERT(kv_self.recurrent); - - struct ggml_tensor * state_copy = build_inp_s_copy(); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); - - conv_states = ggml_get_rows(ctx0, conv_states, state_copy); - ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); - - // TODO: name the intermediate tensors with cb() - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il])); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il])); - } - - return gf; - } - struct ggml_cgraph * build_defrag(const std::vector & ids) { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -6267,21 +6804,21 @@ struct llm_build_context { } struct ggml_tensor * build_inp_s_copy() { - lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size); + lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); cb(lctx.inp_s_copy, "inp_s_copy", -1); ggml_set_input(lctx.inp_s_copy); return lctx.inp_s_copy; } struct ggml_tensor * build_inp_s_mask() { - lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); + lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_rs); cb(lctx.inp_s_mask, "inp_s_mask", -1); ggml_set_input(lctx.inp_s_mask); return lctx.inp_s_mask; } struct ggml_tensor * build_inp_s_seq() { - lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); + lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_rs, n_tokens); cb(lctx.inp_s_seq, "inp_s_seq", -1); ggml_set_input(lctx.inp_s_seq); return lctx.inp_s_seq; @@ -9269,26 +9806,31 @@ struct llm_build_context { // {n_embd, n_tokens} inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { - // (ab)using the KV cache to store the states - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(), rs_self.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(), rs_self.size); + + // copy states + { + // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows + // NOTE: assuming the copy destinations are ALL contained in the current batch + // this shrinks the tensors's ne[1] to n_rs + conv_states = ggml_get_rows(ctx0, conv_states, state_copy); + ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); + } // clear states of sequences which are starting at the beginning of this batch { - conv_states = ggml_mul(ctx0, - ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]), - state_mask); - ssm_states = ggml_mul(ctx0, - ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]), - state_mask); + conv_states = ggml_mul(ctx0, conv_states, state_mask); + ssm_states = ggml_mul(ctx0, ssm_states, state_mask); } - conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv); - ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv); + conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_rs); + ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_rs); // norm cur = llm_build_norm(ctx0, inpL, hparams, @@ -9321,8 +9863,8 @@ struct llm_build_context { // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), - ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); + ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), + ggml_view_1d(ctx0, rs_self.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); // extract x from x_conv x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); @@ -9348,15 +9890,15 @@ struct llm_build_context { // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. - // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined, + // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined, // because only a single tensor can be returned. struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); // store last states (the second part of y_ssm_states) ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states)))); + ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), + ggml_view_1d(ctx0, rs_self.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); @@ -9558,23 +10100,6 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { return result; } -static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) { - llama_batch dummy; - dummy.n_tokens = 0; - - llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - - struct llm_build_context llm(lctx, dummy, cb, false); - - llm.init(); - - struct ggml_cgraph * result = llm.build_s_copy(); - - llm.free(); - - return result; -} - static struct ggml_cgraph * llama_build_graph( llama_context & lctx, const llama_batch & batch, @@ -9729,26 +10254,14 @@ static struct ggml_cgraph * llama_build_graph( } static void llama_set_k_shift(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; + const int64_t kv_size = lctx.cache.kv.size; assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer)); int32_t * data = (int32_t *) lctx.inp_K_shift->data; for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].delta; - } -} - -static void llama_set_s_copy(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; - - assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); - - int32_t * data = (int32_t *) lctx.inp_s_copy->data; - - for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].src; + data[i] = lctx.cache.kv.cells[i].delta; } } @@ -9759,7 +10272,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const auto & hparams = lctx.model.hparams; const auto & cparams = lctx.cparams; - const auto & kv_self = lctx.kv_self; + const auto & kv_self = lctx.cache.kv; + const auto & rs_self = lctx.cache.rs; if (batch.token) { const int64_t n_tokens = batch.n_tokens; @@ -9835,7 +10349,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_kv; ++i) { float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { f = -INFINITY; } else { f = 0.0f; @@ -9886,7 +10400,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float * data = (float *) lctx.inp_KQ_pos->data; for (int i = 0; i < n_kv; ++i) { - data[i] = float(lctx.kv_self.cells[i].pos); + data[i] = float(kv_self.cells[i].pos); } } @@ -9943,29 +10457,54 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (kv_self.recurrent) { - const int64_t n_kv = kv_self.n; + if (rs_self.size > 0) { + const int64_t n_rs = rs_self.n; if (lctx.inp_s_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); float * data = (float *) lctx.inp_s_mask->data; - // states which are not affected by the current batch are left untouched - for (int i = 0; i < n_kv; ++i) { - llama_seq_id seq_id = i + lctx.kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id]; - bool has_self_seq = kv_cell.has_seq_id(seq_id); + // clear unused states + for (int i = 0; i < n_rs; ++i) { + uint32_t cell_id = i + rs_self.head; + llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; - data[i] = (float) has_self_seq; + data[i] = (float) rs_cell.src >= 0; - // ensure current sequences will be kept - if (!has_self_seq && kv_cell.pos >= 0) { - kv_cell.seq_id.insert(seq_id); + // only clear once + if (rs_cell.src < 0) { + rs_cell.src = cell_id; } } } + + // checkpoints require copies between cells + if (lctx.inp_s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); + int32_t * data = (int32_t *) lctx.inp_s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + const uint32_t cell_id = i + rs_self.head; + llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; + + // prevent out-of-bound sources + if (rs_cell.src < 0 || (uint32_t) rs_cell.src >= rs_self.size) { + rs_cell.src = cell_id; + } + + data[i] = rs_cell.src; + + // ensure copy only happens once + if (rs_cell.src != (int32_t) cell_id) { + rs_cell.src = cell_id; + } + } + } + // For Mamba (and other recurrent architectures), // update the correct state(s)/sequence(s) for each token of the batch. + // Each row contains relative cell ids of the sequences for the associated token. // Like with the KQ_mask, if a token in the batch has multiple sequences, // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). if (lctx.inp_s_seq) { @@ -9978,12 +10517,20 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const int32_t n_seq = batch.n_seq_id[j]; GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence - for (int i = 0; i < n_kv; ++i) { + for (int i = 0; i < n_rs; ++i) { if (i < n_seq) { - // for this type of model, the head is the minimum seq_id of the batch - data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head; + llama_seq_id seq_id = batch.seq_id[j][i]; + GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); + const auto & seq = rs_self.seq_tails[seq_id]; + // all sequences of this batch should already be initialized + GGML_ASSERT(seq.tail >= 0); + // ensure the relative cell id will be positive but not too big + GGML_ASSERT((uint32_t) seq.tail >= rs_self.head); + GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); + + data[j*n_rs + i] = seq.tail - rs_self.head; } else { - data[j*n_kv + i] = -1; + data[j*n_rs + i] = -1; } } } @@ -10129,7 +10676,8 @@ static int llama_decode_internal( //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); #endif - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; + auto & rs_self = lctx.cache.rs; const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -10245,17 +10793,11 @@ static int llama_decode_internal( if (hparams.causal_attn) { llama_kv_cache_update(&lctx); - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self.head > kv_self.used + 2*n_tokens) { - kv_self.head = 0; - } - - if (!llama_kv_cache_find_slot(kv_self, u_batch)) { + if (!llama_kv_cache_find_slot(lctx.cache, u_batch)) { return 1; } - if (!kv_self.recurrent) { + if (kv_self.size > 0) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important @@ -10329,11 +10871,15 @@ static int llama_decode_internal( // update the kv ring buffer { kv_self.head += n_tokens; + rs_self.head += rs_self.n; // Ensure kv cache head points to a valid index. if (kv_self.head >= kv_self.size) { kv_self.head = 0; } + if (rs_self.head >= rs_self.size) { + rs_self.head = 0; + } } #ifdef GGML_PERF @@ -10430,7 +10976,7 @@ static int llama_decode_internal( // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; const auto & hparams = lctx.model.hparams; @@ -10651,7 +11197,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { bool need_reserve = false; // apply K-shift if needed - if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) { + if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.cache.kv.has_shift) { { ggml_backend_sched_reset(lctx.sched); @@ -10667,7 +11213,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } { - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; kv_self.has_shift = false; @@ -10677,39 +11223,13 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } } - if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) { - { - ggml_backend_sched_reset(lctx.sched); - - ggml_cgraph * gf = llama_build_graph_s_copy(lctx); - - ggml_backend_sched_alloc_graph(lctx.sched, gf); - - llama_set_s_copy(lctx); - - llama_graph_compute(lctx, gf, lctx.cparams.n_threads); - - need_reserve = true; - } - - { - auto & kv_self = lctx.kv_self; - - kv_self.do_copy = false; - - for (uint32_t i = 0; i < kv_self.size; ++i) { - kv_self.cells[i].src = i; - } - } - } - // defragment the KV cache if needed - if (lctx.kv_self.do_defrag) { + if (lctx.cache.kv.do_defrag) { llama_kv_cache_defrag_internal(lctx); need_reserve = true; - lctx.kv_self.do_defrag = false; + lctx.cache.kv.do_defrag = false; } // reserve a worst case graph again @@ -14258,18 +14778,8 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; - uint32_t kv_size = cparams.n_ctx; - ggml_type type_k = params.type_k; - ggml_type type_v = params.type_v; - - // Mamba only needs a constant number of KV cache cells per sequence - if (model->arch == LLM_ARCH_MAMBA) { - // Mamba needs at least as many KV cells as there are sequences kept at any time - kv_size = std::max((uint32_t) 1, params.n_seq_max); - // it's probably best to keep as much precision as possible for the states - type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states - type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states - } + const ggml_type type_k = params.type_k; + const ggml_type type_v = params.type_v; GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0); GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); @@ -14377,25 +14887,42 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) { + if (!llama_cache_init(ctx->cache, ctx->model, type_k, type_v, cparams.n_ctx, cparams.n_seq_max, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; } - { + if (ctx->cache.rs.size > 0) { + size_t memory_size_r = 0; + size_t memory_size_s = 0; + + for (auto & r : ctx->cache.rs.r_l) { + memory_size_r += ggml_nbytes(r); + } + + for (auto & s : ctx->cache.rs.s_l) { + memory_size_s += ggml_nbytes(s); + } + + LLAMA_LOG_INFO("%s: SSM state size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), + ggml_type_name(GGML_TYPE_F32), (float)memory_size_r / (1024.0f * 1024.0f), + ggml_type_name(GGML_TYPE_F32), (float)memory_size_s / (1024.0f * 1024.0f)); + } + if (ctx->cache.kv.size > 0) { size_t memory_size_k = 0; size_t memory_size_v = 0; - for (auto & k : ctx->kv_self.k_l) { + for (auto & k : ctx->cache.kv.k_l) { memory_size_k += ggml_nbytes(k); } - for (auto & v : ctx->kv_self.v_l) { + for (auto & v : ctx->cache.kv.v_l) { memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); @@ -14513,7 +15040,11 @@ uint32_t llama_n_ubatch(const struct llama_context * ctx) { } uint32_t llama_n_seq_max(const struct llama_context * ctx) { - return ctx->kv_self.size; + if (ctx->cache.rs.size > 0) { + return ctx->cache.rs.size; + } else { + return ctx->cache.kv.size; + } } enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { @@ -14799,8 +15330,9 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) { } void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) { - if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) { - view->n_cells = int32_t(ctx->kv_self.size); + const llama_kv_cache & kv_self = ctx->cache.kv; + if (uint32_t(view->n_cells) < kv_self.size || view->cells == nullptr) { + view->n_cells = int32_t(kv_self.size); void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells); GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); view->cells = (struct llama_kv_cache_view_cell *)p; @@ -14809,7 +15341,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k view->cells_sequences = (llama_seq_id *)p; } - const std::vector & kv_cells = ctx->kv_self.cells; + const std::vector & kv_cells = kv_self.cells; llama_kv_cache_view_cell * c_curr = view->cells; llama_seq_id * cs_curr = view->cells_sequences; int32_t used_cells = 0; @@ -14818,7 +15350,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k uint32_t max_contig = 0; int32_t max_contig_idx = -1; - for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) { + for (int32_t i = 0; i < int32_t(kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) { const size_t curr_size = kv_cells[i].seq_id.size(); token_count += curr_size; c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; @@ -14856,67 +15388,77 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k view->max_contiguous_idx = max_contig_idx; view->token_count = token_count; view->used_cells = used_cells; - if (uint32_t(used_cells) != ctx->kv_self.used) { + if (uint32_t(used_cells) != kv_self.used) { LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n", - __func__, ctx->kv_self.used, used_cells); + __func__, kv_self.used, used_cells); } } int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) { int result = 0; - for (uint32_t i = 0; i < ctx->kv_self.size; i++) { - result += ctx->kv_self.cells[i].seq_id.size(); + for (uint32_t i = 0; i < ctx->cache.kv.size; i++) { + result += ctx->cache.kv.cells[i].seq_id.size(); } return result; } int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) { - return ctx->kv_self.used; + return ctx->cache.kv.used; } void llama_kv_cache_clear(struct llama_context * ctx) { - llama_kv_cache_clear(ctx->kv_self); + llama_cache_clear(ctx->cache); } bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return false; } + llama_pos n_past = llama_cache_seq_rm(ctx->cache, seq_id, p0, p1); + return n_past >= p0; } void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + uint32_t n_seq_max = llama_n_seq_max(ctx); + if (seq_id_src < 0 || seq_id_dst < 0 || (uint32_t) seq_id_src >= n_seq_max || (uint32_t) seq_id_dst >= n_seq_max) { + return; + } if (seq_id_src == seq_id_dst) { return; } - llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); + llama_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); } void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { - llama_kv_cache_seq_keep(ctx->kv_self, seq_id); + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + llama_cache_seq_keep(ctx->cache, seq_id); } void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { if (delta == 0) { return; } + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta); + llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); } void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { if (d == 1) { return; } + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d); + llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); } llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id); + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } + return llama_cache_seq_pos_max(ctx->cache, seq_id); } void llama_kv_cache_defrag(struct llama_context * ctx) { - llama_kv_cache_defrag(ctx->kv_self); + llama_kv_cache_defrag(ctx->cache.kv); } void llama_kv_cache_update(struct llama_context * ctx) { @@ -14944,9 +15486,10 @@ size_t llama_get_state_size(const struct llama_context * ctx) { const size_t s_kv_head = sizeof(uint32_t); const size_t s_kv_size = sizeof(uint32_t); const size_t s_kv_used = sizeof(uint32_t); - const size_t s_kv = ctx->kv_self.total_size(); + const size_t s_kv = ctx->cache.kv.total_size(); const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); - const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; + const size_t s_kv_cells = ctx->cache.kv.size * s_kv_cell; + // TODO: rs cache cells const size_t s_total = ( + s_rng_size @@ -15241,14 +15784,15 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { } } + // FIXME: set rs cache too // set kv cache { - const auto & kv_self = ctx->kv_self; + const auto & kv_self = ctx->cache.kv; const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); size_t kv_buf_size; uint32_t kv_head; @@ -15279,16 +15823,6 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size); inp += k_size; - if (kv_self.recurrent) { - // v is contiguous for recurrent models - // TODO: use other tensors for state models than k and v - const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); - - ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size); - inp += v_size; - continue; - } - // v is not contiguous, copy row by row const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head); const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size); @@ -15303,8 +15837,8 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { llama_kv_cache_clear(ctx); - ctx->kv_self.head = kv_head; - ctx->kv_self.used = kv_used; + ctx->cache.kv.head = kv_head; + ctx->cache.kv.used = kv_used; for (uint32_t i = 0; i < kv_head; ++i) { llama_pos pos; @@ -15313,13 +15847,13 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { memcpy(&pos, inp, sizeof(pos)); inp += sizeof(pos); memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size); - ctx->kv_self.cells[i].pos = pos; + ctx->cache.kv.cells[i].pos = pos; llama_seq_id seq_id; for (size_t j = 0; j < seq_id_size; ++j) { memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id); - ctx->kv_self.cells[i].seq_id.insert(seq_id); + ctx->cache.kv.cells[i].seq_id.insert(seq_id); } } } From 8db1e4d45fb27a5e76ac55559a008a425e00fbac Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 4 Apr 2024 10:46:43 -0400 Subject: [PATCH 002/117] llama : use std::find for seq_nodes in llama_rs_cache --- llama.cpp | 153 ++++++++++++++++++++++-------------------------------- 1 file changed, 61 insertions(+), 92 deletions(-) diff --git a/llama.cpp b/llama.cpp index 9ca8ca0f41320..6dc310bf94c6c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1962,11 +1962,12 @@ struct llama_rs_seq_node { llama_seq_id seq_id = -1; int32_t next_cell = -1; - // needed for automatic typecasting with .find() + // needed for automatic typecasting from a llama_seq_id llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {} - bool operator<(const llama_rs_seq_node & other) const { - return seq_id < other.seq_id; + // needed for more convenient std::find + bool operator==(const llama_rs_seq_node & other) const { + return seq_id == other.seq_id; } bool is_tail() const { @@ -1989,48 +1990,18 @@ struct llama_rs_cell { // seq_ids by insertion order, to simplify updating n_cells compared to a set std::vector seq_nodes; - llama_rs_seq_node * get_node(const llama_seq_id & id) { - for (size_t i = 0; i < seq_nodes.size(); ++i) { - if (seq_nodes[i].seq_id == id) { - return &seq_nodes[i]; - } - } - return nullptr; - } - void insert_node(const llama_rs_seq_node & node) { - llama_rs_seq_node * node_dest = get_node(node.seq_id); - if (node_dest == nullptr) { + auto node_dest = std::find(seq_nodes.begin(), seq_nodes.end(), node); + if (node_dest == seq_nodes.end()) { seq_nodes.push_back(node); } else { + // overwrite the pre-existing node with the same seq_id if it exists *node_dest = node; } } - bool remove_node(llama_rs_seq_node * node_ptr) { - if (node_ptr != nullptr && seq_nodes.data() <= node_ptr) { - size_t offset = node_ptr - seq_nodes.data(); - if (offset % sizeof(llama_rs_seq_node) == 0) { - offset /= sizeof(llama_rs_seq_node); - if (offset < seq_nodes.size()) { - for (size_t i = offset + 1; i < seq_nodes.size(); ++i) { - seq_nodes[i - 1] = seq_nodes[i]; - } - seq_nodes.resize(seq_nodes.size() - 1); - return true; - } - } - } - return false; - } - bool has_seq_id(const llama_seq_id & id) const { - for (size_t i = 0; i < seq_nodes.size(); ++i) { - if (seq_nodes[i].seq_id == id) { - return true; - } - } - return false; + return std::find(seq_nodes.begin(), seq_nodes.end(), id) != seq_nodes.end(); } bool is_empty() const { @@ -2132,67 +2103,65 @@ struct llama_rs_cache { bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) { if (i_cell < size && (size_t) id < size) { llama_rs_cell & rs_cell = cells[i_cell]; - auto * node_ptr = rs_cell.get_node(id); // search once - if (node_ptr != nullptr) { + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once + if (node_iter != rs_cell.seq_nodes.end()) { if (rs_cell.seq_nodes.size() == 1) { return clear_cell(i_cell); - } else { - // update tree - llama_rs_seq_node node = *node_ptr; - if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { - cells[node.next_cell].prev = rs_cell.prev; + } + // else update tree + llama_rs_seq_node node = *node_iter; + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + cells[node.next_cell].prev = rs_cell.prev; + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2; + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (seq.tail >= 0 && (uint32_t) seq.tail < size) { + llama_rs_cell & new_tail = cells[seq.tail]; + new_tail.insert_node(node.seq_id); // ensures next_cell == -1 + new_tail.tail_rc += 1; + seq.shared = cells[seq.tail].seq_nodes.size() > 1; + } else { + seq.shared = false; + } + GGML_ASSERT(rs_cell.tail_rc > 0); + rs_cell.tail_rc -= 1; } - if ((uint32_t) node.seq_id < seq_tails.size()) { - auto & seq = seq_tails[node.seq_id]; - bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2; - if (node.is_tail()) { - seq.tail = rs_cell.prev; - if (seq.tail >= 0 && (uint32_t) seq.tail < size) { - llama_rs_cell & new_tail = cells[seq.tail]; - new_tail.insert_node(node.seq_id); // ensures next_cell == -1 - new_tail.tail_rc += 1; - seq.shared = cells[seq.tail].seq_nodes.size() > 1; - } else { - seq.shared = false; - } - GGML_ASSERT(rs_cell.tail_rc > 0); - rs_cell.tail_rc -= 1; + if (node_iter == rs_cell.seq_nodes.begin()) { + // this seq_id was the first in the list + seq.n_cells -= 1; + if (seq.n_cells == 0) { + n_seqs -= 1; } - if (node_ptr == rs_cell.seq_nodes.data()) { - // this seq_id was the first in the list - seq.n_cells -= 1; - if (seq.n_cells == 0) { - n_seqs -= 1; - } - // the next node is the new first one, so update its n_cells - // (will never be out-of-bounds because the size is > 1) - llama_rs_seq_node next_node = node_ptr[1]; - if ((uint32_t) next_node.seq_id < seq_tails.size()) { - auto & next_seq = seq_tails[next_node.seq_id]; - next_seq.n_cells += 1; - if (next_seq.n_cells == 1) { - n_seqs += 1; - } - if (other_no_longer_shared) { - next_seq.shared = false; - } - } else { - GGML_ASSERT(false && "invalid seq_id"); + // the next node is the new first one, so update its n_cells + // (will never be out-of-bounds because the size is > 1) + llama_rs_seq_node next_node = *(std::next(node_iter)); + if ((uint32_t) next_node.seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node.seq_id]; + next_seq.n_cells += 1; + if (next_seq.n_cells == 1) { + n_seqs += 1; } - } else if (other_no_longer_shared) { - llama_rs_seq_node first_node = rs_cell.seq_nodes[0]; - if ((uint32_t) first_node.seq_id < seq_tails.size()) { - seq_tails[first_node.seq_id].shared = false; - } else { - GGML_ASSERT(false && "invalid seq_id"); + if (other_no_longer_shared) { + next_seq.shared = false; } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } else if (other_no_longer_shared) { + llama_rs_seq_node first_node = rs_cell.seq_nodes[0]; + if ((uint32_t) first_node.seq_id < seq_tails.size()) { + seq_tails[first_node.seq_id].shared = false; + } else { + GGML_ASSERT(false && "invalid seq_id"); } - } else { - GGML_ASSERT(false && "invalid seq_id"); } - const bool removed = rs_cell.remove_node(node_ptr); - GGML_ASSERT(removed); + } else { + GGML_ASSERT(false && "invalid seq_id"); } + rs_cell.seq_nodes.erase(node_iter); } } return false; @@ -2215,8 +2184,8 @@ struct llama_rs_cache { if (prev >= 0 && (uint32_t) prev < size) { // the targeted cell has a previous cell llama_rs_cell & prev_cell = cells[prev]; - llama_rs_seq_node * prev_node = prev_cell.get_node(id); - GGML_ASSERT(prev_node != nullptr); // TODO: recursive insert instead of failing + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id); + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken if (rs_cell.pos < 0) { GGML_ASSERT(rs_cell.is_empty()); @@ -2267,7 +2236,7 @@ struct llama_rs_cache { int32_t n_system_seqs = 0; int32_t n_system_cells = 0; for (size_t i = 0; i < seq_tails.size(); ++i) { - auto & seq = seq_tails[i]; + const auto & seq = seq_tails[i]; if (seq.tail >= 0 && (size_t) seq.tail < size) { if (seq.shared && seq.n_cells > 0) { n_system_seqs += 1; From 0028010d01447c079f98bc33f06fca691fc99905 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 8 Apr 2024 09:54:35 -0400 Subject: [PATCH 003/117] llama : state checkpoints for recurrent models --- ggml.c | 96 +++---- llama.cpp | 751 +++++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 585 insertions(+), 262 deletions(-) diff --git a/ggml.c b/ggml.c index c9b0a6a0ef776..7a3f1b7a2f882 100644 --- a/ggml.c +++ b/ggml.c @@ -6335,19 +6335,18 @@ struct ggml_tensor * ggml_ssm_conv( GGML_ASSERT(ggml_is_3d(s)); GGML_ASSERT(ggml_is_matrix(x)); GGML_ASSERT(ggml_is_matrix(c)); - GGML_ASSERT(ggml_is_matrix(sq)); + GGML_ASSERT(ggml_is_vector(sq)); GGML_ASSERT(sq->type == GGML_TYPE_I32); const int64_t d_conv = c->ne[0]; const int64_t d_inner = c->ne[1]; const int64_t n_tokens = x->ne[1]; - const int64_t n_kv = s->ne[2]; + const int64_t n_rs = s->ne[2]; GGML_ASSERT( s->ne[0] == d_conv - 1); GGML_ASSERT( s->ne[1] == d_inner); GGML_ASSERT( x->ne[0] == d_inner); - GGML_ASSERT(sq->ne[0] == n_kv); - GGML_ASSERT(sq->ne[1] == n_tokens); + GGML_ASSERT(sq->ne[0] == n_tokens); bool is_node = false; @@ -6356,8 +6355,8 @@ struct ggml_tensor * ggml_ssm_conv( is_node = true; } - // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv} - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv)); + // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs} + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs)); result->op = GGML_OP_SSM_CONV; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -6410,7 +6409,7 @@ struct ggml_tensor * ggml_ssm_scan( is_node = true; } - // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv} + // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs} struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); result->op = GGML_OP_SSM_SCAN; @@ -15087,9 +15086,9 @@ static void ggml_compute_forward_ssm_conv_f32( const int nc = src2->ne[0]; // d_conv const int nr = src0->ne[1]; // d_inner const int n_t = src1->ne[1]; // n_tokens - const int n_kv = src0->ne[2]; // max number of sequences in the batch + const int n_rs = src0->ne[2]; // max number of sequences in the batch - GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst)); + GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); @@ -15106,10 +15105,12 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - if (n_kv > 1) { + const int32_t * sq = src3->data; // {n_tokens} + + if (n_rs > 1) { // multiple sequences means it's hard to know when it's the first time a state is read, // so copy them all over to the destination, just to be sure. - for (int i3 = 0; i3 < n_kv; ++i3) { + for (int i3 = 0; i3 < n_rs; ++i3) { float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); // can't use memcpy because of d_conv vs d_conv - 1 @@ -15123,19 +15124,19 @@ static void ggml_compute_forward_ssm_conv_f32( } for (int i2 = 0; i2 < n_t; ++i2) { - int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens} - float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv} - float * s0; // {d_conv - 1, d_inner, n_kv} - float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} + int32_t sq_i = sq[i2]; + float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs} + float * s0; // {d_conv - 1, d_inner, n_rs} + float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} int ne0s0; - GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + GGML_ASSERT(0 <= sq_i && sq_i < n_rs); // avoid needing to copy the state for the first token if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv} + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs} ne0s0 = src0->ne[0]; } else { // the source is the last (d_conv - 1) columns of the destination @@ -15153,18 +15154,6 @@ static void ggml_compute_forward_ssm_conv_f32( s[(nc - 1) + i1*nc] = x0[i1]; } - // handle copies when there are multiple output states - for (int i3 = 1; i3 < n_kv; ++i3) { - int32_t seq = sq[i3]; - if (0 <= seq && seq < n_kv) { - float * s1 = s + (seq - sq[0])*nc*nr; - memcpy(s1, s, nc*ir*sizeof(float)); - } else { - // stop at negative or too big seq_ids - break; - } - } - // it seems a little faster when this is separate from the state shift for (int i1 = 0; i1 < ir; ++i1) { // rowwise dot product @@ -15216,7 +15205,7 @@ static void ggml_compute_forward_ssm_scan_f32( const int64_t nc = src0->ne[0]; // d_state const int64_t nr = src0->ne[1]; // d_inner const int64_t n_t = src1->ne[1]; // number of tokens in the batch - const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch + const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -15225,6 +15214,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); // required for the dot product between s and C, and when copying the states GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); // required for per-sequence offsets for states @@ -15240,10 +15230,12 @@ static void ggml_compute_forward_ssm_scan_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - if (n_kv > 1) { + const int32_t * sq = src6->data; // {n_tokens} + + if (n_rs > 1) { // it's hard to know if the source states have already been copied // when there are multiple, so copy them already. - for (int i3 = 0; i3 < n_kv; ++i3) { + for (int i3 = 0; i3 < n_rs; ++i3) { float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); memcpy(s, s0, nc*ir*sizeof(float)); @@ -15251,21 +15243,21 @@ static void ggml_compute_forward_ssm_scan_f32( } for (int i2 = 0; i2 < n_t; ++i2) { - int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens} - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv} - float * s0; - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} - - GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv); + int32_t sq_i = sq[i2]; + float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs} + float * s0; + float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} + float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} + float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} + float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} + + GGML_ASSERT(0 <= sq_i && sq_i < n_rs); // avoid needing to copy the state for the first token if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv} + s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs} } else { // otherwise the source is the same as the destination s0 = s; @@ -15288,18 +15280,6 @@ static void ggml_compute_forward_ssm_scan_f32( } y[i1] = sumf; } - - // handle copies when there are multiple output states - for (int i3 = 1; i3 < n_kv; ++i3) { - int32_t seq = sq[i3]; - if (0 <= seq && seq < n_kv) { - float * s1 = s + (seq - sq[0])*nc*nr; - memcpy(s1, s, nc*ir*sizeof(float)); - } else { - // stop at negative or too big seq_ids - break; - } - } } } diff --git a/llama.cpp b/llama.cpp index 6dc310bf94c6c..d561f80f62b6d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2016,11 +2016,13 @@ struct llama_rs_seq_meta { // number of cells for which this seq_id is the first // (useful to know if cells in this sequence should be pruned) int32_t n_cells = 0; - // whether the tail is a cell part of multiple sequences - bool shared = false; + // changing the tail cell of a sequence can only be done at batch boundary, + // this guards against changing the cell when it shouldn't be; + // should be cleared when done finding a slot + bool in_ubatch = false; }; -// ring-buffer of cached recurrent state data +// ring-buffered tree of cached recurrent state data struct llama_rs_cache { bool do_copy = false; @@ -2032,8 +2034,10 @@ struct llama_rs_cache { uint32_t n = 0; // range of states used for the last slot // useful to know the minimum reserved cell count per seq_id - // only counts sequences with n_cells > 0 + // only counts sequences with n_cells > 0 AND which have a non-shared tail uint32_t n_seqs = 0; + // cells part of multiple sequences AND which have at least one tail + uint32_t n_shared_tail_cells = 0; // with state models, a cell can hold the state for more than one past token // TODO: it's probably not possible to always use contiguous cells @@ -2047,127 +2051,332 @@ struct llama_rs_cache { std::vector r_l; // rolling/shift states std::vector s_l; // ssm (recurrent) states - // returns whether or not a cell was freed - bool clear_cell(uint32_t i) { - if (i < size) { - llama_rs_cell & rs_cell = cells[i]; - if (!rs_cell.is_empty()) { - // update sequence tree links - bool first = true; - for (const llama_rs_seq_node & node : rs_cell.seq_nodes) { - if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { - // NOTE: if all next cells are the same cell, this should still work - cells[node.next_cell].prev = rs_cell.prev; + // TODO: maybe use a simpler data structure than a tree + + // Inefficient, but thorough verification and rebuilding of the rs cache + // from only the cells list with `pos` and seq_ids. + // Should not be called in a hot loop except when desperate and/or debugging. + bool rebuild(bool debug) { + bool was_valid = true; + // the source of truth is the cells list + // buffer sizes + if (size != cells.size()) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", + __func__, cells.size(), size); + } + cells.resize(size); + was_valid = false; + } + if (size != seq_tails.size()) { + if (debug) { + LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", + __func__, seq_tails.size(), size); + } + seq_tails.resize(size); + was_valid = false; + } + // cells consistency + uint32_t used_verif = 0; + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + if (cell.seq_nodes.empty()) { + if (cell.pos >= 0) { + cell.pos = -1; + was_valid = false; + } + } + if (cell.pos < 0) { + if (cell.pos != -1) { + cell.pos = -1; + was_valid = false; + } + if (!cell.seq_nodes.empty()) { + cell.seq_nodes.clear(); + was_valid = false; + } + cell.src = -1; + if (cell.prev != -1) { + cell.prev = -1; + was_valid = false; + } + } else if (!debug) { + // Assuming the cache should be actually rebuilt when not debugging + cell.src = cell_id; + } + if (!cell.seq_nodes.empty()) { + used_verif += 1; + } + } + if (used != used_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid used cell count (%u instead of %u)\n", + __func__, used, used_verif); + } + used = used_verif; + was_valid = false; + } + // tail verification + std::vector> seq_cells; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + seq_cells.clear(); + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + if (cell.has_seq_id(seq_id)) { + seq_cells.push_back({cell.pos, cell_id}); + } + } + // sort by pos and then by cell_id + std::sort(seq_cells.begin(), seq_cells.end()); + int32_t tail = seq_cells.empty() ? -1 : seq_cells[seq_cells.size() - 1].second; + if (tail != seq.tail) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid tail for seq_id %d (%d instead of %d)\n", + __func__, seq_id, seq.tail, tail); + } + seq.tail = tail; + was_valid = false; + } + int32_t prev = -1; + for (size_t i = 0; i < seq_cells.size(); ++i) { + uint32_t cell_id = seq_cells[i].second; + llama_rs_cell & cell = cells[cell_id]; + if (cell.prev != prev) { + // TODO: relax the error when multiple cells have the same pos + if (debug) { + LLAMA_LOG_ERROR("%s: invalid prev cell for cells[%u] (%d instead of %d)\n", + __func__, cell_id, cell.prev, prev); } - if ((uint32_t) node.seq_id < seq_tails.size()) { - auto & seq = seq_tails[node.seq_id]; - // update tail - if (node.is_tail()) { - seq.tail = rs_cell.prev; - if (seq.tail >= 0 && (uint32_t) seq.tail < size) { - llama_rs_cell & new_tail = cells[seq.tail]; - new_tail.insert_node(node.seq_id); // ensures next_cell == -1 - new_tail.tail_rc += 1; - seq.shared = new_tail.seq_nodes.size() > 1; - } else { - seq.shared = false; - } - } - // cell counts - if (first) { - seq.n_cells -= 1; - if (seq.n_cells == 0) { - GGML_ASSERT(seq.tail < 0); - n_seqs -= 1; - } - first = false; - } + cell.prev = prev; + was_valid = false; + } + prev = cell_id; + } + int32_t n_cells = 0; + int32_t next = -1; + for (size_t i = seq_cells.size(); i-- > 0;) { + uint32_t cell_id = seq_cells[i].second; + llama_rs_cell & cell = cells[cell_id]; + // assuming it's always found, because how else would it end up in the list of cells for this seq_id? + auto seq_node = std::find(cell.seq_nodes.begin(), cell.seq_nodes.end(), seq_id); + if (seq_node == cell.seq_nodes.begin()) { + n_cells += 1; + } + if (seq_node->next_cell != next) { + // TODO: relax the error when multiple cells have the same pos + if (debug) { + LLAMA_LOG_ERROR("%s: invalid next cell for cells[%u] (%d instead of %d)\n", + __func__, cell_id, seq_node->next_cell, next); + } + seq_node->next_cell = next; + was_valid = false; + } + next = cell_id; + } + if (seq.n_cells != n_cells) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid n_cells for seq_id %d (%d instead of %d)\n", + __func__, seq_id, seq.n_cells, n_cells); + } + seq.n_cells = n_cells; + } + // in_batch should only be true when in the process of finding a slot + if (seq.in_ubatch != false) { + if (debug) { + LLAMA_LOG_ERROR("%s: in_ubatch was true while it should have been false for seq_id %d\n", + __func__, seq_id); + } + seq.in_ubatch = false; + was_valid = false; + } + } + // tail_rc + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + uint32_t tail_rc = 0; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + if (seq.tail >= 0 && (uint32_t) seq.tail == cell_id) { + tail_rc += 1; + } + } + if (cell.tail_rc != tail_rc) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid tail_rc for cells[%u] (%u instead of %u)\n", + __func__, cell_id, cell.tail_rc, tail_rc); + } + cell.tail_rc = tail_rc; + was_valid = false; + } + } + // n_seqs + uint32_t n_seqs_verif = 0; + uint32_t n_shared_tail_cells_verif = 0; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + if (seq.tail >= 0) { + llama_rs_cell & tail_cell = cells[seq.tail]; + // NOTE: could also have checked if n_cells > 0 + if (!tail_cell.seq_nodes.empty() && tail_cell.seq_nodes[0].seq_id == seq_id) { + if (tail_cell.seq_nodes.size() > 1) { + n_shared_tail_cells_verif += 1; } else { - GGML_ASSERT(false && "invalid seq_id"); + n_seqs_verif += 1; } } - rs_cell.pos = -1; - rs_cell.src = -1; - rs_cell.prev = -1; - rs_cell.tail_rc = 0; - rs_cell.seq_nodes.clear(); - used -= 1; - return true; } } - return false; + if (n_seqs != n_seqs_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: wrong n_seqs (%u instead of %u)\n", + __func__, n_seqs, n_seqs_verif); + } + n_seqs = n_seqs_verif; + was_valid = false; + } + if (n_shared_tail_cells != n_shared_tail_cells_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: wrong n_shared_tail_cells (%u instead of %u)\n", + __func__, n_shared_tail_cells, n_shared_tail_cells_verif); + } + n_shared_tail_cells = n_shared_tail_cells_verif; + was_valid = false; + } + return was_valid; } - // TODO: maybe use a simpler data structure than a tree // returns whether or not a cell was freed - bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) { - if (i_cell < size && (size_t) id < size) { - llama_rs_cell & rs_cell = cells[i_cell]; - auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once - if (node_iter != rs_cell.seq_nodes.end()) { - if (rs_cell.seq_nodes.size() == 1) { - return clear_cell(i_cell); - } - // else update tree - llama_rs_seq_node node = *node_iter; + void clear_cell(llama_rs_cell & rs_cell) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + if (!rs_cell.is_empty()) { + // update sequence tree links + bool first = true; + for (const llama_rs_seq_node & node : rs_cell.seq_nodes) { if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + // NOTE: if all next cells are the same cell, this should still work cells[node.next_cell].prev = rs_cell.prev; } + // next_cell of the nodes of the previous cell + if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { + llama_rs_cell & prev_cell = cells[rs_cell.prev]; + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); + // assuming the previous node is always found + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); + prev_node->next_cell = node.next_cell; + if (node.is_tail()) { + prev_cell.tail_rc += 1; + } + } if ((uint32_t) node.seq_id < seq_tails.size()) { auto & seq = seq_tails[node.seq_id]; - bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2; + // update tail if (node.is_tail()) { seq.tail = rs_cell.prev; - if (seq.tail >= 0 && (uint32_t) seq.tail < size) { - llama_rs_cell & new_tail = cells[seq.tail]; - new_tail.insert_node(node.seq_id); // ensures next_cell == -1 - new_tail.tail_rc += 1; - seq.shared = cells[seq.tail].seq_nodes.size() > 1; - } else { - seq.shared = false; - } - GGML_ASSERT(rs_cell.tail_rc > 0); - rs_cell.tail_rc -= 1; } - if (node_iter == rs_cell.seq_nodes.begin()) { - // this seq_id was the first in the list + // cell counts + if (first) { seq.n_cells -= 1; - if (seq.n_cells == 0) { - n_seqs -= 1; - } - // the next node is the new first one, so update its n_cells - // (will never be out-of-bounds because the size is > 1) - llama_rs_seq_node next_node = *(std::next(node_iter)); - if ((uint32_t) next_node.seq_id < seq_tails.size()) { - auto & next_seq = seq_tails[next_node.seq_id]; - next_seq.n_cells += 1; - if (next_seq.n_cells == 1) { - n_seqs += 1; - } - if (other_no_longer_shared) { - next_seq.shared = false; + if (rs_cell.tail_rc > 0 && seq.tail < 0) { + // last tail cell + if (rs_cell.seq_nodes.size() > 1) { + n_shared_tail_cells -= 1; + } else { + n_seqs -= 1; } - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - } else if (other_no_longer_shared) { - llama_rs_seq_node first_node = rs_cell.seq_nodes[0]; - if ((uint32_t) first_node.seq_id < seq_tails.size()) { - seq_tails[first_node.seq_id].shared = false; - } else { - GGML_ASSERT(false && "invalid seq_id"); } + first = false; } } else { GGML_ASSERT(false && "invalid seq_id"); } - rs_cell.seq_nodes.erase(node_iter); } + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + rs_cell.tail_rc = 0; + rs_cell.seq_nodes.clear(); + used -= 1; + } + } + + // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. + std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + // TODO: assert the iterator points inside the correct vector + if (node_iter != rs_cell.seq_nodes.end()) { + if (rs_cell.seq_nodes.size() == 1) { + clear_cell(rs_cell); + return rs_cell.seq_nodes.end(); + } + // else update tree + llama_rs_seq_node node = *node_iter; + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + cells[node.next_cell].prev = rs_cell.prev; + } + if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { + llama_rs_cell & prev_cell = cells[rs_cell.prev]; + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); + // assuming the previous node is always found + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); + prev_node->next_cell = node.next_cell; + if (node.is_tail()) { + prev_cell.tail_rc += 1; + } + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (seq.tail < 0 && rs_cell.tail_rc == 1) { + // assuming the previous cell of a shared cell is also shared, + // (no need to update the shared tail cells count elsewhere, then) + // this was a shared tail cell, but will no longer be a tail cell + n_shared_tail_cells -= 1; + } + GGML_ASSERT(rs_cell.tail_rc > 0); + rs_cell.tail_rc -= 1; + } + if (node_iter == rs_cell.seq_nodes.begin()) { + // this seq_id was the first in the list + seq.n_cells -= 1; + + // the next node is the new first one, so update its n_cells + // (will never be out-of-bounds because the size is > 1) + llama_rs_seq_node next_node = *(std::next(node_iter)); + if ((uint32_t) next_node.seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node.seq_id]; + next_seq.n_cells += 1; + // only the tail ref count from the other seq_ids are left in tail_rc + if (rs_cell.tail_rc > 0) { + // will become a non-shared cell + if (rs_cell.seq_nodes.size() == 2) { + n_seqs += 1; + } + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + return rs_cell.seq_nodes.erase(node_iter); + } + return node_iter; + } + + // returns whether or not the seq_id was removed + bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) { + if (i_cell < size && (size_t) id < size) { + llama_rs_cell & rs_cell = cells[i_cell]; + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once + return node_iter != remove_seq_node_from_cell(rs_cell, node_iter); } return false; } - bool insert_seq_tail_to_cell(uint32_t i_cell, const llama_seq_id & id) { + bool insert_seq_tail_to_cell_id(uint32_t i_cell, const llama_seq_id & id) { if (i_cell < size && (size_t) id < seq_tails.size()) { llama_rs_cell & rs_cell = cells[i_cell]; auto & seq = seq_tails[id]; @@ -2194,10 +2403,11 @@ struct llama_rs_cache { } prev_cell.tail_rc -= 1; prev_node->next_cell = i_cell; + rs_cell.prev = prev; } if (rs_cell.is_empty()) { - // only add after potential failures above - if (seq.n_cells == 0) { + // either the sequence didn't own any cells or had a shared tail cell + if (seq.n_cells == 0 || (seq.tail >= 0 && cells[seq.tail].seq_nodes.size() > 1)) { n_seqs += 1; } seq.n_cells += 1; @@ -2206,12 +2416,40 @@ struct llama_rs_cache { rs_cell.pos = 0; rs_cell.src = -1; } + used += 1; + } else if (rs_cell.seq_nodes.size() == 1 && rs_cell.tail_rc == 1) { + // don't count shared-cell tails + // FIXME: make this saner + n_seqs -= 1; + n_shared_tail_cells += 1; + } else if (rs_cell.tail_rc == 0) { + // shared cell without a tail gets a tail; + // FIXME: don't prune, in case this is used in llama_cache_seq_cp + GGML_ASSERT(false); // make sure we don't get here by accident + // prune the other sequences out of this cell + // NOTE: have to inline the removal because the state tree is partially invalid + bool first = true; + for (auto & node : rs_cell.seq_nodes) { + GGML_ASSERT(node.seq_id != id); + GGML_ASSERT(node.next_cell >= 0); + // easy removal, none of the nodes are tails + llama_rs_cell & next_cell = cells[node.next_cell]; + next_cell.prev = rs_cell.prev; + if (first) { + auto & first_seq = seq_tails[node.seq_id]; + first_seq.n_cells -= 1; + first = false; + } + } + rs_cell.seq_nodes.clear(); + } else if (rs_cell.seq_nodes.size() != rs_cell.tail_rc) { + // this is correct as long as this isn't called when trying to find a slot + // TODO: find a way to assert this } // the target cell was not already a tail of this seq_id rs_cell.insert_node(id); // next_cell == -1 by default rs_cell.tail_rc += 1; seq.tail = i_cell; - seq.shared = rs_cell.seq_nodes.size() > 1; return true; } return false; @@ -2219,33 +2457,12 @@ struct llama_rs_cache { // each seq_id should have access to at least this many cells // (to use when pruning (to avoid over-pruning)) - // (but this over-prunes when the system prompt doesn't take lots of cells) - // Hmm. The system prompt does not need checkpoints... - size_t min_cells_per_seq() const { - return size / (n_seqs > 0 ? n_seqs : 1); - } - - // each seq_id can have at most this many cells - // (ignoring seqs which behave as a shared prompt) - // TODO: avoid recalculating system seq_ids - // (to use when pruning (to avoid over-pruning)) - // NOTE: this also limits the shared prompt to at most half the cells - // (but the shared prompt technically needs only one cell...) - // (IDEA: keep only one cell when `llama_kv_cache_seq_cp` is called on a sequence) - size_t max_cells_per_seq() const { - int32_t n_system_seqs = 0; - int32_t n_system_cells = 0; - for (size_t i = 0; i < seq_tails.size(); ++i) { - const auto & seq = seq_tails[i]; - if (seq.tail >= 0 && (size_t) seq.tail < size) { - if (seq.shared && seq.n_cells > 0) { - n_system_seqs += 1; - n_system_cells += seq.n_cells; - } - } + size_t min_cells_per_seq(const llama_rs_seq_meta & new_seq) const { + uint32_t seqs = n_seqs; + if (new_seq.tail < 0 || new_seq.n_cells == 0) { + seqs += 1; } - int32_t n_other_seqs = n_seqs - n_system_seqs; - return (size - n_system_cells) / (n_other_seqs > 0 ? n_other_seqs : 1); + return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1); } size_t total_size() const { @@ -2528,7 +2745,7 @@ struct llama_context { struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [n_rs] struct ggml_tensor * inp_s_mask; // F32 [1, n_rs] - struct ggml_tensor * inp_s_seq; // I32 [n_rs, n_batch] + struct ggml_tensor * inp_s_seq; // I32 [n_batch] // control vectors struct llama_control_vector cvec; @@ -2657,7 +2874,7 @@ static bool llama_cache_init( return false; } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s ctx buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); cache.bufs.push_back(buf); } @@ -2678,54 +2895,170 @@ static bool llama_kv_cache_find_slot( if (rs_size > 0) { // For recurrent state architectures (like Mamba), // each cache cell can store the state for a whole sequence. - // TODO: real ring-buffer of states - // TODO: state chekpoints (multiple cells per sequence) // TODO: find a way to always make the rs slot contiguous - // Okay, need to find a slot. Everything should fit assuming the biggest seq_id < rs_size - - - llama_seq_id min = cache.rs.size - 1; - llama_seq_id max = 0; + llama_seq_id min_seq = cache.rs.size - 1; + llama_seq_id max_seq = 0; + uint32_t min_cell = cache.rs.size - 1; + uint32_t max_cell = 0; for (uint32_t i = 0; i < n_tokens; ++i) { - for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { + int32_t target_cell = -1; // ensure all the sequences of a token get the same cell + int32_t n_seq_ids = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_ids; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; - // make sure it's a valid seq_id + bool need_new_cell = false; + // Everything should fit assuming the biggest seq_id < rs_size if ((uint32_t) seq_id < rs_size) { - if (seq_id > max) { - max = seq_id; + llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; + if (seq_id > max_seq) { max_seq = seq_id; } + if (seq_id < min_seq) { min_seq = seq_id; } + + if (!seq.in_ubatch && target_cell >= 0) { + // never saw this seq_id before, + // but there's already a cell reserved for this token, use it + cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); + } else if (seq.tail < 0) { + need_new_cell = true; + } else { + llama_rs_cell & tail = cache.rs.cells[seq.tail]; + if (seq.in_ubatch) { + // this seq_id was already seen before in the batch + // assuming the tail cell already "has" this seq_id + tail.pos += 1; + target_cell = seq.tail; + } else { + // first time this sequence is seen, + // there's no reserved cell yet; + // if it's not the first sequence of the token, how could it even get here? + GGML_ASSERT(j == 0); + + bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; + if (has_same_seqs) { + // the tail cell of a seq_id is assumed to already be part of the seq_id, + // hence the skip of the first seq_id + for (int32_t k = 1; k < n_seq_ids; ++k) { + if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { + has_same_seqs = false; + } + } + } + + // TODO: make the checkpoint interval configurable + if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { + // a checkpoint should be saved + need_new_cell = true; + } else { + // re-use last tail + tail.pos += 1; + target_cell = seq.tail; + } + } } - if (seq_id < min) { - min = seq_id; + + if (need_new_cell && target_cell < 0) { + const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); + + uint32_t cell_id = cache.rs.size; + bool looped_once = false; + + while (true) { + if (cache.rs.head >= cache.rs.size) { + cache.rs.head = 0; + if (looped_once) { + // avoid infinite loop + // NOTE: this should not happen, but gracefully fail anyway + LLAMA_LOG_ERROR("%s: recurrent state cache seems full, but should not. This is a bug.\n", __func__); + return false; + } + looped_once = true; + } + cell_id = cache.rs.head; + llama_rs_cell & candidate = cache.rs.cells[cell_id]; + if (candidate.is_empty()) { break; } + if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { + if (candidate.seq_nodes.size() > 1) { + // prune out the other seq_ids, because they diverge + // TODO(maybe): hande this in insert_seq_tail_to_cell_id + // (hopefully doesn't happen too often) + for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { + if (node_iter->seq_id == seq_id) { + node_iter = std::next(node_iter); + } else { + node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); + } + } + } + // re-use the tail cell to avoid not finding anything + candidate.pos += 1; + break; + } + if (candidate.tail_rc > 0) { + // skip tails of other sequences + cache.rs.head += 1; + continue; + } + if (candidate.seq_nodes.size() > 1) { + // shared prompts are not usually backtracked, so they can be pruned + cache.rs.clear_cell(candidate); + break; + } + + // prune too-long sequences + llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; + if (seq_id_to_prune == seq_id) { + // TODO: selectively skip some cells to keep older states + cache.rs.clear_cell(candidate); + break; + } + GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); + auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; + if (seq_to_prune.n_cells > min_cells_per_seq) { + cache.rs.clear_cell(candidate); + break; + } + cache.rs.head += 1; + } + if (cell_id < cache.rs.size) { + cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); + target_cell = cell_id; + } } + + if (seq.tail >= 0) { + if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } + if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } + seq.in_ubatch = true; + } + // Assuming the tokens are in-order - if (batch.pos[i] != cache.rs.cells[seq_id].pos + 1) { + if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { // What should happen when the pos backtracks or skips a value? // Clearing the state mid-batch would require special-casing which isn't done. LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.rs.cells[seq_id].pos, seq_id); - } - if (cache.rs.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { - cache.rs.used += 1; + __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); } - cache.rs.cells[seq_id].pos = batch.pos[i]; - cache.rs.cells[seq_id].seq_nodes.insert(seq_id); } else { // too big seq_id - // TODO: would it be possible to resize the KV cache size instead? + // TODO: would it be possible to resize the rs cache size instead? LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); return false; } } + cache.rs.head = target_cell + 1; + } + + for (llama_seq_id i = min_seq; i <= max_seq; ++i) { + // make sure it's cleared for next time + cache.rs.seq_tails[i].in_ubatch = false; } // allow getting the range of used cells, from head to head + n - cache.rs.head = min; - cache.rs.n = max - min + 1; + cache.rs.head = min_cell; + cache.rs.n = max_cell - min_cell + 1; // sanity check - if (max < min) { + if (max_seq < min_seq || max_cell < min_cell) { return false; } } @@ -2799,6 +3132,7 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { return 0; } +// find how many recurrent state cells are currently in use static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) { for (uint32_t i = cache.size; i > 0; --i) { const llama_rs_cell & cell = cache.cells[i - 1]; @@ -2829,12 +3163,15 @@ static void llama_cache_clear(struct llama_cache & cache) { llama_rs_cell & rs_cell = cache.rs.cells[i]; rs_cell.pos = -1; rs_cell.src = -1; + rs_cell.prev = -1; + rs_cell.tail_rc = 0; rs_cell.seq_nodes.clear(); } cache.rs.do_copy = false; cache.rs.head = 0; cache.rs.used = 0; cache.rs.n_seqs = 0; + cache.rs.n_shared_tail_cells = 0; cache.rs.seq_tails.clear(); cache.rs.seq_tails.resize(cache.rs.size); } @@ -2846,8 +3183,8 @@ static llama_pos llama_cache_seq_rm( llama_pos p0, llama_pos p1) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } llama_pos n_past = p0; @@ -2863,7 +3200,9 @@ static llama_pos llama_cache_seq_rm( for (uint32_t i = 0; i < cache.rs.size; ++i) { llama_rs_cell & rs_cell = cache.rs.cells[i]; - if (seq_id < 0 || rs_cell.has_seq_id(seq_id)) { + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); + + if (seq_id < 0 || seq_node != rs_cell.seq_nodes.end()) { if (rs_cell.pos < p0) { // move forward the new p0 further if (rs_cell.pos >= new_p0) { @@ -2879,9 +3218,9 @@ static llama_pos llama_cache_seq_rm( } } else { // (rs_cell.pos >= p0 && rs_cell.pos < p1) if (seq_id < 0) { - cache.rs.clear_cell(i); + cache.rs.clear_cell(rs_cell); } else { // (rs_cell.has_seq_id(seq_id)) - cache.rs.remove_seq_from_cell(i, seq_id); + cache.rs.remove_seq_node_from_cell(rs_cell, seq_node); } if (rs_cell.is_empty() && new_head == cache.rs.size) { new_head = i; @@ -2943,11 +3282,12 @@ static llama_pos llama_cache_seq_cp( llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } // TODO: in practice this seems to be only used on whole sequences; - // should partial sequence copy be removed? + // should partial sequence copy support be removed? + // TODO: What if the destination sequence is not empty? llama_pos n_past = 0; @@ -2973,11 +3313,11 @@ static llama_pos llama_cache_seq_cp( if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) { if (i == (uint32_t) src_tail) { // need to be inserted in order, but there's only one - cache.rs.insert_seq_tail_to_cell(i, seq_id_dst); + cache.rs.insert_seq_tail_to_cell_id(i, seq_id_dst); } else { // keep only the tail cell of the source // assuming a copy means no rollback will be attempted afterwards - cache.rs.remove_seq_from_cell(i, seq_id_src); + cache.rs.remove_seq_from_cell_id(i, seq_id_src); if (new_head == cache.rs.size) { new_head = i; } @@ -3009,16 +3349,41 @@ static llama_pos llama_cache_seq_cp( } static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id) { + if (cache.rs.size > 0) { + uint32_t new_head = cache.rs.size; + + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (!rs_cell.seq_nodes.empty()) { + for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { + if (node_iter->seq_id != seq_id) { + node_iter = cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + } else { + node_iter = std::next(node_iter); + } + } + if (new_head == cache.rs.size && rs_cell.is_empty()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; + } + } + if (cache.kv.size > 0) { uint32_t new_head = cache.kv.size; for (uint32_t i = 0; i < cache.kv.size; ++i) { llama_kv_cell & kv_cell = cache.kv.cells[i]; if (!kv_cell.has_seq_id(seq_id)) { - if (kv_cell.pos >= 0) cache.kv.used--; + if (kv_cell.pos >= 0) { cache.kv.used--; } kv_cell.pos = -1; kv_cell.seq_id.clear(); - if (new_head == cache.kv.size) new_head = i; + if (new_head == cache.kv.size) { new_head = i; } } else { kv_cell.seq_id.clear(); kv_cell.seq_id.insert(seq_id); @@ -3052,13 +3417,12 @@ static llama_pos llama_cache_seq_add( while (cell_id >= 0) { GGML_ASSERT((uint32_t) cell_id < cache.rs.size); llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; - int32_t i = cell_id; cell_id = rs_cell.prev; if (rs_cell.pos >= p0 && rs_cell.pos < p1) { rs_cell.pos += delta; if (rs_cell.pos < 0) { // NOTE: this affects the other sequences which share the cell - cache.rs.clear_cell(i); + cache.rs.clear_cell(rs_cell); // TODO: update cache.rs.head } } @@ -6787,7 +7151,7 @@ struct llm_build_context { } struct ggml_tensor * build_inp_s_seq() { - lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_rs, n_tokens); + lctx.inp_s_seq = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(lctx.inp_s_seq, "inp_s_seq", -1); ggml_set_input(lctx.inp_s_seq); return lctx.inp_s_seq; @@ -10482,26 +10846,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); int32_t * data = (int32_t *) lctx.inp_s_seq->data; - for (int j = 0; j < n_tokens; ++j) { - const int32_t n_seq = batch.n_seq_id[j]; - GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence - - for (int i = 0; i < n_rs; ++i) { - if (i < n_seq) { - llama_seq_id seq_id = batch.seq_id[j][i]; - GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); - const auto & seq = rs_self.seq_tails[seq_id]; - // all sequences of this batch should already be initialized - GGML_ASSERT(seq.tail >= 0); - // ensure the relative cell id will be positive but not too big - GGML_ASSERT((uint32_t) seq.tail >= rs_self.head); - GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); - - data[j*n_rs + i] = seq.tail - rs_self.head; - } else { - data[j*n_rs + i] = -1; - } - } + for (int i = 0; i < n_tokens; ++i) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); + const auto & seq = rs_self.seq_tails[seq_id]; + // ensure the relative cell id will be positive but not too big + GGML_ASSERT((uint32_t) seq.tail >= rs_self.head); + GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); + + data[i] = seq.tail - rs_self.head; } } } @@ -14874,7 +15227,7 @@ struct llama_context * llama_new_context_with_model( memory_size_s += ggml_nbytes(s); } - LLAMA_LOG_INFO("%s: SSM state size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: SSM state size = %8.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), ggml_type_name(GGML_TYPE_F32), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(GGML_TYPE_F32), (float)memory_size_s / (1024.0f * 1024.0f)); @@ -14891,7 +15244,7 @@ struct llama_context * llama_new_context_with_model( memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: KV cache size = %8.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); @@ -15458,7 +15811,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) { const size_t s_kv = ctx->cache.kv.total_size(); const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); const size_t s_kv_cells = ctx->cache.kv.size * s_kv_cell; - // TODO: rs cache cells + // FIXME: rs cache cells const size_t s_total = ( + s_rng_size @@ -15606,14 +15959,15 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat } } + // FIXME: copy rs cache // copy kv cache { - const auto & kv_self = ctx->kv_self; + const auto & kv_self = ctx->cache.kv; const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); // NOTE: kv_size and kv_buf_size are mostly used for sanity checks const uint32_t kv_head = llama_kv_cache_cell_max(kv_self); @@ -15637,17 +15991,6 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size()); data_ctx->write(tmp_buf.data(), tmp_buf.size()); - if (kv_self.recurrent) { - // v is contiguous for recurrent models - // TODO: use other tensors for state models than k and v - const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); - - tmp_buf.resize(v_size); - ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size()); - data_ctx->write(tmp_buf.data(), tmp_buf.size()); - continue; - } - // v is not contiguous, copy row by row const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head); const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size); @@ -15753,7 +16096,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { } } - // FIXME: set rs cache too + // FIXME: set rs cache // set kv cache { const auto & kv_self = ctx->cache.kv; From 0c8b3b20956521acc8f1f297cb58ab3172b3c3e7 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 9 Apr 2024 17:35:22 -0400 Subject: [PATCH 004/117] llama : correctly handle more edge cases for the rs cache --- llama.cpp | 407 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 211 insertions(+), 196 deletions(-) diff --git a/llama.cpp b/llama.cpp index d561f80f62b6d..5433bde86796a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2034,7 +2034,7 @@ struct llama_rs_cache { uint32_t n = 0; // range of states used for the last slot // useful to know the minimum reserved cell count per seq_id - // only counts sequences with n_cells > 0 AND which have a non-shared tail + // only counts sequences which have a non-shared tail uint32_t n_seqs = 0; // cells part of multiple sequences AND which have at least one tail uint32_t n_shared_tail_cells = 0; @@ -2082,21 +2082,37 @@ struct llama_rs_cache { llama_rs_cell & cell = cells[cell_id]; if (cell.seq_nodes.empty()) { if (cell.pos >= 0) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.pos); + } cell.pos = -1; was_valid = false; } } if (cell.pos < 0) { if (cell.pos != -1) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.pos); + } cell.pos = -1; was_valid = false; } if (!cell.seq_nodes.empty()) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d] has %zu seq_ids while it's empty (should have none)\n", + __func__, cell_id, cell.seq_nodes.size()); + } cell.seq_nodes.clear(); was_valid = false; } cell.src = -1; if (cell.prev != -1) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].prev is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.prev); + } cell.prev = -1; was_valid = false; } @@ -2213,17 +2229,15 @@ struct llama_rs_cache { // n_seqs uint32_t n_seqs_verif = 0; uint32_t n_shared_tail_cells_verif = 0; - for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { - auto & seq = seq_tails[seq_id]; - if (seq.tail >= 0) { - llama_rs_cell & tail_cell = cells[seq.tail]; - // NOTE: could also have checked if n_cells > 0 - if (!tail_cell.seq_nodes.empty() && tail_cell.seq_nodes[0].seq_id == seq_id) { - if (tail_cell.seq_nodes.size() > 1) { - n_shared_tail_cells_verif += 1; - } else { + for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) { + llama_rs_cell & rs_cell = cells[cell_id]; + if (!rs_cell.seq_nodes.empty()) { + if (rs_cell.seq_nodes.size() == 1) { + if (rs_cell.tail_rc == 1) { n_seqs_verif += 1; } + } else if (rs_cell.tail_rc > 0) { + n_shared_tail_cells_verif += 1; } } } @@ -2246,72 +2260,15 @@ struct llama_rs_cache { return was_valid; } - // returns whether or not a cell was freed - void clear_cell(llama_rs_cell & rs_cell) { - GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); - if (!rs_cell.is_empty()) { - // update sequence tree links - bool first = true; - for (const llama_rs_seq_node & node : rs_cell.seq_nodes) { - if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { - // NOTE: if all next cells are the same cell, this should still work - cells[node.next_cell].prev = rs_cell.prev; - } - // next_cell of the nodes of the previous cell - if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { - llama_rs_cell & prev_cell = cells[rs_cell.prev]; - auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); - // assuming the previous node is always found - GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); - prev_node->next_cell = node.next_cell; - if (node.is_tail()) { - prev_cell.tail_rc += 1; - } - } - if ((uint32_t) node.seq_id < seq_tails.size()) { - auto & seq = seq_tails[node.seq_id]; - // update tail - if (node.is_tail()) { - seq.tail = rs_cell.prev; - } - // cell counts - if (first) { - seq.n_cells -= 1; - if (rs_cell.tail_rc > 0 && seq.tail < 0) { - // last tail cell - if (rs_cell.seq_nodes.size() > 1) { - n_shared_tail_cells -= 1; - } else { - n_seqs -= 1; - } - } - first = false; - } - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - } - rs_cell.pos = -1; - rs_cell.src = -1; - rs_cell.prev = -1; - rs_cell.tail_rc = 0; - rs_cell.seq_nodes.clear(); - used -= 1; - } - } - // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); // TODO: assert the iterator points inside the correct vector if (node_iter != rs_cell.seq_nodes.end()) { - if (rs_cell.seq_nodes.size() == 1) { - clear_cell(rs_cell); - return rs_cell.seq_nodes.end(); - } - // else update tree + // update the tree llama_rs_seq_node node = *node_iter; if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + // NOTE: because of this, partially removing seq_ids from cells should only be done from the tail cells[node.next_cell].prev = rs_cell.prev; } if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { @@ -2321,6 +2278,14 @@ struct llama_rs_cache { GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); prev_node->next_cell = node.next_cell; if (node.is_tail()) { + if (prev_cell.seq_nodes.size() > 1) { + if (prev_cell.tail_rc == 0) { + n_shared_tail_cells += 1; + } + if (rs_cell.seq_nodes.size() == 1) { + n_seqs -= 1; + } + } prev_cell.tail_rc += 1; } } @@ -2328,11 +2293,15 @@ struct llama_rs_cache { auto & seq = seq_tails[node.seq_id]; if (node.is_tail()) { seq.tail = rs_cell.prev; - if (seq.tail < 0 && rs_cell.tail_rc == 1) { - // assuming the previous cell of a shared cell is also shared, - // (no need to update the shared tail cells count elsewhere, then) - // this was a shared tail cell, but will no longer be a tail cell - n_shared_tail_cells -= 1; + if (rs_cell.tail_rc == 1) { + if (rs_cell.seq_nodes.size() > 1) { + // assuming the previous cell of a shared cell is also shared, + // this was a shared tail cell, but will no longer be a tail cell + n_shared_tail_cells -= 1; + } else if (seq.tail < 0) { + // no more tail, no more sequence + n_seqs -= 1; + } } GGML_ASSERT(rs_cell.tail_rc > 0); rs_cell.tail_rc -= 1; @@ -2341,21 +2310,30 @@ struct llama_rs_cache { // this seq_id was the first in the list seq.n_cells -= 1; - // the next node is the new first one, so update its n_cells - // (will never be out-of-bounds because the size is > 1) - llama_rs_seq_node next_node = *(std::next(node_iter)); - if ((uint32_t) next_node.seq_id < seq_tails.size()) { - auto & next_seq = seq_tails[next_node.seq_id]; - next_seq.n_cells += 1; - // only the tail ref count from the other seq_ids are left in tail_rc - if (rs_cell.tail_rc > 0) { - // will become a non-shared cell - if (rs_cell.seq_nodes.size() == 2) { - n_seqs += 1; + auto next_node = std::next(node_iter); + if (next_node != rs_cell.seq_nodes.end()) { + // the next node is the new first one, so update its n_cells + if ((uint32_t) next_node->seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node->seq_id]; + next_seq.n_cells += 1; + // only the tail ref count from the other seq_ids are left in tail_rc + if (rs_cell.tail_rc > 0) { + // will become a non-shared cell + if (rs_cell.seq_nodes.size() == 2) { + n_shared_tail_cells -= 1; + n_seqs += 1; + } } + } else { + GGML_ASSERT(false && "invalid seq_id"); } } else { - GGML_ASSERT(false && "invalid seq_id"); + // this was the last seq_id of the cell + used -= 1; + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + // the other fields *should* have already been updated elsewhere } } } else { @@ -2366,6 +2344,13 @@ struct llama_rs_cache { return node_iter; } + void clear_cell(llama_rs_cell & rs_cell) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { + node_iter = remove_seq_node_from_cell(rs_cell, node_iter); + } + } + // returns whether or not the seq_id was removed bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) { if (i_cell < size && (size_t) id < size) { @@ -2404,47 +2389,63 @@ struct llama_rs_cache { prev_cell.tail_rc -= 1; prev_node->next_cell = i_cell; rs_cell.prev = prev; + if (seq.tail == prev) { + // What to do when the tail moves... + // from unique to shared (n_seqs--) + // if the new cell has one seq_id or has no tails (n_shared_tail_cells++) + // if the new cell has one seq_id and a tail (n_seqs-- (yes, another time)) + // from unique to unique (seq.n_cells++) + // from empty to unique (seq.n_cells++, n_seqs++) + // from empty to shared + // if the new cell only has one seq_id or has no tail (n_shared_tail_cells++) + // if the new cell only has one seq_id and has one tail (n_seqs--) + // from shared to shared + // if the last cell has no tails (n_shared_tail_cells--) + // if the new cell has no tails or has one seq_id (n_shared_tail_cells++) + // if the new cell only has one seq_id and has one tail (n_seqs--) + // from shared to unique (seq.n_cells++) + // if this seq_id was not the first of the last cell (n_seqs++) + // if the last cell has no tails (n_shared_tail_cells--) + if (prev_cell.seq_nodes.size() > 1) { + // from shared + if (rs_cell.is_empty()) { + // to unique + if (prev_cell.seq_nodes[0].seq_id != id) { + n_seqs += 1; + } + } + // the previous cell is no longer a shared tail + if (prev_cell.tail_rc == 0) { + n_shared_tail_cells -= 1; + } + } else if (!rs_cell.is_empty()) { + // from unique to shared + n_seqs -= 1; + } + } } if (rs_cell.is_empty()) { - // either the sequence didn't own any cells or had a shared tail cell - if (seq.n_cells == 0 || (seq.tail >= 0 && cells[seq.tail].seq_nodes.size() > 1)) { - n_seqs += 1; - } + // to unique seq.n_cells += 1; - // set pos if still unset - if (rs_cell.pos < 0) { + if (seq.tail < 0) { + // from empty to unique + n_seqs += 1; + // pos was not yet set rs_cell.pos = 0; rs_cell.src = -1; } used += 1; - } else if (rs_cell.seq_nodes.size() == 1 && rs_cell.tail_rc == 1) { - // don't count shared-cell tails - // FIXME: make this saner - n_seqs -= 1; - n_shared_tail_cells += 1; - } else if (rs_cell.tail_rc == 0) { - // shared cell without a tail gets a tail; - // FIXME: don't prune, in case this is used in llama_cache_seq_cp - GGML_ASSERT(false); // make sure we don't get here by accident - // prune the other sequences out of this cell - // NOTE: have to inline the removal because the state tree is partially invalid - bool first = true; - for (auto & node : rs_cell.seq_nodes) { - GGML_ASSERT(node.seq_id != id); - GGML_ASSERT(node.next_cell >= 0); - // easy removal, none of the nodes are tails - llama_rs_cell & next_cell = cells[node.next_cell]; - next_cell.prev = rs_cell.prev; - if (first) { - auto & first_seq = seq_tails[node.seq_id]; - first_seq.n_cells -= 1; - first = false; + } else { + // to shared + if (rs_cell.seq_nodes.size() == 1) { + // a lone tail becomes a shared cell + if (rs_cell.tail_rc > 0) { + n_seqs -= 1; } + n_shared_tail_cells += 1; + } else if (rs_cell.tail_rc == 0) { + n_shared_tail_cells += 1; } - rs_cell.seq_nodes.clear(); - } else if (rs_cell.seq_nodes.size() != rs_cell.tail_rc) { - // this is correct as long as this isn't called when trying to find a slot - // TODO: find a way to assert this } // the target cell was not already a tail of this seq_id rs_cell.insert_node(id); // next_cell == -1 by default @@ -2977,6 +2978,7 @@ static bool llama_kv_cache_find_slot( llama_rs_cell & candidate = cache.rs.cells[cell_id]; if (candidate.is_empty()) { break; } if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { + // the candidate is the old tail if (candidate.seq_nodes.size() > 1) { // prune out the other seq_ids, because they diverge // TODO(maybe): hande this in insert_seq_tail_to_cell_id @@ -3198,40 +3200,42 @@ static llama_pos llama_cache_seq_rm( llama_pos new_p0 = 0; llama_pos new_p1 = std::numeric_limits::max(); - for (uint32_t i = 0; i < cache.rs.size; ++i) { - llama_rs_cell & rs_cell = cache.rs.cells[i]; - auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); + // partial seq_id removal has to happen from the tail + llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; + int32_t cell_id = seq.tail; + + while (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + // copy before the cell is potentially changed + int32_t prev_id = rs_cell.prev; + if (rs_cell.pos >= p1 && rs_cell.seq_nodes.size() > 1) { + // non-tail removal for shared cells can only be done when clearing a cell + // (i.e. when the next cell's link to the previous cell can be safely changed) + p1 = rs_cell.pos + 1; + } + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); + // if the node isn't found, the sequence tree is malformed + GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); + cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + // get the smallest removed cell id + if (new_head > (uint32_t) cell_id) { new_head = cell_id; } + } else { + // one more than the biggest non-removed cell of this sequence + if (rs_cell.pos >= n_past) { n_past = rs_cell.pos + 1; } - if (seq_id < 0 || seq_node != rs_cell.seq_nodes.end()) { if (rs_cell.pos < p0) { - // move forward the new p0 further - if (rs_cell.pos >= new_p0) { - new_p0 = rs_cell.pos + 1; - } - } else if (rs_cell.pos >= p1) { - // move back the new p1 further - if (rs_cell.pos < new_p1) { - new_p1 = rs_cell.pos; - } - if (rs_cell.pos >= n_past) { - n_past = rs_cell.pos + 1; - } - } else { // (rs_cell.pos >= p0 && rs_cell.pos < p1) - if (seq_id < 0) { - cache.rs.clear_cell(rs_cell); - } else { // (rs_cell.has_seq_id(seq_id)) - cache.rs.remove_seq_node_from_cell(rs_cell, seq_node); - } - if (rs_cell.is_empty() && new_head == cache.rs.size) { - new_head = i; - } + // new_p0 should be right after the max pos in the states before p0 + if (rs_cell.pos >= new_p0) { new_p0 = rs_cell.pos + 1; } + } else { // (rs_cell.pos >= p1) + // new_p1 should be the min pos in the states after p1 + if (rs_cell.pos < new_p1) { new_p1 = rs_cell.pos; } } } + cell_id = prev_id; } p0 = new_p0; p1 = new_p1; - // correctly set n_past when there's nothing after p1 - if (n_past < p0) { n_past = p0; } // If we freed up a slot, set head to it so searching can start there. if (new_head != cache.rs.size && new_head < cache.rs.head) { @@ -3259,10 +3263,8 @@ static llama_pos llama_cache_seq_rm( kv_cell.pos = -1; if (new_head == cache.kv.size) { new_head = i; } } - } else { - if (kv_cell.pos >= n_past) { - n_past = kv_cell.pos + 1; - } + } else if (kv_cell.pos >= n_past) { + n_past = kv_cell.pos + 1; } } } @@ -3292,42 +3294,37 @@ static llama_pos llama_cache_seq_cp( llama_pos n_past = 0; if (cache.rs.size > 0) { - // have to start from beginning for recurrent models + // have to start from the beginning for recurrent models p0 = 0; if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) { - auto seq_src = cache.rs.seq_tails[seq_id_src]; - int32_t src_tail = seq_src.tail; - // find the last tail of src in the pos range - while (src_tail >= 0 && (uint32_t) src_tail < cache.rs.size) { - llama_rs_cell & tail_cell = cache.rs.cells[src_tail]; - if (tail_cell.pos < p1) { - break; - } - src_tail = tail_cell.prev; - } - - uint32_t new_head = cache.rs.size; - + int32_t src_head = -1; + int32_t head_pos = p1; + int32_t src_next = -1; + // find the start of the sequence for (uint32_t i = 0; i < cache.rs.size; ++i) { llama_rs_cell & rs_cell = cache.rs.cells[i]; - if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) { - if (i == (uint32_t) src_tail) { - // need to be inserted in order, but there's only one - cache.rs.insert_seq_tail_to_cell_id(i, seq_id_dst); - } else { - // keep only the tail cell of the source - // assuming a copy means no rollback will be attempted afterwards - cache.rs.remove_seq_from_cell_id(i, seq_id_src); - if (new_head == cache.rs.size) { - new_head = i; - } + if (!rs_cell.is_empty() && rs_cell.prev < 0) { + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); + if (seq_node != rs_cell.seq_nodes.end()) { + src_head = i; + head_pos = rs_cell.pos; + src_next = seq_node->next_cell; + break; } } } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.rs.size && new_head < cache.rs.head) { - cache.rs.head = new_head; + while (src_head >= 0 && head_pos < p1) { + cache.rs.insert_seq_tail_to_cell_id(src_head, seq_id_dst); + src_head = src_next; + if (head_pos >= n_past) { n_past = head_pos + 1; } + if (src_next >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[src_next]; + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); + head_pos = rs_cell.pos; + // it should always be found if the seq tree is valid + GGML_ASSERT(seq_node != rs_cell.seq_nodes.end()); + src_next = seq_node->next_cell; + } } } p1 = n_past; @@ -3338,9 +3335,7 @@ static llama_pos llama_cache_seq_cp( llama_kv_cell & kv_cell = cache.kv.cells[i]; if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) { kv_cell.seq_id.insert(seq_id_dst); - if (kv_cell.pos >= n_past) { - n_past = kv_cell.pos + 1; - } + if (kv_cell.pos >= n_past) { n_past = kv_cell.pos + 1; } } } } @@ -3352,18 +3347,19 @@ static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id if (cache.rs.size > 0) { uint32_t new_head = cache.rs.size; - for (uint32_t i = 0; i < cache.rs.size; ++i) { - llama_rs_cell & rs_cell = cache.rs.cells[i]; - if (!rs_cell.seq_nodes.empty()) { - for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { - if (node_iter->seq_id != seq_id) { - node_iter = cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); - } else { - node_iter = std::next(node_iter); - } - } - if (new_head == cache.rs.size && rs_cell.is_empty()) { - new_head = i; + // partial seq_id removal has to happen from the tail(s) + for (uint32_t i = 0; i < cache.rs.seq_tails.size(); ++i) { + if (i == (uint32_t) seq_id) { continue; } + llama_rs_seq_meta & seq = cache.rs.seq_tails[i]; + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), i); + GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); + cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + cell_id = rs_cell.prev; + if (new_head > (uint32_t) cell_id && rs_cell.is_empty()) { + new_head = cell_id; } } } @@ -3414,6 +3410,7 @@ static llama_pos llama_cache_seq_add( auto & seq = cache.rs.seq_tails[seq_id]; // follow the sequence from its tail int32_t cell_id = seq.tail; + uint32_t new_head = cache.rs.size; while (cell_id >= 0) { GGML_ASSERT((uint32_t) cell_id < cache.rs.size); llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; @@ -3423,13 +3420,19 @@ static llama_pos llama_cache_seq_add( if (rs_cell.pos < 0) { // NOTE: this affects the other sequences which share the cell cache.rs.clear_cell(rs_cell); - // TODO: update cache.rs.head + if (new_head > (uint32_t) cell_id) { + new_head = cell_id; + } } } if (n_past <= rs_cell.pos) { n_past = rs_cell.pos + 1; } } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.rs.head = new_head != cache.rs.size ? new_head : 0; } if (cache.kv.size > 0) { @@ -3474,8 +3477,8 @@ static llama_pos llama_cache_seq_div( llama_pos p0, llama_pos p1, int d) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } llama_pos n_past = p0; @@ -11275,6 +11278,10 @@ static int llama_decode_internal( } } n_outputs_prev += lctx.n_outputs; + +#ifndef NDEBUG + GGML_ASSERT(lctx.cache.rs.rebuild(true)); +#endif } // wait for the computation to finish (automatically done when obtaining the model output) @@ -16332,11 +16339,19 @@ void llama_batch_free(struct llama_batch batch) { int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { +#ifndef NDEBUG + GGML_ASSERT(ctx->cache.rs.rebuild(true)); +#endif + const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } +#ifndef NDEBUG + GGML_ASSERT(ctx->cache.rs.rebuild(true)); +#endif + return ret; } From a09db95eabb5f75a5534f804882cf82e1bb5cadd Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 29 Apr 2024 10:24:45 -0400 Subject: [PATCH 005/117] llama : rename many llama_kv_cache_* functions --- llama.cpp | 97 +++++++++++++++++++++++++++++++++++++++---------------- llama.h | 72 ++++++++++++++++++++++++++++++++++------- 2 files changed, 131 insertions(+), 38 deletions(-) diff --git a/llama.cpp b/llama.cpp index 9d887c6dbfe29..f972c3472a278 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2032,7 +2032,6 @@ struct llama_rs_seq_meta { // ring-buffered tree of cached recurrent state data struct llama_rs_cache { - bool do_copy = false; uint32_t head = 0; // first state used for the last slot uint32_t size = 0; @@ -2769,7 +2768,7 @@ struct llama_context { }; // -// kv cache helpers +// kv and rs cache helpers // static bool llama_cache_init( @@ -2898,7 +2897,7 @@ static bool llama_cache_init( // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. -static bool llama_kv_cache_find_slot( +static bool llama_cache_find_slot( struct llama_cache & cache, const struct llama_batch & batch) { const uint32_t kv_size = cache.kv.size; @@ -3181,7 +3180,6 @@ static void llama_cache_clear(struct llama_cache & cache) { rs_cell.tail_rc = 0; rs_cell.seq_nodes.clear(); } - cache.rs.do_copy = false; cache.rs.head = 0; cache.rs.used = 0; cache.rs.n_seqs = 0; @@ -3412,8 +3410,8 @@ static llama_pos llama_cache_seq_add( llama_pos p1, llama_pos delta) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } llama_pos n_past = p0; @@ -3535,7 +3533,7 @@ static llama_pos llama_cache_seq_div( } static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { - llama_pos result = 0; + llama_pos result = -1; if (cache.rs.size > 0) { int32_t cell_id = cache.rs.seq_tails[seq_id].tail; @@ -11174,7 +11172,7 @@ static int llama_decode_internal( if (hparams.causal_attn) { llama_kv_cache_update(&lctx); - if (!llama_kv_cache_find_slot(lctx.cache, u_batch)) { + if (!llama_cache_find_slot(lctx.cache, u_batch)) { return 1; } @@ -15790,6 +15788,10 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k } } +bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug) { + return ctx->cache.rs.rebuild(debug); +} + int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) { int result = 0; @@ -15804,55 +15806,96 @@ int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) { return ctx->cache.kv.used; } -void llama_kv_cache_clear(struct llama_context * ctx) { +int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx) { + return ctx->cache.rs.used; +} + +void llama_cache_clear(struct llama_context * ctx) { llama_cache_clear(ctx->cache); } +// deprecated +void llama_kv_cache_clear(struct llama_context * ctx) { + llama_cache_clear(ctx); +} + +llama_pos llama_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } + return llama_cache_seq_rm(ctx->cache, seq_id, p0, p1); +} + +// deprecated bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return false; } - llama_pos n_past = llama_cache_seq_rm(ctx->cache, seq_id, p0, p1); + llama_pos n_past = llama_cache_seq_rm(ctx, seq_id, p0, p1); return n_past >= p0; } -void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + +llama_pos llama_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { uint32_t n_seq_max = llama_n_seq_max(ctx); if (seq_id_src < 0 || seq_id_dst < 0 || (uint32_t) seq_id_src >= n_seq_max || (uint32_t) seq_id_dst >= n_seq_max) { - return; + return 0; } if (seq_id_src == seq_id_dst) { - return; + return llama_cache_seq_pos_max(ctx->cache, seq_id_dst) + 1; } - llama_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); + return llama_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); } -void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { +// deprecated +void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + llama_cache_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); +} + +void llama_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } llama_cache_seq_keep(ctx->cache, seq_id); } -void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { +// deprecated +void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { + llama_cache_seq_keep(ctx, seq_id); +} + +llama_pos llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } if (delta == 0) { - return; + return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1; } - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); + return llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); } -void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { +// deprecated +void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + llama_cache_seq_add(ctx, seq_id, p0, p1, delta); +} + +llama_pos llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } if (d == 1) { - return; + return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1; } - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); + return llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); } -llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } +// deprecated +void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + llama_cache_seq_div(ctx, seq_id, p0, p1, d); +} + +llama_pos llama_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return -1; } return llama_cache_seq_pos_max(ctx->cache, seq_id); } +// deprecated +llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { + llama_pos max_pos = llama_cache_seq_pos_max(ctx, seq_id); + return max_pos < 0 ? 0 : max_pos; +} + void llama_kv_cache_defrag(struct llama_context * ctx) { llama_kv_cache_defrag(ctx->cache.kv); } @@ -16597,7 +16640,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, batch.n_seq_id[i] = 1; batch.seq_id[i][0] = dest_seq_id; } - if (!llama_kv_cache_find_slot(cache, batch)) { + if (!llama_cache_find_slot(cache, batch)) { llama_batch_free(batch); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return 0; diff --git a/llama.h b/llama.h index b770a275ff02f..c211ca592a5df 100644 --- a/llama.h +++ b/llama.h @@ -515,6 +515,12 @@ extern "C" { // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); + // Rebuild and check the validity of the recurrent state cache's tree of sequences. + // (slow, use only for debugging purposes) + // Returns whether or not the rs cache was valid. + // The errors are always corrected, but only logged when debug is true. + LLAMA_API bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug); + // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); @@ -522,36 +528,60 @@ extern "C" { // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); - // Clear the KV cache - LLAMA_API void llama_kv_cache_clear( + // Returns the number of used recurrent state cells (i.e. have at least one sequence assigned to them) + LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx); + + // Clear the KV and recurrent state caches + LLAMA_API void llama_cache_clear( struct llama_context * ctx); + LLAMA_API DEPRECATED(void llama_kv_cache_clear( + struct llama_context * ctx), + "use llama_cache_clear instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) - // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API bool llama_kv_cache_seq_rm( + // Returns n_past + LLAMA_API llama_pos llama_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1), + "use llama_cache_seq_rm instead, and handle its return value for partial removals"); // Copy all tokens that belong to the specified sequence to another sequence - // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_cp( + // Returns n_past + LLAMA_API llama_pos llama_cache_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1), + "use llama_cache_seq_cp instead, and handle its return value for partial copies"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_cache_seq_keep( + LLAMA_API void llama_cache_seq_keep( struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep( + struct llama_context * ctx, + llama_seq_id seq_id), + "use llama_cache_seq_keep instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -559,12 +589,20 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_add( + // Returns n_past + LLAMA_API llama_pos llama_cache_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_add( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta), + "use llama_cache_seq_add instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -572,17 +610,29 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_div( + // Returns n_past + LLAMA_API llama_pos llama_cache_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_div( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d), + "use llama_cache_seq_div instead"); - // Returns the largest position present in the KV cache for the specified sequence - LLAMA_API llama_pos llama_kv_cache_seq_pos_max( + // Returns the largest position present in the KV and/or RS cache for the specified sequence + LLAMA_API llama_pos llama_cache_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( + struct llama_context * ctx, + llama_seq_id seq_id), + "use llama_cache_seq_pos_max instead, which also now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // This will be applied: From b6fafd174721c930e89b27df7de6ee776ace9ade Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 29 Apr 2024 12:59:43 -0400 Subject: [PATCH 006/117] llama : remove useless return value for some llama_cache_* functions --- llama.cpp | 47 ++++++++++++----------------------------------- llama.h | 14 +++++++------- 2 files changed, 19 insertions(+), 42 deletions(-) diff --git a/llama.cpp b/llama.cpp index 92bff6b907b8f..15f7ca43a6dc8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2887,7 +2887,6 @@ static bool llama_cache_init( bool offload) { const struct llama_hparams & hparams = model.hparams; - // TODO: per layer n_embd_* const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); @@ -3010,6 +3009,8 @@ static bool llama_cache_find_slot( const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; + // FIXME: on failure, leave all caches in a consistent state. + if (rs_size > 0) { // For recurrent state architectures (like Mamba), // each cache cell can store the state for a whole sequence. @@ -3509,7 +3510,7 @@ static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id } } -static llama_pos llama_cache_seq_add( +static void llama_cache_seq_add( struct llama_cache & cache, llama_seq_id seq_id, llama_pos p0, @@ -3519,8 +3520,6 @@ static llama_pos llama_cache_seq_add( if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } - llama_pos n_past = p0; - if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be shifted auto & seq = cache.rs.seq_tails[seq_id]; @@ -3541,9 +3540,6 @@ static llama_pos llama_cache_seq_add( } } } - if (n_past <= rs_cell.pos) { - n_past = rs_cell.pos + 1; - } } // If we freed up a slot, set head to it so searching can start there. @@ -3573,9 +3569,6 @@ static llama_pos llama_cache_seq_add( } } } - if (n_past <= kv_cell.pos) { - n_past = kv_cell.pos + 1; - } } } @@ -3583,11 +3576,9 @@ static llama_pos llama_cache_seq_add( // Otherwise we just start the next search from the beginning. cache.kv.head = new_head != cache.kv.size ? new_head : 0; } - - return n_past; } -static llama_pos llama_cache_seq_div( +static void llama_cache_seq_div( struct llama_cache & cache, llama_seq_id seq_id, llama_pos p0, @@ -3596,8 +3587,6 @@ static llama_pos llama_cache_seq_div( if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } - llama_pos n_past = p0; - if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be changed auto & seq = cache.rs.seq_tails[seq_id]; @@ -3609,9 +3598,6 @@ static llama_pos llama_cache_seq_div( rs_cell.pos /= d; } cell_id = rs_cell.prev; - if (n_past <= rs_cell.pos) { - n_past = rs_cell.pos + 1; - } } } @@ -3628,14 +3614,9 @@ static llama_pos llama_cache_seq_div( kv_cell.delta += kv_cell.pos - p_old; } } - if (n_past <= kv_cell.pos) { - n_past = kv_cell.pos + 1; - } } } } - - return n_past; } static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { @@ -16935,13 +16916,11 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { llama_cache_seq_keep(ctx, seq_id); } -llama_pos llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } - if (delta == 0) { - return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1; - } +void llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + if (delta == 0) { return; } - return llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); + llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); } // deprecated @@ -16949,13 +16928,11 @@ void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, lla llama_cache_seq_add(ctx, seq_id, p0, p1, delta); } -llama_pos llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } - if (d == 1) { - return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1; - } +void llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + if (d == 1) { return; } - return llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); + llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); } // deprecated diff --git a/llama.h b/llama.h index fa6d0b58625be..bf0f4a9e140d6 100644 --- a/llama.h +++ b/llama.h @@ -562,7 +562,8 @@ extern "C" { // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past + // Returns n_past (one more than the largest remaining pos in the seq_id) + // which is only meaningful to handle for partial removals. LLAMA_API llama_pos llama_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, @@ -579,7 +580,8 @@ extern "C" { // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past + // Returns n_past (one more than the largest remaining pos in the destination seq_id) + // which is only meaningful to handle when partially copying. LLAMA_API llama_pos llama_cache_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, @@ -609,8 +611,7 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past - LLAMA_API llama_pos llama_cache_seq_add( + LLAMA_API void llama_cache_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -630,8 +631,7 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past - LLAMA_API llama_pos llama_cache_seq_div( + LLAMA_API void llama_cache_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -652,7 +652,7 @@ extern "C" { LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id), - "use llama_cache_seq_pos_max instead, which also now returns -1 instead of 0 when the seq_id has no cells"); + "use llama_cache_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // This will be applied: From 7e13f19fb527b62ca87930841608b7369d86173a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 24 May 2024 16:19:25 -0400 Subject: [PATCH 007/117] llama : rethink recurrent state cell counts * llama : begin work on support for variable GQA This will also be useful for Jamba if we consider the Mamba layers to have 0 KV heads. * llama : gracefully fail when not finding hybrid slot --- llama.cpp | 586 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 307 insertions(+), 279 deletions(-) diff --git a/llama.cpp b/llama.cpp index 3501163ba2542..969249126c186 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1753,6 +1753,9 @@ struct llama_hparams { uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + // TODO: find a more compact way to add more per-layer hyper-parameters + std::vector n_head_kv_vec; + float f_norm_eps; float f_norm_rms_eps; @@ -1793,6 +1796,8 @@ struct llama_hparams { if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_head_kv_vec != other.n_head_kv_vec) return true; + if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; @@ -1812,29 +1817,46 @@ struct llama_hparams { return false; } - uint32_t n_gqa() const { + uint32_t n_head_kv_l(uint32_t layer) const { + if (layer < n_head_kv_vec.size()) { + int32_t n_hkv_l = n_head_kv_vec[layer]; + // TODO: what should happen when it's negative? + GGML_ASSERT(n_hkv_l >= 0); + return n_hkv_l; + } + return n_head_kv; + } + + uint32_t n_gqa(uint32_t layer = 0) const { + uint32_t n_head_kv = n_head_kv_l(layer); if (n_head_kv == 0) { return 0; } return n_head/n_head_kv; } - uint32_t n_embd_k_gqa() const { // dimension of key embeddings across all k-v heads + uint32_t n_embd_k_gqa(uint32_t layer = 0) const { // dimension of key embeddings across all k-v heads + uint32_t n_head_kv = n_head_kv_l(layer); return n_embd_head_k * n_head_kv; } - uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads + uint32_t n_embd_v_gqa(uint32_t layer = 0) const { // dimension of value embeddings across all k-v heads + uint32_t n_head_kv = n_head_kv_l(layer); return n_embd_head_v * n_head_kv; } - uint32_t n_embd_r() const { // dimension of the rolling state embeddings + uint32_t n_embd_r(uint32_t layer) const { // dimension of the rolling state embeddings + // TODO: support using an SSM in place of the MLP of a Transformer + if (n_head_kv_l(layer) != 0) { return 0; } // corresponds to Mamba's conv_states size // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } - uint32_t n_embd_s() const { // dimension of the recurrent state embeddings + uint32_t n_embd_s(uint32_t layer) const { // dimension of the recurrent state embeddings + // TODO: support using an SSM in place of the MLP of a Transformer + if (n_head_kv_l(layer) != 0) { return 0; } // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } @@ -2078,10 +2100,12 @@ struct llama_rs_cache { // computed when finding a slot uint32_t n = 0; // range of states used for the last slot - // useful to know the minimum reserved cell count per seq_id - // only counts sequences which have a non-shared tail + // only counts cells which are tails of all of their sequences. + // useful to know the minimum reserved cell count per seq_id. uint32_t n_seqs = 0; - // cells part of multiple sequences AND which have at least one tail + // cells part of multiple sequences, + // but which are only the tail of some of them. + // useful to dismiss sequences used as a shared prompt uint32_t n_shared_tail_cells = 0; // with state models, a cell can hold the state for more than one past token @@ -2279,10 +2303,8 @@ struct llama_rs_cache { for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) { llama_rs_cell & rs_cell = cells[cell_id]; if (!rs_cell.seq_nodes.empty()) { - if (rs_cell.seq_nodes.size() == 1) { - if (rs_cell.tail_rc == 1) { - n_seqs_verif += 1; - } + if (rs_cell.seq_nodes.size() == rs_cell.tail_rc) { + n_seqs_verif += 1; } else if (rs_cell.tail_rc > 0) { n_shared_tail_cells_verif += 1; } @@ -2308,9 +2330,11 @@ struct llama_rs_cache { } // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. + // Why an iterator? Because it allows using std::vector::erase. std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); - // TODO: assert the iterator points inside the correct vector + // The iterator needs to point inside the correct vector + GGML_ASSERT(node_iter.base() >= rs_cell.seq_nodes.data() && node_iter.base() < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size()); if (node_iter != rs_cell.seq_nodes.end()) { // update the tree llama_rs_seq_node node = *node_iter; @@ -2325,12 +2349,20 @@ struct llama_rs_cache { GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); prev_node->next_cell = node.next_cell; if (node.is_tail()) { + // move the tail back to the previous cell if (prev_cell.seq_nodes.size() > 1) { - if (prev_cell.tail_rc == 0) { - n_shared_tail_cells += 1; - } - if (rs_cell.seq_nodes.size() == 1) { - n_seqs -= 1; + if (rs_cell.tail_rc == rs_cell.seq_nodes.size()) { + if (prev_cell.tail_rc == 0) { + n_shared_tail_cells += 1; + } + + // o oo oo + // |/ -> o/ + // | | + // e.g. when removing the leaf with a single tail + if (rs_cell.tail_rc == 1 && prev_cell.tail_rc != prev_cell.seq_nodes.size()) { + n_seqs -= 1; + } } } prev_cell.tail_rc += 1; @@ -2341,17 +2373,22 @@ struct llama_rs_cache { if (node.is_tail()) { seq.tail = rs_cell.prev; if (rs_cell.tail_rc == 1) { - if (rs_cell.seq_nodes.size() > 1) { - // assuming the previous cell of a shared cell is also shared, - // this was a shared tail cell, but will no longer be a tail cell - n_shared_tail_cells -= 1; - } else if (seq.tail < 0) { + if (seq.tail < 0) { // no more tail, no more sequence - n_seqs -= 1; + if (rs_cell.seq_nodes.size() > 1) { + n_shared_tail_cells -= 1; + } else { + n_seqs -= 1; + } } } GGML_ASSERT(rs_cell.tail_rc > 0); rs_cell.tail_rc -= 1; + } else if (rs_cell.tail_rc == rs_cell.seq_nodes.size() - 1) { + // will fully become a tail cell + if (rs_cell.tail_rc > 0) { + n_seqs += 1; + } } if (node_iter == rs_cell.seq_nodes.begin()) { // this seq_id was the first in the list @@ -2363,14 +2400,6 @@ struct llama_rs_cache { if ((uint32_t) next_node->seq_id < seq_tails.size()) { auto & next_seq = seq_tails[next_node->seq_id]; next_seq.n_cells += 1; - // only the tail ref count from the other seq_ids are left in tail_rc - if (rs_cell.tail_rc > 0) { - // will become a non-shared cell - if (rs_cell.seq_nodes.size() == 2) { - n_shared_tail_cells -= 1; - n_seqs += 1; - } - } } else { GGML_ASSERT(false && "invalid seq_id"); } @@ -2433,43 +2462,41 @@ struct llama_rs_cache { rs_cell.pos = prev_cell.pos + 1; rs_cell.src = prev_cell.src; } - prev_cell.tail_rc -= 1; prev_node->next_cell = i_cell; rs_cell.prev = prev; if (seq.tail == prev) { // What to do when the tail moves... - // from unique to shared (n_seqs--) - // if the new cell has one seq_id or has no tails (n_shared_tail_cells++) - // if the new cell has one seq_id and a tail (n_seqs-- (yes, another time)) - // from unique to unique (seq.n_cells++) - // from empty to unique (seq.n_cells++, n_seqs++) - // from empty to shared - // if the new cell only has one seq_id or has no tail (n_shared_tail_cells++) - // if the new cell only has one seq_id and has one tail (n_seqs--) - // from shared to shared - // if the last cell has no tails (n_shared_tail_cells--) - // if the new cell has no tails or has one seq_id (n_shared_tail_cells++) - // if the new cell only has one seq_id and has one tail (n_seqs--) - // from shared to unique (seq.n_cells++) - // if this seq_id was not the first of the last cell (n_seqs++) - // if the last cell has no tails (n_shared_tail_cells--) - if (prev_cell.seq_nodes.size() > 1) { - // from shared - if (rs_cell.is_empty()) { - // to unique - if (prev_cell.seq_nodes[0].seq_id != id) { - n_seqs += 1; - } + // (Legend: tail: O, one or more non-tails: o, one or more tails O+, empty: _) + // O -> oO (n_seqs--, n_shared_tail_cells++) + // O -> O (seq.n_cells++) + // OO+ -> oO (n_seqs--, n_shared_tail_cells += 2) + // OO+ -> O+ (n_shared_tail_cells++ (the previous cell becomes oO+)) + // _ -> oO (n_shared_tail_cells++) + // _ -> O (seq.n_cells++, n_seqs++) + // Oo -> O (seq.n_cells++, n_seqs++, n_shared_tail_cell--) + // Oo -> OO+ (n_shared_tail_cell--) + // OOo -> O (seq.n_cells++, n_seqs++) + if (prev_cell.seq_nodes.size() == prev_cell.tail_rc) { + // from fully tail + if (prev_cell.tail_rc > 1) { + // the previous tail becomes shared with a non-tail + n_shared_tail_cells += 1; + } + if (!rs_cell.is_empty() && rs_cell.tail_rc == 0) { + // the new tail cell was previously a fully non-tail cell + n_shared_tail_cells += 1; + n_seqs -= 1; } - // the previous cell is no longer a shared tail - if (prev_cell.tail_rc == 0) { + } else if (rs_cell.is_empty()) { + // from shared to unique + n_seqs += 1; + if (prev_cell.tail_rc == 1) { + // it was the last tail of the previous cell n_shared_tail_cells -= 1; } - } else if (!rs_cell.is_empty()) { - // from unique to shared - n_seqs -= 1; } } + prev_cell.tail_rc -= 1; } if (rs_cell.is_empty()) { // to unique @@ -2482,15 +2509,10 @@ struct llama_rs_cache { rs_cell.src = -1; } used += 1; - } else { + } else if (rs_cell.tail_rc == 0) { // to shared - if (rs_cell.seq_nodes.size() == 1) { - // a lone tail becomes a shared cell - if (rs_cell.tail_rc > 0) { - n_seqs -= 1; - } - n_shared_tail_cells += 1; - } else if (rs_cell.tail_rc == 0) { + if (seq.tail < 0) { + // from empty to shared n_shared_tail_cells += 1; } } @@ -2910,26 +2932,18 @@ static bool llama_cache_init( const llama_context * ctx, ggml_type type_k, ggml_type type_v, - uint32_t n_ctx, - uint32_t n_seq_max, bool offload) { const llama_model & model = ctx->model; const llama_cparams & cparams = ctx->cparams; const struct llama_hparams & hparams = model.hparams; - // TODO: per layer n_embd_* - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - const uint32_t n_embd_r = hparams.n_embd_r(); - const uint32_t n_embd_s = hparams.n_embd_s(); - const bool has_kv = hparams.n_head != 0 && hparams.causal_attn; - const bool has_r = n_embd_r != 0; - const bool has_s = n_embd_s != 0; + const bool has_kv = hparams.n_head_kv != 0 && hparams.causal_attn; + const bool has_r = hparams.ssm_d_conv != 0 && hparams.ssm_d_inner != 0; + const bool has_s = hparams.ssm_d_state != 0 && hparams.ssm_d_inner != 0; const bool has_rs = has_r || has_s; - const uint32_t kv_size = has_kv ? n_ctx : 0; - const uint32_t rs_size = has_rs ? n_seq_max : 0; - // TODO: per cache type layer count + const uint32_t kv_size = has_kv ? cparams.n_ctx : 0; + const uint32_t rs_size = has_rs ? cparams.n_seq_max : 0; const int64_t n_layer = hparams.n_layer; cache.kv.size = kv_size; @@ -2967,6 +2981,7 @@ static bool llama_cache_init( std::map ctx_map; for (auto & it : buft_layer_count) { int n_layers = it.second; + // TODO: for mixed architectures, avoid allocating empty recurrent state or kv cache tensors struct ggml_init_params params = { /*.mem_size =*/ (2*has_kv + has_r+has_s)*n_layers*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, @@ -2995,20 +3010,20 @@ static bool llama_cache_init( for (int i = 0; i < (int) n_layer; i++) { struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); if (has_kv) { - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.kv.k_l.push_back(k); cache.kv.v_l.push_back(v); } if (has_r) { - ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_r*rs_size); + ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)*rs_size); ggml_format_name(r, "cache_r_l%d", i); cache.rs.r_l.push_back(r); } if (has_s) { - ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_s*rs_size); + ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)*rs_size); ggml_format_name(s, "cache_s_l%d", i); cache.rs.s_l.push_back(s); } @@ -3024,7 +3039,7 @@ static bool llama_cache_init( return false; } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s ctx buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s cache buf size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); cache.bufs.push_back(buf); } @@ -3042,177 +3057,21 @@ static bool llama_cache_find_slot( const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; - // FIXME: on failure, leave all caches in a consistent state. - + // only check first, to allow failing gracefully if (rs_size > 0) { - // For recurrent state architectures (like Mamba), - // each cache cell can store the state for a whole sequence. - // TODO: find a way to always make the rs slot contiguous - - llama_seq_id min_seq = cache.rs.size - 1; - llama_seq_id max_seq = 0; - uint32_t min_cell = cache.rs.size - 1; - uint32_t max_cell = 0; - + // everything should fit if all seq_ids are smaller than the max for (uint32_t i = 0; i < n_tokens; ++i) { - int32_t target_cell = -1; // ensure all the sequences of a token get the same cell - int32_t n_seq_ids = batch.n_seq_id[i]; - for (int32_t j = 0; j < n_seq_ids; ++j) { + int32_t n_seq_id = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; - bool need_new_cell = false; - // Everything should fit assuming the biggest seq_id < rs_size - if ((uint32_t) seq_id < rs_size) { - llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; - if (seq_id > max_seq) { max_seq = seq_id; } - if (seq_id < min_seq) { min_seq = seq_id; } - - if (!seq.in_ubatch && target_cell >= 0) { - // never saw this seq_id before, - // but there's already a cell reserved for this token, use it - cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); - } else if (seq.tail < 0) { - need_new_cell = true; - } else { - llama_rs_cell & tail = cache.rs.cells[seq.tail]; - if (seq.in_ubatch) { - // this seq_id was already seen before in the batch - // assuming the tail cell already "has" this seq_id - tail.pos += 1; - target_cell = seq.tail; - } else { - // first time this sequence is seen, - // there's no reserved cell yet; - // if it's not the first sequence of the token, how could it even get here? - GGML_ASSERT(j == 0); - - bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; - if (has_same_seqs) { - // the tail cell of a seq_id is assumed to already be part of the seq_id, - // hence the skip of the first seq_id - for (int32_t k = 1; k < n_seq_ids; ++k) { - if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { - has_same_seqs = false; - } - } - } - - // TODO: make the checkpoint interval configurable - if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { - // a checkpoint should be saved - need_new_cell = true; - } else { - // re-use last tail - tail.pos += 1; - target_cell = seq.tail; - } - } - } - if (need_new_cell && target_cell < 0) { - const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); - - uint32_t cell_id = cache.rs.size; - bool looped_once = false; - - while (true) { - if (cache.rs.head >= cache.rs.size) { - cache.rs.head = 0; - if (looped_once) { - // avoid infinite loop - // NOTE: this should not happen, but gracefully fail anyway - LLAMA_LOG_ERROR("%s: recurrent state cache seems full, but should not. This is a bug.\n", __func__); - return false; - } - looped_once = true; - } - cell_id = cache.rs.head; - llama_rs_cell & candidate = cache.rs.cells[cell_id]; - if (candidate.is_empty()) { break; } - if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { - // the candidate is the old tail - if (candidate.seq_nodes.size() > 1) { - // prune out the other seq_ids, because they diverge - // TODO(maybe): hande this in insert_seq_tail_to_cell_id - // (hopefully doesn't happen too often) - for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { - if (node_iter->seq_id == seq_id) { - node_iter = std::next(node_iter); - } else { - node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); - } - } - } - // re-use the tail cell to avoid not finding anything - candidate.pos += 1; - break; - } - if (candidate.tail_rc > 0) { - // skip tails of other sequences - cache.rs.head += 1; - continue; - } - if (candidate.seq_nodes.size() > 1) { - // shared prompts are not usually backtracked, so they can be pruned - cache.rs.clear_cell(candidate); - break; - } - - // prune too-long sequences - llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; - if (seq_id_to_prune == seq_id) { - // TODO: selectively skip some cells to keep older states - cache.rs.clear_cell(candidate); - break; - } - GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); - auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; - if (seq_to_prune.n_cells > min_cells_per_seq) { - cache.rs.clear_cell(candidate); - break; - } - cache.rs.head += 1; - } - if (cell_id < cache.rs.size) { - cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); - target_cell = cell_id; - } - } - - if (seq.tail >= 0) { - if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } - if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } - seq.in_ubatch = true; - } - - // Assuming the tokens are in-order - if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); - } - } else { + if (seq_id < 0 || (uint32_t) seq_id >= rs_size) { // too big seq_id // TODO: would it be possible to resize the rs cache size instead? LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); return false; } } - cache.rs.head = target_cell + 1; - } - - for (llama_seq_id i = min_seq; i <= max_seq; ++i) { - // make sure it's cleared for next time - cache.rs.seq_tails[i].in_ubatch = false; - } - - // allow getting the range of used cells, from head to head + n - cache.rs.head = min_cell; - cache.rs.n = max_cell - min_cell + 1; - - // sanity check - if (max_seq < min_seq || max_cell < min_cell) { - return false; } } @@ -3257,7 +3116,174 @@ static bool llama_cache_find_slot( return false; } } + } + + // now modification can be done, and should NOT fail + + if (rs_size > 0) { + // For recurrent state architectures (like Mamba), + // each cache cell can store the state for a whole sequence. + // TODO: find a way to always make the rs slot contiguous + + llama_seq_id min_seq = cache.rs.size - 1; + llama_seq_id max_seq = 0; + uint32_t min_cell = cache.rs.size - 1; + uint32_t max_cell = 0; + + for (uint32_t i = 0; i < n_tokens; ++i) { + int32_t target_cell = -1; // ensure all the sequences of a token get the same cell + int32_t n_seq_ids = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_ids; ++j) { + llama_seq_id seq_id = batch.seq_id[i][j]; + bool need_new_cell = false; + // Everything should fit assuming the biggest seq_id < rs_size + GGML_ASSERT((uint32_t) seq_id < rs_size); + llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; + if (seq_id > max_seq) { max_seq = seq_id; } + if (seq_id < min_seq) { min_seq = seq_id; } + + if (!seq.in_ubatch && target_cell >= 0) { + // never saw this seq_id before, + // but there's already a cell reserved for this token, use it + cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); + } else if (seq.tail < 0) { + // this seq_id has no tail (and is empty) + need_new_cell = true; + } else { + llama_rs_cell & tail = cache.rs.cells[seq.tail]; + if (seq.in_ubatch) { + // this seq_id was already seen before in the batch + // assuming the tail cell already "has" this seq_id + tail.pos += 1; + target_cell = seq.tail; + } else { + // first time this sequence is seen, + // there's no reserved cell yet; + // if it's not the first sequence of the token, how could it even get here? + GGML_ASSERT(j == 0); + + bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; + if (has_same_seqs) { + // the tail cell of a seq_id is assumed to already be part of the seq_id, + // hence the skip of the first seq_id + for (int32_t k = 1; k < n_seq_ids; ++k) { + if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { + has_same_seqs = false; + } + } + } + + // TODO: make the checkpoint interval configurable + if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { + // a checkpoint should be saved + need_new_cell = true; + } else { + // re-use last tail + tail.pos += 1; + target_cell = seq.tail; + } + } + } + + // reserve a cell for this seq_id + if (need_new_cell && target_cell < 0) { + const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); + uint32_t cell_id = cache.rs.size; + bool looped_once = false; + + while (true) { + if (cache.rs.head >= cache.rs.size) { + cache.rs.head = 0; + // avoid infinite loop + // NOTE: this should not fail; if it does, it's a bug. + GGML_ASSERT(!looped_once && "recurrent state cache seems full, but should not."); + looped_once = true; + } + cell_id = cache.rs.head; + llama_rs_cell & candidate = cache.rs.cells[cell_id]; + if (candidate.is_empty()) { break; } + if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { + // the candidate is the old tail + if (candidate.seq_nodes.size() > 1) { + // prune out the other seq_ids, because they diverge + // TODO(maybe): hande this in insert_seq_tail_to_cell_id + // (hopefully doesn't happen too often) + for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { + if (node_iter->seq_id == seq_id) { + node_iter = std::next(node_iter); + } else { + node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); + } + } + } + // re-use the tail cell to avoid not finding anything + candidate.pos += 1; + break; + } + if (candidate.tail_rc > 0) { + // skip tails of other sequences + cache.rs.head += 1; + continue; + } + if (candidate.seq_nodes.size() > 1) { + // shared prompts are not usually backtracked, so they can be pruned + cache.rs.clear_cell(candidate); + break; + } + + // prune too-long sequences + llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; + if (seq_id_to_prune == seq_id) { + // TODO: selectively skip some cells to keep older states + cache.rs.clear_cell(candidate); + break; + } + GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); + auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; + if (seq_to_prune.n_cells > min_cells_per_seq) { + cache.rs.clear_cell(candidate); + break; + } + cache.rs.head += 1; + } + if (cell_id < cache.rs.size) { + cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); + target_cell = cell_id; + } + } + + if (seq.tail >= 0) { + if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } + if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } + seq.in_ubatch = true; + } + + // Assuming the tokens are in-order + if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", + __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); + } + } + cache.rs.head = target_cell + 1; + } + + for (llama_seq_id i = min_seq; i <= max_seq; ++i) { + // make sure it's cleared for next time + cache.rs.seq_tails[i].in_ubatch = false; + } + + // allow getting the range of used cells, from head to head + n + cache.rs.head = min_cell; + cache.rs.n = max_cell - min_cell + 1; + + // sanity check + GGML_ASSERT(min_seq <= max_seq && min_cell <= max_cell); + } + + if (kv_size > 0) { for (uint32_t i = 0; i < n_tokens; i++) { cache.kv.cells[cache.kv.head + i].pos = batch.pos[i]; @@ -4194,9 +4220,9 @@ struct llama_model_loader { bool get_arr(const std::string & key, std::vector & result, const bool required = true) { const int kid = gguf_find_key(meta, key.c_str()); - if (kid < 0) { + if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) { if (required) { - throw std::runtime_error(format("key not found in model: %s", key.c_str())); + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); } return false; } @@ -4204,16 +4230,17 @@ struct llama_model_loader { struct GGUFMeta::ArrayInfo arr_info = GGUFMeta::GKV::get_kv(meta, kid); - if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) { - throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str())); - } - + // TODO: allow ANY lossless cast // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); - GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same::value)); - GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same::value)); + switch (arr_info.gt) { + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + } - result.resize(arr_info.length); - result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + result.reserve(arr_info.length); + result.assign((const T *)arr_info.data, (const T *)arr_info.data + arr_info.length); return true; } @@ -4750,7 +4777,12 @@ static void llm_load_hparams( // n_head_kv is optional, default to n_head hparams.n_head_kv = hparams.n_head; - ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false); + + // per-layer n_head_kv + if (!ml.get_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_vec, false)) { + // global/fallback n_head_kv + ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false); + } bool rope_finetuned = false; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); @@ -6704,10 +6736,7 @@ static bool llm_load_tensors( model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading - const int64_t n_ff = hparams.n_ff; const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); for (uint32_t i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -7198,8 +7227,8 @@ static void llm_build_kv_store( int64_t il) { const int64_t n_ctx = cparams.n_ctx; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); GGML_ASSERT(kv.size == n_ctx); @@ -7465,9 +7494,9 @@ static struct ggml_tensor * llm_build_kqv( int il) { const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_head_kv = hparams.n_head_kv_l(il); const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_head_v = hparams.n_embd_head_v; const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); @@ -7619,9 +7648,7 @@ struct llm_build_context { const int64_t n_head; const int64_t n_head_kv; const int64_t n_embd_head_k; - const int64_t n_embd_k_gqa; const int64_t n_embd_head_v; - const int64_t n_embd_v_gqa; const int64_t n_expert; const int64_t n_expert_used; @@ -7673,9 +7700,7 @@ struct llm_build_context { n_head (hparams.n_head), n_head_kv (hparams.n_head_kv), n_embd_head_k (hparams.n_embd_head_k), - n_embd_k_gqa (hparams.n_embd_k_gqa()), n_embd_head_v (hparams.n_embd_head_v), - n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), n_expert_used (hparams.n_expert_used), freq_base (cparams.rope_freq_base), @@ -7746,9 +7771,9 @@ struct llm_build_context { // we rotate only the first n_rot dimensions ggml_rope_ext_inplace(ctx0, ggml_view_3d(ctx0, kv_self.k_l[il], - n_embd_head_k, n_head_kv, n_ctx, + n_embd_head_k, hparams.n_head_kv_l(il), n_ctx, ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self.k_l[il]->type, hparams.n_embd_k_gqa(il)), 0), lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -7777,6 +7802,9 @@ struct llm_build_context { } for (int il = 0; il < n_layer; ++il) { + int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, nm, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), @@ -11014,8 +11042,8 @@ struct llm_build_context { struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(), rs_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(), rs_self.size); + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(il), rs_self.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(il), rs_self.size); // copy states { @@ -16452,7 +16480,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.n_ctx, cparams.n_seq_max, cparams.offload_kqv)) { + if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -17282,7 +17310,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); // NOTE: kv_size and kv_buf_size are mostly used for sanity checks @@ -17434,7 +17462,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); size_t kv_buf_size; @@ -17627,7 +17655,7 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); for (uint32_t i = 0; i < kv_self.size; ++i) { @@ -17713,7 +17741,7 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); // Write the layer count @@ -17843,7 +17871,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, // Sanity check model compatibility const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); if (n_layer != n_layer_ref) { LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref); From cbc743e6006349dde61fe214d56c2d6efa34828d Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 24 May 2024 19:27:27 -0400 Subject: [PATCH 008/117] llama : support Jamba --- convert-hf-to-gguf.py | 103 ++++++- gguf-py/gguf/constants.py | 36 +++ gguf-py/gguf/gguf_writer.py | 7 +- gguf-py/gguf/tensor_mapping.py | 52 +++- llama.cpp | 521 ++++++++++++++++++++++++++------- 5 files changed, 601 insertions(+), 118 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index daad1c4fc7255..83d9b0638f856 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2300,7 +2300,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_embedding_length(d_model) self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_ssm_conv_kernel(d_conv) self.gguf_writer.add_ssm_inner_size(d_inner) self.gguf_writer.add_ssm_state_size(d_state) @@ -2346,6 +2346,107 @@ def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: i ) +@Model.register("JambaForCausalLM") +class JambaModel(Model): + model_arch = gguf.MODEL_ARCH.JAMBA + + def get_vocab_base_pre(self, tokenizer) -> str: + del tokenizer # unused + + return "gpt-2" + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "mamba_d_model"]) + d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4 + d_inner = self.hparams["mamba_expand"] * d_model + d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16 + # ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 + dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16) + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6 + n_kv_head = self.hparams["num_key_value_heads"] + attn_offset = self.hparams["attn_layer_offset"] + attn_period = self.hparams["attn_layer_period"] + n_kv_vec = [0 for _ in range(attn_offset)] + [ + n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count) + ] + + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(n_kv_vec) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(dt_rank) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_expert_count(self.hparams["num_experts"]) + self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_file_type(self.ftype) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + # process the experts separately + if ".feed_forward.experts." in name: + n_experts = self.hparams["num_experts"] + + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + + # merge the experts into a single 3d tensor + for wid in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + # using the same merged name as qwen2moe + merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight" + + new_name = self.map_tensor_name(merged_name) + + yield new_name, data_torch + return + + new_name = self.map_tensor_name(name) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield new_name, data_torch + + # same as Mamba + def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: + del n_dims # unused + + return bid is not None and new_name in ( + self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [ + gguf.MODEL_TENSOR.SSM_CONV1D, + gguf.MODEL_TENSOR.SSM_X, + gguf.MODEL_TENSOR.SSM_DT, + gguf.MODEL_TENSOR.SSM_A, + gguf.MODEL_TENSOR.SSM_D, + ] + ) + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 42df2e4d00604..3668778be0af1 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -135,6 +135,7 @@ class MODEL_ARCH(IntEnum): GEMMA = auto() STARCODER2 = auto() MAMBA = auto() + JAMBA = auto() XVERSE = auto() COMMAND_R = auto() DBRX = auto() @@ -180,7 +181,10 @@ class MODEL_TENSOR(IntEnum): SSM_CONV1D = auto() SSM_X = auto() SSM_DT = auto() + SSM_DT_NORM = auto() SSM_A = auto() + SSM_B_NORM = auto() + SSM_C_NORM = auto() SSM_D = auto() SSM_OUT = auto() @@ -214,6 +218,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA: "gemma", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.JAMBA: "jamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", @@ -259,7 +264,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", + MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm", + MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", } @@ -678,6 +686,34 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.JAMBA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_X, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_DT_NORM, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_B_NORM, + MODEL_TENSOR.SSM_C_NORM, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 8b41b54eaa5a6..272ef4a8071cd 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -385,8 +385,11 @@ def add_parallel_residual(self, use: bool) -> None: def add_head_count(self, count: int) -> None: self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count) - def add_head_count_kv(self, count: int) -> None: - self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) + def add_head_count_kv(self, count: int | Sequence[int]) -> None: + if isinstance(count, int): + self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) + else: + self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) def add_key_length(self, length: int) -> None: self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 8e1cac9152f55..eb60bb8ac01d4 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -206,6 +206,7 @@ class TensorNameMap: "h.{bid}.ln_2", # gpt2 "model.layers.{bid}.ffn_norm", # internlm2 "transformer.decoder_layer.{bid}.rms_norm_2", # Grok + "model.layers.{bid}.pre_ff_layernorm", # jamba ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -214,6 +215,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.gate", # qwen2moe "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx + "model.layers.{bid}.feed_forward.router", # jamba ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -244,6 +246,7 @@ class TensorNameMap: "encoder.layers.{bid}.mlp.fc11", # nomic-bert "model.layers.{bid}.mlp.c_fc", # starcoder2 "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 + "model.layers.{bid}.feed_forward.up_proj", # jamba ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -272,6 +275,7 @@ class TensorNameMap: "encoder.layers.{bid}.mlp.fc12", # nomic-bert "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 "transformer.h.{bid}.mlp.linear_1", # refact + "model.layers.{bid}.feed_forward.gate_proj", # jamba ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -306,6 +310,7 @@ class TensorNameMap: "encoder.layers.{bid}.mlp.fc2", # nomic-bert "model.layers.{bid}.mlp.c_proj", # starcoder2 "encoder.layer.{bid}.mlp.wo", # jina-bert-v2 + "model.layers.{bid}.feed_forward.down_proj", # jamba ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -347,38 +352,57 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_IN: ( - "model.layers.{bid}.in_proj", - "backbone.layers.{bid}.mixer.in_proj", + "model.layers.{bid}.in_proj", # mamba-hf + "backbone.layers.{bid}.mixer.in_proj", # mamba + "model.layers.{bid}.mamba.in_proj", # jamba ), MODEL_TENSOR.SSM_CONV1D: ( - "model.layers.{bid}.conv1d", - "backbone.layers.{bid}.mixer.conv1d", + "model.layers.{bid}.conv1d", # mamba-hf + "backbone.layers.{bid}.mixer.conv1d", # mamba + "model.layers.{bid}.mamba.conv1d", # jamba ), MODEL_TENSOR.SSM_X: ( - "model.layers.{bid}.x_proj", - "backbone.layers.{bid}.mixer.x_proj", + "model.layers.{bid}.x_proj", # mamba-hf + "backbone.layers.{bid}.mixer.x_proj", # mamba + "model.layers.{bid}.mamba.x_proj", # jamba ), MODEL_TENSOR.SSM_DT: ( - "model.layers.{bid}.dt_proj", - "backbone.layers.{bid}.mixer.dt_proj", + "model.layers.{bid}.dt_proj", # mamba-hf + "backbone.layers.{bid}.mixer.dt_proj", # mamba + "model.layers.{bid}.mamba.dt_proj", # jamba + ), + + MODEL_TENSOR.SSM_DT_NORM: ( + "model.layers.{bid}.mamba.dt_layernorm", # jamba ), MODEL_TENSOR.SSM_A: ( - "model.layers.{bid}.A_log", - "backbone.layers.{bid}.mixer.A_log", + "model.layers.{bid}.A_log", # mamba-hf + "backbone.layers.{bid}.mixer.A_log", # mamba + "model.layers.{bid}.mamba.A_log", # jamba + ), + + MODEL_TENSOR.SSM_B_NORM: ( + "model.layers.{bid}.mamba.b_layernorm", # jamba + ), + + MODEL_TENSOR.SSM_C_NORM: ( + "model.layers.{bid}.mamba.c_layernorm", # jamba ), MODEL_TENSOR.SSM_D: ( - "model.layers.{bid}.D", - "backbone.layers.{bid}.mixer.D", + "model.layers.{bid}.D", # mamba-hf + "backbone.layers.{bid}.mixer.D", # mamba + "model.layers.{bid}.mamba.D", # jamba ), MODEL_TENSOR.SSM_OUT: ( - "model.layers.{bid}.out_proj", - "backbone.layers.{bid}.mixer.out_proj", + "model.layers.{bid}.out_proj", # mamba-hf + "backbone.layers.{bid}.mixer.out_proj", # mamba + "model.layers.{bid}.mamba.out_proj", # jamba ), } diff --git a/llama.cpp b/llama.cpp index 969249126c186..3176c8d0d5d64 100644 --- a/llama.cpp +++ b/llama.cpp @@ -221,6 +221,7 @@ enum llm_arch { LLM_ARCH_GEMMA, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, + LLM_ARCH_JAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, @@ -257,6 +258,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_JAMBA, "jamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, @@ -472,7 +474,10 @@ enum llm_tensor { LLM_TENSOR_SSM_CONV1D, LLM_TENSOR_SSM_X, LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_DT_NORM, LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_B_NORM, + LLM_TENSOR_SSM_C_NORM, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_OUT, }; @@ -970,6 +975,37 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_JAMBA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, + { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -1905,6 +1941,9 @@ struct llama_layer { struct ggml_tensor * attn_k_norm_b; struct ggml_tensor * attn_out_norm; struct ggml_tensor * attn_out_norm_b; + struct ggml_tensor * ssm_dt_norm; + struct ggml_tensor * ssm_b_norm; + struct ggml_tensor * ssm_c_norm; // attention struct ggml_tensor * wq; @@ -5150,6 +5189,22 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_JAMBA: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + // TODO: Jamba layers are a bit heterogenous, so naming this is hard. + case 12: // 900M 8x???M + case 32: // 51B 16x?B + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -6854,6 +6909,118 @@ static bool llm_load_tensors( layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); } } break; + case LLM_ARCH_JAMBA: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + GGML_ASSERT((int64_t) hparams.n_head_kv_vec.size() == n_layer); + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head_kv = hparams.n_head_kv_vec[i]; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + // norm + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + if (n_head_kv == 0) { + // Mamba layer + layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); + + layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); + layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); + + layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); + + layer.ssm_dt_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}); + + layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}); + layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); + + layer.ssm_b_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}); + layer.ssm_c_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}); + + // no "weight" suffix for these + layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); + layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); + + // out_proj + layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); + + layer.wq = nullptr; + layer.wk = nullptr; + layer.wv = nullptr; + layer.wo = nullptr; + + } else { + // Attention layers + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ssm_in = nullptr; + layer.ssm_conv1d = nullptr; + layer.ssm_conv1d_b = nullptr; + layer.ssm_x = nullptr; + layer.ssm_dt_norm = nullptr; + layer.ssm_dt = nullptr; + layer.ssm_dt_b = nullptr; + layer.ssm_b_norm = nullptr; + layer.ssm_c_norm = nullptr; + layer.ssm_a = nullptr; + layer.ssm_d = nullptr; + layer.ssm_out = nullptr; + } + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + + if (layer.ffn_gate_inp) { + // MoE + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + + layer.ffn_gate = nullptr; + layer.ffn_down = nullptr; + layer.ffn_up = nullptr; + } else { + // FFN (no MoE) + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + + layer.ffn_gate_exps = nullptr; + layer.ffn_down_exps = nullptr; + layer.ffn_up_exps = nullptr; + } + } + } break; case LLM_ARCH_XVERSE: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -7632,6 +7799,132 @@ static struct ggml_tensor * llm_build_kv( return cur; } +// TODO: split +static struct ggml_tensor * llm_build_mamba( + struct ggml_context * ctx, + const llama_model & model, + const llama_hparams & hparams, + const llama_rs_cache & rs, + struct ggml_cgraph * graph, + struct ggml_tensor * cur, + struct ggml_tensor * state_copy, + struct ggml_tensor * state_mask, + struct ggml_tensor * state_seq, + struct ggml_tensor * w_dt_norm, + struct ggml_tensor * w_b_norm, + struct ggml_tensor * w_c_norm, + int32_t n_tokens, + int32_t rs_head, + int32_t n_rs, + const llm_build_cb & cb, + int il) { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx, rs.r_l[il], hparams.n_embd_r(il), rs.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx, rs.s_l[il], hparams.n_embd_s(il), rs.size); + + // copy states + { + // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows + // NOTE: assuming the copy destinations are ALL contained in the current batch + // this shrinks the tensors's ne[1] to n_rs + conv_states = ggml_get_rows(ctx, conv_states, state_copy); + ssm_states = ggml_get_rows(ctx, ssm_states, state_copy); + } + + // clear states of sequences which are starting at the beginning of this batch + { + conv_states = ggml_mul(ctx, conv_states, state_mask); + ssm_states = ggml_mul(ctx, ssm_states, state_mask); + } + + conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs); + ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs); + + // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} + struct ggml_tensor * xz = ggml_mul_mat(ctx, model.layers[il].ssm_in, cur); + // split the above in two + // => {d_inner, n_tokens} + struct ggml_tensor * x = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], 0); + struct ggml_tensor * z = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); + + // conv + { + // Custom operator which is needed only to ease simultaneous sequence processing. + // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, + // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weigth, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // The new conv_states is the last (d_conv - 1) columns + // of the last 3rd dimensional "layer" of the self-overlapping view. + // For simultaneous sequences, it's more complicated. + struct ggml_tensor * x_conv = ggml_ssm_conv(ctx, conv_states, x, model.layers[il].ssm_conv1d, state_seq); + + // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_2d(ctx, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), + ggml_view_1d(ctx, rs.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); + + // extract x from x_conv + x = ggml_view_2d(ctx, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); + + // bias + x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); + + x = ggml_silu(ctx, x); + } + + // ssm + { + // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} + struct ggml_tensor * x_db = ggml_mul_mat(ctx, model.layers[il].ssm_x, x); + // split + struct ggml_tensor * dt = ggml_view_2d(ctx, x_db, dt_rank, n_tokens, x_db->nb[1], 0); + struct ggml_tensor * B = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); + + if (w_dt_norm) { dt = llm_build_norm(ctx, dt, hparams, w_dt_norm, NULL, LLM_NORM_RMS, cb, il); } + if (w_b_norm) { B = llm_build_norm(ctx, B, hparams, w_b_norm, NULL, LLM_NORM_RMS, cb, il); } + if (w_c_norm) { C = llm_build_norm(ctx, C, hparams, w_b_norm, NULL, LLM_NORM_RMS, cb, il); } + + // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} + dt = ggml_mul_mat(ctx, model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined, + // because only a single tensor can be returned. + struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); + + // store last states (the second part of y_ssm_states) + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), + ggml_view_1d(ctx, rs.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); + + struct ggml_tensor * y = ggml_view_2d(ctx, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); + + // TODO: skip computing output for unused tokens + + // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} + y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx, y, ggml_silu(ctx, z)); + + // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} + cur = ggml_mul_mat(ctx, model.layers[il].ssm_out, y); + } + + return cur; +} + struct llm_build_context { const llama_model & model; llama_context & lctx; @@ -11024,13 +11317,6 @@ struct llm_build_context { struct ggml_cgraph * build_mamba() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); - const int64_t d_model = n_embd; - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - GGML_ASSERT(2 * d_model == d_inner); - const int64_t d_state = hparams.ssm_d_state; - const int64_t dt_rank = hparams.ssm_dt_rank; - struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -11042,116 +11328,144 @@ struct llm_build_context { struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(il), rs_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(il), rs_self.size); + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); - // copy states - { - // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows - // NOTE: assuming the copy destinations are ALL contained in the current batch - // this shrinks the tensors's ne[1] to n_rs - conv_states = ggml_get_rows(ctx0, conv_states, state_copy); - ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); - } + cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, + state_copy, state_mask, state_seq, NULL, NULL, NULL, + n_tokens, rs_head, n_rs, cb, il); - // clear states of sequences which are starting at the beginning of this batch - { - conv_states = ggml_mul(ctx0, conv_states, state_mask); - ssm_states = ggml_mul(ctx0, ssm_states, state_mask); + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } - conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_rs); - ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_rs); + // residual + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_jamba() { + + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + struct ggml_tensor * state_copy = build_inp_s_copy(); + struct ggml_tensor * state_mask = build_inp_s_mask(); + struct ggml_tensor * state_seq = build_inp_s_seq(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_head_kv = hparams.n_head_kv_l(il); - // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} - struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur); - // split the above in two - // => {d_inner, n_tokens} - struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0); - struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); + if (n_head_kv == 0) { + // Mamba + cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, + state_copy, state_mask, state_seq, + model.layers[il].ssm_dt_norm, model.layers[il].ssm_b_norm, model.layers[il].ssm_c_norm, + n_tokens, rs_head, n_rs, cb, il); + } else { + // Attention - // conv - { - // Custom operator which is needed only to ease simultaneous sequence processing. - // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, - // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, - // then element-wise multiply that with the conv1d weigth, - // then sum the elements of each row, - // (the last two steps are a dot product over rows (also doable with mul_mat)) - // then permute away the ne[0] dimension, - // and then you're left with the resulting x tensor. - // The new conv_states is the last (d_conv - 1) columns - // of the last 3rd dimensional "layer" of the self-overlapping view. - // For simultaneous sequences, it's more complicated. - struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq); - - // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), - ggml_view_1d(ctx0, rs_self.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); - - // extract x from x_conv - x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); - - // bias - x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); - - x = ggml_silu(ctx0, x); - } - - // ssm - { - // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} - struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x); - // split - struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0); - struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); - - // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} - dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt); - dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); - - // Custom operator to optimize the parallel associative scan - // as described in the Annex D of the Mamba paper. - // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined, - // because only a single tensor can be returned. - struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); - - // store last states (the second part of y_ssm_states) - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx0, rs_self.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); - - struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); - - if (il == n_layer - 1) { - // skip computing output for unused tokens - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - x = ggml_get_rows(ctx0, x, inp_out_ids); - y = ggml_get_rows(ctx0, y, inp_out_ids); - z = ggml_get_rows(ctx0, z, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - } + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} - y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - y = ggml_mul(ctx0, y, ggml_silu(ctx0, z)); + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); - // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} - cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y); + // No RoPE :) + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } // residual - cur = ggml_add(ctx0, cur, inpL); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, inpL, cur); + cb(cur, "ffn_inp", il); + + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // FFN + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + cb, il); + cb(cur, "ffn_moe_out", il); + } + + // residual + cur = ggml_add(ctx0, ffn_inp, cur); cb(cur, "l_out", il); // input for next layer @@ -11630,6 +11944,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_mamba(); } break; + case LLM_ARCH_JAMBA: + { + result = llm.build_jamba(); + } break; case LLM_ARCH_XVERSE: { result = llm.build_xverse(); @@ -16644,6 +16962,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_JAMBA: case LLM_ARCH_JINA_BERT_V2: return LLAMA_ROPE_TYPE_NONE; From 61a88a1da399be2207c8aa0a8a280dffc3f64887 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 24 May 2024 22:41:38 -0400 Subject: [PATCH 009/117] llama : fix BERT inference without KV cache --- llama.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/llama.cpp b/llama.cpp index 6bc5167be6f60..678c49094b22e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3105,6 +3105,10 @@ static bool llama_cache_init( ggml_context * ctx = it.second; ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (!buf) { + if (!has_kv && !has_rs) { + // no buffer was needed, so this is fine + return true; + } LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); return false; } From ea2e63e9d2b4d9e60587083b9fc824d9ca342af1 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 25 May 2024 12:54:30 -0400 Subject: [PATCH 010/117] convert-hf : check for unprocessed Jamba experts --- convert-hf-to-gguf.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 971875069dcc3..28a43c54f70f7 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2470,6 +2470,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield new_name, data_torch + def write_tensors(self): + super().write_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + # same as Mamba def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: del n_dims # unused From fc59407efea1d49a3d8338fd20fa38afbe06fdb5 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 25 May 2024 13:55:11 -0400 Subject: [PATCH 011/117] convert-hf : support Mini-Jamba conversion --- convert-hf-to-gguf.py | 21 ++++++++++++++++++++- gguf-py/gguf/tensor_mapping.py | 3 +++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 28a43c54f70f7..a42458e63d23f 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2393,6 +2393,16 @@ def get_vocab_base_pre(self, tokenizer) -> str: return "gpt-2" + def set_vocab(self): + if (self.dir_model / "tokenizer.model").is_file(): + # Using Jamba's tokenizer.json causes errors on model load + # (something about "byte not found in vocab"), + # but there's a working tokenizer.model + self._set_vocab_sentencepiece() + else: + # Some Jamba models only have a tokenizer.json, which works. + self._set_vocab_gpt2() + def set_gguf_parameters(self): d_model = self.find_hparam(["hidden_size", "mamba_d_model"]) d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4 @@ -2412,7 +2422,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_name(self.dir_model.name) self.gguf_writer.add_block_count(self.block_count) - self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"])) self.gguf_writer.add_embedding_length(d_model) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) @@ -2430,6 +2440,15 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Mini-Jamba + name = name.replace(".moe.", ".feed_forward.") + if bid is not None: + moe_offset = self.hparams["expert_layer_offset"] + moe_period = self.hparams["expert_layer_period"] + + if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0): + name = name.replace(".experts.0.", ".") + # process the experts separately if ".feed_forward.experts." in name: n_experts = self.hparams["num_experts"] diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index b71bf1ecdd4d4..c81600151b142 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -207,6 +207,7 @@ class TensorNameMap: "model.layers.{bid}.ffn_norm", # internlm2 "transformer.decoder_layer.{bid}.rms_norm_2", # Grok "model.layers.{bid}.pre_ff_layernorm", # jamba + "model.layers.{bid}.pre_moe_layernorm", # mini-jamba ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -390,10 +391,12 @@ class TensorNameMap: MODEL_TENSOR.SSM_B_NORM: ( "model.layers.{bid}.mamba.b_layernorm", # jamba + "model.layers.{bid}.mamba.B_layernorm", # mini-jamba ), MODEL_TENSOR.SSM_C_NORM: ( "model.layers.{bid}.mamba.c_layernorm", # jamba + "model.layers.{bid}.mamba.C_layernorm", # mini-jamba ), MODEL_TENSOR.SSM_D: ( From 181dadf294d9495b54a86a23299fc15b282dac1d Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 28 May 2024 12:23:05 -0400 Subject: [PATCH 012/117] llama : fix Jamba quantization sanity checks --- llama.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 678c49094b22e..4c9ecf018e67f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -16290,11 +16290,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; // sanity checks - // - // - qs.n_attention_wv == 0 for Mamba models - // - qs.n_attention_wv == model.hparams.n_layer for Transformer models - // - GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer) && "n_attention_wv is unexpected"); + { + const auto & n_head_kv_vec = model.hparams.n_head_kv_vec; + int n_attn_layer; + if (model.hparams.n_head_kv == 0) { + // Mamba models don't have attention layers + n_attn_layer = 0; + } else { + // Transformers and hybrid models (like Jamba) have attention layers + n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_vec.begin(), n_head_kv_vec.end(), 0); + } + GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected"); + } size_t total_size_org = 0; size_t total_size_new = 0; From 3a414b0be242be52f8c186acb368510975eb0d15 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 28 May 2024 12:21:52 -0400 Subject: [PATCH 013/117] llama : sequence-length-aware batch splitting --- llama.cpp | 443 +++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 355 insertions(+), 88 deletions(-) diff --git a/llama.cpp b/llama.cpp index 4c9ecf018e67f..209d3063cb5ec 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2807,6 +2807,321 @@ struct llama_model { } }; +// very similar to llama_batch, +// but has more metadata about sequences +struct llama_ubatch { + bool equal_seqs; + + int32_t n_tokens; + int32_t n_seqs; + + llama_token * token; + float * embd; + llama_pos * pos; + int32_t * n_seq_id; + llama_seq_id ** seq_id; + int8_t * output; +}; + +struct llama_sbatch_seq { + int32_t n_seq_id; + llama_seq_id * seq_id; + size_t offset; + size_t length; + + // helper for smoother batch API transition -- can be deprecated in the future + llama_seq_id all_seq_id; // used if seq_id == NULL +}; + +// sequence-length-aware batch splitting +struct llama_sbatch { + // tokens left in this batch + size_t n_tokens; + + size_t n_embd; + + bool logits_all; // TODO: remove once lctx.logits_all is removed too + + // sorted indices into the batch + std::vector ids; + // batch indices of the output + std::vector out_ids; + std::vector seq; + const llama_batch * batch = nullptr; + + // buffers for the ubatch + std::vector ubatch_token; + std::vector ubatch_embd; + std::vector ubatch_pos; + std::vector ubatch_n_seq_id; + std::vector ubatch_seq_id; + std::vector ubatch_output; + + llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) { + // clear empty sequences + // the previous ubatch is assumed to be gone, + // so nothing should refer to values in these sequences anymore. + for (size_t i = seq.size(); i-- > 0;) { + if (seq[i].length == 0) { + seq.pop_back(); + } else { + break; + } + } + ubatch_token.resize(!has_embd ? n_ubatch : 0); + ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); + ubatch_pos.resize(n_ubatch); + ubatch_n_seq_id.resize(n_ubatch); + ubatch_seq_id.resize(n_ubatch); + ubatch_output.resize(n_ubatch); + llama_ubatch ubatch = { + true, + 0, + 0, + !has_embd ? ubatch_token.data() : nullptr, + has_embd ? ubatch_embd.data() : nullptr, + ubatch_pos.data(), + ubatch_n_seq_id.data(), + ubatch_seq_id.data(), + ubatch_output.data(), + }; + return ubatch; + } + + void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { + GGML_ASSERT(batch != nullptr); + GGML_ASSERT(length <= seq.length); + if (ubatch.equal_seqs) { + // is the new sequence of a different size than expected? + if (ubatch.n_seqs > 0 && length != (size_t) ubatch.n_tokens / ubatch.n_seqs) { + ubatch.equal_seqs = false; + } + } + // NOTE: loops are separated for cache-friendliness + if (batch->token) { + for (size_t i = 0; i < length; ++i) { + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + } + } else { + ubatch.token = nullptr; + } + if (batch->embd) { + for (size_t i = 0; i < length; ++i) { + memcpy( + ubatch.embd + n_embd * (ubatch.n_tokens + i), + batch->embd + n_embd * ids[seq.offset + i], + n_embd * sizeof(float) + ); + } + } else { + ubatch.embd = nullptr; + } + // from here on, the else branches are deprecated; + // they are helpers for smoother batch API transition + if (batch->pos) { + for (size_t i = 0; i < length; ++i) { + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + llama_pos bi = ids[seq.offset + i]; + ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); + } + } + if (batch->n_seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_tokens + i] = batch->n_seq_id[ids[seq.offset + i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_tokens + i] = 1; + } + } + if (batch->seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id[ubatch.n_tokens + i] = batch->seq_id[ids[seq.offset + i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id[ubatch.n_tokens + i] = &seq.all_seq_id; + } + } + if (batch->logits) { + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_output = batch->logits[id]; + ubatch.output[ubatch.n_tokens + i] = is_output; + if (is_output) { out_ids.push_back(id); } + } + } else if (logits_all) { + for (size_t i = 0; i < length; ++i) { + ubatch.output[ubatch.n_tokens + i] = 1; + out_ids.push_back(ids[seq.offset + i]); + } + } else { + // only get last output + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_last = id == ids.size() - 1; + ubatch.output[ubatch.n_tokens + i] = is_last; + if (is_last) { out_ids.push_back(id); } + } + } + ubatch.n_tokens += length; + ubatch.n_seqs += seq.n_seq_id != 0; // don't count seq_ids for legacy splits + seq.offset += length; + seq.length -= length; + n_tokens -= length; + } + + // legacy split, unknown number of sequences of unequal lengths + llama_ubatch split_slice(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + ubatch.equal_seqs = false; + if (!seq.empty()) { + llama_sbatch_seq & s = seq[0]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits + // TODO: reduce copies + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; + } + + // make batches of equal-length sequences + llama_ubatch split_equal(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + size_t length = 0; + size_t n_tokens_in_ubatch = 0; + GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with legacy splits + // smallest first, because it's easier to split this way; + // starting from the end to pop in constant time. + for (size_t i = seq.size(); i-- > 0;) { + llama_sbatch_seq & s = seq[i]; + GGML_ASSERT(s.length > 0); + if (length == 0) { + length = s.length < n_ubatch ? s.length : n_ubatch; + } + add_seq_to_ubatch(ubatch, s, length); + n_tokens_in_ubatch += length; + // shared prompts can't be mixed with any of their sequences, + // so it's safer to compute them in their own ubatch + if (s.n_seq_id > 1) { break; } + // stop when there isn't enough space for another sequence + if (length + n_tokens_in_ubatch > n_ubatch) { break; } + } + } + return ubatch; + } + + // sequence-wise split + llama_ubatch split_seq(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + llama_sbatch_seq & s = seq[seq.size() - 1]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with legacy splits + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; + } + + void from_batch(const llama_batch & batch, const size_t n_embd, const bool legacy_split = false, const bool logits_all = false) { + GGML_ASSERT(batch.n_tokens >= 0); + this->batch = &batch; + this->n_embd = n_embd; + this->logits_all = logits_all; + + n_tokens = batch.n_tokens; + ids.resize(n_tokens); + out_ids.clear(); + // TODO: reserve out_ids and seq + + for (size_t i = 0; i < n_tokens; ++i) { + ids[i] = i; + } + if (legacy_split) { + seq.resize(1); + llama_sbatch_seq & s = seq[0]; + s.n_seq_id = 0; + s.seq_id = nullptr; + s.offset = 0; + s.length = n_tokens; + s.all_seq_id = batch.all_seq_id; + return; + } + std::sort(ids.begin(), ids.end(), + [batch](size_t a, size_t b) { + int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; + int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1; + // sort by seq_id, then by pos + if (n_seq_a == n_seq_b) { + if (batch.seq_id) { + for (int32_t i = 0; i < n_seq_a; ++i) { + llama_seq_id seq_id_a = batch.seq_id[a][i]; + llama_seq_id seq_id_b = batch.seq_id[b][i]; + // smaller seq_ids go first + if (seq_id_a != seq_id_b) { + return seq_id_a < seq_id_b; + } + } + } + // when all else is equal, sort by pos + if (batch.pos) { + return batch.pos[a] < batch.pos[b]; + } + // no pos, sort by id (assuming batch.all_pos_1 is positive) + return a < b; + } + // shared prompts go first + return n_seq_a > n_seq_b; + } + ); + // init seq + llama_sbatch_seq * last_seq = nullptr; + + if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) { + for (size_t i = 0; i < n_tokens; ++i) { + const size_t bi = ids[i]; + const size_t s_len = seq.size(); + const int32_t n_seqs = batch.n_seq_id[bi]; + llama_seq_id * seq_ids = batch.seq_id[bi]; + if (last_seq != nullptr) { + bool same = n_seqs == last_seq->n_seq_id; + for (int32_t j = 0; same && j < n_seqs; ++j) { + if (seq_ids[j] != last_seq->seq_id[j]) { + same = false; + } + } + if (same) { + last_seq->length += 1; + continue; + } + } + llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id}; + seq.push_back(new_seq); + last_seq = &seq[s_len]; + } + } else { + llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id}; + seq.push_back(new_seq); + } + // keep shared prompts first at the end, then sort by length descending. + std::sort(seq.begin(), seq.end(), + [](llama_sbatch_seq & a, llama_sbatch_seq & b) { + if (a.n_seq_id == b.n_seq_id) { + return a.length > b.length; + } + return a.n_seq_id < b.n_seq_id; + } + ); + } +}; + struct llama_context { llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} ~llama_context() { @@ -2832,6 +3147,9 @@ struct llama_context { // key + value cache for self-attention, and/or recurrent state cache struct llama_cache cache; + // sequence-length-aware batch splitting + llama_sbatch sbatch; + std::mt19937 rng; bool has_evaluated_once = false; @@ -3126,7 +3444,7 @@ static bool llama_cache_init( // to the first cell of the slot. static bool llama_cache_find_slot( struct llama_cache & cache, - const struct llama_batch & batch) { + const struct llama_ubatch & batch) { const uint32_t kv_size = cache.kv.size; const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; @@ -7533,7 +7851,7 @@ static struct ggml_tensor * llm_build_inp_embd( struct ggml_context * ctx, struct llama_context & lctx, const llama_hparams & hparams, - const llama_batch & batch, + const llama_ubatch & batch, struct ggml_tensor * tok_embd, const llm_build_cb & cb) { const int64_t n_embd = hparams.n_embd; @@ -8107,7 +8425,7 @@ struct llm_build_context { llama_context & lctx; const llama_hparams & hparams; const llama_cparams & cparams; - const llama_batch & batch; + const llama_ubatch & batch; const llama_kv_cache & kv_self; const llama_rs_cache & rs_self; @@ -8153,7 +8471,7 @@ struct llm_build_context { // TODO: consider making the entire interface noexcept llm_build_context( llama_context & lctx, - const llama_batch & batch, + const llama_ubatch & batch, const llm_build_cb & cb, bool worst_case) : model (lctx.model), @@ -12215,8 +12533,8 @@ struct llm_build_context { }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { - llama_batch dummy; - dummy.n_tokens = 0; + llama_ubatch dummy = {}; + dummy.equal_seqs = true; llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; @@ -12232,8 +12550,8 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const } static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { - llama_batch dummy; - dummy.n_tokens = 0; + llama_ubatch dummy = {}; + dummy.equal_seqs = true; llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; @@ -12250,7 +12568,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { static struct ggml_cgraph * llama_build_graph( llama_context & lctx, - const llama_batch & batch, + const llama_ubatch & batch, bool worst_case) { const auto & model = lctx.model; @@ -12438,7 +12756,7 @@ static void llama_set_k_shift(llama_context & lctx) { } } -static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { +static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { // // set input data // @@ -12478,10 +12796,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_tokens; ++i) { data[i] = i; } - } else if (batch.logits) { + } else if (batch.output) { int32_t n_outputs = 0; for (int i = 0; i < n_tokens; ++i) { - if (batch.logits[i]) { + if (batch.output[i]) { data[n_outputs++] = i; } } @@ -12835,11 +13153,6 @@ static int llama_decode_internal( const auto n_ubatch = cparams.n_ubatch; - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_arr; - std::vector> seq_id; - // count outputs if (batch_all.logits) { for (uint32_t i = 0; i < n_tokens_all; ++i) { @@ -12852,55 +13165,29 @@ static int llama_decode_internal( n_outputs = 1; } + lctx.sbatch.from_batch(batch_all, n_embd, /* legacy_split */ rs_self.size == 0, lctx.logits_all); + // reserve output buffer if (llama_output_reserve(lctx, n_outputs) < n_outputs) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs); return -2; }; - // set output mappings - if (batch_all.logits) { - int32_t i_logits = 0; - for (uint32_t i = 0; i < n_tokens_all; ++i) { - if (batch_all.logits[i]) { - lctx.output_ids[i] = i_logits++; - } - } - } else { - for (uint32_t i = 0; i < n_outputs; ++i) { - lctx.output_ids[i] = i; - } - } - - for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) { - const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); - llama_batch u_batch = { - /* .n_tokens = */ (int32_t) n_tokens, - /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr, - /* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr, - /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr, - /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr, - /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr, - /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr, - /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1, - /* .all_pos_1 = */ batch_all.all_pos_1, - /* .all_seq_id = */ batch_all.all_seq_id, - }; + while (lctx.sbatch.n_tokens > 0) { + // TODO: deprecate slice splits in favor of equal splits + llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch); + const uint32_t n_tokens = u_batch.n_tokens; // count the outputs in this u_batch { int32_t n_outputs_new = 0; - if (u_batch.logits) { - for (uint32_t i = 0; i < n_tokens; i++) { - n_outputs_new += u_batch.logits[i] != 0; - } - } else if (n_outputs == n_tokens_all) { + if (n_outputs == n_tokens_all) { n_outputs_new = n_tokens; } else { - // keep last output only - if (cur_token + n_tokens >= n_tokens_all) { - n_outputs_new = 1; + GGML_ASSERT(u_batch.output); + for (uint32_t i = 0; i < n_tokens; i++) { + n_outputs_new += u_batch.output[i] != 0; } } @@ -12911,32 +13198,6 @@ static int llama_decode_internal( int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; GGML_ASSERT(n_threads > 0); - // helpers for smoother batch API transition - // after deprecating the llama_eval calls, these will be removed - if (u_batch.pos == nullptr) { - pos.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1; - } - - u_batch.pos = pos.data(); - } - - if (u_batch.seq_id == nullptr) { - n_seq_id.resize(n_tokens); - seq_id.resize(n_tokens); - seq_id_arr.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - n_seq_id[i] = 1; - seq_id[i].resize(1); - seq_id[i][0] = u_batch.all_seq_id; - seq_id_arr[i] = seq_id[i].data(); - } - - u_batch.n_seq_id = n_seq_id.data(); - u_batch.seq_id = seq_id_arr.data(); - } - // non-causal masks do not use the KV cache if (hparams.causal_attn) { llama_kv_cache_update(&lctx); @@ -12945,6 +13206,7 @@ static int llama_decode_internal( return 1; } + // TODO: move into llama_cache_find_slot if (kv_self.size > 0) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears @@ -13108,6 +13370,12 @@ static int llama_decode_internal( #endif } + // set output mappings + GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs); + for (size_t i = 0; i < n_outputs; ++i) { + lctx.output_ids[lctx.sbatch.out_ids[i]] = i; + } + // set to total number of outputs in the batch, for use in llama_get_logits_ith lctx.n_outputs = n_outputs; @@ -13398,10 +13666,11 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { if (need_reserve) { // TODO: extract to a function // build worst-case graph + int n_seqs = 1; // TODO: worst-case number of sequences int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); - int n_past = lctx.cparams.n_ctx - n_tokens; llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(lctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true); + llama_ubatch ubatch = { true, n_tokens, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); // initialize scheduler with the worst-case graph ggml_backend_sched_reset(lctx.sched); @@ -17345,10 +17614,11 @@ struct llama_context * llama_new_context_with_model( } // build worst-case graph + int n_seqs = 1; // TODO: worst-case number of sequences int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch); - int n_past = cparams.n_ctx - n_tokens; llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true); + llama_ubatch ubatch = { true, n_tokens, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true); // initialize scheduler with the worst-case graph if (!ggml_backend_sched_reserve(ctx->sched, gf)) { @@ -18662,8 +18932,9 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, // Allocate the new cells for the slot if (cell_count) { - llama_batch batch = llama_batch_init(cell_count, 0, 1); + llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); batch.n_tokens = cell_count; + batch.n_seqs = 1; for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; memcpy(&pos, inp, sizeof(pos)); @@ -18674,7 +18945,6 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, batch.seq_id[i][0] = dest_seq_id; } if (!llama_cache_find_slot(cache, batch)) { - llama_batch_free(batch); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return 0; } @@ -18686,9 +18956,6 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); - - // Cleanup - llama_batch_free(batch); } const uint32_t kv_size = kv_self.size; From 3587a9498773203f10f66814f67568797f1ce7a0 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 1 Jun 2024 11:37:14 -0400 Subject: [PATCH 014/117] llama : use equal-sequence-length sub-batches for recurrent models * ggml : simplify SSM-related operators * llama : make recurrent state slot allocation contiguous * llama : adapt internal uses of batches to llama_ubatch --- ggml.c | 250 ++++++--------- ggml.h | 6 +- llama.cpp | 946 ++++++++++++++++++++++++++++++++++-------------------- 3 files changed, 699 insertions(+), 503 deletions(-) diff --git a/ggml.c b/ggml.c index 58ac9702694c7..7a3a5fa9468ff 100644 --- a/ggml.c +++ b/ggml.c @@ -7103,40 +7103,35 @@ struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, struct ggml_tensor * s, struct ggml_tensor * x, - struct ggml_tensor * c, - struct ggml_tensor * sq) { + struct ggml_tensor * c) { GGML_ASSERT(ggml_is_3d(s)); - GGML_ASSERT(ggml_is_matrix(x)); + GGML_ASSERT(ggml_is_3d(x)); GGML_ASSERT(ggml_is_matrix(c)); - GGML_ASSERT(ggml_is_vector(sq)); - GGML_ASSERT(sq->type == GGML_TYPE_I32); - const int64_t d_conv = c->ne[0]; - const int64_t d_inner = c->ne[1]; - const int64_t n_tokens = x->ne[1]; - const int64_t n_rs = s->ne[2]; + const int64_t d_conv = c->ne[0]; + const int64_t d_inner = c->ne[1]; + const int64_t n_t = x->ne[1]; // tokens per sequence + const int64_t n_s = s->ne[2]; - GGML_ASSERT( s->ne[0] == d_conv - 1); - GGML_ASSERT( s->ne[1] == d_inner); - GGML_ASSERT( x->ne[0] == d_inner); - GGML_ASSERT(sq->ne[0] == n_tokens); + GGML_ASSERT(s->ne[0] == d_conv - 1); + GGML_ASSERT(s->ne[1] == d_inner); + GGML_ASSERT(x->ne[0] == d_inner); + GGML_ASSERT(x->ne[2] == n_s); bool is_node = false; - if (s->grad || x->grad || c->grad || sq->grad) { + if (s->grad || x->grad || c->grad) { GGML_ASSERT(false); // TODO: implement is_node = true; } - // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_rs} - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_rs)); + struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s); result->op = GGML_OP_SSM_CONV; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = s; result->src[1] = x; result->src[2] = c; - result->src[3] = sq; return result; } @@ -7150,40 +7145,43 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C, - struct ggml_tensor * sq) { + struct ggml_tensor * C) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(sq->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_matrix(A)); + GGML_ASSERT(ggml_is_3d(B)); + GGML_ASSERT(ggml_is_3d(s)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(ggml_are_same_shape(B, C)); { - const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_tokens = x->ne[1]; + const int64_t d_state = s->ne[0]; + const int64_t d_inner = s->ne[1]; + const int64_t n_seq_tokens = x->ne[1]; + const int64_t n_seqs = x->ne[2]; + GGML_ASSERT(s->ne[2] == n_seqs); GGML_ASSERT(x->ne[0] == d_inner); GGML_ASSERT(A->ne[0] == d_state); GGML_ASSERT(A->ne[1] == d_inner); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_tokens); - GGML_ASSERT(C->ne[0] == d_state); - GGML_ASSERT(C->ne[1] == n_tokens); + GGML_ASSERT(B->ne[1] == n_seq_tokens); + GGML_ASSERT(B->ne[2] == n_seqs); } bool is_node = false; - if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) { + if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) { GGML_ASSERT(false); // TODO: implement is_node = true; } - // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_rs} - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + // y + struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, x->ne[0], x->ne[1], x->ne[2]); result->op = GGML_OP_SSM_SCAN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -7193,7 +7191,6 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; - result->src[6] = sq; return result; } @@ -16249,24 +16246,20 @@ static void ggml_compute_forward_ssm_conv_f32( const struct ggml_tensor * src0 = dst->src[0]; // conv_state const struct ggml_tensor * src1 = dst->src[1]; // x const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight - const struct ggml_tensor * src3 = dst->src[3]; // state_seq const int ith = params->ith; const int nth = params->nth; - const int nc = src2->ne[0]; // d_conv - const int nr = src0->ne[1]; // d_inner - const int n_t = src1->ne[1]; // n_tokens - const int n_rs = src0->ne[2]; // max number of sequences in the batch + const int nc = src2->ne[0]; // d_conv + const int nr = src0->ne[1]; // d_inner + const int n_t = src1->ne[1]; // tokens per sequence + const int n_s = src0->ne[2]; // number of sequences in the batch - GGML_ASSERT((nr*n_t) + (nc*nr*n_rs) == ggml_nelements(dst)); + GGML_ASSERT(ggml_are_same_shape(src1, dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); - GGML_ASSERT(src3->nb[0] == sizeof(int32_t)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // for use with the destination state offset between sequences - GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -16276,64 +16269,53 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - const int32_t * sq = src3->data; // {n_tokens} + // TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)? + // This would avoid having to copy into an intermediate buffer, but the state would be bigger. + float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith; - if (n_rs > 1) { - // multiple sequences means it's hard to know when it's the first time a state is read, - // so copy them all over to the destination, just to be sure. - for (int i3 = 0; i3 < n_rs; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float)); - // can't use memcpy because of d_conv vs d_conv - 1 - for (int i1 = 0; i1 < ir; ++i1) { - for (int i0 = 0; i0 < nc - 1; ++i0) { - // copy s0 to last (d_conv - 1) columns of s - s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; - } + for (int i3 = 0; i3 < n_s; ++i3) { + float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} + + // copy the state into working memory + // can't use memcpy because (d_conv) != (d_conv - 1) + for (int i1 = 0; i1 < ir; ++i1) { + for (int i0 = 0; i0 < nc - 1; ++i0) { + s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; } } - } - for (int i2 = 0; i2 < n_t; ++i2) { - int32_t sq_i = sq[i2]; - float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq_i*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_rs} - float * s0; // {d_conv - 1, d_inner, n_rs} - float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} - int ne0s0; + for (int i2 = 0; i2 < n_t; ++i2) { + float * x = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} + float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} - GGML_ASSERT(0 <= sq_i && sq_i < n_rs); + // shift state left + memmove(s, s + 1, (nc*ir - 1) * sizeof(float)); - // avoid needing to copy the state for the first token - if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_conv - 1, d_inner, n_rs} - ne0s0 = src0->ne[0]; - } else { - // the source is the last (d_conv - 1) columns of the destination - s0 = s + 1; - ne0s0 = nc; - } + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // insert x on the last column + s[(nc - 1) + i1*nc] = x0[i1]; + } - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // shift state left - for (int i0 = 0; i0 < nc - 1; ++i0) { - s[i0 + i1*nc] = s0[i0 + i1*ne0s0]; + // it seems a little faster when this is separate from the state shift + for (int i1 = 0; i1 < ir; ++i1) { + // rowwise dot product + // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision + float sumf = 0.0f; + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + sumf += s[i] * c[i]; + } + x[i1] = sumf; } - // insert x on the last column - s[(nc - 1) + i1*nc] = x0[i1]; } - // it seems a little faster when this is separate from the state shift + // copy the state out of it for (int i1 = 0; i1 < ir; ++i1) { - // rowwise dot product - float sumf = 0.0f; - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - sumf += s[i] * c[i]; + for (int i0 = 0; i0 < nc - 1; ++i0) { + s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc]; } - x[i1] = sumf; } } } @@ -16368,30 +16350,24 @@ static void ggml_compute_forward_ssm_scan_f32( const struct ggml_tensor * src3 = dst->src[3]; // A const struct ggml_tensor * src4 = dst->src[4]; // B const struct ggml_tensor * src5 = dst->src[5]; // C - const struct ggml_tensor * src6 = dst->src[6]; // sq const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens in the batch - const int64_t n_rs = src0->ne[2]; // max number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = src1->ne[1]; // number of tokens per sequence + const int64_t n_s = src0->ne[2]; // number of sequences in the batch - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); - // required for the dot product between s and C, and when copying the states + // required for the dot product between s and C GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[2]) - GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -16401,55 +16377,33 @@ static void ggml_compute_forward_ssm_scan_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - const int32_t * sq = src6->data; // {n_tokens} - - if (n_rs > 1) { - // it's hard to know if the source states have already been copied - // when there are multiple, so copy them already. - for (int i3 = 0; i3 < n_rs; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]); - memcpy(s, s0, nc*ir*sizeof(float)); - } - } - - for (int i2 = 0; i2 < n_t; ++i2) { - int32_t sq_i = sq[i2]; - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_rs} - float * s0; - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens} - - GGML_ASSERT(0 <= sq_i && sq_i < n_rs); - - // avoid needing to copy the state for the first token - if (i2 == 0) { - s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq_i*(src0->nb[2])); // {d_state, d_inner, n_rs} - } else { - // otherwise the source is the same as the destination - s0 = s; - } - - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + float * y = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} + float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + float * B = (float *) ((char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + float * C = (float *) ((char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; } - y[i1] = sumf; } } } @@ -19614,7 +19568,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_SSM_CONV: + { + const int64_t d_conv = node->src[2]->ne[0]; + const int64_t d_inner = node->src[0]->ne[1]; + cur += sizeof(float)*d_conv*(d_inner + n_tasks - 1); + } break; case GGML_OP_CROSS_ENTROPY_LOSS: { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); diff --git a/ggml.h b/ggml.h index 4e6bcb30fd931..bdf05a31139e5 100644 --- a/ggml.h +++ b/ggml.h @@ -1793,8 +1793,7 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * s, struct ggml_tensor * x, - struct ggml_tensor * c, - struct ggml_tensor * sq); + struct ggml_tensor * c); GGML_API struct ggml_tensor * ggml_ssm_scan( struct ggml_context * ctx, @@ -1803,8 +1802,7 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C, - struct ggml_tensor * sq); + struct ggml_tensor * C); // partition into non-overlapping windows with padding if needed // example: diff --git a/llama.cpp b/llama.cpp index 27374c18506c9..ca64b7e29df7a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2114,6 +2114,24 @@ struct llama_layer { struct ggml_tensor * rope_short = nullptr; }; +// very similar to llama_batch, +// but has more metadata about sequences +struct llama_ubatch { + bool equal_seqs; + + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_seq_tokens; // tokens per sequence + uint32_t n_seqs; + + llama_token * token; // [n_tokens] + float * embd; // [n_embd, n_tokens] + llama_pos * pos; // [n_tokens] + // FIXME: make all uses of this use n_seqs + int32_t * n_seq_id; // [n_seqs] + llama_seq_id ** seq_id; // [n_seqs] + int8_t * output; // [n_tokens] +}; + struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; @@ -2223,17 +2241,15 @@ struct llama_rs_cell { } }; - struct llama_rs_seq_meta { // cell id of the latest state of this seq_id int32_t tail = -1; // number of cells for which this seq_id is the first // (useful to know if cells in this sequence should be pruned) int32_t n_cells = 0; - // changing the tail cell of a sequence can only be done at batch boundary, - // this guards against changing the cell when it shouldn't be; - // should be cleared when done finding a slot - bool in_ubatch = false; + // the last pos of this sequence if it is in the current ubatch, + // only set and used when finding a slot. + llama_pos ubatch_end_pos = -1; }; // ring-buffered tree of cached recurrent state data @@ -2261,6 +2277,10 @@ struct llama_rs_cache { // find tail cells faster std::vector seq_tails; // map seq_ids to cell ids + // freeable cell ids, computed when finding a slot + // useful to find the smallest range to defrag + std::vector freeable; + // per layer // NOTE: the naming of r and s is arbitrary std::vector r_l; // rolling/shift states @@ -2399,8 +2419,8 @@ struct llama_rs_cache { if (seq_node->next_cell != next) { // TODO: relax the error when multiple cells have the same pos if (debug) { - LLAMA_LOG_ERROR("%s: invalid next cell for cells[%u] (%d instead of %d)\n", - __func__, cell_id, seq_node->next_cell, next); + LLAMA_LOG_ERROR("%s: invalid next cell for seq_id %d in cells[%u] (%d instead of %d)\n", + __func__, seq_id, cell_id, seq_node->next_cell, next); } seq_node->next_cell = next; was_valid = false; @@ -2414,15 +2434,6 @@ struct llama_rs_cache { } seq.n_cells = n_cells; } - // in_batch should only be true when in the process of finding a slot - if (seq.in_ubatch != false) { - if (debug) { - LLAMA_LOG_ERROR("%s: in_ubatch was true while it should have been false for seq_id %d\n", - __func__, seq_id); - } - seq.in_ubatch = false; - was_valid = false; - } } // tail_rc for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { @@ -2475,6 +2486,88 @@ struct llama_rs_cache { return was_valid; } + // each seq_id should have access to at least this many cells + // (to use when pruning (to avoid over-pruning)) + uint32_t min_cells_per_seq(const llama_ubatch & batch) const { + uint32_t seqs = n_seqs; + for (uint32_t i = 0; i < batch.n_seqs; ++i) { + llama_seq_id seq_id = batch.seq_id[i][0]; + const llama_rs_seq_meta & new_seq = seq_tails[seq_id]; + if (new_seq.tail < 0 || new_seq.n_cells == 0) { + seqs += 1; + } + } + return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1); + } + + void freeable_for_batch(const llama_ubatch & batch, llama_pos checkpoint_interval) { + GGML_ASSERT(batch.equal_seqs); + int32_t min_cells = min_cells_per_seq(batch); + + // TODO: minimize work required to find freeable cells + // currently, this finds freeable cells by excluding non-freeable cells, + // because some conditions are more easily expressed this way. + + freeable.assign(size, 1); + + for (llama_rs_seq_meta & seq : seq_tails) { + seq.ubatch_end_pos = -1; + } + + for (uint32_t i = 0; i < batch.n_seqs; ++i) { + int32_t n_seq_id = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_id; j++) { + llama_seq_id seq_id = batch.seq_id[i][j]; + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_tails.size()); + llama_rs_seq_meta & seq = seq_tails[seq_id]; + seq.ubatch_end_pos = batch.pos[i * batch.n_seq_tokens + batch.n_seq_tokens - 1]; + } + } + + for (llama_rs_seq_meta & seq : seq_tails) { + if (seq.tail >= 0 && freeable[seq.tail] != 0) { + llama_pos end_pos = seq.ubatch_end_pos; + // When is a tail cell not freeable? + if (end_pos < 0) { + // when any of its tails are not in the batch + freeable[seq.tail] = 0; + } else if (min_cells > 1) { + // TODO: fallback to this less often + llama_rs_cell & tail = cells[seq.tail]; + GGML_ASSERT(tail.pos < end_pos); + if (tail.prev < 0 || tail.pos + checkpoint_interval <= end_pos) { + // make a checkpoint before prompt processing + // TODO: should it always be done after instead? + freeable[seq.tail] = 0; + } else { + llama_rs_cell & prev = cells[tail.prev]; + if (prev.pos + checkpoint_interval <= end_pos) { + // make a checkpoint during text generation + freeable[seq.tail] = 0; + } + } + } + } + } + + for (uint32_t i = 0; i < size; ++i) { + llama_rs_cell & cell = cells[i]; + if (!cell.is_empty() && cell.tail_rc == 0) { + // TODO: reduce indirection here + llama_rs_seq_node & seq_node = cell.seq_nodes[0]; + llama_rs_seq_meta & seq = seq_tails[seq_node.seq_id]; + bool keep_tail = freeable[seq.tail] == 0; + // kept tails use an additional cell, so make them allow freeing a checkpoint + int32_t really_min_cells = keep_tail ? min_cells - 1 : min_cells; + // A checkpoint is kept if there's enough alloted space for this sequence + // or if it's the state right before the tail + if (seq.n_cells <= really_min_cells || (really_min_cells > 1 && seq_node.next_cell == seq.tail)) { + freeable[i] = 0; + } + } + } + } + // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. // Why an iterator? Because it allows using std::vector::erase. std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { @@ -2496,22 +2589,30 @@ struct llama_rs_cache { prev_node->next_cell = node.next_cell; if (node.is_tail()) { // move the tail back to the previous cell + prev_cell.tail_rc += 1; if (prev_cell.seq_nodes.size() > 1) { if (rs_cell.tail_rc == rs_cell.seq_nodes.size()) { - if (prev_cell.tail_rc == 0) { + if (prev_cell.tail_rc == 1) { n_shared_tail_cells += 1; } - // o oo oo - // |/ -> o/ - // | | - // e.g. when removing the leaf with a single tail - if (rs_cell.tail_rc == 1 && prev_cell.tail_rc != prev_cell.seq_nodes.size()) { - n_seqs -= 1; + if (rs_cell.tail_rc == 1) { + if (prev_cell.tail_rc != prev_cell.seq_nodes.size()) { + // o oo oo + // |/ -> o/ + // | | + // e.g. when removing the leaf of a split tree + n_seqs -= 1; + } else { + // o + // o -> oo + // | | + // e.g. when merging back with a previous tail + n_shared_tail_cells -= 1; + } } } } - prev_cell.tail_rc += 1; } } if ((uint32_t) node.seq_id < seq_tails.size()) { @@ -2534,6 +2635,7 @@ struct llama_rs_cache { // will fully become a tail cell if (rs_cell.tail_rc > 0) { n_seqs += 1; + n_shared_tail_cells -= 1; } } if (node_iter == rs_cell.seq_nodes.begin()) { @@ -2583,14 +2685,107 @@ struct llama_rs_cache { return false; } - bool insert_seq_tail_to_cell_id(uint32_t i_cell, const llama_seq_id & id) { + bool swap_cells(uint32_t i_src, uint32_t i_dst) { + if (i_src < size && i_dst < size && i_src != i_dst) { + llama_rs_cell & src = cells[i_src]; + llama_rs_cell & dst = cells[i_dst]; + + for (llama_rs_seq_node & seq_node : src.seq_nodes) { + if (seq_node.next_cell >= 0) { + llama_rs_cell & next = cells[seq_node.next_cell]; + next.prev = i_dst; + if ((uint32_t) seq_node.next_cell == i_dst) { + seq_node.next_cell = i_src; + } + } else { + // this is a tail + seq_tails[seq_node.seq_id].tail = i_dst; + } + } + for (llama_rs_seq_node & seq_node : dst.seq_nodes) { + if (seq_node.next_cell >= 0) { + llama_rs_cell & next = cells[seq_node.next_cell]; + next.prev = i_src; + if ((uint32_t) seq_node.next_cell == i_src) { + seq_node.next_cell = i_dst; + } + } else { + // this is a tail + seq_tails[seq_node.seq_id].tail = i_src; + } + } + + if (src.prev == dst.prev) { + // avoid swapping them twice + if (src.prev >= 0) { + llama_rs_cell & prev = cells[src.prev]; + for (llama_rs_seq_node & seq_node : prev.seq_nodes) { + if ((uint32_t) seq_node.next_cell == i_src) { + seq_node.next_cell = i_dst; + } else if ((uint32_t) seq_node.next_cell == i_dst) { + seq_node.next_cell = i_src; + } + } + } + } else { + if (src.prev >= 0) { + llama_rs_cell & prev = cells[src.prev]; + for (llama_rs_seq_node & seq_node : prev.seq_nodes) { + if ((uint32_t) seq_node.next_cell == i_src) { + seq_node.next_cell = i_dst; + } + } + } + if (dst.prev >= 0) { + llama_rs_cell & prev = cells[dst.prev]; + for (llama_rs_seq_node & seq_node : prev.seq_nodes) { + if ((uint32_t) seq_node.next_cell == i_dst) { + seq_node.next_cell = i_src; + } + } + } + } + + std::swap(src.pos, dst.pos); + std::swap(src.src, dst.src); + std::swap(src.prev, dst.prev); + std::swap(src.tail_rc, dst.tail_rc); + std::swap(src.seq_nodes, dst.seq_nodes); + + return true; + } + return false; + } + + bool insert_seq_tail_to_cell_id(uint32_t i_cell, llama_seq_id id, llama_pos end_pos = -1) { if (i_cell < size && (size_t) id < seq_tails.size()) { llama_rs_cell & rs_cell = cells[i_cell]; auto & seq = seq_tails[id]; int32_t prev = rs_cell.prev; + if (end_pos >= 0) { + if (end_pos <= rs_cell.pos) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", + __func__, end_pos, rs_cell.pos, id); + } + rs_cell.pos = end_pos; + } else { + // if no pos was specified, then the target cell should already have a valid one. + GGML_ASSERT(!rs_cell.is_empty()); + } if ((uint32_t) seq.tail == i_cell) { // the cell is already the tail of this seq_id - return false; + if (rs_cell.tail_rc != rs_cell.seq_nodes.size()) { + GGML_ASSERT(end_pos >= 0); // make sure this is the first re-added seq_id + // remove non-tail seq_ids (branch off them) + for (size_t i = rs_cell.seq_nodes.size(); i-- > 0;) { + if (!rs_cell.seq_nodes[i].is_tail()) { + remove_seq_node_from_cell(rs_cell, rs_cell.seq_nodes.begin() + i); + } + } + } + return true; } if (rs_cell.is_empty()) { prev = seq.tail; @@ -2603,9 +2798,7 @@ struct llama_rs_cache { auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id); GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken - if (rs_cell.pos < 0) { - GGML_ASSERT(rs_cell.is_empty()); - rs_cell.pos = prev_cell.pos + 1; + if (rs_cell.is_empty()) { rs_cell.src = prev_cell.src; } prev_node->next_cell = i_cell; @@ -2650,8 +2843,7 @@ struct llama_rs_cache { if (seq.tail < 0) { // from empty to unique n_seqs += 1; - // pos was not yet set - rs_cell.pos = 0; + // make sure it's cleared rs_cell.src = -1; } used += 1; @@ -2671,16 +2863,6 @@ struct llama_rs_cache { return false; } - // each seq_id should have access to at least this many cells - // (to use when pruning (to avoid over-pruning)) - size_t min_cells_per_seq(const llama_rs_seq_meta & new_seq) const { - uint32_t seqs = n_seqs; - if (new_seq.tail < 0 || new_seq.n_cells == 0) { - seqs += 1; - } - return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1); - } - size_t total_size() const { size_t size = 0; for (struct ggml_tensor * r : r_l) { @@ -2883,22 +3065,6 @@ struct llama_model { } }; -// very similar to llama_batch, -// but has more metadata about sequences -struct llama_ubatch { - bool equal_seqs; - - int32_t n_tokens; - int32_t n_seqs; - - llama_token * token; - float * embd; - llama_pos * pos; - int32_t * n_seq_id; - llama_seq_id ** seq_id; - int8_t * output; -}; - struct llama_sbatch_seq { int32_t n_seq_id; llama_seq_id * seq_id; @@ -2954,6 +3120,7 @@ struct llama_sbatch { true, 0, 0, + 0, !has_embd ? ubatch_token.data() : nullptr, has_embd ? ubatch_embd.data() : nullptr, ubatch_pos.data(), @@ -2967,16 +3134,14 @@ struct llama_sbatch { void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { GGML_ASSERT(batch != nullptr); GGML_ASSERT(length <= seq.length); - if (ubatch.equal_seqs) { - // is the new sequence of a different size than expected? - if (ubatch.n_seqs > 0 && length != (size_t) ubatch.n_tokens / ubatch.n_seqs) { - ubatch.equal_seqs = false; - } - } + // Can only add sequences of equal lengths to a batch, + // otherwise it isn't clear to which sequence a token belongs + GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs); + GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); // NOTE: loops are separated for cache-friendliness if (batch->token) { for (size_t i = 0; i < length; ++i) { - ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; } } else { ubatch.token = nullptr; @@ -3004,22 +3169,32 @@ struct llama_sbatch { ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); } } - if (batch->n_seq_id) { - for (size_t i = 0; i < length; ++i) { - ubatch.n_seq_id[ubatch.n_tokens + i] = batch->n_seq_id[ids[seq.offset + i]]; + if (seq.n_seq_id > 0) { + ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; + if (seq.seq_id) { + ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; + } else { + GGML_ASSERT(seq.n_seq_id == 1); + ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id; } } else { - for (size_t i = 0; i < length; ++i) { - ubatch.n_seq_id[ubatch.n_tokens + i] = 1; - } - } - if (batch->seq_id) { - for (size_t i = 0; i < length; ++i) { - ubatch.seq_id[ubatch.n_tokens + i] = batch->seq_id[ids[seq.offset + i]]; + if (batch->n_seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_seqs + i] = batch->n_seq_id[ids[seq.offset + i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_seqs + i] = 1; + } } - } else { - for (size_t i = 0; i < length; ++i) { - ubatch.seq_id[ubatch.n_tokens + i] = &seq.all_seq_id; + if (batch->seq_id) { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id[ubatch.n_seqs + i] = batch->seq_id[ids[seq.offset + i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id; + } } } if (batch->logits) { @@ -3043,11 +3218,15 @@ struct llama_sbatch { if (is_last) { out_ids.push_back(id); } } } + if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { + ubatch.n_seq_tokens = seq.n_seq_id > 0 ? length : 1; + } ubatch.n_tokens += length; - ubatch.n_seqs += seq.n_seq_id != 0; // don't count seq_ids for legacy splits + ubatch.n_seqs += seq.n_seq_id > 0 ? 1 : length; // virtual sequences for legacy splits seq.offset += length; seq.length -= length; n_tokens -= length; + GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); } // legacy split, unknown number of sequences of unequal lengths @@ -3283,7 +3462,6 @@ struct llama_context { struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [n_rs] struct ggml_tensor * inp_s_mask; // F32 [1, n_rs] - struct ggml_tensor * inp_s_seq; // I32 [n_batch] // control vectors struct llama_control_vector cvec; @@ -3426,6 +3604,7 @@ static bool llama_cache_init( cache.rs.cells.resize(rs_size); cache.rs.seq_tails.clear(); cache.rs.seq_tails.resize(rs_size); + cache.rs.freeable.reserve(rs_size); #ifdef GGML_USE_CLBLAST offload = false; @@ -3524,11 +3703,13 @@ static bool llama_cache_find_slot( const uint32_t kv_size = cache.kv.size; const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; + const uint32_t n_seqs = batch.n_seqs; + const uint32_t n_seq_tokens = batch.n_seq_tokens; // only check first, to allow failing gracefully if (rs_size > 0) { // everything should fit if all seq_ids are smaller than the max - for (uint32_t i = 0; i < n_tokens; ++i) { + for (uint32_t i = 0; i < n_seqs; ++i) { int32_t n_seq_id = batch.n_seq_id[i]; for (int32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; @@ -3541,6 +3722,23 @@ static bool llama_cache_find_slot( } } } + // TODO: configurable checkpoint interval + cache.rs.freeable_for_batch(batch, 8); + { + uint32_t freeable_rs_cell_count = 0; + for (uint32_t is_freeable : cache.rs.freeable) { + freeable_rs_cell_count += (uint32_t) (is_freeable != 0); + if (freeable_rs_cell_count >= n_seqs) { + // there's enough, no need to count them all + break; + } + } + if (n_seqs > freeable_rs_cell_count) { + // This should not happen + LLAMA_LOG_ERROR("%s: n_seqs=%d > freeable_rs_cell_count=%d\n", __func__, n_seqs, freeable_rs_cell_count); + return false; + } + } } if (kv_size > 0) { @@ -3591,172 +3789,146 @@ static bool llama_cache_find_slot( if (rs_size > 0) { // For recurrent state architectures (like Mamba), // each cache cell can store the state for a whole sequence. - // TODO: find a way to always make the rs slot contiguous + // A slot should be always be contiguous. - llama_seq_id min_seq = cache.rs.size - 1; - llama_seq_id max_seq = 0; - uint32_t min_cell = cache.rs.size - 1; - uint32_t max_cell = 0; + uint32_t min_head = 0; + uint32_t min_n = cache.rs.size; + uint32_t min_free = 0; - for (uint32_t i = 0; i < n_tokens; ++i) { - int32_t target_cell = -1; // ensure all the sequences of a token get the same cell - int32_t n_seq_ids = batch.n_seq_id[i]; - for (int32_t j = 0; j < n_seq_ids; ++j) { - llama_seq_id seq_id = batch.seq_id[i][j]; - bool need_new_cell = false; - // Everything should fit assuming the biggest seq_id < rs_size - GGML_ASSERT((uint32_t) seq_id < rs_size); - llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; - if (seq_id > max_seq) { max_seq = seq_id; } - if (seq_id < min_seq) { min_seq = seq_id; } - - if (!seq.in_ubatch && target_cell >= 0) { - // never saw this seq_id before, - // but there's already a cell reserved for this token, use it - cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); - } else if (seq.tail < 0) { - // this seq_id has no tail (and is empty) - need_new_cell = true; - } else { - llama_rs_cell & tail = cache.rs.cells[seq.tail]; - if (seq.in_ubatch) { - // this seq_id was already seen before in the batch - // assuming the tail cell already "has" this seq_id - tail.pos += 1; - target_cell = seq.tail; - } else { - // first time this sequence is seen, - // there's no reserved cell yet; - // if it's not the first sequence of the token, how could it even get here? - GGML_ASSERT(j == 0); - - bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; - if (has_same_seqs) { - // the tail cell of a seq_id is assumed to already be part of the seq_id, - // hence the skip of the first seq_id - for (int32_t k = 1; k < n_seq_ids; ++k) { - if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { - has_same_seqs = false; - } + // compact the freeable cell list + // e.g. 0,1,0,0,1,1,0,1,0,1 -> 1,4,5,7,9 + // while also finding the smallest cell range for the slot + { + uint32_t next_free = 0; + for (size_t i = 0; i < cache.rs.freeable.size(); ++i) { + if (cache.rs.freeable[i]) { + cache.rs.freeable[next_free] = i; + next_free += 1; + + if (next_free >= n_seqs) { + uint32_t head = cache.rs.freeable[next_free - n_seqs]; + // i is the last seen freeable cell id + uint32_t n = i - head + 1; + // keep the first smallest big enough slot + if (n < min_n) { + min_free = next_free - n_seqs; + min_head = head; + min_n = n; + if (n == n_seqs) { + // it's the smallest it can be + break; } } - - // TODO: make the checkpoint interval configurable - if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { - // a checkpoint should be saved - need_new_cell = true; - } else { - // re-use last tail - tail.pos += 1; - target_cell = seq.tail; - } } } + } + } - // reserve a cell for this seq_id - if (need_new_cell && target_cell < 0) { - const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); + // sanity check + GGML_ASSERT(min_head + min_n <= cache.rs.size); - uint32_t cell_id = cache.rs.size; - bool looped_once = false; + // keep only the necessary range + cache.rs.freeable.resize(min_free + n_seqs); + cache.rs.freeable.erase(cache.rs.freeable.begin(), cache.rs.freeable.begin() + min_free); + GGML_ASSERT(cache.rs.freeable.size() == n_seqs); + GGML_ASSERT(min_n >= n_seqs); + cache.rs.freeable.resize(min_n); - while (true) { - if (cache.rs.head >= cache.rs.size) { - cache.rs.head = 0; - // avoid infinite loop - // NOTE: this should not fail; if it does, it's a bug. - GGML_ASSERT(!looped_once && "recurrent state cache seems full, but should not."); - looped_once = true; - } - cell_id = cache.rs.head; - llama_rs_cell & candidate = cache.rs.cells[cell_id]; - if (candidate.is_empty()) { break; } - if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { - // the candidate is the old tail - if (candidate.seq_nodes.size() > 1) { - // prune out the other seq_ids, because they diverge - // TODO(maybe): hande this in insert_seq_tail_to_cell_id - // (hopefully doesn't happen too often) - for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { - if (node_iter->seq_id == seq_id) { - node_iter = std::next(node_iter); - } else { - node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); - } + // expand the free list + // e.g. 2,4,5,8 -> 1,0,1,1,0,0,1 + for (uint32_t i = n_seqs; i-- > 0;) { + uint32_t dst = cache.rs.freeable[i] - min_head; + if (dst != i) { + cache.rs.freeable[i] = 0; + } + GGML_ASSERT(dst >= i); + cache.rs.freeable[dst] = 1; + } + + // coalesce the free cells together + // e.g. 1,0,1,1,0,0,1 -> 1,1,1,1,0,0,0 + // or 1,0,1,1,1,1 -> 1,1,1,1,1,0 + { + uint32_t top_free = min_n - 1; + for (uint32_t i = min_n; i-- > 1;) { + uint32_t is_free = cache.rs.freeable[i]; + if (!is_free) { + GGML_ASSERT(top_free > i); + cache.rs.swap_cells(min_head + i, min_head + top_free); + std::swap(cache.rs.freeable[i], cache.rs.freeable[top_free]); + // the previous one has to be free, + // otherwise it would already have been swapped. + top_free -= 1; + } + // stop early if all freeable cells have already been put at the beginning + if (top_free < n_seqs) { break; } + } + } + + // order the re-used cells identically to their batch order + // (and clear the non-reused cells) + { + for (uint32_t i = 0; i < n_seqs; ++i) { + // ignore the already-swapped cells + if (cache.rs.freeable[i]) { + llama_rs_cell & cell = cache.rs.cells[min_head + i]; + if (!cell.is_empty()) { + if (cell.tail_rc == 0) { + cache.rs.clear_cell(cell); + } else { + // TODO: does this always work correctly + // even if there are more than one seq_node in this cell? + + // Which seq_id of the batch is it? + llama_seq_id seq_id = cell.seq_nodes[0].seq_id; + int32_t nth_seq_id = -1; + for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) { + if (seq_id == batch.seq_id[s][0]) { + nth_seq_id = s; + break; } } - // re-use the tail cell to avoid not finding anything - candidate.pos += 1; - break; - } - if (candidate.tail_rc > 0) { - // skip tails of other sequences - cache.rs.head += 1; - continue; - } - if (candidate.seq_nodes.size() > 1) { - // shared prompts are not usually backtracked, so they can be pruned - cache.rs.clear_cell(candidate); - break; - } + GGML_ASSERT(nth_seq_id != -1); - // prune too-long sequences - llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; - if (seq_id_to_prune == seq_id) { - // TODO: selectively skip some cells to keep older states - cache.rs.clear_cell(candidate); - break; - } - GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); - auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; - if (seq_to_prune.n_cells > min_cells_per_seq) { - cache.rs.clear_cell(candidate); - break; + cache.rs.swap_cells(min_head + i, min_head + nth_seq_id); + cache.rs.freeable[i] = 0; + std::swap(cache.rs.freeable[i], cache.rs.freeable[nth_seq_id]); + i -= 1; // check this cell again, now that it was swapped } - cache.rs.head += 1; } - if (cell_id < cache.rs.size) { - cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); - target_cell = cell_id; - } - } - - if (seq.tail >= 0) { - if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } - if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } - seq.in_ubatch = true; - } - - // Assuming the tokens are in-order - if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); } } - cache.rs.head = target_cell + 1; } - for (llama_seq_id i = min_seq; i <= max_seq; ++i) { - // make sure it's cleared for next time - cache.rs.seq_tails[i].in_ubatch = false; + // reserve + { + for (uint32_t i = 0; i < n_seqs; ++i) { + uint32_t i_cell = min_head + i; + int32_t n_seq_id = batch.n_seq_id[i]; + llama_pos end_pos = batch.pos[(i * n_seq_tokens) + n_seq_tokens - 1]; + // set the pos with the first seq_id + cache.rs.insert_seq_tail_to_cell_id(i_cell, batch.seq_id[i][0], end_pos); + // insert the rest of the seq_ids by re-using the cell's pos + for (int j = 1; j < n_seq_id; ++j) { + cache.rs.insert_seq_tail_to_cell_id(i_cell, batch.seq_id[i][j]); + } + } } // allow getting the range of used cells, from head to head + n - cache.rs.head = min_cell; - cache.rs.n = max_cell - min_cell + 1; - - // sanity check - GGML_ASSERT(min_seq <= max_seq && min_cell <= max_cell); + cache.rs.head = min_head; + cache.rs.n = min_n; } if (kv_size > 0) { - for (uint32_t i = 0; i < n_tokens; i++) { - cache.kv.cells[cache.kv.head + i].pos = batch.pos[i]; + for (uint32_t s = 0; s < n_seqs; s++) { + for (uint32_t i = 0; i < n_seq_tokens; ++i) { + uint32_t k = s*n_seq_tokens + i; + cache.kv.cells[cache.kv.head + k].pos = batch.pos[k]; - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - cache.kv.cells[cache.kv.head + i].seq_id.insert(batch.seq_id[i][j]); + for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { + cache.kv.cells[cache.kv.head + k].seq_id.insert(batch.seq_id[s][j]); + } } } @@ -8492,16 +8664,15 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_ubatch & batch, const llama_rs_cache & rs, struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, struct ggml_tensor * state_mask, - struct ggml_tensor * state_seq, struct ggml_tensor * w_dt_norm, struct ggml_tensor * w_b_norm, struct ggml_tensor * w_c_norm, - int32_t n_tokens, int32_t rs_head, int32_t n_rs, const llm_build_cb & cb, @@ -8510,14 +8681,23 @@ static struct ggml_tensor * llm_build_mamba( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_seqs = batch.n_seqs; + + const int64_t n_seq_tokens = batch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx, rs.r_l[il], hparams.n_embd_r(il), rs.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx, rs.s_l[il], hparams.n_embd_s(il), rs.size); + struct ggml_tensor * conv_states_all = rs.r_l[il]; + struct ggml_tensor * ssm_states_all = rs.s_l[il]; + + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx, conv_states_all, hparams.n_embd_r(il), rs.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx, ssm_states_all, hparams.n_embd_s(il), rs.size); // copy states { - // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows - // NOTE: assuming the copy destinations are ALL contained in the current batch + // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs // this shrinks the tensors's ne[1] to n_rs conv_states = ggml_get_rows(ctx, conv_states, state_copy); ssm_states = ggml_get_rows(ctx, ssm_states, state_copy); @@ -8532,17 +8712,24 @@ static struct ggml_tensor * llm_build_mamba( conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs); ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs); - // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} + struct ggml_tensor * conv = ggml_view_3d(ctx, conv_states, d_conv - 1, d_inner, n_seqs, conv_states->nb[1], conv_states->nb[2], 0); + struct ggml_tensor * ssm = ggml_view_3d(ctx, ssm_states, d_state, d_inner, n_seqs, ssm_states->nb[1], ssm_states->nb[2], 0); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} struct ggml_tensor * xz = ggml_mul_mat(ctx, model.layers[il].ssm_in, cur); // split the above in two - // => {d_inner, n_tokens} - struct ggml_tensor * x = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], 0); - struct ggml_tensor * z = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner); + // => {d_inner, n_seq_tokens, n_seqs} + struct ggml_tensor * x = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); + struct ggml_tensor * z = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz)); // conv { - // Custom operator which is needed only to ease simultaneous sequence processing. - // For a single sequence, the equivalent is to concatenate the columns of conv_states and x, + // Custom operator, which is needed because self-overlapping views aren't yet well supported by ggml. + // And also because this uses much less memory for large batches (4 times less when d_conv is 4). + // The equivalent is to concatenate the columns of conv_states and x, // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, // then element-wise multiply that with the conv1d weigth, // then sum the elements of each row, @@ -8551,17 +8738,17 @@ static struct ggml_tensor * llm_build_mamba( // and then you're left with the resulting x tensor. // The new conv_states is the last (d_conv - 1) columns // of the last 3rd dimensional "layer" of the self-overlapping view. - // For simultaneous sequences, it's more complicated. - struct ggml_tensor * x_conv = ggml_ssm_conv(ctx, conv_states, x, model.layers[il].ssm_conv1d, state_seq); + // For simultaneous sequences, all sequences need to have the same length. + x = ggml_ssm_conv(ctx, conv, x, model.layers[il].ssm_conv1d); - // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache - ggml_build_forward_expand(graph, - ggml_cpy(ctx, - ggml_view_2d(ctx, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), - ggml_view_1d(ctx, rs.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); + // ensure conv is updated before copying into the recurrent state cache + ggml_build_forward_expand(graph, x); - // extract x from x_conv - x = ggml_view_2d(ctx, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); + ggml_build_forward_expand(graph, + ggml_cpy(ctx, conv_states, + ggml_view_1d(ctx, conv_states_all, + (d_conv - 1)*(d_inner)*(n_rs), + rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); // bias x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); @@ -8571,45 +8758,47 @@ static struct ggml_tensor * llm_build_mamba( // ssm { - // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} + // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} struct ggml_tensor * x_db = ggml_mul_mat(ctx, model.layers[il].ssm_x, x); // split - struct ggml_tensor * dt = ggml_view_2d(ctx, x_db, dt_rank, n_tokens, x_db->nb[1], 0); - struct ggml_tensor * B = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_2d(ctx, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); + struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); + struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); if (w_dt_norm) { dt = llm_build_norm(ctx, dt, hparams, w_dt_norm, NULL, LLM_NORM_RMS, cb, il); } if (w_b_norm) { B = llm_build_norm(ctx, B, hparams, w_b_norm, NULL, LLM_NORM_RMS, cb, il); } if (w_c_norm) { C = llm_build_norm(ctx, C, hparams, w_b_norm, NULL, LLM_NORM_RMS, cb, il); } - // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} + // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} dt = ggml_mul_mat(ctx, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. - // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined, - // because only a single tensor can be returned. - struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); + // => {d_inner, n_seq_tokens, n_seqs} + struct ggml_tensor * y = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); - // store last states (the second part of y_ssm_states) - ggml_build_forward_expand(graph, - ggml_cpy(ctx, - ggml_view_1d(ctx, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx, rs.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); + // The ssm scan also changes the state, ensure it's done before copying to the recurrent state cache + ggml_build_forward_expand(graph, y); - struct ggml_tensor * y = ggml_view_2d(ctx, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); + // store last states + ggml_build_forward_expand(graph, + ggml_cpy(ctx, ssm_states, + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - // TODO: skip computing output for unused tokens + // TODO: skip computing output earlier for unused tokens - // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, z)); - // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_mul_mat(ctx, model.layers[il].ssm_out, y); } + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs); + return cur; } @@ -8642,6 +8831,8 @@ struct llm_build_context { const float norm_eps; const float norm_rms_eps; + const int32_t n_seqs; + const int32_t n_seq_tokens; const int32_t n_tokens; const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) const int32_t n_rs; @@ -8692,6 +8883,8 @@ struct llm_build_context { beta_slow (cparams.yarn_beta_slow), norm_eps (hparams.f_norm_eps), norm_rms_eps (hparams.f_norm_rms_eps), + n_seqs (batch.n_seqs), + n_seq_tokens (batch.n_seq_tokens), n_tokens (batch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), n_rs (worst_case ? rs_self.size : rs_self.n), @@ -8726,7 +8919,6 @@ struct llm_build_context { lctx.inp_cls = nullptr; lctx.inp_s_copy = nullptr; lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; } void free() { @@ -8898,13 +9090,6 @@ struct llm_build_context { return lctx.inp_s_mask; } - struct ggml_tensor * build_inp_s_seq() { - lctx.inp_s_seq = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(lctx.inp_s_seq, "inp_s_seq", -1); - ggml_set_input(lctx.inp_s_seq); - return lctx.inp_s_seq; - } - struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -12017,7 +12202,6 @@ struct llm_build_context { struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); - struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { // norm @@ -12026,9 +12210,9 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, - state_copy, state_mask, state_seq, NULL, NULL, NULL, - n_tokens, rs_head, n_rs, cb, il); + cur = llm_build_mamba(ctx0, model, hparams, batch, rs_self, gf, cur, + state_copy, state_mask, NULL, NULL, NULL, + rs_head, n_rs, cb, il); if (il == n_layer - 1) { // skip computing output for unused tokens @@ -12074,7 +12258,6 @@ struct llm_build_context { struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); - struct ggml_tensor * state_seq = build_inp_s_seq(); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); @@ -12089,10 +12272,9 @@ struct llm_build_context { if (n_head_kv == 0) { // Mamba - cur = llm_build_mamba(ctx0, model, hparams, rs_self, gf, cur, - state_copy, state_mask, state_seq, + cur = llm_build_mamba(ctx0, model, hparams, batch, rs_self, gf, cur, state_copy, state_mask, model.layers[il].ssm_dt_norm, model.layers[il].ssm_b_norm, model.layers[il].ssm_c_norm, - n_tokens, rs_head, n_rs, cb, il); + rs_head, n_rs, cb, il); } else { // Attention @@ -12152,6 +12334,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, false, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); } @@ -13234,8 +13417,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { if (lctx.inp_KQ_mask) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. if (cparams.causal_attn) { - const int64_t n_kv = kv_self.n; - const int64_t n_tokens = batch.n_tokens; + const int64_t n_kv = kv_self.n; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); @@ -13245,22 +13430,25 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { // of the correct sequence for each token of the batch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j][0]; - - for (int i = 0; i < n_kv; ++i) { - float f; - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - f = -INFINITY; - } else { - if (hparams.use_alibi) { - f = -fabs(kv_self.cells[i].pos - pos); + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos pos = batch.pos[s*n_seq_tokens + j]; + + for (int i = 0; i < n_kv; ++i) { + float f; + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + f = -INFINITY; } else { - f = 0.0f; + if (hparams.use_alibi) { + f = -fabs(kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } } + data[h*(n_kv*n_seq_tokens*n_seqs) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; } - data[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } @@ -13271,8 +13459,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } } else { + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; // when using kv cache, the mask needs to match the kv cache size - const int64_t n_tokens = batch.n_tokens; const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); @@ -13280,27 +13470,35 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { float * data = (float *) lctx.inp_KQ_mask->data; for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_seq_id seq_id = batch.seq_id[j][0]; - - for (int i = 0; i < n_tokens; ++i) { - float f = -INFINITY; - for (int s = 0; s < batch.n_seq_id[i]; ++s) { - if (batch.seq_id[i][s] == seq_id) { - if (hparams.use_alibi) { - f = -fabs(batch.pos[i] - batch.pos[j]); - } else { - f = 0.0f; + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = batch.seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < batch.n_seq_id[s0]; ++s) { + if (batch.seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -fabs(batch.pos[ti] - batch.pos[tj]); + } else { + f = 0.0f; + } + break; + } } - break; + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; } } - data[h*(n_tokens*n_tokens) + j*n_stride + i] = f; - } - - for (int i = n_tokens; i < n_stride; ++i) { - data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY; + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } } } } @@ -13308,7 +13506,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(lctx.inp_mean); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); @@ -13317,12 +13517,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean)); std::vector sum(n_tokens, 0); - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when batch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); - sum[seq_id] += 1; + sum[seq_id] += batch.n_seq_tokens; } std::vector div(n_tokens, 0.0f); @@ -13333,14 +13535,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - data[seq_id*n_tokens + i] = div[seq_id]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + + for (int i = 0; i < n_seq_tokens; ++i) { + data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; + } } } if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { - const int64_t n_tokens = batch.n_tokens; + const int64_t n_tokens = batch.n_tokens; + const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seqs = batch.n_seqs; GGML_ASSERT(lctx.inp_cls); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); @@ -13348,14 +13555,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { uint32_t * data = (uint32_t *) lctx.inp_cls->data; memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - const llama_pos pos = batch.pos[i]; + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + // TODO: adapt limits to n_seqs when batch.equal_seqs is true GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS"); - if (pos == 0) { - data[seq_id] = i; + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = batch.pos[s*n_seq_tokens + i]; + + if (pos == 0) { + data[seq_id] = s*n_seq_tokens + i; + } } } } @@ -13372,7 +13583,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { uint32_t cell_id = i + rs_self.head; llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; - data[i] = (float) rs_cell.src >= 0; + data[i] = (float) (rs_cell.src >= 0); // only clear once if (rs_cell.src < 0) { @@ -13404,29 +13615,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } } - - // For Mamba (and other recurrent architectures), - // update the correct state(s)/sequence(s) for each token of the batch. - // Each row contains relative cell ids of the sequences for the associated token. - // Like with the KQ_mask, if a token in the batch has multiple sequences, - // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). - if (lctx.inp_s_seq) { - const int64_t n_tokens = batch.n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer)); - int32_t * data = (int32_t *) lctx.inp_s_seq->data; - - for (int i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = batch.seq_id[i][0]; - GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); - const auto & seq = rs_self.seq_tails[seq_id]; - // ensure the relative cell id will be positive but not too big - GGML_ASSERT((uint32_t) seq.tail >= rs_self.head); - GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); - - data[i] = seq.tail - rs_self.head; - } - } } } @@ -13598,7 +13786,7 @@ static int llama_decode_internal( } else { GGML_ASSERT(u_batch.output); for (uint32_t i = 0; i < n_tokens; i++) { - n_outputs_new += u_batch.output[i] != 0; + n_outputs_new += (int32_t) (u_batch.output[i] != 0); } } @@ -14077,10 +14265,10 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { if (need_reserve) { // TODO: extract to a function // build worst-case graph - int n_seqs = 1; // TODO: worst-case number of sequences - int n_tokens = (int)std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); // initialize scheduler with the worst-case graph @@ -18026,10 +18214,10 @@ struct llama_context * llama_new_context_with_model( } // build worst-case graph - int n_seqs = 1; // TODO: worst-case number of sequences - int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch); + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true); // initialize scheduler with the worst-case graph @@ -19347,6 +19535,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, if (cell_count) { llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; batch.n_seqs = 1; for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; @@ -19354,9 +19543,9 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(pos); batch.pos[i] = pos; - batch.n_seq_id[i] = 1; - batch.seq_id[i][0] = dest_seq_id; } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; if (!llama_cache_find_slot(cache, batch)) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return 0; @@ -19680,9 +19869,54 @@ void llama_synchronize(struct llama_context * ctx) { ctx->t_compute_start_us = 0; } +// make the outputs have the same order they had in the user-provided batch +static void llama_reorder_outputs(struct llama_context * ctx) { + std::vector & out_ids = ctx->sbatch.out_ids; + if (!out_ids.empty()) { + std::vector logits_tmp; + std::vector embd_tmp; + uint32_t n_vocab = ctx->model.hparams.n_vocab; + uint32_t n_embd = ctx->model.hparams.n_embd; + int32_t n_outputs = ctx->n_outputs; + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + // insertion sort (from https://en.wikipedia.org/wiki/Insertion_sort, but using memmove) + for (int32_t i = 1; i < n_outputs; ++i) { + int32_t j = i; + size_t out_id_tmp = out_ids[i]; + while (j > 0 && out_ids[j - 1] > out_id_tmp) { j -= 1; } + if (i - j == 0) { continue; } + memmove(out_ids.data() + j + 1, out_ids.data() + j, (i - j)*sizeof(out_ids[0])); + out_ids[j] = out_id_tmp; + if (ctx->logits_size > 0) { + // only allocate once something needs to be moved + if (logits_tmp.empty()) { logits_tmp.resize(n_vocab); } + memcpy(logits_tmp.data(), ctx->logits + i*n_vocab, n_vocab*sizeof(float)); + memmove(ctx->logits + (j + 1)*n_vocab, ctx->logits + j*n_vocab, (i - j)*n_vocab*sizeof(float)); + memcpy(ctx->logits + j*n_vocab, logits_tmp.data(), n_vocab*sizeof(float)); + } + if (ctx->embd_size > 0) { + // only allocate once something needs to be moved + if (embd_tmp.empty()) { embd_tmp.resize(n_embd); } + memcpy(embd_tmp.data(), ctx->embd + i*n_embd, n_embd*sizeof(float)); + memmove(ctx->embd + (j + 1)*n_embd, ctx->embd + j*n_embd, (i - j)*n_embd*sizeof(float)); + memcpy(ctx->embd + j*n_embd, embd_tmp.data(), n_embd*sizeof(float)); + } + } + std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + ctx->output_ids[out_ids[i]] = i; + } + out_ids.clear(); + } +} + float * llama_get_logits(struct llama_context * ctx) { llama_synchronize(ctx); + // reorder logits for backward compatibility + // TODO: maybe deprecate this + llama_reorder_outputs(ctx); + return ctx->logits; } @@ -19727,6 +19961,10 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { float * llama_get_embeddings(struct llama_context * ctx) { llama_synchronize(ctx); + // reorder embeddings for backward compatibility + // TODO: maybe deprecate this + llama_reorder_outputs(ctx); + return ctx->embd; } From 72eea49224e5b90263de08f8cddc6010353841eb Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 1 Jun 2024 12:24:19 -0400 Subject: [PATCH 015/117] llama : fix batch split output count for embeddings --- llama.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 6878bc8936046..7c6afa7d1fbe8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13730,7 +13730,9 @@ static int llama_decode_internal( n_outputs = 1; } - lctx.sbatch.from_batch(batch_all, n_embd, /* legacy_split */ rs_self.size == 0, lctx.logits_all); + lctx.sbatch.from_batch(batch_all, n_embd, + /* legacy_split */ rs_self.size == 0, + /* logits_all */ n_outputs == n_tokens_all); // reserve output buffer if (llama_output_reserve(lctx, n_outputs) < n_outputs) { @@ -13740,6 +13742,7 @@ static int llama_decode_internal( while (lctx.sbatch.n_tokens > 0) { // TODO: deprecate slice splits in favor of equal splits + // For now, only use equal splits for recurrent or hybrid model architectures llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch); const uint32_t n_tokens = u_batch.n_tokens; From 18d1c140471da9443db9e0b67f61ccf540e113c0 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 1 Jun 2024 15:01:34 -0400 Subject: [PATCH 016/117] llama : minimize swaps when reordering logits This reduces overhead when running hellaswag on thousands of sequences with very small 100k params Mamba models. --- llama.cpp | 50 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/llama.cpp b/llama.cpp index 7c6afa7d1fbe8..d44dfe7b20933 100644 --- a/llama.cpp +++ b/llama.cpp @@ -19828,33 +19828,43 @@ void llama_synchronize(struct llama_context * ctx) { static void llama_reorder_outputs(struct llama_context * ctx) { std::vector & out_ids = ctx->sbatch.out_ids; if (!out_ids.empty()) { - std::vector logits_tmp; - std::vector embd_tmp; uint32_t n_vocab = ctx->model.hparams.n_vocab; uint32_t n_embd = ctx->model.hparams.n_embd; int32_t n_outputs = ctx->n_outputs; GGML_ASSERT((size_t) n_outputs == out_ids.size()); - // insertion sort (from https://en.wikipedia.org/wiki/Insertion_sort, but using memmove) - for (int32_t i = 1; i < n_outputs; ++i) { - int32_t j = i; - size_t out_id_tmp = out_ids[i]; - while (j > 0 && out_ids[j - 1] > out_id_tmp) { j -= 1; } - if (i - j == 0) { continue; } - memmove(out_ids.data() + j + 1, out_ids.data() + j, (i - j)*sizeof(out_ids[0])); - out_ids[j] = out_id_tmp; + { + bool is_already_sorted = true; + for (int32_t i = 0; i < n_outputs - 1; ++i) { + if (out_ids[i] > out_ids[i + 1]) { + is_already_sorted = false; + break; + } + } + if (is_already_sorted) { + out_ids.clear(); + return; + } + } + // TODO: is there something more efficient which also minimizes swaps? + // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); if (ctx->logits_size > 0) { - // only allocate once something needs to be moved - if (logits_tmp.empty()) { logits_tmp.resize(n_vocab); } - memcpy(logits_tmp.data(), ctx->logits + i*n_vocab, n_vocab*sizeof(float)); - memmove(ctx->logits + (j + 1)*n_vocab, ctx->logits + j*n_vocab, (i - j)*n_vocab*sizeof(float)); - memcpy(ctx->logits + j*n_vocab, logits_tmp.data(), n_vocab*sizeof(float)); + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]); + } } if (ctx->embd_size > 0) { - // only allocate once something needs to be moved - if (embd_tmp.empty()) { embd_tmp.resize(n_embd); } - memcpy(embd_tmp.data(), ctx->embd + i*n_embd, n_embd*sizeof(float)); - memmove(ctx->embd + (j + 1)*n_embd, ctx->embd + j*n_embd, (i - j)*n_embd*sizeof(float)); - memcpy(ctx->embd + j*n_embd, embd_tmp.data(), n_embd*sizeof(float)); + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]); + } } } std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1); From 61200ef29fc0e76f264ada583b77e9228120779f Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 1 Jun 2024 16:41:22 -0400 Subject: [PATCH 017/117] llama : fix edge case finding batch seq_id of split recurrent cell This otherwise was a problem when running the HellaSwag benchmark with small batch sizes, making it crash. --- llama.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index d44dfe7b20933..62d66c2bc2831 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3879,11 +3879,17 @@ static bool llama_cache_find_slot( if (cell.tail_rc == 0) { cache.rs.clear_cell(cell); } else { - // TODO: does this always work correctly - // even if there are more than one seq_node in this cell? + // Find the seq_id of the first tail of this cell + llama_seq_id seq_id = -1; + for (llama_rs_seq_node & seq_node : cell.seq_nodes) { + if (seq_node.is_tail()) { + seq_id = seq_node.seq_id; + break; + } + } + GGML_ASSERT(seq_id != -1); // Which seq_id of the batch is it? - llama_seq_id seq_id = cell.seq_nodes[0].seq_id; int32_t nth_seq_id = -1; for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) { if (seq_id == batch.seq_id[s][0]) { From eb589d5e3664b784aef5326aa14dd21889eb1948 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 1 Jun 2024 23:05:13 -0400 Subject: [PATCH 018/117] llama : avoid copies for simple batch splits --- llama.cpp | 81 +++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 52 insertions(+), 29 deletions(-) diff --git a/llama.cpp b/llama.cpp index 62d66c2bc2831..ce96d7b5503d2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3143,19 +3143,29 @@ struct llama_sbatch { GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); // NOTE: loops are separated for cache-friendliness if (batch->token) { - for (size_t i = 0; i < length; ++i) { - ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.token = batch->token + seq.offset; } } else { ubatch.token = nullptr; } if (batch->embd) { - for (size_t i = 0; i < length; ++i) { - memcpy( - ubatch.embd + n_embd * (ubatch.n_tokens + i), - batch->embd + n_embd * ids[seq.offset + i], - n_embd * sizeof(float) - ); + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + memcpy( + ubatch.embd + n_embd * (ubatch.n_tokens + i), + batch->embd + n_embd * ids[seq.offset + i], + n_embd * sizeof(float) + ); + } + } else { + // simple split + ubatch.embd = batch->embd + seq.offset; } } else { ubatch.embd = nullptr; @@ -3163,8 +3173,13 @@ struct llama_sbatch { // from here on, the else branches are deprecated; // they are helpers for smoother batch API transition if (batch->pos) { - for (size_t i = 0; i < length; ++i) { - ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.pos = batch->pos + seq.offset; } } else { for (size_t i = 0; i < length; ++i) { @@ -3172,7 +3187,7 @@ struct llama_sbatch { ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); } } - if (seq.n_seq_id > 0) { + if (ubatch.equal_seqs) { ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; if (seq.seq_id) { ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; @@ -3181,9 +3196,10 @@ struct llama_sbatch { ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id; } } else { + // simple split if (batch->n_seq_id) { for (size_t i = 0; i < length; ++i) { - ubatch.n_seq_id[ubatch.n_seqs + i] = batch->n_seq_id[ids[seq.offset + i]]; + ubatch.n_seq_id = batch->n_seq_id + seq.offset; } } else { for (size_t i = 0; i < length; ++i) { @@ -3192,7 +3208,7 @@ struct llama_sbatch { } if (batch->seq_id) { for (size_t i = 0; i < length; ++i) { - ubatch.seq_id[ubatch.n_seqs + i] = batch->seq_id[ids[seq.offset + i]]; + ubatch.seq_id = batch->seq_id + seq.offset; } } else { for (size_t i = 0; i < length; ++i) { @@ -3201,11 +3217,19 @@ struct llama_sbatch { } } if (batch->logits) { - for (size_t i = 0; i < length; ++i) { - size_t id = ids[seq.offset + i]; - int8_t is_output = batch->logits[id]; - ubatch.output[ubatch.n_tokens + i] = is_output; - if (is_output) { out_ids.push_back(id); } + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_output = batch->logits[id]; + ubatch.output[ubatch.n_tokens + i] = is_output; + if (is_output) { out_ids.push_back(id); } + } + } else { + // simple split + ubatch.output = batch->logits + seq.offset; + for (size_t i = 0; i < length; ++i) { + if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } + } } } else if (logits_all) { for (size_t i = 0; i < length; ++i) { @@ -3222,18 +3246,18 @@ struct llama_sbatch { } } if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { - ubatch.n_seq_tokens = seq.n_seq_id > 0 ? length : 1; + ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1; } ubatch.n_tokens += length; - ubatch.n_seqs += seq.n_seq_id > 0 ? 1 : length; // virtual sequences for legacy splits + ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits seq.offset += length; seq.length -= length; n_tokens -= length; GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); } - // legacy split, unknown number of sequences of unequal lengths - llama_ubatch split_slice(size_t n_ubatch) { + // simple split, unknown number of sequences of unequal lengths + llama_ubatch split_simple(size_t n_ubatch) { n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); ubatch.equal_seqs = false; @@ -3241,7 +3265,6 @@ struct llama_sbatch { llama_sbatch_seq & s = seq[0]; size_t length = s.length < n_ubatch ? s.length : n_ubatch; GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits - // TODO: reduce copies add_seq_to_ubatch(ubatch, s, length); } return ubatch; @@ -3254,7 +3277,7 @@ struct llama_sbatch { if (!seq.empty()) { size_t length = 0; size_t n_tokens_in_ubatch = 0; - GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with legacy splits + GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits // smallest first, because it's easier to split this way; // starting from the end to pop in constant time. for (size_t i = seq.size(); i-- > 0;) { @@ -3282,13 +3305,13 @@ struct llama_sbatch { if (!seq.empty()) { llama_sbatch_seq & s = seq[seq.size() - 1]; size_t length = s.length < n_ubatch ? s.length : n_ubatch; - GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with legacy splits + GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits add_seq_to_ubatch(ubatch, s, length); } return ubatch; } - void from_batch(const llama_batch & batch, const size_t n_embd, const bool legacy_split = false, const bool logits_all = false) { + void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) { GGML_ASSERT(batch.n_tokens >= 0); this->batch = &batch; this->n_embd = n_embd; @@ -3302,7 +3325,7 @@ struct llama_sbatch { for (size_t i = 0; i < n_tokens; ++i) { ids[i] = i; } - if (legacy_split) { + if (simple_split) { seq.resize(1); llama_sbatch_seq & s = seq[0]; s.n_seq_id = 0; @@ -13737,7 +13760,7 @@ static int llama_decode_internal( } lctx.sbatch.from_batch(batch_all, n_embd, - /* legacy_split */ rs_self.size == 0, + /* simple_split */ rs_self.size == 0, /* logits_all */ n_outputs == n_tokens_all); // reserve output buffer @@ -13749,7 +13772,7 @@ static int llama_decode_internal( while (lctx.sbatch.n_tokens > 0) { // TODO: deprecate slice splits in favor of equal splits // For now, only use equal splits for recurrent or hybrid model architectures - llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch); + llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch); const uint32_t n_tokens = u_batch.n_tokens; // count the outputs in this u_batch From 8fb57ac0fbf21d09abd21f3c167ee2cec8bb7094 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 2 Jun 2024 22:49:24 -0400 Subject: [PATCH 019/117] llama : use im2col and mul_mat to perform convolution for Mamba This removes the need for ggml_ssm_conv!!! But performance seems slighly worse on my system, especially for prompt processing. Maybe ggml_mul_mat isn't optimized for small row sizes? More performance testing is necessary until GGML_OP_SSM_CONV is removed. * ggml : make ggml_ssm_scan not modify its source tensors * llama : fix shared recurrent tail cell count for small ubatch sizes Otherwise it was impossible to run the 'parallel' example with '-ub 1' with a Mamba or Jamba model. --- ggml.c | 121 +++++++++++++++++++++--------------------------------- ggml.h | 3 +- llama.cpp | 83 +++++++++++++++++++++++++------------ 3 files changed, 103 insertions(+), 104 deletions(-) diff --git a/ggml.c b/ggml.c index 426501015bbe5..253b3fa416e93 100644 --- a/ggml.c +++ b/ggml.c @@ -7124,26 +7124,24 @@ struct ggml_tensor * ggml_flash_attn_back( struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, - struct ggml_tensor * s, - struct ggml_tensor * x, + struct ggml_tensor * sx, struct ggml_tensor * c) { - GGML_ASSERT(ggml_is_3d(s)); - GGML_ASSERT(ggml_is_3d(x)); + GGML_ASSERT(ggml_is_3d(sx)); GGML_ASSERT(ggml_is_matrix(c)); const int64_t d_conv = c->ne[0]; const int64_t d_inner = c->ne[1]; - const int64_t n_t = x->ne[1]; // tokens per sequence - const int64_t n_s = s->ne[2]; + const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence + const int64_t n_s = sx->ne[2]; - GGML_ASSERT(s->ne[0] == d_conv - 1); - GGML_ASSERT(s->ne[1] == d_inner); - GGML_ASSERT(x->ne[0] == d_inner); - GGML_ASSERT(x->ne[2] == n_s); + // TODO: maybe support other strides than 1? + GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); + GGML_ASSERT(sx->ne[1] == d_inner); + GGML_ASSERT(n_t >= 0); bool is_node = false; - if (s->grad || x->grad || c->grad) { + if (sx->grad || c->grad) { GGML_ASSERT(false); // TODO: implement is_node = true; } @@ -7152,9 +7150,8 @@ struct ggml_tensor * ggml_ssm_conv( result->op = GGML_OP_SSM_CONV; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = s; - result->src[1] = x; - result->src[2] = c; + result->src[0] = sx; + result->src[1] = c; return result; } @@ -7203,8 +7200,8 @@ struct ggml_tensor * ggml_ssm_scan( is_node = true; } - // y - struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, x->ne[0], x->ne[1], x->ne[2]); + // concatenated y + ssm_states + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); result->op = GGML_OP_SSM_SCAN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -16252,22 +16249,21 @@ static void ggml_compute_forward_ssm_conv_f32( return; } - const struct ggml_tensor * src0 = dst->src[0]; // conv_state - const struct ggml_tensor * src1 = dst->src[1]; // x - const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight + const struct ggml_tensor * src0 = dst->src[0]; // conv_x + const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight const int ith = params->ith; const int nth = params->nth; - const int nc = src2->ne[0]; // d_conv + const int nc = src1->ne[0]; // d_conv + const int ncs = src0->ne[0]; // d_conv - 1 + n_t const int nr = src0->ne[1]; // d_inner - const int n_t = src1->ne[1]; // tokens per sequence - const int n_s = src0->ne[2]; // number of sequences in the batch + const int n_t = dst->ne[1]; // tokens per sequence + const int n_s = dst->ne[2]; // number of sequences in the batch - GGML_ASSERT(ggml_are_same_shape(src1, dst)); + GGML_ASSERT( dst->ne[0] == nr); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); - GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); // rows per thread @@ -16278,54 +16274,28 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - // TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)? - // This would avoid having to copy into an intermediate buffer, but the state would be bigger. - float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith; - for (int i3 = 0; i3 < n_s; ++i3) { - float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} - - // copy the state into working memory - // can't use memcpy because (d_conv) != (d_conv - 1) - for (int i1 = 0; i1 < ir; ++i1) { - for (int i0 = 0; i0 < nc - 1; ++i0) { - s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)]; - } - } - for (int i2 = 0; i2 < n_t; ++i2) { - float * x = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} - float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner} - - // shift state left - memmove(s, s + 1, (nc*ir - 1) * sizeof(float)); + // {d_conv - 1 + n_t, d_inner, n_seqs} + // sliding window + const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} + const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} + float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} + // TODO: transpose the output for smaller strides for big batches? // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // insert x on the last column - s[(nc - 1) + i1*nc] = x0[i1]; - } - - // it seems a little faster when this is separate from the state shift for (int i1 = 0; i1 < ir; ++i1) { // rowwise dot product // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision float sumf = 0.0f; + + // d_conv for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - sumf += s[i] * c[i]; + sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; } x[i1] = sumf; } } - - // copy the state out of it - for (int i1 = 0; i1 < ir; ++i1) { - for (int i0 = 0; i0 < nc - 1; ++i0) { - s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc]; - } - } } } @@ -16368,7 +16338,7 @@ static void ggml_compute_forward_ssm_scan_f32( const int64_t n_t = src1->ne[1]; // number of tokens per sequence const int64_t n_s = src0->ne[2]; // number of sequences in the batch - GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(dst)); + GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); @@ -16377,6 +16347,10 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src5->nb[0] == sizeof(float)); // required for the dot product between s and C GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + // required for per-sequence offsets for states + GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); + // required to get correct offset for state destination (i.e. src1->nb[3]) + GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -16388,13 +16362,17 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i3 = 0; i3 < n_s; ++i3) { for (int i2 = 0; i2 < n_t; ++i2) { - float * y = (float *) ((char *) dst->data + ir0*( dst->nb[0]) + i2*( dst->nb[1]) + i3*( dst->nb[2])); // {d_inner, n_t, n_s} - float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - float * C = (float *) ((char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } // d_inner for (int i1 = 0; i1 < ir; ++i1) { @@ -16406,7 +16384,7 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i0 = 0; i0 < nc; ++i0) { int i = i0 + i1*nc; // state = prev_state * dA + dB * x - float state = (s[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[i0]; s[i] = state; @@ -19577,13 +19555,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } } break; - case GGML_OP_SSM_CONV: - { - const int64_t d_conv = node->src[2]->ne[0]; - const int64_t d_inner = node->src[0]->ne[1]; - - cur += sizeof(float)*d_conv*(d_inner + n_tasks - 1); - } break; case GGML_OP_CROSS_ENTROPY_LOSS: { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); diff --git a/ggml.h b/ggml.h index 9df601e2cd826..c772febf0aafa 100644 --- a/ggml.h +++ b/ggml.h @@ -1803,8 +1803,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, - struct ggml_tensor * s, - struct ggml_tensor * x, + struct ggml_tensor * sx, struct ggml_tensor * c); GGML_API struct ggml_tensor * ggml_ssm_scan( diff --git a/llama.cpp b/llama.cpp index ce96d7b5503d2..ecdcf3a4e7096 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2827,11 +2827,13 @@ struct llama_rs_cache { n_shared_tail_cells += 1; n_seqs -= 1; } - } else if (rs_cell.is_empty()) { - // from shared to unique - n_seqs += 1; - if (prev_cell.tail_rc == 1) { - // it was the last tail of the previous cell + } else { + if (rs_cell.is_empty()) { + // from shared to unique + n_seqs += 1; + } + if (prev_cell.tail_rc == 1 && rs_cell.seq_nodes.size() == rs_cell.tail_rc) { + // from last shared to fully tail n_shared_tail_cells -= 1; } } @@ -8683,6 +8685,18 @@ static struct ggml_tensor * llm_build_mamba( conv_states = ggml_reshape_3d(ctx, conv_states, d_conv - 1, d_inner, n_rs); ssm_states = ggml_reshape_3d(ctx, ssm_states, d_state, d_inner, n_rs); + // copy states which won't be changed further (between n_seqs and n_rs) + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, conv_states, (d_conv - 1)*d_inner*(n_rs - n_seqs), n_seqs*(conv_states->nb[2])), + ggml_view_1d(ctx, conv_states_all, (d_conv - 1)*d_inner*(n_rs - n_seqs), (rs_head + n_seqs)*(d_conv - 1)*d_inner*ggml_element_size(conv_states_all)))); + + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, ssm_states, d_state*d_inner*(n_rs - n_seqs), n_seqs*(ssm_states->nb[2])), + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*(n_rs - n_seqs), (rs_head + n_seqs)*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + // the part of the states that will be used and modified struct ggml_tensor * conv = ggml_view_3d(ctx, conv_states, d_conv - 1, d_inner, n_seqs, conv_states->nb[1], conv_states->nb[2], 0); struct ggml_tensor * ssm = ggml_view_3d(ctx, ssm_states, d_state, d_inner, n_seqs, ssm_states->nb[1], ssm_states->nb[2], 0); @@ -8698,28 +8712,43 @@ static struct ggml_tensor * llm_build_mamba( // conv { - // Custom operator, which is needed because self-overlapping views aren't yet well supported by ggml. - // And also because this uses much less memory for large batches (4 times less when d_conv is 4). - // The equivalent is to concatenate the columns of conv_states and x, - // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension, - // then element-wise multiply that with the conv1d weigth, + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_cont(ctx, ggml_transpose(ctx, x)), 0); + + // copy last (d_conv - 1) columns back into the state cache + struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(graph, + ggml_cpy(ctx, last_conv, + ggml_view_1d(ctx, conv_states_all, + (d_conv - 1)*(d_inner)*(n_seqs), + rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, // then sum the elements of each row, // (the last two steps are a dot product over rows (also doable with mul_mat)) // then permute away the ne[0] dimension, // and then you're left with the resulting x tensor. - // The new conv_states is the last (d_conv - 1) columns - // of the last 3rd dimensional "layer" of the self-overlapping view. // For simultaneous sequences, all sequences need to have the same length. - x = ggml_ssm_conv(ctx, conv, x, model.layers[il].ssm_conv1d); - // ensure conv is updated before copying into the recurrent state cache - ggml_build_forward_expand(graph, x); + // For some reason, im2col expects a F16 kernel, but doesn't even read from it. + // TODO: make im2col accept F32 kernels to directly pass ssm_conv1d to it. + // => { d_conv * d_inner, n_seq_tokens, n_seqs} + x = ggml_im2col(ctx, + ggml_new_tensor_2d(ctx, GGML_TYPE_F16, d_conv, d_inner), + conv_x, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F32); - ggml_build_forward_expand(graph, - ggml_cpy(ctx, conv_states, - ggml_view_1d(ctx, conv_states_all, - (d_conv - 1)*(d_inner)*(n_rs), - rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + x = ggml_reshape_4d(ctx, x, d_conv, 1, d_inner, n_seq_tokens * n_seqs); + + // => {1, 1, d_inner, n_seq_tokens * n_seqs} + x = ggml_mul_mat(ctx, ggml_reshape_3d(ctx, model.layers[il].ssm_conv1d, d_conv, 1, d_inner), x); + x = ggml_reshape_3d(ctx, x, d_inner, n_seq_tokens, n_seqs); + + // Alternatively, this does the same as the above + // x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); // bias x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); @@ -8746,16 +8775,16 @@ static struct ggml_tensor * llm_build_mamba( // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. - // => {d_inner, n_seq_tokens, n_seqs} - struct ggml_tensor * y = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); - - // The ssm scan also changes the state, ensure it's done before copying to the recurrent state cache - ggml_build_forward_expand(graph, y); + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); // store last states ggml_build_forward_expand(graph, - ggml_cpy(ctx, ssm_states, - ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + ggml_cpy(ctx, + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); // TODO: skip computing output earlier for unused tokens From 17f6c1ef3bdb8332393ea8da008023134291b0c3 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 3 Jun 2024 00:41:15 -0400 Subject: [PATCH 020/117] llama : fix .base() compilation error on Windows --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index ecdcf3a4e7096..d4736473169a1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2574,7 +2574,7 @@ struct llama_rs_cache { std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); // The iterator needs to point inside the correct vector - GGML_ASSERT(node_iter.base() >= rs_cell.seq_nodes.data() && node_iter.base() < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size()); + GGML_ASSERT(&(*node_iter) >= rs_cell.seq_nodes.data() && &(*node_iter) < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size()); if (node_iter != rs_cell.seq_nodes.end()) { // update the tree llama_rs_seq_node node = *node_iter; From fee3c1d740c0e027c81e2f2f3fb48d619857175f Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 3 Jun 2024 13:49:56 -0400 Subject: [PATCH 021/117] llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL * ggml : allow GGML_OP_CONCAT to work on non-contiguous tensors The implementation already supported it, and this makes Mamba's conv step slightly faster. --- ggml.c | 5 ----- llama.cpp | 20 ++++++++++++++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/ggml.c b/ggml.c index 253b3fa416e93..1a37ff2f070be 100644 --- a/ggml.c +++ b/ggml.c @@ -10992,11 +10992,6 @@ static void ggml_compute_forward_concat_f32( GGML_TENSOR_BINARY_OP_LOCALS - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - const int32_t dim = ggml_get_op_params_i32(dst, 0); GGML_ASSERT(dim >= 0 && dim < 4); diff --git a/llama.cpp b/llama.cpp index d4736473169a1..36b824d566b90 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8713,7 +8713,7 @@ static struct ggml_tensor * llm_build_mamba( // conv { // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} - struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_cont(ctx, ggml_transpose(ctx, x)), 0); + struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, x), 0); // copy last (d_conv - 1) columns back into the state cache struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); @@ -8734,6 +8734,8 @@ static struct ggml_tensor * llm_build_mamba( // and then you're left with the resulting x tensor. // For simultaneous sequences, all sequences need to have the same length. + // TODO: remove unused implementations +#if 0 // For some reason, im2col expects a F16 kernel, but doesn't even read from it. // TODO: make im2col accept F32 kernels to directly pass ssm_conv1d to it. // => { d_conv * d_inner, n_seq_tokens, n_seqs} @@ -8741,14 +8743,24 @@ static struct ggml_tensor * llm_build_mamba( ggml_new_tensor_2d(ctx, GGML_TYPE_F16, d_conv, d_inner), conv_x, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F32); + #if 0 + // TODO: CUDA, SYCL, and Vulkan don't (yet) support broadcasting the ne[3] dimension on MUL_MAT x = ggml_reshape_4d(ctx, x, d_conv, 1, d_inner, n_seq_tokens * n_seqs); // => {1, 1, d_inner, n_seq_tokens * n_seqs} x = ggml_mul_mat(ctx, ggml_reshape_3d(ctx, model.layers[il].ssm_conv1d, d_conv, 1, d_inner), x); - x = ggml_reshape_3d(ctx, x, d_inner, n_seq_tokens, n_seqs); + #else + x = ggml_reshape_4d(ctx, x, d_conv, d_inner, n_seq_tokens, n_seqs); - // Alternatively, this does the same as the above - // x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); + // NOTE: it seems this is very slighly more performant than MUL_MAT on CPU for small row sizes + // => {1, d_inner, n_seq_tokens, n_seqs} + x = ggml_sum_rows(ctx, ggml_mul(ctx, x, model.layers[il].ssm_conv1d)); + #endif + x = ggml_reshape_3d(ctx, x, d_inner, n_seq_tokens, n_seqs); +#else + // Alternatively, this does the same as the above, but faster + x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); +#endif // bias x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); From 372482dffeecc25b8eec24ad672ec66bd9baa55c Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 8 Jun 2024 17:58:40 -0400 Subject: [PATCH 022/117] llama : rename llama_cache to llama_past This can be changed back later if the name change is wrong. I was renaming the functions anyway to generalize kv-cache-related functions to hybrid and recurrent model architectures. I think llama_past is a better name than llama_cache for a combined kv cache and recurrent state cache, because the states it contains pretty much always come before the newly-added ones for any particular sequence. Also 'llama_past_clear' sounds more obvious in what it does than 'llama_kv_cache_clear'. The future is what the models generate. (For embeddings, the kv cache isn't really used anyway) Still, I'm open to better suggestions. --- llama.cpp | 104 +++++++++++++++++++++++++++--------------------------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/llama.cpp b/llama.cpp index 4a2fb3a92f452..4b84313cf8d37 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2877,7 +2877,7 @@ struct llama_rs_cache { } }; -struct llama_cache { +struct llama_past { // key + value cache for self attention llama_kv_cache kv; @@ -2896,7 +2896,7 @@ struct llama_cache { return size; } - ~llama_cache() { + ~llama_past() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); } @@ -3426,7 +3426,7 @@ struct llama_context { const llama_model & model; // key + value cache for self-attention, and/or recurrent state cache - struct llama_cache cache; + struct llama_past cache; // sequence-length-aware batch splitting llama_sbatch sbatch; @@ -3604,8 +3604,8 @@ static size_t llama_get_device_memory(const llama_model & model, int device) { // kv and rs cache helpers // -static bool llama_cache_init( - struct llama_cache & cache, +static bool llama_past_init( + struct llama_past & cache, const llama_context * ctx, ggml_type type_k, ggml_type type_v, @@ -3713,11 +3713,11 @@ static bool llama_cache_init( // no buffer was needed, so this is fine return true; } - LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); + LLAMA_LOG_ERROR("%s: failed to allocate buffer for past cache\n", __func__); return false; } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s cache buf size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s past cache size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); cache.bufs.push_back(buf); } @@ -3728,8 +3728,8 @@ static bool llama_cache_init( // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. -static bool llama_cache_find_slot( - struct llama_cache & cache, +static bool llama_past_find_slot( + struct llama_past & cache, const struct llama_ubatch & batch) { const uint32_t kv_size = cache.kv.size; const uint32_t rs_size = cache.rs.size; @@ -4001,7 +4001,7 @@ static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) { return 0; } -static void llama_cache_clear(struct llama_cache & cache) { +static void llama_past_clear(struct llama_past & cache) { if (cache.kv.size > 0) { for (uint32_t i = 0; i < cache.kv.size; ++i) { llama_kv_cell & kv_cell = cache.kv.cells[i]; @@ -4035,8 +4035,8 @@ static void llama_cache_clear(struct llama_cache & cache) { } } -static llama_pos llama_cache_seq_rm( - struct llama_cache & cache, +static llama_pos llama_past_seq_rm( + struct llama_past & cache, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { @@ -4134,8 +4134,8 @@ static llama_pos llama_cache_seq_rm( return n_past; } -static llama_pos llama_cache_seq_cp( - struct llama_cache & cache, +static llama_pos llama_past_seq_cp( + struct llama_past & cache, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, @@ -4199,7 +4199,7 @@ static llama_pos llama_cache_seq_cp( return n_past; } -static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id) { +static void llama_past_seq_keep(struct llama_past & cache, llama_seq_id seq_id) { if (cache.rs.size > 0) { uint32_t new_head = cache.rs.size; @@ -4249,8 +4249,8 @@ static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id } } -static void llama_cache_seq_add( - struct llama_cache & cache, +static void llama_past_seq_add( + struct llama_past & cache, llama_seq_id seq_id, llama_pos p0, llama_pos p1, @@ -4317,8 +4317,8 @@ static void llama_cache_seq_add( } } -static void llama_cache_seq_div( - struct llama_cache & cache, +static void llama_past_seq_div( + struct llama_past & cache, llama_seq_id seq_id, llama_pos p0, llama_pos p1, @@ -4358,7 +4358,7 @@ static void llama_cache_seq_div( } } -static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { +static llama_pos llama_past_seq_pos_max(struct llama_past & cache, llama_seq_id seq_id) { llama_pos result = -1; if (cache.rs.size > 0) { @@ -13911,7 +13911,7 @@ static int llama_decode_internal( if (hparams.causal_attn) { llama_kv_cache_update(&lctx); - if (!llama_cache_find_slot(lctx.cache, u_batch)) { + if (!llama_past_find_slot(lctx.cache, u_batch)) { return 1; } @@ -17981,7 +17981,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) { + if (!llama_past_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -18515,85 +18515,85 @@ int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx) { return ctx->cache.rs.used; } -void llama_cache_clear(struct llama_context * ctx) { - llama_cache_clear(ctx->cache); +void llama_past_clear(struct llama_context * ctx) { + llama_past_clear(ctx->cache); } // deprecated void llama_kv_cache_clear(struct llama_context * ctx) { - llama_cache_clear(ctx); + llama_past_clear(ctx); } -llama_pos llama_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { +llama_pos llama_past_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } - return llama_cache_seq_rm(ctx->cache, seq_id, p0, p1); + return llama_past_seq_rm(ctx->cache, seq_id, p0, p1); } // deprecated bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - llama_pos n_past = llama_cache_seq_rm(ctx, seq_id, p0, p1); + llama_pos n_past = llama_past_seq_rm(ctx, seq_id, p0, p1); return n_past >= p0; } -llama_pos llama_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +llama_pos llama_past_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { uint32_t n_seq_max = llama_n_seq_max(ctx); if (seq_id_src < 0 || seq_id_dst < 0 || (uint32_t) seq_id_src >= n_seq_max || (uint32_t) seq_id_dst >= n_seq_max) { return 0; } if (seq_id_src == seq_id_dst) { - return llama_cache_seq_pos_max(ctx->cache, seq_id_dst) + 1; + return llama_past_seq_pos_max(ctx->cache, seq_id_dst) + 1; } - return llama_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); + return llama_past_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); } // deprecated void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - llama_cache_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); + llama_past_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); } -void llama_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { +void llama_past_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - llama_cache_seq_keep(ctx->cache, seq_id); + llama_past_seq_keep(ctx->cache, seq_id); } // deprecated void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { - llama_cache_seq_keep(ctx, seq_id); + llama_past_seq_keep(ctx, seq_id); } -void llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { +void llama_past_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } if (delta == 0) { return; } - llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); + llama_past_seq_add(ctx->cache, seq_id, p0, p1, delta); } // deprecated void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - llama_cache_seq_add(ctx, seq_id, p0, p1, delta); + llama_past_seq_add(ctx, seq_id, p0, p1, delta); } -void llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { +void llama_past_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } if (d == 1) { return; } - llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); + llama_past_seq_div(ctx->cache, seq_id, p0, p1, d); } // deprecated void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - llama_cache_seq_div(ctx, seq_id, p0, p1, d); + llama_past_seq_div(ctx, seq_id, p0, p1, d); } -llama_pos llama_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { +llama_pos llama_past_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return -1; } - return llama_cache_seq_pos_max(ctx->cache, seq_id); + return llama_past_seq_pos_max(ctx->cache, seq_id); } // deprecated llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { - llama_pos max_pos = llama_cache_seq_pos_max(ctx, seq_id); + llama_pos max_pos = llama_past_seq_pos_max(ctx, seq_id); return max_pos < 0 ? 0 : max_pos; } @@ -19345,7 +19345,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, GGML_ASSERT(cache.rs.size == 0); // not implemented // Wipe the slot - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); const uint8_t * inp = src; @@ -19402,7 +19402,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, } batch.n_seq_id[0] = 1; batch.seq_id[0] = &dest_seq_id; - if (!llama_cache_find_slot(cache, batch)) { + if (!llama_past_find_slot(cache, batch)) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return 0; } @@ -19427,7 +19427,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(k_type_i_ref); const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; if (k_type_i != k_type_i_ref) { - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); return 0; } @@ -19438,7 +19438,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(k_size_row_ref); const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, k_size_row_ref, il); return 0; } @@ -19459,7 +19459,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(v_type_i_ref); const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; if (v_type_i != v_type_i_ref) { - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return 0; } @@ -19470,7 +19470,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(v_size_row_ref); const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); return 0; } @@ -19490,7 +19490,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(v_type_i_ref); const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; if (v_type_i != v_type_i_ref) { - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return 0; } @@ -19501,7 +19501,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, inp += sizeof(v_size_el_ref); const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); if (v_size_el != v_size_el_ref) { - llama_cache_seq_rm(cache, dest_seq_id, -1, -1); + llama_past_seq_rm(cache, dest_seq_id, -1, -1); LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); return 0; } From 43d8d4bf9e88df10203f7d8d4a1107b84bebbcfd Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 10 Jun 2024 14:44:42 -0400 Subject: [PATCH 023/117] examples : replace llama_kv_cache_seq_* with llama_past_seq_* --- common/common.cpp | 2 +- examples/batched-bench/batched-bench.cpp | 4 +- examples/batched.swift/Sources/main.swift | 2 +- examples/batched/batched.cpp | 2 +- examples/embedding/embedding.cpp | 2 +- examples/gritlm/gritlm.cpp | 4 +- examples/imatrix/imatrix.cpp | 2 +- examples/infill/infill.cpp | 4 +- examples/llama-bench/llama-bench.cpp | 4 +- .../llama/src/main/cpp/llama-android.cpp | 8 +-- .../llama.cpp.swift/LibLlama.swift | 8 +-- examples/lookahead/lookahead.cpp | 13 ++--- examples/lookup/lookup.cpp | 3 +- examples/main/main.cpp | 21 +++++--- examples/parallel/parallel.cpp | 10 ++-- examples/passkey/passkey.cpp | 28 +++++------ examples/perplexity/perplexity.cpp | 12 ++--- examples/retrieval/retrieval.cpp | 2 +- examples/save-load-state/save-load-state.cpp | 2 +- examples/server/server.cpp | 49 ++++++++++--------- examples/speculative/speculative.cpp | 28 ++++++----- llama.cpp | 3 +- llama.h | 28 +++++------ 23 files changed, 127 insertions(+), 114 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 1591790e6df4c..d04e047410778 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2366,7 +2366,7 @@ std::tuple llama_init_from_gpt_par std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); - llama_kv_cache_clear(lctx); + llama_past_clear(lctx); llama_synchronize(lctx); llama_reset_timings(lctx); } diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 718f0a61a1878..114dd811ee3f9 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -153,7 +153,7 @@ int main(int argc, char ** argv) { const auto t_pp_start = ggml_time_us(); - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { LOG_TEE("%s: llama_decode() failed\n", __func__); @@ -162,7 +162,7 @@ int main(int argc, char ** argv) { if (is_pp_shared) { for (int32_t i = 1; i < pl; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } } diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index dbbd06da58183..443a03d575ea4 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 { } for i in 1 ..< n_parallel { - llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens) + llama_past_seq_cp(context, 0, Int32(i), -1, -1) } if n_parallel > 1 { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 62d9b144d3340..888cf9e8e8c34 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -112,7 +112,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them for (int32_t i = 1; i < n_parallel; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } if (n_parallel > 1) { diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 244751e003d9e..9a7c32d6b8ca2 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -25,7 +25,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // run model fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 2135157916c97..dd389ac004383 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -43,7 +43,7 @@ static std::vector> encode(llama_context * ctx, const std::ve } // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); llama_set_causal_attn(ctx, false); // run model @@ -97,7 +97,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo const llama_model * mdl = llama_get_model(ctx); llama_token eos_token = llama_token_eos(mdl); - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); llama_set_causal_attn(ctx, true); llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index e18f495630616..c81590a3f8c88 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -455,7 +455,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 0e4ec79c693fa..0a74b93abd698 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -380,8 +380,8 @@ int main(int argc, char ** argv) { LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_past_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); n_past -= n_discard; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 5c31548a6c25c..d48eb245daa80 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1360,7 +1360,7 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // warmup run if (t.n_prompt > 0) { @@ -1372,7 +1372,7 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); uint64_t t_start = get_time_ns(); diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 874158ef0f98f..57ee5a650893a 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( } batch->logits[batch->n_tokens - 1] = true; - llama_kv_cache_clear(context); + llama_past_clear(context); const auto t_pp_start = ggml_time_us(); if (llama_decode(context, *batch) != 0) { @@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( LOGi("Benchmark text generation (tg)"); - llama_kv_cache_clear(context); + llama_past_clear(context); const auto t_tg_start = ggml_time_us(); for (i = 0; i < tg; i++) { @@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto t_tg_end = ggml_time_us(); - llama_kv_cache_clear(context); + llama_past_clear(context); const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; @@ -439,5 +439,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { - llama_kv_cache_clear(reinterpret_cast(context)); + llama_past_clear(reinterpret_cast(context)); } diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 737f882fb2d2e..50fcaa12d6145 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -214,7 +214,7 @@ actor LlamaContext { } batch.logits[Int(batch.n_tokens) - 1] = 1 // true - llama_kv_cache_clear(context) + llama_past_clear(context) let t_pp_start = ggml_time_us() @@ -227,7 +227,7 @@ actor LlamaContext { // bench text generation - llama_kv_cache_clear(context) + llama_past_clear(context) let t_tg_start = ggml_time_us() @@ -246,7 +246,7 @@ actor LlamaContext { let t_tg_end = ggml_time_us() - llama_kv_cache_clear(context) + llama_past_clear(context) let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 @@ -296,7 +296,7 @@ actor LlamaContext { func clear() { tokens_list.removeAll() temporary_invalid_cchars.removeAll() - llama_kv_cache_clear(context) + llama_past_clear(context) } private func tokenize(text: String, add_bos: Bool) -> [llama_token] { diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index fb20ad93f9c1d..7f6e42e8d2810 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -96,7 +96,7 @@ int main(int argc, char ** argv) { llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); + llama_past_seq_cp(ctx, 0, s, -1, -1); } const auto t_enc_end = ggml_time_us(); @@ -438,17 +438,18 @@ int main(int argc, char ** argv) { // KV cache management // if no verification token matched, we simply remove all cells from this batch -> no fragmentation - llama_kv_cache_seq_rm(ctx, -1, n_past, -1); + // FIXME: recurrent and hybrid models + llama_past_seq_rm(ctx, -1, n_past, -1); if (seq_id_best != 0) { // if a verification token matched, we keep the best sequence and remove the rest // this leads to some KV cache fragmentation - llama_kv_cache_seq_keep(ctx, seq_id_best); - llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1); - llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1); + llama_past_seq_keep(ctx, seq_id_best); + llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1); + llama_past_seq_rm (ctx, seq_id_best, -1, -1); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); + llama_past_seq_cp(ctx, 0, s, -1, -1); } } } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 80ecd925d5962..db861d6ad99f0 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -195,7 +195,8 @@ int main(int argc, char ** argv){ // KV cache management // clean the cache of draft tokens that weren't accepted - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + // FIXME: recurrent and hybrid models + llama_past_seq_rm(ctx, 0, n_past, -1); llama_batch_clear(batch_tgt); llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b97b7b7937f02..446fe035c3d25 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -299,6 +299,10 @@ int main(int argc, char ** argv) { } n_matching_session_tokens++; } + + // remove any "future" tokens that we might have inherited from the previous session + n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1); + if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { LOG_TEE("%s: using full prompt from session file\n", __func__); } else if (n_matching_session_tokens >= embd_inp.size()) { @@ -310,9 +314,6 @@ int main(int argc, char ** argv) { LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n", __func__, n_matching_session_tokens, embd_inp.size()); } - - // remove any "future" tokens that we might have inherited from the previous session - llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); } LOGLN( @@ -325,6 +326,8 @@ int main(int argc, char ** argv) { LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1); session_tokens.resize(embd_inp.size() - 1); + } else { + session_tokens.resize(n_matching_session_tokens); } // number of tokens to keep when resetting context @@ -535,8 +538,8 @@ int main(int argc, char ** argv) { LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); + llama_past_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); + llama_past_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); n_past -= n_discard; @@ -563,9 +566,9 @@ int main(int argc, char ** argv) { LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); - llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd); - llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); - llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); + llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd); + llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); n_past -= bd; @@ -579,6 +582,8 @@ int main(int argc, char ** argv) { if (n_session_consumed < (int) session_tokens.size()) { size_t i = 0; for ( ; i < embd.size(); i++) { + // TODO: are the session tokens guaranteed to all be matching here? + // Should n_matching_session_tokens be re-used instead? if (embd[i] != session_tokens[n_session_consumed]) { session_tokens.resize(n_session_consumed); break; diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7faeaec975ae3..f684788043450 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -200,7 +200,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= n_clients; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } LOG_TEE("\n"); @@ -232,9 +232,9 @@ int main(int argc, char ** argv) { if (batch.n_tokens == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { - llama_kv_cache_seq_rm(ctx, i, -1, -1); + llama_past_seq_rm(ctx, i, -1, -1); // but keep the system prompt - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } LOG_TEE("%s: clearing the KV cache\n", __func__); @@ -371,8 +371,8 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); - llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); + llama_past_seq_rm(ctx, client.id + 1, -1, -1); + llama_past_seq_cp(ctx, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index d03215cd1e0a9..c6564c5cfd4c7 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -126,11 +126,11 @@ int main(int argc, char ** argv) { const int ib = i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); - llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); - llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); - llama_kv_cache_update (ctx); + llama_past_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); + llama_past_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + llama_kv_cache_update(ctx); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_past_seq_pos_max(ctx, 0) + 1; } llama_batch_clear(batch); @@ -160,12 +160,12 @@ int main(int argc, char ** argv) { LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag (ctx); - llama_kv_cache_update (ctx); + llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag(ctx); + llama_kv_cache_update(ctx); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_past_seq_pos_max(ctx, 0) + 1; llama_batch_clear(batch); @@ -191,12 +191,12 @@ int main(int argc, char ** argv) { if (n_discard > 0) { LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag (ctx); - llama_kv_cache_update (ctx); + llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag(ctx); + llama_kv_cache_update(ctx); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_past_seq_pos_max(ctx, 0) + 1; } } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 0bd78c21a86a1..ad03b3bb5552b 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -400,7 +400,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -575,7 +575,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -944,7 +944,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { return; } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1221,7 +1221,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { return; } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1594,7 +1594,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params return; } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1780,7 +1780,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { } // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 55b7b2f70ae2a..bd7d06d371c2e 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -81,7 +81,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // run model fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 00c2277ac2827..974dc3c3ed5f5 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -192,7 +192,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); // erase whole kv - llama_kv_cache_clear(ctx3); + llama_past_clear(ctx3); fprintf(stderr, "%s : kv cache cleared\n", __func__); // restore kv into seq 1 diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6ffaa8d9fe637..a04c47bae21e0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1107,7 +1107,7 @@ struct server_context { LOG_VERBOSE("clearing KV cache", {}); // clear the entire KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); clean_kv_cache = false; } @@ -1151,7 +1151,7 @@ struct server_context { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= params.n_parallel; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } } @@ -1824,7 +1824,7 @@ struct server_context { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + llama_past_seq_rm(ctx, slot->id + 1, -1, -1); slot->cache_tokens.clear(); server_task_result result; @@ -1939,8 +1939,8 @@ struct server_context { {"n_cache_tokens", slot.cache_tokens.size()} }); - llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); + llama_past_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); + llama_past_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); if (slot.params.cache_prompt) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { @@ -2155,23 +2155,28 @@ struct server_context { } // keep only the common part - int p0 = (int) system_tokens.size() + slot.n_past; - if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) { - // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1); - - p0 = (int) system_tokens.size(); - if (p0 != 0) { - // copy over the system prompt when there is one - llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1); + llama_pos p0 = (llama_pos) system_tokens.size() + slot.n_past; + + // for recurrent and hybrid models, sometimes it goes back further than asked + llama_pos new_p0 = llama_past_seq_rm(ctx, slot.id + 1, p0, -1); + + if (new_p0 < p0) { + GGML_ASSERT(new_p0 >= (llama_pos) system_tokens.size()); + + slot.n_past -= p0 - new_p0; + if (slot.ga_i > 0) { + // TODO: test with an hybrid model (e.g. Jamba) + slot.n_past_se -= p0 - new_p0; } - // there is no common part left (except for the system prompt) - slot.n_past = 0; - slot.n_past_se = 0; - slot.ga_i = 0; - // TODO: is the system prompt ever in the sampling context? + // TODO: find a way to avoid rolling back the sampling context twice llama_sampling_reset(slot.ctx_sampling); + // push the prompt into the sampling context (do not apply grammar) + for (int i = 0; i < slot.n_past; ++i) { + llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false); + } + + p0 = new_p0; } // remove the non-common part from the cache @@ -2273,9 +2278,9 @@ struct server_context { LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); - llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); + llama_past_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); + llama_past_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); + llama_past_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); slot.n_past_se -= bd; diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 0939a1a6a7a38..3a1ef06a5e6b4 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -394,14 +394,15 @@ int main(int argc, char ** argv) { { LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); - llama_kv_cache_seq_keep(ctx_dft, s_keep); - llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); - llama_kv_cache_seq_keep(ctx_dft, 0); - - llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); - llama_kv_cache_seq_keep(ctx_tgt, s_keep); - llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); - llama_kv_cache_seq_keep(ctx_tgt, 0); + llama_past_seq_keep(ctx_dft, s_keep); + llama_past_seq_cp (ctx_dft, s_keep, 0, -1, -1); + llama_past_seq_keep(ctx_dft, 0); + + // FIXME: recurrent and hybrid models + llama_past_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); + llama_past_seq_keep(ctx_tgt, s_keep); + llama_past_seq_cp (ctx_tgt, s_keep, 0, -1, -1); + llama_past_seq_keep(ctx_tgt, 0); } for (int s = 0; s < n_seq_dft; ++s) { @@ -418,7 +419,8 @@ int main(int argc, char ** argv) { llama_batch_clear(batch_dft); llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); - llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); + // FIXME: recurrent and hybrid models + llama_past_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); llama_decode(ctx_dft, batch_dft); @@ -474,8 +476,8 @@ int main(int argc, char ** argv) { if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { LOG("splitting seq %3d into %3d\n", s, n_seq_cur); - llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); - llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); + llama_past_seq_rm(ctx_dft, n_seq_cur, -1, -1); + llama_past_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); // all previous tokens from this branch are now also part of the new branch for (int t = 0; t < batch_tgt.n_tokens; ++t) { @@ -553,9 +555,9 @@ int main(int argc, char ** argv) { // evaluate the target model on the drafted tokens { - llama_kv_cache_seq_keep(ctx_tgt, 0); + llama_past_seq_keep(ctx_tgt, 0); for (int s = 1; s < n_seq_dft; ++s) { - llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); + llama_past_seq_cp(ctx_tgt, 0, s, -1, -1); } // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); diff --git a/llama.cpp b/llama.cpp index 4b84313cf8d37..2233161d8f938 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2126,7 +2126,6 @@ struct llama_ubatch { llama_token * token; // [n_tokens] float * embd; // [n_embd, n_tokens] llama_pos * pos; // [n_tokens] - // FIXME: make all uses of this use n_seqs int32_t * n_seq_id; // [n_seqs] llama_seq_id ** seq_id; // [n_seqs] int8_t * output; // [n_tokens] @@ -18992,7 +18991,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { __func__, kv_head, kv_size, kv_self.size); } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); if (kv_buf_size) { const size_t pre_kv_buf_size = inp - src; diff --git a/llama.h b/llama.h index 0d9d522569632..4ecfc5f3e0a91 100644 --- a/llama.h +++ b/llama.h @@ -583,11 +583,11 @@ extern "C" { LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx); // Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed - LLAMA_API void llama_cache_clear( + LLAMA_API void llama_past_clear( struct llama_context * ctx); LLAMA_API DEPRECATED(void llama_kv_cache_clear( struct llama_context * ctx), - "use llama_cache_clear instead"); + "use llama_past_clear instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // seq_id < 0 : match any sequence @@ -595,7 +595,7 @@ extern "C" { // p1 < 0 : [p0, inf) // Returns n_past (one more than the largest remaining pos in the seq_id) // which is only meaningful to handle for partial removals. - LLAMA_API llama_pos llama_cache_seq_rm( + LLAMA_API llama_pos llama_past_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -605,7 +605,7 @@ extern "C" { llama_seq_id seq_id, llama_pos p0, llama_pos p1), - "use llama_cache_seq_rm instead, and handle its return value for partial removals"); + "use llama_past_seq_rm instead, and handle its return value for partial removals"); // Copy all tokens that belong to the specified sequence to another sequence // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence @@ -613,7 +613,7 @@ extern "C" { // p1 < 0 : [p0, inf) // Returns n_past (one more than the largest remaining pos in the destination seq_id) // which is only meaningful to handle when partially copying. - LLAMA_API llama_pos llama_cache_seq_cp( + LLAMA_API llama_pos llama_past_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, @@ -625,16 +625,16 @@ extern "C" { llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1), - "use llama_cache_seq_cp instead, and handle its return value for partial copies"); + "use llama_past_seq_cp instead, and handle its return value for partial copies"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_cache_seq_keep( + LLAMA_API void llama_past_seq_keep( struct llama_context * ctx, llama_seq_id seq_id); LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep( struct llama_context * ctx, llama_seq_id seq_id), - "use llama_cache_seq_keep instead"); + "use llama_past_seq_keep instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -642,7 +642,7 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_cache_seq_add( + LLAMA_API void llama_past_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -654,7 +654,7 @@ extern "C" { llama_pos p0, llama_pos p1, llama_pos delta), - "use llama_cache_seq_add instead"); + "use llama_past_seq_add instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -662,7 +662,7 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_cache_seq_div( + LLAMA_API void llama_past_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -674,16 +674,16 @@ extern "C" { llama_pos p0, llama_pos p1, int d), - "use llama_cache_seq_div instead"); + "use llama_past_seq_div instead"); // Returns the largest position present in the KV and/or RS cache for the specified sequence - LLAMA_API llama_pos llama_cache_seq_pos_max( + LLAMA_API llama_pos llama_past_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id); LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id), - "use llama_cache_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); + "use llama_past_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // This will be applied: From 33425a7e1ed366082a2dbf64f2485531471515e0 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 12 Jun 2024 12:57:02 -0400 Subject: [PATCH 024/117] mamba : fix non-contiguous usage of ggml_silu --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 2233161d8f938..37190bf1c48b0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8867,7 +8867,7 @@ static struct ggml_tensor * llm_build_mamba( // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); - y = ggml_mul(ctx, y, ggml_silu(ctx, z)); + y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_mul_mat(ctx, model.layers[il].ssm_out, y); From 1f0fea70fb761d10e2264cbdcf4852ed32706c89 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 1 Aug 2024 10:43:42 -0400 Subject: [PATCH 025/117] llama : initial Mamba-2 support --- convert_hf_to_gguf.py | 67 ++++++++ ggml/include/ggml.h | 3 +- ggml/src/ggml.c | 193 ++++++++++++++-------- gguf-py/gguf/constants.py | 19 +++ gguf-py/gguf/gguf_writer.py | 3 + gguf-py/gguf/tensor_mapping.py | 6 +- src/llama.cpp | 291 +++++++++++++++++++++++++++++++-- 7 files changed, 495 insertions(+), 87 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 108c822cff5d2..0ac64574a3043 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2788,6 +2788,73 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(new_name, data_torch)] +@Model.register("Mamba2ForCausalLM") +class Mamba2Model(Model): + model_arch = gguf.MODEL_ARCH.MAMBA2 + + def set_vocab(self): + vocab_size = self.hparams["vocab_size"] + # Round vocab size to next multiple of 16 + pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16) + # pad using ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + vocab_size = -(vocab_size // -pad_vocab) * pad_vocab + self.hparams["vocab_size"] = vocab_size + + if (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + elif (self.dir_model / "tokenizer.model").is_file(): + self._set_vocab_sentencepiece() + elif (self.dir_model / "tokenizer.model.v3").is_file(): + # mamba-codestral + raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}") + else: + # Use the GPT-NeoX tokenizer when no tokenizer files are present + self._set_vocab_builtin("gpt-neox", vocab_size) + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128 + head_dim = self.find_hparam(["head_dim"], optional=True) or 64 + n_group = self.find_hparam(["n_groups"], optional=True) or 1 + + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + + # Fail early for models which don't have a block expansion factor of 2 + # TODO: does this really matter? + assert d_inner == 2 * d_model + assert d_inner % head_dim == 0 + + self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading + self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim) + self.gguf_writer.add_ssm_group_count(n_group) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + + new_name = self.map_tensor_name(name) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield (new_name, data_torch) + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b8a21a2ccc3f0..59e0022dd4286 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1787,7 +1787,8 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C); + struct ggml_tensor * C, + struct ggml_tensor * D); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d63c917a5705a..6668209081b6c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7270,32 +7270,48 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C) { + struct ggml_tensor * C, + struct ggml_tensor * D) { GGML_ASSERT(ggml_is_contiguous(s)); - GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(ggml_is_matrix(A)); - GGML_ASSERT(ggml_is_3d(B)); - GGML_ASSERT(ggml_is_3d(s)); + GGML_ASSERT(x->nb[0] == ggml_type_size(x->type)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); - GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]); + GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); + GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); { const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_seq_tokens = x->ne[1]; - const int64_t n_seqs = x->ne[2]; - - GGML_ASSERT(s->ne[2] == n_seqs); - GGML_ASSERT(x->ne[0] == d_inner); - GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == d_inner); + const int64_t head_dim = x->ne[0]; + const int64_t n_head = x->ne[1]; + const int64_t n_seq_tokens = x->ne[2]; + const int64_t n_seqs = x->ne[3]; + + GGML_ASSERT(dt->ne[0] == n_head); + GGML_ASSERT(dt->ne[1] == n_seq_tokens); + GGML_ASSERT(dt->ne[2] == n_seqs); + GGML_ASSERT(ggml_is_3d(dt)); + GGML_ASSERT(s->ne[1] == head_dim); + GGML_ASSERT(s->ne[2] == n_head); + GGML_ASSERT(s->ne[3] == n_seqs); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_seq_tokens); - GGML_ASSERT(B->ne[2] == n_seqs); + GGML_ASSERT(B->ne[2] == n_seq_tokens); + GGML_ASSERT(B->ne[3] == n_seqs); + GGML_ASSERT(D->ne[0] == n_head); + GGML_ASSERT(ggml_is_vector(D)); + + if (ggml_is_vector(A)) { + // Mamba-2 + GGML_ASSERT(A->ne[0] == n_head); + } else { + // Mamba-1 + GGML_ASSERT(A->ne[0] == d_state); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); + } } bool is_node = false; @@ -7316,6 +7332,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; + result->src[6] = D; return result; } @@ -15840,20 +15857,25 @@ static void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // s - const struct ggml_tensor * src1 = dst->src[1]; // x - const struct ggml_tensor * src2 = dst->src[2]; // dt - const struct ggml_tensor * src3 = dst->src[3]; // A - const struct ggml_tensor * src4 = dst->src[4]; // B - const struct ggml_tensor * src5 = dst->src[5]; // C + const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs} + const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} + const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} + const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head} + const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} + const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} + const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // dim + const int64_t nh = src1->ne[1]; // n_head + const int64_t ng = src4->ne[1]; + const int64_t nt = src1->ne[2]; // number of tokens per sequence + const int64_t ns = src0->ne[3]; // number of sequences in the batch + + const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1); GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -15862,51 +15884,86 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[3]) - GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - const int ir = ir1 - ir0; - - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - - // use the output as the source for the next token-wise iterations + GGML_ASSERT(src6->nb[0] == sizeof(float)); + // allows optimizing the modulo since n_group should be a power of 2 + GGML_ASSERT((ng & -ng) == ng); + + // heads per thread + const int dh = (nh + nth - 1)/nth; + + // head range for this thread + const int ih0 = dh*ith; + const int ih1 = MIN(ih0 + dh, nh); + + for (int i3 = 0; i3 < ns; ++i3) { + for (int i2 = 0; i2 < nt; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns} + const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} + const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} + const float * D = (const float *) ((const char *) src6->data); // {nh} + float * y = (float *) ((char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} + float * s = (float *) ((char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} + + // use the output as the source when it's not the first token-wise iteration if (i2 > 0) { s0 = s; } - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + if (ggml_is_vector(src3)) { + // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dA = expf(dt_soft_plus * A[h]); + + // TODO: SIMD implementation + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int i = i1 + h*nr; + const float x_dt = x[i] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + const int ii = i0 + i*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[ii] * dA) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[ii] = state; + } + y[i] = sumf + x[i] * D[h]; + } + } + } else { + // Mamba-1 has an element-wise decay factor for the states + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int i = i1 + h*nr; + const float x_dt = x[i] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + const int ii = i0 + i*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[ii] = state; + } + y[i] = sumf + x[i] * D[h]; + } } - y[i1] = sumf; } } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b55effa9907b1..32a2fb20f84b9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -130,6 +130,7 @@ class SSM: INNER_SIZE = "{arch}.ssm.inner_size" STATE_SIZE = "{arch}.ssm.state_size" TIME_STEP_RANK = "{arch}.ssm.time_step_rank" + GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" class Tokenizer: @@ -208,6 +209,7 @@ class MODEL_ARCH(IntEnum): GEMMA2 = auto() STARCODER2 = auto() MAMBA = auto() + MAMBA2 = auto() XVERSE = auto() COMMAND_R = auto() DBRX = auto() @@ -269,6 +271,7 @@ class MODEL_TENSOR(IntEnum): SSM_DT = auto() SSM_A = auto() SSM_D = auto() + SSM_NORM = auto() SSM_OUT = auto() ATTN_Q_A = auto() ATTN_Q_B = auto() @@ -338,6 +341,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA2: "gemma2", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.MAMBA2: "mamba2", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", @@ -399,6 +403,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", @@ -869,6 +874,19 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.MAMBA2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -1373,6 +1391,7 @@ def get_type(val: Any) -> GGUFValueType: KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK +KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS # tokenization diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index af3b98c679b0b..ea788918dbf2c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -730,6 +730,9 @@ def add_ssm_state_size(self, value: int) -> None: def add_ssm_time_step_rank(self, value: int) -> None: self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) + def add_ssm_group_count(self, value: int) -> None: + self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value) + def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a4f185c0658a3..8593a80a5ab8f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -396,7 +396,7 @@ class TensorNameMap: "encoder.layers.{bid}.norm2", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_3", # Grok "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 - "encoder.layer.{bid}.layer_norm_2" # jina-v2-code + "encoder.layer.{bid}.layer_norm_2", # jina-v2-code ), MODEL_TENSOR.SSM_IN: ( @@ -429,6 +429,10 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.D", ), + MODEL_TENSOR.SSM_NORM: ( + "backbone.layers.{bid}.mixer.norm", # mamba2 + ), + MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", diff --git a/src/llama.cpp b/src/llama.cpp index bd7f1508b2644..5be0ef7a2ac7a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -198,6 +198,7 @@ enum llm_arch { LLM_ARCH_GEMMA2, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, + LLM_ARCH_MAMBA2, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, @@ -245,6 +246,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_MAMBA2, "mamba2" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, @@ -328,6 +330,7 @@ enum llm_kv { LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, LLM_KV_TOKENIZER_MODEL, @@ -427,7 +430,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, - { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -517,6 +521,7 @@ enum llm_tensor { LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_ATTN_Q_A, LLM_TENSOR_ATTN_Q_B, @@ -1068,6 +1073,22 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_MAMBA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -2239,6 +2260,7 @@ struct llama_hparams { uint32_t ssm_d_inner = 0; uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + uint32_t ssm_n_group = 0; bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -2289,6 +2311,7 @@ struct llama_hparams { if (this->ssm_d_inner != other.ssm_d_inner) return true; if (this->ssm_d_state != other.ssm_d_state) return true; if (this->ssm_dt_rank != other.ssm_dt_rank) return true; + if (this->ssm_n_group != other.ssm_n_group) return true; if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; if (this->dec_start_token_id != other.dec_start_token_id) return true; @@ -2357,7 +2380,7 @@ struct llama_hparams { // corresponds to Mamba's conv_states size // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); } uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings @@ -2419,6 +2442,7 @@ struct llama_layer { struct ggml_tensor * ffn_sub_norm; struct ggml_tensor * attn_norm_cross; struct ggml_tensor * attn_norm_enc; + struct ggml_tensor * ssm_norm; // attention struct ggml_tensor * wq; @@ -5573,6 +5597,38 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_MAMBA2: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: model.type = e_model::MODEL_SMALL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: model.type = e_model::MODEL_MEDIUM; break; + case 1536: model.type = e_model::MODEL_LARGE; break; + case 2048: model.type = e_model::MODEL_XL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + case 4096: model.type = e_model::MODEL_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -6404,6 +6460,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); } @@ -7639,7 +7696,7 @@ static bool llm_load_tensors( layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); - layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); + layer.ssm_conv1d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); @@ -7648,9 +7705,61 @@ static bool llm_load_tensors( layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); // no "weight" suffix for these - layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); + layer.ssm_a = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); + // out_proj + layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); + } + } break; + case LLM_ARCH_MAMBA2: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = n_embd / n_head; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + // norm + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}); + + layer.ssm_conv1d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}); + layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}); + + layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}); + + // no "weight" suffix for these + layer.ssm_a = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A, i), {n_head}); + layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {n_head}); + + layer.ssm_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner}); + // out_proj layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); } @@ -9041,6 +9150,8 @@ static struct ggml_tensor * llm_build_mamba( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_head = d_inner; + const int64_t head_dim = 1; const int64_t n_seqs = batch.n_seqs; // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; @@ -9064,7 +9175,7 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, graph, ssm_states_all, state_copy, state_mask, hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9113,8 +9224,8 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x); // split struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); - struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); + struct ggml_tensor * B = ggml_view_4d(ctx, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_4d(ctx, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers if (ssm_dt_b_c_rms) { @@ -9127,23 +9238,23 @@ static struct ggml_tensor * llm_build_mamba( dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); + // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); // store last states ggml_build_forward_expand(graph, ggml_cpy(ctx, - ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]), ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0); // TODO: skip computing output earlier for unused tokens - // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} - y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -9157,6 +9268,136 @@ static struct ggml_tensor * llm_build_mamba( return cur; } +static struct ggml_tensor * llm_build_mamba2( + struct ggml_context * ctx, + struct llama_context & lctx, + const llama_ubatch & batch, + struct ggml_cgraph * graph, + struct ggml_tensor * cur, + struct ggml_tensor * state_copy, + struct ggml_tensor * state_mask, + int32_t kv_head, + int32_t n_kv, + const llm_build_cb & cb, + int il) { + const llama_model & model = lctx.model; + const llama_hparams & hparams = model.hparams; + const llama_kv_cache & kv = lctx.kv_self; + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_head; // FIXME + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = batch.n_seqs; + + const int64_t n_seq_tokens = batch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); + + struct ggml_tensor * conv_states_all = kv.k_l[il]; + struct ggml_tensor * ssm_states_all = kv.v_l[il]; + + // (ab)using the KV cache to store the states + struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, + graph, conv_states_all, state_copy, state_mask, + hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); + conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, + graph, ssm_states_all, state_copy, state_mask, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + + // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + struct ggml_tensor * zxBCdt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur); + + // split the above in three + struct ggml_tensor * z = ggml_view_3d(ctx, zxBCdt, d_inner, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], 0); + struct ggml_tensor * xBC = ggml_view_3d(ctx, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); + struct ggml_tensor * dt = ggml_view_3d(ctx, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} + struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, xBC), 0); + + // copy last (d_conv - 1) columns back into the state cache + struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(graph, + ggml_cpy(ctx, last_conv, + ggml_view_1d(ctx, conv_states_all, + (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + xBC = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); + + // bias + xBC = ggml_add(ctx, xBC, model.layers[il].ssm_conv1d_b); + + xBC = ggml_silu(ctx, xBC); + } + + // ssm + { + // These correspond to V K Q in SSM/attention duality + struct ggml_tensor * x = ggml_view_4d(ctx, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + struct ggml_tensor * B = ggml_view_4d(ctx, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); + struct ggml_tensor * C = ggml_view_4d(ctx, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); + + // {n_head, n_seq_tokens, n_seqs} + dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + + // store last states + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + + // TODO: skip computing output earlier for unused tokens + + y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); + + // grouped RMS norm + y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = llm_build_norm(ctx, y, hparams, + model.layers[il].ssm_norm, NULL, + LLM_NORM_RMS, cb, il); + y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + + return cur; +} + struct llm_build_context { const llama_model & model; llama_context & lctx; @@ -12788,7 +13029,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_mamba() { + struct ggml_cgraph * build_mamba(int32_t version = 1) { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); struct ggml_tensor * cur; @@ -12807,9 +13048,19 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + switch (version) { + case 2: + cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, + state_copy, state_mask, + kv_head, n_kv, cb, il); + break; + case 1: + default: + cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, + state_copy, state_mask, + kv_head, n_kv, cb, il); + break; + } if (il == n_layer - 1) { // skip computing output for unused tokens @@ -14858,7 +15109,11 @@ static struct ggml_cgraph * llama_build_graph( } break; case LLM_ARCH_MAMBA: { - result = llm.build_mamba(); + result = llm.build_mamba(/* version */ 1); + } break; + case LLM_ARCH_MAMBA2: + { + result = llm.build_mamba(/* version */ 2); } break; case LLM_ARCH_XVERSE: { @@ -17954,6 +18209,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: @@ -18125,6 +18381,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) { bool llama_model_is_recurrent(const struct llama_model * model) { switch (model->arch) { + case LLM_ARCH_MAMBA2: case LLM_ARCH_MAMBA: return true; default: return false; } From dceff23faec99945d3161d24ea209a0c433546db Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 18 Aug 2024 21:49:39 -0400 Subject: [PATCH 026/117] ggml : SIMD ggml_ssm_scan for Mamba-2 * ggml : improve ggml_mul speed when masking recurrent states --- ggml/src/ggml.c | 95 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 80 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 6668209081b6c..f8e708088b357 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10207,7 +10207,37 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (nb10 == sizeof(float)) { + if (ne00 > 1 && ne10 == 1) { + // fast broadcast path + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + const float scale = src1_ptr[0]; + + if (scale == 0.0f) { + // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, + // but it is useful when resetting the state of recurrent models. + memset((char *)dst->data + ir*nb1, 0, nb1); + } else { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float)); + } + if (scale != 1.0f) { + ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); + } + } + } + } else if (nb10 == sizeof(float)) { for (int64_t ir = ith; ir < nr; ir += nth) { // src0 and dst are same shape => same indices const int64_t i03 = ir/(ne02*ne01); @@ -15919,23 +15949,56 @@ static void ggml_compute_forward_ssm_scan_f32( const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; const float dA = expf(dt_soft_plus * A[h]); - // TODO: SIMD implementation // dim for (int i1 = 0; i1 < nr; ++i1) { - const int i = i1 + h*nr; - const float x_dt = x[i] * dt_soft_plus; + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; float sumf = 0.0f; +#if defined(GGML_SIMD) + const int np = (nc & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + GGML_F32_VEC az[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc); + ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + + ax[j] = GGML_F32_VEC_MUL(ax[j], adA); + ay[j] = GGML_F32_VEC_MUL(ay[j], axdt); + + ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]); + + GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); +#else + const int np = 0; +#endif // d_state - for (int i0 = 0; i0 < nc; ++i0) { - const int ii = i0 + i*nc; + for (int i0 = np; i0 < nc; ++i0) { + const int i = i0 + ii*nc; const int ig = i0 + (h & (ng - 1))*nc; // state = prev_state * dA + dB * x - const float state = (s0[ii] * dA) + (B[ig] * x_dt); + const float state = (s0[i] * dA) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[ig]; - s[ii] = state; + s[i] = state; } - y[i] = sumf + x[i] * D[h]; + y[ii] = sumf + x[ii] * D[h]; } } } else { @@ -15948,20 +16011,22 @@ static void ggml_compute_forward_ssm_scan_f32( // dim for (int i1 = 0; i1 < nr; ++i1) { - const int i = i1 + h*nr; - const float x_dt = x[i] * dt_soft_plus; + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; float sumf = 0.0f; + // NOTE: can't really use GGML_SIMD here because d_state is usually 16 + // and also because expf is used within the loop. // d_state for (int i0 = 0; i0 < nc; ++i0) { - const int ii = i0 + i*nc; + const int i = i0 + ii*nc; const int ig = i0 + (h & (ng - 1))*nc; // state = prev_state * dA + dB * x - const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[ig]; - s[ii] = state; + s[i] = state; } - y[i] = sumf + x[i] * D[h]; + y[ii] = sumf + x[ii] * D[h]; } } } From 2bfe9de6d3a3598d4b778f9b144bb8ac33c2797b Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 18 Aug 2024 22:43:39 -0400 Subject: [PATCH 027/117] llama : support running Mamba-Codestral-7B-v0.1 --- convert_hf_to_gguf.py | 4 ++++ src/llama.cpp | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0ac64574a3043..a5bdd5def2029 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2843,6 +2843,10 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused + if name.startswith("model.backbone") or name.startswith("model.lm_head"): + # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2 + name = name.removeprefix("model.") + if name.endswith(".dt_bias"): name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" diff --git a/src/llama.cpp b/src/llama.cpp index 5be0ef7a2ac7a..fd80361bd7605 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9383,7 +9383,7 @@ static struct ggml_tensor * llm_build_mamba2( // grouped RMS norm y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); y = llm_build_norm(ctx, y, hparams, - model.layers[il].ssm_norm, NULL, + ggml_reshape_2d(ctx, model.layers[il].ssm_norm, d_inner / n_group, n_group), NULL, LLM_NORM_RMS, cb, il); y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs); From aff96920f972d8e042dfdef6dc08644cd8df0234 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 21 Aug 2024 16:28:07 -0400 Subject: [PATCH 028/117] llama : fix Mamba-2 conv state saving * ggml : make the ggml_mul fast broadcast path more consistently formatted --- ggml/src/ggml.c | 4 ++-- src/llama.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f8e708088b357..415fa6901304a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10226,11 +10226,11 @@ static void ggml_compute_forward_mul_f32( if (scale == 0.0f) { // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, // but it is useful when resetting the state of recurrent models. - memset((char *)dst->data + ir*nb1, 0, nb1); + memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float)); } else { if (dst->data != src0->data) { // src0 is same shape as dst => same indices - memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float)); + memcpy((char *) dst->data + ir*nb1, (char *) src0->data + ir*nb01, ne0 * sizeof(float)); } if (scale != 1.0f) { ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); diff --git a/src/llama.cpp b/src/llama.cpp index fd80361bd7605..03f93164a89e8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9335,7 +9335,7 @@ static struct ggml_tensor * llm_build_mamba2( ggml_cpy(ctx, last_conv, ggml_view_1d(ctx, conv_states_all, (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), - kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); // 1D convolution // The equivalent is to make a self-overlapping view of conv_x From e04910dc48966f1cbc7309d12b8e1b55bdd33df2 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 21 Aug 2024 23:06:22 -0400 Subject: [PATCH 029/117] llama : remove unused variable --- src/llama.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 03f93164a89e8..dda3d51b017d6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7718,7 +7718,6 @@ static bool llm_load_tensors( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t n_head = hparams.ssm_dt_rank; - const int64_t head_dim = n_embd / n_head; const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; @@ -9287,7 +9286,7 @@ static struct ggml_tensor * llm_build_mamba2( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t n_head = hparams.ssm_dt_rank; - const int64_t head_dim = d_inner / n_head; // FIXME + const int64_t head_dim = d_inner / n_head; const int64_t n_group = hparams.ssm_n_group; const int64_t n_seqs = batch.n_seqs; From fa358e707132ace9012cb90880abe86fd32464a6 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 22 Aug 2024 01:13:43 -0400 Subject: [PATCH 030/117] llama : add missing break --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index dda3d51b017d6..5b6b6707a1c95 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5628,7 +5628,7 @@ static void llm_load_hparams( } break; default: model.type = e_model::MODEL_UNKNOWN; } - } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); From 38913dc8ddd1e119df0e0cfcacfb260b9b1f5c02 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 22 Aug 2024 14:31:12 -0400 Subject: [PATCH 031/117] convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present The tokenzier.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly. --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a5bdd5def2029..4851926b7b98f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2801,13 +2801,13 @@ def set_vocab(self): vocab_size = -(vocab_size // -pad_vocab) * pad_vocab self.hparams["vocab_size"] = vocab_size - if (self.dir_model / "tokenizer.json").is_file(): - self._set_vocab_gpt2() - elif (self.dir_model / "tokenizer.model").is_file(): + if (self.dir_model / "tokenizer.model").is_file(): self._set_vocab_sentencepiece() elif (self.dir_model / "tokenizer.model.v3").is_file(): # mamba-codestral raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}") + elif (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() else: # Use the GPT-NeoX tokenizer when no tokenizer files are present self._set_vocab_builtin("gpt-neox", vocab_size) From fcb889cf7fb6588a6565f4cc6373be3f53ff25ca Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 1 Sep 2024 20:31:30 -0400 Subject: [PATCH 032/117] llama : session saving and reloading for hybrid models --- include/llama.h | 4 +- src/llama.cpp | 519 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 390 insertions(+), 133 deletions(-) diff --git a/include/llama.h b/include/llama.h index 59f38936fbed7..6f6e73c901091 100644 --- a/include/llama.h +++ b/include/llama.h @@ -38,10 +38,10 @@ #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 8 +#define LLAMA_SESSION_VERSION 9 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ -#define LLAMA_STATE_SEQ_VERSION 2 +#define LLAMA_STATE_SEQ_VERSION 3 #ifdef __cplusplus extern "C" { diff --git a/src/llama.cpp b/src/llama.cpp index 213a27cc8e2db..0f55196cf8edb 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -19839,8 +19839,28 @@ struct llama_data_write { } } + void write_rs_cache_meta(const llama_rs_cache & rs_self, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) { + + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = rs_self.cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_nodes.size() : 0; + + write(&pos, sizeof(pos)); + write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_node : cell.seq_nodes) { + write(&seq_node.seq_id, sizeof(seq_node.seq_id)); + } + } + } + } + } + void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) { - const struct llama_kv_cache & kv_self = ctx->kv_self; + const struct llama_kv_cache & kv_self = ctx->cache.kv; const struct llama_hparams & hparams = ctx->model.hparams; const uint32_t v_trans = kv_self.v_trans ? 1 : 0; @@ -19849,12 +19869,10 @@ struct llama_data_write { write(&v_trans, sizeof(v_trans)); write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); // Write key type const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; @@ -19874,7 +19892,7 @@ struct llama_data_write { if (!kv_self.v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Write value type const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; @@ -19895,7 +19913,7 @@ struct llama_data_write { // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = kv_self.size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Write value type const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; @@ -19922,43 +19940,151 @@ struct llama_data_write { } } - void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) { - const struct llama_kv_cache & kv_self = ctx->kv_self; - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; + void write_rs_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) { + const struct llama_rs_cache & rs_self = ctx->cache.rs; + const struct llama_hparams & hparams = ctx->model.hparams; - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = kv_self.size; - for (uint32_t i = 0; i < kv_self.size; ++i) { - const auto & cell = kv_self.cells[i]; - if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { - ++cell_count; - if (cell_range_begin == kv_self.size) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != kv_self.size) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = kv_self.size; + const uint32_t n_layer = hparams.n_layer; + + write(&n_layer, sizeof(n_layer)); + + // Iterate and write all recurrent states, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_r = hparams.n_embd_r(il); + + // Write type + const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type; + write(&r_type_i, sizeof(r_type_i)); + + // Write row size + const uint64_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r); + write(&r_size_row, sizeof(r_size_row)); + + // Read each range of cells of r_size length each and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * r_size_row; + write_tensor_data(rs_self.r_l[il], range.first * r_size_row, buf_size); + } + } + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_s = hparams.n_embd_s(il); + + // Write type + const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type; + write(&s_type_i, sizeof(s_type_i)); + + // Write row size + const uint64_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s); + write(&s_size_row, sizeof(s_size_row)); + + // Read each range of cells of s_size length each and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * s_size_row; + write_tensor_data(rs_self.s_l[il], range.first * s_size_row, buf_size); + } + } + } + + void write_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) { + const struct llama_kv_cache & kv_self = ctx->cache.kv; + const struct llama_rs_cache & rs_self = ctx->cache.rs; + std::vector> kv_cell_ranges; // ranges, from inclusive, to exclusive + std::vector> rs_cell_ranges; // ranges, from inclusive, to exclusive + uint32_t kv_cell_count = 0; + uint32_t rs_cell_count = 0; + // Transformer KV cache + { + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = kv_self.size; + for (uint32_t i = 0; i < kv_self.size; ++i) { + const auto & cell = kv_self.cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++kv_cell_count; + if (cell_range_begin == kv_self.size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != kv_self.size) { + kv_cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = kv_self.size; + } } } + if (cell_range_begin != kv_self.size) { + kv_cell_ranges.emplace_back(cell_range_begin, kv_self.size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : kv_cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(kv_cell_count == cell_count_check); } - if (cell_range_begin != kv_self.size) { - cell_ranges.emplace_back(cell_range_begin, kv_self.size); + // Recurrent state cache + if (seq_id == -1) { + // Find all the ranges of cells + uint32_t cell_range_begin = rs_self.size; + for (uint32_t i = 0; i < rs_self.size; ++i) { + const auto & cell = rs_self.cells[i]; + if (!cell.is_empty()) { + ++rs_cell_count; + if (cell_range_begin == rs_self.size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != rs_self.size) { + rs_cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = rs_self.size; + } + } + } + if (cell_range_begin != rs_self.size) { + rs_cell_ranges.emplace_back(cell_range_begin, rs_self.size); + } + + } else { + // Find the cell ranges of the specified seq_id + if ((size_t) seq_id < rs_self.seq_tails.size()) { + int32_t tail_cell_id = rs_self.seq_tails[seq_id].tail; + if (tail_cell_id >= 0) { + ++rs_cell_count; + rs_cell_ranges.emplace_back(tail_cell_id, tail_cell_id + 1); + } + } } - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; + { + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : rs_cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(rs_cell_count == cell_count_check); } - GGML_ASSERT(cell_count == cell_count_check); - write(&cell_count, sizeof(cell_count)); + write(&kv_cell_count, sizeof(kv_cell_count)); + write(&rs_cell_count, sizeof(rs_cell_count)); - write_kv_cache_meta(kv_self, cell_ranges, seq_id); - write_kv_cache_data(ctx, cell_ranges); + if (seq_id == -1) { + // write metadata for both when the whole cache needs to be saved + write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id); + write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id); + } else if (kv_cell_count > 0) { + write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id); + } else { + write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id); + } + if (kv_cell_count > 0) { + write_kv_cache_data(ctx, kv_cell_ranges); + } + if (rs_cell_count > 0) { + write_rs_cache_data(ctx, rs_cell_ranges); + } } }; @@ -20050,108 +20176,98 @@ struct llama_data_read { } } - bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) { - struct llama_kv_cache & kv_self = ctx->kv_self; + bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count) { + if (cell_count == 0) { return true; } + struct llama_past & cache = ctx->cache; + struct llama_kv_cache & kv_self = cache.kv; + + // whole KV cache restore + + if (cell_count > kv_self.size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } - if (dest_seq_id != -1) { - // single sequence + for (uint32_t i = 0; i < cell_count; ++i) { + llama_kv_cell & cell = kv_self.cells[i]; - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + llama_pos pos; + uint32_t n_seq_id; - llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); - batch.n_tokens = cell_count; - batch.n_seq_tokens = cell_count; - batch.n_seqs = 1; + read_to(&pos, sizeof(pos)); + read_to(&n_seq_id, sizeof(n_seq_id)); - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; + cell.pos = pos; - read_to(&pos, sizeof(pos)); - read_to(&n_seq_id, sizeof(n_seq_id)); + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + read_to(&seq_id, sizeof(seq_id)); - if (n_seq_id != 0) { - LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); return false; } - batch.pos[i] = pos; - } - batch.n_seq_id[0] = 1; - batch.seq_id[0] = &dest_seq_id; - if (!llama_kv_cache_find_slot(kv_self, batch)) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); - return false; + cell.seq_id.insert(seq_id); } + } - // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(kv_self.head + cell_count <= kv_self.size); - GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]); - GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); - GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); - GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); - } else { - // whole KV cache restore + kv_self.head = 0; + kv_self.used = cell_count; - if (cell_count > kv_self.size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); - return false; - } + return true; + } - llama_kv_cache_clear(kv_self); + bool read_rs_cache_meta(struct llama_context * ctx, uint32_t cell_count) { + if (cell_count == 0) { return true; } + struct llama_past & cache = ctx->cache; + struct llama_rs_cache & rs_self = cache.rs; - for (uint32_t i = 0; i < cell_count; ++i) { - llama_kv_cell & cell = kv_self.cells[i]; + // whole RS cache restore - llama_pos pos; - uint32_t n_seq_id; + if (cell_count > rs_self.size) { + LLAMA_LOG_ERROR("%s: not enough cells in rs cache\n", __func__); + return false; + } - read_to(&pos, sizeof(pos)); - read_to(&n_seq_id, sizeof(n_seq_id)); + for (uint32_t i = 0; i < cell_count; ++i) { + llama_rs_cell & cell = rs_self.cells[i]; - cell.pos = pos; + llama_pos pos; + uint32_t n_seq_id; - for (uint32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id; - read_to(&seq_id, sizeof(seq_id)); + read_to(&pos, sizeof(pos)); + read_to(&n_seq_id, sizeof(n_seq_id)); - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - return false; - } + cell.pos = pos; + cell.src = i; - cell.seq_id.insert(seq_id); + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + read_to(&seq_id, sizeof(seq_id)); - if (kv_self.recurrent) { - int32_t & tail = kv_self.cells[seq_id].tail; - if (tail != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); - return false; - } - tail = i; - } + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + return false; } - } - kv_self.head = 0; - kv_self.used = cell_count; - } + cell.insert_node(seq_id); - if (kv_self.recurrent) { - for (uint32_t i = 0; i < cell_count; ++i) { - uint32_t cell_id = kv_self.head + i; - // make sure the recurrent states will keep their restored state - kv_self.cells[cell_id].src = cell_id; } } + rs_self.head = 0; + rs_self.used = cell_count; + + rs_self.rebuild(/* debug */ false); + return true; } bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) { + if (cell_count == 0) { return true; } const struct llama_hparams & hparams = ctx->model.hparams; - struct llama_kv_cache & kv_self = ctx->kv_self; + struct llama_kv_cache & kv_self = ctx->cache.kv; uint32_t v_trans; uint32_t n_layer; read_to(&v_trans, sizeof(v_trans)); @@ -20172,7 +20288,7 @@ struct llama_data_read { // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); // Read type of key int32_t k_type_i_ref; @@ -20192,15 +20308,13 @@ struct llama_data_read { return false; } - if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row); - } + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row); } if (!kv_self.v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Read type of value int32_t v_type_i_ref; @@ -20220,15 +20334,13 @@ struct llama_data_read { return false; } - if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row); - } + // Read and set the values for the whole cell range + ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row); } } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Read type of value int32_t v_type_i_ref; @@ -20256,29 +20368,174 @@ struct llama_data_read { return false; } - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el; - ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); - } + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el; + ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); } } } return true; } - void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) { - uint32_t cell_count; - read_to(&cell_count, sizeof(cell_count)); + bool read_rs_cache_data(struct llama_context * ctx, uint32_t cell_count) { + if (cell_count == 0) { return true; } + const struct llama_hparams & hparams = ctx->model.hparams; + struct llama_rs_cache & rs_self = ctx->cache.rs; + uint32_t n_layer; + read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > rs_self.size) { + LLAMA_LOG_ERROR("%s: not enough cells in rs cache to restore state (%u > %u)\n", __func__, cell_count, rs_self.size); + return false; + } + + // For each layer, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_r = hparams.n_embd_r(il); + + // Read type of key + int32_t r_type_i_ref; + read_to(&r_type_i_ref, sizeof(r_type_i_ref)); + const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type; + if (r_type_i != r_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t r_size_row_ref; + read_to(&r_size_row_ref, sizeof(r_size_row_ref)); + const size_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r); + if (r_size_row != r_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il); + return false; + } + + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(rs_self.r_l[il], read(cell_count * r_size_row), rs_self.head * r_size_row, cell_count * r_size_row); + } + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_s = hparams.n_embd_s(il); + + // Read type of key + int32_t s_type_i_ref; + read_to(&s_type_i_ref, sizeof(s_type_i_ref)); + const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type; + if (s_type_i != s_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t s_size_row_ref; + read_to(&s_size_row_ref, sizeof(s_size_row_ref)); + const size_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s); + if (s_size_row != s_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il); + return false; + } + + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(rs_self.s_l[il], read(cell_count * s_size_row), rs_self.head * s_size_row, cell_count * s_size_row); + } + + return true; + } + + bool read_cache_seq_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) { + + if (seq_id < 0 || seq_id >= llama_n_seq_max(ctx)) { + LLAMA_LOG_ERROR("%s: seq_id out of range [0, %d): %d\n", __func__, llama_n_seq_max(ctx), seq_id); + return false; + } + + // single sequence + + llama_past & cache = ctx->cache; + llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + read_to(&pos, sizeof(pos)); + read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &seq_id; + if (!llama_past_find_slot(cache, batch)) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + if (cache.kv.size > 0) { + // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(cache.kv.head + cell_count <= cache.kv.size); + GGML_ASSERT(cache.kv.cells[cache.kv.head].pos == batch.pos[0]); + GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cache.kv.cells[cache.kv.head].has_seq_id(seq_id)); + GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].has_seq_id(seq_id)); + } + if (cache.rs.size > 0) { + GGML_ASSERT(cache.rs.head + cache.rs.n <= cache.rs.size); + GGML_ASSERT(cache.rs.n == 1); + GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cache.rs.cells[cache.rs.head].has_seq_id(seq_id)); + GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].has_seq_id(seq_id)); + // Prevent cells from being cleared + for (uint32_t i = cache.rs.head; i < cache.rs.head + cache.rs.n; ++i) { + cache.rs.cells[i].src = i; + } + } + + return true; + } + + void read_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) { + uint32_t kv_cell_count; + read_to(&kv_cell_count, sizeof(kv_cell_count)); + uint32_t rs_cell_count; + read_to(&rs_cell_count, sizeof(rs_cell_count)); + + bool res = true; + + if (seq_id == -1) { + llama_past_clear(ctx); + res = read_kv_cache_meta(ctx, kv_cell_count) && read_rs_cache_meta(ctx, rs_cell_count); + } else { + llama_past_seq_rm(ctx, seq_id, -1, -1); + // Only a single recurrent cell at most, + // because otherwise the cells can be shuffled when a slot is allocated + if (rs_cell_count > 1) { + LLAMA_LOG_ERROR("%s: too many recurrent state cells for single-sequence session\n", __func__); + res = false; + } + res = res && read_cache_seq_meta(ctx, std::max(kv_cell_count, rs_cell_count), seq_id); + } - bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count); + res = res && read_kv_cache_data(ctx, kv_cell_count) && read_rs_cache_data(ctx, rs_cell_count); if (!res) { if (seq_id == -1) { - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); } else { - llama_kv_cache_seq_rm(ctx, seq_id, -1, -1); + llama_past_seq_rm(ctx, seq_id, -1, -1); } throw std::runtime_error("failed to restore kv cache"); } @@ -20433,7 +20690,7 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da data_ctx.write_logits(ctx); data_ctx.write_embeddings(ctx); - data_ctx.write_kv_cache(ctx); + data_ctx.write_cache(ctx); return data_ctx.get_size_written(); } @@ -20473,7 +20730,7 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da data_ctx.read_logits(ctx); data_ctx.read_embeddings(ctx); - data_ctx.read_kv_cache(ctx); + data_ctx.read_cache(ctx); return data_ctx.get_size_read(); } @@ -20569,7 +20826,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) { llama_synchronize(ctx); - data_ctx.write_kv_cache(ctx, seq_id); + data_ctx.write_cache(ctx, seq_id); return data_ctx.get_size_written(); } @@ -20592,7 +20849,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_ static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) { llama_synchronize(ctx); - data_ctx.read_kv_cache(ctx, dest_seq_id); + data_ctx.read_cache(ctx, dest_seq_id); return data_ctx.get_size_read(); } From 9d3f44dad426acc26d35e3b6cf1462d3a3f43113 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 1 Sep 2024 21:46:27 -0400 Subject: [PATCH 033/117] convert_hf : fix Jamba conversion --- convert_hf_to_gguf.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 00059bd01afca..e9bb4b20bd6d3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2910,7 +2910,6 @@ def set_gguf_parameters(self): n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count) ] - self.gguf_writer.add_name(self.dir_model.name) self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"])) self.gguf_writer.add_embedding_length(d_model) @@ -2979,8 +2978,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield new_name, data_torch - def write_tensors(self): - super().write_tensors() + def prepare_tensors(self): + super().prepare_tensors() if self._experts is not None: # flatten `list[dict[str, Tensor]]` into `list[str]` @@ -2988,20 +2987,6 @@ def write_tensors(self): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") - # same as Mamba - def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: - del n_dims # unused - - return bid is not None and new_name in ( - self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [ - gguf.MODEL_TENSOR.SSM_CONV1D, - gguf.MODEL_TENSOR.SSM_X, - gguf.MODEL_TENSOR.SSM_DT, - gguf.MODEL_TENSOR.SSM_A, - gguf.MODEL_TENSOR.SSM_D, - ] - ) - @Model.register("CohereForCausalLM") class CommandR2Model(Model): From 5f62db790b8e548eb7db0f69a9fadb7f809f6c96 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 1 Sep 2024 21:50:27 -0400 Subject: [PATCH 034/117] llama : fix mixed signedness comparison --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index 842de9118876c..cf7dccb384f2b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20963,7 +20963,7 @@ struct llama_data_read { bool read_cache_seq_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) { - if (seq_id < 0 || seq_id >= llama_n_seq_max(ctx)) { + if (seq_id < 0 || seq_id >= (llama_seq_id) llama_n_seq_max(ctx)) { LLAMA_LOG_ERROR("%s: seq_id out of range [0, %d): %d\n", __func__, llama_n_seq_max(ctx), seq_id); return false; } From 375de5b1f8c07b5bfdef7f00b738eb176f8431ba Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 1 Sep 2024 21:59:24 -0400 Subject: [PATCH 035/117] llama : use unused n_embd_k_gqa in k_shift This also slightly reduces the diff from the master branch --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index cf7dccb384f2b..043f3d7ec7853 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -10806,7 +10806,7 @@ struct llm_build_context { ggml_view_3d(ctx0, kv_self.k_l[il], n_embd_head_k, n_head_kv, n_ctx, ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), - ggml_row_size(kv_self.k_l[il]->type, hparams.n_embd_k_gqa(il)), + ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), 0), lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); From 4bb4b22a58b06da1fef8193c97ead7a9099c37bf Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 14 Sep 2024 15:00:07 -0400 Subject: [PATCH 036/117] llama : begin renaming llama_past back to llama_kv_cache --- src/llama.cpp | 106 ++++++++++++++++++++++++-------------------------- 1 file changed, 50 insertions(+), 56 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 043f3d7ec7853..2dc413da903b6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2719,7 +2719,7 @@ struct llama_kv_cell { }; // ring-buffer of cached KV data -struct llama_kv_cache { +struct llama_kv_self_cache { bool has_shift = false; bool do_defrag = false; bool v_trans = true; // the value tensor is transposed @@ -2820,7 +2820,7 @@ struct llama_rs_seq_meta { }; // ring-buffered tree of cached recurrent state data -struct llama_rs_cache { +struct llama_rs_self_cache { uint32_t head = 0; // first state used for the last slot uint32_t size = 0; @@ -3444,12 +3444,12 @@ struct llama_rs_cache { } }; -struct llama_past { +struct llama_kv_cache { // key + value cache for self attention - llama_kv_cache kv; + llama_kv_self_cache kv; // recurrent state cache for state space models - llama_rs_cache rs; + llama_rs_self_cache rs; std::vector ctxs; std::vector bufs; @@ -3463,7 +3463,7 @@ struct llama_past { return size; } - ~llama_past() { + ~llama_kv_cache() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); } @@ -3949,7 +3949,7 @@ struct llama_context { struct llama_cparams cparams; struct llama_sampling sampling; struct llama_sbatch sbatch; - struct llama_past cache; + struct llama_kv_cache cache; struct llama_control_vector cvec; std::unordered_map lora_adapters; @@ -4195,8 +4195,8 @@ static size_t llama_get_device_memory(const llama_model & model, int device) { // kv and rs cache helpers // -static bool llama_past_init( - struct llama_past & cache, +static bool llama_kv_cache_init( + struct llama_kv_cache & cache, const llama_context * ctx, ggml_type type_k, ggml_type type_v, @@ -4300,11 +4300,11 @@ static bool llama_past_init( // no buffer was needed, so this is fine return true; } - LLAMA_LOG_ERROR("%s: failed to allocate buffer for past cache\n", __func__); + LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); return false; } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s past cache size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); cache.bufs.push_back(buf); } @@ -4315,9 +4315,9 @@ static bool llama_past_init( // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. -static bool llama_past_find_slot( - struct llama_past & cache, - const struct llama_ubatch & batch) { +static bool llama_kv_cache_find_slot( + struct llama_kv_cache & cache, + const struct llama_ubatch & batch) { const uint32_t kv_size = cache.kv.size; const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; @@ -4563,7 +4563,7 @@ static bool llama_past_find_slot( } // find how many KV cells are currently in use -static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { +static uint32_t llama_kv_cache_cell_max(const struct llama_kv_self_cache & cache) { for (uint32_t i = cache.size; i > 0; --i) { const llama_kv_cell & cell = cache.cells[i - 1]; @@ -4576,7 +4576,7 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { } // find how many recurrent state cells are currently in use -static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) { +static uint32_t llama_rs_cache_cell_max(const struct llama_rs_self_cache & cache) { for (uint32_t i = cache.size; i > 0; --i) { const llama_rs_cell & cell = cache.cells[i - 1]; @@ -4588,7 +4588,7 @@ static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) { return 0; } -static void llama_past_clear(struct llama_past & cache) { +static void llama_past_clear(struct llama_kv_cache & cache) { if (cache.kv.size > 0) { for (uint32_t i = 0; i < cache.kv.size; ++i) { llama_kv_cell & kv_cell = cache.kv.cells[i]; @@ -4623,7 +4623,7 @@ static void llama_past_clear(struct llama_past & cache) { } static llama_pos llama_past_seq_rm( - struct llama_past & cache, + struct llama_kv_cache & cache, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { @@ -4722,7 +4722,7 @@ static llama_pos llama_past_seq_rm( } static llama_pos llama_past_seq_cp( - struct llama_past & cache, + struct llama_kv_cache & cache, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, @@ -4786,7 +4786,7 @@ static llama_pos llama_past_seq_cp( return n_past; } -static void llama_past_seq_keep(struct llama_past & cache, llama_seq_id seq_id) { +static void llama_past_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { if (cache.rs.size > 0) { uint32_t new_head = cache.rs.size; @@ -4837,7 +4837,7 @@ static void llama_past_seq_keep(struct llama_past & cache, llama_seq_id seq_id) } static void llama_past_seq_add( - struct llama_past & cache, + struct llama_kv_cache & cache, llama_seq_id seq_id, llama_pos p0, llama_pos p1, @@ -4905,7 +4905,7 @@ static void llama_past_seq_add( } static void llama_past_seq_div( - struct llama_past & cache, + struct llama_kv_cache & cache, llama_seq_id seq_id, llama_pos p0, llama_pos p1, @@ -4945,7 +4945,7 @@ static void llama_past_seq_div( } } -static llama_pos llama_past_seq_pos_max(struct llama_past & cache, llama_seq_id seq_id) { +static llama_pos llama_past_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { llama_pos result = -1; if (cache.rs.size > 0) { @@ -4970,7 +4970,7 @@ static llama_pos llama_past_seq_pos_max(struct llama_past & cache, llama_seq_id return result; } -static void llama_kv_cache_defrag(struct llama_kv_cache & cache) { +static void llama_kv_cache_defrag(struct llama_kv_self_cache & cache) { cache.do_defrag = true; } @@ -9772,7 +9772,7 @@ static void llm_build_kv_store( struct ggml_context * ctx, const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache & kv, + const llama_kv_self_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * k_cur, struct ggml_tensor * v_cur, @@ -10129,7 +10129,7 @@ static struct ggml_tensor * llm_build_moe_ffn( static struct ggml_tensor * llm_build_kqv( struct ggml_context * ctx, struct llama_context & lctx, - const llama_kv_cache & kv, + const llama_kv_self_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, struct ggml_tensor * wo_b, @@ -10260,7 +10260,7 @@ static struct ggml_tensor * llm_build_kqv( static struct ggml_tensor * llm_build_kv( struct ggml_context * ctx, struct llama_context & lctx, - const llama_kv_cache & kv, + const llama_kv_self_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, struct ggml_tensor * wo_b, @@ -10344,7 +10344,7 @@ static struct ggml_tensor * llm_build_mamba( int il) { const llama_model & model = lctx.model; const llama_hparams & hparams = model.hparams; - const llama_rs_cache & rs = lctx.cache.rs; + const llama_rs_self_cache & rs = lctx.cache.rs; const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; @@ -10661,8 +10661,8 @@ struct llm_build_context { const llama_hparams & hparams; const llama_cparams & cparams; const llama_ubatch & batch; - const llama_kv_cache & kv_self; - const llama_rs_cache & rs_self; + const llama_kv_self_cache & kv_self; + const llama_rs_self_cache & rs_self; const int64_t n_embd; const int64_t n_layer; @@ -17367,17 +17367,11 @@ static int llama_decode_internal( if (hparams.causal_attn) { llama_kv_cache_update(&lctx); - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self.head > kv_self.used + 2*n_tokens) { - kv_self.head = 0; - } - - if (!llama_past_find_slot(lctx.cache, ubatch)) { + if (!llama_kv_cache_find_slot(lctx.cache, ubatch)) { return 1; } - // TODO: move into llama_past_find_slot + // TODO: move into llama_kv_cache_find_slot if (kv_self.size > 0) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears @@ -19557,7 +19551,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_past_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -19575,7 +19569,7 @@ struct llama_context * llama_new_context_with_model( memory_size_s += ggml_nbytes(s); } - LLAMA_LOG_INFO("%s: SSM state size = %8.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: RS self size = %8.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), ggml_type_name(GGML_TYPE_F32), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(GGML_TYPE_F32), (float)memory_size_s / (1024.0f * 1024.0f)); @@ -19592,7 +19586,7 @@ struct llama_context * llama_new_context_with_model( memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: KV cache size = %8.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: KV self size = %8.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); @@ -20052,7 +20046,7 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) { } void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) { - const llama_kv_cache & kv_self = ctx->cache.kv; + const llama_kv_self_cache & kv_self = ctx->cache.kv; if (uint32_t(view->n_cells) < kv_self.size || view->cells == nullptr) { view->n_cells = int32_t(kv_self.size); void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells); @@ -20333,7 +20327,7 @@ struct llama_data_write { } } - void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) { + void write_kv_cache_meta(const llama_kv_self_cache & kv_self, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) { for (const auto & range : cell_ranges) { for (uint32_t i = range.first; i < range.second; ++i) { @@ -20353,7 +20347,7 @@ struct llama_data_write { } } - void write_rs_cache_meta(const llama_rs_cache & rs_self, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) { + void write_rs_cache_meta(const llama_rs_self_cache & rs_self, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) { for (const auto & range : cell_ranges) { for (uint32_t i = range.first; i < range.second; ++i) { @@ -20374,7 +20368,7 @@ struct llama_data_write { } void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) { - const struct llama_kv_cache & kv_self = ctx->cache.kv; + const struct llama_kv_self_cache & kv_self = ctx->cache.kv; const struct llama_hparams & hparams = ctx->model.hparams; const uint32_t v_trans = kv_self.v_trans ? 1 : 0; @@ -20455,7 +20449,7 @@ struct llama_data_write { } void write_rs_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) { - const struct llama_rs_cache & rs_self = ctx->cache.rs; + const struct llama_rs_self_cache & rs_self = ctx->cache.rs; const struct llama_hparams & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; @@ -20503,8 +20497,8 @@ struct llama_data_write { } void write_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) { - const struct llama_kv_cache & kv_self = ctx->cache.kv; - const struct llama_rs_cache & rs_self = ctx->cache.rs; + const struct llama_kv_self_cache & kv_self = ctx->cache.kv; + const struct llama_rs_self_cache & rs_self = ctx->cache.rs; std::vector> kv_cell_ranges; // ranges, from inclusive, to exclusive std::vector> rs_cell_ranges; // ranges, from inclusive, to exclusive uint32_t kv_cell_count = 0; @@ -20692,8 +20686,8 @@ struct llama_data_read { bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count) { if (cell_count == 0) { return true; } - struct llama_past & cache = ctx->cache; - struct llama_kv_cache & kv_self = cache.kv; + struct llama_kv_cache & cache = ctx->cache; + struct llama_kv_self_cache & kv_self = cache.kv; // whole KV cache restore @@ -20734,8 +20728,8 @@ struct llama_data_read { bool read_rs_cache_meta(struct llama_context * ctx, uint32_t cell_count) { if (cell_count == 0) { return true; } - struct llama_past & cache = ctx->cache; - struct llama_rs_cache & rs_self = cache.rs; + struct llama_kv_cache & cache = ctx->cache; + struct llama_rs_self_cache & rs_self = cache.rs; // whole RS cache restore @@ -20781,7 +20775,7 @@ struct llama_data_read { bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) { if (cell_count == 0) { return true; } const struct llama_hparams & hparams = ctx->model.hparams; - struct llama_kv_cache & kv_self = ctx->cache.kv; + struct llama_kv_self_cache & kv_self = ctx->cache.kv; uint32_t v_trans; uint32_t n_layer; read_to(&v_trans, sizeof(v_trans)); @@ -20895,7 +20889,7 @@ struct llama_data_read { bool read_rs_cache_data(struct llama_context * ctx, uint32_t cell_count) { if (cell_count == 0) { return true; } const struct llama_hparams & hparams = ctx->model.hparams; - struct llama_rs_cache & rs_self = ctx->cache.rs; + struct llama_rs_self_cache & rs_self = ctx->cache.rs; uint32_t n_layer; read_to(&n_layer, sizeof(n_layer)); @@ -20970,7 +20964,7 @@ struct llama_data_read { // single sequence - llama_past & cache = ctx->cache; + llama_kv_cache & cache = ctx->cache; llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); batch.n_tokens = cell_count; batch.n_seq_tokens = cell_count; @@ -20992,7 +20986,7 @@ struct llama_data_read { } batch.n_seq_id[0] = 1; batch.seq_id[0] = &seq_id; - if (!llama_past_find_slot(cache, batch)) { + if (!llama_kv_cache_find_slot(cache, batch)) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; } From 273e7a495ad8c93bb9ba8123c1a3de3c68f93cf9 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 30 Sep 2024 15:52:42 -0400 Subject: [PATCH 037/117] llama : avoid redundant state copy for Mamba 1 and 2 --- ggml/include/ggml.h | 3 +- ggml/src/ggml.c | 50 ++++++------ src/llama.cpp | 154 +++++++++++++++++-------------------- tests/test-backend-ops.cpp | 54 ++++++++++--- 4 files changed, 142 insertions(+), 119 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index fec6798ff6d06..1fc53bebebf30 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1833,7 +1833,8 @@ extern "C" { struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D); + struct ggml_tensor * D, + struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 12e4f26942f86..1c4c393e55d06 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7598,7 +7598,8 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D) { + struct ggml_tensor * D, + struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); @@ -7609,6 +7610,7 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); { const int64_t d_state = s->ne[0]; @@ -7623,21 +7625,19 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(ggml_is_3d(dt)); GGML_ASSERT(s->ne[1] == head_dim); GGML_ASSERT(s->ne[2] == n_head); - GGML_ASSERT(s->ne[3] == n_seqs); GGML_ASSERT(B->ne[0] == d_state); GGML_ASSERT(B->ne[2] == n_seq_tokens); GGML_ASSERT(B->ne[3] == n_seqs); GGML_ASSERT(D->ne[0] == n_head); GGML_ASSERT(ggml_is_vector(D)); + GGML_ASSERT(ids->ne[0] == n_seqs); + GGML_ASSERT(ggml_is_vector(ids)); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); - if (ggml_is_vector(A)) { - // Mamba-2 - GGML_ASSERT(A->ne[0] == n_head); - } else { - // Mamba-1 + if (A->ne[0] != 1) { + // Mamba-1 has more granular decay factors GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == n_head); - GGML_ASSERT(ggml_is_matrix(A)); } } @@ -7649,7 +7649,7 @@ struct ggml_tensor * ggml_ssm_scan( } // concatenated y + ssm_states - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]); result->op = GGML_OP_SSM_SCAN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -7660,6 +7660,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[4] = B; result->src[5] = C; result->src[6] = D; + result->src[7] = ids; return result; } @@ -16635,13 +16636,14 @@ static void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs} + const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+} const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} - const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head} + const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} + const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; @@ -16651,11 +16653,12 @@ static void ggml_compute_forward_ssm_scan_f32( const int64_t nh = src1->ne[1]; // n_head const int64_t ng = src4->ne[1]; const int64_t nt = src1->ne[2]; // number of tokens per sequence - const int64_t ns = src0->ne[3]; // number of sequences in the batch + const int64_t ns = src1->ne[3]; // number of sequences in the batch - const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1); + // can't use ggml_nbytes because src1 is not necessarily contiguous + const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1); - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); @@ -16663,6 +16666,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src6->nb[0] == sizeof(float)); + GGML_ASSERT(src7->nb[0] == sizeof(int32_t)); // allows optimizing the modulo since n_group should be a power of 2 GGML_ASSERT((ng & -ng) == ng); @@ -16673,22 +16677,22 @@ static void ggml_compute_forward_ssm_scan_f32( const int ih0 = dh*ith; const int ih1 = MIN(ih0 + dh, nh); + const int32_t * ids = (const int32_t *) src7->data; + for (int i3 = 0; i3 < ns; ++i3) { + const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} + float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} + for (int i2 = 0; i2 < nt; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns} const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} - const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} const float * D = (const float *) ((const char *) src6->data); // {nh} float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} - float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} - - // use the output as the source when it's not the first token-wise iteration - if (i2 > 0) { s0 = s; } - if (ggml_is_vector(src3)) { + if (src3->ne[0] == 1) { // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop // n_head @@ -16778,6 +16782,8 @@ static void ggml_compute_forward_ssm_scan_f32( } } } + // use the output as the source when it's not the first token-wise iteration + s0 = s; } } } diff --git a/src/llama.cpp b/src/llama.cpp index c11472112f8fb..3e1f8755ffb85 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2801,6 +2801,10 @@ struct llama_kv_cache { // computed before each graph build uint32_t n = 0; + // first zero-ed state + // NOTE: only used by recurrent models + int32_t rs_z = -1; + ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; @@ -3381,8 +3385,6 @@ struct llama_context { struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] @@ -3813,6 +3815,15 @@ static bool llama_kv_cache_find_slot( } } + // Find first to-be-cleared cell + cache.rs_z = -1; + for (int i = min; i <= max; ++i) { + if (cache.cells[i].src == -1) { + cache.rs_z = i; + break; + } + } + // allow getting the range of used cells, from head to head + n cache.head = min; cache.n = max - min + 1; @@ -9569,36 +9580,42 @@ static struct ggml_tensor * llm_build_kv( return cur; } -static struct ggml_tensor * llm_build_copy_mask_state( +static struct ggml_tensor * llm_build_rs( struct ggml_context * ctx, struct ggml_cgraph * graph, struct ggml_tensor * s, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t n_state, int32_t kv_size, int32_t kv_head, int32_t n_kv, - int32_t n_seqs) { + int32_t n_seqs, + bool avoid_copies = false) { struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size); - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // this shrinks the tensors's ne[1] to n_kv - states = ggml_get_rows(ctx, states, state_copy); - - // clear states of sequences which are starting at the beginning of this batch - // FIXME: zero-out NANs? - states = ggml_mul(ctx, states, state_mask); + // Clear a single state which will then be copied to the other cleared states. + // Note that this is a no-op when the view is zero-sized. + struct ggml_tensor * state_zero = ggml_view_1d(ctx, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); + ggml_build_forward_expand(graph, ggml_scale_inplace(ctx, state_zero, 0)); // copy states which won't be changed further (between n_seqs and n_kv) + struct ggml_tensor * states_extra = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); ggml_build_forward_expand(graph, ggml_cpy(ctx, - ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)), + states_extra, ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); - // the part of the states that will be used and modified - return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); + if (!avoid_copies) { + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // this shrinks the tensors's ne[1] to n_kv + states = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_seqs, 0)); + // the part of the states that will be used and modified + states = ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); + } + + return states; } // TODO: split @@ -9609,7 +9626,7 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t kv_head, int32_t n_kv, const llm_build_cb & cb, @@ -9639,14 +9656,14 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * ssm_states_all = kv.v_l[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, - graph, conv_states_all, state_copy, state_mask, + struct ggml_tensor * conv = llm_build_rs(ctx, + graph, conv_states_all, state_copy, rs_zero, hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs); - struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, - graph, ssm_states_all, state_copy, state_mask, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + struct ggml_tensor * ssm = llm_build_rs(ctx, + graph, ssm_states_all, state_copy, rs_zero, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9711,10 +9728,11 @@ static struct ggml_tensor * llm_build_mamba( x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); + struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -9746,7 +9764,7 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t kv_head, int32_t n_kv, const llm_build_cb & cb, @@ -9772,14 +9790,14 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_tensor * ssm_states_all = kv.v_l[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, - graph, conv_states_all, state_copy, state_mask, + struct ggml_tensor * conv = llm_build_rs(ctx, + graph, conv_states_all, state_copy, rs_zero, hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); - struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, - graph, ssm_states_all, state_copy, state_mask, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + struct ggml_tensor * ssm = llm_build_rs(ctx, + graph, ssm_states_all, state_copy, rs_zero, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9835,9 +9853,12 @@ static struct ggml_tensor * llm_build_mamba2( // {n_head, n_seq_tokens, n_seqs} dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); + // Use the same shape semantics for A as Mamba-1 + struct ggml_tensor * A = ggml_reshape_2d(ctx, model.layers[il].ssm_a, 1, n_head); // TODO: use semistructured matrices to implement state-space duality // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, model.layers[il].ssm_d, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10069,6 +10090,7 @@ struct llm_build_context { const int32_t n_outputs; const int32_t n_outputs_enc; const int32_t kv_head; // index of where we store new KV data in the cache + const int32_t rs_zero; // the first zero-ed recurrent state const int32_t n_ctx_orig; const bool flash_attn; @@ -10119,6 +10141,7 @@ struct llm_build_context { n_outputs (worst_case ? n_tokens : lctx.n_outputs), n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), + rs_zero (kv_self.rs_z), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), @@ -10147,8 +10170,6 @@ struct llm_build_context { lctx.inp_mean = nullptr; lctx.inp_cls = nullptr; lctx.inp_s_copy = nullptr; - lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; @@ -10332,13 +10353,6 @@ struct llm_build_context { return lctx.inp_s_copy; } - struct ggml_tensor * build_inp_s_mask() { - lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); - cb(lctx.inp_s_mask, "inp_s_mask", -1); - ggml_set_input(lctx.inp_s_mask); - return lctx.inp_s_mask; - } - struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { // find result_norm tensor for input struct ggml_tensor * inp = nullptr; @@ -13901,7 +13915,6 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); struct ggml_tensor * state_copy = build_inp_s_copy(); - struct ggml_tensor * state_mask = build_inp_s_mask(); for (int il = 0; il < n_layer; ++il) { // norm @@ -13912,15 +13925,13 @@ struct llm_build_context { switch (version) { case 2: - cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, state_copy, + rs_zero, kv_head, n_kv, cb, il); break; case 1: default: - cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, state_copy, + rs_zero, kv_head, n_kv, cb, il); break; } @@ -15946,7 +15957,6 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; struct ggml_tensor * state_copy = build_inp_s_copy(); - struct ggml_tensor * state_mask = build_inp_s_mask(); inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); @@ -15955,11 +15965,11 @@ struct llm_build_context { const llama_layer * layer = &model.layers[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0, - gf, kv_self.k_l[il], state_copy, state_mask, + struct ggml_tensor * token_shift = llm_build_rs(ctx0, + gf, kv_self.k_l[il], state_copy, rs_zero, hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs); - struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0, - gf, kv_self.v_l[il], state_copy, state_mask, + struct ggml_tensor * wkv_states = llm_build_rs(ctx0, + gf, kv_self.v_l[il], state_copy, rs_zero, hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs); cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); @@ -16329,18 +16339,6 @@ static void llama_set_k_shift(llama_context & lctx) { } } -static void llama_set_s_copy(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; - - assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); - - int32_t * data = (int32_t *) lctx.inp_s_copy->data; - - for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].src; - } -} - static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; @@ -16656,24 +16654,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { if (kv_self.recurrent) { const int64_t n_kv = kv_self.n; - if (lctx.inp_s_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); - float * data = (float *) lctx.inp_s_mask->data; - - // clear unused states - for (int i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - - data[i] = (float) (kv_cell.src >= 0); - - // only clear once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; - } - } - } - if (lctx.inp_s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); int32_t * data = (int32_t *) lctx.inp_s_copy->data; @@ -16683,8 +16663,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { const uint32_t cell_id = i + kv_self.head; llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { + if (kv_cell.src < 0) { + GGML_ASSERT(kv_self.rs_z >= 0); // Need a valid zero-ed cell as a source + kv_cell.src = kv_self.rs_z; + } + if ((uint32_t) kv_cell.src >= kv_self.size) { + // ignore out-of-bound sources kv_cell.src = cell_id; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index aa7896defdad0..092639eed42e1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1530,27 +1530,58 @@ struct test_ssm_scan : public test_case { const int64_t d_state; const int64_t d_inner; + const int64_t n_head; + const int64_t n_group; const int64_t n_seq_tokens; const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs); + return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs); } test_ssm_scan(ggml_type type = GGML_TYPE_F32, - int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) - : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + int64_t d_state = 32, + int64_t d_inner = 1, // non-zero for Mamba-2 + int64_t n_head = 32, + int64_t n_group = 1, + int64_t n_seq_tokens = 32, + int64_t n_seqs = 32) + : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, n_seqs, 1 }.data()); - ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, 1 , 1 }.data()); - ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); + ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head); + ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * D = ggml_new_tensor_1d(ctx, type, n_head); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids); return out; } + + // similar to test_mul_mat_id + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // ids + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); + } + } else { + init_tensor_uniform(t); + } + } + } }; // GGML_OP_MUL_MAT @@ -3255,7 +3286,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1})); - test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2 #if 1 for (ggml_type type_a : base_types) { From 2c77d799f9387f5971289139aaca23b4ce37c435 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:36:22 -0400 Subject: [PATCH 038/117] metal : attempt to adapt SSM_SCAN for Mamba-2 --- ggml/src/ggml-metal.m | 107 ++++++++++++++++++++-------- ggml/src/ggml-metal.metal | 146 ++++++++++++++++++++++++++++++++------ 2 files changed, 202 insertions(+), 51 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 9da08fe2e9771..5d5b98307d264 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -95,6 +95,7 @@ GGML_METAL_KERNEL_TYPE_NORM, GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, @@ -591,6 +592,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); @@ -1629,47 +1631,74 @@ static void ggml_metal_encode_node( struct ggml_tensor * src3 = node->src[3]; struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; + struct ggml_tensor * src6 = node->src[6]; + struct ggml_tensor * src7 = node->src[7]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); + GGML_ASSERT(src6); + GGML_ASSERT(src7); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; + size_t offs_src6 = 0; + size_t offs_src7 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; + id id_src7 = src7 ? ggml_metal_get_buffer(src7, &offs_src7) : nil; - const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); const uint64_t nb30 = src3->nb[0]; const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne41 = src4->ne[1]; const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); const uint64_t nb40 = src4->nb[0]; const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; + const uint64_t nb43 = src4->nb[3]; const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); const uint64_t nb50 = src5->nb[0]; const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; + const uint64_t nb53 = src5->nb[3]; + + const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); + + const uint64_t nb60 = src6->nb[0]; + + const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); + + const uint64_t nb70 = src7->nb[0]; const int64_t d_state = ne00; const int64_t d_inner = ne01; + const int64_t n_head = ne02; + const int64_t n_group = ne41; const int64_t n_seq_tokens = ne11; - const int64_t n_seqs = ne02; + const int64_t n_seqs = ne13; - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + if (ne30 == 1) { + // Mamba-2 + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + } else { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1678,33 +1707,49 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; - - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; - - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; + [encoder setBuffer:id_src7 offset:offs_src7 atIndex:7]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:8]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:9]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10]; + [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; + [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; + + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:15]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:16]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:17]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:18]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:19]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:20]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:21]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:22]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:23]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:24]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:25]; + [encoder setBytes:&nb23 length:sizeof(nb23) atIndex:26]; + [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:27]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:28]; + [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:29]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:30]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:31]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:32]; + [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:33]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:34]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:35]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:36]; + [encoder setBytes:&nb60 length:sizeof(nb60) atIndex:37]; + [encoder setBytes:&nb70 length:sizeof(nb70) atIndex:38]; + + if (ne30 == 1) { + // Mamba-2 + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + GGML_ASSERT(d_inner == 1); + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } } break; case GGML_OP_MUL_MAT: { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2b200032394b1..c75fa25c34e7d 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -795,7 +795,7 @@ kernel void kernel_ssm_conv_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part // TODO: optimize kernel void kernel_ssm_scan_f32( device const void * src0, @@ -804,14 +804,19 @@ kernel void kernel_ssm_scan_f32( device const void * src3, device const void * src4, device const void * src5, + device const void * src6, + device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, + constant int64_t & n_head, + constant int64_t & n_group, constant int64_t & n_seq_tokens, constant int64_t & n_seqs, constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, @@ -824,47 +829,148 @@ kernel void kernel_ssm_scan_f32( constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, + constant uint64_t & nb43, constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, + constant uint64_t & nb53, + constant uint64_t & nb60, + constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t ir = tgpig.x; - const int64_t i3 = tgpig.y; + const int64_t i1 = 0; + const int64_t ir = tgpig.x; // current head + const int64_t i3 = tgpig.y; // current seq const int64_t nc = d_state; const int64_t nr = d_inner; + const int64_t nh = n_head; + const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; const int64_t n_s = n_seqs; + const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src7; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); - device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); - device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); - device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); - device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52); - device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); - - if (i2 > 0) { - s0 = s; + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} + device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; } - // i1 == 0 - float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - float x_dt = x[0] * dt_soft_plus; + y[0] = sumf + x[0] * D[0]; + + // recurse + s0 = s; + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// TODO: optimize (e.g. by parallelizing over d_state) +kernel void kernel_ssm_scan_f32_group( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device const void * src6, + device const void * src7, + device float * dst, + constant int64_t & d_state, + constant int64_t & d_inner, + constant int64_t & n_head, + constant int64_t & n_group, + constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb20, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb30, + constant uint64_t & nb31, + constant uint64_t & nb40, + constant uint64_t & nb41, + constant uint64_t & nb42, + constant uint64_t & nb43, + constant uint64_t & nb50, + constant uint64_t & nb51, + constant uint64_t & nb52, + constant uint64_t & nb53, + constant uint64_t & nb60, + constant uint64_t & nb70, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i1 = tgpig.x; + const int64_t ir = tgpig.y; // current head + const int64_t i3 = tgpig.z; // current seq + + const int64_t nc = d_state; + const int64_t nr = d_inner; + const int64_t nh = n_head; + const int64_t ng = n_group; + const int64_t n_t = n_seq_tokens; + const int64_t n_s = n_seqs; + + const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src7; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} + device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + const float dA = expf(dt_soft_plus * A[0]); float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { - int64_t i = i0; - float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * dA) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } - y[0] = sumf; + y[0] = sumf + x[0] * D[0]; + + // recurse + s0 = s; } } From 87b97d08f43652c7a2e73929e34432ae5f9e8713 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:41:10 -0400 Subject: [PATCH 039/117] metal : fix SSM_SCAN pipeline scope --- ggml/src/ggml-metal.m | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 5d5b98307d264..477f720a0e32f 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1693,11 +1693,13 @@ static void ggml_metal_encode_node( const int64_t n_seq_tokens = ne11; const int64_t n_seqs = ne13; + id pipeline = nil; + if (ne30 == 1) { // Mamba-2 - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; } else { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; } [encoder setComputePipelineState:pipeline]; From 03d0e6eabe6172a56a7d470bfd844012f2c2b291 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:58:41 -0400 Subject: [PATCH 040/117] metal : use log and exp instead of log1pf and expf in SSM_SCAN --- ggml/src/ggml-metal.metal | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c75fa25c34e7d..cee9980a75619 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -866,13 +866,13 @@ kernel void kernel_ssm_scan_f32( device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} - const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { const int64_t i = i0 + i1*nc; - const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } @@ -955,9 +955,9 @@ kernel void kernel_ssm_scan_f32_group( device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} - const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; - const float dA = expf(dt_soft_plus * A[0]); + const float dA = exp(dt_soft_plus * A[0]); float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { From 7a351abc28e36aeb73d1fd8ce172db56fbb3ebcb Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 11:28:16 -0400 Subject: [PATCH 041/117] metal : remove unused arguments for SSM_SCAN The max index is 31, so trimming the arguments is necessary. --- ggml/src/ggml-metal.m | 53 ++++++++++++++++----------------------- ggml/src/ggml-metal.metal | 34 +++++++++---------------- 2 files changed, 34 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 477f720a0e32f..5127b34f8edaa 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1655,7 +1655,7 @@ static void ggml_metal_encode_node( const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); - const uint64_t nb30 = src3->nb[0]; + const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30); const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); @@ -1663,7 +1663,7 @@ static void ggml_metal_encode_node( const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); - const uint64_t nb40 = src4->nb[0]; + const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40); const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; const uint64_t nb43 = src4->nb[3]; @@ -1673,18 +1673,18 @@ static void ggml_metal_encode_node( const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); - const uint64_t nb50 = src5->nb[0]; + const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50); const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; const uint64_t nb53 = src5->nb[3]; const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); - const uint64_t nb60 = src6->nb[0]; + const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); - const uint64_t nb70 = src7->nb[0]; + const uint64_t nb70 = src7->nb[0]; GGML_UNUSED(nb70); const int64_t d_state = ne00; const int64_t d_inner = ne01; @@ -1718,32 +1718,23 @@ static void ggml_metal_encode_node( [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; - - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:15]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:16]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:17]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:18]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:19]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:20]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:21]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:22]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:23]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:24]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:25]; - [encoder setBytes:&nb23 length:sizeof(nb23) atIndex:26]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:27]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:28]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:29]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:30]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:31]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:32]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:33]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:34]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:35]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:36]; - [encoder setBytes:&nb60 length:sizeof(nb60) atIndex:37]; - [encoder setBytes:&nb70 length:sizeof(nb70) atIndex:38]; + + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; + // NOTE: max index is 31 if (ne30 == 1) { // Mamba-2 diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index cee9980a75619..3745f2f225512 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -812,30 +812,21 @@ kernel void kernel_ssm_scan_f32( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, - constant int64_t & n_seqs, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant uint64_t & nb20, constant uint64_t & nb21, constant uint64_t & nb22, - constant uint64_t & nb30, constant uint64_t & nb31, - constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, constant uint64_t & nb43, - constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, constant uint64_t & nb53, - constant uint64_t & nb60, - constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -843,12 +834,16 @@ kernel void kernel_ssm_scan_f32( const int64_t ir = tgpig.x; // current head const int64_t i3 = tgpig.y; // current seq + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); + const uint64_t nb60 = sizeof(float); + const int64_t nc = d_state; const int64_t nr = d_inner; const int64_t nh = n_head; const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; - const int64_t n_s = n_seqs; const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); @@ -864,7 +859,7 @@ kernel void kernel_ssm_scan_f32( device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; @@ -901,30 +896,21 @@ kernel void kernel_ssm_scan_f32_group( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, - constant int64_t & n_seqs, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant uint64_t & nb20, constant uint64_t & nb21, constant uint64_t & nb22, - constant uint64_t & nb30, constant uint64_t & nb31, - constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, constant uint64_t & nb43, - constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, constant uint64_t & nb53, - constant uint64_t & nb60, - constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -932,12 +918,16 @@ kernel void kernel_ssm_scan_f32_group( const int64_t ir = tgpig.y; // current head const int64_t i3 = tgpig.z; // current seq + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); + const uint64_t nb60 = sizeof(float); + const int64_t nc = d_state; const int64_t nr = d_inner; const int64_t nh = n_head; const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; - const int64_t n_s = n_seqs; const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); @@ -953,7 +943,7 @@ kernel void kernel_ssm_scan_f32_group( device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; From 8b15bc6fa0fbb7a0d831b90955430c0a9e281ac2 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 11:47:56 -0400 Subject: [PATCH 042/117] metal : add back n_seqs to SSM_SCAN args Whoops, this is needed for the offset in the concatenated output. --- ggml/src/ggml-metal.m | 33 +++++++++++++++++---------------- ggml/src/ggml-metal.metal | 2 ++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 5127b34f8edaa..3f7183060d83d 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1718,22 +1718,23 @@ static void ggml_metal_encode_node( [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; - - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; + + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:23]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29]; // NOTE: max index is 31 if (ne30 == 1) { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 3745f2f225512..c36eedb010de1 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -812,6 +812,7 @@ kernel void kernel_ssm_scan_f32( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, @@ -896,6 +897,7 @@ kernel void kernel_ssm_scan_f32_group( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, From 5b8ec2b978b84dfdb05e6fca4def928f72b1090c Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 12:11:45 -0400 Subject: [PATCH 043/117] metal : fix SSM_SCAN state head offset --- ggml/src/ggml-metal.metal | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c36eedb010de1..9e1d14ff5d8b5 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -850,8 +850,8 @@ kernel void kernel_ssm_scan_f32( device const int32_t * ids = (device const int32_t *) src7; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} @@ -935,8 +935,8 @@ kernel void kernel_ssm_scan_f32_group( device const int32_t * ids = (device const int32_t *) src7; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} From 62b09b343c6c4e35486368f1a7b653c9ae58574a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 21:35:50 -0400 Subject: [PATCH 044/117] metal : fix wrong number of tokens per sequence in SSM_SCAN --- ggml/src/ggml-metal.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 3f7183060d83d..a39770bd4ed1b 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1690,7 +1690,7 @@ static void ggml_metal_encode_node( const int64_t d_inner = ne01; const int64_t n_head = ne02; const int64_t n_group = ne41; - const int64_t n_seq_tokens = ne11; + const int64_t n_seq_tokens = ne12; const int64_t n_seqs = ne13; id pipeline = nil; From 805512a73b9876853f0e7d0cd612259806fa5d93 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 12 Oct 2024 16:20:26 -0400 Subject: [PATCH 045/117] ggml : remove unused fast broadcast path in GGML_MUL This was initially added because states were masked with ggml_mul, but this is no longer done and so this "optimisation" is no longer necessary, or at least not worth the additional code complexity. --- ggml/src/ggml.c | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e8a5e3d153548..8fd335270dd5a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10173,37 +10173,7 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (ne00 > 1 && ne10 == 1) { - // fast broadcast path - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - const float scale = src1_ptr[0]; - - if (scale == 0.0f) { - // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, - // but it is useful when resetting the state of recurrent models. - memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float)); - } else { - if (dst->data != src0->data) { - // src0 is same shape as dst => same indices - memcpy((char *) dst->data + ir*nb1, (char *) src0->data + ir*nb01, ne0 * sizeof(float)); - } - if (scale != 1.0f) { - ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); - } - } - } - } else if (nb10 == sizeof(float)) { + if (nb10 == sizeof(float)) { for (int64_t ir = ith; ir < nr; ir += nth) { // src0 and dst are same shape => same indices const int64_t i03 = ir/(ne02*ne01); From 3bc7103d2ef1c41cd380a1ad8d918cf9c26694d8 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 4 Nov 2024 11:36:37 -0500 Subject: [PATCH 046/117] ggml : avoid multiply by D in GGML_OP_SSM_SCAN This makes the weight buft detection in src/llama.cpp simpler. * convert : transpose Mamba-2 A, D and reshape SSM_NORM This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner. * llama : more appropriate SSM_SCAN and SSM_CONV buft support checks --- convert_hf_to_gguf.py | 26 ++++++++++++++++- ggml/include/ggml.h | 1 - ggml/src/ggml-metal.m | 57 ++++++++++++++++---------------------- ggml/src/ggml-metal.metal | 14 +++------- ggml/src/ggml.c | 20 ++++--------- src/llama.cpp | 54 +++++++++++++++++++----------------- tests/test-backend-ops.cpp | 25 ++++++++--------- 7 files changed, 100 insertions(+), 97 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f307b1ac69202..f0a63d921d65f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -264,6 +264,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + # TODO: merge into modify_tensors? (need to check tensor shapes for all arches before doing that) + def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor: + del new_name, bid # unused + + return data_torch.squeeze() + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid, n_dims # unused @@ -295,7 +301,7 @@ def prepare_tensors(self): break for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): - data = data_torch.squeeze().numpy() + data = self.reshape_tensors(data_torch, new_name, bid).numpy() # if data ends up empty, it means data_torch was a scalar tensor -> restore if len(data.shape) == 0: @@ -3063,6 +3069,24 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield (new_name, data_torch) + def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor: + if any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [ + gguf.MODEL_TENSOR.SSM_A, + gguf.MODEL_TENSOR.SSM_D, + ]): + # unsqueeze A to use similar shape semantics as Mamba-1 + # (D is also unsqueezed, but for more straightforward broadcast internally) + return data_torch.reshape((*data_torch.shape, 1)) + + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + n_group = self.hparams.get("n_groups", 1) + return data_torch.reshape((n_group, d_inner // n_group)) + + return data_torch.squeeze() + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 0d2e5cb011a3b..735f56b005a28 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1828,7 +1828,6 @@ extern "C" { struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D, struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 73e2fedc36544..902728d8e6b55 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1649,25 +1649,21 @@ static void ggml_metal_encode_node( struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; struct ggml_tensor * src6 = node->src[6]; - struct ggml_tensor * src7 = node->src[7]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); GGML_ASSERT(src6); - GGML_ASSERT(src7); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; size_t offs_src6 = 0; - size_t offs_src7 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; - id id_src7 = src7 ? ggml_metal_get_buffer(src7, &offs_src7) : nil; const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); @@ -1699,10 +1695,6 @@ static void ggml_metal_encode_node( const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); - const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); - - const uint64_t nb70 = src7->nb[0]; GGML_UNUSED(nb70); - const int64_t d_state = ne00; const int64_t d_inner = ne01; const int64_t n_head = ne02; @@ -1727,31 +1719,30 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; - [encoder setBuffer:id_src7 offset:offs_src7 atIndex:7]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:8]; - - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:9]; - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10]; - [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; - [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; - - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:8]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:9]; + [encoder setBytes:&n_head length:sizeof(n_head) atIndex:10]; + [encoder setBytes:&n_group length:sizeof(n_group) atIndex:11]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:12]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:13]; + + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; // NOTE: max index is 31 if (ne30 == 1) { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2f5a4d12eeec3..05d04e8f3fdbf 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -805,7 +805,6 @@ kernel void kernel_ssm_scan_f32( device const void * src4, device const void * src5, device const void * src6, - device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, @@ -838,7 +837,6 @@ kernel void kernel_ssm_scan_f32( const uint64_t nb00 = sizeof(float); const uint64_t nb10 = sizeof(float); const uint64_t nb20 = sizeof(float); - const uint64_t nb60 = sizeof(float); const int64_t nc = d_state; const int64_t nr = d_inner; @@ -848,7 +846,7 @@ kernel void kernel_ssm_scan_f32( const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); - device const int32_t * ids = (device const int32_t *) src7; + device const int32_t * ids = (device const int32_t *) src6; device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); @@ -859,7 +857,6 @@ kernel void kernel_ssm_scan_f32( device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh} device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} - device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; @@ -873,7 +870,7 @@ kernel void kernel_ssm_scan_f32( s[i] = state; } - y[0] = sumf + x[0] * D[0]; + y[0] = sumf; // recurse s0 = s; @@ -890,7 +887,6 @@ kernel void kernel_ssm_scan_f32_group( device const void * src4, device const void * src5, device const void * src6, - device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, @@ -923,7 +919,6 @@ kernel void kernel_ssm_scan_f32_group( const uint64_t nb00 = sizeof(float); const uint64_t nb10 = sizeof(float); const uint64_t nb20 = sizeof(float); - const uint64_t nb60 = sizeof(float); const int64_t nc = d_state; const int64_t nr = d_inner; @@ -933,7 +928,7 @@ kernel void kernel_ssm_scan_f32_group( const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); - device const int32_t * ids = (device const int32_t *) src7; + device const int32_t * ids = (device const int32_t *) src6; device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); @@ -944,7 +939,6 @@ kernel void kernel_ssm_scan_f32_group( device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh} device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} - device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; @@ -959,7 +953,7 @@ kernel void kernel_ssm_scan_f32_group( s[i] = state; } - y[0] = sumf + x[0] * D[0]; + y[0] = sumf; // recurse s0 = s; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 91b256a4c25f0..9036fc0be9858 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7181,7 +7181,6 @@ struct ggml_tensor * ggml_ssm_conv( const int64_t n_s = sx->ne[2]; // TODO: maybe support other strides than 1? - // FIXME: this is always true? GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); GGML_ASSERT(sx->ne[1] == d_inner); GGML_ASSERT(n_t >= 0); @@ -7205,7 +7204,6 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D, struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(dt)); @@ -7235,8 +7233,6 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(B->ne[0] == d_state); GGML_ASSERT(B->ne[2] == n_seq_tokens); GGML_ASSERT(B->ne[3] == n_seqs); - GGML_ASSERT(D->ne[0] == n_head); - GGML_ASSERT(ggml_is_vector(D)); GGML_ASSERT(ids->ne[0] == n_seqs); GGML_ASSERT(ggml_is_vector(ids)); GGML_ASSERT(A->ne[1] == n_head); @@ -7258,8 +7254,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; - result->src[6] = D; - result->src[7] = ids; + result->src[6] = ids; return result; } @@ -16217,8 +16212,7 @@ static void ggml_compute_forward_ssm_scan_f32( const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} - const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} - const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs} + const struct ggml_tensor * src6 = dst->src[6]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; @@ -16240,8 +16234,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - GGML_ASSERT(src6->nb[0] == sizeof(float)); - GGML_ASSERT(src7->nb[0] == sizeof(int32_t)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); // allows optimizing the modulo since n_group should be a power of 2 GGML_ASSERT((ng & -ng) == ng); @@ -16252,7 +16245,7 @@ static void ggml_compute_forward_ssm_scan_f32( const int ih0 = dh*ith; const int ih1 = MIN(ih0 + dh, nh); - const int32_t * ids = (const int32_t *) src7->data; + const int32_t * ids = (const int32_t *) src6->data; for (int i3 = 0; i3 < ns; ++i3) { const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} @@ -16264,7 +16257,6 @@ static void ggml_compute_forward_ssm_scan_f32( const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} - const float * D = (const float *) ((const char *) src6->data); // {nh} float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} if (src3->ne[0] == 1) { @@ -16325,7 +16317,7 @@ static void ggml_compute_forward_ssm_scan_f32( sumf += state * C[ig]; s[i] = state; } - y[ii] = sumf + x[ii] * D[h]; + y[ii] = sumf; } } } else { @@ -16353,7 +16345,7 @@ static void ggml_compute_forward_ssm_scan_f32( sumf += state * C[ig]; s[i] = state; } - y[ii] = sumf + x[ii] * D[h]; + y[ii] = sumf; } } } diff --git a/src/llama.cpp b/src/llama.cpp index e84510ce8ffd1..52052caf250b1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7120,6 +7120,7 @@ static const std::map llm_tensor_info_mapping = { {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -7227,23 +7228,27 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w } break; case GGML_OP_SSM_CONV: { - // FIXME - ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); op_tensor = ggml_ssm_conv(ctx, conv_x, w); } break; case GGML_OP_SSM_SCAN: { - // FIXME - const int64_t d_state = w->ne[0]; - const int64_t d_inner = w->ne[1]; + // w is ssm_a + const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; + const int64_t n_head = w->ne[1]; + const int64_t head_dim = hparams.ssm_d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group; const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 1; - ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); - ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + const int64_t n_seqs = 3; + ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); } break; case GGML_OP_RWKV_WKV: { @@ -8572,10 +8577,10 @@ static bool llm_load_tensors( layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {n_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {n_head}, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner}, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); // out_proj layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); @@ -9994,7 +9999,7 @@ static struct ggml_tensor * llm_build_rs( return states; } -// TODO: split +// TODO: split conv and ssm static struct ggml_tensor * llm_build_mamba( struct ggml_context * ctx, struct llama_context & lctx, @@ -10102,13 +10107,14 @@ static struct ggml_tensor * llm_build_mamba( dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + cur = x; x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d, ssm_ids); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10120,6 +10126,7 @@ static struct ggml_tensor * llm_build_mamba( // TODO: skip computing output earlier for unused tokens + y = ggml_add(ctx, y, ggml_mul(ctx, cur, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -10184,7 +10191,7 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_tensor * zxBCdt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur); // split the above in three - struct ggml_tensor * z = ggml_view_3d(ctx, zxBCdt, d_inner, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], 0); + struct ggml_tensor * z = ggml_view_4d(ctx, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); struct ggml_tensor * xBC = ggml_view_3d(ctx, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); struct ggml_tensor * dt = ggml_view_3d(ctx, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); @@ -10230,11 +10237,9 @@ static struct ggml_tensor * llm_build_mamba2( dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); - // Use the same shape semantics for A as Mamba-1 - struct ggml_tensor * A = ggml_reshape_2d(ctx, model.layers[il].ssm_a, 1, n_head); // TODO: use semistructured matrices to implement state-space duality // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, model.layers[il].ssm_d, ssm_ids); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10242,17 +10247,16 @@ static struct ggml_tensor * llm_build_mamba2( ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + struct ggml_tensor * y = ggml_view_4d(ctx, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); // TODO: skip computing output earlier for unused tokens + y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // grouped RMS norm y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = llm_build_norm(ctx, y, hparams, - ggml_reshape_2d(ctx, model.layers[il].ssm_norm, d_inner / n_group, n_group), NULL, - LLM_NORM_RMS, cb, il); + y = llm_build_norm(ctx, y, hparams, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, cb, il); y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6ca254a45f23f..95f8abbd80968 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1589,35 +1589,34 @@ struct test_ssm_scan : public test_case { const ggml_type type; const int64_t d_state; - const int64_t d_inner; + const int64_t head_dim; const int64_t n_head; const int64_t n_group; const int64_t n_seq_tokens; const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs); + return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs); } test_ssm_scan(ggml_type type = GGML_TYPE_F32, int64_t d_state = 32, - int64_t d_inner = 1, // non-zero for Mamba-2 + int64_t head_dim = 1, // non-zero for Mamba-2 int64_t n_head = 32, int64_t n_group = 1, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) - : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs); - ggml_tensor * x = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); - ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head); - ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * D = ggml_new_tensor_1d(ctx, type, n_head); - ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); - ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids); + ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head); + ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids); return out; } From b4e9c5998dea2d657cfd22bc2e6fa0630fba2fa9 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 4 Nov 2024 15:26:15 -0500 Subject: [PATCH 047/117] convert : fix flake8 lint --- convert_hf_to_gguf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f0efe5d5b0c7c..019e7b7ef93b6 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3088,7 +3088,6 @@ def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> return data_torch.squeeze() - @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R From 8006f3b3c83d63995acfaff19cd2f9c3ffc52949 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 24 Nov 2024 20:35:30 -0500 Subject: [PATCH 048/117] llama : remove implicit recurrent state rollbacks --- common/common.cpp | 2 +- examples/batched-bench/batched-bench.cpp | 4 +- examples/batched.swift/Sources/main.swift | 2 +- examples/batched/batched.cpp | 2 +- .../cvector-generator/cvector-generator.cpp | 2 +- examples/embedding/embedding.cpp | 2 +- examples/gritlm/gritlm.cpp | 4 +- examples/imatrix/imatrix.cpp | 2 +- examples/infill/infill.cpp | 4 +- examples/llama-bench/llama-bench.cpp | 4 +- .../llama/src/main/cpp/llama-android.cpp | 8 +- .../llama.cpp.swift/LibLlama.swift | 8 +- examples/lookahead/lookahead.cpp | 12 +- examples/lookup/lookup.cpp | 2 +- examples/main/main.cpp | 17 +- examples/parallel/parallel.cpp | 10 +- examples/passkey/passkey.cpp | 28 +- examples/perplexity/perplexity.cpp | 12 +- examples/retrieval/retrieval.cpp | 2 +- examples/save-load-state/save-load-state.cpp | 2 +- examples/server/server.cpp | 4 +- examples/speculative/speculative.cpp | 22 +- ggml/src/ggml.c | 1 + include/llama.h | 59 +- src/llama.cpp | 1291 ++++------------- 25 files changed, 399 insertions(+), 1107 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index d47b12acba04e..451307b554b6b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -966,7 +966,7 @@ struct common_init_result common_init_from_params(common_params & params) { if (llama_model_has_decoder(model)) { llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); } - llama_past_clear(lctx); + llama_kv_cache_clear(lctx); llama_synchronize(lctx); llama_perf_context_reset(lctx); } diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index ecaa793bafd8b..81c3220ada0b0 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -133,7 +133,7 @@ int main(int argc, char ** argv) { const auto t_pp_start = ggml_time_us(); - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { LOG_ERR("%s: llama_decode() failed\n", __func__); @@ -142,7 +142,7 @@ int main(int argc, char ** argv) { if (is_pp_shared) { for (int32_t i = 1; i < pl; ++i) { - llama_past_seq_cp(ctx, 0, i, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } } diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index e1552750cf76f..d3d156932e60c 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -111,7 +111,7 @@ if llama_decode(context, batch) != 0 { } for i in 1 ..< n_parallel { - llama_past_seq_cp(context, 0, Int32(i), -1, -1) + llama_kv_cache_seq_cp(context, 0, Int32(i), -1, -1) } if n_parallel > 1 { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 83312ad969d18..3b554033e7ee4 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -138,7 +138,7 @@ int main(int argc, char ** argv) { //// assign the system KV cache to all parallel sequences //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them //for (int32_t i = 1; i < n_parallel; ++i) { - // llama_past_seq_cp(ctx, 0, i, -1, -1); + // llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); //} if (n_parallel > 1) { diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 846905aacbed5..69e141ecb94e4 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -338,7 +338,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { } static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 3f38667ae74a3..3f18fc6a70878 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -37,7 +37,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu const struct llama_model * model = llama_get_model(ctx); // clear previous kv_cache values (irrelevant for embeddings) - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 36668b0e86638..6e42fa0734ecb 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -44,7 +44,7 @@ static std::vector> encode(llama_context * ctx, const std::ve } // clear previous kv_cache values (irrelevant for embeddings) - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); llama_set_embeddings(ctx, true); llama_set_causal_attn(ctx, false); @@ -99,7 +99,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std const llama_model * model = llama_get_model(ctx); llama_token eos_token = llama_token_eos(model); - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index d882d02de41f0..d1ff3e8bc4c7f 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -494,7 +494,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 5d092f000208e..f82c614f5706f 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -375,8 +375,8 @@ int main(int argc, char ** argv) { LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_past_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); n_past -= n_discard; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 6fbb97f85de5c..c22bdedcfa231 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1566,7 +1566,7 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); // cool off before the test if (params.delay) { @@ -1606,7 +1606,7 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); uint64_t t_start = get_time_ns(); diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 8e4ffd851e640..f5ffd063f8e2d 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( } batch->logits[batch->n_tokens - 1] = true; - llama_past_clear(context); + llama_kv_cache_clear(context); const auto t_pp_start = ggml_time_us(); if (llama_decode(context, *batch) != 0) { @@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( LOGi("Benchmark text generation (tg)"); - llama_past_clear(context); + llama_kv_cache_clear(context); const auto t_tg_start = ggml_time_us(); for (i = 0; i < tg; i++) { @@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto t_tg_end = ggml_time_us(); - llama_past_clear(context); + llama_kv_cache_clear(context); const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; @@ -446,5 +446,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { - llama_past_clear(reinterpret_cast(context)); + llama_kv_cache_clear(reinterpret_cast(context)); } diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 570b4081c9942..dcd9803a2adc2 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -209,7 +209,7 @@ actor LlamaContext { } batch.logits[Int(batch.n_tokens) - 1] = 1 // true - llama_past_clear(context) + llama_kv_cache_clear(context) let t_pp_start = ggml_time_us() @@ -222,7 +222,7 @@ actor LlamaContext { // bench text generation - llama_past_clear(context) + llama_kv_cache_clear(context) let t_tg_start = ggml_time_us() @@ -241,7 +241,7 @@ actor LlamaContext { let t_tg_end = ggml_time_us() - llama_past_clear(context) + llama_kv_cache_clear(context) let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 @@ -291,7 +291,7 @@ actor LlamaContext { func clear() { tokens_list.removeAll() temporary_invalid_cchars.removeAll() - llama_past_clear(context) + llama_kv_cache_clear(context) } private func tokenize(text: String, add_bos: Bool) -> [llama_token] { diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 607c755fce0cc..03cd63f3fe95f 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -93,7 +93,7 @@ int main(int argc, char ** argv) { llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); for (int s = 1; s < W + G + 1; ++s) { - llama_past_seq_cp(ctx, 0, s, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); } const auto t_enc_end = ggml_time_us(); @@ -436,17 +436,17 @@ int main(int argc, char ** argv) { // KV cache management // if no verification token matched, we simply remove all cells from this batch -> no fragmentation // FIXME: recurrent and hybrid models - llama_past_seq_rm(ctx, -1, n_past, -1); + llama_kv_cache_seq_rm(ctx, -1, n_past, -1); if (seq_id_best != 0) { // if a verification token matched, we keep the best sequence and remove the rest // this leads to some KV cache fragmentation - llama_past_seq_keep(ctx, seq_id_best); - llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1); - llama_past_seq_rm (ctx, seq_id_best, -1, -1); + llama_kv_cache_seq_keep(ctx, seq_id_best); + llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1); + llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1); for (int s = 1; s < W + G + 1; ++s) { - llama_past_seq_cp(ctx, 0, s, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); } } } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 700d519717e11..e2c8c3828f5d7 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -191,7 +191,7 @@ int main(int argc, char ** argv){ // KV cache management // clean the cache of draft tokens that weren't accepted // FIXME: recurrent and hybrid models - llama_past_seq_rm(ctx, 0, n_past, -1); + llama_kv_cache_seq_rm(ctx, 0, n_past, -1); common_batch_clear(batch_tgt); common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 4632da83442a9..fb10c20c5e36d 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -332,10 +332,6 @@ int main(int argc, char ** argv) { } n_matching_session_tokens++; } - - // remove any "future" tokens that we might have inherited from the previous session - n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1); - if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { LOG_INF("%s: using full prompt from session file\n", __func__); } else if (n_matching_session_tokens >= embd_inp.size()) { @@ -347,6 +343,9 @@ int main(int argc, char ** argv) { LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n", __func__, n_matching_session_tokens, embd_inp.size()); } + + // remove any "future" tokens that we might have inherited from the previous session + llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); } LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n", @@ -358,8 +357,6 @@ int main(int argc, char ** argv) { LOG_DBG("recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1); session_tokens.resize(embd_inp.size() - 1); - } else { - session_tokens.resize(n_matching_session_tokens); } // number of tokens to keep when resetting context @@ -609,9 +606,9 @@ int main(int argc, char ** argv) { LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); - llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd); - llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); - llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); + llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd); + llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); n_past -= bd; @@ -625,8 +622,6 @@ int main(int argc, char ** argv) { if (n_session_consumed < (int) session_tokens.size()) { size_t i = 0; for ( ; i < embd.size(); i++) { - // TODO: are the session tokens guaranteed to all be matching here? - // Should n_matching_session_tokens be re-used instead? if (embd[i] != session_tokens[n_session_consumed]) { session_tokens.resize(n_session_consumed); break; diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index f3b54c2ee7640..20274c1479a47 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -199,7 +199,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= n_clients; ++i) { - llama_past_seq_cp(ctx, 0, i, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } LOG_INF("\n"); @@ -231,9 +231,9 @@ int main(int argc, char ** argv) { if (batch.n_tokens == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { - llama_past_seq_rm(ctx, i, -1, -1); + llama_kv_cache_seq_rm(ctx, i, -1, -1); // but keep the system prompt - llama_past_seq_cp(ctx, 0, i, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } LOG_INF("%s: clearing the KV cache\n", __func__); @@ -370,8 +370,8 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_past_seq_rm(ctx, client.id + 1, -1, -1); - llama_past_seq_cp(ctx, 0, client.id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 87df5f2421049..09bba708f6f91 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -130,11 +130,11 @@ int main(int argc, char ** argv) { const int ib = i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); - llama_past_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); - llama_past_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); - llama_kv_cache_update(ctx); + llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); + llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + llama_kv_cache_update (ctx); - n_past = llama_past_seq_pos_max(ctx, 0) + 1; + n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; } common_batch_clear(batch); @@ -164,12 +164,12 @@ int main(int argc, char ** argv) { LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard); - llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag(ctx); - llama_kv_cache_update(ctx); + llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag (ctx); + llama_kv_cache_update (ctx); - n_past = llama_past_seq_pos_max(ctx, 0) + 1; + n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; common_batch_clear(batch); @@ -195,12 +195,12 @@ int main(int argc, char ** argv) { if (n_discard > 0) { LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); - llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag(ctx); - llama_kv_cache_update(ctx); + llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag (ctx); + llama_kv_cache_update (ctx); - n_past = llama_past_seq_pos_max(ctx, 0) + 1; + n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; } } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 707e98b0b66d3..efb41b80a3df3 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -406,7 +406,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -580,7 +580,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -955,7 +955,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { return; } - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1232,7 +1232,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) return; } - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1602,7 +1602,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par return; } - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1789,7 +1789,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { } // clear the KV cache - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 3741f63d61704..1768aae510067 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -83,7 +83,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index b936a2ec58cbf..3866cfa27e13e 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -199,7 +199,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); // erase whole kv - llama_past_clear(ctx3); + llama_kv_cache_clear(ctx3); fprintf(stderr, "%s : kv cache cleared\n", __func__); // restore kv into seq 1 diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5f054dea40d81..f809c46d5a308 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1029,7 +1029,7 @@ struct server_context { SRV_DBG("%s", "clearing KV cache\n"); // clear the entire KV cache - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); clean_kv_cache = false; } @@ -1760,7 +1760,7 @@ struct server_context { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_past_seq_rm(ctx, slot->id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); slot->cache_tokens.clear(); server_task_result result; diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 8dbcdba3050ba..33b469e8f5e68 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -410,15 +410,15 @@ int main(int argc, char ** argv) { { LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); - llama_past_seq_keep(ctx_dft, s_keep); - llama_past_seq_cp (ctx_dft, s_keep, 0, -1, -1); - llama_past_seq_keep(ctx_dft, 0); + llama_kv_cache_seq_keep(ctx_dft, s_keep); + llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); + llama_kv_cache_seq_keep(ctx_dft, 0); // FIXME: recurrent and hybrid models - llama_past_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); - llama_past_seq_keep(ctx_tgt, s_keep); - llama_past_seq_cp (ctx_tgt, s_keep, 0, -1, -1); - llama_past_seq_keep(ctx_tgt, 0); + llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); + llama_kv_cache_seq_keep(ctx_tgt, s_keep); + llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); + llama_kv_cache_seq_keep(ctx_tgt, 0); } for (int s = 0; s < n_seq_dft; ++s) { @@ -495,8 +495,8 @@ int main(int argc, char ** argv) { if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) { LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur); - llama_past_seq_rm(ctx_dft, n_seq_cur, -1, -1); - llama_past_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); + llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); + llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); // all previous tokens from this branch are now also part of the new branch for (int t = 0; t < batch_tgt.n_tokens; ++t) { @@ -577,9 +577,9 @@ int main(int argc, char ** argv) { // evaluate the target model on the drafted tokens { - llama_past_seq_keep(ctx_tgt, 0); + llama_kv_cache_seq_keep(ctx_tgt, 0); for (int s = 1; s < n_seq_dft; ++s) { - llama_past_seq_cp(ctx_tgt, 0, s, -1, -1); + llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); } // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index abe8d04ff8f22..3f01092d9f59a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -19825,6 +19825,7 @@ struct ggml_cplan ggml_graph_plan( cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_CROSS_ENTROPY_LOSS: { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); diff --git a/include/llama.h b/include/llama.h index ebd00e771f526..510e862caafa4 100644 --- a/include/llama.h +++ b/include/llama.h @@ -41,7 +41,7 @@ #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 9 +#define LLAMA_SESSION_VERSION 10 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 3 @@ -613,58 +613,35 @@ extern "C" { LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx); // Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed - LLAMA_API void llama_past_clear( + LLAMA_API void llama_kv_cache_clear( struct llama_context * ctx); - LLAMA_API DEPRECATED(void llama_kv_cache_clear( - struct llama_context * ctx), - "use llama_past_clear instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past (one more than the largest remaining pos in the seq_id) - // which is only meaningful to handle for partial removals. - LLAMA_API llama_pos llama_past_seq_rm( + LLAMA_API bool llama_kv_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); - LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1), - "use llama_past_seq_rm instead, and handle its return value for partial removals"); // Copy all tokens that belong to the specified sequence to another sequence // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past (one more than the largest remaining pos in the destination seq_id) - // which is only meaningful to handle when partially copying. - LLAMA_API llama_pos llama_past_seq_cp( + LLAMA_API void llama_kv_cache_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); - LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp( - struct llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1), - "use llama_past_seq_cp instead, and handle its return value for partial copies"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_past_seq_keep( + LLAMA_API void llama_kv_cache_seq_keep( struct llama_context * ctx, llama_seq_id seq_id); - LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep( - struct llama_context * ctx, - llama_seq_id seq_id), - "use llama_past_seq_keep instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -672,19 +649,12 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_past_seq_add( + LLAMA_API void llama_kv_cache_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); - LLAMA_API DEPRECATED(void llama_kv_cache_seq_add( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta), - "use llama_past_seq_add instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -692,28 +662,17 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_past_seq_div( + LLAMA_API void llama_kv_cache_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d); - LLAMA_API DEPRECATED(void llama_kv_cache_seq_div( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d), - "use llama_past_seq_div instead"); // Returns the largest position present in the KV and/or RS cache for the specified sequence - LLAMA_API llama_pos llama_past_seq_pos_max( + LLAMA_API llama_pos llama_kv_cache_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id); - LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( - struct llama_context * ctx, - llama_seq_id seq_id), - "use llama_past_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // This will be applied: diff --git a/src/llama.cpp b/src/llama.cpp index ee22ec394d876..0b3b181f70fe9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2847,69 +2847,24 @@ struct llama_kv_self_cache { } }; -// for recurrent models, use a tree of sequences to simplify -// quickly finding the tail cell of each sequence -// TODO: drop the _rs_ infix -struct llama_rs_seq_node { - llama_seq_id seq_id = -1; - int32_t next_cell = -1; - - // needed for automatic typecasting from a llama_seq_id - llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {} - - // needed for more convenient std::find - bool operator==(const llama_rs_seq_node & other) const { - return seq_id == other.seq_id; - } - - bool is_tail() const { - return next_cell < 0; - } -}; - struct llama_rs_cell { llama_pos pos = -1; int32_t src = -1; // copy source id (cleared next when -1) - // Link to previous cell in this sequence. - // Sequences can only diverge, never converge, - // so this works when there are multiple seq_ids per cell too. - int32_t prev = -1; - - // ref count of tails (should match the number of next_cell == -1 in seq_nodes) - uint32_t tail_rc = 0; - - // seq_ids by insertion order, to simplify updating n_cells compared to a set - std::vector seq_nodes; - - void insert_node(const llama_rs_seq_node & node) { - auto node_dest = std::find(seq_nodes.begin(), seq_nodes.end(), node); - if (node_dest == seq_nodes.end()) { - seq_nodes.push_back(node); - } else { - // overwrite the pre-existing node with the same seq_id if it exists - *node_dest = node; - } - } + std::set seq_id; bool has_seq_id(const llama_seq_id & id) const { - return std::find(seq_nodes.begin(), seq_nodes.end(), id) != seq_nodes.end(); + return seq_id.find(id) != seq_id.end(); } bool is_empty() const { - return seq_nodes.empty(); + return seq_id.empty(); } }; struct llama_rs_seq_meta { // cell id of the latest state of this seq_id int32_t tail = -1; - // number of cells for which this seq_id is the first - // (useful to know if cells in this sequence should be pruned) - int32_t n_cells = 0; - // the last pos of this sequence if it is in the current ubatch, - // only set and used when finding a slot. - llama_pos ubatch_end_pos = -1; }; // ring-buffered tree of cached recurrent state data @@ -2922,32 +2877,17 @@ struct llama_rs_self_cache { // computed when finding a slot uint32_t n = 0; // range of states used for the last slot - // only counts cells which are tails of all of their sequences. - // useful to know the minimum reserved cell count per seq_id. - uint32_t n_seqs = 0; - // cells part of multiple sequences, - // but which are only the tail of some of them. - // useful to dismiss sequences used as a shared prompt - uint32_t n_shared_tail_cells = 0; - // with state models, a cell can hold the state for more than one past token - // TODO: it's probably not possible to always use contiguous cells std::vector cells; // find tail cells faster std::vector seq_tails; // map seq_ids to cell ids - // freeable cell ids, computed when finding a slot - // useful to find the smallest range to defrag - std::vector freeable; - // per layer // NOTE: the naming of r and s is arbitrary std::vector r_l; // rolling/shift states std::vector s_l; // ssm (recurrent) states - // TODO: maybe use a simpler data structure than a tree - // Inefficient, but thorough verification and rebuilding of the rs cache // from only the cells list with `pos` and seq_ids. // Should not be called in a hot loop except when desperate and/or debugging. @@ -2977,7 +2917,7 @@ struct llama_rs_self_cache { uint32_t used_verif = 0; for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { llama_rs_cell & cell = cells[cell_id]; - if (cell.seq_nodes.empty()) { + if (cell.is_empty()) { if (cell.pos >= 0) { if (debug) { LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n", @@ -2986,6 +2926,8 @@ struct llama_rs_self_cache { cell.pos = -1; was_valid = false; } + } else { + used_verif += 1; } if (cell.pos < 0) { if (cell.pos != -1) { @@ -2996,30 +2938,19 @@ struct llama_rs_self_cache { cell.pos = -1; was_valid = false; } - if (!cell.seq_nodes.empty()) { + if (!cell.is_empty()) { if (debug) { LLAMA_LOG_ERROR("%s: cells[%d] has %zu seq_ids while it's empty (should have none)\n", - __func__, cell_id, cell.seq_nodes.size()); + __func__, cell_id, cell.seq_id.size()); } - cell.seq_nodes.clear(); + cell.seq_id.clear(); was_valid = false; } cell.src = -1; - if (cell.prev != -1) { - if (debug) { - LLAMA_LOG_ERROR("%s: cells[%d].prev is %d while it's empty (should be -1)\n", - __func__, cell_id, cell.prev); - } - cell.prev = -1; - was_valid = false; - } } else if (!debug) { // Assuming the cache should be actually rebuilt when not debugging cell.src = cell_id; } - if (!cell.seq_nodes.empty()) { - used_verif += 1; - } } if (used != used_verif) { if (debug) { @@ -3051,480 +2982,10 @@ struct llama_rs_self_cache { seq.tail = tail; was_valid = false; } - int32_t prev = -1; - for (size_t i = 0; i < seq_cells.size(); ++i) { - uint32_t cell_id = seq_cells[i].second; - llama_rs_cell & cell = cells[cell_id]; - if (cell.prev != prev) { - // TODO: relax the error when multiple cells have the same pos - if (debug) { - LLAMA_LOG_ERROR("%s: invalid prev cell for cells[%u] (%d instead of %d)\n", - __func__, cell_id, cell.prev, prev); - } - cell.prev = prev; - was_valid = false; - } - prev = cell_id; - } - int32_t n_cells = 0; - int32_t next = -1; - for (size_t i = seq_cells.size(); i-- > 0;) { - uint32_t cell_id = seq_cells[i].second; - llama_rs_cell & cell = cells[cell_id]; - // assuming it's always found, because how else would it end up in the list of cells for this seq_id? - auto seq_node = std::find(cell.seq_nodes.begin(), cell.seq_nodes.end(), seq_id); - if (seq_node == cell.seq_nodes.begin()) { - n_cells += 1; - } - if (seq_node->next_cell != next) { - // TODO: relax the error when multiple cells have the same pos - if (debug) { - LLAMA_LOG_ERROR("%s: invalid next cell for seq_id %d in cells[%u] (%d instead of %d)\n", - __func__, seq_id, cell_id, seq_node->next_cell, next); - } - seq_node->next_cell = next; - was_valid = false; - } - next = cell_id; - } - if (seq.n_cells != n_cells) { - if (debug) { - LLAMA_LOG_ERROR("%s: invalid n_cells for seq_id %d (%d instead of %d)\n", - __func__, seq_id, seq.n_cells, n_cells); - } - seq.n_cells = n_cells; - } - } - // tail_rc - for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { - llama_rs_cell & cell = cells[cell_id]; - uint32_t tail_rc = 0; - for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { - auto & seq = seq_tails[seq_id]; - if (seq.tail >= 0 && (uint32_t) seq.tail == cell_id) { - tail_rc += 1; - } - } - if (cell.tail_rc != tail_rc) { - if (debug) { - LLAMA_LOG_ERROR("%s: invalid tail_rc for cells[%u] (%u instead of %u)\n", - __func__, cell_id, cell.tail_rc, tail_rc); - } - cell.tail_rc = tail_rc; - was_valid = false; - } - } - // n_seqs - uint32_t n_seqs_verif = 0; - uint32_t n_shared_tail_cells_verif = 0; - for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) { - llama_rs_cell & rs_cell = cells[cell_id]; - if (!rs_cell.seq_nodes.empty()) { - if (rs_cell.seq_nodes.size() == rs_cell.tail_rc) { - n_seqs_verif += 1; - } else if (rs_cell.tail_rc > 0) { - n_shared_tail_cells_verif += 1; - } - } - } - if (n_seqs != n_seqs_verif) { - if (debug) { - LLAMA_LOG_ERROR("%s: wrong n_seqs (%u instead of %u)\n", - __func__, n_seqs, n_seqs_verif); - } - n_seqs = n_seqs_verif; - was_valid = false; - } - if (n_shared_tail_cells != n_shared_tail_cells_verif) { - if (debug) { - LLAMA_LOG_ERROR("%s: wrong n_shared_tail_cells (%u instead of %u)\n", - __func__, n_shared_tail_cells, n_shared_tail_cells_verif); - } - n_shared_tail_cells = n_shared_tail_cells_verif; - was_valid = false; } return was_valid; } - // each seq_id should have access to at least this many cells - // (to use when pruning (to avoid over-pruning)) - uint32_t min_cells_per_seq(const llama_ubatch & batch) const { - uint32_t seqs = n_seqs; - for (uint32_t i = 0; i < batch.n_seqs; ++i) { - llama_seq_id seq_id = batch.seq_id[i][0]; - const llama_rs_seq_meta & new_seq = seq_tails[seq_id]; - if (new_seq.tail < 0 || new_seq.n_cells == 0) { - seqs += 1; - } - } - return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1); - } - - void freeable_for_batch(const llama_ubatch & batch, llama_pos checkpoint_interval) { - GGML_ASSERT(batch.equal_seqs); - int32_t min_cells = min_cells_per_seq(batch); - - // TODO: minimize work required to find freeable cells - // currently, this finds freeable cells by excluding non-freeable cells, - // because some conditions are more easily expressed this way. - - freeable.assign(size, 1); - - for (llama_rs_seq_meta & seq : seq_tails) { - seq.ubatch_end_pos = -1; - } - - for (uint32_t i = 0; i < batch.n_seqs; ++i) { - int32_t n_seq_id = batch.n_seq_id[i]; - for (int32_t j = 0; j < n_seq_id; j++) { - llama_seq_id seq_id = batch.seq_id[i][j]; - GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_tails.size()); - llama_rs_seq_meta & seq = seq_tails[seq_id]; - seq.ubatch_end_pos = batch.pos[i * batch.n_seq_tokens + batch.n_seq_tokens - 1]; - } - } - - for (llama_rs_seq_meta & seq : seq_tails) { - if (seq.tail >= 0 && freeable[seq.tail] != 0) { - llama_pos end_pos = seq.ubatch_end_pos; - // When is a tail cell not freeable? - if (end_pos < 0) { - // when any of its tails are not in the batch - freeable[seq.tail] = 0; - } else if (min_cells > 1) { - // TODO: fallback to this less often - llama_rs_cell & tail = cells[seq.tail]; - GGML_ASSERT(tail.pos < end_pos); - if (tail.prev < 0 || tail.pos + checkpoint_interval <= end_pos) { - // make a checkpoint before prompt processing - // TODO: should it always be done after instead? - freeable[seq.tail] = 0; - } else { - llama_rs_cell & prev = cells[tail.prev]; - if (prev.pos + checkpoint_interval <= end_pos) { - // make a checkpoint during text generation - freeable[seq.tail] = 0; - } - } - } - } - } - - for (uint32_t i = 0; i < size; ++i) { - llama_rs_cell & cell = cells[i]; - if (!cell.is_empty() && cell.tail_rc == 0) { - // TODO: reduce indirection here - llama_rs_seq_node & seq_node = cell.seq_nodes[0]; - llama_rs_seq_meta & seq = seq_tails[seq_node.seq_id]; - bool keep_tail = freeable[seq.tail] == 0; - // kept tails use an additional cell, so make them allow freeing a checkpoint - int32_t really_min_cells = keep_tail ? min_cells - 1 : min_cells; - // A checkpoint is kept if there's enough alloted space for this sequence - // or if it's the state right before the tail - if (seq.n_cells <= really_min_cells || (really_min_cells > 1 && seq_node.next_cell == seq.tail)) { - freeable[i] = 0; - } - } - } - } - - // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. - // Why an iterator? Because it allows using std::vector::erase. - std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { - GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); - // The iterator needs to point inside the correct vector - GGML_ASSERT(&(*node_iter) >= rs_cell.seq_nodes.data() && &(*node_iter) < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size()); - if (node_iter != rs_cell.seq_nodes.end()) { - // update the tree - llama_rs_seq_node node = *node_iter; - if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { - // NOTE: because of this, partially removing seq_ids from cells should only be done from the tail - cells[node.next_cell].prev = rs_cell.prev; - } - if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { - llama_rs_cell & prev_cell = cells[rs_cell.prev]; - auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); - // assuming the previous node is always found - GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); - prev_node->next_cell = node.next_cell; - if (node.is_tail()) { - // move the tail back to the previous cell - prev_cell.tail_rc += 1; - if (prev_cell.seq_nodes.size() > 1) { - if (rs_cell.tail_rc == rs_cell.seq_nodes.size()) { - if (prev_cell.tail_rc == 1) { - n_shared_tail_cells += 1; - } - - if (rs_cell.tail_rc == 1) { - if (prev_cell.tail_rc != prev_cell.seq_nodes.size()) { - // o oo oo - // |/ -> o/ - // | | - // e.g. when removing the leaf of a split tree - n_seqs -= 1; - } else { - // o - // o -> oo - // | | - // e.g. when merging back with a previous tail - n_shared_tail_cells -= 1; - } - } - } - } - } - } - if ((uint32_t) node.seq_id < seq_tails.size()) { - auto & seq = seq_tails[node.seq_id]; - if (node.is_tail()) { - seq.tail = rs_cell.prev; - if (rs_cell.tail_rc == 1) { - if (seq.tail < 0) { - // no more tail, no more sequence - if (rs_cell.seq_nodes.size() > 1) { - n_shared_tail_cells -= 1; - } else { - n_seqs -= 1; - } - } - } - GGML_ASSERT(rs_cell.tail_rc > 0); - rs_cell.tail_rc -= 1; - } else if (rs_cell.tail_rc == rs_cell.seq_nodes.size() - 1) { - // will fully become a tail cell - if (rs_cell.tail_rc > 0) { - n_seqs += 1; - n_shared_tail_cells -= 1; - } - } - if (node_iter == rs_cell.seq_nodes.begin()) { - // this seq_id was the first in the list - seq.n_cells -= 1; - - auto next_node = std::next(node_iter); - if (next_node != rs_cell.seq_nodes.end()) { - // the next node is the new first one, so update its n_cells - if ((uint32_t) next_node->seq_id < seq_tails.size()) { - auto & next_seq = seq_tails[next_node->seq_id]; - next_seq.n_cells += 1; - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - } else { - // this was the last seq_id of the cell - used -= 1; - rs_cell.pos = -1; - rs_cell.src = -1; - rs_cell.prev = -1; - // the other fields *should* have already been updated elsewhere - } - } - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - return rs_cell.seq_nodes.erase(node_iter); - } - return node_iter; - } - - void clear_cell(llama_rs_cell & rs_cell) { - GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); - for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { - node_iter = remove_seq_node_from_cell(rs_cell, node_iter); - } - } - - // returns whether or not the seq_id was removed - bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) { - if (i_cell < size && (size_t) id < size) { - llama_rs_cell & rs_cell = cells[i_cell]; - auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once - return node_iter != remove_seq_node_from_cell(rs_cell, node_iter); - } - return false; - } - - bool swap_cells(uint32_t i_src, uint32_t i_dst) { - if (i_src < size && i_dst < size && i_src != i_dst) { - llama_rs_cell & src = cells[i_src]; - llama_rs_cell & dst = cells[i_dst]; - - for (llama_rs_seq_node & seq_node : src.seq_nodes) { - if (seq_node.next_cell >= 0) { - llama_rs_cell & next = cells[seq_node.next_cell]; - next.prev = i_dst; - if ((uint32_t) seq_node.next_cell == i_dst) { - seq_node.next_cell = i_src; - } - } else { - // this is a tail - seq_tails[seq_node.seq_id].tail = i_dst; - } - } - for (llama_rs_seq_node & seq_node : dst.seq_nodes) { - if (seq_node.next_cell >= 0) { - llama_rs_cell & next = cells[seq_node.next_cell]; - next.prev = i_src; - if ((uint32_t) seq_node.next_cell == i_src) { - seq_node.next_cell = i_dst; - } - } else { - // this is a tail - seq_tails[seq_node.seq_id].tail = i_src; - } - } - - if (src.prev == dst.prev) { - // avoid swapping them twice - if (src.prev >= 0) { - llama_rs_cell & prev = cells[src.prev]; - for (llama_rs_seq_node & seq_node : prev.seq_nodes) { - if ((uint32_t) seq_node.next_cell == i_src) { - seq_node.next_cell = i_dst; - } else if ((uint32_t) seq_node.next_cell == i_dst) { - seq_node.next_cell = i_src; - } - } - } - } else { - if (src.prev >= 0) { - llama_rs_cell & prev = cells[src.prev]; - for (llama_rs_seq_node & seq_node : prev.seq_nodes) { - if ((uint32_t) seq_node.next_cell == i_src) { - seq_node.next_cell = i_dst; - } - } - } - if (dst.prev >= 0) { - llama_rs_cell & prev = cells[dst.prev]; - for (llama_rs_seq_node & seq_node : prev.seq_nodes) { - if ((uint32_t) seq_node.next_cell == i_dst) { - seq_node.next_cell = i_src; - } - } - } - } - - std::swap(src.pos, dst.pos); - std::swap(src.src, dst.src); - std::swap(src.prev, dst.prev); - std::swap(src.tail_rc, dst.tail_rc); - std::swap(src.seq_nodes, dst.seq_nodes); - - return true; - } - return false; - } - - bool insert_seq_tail_to_cell_id(uint32_t i_cell, llama_seq_id id, llama_pos end_pos = -1) { - if (i_cell < size && (size_t) id < seq_tails.size()) { - llama_rs_cell & rs_cell = cells[i_cell]; - auto & seq = seq_tails[id]; - int32_t prev = rs_cell.prev; - if (end_pos >= 0) { - if (end_pos <= rs_cell.pos) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, end_pos, rs_cell.pos, id); - } - rs_cell.pos = end_pos; - } else { - // if no pos was specified, then the target cell should already have a valid one. - GGML_ASSERT(!rs_cell.is_empty()); - } - if ((uint32_t) seq.tail == i_cell) { - // the cell is already the tail of this seq_id - if (rs_cell.tail_rc != rs_cell.seq_nodes.size()) { - GGML_ASSERT(end_pos >= 0); // make sure this is the first re-added seq_id - // remove non-tail seq_ids (branch off them) - for (size_t i = rs_cell.seq_nodes.size(); i-- > 0;) { - if (!rs_cell.seq_nodes[i].is_tail()) { - remove_seq_node_from_cell(rs_cell, rs_cell.seq_nodes.begin() + i); - } - } - } - return true; - } - if (rs_cell.is_empty()) { - prev = seq.tail; - } - // ensure the new tail won't mess up the tree - GGML_ASSERT(seq.tail == -1 || seq.tail == prev); - if (prev >= 0 && (uint32_t) prev < size) { - // the targeted cell has a previous cell - llama_rs_cell & prev_cell = cells[prev]; - auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id); - GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing - GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken - if (rs_cell.is_empty()) { - rs_cell.src = prev_cell.src; - } - prev_node->next_cell = i_cell; - rs_cell.prev = prev; - if (seq.tail == prev) { - // What to do when the tail moves... - // (Legend: tail: O, one or more non-tails: o, one or more tails O+, empty: _) - // O -> oO (n_seqs--, n_shared_tail_cells++) - // O -> O (seq.n_cells++) - // OO+ -> oO (n_seqs--, n_shared_tail_cells += 2) - // OO+ -> O+ (n_shared_tail_cells++ (the previous cell becomes oO+)) - // _ -> oO (n_shared_tail_cells++) - // _ -> O (seq.n_cells++, n_seqs++) - // Oo -> O (seq.n_cells++, n_seqs++, n_shared_tail_cell--) - // Oo -> OO+ (n_shared_tail_cell--) - // OOo -> O (seq.n_cells++, n_seqs++) - if (prev_cell.seq_nodes.size() == prev_cell.tail_rc) { - // from fully tail - if (prev_cell.tail_rc > 1) { - // the previous tail becomes shared with a non-tail - n_shared_tail_cells += 1; - } - if (!rs_cell.is_empty() && rs_cell.tail_rc == 0) { - // the new tail cell was previously a fully non-tail cell - n_shared_tail_cells += 1; - n_seqs -= 1; - } - } else { - if (rs_cell.is_empty()) { - // from shared to unique - n_seqs += 1; - } - if (prev_cell.tail_rc == 1 && rs_cell.seq_nodes.size() == rs_cell.tail_rc) { - // from last shared to fully tail - n_shared_tail_cells -= 1; - } - } - } - prev_cell.tail_rc -= 1; - } - if (rs_cell.is_empty()) { - // to unique - seq.n_cells += 1; - if (seq.tail < 0) { - // from empty to unique - n_seqs += 1; - // make sure it's cleared - rs_cell.src = -1; - } - used += 1; - } else if (rs_cell.tail_rc == 0) { - // to shared - if (seq.tail < 0) { - // from empty to shared - n_shared_tail_cells += 1; - } - } - // the target cell was not already a tail of this seq_id - rs_cell.insert_node(id); // next_cell == -1 by default - rs_cell.tail_rc += 1; - seq.tail = i_cell; - return true; - } - return false; - } - size_t total_size() const { size_t size = 0; for (struct ggml_tensor * r : r_l) { @@ -4341,7 +3802,6 @@ static bool llama_kv_cache_init( cache.rs.cells.resize(rs_size); cache.rs.seq_tails.clear(); cache.rs.seq_tails.resize(rs_size); - cache.rs.freeable.reserve(rs_size); // count used buffer types std::map buft_layer_count; @@ -4429,86 +3889,78 @@ static bool llama_kv_cache_init( static bool llama_kv_cache_find_slot( struct llama_kv_cache & cache, const struct llama_ubatch & batch) { - const uint32_t kv_size = cache.kv.size; - const uint32_t rs_size = cache.rs.size; + struct llama_kv_self_cache & kv_self = cache.kv; + struct llama_rs_self_cache & rs_self = cache.rs; + const uint32_t kv_size = kv_self.size; + const uint32_t rs_size = rs_self.size; const uint32_t n_tokens = batch.n_tokens; const uint32_t n_seqs = batch.n_seqs; const uint32_t n_seq_tokens = batch.n_seq_tokens; - // only check first, to allow failing gracefully - if (rs_size > 0) { - // everything should fit if all seq_ids are smaller than the max - for (uint32_t i = 0; i < n_seqs; ++i) { - int32_t n_seq_id = batch.n_seq_id[i]; - for (int32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id = batch.seq_id[i][j]; - - if (seq_id < 0 || (uint32_t) seq_id >= rs_size) { - // too big seq_id - // TODO: would it be possible to resize the rs cache size instead? - LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); - return false; - } + // check only at first, to allow failing gracefully + { + if (rs_size > 0) { + if (!batch.equal_seqs) { + LLAMA_LOG_ERROR("%s: can't process batch with unequal new tokens per sequence for recurrent models\n", __func__); + return false; } - } - // TODO: configurable checkpoint interval - cache.rs.freeable_for_batch(batch, 8); - { - uint32_t freeable_rs_cell_count = 0; - for (uint32_t is_freeable : cache.rs.freeable) { - freeable_rs_cell_count += (uint32_t) (is_freeable != 0); - if (freeable_rs_cell_count >= n_seqs) { - // there's enough, no need to count them all - break; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t i = 0; i < n_seqs; ++i) { + int32_t n_seq_id = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id = batch.seq_id[i][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= rs_size) { + // too big seq_id + // TODO: would it be possible to resize the rs cache size instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, rs_size); + return false; + } } } - if (n_seqs > freeable_rs_cell_count) { - // This should not happen - LLAMA_LOG_ERROR("%s: n_seqs=%d > freeable_rs_cell_count=%d\n", __func__, n_seqs, freeable_rs_cell_count); + } + + if (kv_size > 0) { + // one KV cell per token + if (n_tokens > kv_size) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > kv_size=%d\n", __func__, n_tokens, kv_size); return false; } - } - } - if (kv_size > 0) { - // one KV cell per token - if (n_tokens > kv_size) { - LLAMA_LOG_ERROR("%s: n_tokens=%d > kv_size=%d\n", __func__, n_tokens, kv_size); - return false; - } + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (kv_self.head > kv_self.used + 2*n_tokens) { + kv_self.head = 0; + } - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (cache.kv.head > cache.kv.used + 2*n_tokens) { - cache.kv.head = 0; - } + uint32_t n_tested = 0; - uint32_t n_tested = 0; + while (true) { + if (kv_self.head + n_tokens > kv_size) { + n_tested += kv_size - kv_self.head; + kv_self.head = 0; + continue; + } - while (true) { - if (cache.kv.head + n_tokens > kv_size) { - n_tested += kv_size - cache.kv.head; - cache.kv.head = 0; - continue; - } + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (kv_self.cells[kv_self.head + i].pos >= 0) { + found = false; + kv_self.head += i + 1; + n_tested += i + 1; + break; + } + } - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cache.kv.cells[cache.kv.head + i].pos >= 0) { - found = false; - cache.kv.head += i + 1; - n_tested += i + 1; + if (found) { break; } - } - - if (found) { - break; - } - if (n_tested >= kv_size) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; + if (n_tested >= kv_size) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } } } } @@ -4520,154 +3972,142 @@ static bool llama_kv_cache_find_slot( // each cache cell can store the state for a whole sequence. // A slot should be always be contiguous. - uint32_t min_head = 0; - uint32_t min_n = cache.rs.size; - uint32_t min_free = 0; + int32_t min = rs_size - 1; + int32_t max = 0; - // compact the freeable cell list - // e.g. 0,1,0,0,1,1,0,1,0,1 -> 1,4,5,7,9 - // while also finding the smallest cell range for the slot - { - uint32_t next_free = 0; - for (size_t i = 0; i < cache.rs.freeable.size(); ++i) { - if (cache.rs.freeable[i]) { - cache.rs.freeable[next_free] = i; - next_free += 1; - - if (next_free >= n_seqs) { - uint32_t head = cache.rs.freeable[next_free - n_seqs]; - // i is the last seen freeable cell id - uint32_t n = i - head + 1; - // keep the first smallest big enough slot - if (n < min_n) { - min_free = next_free - n_seqs; - min_head = head; - min_n = n; - if (n == n_seqs) { - // it's the smallest it can be - break; - } - } + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = batch.n_seq_id[s]; + for (uint32_t j = 1; j < n_seq_id; ++j) { + const llama_seq_id seq_id = batch.seq_id[s][j]; + + llama_rs_seq_meta & seq = rs_self.seq_tails[seq_id]; + if (seq.tail >= 0) { + llama_rs_cell & cell = rs_self.cells[seq.tail]; + // Clear previous tail cells from seq_ids that become shared. + // Only happens on batches with multiple seq_ids per token, + // but the seq_ids each had their own tail cell. + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + rs_self.used -= 1; } } } } - // sanity check - GGML_ASSERT(min_head + min_n <= cache.rs.size); - - // keep only the necessary range - cache.rs.freeable.resize(min_free + n_seqs); - cache.rs.freeable.erase(cache.rs.freeable.begin(), cache.rs.freeable.begin() + min_free); - GGML_ASSERT(cache.rs.freeable.size() == n_seqs); - GGML_ASSERT(min_n >= n_seqs); - cache.rs.freeable.resize(min_n); + // find next empty cell + uint32_t next_empty_cell = rs_self.head; - // expand the free list - // e.g. 2,4,5,8 -> 1,0,1,1,0,0,1 - for (uint32_t i = n_seqs; i-- > 0;) { - uint32_t dst = cache.rs.freeable[i] - min_head; - if (dst != i) { - cache.rs.freeable[i] = 0; - } - GGML_ASSERT(dst >= i); - cache.rs.freeable[dst] = 1; + for (uint32_t i = 0; i < rs_size; ++i) { + if (next_empty_cell >= rs_size) { next_empty_cell -= rs_size; } + llama_rs_cell & cell = rs_self.cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; } - // coalesce the free cells together - // e.g. 1,0,1,1,0,0,1 -> 1,1,1,1,0,0,0 - // or 1,0,1,1,1,1 -> 1,1,1,1,1,0 - { - uint32_t top_free = min_n - 1; - for (uint32_t i = min_n; i-- > 1;) { - uint32_t is_free = cache.rs.freeable[i]; - if (!is_free) { - GGML_ASSERT(top_free > i); - cache.rs.swap_cells(min_head + i, min_head + top_free); - std::swap(cache.rs.freeable[i], cache.rs.freeable[top_free]); - // the previous one has to be free, - // otherwise it would already have been swapped. - top_free -= 1; + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + llama_rs_seq_meta & seq_meta = rs_self.seq_tails[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + llama_rs_cell & cell = rs_self.cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + llama_rs_cell & empty_cell = rs_self.cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + llama_rs_cell & orig_cell = rs_self.cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < rs_size; ++i) { + if (next_empty_cell >= rs_size) { next_empty_cell -= rs_size; } + llama_rs_cell & cell = rs_self.cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } } - // stop early if all freeable cells have already been put at the beginning - if (top_free < n_seqs) { break; } } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } } - // order the re-used cells identically to their batch order - // (and clear the non-reused cells) - { - for (uint32_t i = 0; i < n_seqs; ++i) { - // ignore the already-swapped cells - if (cache.rs.freeable[i]) { - llama_rs_cell & cell = cache.rs.cells[min_head + i]; - if (!cell.is_empty()) { - if (cell.tail_rc == 0) { - cache.rs.clear_cell(cell); - } else { - // Find the seq_id of the first tail of this cell - llama_seq_id seq_id = -1; - for (llama_rs_seq_node & seq_node : cell.seq_nodes) { - if (seq_node.is_tail()) { - seq_id = seq_node.seq_id; - break; - } - } - GGML_ASSERT(seq_id != -1); - - // Which seq_id of the batch is it? - int32_t nth_seq_id = -1; - for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) { - if (seq_id == batch.seq_id[s][0]) { - nth_seq_id = s; - break; - } - } - GGML_ASSERT(nth_seq_id != -1); + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = rs_self.seq_tails[batch.seq_id[s][0]].tail; + if (dst_id != src_id) { + llama_rs_cell & dst_cell = rs_self.cells[dst_id]; + llama_rs_cell & src_cell = rs_self.cells[src_id]; - cache.rs.swap_cells(min_head + i, min_head + nth_seq_id); - cache.rs.freeable[i] = 0; - std::swap(cache.rs.freeable[i], cache.rs.freeable[nth_seq_id]); - i -= 1; // check this cell again, now that it was swapped - } - } + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + rs_self.seq_tails[seq_id].tail = src_id; + } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + rs_self.seq_tails[seq_id].tail = dst_id; } } } - // reserve - { - for (uint32_t i = 0; i < n_seqs; ++i) { - uint32_t i_cell = min_head + i; - int32_t n_seq_id = batch.n_seq_id[i]; - llama_pos end_pos = batch.pos[(i * n_seq_tokens) + n_seq_tokens - 1]; - // set the pos with the first seq_id - cache.rs.insert_seq_tail_to_cell_id(i_cell, batch.seq_id[i][0], end_pos); - // insert the rest of the seq_ids by re-using the cell's pos - for (int j = 1; j < n_seq_id; ++j) { - cache.rs.insert_seq_tail_to_cell_id(i_cell, batch.seq_id[i][j]); - } + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + llama_rs_cell & cell = rs_self.cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = batch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + rs_self.seq_tails[seq_id].tail = cell_id; } } // allow getting the range of used cells, from head to head + n - cache.rs.head = min_head; - cache.rs.n = min_n; + rs_self.head = min; + rs_self.n = max - min + 1; } if (kv_size > 0) { for (uint32_t s = 0; s < n_seqs; s++) { for (uint32_t i = 0; i < n_seq_tokens; ++i) { uint32_t k = s*n_seq_tokens + i; - cache.kv.cells[cache.kv.head + k].pos = batch.pos[k]; + kv_self.cells[kv_self.head + k].pos = batch.pos[k]; for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { - cache.kv.cells[cache.kv.head + k].seq_id.insert(batch.seq_id[s][j]); + kv_self.cells[kv_self.head + k].seq_id.insert(batch.seq_id[s][j]); } } } - cache.kv.used += n_tokens; + kv_self.used += n_tokens; } return true; @@ -4686,20 +4126,7 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_self_cache & cache return 0; } -// find how many recurrent state cells are currently in use -static uint32_t llama_rs_cache_cell_max(const struct llama_rs_self_cache & cache) { - for (uint32_t i = cache.size; i > 0; --i) { - const llama_rs_cell & cell = cache.cells[i - 1]; - - if (cell.pos >= 0 && !cell.is_empty()) { - return i; - } - } - - return 0; -} - -static void llama_past_clear(struct llama_kv_cache & cache) { +static void llama_kv_cache_clear(struct llama_kv_cache & cache) { if (cache.kv.size > 0) { for (uint32_t i = 0; i < cache.kv.size; ++i) { llama_kv_cell & kv_cell = cache.kv.cells[i]; @@ -4717,14 +4144,10 @@ static void llama_past_clear(struct llama_kv_cache & cache) { llama_rs_cell & rs_cell = cache.rs.cells[i]; rs_cell.pos = -1; rs_cell.src = -1; - rs_cell.prev = -1; - rs_cell.tail_rc = 0; - rs_cell.seq_nodes.clear(); - } - cache.rs.head = 0; - cache.rs.used = 0; - cache.rs.n_seqs = 0; - cache.rs.n_shared_tail_cells = 0; + rs_cell.seq_id.clear(); + } + cache.rs.head = 0; + cache.rs.used = 0; cache.rs.seq_tails.clear(); cache.rs.seq_tails.resize(cache.rs.size); } @@ -4733,63 +4156,65 @@ static void llama_past_clear(struct llama_kv_cache & cache) { } } -static llama_pos llama_past_seq_rm( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { +static bool llama_kv_cache_seq_rm( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } - llama_pos n_past = p0; - + // models like Mamba or RWKV can't have a state partially erased + // TODO: refactor the recurrent state cache to allow partial rollbacks if (cache.rs.size > 0) { - if (seq_id >= (int64_t) cache.rs.size) { + uint32_t new_head = cache.rs.size; + + if (seq_id >= (int64_t) cache.rs.seq_tails.size()) { // could be fatal - return n_past; + return false; + } + if (0 <= seq_id) { + int32_t & tail_id = cache.rs.seq_tails[seq_id].tail; + if (tail_id >= 0) { + const llama_rs_cell & cell = cache.rs.cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } } - uint32_t new_head = cache.rs.size; - // adjust p0 and p1 according to the states found - llama_pos new_p0 = 0; - llama_pos new_p1 = std::numeric_limits::max(); - // partial seq_id removal has to happen from the tail - llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; - int32_t cell_id = seq.tail; + // Assume there's only one cell per seq_id + for (uint32_t i = 0; i < cache.rs.seq_tails.size(); ++i) { + if (seq_id < 0 || i == (uint32_t) seq_id) { + int32_t tail_id = cache.rs.seq_tails[i].tail; + if (tail_id >= 0) { + llama_rs_cell rs_cell = cache.rs.cells[tail_id]; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.seq_id.erase(i); + if (rs_cell.is_empty()) { + // keep count of the number of used cells + if (cache.rs.cells[i].pos >= 0) { cache.rs.used--; } - while (cell_id >= 0) { - llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; - // copy before the cell is potentially changed - int32_t prev_id = rs_cell.prev; - if (rs_cell.pos >= p1 && rs_cell.seq_nodes.size() > 1) { - // non-tail removal for shared cells can only be done when clearing a cell - // (i.e. when the next cell's link to the previous cell can be safely changed) - p1 = rs_cell.pos + 1; - } - if (rs_cell.pos >= p0 && rs_cell.pos < p1) { - auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); - // if the node isn't found, the sequence tree is malformed - GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); - cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); - // get the smallest removed cell id - if (new_head > (uint32_t) cell_id) { new_head = cell_id; } - } else { - // one more than the biggest non-removed cell of this sequence - if (rs_cell.pos >= n_past) { n_past = rs_cell.pos + 1; } - - if (rs_cell.pos < p0) { - // new_p0 should be right after the max pos in the states before p0 - if (rs_cell.pos >= new_p0) { new_p0 = rs_cell.pos + 1; } - } else { // (rs_cell.pos >= p1) - // new_p1 should be the min pos in the states after p1 - if (rs_cell.pos < new_p1) { new_p1 = rs_cell.pos; } + cache.rs.cells[i].pos = -1; + cache.rs.cells[i].src = -1; + if (new_head == cache.rs.size) { new_head = i; } + } + } + cache.rs.seq_tails[i].tail = -1; } } - cell_id = prev_id; } - p0 = new_p0; - p1 = new_p1; // If we freed up a slot, set head to it so searching can start there. if (new_head != cache.rs.size && new_head < cache.rs.head) { @@ -4801,24 +4226,20 @@ static llama_pos llama_past_seq_rm( uint32_t new_head = cache.kv.size; for (uint32_t i = 0; i < cache.kv.size; ++i) { - llama_kv_cell & kv_cell = cache.kv.cells[i]; - - if (seq_id < 0 || kv_cell.has_seq_id(seq_id)) { - if (kv_cell.pos >= p0 && kv_cell.pos < p1) { - if (seq_id < 0) { - kv_cell.seq_id.clear(); - } else { // (kv_cell.has_seq_id(seq_id)) - kv_cell.seq_id.erase(seq_id); - } - if (kv_cell.is_empty()) { - // keep count of the number of used cells - if (kv_cell.pos >= 0) { cache.kv.used--; } + if (cache.kv.cells[i].pos >= p0 && cache.kv.cells[i].pos < p1) { + if (seq_id < 0) { + cache.kv.cells[i].seq_id.clear(); + } else if (cache.kv.cells[i].has_seq_id(seq_id)) { + cache.kv.cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cache.kv.cells[i].is_empty()) { + // keep count of the number of used cells + if (cache.kv.cells[i].pos >= 0) { cache.kv.used--; } - kv_cell.pos = -1; - if (new_head == cache.kv.size) { new_head = i; } - } - } else if (kv_cell.pos >= n_past) { - n_past = kv_cell.pos + 1; + cache.kv.cells[i].pos = -1; + if (new_head == cache.kv.size) { new_head = i; } } } } @@ -4829,59 +4250,29 @@ static llama_pos llama_past_seq_rm( } } - return n_past; + return true; } -static llama_pos llama_past_seq_cp( - struct llama_kv_cache & cache, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { +static void llama_kv_cache_seq_cp( + struct llama_kv_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } - // TODO: in practice this seems to be only used on whole sequences; - // should partial sequence copy support be removed? - // TODO: What if the destination sequence is not empty? - - llama_pos n_past = 0; - if (cache.rs.size > 0) { - // have to start from the beginning for recurrent models - p0 = 0; - if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) { - int32_t src_head = -1; - int32_t head_pos = p1; - int32_t src_next = -1; - // find the start of the sequence - for (uint32_t i = 0; i < cache.rs.size; ++i) { - llama_rs_cell & rs_cell = cache.rs.cells[i]; - if (!rs_cell.is_empty() && rs_cell.prev < 0) { - auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); - if (seq_node != rs_cell.seq_nodes.end()) { - src_head = i; - head_pos = rs_cell.pos; - src_next = seq_node->next_cell; - break; - } - } - } - while (src_head >= 0 && head_pos < p1) { - cache.rs.insert_seq_tail_to_cell_id(src_head, seq_id_dst); - src_head = src_next; - if (head_pos >= n_past) { n_past = head_pos + 1; } - if (src_next >= 0) { - llama_rs_cell & rs_cell = cache.rs.cells[src_next]; - auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); - head_pos = rs_cell.pos; - // it should always be found if the seq tree is valid - GGML_ASSERT(seq_node != rs_cell.seq_nodes.end()); - src_next = seq_node->next_cell; - } + llama_rs_seq_meta & seq_meta = cache.rs.seq_tails[seq_id_src]; + if (seq_meta.tail >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[seq_meta.tail]; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.seq_id.insert(seq_id_dst); + // TODO: What if the destination sequence is not empty? + GGML_ASSERT(cache.rs.seq_tails[seq_id_dst].tail < 0); + cache.rs.seq_tails[seq_id_dst].tail = seq_meta.tail; } } - p1 = n_past; } if (cache.kv.size > 0) { @@ -4889,32 +4280,30 @@ static llama_pos llama_past_seq_cp( llama_kv_cell & kv_cell = cache.kv.cells[i]; if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) { kv_cell.seq_id.insert(seq_id_dst); - if (kv_cell.pos >= n_past) { n_past = kv_cell.pos + 1; } } } } - - return n_past; } -static void llama_past_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { +static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { if (cache.rs.size > 0) { uint32_t new_head = cache.rs.size; - // partial seq_id removal has to happen from the tail(s) + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (!rs_cell.has_seq_id(seq_id)) { + if (rs_cell.pos >= 0) { cache.rs.used--; } + rs_cell.pos = -1; + rs_cell.seq_id.clear(); + if (new_head == cache.rs.size) { new_head = i; } + } else { + rs_cell.seq_id.clear(); + rs_cell.seq_id.insert(seq_id); + } + } for (uint32_t i = 0; i < cache.rs.seq_tails.size(); ++i) { - if (i == (uint32_t) seq_id) { continue; } - llama_rs_seq_meta & seq = cache.rs.seq_tails[i]; - int32_t cell_id = seq.tail; - while (cell_id >= 0) { - llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; - auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), i); - GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); - cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); - cell_id = rs_cell.prev; - if (new_head > (uint32_t) cell_id && rs_cell.is_empty()) { - new_head = cell_id; - } + if (i != (uint32_t) seq_id) { + cache.rs.seq_tails[i].tail = -1; } } @@ -4947,41 +4336,29 @@ static void llama_past_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_ } } -static void llama_past_seq_add( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { +static void llama_kv_cache_seq_add( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } if (cache.rs.size > 0) { - // for Mamba-like or RKWV models, only the pos needs to be shifted - auto & seq = cache.rs.seq_tails[seq_id]; - // follow the sequence from its tail - int32_t cell_id = seq.tail; - uint32_t new_head = cache.rs.size; - while (cell_id >= 0) { - GGML_ASSERT((uint32_t) cell_id < cache.rs.size); - llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; - cell_id = rs_cell.prev; - if (rs_cell.pos >= p0 && rs_cell.pos < p1) { - rs_cell.pos += delta; - if (rs_cell.pos < 0) { - // NOTE: this affects the other sequences which share the cell - cache.rs.clear_cell(rs_cell); - if (new_head > (uint32_t) cell_id) { - new_head = cell_id; - } + // for recurrent states, the pos shift is faked + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (rs_cell.has_seq_id(seq_id)) { + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos += delta; + // TODO: handle deletion + // (but this should not happen anyway when only the last states are stored) + GGML_ASSERT(rs_cell.pos >= 0); } } } - - // If we freed up a slot, set head to it so searching can start there. - // Otherwise we just start the next search from the beginning. - cache.rs.head = new_head != cache.rs.size ? new_head : 0; } if (cache.kv.size > 0) { @@ -5015,26 +4392,24 @@ static void llama_past_seq_add( } } -static void llama_past_seq_div( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { +static void llama_kv_cache_seq_div( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } if (cache.rs.size > 0) { - // for Mamba-like or RWKV models, only the pos needs to be changed - auto & seq = cache.rs.seq_tails[seq_id]; - int32_t cell_id = seq.tail; - while (cell_id >= 0) { - GGML_ASSERT((uint32_t) cell_id < cache.rs.size); - llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; - if (rs_cell.pos >= p0 && rs_cell.pos < p1) { - rs_cell.pos /= d; + // for recurrent states, the pos shift is faked + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (rs_cell.has_seq_id(seq_id)) { + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos /= d; + } } - cell_id = rs_cell.prev; } } @@ -5056,7 +4431,7 @@ static void llama_past_seq_div( } } -static llama_pos llama_past_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { +static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { llama_pos result = -1; if (cache.rs.size > 0) { @@ -21341,86 +20716,49 @@ int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx) { return ctx->cache.rs.used; } -void llama_past_clear(struct llama_context * ctx) { - llama_past_clear(ctx->cache); -} - -// deprecated void llama_kv_cache_clear(struct llama_context * ctx) { - llama_past_clear(ctx); + llama_kv_cache_clear(ctx->cache); } -llama_pos llama_past_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } - return llama_past_seq_rm(ctx->cache, seq_id, p0, p1); -} - -// deprecated bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - llama_pos n_past = llama_past_seq_rm(ctx, seq_id, p0, p1); - return n_past >= p0; + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } + return llama_kv_cache_seq_rm(ctx->cache, seq_id, p0, p1); } - -llama_pos llama_past_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { uint32_t n_seq_max = llama_n_seq_max(ctx); if (seq_id_src < 0 || seq_id_dst < 0 || (uint32_t) seq_id_src >= n_seq_max || (uint32_t) seq_id_dst >= n_seq_max) { - return 0; + // TODO: error? + return; } if (seq_id_src == seq_id_dst) { - return llama_past_seq_pos_max(ctx->cache, seq_id_dst) + 1; + return; } - return llama_past_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); + return llama_kv_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); } -// deprecated -void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - llama_past_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); -} - -void llama_past_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - llama_past_seq_keep(ctx->cache, seq_id); -} - -// deprecated void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { - llama_past_seq_keep(ctx, seq_id); -} - -void llama_past_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - if (delta == 0) { return; } - - llama_past_seq_add(ctx->cache, seq_id, p0, p1, delta); + llama_kv_cache_seq_keep(ctx->cache, seq_id); } -// deprecated void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - llama_past_seq_add(ctx, seq_id, p0, p1, delta); -} - -void llama_past_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - if (d == 1) { return; } + if (delta == 0) { return; } - llama_past_seq_div(ctx->cache, seq_id, p0, p1, d); + llama_kv_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); } -// deprecated void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - llama_past_seq_div(ctx, seq_id, p0, p1, d); -} + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + if (d == 1) { return; } -llama_pos llama_past_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return -1; } - return llama_past_seq_pos_max(ctx->cache, seq_id); + llama_kv_cache_seq_div(ctx->cache, seq_id, p0, p1, d); } -// deprecated llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { - llama_pos max_pos = llama_past_seq_pos_max(ctx, seq_id); - return max_pos < 0 ? 0 : max_pos; + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return -1; } + return llama_kv_cache_seq_pos_max(ctx->cache, seq_id); } void llama_kv_cache_defrag(struct llama_context * ctx) { @@ -21562,14 +20900,14 @@ struct llama_data_write { for (uint32_t i = range.first; i < range.second; ++i) { const auto & cell = rs_self.cells[i]; const llama_pos pos = cell.pos; - const uint32_t n_seq_id = seq_id == -1 ? cell.seq_nodes.size() : 0; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; write(&pos, sizeof(pos)); write(&n_seq_id, sizeof(n_seq_id)); if (n_seq_id) { - for (auto seq_node : cell.seq_nodes) { - write(&seq_node.seq_id, sizeof(seq_node.seq_id)); + for (auto seq_id : cell.seq_id) { + write(&seq_id, sizeof(seq_id)); } } } @@ -21968,8 +21306,7 @@ struct llama_data_read { return false; } - cell.insert_node(seq_id); - + cell.seq_id.insert(seq_id); } } @@ -22233,10 +21570,10 @@ struct llama_data_read { bool res = true; if (seq_id == -1) { - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); res = read_kv_cache_meta(ctx, kv_cell_count) && read_rs_cache_meta(ctx, rs_cell_count); } else { - llama_past_seq_rm(ctx, seq_id, -1, -1); + llama_kv_cache_seq_rm(ctx, seq_id, -1, -1); // Only a single recurrent cell at most, // because otherwise the cells can be shuffled when a slot is allocated if (rs_cell_count > 1) { @@ -22250,9 +21587,9 @@ struct llama_data_read { if (!res) { if (seq_id == -1) { - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); } else { - llama_past_seq_rm(ctx, seq_id, -1, -1); + llama_kv_cache_seq_rm(ctx, seq_id, -1, -1); } throw std::runtime_error("failed to restore kv cache"); } From e3fe61203c2230b8af32ef1f8153114c175e696b Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 25 Nov 2024 11:31:46 -0500 Subject: [PATCH 049/117] llama : partially apply clang-format style --- src/llama.cpp | 55 +++++++++++++++++++++++---------------------------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 510b7fe8935f7..b076681134bce 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2583,7 +2583,7 @@ struct llama_hparams { return n_embd_head_v * n_head_kv; } - uint32_t n_embd_r(uint32_t il) const { // dimension of the rolling state embeddings + uint32_t n_embd_r(uint32_t il) const { // dimension of the rolling state embeddings // TODO: support using an SSM in place of the MLP of a Transformer if (n_head_kv(il) != 0) { return 0; } // corresponds to Mamba's conv_states size or RWKV's token_shift states size @@ -2597,7 +2597,7 @@ struct llama_hparams { } } - uint32_t n_embd_s(uint32_t il) const { // dimension of the recurrent state embeddings + uint32_t n_embd_s(uint32_t il) const { // dimension of the recurrent state embeddings // TODO: support using an SSM in place of the MLP of a Transformer if (n_head_kv(il) != 0) { return 0; } @@ -2875,17 +2875,13 @@ struct llama_kv_self_cache { struct llama_rs_cell { llama_pos pos = -1; - int32_t src = -1; // copy source id (cleared next when -1) + int32_t src = -1; // copy source id (cleared next when -1) std::set seq_id; - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } + bool has_seq_id(const llama_seq_id & id) const { return seq_id.find(id) != seq_id.end(); } - bool is_empty() const { - return seq_id.empty(); - } + bool is_empty() const { return seq_id.empty(); } }; struct llama_rs_seq_meta { @@ -2895,24 +2891,23 @@ struct llama_rs_seq_meta { // ring-buffered tree of cached recurrent state data struct llama_rs_self_cache { - - uint32_t head = 0; // first state used for the last slot + uint32_t head = 0; // first state used for the last slot uint32_t size = 0; uint32_t used = 0; // computed when finding a slot - uint32_t n = 0; // range of states used for the last slot + uint32_t n = 0; // range of states used for the last slot // with state models, a cell can hold the state for more than one past token std::vector cells; // find tail cells faster - std::vector seq_tails; // map seq_ids to cell ids + std::vector seq_tails; // map seq_ids to cell ids // per layer // NOTE: the naming of r and s is arbitrary - std::vector r_l; // rolling/shift states - std::vector s_l; // ssm (recurrent) states + std::vector r_l; // rolling/shift states + std::vector s_l; // ssm (recurrent) states // Inefficient, but thorough verification and rebuilding of the rs cache // from only the cells list with `pos` and seq_ids. @@ -2920,21 +2915,21 @@ struct llama_rs_self_cache { bool rebuild(bool debug) { bool was_valid = true; // skip for non-recurrent models - if (size == 0) { return true; } + if (size == 0) { + return true; + } // the source of truth is the cells list // buffer sizes if (size != cells.size()) { if (debug) { - LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", - __func__, cells.size(), size); + LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", __func__, cells.size(), size); } cells.resize(size); was_valid = false; } if (size != seq_tails.size()) { if (debug) { - LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", - __func__, seq_tails.size(), size); + LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", __func__, seq_tails.size(), size); } seq_tails.resize(size); was_valid = false; @@ -2994,7 +2989,7 @@ struct llama_rs_self_cache { for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { llama_rs_cell & cell = cells[cell_id]; if (cell.has_seq_id(seq_id)) { - seq_cells.push_back({cell.pos, cell_id}); + seq_cells.push_back({ cell.pos, cell_id }); } } // sort by pos and then by cell_id @@ -3718,16 +3713,16 @@ static bool llama_kv_cache_init( } if (has_kv) { - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)*kv_size); + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i) * kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i) * kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.kv.k_l.push_back(k); cache.kv.v_l.push_back(v); } if (has_rs) { - ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)*rs_size); - ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)*rs_size); + ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i) * rs_size); + ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i) * rs_size); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); cache.rs.r_l.push_back(r); @@ -4370,8 +4365,8 @@ struct llama_kv_slot_restorer { bool do_restore = false; explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) { - old_state.head = cache.kv.head; - old_state.n = cache.kv.n; + old_state.head = cache.kv.head; + old_state.n = cache.kv.n; } // saves a slot information for future restoration @@ -4388,10 +4383,10 @@ struct llama_kv_slot_restorer { // and rollback changes from all llama_kv_cache_find_slot calls void restore(struct llama_kv_cache & cache) { if (do_restore) { - cache.kv.head = old_state.head; - cache.kv.n = old_state.n; + cache.kv.head = old_state.head; + cache.kv.n = old_state.n; - if (cache.rs.size > 0) { // recurrent models like Mamba or RWKV can't have a state partially erased + if (cache.rs.size > 0) { // recurrent models like Mamba or RWKV can't have a state partially erased llama_kv_cache_seq_rm(cache, -1, -1, -1); } else { for (auto & slot : slot_boundaries) { From cf4f0a4123d94b3f09e9f9343b76f48bd6043756 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 1 May 2025 18:55:34 -0400 Subject: [PATCH 050/117] metal : fix confusion between ; and , --- ggml/src/ggml-metal/ggml-metal.m | 42 ++++++++++++++++---------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 370d0ad7744fa..7d6377c790903 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2652,27 +2652,27 @@ static bool ggml_metal_encode_node( } ggml_metal_kargs_ssm_scan args = { - /*.d_state =*/ d_state; - /*.d_inner =*/ d_inner; - /*.n_head =*/ n_head; - /*.n_group =*/ n_group; - /*.n_seq_tokens =*/ n_seq_tokens; - /*.n_seqs =*/ n_seqs; - /*.nb01 =*/ nb01; - /*.nb02 =*/ nb02; - /*.nb03 =*/ nb03; - /*.nb11 =*/ nb11; - /*.nb12 =*/ nb12; - /*.nb13 =*/ nb13; - /*.nb21 =*/ nb21; - /*.nb22 =*/ nb22; - /*.nb31 =*/ nb31; - /*.nb41 =*/ nb41; - /*.nb42 =*/ nb42; - /*.nb43 =*/ nb43; - /*.nb51 =*/ nb51; - /*.nb52 =*/ nb52; - /*.nb53 =*/ nb53; + /*.d_state =*/ d_state, + /*.d_inner =*/ d_inner, + /*.n_head =*/ n_head, + /*.n_group =*/ n_group, + /*.n_seq_tokens =*/ n_seq_tokens, + /*.n_seqs =*/ n_seqs, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb31 =*/ nb31, + /*.nb41 =*/ nb41, + /*.nb42 =*/ nb42, + /*.nb43 =*/ nb43, + /*.nb51 =*/ nb51, + /*.nb52 =*/ nb52, + /*.nb53 =*/ nb53, }; [encoder setComputePipelineState:pipeline]; From 6def5cd729fdde64b2addeaa5cce016c72485e06 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 1 May 2025 19:10:20 -0400 Subject: [PATCH 051/117] metal : add missing args for nb references in ssm_scan_f32_group --- ggml/src/ggml-metal/ggml-metal.metal | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 71ab693721298..4b5e4f8457210 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1350,16 +1350,16 @@ kernel void kernel_ssm_scan_f32_group( device const int32_t * ids = (device const int32_t *) src6; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); - device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns} - device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh} - device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + device const float * x = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*args.nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; From 791998b42d6cd6edb31e4d5824e29c100cecd40b Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 1 May 2025 21:27:12 -0400 Subject: [PATCH 052/117] metal : single-user mamba2 inference works --- ggml/src/ggml-metal/ggml-metal.metal | 14 +++++++------- src/llama-model.cpp | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4b5e4f8457210..4e50efdee41ca 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1284,7 +1284,7 @@ kernel void kernel_ssm_scan_f32( const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - const int64_t s_off = nr * nh * nt * args.n_seqs * sizeof(float); + const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); device const int32_t * ids = (device const int32_t *) src6; @@ -1292,12 +1292,12 @@ kernel void kernel_ssm_scan_f32( device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh} device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*args.nb00); // {dim, nh, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; @@ -1354,12 +1354,12 @@ kernel void kernel_ssm_scan_f32_group( device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*args.nb00); // {dim, nh, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f295c684099e7..cffdbc6845363 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9009,7 +9009,7 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); // {n_head, n_seq_tokens, n_seqs} - dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0); // TODO: use semistructured matrices to implement state-space duality From 94c3d5304352eef27c33b08a858facdffbb28438 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 1 May 2025 22:18:57 -0400 Subject: [PATCH 053/117] kv-cache : remove const_cast when setting inputs for s_copy And also fix multi-user inference for recurrent models by using cell_id instead of i as the kv cell index when populating s_copy. --- src/llama-graph.cpp | 22 ++++++++-------------- src/llama-kv-cache.cpp | 7 +++++-- src/llama-kv-cache.h | 1 + 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0f77f98b24f64..8d2fceb17def5 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -286,27 +286,21 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { for (uint32_t i = 0; i < n_kv; ++i) { const uint32_t cell_id = i + kv_self->head; - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; + const llama_kv_cell & kv_cell = kv_self->cells[cell_id]; + + int32_t src = kv_cell.src0; // prevent out-of-bound sources - if (kv_cell.src < 0) { + if (src < 0) { GGML_ASSERT(kv_self->rs_z >= 0); // Need a valid zero-ed cell as a source - kv_cell.src = kv_self->rs_z; + src = kv_self->rs_z; } - if ((uint32_t) kv_cell.src >= kv_self->size) { + if ((uint32_t) src >= kv_self->size) { // ignore out-of-bound sources - kv_cell.src = cell_id; + src = cell_id; } - data[i] = kv_cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (kv_cell.src != (int32_t) cell_id) { - kv_cell.src = cell_id; - } + data[i] = src; } } } diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 108c07731b1ab..743b30badcf67 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -665,10 +665,13 @@ bool llama_kv_cache_unified::find_slot( // Find first to-be-cleared cell rs_z = -1; for (int i = min; i <= max; ++i) { - if (cells[i].src == -1) { + if (rs_z < 0 && cells[i].src == -1) { rs_z = i; - break; } + // Stage the source ids for all used cells to allow correct seq_* behavior + // and still make these values available when setting the inputs + cells[i].src0 = cells[i].src; + cells[i].src = i; } // allow getting the range of used cells, from head to head + n diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 7939bc6b8dd2d..6b115e8f7d134 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -47,6 +47,7 @@ struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; int32_t src = -1; // used by recurrent state models to copy states + int32_t src0 = -1; // like src, but used when setting the inputs (allowing to copy once) int32_t tail = -1; std::set seq_id; From d55b0d06210cdc10b6cf872b9009d82bb6372b01 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 2 May 2025 18:24:55 -0400 Subject: [PATCH 054/117] convert : avoid AutoConfig for Mamba and Mamba2 hparams --- convert_hf_to_gguf.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 532cc879de324..2debb6e63fef9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4127,6 +4127,14 @@ def set_gguf_parameters(self): class MambaModel(TextModel): model_arch = gguf.MODEL_ARCH.MAMBA + def __init__(self, dir_model: Path, *args, **kwargs): + # Avoid using AutoConfig for hparams + hparams = kwargs.pop("hparams", None) + if hparams is None: + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + super().__init__(dir_model, *args, hparams=hparams, **kwargs) + def set_vocab(self): vocab_size = self.hparams["vocab_size"] # Round vocab size to next multiple of 8 @@ -4205,6 +4213,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter class Mamba2Model(TextModel): model_arch = gguf.MODEL_ARCH.MAMBA2 + def __init__(self, dir_model: Path, *args, **kwargs): + # Avoid using AutoConfig for hparams + # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1 + hparams = kwargs.pop("hparams", None) + if hparams is None: + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + super().__init__(dir_model, *args, hparams=hparams, **kwargs) + def set_vocab(self): vocab_size = self.hparams["vocab_size"] # Round vocab size to next multiple of 16 @@ -5968,12 +5985,20 @@ def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams text_config = hparams.get("text_config", {}) vision_config = hparams.get("vision_config", {}) - arch = hparams["architectures"][0] + arch = None + if (arches := hparams.get("architectures")) is not None and len(arches) > 0: + arch = arches[0] + elif "ssm_cfg" in hparams: + # For non-hf Mamba and Mamba2 models + arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM" + # if "architectures" is found in the sub-config, use that instead if model_type == ModelType.TEXT and text_config.get("architectures") is not None: arch = text_config["architectures"][0] elif model_type == ModelType.VISION and vision_config.get("architectures") is not None: arch = vision_config["architectures"][0] + if arch is None: + raise ValueError("Failed to detect model architecture") return arch From e94f3932f2dbcb2120580a9f42878e058a18cf5b Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 2 May 2025 19:29:23 -0400 Subject: [PATCH 055/117] kv-cache : allow context shift for recurrent models --- src/llama-kv-cache.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 87ce7ce03d503..99dd20b68fd73 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1938,7 +1938,8 @@ llama_pos llama_kv_cache_recurrent::get_pos_max() const { } bool llama_kv_cache_recurrent::get_can_shift() const { - return false; + // shifting is trivial, the recurrent states don't care about the absolute position + return true; } uint32_t llama_kv_cache_recurrent::cell_max() const { From 2fa5f2ceb8b49bbd2835878ad5429ea74383566c Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 10 Jun 2025 20:00:41 -0400 Subject: [PATCH 056/117] graph : fix recurrent state copies when avoiding copies Works, but using lambda functions might not be that clean. --- src/llama-graph.cpp | 19 +++++++------------ src/llama-graph.h | 3 ++- src/llama-model.cpp | 44 +++++++++++++++++++++++++++++--------------- 3 files changed, 38 insertions(+), 28 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index e74c9ff53b05a..1abe3b8febb4a 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1429,7 +1429,8 @@ ggml_tensor * llm_graph_context::build_recurrent_state( ggml_tensor * state_copy, int32_t state_size, int32_t n_seqs, - bool avoid_copies) const { + const std::function & get_state_rows) const { + const auto * kv_state = static_cast(mstate); const auto n_kv = kv_state->get_n_kv(); @@ -1445,17 +1446,11 @@ ggml_tensor * llm_graph_context::build_recurrent_state( ggml_tensor * output_states; - if (!avoid_copies) { - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // {state_size, kv_size} -> {state_size, n_seqs} - output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); - ggml_build_forward_expand(gf, output_states); - } else { - // FIXME: make the gathering operation happen before the copy below - // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?) - output_states = states; - } + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // {state_size, kv_size} -> {state_size, n_seqs} + output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + ggml_build_forward_expand(gf, output_states); // copy extra states which won't be changed further (between n_seqs and n_kv) ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); diff --git a/src/llama-graph.h b/src/llama-graph.h index 88fb77f1ddc9a..1fcf1cde45a41 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -599,7 +599,8 @@ struct llm_graph_context { ggml_tensor * state_copy, int32_t state_size, int32_t n_seqs, - bool avoid_copies = false) const; + const std::function + & get_state_rows = ggml_get_rows) const; ggml_tensor * build_rwkv_token_shift_load( ggml_cgraph * gf, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2c0f7d4084344..2999483ad71ed 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9024,11 +9024,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = kv_state->get_k_l(il); ggml_tensor * ssm_states_all = kv_state->get_v_l(il); - // (ab)using the KV cache to store the states ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); - ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true); - ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size()); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9094,11 +9091,21 @@ struct llm_build_mamba : public llm_graph_context { cur = x; x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs); - ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0); - // Custom operator to optimize the parallel associative scan - // as described in the Annex D of the Mamba paper. - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -9151,11 +9158,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = kv_state->get_k_l(il); ggml_tensor * ssm_states_all = kv_state->get_v_l(il); - // (ab)using the KV cache to store the states ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); - ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true); - ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size()); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9211,10 +9215,20 @@ struct llm_build_mamba : public llm_graph_context { // {n_head, n_seq_tokens, n_seqs} dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); - ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0); - // TODO: use semistructured matrices to implement state-space duality - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); + + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, From 757aa6239de5cc41afdea32561ab227b7b447424 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 11 Jun 2025 12:33:05 -0400 Subject: [PATCH 057/117] ggml : fix mamba2 ssm scan when compiled with SVE --- ggml/src/ggml-cpu/ops.cpp | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2a6be25852e4e..11d4819c868f3 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7664,6 +7664,37 @@ static void ggml_compute_forward_ssm_scan_f32( const float x_dt = x[ii] * dt_soft_plus; float sumf = 0.0f; #if defined(GGML_SIMD) + #if defined(__ARM_FEATURE_SVE) + const int ggml_f32_epr = svcntw(); + const int ggml_f32_step = 1 * ggml_f32_epr; + + const int np = (nc & ~(ggml_f32_step - 1)); + + GGML_F32_VEC sum = GGML_F32_VEC_ZERO; + + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + for (int i = 0; i < np; i += ggml_f32_step) { + // TODO: maybe unroll more? + for (int j = 0; j < 1; j++) { + GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc); + GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + + t0 = GGML_F32_VEC_MUL(t0, adA); + t1 = GGML_F32_VEC_MUL(t1, axdt); + + t0 = GGML_F32_VEC_ADD(t0, t1); + + sum = GGML_F32_VEC_FMA(sum, t0, t2); + + GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0); + } + } + + sumf = GGML_F32xt_REDUCE_ONE(sum); + #else const int np = (nc & ~(GGML_F32_STEP - 1)); GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; @@ -7694,6 +7725,7 @@ static void ggml_compute_forward_ssm_scan_f32( // reduce sum0..sum3 to sum0 GGML_F32_VEC_REDUCE(sumf, sum); + #endif #else const int np = 0; #endif @@ -7722,7 +7754,7 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i1 = 0; i1 < nr; ++i1) { const int ii = i1 + h*nr; const float x_dt = x[ii] * dt_soft_plus; -#ifdef __ARM_FEATURE_SVE +#if defined(__ARM_FEATURE_SVE) svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt); svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus); svfloat32_t r1_vector = GGML_F32_VEC_ZERO; From 0b6f6becb4e916a24fcaf2966647381a21d1f084 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 11 Jun 2025 15:29:58 -0400 Subject: [PATCH 058/117] ggml-cpu : reorder SVE FMA for consistency with other SIMD arches --- ggml/src/ggml-cpu/ops.cpp | 2 +- ggml/src/ggml-cpu/simd-mappings.h | 2 +- ggml/src/ggml-cpu/vec.cpp | 18 +++++++++--------- ggml/src/ggml-cpu/vec.h | 18 +++++++++--------- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 11d4819c868f3..711d8abcc5fdd 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7771,7 +7771,7 @@ static void ggml_compute_forward_ssm_scan_f32( t1 = exp_ps_sve(svptrue_b32(), t1); svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB); - vs0 = GGML_F32_VEC_FMA(vs0, t1, t2); + vs0 = GGML_F32_VEC_FMA(t2, vs0, t1); r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector); GGML_F32_VEC_STORE(&s[ii*nc + k], vs0); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 2e3669c0186c9..91bb867bf57b8 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -32,7 +32,7 @@ #define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b) #define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__) -#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c) +#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a) #define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b) #define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__) diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index f7614568ea388..7e61b5bf965a3 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G for (int i = 0; i < np; i += ggml_f32_step) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1); ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); - sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2); + sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2); ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); - sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3); + sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3); ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); - sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4); + sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4); ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); - sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5); + sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5); ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); - sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6); + sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6); ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); - sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7); + sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7); ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); - sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8); + sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8); } // leftovers // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop @@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G for (int i = np; i < np2; i += ggml_f32_epr) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1); } // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only if (np2 < n) { diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 09dbade2179fb..a144259800477 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); + ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx); GGML_F32_VEC_STORE(y + i, ay1); ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); - ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2); + ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx); GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); - ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3); + ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx); GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3); ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); - ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4); + ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx); GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4); ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); - ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5); + ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx); GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5); ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); - ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6); + ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx); GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6); ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); - ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7); + ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx); GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7); ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); - ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8); + ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx); GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8); } @@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const for (int i = np; i < np2; i += ggml_f32_epr) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); + ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx); GGML_F32_VEC_STORE(y + i, ay1); } From f8c7caeeb7dd610d9689a8965d1205b23a58d9ae Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 15 May 2025 18:09:53 -0400 Subject: [PATCH 059/117] cuda : implement ssm scan for Mamba2 There is still room for improvement, but it works! * cuda : adapt Mamba1 ssm scan to shape changes from Mamba2 --- ggml/src/ggml-cuda/ggml-cuda.cu | 13 +- ggml/src/ggml-cuda/ssm-scan.cu | 232 +++++++++++++++++++++++++------- src/llama-model.cpp | 2 +- tests/test-backend-ops.cpp | 2 +- 4 files changed, 195 insertions(+), 54 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 898b24341471d..dc67360464450 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3191,7 +3191,18 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_LOG: - case GGML_OP_SSM_SCAN: + return true; + case GGML_OP_SSM_SCAN: { + if (op->src[3]->ne[0] == 1) { + // Mamba2 + // (kernel only supports d_state == 128 && d_head % 16 == 0) + return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0; + } else { + // Mamba + // (kernel only supports d_state == 16, n_group == 1, d_head == 1) + return op->src[0]->ne[0] == 16 && op->src[4]->ne[1] == 1 && op->src[0]->ne[1] == 1; + } + } case GGML_OP_SSM_CONV: return true; case GGML_OP_CONT: diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 2d34b836054f8..61f35f859b7be 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -4,16 +4,15 @@ template __global__ void __launch_bounds__(splitD, 2) ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, - const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2, - const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, - const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, - float * __restrict__ dst, const int64_t L) { - GGML_UNUSED(src1_nb0); - GGML_UNUSED(src2_nb0); + const int32_t * __restrict__ src6, float * __restrict__ dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, + const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, + const int64_t s_off, const int64_t d_inner, const int64_t L) { constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - const int bidx = blockIdx.x; // split along B - const int bidy = blockIdx.y; // split along D + const int bidx = blockIdx.x; // split along B (sequences) + const int bidy = blockIdx.y; // split along D (d_inner) const int tid = threadIdx.x; const int wid = tid / 32; const int wtid = tid % 32; @@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2) float * smem_A = smem; float * smem_s0 = smem_A + splitD * stride_sA; - const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1); - const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); + const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2); + const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float)); const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1); - const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2)); - const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2)); - float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); - float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1); + const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3)); + const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3)); + float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float)); + float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2); - const int stride_s0 = src0_nb1 / sizeof(float); - const int stride_x = src1_nb1 / sizeof(float); + const int stride_s0 = src0_nb2 / sizeof(float); + const int stride_x = src1_nb2 / sizeof(float); const int stride_dt = src2_nb1 / sizeof(float); const int stride_A = src3_nb1 / sizeof(float); - const int stride_B = src4_nb1 / sizeof(float); - const int stride_C = src5_nb1 / sizeof(float); + const int stride_B = src4_nb2 / sizeof(float); + const int stride_C = src5_nb2 / sizeof(float); const int stride_s = stride_s0; - const int stride_y = stride_x; + const int stride_y = d_inner; // can N not be 16? for example 32? if (N == 16) { @@ -84,24 +83,157 @@ __global__ void __launch_bounds__(splitD, 2) } } +// assumes as many threads as d_state +template +__global__ void __launch_bounds__(d_state, 1) + ssm_scan_f32_group( + const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, + const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, + const int32_t * __restrict__ src6, float * __restrict__ dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, + const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, + const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { + + const int head_idx = (blockIdx.x * splitH) / d_head; + const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float); + const int seq_idx = blockIdx.y; + + const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float); + + const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float)); + const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); + const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1); + const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); + const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); + float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH; + float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + + // strides across n_seq_tokens + const int stride_x = src1_nb2 / sizeof(float); + const int stride_dt = src2_nb1 / sizeof(float); + const int stride_B = src4_nb2 / sizeof(float); + const int stride_C = src5_nb2 / sizeof(float); + const int stride_y = n_head * d_head; + + float state[splitH]; + // for the parallel accumulation + __shared__ float stateC[splitH * d_state]; + +#pragma unroll + for (int j = 0; j < splitH; j++) { + state[j] = s0_block[j * d_state + threadIdx.x]; + } + + for (int64_t i = 0; i < n_tok; i++) { + // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements + // TODO: only calculate B and C once per head group + // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here. + float dt_soft_plus = dt_block[i * stride_dt]; + if (dt_soft_plus <= 20.0f) { + dt_soft_plus = log1pf(expf(dt_soft_plus)); + } + const float dA = expf(dt_soft_plus * A_block[0]); + const float B = B_block[i * stride_B + threadIdx.x]; + const float C = C_block[i * stride_C + threadIdx.x]; + + // across d_head +#pragma unroll + for (int j = 0; j < splitH; j++) { + const float x_dt = x_block[i * stride_x + j] * dt_soft_plus; + + state[j] = (state[j] * dA) + (B * x_dt); + + stateC[j * d_state + threadIdx.x] = state[j] * C; + } + + __syncthreads(); + + // parallel accumulation for stateC + // TODO: simplify + { + static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2"); + static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2"); + + // reduce until w matches the warp size + // TODO: does this work even when the physical warp size is 64? +#pragma unroll + for (int w = d_state; w > WARP_SIZE; w >>= 1) { + // (assuming there are d_state threads) +#pragma unroll + for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) { + // TODO: check for bank conflicts + const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1)); + stateC[k] += stateC[k + (w >> 1)]; + + } + __syncthreads(); + } + + static_assert(splitH >= d_state / WARP_SIZE); + +#pragma unroll + for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) { + float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)]; + y = warp_reduce_sum(y); + + // store the above accumulations + if (threadIdx.x % WARP_SIZE == 0) { + const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE); + y_block[i * stride_y + k] = y; + } + } + } + } + + // write back the state +#pragma unroll + for (int j = 0; j < splitH; j++) { + s_block[j * d_state + threadIdx.x] = state[j]; + } +} + static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3, - const float * src4, const float * src5, const int src0_nb1, const int src0_nb2, - const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, - const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, - const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, - float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B, + const float * src4, const float * src5, const int32_t * src6, float * dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, + const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, + const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, + const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, cudaStream_t stream) { const int threads = 128; - // todo: consider D cannot be divided,does this situation exist? - GGML_ASSERT(D % threads == 0); - const dim3 blocks(B, (D + threads - 1) / threads, 1); - const int smem_size = (threads * (N + 1) * 2) * sizeof(float); - if (N == 16) { - ssm_scan_f32<128, 16><<>>( - src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0, - src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L); + // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! + if (src3_nb1 == sizeof(float)) { + // Mamba2 + if (d_state == 128) { + GGML_ASSERT(d_state % threads == 0); + // NOTE: can be any power of two between 4 and 64 + const int splitH = 16; + GGML_ASSERT(head_dim % splitH == 0); + const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); + ssm_scan_f32_group<16, 128><<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, + src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); + } else { + GGML_ABORT("doesn't support d_state!=128."); + } } else { - GGML_ABORT("doesn't support N!=16."); + // Mamba1 + // todo: consider n_head cannot be divided, does this situation exist? + GGML_ASSERT(n_head % threads == 0); + GGML_ASSERT(head_dim == 1); + GGML_ASSERT(n_group == 1); + const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); + const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); + if (d_state == 16) { + ssm_scan_f32<128, 16><<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + } else { + GGML_ABORT("doesn't support d_state!=16."); + } } } @@ -112,30 +244,25 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const struct ggml_tensor * src3 = dst->src[3]; // A const struct ggml_tensor * src4 = dst->src[4]; // B const struct ggml_tensor * src5 = dst->src[5]; // C - - // const int64_t d_state = src0->ne[0]; - // const int64_t d_inner = src0->ne[1]; - // const int64_t l = src1->ne[1]; - // const int64_t b = src0->ne[2]; + const struct ggml_tensor * src6 = dst->src[6]; // ids const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch + const int64_t nr = src0->ne[1]; // head_dim or 1 + const int64_t nh = src1->ne[1]; // n_head + const int64_t ng = src4->ne[1]; // n_group + const int64_t n_t = src1->ne[2]; // number of tokens per sequence + const int64_t n_s = src1->ne[3]; // number of sequences in the batch + + const int64_t s_off = ggml_nelements(src1) * sizeof(float); - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[3]) - GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); const float * src0_d = (const float *) src0->data; const float * src1_d = (const float *) src1->data; @@ -143,13 +270,16 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const float * src3_d = (const float *) src3->data; const float * src4_d = (const float *) src4->data; const float * src5_d = (const float *) src5->data; + const int32_t * src6_d = (const int32_t *) src6->data; float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src6->type == GGML_TYPE_I32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0], - src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1], - src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream); + ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d, + src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2], + src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3], + s_off, nc, nr, nh, ng, n_t, n_s, stream); } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 74ce7ac78cff4..3e5093185ba25 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -215,7 +215,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; const int64_t n_head = w->ne[1]; const int64_t head_dim = hparams.ssm_d_inner / n_head; - const int64_t n_group = hparams.ssm_n_group; + const int64_t n_group = hparams.ssm_n_group ? hparams.ssm_n_group : 1; const int64_t n_seq_tokens = 512; const int64_t n_seqs = 3; ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 653ad0d9bd91e..3957325ec2485 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4225,7 +4225,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1})); test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 - test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2 test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1)); test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1)); From 28881af112e079141e6b92ed969e1bc6e301f48f Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 13 May 2025 10:48:54 -0600 Subject: [PATCH 060/117] feat: Add conversion for Bamba models This is borrowed and adapted from the original implementation https://github.com/ggml-org/llama.cpp/pull/10810 Branch: GraniteFour Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 109 +++++++++++++++++++++++++++++++-- gguf-py/gguf/constants.py | 30 +++++++++ gguf-py/gguf/gguf_writer.py | 3 + gguf-py/gguf/tensor_mapping.py | 17 ++++- 4 files changed, 152 insertions(+), 7 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 32456d5c743af..218a46372ee38 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4717,6 +4717,9 @@ def __init__(self, dir_model: Path, *args, **kwargs): with open(dir_model / "config.json", "r", encoding="utf-8") as f: hparams = json.load(f) super().__init__(dir_model, *args, hparams=hparams, **kwargs) + self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model + self.n_group = self.hparams.get("n_groups", 1) def set_vocab(self): vocab_size = self.hparams["vocab_size"] @@ -4787,10 +4790,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # (D is also unsqueezed, but for more straightforward broadcast internally) data_torch = data_torch.reshape((*data_torch.shape, 1)) elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid): - d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) - d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model - n_group = self.hparams.get("n_groups", 1) - data_torch = data_torch.reshape((n_group, d_inner // n_group)) + data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group)) if name.endswith(".A_log"): logger.debug("A_log --> A ==> " + new_name) @@ -4799,6 +4799,107 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield (new_name, data_torch) +@ModelBase.register("BambaForCausalLM") +class BambaModel(Mamba2Model): + """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers""" + model_arch = gguf.MODEL_ARCH.BAMBA + undo_permute = True + + def __init__(self, *args, **kwargs): + + # Hybrid mamba models use a prefix for the mamba-specific params. + # TODO: Extend this if the prefix(es) need to be configurable + self.hparam_prefixes = ["mamba"] + + super().__init__(*args, **kwargs) + + # Use Llama conversion for attention + self._transformer_model_class: type[TextModel] = LlamaModel + + # Lists of which layers use ssm vs attention + self._attn_layers = self.hparams.get("attn_layer_indices", []) + if not self._attn_layers: + attn_period = self.hparams.get("attn_layer_period") + assert attn_period, "Didn't find attn_layer_indices or attn_layer_period" + attn_offset = self.hparams.get("attn_layer_offset") + assert attn_offset is not None, "No attention layer offset set with attn_layer_period" + self._attn_layers = [ + i for i in range(self.block_count) + if i % attn_period == attn_offset + ] + self._ssm_layers = [ + i for i in range(self.block_count) + if i not in self._attn_layers + ] + + # n_group and d_inner are used during reshape_tensors for mamaba2 + self.d_model = self.find_hparam(["hidden_size", "d_model"]) + self.n_group = self.find_hparam(["n_groups"]) + self.d_inner = self.find_hparam(["expand"]) * self.d_model + + def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any: + prefixed = [] + for pfx in self.hparam_prefixes: + prefixed.extend( + "_".join([pfx, k]) + for k in keys + ) + keys = list(keys) + prefixed + return super().find_hparam(keys, *args, **kwargs) + + def set_gguf_parameters(self): + + ## General Params ## + self.gguf_writer.add_embedding_length(self.d_model) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0)) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + + ## Mamba mixer params ## + self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"])) + self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"])) + self.gguf_writer.add_ssm_group_count(self.n_group) + self.gguf_writer.add_ssm_inner_size(self.d_inner) + # NOTE: The mamba_dt_rank is _not_ the right field for how this is used + # in llama.cpp + self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"])) + + ## Attention params ## + self.gguf_writer.add_attn_layer_indices(self._attn_layers) + self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"])) + + ## Feed Forward Params ## + self.gguf_writer.add_layer_norm_rms_eps( + self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + ) + + ## Validation ## + d_head = self.find_hparam(["d_head"], optional=True) or 64 + assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported" + assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}" + + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: + + # Determine whether this is a mamaba layer or an attention layer + if bid in self._ssm_layers: + for mamba_new_name, data_torch in super().modify_tensors( + data_torch, name, bid + ): + yield mamba_new_name, data_torch + elif bid in self._attn_layers: + for llama_new_name, data_torch in self._transformer_model_class.modify_tensors( + self, data_torch, name, bid + ): + yield llama_new_name, data_torch + else: + yield self.map_tensor_name(name), data_torch + + @ModelBase.register("CohereForCausalLM") class CommandR2Model(TextModel): model_arch = gguf.MODEL_ARCH.COMMAND_R diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index cecda83a13ef5..3b2c62cbbc29b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -167,6 +167,9 @@ class SSM: GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" + class HybridAttention: + ATTN_LAYER_INDICES = "{arch}.attention.layer_indices" + class WKV: HEAD_SIZE = "{arch}.wkv.head_size" @@ -322,6 +325,7 @@ class MODEL_ARCH(IntEnum): ARWKV7 = auto() MAMBA = auto() MAMBA2 = auto() + BAMBA = auto() XVERSE = auto() COMMAND_R = auto() COHERE2 = auto() @@ -607,6 +611,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.ARWKV7: "arwkv7", MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA2: "mamba2", + MODEL_ARCH.BAMBA: "bamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COHERE2: "cohere2", @@ -1655,6 +1660,31 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.BAMBA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 5a57bc20b5b76..8fd7d5ef2a7bd 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -849,6 +849,9 @@ def add_ssm_group_count(self, value: int) -> None: def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) + def add_attn_layer_indices(self, values: list[int]) -> None: + self.add_array(Keys.HybridAttention.ATTN_LAYER_INDICES.format(arch=self.arch), values) + def add_tokenizer_model(self, model: str) -> None: self.add_string(Keys.Tokenizer.MODEL, model) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 23ef00c3defa3..1255d223622a7 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -13,7 +13,7 @@ class TensorNameMap: "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone "transformer.word_embeddings", # falcon "word_embeddings", # bloom - "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 + "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 bamba "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert "language_model.embedding.word_embeddings", # persimmon @@ -118,7 +118,7 @@ class TensorNameMap: "transformer.h.{bid}.input_layernorm", # falcon7b "h.{bid}.input_layernorm", # bloom "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe + "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe bamba "layers.{bid}.attention_norm", # llama-pth "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi @@ -281,6 +281,7 @@ class TensorNameMap: "transformer.layers.{bid}.ffn_norm", # openelm "model.layers.{bid}.post_attention_layernorm", # llama4 "transformer_encoder.{bid}.ffn_norm", # neobert + "model.layers.{bid}.pre_ff_layernorm", # bamba ), # Post feed-forward norm @@ -346,6 +347,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.c_fc_1", # exaone "model.layers.{bid}.feed_forward.up_proj", # llama4 "transformer_encoder.{bid}.ffn.w12", # neobert + "model.layers.{bid}.feed_forward.up_proj", # bamba ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -382,7 +384,8 @@ class TensorNameMap: "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic "transformer.h.{bid}.mlp.c_fc_0", # exaone - "model.layers.{bid}.feed_forward.gate_proj", # llama4 + "language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4 + "model.layers.{bid}.feed_forward.gate_proj", # bamba ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -429,6 +432,7 @@ class TensorNameMap: "model.layers.h.{bid}.mlp.c_proj", # exaone "model.layers.{bid}.feed_forward.down_proj", # llama4 "transformer_encoder.{bid}.ffn.w3", # neobert + "model.layers.{bid}.feed_forward.down_proj", # bamba ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -483,11 +487,13 @@ class TensorNameMap: MODEL_TENSOR.SSM_IN: ( "model.layers.{bid}.in_proj", "backbone.layers.{bid}.mixer.in_proj", + "model.layers.{bid}.mamba.in_proj", # bamba ), MODEL_TENSOR.SSM_CONV1D: ( "model.layers.{bid}.conv1d", "backbone.layers.{bid}.mixer.conv1d", + "model.layers.{bid}.mamba.conv1d", # bamba ), MODEL_TENSOR.SSM_X: ( @@ -498,25 +504,30 @@ class TensorNameMap: MODEL_TENSOR.SSM_DT: ( "model.layers.{bid}.dt_proj", "backbone.layers.{bid}.mixer.dt_proj", + "model.layers.{bid}.mamba.dt_proj", # bamba ), MODEL_TENSOR.SSM_A: ( "model.layers.{bid}.A_log", "backbone.layers.{bid}.mixer.A_log", + "model.layers.{bid}.mamba.A_log", # bamba ), MODEL_TENSOR.SSM_D: ( "model.layers.{bid}.D", "backbone.layers.{bid}.mixer.D", + "model.layers.{bid}.mamba.D", # bamba ), MODEL_TENSOR.SSM_NORM: ( "backbone.layers.{bid}.mixer.norm", # mamba2 + "model.layers.{bid}.mamba.norm", # bamba ), MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", + "model.layers.{bid}.mamba.out_proj", # bamba ), MODEL_TENSOR.TIME_MIX_W0: ( From c43259bdeed32a5e8f004d753f91c787594dbae6 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 14:50:23 -0600 Subject: [PATCH 061/117] feat: Add Granite 4 conversion This is a manual copy from my draft branch https://github.com/gabe-l-hart/llama.cpp/blob/GraniteFourDraft/convert_hf_to_gguf.py#L5076 Branch: GraniteFour Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 99 ++++++++++-- gguf-py/gguf/constants.py | 319 +++++++++++++++++++++----------------- 2 files changed, 261 insertions(+), 157 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 218a46372ee38..f6c9a115581be 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4817,16 +4817,7 @@ def __init__(self, *args, **kwargs): self._transformer_model_class: type[TextModel] = LlamaModel # Lists of which layers use ssm vs attention - self._attn_layers = self.hparams.get("attn_layer_indices", []) - if not self._attn_layers: - attn_period = self.hparams.get("attn_layer_period") - assert attn_period, "Didn't find attn_layer_indices or attn_layer_period" - attn_offset = self.hparams.get("attn_layer_offset") - assert attn_offset is not None, "No attention layer offset set with attn_layer_period" - self._attn_layers = [ - i for i in range(self.block_count) - if i % attn_period == attn_offset - ] + self._attn_layers = self.get_attn_layres() self._ssm_layers = [ i for i in range(self.block_count) if i not in self._attn_layers @@ -4837,6 +4828,19 @@ def __init__(self, *args, **kwargs): self.n_group = self.find_hparam(["n_groups"]) self.d_inner = self.find_hparam(["expand"]) * self.d_model + def get_attn_layres(self) -> list[int]: + attn_layers = self.hparams.get("attn_layer_indices", []) + if not attn_layers: + attn_period = self.hparams.get("attn_layer_period") + assert attn_period, "Didn't find attn_layer_indices or attn_layer_period" + attn_offset = self.hparams.get("attn_layer_offset") + assert attn_offset is not None, "No attention layer offset set with attn_layer_period" + attn_layers = [ + i for i in range(self.block_count) + if i % attn_period == attn_offset + ] + return attn_layers + def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any: prefixed = [] for pfx in self.hparam_prefixes: @@ -4867,7 +4871,8 @@ def set_gguf_parameters(self): ## Attention params ## self.gguf_writer.add_attn_layer_indices(self._attn_layers) - self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"]) + if rope_dim := self.hparams.get("attn_rotary_emb"): + self.gguf_writer.add_rope_dimension_count(rope_dim) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"])) @@ -6273,6 +6278,78 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return super().modify_tensors(data_torch, name, bid) +@ModelBase.register("GraniteMoeHybridForCausalLM") +class GraniteMoeHybridModel(BambaModel, GraniteMoeModel): + """GraniteMoeHybrid is a hybrid SSM + MoE Attention model that uses Mamba2 + SSM layers""" + model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._transformer_model_class = GraniteMoeModel + + def get_attn_layres(self): + if layer_types := self.hparams.get("layer_types"): + return [ + i for i, typ in enumerate(layer_types) + if typ == "attention" + ] + return super().get_attn_layres() + + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: + + # In GraniteMoeHybrid, the mamba layers also have an MoE + Shared expert + if name.endswith("block_sparse_moe.input_linear.weight"): + ffn_dim = self.hparams["intermediate_size"] + assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size" + gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :] + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate), + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up), + ] + if name.endswith("shared_mlp.input_linear.weight"): + ffn_dim = self.hparams["shared_intermediate_size"] + assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size" + gate, up = data_torch.split(ffn_dim, dim=-2) + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate), + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up), + ] + + return super().modify_tensors(data_torch, name, bid) + + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if attention_scale := self.hparams.get("attention_multiplier"): + self.gguf_writer.add_attention_scale(attention_scale) + logger.info("gguf: (granite) attention_scale = %s", attention_scale) + if embedding_scale := self.hparams.get("embedding_multiplier"): + self.gguf_writer.add_embedding_scale(embedding_scale) + logger.info("gguf: (granite) embedding_scale = %s", embedding_scale) + if residual_scale := self.hparams.get("residual_multiplier"): + self.gguf_writer.add_residual_scale(residual_scale) + logger.info("gguf: (granite) residual_scale = %s", residual_scale) + if logits_scale := self.hparams.get("logits_scaling"): + self.gguf_writer.add_logit_scale(logits_scale) + logger.info("gguf: (granite) logits_scale = %s", logits_scale) + if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"): + self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length) + logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length) + if (n_experts := self.hparams.get("num_local_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + logger.info(f"gguf: expert count = {n_experts}") + if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) + logger.info(f"gguf: experts used count = {n_experts_used}") + + def set_vocab(self): + self.hparams["pad_vocab_size_multiple"] = 8 + super().set_vocab() + + @ModelBase.register("BailingMoeForCausalLM") class BailingMoeModel(TextModel): model_arch = gguf.MODEL_ARCH.BAILINGMOE diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 3b2c62cbbc29b..358b409d6d957 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -280,79 +280,80 @@ class GGUFType: class MODEL_ARCH(IntEnum): - MMPROJ = auto() # dummy arch for clip.cpp - LLAMA = auto() - LLAMA4 = auto() - DECI = auto() - FALCON = auto() - BAICHUAN = auto() - GROK = auto() - GPT2 = auto() - GPTJ = auto() - GPTNEOX = auto() - MPT = auto() - STARCODER = auto() - REFACT = auto() - BERT = auto() - NOMIC_BERT = auto() - NOMIC_BERT_MOE = auto() - NEO_BERT = auto() - JINA_BERT_V2 = auto() - BLOOM = auto() - STABLELM = auto() - QWEN = auto() - QWEN2 = auto() - QWEN2MOE = auto() - QWEN2VL = auto() - QWEN3 = auto() - QWEN3MOE = auto() - PHI2 = auto() - PHI3 = auto() - PHIMOE = auto() - PLAMO = auto() - CODESHELL = auto() - ORION = auto() - INTERNLM2 = auto() - MINICPM = auto() - MINICPM3 = auto() - GEMMA = auto() - GEMMA2 = auto() - GEMMA3 = auto() - STARCODER2 = auto() - RWKV6 = auto() - RWKV6QWEN2 = auto() - RWKV7 = auto() - ARWKV7 = auto() - MAMBA = auto() - MAMBA2 = auto() - BAMBA = auto() - XVERSE = auto() - COMMAND_R = auto() - COHERE2 = auto() - DBRX = auto() - OLMO = auto() - OLMO2 = auto() - OLMOE = auto() - OPENELM = auto() - ARCTIC = auto() - DEEPSEEK = auto() - DEEPSEEK2 = auto() - CHATGLM = auto() - GLM4 = auto() - BITNET = auto() - T5 = auto() - T5ENCODER = auto() - JAIS = auto() - NEMOTRON = auto() - EXAONE = auto() - GRANITE = auto() - GRANITE_MOE = auto() - CHAMELEON = auto() - WAVTOKENIZER_DEC = auto() - PLM = auto() - BAILINGMOE = auto() - DOTS1 = auto() - ARCEE = auto() + MMPROJ = auto() # dummy arch for clip.cpp + LLAMA = auto() + LLAMA4 = auto() + DECI = auto() + FALCON = auto() + BAICHUAN = auto() + GROK = auto() + GPT2 = auto() + GPTJ = auto() + GPTNEOX = auto() + MPT = auto() + STARCODER = auto() + REFACT = auto() + BERT = auto() + NOMIC_BERT = auto() + NOMIC_BERT_MOE = auto() + NEO_BERT = auto() + JINA_BERT_V2 = auto() + BLOOM = auto() + STABLELM = auto() + QWEN = auto() + QWEN2 = auto() + QWEN2MOE = auto() + QWEN2VL = auto() + QWEN3 = auto() + QWEN3MOE = auto() + PHI2 = auto() + PHI3 = auto() + PHIMOE = auto() + PLAMO = auto() + CODESHELL = auto() + ORION = auto() + INTERNLM2 = auto() + MINICPM = auto() + MINICPM3 = auto() + GEMMA = auto() + GEMMA2 = auto() + GEMMA3 = auto() + STARCODER2 = auto() + RWKV6 = auto() + RWKV6QWEN2 = auto() + RWKV7 = auto() + ARWKV7 = auto() + MAMBA = auto() + MAMBA2 = auto() + BAMBA = auto() + XVERSE = auto() + COMMAND_R = auto() + COHERE2 = auto() + DBRX = auto() + OLMO = auto() + OLMO2 = auto() + OLMOE = auto() + OPENELM = auto() + ARCTIC = auto() + DEEPSEEK = auto() + DEEPSEEK2 = auto() + CHATGLM = auto() + GLM4 = auto() + BITNET = auto() + T5 = auto() + T5ENCODER = auto() + JAIS = auto() + NEMOTRON = auto() + EXAONE = auto() + GRANITE = auto() + GRANITE_MOE = auto() + GRANITE_MOE_HYBRID = auto() + CHAMELEON = auto() + WAVTOKENIZER_DEC = auto() + PLM = auto() + BAILINGMOE = auto() + DOTS1 = auto() + ARCEE = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -566,79 +567,80 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { - MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp - MODEL_ARCH.LLAMA: "llama", - MODEL_ARCH.LLAMA4: "llama4", - MODEL_ARCH.DECI: "deci", - MODEL_ARCH.FALCON: "falcon", - MODEL_ARCH.BAICHUAN: "baichuan", - MODEL_ARCH.GROK: "grok", - MODEL_ARCH.GPT2: "gpt2", - MODEL_ARCH.GPTJ: "gptj", - MODEL_ARCH.GPTNEOX: "gptneox", - MODEL_ARCH.MPT: "mpt", - MODEL_ARCH.STARCODER: "starcoder", - MODEL_ARCH.REFACT: "refact", - MODEL_ARCH.BERT: "bert", - MODEL_ARCH.NOMIC_BERT: "nomic-bert", - MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", - MODEL_ARCH.NEO_BERT: "neo-bert", - MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", - MODEL_ARCH.BLOOM: "bloom", - MODEL_ARCH.STABLELM: "stablelm", - MODEL_ARCH.QWEN: "qwen", - MODEL_ARCH.QWEN2: "qwen2", - MODEL_ARCH.QWEN2MOE: "qwen2moe", - MODEL_ARCH.QWEN2VL: "qwen2vl", - MODEL_ARCH.QWEN3: "qwen3", - MODEL_ARCH.QWEN3MOE: "qwen3moe", - MODEL_ARCH.PHI2: "phi2", - MODEL_ARCH.PHI3: "phi3", - MODEL_ARCH.PHIMOE: "phimoe", - MODEL_ARCH.PLAMO: "plamo", - MODEL_ARCH.CODESHELL: "codeshell", - MODEL_ARCH.ORION: "orion", - MODEL_ARCH.INTERNLM2: "internlm2", - MODEL_ARCH.MINICPM: "minicpm", - MODEL_ARCH.MINICPM3: "minicpm3", - MODEL_ARCH.GEMMA: "gemma", - MODEL_ARCH.GEMMA2: "gemma2", - MODEL_ARCH.GEMMA3: "gemma3", - MODEL_ARCH.STARCODER2: "starcoder2", - MODEL_ARCH.RWKV6: "rwkv6", - MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", - MODEL_ARCH.RWKV7: "rwkv7", - MODEL_ARCH.ARWKV7: "arwkv7", - MODEL_ARCH.MAMBA: "mamba", - MODEL_ARCH.MAMBA2: "mamba2", - MODEL_ARCH.BAMBA: "bamba", - MODEL_ARCH.XVERSE: "xverse", - MODEL_ARCH.COMMAND_R: "command-r", - MODEL_ARCH.COHERE2: "cohere2", - MODEL_ARCH.DBRX: "dbrx", - MODEL_ARCH.OLMO: "olmo", - MODEL_ARCH.OLMO2: "olmo2", - MODEL_ARCH.OLMOE: "olmoe", - MODEL_ARCH.OPENELM: "openelm", - MODEL_ARCH.ARCTIC: "arctic", - MODEL_ARCH.DEEPSEEK: "deepseek", - MODEL_ARCH.DEEPSEEK2: "deepseek2", - MODEL_ARCH.CHATGLM: "chatglm", - MODEL_ARCH.GLM4: "glm4", - MODEL_ARCH.BITNET: "bitnet", - MODEL_ARCH.T5: "t5", - MODEL_ARCH.T5ENCODER: "t5encoder", - MODEL_ARCH.JAIS: "jais", - MODEL_ARCH.NEMOTRON: "nemotron", - MODEL_ARCH.EXAONE: "exaone", - MODEL_ARCH.GRANITE: "granite", - MODEL_ARCH.GRANITE_MOE: "granitemoe", - MODEL_ARCH.CHAMELEON: "chameleon", - MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", - MODEL_ARCH.PLM: "plm", - MODEL_ARCH.BAILINGMOE: "bailingmoe", - MODEL_ARCH.DOTS1: "dots1", - MODEL_ARCH.ARCEE: "arcee", + MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp + MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.LLAMA4: "llama4", + MODEL_ARCH.DECI: "deci", + MODEL_ARCH.FALCON: "falcon", + MODEL_ARCH.BAICHUAN: "baichuan", + MODEL_ARCH.GROK: "grok", + MODEL_ARCH.GPT2: "gpt2", + MODEL_ARCH.GPTJ: "gptj", + MODEL_ARCH.GPTNEOX: "gptneox", + MODEL_ARCH.MPT: "mpt", + MODEL_ARCH.STARCODER: "starcoder", + MODEL_ARCH.REFACT: "refact", + MODEL_ARCH.BERT: "bert", + MODEL_ARCH.NOMIC_BERT: "nomic-bert", + MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", + MODEL_ARCH.NEO_BERT: "neo-bert", + MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", + MODEL_ARCH.BLOOM: "bloom", + MODEL_ARCH.STABLELM: "stablelm", + MODEL_ARCH.QWEN: "qwen", + MODEL_ARCH.QWEN2: "qwen2", + MODEL_ARCH.QWEN2MOE: "qwen2moe", + MODEL_ARCH.QWEN2VL: "qwen2vl", + MODEL_ARCH.QWEN3: "qwen3", + MODEL_ARCH.QWEN3MOE: "qwen3moe", + MODEL_ARCH.PHI2: "phi2", + MODEL_ARCH.PHI3: "phi3", + MODEL_ARCH.PHIMOE: "phimoe", + MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.CODESHELL: "codeshell", + MODEL_ARCH.ORION: "orion", + MODEL_ARCH.INTERNLM2: "internlm2", + MODEL_ARCH.MINICPM: "minicpm", + MODEL_ARCH.MINICPM3: "minicpm3", + MODEL_ARCH.GEMMA: "gemma", + MODEL_ARCH.GEMMA2: "gemma2", + MODEL_ARCH.GEMMA3: "gemma3", + MODEL_ARCH.STARCODER2: "starcoder2", + MODEL_ARCH.RWKV6: "rwkv6", + MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", + MODEL_ARCH.RWKV7: "rwkv7", + MODEL_ARCH.ARWKV7: "arwkv7", + MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.MAMBA2: "mamba2", + MODEL_ARCH.BAMBA: "bamba", + MODEL_ARCH.XVERSE: "xverse", + MODEL_ARCH.COMMAND_R: "command-r", + MODEL_ARCH.COHERE2: "cohere2", + MODEL_ARCH.DBRX: "dbrx", + MODEL_ARCH.OLMO: "olmo", + MODEL_ARCH.OLMO2: "olmo2", + MODEL_ARCH.OLMOE: "olmoe", + MODEL_ARCH.OPENELM: "openelm", + MODEL_ARCH.ARCTIC: "arctic", + MODEL_ARCH.DEEPSEEK: "deepseek", + MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.CHATGLM: "chatglm", + MODEL_ARCH.GLM4: "glm4", + MODEL_ARCH.BITNET: "bitnet", + MODEL_ARCH.T5: "t5", + MODEL_ARCH.T5ENCODER: "t5encoder", + MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.NEMOTRON: "nemotron", + MODEL_ARCH.EXAONE: "exaone", + MODEL_ARCH.GRANITE: "granite", + MODEL_ARCH.GRANITE_MOE: "granitemoe", + MODEL_ARCH.GRANITE_MOE_HYBRID: "granitemoehybrid", + MODEL_ARCH.CHAMELEON: "chameleon", + MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + MODEL_ARCH.PLM: "plm", + MODEL_ARCH.BAILINGMOE: "bailingmoe", + MODEL_ARCH.DOTS1: "dots1", + MODEL_ARCH.ARCEE: "arcee", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -2054,6 +2056,31 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, ], + MODEL_ARCH.GRANITE_MOE_HYBRID: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + ], MODEL_ARCH.CHAMELEON: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, From 26816fd6c6d9b431f399918eb368b5d609c0fe44 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 14:59:29 -0600 Subject: [PATCH 062/117] feat: Plumb bamba through llama-arch Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-arch.cpp | 33 +++++++++++++++++++++++++++++++++ src/llama-arch.h | 1 + 2 files changed, 34 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index fb3f0c72e87fb..712eaa31577a1 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -45,6 +45,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA2, "mamba2" }, + { LLM_ARCH_BAMBA, "bamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_COHERE2, "cohere2" }, @@ -984,6 +985,38 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_BAMBA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + // mamba(2) ssm layers + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + // attention layers + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + // non-moe FFN + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + // moe FFN + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_XVERSE, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 0083dc7b28ad9..31426b901957e 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -49,6 +49,7 @@ enum llm_arch { LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_MAMBA2, + LLM_ARCH_BAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_COHERE2, From b901947a3589d7919919b849997c0091f0e28c31 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 20 May 2025 15:46:24 -0600 Subject: [PATCH 063/117] feat: Add bamba to llama_arch_is_hybrid_recurrent Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-arch.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 712eaa31577a1..d5e1e8473410d 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1889,6 +1889,8 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { // TODO: There are currently no hybrid models! Once there are, this will be // the place to identify them switch (arch) { + case LLM_ARCH_BAMBA: + return true; default: return false; } From fc56325a8f8f01e7bfea707ec0cceb2077e2de50 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 13 May 2025 16:49:23 -0600 Subject: [PATCH 064/117] feat: Add optional mamba ssm_in bias tensor Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llama-model.h b/src/llama-model.h index 3896d3314722d..0f8487d0820a2 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -254,6 +254,7 @@ struct llama_layer { // mamba bias struct ggml_tensor * ssm_conv1d_b = nullptr; struct ggml_tensor * ssm_dt_b = nullptr; + struct ggml_tensor * ssm_in_b = nullptr; // rwkv struct ggml_tensor * time_mix_w1 = nullptr; From b3453dc935a7f535989cd03e628803e7298b70b5 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 13 May 2025 16:52:17 -0600 Subject: [PATCH 065/117] feat: Add template specialization for get_arr to load a vector for layer index arr in hparams Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model-loader.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index bd9e6da8832b7..0bd1e5d006950 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -464,6 +464,7 @@ namespace GGUFMeta { // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_arr(enum llm_kv kid, std::vector & result, bool required); llama_model_loader::llama_model_loader( const std::string & fname, From 13e8d3df72afad69638834978a1ccfef7b663706 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 12 Jun 2025 17:21:00 -0600 Subject: [PATCH 066/117] feat: Use an explicit bool to determine mamaba vs mamba2 This allows other architectures like bamba and granitemoehybrid to use mamab2 without a growing architecture `if` statement inside the mamba implementation. Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3c3b3981e1e84..ef6231e450be0 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9200,7 +9200,12 @@ struct llm_build_starcoder2 : public llm_graph_context { struct llm_build_mamba : public llm_graph_context { const llama_model & model; - llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + llm_build_mamba( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf, + const bool use_mamba2 + ) : llm_graph_context(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -9218,7 +9223,7 @@ struct llm_build_mamba : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - if (model.arch == LLM_ARCH_MAMBA2) { + if (use_mamba2) { cur = build_mamba2_layer(rs_inp, gf, cur, ubatch, il); } else { cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il); @@ -14206,9 +14211,12 @@ llm_graph_result_ptr llama_model::build_graph( llm = std::make_unique(*this, params, gf); } break; case LLM_ARCH_MAMBA: + { + llm = std::make_unique(*this, params, gf, false); + } break; case LLM_ARCH_MAMBA2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params, gf, true); } break; case LLM_ARCH_XVERSE: { From b435dce241dfcf0f2ff7f28569f30c435cde135f Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 18 Jun 2025 11:18:42 -0600 Subject: [PATCH 067/117] feat: Isolate mamba(2) and granite attention layer building in static methods This will allow these layer-builder methods to be used from other build structs without complex inheritance. Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 204 +++++++++++++++++++++++++------------------- 1 file changed, 114 insertions(+), 90 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ef6231e450be0..21da81de91b0e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9224,9 +9224,9 @@ struct llm_build_mamba : public llm_graph_context { cb(cur, "attn_norm", il); if (use_mamba2) { - cur = build_mamba2_layer(rs_inp, gf, cur, ubatch, il); + cur = build_mamba2_layer(this, rs_inp, gf, cur, model, ubatch, il); } else { - cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il); + cur = build_mamba_layer(this, rs_inp, gf, cur, model, ubatch, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9262,26 +9262,29 @@ struct llm_build_mamba : public llm_graph_context { } ggml_tensor * build_mamba_layer( - llm_graph_input_rs * inp, - ggml_cgraph * gf, - ggml_tensor * cur, - const llama_ubatch & ubatch, - int il) const { - const auto * mctx_cur = static_cast(mctx); + const llm_graph_context * self, + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) const { + const auto * mctx_cur = static_cast(self->mctx); const auto kv_head = mctx_cur->get_head(); + auto * ctx0 = self->ctx0; - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t d_conv = self->hparams.ssm_d_conv; + const int64_t d_inner = self->hparams.ssm_d_inner; + const int64_t d_state = self->hparams.ssm_d_state; + const int64_t dt_rank = self->hparams.ssm_dt_rank; const int64_t n_head = d_inner; const int64_t head_dim = 1; const int64_t n_seqs = ubatch.n_seqs; // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) - const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; + const bool ssm_dt_b_c_rms = self->hparams.ssm_dt_b_c_rms; // Use the same RMS norm as the final layer norm - const float norm_rms_eps = hparams.f_norm_rms_eps; + const float norm_rms_eps = self->hparams.f_norm_rms_eps; const int64_t n_seq_tokens = ubatch.n_seq_tokens; @@ -9292,14 +9295,14 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = self->build_rs(inp, gf, conv_states_all, self->hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} - ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, cur); + ggml_tensor * xz = self->build_lora_mm(model.layers[il].ssm_in, cur); // split the above in two // => {d_inner, n_seq_tokens, n_seqs} ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); @@ -9339,7 +9342,7 @@ struct llm_build_mamba : public llm_graph_context { // ssm { // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} - ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x); + ggml_tensor * x_db = self->build_lora_mm(model.layers[il].ssm_x, x); // split ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); @@ -9353,7 +9356,7 @@ struct llm_build_mamba : public llm_graph_context { } // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} - dt = build_lora_mm(model.layers[il].ssm_dt, dt); + dt = self->build_lora_mm(model.layers[il].ssm_dt, dt); dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); cur = x; @@ -9373,7 +9376,7 @@ struct llm_build_mamba : public llm_graph_context { return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); }; - ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + ggml_tensor * y_ssm = self->build_rs(inp, gf, ssm_states_all, self->hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -9389,7 +9392,7 @@ struct llm_build_mamba : public llm_graph_context { y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(model.layers[il].ssm_out, y); + cur = self->build_lora_mm(model.layers[il].ssm_out, y); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} @@ -9400,21 +9403,24 @@ struct llm_build_mamba : public llm_graph_context { } ggml_tensor * build_mamba2_layer( - llm_graph_input_rs * inp, - ggml_cgraph * gf, - ggml_tensor * cur, - const llama_ubatch & ubatch, - int il) const { - const auto * mctx_cur = static_cast(mctx); + const llm_graph_context * self, + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) const { + const auto * mctx_cur = static_cast(self->mctx); const auto kv_head = mctx_cur->get_head(); + auto * ctx0 = self->ctx0; - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_head = hparams.ssm_dt_rank; + const int64_t d_conv = self->hparams.ssm_d_conv; + const int64_t d_inner = self->hparams.ssm_d_inner; + const int64_t d_state = self->hparams.ssm_d_state; + const int64_t n_head = self->hparams.ssm_dt_rank; const int64_t head_dim = d_inner / n_head; - const int64_t n_group = hparams.ssm_n_group; + const int64_t n_group = self->hparams.ssm_n_group; const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seq_tokens = ubatch.n_seq_tokens; @@ -9426,7 +9432,7 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = self->build_rs(inp, gf, conv_states_all, self->hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -9435,7 +9441,7 @@ struct llm_build_mamba : public llm_graph_context { // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} - ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + ggml_tensor * zxBCdt = self->build_lora_mm(model.layers[il].ssm_in, cur); // split the above in three ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); @@ -9496,7 +9502,7 @@ struct llm_build_mamba : public llm_graph_context { return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); }; - ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + ggml_tensor * y_ssm = self->build_rs(inp, gf, ssm_states_all, self->hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -9513,11 +9519,11 @@ struct llm_build_mamba : public llm_graph_context { // grouped RMS norm y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + y = self->build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(model.layers[il].ssm_out, y); + cur = self->build_lora_mm(model.layers[il].ssm_out, y); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} @@ -12855,8 +12861,8 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { } }; - struct llm_build_granite : public llm_graph_context { + llm_build_granite( const llama_model & model, const llm_graph_params & params, @@ -12882,8 +12888,6 @@ struct llm_build_granite : public llm_graph_context { auto * inp_attn = build_attn_inp_kv_unified(); - const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { @@ -12896,57 +12900,9 @@ struct llm_build_granite : public llm_graph_context { cb(cur, "attn_norm", il); // self-attention - { - // compute Q and K and (optionally) RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - if (use_rope) { - ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - } - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - cur = build_attn(inp_attn, gf, - model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); - cb(cur, "attn_out", il); - } + cur = build_attention_layer( + this, gf, cur, inp_pos, inp_attn, + model, n_embd_head, use_rope, il); if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); @@ -13042,6 +12998,74 @@ struct llm_build_granite : public llm_graph_context { ggml_build_forward_expand(gf, cur); } + + // static layer build function that enables other models to borrow this + // layer logic + static ggml_tensor * build_attention_layer( + const llm_graph_context * self, + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv_unified * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const bool use_rope, + const int il) { + + auto * ctx0 = self->ctx0; + + // compute Q and K and (optionally) RoPE them + ggml_tensor * Qcur = self->build_lora_mm(model.layers[il].wq, cur); + self->cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + self->cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = self->build_lora_mm(model.layers[il].wk, cur); + self->cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + self->cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = self->build_lora_mm(model.layers[il].wv, cur); + self->cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + self->cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, self->n_head, self->n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, self->n_head_kv, self->n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, self->n_head_kv, self->n_tokens); + + if (use_rope) { + ggml_tensor * rope_factors = model.get_rope_factors(self->cparams, il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + self->n_rot, self->rope_type, self->n_ctx_orig, self->freq_base, self->freq_scale, + self->ext_factor, self->attn_factor, self->beta_fast, self->beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + self->n_rot, self->rope_type, self->n_ctx_orig, self->freq_base, self->freq_scale, + self->ext_factor, self->attn_factor, self->beta_fast, self->beta_slow + ); + } + + self->cb(Qcur, "Qcur", il); + self->cb(Kcur, "Kcur", il); + self->cb(Vcur, "Vcur", il); + + const float kq_scale = self->hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : self->hparams.f_attention_scale; + cur = self->build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + self->cb(cur, "attn_out", il); + return cur; + } }; // ref: https://github.com/facebookresearch/chameleon From 3d4c36b520002494b40201e9eddfeeb0ae5df708 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 14 May 2025 09:26:42 -0600 Subject: [PATCH 068/117] fix: Use per-layer sizes in granite build_attention_layer Also no need to pass in kv cache since it's already in the inp_attn Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 21da81de91b0e..81a865f3aa44c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13036,9 +13036,9 @@ struct llm_build_granite : public llm_graph_context { self->cb(Vcur, "Vcur", il); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, self->n_head, self->n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, self->n_head_kv, self->n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, self->n_head_kv, self->n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, self->hparams.n_head(il), self->n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, self->hparams.n_head_kv(il), self->n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, self->hparams.n_head_kv(il), self->n_tokens); if (use_rope) { ggml_tensor * rope_factors = model.get_rope_factors(self->cparams, il); From 0d28bf61d4f29c1f88485b9e5338f9b24081531a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 14 May 2025 09:28:38 -0600 Subject: [PATCH 069/117] feat: First (broken) pass at end-to-end Bamba implementation It generates (garbage) tokens! Still lots of debugging to do. Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 282 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 282 insertions(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 81a865f3aa44c..ccada815e3e05 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1472,6 +1472,49 @@ void llama_model::load_hparams(llama_model_loader & ml) { // For Granite MoE Shared ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); } break; + case LLM_ARCH_BAMBA: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /* required */ false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /* required */ false); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, /* required */ false); + + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Zero-out n_head_arr and n_head_kv_arr since SSM layers don't + // have attention heads. We'll set them correctly below once we + // know which layers are attention layers + // NOTE: It's important that this happens after n_embd_head_[kv] + // are set above! + const auto n_head_attn = hparams.n_head(); + const auto n_head_kv_attn = hparams.n_head_kv(); + std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); + std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); + + // Attention params + std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); + std::vector attn_layer_indices; + ml.get_arr(LLM_KV_ATTENTION_LAYER_INDICES, attn_layer_indices); + for (const auto attn_idx : attn_layer_indices) { + GGML_ASSERT(attn_idx < hparams.n_layer); + hparams.recurrent_layer_arr[attn_idx] = false; + // Correctly set n_head and n_head_kv for attention layers + hparams.n_head_arr[attn_idx] = n_head_attn; + hparams.n_head_kv_arr[attn_idx] = n_head_kv_attn; + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + // TODO: Add llm type label (not sure this is useful) + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_CHAMELEON: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -3123,6 +3166,83 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); } } break; + case LLM_ARCH_BAMBA: + { + // mamba2 Mixer SSM params + // NOTE: int64_t for tensor dimensions + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_ssm_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.recurrent_layer(i)) { + // ssm layers + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else { + // attention layers (with optional bias) + const int64_t n_head_i = hparams.n_head(i); + const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + + // feed forward (w/ optional biases) + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + } break; case LLM_ARCH_XVERSE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -13068,6 +13188,160 @@ struct llm_build_granite : public llm_graph_context { } }; +struct llm_build_hybrid_mamba : public llm_graph_context { + + llm_build_hybrid_mamba( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf, + const bool use_mamba2 = true, + const bool use_rope = true) + : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // Build the inputs in the recurrent cache + ggml_tensor * state_copy = build_inp_s_copy(); + + // Build the inputs in the attention cache + auto * inp_attn = build_attn_inp_kv_unified(); + + // Positional embeddings populated if rope enabled + ggml_tensor * inp_pos = nullptr; + if (use_rope) { + inp_pos = build_inp_pos(); + } + + // Extract the recurrent cache from the hybrid parent + const auto * kv_recurrent = static_cast(memory)->get_child_cache(); + GGML_ASSERT(kv_recurrent); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + if (hparams.recurrent_layer(il)) { + // ssm layer // + if (use_mamba2) { + cur = llm_build_mamba::build_mamba2_layer(this, gf, cur, state_copy, kv_recurrent, model, ubatch, il); + } else { + cur = llm_build_mamba::build_mamba_layer(this, gf, cur, state_copy, kv_recurrent, model, ubatch, il); + } + } else { + // attention layer // + cur = llm_build_granite::build_attention_layer( + this, gf, cur, inp_pos, inp_attn, + model, n_embd_head, use_rope, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // For Granite architectures - scale residual + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } + + // For Granite architectures - scale residual + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // For Granite architectures - scale logits + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + // ref: https://github.com/facebookresearch/chameleon // based on the original build_llama() function, changes: // * qk-norm @@ -14355,6 +14629,13 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_BAMBA: + { + llm = std::make_unique( + *this, params, gf, + /* use_mamba2 */ true, + /* use_rope */ true); + } break; case LLM_ARCH_CHAMELEON: { llm = std::make_unique(*this, params, gf); @@ -14527,6 +14808,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GLM4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_BAMBA: case LLM_ARCH_CHAMELEON: case LLM_ARCH_BAILINGMOE: case LLM_ARCH_NEO_BERT: From ed6216a785e4b19d2de0a0e4bebaa226eeb91118 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 14 May 2025 10:13:59 -0600 Subject: [PATCH 070/117] fix: Only do Granite multipliers if set Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ccada815e3e05..9439adbded626 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13252,7 +13252,9 @@ struct llm_build_hybrid_mamba : public llm_graph_context { } // For Granite architectures - scale residual - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -13310,7 +13312,9 @@ struct llm_build_hybrid_mamba : public llm_graph_context { } // For Granite architectures - scale residual - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -13334,7 +13338,9 @@ struct llm_build_hybrid_mamba : public llm_graph_context { cur = build_lora_mm(model.output, cur); // For Granite architectures - scale logits - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + if (hparams.f_logit_scale) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + } cb(cur, "result_output", -1); res->t_logits = cur; From a6f9f90d3ad4630b4e3b2e4f05985f06cc275192 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 14 May 2025 10:26:14 -0600 Subject: [PATCH 071/117] refactor: Pull granite ffn portion into a static function and reuse in hybrid Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 222 ++++++++++++++++++-------------------------- 1 file changed, 88 insertions(+), 134 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 9439adbded626..b81332099e5ae 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13029,71 +13029,8 @@ struct llm_build_granite : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // For Granite architectures - scale residual - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network (non-MoE) - if (model.layers[il].ffn_gate_inp == nullptr) { - - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - - } else { - // MoE branch - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - ggml_tensor * moe_out = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, true, - false, 0.0, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); - cb(moe_out, "ffn_moe_out", il); - - // For Granite MoE Shared - if (hparams.n_ff_shexp > 0) { - ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(ffn_shexp, "ffn_shexp", il); - - cur = ggml_add(ctx0, moe_out, ffn_shexp); - cb(cur, "ffn_out", il); - } else { - cur = moe_out; - } - } - - // For Granite architectures - scale residual - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); + // ffn + cur = build_layer_ffn(this, cur, inpSA, model, il); // input for next layer inpL = cur; @@ -13186,6 +13123,90 @@ struct llm_build_granite : public llm_graph_context { self->cb(cur, "attn_out", il); return cur; } + + // static ffn layer builder for reuse in hybrid architectures + static ggml_tensor * build_layer_ffn( + const llm_graph_context * self, + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il) { + + auto * ctx0 = self->ctx0; + const auto& hparams = self->hparams; + + // For Granite architectures - scale residual + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + self->cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = self->build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + self->cb(cur, "ffn_norm", il); + + cur = self->build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + self->cb(cur, "ffn_out", il); + + } else { + // MoE branch + cur = self->build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + self->cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = self->build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + self->n_expert, self->n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + self->cb(moe_out, "ffn_moe_out", il); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + ggml_tensor * ffn_shexp = self->build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + self->cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + self->cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } + + // For Granite architectures - scale residual + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + cur = ggml_add(ctx0, cur, ffn_inp); + self->cb(cur, "ffn_out", il); + + cur = self->build_cvec(cur, il); + self->cb(cur, "l_out", il); + + return cur; + } }; struct llm_build_hybrid_mamba : public llm_graph_context { @@ -13251,75 +13272,8 @@ struct llm_build_hybrid_mamba : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // For Granite architectures - scale residual - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network (non-MoE) - if (model.layers[il].ffn_gate_inp == nullptr) { - - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - - } else { - // MoE branch - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - ggml_tensor * moe_out = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, true, - false, 0.0, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); - cb(moe_out, "ffn_moe_out", il); - - // For Granite MoE Shared - if (hparams.n_ff_shexp > 0) { - ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(ffn_shexp, "ffn_shexp", il); - - cur = ggml_add(ctx0, moe_out, ffn_shexp); - cb(cur, "ffn_out", il); - } else { - cur = moe_out; - } - } - - // For Granite architectures - scale residual - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); + // ffn + cur = llm_build_granite::build_layer_ffn(this, cur, inpSA, model, il); // input for next layer inpL = cur; From de4d87010beb14ce2a2dcec622e472ad04a8817b Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 14 May 2025 12:06:59 -0600 Subject: [PATCH 072/117] feat(py): Allow gguf duplicate keys if they match by value and type This is helpful for hybrid models that want to do gguf param setting by calling multiple parent classes without needing to make those parent classes try/except on every attempt to set a gguf value. Branch: GraniteFour Signed-off-by: Gabe Goodhart --- gguf-py/gguf/gguf_writer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 8fd7d5ef2a7bd..998e84abc16c7 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -270,7 +270,14 @@ def write_ti_data_to_file(self) -> None: self.state = WriterState.TI_DATA def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None: - if any(key in kv_data for kv_data in self.kv_data): + # Disallow duplicate keys if they differ by value or type + if any( + ( + key in kv_data and + (kv_data[key].value != val or kv_data[key].type != vtype) + ) + for kv_data in self.kv_data + ): logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}') self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type) From 7c2b0b80ceaacd53d7a7a4fea5af53359634d983 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 14 May 2025 12:07:26 -0600 Subject: [PATCH 073/117] refactor(py): Simplify granitemoehybrid conversion to use parents better Branch: GraniteFour Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 53 ++++++------------------------------------- 1 file changed, 7 insertions(+), 46 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f6c9a115581be..217814aba2090 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6284,10 +6284,6 @@ class GraniteMoeHybridModel(BambaModel, GraniteMoeModel): SSM layers""" model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._transformer_model_class = GraniteMoeModel - def get_attn_layres(self): if layer_types := self.hparams.get("layer_types"): return [ @@ -6299,51 +6295,16 @@ def get_attn_layres(self): def modify_tensors( self, data_torch: Tensor, name: str, bid: int | None ) -> Iterable[tuple[str, Tensor]]: - - # In GraniteMoeHybrid, the mamba layers also have an MoE + Shared expert - if name.endswith("block_sparse_moe.input_linear.weight"): - ffn_dim = self.hparams["intermediate_size"] - assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size" - gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :] - return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate), - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up), - ] - if name.endswith("shared_mlp.input_linear.weight"): - ffn_dim = self.hparams["shared_intermediate_size"] - assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size" - gate, up = data_torch.split(ffn_dim, dim=-2) - return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate), - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up), - ] - + if ( + name.endswith("block_sparse_moe.input_linear.weight") or + name.endswith("shared_mlp.input_linear.weight") + ): + return GraniteMoeModel.modify_tensors(self, data_torch, name, bid) return super().modify_tensors(data_torch, name, bid) - def set_gguf_parameters(self): - super().set_gguf_parameters() - if attention_scale := self.hparams.get("attention_multiplier"): - self.gguf_writer.add_attention_scale(attention_scale) - logger.info("gguf: (granite) attention_scale = %s", attention_scale) - if embedding_scale := self.hparams.get("embedding_multiplier"): - self.gguf_writer.add_embedding_scale(embedding_scale) - logger.info("gguf: (granite) embedding_scale = %s", embedding_scale) - if residual_scale := self.hparams.get("residual_multiplier"): - self.gguf_writer.add_residual_scale(residual_scale) - logger.info("gguf: (granite) residual_scale = %s", residual_scale) - if logits_scale := self.hparams.get("logits_scaling"): - self.gguf_writer.add_logit_scale(logits_scale) - logger.info("gguf: (granite) logits_scale = %s", logits_scale) - if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"): - self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length) - logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length) - if (n_experts := self.hparams.get("num_local_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) - logger.info(f"gguf: expert count = {n_experts}") - if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: - self.gguf_writer.add_expert_used_count(n_experts_used) - logger.info(f"gguf: experts used count = {n_experts_used}") + GraniteMoeModel.set_gguf_parameters(self) + BambaModel.set_gguf_parameters(self) def set_vocab(self): self.hparams["pad_vocab_size_multiple"] = 8 From 915f1e3f00ce5176f5ce50fe81124576e5cc87d8 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 14 May 2025 13:09:49 -0600 Subject: [PATCH 074/117] feat: Add GRANITE_MOE_HYBRID through llama-arch Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-arch.cpp | 180 +++++++++++++++++++++++++++------------------ src/llama-arch.h | 1 + 2 files changed, 108 insertions(+), 73 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index d5e1e8473410d..71e156c2c323c 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -5,79 +5,80 @@ #include static const std::map LLM_ARCH_NAMES = { - { LLM_ARCH_LLAMA, "llama" }, - { LLM_ARCH_LLAMA4, "llama4" }, - { LLM_ARCH_DECI, "deci" }, - { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_GROK, "grok" }, - { LLM_ARCH_GPT2, "gpt2" }, - { LLM_ARCH_GPTJ, "gptj" }, - { LLM_ARCH_GPTNEOX, "gptneox" }, - { LLM_ARCH_MPT, "mpt" }, - { LLM_ARCH_BAICHUAN, "baichuan" }, - { LLM_ARCH_STARCODER, "starcoder" }, - { LLM_ARCH_REFACT, "refact" }, - { LLM_ARCH_BERT, "bert" }, - { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, - { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, - { LLM_ARCH_NEO_BERT, "neo-bert" }, - { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, - { LLM_ARCH_BLOOM, "bloom" }, - { LLM_ARCH_STABLELM, "stablelm" }, - { LLM_ARCH_QWEN, "qwen" }, - { LLM_ARCH_QWEN2, "qwen2" }, - { LLM_ARCH_QWEN2MOE, "qwen2moe" }, - { LLM_ARCH_QWEN2VL, "qwen2vl" }, - { LLM_ARCH_QWEN3, "qwen3" }, - { LLM_ARCH_QWEN3MOE, "qwen3moe" }, - { LLM_ARCH_PHI2, "phi2" }, - { LLM_ARCH_PHI3, "phi3" }, - { LLM_ARCH_PHIMOE, "phimoe" }, - { LLM_ARCH_PLAMO, "plamo" }, - { LLM_ARCH_CODESHELL, "codeshell" }, - { LLM_ARCH_ORION, "orion" }, - { LLM_ARCH_INTERNLM2, "internlm2" }, - { LLM_ARCH_MINICPM, "minicpm" }, - { LLM_ARCH_MINICPM3, "minicpm3" }, - { LLM_ARCH_GEMMA, "gemma" }, - { LLM_ARCH_GEMMA2, "gemma2" }, - { LLM_ARCH_GEMMA3, "gemma3" }, - { LLM_ARCH_STARCODER2, "starcoder2" }, - { LLM_ARCH_MAMBA, "mamba" }, - { LLM_ARCH_MAMBA2, "mamba2" }, - { LLM_ARCH_BAMBA, "bamba" }, - { LLM_ARCH_XVERSE, "xverse" }, - { LLM_ARCH_COMMAND_R, "command-r" }, - { LLM_ARCH_COHERE2, "cohere2" }, - { LLM_ARCH_DBRX, "dbrx" }, - { LLM_ARCH_OLMO, "olmo" }, - { LLM_ARCH_OLMO2, "olmo2" }, - { LLM_ARCH_OLMOE, "olmoe" }, - { LLM_ARCH_OPENELM, "openelm" }, - { LLM_ARCH_ARCTIC, "arctic" }, - { LLM_ARCH_DEEPSEEK, "deepseek" }, - { LLM_ARCH_DEEPSEEK2, "deepseek2" }, - { LLM_ARCH_CHATGLM, "chatglm" }, - { LLM_ARCH_GLM4, "glm4" }, - { LLM_ARCH_BITNET, "bitnet" }, - { LLM_ARCH_T5, "t5" }, - { LLM_ARCH_T5ENCODER, "t5encoder" }, - { LLM_ARCH_JAIS, "jais" }, - { LLM_ARCH_NEMOTRON, "nemotron" }, - { LLM_ARCH_EXAONE, "exaone" }, - { LLM_ARCH_RWKV6, "rwkv6" }, - { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, - { LLM_ARCH_RWKV7, "rwkv7" }, - { LLM_ARCH_ARWKV7, "arwkv7" }, - { LLM_ARCH_GRANITE, "granite" }, - { LLM_ARCH_GRANITE_MOE, "granitemoe" }, - { LLM_ARCH_CHAMELEON, "chameleon" }, - { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, - { LLM_ARCH_PLM, "plm" }, - { LLM_ARCH_BAILINGMOE, "bailingmoe" }, - { LLM_ARCH_DOTS1, "dots1" }, - { LLM_ARCH_ARCEE, "arcee" }, - { LLM_ARCH_UNKNOWN, "(unknown)" }, + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_LLAMA4, "llama4" }, + { LLM_ARCH_DECI, "deci" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GROK, "grok" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, + { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, + { LLM_ARCH_NEO_BERT, "neo-bert" }, + { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_BLOOM, "bloom" }, + { LLM_ARCH_STABLELM, "stablelm" }, + { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_QWEN2, "qwen2" }, + { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_QWEN2VL, "qwen2vl" }, + { LLM_ARCH_QWEN3, "qwen3" }, + { LLM_ARCH_QWEN3MOE, "qwen3moe" }, + { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PHIMOE, "phimoe" }, + { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_CODESHELL, "codeshell" }, + { LLM_ARCH_ORION, "orion" }, + { LLM_ARCH_INTERNLM2, "internlm2" }, + { LLM_ARCH_MINICPM, "minicpm" }, + { LLM_ARCH_MINICPM3, "minicpm3" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_GEMMA3, "gemma3" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_MAMBA2, "mamba2" }, + { LLM_ARCH_BAMBA, "bamba" }, + { LLM_ARCH_XVERSE, "xverse" }, + { LLM_ARCH_COMMAND_R, "command-r" }, + { LLM_ARCH_COHERE2, "cohere2" }, + { LLM_ARCH_DBRX, "dbrx" }, + { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OLMO2, "olmo2" }, + { LLM_ARCH_OLMOE, "olmoe" }, + { LLM_ARCH_OPENELM, "openelm" }, + { LLM_ARCH_ARCTIC, "arctic" }, + { LLM_ARCH_DEEPSEEK, "deepseek" }, + { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_GLM4, "glm4" }, + { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_T5, "t5" }, + { LLM_ARCH_T5ENCODER, "t5encoder" }, + { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_RWKV6, "rwkv6" }, + { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, + { LLM_ARCH_RWKV7, "rwkv7" }, + { LLM_ARCH_ARWKV7, "arwkv7" }, + { LLM_ARCH_GRANITE, "granite" }, + { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_GRANITE_MOE_HYBRID, "granitemoehybrid" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, + { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_PLM, "plm" }, + { LLM_ARCH_BAILINGMOE, "bailingmoe" }, + { LLM_ARCH_DOTS1, "dots1" }, + { LLM_ARCH_ARCEE, "arcee" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, }; static const std::map LLM_KV_NAMES = { @@ -1577,6 +1578,38 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_GRANITE_MOE_HYBRID, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + // mamba(2) ssm layers + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + // attention layers + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + // moe FFN + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + // shared expert + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_CHAMELEON, { @@ -1890,6 +1923,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { // the place to identify them switch (arch) { case LLM_ARCH_BAMBA: + case LLM_ARCH_GRANITE_MOE_HYBRID: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index 31426b901957e..981410e75f39f 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -75,6 +75,7 @@ enum llm_arch { LLM_ARCH_ARWKV7, LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, + LLM_ARCH_GRANITE_MOE_HYBRID, LLM_ARCH_CHAMELEON, LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, From d0d3723a497d106a986d61ec6957eb3ca3ac5bca Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 14 May 2025 13:11:28 -0600 Subject: [PATCH 075/117] feat: Support GRANITE_MOE_HYBRID in llama-model This re-uses the Bamba code paths heavily and simply adds the missing parts for loading MoE and the shared expert. Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 49 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b81332099e5ae..059a95f399466 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1473,6 +1473,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); } break; case LLM_ARCH_BAMBA: + case LLM_ARCH_GRANITE_MOE_HYBRID: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false); @@ -1514,6 +1515,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { // TODO: Add llm type label (not sure this is useful) default: type = LLM_TYPE_UNKNOWN; } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); } break; case LLM_ARCH_CHAMELEON: { @@ -3167,6 +3171,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_BAMBA: + case LLM_ARCH_GRANITE_MOE_HYBRID: { // mamba2 Mixer SSM params // NOTE: int64_t for tensor dimensions @@ -3233,14 +3238,31 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // feed forward (w/ optional biases) - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (n_expert > 0) { + // MoE FFN + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } else { + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + } } } break; case LLM_ARCH_XVERSE: @@ -4781,7 +4803,9 @@ void llama_model::print_info() const { if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || - arch == LLM_ARCH_GRANITE_MOE) { + arch == LLM_ARCH_GRANITE_MOE || + arch == LLM_ARCH_GRANITE_MOE_HYBRID || + arch == LLM_ARCH_BAMBA) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); @@ -14589,6 +14613,12 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_GRANITE_MOE_HYBRID: + { + llm = std::make_unique(*this, params, gf, + /* use_mamba2 */ true, + /* use_rope */ false); + } break; case LLM_ARCH_BAMBA: { llm = std::make_unique( @@ -14768,6 +14798,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GLM4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_GRANITE_MOE_HYBRID: case LLM_ARCH_BAMBA: case LLM_ARCH_CHAMELEON: case LLM_ARCH_BAILINGMOE: From 2ca34162f9461bb480e60180cc81dd52f6bb3e4d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 14 May 2025 14:18:19 -0600 Subject: [PATCH 076/117] style: Fix flake8 errors Branch: GraniteFour Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 4 ++-- gguf-py/gguf/gguf_writer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 217814aba2090..26ecf6f2acc44 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6296,8 +6296,8 @@ def modify_tensors( self, data_torch: Tensor, name: str, bid: int | None ) -> Iterable[tuple[str, Tensor]]: if ( - name.endswith("block_sparse_moe.input_linear.weight") or - name.endswith("shared_mlp.input_linear.weight") + name.endswith("block_sparse_moe.input_linear.weight") + or name.endswith("shared_mlp.input_linear.weight") ): return GraniteMoeModel.modify_tensors(self, data_torch, name, bid) return super().modify_tensors(data_torch, name, bid) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 998e84abc16c7..121a1bcc785bd 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -273,8 +273,8 @@ def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUF # Disallow duplicate keys if they differ by value or type if any( ( - key in kv_data and - (kv_data[key].value != val or kv_data[key].type != vtype) + key in kv_data + and (kv_data[key].value != val or kv_data[key].type != vtype) ) for kv_data in self.kv_data ): From 3c22e1def2c79000d49da1920dde50c8dc1e336c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 28 May 2025 13:28:09 -0600 Subject: [PATCH 077/117] fix: Fix recurrent cache get after rebase Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 059a95f399466..39a105fc20036 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13263,7 +13263,7 @@ struct llm_build_hybrid_mamba : public llm_graph_context { } // Extract the recurrent cache from the hybrid parent - const auto * kv_recurrent = static_cast(memory)->get_child_cache(); + const auto * kv_recurrent = static_cast(memory)->get_kv_recurrent(); GGML_ASSERT(kv_recurrent); for (int il = 0; il < n_layer; ++il) { From 08493bfffd63a39d8251ca0be0aaf30fbf483718 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 29 May 2025 16:07:07 -0600 Subject: [PATCH 078/117] fix: Fix hybrid granite implementation for signature changes in build_mamba*_layer Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 39a105fc20036..27d373ca2e788 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13262,10 +13262,6 @@ struct llm_build_hybrid_mamba : public llm_graph_context { inp_pos = build_inp_pos(); } - // Extract the recurrent cache from the hybrid parent - const auto * kv_recurrent = static_cast(memory)->get_kv_recurrent(); - GGML_ASSERT(kv_recurrent); - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -13278,9 +13274,9 @@ struct llm_build_hybrid_mamba : public llm_graph_context { if (hparams.recurrent_layer(il)) { // ssm layer // if (use_mamba2) { - cur = llm_build_mamba::build_mamba2_layer(this, gf, cur, state_copy, kv_recurrent, model, ubatch, il); + cur = llm_build_mamba::build_mamba2_layer(this, gf, cur, state_copy, model, ubatch, il); } else { - cur = llm_build_mamba::build_mamba_layer(this, gf, cur, state_copy, kv_recurrent, model, ubatch, il); + cur = llm_build_mamba::build_mamba_layer(this, gf, cur, state_copy, model, ubatch, il); } } else { // attention layer // From ed150125d9778ec92c25bd64e258397923d04892 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 26 Jun 2025 09:53:48 -0600 Subject: [PATCH 079/117] refactor: Refactor relationship between non-hybrid classes and hybrid impl to use mixins The challenge here is to give both the non-hybrid classes (llm_build_mamba and llm_build_granite) AND the hybrid class (llm_build_hybrid_mamba) access to the same intermediate "base class" functionality (build_mamba*_layer, build_granite_attention_layer) without running into trouble with diamond inheritance of llm_graph_context. Due to the non-trivial initialization that happens in llm_graph_context, diamond inheritance results in multiple initializations of the common base which cause problems around the unique ptrs. I wanted to get away from `self->` everywhere, but this is still a bit cleaner than making those methods static I think. Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 377 +++++++++++++++++++++++--------------------- 1 file changed, 199 insertions(+), 178 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 27d373ca2e788..362ac34625250 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3204,7 +3204,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (hparams.recurrent_layer(i)) { + if (hparams.is_recurrent(i)) { // ssm layers layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, llama_model_loader::TENSOR_NOT_REQUIRED); @@ -9341,94 +9341,39 @@ struct llm_build_starcoder2 : public llm_graph_context { } }; -struct llm_build_mamba : public llm_graph_context { - const llama_model & model; - - llm_build_mamba( - const llama_model & model, - const llm_graph_params & params, - ggml_cgraph * gf, - const bool use_mamba2 - ) : llm_graph_context(params), model(model) { - ggml_tensor * cur; - ggml_tensor * inpL; - - // {n_embd, n_tokens} - inpL = build_inp_embd(model.tok_embd); - - auto * rs_inp = build_rs_inp(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - if (use_mamba2) { - cur = build_mamba2_layer(this, rs_inp, gf, cur, model, ubatch, il); - } else { - cur = build_mamba_layer(this, rs_inp, gf, cur, model, ubatch, il); - } - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - } +// Mixin class to allow graph builders to use mamba layer construction +struct llm_build_mamba_mixin { - // residual - cur = ggml_add(ctx0, cur, inpL); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - // final rmsnorm - cur = build_norm(inpL, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; + llm_graph_context * self; - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); - } + llm_build_mamba_mixin(llm_graph_context * self) : self(self) {} + // static layer build function that enables other models to borrow this + // layer logic ggml_tensor * build_mamba_layer( - const llm_graph_context * self, llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, - int il) const { - const auto * mctx_cur = static_cast(self->mctx); + int il) { + const auto * mctx_cur = inp->mctx; const auto kv_head = mctx_cur->get_head(); auto * ctx0 = self->ctx0; + const auto & hparams = self->hparams; - const int64_t d_conv = self->hparams.ssm_d_conv; - const int64_t d_inner = self->hparams.ssm_d_inner; - const int64_t d_state = self->hparams.ssm_d_state; - const int64_t dt_rank = self->hparams.ssm_dt_rank; + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; const int64_t n_head = d_inner; const int64_t head_dim = 1; const int64_t n_seqs = ubatch.n_seqs; // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) - const bool ssm_dt_b_c_rms = self->hparams.ssm_dt_b_c_rms; + const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; // Use the same RMS norm as the final layer norm - const float norm_rms_eps = self->hparams.f_norm_rms_eps; + const float norm_rms_eps = hparams.f_norm_rms_eps; const int64_t n_seq_tokens = ubatch.n_seq_tokens; @@ -9439,7 +9384,7 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - ggml_tensor * conv = self->build_rs(inp, gf, conv_states_all, self->hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = self->build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -9520,7 +9465,7 @@ struct llm_build_mamba : public llm_graph_context { return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); }; - ggml_tensor * y_ssm = self->build_rs(inp, gf, ssm_states_all, self->hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + ggml_tensor * y_ssm = self->build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -9546,25 +9491,27 @@ struct llm_build_mamba : public llm_graph_context { return cur; } + // static layer build function that enables other models to borrow this + // layer logic ggml_tensor * build_mamba2_layer( - const llm_graph_context * self, llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const { - const auto * mctx_cur = static_cast(self->mctx); + const auto * mctx_cur = inp->mctx; const auto kv_head = mctx_cur->get_head(); auto * ctx0 = self->ctx0; + const auto & hparams = self->hparams; - const int64_t d_conv = self->hparams.ssm_d_conv; - const int64_t d_inner = self->hparams.ssm_d_inner; - const int64_t d_state = self->hparams.ssm_d_state; - const int64_t n_head = self->hparams.ssm_dt_rank; + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; const int64_t head_dim = d_inner / n_head; - const int64_t n_group = self->hparams.ssm_n_group; + const int64_t n_group = hparams.ssm_n_group; const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seq_tokens = ubatch.n_seq_tokens; @@ -9576,7 +9523,7 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - ggml_tensor * conv = self->build_rs(inp, gf, conv_states_all, self->hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = self->build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -9646,7 +9593,7 @@ struct llm_build_mamba : public llm_graph_context { return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); }; - ggml_tensor * y_ssm = self->build_rs(inp, gf, ssm_states_all, self->hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + ggml_tensor * y_ssm = self->build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -9678,6 +9625,72 @@ struct llm_build_mamba : public llm_graph_context { } }; +struct llm_build_mamba : public llm_graph_context, public llm_build_mamba_mixin { + const llama_model & model; + + llm_build_mamba( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf, + const bool use_mamba2 + ) : llm_graph_context(params), llm_build_mamba_mixin(this), model(model) { + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + auto * rs_inp = build_rs_inp(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + if (use_mamba2) { + cur = build_mamba2_layer(this, rs_inp, gf, cur, model, ubatch, il); + } else { + cur = build_mamba_layer(this, rs_inp, gf, cur, model, ubatch, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // residual + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = build_norm(inpL, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + +}; + struct llm_build_command_r : public llm_graph_context { llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -13005,85 +13018,16 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { } }; -struct llm_build_granite : public llm_graph_context { - - llm_build_granite( - const llama_model & model, - const llm_graph_params & params, - ggml_cgraph * gf, - const bool use_rope = true) - : llm_graph_context(params) { - - const int64_t n_embd_head = hparams.n_embd_head_v; - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - // inp_pos - built only if rope enabled - ggml_tensor * inp_pos = nullptr; - if (use_rope) { - inp_pos = build_inp_pos(); - } - - auto * inp_attn = build_attn_inp_kv_unified(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - cur = build_attention_layer( - this, gf, cur, inp_pos, inp_attn, - model, n_embd_head, use_rope, il); - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - // ffn - cur = build_layer_ffn(this, cur, inpSA, model, il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; +// Mixin class to give builders access to a common granite layer builder +struct llm_build_granite_mixin { - // lm_head - cur = build_lora_mm(model.output, cur); + llm_graph_context * self; - // For Granite architectures - scale logits - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); - } + llm_build_granite_mixin(llm_graph_context * self) : self(self) {} // static layer build function that enables other models to borrow this // layer logic - static ggml_tensor * build_attention_layer( - const llm_graph_context * self, + ggml_tensor * build_granite_attention_layer( ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, @@ -13094,6 +13038,8 @@ struct llm_build_granite : public llm_graph_context { const int il) { auto * ctx0 = self->ctx0; + const auto & hparams = self->hparams; + const auto & cparams = self->cparams; // compute Q and K and (optionally) RoPE them ggml_tensor * Qcur = self->build_lora_mm(model.layers[il].wq, cur); @@ -13117,12 +13063,12 @@ struct llm_build_granite : public llm_graph_context { self->cb(Vcur, "Vcur", il); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, self->hparams.n_head(il), self->n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, self->hparams.n_head_kv(il), self->n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, self->hparams.n_head_kv(il), self->n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), self->n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), self->n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), self->n_tokens); if (use_rope) { - ggml_tensor * rope_factors = model.get_rope_factors(self->cparams, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, self->n_rot, self->rope_type, self->n_ctx_orig, self->freq_base, self->freq_scale, @@ -13140,7 +13086,7 @@ struct llm_build_granite : public llm_graph_context { self->cb(Kcur, "Kcur", il); self->cb(Vcur, "Vcur", il); - const float kq_scale = self->hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : self->hparams.f_attention_scale; + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = self->build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); @@ -13149,15 +13095,14 @@ struct llm_build_granite : public llm_graph_context { } // static ffn layer builder for reuse in hybrid architectures - static ggml_tensor * build_layer_ffn( - const llm_graph_context * self, + ggml_tensor * build_layer_ffn( ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il) { auto * ctx0 = self->ctx0; - const auto& hparams = self->hparams; + const auto & hparams = self->hparams; // For Granite architectures - scale residual if (hparams.f_residual_scale) { @@ -13233,29 +13178,102 @@ struct llm_build_granite : public llm_graph_context { } }; -struct llm_build_hybrid_mamba : public llm_graph_context { +struct llm_build_granite : public llm_graph_context, public llm_build_granite_mixin { - llm_build_hybrid_mamba( + llm_build_granite( const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf, - const bool use_mamba2 = true, const bool use_rope = true) - : llm_graph_context(params) { + : llm_graph_context(params), llm_build_granite_mixin(this) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); ggml_tensor * cur; ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); - // Build the inputs in the recurrent cache - ggml_tensor * state_copy = build_inp_s_copy(); + // inp_pos - built only if rope enabled + ggml_tensor * inp_pos = nullptr; + if (use_rope) { + inp_pos = build_inp_pos(); + } - // Build the inputs in the attention cache auto * inp_attn = build_attn_inp_kv_unified(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + cur = build_granite_attention_layer( + gf, cur, inp_pos, inp_attn, + model, n_embd_head, use_rope, il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // ffn + cur = build_layer_ffn(cur, inpSA, model, il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // For Granite architectures - scale logits + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_hybrid_mamba : public llm_graph_context, public llm_build_mamba_mixin, public llm_build_granite_mixin { + + llm_build_hybrid_mamba( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf, + const bool use_mamba2 = true, + const bool use_rope = true) + : llm_graph_context(params), llm_build_mamba_mixin(this), llm_build_granite_mixin(this) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // Build the inputs in the recurrent and attention caches + auto * inp = build_inp_mem_hybrid(); + // Positional embeddings populated if rope enabled ggml_tensor * inp_pos = nullptr; if (use_rope) { @@ -13271,18 +13289,21 @@ struct llm_build_hybrid_mamba : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - if (hparams.recurrent_layer(il)) { + // NOTE: Broken during rebase!!! The structure of inp changed in the + // last round, but the next commit will undo the need for this, so + // not worth fixing correctly. + if (hparams.is_recurrent(il)) { // ssm layer // if (use_mamba2) { - cur = llm_build_mamba::build_mamba2_layer(this, gf, cur, state_copy, model, ubatch, il); + cur = build_mamba2_layer(inp, gf, cur, model, ubatch, il); } else { - cur = llm_build_mamba::build_mamba_layer(this, gf, cur, state_copy, model, ubatch, il); + cur = build_mamba_layer(inp, gf, cur, model, ubatch, il); } } else { // attention layer // - cur = llm_build_granite::build_attention_layer( - this, gf, cur, inp_pos, inp_attn, - model, n_embd_head, use_rope, il); + cur = build_granite_attention_layer( + gf, cur, inp_pos, inp, model, + n_embd_head, use_rope, il); } if (il == n_layer - 1) { @@ -13293,7 +13314,7 @@ struct llm_build_hybrid_mamba : public llm_graph_context { } // ffn - cur = llm_build_granite::build_layer_ffn(this, cur, inpSA, model, il); + cur = build_layer_ffn(cur, inpSA, model, il); // input for next layer inpL = cur; From 40e23469deaabf151a4880d89881865adfcfd7b1 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 26 Jun 2025 10:00:09 -0600 Subject: [PATCH 080/117] refactor: Implement the full copy-paste version to duplicate the layer builders This follows the pattern where the type of input is pinned to the type of memory and that is used to dispatch to the correct version of `build_rs` / `build_attn`. There's a lot of code duplication that can hopefully be pulled into common functions in the graph later. Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 708 +++++++++++++++++++++++++++++--------------- 1 file changed, 466 insertions(+), 242 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 362ac34625250..e3b1a10ff9b9c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9341,27 +9341,74 @@ struct llm_build_starcoder2 : public llm_graph_context { } }; -// Mixin class to allow graph builders to use mamba layer construction -struct llm_build_mamba_mixin { +struct llm_build_mamba : public llm_graph_context { + const llama_model & model; + + llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + ggml_tensor * cur; + ggml_tensor * inpL; - llm_graph_context * self; + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); - llm_build_mamba_mixin(llm_graph_context * self) : self(self) {} + auto * inp = build_rs_inp(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + if (model.arch == LLM_ARCH_MAMBA2) { + cur = build_mamba2_layer(inp, gf, cur, ubatch, il); + } else { + cur = build_mamba_layer(inp, gf, cur, ubatch, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // residual + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = build_norm(inpL, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } - // static layer build function that enables other models to borrow this - // layer logic ggml_tensor * build_mamba_layer( - llm_graph_input_rs * inp, - ggml_cgraph * gf, - ggml_tensor * cur, - const llama_model & model, - const llama_ubatch & ubatch, - int il) { - const auto * mctx_cur = inp->mctx; + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_ubatch & ubatch, + int il) const { + const auto * mctx_cur = static_cast(mctx); const auto kv_head = mctx_cur->get_head(); - auto * ctx0 = self->ctx0; - const auto & hparams = self->hparams; const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; @@ -9384,14 +9431,14 @@ struct llm_build_mamba_mixin { ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - ggml_tensor * conv = self->build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} - ggml_tensor * xz = self->build_lora_mm(model.layers[il].ssm_in, cur); + ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, cur); // split the above in two // => {d_inner, n_seq_tokens, n_seqs} ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); @@ -9431,7 +9478,7 @@ struct llm_build_mamba_mixin { // ssm { // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} - ggml_tensor * x_db = self->build_lora_mm(model.layers[il].ssm_x, x); + ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x); // split ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); @@ -9445,7 +9492,7 @@ struct llm_build_mamba_mixin { } // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} - dt = self->build_lora_mm(model.layers[il].ssm_dt, dt); + dt = build_lora_mm(model.layers[il].ssm_dt, dt); dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); cur = x; @@ -9465,7 +9512,7 @@ struct llm_build_mamba_mixin { return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); }; - ggml_tensor * y_ssm = self->build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -9481,7 +9528,7 @@ struct llm_build_mamba_mixin { y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = self->build_lora_mm(model.layers[il].ssm_out, y); + cur = build_lora_mm(model.layers[il].ssm_out, y); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} @@ -9491,20 +9538,15 @@ struct llm_build_mamba_mixin { return cur; } - // static layer build function that enables other models to borrow this - // layer logic ggml_tensor * build_mamba2_layer( - llm_graph_input_rs * inp, - ggml_cgraph * gf, - ggml_tensor * cur, - const llama_model & model, - const llama_ubatch & ubatch, - int il) const { - const auto * mctx_cur = inp->mctx; + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_ubatch & ubatch, + int il) const { + const auto * mctx_cur = static_cast(mctx); const auto kv_head = mctx_cur->get_head(); - auto * ctx0 = self->ctx0; - const auto & hparams = self->hparams; const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; @@ -9523,7 +9565,7 @@ struct llm_build_mamba_mixin { ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - ggml_tensor * conv = self->build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -9532,7 +9574,7 @@ struct llm_build_mamba_mixin { // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} - ggml_tensor * zxBCdt = self->build_lora_mm(model.layers[il].ssm_in, cur); + ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); // split the above in three ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); @@ -9593,7 +9635,7 @@ struct llm_build_mamba_mixin { return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); }; - ggml_tensor * y_ssm = self->build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -9610,11 +9652,11 @@ struct llm_build_mamba_mixin { // grouped RMS norm y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = self->build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = self->build_lora_mm(model.layers[il].ssm_out, y); + cur = build_lora_mm(model.layers[il].ssm_out, y); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} @@ -9625,72 +9667,6 @@ struct llm_build_mamba_mixin { } }; -struct llm_build_mamba : public llm_graph_context, public llm_build_mamba_mixin { - const llama_model & model; - - llm_build_mamba( - const llama_model & model, - const llm_graph_params & params, - ggml_cgraph * gf, - const bool use_mamba2 - ) : llm_graph_context(params), llm_build_mamba_mixin(this), model(model) { - ggml_tensor * cur; - ggml_tensor * inpL; - - // {n_embd, n_tokens} - inpL = build_inp_embd(model.tok_embd); - - auto * rs_inp = build_rs_inp(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - if (use_mamba2) { - cur = build_mamba2_layer(this, rs_inp, gf, cur, model, ubatch, il); - } else { - cur = build_mamba_layer(this, rs_inp, gf, cur, model, ubatch, il); - } - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - } - - // residual - cur = ggml_add(ctx0, cur, inpL); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - // final rmsnorm - cur = build_norm(inpL, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); - } - -}; - struct llm_build_command_r : public llm_graph_context { llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -13018,12 +12994,80 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { } }; -// Mixin class to give builders access to a common granite layer builder -struct llm_build_granite_mixin { +struct llm_build_granite : public llm_graph_context { + + llm_build_granite( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf, + const bool use_rope = true) + : llm_graph_context(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - built only if rope enabled + ggml_tensor * inp_pos = nullptr; + if (use_rope) { + inp_pos = build_inp_pos(); + } + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; - llm_graph_context * self; + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); - llm_build_granite_mixin(llm_graph_context * self) : self(self) {} + // self-attention + cur = build_granite_attention_layer( + gf, cur, inp_pos, inp_attn, + model, n_embd_head, use_rope, il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // ffn + cur = build_layer_ffn(cur, inpSA, model, il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // For Granite architectures - scale logits + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } // static layer build function that enables other models to borrow this // layer logic @@ -13037,60 +13081,56 @@ struct llm_build_granite_mixin { const bool use_rope, const int il) { - auto * ctx0 = self->ctx0; - const auto & hparams = self->hparams; - const auto & cparams = self->cparams; - // compute Q and K and (optionally) RoPE them - ggml_tensor * Qcur = self->build_lora_mm(model.layers[il].wq, cur); - self->cb(Qcur, "Qcur", il); + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - self->cb(Qcur, "Qcur", il); + cb(Qcur, "Qcur", il); } - ggml_tensor * Kcur = self->build_lora_mm(model.layers[il].wk, cur); - self->cb(Kcur, "Kcur", il); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - self->cb(Kcur, "Kcur", il); + cb(Kcur, "Kcur", il); } - ggml_tensor * Vcur = self->build_lora_mm(model.layers[il].wv, cur); - self->cb(Vcur, "Vcur", il); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - self->cb(Vcur, "Vcur", il); + cb(Vcur, "Vcur", il); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), self->n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), self->n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), self->n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); if (use_rope) { ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, - self->n_rot, self->rope_type, self->n_ctx_orig, self->freq_base, self->freq_scale, - self->ext_factor, self->attn_factor, self->beta_fast, self->beta_slow + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors, - self->n_rot, self->rope_type, self->n_ctx_orig, self->freq_base, self->freq_scale, - self->ext_factor, self->attn_factor, self->beta_fast, self->beta_slow + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow ); } - self->cb(Qcur, "Qcur", il); - self->cb(Kcur, "Kcur", il); - self->cb(Vcur, "Vcur", il); + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - cur = self->build_attn(inp_attn, gf, + cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); - self->cb(cur, "attn_out", il); + cb(cur, "attn_out", il); return cur; } @@ -13101,64 +13141,61 @@ struct llm_build_granite_mixin { const llama_model & model, const int il) { - auto * ctx0 = self->ctx0; - const auto & hparams = self->hparams; - // For Granite architectures - scale residual if (hparams.f_residual_scale) { cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); } ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - self->cb(ffn_inp, "ffn_inp", il); + cb(ffn_inp, "ffn_inp", il); // feed-forward network (non-MoE) if (model.layers[il].ffn_gate_inp == nullptr) { - cur = self->build_norm(ffn_inp, + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); - self->cb(cur, "ffn_norm", il); + cb(cur, "ffn_norm", il); - cur = self->build_ffn(cur, + cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); - self->cb(cur, "ffn_out", il); + cb(cur, "ffn_out", il); } else { // MoE branch - cur = self->build_norm(ffn_inp, + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); - self->cb(cur, "ffn_norm", il); + cb(cur, "ffn_norm", il); - ggml_tensor * moe_out = self->build_moe_ffn(cur, + ggml_tensor * moe_out = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, - self->n_expert, self->n_expert_used, + n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); - self->cb(moe_out, "ffn_moe_out", il); + cb(moe_out, "ffn_moe_out", il); // For Granite MoE Shared if (hparams.n_ff_shexp > 0) { - ggml_tensor * ffn_shexp = self->build_ffn(cur, + ggml_tensor * ffn_shexp = build_ffn(cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); - self->cb(ffn_shexp, "ffn_shexp", il); + cb(ffn_shexp, "ffn_shexp", il); cur = ggml_add(ctx0, moe_out, ffn_shexp); - self->cb(cur, "ffn_out", il); + cb(cur, "ffn_out", il); } else { cur = moe_out; } @@ -13169,46 +13206,43 @@ struct llm_build_granite_mixin { cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); } cur = ggml_add(ctx0, cur, ffn_inp); - self->cb(cur, "ffn_out", il); + cb(cur, "ffn_out", il); - cur = self->build_cvec(cur, il); - self->cb(cur, "l_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); return cur; } }; -struct llm_build_granite : public llm_graph_context, public llm_build_granite_mixin { +struct llm_build_hybrid_mamba : public llm_graph_context { - llm_build_granite( + const llama_model & model; + + llm_build_hybrid_mamba( const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf, const bool use_rope = true) - : llm_graph_context(params), llm_build_granite_mixin(this) { - + : llm_graph_context(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); ggml_tensor * cur; ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); - // inp_pos - built only if rope enabled + auto * inp = build_inp_mem_hybrid(); + + // Positional embeddings populated if rope enabled ggml_tensor * inp_pos = nullptr; if (use_rope) { inp_pos = build_inp_pos(); } - auto * inp_attn = build_attn_inp_kv_unified(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; + struct ggml_tensor * inpSA = inpL; // norm cur = build_norm(inpL, @@ -13216,12 +13250,19 @@ struct llm_build_granite : public llm_graph_context, public llm_build_granite_mi LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - // self-attention - cur = build_granite_attention_layer( - gf, cur, inp_pos, inp_attn, - model, n_embd_head, use_rope, il); + if (hparams.is_recurrent(il)) { + // ssm layer // + cur = build_mamba2_layer(inp, gf, cur, ubatch, il); + } else { + // attention layer // + cur = build_granite_attention_layer( + gf, cur, inp_pos, inp, model, + n_embd_head, use_rope, il); + } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -13246,100 +13287,286 @@ struct llm_build_granite : public llm_graph_context, public llm_build_granite_mi cur = build_lora_mm(model.output, cur); // For Granite architectures - scale logits - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + if (hparams.f_logit_scale) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + } cb(cur, "result_output", -1); res->t_logits = cur; ggml_build_forward_expand(gf, cur); } -}; -struct llm_build_hybrid_mamba : public llm_graph_context, public llm_build_mamba_mixin, public llm_build_granite_mixin { + ggml_tensor * build_mamba2_layer( + llm_graph_input_mem_hybrid * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_ubatch & ubatch, + int il) const { + const auto * mctx_cur = static_cast(mctx)->get_recr(); - llm_build_hybrid_mamba( - const llama_model & model, - const llm_graph_params & params, - ggml_cgraph * gf, - const bool use_mamba2 = true, - const bool use_rope = true) - : llm_graph_context(params), llm_build_mamba_mixin(this), llm_build_granite_mixin(this) { - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + const auto kv_head = mctx_cur->get_head(); - ggml_tensor * cur; - ggml_tensor * inpL; + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = ubatch.n_seqs; - inpL = build_inp_embd(model.tok_embd); + const int64_t n_seq_tokens = ubatch.n_seq_tokens; - // Build the inputs in the recurrent and attention caches - auto * inp = build_inp_mem_hybrid(); + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - // Positional embeddings populated if rope enabled - ggml_tensor * inp_pos = nullptr; - if (use_rope) { - inp_pos = build_inp_pos(); + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + + // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + + // split the above in three + ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); + ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); + ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + + // bias + xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); + + xBC = ggml_silu(ctx0, xBC); } - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; + // ssm + { + // These correspond to V K Q in SSM/attention duality + ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); + ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); + // {n_head, n_seq_tokens, n_seqs} + dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); - // NOTE: Broken during rebase!!! The structure of inp changed in the - // last round, but the next commit will undo the need for this, so - // not worth fixing correctly. - if (hparams.is_recurrent(il)) { - // ssm layer // - if (use_mamba2) { - cur = build_mamba2_layer(inp, gf, cur, model, ubatch, il); - } else { - cur = build_mamba_layer(inp, gf, cur, model, ubatch, il); - } - } else { - // attention layer // - cur = build_granite_attention_layer( - gf, cur, inp_pos, inp, model, - n_embd_head, use_rope, il); - } + ggml_tensor * A = model.layers[il].ssm_a; - if (il == n_layer - 1) { - // skip computing output for unused tokens - ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } + // use the states and the indices provided by build_rs + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); - // ffn - cur = build_layer_ffn(cur, inpSA, model, il); + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; - // input for next layer - inpL = cur; + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + + // TODO: skip computing output earlier for unused tokens + + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + + // grouped RMS norm + y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = build_lora_mm(model.layers[il].ssm_out, y); } - cur = inpL; + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + // cb(cur, "mamba_out", il); - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); + return cur; + } - cb(cur, "result_norm", -1); - res->t_embd = cur; + // static layer build function that enables other models to borrow this + // layer logic + ggml_tensor * build_granite_attention_layer( + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_mem_hybrid * inp, + const llama_model & model, + const int64_t n_embd_head, + const bool use_rope, + const int il) { - // lm_head - cur = build_lora_mm(model.output, cur); + // compute Q and K and (optionally) RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } - // For Granite architectures - scale logits - if (hparams.f_logit_scale) { - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); } - cb(cur, "result_output", -1); - res->t_logits = cur; - ggml_build_forward_expand(gf, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + + if (use_rope) { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + cur = build_attn(inp, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + return cur; + } + + // static ffn layer builder for reuse in hybrid architectures + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il) { + + // For Granite architectures - scale residual + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } + + // For Granite architectures - scale residual + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + return cur; } }; @@ -14511,11 +14738,11 @@ llm_graph_result_ptr llama_model::build_graph( } break; case LLM_ARCH_MAMBA: { - llm = std::make_unique(*this, params, gf, false); + llm = std::make_unique(*this, params, gf); } break; case LLM_ARCH_MAMBA2: { - llm = std::make_unique(*this, params, gf, true); + llm = std::make_unique(*this, params, gf); } break; case LLM_ARCH_XVERSE: { @@ -14633,14 +14860,11 @@ llm_graph_result_ptr llama_model::build_graph( case LLM_ARCH_GRANITE_MOE_HYBRID: { llm = std::make_unique(*this, params, gf, - /* use_mamba2 */ true, /* use_rope */ false); } break; case LLM_ARCH_BAMBA: { - llm = std::make_unique( - *this, params, gf, - /* use_mamba2 */ true, + llm = std::make_unique(*this, params, gf, /* use_rope */ true); } break; case LLM_ARCH_CHAMELEON: From a9dcc8452b8215df38b9bee56be2e364dd7eb7a6 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 26 Jun 2025 10:00:47 -0600 Subject: [PATCH 081/117] refactor: Rename llm_build_hybrid_mamba -> llm_build_granite_hybrid I've got back-and-forth a lot about how/if to try to implement reuse of the "child model" layer types for hybrid models. At the end of the day, I think hybrid models are their own beast and even if their layers are inspired by other models, they should maintain control of their own layer building (in other words, the copy-paste method). Given that, the name should reflect that this is not a generic hybrid model builder, but rather a granite- specific hybrid model builder that can do MoE (granite 4) or dense (bamba). As part if this, I also cleaned up dangling comments from previous attempts at using static methods for reusability. Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e3b1a10ff9b9c..13f1ec5bdf9ac 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13069,8 +13069,6 @@ struct llm_build_granite : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - // static layer build function that enables other models to borrow this - // layer logic ggml_tensor * build_granite_attention_layer( ggml_cgraph * gf, ggml_tensor * cur, @@ -13134,7 +13132,6 @@ struct llm_build_granite : public llm_graph_context { return cur; } - // static ffn layer builder for reuse in hybrid architectures ggml_tensor * build_layer_ffn( ggml_tensor * cur, ggml_tensor * inpSA, @@ -13215,16 +13212,17 @@ struct llm_build_granite : public llm_graph_context { } }; -struct llm_build_hybrid_mamba : public llm_graph_context { +struct llm_build_granite_hybrid : public llm_graph_context { const llama_model & model; - llm_build_hybrid_mamba( - const llama_model & model, - const llm_graph_params & params, - ggml_cgraph * gf, - const bool use_rope = true) - : llm_graph_context(params), model(model) { + llm_build_granite_hybrid( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf, + const bool use_rope = true) : + llm_graph_context(params), model(model) { + const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -13424,8 +13422,6 @@ struct llm_build_hybrid_mamba : public llm_graph_context { return cur; } - // static layer build function that enables other models to borrow this - // layer logic ggml_tensor * build_granite_attention_layer( ggml_cgraph * gf, ggml_tensor * cur, @@ -13489,7 +13485,6 @@ struct llm_build_hybrid_mamba : public llm_graph_context { return cur; } - // static ffn layer builder for reuse in hybrid architectures ggml_tensor * build_layer_ffn( ggml_tensor * cur, ggml_tensor * inpSA, @@ -14859,12 +14854,12 @@ llm_graph_result_ptr llama_model::build_graph( } break; case LLM_ARCH_GRANITE_MOE_HYBRID: { - llm = std::make_unique(*this, params, gf, + llm = std::make_unique(*this, params, gf, /* use_rope */ false); } break; case LLM_ARCH_BAMBA: { - llm = std::make_unique(*this, params, gf, + llm = std::make_unique(*this, params, gf, /* use_rope */ true); } break; case LLM_ARCH_CHAMELEON: From dc1d109da863e5bb4ca4c58c83f185fc82b59ce6 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 26 Jun 2025 17:52:28 -0400 Subject: [PATCH 082/117] mamba : fix mismatched new and delete size for llm_build_mamba Subclasses of llm_graph_context cannot have extra fields, because the called destructor is not the one from the subclass. This otherwise would cause problems when runnning Mamba-(1|2) inference when compiled -DGGML_SANITIZE_ADDRESS=ON --- src/llama-model.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3c3b3981e1e84..354778bc45c66 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9198,9 +9198,7 @@ struct llm_build_starcoder2 : public llm_graph_context { }; struct llm_build_mamba : public llm_graph_context { - const llama_model & model; - - llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -9219,9 +9217,9 @@ struct llm_build_mamba : public llm_graph_context { cb(cur, "attn_norm", il); if (model.arch == LLM_ARCH_MAMBA2) { - cur = build_mamba2_layer(rs_inp, gf, cur, ubatch, il); + cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il); } else { - cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il); + cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9260,6 +9258,7 @@ struct llm_build_mamba : public llm_graph_context { llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * cur, + const llama_model & model, const llama_ubatch & ubatch, int il) const { const auto * mctx_cur = static_cast(mctx); @@ -9398,6 +9397,7 @@ struct llm_build_mamba : public llm_graph_context { llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * cur, + const llama_model & model, const llama_ubatch & ubatch, int il) const { const auto * mctx_cur = static_cast(mctx); From 66a7a432f9650c01ae931c330b11a949af2e54c9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 29 Jun 2025 10:09:52 +0300 Subject: [PATCH 083/117] memory : correctly handle failure in apply() ggml-ci --- src/llama-kv-cache-unified-iswa.cpp | 2 +- src/llama-kv-cache-unified.cpp | 2 +- src/llama-memory-hybrid.cpp | 2 +- src/llama-memory-recurrent.cpp | 10 +++++++++- src/llama-memory.cpp | 17 +++++++++++++++++ src/llama-memory.h | 3 +++ 6 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index b9169299c0760..d1f839b63aaf5 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -246,7 +246,7 @@ bool llama_kv_cache_unified_iswa_context::next() { } bool llama_kv_cache_unified_iswa_context::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + assert(!llama_memory_status_is_fail(status)); bool res = true; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 8517b722a9f80..7f7b162ffd7ce 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -1776,7 +1776,7 @@ bool llama_kv_cache_unified_context::next() { } bool llama_kv_cache_unified_context::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + assert(!llama_memory_status_is_fail(status)); // no ubatches -> this is a KV cache update if (ubatches.empty()) { diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 15cde98d138a8..67cbf95548235 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -218,7 +218,7 @@ bool llama_memory_hybrid_context::next() { } bool llama_memory_hybrid_context::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + assert(!llama_memory_status_is_fail(status)); bool res = true; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index e52156bf308b6..6ed84057ccfe2 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -1071,7 +1071,15 @@ bool llama_memory_recurrent_context::next() { } bool llama_memory_recurrent_context::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + assert(!llama_memory_status_is_fail(status)); + + // no ubatches -> this is an update + if (ubatches.empty()) { + // recurrent cache never performs updates + assert(status == LLAMA_MEMORY_STATUS_NO_UPDATE); + + return true; + } mem->find_slot(ubatches[i_next]); diff --git a/src/llama-memory.cpp b/src/llama-memory.cpp index f1107672c6476..ca6844c32a767 100644 --- a/src/llama-memory.cpp +++ b/src/llama-memory.cpp @@ -40,3 +40,20 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me // if either status has an update, then the combined status has an update return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE; } + +bool llama_memory_status_is_fail(llama_memory_status status) { + switch (status) { + case LLAMA_MEMORY_STATUS_SUCCESS: + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + return false; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return true; + } + } + + return false; +} diff --git a/src/llama-memory.h b/src/llama-memory.h index 16b7e5ee2484a..e8ba336e8525d 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -31,6 +31,9 @@ enum llama_memory_status { // useful for implementing hybrid memory types (e.g. iSWA) llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); +// helper function for checking if a memory status indicates a failure +bool llama_memory_status_is_fail(llama_memory_status status); + // the interface for managing the memory context during batch processing // this interface is implemented per memory type. see: // - llama_kv_cache_unified_context From 8f9b5130105698d2f70b9a1e29ccfa51b5599941 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 2 Jul 2025 11:55:35 -0600 Subject: [PATCH 084/117] style: Remove TODO for adding first hybrid models to the switch Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-arch.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index bb267876e4437..d895f4cb0e4c6 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1991,8 +1991,6 @@ bool llm_arch_is_recurrent(const llm_arch & arch) { } bool llm_arch_is_hybrid(const llm_arch & arch) { - // TODO: There are currently no hybrid models! Once there are, this will be - // the place to identify them switch (arch) { case LLM_ARCH_BAMBA: case LLM_ARCH_GRANITE_MOE_HYBRID: From eaec9c68207547d54f4ce3e18f5f4112c116cba4 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 2 Jul 2025 11:58:41 -0600 Subject: [PATCH 085/117] fix: Fix bad merge in tensor_mapping.py w/ SSM_NORM Branch: GraniteFour Signed-off-by: Gabe Goodhart --- gguf-py/gguf/tensor_mapping.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 42e2def47782f..be024408108f5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -588,10 +588,6 @@ class TensorNameMap: "model.layers.{bid}.mamba.norm", # bamba ), - MODEL_TENSOR.SSM_NORM: ( - "backbone.layers.{bid}.mixer.norm", # mamba2 - ), - MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", From 1085cf9cff96448b06bba2189780e519aea78dc1 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 2 Jul 2025 12:07:19 -0600 Subject: [PATCH 086/117] fix: Fix bad merge resolution with variable renames/moves in llm_build_mamba Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 29548403333ae..9688372610b78 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9904,7 +9904,9 @@ struct llm_build_mamba : public llm_graph_context { // {n_embd, n_tokens} inpL = build_inp_embd(model.tok_embd); - auto * inp = build_rs_inp(); + auto * rs_inp = build_rs_inp(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { // norm @@ -9914,14 +9916,12 @@ struct llm_build_mamba : public llm_graph_context { cb(cur, "attn_norm", il); if (model.arch == LLM_ARCH_MAMBA2) { - cur = build_mamba2_layer(inp, gf, cur, model, ubatch, il); + cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il); } else { - cur = build_mamba_layer(inp, gf, cur, model, ubatch, il); + cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il); } if (il == n_layer - 1) { - // skip computing output for unused tokens - ggml_tensor * inp_out_ids = build_inp_out_ids(); cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } @@ -13550,7 +13550,6 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { }; struct llm_build_granite : public llm_graph_context { - llm_build_granite( const llama_model & model, const llm_graph_params & params, From b6d772f97f50078b1cc9bac6e123c4858ddde4d9 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 2 Jul 2025 12:44:25 -0600 Subject: [PATCH 087/117] docs: Fix comment about duplicate key check Branch: GraniteFour Signed-off-by: Gabe Goodhart --- gguf-py/gguf/gguf_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 0395366e0b8a8..48586a2a0ccbb 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -270,7 +270,7 @@ def write_ti_data_to_file(self) -> None: self.state = WriterState.TI_DATA def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None: - # Disallow duplicate keys if they differ by value or type + # Warn about duplicate keys if they differ by value or type if any( ( key in kv_data From bb590f2e166a20abe5beefe20649337914edcaa3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 2 Jul 2025 12:49:29 -0600 Subject: [PATCH 088/117] fix: Conform to standard way of initializing inp_out_ids Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 9688372610b78..2372c374fec9c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9921,7 +9921,7 @@ struct llm_build_mamba : public llm_graph_context { cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il); } - if (il == n_layer - 1) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } @@ -13785,6 +13785,8 @@ struct llm_build_granite_hybrid : public llm_graph_context { auto * inp = build_inp_mem_hybrid(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + // Positional embeddings populated if rope enabled ggml_tensor * inp_pos = nullptr; if (use_rope) { @@ -13810,9 +13812,7 @@ struct llm_build_granite_hybrid : public llm_graph_context { n_embd_head, use_rope, il); } - if (il == n_layer - 1) { - // skip computing output for unused tokens - ggml_tensor * inp_out_ids = build_inp_out_ids(); + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } From 908e6559d696c9965e4f91e59a92722e442757b0 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Jul 2025 23:49:12 -0400 Subject: [PATCH 089/117] convert : fix jamba conv1d shape squeezing --- convert_hf_to_gguf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0474fd9652e5a..441d6c6d6ffa5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5055,6 +5055,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter new_name = self.map_tensor_name(name) + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid): + data_torch = data_torch.squeeze() + if name.endswith(".A_log"): logger.debug("A_log --> A ==> " + new_name) data_torch = -torch.exp(data_torch) From 4b5f67357b7f2d73684200dee320f7e5430cffb3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 3 Jul 2025 11:57:52 -0600 Subject: [PATCH 090/117] fix: Fix input initialization in granite_hybrid after removal of hybrid inputs Branch: GraniteFourWithJamba Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f3cec23d65b86..666159784b377 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -14028,7 +14028,11 @@ struct llm_build_granite_hybrid : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); - auto * inp = build_inp_mem_hybrid(); + const auto * mctx_hyb = static_cast(mctx); + + auto * inp_rs = build_rs_inp(mctx_hyb->get_recr()); + + auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn()); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -14049,11 +14053,11 @@ struct llm_build_granite_hybrid : public llm_graph_context { if (hparams.is_recurrent(il)) { // ssm layer // - cur = build_mamba2_layer(inp, gf, cur, model, ubatch, il); + cur = build_mamba2_layer(inp_rs, gf, cur, model, ubatch, il); } else { // attention layer // cur = build_granite_attention_layer( - gf, cur, inp_pos, inp, model, + gf, cur, inp_pos, inp_attn, model, n_embd_head, use_rope, il); } @@ -14092,12 +14096,12 @@ struct llm_build_granite_hybrid : public llm_graph_context { } ggml_tensor * build_mamba2_layer( - llm_graph_input_mem_hybrid * inp, - ggml_cgraph * gf, - ggml_tensor * cur, - const llama_model & model, - const llama_ubatch & ubatch, - int il) const { + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) const { const auto * mctx_cur = static_cast(mctx)->get_recr(); const auto kv_head = mctx_cur->get_head(); @@ -14221,14 +14225,14 @@ struct llm_build_granite_hybrid : public llm_graph_context { } ggml_tensor * build_granite_attention_layer( - ggml_cgraph * gf, - ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_mem_hybrid * inp, - const llama_model & model, - const int64_t n_embd_head, - const bool use_rope, - const int il) { + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv_unified * inp, + const llama_model & model, + const int64_t n_embd_head, + const bool use_rope, + const int il) { // compute Q and K and (optionally) RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); From 0796726b9d711fe360984fd73b31ff4b41a495f6 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 3 Jul 2025 12:00:47 -0600 Subject: [PATCH 091/117] fix: Use llm_graph_context_mamba in llm_build_granite_hybrid Branch: GraniteFourWithJamba Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 133 +------------------------------------------- 1 file changed, 2 insertions(+), 131 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 666159784b377..37ba99764fa4b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -14011,14 +14011,14 @@ struct llm_build_granite : public llm_graph_context { } }; -struct llm_build_granite_hybrid : public llm_graph_context { +struct llm_build_granite_hybrid : public llm_graph_context_mamba { llm_build_granite_hybrid( const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf, const bool use_rope = true) : - llm_graph_context(params) { + llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -14095,135 +14095,6 @@ struct llm_build_granite_hybrid : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - ggml_tensor * build_mamba2_layer( - llm_graph_input_rs * inp, - ggml_cgraph * gf, - ggml_tensor * cur, - const llama_model & model, - const llama_ubatch & ubatch, - int il) const { - const auto * mctx_cur = static_cast(mctx)->get_recr(); - - const auto kv_head = mctx_cur->get_head(); - - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_head = hparams.ssm_dt_rank; - const int64_t head_dim = d_inner / n_head; - const int64_t n_group = hparams.ssm_n_group; - const int64_t n_seqs = ubatch.n_seqs; - - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - - GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(ubatch.equal_seqs); - GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - - ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); - ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - - ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); - conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); - - // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} - cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); - - // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads - - // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} - ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); - - // split the above in three - ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); - ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); - ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); - - // conv - { - // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} - ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0); - - // copy last (d_conv - 1) columns back into the state cache - ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); - - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, last_conv, - ggml_view_1d(ctx0, conv_states_all, - (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), - kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); - - // 1D convolution - // The equivalent is to make a self-overlapping view of conv_x - // over d_conv columns at each stride in the 3rd dimension, - // then element-wise multiply that with the conv1d weight, - // then sum the elements of each row, - // (the last two steps are a dot product over rows (also doable with mul_mat)) - // then permute away the ne[0] dimension, - // and then you're left with the resulting x tensor. - // For simultaneous sequences, all sequences need to have the same length. - xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); - - // bias - xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); - - xBC = ggml_silu(ctx0, xBC); - } - - // ssm - { - // These correspond to V K Q in SSM/attention duality - ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); - ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); - ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); - - // {n_head, n_seq_tokens, n_seqs} - dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); - - ggml_tensor * A = model.layers[il].ssm_a; - - // use the states and the indices provided by build_rs - // (this is necessary in order to properly use the states before they are overwritten, - // while avoiding to make unnecessary copies of the states) - auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { - ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); - - // TODO: use semistructured matrices to implement state-space duality - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); - }; - - ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); - - // store last states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), - ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - - ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); - - // TODO: skip computing output earlier for unused tokens - - y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); - - // grouped RMS norm - y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); - y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); - - // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(model.layers[il].ssm_out, y); - } - - // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); - // cb(cur, "mamba_out", il); - - return cur; - } - ggml_tensor * build_granite_attention_layer( ggml_cgraph * gf, ggml_tensor * cur, From f7fa1b15000d73793615165ae5e66fe1a2a3e66d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 3 Jul 2025 12:21:43 -0600 Subject: [PATCH 092/117] refactor: Refactor mamba2/granite/jamba/granite_hybrid relationships as mixins The key is for the mixin classes (llm_graph_context_mamba, llm_graph_context_granite) to use virtual inheritance from llm_graph_context. This allows the common members to exist only once in the class hierarchy. The downside is that llm_graph_context will be re-initialized once for each parent (ie 2x for single mixin, 3x for two mixins, etc...). Branch: GraniteFourWithJamba Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 310 +++++++++++++------------------------------- 1 file changed, 88 insertions(+), 222 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 37ba99764fa4b..2651035fbb58f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -10024,7 +10024,7 @@ struct llm_build_starcoder2 : public llm_graph_context { } }; -struct llm_graph_context_mamba : public llm_graph_context { +struct llm_graph_context_mamba : public virtual llm_graph_context { llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} ggml_tensor * build_mamba_layer( @@ -10298,7 +10298,8 @@ struct llm_graph_context_mamba : public llm_graph_context { }; struct llm_build_mamba : public llm_graph_context_mamba { - llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) + : llm_graph_context(params), llm_graph_context_mamba(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -10355,7 +10356,8 @@ struct llm_build_mamba : public llm_graph_context_mamba { }; struct llm_build_jamba : public llm_graph_context_mamba { - llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) + : llm_graph_context(params), llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; @@ -13794,81 +13796,10 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { } }; -struct llm_build_granite : public llm_graph_context { - llm_build_granite( - const llama_model & model, - const llm_graph_params & params, - ggml_cgraph * gf, - const bool use_rope = true) - : llm_graph_context(params) { - - const int64_t n_embd_head = hparams.n_embd_head_v; - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - // inp_pos - built only if rope enabled - ggml_tensor * inp_pos = nullptr; - if (use_rope) { - inp_pos = build_inp_pos(); - } - - auto * inp_attn = build_attn_inp_kv_unified(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; +struct llm_graph_context_granite : public virtual llm_graph_context { + llm_graph_context_granite(const llm_graph_params & params) : llm_graph_context(params) {} - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - cur = build_granite_attention_layer( - gf, cur, inp_pos, inp_attn, - model, n_embd_head, use_rope, il); - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - // ffn - cur = build_layer_ffn(cur, inpSA, model, il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - // For Granite architectures - scale logits - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); - } - - ggml_tensor * build_granite_attention_layer( + ggml_tensor * build_attention_layer( ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, @@ -14011,14 +13942,91 @@ struct llm_build_granite : public llm_graph_context { } }; -struct llm_build_granite_hybrid : public llm_graph_context_mamba { +struct llm_build_granite : public llm_graph_context_granite { + llm_build_granite( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf, + const bool use_rope = true) + : llm_graph_context(params), llm_graph_context_granite(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - built only if rope enabled + ggml_tensor * inp_pos = nullptr; + if (use_rope) { + inp_pos = build_inp_pos(); + } + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + cur = build_attention_layer( + gf, cur, inp_pos, inp_attn, + model, n_embd_head, use_rope, il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // ffn + cur = build_layer_ffn(cur, inpSA, model, il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // For Granite architectures - scale logits + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_granite_hybrid : public llm_graph_context_mamba, public llm_graph_context_granite { llm_build_granite_hybrid( const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf, const bool use_rope = true) : - llm_graph_context_mamba(params) { + llm_graph_context(params), + llm_graph_context_mamba(params), + llm_graph_context_granite(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -14056,7 +14064,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { cur = build_mamba2_layer(inp_rs, gf, cur, model, ubatch, il); } else { // attention layer // - cur = build_granite_attention_layer( + cur = build_attention_layer( gf, cur, inp_pos, inp_attn, model, n_embd_head, use_rope, il); } @@ -14094,148 +14102,6 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { ggml_build_forward_expand(gf, cur); } - - ggml_tensor * build_granite_attention_layer( - ggml_cgraph * gf, - ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv_unified * inp, - const llama_model & model, - const int64_t n_embd_head, - const bool use_rope, - const int il) { - - // compute Q and K and (optionally) RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - - if (use_rope) { - ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - } - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - cur = build_attn(inp, gf, - model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); - cb(cur, "attn_out", il); - return cur; - } - - ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - ggml_tensor * inpSA, - const llama_model & model, - const int il) { - - // For Granite architectures - scale residual - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network (non-MoE) - if (model.layers[il].ffn_gate_inp == nullptr) { - - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - - } else { - // MoE branch - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - ggml_tensor * moe_out = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, true, - false, 0.0, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); - cb(moe_out, "ffn_moe_out", il); - - // For Granite MoE Shared - if (hparams.n_ff_shexp > 0) { - ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(ffn_shexp, "ffn_shexp", il); - - cur = ggml_add(ctx0, moe_out, ffn_shexp); - cb(cur, "ffn_out", il); - } else { - cur = moe_out; - } - } - - // For Granite architectures - scale residual - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - return cur; - } }; // ref: https://github.com/facebookresearch/chameleon From 20f8e43e63033a1bf5ba936b468e70aec36f6e53 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 3 Jul 2025 17:07:46 -0400 Subject: [PATCH 093/117] graph : add back hybrid memory graph input But this time it contains the sub-cache graph inputs. This *should* make it easier to handle updating the inputs when caching the graph (eventually). --- src/llama-graph.cpp | 59 ++++++++++++++++++++++++++++++++++++--------- src/llama-graph.h | 33 ++++++++++++++++++++++--- src/llama-model.cpp | 10 +++----- 3 files changed, 80 insertions(+), 22 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 750175856c1ea..7c2e880066e3d 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -335,6 +335,11 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } } +void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { + inp_attn->set_input(ubatch); + inp_rs->set_input(ubatch); +} + void llm_graph_input_one::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); GGML_ASSERT(one && ggml_nelements(one) == 1); @@ -1147,10 +1152,12 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur) const { - if (!mctx_cur) { - mctx_cur = static_cast(mctx); - } +static std::unique_ptr build_attn_inp_kv_unified_impl( + ggml_context * ctx0, + const llama_ubatch & ubatch, + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_unified_context * mctx_cur) { auto inp = std::make_unique(hparams, cparams, mctx_cur); @@ -1158,6 +1165,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); const auto n_kv = mctx_cur->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); @@ -1168,6 +1176,14 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } + return inp; +} + +llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); } @@ -1346,10 +1362,11 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur) const { - if (!mctx_cur) { - mctx_cur = static_cast(mctx); - } +// TODO: maybe separate the inner implementation into a separate function +// like with the non-sliding window equivalent +// once sliding-window hybrid caches are a thing. +llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { + const auto * mctx_cur = static_cast(mctx); auto inp = std::make_unique(hparams, cparams, mctx_cur); @@ -1417,10 +1434,9 @@ ggml_tensor * llm_graph_context::build_rs( return output_states; } -llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurrent_context * mctx_cur) const { - if (!mctx_cur) { - mctx_cur = static_cast(mctx); - } +static std::unique_ptr build_rs_inp_impl( + ggml_context * ctx0, + const llama_memory_recurrent_context * mctx_cur) { auto inp = std::make_unique(mctx_cur); @@ -1429,6 +1445,14 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurren inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); ggml_set_input(inp->s_copy); + return inp; +} + +llm_graph_input_rs * llm_graph_context::build_rs_inp() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp = build_rs_inp_impl(ctx0, mctx_cur); + return (llm_graph_input_rs *) res->add_input(std::move(inp)); } @@ -1486,6 +1510,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ); } +llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr()); + auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); + + auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); +} + void llm_graph_context::build_pooling( ggml_cgraph * gf, ggml_tensor * cls, diff --git a/src/llama-graph.h b/src/llama-graph.h index b3542e337d4c8..d8dc1e3307db2 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -319,6 +319,28 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { const llama_cross * cross = nullptr; }; +class llm_graph_input_mem_hybrid : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid( + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid() = default; + + void set_input(const llama_ubatch * ubatch) override; + + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; + + llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_memory_hybrid_context * mctx; +}; + // TODO: remove this when ggml_scale_add is implemented class llm_graph_input_one : public llm_graph_input_i { public: @@ -575,7 +597,7 @@ struct llm_graph_context { float kq_scale, int il) const; - llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur = nullptr) const; + llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const; ggml_tensor * build_attn( llm_graph_input_attn_kv_unified * inp, @@ -590,7 +612,7 @@ struct llm_graph_context { float kq_scale, int il) const; - llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur = nullptr) const; + llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const; // note: if k_cur or v_cur are not provided, they will not be stored in the memory ggml_tensor * build_attn( @@ -643,7 +665,7 @@ struct llm_graph_context { int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; - llm_graph_input_rs * build_rs_inp(const llama_memory_recurrent_context * mctx_cur = nullptr) const; + llm_graph_input_rs * build_rs_inp() const; ggml_tensor * build_rs( llm_graph_input_rs * inp, @@ -663,6 +685,11 @@ struct llm_graph_context { ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const; + // + // hybrid + // + + llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; // // pooling diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e965715f2d060..3121d9c4b7b63 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -10220,11 +10220,7 @@ struct llm_build_jamba : public llm_graph_context_mamba { // {n_embd, n_tokens} inpL = build_inp_embd(model.tok_embd); - const auto * mctx_hyb = static_cast(mctx); - - auto * inp_rs = build_rs_inp(mctx_hyb->get_recr()); - - auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn()); + auto * inp_hybrid = build_inp_mem_hybrid(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -10235,7 +10231,7 @@ struct llm_build_jamba : public llm_graph_context_mamba { cb(cur, "attn_norm", il); if (n_head_kv == 0) { - cur = build_mamba_layer(inp_rs, gf, cur, model, ubatch, il); + cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il); } else { // Attention @@ -10256,7 +10252,7 @@ struct llm_build_jamba : public llm_graph_context_mamba { cb(Vcur, "Vcur", il); // No RoPE :) - cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); + cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { From 07c252f0382f2431156bb922a4f0a46cc2450e74 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 3 Jul 2025 17:10:18 -0400 Subject: [PATCH 094/117] model : add Jamba to Mamba-specific hparams printing --- src/llama-model.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3121d9c4b7b63..9f0fca5ede3dd 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4842,16 +4842,6 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); - } - - if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); - if (!classifier_labels.empty()) { LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); @@ -4862,6 +4852,15 @@ void llama_model::print_info() const { } } + if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA) { + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + } + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); if (pimpl->n_elements >= 1e12) { LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); From 5c32e80d34734fe3314ab268c33c7e061518f7d0 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 7 Jul 2025 12:26:28 -0600 Subject: [PATCH 095/117] fix: Fix input setup after upstream merge Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 76d6bd94dbd25..4bbb19712b2a9 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -14031,11 +14031,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba, public llm_gra inpL = build_inp_embd(model.tok_embd); - const auto * mctx_hyb = static_cast(mctx); - - auto * inp_rs = build_rs_inp(mctx_hyb->get_recr()); - - auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn()); + auto * inp = build_inp_mem_hybrid(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -14056,11 +14052,11 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba, public llm_gra if (hparams.is_recurrent(il)) { // ssm layer // - cur = build_mamba2_layer(inp_rs, gf, cur, model, ubatch, il); + cur = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il); } else { // attention layer // cur = build_attention_layer( - gf, cur, inp_pos, inp_attn, model, + gf, cur, inp_pos, inp->get_attn(), model, n_embd_head, use_rope, il); } From db5ff0cc6b1bb9c68e7b7098f60c14df003e98fc Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 8 Jul 2025 15:15:49 -0400 Subject: [PATCH 096/117] jamba : remove redundant nullptr initializations --- src/llama-model.cpp | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6aaaa3a5d1d06..1aeb47a786986 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3305,12 +3305,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // out_proj layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - - layer.wq = nullptr; - layer.wk = nullptr; - layer.wv = nullptr; - layer.wo = nullptr; - } else { // Attention layers @@ -3318,19 +3312,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ssm_in = nullptr; - layer.ssm_conv1d = nullptr; - layer.ssm_conv1d_b = nullptr; - layer.ssm_x = nullptr; - layer.ssm_dt_norm = nullptr; - layer.ssm_dt = nullptr; - layer.ssm_dt_b = nullptr; - layer.ssm_b_norm = nullptr; - layer.ssm_c_norm = nullptr; - layer.ssm_a = nullptr; - layer.ssm_d = nullptr; - layer.ssm_out = nullptr; } layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3342,19 +3323,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - layer.ffn_gate = nullptr; - layer.ffn_down = nullptr; - layer.ffn_up = nullptr; } else { // FFN (no MoE) layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - layer.ffn_gate_exps = nullptr; - layer.ffn_down_exps = nullptr; - layer.ffn_up_exps = nullptr; } } } break; From 2f39cd7bb7d468965c758f50c6dc04aed4df7601 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 8 Jul 2025 15:37:49 -0400 Subject: [PATCH 097/117] model : remove unnecessary prefix for tensor loading constants MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- src/llama-model.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1aeb47a786986..501dbbb92ed11 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3220,10 +3220,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { { output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed, duplicated to allow offloading if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } } @@ -3266,10 +3266,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { { output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed, duplicated to allow offloading if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } } @@ -3316,7 +3316,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); if (layer.ffn_gate_inp) { // MoE From f7c7a926f063c912d367adc594aada9a933a6d6a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 8 Jul 2025 15:45:20 -0400 Subject: [PATCH 098/117] model : use ggml_swiglu_split for Mamba MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- src/llama-model.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 501dbbb92ed11..09e46381a9105 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -10057,7 +10057,7 @@ struct llm_graph_context_mamba : public llm_graph_context { // TODO: skip computing output earlier for unused tokens y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d)); - y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} cur = build_lora_mm(layer.ssm_out, y); @@ -10181,7 +10181,7 @@ struct llm_graph_context_mamba : public llm_graph_context { // TODO: skip computing output earlier for unused tokens y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // grouped RMS norm y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); From 8a1ea3ef5c42e1fe315b354be1e2e12566629bcc Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 8 Jul 2025 14:50:33 -0600 Subject: [PATCH 099/117] feat: Add support for dense FFN in GraniteMoeHybrid This was already partially supported via reusing the granite ffn builder, and there may be models that leverage this architecture going forward. The naming is a bit odd, but in the transformers version, it reuses the same model class and simply has zero regular experts and a single shared expert (which is the same as a single dense FFN). Branch: GraniteFour Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 18 +++++++++++++++--- gguf-py/gguf/constants.py | 5 +++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e04ccac8168c2..f9f0df3e2a20d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6538,13 +6538,25 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up), ] + has_experts = bool(self.hparams.get('num_local_experts')) + if name.endswith("shared_mlp.input_linear.weight"): ffn_dim = self.hparams["shared_intermediate_size"] assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size" gate, up = data_torch.split(ffn_dim, dim=-2) + if has_experts: + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate), + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up), + ] + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), gate), + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), up), + ] + + if not has_experts and name.endswith("shared_mlp.output_linear.weight"): return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate), - (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up), + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), data_torch) ] return super().modify_tensors(data_torch, name, bid) @@ -6569,7 +6581,7 @@ def modify_tensors( ) -> Iterable[tuple[str, Tensor]]: if ( name.endswith("block_sparse_moe.input_linear.weight") - or name.endswith("shared_mlp.input_linear.weight") + or "shared_mlp" in name ): return GraniteMoeModel.modify_tensors(self, data_torch, name, bid) return super().modify_tensors(data_torch, name, bid) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index f387c7d23ec93..8c5ca83cd562a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2190,6 +2190,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.FFN_NORM, + # MoE MODEL_TENSOR.FFN_GATE_INP, MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_DOWN_EXP, @@ -2197,6 +2198,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, + # Dense + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, ], MODEL_ARCH.CHAMELEON: [ MODEL_TENSOR.TOKEN_EMBD, From 12c50f135e233ab21463091cc182b591080c9444 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 8 Jul 2025 14:50:59 -0600 Subject: [PATCH 100/117] feat: Add support for dense FFN tensor names on c++ side Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-arch.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 23044e9b8f00b..fd863f9211fde 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1670,6 +1670,11 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + // dense FFN + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, // moe FFN { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, From f8b81c0ed301a9f7ee18b876f48b4a16a9d58e42 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 9 Jul 2025 07:57:23 -0600 Subject: [PATCH 101/117] fix: Use child inputs for Falcon H1 after merge resolution Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e3422d5c429c0..dfbc75e7eeaf7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -15239,7 +15239,7 @@ struct llm_build_falcon_h1 : public llm_graph_context { cb(Kcur, "Kcur-post-rope", il); cb(Vcur, "Vcur-post-rope", il); - ggml_tensor * attn_out = build_attn(inp, gf, + ggml_tensor * attn_out = build_attn(inp->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); @@ -15334,7 +15334,7 @@ struct llm_build_falcon_h1 : public llm_graph_context { ggml_tensor * conv_states_all = kv_state->get_r_l(il); ggml_tensor * ssm_states_all = kv_state->get_s_l(il); - ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = build_rs(inp->get_recr(), gf, conv_states_all, hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -15407,7 +15407,7 @@ struct llm_build_falcon_h1 : public llm_graph_context { return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); }; - ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + ggml_tensor * y_ssm = build_rs(inp->get_recr(), gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, From 0583d9524d90209fd970b11cc3c0a58da9992289 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 8 Jul 2025 14:56:16 -0600 Subject: [PATCH 102/117] fix: Remove unnecessary prefix on tensor constants MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Gabe Goodhart Co-authored-by: Sigbjørn Skjæret --- src/llama-model.cpp | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index dfbc75e7eeaf7..64ef17141817e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3430,10 +3430,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output { output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed, duplicated to allow offloading if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } } @@ -3446,10 +3446,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (hparams.is_recurrent(i)) { // ssm layers layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); - layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, TENSOR_NOT_REQUIRED); layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); @@ -3470,17 +3470,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } // feed forward (w/ optional biases) if (n_expert > 0) { // MoE FFN layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); @@ -3494,13 +3494,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } else { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); } } } break; From 7f3955a06833afbbc10c01f08f43d1bfb23c58c1 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 9 Jul 2025 09:44:37 -0400 Subject: [PATCH 103/117] model : make falcon-h1 use shared mamba2 layer builder --- src/llama-model.cpp | 155 ++++---------------------------------------- 1 file changed, 13 insertions(+), 142 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 661a242f055d3..c21cc28806c75 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -5021,7 +5021,10 @@ void llama_model::print_info() const { } } - if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA) { + if (arch == LLM_ARCH_MAMBA || + arch == LLM_ARCH_MAMBA2 || + arch == LLM_ARCH_JAMBA || + arch == LLM_ARCH_FALCON_H1) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); @@ -10292,8 +10295,11 @@ struct llm_graph_context_mamba : public llm_graph_context { y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // grouped RMS norm - y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + if (model.layers[il].ssm_norm) { + y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + } + y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -14919,10 +14925,8 @@ struct llm_build_ernie4_5 : public llm_graph_context { } }; -struct llm_build_falcon_h1 : public llm_graph_context { - const llama_model & model; - - llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { +struct llm_build_falcon_h1 : public llm_graph_context_mamba { + llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; @@ -14978,7 +14982,7 @@ struct llm_build_falcon_h1 : public llm_graph_context { cb(Kcur, "Kcur-post-rope", il); cb(Vcur, "Vcur-post-rope", il); - ggml_tensor * attn_out = build_attn(inp, gf, + ggml_tensor * attn_out = build_attn(inp->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); @@ -14989,7 +14993,7 @@ struct llm_build_falcon_h1 : public llm_graph_context { // Mamba2 layer cb(cur, "ssm_in", il); - ggml_tensor * ssm_out = build_mamba2_layer(inp, gf, cur, ubatch, il); + ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il); cb(ssm_out, "ssm_out", il); // // Aggregation @@ -15045,139 +15049,6 @@ struct llm_build_falcon_h1 : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - - ggml_tensor * build_mamba2_layer( - llm_graph_input_mem_hybrid * inp, - ggml_cgraph * gf, - ggml_tensor * cur, - const llama_ubatch & ubatch, - int il) const { - const auto * kv_state = static_cast(mctx)->get_recr(); - - const auto kv_head = kv_state->get_head(); - - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_head = hparams.ssm_dt_rank; - const int64_t head_dim = d_inner / n_head; - const int64_t n_group = hparams.ssm_n_group; - const int64_t n_seqs = ubatch.n_seqs; - - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - - GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(ubatch.equal_seqs); - GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - - ggml_tensor * conv_states_all = kv_state->get_r_l(il); - ggml_tensor * ssm_states_all = kv_state->get_s_l(il); - - ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); - conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); - - // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} - cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); - - // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads - - // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} - ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); - cb(zxBCdt, "zxBCdt", il); - - // split the above in three - ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); - ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); - ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); - - // conv - { - // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} - ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0); - - // copy last (d_conv - 1) columns back into the state cache - ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); - - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, last_conv, - ggml_view_1d(ctx0, conv_states_all, - (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), - kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); - - // 1D convolution - // The equivalent is to make a self-overlapping view of conv_x - // over d_conv columns at each stride in the 3rd dimension, - // then element-wise multiply that with the conv1d weight, - // then sum the elements of each row, - // (the last two steps are a dot product over rows (also doable with mul_mat)) - // then permute away the ne[0] dimension, - // and then you're left with the resulting x tensor. - // For simultaneous sequences, all sequences need to have the same length. - xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); - - // bias - xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); - - xBC = ggml_silu(ctx0, xBC); - } - - // ssm - { - // These correspond to V K Q in SSM/attention duality - ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); - - ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); - - ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); - - // {n_head, n_seq_tokens, n_seqs} - dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); - - ggml_tensor * A = model.layers[il].ssm_a; - - // use the states and the indices provided by build_rs - // (this is necessary in order to properly use the states before they are overwritten, - // while avoiding to make unnecessary copies of the states) - auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { - ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); - - // TODO: use semistructured matrices to implement state-space duality - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); - }; - - ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); - - // store last states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), - ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - - ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); - - // TODO: skip computing output earlier for unused tokens - - y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); - - // grouped RMS norm - if (model.layers[il].ssm_norm) { - y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); - } - - y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); - - // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(model.layers[il].ssm_out, y); - } - - // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); - cb(cur, "mamba_out", il); - return cur; - } }; struct llm_build_arcee : public llm_graph_context { From 452207f318acf70b5c494806620b70d23f70e2bf Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 9 Jul 2025 10:05:35 -0400 Subject: [PATCH 104/117] memory : avoid referring to KV in recurrent cache logs --- src/llama-memory-recurrent.cpp | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index a1b5b1a272cc0..2c1ae67098ca4 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -25,9 +25,6 @@ llama_memory_recurrent::llama_memory_recurrent( uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; - LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n", - __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer); - head = 0; size = mem_size; used = 0; @@ -84,7 +81,7 @@ llama_memory_recurrent::llama_memory_recurrent( ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { - throw std::runtime_error("failed to create ggml context for kv cache"); + throw std::runtime_error("failed to create ggml context for rs cache"); } ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size); @@ -102,10 +99,10 @@ llama_memory_recurrent::llama_memory_recurrent( ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (!buf) { - throw std::runtime_error("failed to allocate buffer for kv cache"); + throw std::runtime_error("failed to allocate buffer for rs cache"); } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); bufs.emplace_back(buf); } @@ -113,8 +110,8 @@ llama_memory_recurrent::llama_memory_recurrent( const size_t memory_size_r = size_r_bytes(); const size_t memory_size_s = size_s_bytes(); - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, - (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f)); } From 44cda757c525699a3598a46cd2fa6940cdc956a6 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 9 Jul 2025 08:50:32 -0600 Subject: [PATCH 105/117] fix: Revert order changes for Falcon H1 to stay consistent with upstream Branch: GraniteFour Signed-off-by: Gabe Goodhart --- gguf-py/gguf/constants.py | 2 +- src/llama-arch.cpp | 52 +++++++++++++++++++-------------------- src/llama-arch.h | 2 +- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index e8f3b0d3f1688..44ca443f3fdd1 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -603,7 +603,6 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.LLAMA4: "llama4", MODEL_ARCH.DECI: "deci", MODEL_ARCH.FALCON: "falcon", - MODEL_ARCH.FALCON_H1: "falcon-h1", MODEL_ARCH.BAICHUAN: "baichuan", MODEL_ARCH.GROK: "grok", MODEL_ARCH.GPT2: "gpt2", @@ -676,6 +675,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.DOTS1: "dots1", MODEL_ARCH.ARCEE: "arcee", MODEL_ARCH.ERNIE4_5: "ernie4_5", + MODEL_ARCH.FALCON_H1: "falcon-h1", MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", MODEL_ARCH.SMOLLM3: "smollm3", } diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 05c23dbe2c527..3740767d98e8f 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -9,7 +9,6 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA4, "llama4" }, { LLM_ARCH_DECI, "deci" }, { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_FALCON_H1, "falcon-h1" }, { LLM_ARCH_GROK, "grok" }, { LLM_ARCH_GPT2, "gpt2" }, { LLM_ARCH_GPTJ, "gptj" }, @@ -48,6 +47,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA2, "mamba2" }, { LLM_ARCH_JAMBA, "jamba" }, + { LLM_ARCH_FALCON_H1, "falcon-h1" }, { LLM_ARCH_BAMBA, "bamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, @@ -364,30 +364,6 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, - { - LLM_ARCH_FALCON_H1, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, - { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, - { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, - { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, - { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, - { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, - { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, { LLM_ARCH_GROK, { @@ -1083,6 +1059,30 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_FALCON_H1, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_BAMBA, { @@ -2099,8 +2099,8 @@ bool llm_arch_is_recurrent(const llm_arch & arch) { bool llm_arch_is_hybrid(const llm_arch & arch) { switch (arch) { - case LLM_ARCH_FALCON_H1: case LLM_ARCH_JAMBA: + case LLM_ARCH_FALCON_H1: case LLM_ARCH_BAMBA: case LLM_ARCH_GRANITE_MOE_HYBRID: return true; diff --git a/src/llama-arch.h b/src/llama-arch.h index 832a28721bce6..dca92be16458d 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -13,7 +13,6 @@ enum llm_arch { LLM_ARCH_LLAMA4, LLM_ARCH_DECI, LLM_ARCH_FALCON, - LLM_ARCH_FALCON_H1, LLM_ARCH_BAICHUAN, LLM_ARCH_GROK, LLM_ARCH_GPT2, @@ -52,6 +51,7 @@ enum llm_arch { LLM_ARCH_MAMBA, LLM_ARCH_MAMBA2, LLM_ARCH_JAMBA, + LLM_ARCH_FALCON_H1, LLM_ARCH_BAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, From 4d6a179c68f5db8b194cee2935e8254f43f18583 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 9 Jul 2025 11:58:35 -0400 Subject: [PATCH 106/117] gguf-py : avoid adding duplicate tensor mappings for Jamba Some of the tensor names are common with Llama4 --- gguf-py/gguf/tensor_mapping.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index aebb6145e23c1..215eb297ebcc1 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -304,9 +304,8 @@ class TensorNameMap: "model.layers.{bid}.mlp.gate", # qwen2moe olmoe "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx - "model.layers.{bid}.feed_forward.router", # jamba "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe - "model.layers.{bid}.feed_forward.router", # llama4 + "model.layers.{bid}.feed_forward.router", # llama4 jamba "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe "model.layers.{bid}.mlp.gate.wg", # hunyuan ), @@ -348,10 +347,9 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.gated_layers", # jina-bert-v2 (GEGLU) "encoder.layer.{bid}.mlp.up_gated_layer", # jina-v2-code (GEGLU) "model.layers.{bid}.residual_mlp.w3", # arctic - "model.layers.{bid}.feed_forward.up_proj", # jamba "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "transformer.h.{bid}.mlp.c_fc_1", # exaone - "model.layers.{bid}.feed_forward.up_proj", # llama4 + "model.layers.{bid}.feed_forward.up_proj", # llama4 jamba "transformer_encoder.{bid}.ffn.w12", # neobert ), @@ -390,9 +388,8 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used) "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic - "model.layers.{bid}.feed_forward.gate_proj", # jamba "transformer.h.{bid}.mlp.c_fc_0", # exaone - "model.layers.{bid}.feed_forward.gate_proj", # llama4 + "model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -436,10 +433,9 @@ class TensorNameMap: "transformer.layers.{bid}.ffn.proj_2", # openelm "model.layers.{bid}.residual_mlp.w2", # arctic "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 - "model.layers.{bid}.feed_forward.down_proj", # jamba "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "model.layers.h.{bid}.mlp.c_proj", # exaone - "model.layers.{bid}.feed_forward.down_proj", # llama4 + "model.layers.{bid}.feed_forward.down_proj", # llama4 jamba "transformer_encoder.{bid}.ffn.w3", # neobert ), From 68756970a28e6c5db23a946d867c608edda6fee3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 9 Jul 2025 15:00:55 -0600 Subject: [PATCH 107/117] refactor: Collapse Bamba and GraniteMoeHybrid into GraniteHybrid The only key difference is the use of rope which is now set via rope_finetuned in the hparams Branch: GraniteFour Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 211 +++++++++----------- gguf-py/gguf/constants.py | 345 +++++++++++++++------------------ gguf-py/gguf/tensor_mapping.py | 26 +-- src/llama-arch.cpp | 196 ++++++++----------- src/llama-arch.h | 3 +- src/llama-model.cpp | 48 +++-- 6 files changed, 374 insertions(+), 455 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2bae86197ec24..acd2e89e5ae1d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4971,112 +4971,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield (new_name, data_torch) -@ModelBase.register("BambaForCausalLM") -class BambaModel(Mamba2Model): - """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers""" - model_arch = gguf.MODEL_ARCH.BAMBA - undo_permute = True - - def __init__(self, *args, **kwargs): - - # Hybrid mamba models use a prefix for the mamba-specific params. - # TODO: Extend this if the prefix(es) need to be configurable - self.hparam_prefixes = ["mamba"] - - super().__init__(*args, **kwargs) - - # Use Llama conversion for attention - self._transformer_model_class: type[TextModel] = LlamaModel - - # Lists of which layers use ssm vs attention - self._attn_layers = self.get_attn_layres() - self._ssm_layers = [ - i for i in range(self.block_count) - if i not in self._attn_layers - ] - - # n_group and d_inner are used during reshape_tensors for mamaba2 - self.d_model = self.find_hparam(["hidden_size", "d_model"]) - self.n_group = self.find_hparam(["n_groups"]) - self.d_inner = self.find_hparam(["expand"]) * self.d_model - - def get_attn_layres(self) -> list[int]: - attn_layers = self.hparams.get("attn_layer_indices", []) - if not attn_layers: - attn_period = self.hparams.get("attn_layer_period") - assert attn_period, "Didn't find attn_layer_indices or attn_layer_period" - attn_offset = self.hparams.get("attn_layer_offset") - assert attn_offset is not None, "No attention layer offset set with attn_layer_period" - attn_layers = [ - i for i in range(self.block_count) - if i % attn_period == attn_offset - ] - return attn_layers - - def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any: - prefixed = [] - for pfx in self.hparam_prefixes: - prefixed.extend( - "_".join([pfx, k]) - for k in keys - ) - keys = list(keys) + prefixed - return super().find_hparam(keys, *args, **kwargs) - - def set_gguf_parameters(self): - - ## General Params ## - self.gguf_writer.add_embedding_length(self.d_model) - self.gguf_writer.add_block_count(self.block_count) - self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0)) - self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) - self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) - - ## Mamba mixer params ## - self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"])) - self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"])) - self.gguf_writer.add_ssm_group_count(self.n_group) - self.gguf_writer.add_ssm_inner_size(self.d_inner) - # NOTE: The mamba_dt_rank is _not_ the right field for how this is used - # in llama.cpp - self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"])) - - ## Attention params ## - self.gguf_writer.add_attn_layer_indices(self._attn_layers) - if rope_dim := self.hparams.get("attn_rotary_emb"): - self.gguf_writer.add_rope_dimension_count(rope_dim) - self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) - self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"])) - - ## Feed Forward Params ## - self.gguf_writer.add_layer_norm_rms_eps( - self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 - ) - - ## Validation ## - d_head = self.find_hparam(["d_head"], optional=True) or 64 - assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported" - assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}" - - def modify_tensors( - self, data_torch: Tensor, name: str, bid: int | None - ) -> Iterable[tuple[str, Tensor]]: - - # Determine whether this is a mamaba layer or an attention layer - if bid in self._ssm_layers: - for mamba_new_name, data_torch in super().modify_tensors( - data_torch, name, bid - ): - yield mamba_new_name, data_torch - elif bid in self._attn_layers: - for llama_new_name, data_torch in self._transformer_model_class.modify_tensors( - self, data_torch, name, bid - ): - yield llama_new_name, data_torch - else: - yield self.map_tensor_name(name), data_torch - - @ModelBase.register("JambaForCausalLM") class JambaModel(TextModel): model_arch = gguf.MODEL_ARCH.JAMBA @@ -6579,19 +6473,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return super().modify_tensors(data_torch, name, bid) -@ModelBase.register("GraniteMoeHybridForCausalLM") -class GraniteMoeHybridModel(BambaModel, GraniteMoeModel): - """GraniteMoeHybrid is a hybrid SSM + MoE Attention model that uses Mamba2 - SSM layers""" - model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID +@ModelBase.register("GraniteMoeHybridForCausalLM", "BambaForCausalLM") +class GraniteHybridModel(Mamba2Model, GraniteMoeModel): + """GraniteHybrid is a hybrid SSM + Attention model that uses Mamba2 SSM + layers and optionally uses MoE w/ a shared expert""" + model_arch = gguf.MODEL_ARCH.GRANITE_HYBRID + undo_permute = True + + def __init__(self, *args, **kwargs): + + # Hybrid mamba models use a prefix for the mamba-specific params. + # TODO: Extend this if the prefix(es) need to be configurable + self.hparam_prefixes = ["mamba"] + + super().__init__(*args, **kwargs) + + # Use Granite conversion for attention + self._transformer_model_class: type[TextModel] = GraniteModel + + # Lists of which layers use ssm vs attention + self._attn_layers = self.get_attn_layres() + self._ssm_layers = [ + i for i in range(self.block_count) + if i not in self._attn_layers + ] + + # n_group and d_inner are used during reshape_tensors for mamaba2 + self.d_model = self.find_hparam(["hidden_size", "d_model"]) + self.n_group = self.find_hparam(["n_groups"]) + self.d_inner = self.find_hparam(["expand"]) * self.d_model def get_attn_layres(self): + # Explicit list of layer type names if layer_types := self.hparams.get("layer_types"): return [ i for i, typ in enumerate(layer_types) if typ == "attention" ] - return super().get_attn_layres() + + # Layer types indicated by index or period + attn_layers = self.hparams.get("attn_layer_indices", []) + if not attn_layers: + attn_period = self.hparams.get("attn_layer_period") + assert attn_period, "Didn't find attn_layer_indices or attn_layer_period" + attn_offset = self.hparams.get("attn_layer_offset") + assert attn_offset is not None, "No attention layer offset set with attn_layer_period" + attn_layers = [ + i for i in range(self.block_count) + if i % attn_period == attn_offset + ] + return attn_layers + + def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any: + prefixed = [] + for pfx in self.hparam_prefixes: + prefixed.extend( + "_".join([pfx, k]) + for k in keys + ) + keys = list(keys) + prefixed + return super().find_hparam(keys, *args, **kwargs) def modify_tensors( self, data_torch: Tensor, name: str, bid: int | None @@ -6601,11 +6542,53 @@ def modify_tensors( or "shared_mlp" in name ): return GraniteMoeModel.modify_tensors(self, data_torch, name, bid) - return super().modify_tensors(data_torch, name, bid) + + # Determine whether this is a mamaba layer or an attention layer + if bid in self._ssm_layers: + return super().modify_tensors(data_torch, name, bid) + elif bid in self._attn_layers: + return self._transformer_model_class.modify_tensors(self, data_torch, name, bid) + return [(self.map_tensor_name(name), data_torch)] def set_gguf_parameters(self): GraniteMoeModel.set_gguf_parameters(self) - BambaModel.set_gguf_parameters(self) + + ## General Params ## + self.gguf_writer.add_embedding_length(self.d_model) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0)) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + + ## Mamba mixer params ## + self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"])) + self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"])) + self.gguf_writer.add_ssm_group_count(self.n_group) + self.gguf_writer.add_ssm_inner_size(self.d_inner) + # NOTE: The mamba_dt_rank is _not_ the right field for how this is used + # in llama.cpp + self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"])) + + ## Attention params ## + self.gguf_writer.add_attn_layer_indices(self._attn_layers) + if rope_dim := self.hparams.get("attn_rotary_emb"): + self.gguf_writer.add_rope_dimension_count(rope_dim) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"])) + + ## Feed Forward Params ## + self.gguf_writer.add_layer_norm_rms_eps( + self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + ) + + ## If Bamba, use rope, otherwise don't + use_rope = "BambaForCausalLM" in self.hparams["architectures"] + self.gguf_writer.add_rope_scaling_finetuned(use_rope) + + ## Validation ## + d_head = self.find_hparam(["d_head"], optional=True) or 64 + assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported" + assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}" def set_vocab(self): self.hparams["pad_vocab_size_multiple"] = 8 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 44ca443f3fdd1..a672e574ccbd7 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -286,86 +286,85 @@ class GGUFType: class MODEL_ARCH(IntEnum): - MMPROJ = auto() # dummy arch for clip.cpp - LLAMA = auto() - LLAMA4 = auto() - DECI = auto() - FALCON = auto() - FALCON_H1 = auto() - BAICHUAN = auto() - GROK = auto() - GPT2 = auto() - GPTJ = auto() - GPTNEOX = auto() - MPT = auto() - STARCODER = auto() - REFACT = auto() - BERT = auto() - NOMIC_BERT = auto() - NOMIC_BERT_MOE = auto() - NEO_BERT = auto() - JINA_BERT_V2 = auto() - BLOOM = auto() - STABLELM = auto() - QWEN = auto() - QWEN2 = auto() - QWEN2MOE = auto() - QWEN2VL = auto() - QWEN3 = auto() - QWEN3MOE = auto() - PHI2 = auto() - PHI3 = auto() - PHIMOE = auto() - PLAMO = auto() - CODESHELL = auto() - ORION = auto() - INTERNLM2 = auto() - MINICPM = auto() - MINICPM3 = auto() - GEMMA = auto() - GEMMA2 = auto() - GEMMA3 = auto() - GEMMA3N = auto() - STARCODER2 = auto() - RWKV6 = auto() - RWKV6QWEN2 = auto() - RWKV7 = auto() - ARWKV7 = auto() - MAMBA = auto() - MAMBA2 = auto() - JAMBA = auto() - BAMBA = auto() - XVERSE = auto() - COMMAND_R = auto() - COHERE2 = auto() - DBRX = auto() - OLMO = auto() - OLMO2 = auto() - OLMOE = auto() - OPENELM = auto() - ARCTIC = auto() - DEEPSEEK = auto() - DEEPSEEK2 = auto() - CHATGLM = auto() - GLM4 = auto() - BITNET = auto() - T5 = auto() - T5ENCODER = auto() - JAIS = auto() - NEMOTRON = auto() - EXAONE = auto() - GRANITE = auto() - GRANITE_MOE = auto() - GRANITE_MOE_HYBRID = auto() - CHAMELEON = auto() - WAVTOKENIZER_DEC = auto() - PLM = auto() - BAILINGMOE = auto() - DOTS1 = auto() - ARCEE = auto() - ERNIE4_5 = auto() - HUNYUAN_MOE = auto() - SMOLLM3 = auto() + MMPROJ = auto() # dummy arch for clip.cpp + LLAMA = auto() + LLAMA4 = auto() + DECI = auto() + FALCON = auto() + FALCON_H1 = auto() + BAICHUAN = auto() + GROK = auto() + GPT2 = auto() + GPTJ = auto() + GPTNEOX = auto() + MPT = auto() + STARCODER = auto() + REFACT = auto() + BERT = auto() + NOMIC_BERT = auto() + NOMIC_BERT_MOE = auto() + NEO_BERT = auto() + JINA_BERT_V2 = auto() + BLOOM = auto() + STABLELM = auto() + QWEN = auto() + QWEN2 = auto() + QWEN2MOE = auto() + QWEN2VL = auto() + QWEN3 = auto() + QWEN3MOE = auto() + PHI2 = auto() + PHI3 = auto() + PHIMOE = auto() + PLAMO = auto() + CODESHELL = auto() + ORION = auto() + INTERNLM2 = auto() + MINICPM = auto() + MINICPM3 = auto() + GEMMA = auto() + GEMMA2 = auto() + GEMMA3 = auto() + GEMMA3N = auto() + STARCODER2 = auto() + RWKV6 = auto() + RWKV6QWEN2 = auto() + RWKV7 = auto() + ARWKV7 = auto() + MAMBA = auto() + MAMBA2 = auto() + JAMBA = auto() + XVERSE = auto() + COMMAND_R = auto() + COHERE2 = auto() + DBRX = auto() + OLMO = auto() + OLMO2 = auto() + OLMOE = auto() + OPENELM = auto() + ARCTIC = auto() + DEEPSEEK = auto() + DEEPSEEK2 = auto() + CHATGLM = auto() + GLM4 = auto() + BITNET = auto() + T5 = auto() + T5ENCODER = auto() + JAIS = auto() + NEMOTRON = auto() + EXAONE = auto() + GRANITE = auto() + GRANITE_MOE = auto() + GRANITE_HYBRID = auto() + CHAMELEON = auto() + WAVTOKENIZER_DEC = auto() + PLM = auto() + BAILINGMOE = auto() + DOTS1 = auto() + ARCEE = auto() + ERNIE4_5 = auto() + HUNYUAN_MOE = auto() + SMOLLM3 = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -598,86 +597,85 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { - MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp - MODEL_ARCH.LLAMA: "llama", - MODEL_ARCH.LLAMA4: "llama4", - MODEL_ARCH.DECI: "deci", - MODEL_ARCH.FALCON: "falcon", - MODEL_ARCH.BAICHUAN: "baichuan", - MODEL_ARCH.GROK: "grok", - MODEL_ARCH.GPT2: "gpt2", - MODEL_ARCH.GPTJ: "gptj", - MODEL_ARCH.GPTNEOX: "gptneox", - MODEL_ARCH.MPT: "mpt", - MODEL_ARCH.STARCODER: "starcoder", - MODEL_ARCH.REFACT: "refact", - MODEL_ARCH.BERT: "bert", - MODEL_ARCH.NOMIC_BERT: "nomic-bert", - MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", - MODEL_ARCH.NEO_BERT: "neo-bert", - MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", - MODEL_ARCH.BLOOM: "bloom", - MODEL_ARCH.STABLELM: "stablelm", - MODEL_ARCH.QWEN: "qwen", - MODEL_ARCH.QWEN2: "qwen2", - MODEL_ARCH.QWEN2MOE: "qwen2moe", - MODEL_ARCH.QWEN2VL: "qwen2vl", - MODEL_ARCH.QWEN3: "qwen3", - MODEL_ARCH.QWEN3MOE: "qwen3moe", - MODEL_ARCH.PHI2: "phi2", - MODEL_ARCH.PHI3: "phi3", - MODEL_ARCH.PHIMOE: "phimoe", - MODEL_ARCH.PLAMO: "plamo", - MODEL_ARCH.CODESHELL: "codeshell", - MODEL_ARCH.ORION: "orion", - MODEL_ARCH.INTERNLM2: "internlm2", - MODEL_ARCH.MINICPM: "minicpm", - MODEL_ARCH.MINICPM3: "minicpm3", - MODEL_ARCH.GEMMA: "gemma", - MODEL_ARCH.GEMMA2: "gemma2", - MODEL_ARCH.GEMMA3: "gemma3", - MODEL_ARCH.GEMMA3N: "gemma3n", - MODEL_ARCH.STARCODER2: "starcoder2", - MODEL_ARCH.RWKV6: "rwkv6", - MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", - MODEL_ARCH.RWKV7: "rwkv7", - MODEL_ARCH.ARWKV7: "arwkv7", - MODEL_ARCH.MAMBA: "mamba", - MODEL_ARCH.MAMBA2: "mamba2", - MODEL_ARCH.JAMBA: "jamba", - MODEL_ARCH.BAMBA: "bamba", - MODEL_ARCH.XVERSE: "xverse", - MODEL_ARCH.COMMAND_R: "command-r", - MODEL_ARCH.COHERE2: "cohere2", - MODEL_ARCH.DBRX: "dbrx", - MODEL_ARCH.OLMO: "olmo", - MODEL_ARCH.OLMO2: "olmo2", - MODEL_ARCH.OLMOE: "olmoe", - MODEL_ARCH.OPENELM: "openelm", - MODEL_ARCH.ARCTIC: "arctic", - MODEL_ARCH.DEEPSEEK: "deepseek", - MODEL_ARCH.DEEPSEEK2: "deepseek2", - MODEL_ARCH.CHATGLM: "chatglm", - MODEL_ARCH.GLM4: "glm4", - MODEL_ARCH.BITNET: "bitnet", - MODEL_ARCH.T5: "t5", - MODEL_ARCH.T5ENCODER: "t5encoder", - MODEL_ARCH.JAIS: "jais", - MODEL_ARCH.NEMOTRON: "nemotron", - MODEL_ARCH.EXAONE: "exaone", - MODEL_ARCH.GRANITE: "granite", - MODEL_ARCH.GRANITE_MOE: "granitemoe", - MODEL_ARCH.GRANITE_MOE_HYBRID: "granitemoehybrid", - MODEL_ARCH.CHAMELEON: "chameleon", - MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", - MODEL_ARCH.PLM: "plm", - MODEL_ARCH.BAILINGMOE: "bailingmoe", - MODEL_ARCH.DOTS1: "dots1", - MODEL_ARCH.ARCEE: "arcee", - MODEL_ARCH.ERNIE4_5: "ernie4_5", - MODEL_ARCH.FALCON_H1: "falcon-h1", - MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", - MODEL_ARCH.SMOLLM3: "smollm3", + MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp + MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.LLAMA4: "llama4", + MODEL_ARCH.DECI: "deci", + MODEL_ARCH.FALCON: "falcon", + MODEL_ARCH.BAICHUAN: "baichuan", + MODEL_ARCH.GROK: "grok", + MODEL_ARCH.GPT2: "gpt2", + MODEL_ARCH.GPTJ: "gptj", + MODEL_ARCH.GPTNEOX: "gptneox", + MODEL_ARCH.MPT: "mpt", + MODEL_ARCH.STARCODER: "starcoder", + MODEL_ARCH.REFACT: "refact", + MODEL_ARCH.BERT: "bert", + MODEL_ARCH.NOMIC_BERT: "nomic-bert", + MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", + MODEL_ARCH.NEO_BERT: "neo-bert", + MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", + MODEL_ARCH.BLOOM: "bloom", + MODEL_ARCH.STABLELM: "stablelm", + MODEL_ARCH.QWEN: "qwen", + MODEL_ARCH.QWEN2: "qwen2", + MODEL_ARCH.QWEN2MOE: "qwen2moe", + MODEL_ARCH.QWEN2VL: "qwen2vl", + MODEL_ARCH.QWEN3: "qwen3", + MODEL_ARCH.QWEN3MOE: "qwen3moe", + MODEL_ARCH.PHI2: "phi2", + MODEL_ARCH.PHI3: "phi3", + MODEL_ARCH.PHIMOE: "phimoe", + MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.CODESHELL: "codeshell", + MODEL_ARCH.ORION: "orion", + MODEL_ARCH.INTERNLM2: "internlm2", + MODEL_ARCH.MINICPM: "minicpm", + MODEL_ARCH.MINICPM3: "minicpm3", + MODEL_ARCH.GEMMA: "gemma", + MODEL_ARCH.GEMMA2: "gemma2", + MODEL_ARCH.GEMMA3: "gemma3", + MODEL_ARCH.GEMMA3N: "gemma3n", + MODEL_ARCH.STARCODER2: "starcoder2", + MODEL_ARCH.RWKV6: "rwkv6", + MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", + MODEL_ARCH.RWKV7: "rwkv7", + MODEL_ARCH.ARWKV7: "arwkv7", + MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.MAMBA2: "mamba2", + MODEL_ARCH.JAMBA: "jamba", + MODEL_ARCH.XVERSE: "xverse", + MODEL_ARCH.COMMAND_R: "command-r", + MODEL_ARCH.COHERE2: "cohere2", + MODEL_ARCH.DBRX: "dbrx", + MODEL_ARCH.OLMO: "olmo", + MODEL_ARCH.OLMO2: "olmo2", + MODEL_ARCH.OLMOE: "olmoe", + MODEL_ARCH.OPENELM: "openelm", + MODEL_ARCH.ARCTIC: "arctic", + MODEL_ARCH.DEEPSEEK: "deepseek", + MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.CHATGLM: "chatglm", + MODEL_ARCH.GLM4: "glm4", + MODEL_ARCH.BITNET: "bitnet", + MODEL_ARCH.T5: "t5", + MODEL_ARCH.T5ENCODER: "t5encoder", + MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.NEMOTRON: "nemotron", + MODEL_ARCH.EXAONE: "exaone", + MODEL_ARCH.GRANITE: "granite", + MODEL_ARCH.GRANITE_MOE: "granitemoe", + MODEL_ARCH.GRANITE_HYBRID: "granitehybrid", + MODEL_ARCH.CHAMELEON: "chameleon", + MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + MODEL_ARCH.PLM: "plm", + MODEL_ARCH.BAILINGMOE: "bailingmoe", + MODEL_ARCH.DOTS1: "dots1", + MODEL_ARCH.ARCEE: "arcee", + MODEL_ARCH.ERNIE4_5: "ernie4_5", + MODEL_ARCH.FALCON_H1: "falcon-h1", + MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", + MODEL_ARCH.SMOLLM3: "smollm3", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -1781,31 +1779,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], - MODEL_ARCH.BAMBA: [ - MODEL_TENSOR.TOKEN_EMBD, - MODEL_TENSOR.OUTPUT_NORM, - MODEL_TENSOR.OUTPUT, - MODEL_TENSOR.ATTN_NORM, - MODEL_TENSOR.SSM_IN, - MODEL_TENSOR.SSM_CONV1D, - MODEL_TENSOR.SSM_DT, - MODEL_TENSOR.SSM_A, - MODEL_TENSOR.SSM_D, - MODEL_TENSOR.SSM_NORM, - MODEL_TENSOR.SSM_OUT, - MODEL_TENSOR.ATTN_Q, - MODEL_TENSOR.ATTN_K, - MODEL_TENSOR.ATTN_V, - MODEL_TENSOR.ATTN_OUT, - MODEL_TENSOR.FFN_NORM, - MODEL_TENSOR.FFN_GATE, - MODEL_TENSOR.FFN_DOWN, - MODEL_TENSOR.FFN_UP, - MODEL_TENSOR.FFN_GATE_INP, - MODEL_TENSOR.FFN_GATE_EXP, - MODEL_TENSOR.FFN_DOWN_EXP, - MODEL_TENSOR.FFN_UP_EXP, - ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -2175,7 +2148,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, ], - MODEL_ARCH.GRANITE_MOE_HYBRID: [ + MODEL_ARCH.GRANITE_HYBRID: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 713747ef39613..7a4f275ceec28 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -13,7 +13,7 @@ class TensorNameMap: "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone "transformer.word_embeddings", # falcon "word_embeddings", # bloom - "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 bamba + "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 granite-hybrid "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert "language_model.embedding.word_embeddings", # persimmon @@ -118,7 +118,7 @@ class TensorNameMap: "transformer.h.{bid}.input_layernorm", # falcon7b "h.{bid}.input_layernorm", # bloom "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe bamba + "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe granite-hybrid "layers.{bid}.attention_norm", # llama-pth "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi @@ -279,7 +279,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.rms_norm_2", # Grok "encoder.layers.{bid}.post_attention_layernorm", # chatglm "transformer.layers.{bid}.ffn_norm", # openelm - "model.layers.{bid}.pre_ff_layernorm", # jamba bamba + "model.layers.{bid}.pre_ff_layernorm", # jamba granite-hybrid "model.layers.{bid}.pre_moe_layernorm", # mini-jamba "model.layers.{bid}.post_attention_layernorm", # llama4 "transformer_encoder.{bid}.ffn_norm", # neobert @@ -349,7 +349,7 @@ class TensorNameMap: "model.layers.{bid}.residual_mlp.w3", # arctic "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "transformer.h.{bid}.mlp.c_fc_1", # exaone - "model.layers.{bid}.feed_forward.up_proj", # llama4 jamba bamba + "model.layers.{bid}.feed_forward.up_proj", # llama4 jamba granite-hybrid "transformer_encoder.{bid}.ffn.w12", # neobert ), @@ -389,7 +389,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic "transformer.h.{bid}.mlp.c_fc_0", # exaone - "model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba bamba + "model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -435,7 +435,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "model.layers.h.{bid}.mlp.c_proj", # exaone - "model.layers.{bid}.feed_forward.down_proj", # llama4 jamba bamba + "model.layers.{bid}.feed_forward.down_proj", # llama4 jamba granite-hybrid "transformer_encoder.{bid}.ffn.w3", # neobert ), @@ -558,13 +558,13 @@ class TensorNameMap: MODEL_TENSOR.SSM_IN: ( "model.layers.{bid}.in_proj", # mamba-hf "backbone.layers.{bid}.mixer.in_proj", # mamba - "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 bamba + "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid ), MODEL_TENSOR.SSM_CONV1D: ( "model.layers.{bid}.conv1d", # mamba-hf "backbone.layers.{bid}.mixer.conv1d", # mamba - "model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 bamba + "model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid ), MODEL_TENSOR.SSM_X: ( @@ -576,7 +576,7 @@ class TensorNameMap: MODEL_TENSOR.SSM_DT: ( "model.layers.{bid}.dt_proj", # mamba-hf "backbone.layers.{bid}.mixer.dt_proj", # mamba - "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 bamba + "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid ), MODEL_TENSOR.SSM_DT_NORM: ( @@ -586,7 +586,7 @@ class TensorNameMap: MODEL_TENSOR.SSM_A: ( "model.layers.{bid}.A_log", # mamba-hf "backbone.layers.{bid}.mixer.A_log", # mamba - "model.layers.{bid}.mamba.A_log", # jamba falcon-h1 bamba + "model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid ), MODEL_TENSOR.SSM_B_NORM: ( @@ -602,18 +602,18 @@ class TensorNameMap: MODEL_TENSOR.SSM_D: ( "model.layers.{bid}.D", # mamba-hf "backbone.layers.{bid}.mixer.D", # mamba - "model.layers.{bid}.mamba.D", # jamba falcon-h1 bamba + "model.layers.{bid}.mamba.D", # jamba falcon-h1 granite-hybrid ), MODEL_TENSOR.SSM_NORM: ( - "model.layers.{bid}.mamba.norm", # falcon-h1 bamba + "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid "backbone.layers.{bid}.mixer.norm", # mamba2 ), MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", # mamba-hf "backbone.layers.{bid}.mixer.out_proj", # mamba - "model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 bamba + "model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid ), MODEL_TENSOR.TIME_MIX_W0: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 3740767d98e8f..e60c408601611 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -5,86 +5,85 @@ #include static const std::map LLM_ARCH_NAMES = { - { LLM_ARCH_LLAMA, "llama" }, - { LLM_ARCH_LLAMA4, "llama4" }, - { LLM_ARCH_DECI, "deci" }, - { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_GROK, "grok" }, - { LLM_ARCH_GPT2, "gpt2" }, - { LLM_ARCH_GPTJ, "gptj" }, - { LLM_ARCH_GPTNEOX, "gptneox" }, - { LLM_ARCH_MPT, "mpt" }, - { LLM_ARCH_BAICHUAN, "baichuan" }, - { LLM_ARCH_STARCODER, "starcoder" }, - { LLM_ARCH_REFACT, "refact" }, - { LLM_ARCH_BERT, "bert" }, - { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, - { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, - { LLM_ARCH_NEO_BERT, "neo-bert" }, - { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, - { LLM_ARCH_BLOOM, "bloom" }, - { LLM_ARCH_STABLELM, "stablelm" }, - { LLM_ARCH_QWEN, "qwen" }, - { LLM_ARCH_QWEN2, "qwen2" }, - { LLM_ARCH_QWEN2MOE, "qwen2moe" }, - { LLM_ARCH_QWEN2VL, "qwen2vl" }, - { LLM_ARCH_QWEN3, "qwen3" }, - { LLM_ARCH_QWEN3MOE, "qwen3moe" }, - { LLM_ARCH_PHI2, "phi2" }, - { LLM_ARCH_PHI3, "phi3" }, - { LLM_ARCH_PHIMOE, "phimoe" }, - { LLM_ARCH_PLAMO, "plamo" }, - { LLM_ARCH_CODESHELL, "codeshell" }, - { LLM_ARCH_ORION, "orion" }, - { LLM_ARCH_INTERNLM2, "internlm2" }, - { LLM_ARCH_MINICPM, "minicpm" }, - { LLM_ARCH_MINICPM3, "minicpm3" }, - { LLM_ARCH_GEMMA, "gemma" }, - { LLM_ARCH_GEMMA2, "gemma2" }, - { LLM_ARCH_GEMMA3, "gemma3" }, - { LLM_ARCH_GEMMA3N, "gemma3n" }, - { LLM_ARCH_STARCODER2, "starcoder2" }, - { LLM_ARCH_MAMBA, "mamba" }, - { LLM_ARCH_MAMBA2, "mamba2" }, - { LLM_ARCH_JAMBA, "jamba" }, - { LLM_ARCH_FALCON_H1, "falcon-h1" }, - { LLM_ARCH_BAMBA, "bamba" }, - { LLM_ARCH_XVERSE, "xverse" }, - { LLM_ARCH_COMMAND_R, "command-r" }, - { LLM_ARCH_COHERE2, "cohere2" }, - { LLM_ARCH_DBRX, "dbrx" }, - { LLM_ARCH_OLMO, "olmo" }, - { LLM_ARCH_OLMO2, "olmo2" }, - { LLM_ARCH_OLMOE, "olmoe" }, - { LLM_ARCH_OPENELM, "openelm" }, - { LLM_ARCH_ARCTIC, "arctic" }, - { LLM_ARCH_DEEPSEEK, "deepseek" }, - { LLM_ARCH_DEEPSEEK2, "deepseek2" }, - { LLM_ARCH_CHATGLM, "chatglm" }, - { LLM_ARCH_GLM4, "glm4" }, - { LLM_ARCH_BITNET, "bitnet" }, - { LLM_ARCH_T5, "t5" }, - { LLM_ARCH_T5ENCODER, "t5encoder" }, - { LLM_ARCH_JAIS, "jais" }, - { LLM_ARCH_NEMOTRON, "nemotron" }, - { LLM_ARCH_EXAONE, "exaone" }, - { LLM_ARCH_RWKV6, "rwkv6" }, - { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, - { LLM_ARCH_RWKV7, "rwkv7" }, - { LLM_ARCH_ARWKV7, "arwkv7" }, - { LLM_ARCH_GRANITE, "granite" }, - { LLM_ARCH_GRANITE_MOE, "granitemoe" }, - { LLM_ARCH_GRANITE_MOE_HYBRID, "granitemoehybrid" }, - { LLM_ARCH_CHAMELEON, "chameleon" }, - { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, - { LLM_ARCH_PLM, "plm" }, - { LLM_ARCH_BAILINGMOE, "bailingmoe" }, - { LLM_ARCH_DOTS1, "dots1" }, - { LLM_ARCH_ARCEE, "arcee" }, - { LLM_ARCH_ERNIE4_5, "ernie4_5" }, - { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, - { LLM_ARCH_SMOLLM3, "smollm3" }, - { LLM_ARCH_UNKNOWN, "(unknown)" }, + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_LLAMA4, "llama4" }, + { LLM_ARCH_DECI, "deci" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GROK, "grok" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, + { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, + { LLM_ARCH_NEO_BERT, "neo-bert" }, + { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_BLOOM, "bloom" }, + { LLM_ARCH_STABLELM, "stablelm" }, + { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_QWEN2, "qwen2" }, + { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_QWEN2VL, "qwen2vl" }, + { LLM_ARCH_QWEN3, "qwen3" }, + { LLM_ARCH_QWEN3MOE, "qwen3moe" }, + { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PHIMOE, "phimoe" }, + { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_CODESHELL, "codeshell" }, + { LLM_ARCH_ORION, "orion" }, + { LLM_ARCH_INTERNLM2, "internlm2" }, + { LLM_ARCH_MINICPM, "minicpm" }, + { LLM_ARCH_MINICPM3, "minicpm3" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_GEMMA3, "gemma3" }, + { LLM_ARCH_GEMMA3N, "gemma3n" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_MAMBA2, "mamba2" }, + { LLM_ARCH_JAMBA, "jamba" }, + { LLM_ARCH_FALCON_H1, "falcon-h1" }, + { LLM_ARCH_XVERSE, "xverse" }, + { LLM_ARCH_COMMAND_R, "command-r" }, + { LLM_ARCH_COHERE2, "cohere2" }, + { LLM_ARCH_DBRX, "dbrx" }, + { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OLMO2, "olmo2" }, + { LLM_ARCH_OLMOE, "olmoe" }, + { LLM_ARCH_OPENELM, "openelm" }, + { LLM_ARCH_ARCTIC, "arctic" }, + { LLM_ARCH_DEEPSEEK, "deepseek" }, + { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_GLM4, "glm4" }, + { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_T5, "t5" }, + { LLM_ARCH_T5ENCODER, "t5encoder" }, + { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_RWKV6, "rwkv6" }, + { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, + { LLM_ARCH_RWKV7, "rwkv7" }, + { LLM_ARCH_ARWKV7, "arwkv7" }, + { LLM_ARCH_GRANITE, "granite" }, + { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_GRANITE_HYBRID, "granitehybrid" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, + { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_PLM, "plm" }, + { LLM_ARCH_BAILINGMOE, "bailingmoe" }, + { LLM_ARCH_DOTS1, "dots1" }, + { LLM_ARCH_ARCEE, "arcee" }, + { LLM_ARCH_ERNIE4_5, "ernie4_5" }, + { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, + { LLM_ARCH_SMOLLM3, "smollm3" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, }; static const std::map LLM_KV_NAMES = { @@ -1083,38 +1082,6 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, - { - LLM_ARCH_BAMBA, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - // mamba(2) ssm layers - { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, - { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, - { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, - { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, - { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, - { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, - { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, - // attention layers - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - // non-moe FFN - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - // moe FFN - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, { LLM_ARCH_XVERSE, { @@ -1676,7 +1643,7 @@ static const std::map> LLM_TENSOR_N }, }, { - LLM_ARCH_GRANITE_MOE_HYBRID, + LLM_ARCH_GRANITE_HYBRID, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, @@ -2101,8 +2068,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { switch (arch) { case LLM_ARCH_JAMBA: case LLM_ARCH_FALCON_H1: - case LLM_ARCH_BAMBA: - case LLM_ARCH_GRANITE_MOE_HYBRID: + case LLM_ARCH_GRANITE_HYBRID: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index dca92be16458d..a82af4032bff2 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -52,7 +52,6 @@ enum llm_arch { LLM_ARCH_MAMBA2, LLM_ARCH_JAMBA, LLM_ARCH_FALCON_H1, - LLM_ARCH_BAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_COHERE2, @@ -78,7 +77,7 @@ enum llm_arch { LLM_ARCH_ARWKV7, LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, - LLM_ARCH_GRANITE_MOE_HYBRID, + LLM_ARCH_GRANITE_HYBRID, LLM_ARCH_CHAMELEON, LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3ffa4fc22dfac..051fee30cd985 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1504,6 +1504,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + // Granite uses rope_finetuned as a switch for rope, so default to true + bool rope_finetuned = true; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + switch (hparams.n_layer) { case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_3B; break; @@ -1514,8 +1519,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // For Granite MoE Shared ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); } break; - case LLM_ARCH_BAMBA: - case LLM_ARCH_GRANITE_MOE_HYBRID: + case LLM_ARCH_GRANITE_HYBRID: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false); @@ -1529,6 +1533,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + // Granite uses rope_finetuned as a switch for rope, so default to true + bool rope_finetuned = true; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + // Zero-out n_head_arr and n_head_kv_arr since SSM layers don't // have attention heads. We'll set them correctly below once we // know which layers are attention layers @@ -3409,8 +3418,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } } break; - case LLM_ARCH_BAMBA: - case LLM_ARCH_GRANITE_MOE_HYBRID: + case LLM_ARCH_GRANITE_HYBRID: { // mamba2 Mixer SSM params // NOTE: int64_t for tensor dimensions @@ -5222,8 +5230,7 @@ void llama_model::print_info() const { if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_MOE_HYBRID || - arch == LLM_ARCH_BAMBA) { + arch == LLM_ARCH_GRANITE_HYBRID) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); @@ -13961,7 +13968,6 @@ struct llm_graph_context_granite : public virtual llm_graph_context { llm_graph_input_attn_kv_unified * inp_attn, const llama_model & model, const int64_t n_embd_head, - const bool use_rope, const int il) { // compute Q and K and (optionally) RoPE them @@ -13990,6 +13996,7 @@ struct llm_graph_context_granite : public virtual llm_graph_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + const bool use_rope = hparams.rope_finetuned; if (use_rope) { ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); Qcur = ggml_rope_ext( @@ -14101,8 +14108,7 @@ struct llm_build_granite : public llm_graph_context_granite { llm_build_granite( const llama_model & model, const llm_graph_params & params, - ggml_cgraph * gf, - const bool use_rope = true) + ggml_cgraph * gf) : llm_graph_context(params), llm_graph_context_granite(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -14117,7 +14123,7 @@ struct llm_build_granite : public llm_graph_context_granite { // inp_pos - built only if rope enabled ggml_tensor * inp_pos = nullptr; - if (use_rope) { + if (hparams.rope_finetuned) { inp_pos = build_inp_pos(); } @@ -14137,7 +14143,7 @@ struct llm_build_granite : public llm_graph_context_granite { // self-attention cur = build_attention_layer( gf, cur, inp_pos, inp_attn, - model, n_embd_head, use_rope, il); + model, n_embd_head, il); if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); @@ -14177,8 +14183,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba, public llm_gra llm_build_granite_hybrid( const llama_model & model, const llm_graph_params & params, - ggml_cgraph * gf, - const bool use_rope = true) : + ggml_cgraph * gf) : llm_graph_context(params), llm_graph_context_mamba(params), llm_graph_context_granite(params) { @@ -14197,7 +14202,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba, public llm_gra // Positional embeddings populated if rope enabled ggml_tensor * inp_pos = nullptr; - if (use_rope) { + if (hparams.rope_finetuned) { inp_pos = build_inp_pos(); } @@ -14217,7 +14222,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba, public llm_gra // attention layer // cur = build_attention_layer( gf, cur, inp_pos, inp->get_attn(), model, - n_embd_head, use_rope, il); + n_embd_head, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -16104,15 +16109,9 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; - case LLM_ARCH_GRANITE_MOE_HYBRID: - { - llm = std::make_unique(*this, params, gf, - /* use_rope */ false); - } break; - case LLM_ARCH_BAMBA: + case LLM_ARCH_GRANITE_HYBRID: { - llm = std::make_unique(*this, params, gf, - /* use_rope */ true); + llm = std::make_unique(*this, params, gf); } break; case LLM_ARCH_CHAMELEON: { @@ -16303,8 +16302,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GLM4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: - case LLM_ARCH_GRANITE_MOE_HYBRID: - case LLM_ARCH_BAMBA: + case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_CHAMELEON: case LLM_ARCH_BAILINGMOE: case LLM_ARCH_NEO_BERT: From 8dd7f977629dd0d08ee2ffeee243f2a59e7f909a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 9 Jul 2025 15:07:04 -0600 Subject: [PATCH 108/117] refactor: Remove use of diamond inheritance Per PR discussion, it's simpler to keep this with basic inheritance and not introduce the complexity of virtual inheritance and multiple inheritance https://github.com/ggml-org/llama.cpp/pull/13550#issuecomment-3053787556 Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 307 +++++++++++++++++++++++++++++++------------- 1 file changed, 220 insertions(+), 87 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 051fee30cd985..92bf8c83a865f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -10189,7 +10189,7 @@ struct llm_build_starcoder2 : public llm_graph_context { } }; -struct llm_graph_context_mamba : public virtual llm_graph_context { +struct llm_graph_context_mamba : public llm_graph_context { llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} ggml_tensor * build_mamba_layer( @@ -10466,8 +10466,7 @@ struct llm_graph_context_mamba : public virtual llm_graph_context { }; struct llm_build_mamba : public llm_graph_context_mamba { - llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) - : llm_graph_context(params), llm_graph_context_mamba(params) { + llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -10524,8 +10523,7 @@ struct llm_build_mamba : public llm_graph_context_mamba { }; struct llm_build_jamba : public llm_graph_context_mamba { - llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) - : llm_graph_context(params), llm_graph_context_mamba(params) { + llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; @@ -13958,8 +13956,78 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { } }; -struct llm_graph_context_granite : public virtual llm_graph_context { - llm_graph_context_granite(const llm_graph_params & params) : llm_graph_context(params) {} +struct llm_build_granite : public llm_graph_context { + llm_build_granite( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf) + : llm_graph_context(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - built only if rope enabled + ggml_tensor * inp_pos = nullptr; + if (hparams.rope_finetuned) { + inp_pos = build_inp_pos(); + } + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + cur = build_attention_layer( + gf, cur, inp_pos, inp_attn, + model, n_embd_head, il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // ffn + cur = build_layer_ffn(cur, inpSA, model, il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // For Granite architectures - scale logits + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } ggml_tensor * build_attention_layer( ggml_cgraph * gf, @@ -14104,89 +14172,13 @@ struct llm_graph_context_granite : public virtual llm_graph_context { } }; -struct llm_build_granite : public llm_graph_context_granite { - llm_build_granite( - const llama_model & model, - const llm_graph_params & params, - ggml_cgraph * gf) - : llm_graph_context(params), llm_graph_context_granite(params) { - - const int64_t n_embd_head = hparams.n_embd_head_v; - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - // inp_pos - built only if rope enabled - ggml_tensor * inp_pos = nullptr; - if (hparams.rope_finetuned) { - inp_pos = build_inp_pos(); - } - - auto * inp_attn = build_attn_inp_kv_unified(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - cur = build_attention_layer( - gf, cur, inp_pos, inp_attn, - model, n_embd_head, il); - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - // ffn - cur = build_layer_ffn(cur, inpSA, model, il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - // For Granite architectures - scale logits - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); - } -}; - -struct llm_build_granite_hybrid : public llm_graph_context_mamba, public llm_graph_context_granite { +struct llm_build_granite_hybrid : public llm_graph_context_mamba { llm_build_granite_hybrid( const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : - llm_graph_context(params), - llm_graph_context_mamba(params), - llm_graph_context_granite(params) { + llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -14258,6 +14250,148 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba, public llm_gra ggml_build_forward_expand(gf, cur); } + + ggml_tensor * build_attention_layer( + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv_unified * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il) { + + // compute Q and K and (optionally) RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + + const bool use_rope = hparams.rope_finetuned; + if (use_rope) { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + return cur; + } + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il) { + + // For Granite architectures - scale residual + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } + + // For Granite architectures - scale residual + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + return cur; + } }; // ref: https://github.com/facebookresearch/chameleon @@ -15192,8 +15326,7 @@ struct llm_build_ernie4_5 : public llm_graph_context { }; struct llm_build_falcon_h1 : public llm_graph_context_mamba { - llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) - : llm_graph_context(params), llm_graph_context_mamba(params) { + llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; From dcf51e085c04620b57f73c1b7cab89908d30b62d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 9 Jul 2025 21:06:59 -0600 Subject: [PATCH 109/117] feat: Log mamba params for Granite Hybrid Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e27263faf700b..8179d488ed0c9 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -5174,7 +5174,8 @@ void llama_model::print_info() const { if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA || - arch == LLM_ARCH_FALCON_H1) { + arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_GRANITE_HYBRID) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); From 5b44f4e792c2c6953e8e34b2807cb34a788c4f9a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 9 Jul 2025 21:07:18 -0600 Subject: [PATCH 110/117] fix: Remove unused ssm_in_b Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 1 - src/llama-model.h | 1 - 2 files changed, 2 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8179d488ed0c9..ac8da0684873a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3454,7 +3454,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (hparams.is_recurrent(i)) { // ssm layers layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); - layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, TENSOR_NOT_REQUIRED); layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); diff --git a/src/llama-model.h b/src/llama-model.h index 0cafb1b12dc5c..453f5af62fbc7 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -261,7 +261,6 @@ struct llama_layer { // mamba bias struct ggml_tensor * ssm_conv1d_b = nullptr; struct ggml_tensor * ssm_dt_b = nullptr; - struct ggml_tensor * ssm_in_b = nullptr; // rwkv struct ggml_tensor * time_mix_w1 = nullptr; From 4e9fef1a96dd1971f6cc8b0ba0c3b6f5de3e6549 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 9 Jul 2025 21:30:02 -0600 Subject: [PATCH 111/117] refactor: Remove ATTENTION_LAYER_INDICES hparam in favor of n_head_kv This matches how recurrent vs attention heads are identified for Jamba Branch: GraniteFour Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 7 +++++-- gguf-py/gguf/constants.py | 3 --- gguf-py/gguf/gguf_writer.py | 3 --- src/llama-arch.cpp | 1 - src/llama-arch.h | 1 - src/llama-model.cpp | 23 +++-------------------- 6 files changed, 8 insertions(+), 30 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index acd2e89e5ae1d..d4947a6a000fd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6570,11 +6570,14 @@ def set_gguf_parameters(self): self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"])) ## Attention params ## - self.gguf_writer.add_attn_layer_indices(self._attn_layers) + head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) + head_count_kv_vec = [ + head_count_kv if i in self._attn_layers else 0 for i in range(self.block_count) + ] if rope_dim := self.hparams.get("attn_rotary_emb"): self.gguf_writer.add_rope_dimension_count(rope_dim) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) - self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"])) + self.gguf_writer.add_head_count_kv(head_count_kv_vec) ## Feed Forward Params ## self.gguf_writer.add_layer_norm_rms_eps( diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a672e574ccbd7..37df17fa1cdd0 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -173,9 +173,6 @@ class SSM: GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" - class HybridAttention: - ATTN_LAYER_INDICES = "{arch}.attention.layer_indices" - class WKV: HEAD_SIZE = "{arch}.wkv.head_size" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 6aca08b5fc378..6695c00c44860 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -874,9 +874,6 @@ def add_ssm_group_count(self, value: int) -> None: def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) - def add_attn_layer_indices(self, values: list[int]) -> None: - self.add_array(Keys.HybridAttention.ATTN_LAYER_INDICES.format(arch=self.arch), values) - def add_tokenizer_model(self, model: str) -> None: self.add_string(Keys.Tokenizer.MODEL, model) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index e60c408601611..0d673a847cb86 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -155,7 +155,6 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, - { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index a82af4032bff2..a9dd188a8f27d 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -159,7 +159,6 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, - LLM_KV_ATTENTION_LAYER_INDICES, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ac8da0684873a..26de3c66ceb2b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1538,26 +1538,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); hparams.rope_finetuned = rope_finetuned; - // Zero-out n_head_arr and n_head_kv_arr since SSM layers don't - // have attention heads. We'll set them correctly below once we - // know which layers are attention layers - // NOTE: It's important that this happens after n_embd_head_[kv] - // are set above! - const auto n_head_attn = hparams.n_head(); - const auto n_head_kv_attn = hparams.n_head_kv(); - std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); - std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); - - // Attention params - std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); - std::vector attn_layer_indices; - ml.get_arr(LLM_KV_ATTENTION_LAYER_INDICES, attn_layer_indices); - for (const auto attn_idx : attn_layer_indices) { - GGML_ASSERT(attn_idx < hparams.n_layer); - hparams.recurrent_layer_arr[attn_idx] = false; - // Correctly set n_head and n_head_kv for attention layers - hparams.n_head_arr[attn_idx] = n_head_attn; - hparams.n_head_kv_arr[attn_idx] = n_head_kv_attn; + // A layer is recurrent IFF the n_head_kv value is set to 0 + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); From d02d3ddb542e294ba20ffedc98a797525d03a0d2 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 10 Jul 2025 07:26:19 -0600 Subject: [PATCH 112/117] fix: Remove unused template expansion for get_arr Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model-loader.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 0bd1e5d006950..bd9e6da8832b7 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -464,7 +464,6 @@ namespace GGUFMeta { // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); - template bool llama_model_loader::get_arr(enum llm_kv kid, std::vector & result, bool required); llama_model_loader::llama_model_loader( const std::string & fname, From f43a8dc5b78707ec40112466e89f6bf4b81a96fe Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 10 Jul 2025 14:25:26 -0600 Subject: [PATCH 113/117] fix: Review cleanup in convert_hf_to_gguf The gist is to be explicit about which base class is being used with the multiple inheritance setup Branch: GraniteFour Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d4947a6a000fd..2df43ba112b67 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6488,22 +6488,19 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Use Granite conversion for attention - self._transformer_model_class: type[TextModel] = GraniteModel - # Lists of which layers use ssm vs attention - self._attn_layers = self.get_attn_layres() + self._attn_layers = self.get_attn_layers() self._ssm_layers = [ i for i in range(self.block_count) if i not in self._attn_layers ] - # n_group and d_inner are used during reshape_tensors for mamaba2 + # n_group and d_inner are used during reshape_tensors for mamba2 self.d_model = self.find_hparam(["hidden_size", "d_model"]) self.n_group = self.find_hparam(["n_groups"]) self.d_inner = self.find_hparam(["expand"]) * self.d_model - def get_attn_layres(self): + def get_attn_layers(self): # Explicit list of layer type names if layer_types := self.hparams.get("layer_types"): return [ @@ -6532,7 +6529,7 @@ def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any: for k in keys ) keys = list(keys) + prefixed - return super().find_hparam(keys, *args, **kwargs) + return Mamba2Model.find_hparam(self, keys, *args, **kwargs) def modify_tensors( self, data_torch: Tensor, name: str, bid: int | None @@ -6543,11 +6540,11 @@ def modify_tensors( ): return GraniteMoeModel.modify_tensors(self, data_torch, name, bid) - # Determine whether this is a mamaba layer or an attention layer + # Determine whether this is a mamba layer or an attention layer if bid in self._ssm_layers: - return super().modify_tensors(data_torch, name, bid) + return Mamba2Model.modify_tensors(self, data_torch, name, bid) elif bid in self._attn_layers: - return self._transformer_model_class.modify_tensors(self, data_torch, name, bid) + return GraniteMoeModel.modify_tensors(self, data_torch, name, bid) return [(self.map_tensor_name(name), data_torch)] def set_gguf_parameters(self): @@ -6595,7 +6592,7 @@ def set_gguf_parameters(self): def set_vocab(self): self.hparams["pad_vocab_size_multiple"] = 8 - super().set_vocab() + Mamba2Model.set_vocab(self) @ModelBase.register("BailingMoeForCausalLM") @@ -6821,7 +6818,7 @@ def __init__(self, *args, **kwargs): # Use Llama conversion for attention self._transformer_model_class = LlamaModel - # n_group and d_inner are used during reshape_tensors for mamaba2 + # n_group and d_inner are used during reshape_tensors for mamba2 self.n_group = self.find_hparam(["n_groups"]) self.d_inner = self.find_hparam(["mamba_d_ssm"]) self.d_head = self.find_hparam(["d_head"]) From 63f1ed8399709b111c243b202069de2f65409dc6 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 10 Jul 2025 15:01:41 -0600 Subject: [PATCH 114/117] fix: Undo hidden warnings about duplicate identical keys in add_key_value After further discussion, this encourages sloppy overwriting in the model converters Branch: GraniteFour Signed-off-by: Gabe Goodhart --- gguf-py/gguf/gguf_writer.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 6695c00c44860..a7ecf3d31209f 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -270,14 +270,7 @@ def write_ti_data_to_file(self) -> None: self.state = WriterState.TI_DATA def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None: - # Warn about duplicate keys if they differ by value or type - if any( - ( - key in kv_data - and (kv_data[key].value != val or kv_data[key].type != vtype) - ) - for kv_data in self.kv_data - ): + if any(key in kv_data for kv_data in self.kv_data): logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}') self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type) From f1485d2ab7c3dfc329272eaac53e4085f75c7f51 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 10 Jul 2025 15:44:07 -0600 Subject: [PATCH 115/117] fix: If not using ROPE, context is "infinite" Branch: GraniteFour Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2df43ba112b67..569d925697677 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6553,7 +6553,6 @@ def set_gguf_parameters(self): ## General Params ## self.gguf_writer.add_embedding_length(self.d_model) self.gguf_writer.add_block_count(self.block_count) - self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0)) self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) @@ -6584,6 +6583,8 @@ def set_gguf_parameters(self): ## If Bamba, use rope, otherwise don't use_rope = "BambaForCausalLM" in self.hparams["architectures"] self.gguf_writer.add_rope_scaling_finetuned(use_rope) + if not use_rope: + self.gguf_writer.add_context_length(2**20) ## Validation ## d_head = self.find_hparam(["d_head"], optional=True) or 64 From 04883fc7d236e5b6153f9519fe61ed683fe340be Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 10 Jul 2025 15:47:53 -0600 Subject: [PATCH 116/117] doc: Add a comment outlining expected duplicate key warnings Branch: GraniteFour Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 569d925697677..5cf4658d6270f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6548,6 +6548,19 @@ def modify_tensors( return [(self.map_tensor_name(name), data_torch)] def set_gguf_parameters(self): + """This method merges params from both parents and some that are + specific to this model. The result is some duplication of how the params + get set. The following warnings are expected during conversion: + + WARNING:Duplicated key name 'granitehybrid.embedding_length' + WARNING:Duplicated key name 'granitehybrid.block_count' + WARNING:Duplicated key name 'granitehybrid.vocab_size' + WARNING:Duplicated key name 'granitehybrid.feed_forward_length' + WARNING:Duplicated key name 'granitehybrid.attention.head_count' + WARNING:Duplicated key name 'granitehybrid.attention.head_count_kv' + WARNING:Duplicated key name 'granitehybrid.attention.layer_norm_rms_epsilon' + WARNING:Duplicated key name 'granitehybrid.context_length' + """ GraniteMoeModel.set_gguf_parameters(self) ## General Params ## From e53632b664dba7b39d96bc2768e84e85cea22ac5 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 10 Jul 2025 16:08:18 -0600 Subject: [PATCH 117/117] fix: Remove unnecessary duplicate keys in converter Co-authored-by: Francis Couture-Harpin (thanks for the sharp eyes and patience!) Branch: GraniteFour Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 5cf4658d6270f..52aa87d6a9952 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6552,23 +6552,11 @@ def set_gguf_parameters(self): specific to this model. The result is some duplication of how the params get set. The following warnings are expected during conversion: - WARNING:Duplicated key name 'granitehybrid.embedding_length' - WARNING:Duplicated key name 'granitehybrid.block_count' - WARNING:Duplicated key name 'granitehybrid.vocab_size' - WARNING:Duplicated key name 'granitehybrid.feed_forward_length' - WARNING:Duplicated key name 'granitehybrid.attention.head_count' WARNING:Duplicated key name 'granitehybrid.attention.head_count_kv' - WARNING:Duplicated key name 'granitehybrid.attention.layer_norm_rms_epsilon' WARNING:Duplicated key name 'granitehybrid.context_length' """ GraniteMoeModel.set_gguf_parameters(self) - ## General Params ## - self.gguf_writer.add_embedding_length(self.d_model) - self.gguf_writer.add_block_count(self.block_count) - self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) - self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) - ## Mamba mixer params ## self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"])) self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"])) @@ -6585,14 +6573,8 @@ def set_gguf_parameters(self): ] if rope_dim := self.hparams.get("attn_rotary_emb"): self.gguf_writer.add_rope_dimension_count(rope_dim) - self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(head_count_kv_vec) - ## Feed Forward Params ## - self.gguf_writer.add_layer_norm_rms_eps( - self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 - ) - ## If Bamba, use rope, otherwise don't use_rope = "BambaForCausalLM" in self.hparams["architectures"] self.gguf_writer.add_rope_scaling_finetuned(use_rope)