Skip to content

Turbo dflash#103

Open
aminya wants to merge 20 commits intoTheTom:feature/turboquant-kv-cachefrom
aminya:turbo-dflash
Open

Turbo dflash#103
aminya wants to merge 20 commits intoTheTom:feature/turboquant-kv-cachefrom
aminya:turbo-dflash

Conversation

@aminya
Copy link
Copy Markdown

@aminya aminya commented Apr 23, 2026

Overview

This cherry-picks and fixes the conflicts of Dflash PR.
ggml-org#22105

This PR is built on top of my previous PR ggml-org#18039 (EAGLE3) and currently includes its commits. The reason is that Eagle3 and DFlash have many similarities. Please focus on the DFlash-specific commit(s), the EAGLE3 commits will disappear from the diff once ggml-org#18039 merges.

This PR adds DFlash speculative decoding to llama.cpp, achieving up to 8x speedup (Qwen3) with full numerical equivalence to the reference original implementation.

Compared to EAGLE3 - which uses an autoregressive draft and generates one token per draft step, DFlash produces an entire block of candidates in a single draft forward pass, resulting in higher per-iteration draft throughput. However, DFlash relies on multiple transformer layers for its draft model, whereas EAGLE3 uses only a single transformer layer.

There is still quite meaningful headroom for further performance improvements with current implementation, summarized in the Future Performance Work section below.

Performance Evaluation (NVIDIA L40S 48GB)

Numbers below were collected with --draft-max 16, --temp 0 --top-k 1 --seed 42, n=256. Baseline is llama-cli running the target model alone with the same sampling parameters.

"Thinking on/off" toggles reasoning via the LLAMA_SPEC_NO_THINK env var. Turning off thinking generally yields a higher acceptance rate with DFlash, which may be due to the nature of its training data.

Qwen3-8B

Draft: z-lab/Qwen3-8B-DFlash (bf16), Target: Qwen/Qwen3-8B (bf16)

Prompt block_size Baseline (t/s) DFlash w/ thinking (t/s) Speedup Accept Rate DFlash w/o thinking (t/s) Speedup Accept Rate
Write a quicksort algorithm in Python. Write code only. 16 51.9 92.0 1.77x 12.0% 419.3 8.08x 93.3%
Explain the Pythagorean theorem 16 51.6 95.7 1.85x 13.7% 133.8 2.59x 20.9%
Plan a 1 day trip to DC 16 51.6 56.5 1.09x 4.9% 76.7 1.49x 8.9%

Qwen3-4B

Draft: z-lab/Qwen3-4B-DFlash (bf16), Target: Qwen/Qwen3-4B (bf16)

Prompt block_size Baseline (t/s) DFlash w/ thinking (t/s) Speedup Accept Rate DFlash w/o thinking (t/s) Speedup Accept Rate
Write a quicksort algorithm in Python. Write code only. 16 91.0 138.9 1.53x 11.3% 537.9 5.91x 93.3%
Explain the Pythagorean theorem 16 91.1 130.5 1.43x 11.3% 187.3 2.06x 18.0%
Plan a 1 day trip to DC 16 91.2 102.3 1.12x 6.9% 123.7 1.36x 9.0%

GPT-OSS-20B

Draft: z-lab/gpt-oss-20b-DFlash (bf16), Target: openai/gpt-oss-20b (bf16)

Prompt block_size Baseline (t/s) DFlash w/ thinking (t/s) Speedup Accept DFlash w/o thinking (t/s) Speedup Accept
Write a quicksort algorithm in Python. Write code only. 8 171.0 167.9 0.98x 38.0% 216.8 1.27x 55.6%
Explain the Pythagorean theorem 8 172.0 147.0 0.85x 31.0% 178.0 1.03x 42.0%
Plan a 1 day trip to DC 8 171.9 105.2 0.61x 16.6% 120.5 0.70x 21.9%

For MoE targets (gpt-oss-20b), DFlash speedup is generally smaller than for dense attention targets because more experts get activated during the parallel verification step than during single-token autoregressive decoding (same observation as in ggml-org#18039 for gpt-oss EAGLE3).

Qwen3.5-4B (With Performance Issue)

Draft: z-lab/Qwen3.5-4B-DFlash (bf16), Target: Qwen/Qwen3.5-4B (bf16)

Prompt block_size Baseline (t/s) DFlash w/ thinking (t/s) Speedup Accept DFlash w/o thinking (t/s) Speedup Accept
Write a quicksort algorithm in Python. Write code only. 16 82.4 109.7 1.33x 29.0% 276.6 3.36x 84.8%
Explain the Pythagorean theorem 16 81.9 92.3 1.13x 22.9% 97.6 1.19x 25.6%
Plan a 1 day trip to DC 16 81.3 69.6 0.86x 15.5% 51.2 0.63x 9.3%

Speedup is intrinsically limited on hybrid target models:

  • For Hybrid targets (Qwen3.5, Jamba, ...), when target verify draft tokens, llama.cpp writes KV / recurrent state for the full [id_last + draft block] before acceptance is known.
  • Pure-attention target models can drop rejected suffixes with seq_rm; hybrid targets cannot, because recurrent state is not decomposable by token position.
  • Current workaround in examples/speculative-simple/speculative-simple.cpp:
    • snapshot target state before verify
    • on rejection, restore + replay(rerun target model forward) only the accepted prefix to recover recurrent state
  • Cost: each rejected step requires one extra target forward, which is the main reason hybrid speedup lags pure-attention.

How to run DFlash in llama.cpp

Step 1: Convert models to GGUF

TARGET_MODEL_HF="${MODELS_DIR}/Qwen3-8B"
TARGET_MODEL_GGUF="${MODELS_DIR}/Qwen3-8B.gguf"
DFLASH_MODEL_HF="${MODELS_DIR}/Qwen3-8B-DFlash-b16"
DFLASH_MODEL_GGUF="${MODELS_DIR}/Qwen3-8B-DFlash-b16.gguf"

python convert_hf_to_gguf.py \
    "${TARGET_MODEL_HF}" \
    --outtype bf16 \
    --outfile "${TARGET_MODEL_GGUF}"

python convert_hf_to_gguf.py \
    "${DFLASH_MODEL_HF}" \
    --outtype bf16 \
    --target-model-dir "${TARGET_MODEL_HF}" \
    --outfile "${DFLASH_MODEL_GGUF}"

Step 2: Build llama.cpp

cmake -B build -DGGML_CUDA=ON
cmake --build build --config Release -j

Step 3: Run DFlash speculative decoding

# thinking off: set LLAMA_SPEC_NO_THINK=1
# Omit it to test thinking-mode behavior
export LLAMA_SPEC_NO_THINK=1

for prompt in \
    "Write a quicksort algorithm in Python. Write code only." \
    "Explain the Pythagorean theorem" \
    "Plan a 1 day trip to DC"; do
  echo "=== Prompt: $prompt ==="
  ./build/bin/llama-speculative-simple \
    -m  "${TARGET_MODEL_GGUF}" \
    -md "${DFLASH_MODEL_GGUF}" \
    --dflash -p "$prompt" -n 256 \
    --draft-max 16 \
    -cd 512 -c 1024 \
    --temp 0 --top-k 1 --seed 42 \
    -ngl 99 -ngld 99
done

Future Performance Work

KV cache / graph reuse for the DFlash decoder

The DFlash decoder currently rebuilds its graph every iteration (graphs reused = 0). The main cause is that cross.n_enc (the length of accumulated_target_ctx) grows monotonically, which changes the shape of target_ctx and invalidates all downstream tensor shapes.

Possible improvements:

  • add a draft-side KV cache to the DFlash decoder.
    This would make the implementation closer to the original reference: committed target-context K/V would be materialized once and reused across iterations, instead of recomputing K/V from the full accumulated context every step. This reduces draft-side compute and also makes graph shapes much more stable, which should improve graph reuse. Since the DFlash decoder attention includes both cross-attention and self-attention, the current llama.cpp implementation does not support this pattern well.

  • keep the current no-cache design, but fix the target_ctx input shape.
    Instead of letting target_ctx grow every iteration, reserve a fixed-size buffer, track the active length separately, and mask out the padded region in attention. This preserves the current semantics while allowing the decoder graph to be reused. This method is not ideal compared to using a KV cache.

Hybrid target model performance improvement (For all speculative decoding methods)

Hybrid targets (e.g. Qwen3.5) are slower because the problem is no longer just draft-side graph reuse. During target verify, llama.cpp writes KV / recurrent state for the full draft block before acceptance is known. Pure-attention target models can discard rejected suffixes with seq_rm, but hybrid targets cannot, because their recurrent state is not decomposable by token position.

The current workaround is:

  • snapshot the target state before verify
  • on rejection, restore the snapshot
  • replay only the accepted prefix

This is correct, but each rejected step may require one extra target forward, which is the main reason hybrid speedup lags pure-attention.
A more fundamental future improvement would be target-side deferred commit (SGLang Implementation): verify would compute temporary recurrent states, and only the accepted-prefix state would be committed. That would remove replay from the hybrid path, but it requires deeper changes to llama.cpp’s recurrent-state update flow.
Note this applies to all hybrid models used as target models in speculative decoding methods, not just DFlash.

More (Low Priority)

  • Draft-side sampling fast path: For greedy / no-grammar mode, batch argmax over the entire drafted block instead of invoking the sampler one token at a time.
  • CUDA graph for both draft model and target model
  • ....

Requirements

ruixiang63 and others added 20 commits December 14, 2025 18:12
EAGLE3 is an encoder-decoder based speculative decoding method:
- Extracts features from target model at specific layers
- Uses feature fusion layer to compress target features
- Generates draft tokens with single-layer decoder
- Maps draft vocabulary to target vocabulary via d2t tensor

Key changes:
- Add LLM_ARCH_EAGLE3 architecture
- Add EAGLE3 encoder/decoder graph (src/models/eagle3.cpp)
- Add feature extraction from target model layers
- Add g_embeddings handling for decoder input
- Add GGML_TENSOR_FLAG_SYNC for GPU synchronization
- Add --eagle3 flag for speculative-simple example
- Add EAGLE3 model conversion in convert_hf_to_gguf.py
@nybblr
Copy link
Copy Markdown

nybblr commented Apr 23, 2026

Trying to take this for a spin, looks like a bunch of attention biases were recently renamed upstream (and merged to the turboquant fork):
4f02d47

Got it compiling here:
https://github.com/nybblr/llama.cpp/commits/nybblr/turboquant-dflash/

Now the question is how to test it on llama-server (dflash flag doesn't seem allowed?)

@nybblr
Copy link
Copy Markdown

nybblr commented Apr 23, 2026

I naively resolved merge conflicts and enabled --dflash as a valid llama-server argument. The model and draft successfully load, but on attempting to inference, the server crashes with:

GGML_ASSERT(!dflash.target_features.empty() && "DFlash target features not extracted") failed

@aminya Any pointers on what it should take to make this work with llama-server?

@aminya
Copy link
Copy Markdown
Author

aminya commented Apr 23, 2026

Yes. The server isn't set up but the CLI works. I couldn't get the speed ups I was expecting when using Qwopus. I will try with more models

@taniguchi-taku-softm
Copy link
Copy Markdown

@nybblr

Currently only llama_set_eagle3 is called. You should similarly add a call to llama_set_dflash.

llama-cpp-turboquant/tools/server/server-context.cpp:801

            if (params_base.speculative.eagle3) {
                // EAGLE3 current limitation: extracted target features are per-context; multiple slots would overwrite each other
                if (params_base.n_parallel > 1) {
                    SRV_ERR("%s", "EAGLE3 speculative decoding is not supported with n_parallel > 1\n");
                    return false;
                }
                llama_set_eagle3(ctx, model_dft.get());
                SRV_INF("%s", "EAGLE3 feature extraction enabled on target model\n");
            }
            // TODO: params_base.speculative.dflash

@nybblr
Copy link
Copy Markdown

nybblr commented Apr 24, 2026

Thanks @taniguchi-taku-softm. Got it running on llama-server now, but performance instantly tanked to 20 t/s from 54 t/s on Qwen3.6. I only have a 12GB nvidia card, so I rely heavily on CPU mapped memory. Throwing a draft model at it seems to cause a lot of crashing (from miscalculations I'd guess), and trying to set --ngl on the main model to make space for the draft model on the GPU, seems to instantly kill performance.

@aminya Do you think it's moot trying to get this technique working for a VRAM starved setup like mine? I was thinking a tiny draft model like DFlash could help, even at the expense of consuming some VRAM from the main model.

@yansheng1003
Copy link
Copy Markdown

Vulkan is support?

@aminya
Copy link
Copy Markdown
Author

aminya commented Apr 27, 2026

@nybblr I've tried many different patches and configurations over the weekend for my single 3090 setup. There's no benefit in Dflash I can see. I cannot reproduce any of the claimed speed ups in real workflows.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants