
Commit dbce4e2

Add PLaMo-2 model using hybrid memory module
1 parent 07c252f

File tree: 12 files changed (+1203, −48 lines)


convert_hf_to_gguf.py

Lines changed: 177 additions & 0 deletions
@@ -3476,6 +3476,183 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(new_name, data_torch)]
 
 
+@ModelBase.register("Plamo2ForCausalLM", "PLaMo2ForCausalLM")
+class Plamo2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.PLAMO2
+
+    def set_vocab(self):
+        # PLaMo 2 uses a custom tokenizer with a .jsonl file, which needs special handling
+        tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl"
+        tokenizer_config_path = self.dir_model / "tokenizer_config.json"
+
+        if not tokenizer_jsonl_path.is_file():
+            raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}")
+
+        # Load the tokenizer config
+        with open(tokenizer_config_path, 'r', encoding='utf-8') as f:
+            tokenizer_config = json.load(f)
+
+        # Load tokens from the JSONL file (each line is a JSON list)
+        tokens = []
+        scores = []
+        toktypes = []
+
+        with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f:
+            for line in f:
+                if line.strip():
+                    token_data = json.loads(line)
+                    # Format: [token, score, type, ?, ?, ?, ?]
+                    token = token_data[0].encode("utf-8")
+                    score = float(token_data[1])
+                    token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL"
+
+                    tokens.append(token)
+                    scores.append(score)
+
+                    # Map token type strings to GGUF token types
+                    if token_type_str == "UNKNOWN":
+                        toktypes.append(gguf.TokenType.UNKNOWN)
+                    elif token_type_str == "CONTROL":
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    elif token_type_str == "BYTE":
+                        toktypes.append(gguf.TokenType.BYTE)
+                    else:
+                        # Check for PLaMo-2 special tokens
+                        token_str = token_data[0]
+                        if token_str.startswith("<|plamo:") and token_str.endswith("|>"):
+                            toktypes.append(gguf.TokenType.CONTROL)
+                        else:
+                            toktypes.append(gguf.TokenType.NORMAL)
+
+        # Use the "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer
+        self.gguf_writer.add_tokenizer_model("plamo2")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # Add special tokens from the config
+        if "bos_token_id" in tokenizer_config:
+            self.gguf_writer.add_bos_token_id(tokenizer_config["bos_token_id"])
+        if "eos_token_id" in tokenizer_config:
+            self.gguf_writer.add_eos_token_id(tokenizer_config["eos_token_id"])
+        if "pad_token_id" in tokenizer_config:
+            self.gguf_writer.add_pad_token_id(tokenizer_config["pad_token_id"])
+        if "unk_token_id" in tokenizer_config:
+            self.gguf_writer.add_unk_token_id(tokenizer_config["unk_token_id"])
+
+        self.gguf_writer.add_add_space_prefix(False)
+
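For reference, a minimal standalone sketch (not part of the commit) of the JSONL parsing that set_vocab performs above. The sample rows are hypothetical; the real tokenizer.jsonl rows carry extra trailing fields whose meaning the format comment leaves as "?".

import json

# Hypothetical tokenizer.jsonl rows: [token, score, type, ...]
sample_lines = [
    '["<|plamo:bos|>", 0.0, "CONTROL"]',
    '["hello", -10.5, "NORMAL"]',
]

for line in sample_lines:
    token_data = json.loads(line)
    token = token_data[0].encode("utf-8")
    score = float(token_data[1])
    toktype = token_data[2] if len(token_data) > 2 else "NORMAL"
    print(token, score, toktype)
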
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        # Determine which layers are Mamba layers.
+        # PLaMo 2 uses mamba_step to indicate the pattern (e.g. 2 means every other layer);
+        # this logic matches modeling_plamo.py's is_mamba function.
+        mamba_step = hparams.get("mamba_step", 2)
+        mamba_enabled = hparams.get("mamba_enabled", True)
+        mamba_layers = []
+
+        if mamba_enabled:
+            for i in range(block_count):
+                if block_count <= (mamba_step // 2):
+                    # use attention in the last layer
+                    is_mamba = (i != block_count - 1)
+                else:
+                    is_mamba = (i % mamba_step) != (mamba_step // 2)
+                if is_mamba:
+                    mamba_layers.append(0)
+                else:
+                    mamba_layers.append(hparams.get("num_key_value_heads", 4))
+
+        if mamba_layers:
+            self.gguf_writer.add_head_count_kv(mamba_layers)
+
+        self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
+        self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
+        self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
+        self.gguf_writer.add_group_norm_eps(hparams.get("rms_norm_eps", 1e-06))
+        self.gguf_writer.add_layer_norm_eps(hparams.get("rms_norm_eps", 1e-06))
+        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0))
+
+        # Mamba parameters
+        self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
+        self.gguf_writer.add_ssm_conv_kernel(hparams.get("mamba_d_conv", 4))
+        self.gguf_writer.add_ssm_time_step_rank(hparams.get("mamba_num_heads", 64))
+        intermediate_size = hparams.get("mamba_num_heads", 64) * hparams.get("hidden_size_per_head", 128)
+        self.gguf_writer.add_ssm_inner_size(intermediate_size)
+        self.gguf_writer.add_ssm_group_count(0)
+
+        # MLP feed-forward parameters (for the attention layers)
+        self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384))
+        self.gguf_writer.add_file_type(self.ftype)
+
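As a worked example of the layer pattern above (values assumed for illustration, not from the commit): with mamba_step = 2 and block_count = 8, is_mamba holds whenever i % 2 != 1, so even-indexed layers are Mamba and odd-indexed layers are attention, and head_count_kv becomes a per-layer list in which 0 marks a Mamba layer.

# Standalone sketch of the is_mamba pattern above, with assumed hparams
block_count = 8
mamba_step = 2
num_key_value_heads = 4

mamba_layers = []
for i in range(block_count):
    if block_count <= (mamba_step // 2):
        is_mamba = (i != block_count - 1)  # tiny models: attention only in the last layer
    else:
        is_mamba = (i % mamba_step) != (mamba_step // 2)
    mamba_layers.append(0 if is_mamba else num_key_value_heads)

print(mamba_layers)  # [0, 4, 0, 4, 0, 4, 0, 4]
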
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.endswith(".embed_tokens.weight"):
+            # If there is no lm_head, map the token embedding to the output layer as well
+            assert self.tensor_names is not None
+            if all('lm_head' not in tn for tn in self.tensor_names):
+                output_name = "lm_head"
+
+                embed_tokens_mapped = self.map_tensor_name(name)
+                output_mapped = self.map_tensor_name(output_name) + ".weight"
+
+                return [(embed_tokens_mapped, data_torch), (output_mapped, data_torch)]
+        elif name.endswith(".A_log"):
+            data_torch = -torch.exp(data_torch)
+        elif name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+        elif name.endswith(".dt_norm_weight"):
+            name = name.rpartition(".dt_norm_weight")[0] + ".dt_norm.weight"
+        elif name.endswith(".B_norm_weight"):
+            name = name.rpartition(".B_norm_weight")[0] + ".B_norm.weight"
+        elif name.endswith(".C_norm_weight"):
+            name = name.rpartition(".C_norm_weight")[0] + ".C_norm.weight"
+        elif name.endswith(".k_weight"):
+            name = name.rpartition(".k_weight")[0] + ".k.weight"
+        elif name.endswith(".q_weight"):
+            name = name.rpartition(".q_weight")[0] + ".q.weight"
+        elif name.endswith(".conv1d.weight"):
+            data_torch = torch.squeeze(data_torch)  # drop the singleton middle dimension
+            assert data_torch.ndim == 2
+        elif name.endswith(".pre_mixer_norm.weight"):
+            data_torch += 1.0
+        elif name.endswith(".post_mixer_norm.weight"):
+            data_torch += 1.0 / 5
+        elif name.endswith(".pre_mlp_norm.weight"):
+            data_torch += 1.0
+        elif name.endswith(".post_mlp_norm.weight"):
+            data_torch += 1.0 / (5**1.5)
+        elif name.endswith(".norm.weight"):
+            data_torch += 1.0
+        elif name.endswith(".gate_up_proj.weight"):
+            # Split the combined gate_up tensor into ffn_gate and ffn_up
+            split_size = data_torch.shape[0] // 2
+            gate_tensor = data_torch[:split_size, :]
+            up_tensor = data_torch[split_size:, :]
+
+            name_base = name.replace(".gate_up_proj.weight", "")
+            gate_name = name_base + ".ffn_gate.weight"
+            up_name = name_base + ".ffn_up.weight"
+
+            gate_mapped = self.map_tensor_name(gate_name)
+            up_mapped = self.map_tensor_name(up_name)
+
+            return [(gate_mapped, gate_tensor), (up_mapped, up_tensor)]
+
+        new_name = self.map_tensor_name(name)
+
+        return [(new_name, data_torch)]
+
+
 @ModelBase.register("CodeShellForCausalLM")
 class CodeShellModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CODESHELL
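The gate_up_proj split in modify_tensors above can be sanity-checked in isolation; a short sketch with made-up dimensions (hidden size 8, intermediate size 16), not taken from the commit:

import torch

# Hypothetical combined weight laid out as [gate; up] along dim 0, as split above
hidden_size, intermediate_size = 8, 16
gate_up = torch.randn(2 * intermediate_size, hidden_size)

split_size = gate_up.shape[0] // 2
gate_tensor = gate_up[:split_size, :]  # becomes ffn_gate.weight
up_tensor = gate_up[split_size:, :]    # becomes ffn_up.weight

assert gate_tensor.shape == up_tensor.shape == (intermediate_size, hidden_size)
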

examples/eval-callback/eval-callback.cpp

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
 
     if (!ggml_is_quantized(t->type)) {
         uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
-        ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
+        ggml_print_tensor(data, t->type, t->ne, t->nb, 256);
     }
 
     return true;

gguf-py/gguf/constants.py

Lines changed: 34 additions & 0 deletions
@@ -313,6 +313,7 @@ class MODEL_ARCH(IntEnum):
     PHI3 = auto()
     PHIMOE = auto()
     PLAMO = auto()
+    PLAMO2 = auto()
     CODESHELL = auto()
     ORION = auto()
     INTERNLM2 = auto()
@@ -435,6 +436,7 @@ class MODEL_TENSOR(IntEnum):
     SSM_B_NORM = auto()
     SSM_C_NORM = auto()
     SSM_D = auto()
+    SSM_BCDT = auto()  # PLaMo-2
     SSM_NORM = auto()
     SSM_OUT = auto()
     TIME_MIX_W0 = auto()
@@ -620,6 +622,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.PHI3: "phi3",
     MODEL_ARCH.PHIMOE: "phimoe",
     MODEL_ARCH.PLAMO: "plamo",
+    MODEL_ARCH.PLAMO2: "plamo2",
     MODEL_ARCH.CODESHELL: "codeshell",
     MODEL_ARCH.ORION: "orion",
     MODEL_ARCH.INTERNLM2: "internlm2",
@@ -1350,6 +1353,37 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.PLAMO2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_POST_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_X,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.SSM_BCDT,
+        MODEL_TENSOR.SSM_DT_NORM,
+        MODEL_TENSOR.SSM_B_NORM,
+        MODEL_TENSOR.SSM_C_NORM,
+    ],
     MODEL_ARCH.GPT2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.POS_EMBD,
