
Commit 2b263e6

Merge branch 'mamba2-sync' into GraniteFour
* mamba2-sync: (22 commits)
  recurrent : call balloc split_reset() in init_batch() (ggml-org#14414)
  ggml : add ggml_set_rows (ggml-org#14274)
  convert : fix broken sentencepiece vocab (ggml-org#14416)
  mamba : fix mismatched new and delete size for llm_build_mamba
  model : gemma3n text-only (ggml-org#14400)
  cmake: regen vulkan shaders when shaders-gen sources change (ggml-org#14398)
  llama : return mistral-v7-tekken as default template only (ggml-org#14390)
  metal : add special-case mat-vec mul for ne00 == 4 (ggml-org#14385)
  metal : batch rows copy in a single threadgroup (ggml-org#14384)
  docs: update s390x documentation + add faq (ggml-org#14389)
  musa: enable fp16 mma (all) and cublas on qy2 (ggml-org#13842)
  ggml-cpu: enable IBM NNPA Vector Intrinsics (ggml-org#14317)
  ggml : do not output unprintable characters on GGUF load failure (ggml-org#14381)
  sycl: GGML_SYCL_DISABLE_OPT on by default for all Intel Devices (ggml-org#13973)
  opencl: ref count `ggml_backend_opencl_context` and refactor profiling (ggml-org#14254)
  batch : fix check for empty sequences in memory (ggml-org#14364)
  cmake : use LLAMA_BUILD_NUMBER when defining LLAMA_INSTALL_VERSION (ggml-org#14362)
  server : move no API key doc to /health (ggml-org#14352)
  main : honor --verbose-prompt on interactive prompts (ggml-org#14350)
  jinja : Add Mistral-Small-3.2-24B-Instruct-2506.jinja (ggml-org#14349)
  ...
2 parents a9dcc84 + fdc9a8d commit 2b263e6


70 files changed: 3440 additions & 1747 deletions

.github/workflows/build-cmake-pkg.yml

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+name: Build relocatable cmake package
+on:
+  workflow_dispatch:
+  workflow_call:
+
+jobs:
+  linux:
+    runs-on: ubuntu-24.04
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Install dependencies
+        run: |
+          sudo apt update
+          sudo apt install -y build-essential tcl
+
+      - name: Build
+        run: |
+          PREFIX="$(pwd)"/inst
+          cmake -S . -B build -DCMAKE_PREFIX_PATH="$PREFIX" \
+            -DLLAMA_CURL=OFF -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_TOOLS=OFF \
+            -DLLAMA_BUILD_EXAMPLES=OFF -DCMAKE_BUILD_TYPE=Release
+          cmake --build build --config Release
+          cmake --install build --prefix "$PREFIX" --config Release
+
+          export LLAMA_CONFIG="$PREFIX"/lib/cmake/llama/llama-config.cmake
+          tclsh <<'EOF'
+            set build(commit) [string trim [exec git rev-parse --short HEAD]]
+            set build(number) [string trim [exec git rev-list --count HEAD]]
+            set build(version) "0.0.$build(number)"
+
+            set llamaconfig [read [open "$env(LLAMA_CONFIG)" r]]
+            set checks [list "set\\(LLAMA_VERSION \\s+$build(version)\\)" \
+                             "set\\(LLAMA_BUILD_COMMIT\\s+$build(commit)\\)" \
+                             "set\\(LLAMA_BUILD_NUMBER\\s+$build(number)\\)"]
+
+            puts -nonewline "Checking llama-config.cmake version... "
+            foreach check $checks {
+              if {![regexp -expanded -- $check $llamaconfig]} {
+                puts "\"$check\" failed!"
+                exit 1
+              }
+            }
+            puts "success."
+          EOF
+
+          cd examples/simple-cmake-pkg
+          cmake -S . -B build -DCMAKE_PREFIX_PATH="$PREFIX"/lib/cmake
+          cmake --build build
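The tclsh heredoc above only asserts that the installed llama-config.cmake reports the same version, commit, and build number that git reports for the checkout. For readers who do not speak Tcl, here is a minimal Python sketch of the same consistency check; the install path and the check_llama_config helper are assumptions for illustration, not part of the commit:

import re
import subprocess

def check_llama_config(path: str) -> None:
    # same inputs the workflow derives with `git rev-parse` / `git rev-list`
    commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], text=True).strip()
    number = subprocess.check_output(["git", "rev-list", "--count", "HEAD"], text=True).strip()
    version = f"0.0.{number}"

    config = open(path).read()
    checks = [
        rf"set\(LLAMA_VERSION\s+{re.escape(version)}\)",
        rf"set\(LLAMA_BUILD_COMMIT\s+{re.escape(commit)}\)",
        rf"set\(LLAMA_BUILD_NUMBER\s+{re.escape(number)}\)",
    ]
    for check in checks:
        if not re.search(check, config):
            raise SystemExit(f'"{check}" failed!')
    print("success.")

check_llama_config("inst/lib/cmake/llama/llama-config.cmake")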

.github/workflows/build.yml

Lines changed: 38 additions & 2 deletions
@@ -5,10 +5,43 @@ on:
   push:
     branches:
       - master
-    paths: ['.github/workflows/build.yml', '.github/workflows/build-linux-cross.yml', '**/CMakeLists.txt', '**/.cmake', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp']
+    paths: [
+      '.github/workflows/build.yml',
+      '.github/workflows/build-linux-cross.yml',
+      '.github/workflows/build-cmake-pkg.yml',
+      '**/CMakeLists.txt',
+      '**/.cmake',
+      '**/*.h',
+      '**/*.hpp',
+      '**/*.c',
+      '**/*.cpp',
+      '**/*.cu',
+      '**/*.cuh',
+      '**/*.swift',
+      '**/*.m',
+      '**/*.metal',
+      '**/*.comp'
+    ]
+
   pull_request:
     types: [opened, synchronize, reopened]
-    paths: ['.github/workflows/build.yml', '.github/workflows/build-linux-cross.yml', '**/CMakeLists.txt', '**/.cmake', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp']
+    paths: [
+      '.github/workflows/build.yml',
+      '.github/workflows/build-linux-cross.yml',
+      '.github/workflows/build-cmake-pkg.yml',
+      '**/CMakeLists.txt',
+      '**/.cmake',
+      '**/*.h',
+      '**/*.hpp',
+      '**/*.c',
+      '**/*.cpp',
+      '**/*.cu',
+      '**/*.cuh',
+      '**/*.swift',
+      '**/*.m',
+      '**/*.metal',
+      '**/*.comp'
+    ]

 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}

@@ -478,6 +511,9 @@ jobs:
   build-linux-cross:
     uses: ./.github/workflows/build-linux-cross.yml

+  build-cmake-pkg:
+    uses: ./.github/workflows/build-cmake-pkg.yml
+
   macOS-latest-cmake-ios:
     runs-on: macos-latest

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ endif()
 if (NOT DEFINED LLAMA_BUILD_COMMIT)
     set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT})
 endif()
-set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER})
+set(LLAMA_INSTALL_VERSION 0.0.${LLAMA_BUILD_NUMBER})

 # override ggml options
 set(GGML_ALL_WARNINGS ${LLAMA_ALL_WARNINGS})

convert_hf_to_gguf.py

Lines changed: 117 additions & 5 deletions
@@ -310,6 +310,8 @@ def prepare_tensors(self):
                         gguf.MODEL_TENSOR.POSNET_NORM2,
                         gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
                         gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
+                        gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
+                        gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
                     )
                 )
                 or not new_name.endswith(".weight")
@@ -320,7 +322,11 @@ def prepare_tensors(self):
                     self.match_model_tensor_name(new_name, key, bid)
                     for key in (
                         gguf.MODEL_TENSOR.TOKEN_EMBD,
+                        gguf.MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
                         gguf.MODEL_TENSOR.OUTPUT,
+                        gguf.MODEL_TENSOR.ALTUP_ROUTER,
+                        gguf.MODEL_TENSOR.LAUREL_L,
+                        gguf.MODEL_TENSOR.LAUREL_R,
                     )
                 ):
                     if self.ftype in (
@@ -921,13 +927,20 @@ def _create_vocab_sentencepiece(self):
         tokenizer = SentencePieceProcessor()
         tokenizer.LoadFromFile(str(tokenizer_path))

-        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+        vocab_size = self.find_hparam([
+            "vocab_size_per_layer_input", # gemma3n
+            "vocab_size",
+        ], optional=True) or tokenizer.vocab_size()

         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
         toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size

         for token_id in range(tokenizer.vocab_size()):
+            if token_id >= vocab_size:
+                logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}')
+                break
+
             piece = tokenizer.IdToPiece(token_id)
             text = piece.encode("utf-8")
             score = tokenizer.GetScore(token_id)
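The net effect of the new vocab_size lookup and the out-of-range guard: when the model's vocab_size (or vocab_size_per_layer_input on gemma3n) is larger than the tokenizer, the extra slots stay as [PADi] entries; when it is smaller, tokenizer ids past the limit are skipped with a warning. A toy sketch of that behaviour without SentencePiece; the build_token_list helper is made up for illustration:

def build_token_list(vocab_size: int, tokenizer_tokens: list[str]) -> list[str]:
    # every slot starts as a pad entry, as in _create_vocab_sentencepiece
    tokens = [f"[PAD{i}]" for i in range(vocab_size)]
    for token_id, piece in enumerate(tokenizer_tokens):
        if token_id >= vocab_size:
            break  # out-of-range ids are ignored, mirroring the warning above
        tokens[token_id] = piece
    return tokens

print(build_token_list(4, ["<unk>", "<s>", "</s>"]))  # last slot stays [PAD3]
print(build_token_list(2, ["<unk>", "<s>", "</s>"]))  # third tokenizer id is dropped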
@@ -4217,6 +4230,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
 @ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
 class Gemma3Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GEMMA3
+    norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value

     def set_vocab(self):
         self._set_vocab_sentencepiece()
@@ -4238,9 +4252,8 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
         self.gguf_writer.add_file_type(self.ftype)
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
-        # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+        # attn_logit_softcapping is removed in Gemma3
         assert hparams.get("attn_logit_softcapping") is None
-        assert hparams.get("final_logit_softcapping") is None
         self.gguf_writer.add_sliding_window(hparams["sliding_window"])
         self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
         if hparams.get("rope_scaling") is not None:
@@ -4252,7 +4265,7 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused

-        if name.startswith("language_model."):
+        if "language_model." in name:
             name = name.replace("language_model.", "")

         elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
@@ -4267,8 +4280,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

         # ref code in Gemma3RMSNorm
         # output = output * (1.0 + self.weight.float())
+        # note: this is not the case on gemma3n
         if name.endswith("norm.weight"):
-            data_torch = data_torch + 1
+            data_torch = data_torch + self.norm_shift

         return [(self.map_tensor_name(name), data_torch)]
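The norm_shift handling mirrors the Gemma3RMSNorm reference above: the Hugging Face weight enters the formula as (1.0 + w), so the converter stores w + norm_shift, presumably so the stored tensor can be applied as a plain scale; for gemma3n (norm_shift = 0.0) the weight is written unchanged. A tiny standalone check of that equivalence, with made-up values:

import torch

w_hf = torch.tensor([0.1, -0.2, 0.0])     # HF Gemma3 norm weight
x_normed = torch.tensor([1.0, 2.0, 3.0])  # already RMS-normalized activations

y_reference = x_normed * (1.0 + w_hf)     # Gemma3RMSNorm: output * (1.0 + weight)
w_stored = w_hf + 1.0                     # what modify_tensors() writes for Gemma3 (norm_shift = 1.0)
assert torch.allclose(y_reference, x_normed * w_stored)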

@@ -4325,6 +4339,104 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [] # skip other tensors


+@ModelBase.register("Gemma3nForConditionalGeneration")
+class Gemma3NModel(Gemma3Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA3N
+    norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code
+
+    _altup_proj: list[Tensor] = []
+    _altup_unembd: list[Tensor] = []
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs"
+        self._altup_proj = [
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+        ]
+        self._altup_unembd = [
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+        ]
+
+    def set_vocab(self):
+        with open(self.dir_model / "chat_template.jinja") as f:
+            # quick hack to make sure chat template is added
+            self.gguf_writer.add_chat_template(f.read())
+        super().set_vocab()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"])
+        self.gguf_writer.add_altup_num_inputs(self.hparams["altup_num_inputs"])
+        self.gguf_writer.add_embedding_length_per_layer_input(self.hparams["hidden_size_per_layer_input"])
+        self.gguf_writer.add_shared_kv_layers(self.hparams["num_kv_shared_layers"])
+
+        activation_sparsity_scale = []
+        for s in self.hparams["activation_sparsity_pattern"]:
+            normal_dist = torch.distributions.normal.Normal(0, 1)
+            std_multiplier = normal_dist.icdf(torch.tensor(s, dtype=torch.float32))
+            activation_sparsity_scale.append(std_multiplier.item())
+        self.gguf_writer.add_activation_sparsity_scale(activation_sparsity_scale)
+
+        sliding_window_pattern = []
+        for t in self.hparams["layer_types"]:
+            sliding_window_pattern.append(t == "sliding_attention")
+        self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
+
+    def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None:
+        has_all = all(m.numel() > 0 for m in matrices)
+        if not has_all:
+            return None
+        else:
+            return torch.stack(matrices, dim=0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith("_scale"):
+            name = name + ".weight"
+
+        # TODO: implement self.prediction_coefs.weight.clamp_(...)
+
+        if "language_model." not in name:
+            return [] # skip non-language model tensors
+
+        if "altup_unembed_projections" in name:
+            data_torch = data_torch.to(device="cpu")
+            if ".0." in name:
+                self._altup_unembd[0] = data_torch
+            elif ".1." in name:
+                self._altup_unembd[1] = data_torch
+            elif ".2." in name:
+                self._altup_unembd[2] = data_torch
+            else:
+                raise ValueError(f"Unknown name: {name}")
+            out = self._stack_matrices(self._altup_unembd)
+            if out is not None:
+                return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)]
+            else:
+                return []
+
+        if "altup_projections" in name:
+            data_torch = data_torch.to(device="cpu")
+            if ".0." in name:
+                self._altup_proj[0] = data_torch
+            elif ".1." in name:
+                self._altup_proj[1] = data_torch
+            elif ".2." in name:
+                self._altup_proj[2] = data_torch
+            else:
+                raise ValueError(f"Unknown name: {name}")
+            out = self._stack_matrices(self._altup_proj)
+            if out is not None:
+                return [(self.map_tensor_name("model.altup_projections.weight"), out)]
+            else:
+                return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Starcoder2ForCausalLM")
 class StarCoder2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.STARCODER2
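Two details of the Gemma3NModel conversion above are easier to see with a worked example: the per-layer activation sparsity value is turned into a standard-normal std-dev multiplier via the inverse CDF, and the three altup projection matrices are buffered until all slots are filled and then stacked into one tensor. A small standalone sketch; the tensor shapes and the 0.95 sparsity value are illustrative, not taken from any specific model config:

import torch

# inverse CDF of the standard normal: target sparsity -> std-dev multiplier
normal_dist = torch.distributions.normal.Normal(0, 1)
print(normal_dist.icdf(torch.tensor(0.95, dtype=torch.float32)).item())  # ~1.6449

# buffer-and-stack pattern used for altup(_unembed)_projections
slots = [torch.Tensor(), torch.Tensor(), torch.Tensor()]  # "to be replaced"
for i in range(3):
    slots[i] = torch.randn(4, 8)             # pretend projection i just arrived
    if all(m.numel() > 0 for m in slots):    # only emit once every slot is filled
        stacked = torch.stack(slots, dim=0)  # shape (3, 4, 8), written as one tensor
print(stacked.shape)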

docs/backend/SYCL.md

Lines changed: 1 addition & 1 deletion
@@ -757,7 +757,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 | Name | Value | Function |
 |-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
 | GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
-| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
+| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features for Intel GPUs. (Recommended to 1 for intel devices older than Gen 10) |
 | GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
 | GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
 | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
