diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index dd80a4a05d596..2d332d35b1c01 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3476,6 +3476,184 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(new_name, data_torch)] +@ModelBase.register("Plamo2ForCausalLM", "PLaMo2ForCausalLM") +class Plamo2Model(TextModel): + model_arch = gguf.MODEL_ARCH.PLAMO2 + + def set_vocab(self): + # PLaMo 2 uses a custom tokenizer with a .jsonl file + # We need to handle this specially + tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl" + tokenizer_config_path = self.dir_model / "tokenizer_config.json" + + if not tokenizer_jsonl_path.is_file(): + raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}") + + # Load tokenizer config + with open(tokenizer_config_path, 'r', encoding='utf-8') as f: + tokenizer_config = json.load(f) + + # Load tokens from JSONL file (actually a list format) + tokens = [] + scores = [] + toktypes = [] + + with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f): + if line.strip(): + token_data = json.loads(line) + # Format: [token, score, type, ?, ?, ?, ?] + token = token_data[0].encode("utf-8") + score = float(token_data[1]) + token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL" + + tokens.append(token) + scores.append(score) + + # Map token type strings to GGUF token types + if token_type_str == "UNKNOWN": + toktypes.append(gguf.TokenType.UNKNOWN) + elif token_type_str == "CONTROL": + toktypes.append(gguf.TokenType.CONTROL) + elif token_type_str == "BYTE": + toktypes.append(gguf.TokenType.BYTE) + else: + # Check for PLaMo-2 special tokens + token_str = token_data[0] + if token_str.startswith("<|plamo:") and token_str.endswith("|>"): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.NORMAL) + + # Use "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer + self.gguf_writer.add_tokenizer_model("plamo2") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + # Add special tokens from config + if "bos_token_id" in tokenizer_config: + self.gguf_writer.add_bos_token_id(tokenizer_config["bos_token_id"]) + if "eos_token_id" in tokenizer_config: + self.gguf_writer.add_eos_token_id(tokenizer_config["eos_token_id"]) + if "pad_token_id" in tokenizer_config: + self.gguf_writer.add_pad_token_id(tokenizer_config["pad_token_id"]) + if "unk_token_id" in tokenizer_config: + self.gguf_writer.add_unk_token_id(tokenizer_config["unk_token_id"]) + + self.gguf_writer.add_add_space_prefix(False) + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048)) + self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096)) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32)) + self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4)) + self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06)) + self.gguf_writer.add_group_norm_eps(hparams.get("rms_norm_eps", 1e-06)) + self.gguf_writer.add_layer_norm_eps(hparams.get("rms_norm_eps", 1e-06)) + self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0)) + + # Mamba parameters + 
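+        # These keys are read back by llama.cpp as %s.ssm.* / %s.hybrid.* metadata:
+        # the inner size is derived as mamba_num_heads * hidden_size_per_head rather
+        # than taken from an explicit config field, dt_min/dt_max mirror the config's
+        # time_step_min/time_step_max, and mamba_step records the attention/Mamba
+        # interleaving that is also used below to emit the explicit mamba_layers list.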
self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64)) + self.gguf_writer.add_ssm_conv_kernel(hparams.get("mamba_d_conv", 4)) + self.gguf_writer.add_ssm_num_heads(hparams.get("mamba_num_heads", 64)) + self.gguf_writer.add_ssm_head_dim(hparams.get("hidden_size_per_head", 128)) + self.gguf_writer.add_ssm_inner_size(hparams.get("hidden_size_per_head", 128) * hparams.get("mamba_num_heads", 64)) + self.gguf_writer.add_ssm_time_step_rank(hparams.get("time_step_limit", 192)) + self.gguf_writer.add_ssm_dt_min(hparams.get("time_step_min", 0.001)) + self.gguf_writer.add_ssm_dt_max(hparams.get("time_step_max", 0.1)) + self.gguf_writer.add_hybrid_mamba_step(hparams.get("mamba_step", 2)) + + # MLP feed forward parameters (for attention layers) + self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384)) + + # Which layers are Mamba layers + # PLaMo 2 uses mamba_step to indicate the pattern (e.g., 2 means every other layer) + # This logic matches modeling_plamo.py's is_mamba function + mamba_step = hparams.get("mamba_step", 2) + mamba_enabled = hparams.get("mamba_enabled", True) + mamba_layers = [] + + if mamba_enabled: + for i in range(block_count): + if block_count <= (mamba_step // 2): + # use attention in last layer + is_mamba = (i != block_count - 1) + else: + is_mamba = (i % mamba_step) != (mamba_step // 2) + if is_mamba: + mamba_layers.append(i) + + if mamba_layers: + self.gguf_writer.add_hybrid_mamba_layers(mamba_layers) + + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.endswith(".embed_tokens.weight"): + # If there is no lm_head, we need to map the token embedding to the output layer + assert self.tensor_names is not None + if all(['lm_head' not in name for name in self.tensor_names]): + name_base = name.replace(".embed_tokens.weight", "") + output_name = "lm_head" + + embed_tokens_mapped = self.map_tensor_name(name) + output_mapped = self.map_tensor_name(output_name) + ".weight" + + return [(embed_tokens_mapped, data_torch), (output_mapped, data_torch)] + elif name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + elif name.endswith(".dt_norm_weight"): + name = name.rpartition(".dt_norm_weight")[0] + ".dt_norm.weight" + elif name.endswith(".B_norm_weight"): + name = name.rpartition(".B_norm_weight")[0] + ".B_norm.weight" + elif name.endswith(".C_norm_weight"): + name = name.rpartition(".C_norm_weight")[0] + ".C_norm.weight" + elif name.endswith(".k_weight"): + name = name.rpartition(".k_weight")[0] + ".k.weight" + elif name.endswith(".q_weight"): + name = name.rpartition(".q_weight")[0] + ".q.weight" + elif name.endswith(".conv1d.weight"): + data_torch = torch.squeeze(data_torch) # remove (, 1, ) + assert data_torch.ndim == 2 + elif name.endswith(".pre_mixer_norm.weight"): + data_torch += 1.0 + elif name.endswith(".post_mixer_norm.weight"): + data_torch += 1.0 / 5 + elif name.endswith(".pre_mlp_norm.weight"): + data_torch += 1.0 + elif name.endswith(".post_mlp_norm.weight"): + data_torch += 1.0 / (5**1.5) + elif name.endswith(".norm.weight"): + data_torch += 1.0 + elif name.endswith(".gate_up_proj.weight"): + # Split the combined gate_up tensor + split_size = data_torch.shape[0] // 2 + gate_tensor = data_torch[:split_size, :] + up_tensor = data_torch[split_size:, :] + + # Return both tensors - remove .weight suffix if present + name_base = name.replace(".gate_up_proj.weight", "") + 
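+            # gate_up_proj stacks the gate and up projections along dim 0, so the
+            # first half maps to ffn_gate and the second half to ffn_up; both halves
+            # then go through the regular tensor-name mapping below.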
gate_name = name_base + ".ffn_gate.weight" + up_name = name_base + ".ffn_up.weight" + + gate_mapped = self.map_tensor_name(gate_name) + up_mapped = self.map_tensor_name(up_name) + + return [(gate_mapped, gate_tensor), (up_mapped, up_tensor)] + + new_name = self.map_tensor_name(name) + + return [(new_name, data_torch)] + + @ModelBase.register("CodeShellForCausalLM") class CodeShellModel(TextModel): model_arch = gguf.MODEL_ARCH.CODESHELL diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index bbbec6a01a175..5faec5d6a416d 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -122,7 +122,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { if (!ggml_is_quantized(t->type)) { uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data(); - ggml_print_tensor(data, t->type, t->ne, t->nb, 3); + ggml_print_tensor(data, t->type, t->ne, t->nb, 256); } return true; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c12609c6d9f99..0e4b2754f45d6 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -172,6 +172,14 @@ class SSM: TIME_STEP_RANK = "{arch}.ssm.time_step_rank" GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" + DT_MIN = "{arch}.ssm.dt_min" + DT_MAX = "{arch}.ssm.dt_max" + NUM_HEADS = "{arch}.ssm.num_heads" + HEAD_DIM = "{arch}.ssm.head_dim" + + class Hybrid: + MAMBA_LAYERS = "{arch}.hybrid.mamba_layers" + MAMBA_STEP = "{arch}.hybrid.mamba_step" class WKV: HEAD_SIZE = "{arch}.wkv.head_size" @@ -313,6 +321,7 @@ class MODEL_ARCH(IntEnum): PHI3 = auto() PHIMOE = auto() PLAMO = auto() + PLAMO2 = auto() CODESHELL = auto() ORION = auto() INTERNLM2 = auto() @@ -433,6 +442,12 @@ class MODEL_TENSOR(IntEnum): SSM_D = auto() SSM_NORM = auto() SSM_OUT = auto() + SSM_CONV1D_BIAS = auto() + SSM_DT_BIAS = auto() + SSM_BCDT = auto() + SSM_DT_NORM = auto() + SSM_B_NORM = auto() + SSM_C_NORM = auto() TIME_MIX_W0 = auto() TIME_MIX_W1 = auto() TIME_MIX_W2 = auto() @@ -616,6 +631,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.PHI3: "phi3", MODEL_ARCH.PHIMOE: "phimoe", MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.PLAMO2: "plamo2", MODEL_ARCH.CODESHELL: "codeshell", MODEL_ARCH.ORION: "orion", MODEL_ARCH.INTERNLM2: "internlm2", @@ -736,6 +752,12 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.SSM_CONV1D_BIAS: "blk.{bid}.ssm_conv1d_bias", + MODEL_TENSOR.SSM_DT_BIAS: "blk.{bid}.ssm_dt_bias", + MODEL_TENSOR.SSM_BCDT: "blk.{bid}.ssm_bcdt", + MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm", + MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm", + MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm", MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", @@ -1342,6 +1364,39 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.PLAMO2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + 
MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_POST_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_X, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_DT_BIAS, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.SSM_BCDT, + MODEL_TENSOR.SSM_DT_NORM, + MODEL_TENSOR.SSM_B_NORM, + MODEL_TENSOR.SSM_C_NORM, + ], MODEL_ARCH.GPT2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.POS_EMBD, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 697e057c090da..97db1e8615c23 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -867,6 +867,24 @@ def add_ssm_group_count(self, value: int) -> None: def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) + def add_ssm_dt_min(self, value: float) -> None: + self.add_float32(Keys.SSM.DT_MIN.format(arch=self.arch), value) + + def add_ssm_dt_max(self, value: float) -> None: + self.add_float32(Keys.SSM.DT_MAX.format(arch=self.arch), value) + + def add_ssm_num_heads(self, value: int) -> None: + self.add_uint32(Keys.SSM.NUM_HEADS.format(arch=self.arch), value) + + def add_ssm_head_dim(self, value: int) -> None: + self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value) + + def add_hybrid_mamba_layers(self, layers: list[int]) -> None: + self.add_array(Keys.Hybrid.MAMBA_LAYERS.format(arch=self.arch), layers) + + def add_hybrid_mamba_step(self, step: int) -> None: + self.add_uint32(Keys.Hybrid.MAMBA_STEP.format(arch=self.arch), step) + def add_tokenizer_model(self, model: str) -> None: self.add_string(Keys.Tokenizer.MODEL, model) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 51634ef6bdd2e..89bddeec997a9 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -13,7 +13,7 @@ class TensorNameMap: "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone "transformer.word_embeddings", # falcon "word_embeddings", # bloom - "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 + "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert "language_model.embedding.word_embeddings", # persimmon @@ -62,7 +62,7 @@ class TensorNameMap: # Output MODEL_TENSOR.OUTPUT: ( "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe plamo2 "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 @@ -76,7 +76,7 @@ class TensorNameMap: MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox "transformer.ln_f", # gpt2 gpt-j falcon jais exaone - "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe + "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe plamo2 "norm", # llama-pth "transformer.norm_f", # mpt dbrx "ln_f", # refact bloom qwen gpt2 @@ -125,6 +125,7 @@ class TensorNameMap: "h.{bid}.ln_1", # gpt2 "transformer.h.{bid}.ln", # phi2 "model.layers.layers.{bid}.norm", # plamo + "model.layers.layers.{bid}.pre_mixer_norm", # plamo2 "model.layers.{bid}.attention_norm", # internlm2 "model.layers.{bid}.norm", # mamba-qbert "backbone.layers.{bid}.norm", # mamba @@ -161,6 +162,7 @@ class TensorNameMap: "encoder.layers.{bid}.attn.Wqkv", # nomic-bert "encoder.layers.{bid}.mixer.Wqkv", # jina 
"model.layers.{bid}.self_attn.qkv_proj", # phi3 + "model.layers.layers.{bid}.mixer.qkv_proj", # plamo2 "encoder.layers.{bid}.self_attention.query_key_value", # chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm "transformer_encoder.{bid}.qkv", # neobert @@ -175,6 +177,7 @@ class TensorNameMap: "transformer.layer.{bid}.attention.q_lin", # distillbert "transformer.h.{bid}.attn.q_proj", # gpt-j "model.layers.layers.{bid}.self_attn.q_proj", # plamo + "model.layers.layers.{bid}.mixer.q", # plamo2 "model.layers.{bid}.attention.wq", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok "transformer.h.{bid}.attn.attention.q_proj", # exaone @@ -191,6 +194,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.k_proj", # gpt-j "transformer.h.{bid}.attn.k", # refact "model.layers.layers.{bid}.self_attn.k_proj", # plamo + "model.layers.layers.{bid}.mixer.k", # plamo2 "model.layers.{bid}.attention.wk", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok "transformer.h.{bid}.attn.attention.k_proj", # exaone @@ -230,6 +234,7 @@ class TensorNameMap: "h.{bid}.attn.c_proj", # gpt2 "transformer.h.{bid}.mixer.out_proj", # phi2 "model.layers.layers.{bid}.self_attn.o_proj", # plamo + "model.layers.layers.{bid}.mixer.o_proj", # plamo2 "model.layers.{bid}.attention.wo", # internlm2 "encoder.layers.{bid}.attn.out_proj", # nomic-bert "encoder.layers.{bid}.mixer.out_proj", # jina @@ -252,8 +257,9 @@ class TensorNameMap: ), MODEL_TENSOR.ATTN_POST_NORM: ( - "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge - "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414 + "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge + "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414 + "model.layers.layers.{bid}.post_mixer_norm.weight", # plamo2 ), # Rotary embeddings @@ -281,6 +287,7 @@ class TensorNameMap: "transformer.layers.{bid}.ffn_norm", # openelm "model.layers.{bid}.post_attention_layernorm", # llama4 "transformer_encoder.{bid}.ffn_norm", # neobert + "model.layers.layers.{bid}.pre_mlp_norm", # plamo2 ), # Post feed-forward norm @@ -292,6 +299,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_POST_NORM: ( "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414 + "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2 ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -334,6 +342,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.fc1", # phi2 "model.layers.{bid}.mlp.gate_up_proj", # phi3 glm-4-0414 "model.layers.layers.{bid}.mlp.up_proj", # plamo + "model.layers.layers.{bid}.mlp.ffn_up", # plamo2 "model.layers.{bid}.feed_forward.w3", # internlm2 "encoder.layers.{bid}.mlp.fc11", # nomic-bert "encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe @@ -376,6 +385,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.w2", # qwen "transformer.h.{bid}.mlp.c_fc2", # jais "model.layers.layers.{bid}.mlp.gate_proj", # plamo + "model.layers.layers.{bid}.mlp.ffn_gate" , # plamo2 "model.layers.{bid}.feed_forward.w1", # internlm2 "encoder.layers.{bid}.mlp.fc12", # nomic-bert "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used) @@ -547,31 +557,49 @@ class TensorNameMap: MODEL_TENSOR.SSM_IN: ( "model.layers.{bid}.in_proj", "backbone.layers.{bid}.mixer.in_proj", + "model.layers.layers.{bid}.mixer.in_proj", # plamo2 ), MODEL_TENSOR.SSM_CONV1D: ( "model.layers.{bid}.conv1d", "backbone.layers.{bid}.mixer.conv1d", + "model.layers.layers.{bid}.mixer.conv1d", # plamo2 ), 
MODEL_TENSOR.SSM_X: ( "model.layers.{bid}.x_proj", "backbone.layers.{bid}.mixer.x_proj", + "model.layers.layers.{bid}.mixer.bcdt_proj", # plamo2 ), MODEL_TENSOR.SSM_DT: ( "model.layers.{bid}.dt_proj", "backbone.layers.{bid}.mixer.dt_proj", + "model.layers.layers.{bid}.mixer.dt_proj", # plamo2 ), MODEL_TENSOR.SSM_A: ( "model.layers.{bid}.A_log", "backbone.layers.{bid}.mixer.A_log", + "model.layers.layers.{bid}.mixer.A_log", # plamo2 ), MODEL_TENSOR.SSM_D: ( "model.layers.{bid}.D", "backbone.layers.{bid}.mixer.D", + "model.layers.layers.{bid}.mixer.D", # plamo2 + ), + + MODEL_TENSOR.SSM_DT_NORM: ( + "model.layers.layers.{bid}.mixer.dt_norm.weight", # plamo2 + ), + + MODEL_TENSOR.SSM_B_NORM: ( + "model.layers.layers.{bid}.mixer.B_norm.weight", # plamo2 + ), + + MODEL_TENSOR.SSM_C_NORM: ( + "model.layers.layers.{bid}.mixer.C_norm.weight", # plamo2 ), MODEL_TENSOR.SSM_NORM: ( @@ -581,6 +609,17 @@ class TensorNameMap: MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", + "model.layers.layers.{bid}.mixer.out_proj", # plamo2 + ), + + MODEL_TENSOR.SSM_CONV1D_BIAS: ( + "model.layers.{bid}.conv1d_bias", + "backbone.layers.{bid}.mixer.conv1d.bias", + ), + + MODEL_TENSOR.SSM_DT_BIAS: ( + "model.layers.{bid}.dt_bias", + "backbone.layers.{bid}.mixer.dt_bias", ), MODEL_TENSOR.TIME_MIX_W0: ( diff --git a/include/llama.h b/include/llama.h index 3eda9bc68608c..cab66d70fa2af 100644 --- a/include/llama.h +++ b/include/llama.h @@ -71,12 +71,13 @@ extern "C" { typedef int32_t llama_seq_id; enum llama_vocab_type { - LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab - LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback - LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE - LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece - LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram - LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab + LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback + LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE + LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece + LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming }; // pre-tokenization types diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index ab24054305857..22de2782d2c8d 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -34,6 +34,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_PLAMO2, "plamo2" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, { LLM_ARCH_INTERNLM2, "internlm2" }, @@ -173,6 +174,13 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_SSM_DT_MIN, "%s.ssm.dt_min" }, + { LLM_KV_SSM_DT_MAX, "%s.ssm.dt_max" }, + { LLM_KV_SSM_NUM_HEADS, "%s.ssm.num_heads" }, + { LLM_KV_SSM_HEAD_DIM, "%s.ssm.head_dim" }, + + { LLM_KV_HYBRID_MAMBA_LAYERS, "%s.hybrid.mamba_layers" }, + { LLM_KV_HYBRID_MAMBA_STEP, "%s.hybrid.mamba_step" }, { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, @@ -777,6 +785,38 @@ static const std::map> 
LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PLAMO2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, + { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, + { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_CODESHELL, { @@ -1779,7 +1819,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, - {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -1925,9 +1967,9 @@ bool llm_arch_is_recurrent(const llm_arch & arch) { } bool llm_arch_is_hybrid(const llm_arch & arch) { - // TODO: There are currently no hybrid models! 
Once there are, this will be - // the place to identify them switch (arch) { + case LLM_ARCH_PLAMO2: + return true; default: return false; } diff --git a/src/llama-arch.h b/src/llama-arch.h index b769831dff5ec..cd62f3f9be57a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -38,6 +38,7 @@ enum llm_arch { LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, + LLM_ARCH_PLAMO2, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, LLM_ARCH_INTERNLM2, @@ -177,6 +178,14 @@ enum llm_kv { LLM_KV_SSM_TIME_STEP_RANK, LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, + LLM_KV_SSM_DT_MIN, + LLM_KV_SSM_DT_MAX, + LLM_KV_SSM_NUM_HEADS, + LLM_KV_SSM_HEAD_DIM, + LLM_KV_SSM_STEP, + + LLM_KV_HYBRID_MAMBA_LAYERS, + LLM_KV_HYBRID_MAMBA_STEP, LLM_KV_WKV_HEAD_SIZE, @@ -297,6 +306,9 @@ enum llm_tensor { LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_SSM_DT_NORM, + LLM_TENSOR_SSM_B_NORM, + LLM_TENSOR_SSM_C_NORM, LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 06e93b19cbf40..900be758a1d6b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -7,10 +7,18 @@ #include "llama-mmap.h" #include "llama-model.h" +#include "ggml.h" + +#include #include #include #include #include +#ifdef _WIN32 +#include +#else +#include +#endif // // llama_context @@ -693,6 +701,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, } auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx); + res->graph = gf; // Store the graph pointer in the result object if (!res) { LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); ret = GGML_STATUS_FAILED; @@ -701,6 +710,9 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + // Dump computation graph for visualization + // ggml_graph_dump_dot(gf, NULL, "llama.dot"); + if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); ret = GGML_STATUS_ALLOC_FAILED; @@ -1042,11 +1054,6 @@ int llama_context::decode(const llama_batch & batch_inp) { } } - // plot the computation graph in dot format (for debugging purposes) - //if (n_past%100 == 0) { - // ggml_graph_dump_dot(gf, NULL, "llama.dot"); - //} - auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? 
res->get_embd() : nullptr; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index f2fae6d1b71aa..686d3805cecf3 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1,5 +1,6 @@ #include "llama-graph.h" +#include "ggml.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-cparams.h" @@ -802,7 +803,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { if (ubatch.token) { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); - //cb(inp->tokens, "inp_tokens", -1); + // cb(inp->tokens, "inp_tokens", -1); ggml_set_input(inp->tokens); res->t_tokens = inp->tokens; @@ -1226,7 +1227,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto * mctx_cur = static_cast(mctx); + const auto * mctx_cur = inp->mctx; // store to KV cache { diff --git a/src/llama-graph.h b/src/llama-graph.h index db4e14805caa3..9ce6a19d30784 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -358,9 +358,11 @@ class llm_graph_result_i { virtual ggml_tensor * get_tokens() = 0; virtual ggml_tensor * get_logits() = 0; virtual ggml_tensor * get_embd() = 0; + virtual ggml_cgraph * get_graph() = 0; virtual ggml_tensor * get_embd_pooled() = 0; virtual void set_inputs(const llama_ubatch * ubatch) = 0; + ggml_cgraph * graph = nullptr; // Store the graph here }; using llm_graph_result_ptr = std::unique_ptr; @@ -373,6 +375,7 @@ class llm_graph_result : public llm_graph_result_i { ggml_tensor * get_tokens() override { return t_tokens; } ggml_tensor * get_logits() override { return t_logits; } ggml_tensor * get_embd() override { return t_embd; } + ggml_cgraph * get_graph() override { return graph; } ggml_tensor * get_embd_pooled() override { return t_embd_pooled; } void set_inputs(const llama_ubatch * ubatch) override { diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 86c814d51b901..ac641d2544f27 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -71,6 +71,12 @@ uint32_t llama_hparams::n_embd_r() const { return token_shift_count * n_embd; } + // for PLaMo-2 hybrid models that use mamba_step + if (mamba_step > 0 && ssm_num_heads > 0 && ssm_head_dim > 0) { + // PLaMo-2 uses mamba_num_heads * hidden_size_per_head for Mamba layers + return (ssm_d_conv > 0 ? 
ssm_d_conv - 1 : 0) * (ssm_num_heads * ssm_head_dim); + } + // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size @@ -83,6 +89,12 @@ uint32_t llama_hparams::n_embd_s() const { return n_embd * wkv_head_size; } + // for PLaMo-2 hybrid models that use mamba_step + if (mamba_step > 0 && ssm_num_heads > 0 && ssm_head_dim > 0) { + // PLaMo-2 uses mamba_num_heads * hidden_size_per_head for Mamba layers + return ssm_d_state * (ssm_num_heads * ssm_head_dim); + } + // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 476d0a5eade28..52f18d415d291 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include +#include // bump if necessary #define LLAMA_MAX_LAYERS 512 @@ -114,6 +115,8 @@ struct llama_hparams { uint32_t ssm_d_inner = 0; uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + uint32_t ssm_num_heads = 0; + uint32_t ssm_head_dim = 0; uint32_t ssm_n_group = 0; // for hybrid state space models @@ -121,6 +124,11 @@ struct llama_hparams { bool ssm_dt_b_c_rms = false; + // for PLaMo 2 hybrid architecture + float ssm_dt_min = 0.001f; + float ssm_dt_max = 0.1f; + uint32_t mamba_step = 2; + float f_clamp_kqv = 0.0f; float f_max_alibi_bias = 0.0f; float f_logit_scale = 0.0f; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 7f7b162ffd7ce..80c6603e97f87 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -779,11 +779,30 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_ const int64_t n_tokens = k_cur->ne[2]; - ggml_tensor * k_view = ggml_view_1d(ctx, k, - n_tokens*hparams.n_embd_k_gqa(il), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); + // Handle PLaMo-2 GQA: k_cur might be expanded (128, 32, 512) but cache expects (128, 4, 512) + ggml_tensor * k_view; + if (k_cur->ne[1] != hparams.n_head_kv(il)) { + // k_cur is GQA-expanded, need to take only the first n_head_kv heads + ggml_tensor * k_cur_orig = ggml_view_3d(ctx, k_cur, + k_cur->ne[0], hparams.n_head_kv(il), k_cur->ne[2], + k_cur->nb[1], k_cur->nb[2], 0); + + k_view = ggml_view_3d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_tokens, + ggml_row_size(k->type, hparams.n_embd_head_k), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); + + return ggml_cpy(ctx, k_cur_orig, k_view); + } else { + k_view = ggml_view_3d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_tokens, + ggml_row_size(k->type, hparams.n_embd_head_k), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); - return ggml_cpy(ctx, k_cur, k_view); + return ggml_cpy(ctx, k_cur, k_view); + } } ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const { @@ -793,6 +812,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_ const int64_t n_tokens = v_cur->ne[2]; + // Handle PLaMo-2 GQA: v_cur might be expanded (128, 32, 512) but cache expects (128, 4, 512) + if (v_cur->ne[1] != hparams.n_head_kv(il)) { + // v_cur is GQA-expanded, need to take only the first n_head_kv heads + v_cur = ggml_view_3d(ctx, v_cur, + v_cur->ne[0], hparams.n_head_kv(il), v_cur->ne[2], + v_cur->nb[1], 
v_cur->nb[2], 0); + } + v_cur = ggml_cont(ctx, v_cur); v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); ggml_tensor * v_view = nullptr; diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 4c53f1273ab88..36956b3cbd47c 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -236,6 +236,14 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { virtual ~llama_kv_cache_unified_context(); + // Delete copy constructor and copy assignment to prevent shallow copies + llama_kv_cache_unified_context(const llama_kv_cache_unified_context&) = delete; + llama_kv_cache_unified_context& operator=(const llama_kv_cache_unified_context&) = delete; + + // Delete move constructor and move assignment to prevent issues + llama_kv_cache_unified_context(llama_kv_cache_unified_context&&) = delete; + llama_kv_cache_unified_context& operator=(llama_kv_cache_unified_context&&) = delete; + // // llama_memory_context_i // diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index bd9e6da8832b7..ffe7a999dbb5a 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -465,6 +465,9 @@ namespace GGUFMeta { template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + // Template instantiation for PLaMo 2 hybrid architecture + template bool llama_model_loader::get_arr>(enum llm_kv kid, std::vector & result, bool required); + llama_model_loader::llama_model_loader( const std::string & fname, std::vector & splits, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0573c5bcea0a4..73fef028c91f9 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1,7 +1,8 @@ #include "llama-model.h" +#include "ggml.h" +#include "llama-arch.h" #include "llama-impl.h" -#include "llama-mmap.h" #include "llama-batch.h" #include "llama-cparams.h" #include "llama-model-loader.h" @@ -18,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -929,6 +929,38 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_PLAMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load Mamba SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_NUM_HEADS, hparams.ssm_num_heads); + ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim); + ml.get_key(LLM_KV_SSM_DT_MIN, hparams.ssm_dt_min, false); + ml.get_key(LLM_KV_SSM_DT_MAX, hparams.ssm_dt_max, false); + ml.get_key(LLM_KV_HYBRID_MAMBA_STEP, hparams.mamba_step, true); + + // Load which layers are Mamba layers + std::vector mamba_layer_indices; + if (ml.get_arr(LLM_KV_HYBRID_MAMBA_LAYERS, mamba_layer_indices, false)) { + // Mark specified layers as Mamba + for (uint32_t idx : mamba_layer_indices) { + if (idx % hparams.mamba_step == 0) { + hparams.recurrent_layer_arr[idx] = true; + } + } + } + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GPT2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2807,6 +2839,72 @@ bool 
llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_PLAMO2: + { + const uint32_t d_conv = hparams.ssm_d_conv; + const uint32_t d_state = hparams.ssm_d_state; + const uint32_t num_heads = hparams.ssm_num_heads; + const uint32_t intermediate_size = num_heads * hparams.ssm_head_dim; + const uint32_t head_dim = hparams.ssm_head_dim; + const uint32_t qk_dim = head_dim; + const uint32_t v_dim = head_dim; + const int64_t num_attention_heads = hparams.n_head(); + const int64_t q_num_heads = num_attention_heads; + const int64_t num_key_value_heads = hparams.n_head_kv(); + const int64_t k_num_heads = num_key_value_heads; + const int64_t v_num_heads = num_key_value_heads; + const int64_t q_proj_dim = q_num_heads * qk_dim; + const int64_t k_proj_dim = k_num_heads * qk_dim; + const int64_t v_proj_dim = v_num_heads * v_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + bool is_mamba_layer = hparams.is_recurrent(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (is_mamba_layer) { + // Mamba layer tensors + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); + + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); + + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); + + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); + + // PLaMo-2 SSM norm tensors + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); + } else { + // Attention layer tensors + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {head_dim, num_attention_heads}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {head_dim, k_num_heads}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); + } + + // All layers have post-attention norm, FFN norm, and FFN tensors + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_post_norm = 
create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + } + } break; case LLM_ARCH_GPT2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -8111,6 +8209,357 @@ struct llm_build_plamo : public llm_graph_context { } }; +struct llm_build_plamo2 : public llm_graph_context { + llm_build_plamo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + ggml_tensor * cur; + + // key variables used in PLaMo-2 attention + // const int64_t n_embd_head = hparams.n_embd_head_v; + // ggml_tensor * inp_pos = build_inp_pos(); + + // TODO: Cast to f32 is currently required for ggml_get_rows in build_inp_embd + ggml_tensor * embed_tokens = ggml_cast(ctx0, model.tok_embd, GGML_TYPE_F32); + + // {n_embd, n_tokens} + ggml_tensor * inpL = build_inp_embd(embed_tokens); + cb(inpL, "embedding_output", -1); + + // ensure the memory context is hybrid + const auto * mctx_hybrid = dynamic_cast(mctx); + GGML_ASSERT(mctx_hybrid != nullptr); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * residual = inpL; + + // ggml_graph_add_node(gf, model.layers[il].attn_norm); + // cb(model.layers[il].attn_norm, "attn_norm", il); + + // pre_mixer_norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + + // check if this layer is Mamba or Attention + bool is_mamba_layer = hparams.is_recurrent(il); + + if (is_mamba_layer) { + // PLaMo-2 Mamba layer + cur = build_plamo2_mamba_layer(model, gf, cur, il, mctx_hybrid->get_recr(), ubatch); + } else { + // PLaMo-2 Attention layer + cur = build_plamo2_attn_layer(model, gf, cur, il, mctx_hybrid->get_attn()); + } + + // post_mixer_norm + cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + // residual connection + cur = ggml_add(ctx0, cur, residual); + cb(cur, "attn_residual", il); + + // pre-ffn norm + cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_pre_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + } + cb(cur, "ffn_out", il); + + // post ffn norm + cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + + // residual connection + cur = ggml_add(ctx0, cur, residual); + + inpL = cur; + + if (il >= 2) { + break; + } + } + + cur = inpL; + + // final norm + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + + // Explicitly mark as output tensor to ensure proper backend assignment + ggml_set_output(cur); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + +private: + ggml_tensor * build_plamo2_attn_layer( + const llama_model & model, + ggml_cgraph * gf, + ggml_tensor * inpL, + int il, + const llama_kv_cache_unified_context * attn_ctx) { + + ggml_tensor * cur = inpL; + ggml_tensor * inpSA = inpL; + + // attention layer specific variables + // const int64_t n_embd_head = hparams.n_embd_head_v; + ggml_tensor * inp_pos = build_inp_pos(); + + // self-attention + { + // For PLaMo-2 hybrid architecture, get the correct attention context + // const auto * mctx_hybrid = static_cast(mctx); + // const auto * unified_ctx = mctx_hybrid->get_attn(); + auto inp = std::make_unique(hparams, 
cparams, attn_ctx); + auto * inp_attn = inp.release(); + + // PLaMo-2 uses combined QKV tensor + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "qkv", il); + + // split QKV tensor into Q, K, V + const int64_t n_embd_head_q = hparams.n_embd_head_k; + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + + const int64_t q_offset = 0; + const int64_t k_offset = n_embd_head_q * n_head; + const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; + + ggml_tensor * Qcur = ggml_view_2d(ctx0, qkv, n_embd_head_q * n_head, n_tokens, qkv->nb[1], q_offset * ggml_element_size(qkv)); + ggml_tensor * Kcur = ggml_view_2d(ctx0, qkv, n_embd_head_k * n_head_kv, n_tokens, qkv->nb[1], k_offset * ggml_element_size(qkv)); + ggml_tensor * Vcur = ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv)); + + // make tensors contiguous before reshape + Qcur = ggml_cont(ctx0, Qcur); + Kcur = ggml_cont(ctx0, Kcur); + Vcur = ggml_cont(ctx0, Vcur); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_q, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + // Store original K and V for KV cache (before GQA expansion) + ggml_tensor * Kcur_cache = Kcur; + ggml_tensor * Vcur_cache = Vcur; + + // PLaMo-2 GQA: expand K and V heads to match Q heads (equivalent to _expand_kv) + if (n_head_kv < n_head) { + // const int n_group = n_head / n_head_kv; + + // manually expand K and V tensors to repeat each head n_group times + // create expanded tensors with target dimensions + ggml_tensor * Kcur_expanded = ggml_new_tensor_3d(ctx0, Kcur->type, n_embd_head_k, n_head, n_tokens); + ggml_tensor * Vcur_expanded = ggml_new_tensor_3d(ctx0, Vcur->type, n_embd_head_v, n_head, n_tokens); + + // repeat each head n_group times + Kcur = ggml_repeat(ctx0, Kcur, Kcur_expanded); + Vcur = ggml_repeat(ctx0, Vcur, Vcur_expanded); + + cb(Kcur, "Kcur_expanded", il); + cb(Vcur, "Vcur_expanded", il); + } + + cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, Qcur, Kcur_cache, Vcur_cache, NULL, NULL, 1.0f, il); + } + + // add the input + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_out", il); + + return cur; + } + + ggml_tensor * build_plamo2_mamba_layer( + const llama_model & model, + ggml_cgraph * gf, + ggml_tensor * inpL, + int il, + const llama_memory_recurrent_context * recr_ctx, + const llama_ubatch & ubatch) { + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_seqs = ubatch.n_seqs; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // 
Get conv and ssm states + ggml_tensor * conv_states_all = recr_ctx->get_r_l(il); + ggml_tensor * ssm_states_all = recr_ctx->get_s_l(il); + + const auto kv_head = recr_ctx->get_head(); + + ggml_tensor * conv = ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1) * d_inner * n_seqs, + kv_head * (d_conv - 1) * d_inner * ggml_element_size(conv_states_all)); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); + + ggml_tensor * ssm = ggml_view_1d(ctx0, ssm_states_all, + d_state * d_inner * n_seqs, + kv_head * d_state * d_inner * ggml_element_size(ssm_states_all)); + ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + ggml_tensor * cur = ggml_reshape_3d(ctx0, inpL, inpL->ne[0], n_seq_tokens, n_seqs); + cb(cur, "mamba_input", il); + + // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} + ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur); + cb(zx, "mamba_in_proj", il); + + zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); + zx = ggml_reshape_4d(ctx0, zx, 2 * hparams.ssm_head_dim, hparams.ssm_num_heads, n_seq_tokens, n_seqs); + cb(zx, "mamba_in_proj_out", il); + + // split into z and x + // => {d_inner, n_seq_tokens, n_seqs} + ggml_tensor * x = ggml_view_4d(ctx0, zx, hparams.ssm_head_dim, zx->ne[1], zx->ne[2], zx->ne[3], zx->nb[1], zx->nb[2], zx->nb[3], hparams.ssm_head_dim*ggml_element_size(zx)); + x = ggml_cont(ctx0, x); + x = ggml_reshape_4d(ctx0, x, hparams.ssm_head_dim * hparams.ssm_num_heads, 1, n_seq_tokens, n_seqs); + x = ggml_permute(ctx0, x, 0, 2, 1, 3); + cb(x, "mamba_x_split", il); + ggml_tensor * z = ggml_view_4d(ctx0, zx, hparams.ssm_head_dim, zx->ne[1], zx->ne[2], zx->ne[3], zx->nb[1], zx->nb[2], zx->nb[3], 0); + z = ggml_cont(ctx0, z); + z = ggml_reshape_4d(ctx0, z, hparams.ssm_head_dim * hparams.ssm_num_heads, 1, n_seq_tokens, n_seqs); + z = ggml_permute(ctx0, z, 0, 2, 1, 3); + cb(z, "mamba_z_split", il); + + // conv1d + { + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0); + cb(conv_x, "mamba_conv1d_input", il); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, + conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + cb(x, "mamba_conv1d", il); + + x = ggml_silu(ctx0, x); + cb(x, "mamba_conv1d_silu", il); + } + + // SSM + { + // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} + ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x); + cb(x_bcdt, "mamba_bcdt_proj", il); + + // split into dt, B, C + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0); + ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state); + ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state)); + cb(B, "mamba_B_raw", il); + cb(C, 
"mamba_C_raw", il); + cb(dt, "mamba_dt_raw", il); + + // Apply RMS norm to dt, B, C (PLaMo-2 specific) + B = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il); + C = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il); + dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il); + cb(B, "mamba_B_normed", il); + cb(C, "mamba_C_normed", il); + cb(dt, "mamba_dt_normed", il); + + // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} + dt = build_lora_mm(model.layers[il].ssm_dt, dt); + cb(dt, "mamba_dt_proj", il); + + // This is corresponding to the broadcast_to operation in ssd_update_state() of the originall code + ggml_tensor * dt_expanded = ggml_new_tensor_2d(ctx0, dt->type, dt_dim * hparams.ssm_num_heads, dt->ne[1]); + dt = ggml_repeat(ctx0, dt, dt_expanded); + cb(dt, "mamba_dt_expanded", il); + + ggml_tensor * A_expanded = ggml_new_tensor_2d(ctx0, model.layers[il].ssm_a->type, d_state, d_inner); + A_expanded = ggml_repeat(ctx0, model.layers[il].ssm_a, A_expanded); + A_expanded = ggml_exp(ctx0, A_expanded); + A_expanded = ggml_scale(ctx0, A_expanded, -1.0f); + cb(A_expanded, "mamba_A_expanded", il); + + // SSM scan operation + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, A_expanded, B, C); + cb(y_ssm, "mamba_ssm_scan", il); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, + kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + cb(y, "mamba_y_view", il); + + // Add D parameter and apply gating with z + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + cb(y, "mamba_y_with_D", il); + ggml_tensor * z_silu = ggml_silu(ctx0, ggml_cont(ctx0, z)); + cb(z_silu, "mamba_z_silu", il); + y = ggml_mul(ctx0, y, z_silu); + cb(y, "mamba_y_gated", il); + + // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = build_lora_mm(model.layers[il].ssm_out, y); + cb(cur, "mamba_out_proj", il); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + + return cur; + } +}; + struct llm_build_gpt2 : public llm_graph_context { llm_build_gpt2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -14693,6 +15142,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + // PLaMo-2 layer filters: attention vs Mamba layers + auto filter_attn = [this](int32_t il) -> bool { + return il < (int32_t)hparams.n_layer && !hparams.is_recurrent(il); // attention layers + }; + auto filter_recr = [this](int32_t il) -> bool { + return il < (int32_t)hparams.n_layer && hparams.is_recurrent(il); // Mamba layers + }; + res = new llama_memory_hybrid( /* model */ *this, /* attn_type_k */ params.type_k, @@ -14706,7 +15163,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_v */ GGML_TYPE_F32, /* recurrent_kv_size */ 
std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, - /* offload */ cparams.offload_kqv); + /* offload */ cparams.offload_kqv, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); } else { const auto padding = llama_kv_cache_unified::get_padding(cparams); @@ -14854,6 +15313,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_PLAMO2: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_GPT2: { llm = std::make_unique(*this, params, gf); @@ -15174,6 +15637,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DECI: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: + case LLM_ARCH_PLAMO2: case LLM_ARCH_INTERNLM2: case LLM_ARCH_MINICPM: case LLM_ARCH_XVERSE: diff --git a/src/llama-model.h b/src/llama-model.h index 979fff62045f9..97c1bd72cb073 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -258,6 +258,11 @@ struct llama_layer { struct ggml_tensor * ssm_conv1d_b = nullptr; struct ggml_tensor * ssm_dt_b = nullptr; + // mamba norm (PLaMo-2) + struct ggml_tensor * ssm_dt_norm = nullptr; + struct ggml_tensor * ssm_b_norm = nullptr; + struct ggml_tensor * ssm_c_norm = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 5c9eb87566dde..c03b4d696e3b5 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -1159,6 +1160,40 @@ struct llm_tokenizer_rwkv : llm_tokenizer { struct naive_trie token_matcher; }; +struct llm_tokenizer_plamo2 : llm_tokenizer { + llm_tokenizer_plamo2(const llama_vocab & vocab) { + // Build vocabulary entries for PLaMo-2 tokenizer + std::vector vocab_entries; + vocab_entries.reserve(vocab.n_tokens()); + + for (uint32_t id = 0; id < vocab.n_tokens(); ++id) { + const auto & data = vocab.get_token_data(id); + llama_vocab_plamo2::vocab_entry entry; + entry.text = data.text; + entry.score = data.score; + + // Check if this is a byte token + if (vocab.is_byte(id)) { + entry.type = "BYTE"; + } else { + entry.type = ""; + } + + vocab_entries.push_back(entry); + } + + // Build the Aho-Corasick automaton + plamo2_tokenizer.build(vocab_entries); + } + + void tokenize(const std::string & text, std::vector & output) { + std::vector tokens = plamo2_tokenizer.encode(text); + output.insert(output.end(), tokens.begin(), tokens.end()); + } + + llama_vocab_plamo2 plamo2_tokenizer; +}; + struct llm_tokenizer_rwkv_session { llm_tokenizer_rwkv_session(const llama_vocab & vocab, const llm_tokenizer_rwkv & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} @@ -1498,6 +1533,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_unk_id = LLAMA_TOKEN_NULL; special_sep_id = LLAMA_TOKEN_NULL; special_pad_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "plamo2") { + type = LLAMA_VOCAB_TYPE_PLAMO2; + + // PLaMo-2 default special tokens (these will be overridden by model config) + special_bos_id = 1; // <|plamo:bos|> + special_eos_id = 2; // <|plamo:eos|> + special_unk_id = 0; // <|plamo:unk|> + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = 3; // <|plamo:pad|> + special_mask_id = LLAMA_TOKEN_NULL; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -2134,12 +2179,13 @@ enum llama_vocab_type llama_vocab::impl::get_type() const { 
std::string llama_vocab::impl::type_name() const{ switch (type) { - case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; - case LLAMA_VOCAB_TYPE_SPM: return "SPM"; - case LLAMA_VOCAB_TYPE_BPE: return "BPE"; - case LLAMA_VOCAB_TYPE_WPM: return "WPM"; - case LLAMA_VOCAB_TYPE_UGM: return "UGM"; - case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; + case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; + case LLAMA_VOCAB_TYPE_SPM: return "SPM"; + case LLAMA_VOCAB_TYPE_BPE: return "BPE"; + case LLAMA_VOCAB_TYPE_WPM: return "WPM"; + case LLAMA_VOCAB_TYPE_UGM: return "UGM"; + case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; + case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2"; default: return "unknown"; } } @@ -2223,6 +2269,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) { case LLAMA_VOCAB_TYPE_RWKV: tokenizer = std::make_unique(vocab); break; + case LLAMA_VOCAB_TYPE_PLAMO2: + tokenizer = std::make_unique(vocab); + break; default: GGML_ABORT("unsupported vocab type"); } @@ -2565,6 +2614,23 @@ std::vector llama_vocab::impl::tokenize( } } } break; + case LLAMA_VOCAB_TYPE_PLAMO2: + { + auto * plamo2_tokenizer = static_cast(tokenizer.get()); + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + std::string text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); +#endif + + plamo2_tokenizer->tokenize(text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + } break; case LLAMA_VOCAB_TYPE_NONE: GGML_ABORT("fatal error"); } @@ -2653,6 +2719,24 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t memcpy(buf, result.data(), result.size()); return (int)result.size(); } + case LLAMA_VOCAB_TYPE_PLAMO2: { + // PLaMo-2 uses similar token handling as BPE/SPM + if (vocab.is_byte(token)) { + // Handle byte tokens like <0xXX> + if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') { + int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16); + if (length < 1) { + return -1; + } + buf[0] = static_cast(hex_val); + return 1; + } + } + + // Normal token - just copy the text + std::string result = token_text; + return _try_copy(result.data(), result.size()); + } default: GGML_ABORT("fatal error"); } @@ -2897,6 +2981,12 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const { case LLAMA_VOCAB_TYPE_BPE: { return pimpl->token_to_id.at(unicode_byte_to_utf8(ch)); } + case LLAMA_VOCAB_TYPE_PLAMO2: { + // PLaMo-2 uses byte tokens in format <0xXX> + char hex_str[8]; + snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch); + return pimpl->token_to_id.at(hex_str); + } default: GGML_ABORT("fatal error"); } @@ -3375,3 +3465,287 @@ int32_t llama_detokenize( return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special); } +// +// PLaMo-2 Aho-Corasick tokenizer implementation +// + +llama_vocab_plamo2::llama_vocab_plamo2() : bytes_(256, 0) { +} + +llama_vocab_plamo2::~llama_vocab_plamo2() { +} + +void llama_vocab_plamo2::build(const std::vector & vocab) { + // Reset internal structures + tokens_.clear(); + bytes_.assign(256, 0); + to_suffix_id_.clear(); + table_.clear(); + + // Build token list and byte mapping + std::unordered_map suffix_to_score; + std::unordered_map token_to_id; + + for (size_t token_id = 0; token_id < vocab.size(); ++token_id) { + 
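+        // For every vocab entry: remember its text and id, route "<0xXX>" byte tokens
+        // into bytes_[], and add each regular token together with all of its proper
+        // suffixes to suffix_to_score (suffixes that are not tokens themselves get a
+        // NaN placeholder score).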
+        const auto & entry = vocab[token_id];
+        tokens_.push_back(entry.text);
+        token_to_id[entry.text] = static_cast<int32_t>(token_id);
+
+        // Handle byte tokens
+        if (entry.type == "BYTE") {
+            if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') {
+                std::string hex_str = entry.text.substr(3, 2);
+                int byte_val = std::stoi(hex_str, nullptr, 16);
+                bytes_[byte_val] = static_cast<llama_token>(token_id);
+            }
+            continue;
+        }
+
+        // Add token and all its suffixes to suffix_to_score
+        suffix_to_score[entry.text] = entry.score;
+        for (size_t i = 1; i < entry.text.length(); ++i) {
+            std::string suffix = entry.text.substr(i);
+            if (suffix_to_score.find(suffix) == suffix_to_score.end()) {
+                suffix_to_score[suffix] = std::numeric_limits<float>::quiet_NaN();
+            }
+        }
+    }
+
+    // Check that all byte tokens are set
+    for (int i = 0; i < 256; ++i) {
+        if (bytes_[i] == 0) {
+            throw std::runtime_error("Byte token for <0x" + std::to_string(i) + "> is not set");
+        }
+    }
+
+    // Build suffix list in lexicographical order of reversed strings
+    std::vector<std::string> suffixes;
+    for (const auto & pair : suffix_to_score) {
+        suffixes.push_back(pair.first);
+    }
+    suffixes.push_back(""); // Empty suffix
+
+    std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) {
+        std::string rev_a(a.rbegin(), a.rend());
+        std::string rev_b(b.rbegin(), b.rend());
+        return rev_a < rev_b;
+    });
+
+    // Build suffix_to_id and to_suffix_id_
+    std::unordered_map<std::string, int32_t> suffix_to_id;
+    int32_t num_pieces = 0;
+
+    for (const auto & s : suffixes) {
+        suffix_to_id[s] = num_pieces;
+        if (!s.empty()) {
+            // Convert first character to Unicode code point
+            std::vector<int32_t> unicode_chars = utf8_to_unicode(s);
+            if (!unicode_chars.empty()) {
+                int64_t piece_code = (static_cast<int64_t>(unicode_chars[0]) << 32) | suffix_to_id[s.substr(1)];
+                to_suffix_id_[piece_code] = num_pieces;
+            }
+        }
+
+        // Count number of pieces for this suffix
+        int32_t pieces_for_suffix = 1; // sentinel row
+        for (size_t i = 1; i <= s.length(); ++i) {
+            std::string prefix = s.substr(0, i);
+            if (suffix_to_score.find(prefix) != suffix_to_score.end()) {
+                pieces_for_suffix++;
+            }
+        }
+        num_pieces += pieces_for_suffix;
+    }
+
+    // Build flattened table
+    table_.resize(num_pieces, std::vector<int32_t>(4, 0));
+    int32_t table_idx = 0;
+
+    for (const auto & suffix : suffixes) {
+        // Add all prefixes of the suffix to the table (in decreasing order of length)
+        for (int32_t piece_length = static_cast<int32_t>(suffix.length()); piece_length > 0; --piece_length) {
+            std::string piece = suffix.substr(0, piece_length);
+            auto score_it = suffix_to_score.find(piece);
+            if (score_it == suffix_to_score.end()) {
+                continue;
+            }
+
+            table_[table_idx][TABLE_PIECE_LENGTH] = piece_length;
+            auto token_it = token_to_id.find(piece);
+            table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1;
+
+            float score = score_it->second;
+            table_[table_idx][TABLE_SCORE] = std::isfinite(score) ?
+                static_cast<int32_t>(std::round(score * 1e4)) : INVALID_SCORE;
+            table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece];
+
+            table_idx++;
+        }
+
+        // Add sentinel row
+        table_[table_idx][TABLE_PIECE_LENGTH] = 1;
+        table_[table_idx][TABLE_TOKEN_ID] = -1;
+        table_[table_idx][TABLE_SCORE] = UNKNOWN_SCORE;
+        table_idx++;
+    }
+}
+
+std::vector<int32_t> llama_vocab_plamo2::utf8_to_unicode(const std::string & text) const {
+    std::vector<int32_t> result;
+    const char * ptr = text.c_str();
+    const char * end = ptr + text.length();
+
+    while (ptr < end) {
+        int32_t codepoint = 0;
+        int bytes_read = 0;
+
+        if ((*ptr & 0x80) == 0) {
+            // ASCII
+            codepoint = *ptr;
+            bytes_read = 1;
+        } else if ((*ptr & 0xE0) == 0xC0) {
+            // 2-byte UTF-8
+            codepoint = (*ptr & 0x1F) << 6;
+            codepoint |= (*(ptr + 1) & 0x3F);
+            bytes_read = 2;
+        } else if ((*ptr & 0xF0) == 0xE0) {
+            // 3-byte UTF-8
+            codepoint = (*ptr & 0x0F) << 12;
+            codepoint |= (*(ptr + 1) & 0x3F) << 6;
+            codepoint |= (*(ptr + 2) & 0x3F);
+            bytes_read = 3;
+        } else if ((*ptr & 0xF8) == 0xF0) {
+            // 4-byte UTF-8
+            codepoint = (*ptr & 0x07) << 18;
+            codepoint |= (*(ptr + 1) & 0x3F) << 12;
+            codepoint |= (*(ptr + 2) & 0x3F) << 6;
+            codepoint |= (*(ptr + 3) & 0x3F);
+            bytes_read = 4;
+        } else {
+            // Invalid UTF-8, skip this byte
+            ptr++;
+            continue;
+        }
+
+        result.push_back(codepoint);
+        ptr += bytes_read;
+    }
+
+    return result;
+}
+
+std::vector<llama_token> llama_vocab_plamo2::encode_unicode(const std::vector<int32_t> & unicode_data) const {
+    if (unicode_data.empty()) {
+        return {};
+    }
+
+    const size_t data_len = unicode_data.size();
+
+    // Initialize scores array (dynamic programming)
+    std::vector<int64_t> scores(data_len + 1, static_cast<int64_t>(1) << 60);
+    scores[data_len] = 0;
+
+    // Path array to track best tokenization
+    std::vector<std::vector<int32_t>> path(data_len + 1, std::vector<int32_t>(3, 0));
+
+    int32_t suffix_id = 0;
+
+    // Process from end to beginning
+    for (int i = static_cast<int>(data_len) - 1; i >= 0; --i) {
+        int32_t c = unicode_data[i];
+
+        // Find next suffix ID
+        for (size_t p = suffix_id; p < table_.size(); ++p) {
+            int64_t piece_code = (static_cast<int64_t>(c) << 32) | table_[p][TABLE_PIECE_ID];
+            auto it = to_suffix_id_.find(piece_code);
+            suffix_id = (it != to_suffix_id_.end()) ? it->second : 0;
+
+            if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) {
+                break;
+            }
+        }
+
+        // Update best path
+        for (size_t p = suffix_id; p < table_.size(); ++p) {
+            int32_t score = table_[p][TABLE_SCORE];
+            if (score > INVALID_SCORE) {
+                int32_t piece_length = table_[p][TABLE_PIECE_LENGTH];
+                int64_t s = scores[i + piece_length] - score;
+
+                if (s < scores[i]) {
+                    scores[i] = s;
+                    path[i][PATH_TOKEN_LENGTH] = piece_length;
+                    path[i][PATH_TOKEN_ID] = table_[p][TABLE_TOKEN_ID];
+                    path[i][PATH_NUM_TOKENS] = path[i + piece_length][PATH_NUM_TOKENS] + 1;
+
+                    if (score == UNKNOWN_SCORE) {
+                        // Add UTF-8 byte count
+                        path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
+                    }
+                }
+            }
+
+            if (score == UNKNOWN_SCORE) {
+                break;
+            }
+        }
+    }
+
+    // Decode the best path
+    std::vector<llama_token> token_ids;
+    token_ids.reserve(path[0][PATH_NUM_TOKENS]);
+
+    int pos = 0;
+    while (pos < static_cast<int>(data_len)) {
+        if (path[pos][PATH_TOKEN_ID] >= 0) {
+            token_ids.push_back(path[pos][PATH_TOKEN_ID]);
+        } else {
+            // Fall back to byte tokens
+            int32_t c = unicode_data[pos];
+            int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
+
+            for (int j = 0; j < s; ++j) {
+                uint8_t b = (s == 1) ? c :
+                            (j == 0) ? (0xF00 >> s) & 0xFF :
+                                       0x80 | ((c >> ((s - j - 1) * 6)) & 0x3F);
+                token_ids.push_back(bytes_[b]);
+            }
+        }
+
+        pos += path[pos][PATH_TOKEN_LENGTH];
+    }
+
+    return token_ids;
+}
+
+std::vector<llama_token> llama_vocab_plamo2::encode(const std::string & text) const {
+    std::vector<int32_t> unicode_data = utf8_to_unicode(text);
+    return encode_unicode(unicode_data);
+}
+
+std::vector<std::string> llama_vocab_plamo2::encode_as_tokens(const std::string & text) const {
+    std::vector<llama_token> token_ids = encode(text);
+    std::vector<std::string> result;
+    result.reserve(token_ids.size());
+
+    for (llama_token id : token_ids) {
+        if (id >= 0 && id < static_cast<llama_token>(tokens_.size())) {
+            result.push_back(tokens_[id]);
+        }
+    }
+
+    return result;
+}
+
+const std::string & llama_vocab_plamo2::get_token_text(llama_token id) const {
+    static const std::string empty_string;
+    if (id >= 0 && id < static_cast<llama_token>(tokens_.size())) {
+        return tokens_[id];
+    }
+    return empty_string;
+}
+
+size_t llama_vocab_plamo2::vocab_size() const {
+    return tokens_.size();
+}
+
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index 40e4d1c05b18e..7c8ed6ca14da5 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -5,6 +5,7 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include <unordered_map>

 struct LLM_KV;
 struct llama_model_loader;
@@ -130,3 +131,68 @@ struct llama_vocab {
     struct impl;
     std::unique_ptr<impl> pimpl;
 };
+
+// PLaMo-2 Aho-Corasick tokenizer
+class llama_vocab_plamo2 {
+public:
+    // Constants for table structure
+    static constexpr int32_t TABLE_PIECE_LENGTH = 0;
+    static constexpr int32_t TABLE_TOKEN_ID = 1;
+    static constexpr int32_t TABLE_SCORE = 2;
+    static constexpr int32_t TABLE_PIECE_ID = 3;
+
+    // Constants for path array
+    static constexpr int32_t PATH_TOKEN_LENGTH = 0;
+    static constexpr int32_t PATH_TOKEN_ID = 1;
+    static constexpr int32_t PATH_NUM_TOKENS = 2;
+
+    // Score constants
+    static constexpr int32_t INVALID_SCORE = -20000000;
+    static constexpr int32_t UNKNOWN_SCORE = -10000000;
+
+    struct vocab_entry {
+        std::string text;
+        float score;
+        std::string type; // "BYTE" for byte tokens, empty for normal tokens
+    };
+
+    llama_vocab_plamo2();
+    ~llama_vocab_plamo2();
+
+    // Build the Aho-Corasick automaton from vocabulary
+    void build(const std::vector<vocab_entry> & vocab);
+
+    // Encode text to token IDs
+    std::vector<llama_token> encode(const std::string & text) const;
+
+    // Encode text to token strings
+    std::vector<std::string> encode_as_tokens(const std::string & text) const;
+
+    // Get token text by ID
+    const std::string & get_token_text(llama_token id) const;
+
+    // Get vocabulary size
+    size_t vocab_size() const;
+
+private:
+    // Internal structures for Aho-Corasick automaton
+
+    // List of tokens in the vocabulary
+    std::vector<std::string> tokens_;
+
+    // Mapping from byte code point to token ID (for byte fallback)
+    std::vector<llama_token> bytes_;
+
+    // Mapping from piece code to suffix ID
+    std::unordered_map<int64_t, int32_t> to_suffix_id_;
+
+    // Flattened table representing the Trie structure
+    // Each row contains: [piece_length, token_id, score, piece_id]
+    std::vector<std::vector<int32_t>> table_;
+
+    // Helper functions
+    void build_suffix_map(const std::vector<vocab_entry> & vocab);
+    void build_trie_table(const std::vector<vocab_entry> & vocab);
+    std::vector<int32_t> utf8_to_unicode(const std::string & text) const;
+    std::vector<llama_token> encode_unicode(const std::vector<int32_t> & unicode_data) const;
+};
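A note for reviewers on the encoding step: `llama_vocab_plamo2::encode_unicode` runs a right-to-left dynamic program that picks the segmentation of the input with the best total piece score, using the suffix table built in `build()` to enumerate candidate pieces and falling back to `<0xXX>` byte tokens where nothing matches. The snippet below is a minimal standalone sketch of that segmentation idea only; it is not part of the patch, the toy vocabulary and scores are invented, and it uses a naive O(n^2) substring scan instead of the patch's precomputed suffix table and integer scores.

```cpp
// Standalone illustration (not the patch's implementation) of best-score
// segmentation: choose the split of the input whose summed piece scores are
// maximal. Vocabulary contents and scores here are made up for the example.
#include <cstdio>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
    // toy vocabulary: piece -> score (higher is better), ASCII only for brevity
    const std::unordered_map<std::string, float> vocab = {
        {"to", -1.0f}, {"ken", -1.5f}, {"token", -1.2f}, {"izer", -2.0f},
        {"i", -3.0f}, {"z", -3.0f}, {"e", -3.0f}, {"r", -3.0f},
    };
    const std::string text = "tokenizer";
    const float NEG_INF = -std::numeric_limits<float>::infinity();

    const size_t n = text.size();
    std::vector<float>  best(n + 1, NEG_INF); // best[i] = best score of text[0..i)
    std::vector<size_t> prev(n + 1, 0);       // start index of the piece ending at i
    best[0] = 0.0f;

    for (size_t i = 1; i <= n; ++i) {
        for (size_t j = 0; j < i; ++j) {
            auto it = vocab.find(text.substr(j, i - j));
            if (it != vocab.end() && best[j] != NEG_INF && best[j] + it->second > best[i]) {
                best[i] = best[j] + it->second;
                prev[i] = j;
            }
        }
    }

    if (best[n] == NEG_INF) {
        printf("no segmentation\n"); // the real code would fall back to byte tokens here
        return 0;
    }

    // walk back through prev[] to recover the chosen pieces
    std::vector<std::string> pieces;
    for (size_t i = n; i > 0; i = prev[i]) {
        pieces.insert(pieces.begin(), text.substr(prev[i], i - prev[i]));
    }
    for (const auto & p : pieces) {
        printf("'%s' ", p.c_str());
    }
    printf("\n"); // expected: 'token' 'izer'
}
```

Running the sketch prints `'token' 'izer'`, the split with the highest summed score. The patch reaches the same kind of result in a single backward pass over the precomputed table, with scores stored as integers (`score * 1e4`) and unmatched characters emitted as `<0xXX>` byte tokens.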