diff --git a/common/arg.cpp b/common/arg.cpp index 2e0f46db519..9982ec5a3b6 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1280,6 +1280,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.swa_full = true; } ).set_env("LLAMA_ARG_SWA_FULL")); + add_opt(common_arg( + {"-mtp", "--multi-token-prediction"}, + "enable multi-token prediction (load MTP weights and auto-enable MTP speculative)", + [](common_params & params) { + params.mtp = true; + if (params.speculative.type == COMMON_SPECULATIVE_TYPE_NONE) { + params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; + } + } + ).set_env("LLAMA_ARG_MTP")); add_opt(common_arg( {"-ctxcp", "--ctx-checkpoints", "--swa-checkpoints"}, "N", string_format("max number of context checkpoints to create per slot (default: %d)" @@ -3497,7 +3507,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); add_opt(common_arg( - {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]", + {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod|mtp]", string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n", common_speculative_type_to_str(params.speculative.type).c_str()), [](common_params & params, const std::string & value) { @@ -3513,6 +3523,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V; } else if (value == "ngram-mod") { params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD; + } else if (value == "mtp") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP; + params.mtp = true; } else { throw std::invalid_argument("unknown speculative decoding type without draft model"); } diff --git a/common/common.cpp b/common/common.cpp index 16f78debd02..341a055ef33 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1443,12 +1443,14 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.progress_callback = params.load_progress_callback; mparams.progress_callback_user_data = params.load_progress_callback_user_data; mparams.no_alloc = params.no_alloc; + mparams.mtp = params.mtp; return mparams; } struct llama_context_params common_context_params_to_llama(const common_params & params) { auto cparams = llama_context_default_params(); + const bool mtp_needs_hidden_states = params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP; cparams.n_ctx = params.n_ctx; cparams.n_seq_max = params.n_parallel; @@ -1457,7 +1459,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.n_threads = params.cpuparams.n_threads; cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? 
params.cpuparams.n_threads : params.cpuparams_batch.n_threads; - cparams.embeddings = params.embedding; + cparams.embeddings = params.embedding || mtp_needs_hidden_states; cparams.rope_scaling_type = params.rope_scaling_type; cparams.rope_freq_base = params.rope_freq_base; cparams.rope_freq_scale = params.rope_freq_scale; diff --git a/common/common.h b/common/common.h index 020b6a721ff..db83c2438f4 100644 --- a/common/common.h +++ b/common/common.h @@ -177,6 +177,7 @@ enum common_speculative_type { COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values COMMON_SPECULATIVE_TYPE_NGRAM_MOD, COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache + COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction (uses same model, dedicated draft context) COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type }; @@ -336,7 +337,7 @@ struct common_params_speculative { llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts - llama_context_params cparams_dft; // these are the parameters for the draft llama_context + llama_context_params cparams_dft = llama_context_default_params(); // these are the parameters for the draft llama_context int32_t n_ctx = 0; // draft context size int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) @@ -355,6 +356,10 @@ struct common_params_speculative { bool has_dft() const { return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty(); } + + bool requires_dft() const { + return type == COMMON_SPECULATIVE_TYPE_DRAFT || type == COMMON_SPECULATIVE_TYPE_EAGLE3; + } }; struct common_params_vocoder { @@ -543,6 +548,7 @@ struct common_params { bool no_op_offload = false; // globally disable offload host tensor operations to device bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking) bool no_host = false; // bypass host buffer allowing extra buffers to be used + bool mtp = false; // enable multi-token prediction bool single_turn = false; // single turn chat conversation diff --git a/common/speculative.cpp b/common/speculative.cpp index 3e68c38e49c..bafaa87d741 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -1,5 +1,6 @@ #include "speculative.h" +#include "../src/llama-kv-cache-iswa.h" #include "common.h" #include "ggml.h" #include "llama.h" @@ -25,7 +26,8 @@ const std::vector common_speculative_types = { COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, COMMON_SPECULATIVE_TYPE_NGRAM_MOD, - COMMON_SPECULATIVE_TYPE_NGRAM_CACHE + COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, + COMMON_SPECULATIVE_TYPE_MTP }; const std::map common_speculative_type_from_name_map = { @@ -36,7 +38,8 @@ const std::map common_speculative_typ {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, {"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD}, - {"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE} + {"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}, + {"mtp", COMMON_SPECULATIVE_TYPE_MTP} }; struct common_speculative_config { @@ -130,10 +133,13 @@ struct common_speculative_state { int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds. 
common_speculative_state(enum common_speculative_type type) : type(type) {} - virtual ~common_speculative_state() = default; virtual void begin(const llama_tokens & prompt) = 0; + virtual void begin(const llama_tokens & prompt, llama_pos retained_prefix_len) { + GGML_UNUSED(retained_prefix_len); + begin(prompt); + } virtual void draft( const common_params_speculative & params, @@ -141,7 +147,28 @@ struct common_speculative_state { llama_token id_last, llama_tokens & result) = 0; - virtual void accept(uint16_t n_accepted) = 0; + virtual void accept(uint16_t n_accepted, const std::vector & batch_idxs) = 0; + + virtual llama_pos get_committed_prefix_len() const { + return 0; + } + + virtual void invalidate_retained_state() { + } + + virtual void set_first_pass_source( + const llama_tokens & source_tokens, + const float * hidden_states, + int32_t n_tokens, + int32_t n_embd, + llama_pos start_pos) { + GGML_UNUSED(source_tokens); + GGML_UNUSED(hidden_states); + GGML_UNUSED(n_tokens); + GGML_UNUSED(n_embd); + GGML_UNUSED(start_pos); + } + }; struct common_speculative_state_draft : public common_speculative_state { @@ -403,8 +430,9 @@ struct common_speculative_state_draft : public common_speculative_state { } } - void accept(uint16_t n_accepted) override { + void accept(uint16_t n_accepted, const std::vector & batch_idxs) override { // noop + GGML_UNUSED(batch_idxs); GGML_UNUSED(n_accepted); } @@ -456,8 +484,9 @@ struct common_speculative_state_eagle3 : public common_speculative_state { GGML_UNUSED(draft_tokens); } - void accept(uint16_t n_accepted) override { + void accept(uint16_t n_accepted, const std::vector & batch_idxs) override { // noop + GGML_UNUSED(batch_idxs); GGML_UNUSED(n_accepted); } }; @@ -485,8 +514,9 @@ struct common_speculative_state_ngram_simple : public common_speculative_state { GGML_UNUSED(params); } - void accept(uint16_t n_accepted) override { + void accept(uint16_t n_accepted, const std::vector & batch_idxs) override { // noop + GGML_UNUSED(batch_idxs); GGML_UNUSED(n_accepted); } }; @@ -513,7 +543,8 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state { GGML_UNUSED(params); } - void accept(uint16_t n_accepted) override { + void accept(uint16_t n_accepted, const std::vector & batch_idxs) override { + GGML_UNUSED(batch_idxs); common_ngram_map_accept(map, n_accepted); } }; @@ -621,7 +652,8 @@ struct common_speculative_state_ngram_mod : public common_speculative_state { n_draft_last = result.size(); } - void accept(uint16_t n_accepted) override { + void accept(uint16_t n_accepted, const std::vector & batch_idxs) override { + GGML_UNUSED(batch_idxs); if (verbose) { LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last); } @@ -731,12 +763,547 @@ struct common_speculative_state_ngram_cache : public common_speculative_state { } } - void accept(uint16_t n_accepted) override { + void accept(uint16_t n_accepted, const std::vector & batch_idxs) override { // TODO: noop + GGML_UNUSED(batch_idxs); GGML_UNUSED(n_accepted); } }; +static bool copy_hidden_state( + llama_context * ctx, + int32_t idx, + std::vector & dst, + int32_t dst_row, + int32_t n_embd) { + float * hidden = llama_get_embeddings_ith(ctx, idx); + if (hidden == nullptr) { + return false; + } + + std::memcpy(dst.data() + (int64_t) dst_row*n_embd, hidden, n_embd*sizeof(float)); + return true; +} + +struct common_speculative_state_mtp : public common_speculative_state { + enum mtp_draft_step_status : int { + MTP_DRAFT_STEP_GUARD_STOP = -2, + 
MTP_DRAFT_STEP_DECODE_FAIL = -1, + MTP_DRAFT_STEP_STOP = 0, + MTP_DRAFT_STEP_CONTINUE = 1, + }; + + struct mtp_round_state { + llama_token frontier_token = 0; + llama_pos prompt_size = 0; + llama_tokens draft_tokens; + std::vector recurrence_hidden; + }; + + llama_context * ctx_tgt; + llama_context * ctx_dft; + + common_sampler * smpl = nullptr; + llama_batch batch; + + const int32_t mtp_layer_idx; + const int32_t n_embd; + + // retained state + llama_pos verified_pos_end = 0; + llama_pos committed_prefix_len = 0; + llama_kv_cache_iswa * draft_kv_iswa = nullptr; + bool retained_state_valid = true; + + // staged sources + llama_tokens initial_source_tokens; + std::vector initial_source_hidden_states; + llama_pos initial_source_start_pos = 0; + + llama_tokens pending_target_tokens; + std::vector pending_hidden_states; + llama_pos pending_start_pos = 0; + + mtp_round_state round; + + size_t n_guard_stop_rounds = 0; // rounds that conservatively stopped because the guard found no safe reuse position + size_t n_guard_zero_draft_rounds = 0; // subset of guard stops that degraded to 0-draft immediately + size_t n_guard_partial_rounds = 0; // subset of guard stops that drafted some tokens but stopped before reaching n_max + size_t n_guard_lost_tokens = 0; // total draft tokens lost to guard degradation, computed as n_max - drafted + size_t n_draft_decode_fail_rounds = 0; // regular draft decode failures unrelated to the guard + + common_speculative_state_mtp( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft(ctx_dft) + , batch(llama_batch_init(llama_n_batch(ctx_dft), 0, 1)) + , mtp_layer_idx(llama_model_n_layer(llama_get_model(ctx_dft)) - llama_model_n_nextn_predict_layers(llama_get_model(ctx_dft))) + , n_embd(llama_model_n_embd(llama_get_model(ctx_dft))) { + common_params_sampling sparams; + sparams.no_perf = false; + sparams.top_k = 10; + sparams.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + }; + + auto * mem = llama_get_memory(ctx_dft); + draft_kv_iswa = dynamic_cast(mem); + + smpl = common_sampler_init(llama_get_model(ctx_dft), sparams); + round.recurrence_hidden.resize(n_embd); + + llama_set_embeddings(ctx_dft, true); + } + + ~common_speculative_state_mtp() override { + llama_perf_context_print(ctx_dft); + llama_free(ctx_dft); + common_sampler_free(smpl); + llama_batch_free(batch); + } + + void begin(const llama_tokens & prompt) override { + begin(prompt, 0); + } + + llama_pos apply_retained_prefix(llama_pos retained_prefix_len) { + retained_prefix_len = std::max(retained_prefix_len, 0); + + llama_pos retained_prefix_applied = 0; + if (retained_state_valid) { + retained_prefix_applied = std::min(retained_prefix_len, committed_prefix_len); + if (retained_prefix_applied != retained_prefix_len) { + LOG_WRN("%s: clamping retained prefix from %d to committed prefix %d\n", + __func__, (int) retained_prefix_len, (int) committed_prefix_len); + } + } + + verified_pos_end = retained_prefix_applied; + committed_prefix_len = retained_prefix_applied; + + if (auto * mem = llama_get_memory(ctx_dft)) { + if (retained_prefix_applied == 0) { + llama_memory_clear(mem, true); + } else if (!llama_memory_seq_rm(mem, 0, retained_prefix_applied, -1)) { + LOG_WRN("%s: failed to truncate retained draft state at %d - clearing memory instead\n", + __func__, (int) retained_prefix_applied); + llama_memory_clear(mem, true); + retained_prefix_applied = 0; + verified_pos_end = 0; + committed_prefix_len = 0; + } + } + + 
return retained_prefix_applied; + } + + void begin(const llama_tokens & prompt, llama_pos retained_prefix_len) override { + GGML_UNUSED(prompt); + + // 1. Clear staged first-pass sources and round-local state. + initial_source_tokens.clear(); + initial_source_hidden_states.clear(); + pending_start_pos = 0; + initial_source_start_pos = 0; + + pending_target_tokens.clear(); + pending_hidden_states.clear(); + + round.frontier_token = 0; + round.prompt_size = 0; + round.draft_tokens.clear(); + + // 2. Re-apply the retained prefix boundary. + apply_retained_prefix(retained_prefix_len); + + // 3. Reset draft-side sampler and inputs for the new round. + retained_state_valid = true; + common_sampler_reset(smpl); + llama_set_mtp_op_type(ctx_dft, LLAMA_MTP_OP_NONE); + llama_set_draft_input_hidden_state(ctx_dft, nullptr); + llama_synchronize(ctx_dft); + } + + llama_pos get_committed_prefix_len() const override { + return retained_state_valid ? committed_prefix_len : 0; + } + + void invalidate_retained_state() override { + retained_state_valid = false; + committed_prefix_len = 0; + verified_pos_end = 0; + pending_start_pos = 0; + initial_source_start_pos = 0; + initial_source_tokens.clear(); + initial_source_hidden_states.clear(); + pending_target_tokens.clear(); + pending_hidden_states.clear(); + round.draft_tokens.clear(); + + if (auto * mem = llama_get_memory(ctx_dft)) { + llama_memory_clear(mem, true); + } + + common_sampler_reset(smpl); + llama_set_mtp_op_type(ctx_dft, LLAMA_MTP_OP_NONE); + llama_set_draft_input_hidden_state(ctx_dft, nullptr); + llama_synchronize(ctx_dft); + } + + void set_first_pass_source( + const llama_tokens & source_tokens, + const float * hidden_states, + int32_t n_tokens, + int32_t n_embd_in, + llama_pos start_pos) override { + initial_source_tokens.clear(); + initial_source_hidden_states.clear(); + initial_source_start_pos = 0; + + if (n_tokens <= 0 || source_tokens.empty()) { + return; + } + + if ((int32_t) source_tokens.size() != n_tokens) { + LOG_WRN("%s: ignoring first-pass source with mismatched token count (%zu != %d)\n", + __func__, source_tokens.size(), n_tokens); + return; + } + + if (hidden_states == nullptr) { + LOG_WRN("%s: ignoring first-pass source without hidden states\n", __func__); + return; + } + + if (n_embd_in != n_embd) { + LOG_WRN("%s: ignoring first-pass source with mismatched n_embd (%d != %d)\n", + __func__, n_embd_in, n_embd); + return; + } + + initial_source_tokens = source_tokens; + initial_source_hidden_states.assign(hidden_states, hidden_states + (int64_t) n_tokens*n_embd); + initial_source_start_pos = start_pos; + } + + bool select_first_pass_source( + const llama_tokens *& source_tokens, + const std::vector *& source_hidden_states, + llama_pos & source_start_pos) const { + if (!pending_target_tokens.empty()) { + source_tokens = &pending_target_tokens; + source_hidden_states = &pending_hidden_states; + source_start_pos = pending_start_pos; + return true; + } + + if (!initial_source_tokens.empty()) { + source_tokens = &initial_source_tokens; + source_hidden_states = &initial_source_hidden_states; + source_start_pos = initial_source_start_pos; + return true; + } + + return false; + } + + bool supports_swa_guard() const { + return draft_kv_iswa != nullptr; + } + + void set_swa_guard(llama_pos query_pos) { + GGML_ASSERT(draft_kv_iswa != nullptr); + draft_kv_iswa->set_swa_reuse_guard(query_pos); + } + + void clear_swa_guard() { + if (draft_kv_iswa != nullptr) { + draft_kv_iswa->clear_swa_reuse_guard(); + } + } + + int classify_decode_failure() { 
+ const bool blocked = draft_kv_iswa != nullptr + ? draft_kv_iswa->consume_swa_reuse_guard_block_prepare() + : false; + + return blocked + ? MTP_DRAFT_STEP_GUARD_STOP + : MTP_DRAFT_STEP_DECODE_FAIL; + } + + void record_guard_stop(const common_params_speculative & params, size_t drafted_tokens) { + ++n_guard_stop_rounds; + n_guard_lost_tokens += params.n_max > (int32_t) drafted_tokens ? params.n_max - drafted_tokens : 0; + + const bool zero_draft = drafted_tokens == 0; + if (zero_draft) { + ++n_guard_zero_draft_rounds; + } else { + ++n_guard_partial_rounds; + } + + LOG_DBG("%s: MTP draft stopped by SWA guard: round.prompt_size = %d, drafted = %zu/%d, zero_draft = %d\n", + __func__, + (int) round.prompt_size, + drafted_tokens, + params.n_max, + zero_draft); + } + + void record_decode_fail() { + ++n_draft_decode_fail_rounds; + } + + void record_step_failure(int step_status, const common_params_speculative & params, size_t drafted_tokens) { + GGML_ASSERT(step_status < 0); + + if (step_status == MTP_DRAFT_STEP_GUARD_STOP) { + record_guard_stop(params, drafted_tokens); + } else { + record_decode_fail(); + } + } + + int finalize_step_from_logits( + int32_t output_idx, + const common_params_speculative & params, + llama_tokens & result) { + common_sampler_sample(smpl, ctx_dft, output_idx, true); + const auto * cur_p = common_sampler_get_candidates(smpl, true); + if (!cur_p || cur_p->size == 0) { + return MTP_DRAFT_STEP_DECODE_FAIL; + } + + const llama_token id = cur_p->data[0].id; + const float p = cur_p->data[0].p; + + common_sampler_accept(smpl, id, true); + result.push_back(id); + + float * next_hidden = llama_get_embeddings_ith(ctx_dft, output_idx); + if (next_hidden == nullptr) { + return MTP_DRAFT_STEP_DECODE_FAIL; + } + + std::memcpy(round.recurrence_hidden.data(), next_hidden, n_embd*sizeof(float)); + return p >= params.p_min ? MTP_DRAFT_STEP_CONTINUE : MTP_DRAFT_STEP_STOP; + } + + int run_first_pass( + const llama_tokens & source_tokens, + const std::vector & source_hidden_states, + llama_pos start_pos, + llama_token frontier_token, + const common_params_speculative & params, + llama_tokens & result) { + const int32_t n_tokens = (int32_t) source_tokens.size(); + if (n_tokens <= 0) { + return -1; + } + + if ((int64_t) source_hidden_states.size() != (int64_t) n_tokens*n_embd) { + LOG_WRN("%s: first-pass hidden-state size mismatch (%zu vs %d)\n", + __func__, source_hidden_states.size(), n_tokens*n_embd); + return -1; + } + + common_batch_clear(batch); + for (int32_t i = 0; i < n_tokens; ++i) { + const llama_token token = i + 1 < n_tokens ? 
source_tokens[i + 1] : frontier_token; + common_batch_add(batch, token, start_pos + i, { 0 }, true); + } + + llama_set_mtp_op_type(ctx_dft, LLAMA_MTP_OP_DRAFT_GEN); + llama_set_mtp_layer_idx(ctx_dft, mtp_layer_idx); + llama_set_draft_input_hidden_state(ctx_dft, source_hidden_states.data()); + const auto clear_draft_input = [&]() { + llama_set_mtp_op_type(ctx_dft, LLAMA_MTP_OP_NONE); + llama_set_draft_input_hidden_state(ctx_dft, nullptr); + }; + + if (llama_decode(ctx_dft, batch) != 0) { + const int status = classify_decode_failure(); + clear_draft_input(); + return status; + } + + const int step_status = finalize_step_from_logits(n_tokens - 1, params, result); + if (step_status >= 0) { + verified_pos_end = start_pos + n_tokens; + committed_prefix_len = verified_pos_end; + retained_state_valid = true; + } + + clear_draft_input(); + return step_status; + } + + int run_single_token_step( + llama_token frontier_token, + llama_pos pos, + const common_params_speculative & params, + llama_tokens & result) { + common_batch_clear(batch); + common_batch_add(batch, frontier_token, pos, { 0 }, true); + + llama_set_mtp_op_type(ctx_dft, LLAMA_MTP_OP_DRAFT_GEN); + llama_set_mtp_layer_idx(ctx_dft, mtp_layer_idx); + llama_set_draft_input_hidden_state(ctx_dft, round.recurrence_hidden.data()); + const auto clear_draft_input = [&]() { + llama_set_mtp_op_type(ctx_dft, LLAMA_MTP_OP_NONE); + llama_set_draft_input_hidden_state(ctx_dft, nullptr); + }; + + if (llama_decode(ctx_dft, batch) != 0) { + const int status = classify_decode_failure(); + clear_draft_input(); + return status; + } + + const int step_status = finalize_step_from_logits(0, params, result); + clear_draft_input(); + return step_status; + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + // 1. round setup + result.clear(); + result.reserve(params.n_max); + + round.frontier_token = id_last; + round.prompt_size = (llama_pos) prompt_tgt.size(); + round.draft_tokens.clear(); + + // 2. draft memory / guard + auto * mem = llama_get_memory(ctx_dft); + if (mem == nullptr || !supports_swa_guard()) { + LOG_WRN("%s: MTP draft requires llama_kv_cache or llama_kv_cache_iswa memory\n", __func__); + record_decode_fail(); + pending_target_tokens.clear(); + pending_hidden_states.clear(); + return; + } + set_swa_guard(round.prompt_size); + + do { + llama_memory_seq_rm(mem, 0, verified_pos_end, -1); + + // 3. source selection + first pass + const llama_tokens * source_tokens = nullptr; + const std::vector * source_hidden_states = nullptr; + llama_pos source_start_pos = 0; + + if (!select_first_pass_source(source_tokens, source_hidden_states, source_start_pos)) { + break; + } + + const bool consumed_pending_source = source_tokens == &pending_target_tokens; + common_sampler_reset(smpl); + + const int first_pass_status = run_first_pass(*source_tokens, *source_hidden_states, source_start_pos, id_last, params, result); + + // initial_source_* is a one-shot first-pass source, regardless of success or failure. + initial_source_tokens.clear(); + initial_source_hidden_states.clear(); + initial_source_start_pos = 0; + + if (first_pass_status < 0) { + record_step_failure(first_pass_status, params, result.size()); + + // A failed first pass must conservatively drop any staged pending source. 
+ pending_target_tokens.clear(); + pending_hidden_states.clear(); + break; + } + + if (consumed_pending_source) { + pending_target_tokens.clear(); + pending_hidden_states.clear(); + } + + if (!result.empty()) { + round.draft_tokens.push_back(result.back()); + } + + if (first_pass_status != MTP_DRAFT_STEP_CONTINUE) { + break; + } + + // 4. recurrence + llama_pos next_pos = verified_pos_end; + while ((int) result.size() < params.n_max) { + const size_t result_size_prev = result.size(); + const int step_status = run_single_token_step(result.back(), next_pos, params, result); + if (result.size() == result_size_prev) { + if (step_status < 0) { + record_step_failure(step_status, params, result.size()); + } + break; + } + + round.draft_tokens.push_back(result.back()); + next_pos++; + if (step_status == MTP_DRAFT_STEP_STOP) { + break; + } + if (step_status < 0) { + record_step_failure(step_status, params, result.size()); + break; + } + } + } while (false); + + // 5. cleanup + clear_swa_guard(); + } + + void accept(uint16_t n_accepted, const std::vector & batch_idxs) override { + // 1. clamp n_accepted to the drafted prefix that actually exists. + const int32_t n_accepted_clamped = std::min(n_accepted, (int32_t) round.draft_tokens.size()); + + // 2. build the pending source tokens for the next draft round. + pending_target_tokens.clear(); + pending_hidden_states.clear(); + pending_start_pos = 0; + + pending_target_tokens.reserve(n_accepted_clamped + 1); + pending_target_tokens.push_back(round.frontier_token); + for (int32_t i = 0; i < n_accepted_clamped; ++i) { + pending_target_tokens.push_back(round.draft_tokens[i]); + } + + pending_hidden_states.resize((int64_t) pending_target_tokens.size()*n_embd); + pending_start_pos = round.prompt_size; + + // 3. copy verifier hidden states into pending_hidden_states. + for (int32_t i = 0; i < (int32_t) pending_target_tokens.size(); ++i) { + if (!batch_idxs.empty() && (size_t) i >= batch_idxs.size()) { + LOG_WRN("%s: batch_idxs missing verifier index for pending token %d\n", __func__, i); + pending_target_tokens.clear(); + pending_hidden_states.clear(); + pending_start_pos = 0; + break; + } + const int32_t hidden_idx = batch_idxs.empty() ? 
i : batch_idxs[i]; + if (!copy_hidden_state(ctx_tgt, hidden_idx, pending_hidden_states, i, n_embd)) { + LOG_WRN("%s: failed to copy target hidden state %d for pending first pass\n", __func__, hidden_idx); + pending_target_tokens.clear(); + pending_hidden_states.clear(); + pending_start_pos = 0; + break; + } + } + } +}; + struct common_speculative { std::vector> impls; // list of implementations to use and their states common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) @@ -786,6 +1353,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) { case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod"; case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; + case COMMON_SPECULATIVE_TYPE_MTP: return "mtp"; default: return "unknown"; } } @@ -851,7 +1419,7 @@ common_speculative * common_speculative_init( // Compute the implementations to use based on the config and their order of preference std::vector configs = {}; // list of speculative configs to try { - bool has_draft = !params.mparams_dft.path.empty(); + bool has_draft = params.has_dft(); bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); @@ -859,6 +1427,7 @@ common_speculative * common_speculative_init( bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V); bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD); + bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP); // In a more complex implementation we could use the same implementation but with different parameters. // This was initially used in PR-18471 but removed to simplify the code. 
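
For orientation, a minimal caller-side sketch (not part of the patch) of how the MTP path above is driven for the first draft after the prompt has been evaluated; it mirrors the speculative-simple changes further down. Assumptions: a single sequence, a target context created with embeddings enabled so llama_get_embeddings_ith() returns per-token hidden states, and the usual common.h/speculative.h/llama.h includes; everything else uses only APIs present in this patch or in llama.h.

    // first MTP draft after the prompt was decoded on the target context
    static llama_tokens mtp_first_draft(common_speculative * spec,
                                        const common_params_speculative & params_spec,
                                        llama_context * ctx_tgt,
                                        const llama_tokens & prompt_tgt,
                                        llama_token id_last) {
        const int n_embd = llama_model_n_embd(llama_get_model(ctx_tgt));

        common_speculative_begin(spec, prompt_tgt);

        // stage the prompt tokens plus the target model's hidden states as the
        // one-shot first-pass source consumed by common_speculative_state_mtp
        std::vector<float> hidden((int64_t) prompt_tgt.size()*n_embd);
        for (size_t i = 0; i < prompt_tgt.size(); ++i) {
            const float * h = llama_get_embeddings_ith(ctx_tgt, (int32_t) i);
            GGML_ASSERT(h != nullptr);
            std::memcpy(hidden.data() + (int64_t) i*n_embd, h, n_embd*sizeof(float));
        }
        common_speculative_set_first_pass_source(spec, prompt_tgt, hidden.data(),
                (int32_t) prompt_tgt.size(), n_embd, /*start_pos=*/0);

        // the caller verifies the returned draft with the target model and then calls
        // common_speculative_accept(spec, n_accepted, batch_idxs); an empty batch_idxs
        // means "verifier output row i corresponds to pending token i". Subsequent
        // rounds only need draft() + accept(): accept() stages the next first-pass
        // source from the verifier's hidden states itself.
        return common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
    }
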
@@ -892,6 +1461,9 @@ common_speculative * common_speculative_init( if (has_ngram_cache) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); } + if (has_mtp) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params)); + } if (has_draft) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); } @@ -955,6 +1527,63 @@ common_speculative * common_speculative_init( impls.push_back(std::make_unique(state)); break; } + case COMMON_SPECULATIVE_TYPE_MTP: { + const llama_model * model_tgt = llama_get_model(ctx_tgt); + + if (llama_model_is_recurrent(model_tgt)) { + LOG_WRN("%s: MTP speculative does not support recurrent memory models yet\n", __func__); + break; + } + + if (llama_model_is_hybrid(model_tgt)) { + LOG_WRN("%s: MTP speculative does not support hybrid memory models yet\n", __func__); + break; + } + + if (llama_model_n_nextn_predict_layers(model_tgt) <= 0) { + LOG_WRN("%s: target model has no nextn_predict_layers\n", __func__); + break; + } + + llama_context_params cparams = config.params.cparams_dft; + + if (cparams.n_ctx == 0) { + cparams.n_ctx = llama_n_ctx_seq(ctx_tgt); + } + if (cparams.n_batch == 0) { + cparams.n_batch = llama_n_ctx_seq(ctx_tgt); + } + if (cparams.n_ubatch == 0) { + cparams.n_ubatch = llama_n_ubatch(ctx_tgt); + } + if (cparams.n_threads <= 0) { + cparams.n_threads = llama_n_threads(ctx_tgt); + } + if (cparams.n_threads_batch <= 0) { + cparams.n_threads_batch = llama_n_threads_batch(ctx_tgt); + } + + llama_set_embeddings(ctx_tgt, true); + cparams.embeddings = true; + + llama_context * ctx_mtp = llama_init_from_model(const_cast(llama_get_model(ctx_tgt)), cparams); + if (ctx_mtp == nullptr) { + LOG_WRN("%s", "failed to initialize dedicated MTP draft context\n"); + break; + } + + auto * mem_mtp = llama_get_memory(ctx_mtp); + if (dynamic_cast(mem_mtp) == nullptr) { + LOG_WRN("%s: MTP draft context requires llama_kv_cache_iswa memory for current SWA/iSWA rollback guard\n", + __func__); + llama_free(ctx_mtp); + break; + } + + llama_set_embeddings(ctx_mtp, true); + impls.push_back(std::make_unique(config.type, ctx_tgt, ctx_mtp)); + break; + } default: break; } @@ -981,17 +1610,65 @@ void common_speculative_free(common_speculative * spec) { } void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { + common_speculative_begin(spec, prompt, 0); +} + +void common_speculative_begin( + common_speculative * spec, + const llama_tokens & prompt, + llama_pos retained_prefix_len) { if (spec == nullptr) { return; } for (auto & impl : spec->impls) { common_time_meas tm(impl->t_begin_us, !impl->gen_perf); - impl->begin(prompt); + impl->begin(prompt, retained_prefix_len); impl->n_call_begin++; } } +llama_pos common_speculative_get_committed_prefix_len( + const common_speculative * spec) { + if (spec == nullptr) { + return 0; + } + + llama_pos result = 0; + for (const auto & impl : spec->impls) { + result = std::max(result, impl->get_committed_prefix_len()); + } + + return result; +} + +void common_speculative_invalidate_retained_state( + common_speculative * spec) { + if (spec == nullptr) { + return; + } + + for (auto & impl : spec->impls) { + impl->invalidate_retained_state(); + } +} + +void common_speculative_set_first_pass_source( + common_speculative * spec, + const llama_tokens & source_tokens, + const float * hidden_states, + int32_t n_tokens, + int32_t n_embd, + llama_pos start_pos) { + if (spec == nullptr) { + return; + } + + for (auto & impl : 
spec->impls) { + impl->set_first_pass_source(source_tokens, hidden_states, n_tokens, n_embd, start_pos); + } +} + llama_tokens common_speculative_draft( common_speculative * spec, const common_params_speculative & params, @@ -1024,11 +1701,7 @@ llama_tokens common_speculative_draft( return result; } -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { - if (n_accepted == 0) { - return; - } - +void common_speculative_accept(common_speculative * spec, uint16_t n_accepted, const std::vector & batch_idxs) { common_speculative_state * impl = spec->curr_impl; GGML_ASSERT(impl); @@ -1040,7 +1713,7 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { impl->n_acc_tokens += n_accepted; } - impl->accept(n_accepted); + impl->accept(n_accepted, batch_idxs); impl->n_call_accept++; } } @@ -1070,5 +1743,15 @@ void common_speculative_print_stats(const common_speculative * spec) { impl->n_gen_tokens, impl->n_acc_tokens, str_perf.c_str()); + + if (auto * mtp = dynamic_cast(impl.get())) { + LOG_INF("statistics %s: guard stops = %zu, zero-draft guard stops = %zu, partial guard stops = %zu, guard lost tokens = %zu, draft decode fail rounds = %zu\n", + common_speculative_type_to_str(impl->type).c_str(), + mtp->n_guard_stop_rounds, + mtp->n_guard_zero_draft_rounds, + mtp->n_guard_partial_rounds, + mtp->n_guard_lost_tokens, + mtp->n_draft_decode_fail_rounds); + } } } diff --git a/common/speculative.h b/common/speculative.h index 876cde3d180..b0021bfadb2 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -26,6 +26,28 @@ void common_speculative_free(common_speculative * spec); // optionally call once at the beginning of a new generation void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt); +// starts a new generation while preserving at most the retained common prefix that is +// still valid in both the target and draft contexts +void common_speculative_begin( + common_speculative * spec, + const llama_tokens & prompt, + llama_pos retained_prefix_len); + +llama_pos common_speculative_get_committed_prefix_len( + const common_speculative * spec); + +void common_speculative_invalidate_retained_state( + common_speculative * spec); + +// supplies the token/hidden-state source used by the next MTP first pass; start_pos +// is the target-context position of source_tokens[0] +void common_speculative_set_first_pass_source( + common_speculative * spec, + const llama_tokens & source_tokens, + const float * hidden_states, + int32_t n_tokens, + int32_t n_embd, + llama_pos start_pos); // sample up to n_draft tokens and add them to the batch using the draft model llama_tokens common_speculative_draft( @@ -34,8 +56,9 @@ llama_tokens common_speculative_draft( const llama_tokens & prompt, llama_token id_last); -// informs the speculative decoder that n_accepted tokens were accepted by the target model -void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); +// informs the speculative decoder that n_accepted tokens were accepted by the target model; +// batch_idxs maps the frontier token and accepted draft tokens back to verifier output rows +void common_speculative_accept(common_speculative * spec, uint16_t n_accepted, const std::vector & batch_idxs); // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 09f6e7ae29c..7b7c900d0ca 100755 --- a/convert_hf_to_gguf.py +++ 
b/convert_hf_to_gguf.py @@ -9145,6 +9145,16 @@ def prepare_tensors(self): class Step35Model(TextModel): model_arch = gguf.MODEL_ARCH.STEP35 + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + source_nextn = int(self.hparams.get("num_nextn_predict_layers", 0)) + # Step3.5 runtime currently uses only the first MTP layer. Keep the + # GGUF export aligned with that runtime until multi-layer MTP lands. + nextn = 1 if source_nextn > 0 else 0 + if nextn > 0: + self.block_count = self.hparams["num_hidden_layers"] + nextn + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + def set_gguf_parameters(self): rope_theta = self.hparams.get("rope_theta") if isinstance(rope_theta, list): @@ -9172,6 +9182,11 @@ def set_gguf_parameters(self): kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types] swa_pat = [lt == "sliding_attention" for lt in layer_types] + source_nextn = int(self.hparams.get("num_nextn_predict_layers", 0)) + nextn = 1 if source_nextn > 0 else 0 + if nextn > 0: + self.gguf_writer.add_nextn_predict_layers(nextn) + self.gguf_writer.add_head_count(head_arr) self.gguf_writer.add_head_count_kv(kv_arr) @@ -9212,12 +9227,24 @@ def set_gguf_parameters(self): self.gguf_writer.add_swiglu_clamp_shexp(limits_shared_f) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): - # remove mtp layers + n_main = int(self.hparams.get("num_hidden_layers", self.block_count)) + source_nextn = int(self.hparams.get("num_nextn_predict_layers", 0)) + export_nextn = 1 if source_nextn > 0 else 0 + last_export_layer = n_main + export_nextn + if (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None: il = int(m.group(1)) - n_main = int(self.hparams.get("num_hidden_layers", self.block_count)) - if il >= n_main: + if il >= last_export_layer: return + if il >= n_main: + name = name.replace(f"model.layers.{il}.transformer.", f"model.layers.{il}.") + + if "shared_head.output" in name: + name = name.replace("shared_head.output", "shared_head.head") + elif "embed_tokens" in name: + if il > n_main: + return + if name.endswith("norm.weight"): data_torch += 1.0 # Map router bias (expert selection bias) to a GGUF bias tensor diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index a03dbce887f..df9bf0ff729 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -27,7 +27,7 @@ int main(int argc, char ** argv) { return 1; } - if (params.speculative.mparams_dft.path.empty()) { + if (params.speculative.requires_dft() && !params.speculative.has_dft()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; } @@ -48,39 +48,61 @@ int main(int argc, char ** argv) { const llama_vocab * vocab = llama_model_get_vocab(model_tgt); - // load the draft model + // load the draft model or configure the dedicated MTP draft context llama_model_ptr model_dft; // TODO: simplify this logic { const auto & params_spec = params.speculative; - auto params_dft = params; + if (params_spec.has_dft()) { + auto params_dft = params; - params_dft.n_parallel = 1; - params_dft.n_ctx = params_spec.n_ctx; - params_dft.n_batch = llama_n_ctx_seq(ctx_tgt); - params_dft.devices = params_spec.devices; - params_dft.model = params_spec.mparams_dft; - params_dft.n_gpu_layers = params_spec.n_gpu_layers; + params_dft.n_parallel = 1; + params_dft.n_ctx = params_spec.n_ctx == 0 ? 
(int32_t) llama_n_ctx_seq(ctx_tgt) : params_spec.n_ctx; + params_dft.n_batch = llama_n_ctx_seq(ctx_tgt); + params_dft.cache_type_k = params_spec.cache_type_k; + params_dft.cache_type_v = params_spec.cache_type_v; + params_dft.devices = params_spec.devices; + params_dft.n_gpu_layers = params_spec.n_gpu_layers; - if (params_spec.cpuparams.n_threads > 0) { - params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads; - params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; - } + if (params_spec.cpuparams.n_threads > 0) { + params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads; + params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; + } - params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides; + params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides; + params_dft.model = params_spec.mparams_dft; - auto mparams_dft = common_model_params_to_llama(params_dft); + auto mparams_dft = common_model_params_to_llama(params_dft); - model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); - if (model_dft == nullptr) { - LOG_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str()); - return 1; - } + model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); + if (model_dft == nullptr) { + LOG_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str()); + return 1; + } - params.speculative.model_dft = model_dft.get(); - params.speculative.cparams_dft = common_context_params_to_llama(params_dft); + params.speculative.model_dft = model_dft.get(); + params.speculative.cparams_dft = common_context_params_to_llama(params_dft); + } else if (params_spec.type == COMMON_SPECULATIVE_TYPE_MTP) { + auto params_dft = params; + + params_dft.n_parallel = 1; + params_dft.n_ctx = params_spec.n_ctx == 0 ? (int32_t) llama_n_ctx_seq(ctx_tgt) : params_spec.n_ctx; + params_dft.n_batch = llama_n_ctx_seq(ctx_tgt); + params_dft.cache_type_k = params_spec.cache_type_k; + params_dft.cache_type_v = params_spec.cache_type_v; + params_dft.devices = params_spec.devices; + params_dft.n_gpu_layers = params_spec.n_gpu_layers; + + if (params_spec.cpuparams.n_threads > 0) { + params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads; + params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; + } + + params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides; + params.speculative.cparams_dft = common_context_params_to_llama(params_dft); + } } // Tokenize the prompt @@ -122,7 +144,16 @@ int main(int argc, char ** argv) { struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); // eval the prompt - llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + if (params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) { + llama_batch prompt_batch = llama_batch_init(inp.size() - 1, 0, 1); + for (size_t i = 0; i + 1 < inp.size(); ++i) { + common_batch_add(prompt_batch, inp[i], i, { 0 }, true); + } + llama_decode(ctx_tgt, prompt_batch); + llama_batch_free(prompt_batch); + } else { + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + } // note: keep the last token separate! 
llama_token id_last = inp.back(); @@ -140,6 +171,23 @@ int main(int argc, char ** argv) { common_speculative_begin(spec, prompt_tgt); + if (params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP && !prompt_tgt.empty()) { + std::vector prompt_hidden((int64_t) prompt_tgt.size()*llama_model_n_embd(model_tgt)); + for (size_t i = 0; i < prompt_tgt.size(); ++i) { + float * hidden = llama_get_embeddings_ith(ctx_tgt, i); + GGML_ASSERT(hidden != nullptr); + std::memcpy(prompt_hidden.data() + i*llama_model_n_embd(model_tgt), hidden, + llama_model_n_embd(model_tgt)*sizeof(float)); + } + common_speculative_set_first_pass_source( + spec, + prompt_tgt, + prompt_hidden.data(), + prompt_tgt.size(), + llama_model_n_embd(model_tgt), + 0); + } + llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); const auto t_enc_end = ggml_time_us(); @@ -155,6 +203,7 @@ int main(int argc, char ** argv) { // from a cache or lookup tables. // llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last); + const bool had_draft = !draft.empty(); //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); @@ -196,6 +245,10 @@ int main(int argc, char ** argv) { n_accept += ids.size() - 1; n_predict += ids.size(); + if (had_draft) { + common_speculative_accept(spec, ids.size() - 1, {}); + } + // process the accepted tokens and update contexts // // this is the standard token post-processing that we normally do diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 0cdd1c471da..011752d8a62 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -3790,6 +3790,12 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.LLAMA_EMBED: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/include/llama.h b/include/llama.h index bf2bff8dac6..fe510506bfd 100644 --- a/include/llama.h +++ b/include/llama.h @@ -197,6 +197,11 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported }; + enum llama_mtp_op_type { + LLAMA_MTP_OP_NONE = 0, + LLAMA_MTP_OP_DRAFT_GEN = 1, + }; + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id @@ -318,6 +323,7 @@ extern "C" { bool use_extra_bufts; // use extra buffer types (used for weight repacking) bool no_host; // bypass host buffer allowing extra buffers to be used bool no_alloc; // only load metadata and simulate memory allocations + bool mtp; // enable multi-token prediction (load MTP weights) }; struct llama_sampler_seq_config { @@ -574,6 +580,9 @@ extern "C" { LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); + // number of MTP (multi-token prediction) layers; 0 if model has none + LLAMA_API int32_t llama_model_n_nextn_predict_layers(const struct llama_model * model); + // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); @@ -983,6 +992,11 @@ extern "C" { // If true, all model tensors are activated during llama_decode() to load and cache their weights. 
LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); + // MTP (Multi-Token Prediction) API + LLAMA_API void llama_set_mtp_op_type(struct llama_context * ctx, enum llama_mtp_op_type op); + LLAMA_API void llama_set_mtp_layer_idx(struct llama_context * ctx, int32_t layer_idx); + LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state); + // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index e210dcdae21..1d4a0a3b065 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2485,6 +2485,12 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP_SHEXP, LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; case LLM_ARCH_GPTJ: case LLM_ARCH_UNKNOWN: @@ -2768,14 +2774,13 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - // NextN/MTP tensors are currently ignored (reserved for future MTP support) - // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // NextN/MTP tensors — per-layer repeating tensors (blk.%d.nextn.*) + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a808e3e4542..f613b9228b4 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -9,6 +9,7 @@ #include "llama-model.h" #include "llama-ext.h" +#include #include #include #include @@ -1061,6 +1062,18 @@ void llama_context::set_warmup(bool value) { //sched_need_reserve = true; } +void llama_context::set_mtp_op_type(llama_mtp_op_type op) { + mtp_op_type = op; +} + +void llama_context::set_mtp_layer_idx(int32_t layer_idx) { + mtp_layer_idx = layer_idx; +} + +void llama_context::set_draft_input_hidden_state(const float * hidden_state) { + draft_input_hidden_state = hidden_state; +} + bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { if (!sampler && sampling.samplers.count(seq_id) == 0) { return true; @@ -1616,6 +1629,7 @@ int 
llama_context::decode(const llama_batch & batch_inp) { while (true) { mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + if (!mctx) { return -2; } @@ -2146,22 +2160,33 @@ llm_graph_params llama_context::graph_params( const llama_ubatch & ubatch, const llama_memory_context_i * mctx, llm_graph_type gtype) const { + llm_mtp_op_type mtp_op = LLM_MTP_OP_NONE; + switch (mtp_op_type) { + case LLAMA_MTP_OP_NONE: mtp_op = LLM_MTP_OP_NONE; break; + case LLAMA_MTP_OP_DRAFT_GEN: mtp_op = LLM_MTP_OP_DRAFT_GEN; break; + } + + const float * mtp_hidden = mtp_op_type == LLAMA_MTP_OP_DRAFT_GEN ? draft_input_hidden_state : nullptr; + return { - /*.arch =*/ model.arch, - /*.hparams =*/ model.hparams, - /*.cparams =*/ cparams, - /*.ubatch =*/ ubatch, - /*.gtype =*/ gtype, - /*.sched =*/ sched.get(), - /*.backend_cpu =*/ backend_cpu, - /*.cvec =*/ cvec.get(), - /*.loras =*/ loras.get(), - /*.mctx =*/ mctx, - /*.cross =*/ &cross, - /*.samplers =*/ sampling.samplers, - /*.n_outputs =*/ n_outputs, - /*.cb =*/ graph_get_cb(), - /*.res =*/ res, + /*.arch =*/ model.arch, + /*.hparams =*/ model.hparams, + /*.cparams =*/ cparams, + /*.ubatch =*/ ubatch, + /*.gtype =*/ gtype, + /*.sched =*/ sched.get(), + /*.backend_cpu =*/ backend_cpu, + /*.cvec =*/ cvec.get(), + /*.loras =*/ loras.get(), + /*.mctx =*/ mctx, + /*.cross =*/ &cross, + /*.mtp_op_type =*/ mtp_op, + /*.mtp_layer_idx =*/ mtp_layer_idx, + /*.mtp_hidden_state =*/ mtp_hidden, + /*.samplers =*/ sampling.samplers, + /*.n_outputs =*/ n_outputs, + /*.cb =*/ graph_get_cb(), + /*.res =*/ res, }; } @@ -3064,6 +3089,18 @@ void llama_set_warmup(llama_context * ctx, bool warmup) { ctx->set_warmup(warmup); } +void llama_set_mtp_op_type(llama_context * ctx, llama_mtp_op_type op) { + ctx->set_mtp_op_type(op); +} + +void llama_set_mtp_layer_idx(llama_context * ctx, int32_t layer_idx) { + ctx->set_mtp_layer_idx(layer_idx); +} + +void llama_set_draft_input_hidden_state(llama_context * ctx, const float * hidden_state) { + ctx->set_draft_input_hidden_state(hidden_state); +} + void llama_synchronize(llama_context * ctx) { ctx->synchronize(); } diff --git a/src/llama-context.h b/src/llama-context.h index e0d0085c1c3..2f7b0ac80d2 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -105,6 +105,10 @@ struct llama_context { void set_causal_attn(bool value); void set_warmup(bool value); + void set_mtp_op_type(llama_mtp_op_type op); + void set_mtp_layer_idx(int32_t layer_idx); + void set_draft_input_hidden_state(const float * hidden_state); + void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); @@ -335,6 +339,11 @@ struct llama_context { llm_graph_result_ptr gf_res_prev; llm_graph_result_ptr gf_res_reserve; + // MTP state + llama_mtp_op_type mtp_op_type = LLAMA_MTP_OP_NONE; + int32_t mtp_layer_idx = -1; + const float * draft_input_hidden_state = nullptr; + // host buffer for the model output (logits and embeddings) ggml_backend_buffer_ptr buf_output; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0e7d96ca10d..eaf8bb559ce 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -332,6 +332,26 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_mtp_hidden_state::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (hidden_state && data) { + ggml_backend_tensor_set(hidden_state, data, 0, ggml_nbytes(hidden_state)); + } +} + +bool 
llm_graph_input_mtp_hidden_state::can_reuse(const llm_graph_params & params) { + data = params.mtp_hidden_state; + + bool res = true; + + res &= hidden_state != nullptr; + res &= data != nullptr; + res &= hidden_state->ne[1] == params.ubatch.n_tokens; + + return res; +} + void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -494,15 +514,20 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { } void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { - mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); - mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); + // In single-layer ISWA graphs, one branch can be pruned and never get a backend buffer. + if (self_k_idxs && self_k_idxs->buffer) { + mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } - mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); - mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); + mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + } if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); @@ -929,6 +954,9 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : loras (params.loras), mctx (params.mctx), cross (params.cross), + mtp_op_type (params.mtp_op_type), + mtp_layer_idx (params.mtp_layer_idx), + mtp_hidden_state (params.mtp_hidden_state), samplers (params.samplers), cb_func (params.cb), res (params.res), @@ -1761,6 +1789,19 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { return cur; } +ggml_tensor * llm_graph_context::build_inp_mtp_hidden_state() const { + auto inp = std::make_unique(mtp_hidden_state); + + auto & cur = inp->hidden_state; + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_input(cur); + + res->add_input(std::move(inp)); + + return cur; +} + ggml_tensor * llm_graph_context::build_inp_cross_embd() const { auto inp = std::make_unique(cross); diff --git a/src/llama-graph.h b/src/llama-graph.h index bb0ad75198f..a32b461a8f7 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -33,6 +33,11 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DECODER, }; +enum llm_mtp_op_type { + LLM_MTP_OP_NONE = 0, + LLM_MTP_OP_DRAFT_GEN = 1, +}; + enum llm_ffn_op_type { LLM_FFN_SILU, LLM_FFN_GELU, @@ -258,6 +263,20 @@ class llm_graph_input_cross_embd : public llm_graph_input_i { const llama_cross * cross; }; +class llm_graph_input_mtp_hidden_state : public llm_graph_input_i { +public: + llm_graph_input_mtp_hidden_state(const float * data) : data(data) {} + virtual ~llm_graph_input_mtp_hidden_state() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * hidden_state = nullptr; // F32 [n_embd, n_tokens] + + const float * data = nullptr; +}; + class llm_graph_input_attn_no_cache : public llm_graph_input_i { public: llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) : @@ -542,6 +561,10 @@ struct 
llm_graph_params { const llama_memory_context_i * mctx; const llama_cross * cross; + llm_mtp_op_type mtp_op_type = LLM_MTP_OP_NONE; + int mtp_layer_idx = -1; + const float * mtp_hidden_state = nullptr; + std::map samplers; static bool samplers_equal( @@ -627,7 +650,9 @@ struct llm_graph_params { gtype == other.gtype && cvec == other.cvec && loras == other.loras && - cross == other.cross; + cross == other.cross && + mtp_op_type == other.mtp_op_type && + mtp_layer_idx == other.mtp_layer_idx; } }; @@ -750,6 +775,10 @@ struct llm_graph_context { const llama_memory_context_i * mctx; const llama_cross * cross; + const llm_mtp_op_type mtp_op_type; + const int mtp_layer_idx; + const float * mtp_hidden_state; + std::map samplers; const llm_graph_cb & cb_func; @@ -864,6 +893,7 @@ struct llm_graph_context { ggml_tensor * build_inp_mean() const; ggml_tensor * build_inp_cls() const; + ggml_tensor * build_inp_mtp_hidden_state() const; ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_pos_bucket_enc() const; ggml_tensor * build_inp_pos_bucket_dec() const; diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp index 26e2cb4270b..a2fd68b308b 100644 --- a/src/llama-kv-cache-iswa.cpp +++ b/src/llama-kv-cache-iswa.cpp @@ -247,6 +247,20 @@ llama_kv_cache * llama_kv_cache_iswa::get_swa() const { return kv_swa.get(); } +void llama_kv_cache_iswa::set_swa_reuse_guard(llama_pos query_pos) { + kv_base->clear_swa_reuse_guard(); + kv_swa->set_swa_reuse_guard(query_pos); +} + +void llama_kv_cache_iswa::clear_swa_reuse_guard() { + kv_base->clear_swa_reuse_guard(); + kv_swa->clear_swa_reuse_guard(); +} + +bool llama_kv_cache_iswa::consume_swa_reuse_guard_block_prepare() { + return kv_swa->consume_swa_reuse_guard_block_prepare(); +} + // // llama_kv_cache_iswa_context // diff --git a/src/llama-kv-cache-iswa.h b/src/llama-kv-cache-iswa.h index 70ab22f0d60..158c03200e8 100644 --- a/src/llama-kv-cache-iswa.h +++ b/src/llama-kv-cache-iswa.h @@ -70,6 +70,11 @@ class llama_kv_cache_iswa : public llama_memory_i { llama_kv_cache * get_base() const; llama_kv_cache * get_swa () const; + void set_swa_reuse_guard(llama_pos query_pos); + void clear_swa_reuse_guard(); + + bool consume_swa_reuse_guard_block_prepare(); + private: const llama_hparams & hparams; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 3e0fd3107f3..1a2e753b22e 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -76,6 +76,21 @@ static ggml_tensor * ggml_mul_mat_aux( // llama_kv_cache // +static bool llama_kv_cache_is_swa_cell_reusable_for_query( + uint32_t n_swa, + llama_swa_type swa_type, + llama_pos pos_cell, + llama_pos query_pos) { + if (n_swa == 0 || swa_type == LLAMA_SWA_TYPE_NONE) { + return false; + } + + GGML_ASSERT(pos_cell >= 0); + GGML_ASSERT(query_pos >= 0); + + return llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, query_pos); +} + llama_kv_cache::llama_kv_cache( const llama_model & model, ggml_type type_k, @@ -674,6 +689,8 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector states; + swa_reuse_guard_blocked_prepare = false; + bool success = true; for (const auto & ubatch : ubatches) { @@ -919,6 +936,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, } uint32_t n_tested = 0; + bool guard_blocked = false; // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head // for non-continuous slots, we test the tokens one by one @@ -959,9 +977,28 @@ llama_kv_cache::slot_info 
llama_kv_cache::find_slot(const llama_ubatch & ubatch, if (!can_use) { const llama_seq_id seq_id_cell = cells.seq_get(idx); + const llama_pos query_pos_default = cells.seq_pos_max(seq_id_cell) + 1; // SWA mask - if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + if (swa_reuse_guard.active) { + const bool can_use_default = llama_kv_cache_is_swa_cell_reusable_for_query( + n_swa, + swa_type, + pos_cell, + query_pos_default); + const bool can_use_guarded = llama_kv_cache_is_swa_cell_reusable_for_query( + n_swa, + swa_type, + pos_cell, + swa_reuse_guard.query_pos); + + can_use = can_use_guarded; + guard_blocked = guard_blocked || (can_use_default && !can_use_guarded); + } else if (llama_kv_cache_is_swa_cell_reusable_for_query( + n_swa, + swa_type, + pos_cell, + query_pos_default)) { can_use = true; } } @@ -985,6 +1022,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, } if (n_tested >= cells.size()) { + swa_reuse_guard_blocked_prepare = swa_reuse_guard_blocked_prepare || guard_blocked; //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); return { }; } @@ -1113,6 +1151,25 @@ ggml_type llama_kv_cache::type_v() const { return layers[0].v->type; } +void llama_kv_cache::set_swa_reuse_guard(llama_pos query_pos) { + GGML_ASSERT(query_pos >= 0); + + swa_reuse_guard.active = true; + swa_reuse_guard.query_pos = query_pos; + swa_reuse_guard_blocked_prepare = false; +} + +void llama_kv_cache::clear_swa_reuse_guard() { + swa_reuse_guard.active = false; + swa_reuse_guard.query_pos = 0; +} + +bool llama_kv_cache::consume_swa_reuse_guard_block_prepare() { + const bool blocked = swa_reuse_guard_blocked_prepare; + swa_reuse_guard_blocked_prepare = false; + return blocked; +} + uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index d4569a06f71..986df3007ce 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -13,6 +13,11 @@ struct llama_hparams; struct llama_model; struct llama_context; +struct llama_kv_swa_guard_state { + bool active = false; + llama_pos query_pos = 0; +}; + // // llama_kv_cache // @@ -154,6 +159,10 @@ class llama_kv_cache : public llama_memory_i { ggml_type type_k() const; ggml_type type_v() const; + void set_swa_reuse_guard(llama_pos query_pos); + void clear_swa_reuse_guard(); + + bool consume_swa_reuse_guard_block_prepare(); // // graph_build API @@ -263,6 +272,9 @@ class llama_kv_cache : public llama_memory_i { // pending stream copies that will be applied during the next update stream_copy_info sc_info; + llama_kv_swa_guard_state swa_reuse_guard; + mutable bool swa_reuse_guard_blocked_prepare = false; + std::vector layers; // model layer id -> KV cache layer id @@ -379,6 +391,7 @@ class llama_kv_cache_context : public llama_memory_context_i { void set_input_k_rot(ggml_tensor * dst) const; void set_input_v_rot(ggml_tensor * dst) const; + const slot_info_vec_t & get_sinfos() const { return sinfos; } private: llama_memory_status status; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 79d08ff41e3..82ef1506cde 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2584,7 +2584,19 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); - switch (hparams.n_layer) { 
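
To make the SWA reuse guard above easier to follow: a KV cell may only be overwritten when the sliding window has already masked it out, and while a guard is active that check is evaluated against the guarded (earlier) query position rather than the sequence's default next position; a cell that passes the default check but fails the guarded one marks the prepare pass as blocked. Below is a minimal standalone sketch of that decision, not the library code itself — it uses a simplified stand-in for `llama_hparams::is_masked_swa` with assumed standard sliding-window semantics and made-up positions.

```cpp
#include <cstdint>
#include <cstdio>

// Simplified stand-in for llama_hparams::is_masked_swa, assuming standard
// sliding-window semantics: a cell is masked (and therefore overwritable)
// once it falls outside the window of the query position.
static bool is_masked_swa_standard(uint32_t n_swa, int64_t pos_cell, int64_t query_pos) {
    return query_pos - pos_cell >= (int64_t) n_swa;
}

struct reuse_decision {
    bool can_use;       // cell may be overwritten by the new ubatch
    bool guard_blocked; // reuse was prevented only because of the active guard
};

// Mirrors the guarded branch in llama_kv_cache::find_slot: without a guard the
// query is seq_pos_max(seq) + 1; with a guard the earlier guarded position wins.
static reuse_decision swa_cell_reuse(
        uint32_t n_swa,
        int64_t  pos_cell,
        int64_t  query_pos_default,
        bool     guard_active,
        int64_t  guard_query_pos) {
    const bool ok_default = is_masked_swa_standard(n_swa, pos_cell, query_pos_default);
    if (!guard_active) {
        return { ok_default, false };
    }
    const bool ok_guarded = is_masked_swa_standard(n_swa, pos_cell, guard_query_pos);
    return { ok_guarded, ok_default && !ok_guarded };
}

int main() {
    // made-up positions: window of 8, cell at pos 100, sequence tip at pos 110,
    // guard protecting an MTP draft query that still needs to attend from pos 104
    const reuse_decision d = swa_cell_reuse(/*n_swa      =*/ 8,
                                            /*pos_cell   =*/ 100,
                                            /*query_def  =*/ 111,
                                            /*guard      =*/ true,
                                            /*guard_query=*/ 104);
    std::printf("can_use=%d guard_blocked=%d\n", d.can_use ? 1 : 0, d.guard_blocked ? 1 : 0);
    return 0;
}
```

With these numbers the cell is overwritable for the default query (111 − 100 ≥ 8) but still inside the window of the guarded draft query (104 − 100 < 8), so `can_use` is false and `guard_blocked` is reported — the situation `consume_swa_reuse_guard_block_prepare()` later surfaces.
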
+ ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + if (params.mtp && hparams.nextn_predict_layers > 1) { + LLAMA_LOG_WARN("%s: Step35 MTP uses only the first nextn layer\n", __func__); + } + + if (params.mtp && hparams.nextn_predict_layers > 0) { + hparams.n_layer_kv_from_start = hparams.n_layer; + } else { + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + } + + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 45: type = LLM_TYPE_196B_A11B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -7530,14 +7542,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_STEP35: { + const uint32_t n_nextn = hparams.nextn_predict_layers; + const bool mtp_en = params.mtp && n_nextn > 0; + const uint32_t n_main = n_layer - n_nextn; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor - // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. uint32_t n_rot_max = 0; for (int i = 0; i < n_layer; ++i) { n_rot_max = std::max(n_rot_max, hparams.n_rot(i)); @@ -7547,51 +7561,66 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } for (int i = 0; i < n_layer; ++i) { + const bool is_mtp_layer = (n_nextn > 0 && static_cast(i) >= n_main); + int flags = 0; + if (is_mtp_layer && !mtp_en) { + flags |= TENSOR_SKIP; + } + auto & layer = layers[i]; const uint32_t n_head_l = hparams.n_head(i); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags | TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags | TENSOR_NOT_REQUIRED); - // optional rope factors (llama3) / longrope tensors - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + if (!(flags & TENSOR_SKIP)) { + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, flags | TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, flags | TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, flags | TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } } - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_l}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_l}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, flags); - // head-wise attention gate (Step35 self_attn.g_proj) - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, flags | TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); - // dense MLP (leading dense blocks) - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags | TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags | TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags | TENSOR_NOT_REQUIRED); - // MoE routed experts + selection bias (router_bias) const int64_t n_ff_exp = hparams.n_ff_exp; - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - // shared expert MLP - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags | TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, 
n_ff_exp, n_expert}, flags | TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags | TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags | TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, flags | TENSOR_NOT_REQUIRED); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, flags | TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, flags | TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, flags | TENSOR_NOT_REQUIRED); + + if (is_mtp_layer) { + const int first_mtp = static_cast(n_main); + const int dup_flag = (i > first_mtp) ? TENSOR_DUPLICATED : 0; + + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED | dup_flag); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + } } } break; case LLM_ARCH_MAINCODER: @@ -8954,6 +8983,7 @@ llama_model_params llama_model_default_params() { /*.use_extra_bufts =*/ true, /*.no_host =*/ false, /*.no_alloc =*/ false, + /*.mtp =*/ false, }; return result; @@ -9003,6 +9033,10 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } +int32_t llama_model_n_nextn_predict_layers(const llama_model * model) { + return model->hparams.nextn_predict_layers; +} + uint32_t llama_model_n_cls_out(const struct llama_model * model) { return model->hparams.n_cls_out; } diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index f91d795b3e9..6905209124f 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -99,11 +99,17 @@ static std::string remap_imatrix(const std::string & orig_name, const std::mapquantize_output_tensor || name != "output.weight"; + quantize &= params->quantize_output_tensor || !tensor_name_match_output_weight(name.c_str()); // do not quantize expert gating tensors // NOTE: can't use LLM_TN here because the layer number is not known @@ -1195,6 +1201,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } } } + if (tensor_name_should_ignore_imatrix(tensor->name)) { + imatrix = nullptr; + } if (!imatrix && tm.requires_imatrix) { LLAMA_LOG_ERROR("\n\n============================================================\n"); LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); diff --git a/src/models/models.h b/src/models/models.h index 8e6b9c238fd..49f365d50d6 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -696,6 +696,16 @@ struct llm_build_starcoder : public llm_graph_context { struct llm_build_step35_iswa : public llm_graph_context { llm_build_step35_iswa(const 
llama_model & model, const llm_graph_params & params); + +private: + const llama_model & model; + + ggml_tensor * build_layer( + ggml_tensor * inpL, + int il, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv_iswa * inp_attn, + ggml_tensor * inp_out_ids); }; struct llm_build_t5_dec : public llm_graph_context { diff --git a/src/models/step35-iswa.cpp b/src/models/step35-iswa.cpp index c80cb26c5af..31f0341a5e7 100644 --- a/src/models/step35-iswa.cpp +++ b/src/models/step35-iswa.cpp @@ -1,156 +1,220 @@ #include "models.h" -llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +ggml_tensor * llm_build_step35_iswa::build_layer( + ggml_tensor * inpL, + int il, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv_iswa * inp_attn, + ggml_tensor * inp_out_ids) { + ggml_tensor * cur; - ggml_tensor * inpL; + ggml_tensor * inpSA = inpL; - inpL = build_inp_embd(model.tok_embd); - ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_iswa(); - ggml_tensor * inp_out_ids = build_inp_out_ids(); + const uint32_t n_head_l = hparams.n_head(il); + const uint32_t n_head_kv_l = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + cur = inpL; + + cb(cur, "attn_norm_in", il); + + // self-attention + { + cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - const uint32_t n_head_l = hparams.n_head(il); - const uint32_t n_head_kv_l = hparams.n_head_kv(il); - - const float freq_base_l = model.get_rope_freq_base(cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - - cur = inpL; - - // dump pre-attn RMSNorm input to pinpoint layer boundary issues - cb(cur, "attn_norm_in", il); - - // self-attention - { - cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); - - // Q/K per-head RMSNorm (Step35 q_norm / k_norm) - if (model.layers[il].attn_q_norm) { - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); - } - if (model.layers[il].attn_k_norm) { - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); - cb(Kcur, "Kcur_normed", il); - } - - // RoPE (partial rotary factors per layer) - const bool is_swa = hparams.is_swa(il); - ggml_tensor * rope_factors = is_swa ? 
nullptr : model.get_rope_factors(cparams, il); - const int64_t n_rot_l = hparams.n_rot(il); - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow - ); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow - ); - cb(Qcur, "Qcur_pos", il); - cb(Kcur, "Kcur_pos", il); - - const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); - ggml_tensor * attn_out = build_attn(inp_attn, - nullptr, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); - cb(attn_out, "attn_out", il); - // head-wise attention gate: sigmoid(g_proj(x)) in torch - if (model.layers[il].wqkv_gate) { - ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, cur); // [n_head_l, n_tokens] - cb(gate, "attn_gate", il); - - gate = ggml_sigmoid(ctx0, gate); - cb(gate, "attn_gate_sigmoid", il); - - // reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens] - ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens); - ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens); - cb(gate_3d, "attn_gate_3d", il); - - attn_3d = ggml_mul(ctx0, attn_3d, gate_3d); - cb(attn_3d, "attn_gated_3d", il); - - attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens); - cb(attn_out, "attn_gated", il); - } - - // output projection - cur = build_lora_mm(model.layers[il].wo, attn_out); - cb(cur, "attn_proj", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + const bool is_swa = hparams.is_swa(il); + ggml_tensor * rope_factors = is_swa ? 
nullptr : model.get_rope_factors(cparams, il); + const int64_t n_rot_l = hparams.n_rot(il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur_pos", il); + cb(Kcur, "Kcur_pos", il); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); + ggml_tensor * attn_out = build_attn(inp_attn, + nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(attn_out, "attn_out", il); + + if (model.layers[il].wqkv_gate) { + ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, cur); + cb(gate, "attn_gate", il); + + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "attn_gate_sigmoid", il); + + ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens); + ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens); + cb(gate_3d, "attn_gate_3d", il); + + attn_3d = ggml_mul(ctx0, attn_3d, gate_3d); + cb(attn_3d, "attn_gated_3d", il); + + attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens); + cb(attn_out, "attn_gated", il); } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - cur = build_norm(ffn_inp, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - // feed-forward - if (model.layers[il].ffn_gate_inp == nullptr) { - // dense MLP - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, nullptr, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr, - nullptr, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); + cur = build_lora_mm(model.layers[il].wo, attn_out); + cb(cur, "attn_proj", il); + } + + if (inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, nullptr, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + const bool norm_w = hparams.expert_weights_norm; + const float w_scale = hparams.expert_weights_scale; + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, + norm_w, w_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + ggml_tensor * sh_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, nullptr, nullptr, + model.layers[il].ffn_gate_shexp, nullptr, nullptr, + model.layers[il].ffn_down_shexp, nullptr, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(sh_out, "ffn_shared_out", il); + + cur = ggml_add(ctx0, moe_out, sh_out); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, 
ffn_inp); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + return cur; +} + +llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params), model(model) { + ggml_tensor * cur; + ggml_tensor * inpL; + + const int n_layer_main = n_layer - (int) hparams.nextn_predict_layers; + + if (mtp_op_type != LLM_MTP_OP_NONE) { + // === MTP graph branch === + GGML_ASSERT(mtp_hidden_state != nullptr); + const int il = mtp_layer_idx; + GGML_ASSERT(il >= n_layer_main && il < n_layer); + const auto & layer = model.layers[il]; + + // 1. token embedding + ggml_tensor * inp_tokens = build_inp_embd( + layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + + // 2. MTP hidden state input + ggml_tensor * prev_hidden = build_inp_mtp_hidden_state(); + + // 3. enorm / hnorm (Gemma-style, weights already contain +1 offset) + ggml_tensor * inp_normed = build_norm(inp_tokens, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(inp_normed, "mtp_enorm", il); + + ggml_tensor * hid_normed = build_norm(prev_hidden, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(hid_normed, "mtp_hnorm", il); + + // 4. concat + projection: [2*n_embd, n_tokens] -> [n_embd, n_tokens] + ggml_tensor * concat = ggml_concat(ctx0, inp_normed, hid_normed, 0); + cb(concat, "mtp_concat", il); + + cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + // 5. full decoder layer (attention + MoE FFN with KV cache R/W) + cur = build_layer(cur, il, inp_pos, inp_attn, nullptr); + + // 6. save hidden state for next MTP step + res->t_embd = cur; + + // 7. shared head -> logits + if (layer.nextn.shared_head_norm) { + cur = build_norm(cur, layer.nextn.shared_head_norm, nullptr, LLM_NORM_RMS, -1); + } else { + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + } + cb(cur, "mtp_head_norm", il); + + if (layer.nextn.shared_head_head) { + cur = build_lora_mm(layer.nextn.shared_head_head, cur); } else { - // MoE routed experts - ggml_tensor * moe_out = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, - n_expert, n_expert_used, - LLM_FFN_SILU, hparams.expert_weights_norm, - hparams.expert_weights_scale, - (llama_expert_gating_func_type) hparams.expert_gating_func, - il); - cb(moe_out, "ffn_moe_out", il); - - // shared expert MLP (always added on MoE layers in Step35) - ggml_tensor * sh_out = build_ffn(cur, - model.layers[il].ffn_up_shexp, nullptr, nullptr, - model.layers[il].ffn_gate_shexp, nullptr, nullptr, - model.layers[il].ffn_down_shexp, nullptr, nullptr, - nullptr, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(sh_out, "ffn_shared_out", il); - - cur = ggml_add(ctx0, moe_out, sh_out); - cb(cur, "ffn_out", il); + cur = build_lora_mm(model.output, cur); } - cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "mtp_logits", il); + res->t_logits = cur; - cur = build_cvec(cur, il); - cb(cur, "l_out", il); + ggml_build_forward_expand(gf, cur); + return; + } + + // === Main model graph === + inpL = build_inp_embd(model.tok_embd); + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); - // input for next layer - inpL = cur; + for (int il = 0; il < n_layer_main; ++il) { + ggml_tensor * out_ids = (il == n_layer_main - 1) ? 
inp_out_ids : nullptr; + inpL = build_layer(inpL, il, inp_pos, inp_attn, out_ids); } cur = inpL; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 5523f23b540..52cb4777e8c 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -92,6 +92,8 @@ struct server_slot { bool has_next_token = true; bool has_new_line = false; bool truncated = false; + bool mtp_cache_reused = false; + llama_pos mtp_common_prefix_len = 0; stop_type stop; @@ -148,6 +150,7 @@ struct server_slot { llama_token sampled; // in speculative mode, this is the last accepted token llama_tokens drafted; + std::vector mtp_prompt_hidden; // stats size_t n_sent_text = 0; // number of sent text character @@ -179,6 +182,9 @@ struct server_slot { drafted.clear(); i_batch_dft.clear(); + mtp_prompt_hidden.clear(); + mtp_cache_reused = false; + mtp_common_prefix_len = 0; generated_tokens.clear(); generated_token_probs.clear(); json_schema = json(); @@ -226,8 +232,9 @@ struct server_slot { GGML_ASSERT(task); return - !task->need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + !need_embd() || + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) || + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); } bool can_batch_with(server_slot & other_slot) const { @@ -259,7 +266,17 @@ struct server_slot { } bool can_speculate() const { - return !!spec; + return !!spec && !(task && task->params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP && mtp_cache_reused); + } + + bool need_embd() const { + if (!task) { + return false; + } + if (task->need_embd()) { + return true; + } + return can_speculate() && task->params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP; } void add_token(const completion_token_output & token) { @@ -451,10 +468,54 @@ struct server_slot { other.n_prompt_tokens_processed = n_prompt_tokens_processed; other.prompt = prompt.clone(); + other.mtp_prompt_hidden = mtp_prompt_hidden; + other.mtp_common_prefix_len = 0; other.init_sampler(); } }; +static void server_slot_append_mtp_prompt_hidden( + server_slot & slot, + llama_context * ctx, + const llama_batch & batch_view) { + if (slot.spec == nullptr || + slot.task == nullptr || + !slot.can_speculate() || + slot.task->params.speculative.type != COMMON_SPECULATIVE_TYPE_MTP || + (slot.state != SLOT_STATE_PROCESSING_PROMPT && slot.state != SLOT_STATE_DONE_PROMPT)) { + return; + } + + const int32_t n_embd = llama_model_n_embd(llama_get_model(ctx)); + + int32_t output_idx = 0; + for (int32_t i = 0; i < batch_view.n_tokens; ++i) { + const bool is_output = batch_view.logits ? 
batch_view.logits[i] != 0 : i == batch_view.n_tokens - 1; + if (!is_output) { + continue; + } + + bool matches_slot = false; + if (batch_view.seq_id != nullptr && batch_view.n_seq_id != nullptr) { + for (int32_t s = 0; s < batch_view.n_seq_id[i]; ++s) { + if (batch_view.seq_id[i][s] == slot.id) { + matches_slot = true; + break; + } + } + } + + if (matches_slot) { + float * hidden = llama_get_embeddings_ith(ctx, output_idx); + if (hidden != nullptr) { + slot.mtp_prompt_hidden.insert(slot.mtp_prompt_hidden.end(), hidden, hidden + n_embd); + } + } + + output_idx++; + } +} + // @@ -656,8 +717,6 @@ struct server_context_impl { add_bos_token = llama_vocab_get_add_bos(vocab); if (params_base.speculative.has_dft()) { - SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str()); - const auto & params_spec = params_base.speculative; auto params_dft = params_base; @@ -678,6 +737,8 @@ struct server_context_impl { params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides; + SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str()); + auto mparams_dft = common_model_params_to_llama(params_dft); model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); @@ -688,6 +749,26 @@ struct server_context_impl { params_base.speculative.model_dft = model_dft.get(); params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft); + } else if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) { + const auto & params_spec = params_base.speculative; + + auto params_dft = params_base; + + params_dft.n_parallel = 1; + params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx; + params_dft.n_batch = llama_n_ctx_seq(ctx); + params_dft.devices = params_spec.devices; + params_dft.n_gpu_layers = params_spec.n_gpu_layers; + params_dft.cache_type_k = params_spec.cache_type_k; + params_dft.cache_type_v = params_spec.cache_type_v; + + if (params_spec.cpuparams.n_threads > 0) { + params_dft.cpuparams.n_threads = params_spec.cpuparams.n_threads; + params_dft.cpuparams_batch.n_threads = params_spec.cpuparams_batch.n_threads; + } + + params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides; + params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft); } std::string & mmproj_path = params_base.mmproj.path; @@ -1046,6 +1127,12 @@ struct server_context_impl { if (!ret->prompt_load(*prompt_cache, task.tokens)) { ret->prompt_clear(false); + } else if (task.params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) { + common_speculative_invalidate_retained_state(ret->spec); + ret->mtp_prompt_hidden.clear(); + ret->mtp_cache_reused = true; + ret->mtp_common_prefix_len = 0; + SLT_WRN(*ret, "%s", "disabling MTP speculative on prompt-cache reuse until hidden-state reconstruction is implemented\n"); } prompt_cache->update(); @@ -2076,6 +2163,7 @@ struct server_context_impl { // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; + bool batch_need_embd = false; auto accept_special_token = [&](server_slot & slot, llama_token token) { return params_base.special || @@ -2094,6 +2182,7 @@ struct server_context_impl { } else if (!slot_batched->can_batch_with(slot)) { continue; } + batch_need_embd = batch_need_embd || slot.need_embd(); // generate draft tokens in speculative decoding mode // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] @@ -2123,10 +2212,16 @@ 
struct server_context_impl { if (slot.task->params.speculative.n_min > (int) draft.size()) { SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); - // fallback to normal decoding - slot.i_batch = slot.i_batch_dft[0]; slot.drafted.clear(); - slot.i_batch_dft.clear(); + if (slot.task->params.speculative.type != COMMON_SPECULATIVE_TYPE_MTP) { + // Non-MTP speculation can safely fall back to plain decoding. + slot.i_batch = slot.i_batch_dft[0]; + slot.i_batch_dft.clear(); + } else { + // MTP still needs a 0-accept speculative round so accept() can stage + // the frontier hidden state for the next shifted first pass. + slot.i_batch = -1; + } } else { // keep track of total number of drafted tokens tested slot.n_draft_total += draft.size(); @@ -2188,6 +2283,8 @@ struct server_context_impl { if (slot.state == SLOT_STATE_STARTED) { slot.t_start_process_prompt = ggml_time_us(); slot.t_start_generation = 0; + slot.mtp_prompt_hidden.clear(); + slot.mtp_common_prefix_len = 0; slot.state = SLOT_STATE_PROCESSING_PROMPT; @@ -2271,7 +2368,15 @@ struct server_context_impl { n_past = std::min(n_past, slot.alora_invocation_start - 1); } - const auto n_cache_reuse = slot.task->params.n_cache_reuse; + auto n_cache_reuse = slot.task->params.n_cache_reuse; + const bool is_mtp_request = + slot.spec != nullptr && + slot.task->params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP; + + if (is_mtp_request && n_cache_reuse > 0) { + SLT_WRN(slot, "MTP only supports continuous prefix reuse for now - ignoring n_cache_reuse = %d\n", n_cache_reuse); + n_cache_reuse = 0; + } const bool can_cache_reuse = llama_memory_can_shift(llama_get_memory(ctx)) && @@ -2456,7 +2561,15 @@ struct server_context_impl { SLT_WRN(slot, "n_past was set to %d\n", n_past); } - slot.n_prompt_tokens_cache = n_past; + if (slot.spec != nullptr && + slot.can_speculate() && + slot.task->params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) { + const llama_pos target_prefix_len = std::max(n_past, 0); + const llama_pos draft_prefix_len = common_speculative_get_committed_prefix_len(slot.spec); + slot.mtp_common_prefix_len = std::min(target_prefix_len, draft_prefix_len); + } + + slot.n_prompt_tokens_cache = n_past; slot.n_prompt_tokens_processed = 0; slot.prompt.tokens.keep_first(n_past); @@ -2566,7 +2679,7 @@ struct server_context_impl { cur_tok, slot.prompt.tokens.pos_next(), { slot.id }, - slot.task->need_embd()); + slot.need_embd()); slot.prompt.tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; @@ -2686,6 +2799,7 @@ struct server_context_impl { if (!slot_batched) { slot_batched = &slot; } + batch_need_embd = batch_need_embd || slot.need_embd(); if (batch.n_tokens >= n_batch) { break; @@ -2706,7 +2820,7 @@ struct server_context_impl { slot_batched->lora[alora_disabled_id].scale = alora_scale; } - llama_set_embeddings(ctx, slot_batched->task->need_embd()); + llama_set_embeddings(ctx, batch_need_embd); } if (batch.n_tokens == 0) { @@ -2794,6 +2908,10 @@ struct server_context_impl { // on successful decode, restore the original batch size n_batch = llama_n_batch(ctx); + for (auto & slot : slots) { + server_slot_append_mtp_prompt_hidden(slot, ctx, batch_view); + } + // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too for (auto & slot : slots) { if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) { @@ -2851,7 +2969,41 @@ struct server_context_impl { slot.state = SLOT_STATE_GENERATING; if (slot.can_speculate()) { - 
common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens()); + const auto & prompt_tokens = slot.prompt.tokens.get_text_tokens(); + if (slot.task->params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) { + const int32_t n_embd = llama_model_n_embd(model); + const llama_pos reuse_len = std::min( + slot.mtp_common_prefix_len, + (llama_pos) prompt_tokens.size()); + llama_tokens prompt_tail_tokens( + prompt_tokens.begin() + reuse_len, + prompt_tokens.end()); + + common_speculative_begin(slot.spec, prompt_tokens, reuse_len); + + const int64_t expected_hidden = (int64_t) prompt_tail_tokens.size()*n_embd; + if ((int64_t) slot.mtp_prompt_hidden.size() != expected_hidden) { + SLT_WRN(slot, "MTP prompt tail hidden size mismatch (%zu vs %lld) - clearing initial source\n", + slot.mtp_prompt_hidden.size(), (long long) expected_hidden); + common_speculative_set_first_pass_source( + slot.spec, + llama_tokens(), + nullptr, + 0, + n_embd, + reuse_len); + } else { + common_speculative_set_first_pass_source( + slot.spec, + prompt_tail_tokens, + expected_hidden > 0 ? slot.mtp_prompt_hidden.data() : nullptr, + prompt_tail_tokens.size(), + n_embd, + reuse_len); + } + } else { + common_speculative_begin(slot.spec, prompt_tokens); + } } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots @@ -2912,6 +3064,7 @@ struct server_context_impl { // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted); + const std::vector batch_idxs = slot.i_batch_dft; slot.i_batch_dft.clear(); slot.drafted.clear(); @@ -2925,7 +3078,7 @@ struct server_context_impl { slot.n_draft_accepted += ids.size() - 1; // inform the speculative decoding about the number of accepted tokens - common_speculative_accept(slot.spec, ids.size() - 1); + common_speculative_accept(slot.spec, ids.size() - 1, batch_idxs); // rollback to the state before sampling the draft tokens slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
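
The hand-off above reduces to simple size bookkeeping: the slot accumulates one `n_embd`-sized row of hidden state per prompt output token, and only the rows past the reused prefix are forwarded as the MTP first-pass source; on any mismatch the source is discarded rather than risking misaligned hidden states. The following is a small self-contained sketch of that consistency check under toy sizes — the struct and helper names are hypothetical stand-ins, not the server's actual types.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the slot-side bookkeeping: hidden-state rows are appended
// one per prompt output token, then the tail after the reused prefix is handed off.
struct mtp_first_pass_source {
    std::vector<int32_t> tail_tokens;   // prompt tokens after the reused prefix
    std::vector<float>   tail_hidden;   // [tail_tokens.size() x n_embd], row-major
};

static bool build_first_pass_source(
        const std::vector<int32_t> & prompt_tokens,
        const std::vector<float>   & prompt_hidden, // accumulated rows, n_embd floats each
        int64_t                      n_embd,
        int64_t                      reuse_len,     // prefix whose draft state is reused
        mtp_first_pass_source      & out) {
    const int64_t n_tail = (int64_t) prompt_tokens.size() - reuse_len;
    if (n_tail < 0) {
        return false;
    }
    // same consistency check as the server: exactly one hidden row per tail token, or give up
    if ((int64_t) prompt_hidden.size() != n_tail*n_embd) {
        return false;
    }
    out.tail_tokens.assign(prompt_tokens.begin() + reuse_len, prompt_tokens.end());
    out.tail_hidden = prompt_hidden;
    return true;
}

int main() {
    const int64_t n_embd = 4; // toy embedding width
    std::vector<int32_t> prompt = {11, 12, 13, 14, 15};
    std::vector<float>   hidden(3*n_embd, 0.5f); // rows for the 3 non-reused tokens only

    mtp_first_pass_source src;
    const bool ok = build_first_pass_source(prompt, hidden, n_embd, /*reuse_len=*/2, src);
    std::printf("ok=%d tail=%zu rows=%zu\n", ok ? 1 : 0,
                src.tail_tokens.size(), src.tail_hidden.size()/n_embd);
    return 0;
}
```

When the check fails, the server code above clears the source and logs the size-mismatch warning, which mirrors the prompt-cache-reuse path where MTP speculation is disabled until hidden-state reconstruction exists.
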