
model : add PLaMo-2 model #14560


Open: wants to merge 61 commits into master

Commits (61)
271104c
wip: llama : separate recurrent states from the KV cache
compilade Apr 3, 2024
8db1e4d
llama : use std::find for seq_nodes in llama_rs_cache
compilade Apr 4, 2024
0028010
llama : state checkpoints for recurrent models
compilade Apr 8, 2024
0c8b3b2
llama : correctly handle more edge cases for the rs cache
compilade Apr 9, 2024
d66849f
Merge branch 'master' into compilade/refactor-kv-cache
compilade Apr 10, 2024
a09db95
llama : rename many llama_kv_cache_* functions
compilade Apr 29, 2024
c460ff1
Merge branch 'master' into compilade/refactor-kv-cache
compilade Apr 29, 2024
b6fafd1
llama : remove useless return value for some llama_cache_* functions
compilade Apr 29, 2024
b7ec12e
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 12, 2024
3b57b55
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 22, 2024
7e13f19
llama : rethink recurrent state cell counts
compilade May 24, 2024
cbc743e
llama : support Jamba
compilade May 24, 2024
0fd13e9
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 24, 2024
61a88a1
llama : fix BERT inference without KV cache
compilade May 25, 2024
ea2e63e
convert-hf : check for unprocessed Jamba experts
compilade May 25, 2024
fc59407
convert-hf : support Mini-Jamba conversion
compilade May 25, 2024
181dadf
llama : fix Jamba quantization sanity checks
compilade May 28, 2024
3a414b0
llama : sequence-length-aware batch splitting
compilade May 28, 2024
4e4c41e
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 28, 2024
3587a94
llama : use equal-sequence-length sub-batches for recurrent models
compilade Jun 1, 2024
5d3c7b9
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 1, 2024
72eea49
llama : fix batch split output count for embeddings
compilade Jun 1, 2024
18d1c14
llama : minimize swaps when reordering logits
compilade Jun 1, 2024
61200ef
llama : fix edge case finding batch seq_id of split recurrent cell
compilade Jun 1, 2024
eb589d5
llama : avoid copies for simple batch splits
compilade Jun 2, 2024
8fb57ac
llama : use im2col and mul_mat to perform convolution for Mamba
compilade Jun 3, 2024
17f6c1e
llama : fix .base() compilation error on Windows
compilade Jun 3, 2024
fee3c1d
llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL
compilade Jun 3, 2024
6840ac0
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 8, 2024
372482d
llama : rename llama_cache to llama_past
compilade Jun 8, 2024
43d8d4b
examples : replace llama_kv_cache_seq_* with llama_past_seq_*
compilade Jun 10, 2024
ff794f5
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 12, 2024
33425a7
mamba : fix non-contiguous usage of ggml_silu
compilade Jun 12, 2024
10c3c41
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 30, 2024
9b38f8b
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 4, 2024
bc320ef
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 1, 2024
fcb889c
llama : session saving and reloading for hybrid models
compilade Sep 2, 2024
a03e32a
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 2, 2024
9d3f44d
convert_hf : fix Jamba conversion
compilade Sep 2, 2024
5f62db7
llama : fix mixed signedness comparison
compilade Sep 2, 2024
375de5b
llama : use unused n_embd_k_gqa in k_shift
compilade Sep 2, 2024
4bb4b22
llama : begin renaming llama_past back to llama_kv_cache
compilade Sep 14, 2024
63ac36b
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 14, 2024
124c222
Merge branch 'master' into compilade/refactor-kv-cache
compilade Oct 12, 2024
8006f3b
llama : remove implicit recurrent state rollbacks
compilade Nov 25, 2024
691698e
Merge branch 'master' into compilade/refactor-kv-cache
compilade Nov 25, 2024
e3fe612
llama : partially apply clang-format style
compilade Nov 25, 2024
2bcaf64
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 3, 2025
908e655
convert : fix jamba conv1d shape squeezing
compilade Jul 3, 2025
4682e21
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 3, 2025
20f8e43
graph : add back hybrid memory graph input
compilade Jul 3, 2025
07c252f
model : add Jamba to Mamba-specific hparams printing
compilade Jul 3, 2025
f716358
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 7, 2025
f656712
Add PLaMo-2 model using hybrid memory module
mitmul Jul 7, 2025
4728e42
Fix z shape
mitmul Jul 7, 2025
6acaf3c
Add cmath to include from llama-vocab.h
mitmul Jul 8, 2025
7e4c5ec
Explicitly dequantize normalization weights before RoPE apply
mitmul Jul 8, 2025
149b98c
Revert unnecessary cast because the problem can be solved by excludin…
mitmul Jul 8, 2025
7786520
Use ATTN_K/Q_NORM for k,q weights to prevent quantization
mitmul Jul 8, 2025
0424a76
Remove SSM_BCDT that is not used from anywhere
mitmul Jul 9, 2025
ea95a1d
Do not duplicate embedding weights for output.weight
mitmul Jul 9, 2025
283 changes: 283 additions & 0 deletions convert_hf_to_gguf.py
@@ -3476,6 +3476,172 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
        return [(new_name, data_torch)]


@ModelBase.register("Plamo2ForCausalLM", "PLaMo2ForCausalLM")
class Plamo2Model(TextModel):
model_arch = gguf.MODEL_ARCH.PLAMO2

def set_vocab(self):
# PLaMo 2 uses a custom tokenizer with a .jsonl file
# We need to handle this specially
tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl"
tokenizer_config_path = self.dir_model / "tokenizer_config.json"

if not tokenizer_jsonl_path.is_file():
raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}")

# Load tokenizer config
with open(tokenizer_config_path, 'r', encoding='utf-8') as f:
tokenizer_config = json.load(f)

# Load tokens from JSONL file (actually a list format)
tokens = []
scores = []
toktypes = []

with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
if line.strip():
token_data = json.loads(line)
# Format: [token, score, type, ?, ?, ?, ?]
token = token_data[0].encode("utf-8")
score = float(token_data[1])
token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL"

tokens.append(token)
scores.append(score)

# Map token type strings to GGUF token types
if token_type_str == "UNKNOWN":
toktypes.append(gguf.TokenType.UNKNOWN)
elif token_type_str == "CONTROL":
toktypes.append(gguf.TokenType.CONTROL)
elif token_type_str == "BYTE":
toktypes.append(gguf.TokenType.BYTE)
else:
# Check for PLaMo-2 special tokens
token_str = token_data[0]
if token_str.startswith("<|plamo:") and token_str.endswith("|>"):
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.NORMAL)

# Use "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer
self.gguf_writer.add_tokenizer_model("plamo2")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)

# Add special tokens from config
if "bos_token_id" in tokenizer_config:
self.gguf_writer.add_bos_token_id(tokenizer_config["bos_token_id"])
if "eos_token_id" in tokenizer_config:
self.gguf_writer.add_eos_token_id(tokenizer_config["eos_token_id"])
if "pad_token_id" in tokenizer_config:
self.gguf_writer.add_pad_token_id(tokenizer_config["pad_token_id"])
if "unk_token_id" in tokenizer_config:
self.gguf_writer.add_unk_token_id(tokenizer_config["unk_token_id"])

self.gguf_writer.add_add_space_prefix(False)
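
To illustrate the token-type mapping in set_vocab above, a minimal standalone sketch (the sample lines and the special-token name below are invented, not taken from a real PLaMo-2 tokenizer.jsonl):

import json

# Each tokenizer.jsonl line is a JSON list: [token, score, type, ...] (hypothetical samples)
sample_lines = [
    '["hello", -5.2, "NORMAL"]',
    '["<|plamo:op|>", 0.0, "NORMAL"]',
    '["<unk>", 0.0, "UNKNOWN"]',
]

for line in sample_lines:
    token, score, toktype = json.loads(line)[:3]
    # tokens that look like <|plamo:...|> are treated as CONTROL even when typed NORMAL
    if toktype == "NORMAL" and token.startswith("<|plamo:") and token.endswith("|>"):
        toktype = "CONTROL"
    print(token, score, toktype)
# -> hello -5.2 NORMAL
# -> <|plamo:op|> 0.0 CONTROL
# -> <unk> 0.0 UNKNOWN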

    def set_gguf_parameters(self):
        hparams = self.hparams
        block_count = hparams["num_hidden_layers"]

        # Which layers are Mamba layers
        # PLaMo 2 uses mamba_step to indicate the pattern (e.g., 2 means every other layer)
        # This logic matches modeling_plamo.py's is_mamba function
        mamba_step = hparams.get("mamba_step", 2)
        mamba_enabled = hparams.get("mamba_enabled", True)
        mamba_layers = []

        if mamba_enabled:
            for i in range(block_count):
                if block_count <= (mamba_step // 2):
                    # use attention in last layer
                    is_mamba = (i != block_count - 1)
                else:
                    is_mamba = (i % mamba_step) != (mamba_step // 2)
                if is_mamba:
                    mamba_layers.append(0)
                else:
                    mamba_layers.append(hparams.get("num_key_value_heads", 4))

        if mamba_layers:
            self.gguf_writer.add_head_count_kv(mamba_layers)

        self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
        self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
        self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
        self.gguf_writer.add_group_norm_eps(hparams.get("rms_norm_eps", 1e-06))
        self.gguf_writer.add_layer_norm_eps(hparams.get("rms_norm_eps", 1e-06))
        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0))

        # Mamba parameters
        self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
        self.gguf_writer.add_ssm_conv_kernel(hparams.get("mamba_d_conv", 4))
        self.gguf_writer.add_ssm_time_step_rank(hparams.get("mamba_num_heads", 64))
        intermediate_size = hparams.get("mamba_num_heads", 64) * hparams.get("hidden_size_per_head", 128)
        self.gguf_writer.add_ssm_inner_size(intermediate_size)
        self.gguf_writer.add_ssm_group_count(0)

        # MLP feed forward parameters (for attention layers)
        self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384))
        self.gguf_writer.add_file_type(self.ftype)
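
As a worked example of the layer pattern above (hypothetical sizes, not taken from a real PLaMo-2 config): with mamba_step = 2, the attention layers sit at the odd indices and every Mamba layer reports 0 KV heads:

# Minimal sketch of the is_mamba pattern used in set_gguf_parameters above
# (block_count, mamba_step and num_key_value_heads are assumed values).
block_count = 8
mamba_step = 2
num_key_value_heads = 4

head_count_kv = []
for i in range(block_count):
    if block_count <= (mamba_step // 2):
        is_mamba = (i != block_count - 1)  # tiny models: attention only in the last layer
    else:
        is_mamba = (i % mamba_step) != (mamba_step // 2)
    head_count_kv.append(0 if is_mamba else num_key_value_heads)

print(head_count_kv)  # [0, 4, 0, 4, 0, 4, 0, 4]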

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid  # unused

        if name.endswith(".A_log"):
            data_torch = -torch.exp(data_torch)
        elif name.endswith(".dt_bias"):
            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
        elif name.endswith(".dt_norm_weight"):
            name = name.rpartition(".dt_norm_weight")[0] + ".dt_norm.weight"
        elif name.endswith(".B_norm_weight"):
            name = name.rpartition(".B_norm_weight")[0] + ".B_norm.weight"
        elif name.endswith(".C_norm_weight"):
            name = name.rpartition(".C_norm_weight")[0] + ".C_norm.weight"
        elif name.endswith(".k_weight"):
            name = name.rpartition(".k_weight")[0] + ".k.weight"
        elif name.endswith(".q_weight"):
            name = name.rpartition(".q_weight")[0] + ".q.weight"
        elif name.endswith(".conv1d.weight"):
            data_torch = torch.squeeze(data_torch)  # squeeze out the singleton dimension so the weight is 2D
            assert data_torch.ndim == 2
        elif name.endswith(".pre_mixer_norm.weight"):
            data_torch += 1.0
        elif name.endswith(".post_mixer_norm.weight"):
            data_torch += 1.0 / 5
        elif name.endswith(".pre_mlp_norm.weight"):
            data_torch += 1.0
        elif name.endswith(".post_mlp_norm.weight"):
            data_torch += 1.0 / (5**1.5)
        elif name.endswith(".norm.weight"):
            data_torch += 1.0
        elif name.endswith(".gate_up_proj.weight"):
            # Split the combined gate_up tensor
            split_size = data_torch.shape[0] // 2
            gate_tensor = data_torch[:split_size, :]
            up_tensor = data_torch[split_size:, :]

            # Return both tensors, stripping the .gate_up_proj.weight suffix to build the new names
            name_base = name.replace(".gate_up_proj.weight", "")
            gate_name = name_base + ".ffn_gate.weight"
            up_name = name_base + ".ffn_up.weight"

            gate_mapped = self.map_tensor_name(gate_name)
            up_mapped = self.map_tensor_name(up_name)

            return [(gate_mapped, gate_tensor), (up_mapped, up_tensor)]

        new_name = self.map_tensor_name(name)

        return [(new_name, data_torch)]


@ModelBase.register("CodeShellForCausalLM")
class CodeShellModel(TextModel):
model_arch = gguf.MODEL_ARCH.CODESHELL
@@ -4954,6 +5120,123 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
        yield (new_name, data_torch)


@ModelBase.register("JambaForCausalLM")
class JambaModel(TextModel):
model_arch = gguf.MODEL_ARCH.JAMBA

def get_vocab_base_pre(self, tokenizer) -> str:
del tokenizer # unused

return "gpt-2"

def set_vocab(self):
if (self.dir_model / "tokenizer.model").is_file():
# Using Jamba's tokenizer.json causes errors on model load
# (something about "byte not found in vocab"),
# but there's a working tokenizer.model
self._set_vocab_sentencepiece()
else:
# Some Jamba models only have a tokenizer.json, which works.
self._set_vocab_gpt2()

def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
d_inner = self.hparams["mamba_expand"] * d_model
d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
# ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
n_kv_head = self.hparams["num_key_value_heads"]
attn_offset = self.hparams["attn_layer_offset"]
attn_period = self.hparams["attn_layer_period"]
n_kv_vec = [0 for _ in range(attn_offset)] + [
n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
]

self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(n_kv_vec)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_file_type(self.ftype)
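
For a concrete picture of the head_count_kv vector built above (layer counts are invented, not from a real Jamba config): with attn_layer_offset = 4 and attn_layer_period = 8, only every eighth layer starting at layer 4 is an attention layer:

# Minimal sketch of the n_kv_vec construction in JambaModel.set_gguf_parameters above
# (all values below are assumptions for illustration).
block_count = 12
attn_offset = 4
attn_period = 8
n_kv_head = 8

n_kv_vec = [0 for _ in range(attn_offset)] + [
    n_kv_head if (i - attn_offset) % attn_period == 0 else 0
    for i in range(attn_offset, block_count)
]
print(n_kv_vec)  # [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0]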

    _experts: list[dict[str, Tensor]] | None = None

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

        # Mini-Jamba uses ".moe." instead of ".feed_forward." in its tensor names
        name = name.replace(".moe.", ".feed_forward.")
        if bid is not None:
            moe_offset = self.hparams["expert_layer_offset"]
            moe_period = self.hparams["expert_layer_period"]

            if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
                name = name.replace(".experts.0.", ".")

        # process the experts separately
        if ".feed_forward.experts." in name:
            n_experts = self.hparams["num_experts"]

            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:

                # merge the experts into a single 3d tensor
                for wid in ["down_proj", "gate_proj", "up_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
                        datas.append(self._experts[bid][ename])
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)

                    # using the same merged name as qwen2moe
                    merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"

                    new_name = self.map_tensor_name(merged_name)

                    yield new_name, data_torch
            return

        new_name = self.map_tensor_name(name)

        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
            data_torch = data_torch.squeeze()

        if name.endswith(".A_log"):
            logger.debug("A_log --> A ==> " + new_name)
            data_torch = -torch.exp(data_torch)

        yield (new_name, data_torch)
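
The expert-merging branch above packs each projection's per-expert 2D weights into one 3D tensor. A minimal shape-only sketch (the dimensions are invented for illustration):

import torch

n_experts, d_ff, d_model = 4, 128, 64
# one 2D weight per expert, as collected in self._experts[bid]
expert_weights = [torch.randn(d_ff, d_model) for _ in range(n_experts)]

# stacking along a new leading dimension gives the qwen2moe-style merged layout
merged = torch.stack(expert_weights, dim=0)
print(merged.shape)  # torch.Size([4, 128, 64])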

    def prepare_tensors(self):
        super().prepare_tensors()

        if self._experts is not None:
            # flatten `list[dict[str, Tensor]]` into `list[str]`
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("CohereForCausalLM")
class CommandR2Model(TextModel):
model_arch = gguf.MODEL_ARCH.COMMAND_R