
Commit 9741405: Add PLaMo-2

1 parent bb16041
16 files changed, +823 -46 lines changed

convert_hf_to_gguf.py

Lines changed: 161 additions & 0 deletions
@@ -3417,6 +3417,167 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(new_name, data_torch)]
 
 
+@ModelBase.register("Plamo2ForCausalLM", "PLaMo2ForCausalLM")
+class Plamo2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.PLAMO2
+
+    def set_vocab(self):
+        # PLaMo 2 uses a custom tokenizer with a .jsonl file
+        # We need to handle this specially
+        tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl"
+        tokenizer_config_path = self.dir_model / "tokenizer_config.json"
+
+        if not tokenizer_jsonl_path.is_file():
+            raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}")
+
+        # Load tokenizer config
+        with open(tokenizer_config_path, 'r', encoding='utf-8') as f:
+            tokenizer_config = json.load(f)
+
+        # Load tokens from the JSONL file (each line is actually a list, not an object)
+        tokens = []
+        scores = []
+        toktypes = []
+
+        with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f:
+            for line_num, line in enumerate(f):
+                if line.strip():
+                    token_data = json.loads(line)
+                    # Format: [token, score, type, ?, ?, ?, ?]
+                    token = token_data[0].encode("utf-8")
+                    score = float(token_data[1])
+                    token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL"
+
+                    tokens.append(token)
+                    scores.append(score)
+
+                    # Map token type strings to GGUF token types
+                    if token_type_str == "UNKNOWN":
+                        toktypes.append(gguf.TokenType.UNKNOWN)
+                    elif token_type_str == "CONTROL":
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    elif token_type_str == "BYTE":
+                        toktypes.append(gguf.TokenType.BYTE)
+                    else:
+                        toktypes.append(gguf.TokenType.NORMAL)
+
+        # Use the "llama" (SPM) tokenizer type, which doesn't require merges;
+        # PLaMo 2's tokenizer is more similar to SPM than GPT-2
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # Add special tokens from config
+        if "bos_token_id" in tokenizer_config:
+            self.gguf_writer.add_bos_token_id(tokenizer_config["bos_token_id"])
+        if "eos_token_id" in tokenizer_config:
+            self.gguf_writer.add_eos_token_id(tokenizer_config["eos_token_id"])
+        if "pad_token_id" in tokenizer_config:
+            self.gguf_writer.add_pad_token_id(tokenizer_config["pad_token_id"])
+        if "unk_token_id" in tokenizer_config:
+            self.gguf_writer.add_unk_token_id(tokenizer_config["unk_token_id"])
+
+        self.gguf_writer.add_add_space_prefix(False)
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
+        self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
+        self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
+        self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
+        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0))
+
+        # Mamba parameters
+        self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
+        self.gguf_writer.add_ssm_conv_kernel(hparams.get("mamba_d_conv", 4))
+        self.gguf_writer.add_ssm_num_heads(hparams.get("mamba_num_heads", 64))
+        self.gguf_writer.add_ssm_head_dim(hparams.get("hidden_size_per_head", 128))
+        self.gguf_writer.add_ssm_inner_size(hparams.get("hidden_size_per_head", 128) * hparams.get("mamba_num_heads", 64))
+        self.gguf_writer.add_ssm_time_step_rank(hparams.get("time_step_limit", 192))
+        self.gguf_writer.add_ssm_dt_min(hparams.get("time_step_min", 0.001))
+        self.gguf_writer.add_ssm_dt_max(hparams.get("time_step_max", 0.1))
+        self.gguf_writer.add_hybrid_mamba_step(hparams.get("mamba_step", 2))
+
+        # MLP feed-forward parameters (for attention layers)
+        self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384))
+
+        # Which layers are Mamba layers
+        # PLaMo 2 uses mamba_step to indicate the pattern (e.g., 2 means every other layer)
+        # This logic matches modeling_plamo.py's is_mamba function
+        mamba_step = hparams.get("mamba_step", 2)
+        mamba_enabled = hparams.get("mamba_enabled", True)
+        mamba_layers = []
+
+        if mamba_enabled:
+            for i in range(block_count):
+                if block_count <= (mamba_step // 2):
+                    # use attention in the last layer
+                    is_mamba = (i != block_count - 1)
+                else:
+                    is_mamba = (i % mamba_step) != (mamba_step // 2)
+                if is_mamba:
+                    mamba_layers.append(i)
+
+        if mamba_layers:
+            self.gguf_writer.add_hybrid_mamba_layers(mamba_layers)
+
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+        elif name.endswith(".dt_norm_weight"):
+            name = name.rpartition(".dt_norm_weight")[0] + ".dt_norm.weight"
+        elif name.endswith(".B_norm_weight"):
+            name = name.rpartition(".B_norm_weight")[0] + ".B_norm.weight"
+        elif name.endswith(".C_norm_weight"):
+            name = name.rpartition(".C_norm_weight")[0] + ".C_norm.weight"
+        elif name.endswith(".k_weight"):
+            name = name.rpartition(".k_weight")[0] + ".k.weight"
+        elif name.endswith(".q_weight"):
+            name = name.rpartition(".q_weight")[0] + ".q.weight"
+        elif name.endswith(".conv1d.weight"):
+            data_torch = torch.squeeze(data_torch)  # drop the size-1 middle dimension
+            assert data_torch.ndim == 2
+        elif name.endswith(".pre_mixer_norm.weight"):
+            data_torch += 1.0
+        elif name.endswith(".post_mixer_norm.weight"):
+            data_torch += 1.0 / 5
+        elif name.endswith(".pre_mlp_norm.weight"):
+            data_torch += 1.0
+        elif name.endswith(".post_mlp_norm.weight"):
+            data_torch += 1.0 / (5**1.5)
+        elif name.endswith(".gate_up_proj.weight"):
+            # Split the combined gate_up tensor
+            split_size = data_torch.shape[0] // 2
+            gate_tensor = data_torch[:split_size, :]
+            up_tensor = data_torch[split_size:, :]
+
+            # Return both tensors - remove .weight suffix if present
+            name_base = name.replace(".gate_up_proj.weight", "")
+            gate_name = name_base + ".ffn_gate.weight"
+            up_name = name_base + ".ffn_up.weight"
+
+            gate_mapped = self.map_tensor_name(gate_name)
+            up_mapped = self.map_tensor_name(up_name)
+
+            return [(gate_mapped, gate_tensor), (up_mapped, up_tensor)]
+
+        new_name = self.map_tensor_name(name)
+
+        print(f"Plamo2Model: {name} -> {new_name}, shape={data_torch.shape}")
+
+        return [(new_name, data_torch)]
+
+
 @ModelBase.register("CodeShellForCausalLM")
 class CodeShellModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CODESHELL
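Two notes on the converter logic above, each with a small illustrative sketch.

First, set_vocab() treats each line of tokenizer.jsonl as a JSON array rather than an object, consuming only the first three fields. A hypothetical input line under that assumption (the token text and score are invented for illustration):

import json

line = '["\\u2581hello", -8.25, "NORMAL", null, null, null, null]'  # hypothetical
token_data = json.loads(line)
token = token_data[0].encode("utf-8")  # b'\xe2\x96\x81hello' (SPM-style U+2581 prefix)
score = float(token_data[1])           # -8.25
kind = token_data[2]                   # "NORMAL" -> gguf.TokenType.NORMAL

Second, the Mamba/attention interleave: with the default mamba_step = 2, is_mamba reduces to (i % 2) != 1, so even-indexed layers become Mamba layers and odd-indexed layers keep attention. A standalone sketch of the same selection rule:

def mamba_layer_ids(block_count: int, mamba_step: int = 2) -> list[int]:
    # same rule as Plamo2Model.set_gguf_parameters() above
    layers = []
    for i in range(block_count):
        if block_count <= (mamba_step // 2):
            is_mamba = (i != block_count - 1)  # attention only in the last layer
        else:
            is_mamba = (i % mamba_step) != (mamba_step // 2)
        if is_mamba:
            layers.append(i)
    return layers

print(mamba_layer_ids(8))  # [0, 2, 4, 6] -> even layers Mamba, odd layers attention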

ggml/src/ggml-backend.cpp

Lines changed: 47 additions & 1 deletion
@@ -243,6 +243,18 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor *
 }
 
 void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    if (tensor->data == NULL) {
+        fprintf(stderr, "ERROR: Tensor '%s' data is NULL - cannot read tensor\n",
+                tensor->name ? tensor->name : "unnamed");
+
+        // For output tensors that may not have been properly allocated
+        if (tensor->flags & GGML_TENSOR_FLAG_OUTPUT) {
+            fprintf(stderr, "  Output tensor detected - this may indicate scheduling issue\n");
+            // Return zeros for now to prevent crash
+            memset(data, 0, size);
+            return;
+        }
+    }
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
 

@@ -261,6 +273,24 @@ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, siz
         return;
     }
 
+    if (buf == NULL) {
+        // For input tensors, buffer allocation may happen later by the scheduler
+        if (tensor->flags & GGML_TENSOR_FLAG_INPUT) {
+            // fprintf(stderr, "WARNING: Skipping tensor_set for input tensor '%s' - buffer will be allocated by scheduler\n",
+            //         tensor->name ? tensor->name : "unnamed");
+            return;
+        }
+
+        // Enhanced error message with tensor information
+        fprintf(stderr, "ERROR: Tensor buffer not set for tensor '%s' (op: %s, type: %s)\n",
+                tensor->name ? tensor->name : "unnamed",
+                ggml_op_name(tensor->op),
+                ggml_type_name(tensor->type));
+        if (tensor->view_src) {
+            fprintf(stderr, "  This is a view tensor with view_src: '%s'\n",
+                    tensor->view_src->name ? tensor->view_src->name : "unnamed");
+        }
+    }
     GGML_ASSERT(buf != NULL && "tensor buffer not set");
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");

@@ -1648,7 +1678,23 @@ void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct gg
 ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
     int backend_index = tensor_backend_id(node);
     if (backend_index == -1) {
-        return NULL;
+        // Enhanced debugging for unassigned tensors
+        fprintf(stderr, "ERROR: Tensor '%s' (op: %s, flags: 0x%x) has no backend assigned (backend_id = -1)\n",
+                node->name ? node->name : "unnamed",
+                ggml_op_name(node->op),
+                node->flags);
+
+        // Try to assign to CPU backend as fallback for output tensors
+        if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
+            fprintf(stderr, "  Attempting to assign output tensor to CPU backend\n");
+            backend_index = sched->n_backends - 1; // CPU backend
+            tensor_backend_id(node) = backend_index;
+            SET_CAUSE(node, "out.cpu");
+        }
+
+        if (backend_index == -1) {
+            return NULL;
+        }
     }
     return sched->backends[backend_index];
 }

gguf-py/gguf/constants.py

Lines changed: 55 additions & 0 deletions
@@ -165,6 +165,14 @@ class SSM:
         STATE_SIZE = "{arch}.ssm.state_size"
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
         DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
+        DT_MIN = "{arch}.ssm.dt_min"
+        DT_MAX = "{arch}.ssm.dt_max"
+        NUM_HEADS = "{arch}.ssm.num_heads"
+        HEAD_DIM = "{arch}.ssm.head_dim"
+
+    class Hybrid:
+        MAMBA_LAYERS = "{arch}.hybrid.mamba_layers"
+        MAMBA_STEP = "{arch}.hybrid.mamba_step"
 
     class WKV:
         HEAD_SIZE = "{arch}.wkv.head_size"

@@ -306,6 +314,7 @@ class MODEL_ARCH(IntEnum):
     PHI3 = auto()
     PHIMOE = auto()
     PLAMO = auto()
+    PLAMO2 = auto()
     CODESHELL = auto()
     ORION = auto()
     INTERNLM2 = auto()

@@ -406,6 +415,12 @@ class MODEL_TENSOR(IntEnum):
     SSM_A = auto()
     SSM_D = auto()
     SSM_OUT = auto()
+    SSM_CONV1D_BIAS = auto()
+    SSM_DT_BIAS = auto()
+    SSM_BCDT = auto()
+    SSM_DT_NORM = auto()
+    SSM_B_NORM = auto()
+    SSM_C_NORM = auto()
     TIME_MIX_W0 = auto()
     TIME_MIX_W1 = auto()
     TIME_MIX_W2 = auto()

@@ -589,6 +604,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.PHI3: "phi3",
     MODEL_ARCH.PHIMOE: "phimoe",
     MODEL_ARCH.PLAMO: "plamo",
+    MODEL_ARCH.PLAMO2: "plamo2",
     MODEL_ARCH.CODESHELL: "codeshell",
     MODEL_ARCH.ORION: "orion",
     MODEL_ARCH.INTERNLM2: "internlm2",

@@ -689,6 +705,12 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
     MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
     MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
+    MODEL_TENSOR.SSM_CONV1D_BIAS: "blk.{bid}.ssm_conv1d_bias",
+    MODEL_TENSOR.SSM_DT_BIAS: "blk.{bid}.ssm_dt_bias",
+    MODEL_TENSOR.SSM_BCDT: "blk.{bid}.ssm_bcdt",
+    MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm",
+    MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm",
+    MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
     MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
     MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
     MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",

@@ -1295,6 +1317,39 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.PLAMO2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_POST_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_X,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_DT_BIAS,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.SSM_BCDT,
+        MODEL_TENSOR.SSM_DT_NORM,
+        MODEL_TENSOR.SSM_B_NORM,
+        MODEL_TENSOR.SSM_C_NORM,
+    ],
     MODEL_ARCH.GPT2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.POS_EMBD,
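The new Keys.SSM and Keys.Hybrid templates expand with the architecture name registered in MODEL_ARCH_NAMES, so for PLaMo-2 the resulting GGUF keys should look as follows (a sketch; the expansion is plain str.format):

arch = "plamo2"  # MODEL_ARCH_NAMES[MODEL_ARCH.PLAMO2]
print("{arch}.ssm.dt_min".format(arch=arch))           # plamo2.ssm.dt_min
print("{arch}.ssm.num_heads".format(arch=arch))        # plamo2.ssm.num_heads
print("{arch}.hybrid.mamba_layers".format(arch=arch))  # plamo2.hybrid.mamba_layers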

gguf-py/gguf/gguf_writer.py

Lines changed: 18 additions & 0 deletions
@@ -846,6 +846,24 @@ def add_ssm_time_step_rank(self, value: int) -> None:
     def add_ssm_dt_b_c_rms(self, value: bool) -> None:
         self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
 
+    def add_ssm_dt_min(self, value: float) -> None:
+        self.add_float32(Keys.SSM.DT_MIN.format(arch=self.arch), value)
+
+    def add_ssm_dt_max(self, value: float) -> None:
+        self.add_float32(Keys.SSM.DT_MAX.format(arch=self.arch), value)
+
+    def add_ssm_num_heads(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.NUM_HEADS.format(arch=self.arch), value)
+
+    def add_ssm_head_dim(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value)
+
+    def add_hybrid_mamba_layers(self, layers: list[int]) -> None:
+        self.add_array(Keys.Hybrid.MAMBA_LAYERS.format(arch=self.arch), layers)
+
+    def add_hybrid_mamba_step(self, step: int) -> None:
+        self.add_uint32(Keys.Hybrid.MAMBA_STEP.format(arch=self.arch), step)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
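Combined with Plamo2Model.set_gguf_parameters() above, the new writer methods emit KV pairs roughly as follows (a minimal sketch using the converter's PLaMo-2 defaults; writer stands for any GGUFWriter constructed with arch "plamo2"):

writer.add_ssm_dt_min(0.001)                  # plamo2.ssm.dt_min (float32)
writer.add_ssm_dt_max(0.1)                    # plamo2.ssm.dt_max (float32)
writer.add_ssm_num_heads(64)                  # plamo2.ssm.num_heads (uint32)
writer.add_ssm_head_dim(128)                  # plamo2.ssm.head_dim (uint32)
writer.add_hybrid_mamba_step(2)               # plamo2.hybrid.mamba_step (uint32)
writer.add_hybrid_mamba_layers([0, 2, 4, 6])  # plamo2.hybrid.mamba_layers (array)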
