
[Speculative decoding] feat: add DFlash support #22105

Draft
ruixiang63 wants to merge 22 commits into ggml-org:master from ruixiang63:dflash

Conversation


@ruixiang63 ruixiang63 commented Apr 19, 2026

Overview

This PR is built on top of my previous PR #18039 (EAGLE3) and currently includes its commits, since EAGLE3 and DFlash have many similarities. Please focus on the DFlash-specific commits; the EAGLE3 commits will disappear from the diff once #18039 merges.

This PR adds DFlash speculative decoding to llama.cpp, achieving up to 8x speedup (Qwen3) with full numerical equivalence to the original reference implementation.

Compared to EAGLE3, which uses an autoregressive draft model and generates one token per draft step, DFlash produces an entire block of candidates in a single draft forward pass, resulting in higher per-iteration draft throughput. However, DFlash relies on multiple transformer layers for its draft model, whereas EAGLE3 uses only a single transformer layer.
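
To make the verify step concrete, here is a minimal sketch of the greedy prefix-acceptance rule that block drafting pairs with (illustrative Python, not the code in this PR):

```python
# Minimal sketch of greedy prefix acceptance after one DFlash iteration.
# draft_block:   the block of candidates from a single draft forward pass.
# target_tokens: the target's greedy picks at positions [id_last, d_1, ..., d_B],
#                i.e. one token longer than the drafted block.
def accept_block(draft_block: list[int], target_tokens: list[int]) -> list[int]:
    assert len(target_tokens) == len(draft_block) + 1
    accepted = [target_tokens[0]]             # the target's own token is always kept
    for i, d in enumerate(draft_block):
        if d != target_tokens[i]:             # first mismatch ends acceptance
            break
        accepted.append(target_tokens[i + 1])
    return accepted

# block_size = 4: the first three drafted tokens match, so 4 tokens are committed
print(accept_block([5, 9, 2, 7], [5, 9, 2, 3, 8]))   # -> [5, 9, 2, 3]
```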

There is still meaningful headroom for further performance improvements in the current implementation; this is summarized in the Future Performance Work section below.

Performance Evaluation (NVIDIA L40S 48GB)

Numbers below were collected with --draft-max 16, --temp 0 --top-k 1 --seed 42, n=256. Baseline is llama-cli running the target model alone with the same sampling parameters.

"Thinking on/off" toggles reasoning via the LLAMA_SPEC_NO_THINK env var. Turning off thinking generally yields a higher acceptance rate with DFlash, which may be due to the nature of its training data.

Qwen3-8B

Draft: z-lab/Qwen3-8B-DFlash (bf16), Target: Qwen/Qwen3-8B (bf16)

| Prompt | block_size | Baseline (t/s) | DFlash w/ thinking (t/s) | Speedup | Accept Rate | DFlash w/o thinking (t/s) | Speedup | Accept Rate |
|---|---|---|---|---|---|---|---|---|
| Write a quicksort algorithm in Python. Write code only. | 16 | 51.9 | 92.0 | 1.77x | 12.0% | 419.3 | 8.08x | 93.3% |
| Explain the Pythagorean theorem | 16 | 51.6 | 95.7 | 1.85x | 13.7% | 133.8 | 2.59x | 20.9% |
| Plan a 1 day trip to DC | 16 | 51.6 | 56.5 | 1.09x | 4.9% | 76.7 | 1.49x | 8.9% |

Qwen3-4B

Draft: z-lab/Qwen3-4B-DFlash (bf16), Target: Qwen/Qwen3-4B (bf16)

| Prompt | block_size | Baseline (t/s) | DFlash w/ thinking (t/s) | Speedup | Accept Rate | DFlash w/o thinking (t/s) | Speedup | Accept Rate |
|---|---|---|---|---|---|---|---|---|
| Write a quicksort algorithm in Python. Write code only. | 16 | 91.0 | 138.9 | 1.53x | 11.3% | 537.9 | 5.91x | 93.3% |
| Explain the Pythagorean theorem | 16 | 91.1 | 130.5 | 1.43x | 11.3% | 187.3 | 2.06x | 18.0% |
| Plan a 1 day trip to DC | 16 | 91.2 | 102.3 | 1.12x | 6.9% | 123.7 | 1.36x | 9.0% |

GPT-OSS-20B

Draft: z-lab/gpt-oss-20b-DFlash (bf16), Target: openai/gpt-oss-20b (bf16)

| Prompt | block_size | Baseline (t/s) | DFlash w/ thinking (t/s) | Speedup | Accept | DFlash w/o thinking (t/s) | Speedup | Accept |
|---|---|---|---|---|---|---|---|---|
| Write a quicksort algorithm in Python. Write code only. | 8 | 171.0 | 167.9 | 0.98x | 38.0% | 216.8 | 1.27x | 55.6% |
| Explain the Pythagorean theorem | 8 | 172.0 | 147.0 | 0.85x | 31.0% | 178.0 | 1.03x | 42.0% |
| Plan a 1 day trip to DC | 8 | 171.9 | 105.2 | 0.61x | 16.6% | 120.5 | 0.70x | 21.9% |

For MoE targets (gpt-oss-20b), the DFlash speedup is generally smaller than for dense targets because more experts are activated during the parallel verification step than during single-token autoregressive decoding (the same observation as in #18039 for gpt-oss EAGLE3).
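
As a rough illustration of the effect (placeholder numbers, not gpt-oss's exact MoE config):

```python
# Rough estimate of expert-weight traffic: verify vs. plain decode.
# The numbers are placeholders; substitute the target's real MoE config.
n_expert      = 32   # experts per MoE layer (placeholder)
n_expert_used = 4    # experts routed per token (placeholder)
block_size    = 8    # tokens verified in one target forward

experts_decode = n_expert_used                                  # single-token decode
experts_verify = min(n_expert, block_size * n_expert_used)      # parallel verify

print(experts_decode, experts_verify)   # 4 vs 32: far more expert weights touched per step
```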

Qwen3.5-4B (With Performance Issue)

Draft: z-lab/Qwen3.5-4B-DFlash (bf16), Target: Qwen/Qwen3.5-4B (bf16)

| Prompt | block_size | Baseline (t/s) | DFlash w/ thinking (t/s) | Speedup | Accept | DFlash w/o thinking (t/s) | Speedup | Accept |
|---|---|---|---|---|---|---|---|---|
| Write a quicksort algorithm in Python. Write code only. | 16 | 82.4 | 131.7 | 1.60x | 36.3% | 293.9 | 3.57x | 85.7% |
| Explain the Pythagorean theorem | 16 | 81.9 | 124.0 | 1.51x | 34.1% | 120.7 | 1.47x | 38.2% |
| Plan a 1 day trip to DC | 16 | 81.3 | 102.1 | 1.26x | 26.7% | 75.5 | 0.93x | 17.7% |

Speedup is intrinsically limited on hybrid target models:

  • For hybrid targets (Qwen3.5, ...), when the target verifies draft tokens, llama.cpp writes KV / recurrent state for the full [id_last + draft block] before acceptance is known.
  • Pure-attention target models can drop rejected suffixes with seq_rm; hybrid targets cannot, because recurrent state is not decomposable by token position.
  • Current workaround in examples/speculative-simple/speculative-simple.cpp:
    • snapshot the target state before verify
    • on rejection, restore the snapshot and replay (re-run the target forward) only the accepted prefix to recover the recurrent state
  • Cost: each rejected step requires one extra target forward, which is the main reason hybrid speedup lags pure-attention.
  • Thanks to #19493 (server: speculative checkpointing) and #22227 (speculative-simple: add checkpoint support), llama.cpp now supports fallback for hybrid model states.

Qwen3.5-9B

Draft: z-lab/Qwen3.5-9B-DFlash (bf16), Target: Qwen/Qwen3.5-9B (bf16)

| Prompt | block_size | Baseline (t/s) | DFlash w/ thinking (t/s) | Speedup | Accept | DFlash w/o thinking (t/s) | Speedup | Accept |
|---|---|---|---|---|---|---|---|---|
| Write a quicksort algorithm in Python. Write code only. | 16 | 47.6 | 64.0 | 1.34x | 30.5% | 131.8 | 2.77x | 55.2% |
| Explain the Pythagorean theorem | 16 | 48.0 | 91.2 | 1.90x | 44.3% | 88.6 | 1.85x | 35.1% |
| Plan a 1 day trip to DC | 16 | 48.0 | 69.6 | 1.45x | 25.7% | 52.5 | 1.10x | 17.7% |

How to run DFlash in llama.cpp

Step 1: Convert models to GGUF

```bash
TARGET_MODEL_HF="${MODELS_DIR}/Qwen3-8B"
TARGET_MODEL_GGUF="${MODELS_DIR}/Qwen3-8B.gguf"
DFLASH_MODEL_HF="${MODELS_DIR}/Qwen3-8B-DFlash-b16"
DFLASH_MODEL_GGUF="${MODELS_DIR}/Qwen3-8B-DFlash-b16.gguf"

python convert_hf_to_gguf.py \
    "${TARGET_MODEL_HF}" \
    --outtype bf16 \
    --outfile "${TARGET_MODEL_GGUF}"

python convert_hf_to_gguf.py \
    "${DFLASH_MODEL_HF}" \
    --outtype bf16 \
    --target-model-dir "${TARGET_MODEL_HF}" \
    --outfile "${DFLASH_MODEL_GGUF}"
```
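
If you want a quick sanity check that conversion produced valid GGUF files before quantizing, a minimal header check such as the one below works (illustrative Python; it reads only the raw GGUF header and does not depend on gguf-py):

```python
# check_gguf.py - print GGUF version / tensor count / KV count for each file.
# Reads only the fixed-size GGUF header (magic, version, tensor count, KV count).
import struct
import sys

def gguf_header(path: str):
    with open(path, "rb") as f:
        magic = f.read(4)
        if magic != b"GGUF":
            raise ValueError(f"{path}: not a GGUF file (magic={magic!r})")
        version, n_tensors, n_kv = struct.unpack("<IQQ", f.read(20))
    return version, n_tensors, n_kv

for p in sys.argv[1:]:
    print(p, gguf_header(p))

# usage: python check_gguf.py Qwen3-8B.gguf Qwen3-8B-DFlash-b16.gguf
```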

[Optional] Step 2: Quantize GGUF models

```bash
TARGET_MODEL_GGUF="${MODELS_DIR}/Qwen3-8B.gguf"
DFLASH_MODEL_GGUF="${MODELS_DIR}/Qwen3-8B-DFlash-b16.gguf"

./build/bin/llama-quantize \
  ${TARGET_MODEL_GGUF} \
  ${TARGET_MODEL_GGUF}_Q4_K_M.gguf \
  Q4_K_M

./build/bin/llama-quantize \
  ${DFLASH_MODEL_GGUF} \
  ${DFLASH_MODEL_GGUF}_Q4_K_M.gguf \
  Q4_K_M
```

Step 3: Build llama.cpp

```bash
cmake -B build -DGGML_CUDA=ON
cmake --build build --config Release -j
```

Step 4: Run DFlash speculative decoding

```bash
# thinking off: set LLAMA_SPEC_NO_THINK=1
# Omit it to test thinking-mode behavior
export LLAMA_SPEC_NO_THINK=1

for prompt in \
    "Write a quicksort algorithm in Python. Write code only." \
    "Explain the Pythagorean theorem" \
    "Plan a 1 day trip to DC"; do
  echo "=== Prompt: $prompt ==="
  ./build/bin/llama-speculative-simple \
    -m  "${TARGET_MODEL_GGUF}" \
    -md "${DFLASH_MODEL_GGUF}" \
    --dflash -p "$prompt" -n 256 \
    --draft-max 16 \
    -cd 512 -c 1024 \
    --temp 0 --top-k 1 --seed 42 \
    -ngl 99 -ngld 99
done
```

Future Performance Work

KV cache / graph reuse for the DFlash decoder

The DFlash decoder currently rebuilds its graph every iteration (graphs reused = 0). The main cause is that cross.n_enc (the length of accumulated_target_ctx) grows monotonically, which changes the shape of target_ctx and invalidates all downstream tensor shapes.

Possible improvements:

  • add a draft-side KV cache to the DFlash decoder.
    This would make the implementation closer to the original reference: committed target-context K/V would be materialized once and reused across iterations, instead of recomputing K/V from the full accumulated context every step. This reduces draft-side compute and also makes graph shapes much more stable, which should improve graph reuse. Since the DFlash decoder attention includes both cross-attention and self-attention, the current llama.cpp implementation does not support this pattern well.

  • keep the current no-cache design, but fix the target_ctx input shape.
    Instead of letting target_ctx grow every iteration, reserve a fixed-size buffer, track the active length separately, and mask out the padded region in attention. This preserves the current semantics while allowing the decoder graph to be reused. This method is not ideal compared to using a KV cache.
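
For illustration only (numpy, not ggml), the padding-and-mask variant could look roughly like this:

```python
# Illustration of the fixed-capacity target_ctx idea (numpy, not ggml):
# pad the accumulated context to a constant CAP so graph shapes never change,
# and mask the unused tail with -inf in the additive attention mask.
import numpy as np

CAP    = 1024   # reserved target_ctx capacity -> constant graph shape
n_embd = 5120

def pad_target_ctx(target_ctx: np.ndarray):
    """target_ctx: [n_active, n_embd] -> ([CAP, n_embd], additive mask [CAP])."""
    n_active = target_ctx.shape[0]
    padded = np.zeros((CAP, n_embd), dtype=target_ctx.dtype)
    padded[:n_active] = target_ctx
    mask = np.full((CAP,), -np.inf, dtype=np.float32)
    mask[:n_active] = 0.0                      # real positions attendable, tail masked out
    return padded, mask

ctx, mask = pad_target_ctx(np.random.randn(300, n_embd).astype(np.float32))
print(ctx.shape, int((mask == 0).sum()))       # (1024, 5120) 300
```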

Hybrid target model performance improvement (For all speculative decoding methods)

Hybrid targets (e.g. Qwen3.5) are slower because the problem is no longer just draft-side graph reuse. During target verify, llama.cpp writes KV / recurrent state for the full draft block before acceptance is known. Pure-attention target models can discard rejected suffixes with seq_rm, but hybrid targets cannot, because their recurrent state is not decomposable by token position.

The current workaround is:

  • snapshot the target state before verify
  • on rejection, restore the snapshot
  • replay only the accepted prefix

This is correct, but each rejected step may require one extra target forward, which is the main reason hybrid speedup lags pure-attention.
A more fundamental future improvement would be target-side deferred commit (SGLang Implementation): verify would compute temporary recurrent states, and only the accepted-prefix state would be committed. That would remove replay from the hybrid path, but it requires deeper changes to llama.cpp’s recurrent-state update flow.
Note this applies to all hybrid models used as target models in speculative decoding methods, not just DFlash.
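
In pseudocode, the checkpoint workaround amounts to something like the following toy sketch, where a plain token list stands in for the recurrent state and the replay step corresponds to the extra target forward mentioned above:

```python
# Toy sketch of snapshot / restore / replay for a hybrid target.
# A plain token list stands in for the recurrent state; in the real code the
# snapshot is a context checkpoint and the replay is one extra target forward.
import copy

def verify_with_checkpoint(state, draft_block, target_tokens):
    checkpoint = copy.deepcopy(state)            # snapshot before verify
    state.extend(draft_block)                    # verify writes state for the whole block

    n_accept = 0                                 # greedy prefix acceptance
    for d, t in zip(draft_block, target_tokens):
        if d != t:
            break
        n_accept += 1

    if n_accept < len(draft_block):              # part of the block was rejected
        state[:] = checkpoint                    # restore the snapshot ...
        state.extend(draft_block[:n_accept])     # ... and replay only the accepted prefix
    return state, n_accept
```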

Update: thanks to #19493 and #22227, llama.cpp now supports fallback for hybrid model states.

More (Low Priority)

  • Draft-side sampling fast path: for greedy / no-grammar mode, batch argmax over the entire drafted block instead of invoking the sampler one token at a time (see the sketch after this list).
  • CUDA graph for both draft model and target model
  • ....
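
A minimal sketch of the batched-argmax idea from the first bullet (illustrative numpy, not the common/ sampling code):

```python
# Batched greedy pick over the whole drafted block in one call.
import numpy as np

block_size, n_vocab = 16, 151936                  # example shapes
logits = np.random.randn(block_size, n_vocab).astype(np.float32)

draft_tokens = logits.argmax(axis=-1)             # one argmax call for all positions
print(draft_tokens.shape)                         # (16,)
```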

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: Yes, I used Claude to help discuss and design the DFlash architecture, ask clarifying questions, and assist with writing tests. Everything remains under my control, and I reviewed every single line of AI-generated code.

ruixiang63 and others added 17 commits December 14, 2025 18:12
EAGLE3 is an encoder-decoder based speculative decoding method:
- Extracts features from target model at specific layers
- Uses feature fusion layer to compress target features
- Generates draft tokens with single-layer decoder
- Maps draft vocabulary to target vocabulary via d2t tensor

Key changes:
- Add LLM_ARCH_EAGLE3 architecture
- Add EAGLE3 encoder/decoder graph (src/models/eagle3.cpp)
- Add feature extraction from target model layers
- Add g_embeddings handling for decoder input
- Add GGML_TENSOR_FLAG_SYNC for GPU synchronization
- Add --eagle3 flag for speculative-simple example
- Add EAGLE3 model conversion in convert_hf_to_gguf.py

ggml-gh-bot Bot commented Apr 19, 2026

Hi @ruixiang63, thanks for your contribution!

Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:

  • Large PR: Large changes require prior discussion (e.g. an issue or RFC) and maintainers may not be able to review this PR as-is. Consider splitting it into smaller, focused PRs.

Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.


am17an commented Apr 19, 2026

I think the method of exposing the hidden states of the target model needs to be cleaner, as it's used in both eagle3 and dflash and I guess even MTP. Probably needs a refactoring to expose these endpoints


ruixiang63 commented Apr 19, 2026

> I think the method of exposing the hidden states of the target model needs to be cleaner, as it's used in both eagle3 and dflash and I guess even MTP. Probably needs a refactoring to expose these endpoints

@ggerganov has already been working on this refactoring. And you’re very welcome to contribute if you have any better ideas for this PR :)

@noonghunna

Trying this against RedHatAI/gemma-4-31B-it-speculator.dflash on Ampere (2× RTX 3090, sm_86, CUDA 12.9) — ran into a gap worth flagging for speculators-format drafts.

Issue 1 (small, easy): d2t / t2d not handled in DFlashModel

The EAGLE3 path at convert_hf_to_gguf.py lines ~2923-2931 stashes d2t as int64 and drops t2d. DFlashModel (subclass of Qwen3Model) doesn't replicate that — fails with Can not map tensor 'model.d2t' on any speculators-format draft.

Fix that worked locally for us:

```python
# in DFlashModel.modify_tensors, before the super() fallthrough
if name == "d2t":
    if not hasattr(self, "_eagle3_int_tensors"):
        self._eagle3_int_tensors = {}
    self._eagle3_int_tensors[name] = data_torch
    if not hasattr(self, "is_eagle3"):
        self.is_eagle3 = True
    return []
if name == "t2d":
    return []
```

Piggy-backs on the EAGLE3 int-tensor emit in prepare_tensors.

Issue 2 (bigger): `gguf.MODEL_ARCH.DFLASH` tensor list is missing `TOKEN_EMBD` and `OUTPUT`

`gguf-py/gguf/constants.py` line 3578+ registers `DFLASH_FC` + `DFLASH_HIDDEN_NORM` plus the transformer layers, but not `TOKEN_EMBD` / `OUTPUT`. Works for z-lab-format drafts (they share the target's full vocab — we confirmed `z-lab/Qwen3.6-35B-A3B-DFlash` converts cleanly), but speculators-format drafts carry their own reduced-vocab embeddings — Red Hat's Gemma-4 draft has `draft_vocab_size=32000` vs Gemma's ~300K, with its own `embed_tokens.weight` + `lm_head.weight` + `d2t`/`t2d` remap.

After patching Issue 1, conversion now fails at `Can not map tensor 'model.embed_tokens.weight'`.

Fixing this end-to-end isn't just adding the arch constants — inference also needs to use the draft's own embeddings plus the `d2t` table to map draft-vocab logits back to target-vocab ids during verify. Flagging here rather than filing a separate issue since it's narrow.
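
Roughly, the remap meant here would look like the sketch below (illustrative only; whether `d2t` stores direct target ids or per-index offsets depends on the checkpoint, direct ids are assumed):

```python
# Sketch of the draft->target vocab remap (illustrative; the real d2t comes
# from the draft checkpoint, a random injective table is used here instead).
import numpy as np

draft_vocab_size, target_vocab_size = 32000, 300_000
rng = np.random.default_rng(0)
d2t = rng.choice(target_vocab_size, size=draft_vocab_size, replace=False)

def draft_to_target(draft_logits: np.ndarray) -> int:
    draft_id = int(draft_logits.argmax())   # greedy pick in the reduced draft vocab
    return int(d2t[draft_id])               # corresponding id in the target vocab

print(draft_to_target(rng.standard_normal(draft_vocab_size)))
```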

AI-assisted (Claude) findings, human review & submit.


ruixiang63 commented Apr 24, 2026

Rebased onto the latest master. Hybrid target models (e.g. Qwen3.5) now benefit from the speculative checkpointing mechanism recently merged upstream and the DFlash performance gets better. PR description updated with the new performance numbers.

@ivanbaldo

Have you also looked at DDTree perhaps?

@ruixiang63
Author

> Have you also looked at DDTree perhaps?

Not yet, but I’ll take a look. I’d expect it to come after this PR gets merged.


kroaton commented Apr 24, 2026

Out of curiosity, have you tested quantizing the DFlash model to Q8?
I found this by mistake yesterday https://huggingface.co/spiritbuun/Qwen3.6-27B-DFlash-GGUF, maybe worth a shot at supporting .gguf for DFlash instead of just .safetensors?

https://huggingface.co/lym00/Qwen3.6-35B-A3B-DFlash-GGUF-Test


JonhJonhD commented Apr 26, 2026

I don’t know if this is useful, but I managed to get it working on AMD (though with poor performance)

Main GPU: R9700 AI PRO running Unsloth Q5 Qwen 27B 3.5 (Vulkan backend) + DFlash bf16 compiled GGUF

The acceptance rate works well with the current parameters; changing them does affect the rate.

The --device Vulkan0,Cuda0 and --tensor-split 1,0 parameters are necessary to make things run; otherwise it fails because the GGML scheduler currently does not handle cross-buffer / cross-device operations cleanly for this feature.

It actually runs. Here’s the command and the result. If you’d like me to test something specific that might help, just let me know. I’m clearly out of my depth and can’t really suggest improvements.

Command + result + evaluation:

 ./llama-speculative-simple \
  -m ~/models/qwen3.5/Qwen3.5-27B-Q5_K_M.gguf \
  --model-draft ~/models/qwen3.5/Qwen3.5-27BDFlash.gguf \
  --device Vulkan0,Cuda0 \
  --tensor-split 1,0 \
  -cd 8192 \
  --draft-max 2 \
   --temp 0.8 \
   --top-p 0.9 \
   --top-k 40 \
   --presence_penalty 1.2 \
  -ngl 99 \
  --ctx-size 1572 \
  --flash-attn on \
  -ub 256 \
  -b 1024 \
  --cache-type-k q8_0 --cache-type-v q8_0 \
  --dflash
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 15949 MiB):
  Device 0: NVIDIA GeForce RTX 4060 Ti, compute capability 8.9, VMM: yes, VRAM: 15949 MiB
WARNING: radv is not a conformant Vulkan implementation, testing use only.
common_init_result: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on
common_params_fit_impl: getting device memory data for initial parameters:
common_memory_breakdown_print: | memory breakdown [MiB]                    | total    free     self   model   context   compute       unaccounted |
common_memory_breakdown_print: |   - Vulkan0 (AI PRO R9700 (RADV GFX1201)) | 32768 = 31424 + (18341 = 17856 +     209 +     275) + 17592186027418 |
common_memory_breakdown_print: |   - CUDA0 (RTX 4060 Ti)                   | 15949 = 15809 + (    0 =     0 +       0 +       0) +            140 |
common_memory_breakdown_print: |   - Host                                  |                    851 =   833 +       0 +      18                   |
common_params_fit_impl: projected memory use with initial parameters [MiB]:
common_params_fit_impl:   - Vulkan0 (AMD Radeon AI PRO R9700 (RADV GFX1201)):  32768 total,  18341 used,  13083 free vs. target of   1024
common_params_fit_impl:   - CUDA0 (NVIDIA GeForce RTX 4060 Ti)              :  15949 total,      0 used,  15809 free vs. target of   1024
common_params_fit_impl: projected to use 18341 MiB of device memory vs. 47233 MiB of free device memory
common_params_fit_impl: targets for free memory can be met on all devices, no changes needed
common_fit_params: successfully fit params to free device memory
common_fit_params: fitting params to free memory took 0.43 seconds
llama_model_load_from_file_impl: using device Vulkan0 (AMD Radeon AI PRO R9700 (RADV GFX1201)) (0000:08:00.0) - 31726 MiB free
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA GeForce RTX 4060 Ti) (0000:04:00.0) - 15809 MiB free
llama_model_loader: loaded meta data with 49 key-value pairs and 851 tensors from ~/models/qwen3.5/Qwen3.5-27B-Q5_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen35
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                     general.sampling.top_k i32              = 20
llama_model_loader: - kv   3:                     general.sampling.top_p f32              = 0.950000
llama_model_loader: - kv   4:                      general.sampling.temp f32              = 0.600000
llama_model_loader: - kv   5:                               general.name str              = Qwen3.5-27B
llama_model_loader: - kv   6:                           general.basename str              = Qwen3.5-27B
llama_model_loader: - kv   7:                       general.quantized_by str              = Unsloth
llama_model_loader: - kv   8:                         general.size_label str              = 27B
llama_model_loader: - kv   9:                            general.license str              = apache-2.0
llama_model_loader: - kv  10:                       general.license.link str              = https://huggingface.co/Qwen/Qwen3.5-2...
llama_model_loader: - kv  11:                           general.repo_url str              = https://huggingface.co/unsloth
llama_model_loader: - kv  12:                   general.base_model.count u32              = 1
llama_model_loader: - kv  13:                  general.base_model.0.name str              = Qwen3.5 27B
llama_model_loader: - kv  14:          general.base_model.0.organization str              = Qwen
llama_model_loader: - kv  15:              general.base_model.0.repo_url str              = https://huggingface.co/Qwen/Qwen3.5-27B
llama_model_loader: - kv  16:                               general.tags arr[str,3]       = ["qwen3_5_moe", "unsloth", "image-tex...
llama_model_loader: - kv  17:                         qwen35.block_count u32              = 64
llama_model_loader: - kv  18:                      qwen35.context_length u32              = 262144
llama_model_loader: - kv  19:                    qwen35.embedding_length u32              = 5120
llama_model_loader: - kv  20:                 qwen35.feed_forward_length u32              = 17408
llama_model_loader: - kv  21:                qwen35.attention.head_count u32              = 24
llama_model_loader: - kv  22:             qwen35.attention.head_count_kv u32              = 4
llama_model_loader: - kv  23:             qwen35.rope.dimension_sections arr[i32,4]       = [11, 11, 10, 0]
llama_model_loader: - kv  24:                      qwen35.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  25:    qwen35.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  26:                qwen35.attention.key_length u32              = 256
llama_model_loader: - kv  27:              qwen35.attention.value_length u32              = 256
llama_model_loader: - kv  28:                     qwen35.ssm.conv_kernel u32              = 4
llama_model_loader: - kv  29:                      qwen35.ssm.state_size u32              = 128
llama_model_loader: - kv  30:                     qwen35.ssm.group_count u32              = 16
llama_model_loader: - kv  31:                  qwen35.ssm.time_step_rank u32              = 48
llama_model_loader: - kv  32:                      qwen35.ssm.inner_size u32              = 6144
llama_model_loader: - kv  33:             qwen35.full_attention_interval u32              = 4
llama_model_loader: - kv  34:                qwen35.rope.dimension_count u32              = 64
llama_model_loader: - kv  35:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  36:                         tokenizer.ggml.pre str              = qwen35
llama_model_loader: - kv  37:                      tokenizer.ggml.tokens arr[str,248320]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  38:                  tokenizer.ggml.token_type arr[i32,248320]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  39:                      tokenizer.ggml.merges arr[str,247587]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  40:                tokenizer.ggml.eos_token_id u32              = 248046
llama_model_loader: - kv  41:            tokenizer.ggml.padding_token_id u32              = 248055
llama_model_loader: - kv  42:                    tokenizer.chat_template str              = {%- set image_count = namespace(value...
llama_model_loader: - kv  43:               general.quantization_version u32              = 2
llama_model_loader: - kv  44:                          general.file_type u32              = 17
llama_model_loader: - kv  45:                      quantize.imatrix.file str              = Qwen3.5-27B-GGUF/imatrix_unsloth.gguf
llama_model_loader: - kv  46:                   quantize.imatrix.dataset str              = unsloth_calibration_Qwen3.5-27B.txt
llama_model_loader: - kv  47:             quantize.imatrix.entries_count u32              = 496
llama_model_loader: - kv  48:              quantize.imatrix.chunks_count u32              = 80
llama_model_loader: - type  f32:  353 tensors
llama_model_loader: - type q8_0:   96 tensors
llama_model_loader: - type q5_K:  263 tensors
llama_model_loader: - type q6_K:  139 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q5_K - Medium
print_info: file size   = 18.25 GiB (5.83 BPW) 
load: 0 unused tokens
load: printing all EOG tokens:
load:   - 248044 ('<|endoftext|>')
load:   - 248046 ('<|im_end|>')
load:   - 248063 ('<|fim_pad|>')
load:   - 248064 ('<|repo_name|>')
load:   - 248065 ('<|file_sep|>')
load: special tokens cache size = 33
load: token to piece cache size = 1.7581 MB
print_info: arch                  = qwen35
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 262144
print_info: n_embd                = 5120
print_info: n_embd_inp            = 5120
print_info: n_layer               = 64
print_info: n_head                = 24
print_info: n_head_kv             = 4
print_info: n_rot                 = 64
print_info: n_swa                 = 0
print_info: is_swa_any            = 0
print_info: n_embd_head_k         = 256
print_info: n_embd_head_v         = 256
print_info: n_gqa                 = 6
print_info: n_embd_k_gqa          = 1024
print_info: n_embd_v_gqa          = 1024
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-06
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 0.0e+00
print_info: f_attn_scale          = 0.0e+00
print_info: n_ff                  = 17408
print_info: n_expert              = 0
print_info: n_expert_used         = 0
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = -1
print_info: rope type             = 40
print_info: rope scaling          = linear
print_info: freq_base_train       = 10000000.0
print_info: freq_scale_train      = 1
print_info: n_ctx_orig_yarn       = 262144
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = unknown
print_info: mrope sections        = [11, 11, 10, 0]
print_info: ssm_d_conv            = 4
print_info: ssm_d_inner           = 6144
print_info: ssm_d_state           = 128
print_info: ssm_dt_rank           = 48
print_info: ssm_n_group           = 16
print_info: ssm_dt_b_c_rms        = 0
print_info: model type            = 27B
print_info: model params          = 26.90 B
print_info: general.name          = Qwen3.5-27B
print_info: vocab type            = BPE
print_info: n_vocab               = 248320
print_info: n_merges              = 247587
print_info: BOS token             = 11 ','
print_info: EOS token             = 248046 '<|im_end|>'
print_info: EOT token             = 248046 '<|im_end|>'
print_info: PAD token             = 248055 '<|vision_pad|>'
print_info: LF token              = 198 'Ċ'
print_info: FIM PRE token         = 248060 '<|fim_prefix|>'
print_info: FIM SUF token         = 248062 '<|fim_suffix|>'
print_info: FIM MID token         = 248061 '<|fim_middle|>'
print_info: FIM PAD token         = 248063 '<|fim_pad|>'
print_info: FIM REP token         = 248064 '<|repo_name|>'
print_info: FIM SEP token         = 248065 '<|file_sep|>'
print_info: EOG token             = 248044 '<|endoftext|>'
print_info: EOG token             = 248046 '<|im_end|>'
print_info: EOG token             = 248063 '<|fim_pad|>'
print_info: EOG token             = 248064 '<|repo_name|>'
print_info: EOG token             = 248065 '<|file_sep|>'
print_info: max token length      = 256
load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
load_tensors: offloading output layer to GPU
load_tensors: offloading 63 repeating layers to GPU
load_tensors: offloaded 65/65 layers to GPU
load_tensors:   CPU_Mapped model buffer size =   833.59 MiB
load_tensors:      Vulkan0 model buffer size = 17856.52 MiB
.............................................................................................
common_init_result: added <|endoftext|> logit bias = -inf
common_init_result: added <|im_end|> logit bias = -inf
common_init_result: added <|fim_pad|> logit bias = -inf
common_init_result: added <|repo_name|> logit bias = -inf
common_init_result: added <|file_sep|> logit bias = -inf
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 1792
llama_context: n_ctx_seq     = 1792
llama_context: n_batch       = 1024
llama_context: n_ubatch      = 256
llama_context: causal_attn   = 1
llama_context: flash_attn    = enabled
llama_context: kv_unified    = false
llama_context: freq_base     = 10000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (1792) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context: Vulkan_Host  output buffer size =     0.95 MiB
llama_kv_cache:    Vulkan0 KV buffer size =    59.50 MiB
llama_kv_cache: size =   59.50 MiB (  1792 cells,  16 layers,  1/1 seqs), K (q8_0):   29.75 MiB, V (q8_0):   29.75 MiB
llama_kv_cache: attn_rot_k = 1, n_embd_head_k_all = 256
llama_kv_cache: attn_rot_v = 1, n_embd_head_k_all = 256
llama_memory_recurrent:    Vulkan0 RS buffer size =   149.62 MiB
llama_memory_recurrent: size =  149.62 MiB (     1 cells,  64 layers,  1 seqs), R (f32):    5.62 MiB, S (f32):  144.00 MiB
llama_context: pipeline parallelism enabled
sched_reserve: reserving ...
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
sched_reserve:    Vulkan0 compute buffer size =   275.60 MiB
sched_reserve: Vulkan_Host compute buffer size =    18.14 MiB
sched_reserve: graph nodes  = 3849
sched_reserve: graph splits = 2
sched_reserve: reserve took 23.36 ms, sched copies = 4
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
common_context_can_seq_rm: the target context does not support partial sequence removal
speculative decoding will use checkpoints (context does not support partial sequence removal)
llama_model_load_from_file_impl: skipping device Vulkan1 (NVIDIA GeForce RTX 4060 Ti) with id 0000:04:00.0 - already using device CUDA0 (NVIDIA GeForce RTX 4060 Ti) with the same id
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA GeForce RTX 4060 Ti) (0000:04:00.0) - 15809 MiB free
llama_model_load_from_file_impl: using device Vulkan0 (AMD Radeon AI PRO R9700 (RADV GFX1201)) (0000:08:00.0) - 13380 MiB free
llama_model_loader: loaded meta data with 30 key-value pairs and 58 tensors from ~/models/qwen3.5/Qwen3.5-27BDFlash.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = dflash
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Qwen3.5 DFlash
llama_model_loader: - kv   3:                         general.size_label str              = 1.7B
llama_model_loader: - kv   4:                            general.license str              = mit
llama_model_loader: - kv   5:                               general.tags arr[str,8]       = ["dflash", "speculative-decoding", "d...
llama_model_loader: - kv   6:                         dflash.block_count u32              = 5
llama_model_loader: - kv   7:                      dflash.context_length u32              = 262144
llama_model_loader: - kv   8:                    dflash.embedding_length u32              = 5120
llama_model_loader: - kv   9:                 dflash.feed_forward_length u32              = 17408
llama_model_loader: - kv  10:                dflash.attention.head_count u32              = 32
llama_model_loader: - kv  11:             dflash.attention.head_count_kv u32              = 8
llama_model_loader: - kv  12:                      dflash.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  13:    dflash.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  14:                dflash.attention.key_length u32              = 128
llama_model_loader: - kv  15:              dflash.attention.value_length u32              = 128
llama_model_loader: - kv  16:                          general.file_type u32              = 32
llama_model_loader: - kv  17:                          dflash.block_size u32              = 16
llama_model_loader: - kv  18:                    dflash.target_layer_ids arr[i32,5]       = [2, 17, 32, 47, 62]
llama_model_loader: - kv  19:                       dflash.mask_token_id u32              = 248070
llama_model_loader: - kv  20:               general.quantization_version u32              = 2
llama_model_loader: - kv  21:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  22:                         tokenizer.ggml.pre str              = qwen35
llama_model_loader: - kv  23:                      tokenizer.ggml.tokens arr[str,248320]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  24:                  tokenizer.ggml.token_type arr[i32,248320]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  25:                      tokenizer.ggml.merges arr[str,247587]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  26:                tokenizer.ggml.eos_token_id u32              = 248046
llama_model_loader: - kv  27:            tokenizer.ggml.padding_token_id u32              = 248044
llama_model_loader: - kv  28:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  29:                    tokenizer.chat_template str              = {%- set image_count = namespace(value...
llama_model_loader: - type  f32:   22 tensors
llama_model_loader: - type bf16:   36 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = BF16
print_info: file size   = 3.22 GiB (16.00 BPW) 
load_hparams: DFlash extract_layers = [2, 17, 32, 47, 62]
load_hparams: DFlash block_size = 16, mask_token_id = 248070
load: 0 unused tokens
load: printing all EOG tokens:
load:   - 248044 ('<|endoftext|>')
load:   - 248046 ('<|im_end|>')
load:   - 248063 ('<|fim_pad|>')
load:   - 248064 ('<|repo_name|>')
load:   - 248065 ('<|file_sep|>')
load: special tokens cache size = 33
load: token to piece cache size = 1.7581 MB
print_info: arch                  = dflash
print_info: vocab_only            = 0
print_info: no_alloc              = 0
print_info: n_ctx_train           = 262144
print_info: n_embd                = 5120
print_info: n_embd_inp            = 5120
print_info: n_layer               = 5
print_info: n_head                = 32
print_info: n_head_kv             = 8
print_info: n_rot                 = 128
print_info: n_swa                 = 0
print_info: is_swa_any            = 0
print_info: n_embd_head_k         = 128
print_info: n_embd_head_v         = 128
print_info: n_gqa                 = 4
print_info: n_embd_k_gqa          = 1024
print_info: n_embd_v_gqa          = 1024
print_info: f_norm_eps            = 0.0e+00
print_info: f_norm_rms_eps        = 1.0e-06
print_info: f_clamp_kqv           = 0.0e+00
print_info: f_max_alibi_bias      = 0.0e+00
print_info: f_logit_scale         = 0.0e+00
print_info: f_attn_scale          = 0.0e+00
print_info: n_ff                  = 17408
print_info: n_expert              = 0
print_info: n_expert_used         = 0
print_info: n_expert_groups       = 0
print_info: n_group_used          = 0
print_info: causal attn           = 1
print_info: pooling type          = -1
print_info: rope type             = 2
print_info: rope scaling          = linear
print_info: freq_base_train       = 10000000.0
print_info: freq_scale_train      = 1
print_info: n_ctx_orig_yarn       = 262144
print_info: rope_yarn_log_mul     = 0.0000
print_info: rope_finetuned        = unknown
print_info: model type            = ?B
print_info: model params          = 1.73 B
print_info: general.name          = Qwen3.5 DFlash
print_info: vocab type            = BPE
print_info: n_vocab               = 248320
print_info: n_merges              = 247587
print_info: BOS token             = 11 ','
print_info: EOS token             = 248046 '<|im_end|>'
print_info: EOT token             = 248046 '<|im_end|>'
print_info: PAD token             = 248044 '<|endoftext|>'
print_info: LF token              = 198 'Ċ'
print_info: FIM PRE token         = 248060 '<|fim_prefix|>'
print_info: FIM SUF token         = 248062 '<|fim_suffix|>'
print_info: FIM MID token         = 248061 '<|fim_middle|>'
print_info: FIM PAD token         = 248063 '<|fim_pad|>'
print_info: FIM REP token         = 248064 '<|repo_name|>'
print_info: FIM SEP token         = 248065 '<|file_sep|>'
print_info: EOG token             = 248044 '<|endoftext|>'
print_info: EOG token             = 248046 '<|im_end|>'
print_info: EOG token             = 248063 '<|fim_pad|>'
print_info: EOG token             = 248064 '<|repo_name|>'
print_info: EOG token             = 248065 '<|file_sep|>'
print_info: max token length      = 256
load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
load_tensors: offloading output layer to GPU
load_tensors: offloading 4 repeating layers to GPU
load_tensors: offloaded 6/6 layers to GPU
load_tensors:        CUDA0 model buffer size =  3300.24 MiB
.............................
set_dflash: DFlash extraction enabled for layers [2, 17, 32, 47, 62]
main: DFlash chat template applied


<|im_start|>user
<|im_end|>
<|im_start|>assistant
<think>
sched_reserve: reserving ...
sched_reserve:    Vulkan0 compute buffer size =   417.94 MiB
sched_reserve: Vulkan_Host compute buffer size =    18.14 MiB
sched_reserve: graph nodes  = 3849
sched_reserve: graph splits = 2
sched_reserve: reserve took 10.63 ms, sched copies = 4
Okayllama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 8192
llama_context: n_ctx_seq     = 8192
llama_context: n_batch       = 1792
llama_context: n_ubatch      = 256
llama_context: causal_attn   = 1
llama_context: flash_attn    = enabled
llama_context: kv_unified    = false
llama_context: freq_base     = 10000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (8192) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context:  CUDA_Host  output buffer size =     0.97 MiB
llama_context: pipeline parallelism enabled
sched_reserve: reserving ...
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
sched_reserve:      CUDA0 compute buffer size =   105.00 MiB
sched_reserve:  CUDA_Host compute buffer size =   100.00 MiB
sched_reserve: graph nodes  = 3
sched_reserve: graph splits = 1
sched_reserve: reserve took 53.57 ms, sched copies = 4
llama_init_from_model: DFlash auto-setup: using target model's embedding + lm_head layers
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 8192
llama_context: n_ctx_seq     = 8192
llama_context: n_batch       = 1792
llama_context: n_ubatch      = 256
llama_context: causal_attn   = 1
llama_context: flash_attn    = enabled
llama_context: kv_unified    = false
llama_context: freq_base     = 10000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (8192) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context:  CUDA_Host  output buffer size =     0.97 MiB
llama_context: pipeline parallelism enabled
sched_reserve: reserving ...
sched_reserve: resolving fused Gated Delta Net support:
sched_reserve: fused Gated Delta Net (autoregressive) enabled
sched_reserve: fused Gated Delta Net (chunked) enabled
sched_reserve:      CUDA0 compute buffer size =   776.13 MiB
sched_reserve:    Vulkan0 compute buffer size =   262.50 MiB
sched_reserve:  CUDA_Host compute buffer size =   650.13 MiB
sched_reserve: graph nodes  = 180
sched_reserve: graph splits = 3
sched_reserve: reserve took 328.44 ms, sched copies = 4
, the user just sent a message with no content. Let me check if there's something missing or if they need assistance. Maybe they encountered an issue or want to ask a question but forgot to type it. I should respond politely to prompt them for more details. Let me make sure to keep it friendly and open-ended so they feel comfortable providing more information.
</think>

Hello! It seems like your message might have come through empty. How can I assist you today? Feel free to ask a question, share a topic, or let me know if you need help with anything specific!

encoded   10 tokens in    0.581 seconds, speed:   17.218 t/s
decoded  123 tokens in    7.577 seconds, speed:   16.234 t/s

n_draft   = 2
n_predict = 123
n_drafted = 61
n_accept  = 61
accept    = 100.000%

draft:


target:

common_perf_print:    sampling time =      31.88 ms
common_perf_print:    samplers time =      26.80 ms /    38 tokens
common_perf_print:        load time =    4482.16 ms
common_perf_print: prompt eval time =    3903.13 ms /   154 tokens (   25.35 ms per token,    39.46 tokens per second)
common_perf_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
common_perf_print:       total time =    9866.22 ms /   155 tokens
common_perf_print: unaccounted time =    5931.20 ms /  60.1 %      (total - sampling - prompt eval - eval) / (total)
common_perf_print:    graphs reused =         71
common_memory_breakdown_print: | memory breakdown [MiB]                    | total    free     self   model   context   compute    unaccounted |
common_memory_breakdown_print: |   - Vulkan0 (AI PRO R9700 (RADV GFX1201)) | 32768 = 12974 + (18483 = 17856 +     209 +     417) +        1310 |
common_memory_breakdown_print: |   - CUDA0 (RTX 4060 Ti)                   | 15949 = 11591 + (    0 =     0 +       0 +       0) +        4358 |
common_memory_breakdown_print: |   - Host                                  |                    851 =   833 +       0 +      18                |
llama_perf_context_print:        load time =    2228.10 ms
llama_perf_context_print: prompt eval time =    1077.04 ms /   122 tokens (    8.83 ms per token,   113.27 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =    9793.20 ms /   123 tokens
llama_perf_context_print:    graphs reused =          0

@SunYong0821

> --dflash

llama.cpp-b8941 does not have this parameter.


rlex commented Apr 27, 2026

> llama.cpp-b8941 does not have this parameter.

Because it's not merged yet to master branch?

@SunYong0821

> llama.cpp-b8941 does not have this parameter.

> Because it's not merged yet to master branch?

When is it expected to be merged into master?


Raghuboi commented Apr 27, 2026

getting this startup error: /home/raghuboi/llama.cpp/src/llama-context.cpp:2509: GGML_ASSERT(tensor != nullptr && "DFlash extraction tensor is null") failed

tried these models:
Qwen 3 1.7B
Qwen 3.5 2b

meanwhile https://huggingface.co/spiritbuun/Qwen3.6-27B-DFlash-GGUF fails to load on startup: llama_model_load: error loading model: error loading model architecture: unknown model architecture: 'dflash-draft' llama_model_load_from_file_impl: failed to load model srv load_model: failed to load draft model, '/home/raghuboi/Desktop/models/qwen-3.6/dflash-draft-3.6-q8_0.gguf' srv operator(): operator(): cleaning up before exit...

set_dflash: DFlash extraction enabled for layers [0, 0, 0, 0, 0]
srv load_model: DFlash feature extraction enabled on target model
srv load_model: initializing slots, n_slots = 1
sched_reserve: reserving ...
sched_reserve: CUDA0 compute buffer size = 4981.19 MiB
sched_reserve: CUDA1 compute buffer size = 2373.19 MiB
sched_reserve: CUDA2 compute buffer size = 3119.20 MiB
sched_reserve: CUDA_Host compute buffer size = 4137.22 MiB
sched_reserve: graph nodes = 3849
sched_reserve: graph splits = 4
sched_reserve: reserve took 938.14 ms, sched copies = 4
/home/raghuboi/llama.cpp/src/llama-context.cpp:2509: GGML_ASSERT(tensor != nullptr && "DFlash extraction tensor is null") failed

exec "$LLAMA_SERVER"
-m "$MODEL"
-md "/home/raghuboi/Desktop/models/qwen-3.6/Qwen3-1.7B-Q8_0.gguf"
--dflash
--draft-max 16
--alias "qwen3.6"
--host 0.0.0.0 --port 8081
-ngl 99
-np 1
-kvu
--fit on
--fit-target 1024
--fit-ctx 262144
-c 262144
-cd 4096
--split-mode layer
-b 2048 -ub 1024
--main-gpu 1
--mlock
--no-mmap
--flash-attn on
--chat-template-kwargs '{"preserve_thinking":true}'
--cache-type-k q8_0 --cache-type-v q8_0
--cache-ram 38912
--ctx-checkpoints 128
--cont-batching
--jinja
--threads 8
--threads-batch 16
--temp 0.6 --top-p 0.95 --top-k 20 --min-p 0.00 --repeat-penalty 1.0 --presence-penalty 0.0 \

"$LOG" 2>&1

@mbednarek360

I'm getting the following error when trying to run this PR with the Vulkan backend on an R9700; only one token is generated before it crashes:

GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens") failed

Full Log
Apr 27 08:11:21 michael-server llama-server[888649]: main: starting router server, no model will be loaded in this process
Apr 27 08:11:21 michael-server llama-server[888649]: start: binding port with default address family
Apr 27 08:11:21 michael-server llama-server[888649]: main: router server is listening on http://127.0.0.1:8090
Apr 27 08:11:21 michael-server llama-server[888649]: main: NOTE: router mode is experimental
Apr 27 08:11:21 michael-server llama-server[888649]: main:       it is not recommended to use this mode in untrusted environments
Apr 27 08:11:44 michael-server llama-server[888649]: srv  ensure_model: model name=Qwen3.6 is not loaded, loading...
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load: spawning server instance with name=Qwen3.6 on port 39385
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load: spawning server instance with args:
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/bin/llama-server
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --cache-reuse
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   256
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --chat-template-file
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   /nix/store/wg2gx0sf317iisfy17ap6vmkyrk8qd4y-qwen-36.jinja
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --chat-template-kwargs
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   {"preserve_thinking": true}
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --no-context-shift
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --dflash
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --host
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   127.0.0.1
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --jinja
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --keep
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   -1
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --min-p
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   0.0
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --no-mmproj-auto
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --port
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   39385
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --presence-penalty
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   0.0
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --prio
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   3
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --repeat-penalty
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   1.0
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --temperature
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   0.6
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --top-k
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   20
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --top-p
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   0.95
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --alias
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   Qwen3.6
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --ctx-size-draft
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   2048
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --cache-type-k
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   q8_0
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --cache-type-v
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   q8_0
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --flash-attn
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   on
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --fit
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   off
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --hf-repo
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   bartowski/Qwen_Qwen3.6-27B-GGUF:Q6_K_L
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --model-draft
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   /media/Downloads/Qwen3.6-dflash.gguf
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --n-gpu-layers
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   99
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --n-gpu-layers-draft
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   99
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --parallel
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   1
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   --reasoning
Apr 27 08:11:44 michael-server llama-server[888649]: srv          load:   on
Apr 27 08:11:44 michael-server llama-server[888649]: srv  ensure_model: waiting until model name=Qwen3.6 is fully loaded...
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] WARNING: radv is not a conformant Vulkan implementation, testing use only.
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] common_download_file_single_online: HEAD failed, status: 404
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] no remote preset found, skipping
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] build_info: b0-unknown
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] system_info: n_threads = 16 (n_threads_batch = 16) / 32 | CPU : LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 |
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] Running without SSL
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] init: using 31 threads for HTTP server
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] start: binding port with default address family
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] main: loading model
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] srv    load_model: loading model '/var/cache/llama-cpp/models--bartowski--Qwen_Qwen3.6-27B-GGUF/snapshots/f73b625d7ceedbd05d14a93874387cd3bcd673b7/Qwen_Qwen3.6-27B-Q6_K_L.gguf'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_load_from_file_impl: using device Vulkan0 (AMD Radeon AI PRO R9700 (RADV GFX1201)) (0000:0a:00.0) - 32535 MiB free
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: loaded meta data with 45 key-value pairs and 851 tensors from /var/cache/llama-cpp/models--bartowski--Qwen_Qwen3.6-27B-GGUF/snapshots/f73b625d7ceedbd05d14a93874387cd3bcd673b7/Qwen_Qwen3.6-27B-Q6_K_L.gguf (version GGUF V3 (latest))
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   0:                       general.architecture str              = qwen35
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   1:                               general.type str              = model
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   2:                     general.sampling.top_k i32              = 20
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   3:                     general.sampling.top_p f32              = 0.950000
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   4:                      general.sampling.temp f32              = 1.000000
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   5:                               general.name str              = Qwen3.6 27B
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   6:                           general.basename str              = Qwen3.6
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   7:                         general.size_label str              = 27B
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   8:                            general.license str              = apache-2.0
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   9:                       general.license.link str              = https://huggingface.co/Qwen/Qwen3.6-2...
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  10:                               general.tags arr[str,1]       = ["image-text-to-text"]
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  11:                         qwen35.block_count u32              = 64
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  12:                      qwen35.context_length u32              = 262144
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  13:                    qwen35.embedding_length u32              = 5120
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  14:                 qwen35.feed_forward_length u32              = 17408
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  15:                qwen35.attention.head_count u32              = 24
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  16:             qwen35.attention.head_count_kv u32              = 4
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  17:             qwen35.rope.dimension_sections arr[i32,4]       = [11, 11, 10, 0]
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  18:                      qwen35.rope.freq_base f32              = 10000000.000000
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  19:    qwen35.attention.layer_norm_rms_epsilon f32              = 0.000001
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  20:                qwen35.attention.key_length u32              = 256
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  21:              qwen35.attention.value_length u32              = 256
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  22:                     qwen35.ssm.conv_kernel u32              = 4
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  23:                      qwen35.ssm.state_size u32              = 128
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  24:                     qwen35.ssm.group_count u32              = 16
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  25:                  qwen35.ssm.time_step_rank u32              = 48
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  26:                      qwen35.ssm.inner_size u32              = 6144
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  27:             qwen35.full_attention_interval u32              = 4
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  28:                qwen35.rope.dimension_count u32              = 64
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  29:                       tokenizer.ggml.model str              = gpt2
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  30:                         tokenizer.ggml.pre str              = qwen35
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  31:                      tokenizer.ggml.tokens arr[str,248320]  = ["!", "\"", "#", "$", "%", "&", "'", ...
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  32:                  tokenizer.ggml.token_type arr[i32,248320]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  33:                      tokenizer.ggml.merges arr[str,247587]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  34:                tokenizer.ggml.eos_token_id u32              = 248046
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  35:            tokenizer.ggml.padding_token_id u32              = 248044
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  36:                tokenizer.ggml.bos_token_id u32              = 248044
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  37:               tokenizer.ggml.add_bos_token bool             = false
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  38:                    tokenizer.chat_template str              = {%- set image_count = namespace(value...
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  39:               general.quantization_version u32              = 2
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  40:                          general.file_type u32              = 18
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  41:                      quantize.imatrix.file str              = /models_out/Qwen3.6-27B-GGUF/Qwen_Qwe...
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  42:                   quantize.imatrix.dataset str              = /training_dir/calibration_datav5.txt
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  43:             quantize.imatrix.entries_count u32              = 496
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  44:              quantize.imatrix.chunks_count u32              = 802
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - type  f32:  449 tensors
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - type q8_0:  122 tensors
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] llama_model_loader: - type q6_K:  280 tensors
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: file format = GGUF V3 (latest)
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: file type   = Q6_K
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: file size   = 22.19 GiB (7.09 BPW)
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] load: 0 unused tokens
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] load: printing all EOG tokens:
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] load:   - 248044 ('<|endoftext|>')
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] load:   - 248046 ('<|im_end|>')
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] load:   - 248063 ('<|fim_pad|>')
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] load:   - 248064 ('<|repo_name|>')
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] load:   - 248065 ('<|file_sep|>')
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] load: special tokens cache size = 33
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] load: token to piece cache size = 1.7581 MB
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: arch                  = qwen35
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: vocab_only            = 0
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: no_alloc              = 0
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_ctx_train           = 262144
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_embd                = 5120
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_embd_inp            = 5120
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_layer               = 64
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_head                = 24
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_head_kv             = 4
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_rot                 = 64
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_swa                 = 0
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: is_swa_any            = 0
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_embd_head_k         = 256
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_embd_head_v         = 256
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_gqa                 = 6
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_embd_k_gqa          = 1024
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_embd_v_gqa          = 1024
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: f_norm_eps            = 0.0e+00
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: f_norm_rms_eps        = 1.0e-06
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: f_clamp_kqv           = 0.0e+00
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: f_max_alibi_bias      = 0.0e+00
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: f_logit_scale         = 0.0e+00
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: f_attn_scale          = 0.0e+00
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_ff                  = 17408
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_expert              = 0
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_expert_used         = 0
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_expert_groups       = 0
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_group_used          = 0
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: causal attn           = 1
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: pooling type          = -1
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: rope type             = 40
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: rope scaling          = linear
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: freq_base_train       = 10000000.0
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: freq_scale_train      = 1
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_ctx_orig_yarn       = 262144
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: rope_yarn_log_mul     = 0.0000
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: rope_finetuned        = unknown
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: mrope sections        = [11, 11, 10, 0]
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: ssm_d_conv            = 4
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: ssm_d_inner           = 6144
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: ssm_d_state           = 128
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: ssm_dt_rank           = 48
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: ssm_n_group           = 16
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: ssm_dt_b_c_rms        = 0
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: model type            = 27B
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: model params          = 26.90 B
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: general.name          = Qwen3.6 27B
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: vocab type            = BPE
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_vocab               = 248320
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: n_merges              = 247587
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: BOS token             = 248044 '<|endoftext|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: EOS token             = 248046 '<|im_end|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: EOT token             = 248046 '<|im_end|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: PAD token             = 248044 '<|endoftext|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: LF token              = 198 'Ċ'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: FIM PRE token         = 248060 '<|fim_prefix|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: FIM SUF token         = 248062 '<|fim_suffix|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: FIM MID token         = 248061 '<|fim_middle|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: FIM PAD token         = 248063 '<|fim_pad|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: FIM REP token         = 248064 '<|repo_name|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: FIM SEP token         = 248065 '<|file_sep|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: EOG token             = 248044 '<|endoftext|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: EOG token             = 248046 '<|im_end|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: EOG token             = 248063 '<|fim_pad|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: EOG token             = 248064 '<|repo_name|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: EOG token             = 248065 '<|file_sep|>'
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] print_info: max token length      = 256
Apr 27 08:11:44 michael-server llama-server[888649]: [39385] load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
Apr 27 08:11:59 michael-server llama-server[888649]: [39385] load_tensors: offloading output layer to GPU
Apr 27 08:11:59 michael-server llama-server[888649]: [39385] load_tensors: offloading 63 repeating layers to GPU
Apr 27 08:11:59 michael-server llama-server[888649]: [39385] load_tensors: offloaded 65/65 layers to GPU
Apr 27 08:11:59 michael-server llama-server[888649]: [39385] load_tensors:   CPU_Mapped model buffer size =  1288.28 MiB
Apr 27 08:11:59 michael-server llama-server[888649]: [39385] load_tensors:      Vulkan0 model buffer size = 21436.81 MiB
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] ...........................................................................................
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] common_init_result: added <|endoftext|> logit bias = -inf
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] common_init_result: added <|im_end|> logit bias = -inf
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] common_init_result: added <|fim_pad|> logit bias = -inf
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] common_init_result: added <|repo_name|> logit bias = -inf
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] common_init_result: added <|file_sep|> logit bias = -inf
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_context: constructing llama_context
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_context: n_seq_max     = 1
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_context: n_ctx         = 262144
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_context: n_ctx_seq     = 262144
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_context: n_batch       = 2048
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_context: n_ubatch      = 512
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_context: causal_attn   = 1
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_context: flash_attn    = enabled
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_context: kv_unified    = false
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_context: freq_base     = 10000000.0
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_context: freq_scale    = 1
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_context: Vulkan_Host  output buffer size =     0.95 MiB
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_kv_cache:    Vulkan0 KV buffer size =  8704.00 MiB
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_kv_cache: size = 8704.00 MiB (262144 cells,  16 layers,  1/1 seqs), K (q8_0): 4352.00 MiB, V (q8_0): 4352.00 MiB
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_kv_cache: attn_rot_k = 1, n_embd_head_k_all = 256
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_kv_cache: attn_rot_v = 1, n_embd_head_k_all = 256
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_memory_recurrent:    Vulkan0 RS buffer size =   149.62 MiB
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_memory_recurrent: size =  149.62 MiB (     1 cells,  64 layers,  1 seqs), R (f32):    5.62 MiB, S (f32):  144.00 MiB
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] sched_reserve: reserving ...
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] sched_reserve: resolving fused Gated Delta Net support:
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] sched_reserve: fused Gated Delta Net (autoregressive) enabled
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] sched_reserve: fused Gated Delta Net (chunked) enabled
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] sched_reserve:    Vulkan0 compute buffer size =   840.28 MiB
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] sched_reserve: Vulkan_Host compute buffer size =   532.29 MiB
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] sched_reserve: graph nodes  = 3849
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] sched_reserve: graph splits = 2
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] sched_reserve: reserve took 60.19 ms, sched copies = 1
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] srv    load_model: loading draft model '/media/Downloads/Qwen3.6-dflash.gguf'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_load_from_file_impl: using device Vulkan0 (AMD Radeon AI PRO R9700 (RADV GFX1201)) (0000:0a:00.0) - 1388 MiB free
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: loaded meta data with 31 key-value pairs and 58 tensors from /media/Downloads/Qwen3.6-dflash.gguf (version GGUF V3 (latest))
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   0:                       general.architecture str              = dflash
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   1:                               general.type str              = model
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   2:                               general.name str              = Qwen3.6 27B DFlash
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   3:                           general.finetune str              = 27b-DFlash
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   4:                           general.basename str              = Qwen3.6
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   5:                         general.size_label str              = 1.7B
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   6:                         dflash.block_count u32              = 5
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   7:                      dflash.context_length u32              = 262144
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   8:                    dflash.embedding_length u32              = 5120
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv   9:                 dflash.feed_forward_length u32              = 17408
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  10:                dflash.attention.head_count u32              = 32
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  11:             dflash.attention.head_count_kv u32              = 8
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  12:                      dflash.rope.freq_base f32              = 10000000.000000
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  13:    dflash.attention.layer_norm_rms_epsilon f32              = 0.000001
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  14:                dflash.attention.key_length u32              = 128
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  15:              dflash.attention.value_length u32              = 128
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  16:                          general.file_type u32              = 7
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  17:                          dflash.block_size u32              = 16
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  18:                    dflash.target_layer_ids arr[i32,5]       = [2, 17, 32, 47, 62]
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  19:                       dflash.mask_token_id u32              = 248070
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  20:               general.quantization_version u32              = 2
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  21:                       tokenizer.ggml.model str              = gpt2
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  22:                         tokenizer.ggml.pre str              = qwen35
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  23:                      tokenizer.ggml.tokens arr[str,248320]  = ["!", "\"", "#", "$", "%", "&", "'", ...
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  24:                  tokenizer.ggml.token_type arr[i32,248320]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  25:                      tokenizer.ggml.merges arr[str,247587]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  26:                tokenizer.ggml.eos_token_id u32              = 248046
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  27:            tokenizer.ggml.padding_token_id u32              = 248044
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  28:                tokenizer.ggml.bos_token_id u32              = 248044
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  29:               tokenizer.ggml.add_bos_token bool             = false
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - kv  30:                    tokenizer.chat_template str              = {%- set image_count = namespace(value...
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - type  f32:   22 tensors
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] llama_model_loader: - type q8_0:   36 tensors
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: file format = GGUF V3 (latest)
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: file type   = Q8_0
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: file size   = 1.71 GiB (8.50 BPW)
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] load_hparams: DFlash extract_layers = [2, 17, 32, 47, 62]
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] load_hparams: DFlash block_size = 16, mask_token_id = 248070
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] load: 0 unused tokens
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] load: printing all EOG tokens:
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] load:   - 248044 ('<|endoftext|>')
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] load:   - 248046 ('<|im_end|>')
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] load:   - 248063 ('<|fim_pad|>')
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] load:   - 248064 ('<|repo_name|>')
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] load:   - 248065 ('<|file_sep|>')
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] load: special tokens cache size = 33
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] load: token to piece cache size = 1.7581 MB
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: arch                  = dflash
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: vocab_only            = 0
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: no_alloc              = 0
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_ctx_train           = 262144
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_embd                = 5120
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_embd_inp            = 5120
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_layer               = 5
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_head                = 32
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_head_kv             = 8
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_rot                 = 128
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_swa                 = 0
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: is_swa_any            = 0
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_embd_head_k         = 128
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_embd_head_v         = 128
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_gqa                 = 4
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_embd_k_gqa          = 1024
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_embd_v_gqa          = 1024
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: f_norm_eps            = 0.0e+00
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: f_norm_rms_eps        = 1.0e-06
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: f_clamp_kqv           = 0.0e+00
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: f_max_alibi_bias      = 0.0e+00
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: f_logit_scale         = 0.0e+00
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: f_attn_scale          = 0.0e+00
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_ff                  = 17408
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_expert              = 0
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_expert_used         = 0
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_expert_groups       = 0
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_group_used          = 0
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: causal attn           = 1
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: pooling type          = -1
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: rope type             = 2
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: rope scaling          = linear
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: freq_base_train       = 10000000.0
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: freq_scale_train      = 1
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_ctx_orig_yarn       = 262144
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: rope_yarn_log_mul     = 0.0000
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: rope_finetuned        = unknown
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: model type            = ?B
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: model params          = 1.73 B
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: general.name          = Qwen3.6 27B DFlash
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: vocab type            = BPE
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_vocab               = 248320
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: n_merges              = 247587
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: BOS token             = 248044 '<|endoftext|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: EOS token             = 248046 '<|im_end|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: EOT token             = 248046 '<|im_end|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: PAD token             = 248044 '<|endoftext|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: LF token              = 198 'Ċ'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: FIM PRE token         = 248060 '<|fim_prefix|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: FIM SUF token         = 248062 '<|fim_suffix|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: FIM MID token         = 248061 '<|fim_middle|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: FIM PAD token         = 248063 '<|fim_pad|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: FIM REP token         = 248064 '<|repo_name|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: FIM SEP token         = 248065 '<|file_sep|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: EOG token             = 248044 '<|endoftext|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: EOG token             = 248046 '<|im_end|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: EOG token             = 248063 '<|fim_pad|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: EOG token             = 248064 '<|repo_name|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: EOG token             = 248065 '<|file_sep|>'
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] print_info: max token length      = 256
Apr 27 08:12:04 michael-server llama-server[888649]: [39385] load_tensors: loading model tensors, this can take a while... (mmap = true, direct_io = false)
Apr 27 08:12:08 michael-server llama-server[888649]: [39385] load_tensors: offloading output layer to GPU
Apr 27 08:12:08 michael-server llama-server[888649]: [39385] load_tensors: offloading 4 repeating layers to GPU
Apr 27 08:12:08 michael-server llama-server[888649]: [39385] load_tensors: offloaded 6/6 layers to GPU
Apr 27 08:12:08 michael-server llama-server[888649]: [39385] load_tensors:      Vulkan0 model buffer size =  1753.36 MiB
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] .............................
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] set_dflash: DFlash extraction enabled for layers [2, 17, 32, 47, 62]
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv    load_model: DFlash feature extraction enabled on target model
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv    load_model: cache_reuse is not supported by this context, it will be disabled
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv    load_model: initializing slots, n_slots = 1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: reserving ...
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve:    Vulkan0 compute buffer size =  1335.28 MiB
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: Vulkan_Host compute buffer size =   532.29 MiB
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: graph nodes  = 3849
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: graph splits = 2
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: reserve took 121.81 ms, sched copies = 1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] common_context_can_seq_rm: the target context does not support partial sequence removal
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv    load_model: speculative decoding will use checkpoints
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: constructing llama_context
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: n_seq_max     = 1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: n_ctx         = 2048
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: n_ctx_seq     = 2048
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: n_batch       = 2048
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: n_ubatch      = 512
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: causal_attn   = 1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: flash_attn    = enabled
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: kv_unified    = false
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: freq_base     = 10000000.0
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: freq_scale    = 1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: n_ctx_seq (2048) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: Vulkan_Host  output buffer size =     0.97 MiB
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: reserving ...
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: resolving fused Gated Delta Net support:
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: fused Gated Delta Net (autoregressive) enabled
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: fused Gated Delta Net (chunked) enabled
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve:    Vulkan0 compute buffer size =    60.00 MiB
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: Vulkan_Host compute buffer size =    50.00 MiB
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: graph nodes  = 3
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: graph splits = 1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: reserve took 17.69 ms, sched copies = 1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_init_from_model: DFlash auto-setup: using target model's embedding + lm_head layers
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: constructing llama_context
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: n_seq_max     = 1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: n_ctx         = 2048
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: n_ctx_seq     = 2048
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: n_batch       = 2048
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: n_ubatch      = 512
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: causal_attn   = 1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: flash_attn    = enabled
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: kv_unified    = false
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: freq_base     = 10000000.0
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: freq_scale    = 1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: n_ctx_seq (2048) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] llama_context: Vulkan_Host  output buffer size =     0.97 MiB
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: reserving ...
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: resolving fused Gated Delta Net support:
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: fused Gated Delta Net (autoregressive) enabled
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: fused Gated Delta Net (chunked) enabled
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve:    Vulkan0 compute buffer size =   495.00 MiB
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: Vulkan_Host compute buffer size =    60.02 MiB
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: graph nodes  = 180
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: graph splits = 2
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] sched_reserve: reserve took 6.90 ms, sched copies = 1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] slot   load_model: id  0 | task -1 | speculative decoding context initialized
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] slot   load_model: id  0 | task -1 | new slot, n_ctx = 262144
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv    load_model: prompt cache is enabled, size limit: 8192 MiB
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv    load_model: use `--cache-ram 0` to disable the prompt cache
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv    load_model: for more info see https://github.com/ggml-org/llama.cpp/pull/16391
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv          init: init: --cache-idle-slots requires --kv-unified, disabling
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] init: chat template, example_format: '<|im_start|>system
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] You are a helpful assistant<|im_end|>
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] <|im_start|>user
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] Hello<|im_end|>
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] <|im_start|>assistant
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] <think>
Apr 27 08:12:09 michael-server llama-server[888649]: [39385]
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] </think>
Apr 27 08:12:09 michael-server llama-server[888649]: [39385]
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] Hi there<|im_end|>
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] <|im_start|>user
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] How are you?<|im_end|>
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] <|im_start|>assistant
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] <think>
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] '
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv          init: init: chat template, thinking = 1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] main: model loaded
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] main: server is listening on http://127.0.0.1:39385
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] main: starting the main loop...
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] cmd_child_to_router:ready
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv    operator(): child server monitoring thread started, waiting for EOF on stdin...
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv  update_slots: all slots are idle
Apr 27 08:12:09 michael-server llama-server[888649]: srv  proxy_reques: proxying request to model Qwen3.6 on port 39385
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv  params_from_: Chat format: peg-native
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] slot get_availabl: id  0 | task -1 | selected slot by LRU, t_last = -1
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv  get_availabl: updating prompt cache
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv          load:  - looking for better prompt, base f_keep = -1.000, sim = 0.000
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv        update:  - cache state: 0 prompts, 0.000 MiB (limits: 8192.000 MiB, 262144 tokens, 8589934592 est)
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] srv  get_availabl: prompt cache update took 0.00 ms
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] slot launch_slot_: id  0 | task -1 | sampler chain: logits -> ?penalties -> ?dry -> ?top-n-sigma -> top-k -> ?typical -> top-p -> ?min-p -> ?xtc -> temp-ext -> dist
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] slot launch_slot_: id  0 | task 0 | processing task, is_child = 0
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 262144, n_keep = -1, task.n_tokens = 975
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] slot update_slots: id  0 | task 0 | n_tokens = 0, memory_seq_rm [0, end)
Apr 27 08:12:09 michael-server llama-server[888649]: [39385] slot update_slots: id  0 | task 0 | prompt processing progress, n_tokens = 459, batch.n_tokens = 459, progress = 0.470769
Apr 27 08:12:11 michael-server llama-server[888649]: [39385] slot update_slots: id  0 | task 0 | n_tokens = 459, memory_seq_rm [459, end)
Apr 27 08:12:11 michael-server llama-server[888649]: [39385] slot update_slots: id  0 | task 0 | prompt processing progress, n_tokens = 971, batch.n_tokens = 512, progress = 0.995897
Apr 27 08:12:11 michael-server llama-server[888649]: [39385] slot create_check: id  0 | task 0 | created context checkpoint 1 of 32 (pos_min = 458, pos_max = 458, n_tokens = 459, size = 149.626 MiB)
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] slot update_slots: id  0 | task 0 | n_tokens = 971, memory_seq_rm [971, end)
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] reasoning-budget: activated, budget=2147483647 tokens
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] slot init_sampler: id  0 | task 0 | init sampler, took 0.12 ms, tokens: text = 975, total = 975
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] slot update_slots: id  0 | task 0 | prompt processing done, n_tokens = 975, batch.n_tokens = 4
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] slot create_check: id  0 | task 0 | created context checkpoint 2 of 32 (pos_min = 970, pos_max = 970, n_tokens = 971, size = 149.626 MiB)
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] srv  log_server_r: done request: POST /v1/chat/completions 127.0.0.1 200
Apr 27 08:12:14 michael-server llama-server[888649]: srv  log_server_r: done request: POST /v1/chat/completions 127.0.0.1 200
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /build/source/src/llama-context.cpp:1393: GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens") failed
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/lib/libggml-base.so.0(+0x1979a) [0x7ffbc169879a]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/lib/libggml-base.so.0(ggml_print_backtrace+0x204) [0x7ffbc1698c64]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/lib/libggml-base.so.0(ggml_abort+0x159) [0x7ffbc1698e39]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/lib/libllama.so.0(_ZN13llama_context6encodeERK11llama_batch+0x1089) [0x7ffbc54f9459]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/lib/libllama.so.0(llama_encode+0x11) [0x7ffbc54f9481]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/lib/libllama-common.so.0(_ZN31common_speculative_state_dflash5draftERK25common_params_speculativeRKSt6vectorIiSaIiEEiRS5_+0xff) [0x7ffbc5af92bf]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/lib/libllama-common.so.0(_Z24common_speculative_draftP18common_speculativeRK25common_params_speculativeRKSt6vectorIiSaIiEEi+0xa4) [0x7ffbc5af1894]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/bin/llama-server(+0x1010ce) [0x564c8aff70ce]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/bin/llama-server(+0x10be68) [0x564c8b001e68]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/bin/llama-server(+0x1a8312) [0x564c8b09e312]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/bin/llama-server(+0x6e5e2) [0x564c8af645e2]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/jms7zxzm7w1whczwny5m3gkgdjghmi2r-glibc-2.42-51/lib/libc.so.6(+0x2b285) [0x7ffbc095d285]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/jms7zxzm7w1whczwny5m3gkgdjghmi2r-glibc-2.42-51/lib/libc.so.6(__libc_start_main+0x88) [0x7ffbc095d338]
Apr 27 08:12:14 michael-server llama-server[888649]: [39385] /nix/store/vh2367dlgknxjq9bgivlx314fxz91aw5-llama-cpp-vulkan-0.0.0/bin/llama-server(+0x6ecb5) [0x564c8af64cb5] 

@mbednarek360

getting this startup error: /home/raghuboi/llama.cpp/src/llama-context.cpp:2509: GGML_ASSERT(tensor != nullptr && "DFlash extraction tensor is null") failed

Tried these models: Qwen 3 1.7B, Qwen 3.5 2B.

Meanwhile https://huggingface.co/spiritbuun/Qwen3.6-27B-DFlash-GGUF fails to load on startup:

llama_model_load: error loading model: error loading model architecture: unknown model architecture: 'dflash-draft'
llama_model_load_from_file_impl: failed to load model
srv load_model: failed to load draft model, '/home/raghuboi/Desktop/models/qwen-3.6/dflash-draft-3.6-q8_0.gguf'
srv operator(): operator(): cleaning up before exit...

set_dflash: DFlash extraction enabled for layers [0, 0, 0, 0, 0]
srv load_model: DFlash feature extraction enabled on target model
srv load_model: initializing slots, n_slots = 1
sched_reserve: reserving ...
sched_reserve: CUDA0 compute buffer size = 4981.19 MiB
sched_reserve: CUDA1 compute buffer size = 2373.19 MiB
sched_reserve: CUDA2 compute buffer size = 3119.20 MiB
sched_reserve: CUDA_Host compute buffer size = 4137.22 MiB
sched_reserve: graph nodes = 3849
sched_reserve: graph splits = 4
sched_reserve: reserve took 938.14 ms, sched copies = 4
/home/raghuboi/llama.cpp/src/llama-context.cpp:2509: GGML_ASSERT(tensor != nullptr && "DFlash extraction tensor is null") failed

exec "$LLAMA_SERVER" -m "$MODEL" -md "/home/raghuboi/Desktop/models/qwen-3.6/Qwen3-1.7B-Q8_0.gguf" --dflash --draft-max 16 --alias "qwen3.6" --host 0.0.0.0 --port 8081 -ngl 99 -np 1 -kvu --fit on --fit-target 1024 --fit-ctx 262144 -c 262144 -cd 4096 --split-mode layer -b 2048 -ub 1024 --main-gpu 1 --mlock --no-mmap --flash-attn on --chat-template-kwargs '{"preserve_thinking":true}' --cache-type-k q8_0 --cache-type-v q8_0 --cache-ram 38912 --ctx-checkpoints 128 --cont-batching --jinja --threads 8 --threads-batch 16 --temp 0.6 --top-p 0.95 --top-k 20 --min-p 0.00 --repeat-penalty 1.0 --presence-penalty 0.0 \

"$LOG" 2>&1

The DFlash GGUF you referenced is meant for another fork of llama.cpp, not this PR.
See: https://github.com/spiritbuun/buun-llama-cpp
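For a draft GGUF that this branch can actually load (it registers the dflash architecture, not dflash-draft), the model has to go through the converter bundled with this PR. A rough sketch of what that looks like, assuming the PR follows the usual convert_hf_to_gguf.py flow; the script name and flags here are placeholders, so check the PR description for the authoritative command:

# hypothetical invocation; see the PR description for the real converter command
python convert_hf_to_gguf.py ./Qwen3-8B-DFlash \
  --outfile Qwen3-8B-DFlash-bf16.gguf \
  --outtype bf16

where ./Qwen3-8B-DFlash is a local copy of z-lab/Qwen3-8B-DFlash, one of the draft models listed in the PR description.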

@ruixiang63
Author

ruixiang63 commented Apr 27, 2026

My plan for the next steps on this PR:

This PR currently supports llama-cli and llama-server only with n_parallel = 1 (multi-slot batching is not yet implemented), and it only supports DFlash GGUF models converted with the converter included in this PR. Full functional support will land after the EAGLE3 PR is merged and the unified speculative API (EAGLE3, DFlash, MTP, etc.) is finalized.
I'll then rebase this PR onto that API and follow up with further performance optimizations and n_parallel > 1 support.

Current working commands for llama-cli and llama-server, e.g.

# llama-cli
./build/bin/llama-cli \
  -m "${TARGET_MODEL_GGUF}" \
  -md "${DFLASH_MODEL_GGUF}" \
  --dflash -p "Write a quicksort algorithm in Python. Write code only." -n 256 --draft-max 16 \
  -cd 512 -c 512 \
  --temp 0 --top-k 1 --seed 42 -ngl 99 -ngld 99 \
  --jinja -rea off

# llama-server
./build/bin/llama-server \
  -m "${TARGET_MODEL_GGUF}" \
  -md "${DFLASH_MODEL_GGUF}" \
  --dflash --draft-max 16 \
  -c 2048 -cd 512 \
  --temp 0 --top-k 1 --seed 42 \
  -ngl 99 -ngld 99 \
  --jinja -rea off \
  -np 1 \
  --host 0.0.0.0 --port 8088

@HH1162

HH1162 commented Apr 27, 2026

Why isn't there any speedup after enabling the dflash parameter on this branch? Meanwhile, performance drops significantly when I switch to the official parameters. T_T
Anyone able to help me out?

Runs at 200 tokens/s as usual with this configuration:
/home/xxxx/llar/bin/llama-server
-m "/home/xxxx/models/Qwen3.6-35B-A3B-Q6_K.gguf"
-md "/home/xxxx/models/Qwen3.6-35B-A3B-DFlash-q8_0.gguf"
--dflash
--draft-max 16
--draft-p-min 0.9
--ctx-size 16384
--n-gpu-layers 99
-ngld 99
--host 127.0.0.1
--port 1234
-fa on
-ctk q4_0
-ctv q4_0
--verbose
--samplers "dry;top_k;typ_p;top_p;min_p;xtc;temperature"
--repeat-penalty 1.1
--temp 0
--top-k 1
--no-mmap
--mlock
--kv-unified
--parallel 1
--sleep-idle-seconds -1
--verbose
--batch-size 8192
--ubatch-size 2048
-n 8192
--threads 14
--threads-batch 28
--reasoning-format deepseek
--reasoning-budget 1024
--prio 2
--jinja

Falls back to 40 tokens/s with this configuration:
/home/xxxx/llar/bin/llama-server
-m "/home/xxxx/models/Qwen3.6-35B-A3B-Q6_K.gguf"
-md "/home/xxxx/models/Qwen3.6-35B-A3B-DFlash-q8_0.gguf"
-ngl 99
-ngld 99
--dflash
--draft-max 16
-c 2048
-cd 512
--temp 0
--top-k 1
--seed 42
--jinja -rea off
-np 1
--host 127.0.0.1
--port 1234

@aminya

aminya commented Apr 28, 2026

I've tried many different patches and configurations over the weekend for my single 3090 setup. There's no benefit from DFlash that I can see. I cannot reproduce any of the claimed speedups in real workflows with Qwen 27B or Qwen 35B.

Originally posted by @aminya in TheTom#103 (comment)

@ukrospm

ukrospm commented Apr 28, 2026

For me it crashes after generating 1 token.

set CUDA_VISIBLE_DEVICES=0,1
a:\0_llama_server_d\build\bin\Release\llama-server ^
  -m a:\0_LM_Studio\lmstudio-community\Qwen3.6-27B-GGUF\Qwen3.6-27B-Q4_K_M.gguf ^
  -md a:\0_LM_Studio\lym00\Qwen3.6-27B-DFlash-bf16.gguf ^
  --host 0.0.0.0 ^
  --port 8088 ^
  -ngl 99 ^
  -ngld 99 ^
  --dflash ^
  --draft-max 16 ^
  -c 2048 ^
  -cd 512 ^
  --temp 0 ^
  --top-k 1 ^
  --seed 42 ^
  --jinja -rea off ^
  -np 1 --verbose

done_getting_tensors: tensor 'token_embd.weight' (q4_K) (and 0 others) cannot be used with preferred buffer type CUDA_Host, using CPU instead
.
.
.
srv update_slots: decoding batch, n_tokens = 4
set_adapters_lora: adapters = 0000000000000000
adapters_lora_are_same: adapters = 0000000000000000
set_embeddings: value = 0
srv operator (): http: streamed chunk: data: {"choices":[{"finish_reason":null,"index":0,"delta":{"role":"assistant","content":null}}],"created":1777351446,"id":"chatcmpl-DbepysCBNfjQ2DyiF5glro4x77VGN159","model":"Qwen3.6-27B-Q4_K_M.gguf","system_fingerprint":"b0-unknown","object":"chat.completion.chunk","timings":{"cache_n":0,"prompt_n":9,"prompt_ms":0.0,"prompt_per_token_ms":0.0,"prompt_per_second":null,"predicted_n":0,"predicted_ms":0.0,"predicted_per_token_ms":null,"predicted_per_second":null},"prompt_progress":{"total":13,"cache":0,"processed":9,"time_ms":656}}

extract_dflash_features: Start to extract DFlash features: 5 layers, 4 tokens, 5120 embd
res send: sending result for task id = 0
res send: task id = 0 pushed to result queue
srv operator (): http: streamed chunk: data: {"choices":[{"finish_reason":null,"index":0,"delta":{"role":"assistant","content":null}}],"created":1777351446,"id":"chatcmpl-DbepysCBNfjQ2DyiF5glro4x77VGN159","model":"Qwen3.6-27B-Q4_K_M.gguf","system_fingerprint":"b0-unknown","object":"chat.completion.chunk","timings":{"cache_n":0,"prompt_n":13,"prompt_ms":0.0,"prompt_per_token_ms":0.0,"prompt_per_second":null,"predicted_n":0,"predicted_ms":0.0,"predicted_per_token_ms":null,"predicted_per_second":null},"prompt_progress":{"total":13,"cache":0,"processed":13,"time_ms":728}}

res send: sending result for task id = 0
res send: task id = 0 pushed to result queue
slot process_toke: id 0 | task 0 | n_decoded = 1, n_remaining = -1, next token: 9419 'Hello'
srv update_slots: run slots completed
que start_loop: waiting for new tasks
que start_loop: processing new tasks
que start_loop: processing task, id = 2
que start_loop: update slots
srv update_slots: posting NEXT_RESPONSE
que post: new task, id = 3, front = 0
slot get_n_draft_: id 0 | task 0 | max possible draft: 16
srv operator (): http: streamed chunk: data: {"choices":[{"finish_reason":null,"index":0,"delta":{"role":"assistant","content":null}}],"created":1777351446,"id":"chatcmpl-DbepysCBNfjQ2DyiF5glro4x77VGN159","model":"Qwen3.6-27B-Q4_K_M.gguf","system_fingerprint":"b0-unknown","object":"chat.completion.chunk"}

data: {"choices":[{"finish_reason":null,"index":0,"delta":{"content":"Hello"}}],"created":1777351446,"id":"chatcmpl-DbepysCBNfjQ2DyiF5glro4x77VGN159","model":"Qwen3.6-27B-Q4_K_M.gguf","system_fingerprint":"b0-unknown","object":"chat.completion.chunk","timings":{"cache_n":0,"prompt_n":13,"prompt_ms":731.354,"prompt_per_token_ms":56.258,"prompt_per_second":17.77524974225888,"predicted_n":1,"predicted_ms":0.001,"predicted_per_token_ms":0.001,"predicted_per_second":1000000.0}}


@ruixiang63
Author

I've tried many different patches and configurations over the weekend for my single 3090 setup. There's no benefit from DFlash that I can see. I cannot reproduce any of the claimed speedups in real workflows with Qwen 27B or Qwen 35B.

Did you follow the steps outlined in the PR description? Did you try the model mentioned there to check for speedups?

Please note that this is still a draft PR, and many parts will need to be refactored once the upstream changes land. As mentioned in the PR description, the current performance is not yet optimal. However, you should still see speedups if you follow the steps in the PR description and use the correct models.

@ruixiang63
Author

Thanks everyone for reporting the issues. I really appreciate your efforts in trying out this PR!

The most common issue I’ve seen so far is the use of GGUF models that were not converted with this PR.
Please note that this is still a draft PR, and there will be many refactoring changes in the near future. llama-cli and llama-server currently have only limited support in this PR, and performance is not fully optimized yet.

If you run into any issues, the safest fallback is to use a GGUF model converted with this PR and run it with llama-speculative-simple, following the steps in the PR description (a rough sketch of such a run is included below).

Once the upstream unified API has been finalized, this PR will continue moving forward and be polished further. Thanks again!
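
For anyone who wants a quick sanity check before digging into llama-server settings, here is a rough sketch of such a fallback run, modeled on the llama-cli command earlier in this thread. The flag set is assumed, not authoritative: --dflash is the flag added by this PR, the model paths are placeholders, and the PR description remains the reference for the exact invocation.

./build/bin/llama-speculative-simple \
  -m "${TARGET_MODEL_GGUF}" \
  -md "${DFLASH_MODEL_GGUF}" \
  --dflash --draft-max 16 \
  -p "Write a quicksort algorithm in Python. Write code only." -n 256 \
  --temp 0 --top-k 1 --seed 42 \
  -ngl 99 -ngld 99

If this run shows an acceptance rate and speedup in line with the numbers in the PR description, a remaining slowdown is more likely a server configuration issue than a problem with the converted GGUFs.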
