diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index dd80a4a05d596..1743f4e867cfd 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3476,6 +3476,172 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(new_name, data_torch)] +@ModelBase.register("Plamo2ForCausalLM", "PLaMo2ForCausalLM") +class Plamo2Model(TextModel): + model_arch = gguf.MODEL_ARCH.PLAMO2 + + def set_vocab(self): + # PLaMo 2 uses a custom tokenizer with a .jsonl file + # We need to handle this specially + tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl" + tokenizer_config_path = self.dir_model / "tokenizer_config.json" + + if not tokenizer_jsonl_path.is_file(): + raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}") + + # Load tokenizer config + with open(tokenizer_config_path, 'r', encoding='utf-8') as f: + tokenizer_config = json.load(f) + + # Load tokens from JSONL file (actually a list format) + tokens = [] + scores = [] + toktypes = [] + + with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f): + if line.strip(): + token_data = json.loads(line) + # Format: [token, score, type, ?, ?, ?, ?] + token = token_data[0].encode("utf-8") + score = float(token_data[1]) + token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL" + + tokens.append(token) + scores.append(score) + + # Map token type strings to GGUF token types + if token_type_str == "UNKNOWN": + toktypes.append(gguf.TokenType.UNKNOWN) + elif token_type_str == "CONTROL": + toktypes.append(gguf.TokenType.CONTROL) + elif token_type_str == "BYTE": + toktypes.append(gguf.TokenType.BYTE) + else: + # Check for PLaMo-2 special tokens + token_str = token_data[0] + if token_str.startswith("<|plamo:") and token_str.endswith("|>"): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.NORMAL) + + # Use "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer + self.gguf_writer.add_tokenizer_model("plamo2") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + # Add special tokens from config + if "bos_token_id" in tokenizer_config: + self.gguf_writer.add_bos_token_id(tokenizer_config["bos_token_id"]) + if "eos_token_id" in tokenizer_config: + self.gguf_writer.add_eos_token_id(tokenizer_config["eos_token_id"]) + if "pad_token_id" in tokenizer_config: + self.gguf_writer.add_pad_token_id(tokenizer_config["pad_token_id"]) + if "unk_token_id" in tokenizer_config: + self.gguf_writer.add_unk_token_id(tokenizer_config["unk_token_id"]) + + self.gguf_writer.add_add_space_prefix(False) + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + # Which layers are Mamba layers + # PLaMo 2 uses mamba_step to indicate the pattern (e.g., 2 means every other layer) + # This logic matches modeling_plamo.py's is_mamba function + mamba_step = hparams.get("mamba_step", 2) + mamba_enabled = hparams.get("mamba_enabled", True) + mamba_layers = [] + + if mamba_enabled: + for i in range(block_count): + if block_count <= (mamba_step // 2): + # use attention in last layer + is_mamba = (i != block_count - 1) + else: + is_mamba = (i % mamba_step) != (mamba_step // 2) + if is_mamba: + mamba_layers.append(0) + else: + mamba_layers.append(hparams.get("num_key_value_heads", 4)) + + if mamba_layers: + 
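+            # A zero in this per-layer list marks a Mamba (recurrent) layer, while non-zero
+            # entries carry the KV head count of the attention layers. E.g. with the defaults
+            # above (mamba_step=2, num_key_value_heads=4), a 4-layer model yields [0, 4, 0, 4];
+            # llama.cpp later flags recurrent layers by checking n_head_kv(il) == 0
+            # (see the PLAMO2 case in load_hparams further down in this patch).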
self.gguf_writer.add_head_count_kv(mamba_layers) + + self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048)) + self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096)) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32)) + self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06)) + self.gguf_writer.add_group_norm_eps(hparams.get("rms_norm_eps", 1e-06)) + self.gguf_writer.add_layer_norm_eps(hparams.get("rms_norm_eps", 1e-06)) + self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0)) + + # Mamba parameters + self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64)) + self.gguf_writer.add_ssm_conv_kernel(hparams.get("mamba_d_conv", 4)) + self.gguf_writer.add_ssm_time_step_rank(hparams.get("mamba_num_heads", 64)) + intermediate_size = hparams.get("mamba_num_heads", 64) * hparams.get("hidden_size_per_head", 128) + self.gguf_writer.add_ssm_inner_size(intermediate_size) + self.gguf_writer.add_ssm_group_count(0) + + # MLP feed forward parameters (for attention layers) + self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384)) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.endswith(".A_log"): + data_torch = -torch.exp(data_torch) + elif name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + elif name.endswith(".dt_norm_weight"): + name = name.rpartition(".dt_norm_weight")[0] + ".dt_norm.weight" + elif name.endswith(".B_norm_weight"): + name = name.rpartition(".B_norm_weight")[0] + ".B_norm.weight" + elif name.endswith(".C_norm_weight"): + name = name.rpartition(".C_norm_weight")[0] + ".C_norm.weight" + elif name.endswith(".k_weight"): + name = name.rpartition(".k_weight")[0] + ".k.weight" + elif name.endswith(".q_weight"): + name = name.rpartition(".q_weight")[0] + ".q.weight" + elif name.endswith(".conv1d.weight"): + data_torch = torch.squeeze(data_torch) # remove (, 1, ) + assert data_torch.ndim == 2 + elif name.endswith(".pre_mixer_norm.weight"): + data_torch += 1.0 + elif name.endswith(".post_mixer_norm.weight"): + data_torch += 1.0 / 5 + elif name.endswith(".pre_mlp_norm.weight"): + data_torch += 1.0 + elif name.endswith(".post_mlp_norm.weight"): + data_torch += 1.0 / (5**1.5) + elif name.endswith(".norm.weight"): + data_torch += 1.0 + elif name.endswith(".gate_up_proj.weight"): + # Split the combined gate_up tensor + split_size = data_torch.shape[0] // 2 + gate_tensor = data_torch[:split_size, :] + up_tensor = data_torch[split_size:, :] + + # Return both tensors - remove .weight suffix if present + name_base = name.replace(".gate_up_proj.weight", "") + gate_name = name_base + ".ffn_gate.weight" + up_name = name_base + ".ffn_up.weight" + + gate_mapped = self.map_tensor_name(gate_name) + up_mapped = self.map_tensor_name(up_name) + + return [(gate_mapped, gate_tensor), (up_mapped, up_tensor)] + + new_name = self.map_tensor_name(name) + + return [(new_name, data_torch)] + + @ModelBase.register("CodeShellForCausalLM") class CodeShellModel(TextModel): model_arch = gguf.MODEL_ARCH.CODESHELL @@ -4954,6 +5120,123 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield (new_name, data_torch) +@ModelBase.register("JambaForCausalLM") +class JambaModel(TextModel): + model_arch = gguf.MODEL_ARCH.JAMBA + + def 
get_vocab_base_pre(self, tokenizer) -> str: + del tokenizer # unused + + return "gpt-2" + + def set_vocab(self): + if (self.dir_model / "tokenizer.model").is_file(): + # Using Jamba's tokenizer.json causes errors on model load + # (something about "byte not found in vocab"), + # but there's a working tokenizer.model + self._set_vocab_sentencepiece() + else: + # Some Jamba models only have a tokenizer.json, which works. + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "mamba_d_model"]) + d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4 + d_inner = self.hparams["mamba_expand"] * d_model + d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16 + # ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 + dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16) + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6 + n_kv_head = self.hparams["num_key_value_heads"] + attn_offset = self.hparams["attn_layer_offset"] + attn_period = self.hparams["attn_layer_period"] + n_kv_vec = [0 for _ in range(attn_offset)] + [ + n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count) + ] + + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"])) + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(n_kv_vec) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(dt_rank) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_expert_count(self.hparams["num_experts"]) + self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_file_type(self.ftype) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + # Mini-Jamba + name = name.replace(".moe.", ".feed_forward.") + if bid is not None: + moe_offset = self.hparams["expert_layer_offset"] + moe_period = self.hparams["expert_layer_period"] + + if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0): + name = name.replace(".experts.0.", ".") + + # process the experts separately + if ".feed_forward.experts." 
in name: + n_experts = self.hparams["num_experts"] + + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + + # merge the experts into a single 3d tensor + for wid in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + # using the same merged name as qwen2moe + merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight" + + new_name = self.map_tensor_name(merged_name) + + yield new_name, data_torch + return + + new_name = self.map_tensor_name(name) + + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid): + data_torch = data_torch.squeeze() + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield (new_name, data_torch) + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @ModelBase.register("CohereForCausalLM") class CommandR2Model(TextModel): model_arch = gguf.MODEL_ARCH.COMMAND_R diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c12609c6d9f99..3717dc728499f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -313,6 +313,7 @@ class MODEL_ARCH(IntEnum): PHI3 = auto() PHIMOE = auto() PLAMO = auto() + PLAMO2 = auto() CODESHELL = auto() ORION = auto() INTERNLM2 = auto() @@ -329,6 +330,7 @@ class MODEL_ARCH(IntEnum): ARWKV7 = auto() MAMBA = auto() MAMBA2 = auto() + JAMBA = auto() XVERSE = auto() COMMAND_R = auto() COHERE2 = auto() @@ -429,7 +431,10 @@ class MODEL_TENSOR(IntEnum): SSM_CONV1D = auto() SSM_X = auto() SSM_DT = auto() + SSM_DT_NORM = auto() SSM_A = auto() + SSM_B_NORM = auto() + SSM_C_NORM = auto() SSM_D = auto() SSM_NORM = auto() SSM_OUT = auto() @@ -616,6 +621,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.PHI3: "phi3", MODEL_ARCH.PHIMOE: "phimoe", MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.PLAMO2: "plamo2", MODEL_ARCH.CODESHELL: "codeshell", MODEL_ARCH.ORION: "orion", MODEL_ARCH.INTERNLM2: "internlm2", @@ -632,6 +638,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.ARWKV7: "arwkv7", MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA2: "mamba2", + MODEL_ARCH.JAMBA: "jamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COHERE2: "cohere2", @@ -732,7 +739,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", + MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm", + MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", @@ -1342,6 +1352,36 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.PLAMO2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + 
MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_POST_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_X, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.SSM_DT_NORM, + MODEL_TENSOR.SSM_B_NORM, + MODEL_TENSOR.SSM_C_NORM, + ], MODEL_ARCH.GPT2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.POS_EMBD, @@ -1732,6 +1772,34 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.JAMBA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_X, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_DT_NORM, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_B_NORM, + MODEL_TENSOR.SSM_C_NORM, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 51634ef6bdd2e..8854d33b92135 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -13,7 +13,7 @@ class TensorNameMap: "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone "transformer.word_embeddings", # falcon "word_embeddings", # bloom - "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 + "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert "language_model.embedding.word_embeddings", # persimmon @@ -62,7 +62,7 @@ class TensorNameMap: # Output MODEL_TENSOR.OUTPUT: ( "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe plamo2 "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 @@ -76,7 +76,7 @@ class TensorNameMap: MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox "transformer.ln_f", # gpt2 gpt-j falcon jais exaone - "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe + "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe plamo2 "norm", # llama-pth "transformer.norm_f", # mpt dbrx "ln_f", # refact bloom qwen gpt2 @@ -125,6 +125,7 @@ class TensorNameMap: "h.{bid}.ln_1", # gpt2 "transformer.h.{bid}.ln", # phi2 "model.layers.layers.{bid}.norm", # plamo + "model.layers.layers.{bid}.pre_mixer_norm", # plamo2 "model.layers.{bid}.attention_norm", # internlm2 "model.layers.{bid}.norm", # mamba-qbert "backbone.layers.{bid}.norm", # mamba @@ -161,6 +162,7 @@ class TensorNameMap: "encoder.layers.{bid}.attn.Wqkv", # nomic-bert "encoder.layers.{bid}.mixer.Wqkv", # jina "model.layers.{bid}.self_attn.qkv_proj", # phi3 + "model.layers.layers.{bid}.mixer.qkv_proj", # plamo2 "encoder.layers.{bid}.self_attention.query_key_value", # 
chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm "transformer_encoder.{bid}.qkv", # neobert @@ -230,6 +232,7 @@ class TensorNameMap: "h.{bid}.attn.c_proj", # gpt2 "transformer.h.{bid}.mixer.out_proj", # phi2 "model.layers.layers.{bid}.self_attn.o_proj", # plamo + "model.layers.layers.{bid}.mixer.o_proj", # plamo2 "model.layers.{bid}.attention.wo", # internlm2 "encoder.layers.{bid}.attn.out_proj", # nomic-bert "encoder.layers.{bid}.mixer.out_proj", # jina @@ -252,8 +255,9 @@ class TensorNameMap: ), MODEL_TENSOR.ATTN_POST_NORM: ( - "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge - "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414 + "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge + "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414 + "model.layers.layers.{bid}.post_mixer_norm.weight", # plamo2 ), # Rotary embeddings @@ -279,8 +283,11 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.rms_norm_2", # Grok "encoder.layers.{bid}.post_attention_layernorm", # chatglm "transformer.layers.{bid}.ffn_norm", # openelm + "model.layers.{bid}.pre_ff_layernorm", # jamba + "model.layers.{bid}.pre_moe_layernorm", # mini-jamba "model.layers.{bid}.post_attention_layernorm", # llama4 "transformer_encoder.{bid}.ffn_norm", # neobert + "model.layers.layers.{bid}.pre_mlp_norm", # plamo2 ), # Post feed-forward norm @@ -292,6 +299,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_POST_NORM: ( "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414 + "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2 ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -300,6 +308,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.gate", # qwen2moe olmoe "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx + "model.layers.{bid}.feed_forward.router", # jamba "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe "model.layers.{bid}.feed_forward.router", # llama4 "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe @@ -334,6 +343,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.fc1", # phi2 "model.layers.{bid}.mlp.gate_up_proj", # phi3 glm-4-0414 "model.layers.layers.{bid}.mlp.up_proj", # plamo + "model.layers.layers.{bid}.mlp.ffn_up", # plamo2 "model.layers.{bid}.feed_forward.w3", # internlm2 "encoder.layers.{bid}.mlp.fc11", # nomic-bert "encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe @@ -342,6 +352,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.gated_layers", # jina-bert-v2 (GEGLU) "encoder.layer.{bid}.mlp.up_gated_layer", # jina-v2-code (GEGLU) "model.layers.{bid}.residual_mlp.w3", # arctic + "model.layers.{bid}.feed_forward.up_proj", # jamba "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "transformer.h.{bid}.mlp.c_fc_1", # exaone "model.layers.{bid}.feed_forward.up_proj", # llama4 @@ -376,11 +387,13 @@ class TensorNameMap: "transformer.h.{bid}.mlp.w2", # qwen "transformer.h.{bid}.mlp.c_fc2", # jais "model.layers.layers.{bid}.mlp.gate_proj", # plamo + "model.layers.layers.{bid}.mlp.ffn_gate" , # plamo2 "model.layers.{bid}.feed_forward.w1", # internlm2 "encoder.layers.{bid}.mlp.fc12", # nomic-bert "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used) "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic + "model.layers.{bid}.feed_forward.gate_proj", # jamba "transformer.h.{bid}.mlp.c_fc_0", # exaone "model.layers.{bid}.feed_forward.gate_proj", # llama4 ), @@ -425,6 +438,7 @@ 
class TensorNameMap: "transformer.layers.{bid}.ffn.proj_2", # openelm "model.layers.{bid}.residual_mlp.w2", # arctic "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 + "model.layers.{bid}.feed_forward.down_proj", # jamba "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "model.layers.h.{bid}.mlp.c_proj", # exaone "model.layers.{bid}.feed_forward.down_proj", # llama4 @@ -456,6 +470,7 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.q_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 "transformer.layers.{bid}.attn.q_norm", # openelm + "model.layers.layers.{bid}.mixer.q", # plamo2 ), MODEL_TENSOR.ATTN_K_NORM: ( @@ -465,6 +480,7 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.k_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 "transformer.layers.{bid}.attn.k_norm", # openelm + "model.layers.layers.{bid}.mixer.k", # plamo2 ), MODEL_TENSOR.ROPE_FREQS: ( @@ -545,33 +561,65 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_IN: ( - "model.layers.{bid}.in_proj", - "backbone.layers.{bid}.mixer.in_proj", + "model.layers.{bid}.in_proj", # mamba-hf + "backbone.layers.{bid}.mixer.in_proj", # mamba + "model.layers.{bid}.mamba.in_proj", # jamba + "model.layers.layers.{bid}.mixer.in_proj", # plamo2 ), MODEL_TENSOR.SSM_CONV1D: ( - "model.layers.{bid}.conv1d", - "backbone.layers.{bid}.mixer.conv1d", + "model.layers.{bid}.conv1d", # mamba-hf + "backbone.layers.{bid}.mixer.conv1d", # mamba + "model.layers.{bid}.mamba.conv1d", # jamba + "model.layers.layers.{bid}.mixer.conv1d", # plamo2 ), MODEL_TENSOR.SSM_X: ( - "model.layers.{bid}.x_proj", - "backbone.layers.{bid}.mixer.x_proj", + "model.layers.{bid}.x_proj", # mamba-hf + "backbone.layers.{bid}.mixer.x_proj", # mamba + "model.layers.{bid}.mamba.x_proj", # jamba + "model.layers.layers.{bid}.mixer.bcdt_proj", # plamo2 ), MODEL_TENSOR.SSM_DT: ( - "model.layers.{bid}.dt_proj", - "backbone.layers.{bid}.mixer.dt_proj", + "model.layers.{bid}.dt_proj", # mamba-hf + "backbone.layers.{bid}.mixer.dt_proj", # mamba + "model.layers.{bid}.mamba.dt_proj", # jamba + "model.layers.layers.{bid}.mixer.dt_proj", # plamo2 + ), + + MODEL_TENSOR.SSM_DT_NORM: ( + "model.layers.{bid}.mamba.dt_layernorm", # jamba ), MODEL_TENSOR.SSM_A: ( - "model.layers.{bid}.A_log", - "backbone.layers.{bid}.mixer.A_log", + "model.layers.{bid}.A_log", # mamba-hf + "backbone.layers.{bid}.mixer.A_log", # mamba + "model.layers.{bid}.mamba.A_log", # jamba + "model.layers.layers.{bid}.mixer.A_log", # plamo2 + ), + + MODEL_TENSOR.SSM_B_NORM: ( + "model.layers.{bid}.mamba.b_layernorm", # jamba + "model.layers.{bid}.mamba.B_layernorm", # mini-jamba + "model.layers.layers.{bid}.mixer.B_norm.weight", # plamo2 + ), + + MODEL_TENSOR.SSM_C_NORM: ( + "model.layers.{bid}.mamba.c_layernorm", # jamba + "model.layers.{bid}.mamba.C_layernorm", # mini-jamba + "model.layers.layers.{bid}.mixer.C_norm.weight", # plamo2 ), MODEL_TENSOR.SSM_D: ( - "model.layers.{bid}.D", - "backbone.layers.{bid}.mixer.D", + "model.layers.{bid}.D", # mamba-hf + "backbone.layers.{bid}.mixer.D", # mamba + "model.layers.{bid}.mamba.D", # jamba + "model.layers.layers.{bid}.mixer.D", # plamo2 + ), + + MODEL_TENSOR.SSM_DT_NORM: ( + "model.layers.layers.{bid}.mixer.dt_norm.weight", # plamo2 ), MODEL_TENSOR.SSM_NORM: ( @@ -579,8 +627,10 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_OUT: ( - "model.layers.{bid}.out_proj", - "backbone.layers.{bid}.mixer.out_proj", + "model.layers.{bid}.out_proj", # mamba-hf + "backbone.layers.{bid}.mixer.out_proj", # mamba + 
"model.layers.{bid}.mamba.out_proj", # jamba + "model.layers.layers.{bid}.mixer.out_proj", # plamo2 ), MODEL_TENSOR.TIME_MIX_W0: ( diff --git a/include/llama.h b/include/llama.h index 3eda9bc68608c..cab66d70fa2af 100644 --- a/include/llama.h +++ b/include/llama.h @@ -71,12 +71,13 @@ extern "C" { typedef int32_t llama_seq_id; enum llama_vocab_type { - LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab - LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback - LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE - LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece - LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram - LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab + LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback + LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE + LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece + LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming }; // pre-tokenization types diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index ab24054305857..9fc568c348e76 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -34,6 +34,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_PLAMO2, "plamo2" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, { LLM_ARCH_INTERNLM2, "internlm2" }, @@ -46,6 +47,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA2, "mamba2" }, + { LLM_ARCH_JAMBA, "jamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_COHERE2, "cohere2" }, @@ -777,6 +779,37 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PLAMO2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, + { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, + { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_CODESHELL, { @@ -1022,6 +1055,37 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_JAMBA, + { + { 
LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, + { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -1778,6 +1842,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}}, {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, + {LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -1928,6 +1995,9 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { // TODO: There are currently no hybrid models! 
Once there are, this will be // the place to identify them switch (arch) { + case LLM_ARCH_JAMBA: + case LLM_ARCH_PLAMO2: + return true; default: return false; } diff --git a/src/llama-arch.h b/src/llama-arch.h index b769831dff5ec..9abc66856ff90 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -38,6 +38,7 @@ enum llm_arch { LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, + LLM_ARCH_PLAMO2, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, LLM_ARCH_INTERNLM2, @@ -50,6 +51,7 @@ enum llm_arch { LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_MAMBA2, + LLM_ARCH_JAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_COHERE2, @@ -293,7 +295,10 @@ enum llm_tensor { LLM_TENSOR_SSM_CONV1D, LLM_TENSOR_SSM_X, LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_DT_NORM, LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_B_NORM, + LLM_TENSOR_SSM_C_NORM, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 7f0e8c67f1325..55a059d0975d2 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -336,22 +336,8 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { - mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch); - mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch); - - mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - - const int64_t n_rs = mctx->get_recr()->get_n_rs(); - - if (s_copy) { - GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); - int32_t * data = (int32_t *) s_copy->data; - - // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n - for (uint32_t i = 0; i < n_rs; ++i) { - data[i] = mctx->get_recr()->s_copy(i); - } - } + inp_attn->set_input(ubatch); + inp_rs->set_input(ubatch); } void llm_graph_input_one::set_input(const llama_ubatch * ubatch) { @@ -992,35 +978,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t return pos_bias; } -llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { - const auto * mctx_cur = static_cast(mctx); - - auto inp = std::make_unique(hparams, cparams, mctx_cur); - - { - GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); - - const auto n_kv = inp->mctx->get_attn()->get_n_kv(); - - inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch); - inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch); - - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); - ggml_set_input(inp->self_kq_mask); - - inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; - } - - { - const auto n_rs = mctx_cur->get_recr()->get_n_rs(); - - inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); - ggml_set_input(inp->s_copy); - } - - return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); -} - ggml_tensor * llm_graph_context::build_attn_mha( ggml_cgraph * gf, ggml_tensor * q, @@ -1194,8 +1151,12 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { - const auto * mctx_cur = static_cast(mctx); +static std::unique_ptr build_attn_inp_kv_unified_impl( + ggml_context * ctx0, + const llama_ubatch & ubatch, + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_unified_context * mctx_cur) { auto inp = std::make_unique(hparams, cparams, mctx_cur); @@ -1203,6 +1164,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); const auto n_kv = mctx_cur->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); @@ -1213,6 +1175,14 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } + return inp; +} + +llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); } @@ -1234,7 +1204,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto * mctx_cur = static_cast(mctx); + const auto * mctx_cur = inp->mctx; // store to KV cache { @@ -1293,7 +1263,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, v_cur); } - const auto * mctx_iswa = static_cast(mctx); + const auto * mctx_iswa = inp->mctx; const bool is_swa = hparams.is_swa(il); @@ -1391,59 +1361,9 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -ggml_tensor * llm_graph_context::build_attn( - llm_graph_input_mem_hybrid * inp, - ggml_cgraph * gf, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, - ggml_tensor * k_cur, - ggml_tensor * v_cur, - ggml_tensor * kq_b, - ggml_tensor * v_mla, - float kq_scale, - int il) const { - // these nodes are added to the graph together so that they are not reordered - // by doing so, the number of splits in the graph is reduced - ggml_build_forward_expand(gf, q_cur); - ggml_build_forward_expand(gf, k_cur); - ggml_build_forward_expand(gf, v_cur); - - const auto * mctx_cur = static_cast(mctx)->get_attn(); - - // store to KV cache - { - const auto & k_idxs = inp->get_k_idxs(); - const auto & v_idxs = inp->get_v_idxs(); - - ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); - ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il)); - } - - const auto & kq_mask = inp->get_kq_mask(); - - ggml_tensor * q = q_cur; - ggml_tensor * k = mctx_cur->get_k(ctx0, il); - ggml_tensor * v = mctx_cur->get_v(ctx0, il); - - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, 
kq_scale); - cb(cur, "kqv_out", il); - - if (wo) { - cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators - ggml_mul_mat_set_prec(cur, GGML_PREC_F32); - } - } - - if (wo_b) { - cur = ggml_add(ctx0, cur, wo_b); - } - - return cur; -} - +// TODO: maybe separate the inner implementation into a separate function +// like with the non-sliding window equivalent +// once sliding-window hybrid caches are a thing. llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { const auto * mctx_cur = static_cast(mctx); @@ -1513,8 +1433,9 @@ ggml_tensor * llm_graph_context::build_rs( return output_states; } -llm_graph_input_rs * llm_graph_context::build_rs_inp() const { - const auto * mctx_cur = static_cast(mctx); +static std::unique_ptr build_rs_inp_impl( + ggml_context * ctx0, + const llama_memory_recurrent_context * mctx_cur) { auto inp = std::make_unique(mctx_cur); @@ -1523,29 +1444,25 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const { inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); ggml_set_input(inp->s_copy); - return (llm_graph_input_rs *) res->add_input(std::move(inp)); + return inp; } -ggml_tensor * llm_graph_context::build_rs( - llm_graph_input_rs * inp, - ggml_cgraph * gf, - ggml_tensor * s, - int32_t state_size, - int32_t n_seqs, - const llm_graph_get_rows_fn & get_state_rows) const { - const auto * kv_state = static_cast(mctx); +llm_graph_input_rs * llm_graph_context::build_rs_inp() const { + const auto * mctx_cur = static_cast(mctx); - return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); + auto inp = build_rs_inp_impl(ctx0, mctx_cur); + + return (llm_graph_input_rs *) res->add_input(std::move(inp)); } ggml_tensor * llm_graph_context::build_rs( - llm_graph_input_mem_hybrid * inp, + llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * s, int32_t state_size, int32_t n_seqs, const llm_graph_get_rows_fn & get_state_rows) const { - const auto * kv_state = static_cast(mctx)->get_recr(); + const auto * kv_state = inp->mctx; return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); } @@ -1592,6 +1509,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ); } +llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr()); + auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); + + auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); +} + void llm_graph_context::build_pooling( ggml_cgraph * gf, ggml_tensor * cls, diff --git a/src/llama-graph.h b/src/llama-graph.h index 7bdf656768a0c..54eaaac02b99e 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -322,32 +322,21 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { class llm_graph_input_mem_hybrid : public llm_graph_input_i { public: llm_graph_input_mem_hybrid( - const llama_hparams & hparams, - const llama_cparams & cparams, - const llama_memory_hybrid_context * mctx) : - hparams(hparams), - cparams(cparams), - mctx(mctx) { - } + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const 
llama_memory_hybrid_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + mctx(mctx) { } virtual ~llm_graph_input_mem_hybrid() = default; void set_input(const llama_ubatch * ubatch) override; - ggml_tensor * s_copy; // I32 [kv_size] - - ggml_tensor * get_k_idxs() const { return self_k_idxs; } - ggml_tensor * get_v_idxs() const { return self_v_idxs; } - - ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } - - ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] - - const llama_hparams & hparams; - const llama_cparams & cparams; + llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } const llama_memory_hybrid_context * mctx; }; @@ -579,8 +568,6 @@ struct llm_graph_context { ggml_tensor * build_inp_pos_bucket_dec() const; ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const; - llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; - // // attention // @@ -656,18 +643,6 @@ struct llm_graph_context { float kq_scale, int il) const; - ggml_tensor * build_attn( - llm_graph_input_mem_hybrid * inp, - ggml_cgraph * gf, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] - ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] - ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] - ggml_tensor * kq_b, - ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] - float kq_scale, - int il) const; // // recurrent // @@ -700,14 +675,6 @@ struct llm_graph_context { int32_t n_seqs, const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; - ggml_tensor * build_rs( - llm_graph_input_mem_hybrid * inp, - ggml_cgraph * gf, - ggml_tensor * s, - int32_t state_size, - int32_t n_seqs, - const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; - ggml_tensor * build_rwkv_token_shift_load( llm_graph_input_rs * inp, ggml_cgraph * gf, @@ -718,6 +685,11 @@ struct llm_graph_context { ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const; + // + // hybrid + // + + llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; // // pooling diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 86c814d51b901..1d926d8016b75 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -1,6 +1,7 @@ #include "llama-hparams.h" #include "ggml.h" +#include "llama-arch.h" void llama_hparams::set_swa_pattern(uint32_t n_pattern) { for (uint32_t il = 0; il < n_layer; ++il) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0573c5bcea0a4..bc9f19536e41a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -929,6 +929,27 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_PLAMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load Mamba SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + for (uint32_t i = 0; i < 
hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GPT2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1117,6 +1138,26 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_JAMBA: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer) { + // TODO: Jamba layers are a bit heterogenous, so naming this is hard. + case 12: // 900M 8x???M + case 32: // 51B 16x?B + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2807,6 +2848,74 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_PLAMO2: + { + const uint32_t d_conv = hparams.ssm_d_conv; + const uint32_t d_state = hparams.ssm_d_state; + const uint32_t num_heads = hparams.ssm_dt_rank; + const uint32_t intermediate_size = hparams.ssm_d_inner; + const uint32_t head_dim = intermediate_size / num_heads; + const uint32_t qk_dim = head_dim; + const uint32_t v_dim = head_dim; + const int64_t num_attention_heads = hparams.n_head(); + const int64_t q_num_heads = num_attention_heads; + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + bool is_mamba_layer = hparams.is_recurrent(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (is_mamba_layer) { + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); + + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); + + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); + 
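+                        // The dt/B/C norms are RMS-norm weights ({dt_dim} and {d_state} sized);
+                        // build_mamba_layer applies them to the dt, B and C projections before the scan
+                        // (see the ssm block changes later in this patch).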
layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); + } else { + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t k_num_heads = num_key_value_heads; + const int64_t v_num_heads = num_key_value_heads; + const int64_t q_proj_dim = q_num_heads * qk_dim; + const int64_t k_proj_dim = k_num_heads * qk_dim; + const int64_t v_proj_dim = v_num_heads * v_dim; + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); + } + + // All layers have post-attention norm, FFN norm, and FFN tensors + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + } + } break; case LLM_ARCH_GPT2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3208,6 +3317,114 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); } } break; + case LLM_ARCH_JAMBA: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head_kv = hparams.n_head_kv(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (n_head_kv == 0) { + // Mamba layer + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}, 0); + + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 
0); + + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + + layer.wq = nullptr; + layer.wk = nullptr; + layer.wv = nullptr; + layer.wo = nullptr; + + } else { + // Attention layers + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ssm_in = nullptr; + layer.ssm_conv1d = nullptr; + layer.ssm_conv1d_b = nullptr; + layer.ssm_x = nullptr; + layer.ssm_dt_norm = nullptr; + layer.ssm_dt = nullptr; + layer.ssm_dt_b = nullptr; + layer.ssm_b_norm = nullptr; + layer.ssm_c_norm = nullptr; + layer.ssm_a = nullptr; + layer.ssm_d = nullptr; + layer.ssm_out = nullptr; + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + + if (layer.ffn_gate_inp) { + // MoE + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.ffn_gate = nullptr; + layer.ffn_down = nullptr; + layer.ffn_up = nullptr; + } else { + // FFN (no MoE) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.ffn_gate_exps = nullptr; + layer.ffn_down_exps = nullptr; + layer.ffn_up_exps = nullptr; + } + } + } break; case LLM_ARCH_XVERSE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4714,16 +4931,6 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? 
"yes" : "unknown"); - } - - if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); - if (!classifier_labels.empty()) { LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); @@ -4734,6 +4941,15 @@ void llama_model::print_info() const { } } + if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA || arch == LLM_ARCH_PLAMO2) { + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + } + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); if (pimpl->n_elements >= 1e12) { LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); @@ -8111,6 +8327,8 @@ struct llm_build_plamo : public llm_graph_context { } }; + + struct llm_build_gpt2 : public llm_graph_context { llm_build_gpt2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -9752,62 +9970,8 @@ struct llm_build_starcoder2 : public llm_graph_context { } }; -struct llm_build_mamba : public llm_graph_context { - llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { - ggml_tensor * cur; - ggml_tensor * inpL; - - // {n_embd, n_tokens} - inpL = build_inp_embd(model.tok_embd); - - auto * rs_inp = build_rs_inp(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - if (model.arch == LLM_ARCH_MAMBA2) { - cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il); - } else { - cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il); - } - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - } - - // residual - cur = ggml_add(ctx0, cur, inpL); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - // final rmsnorm - cur = build_norm(inpL, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); - } +struct llm_graph_context_mamba : public llm_graph_context { + llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} ggml_tensor * build_mamba_layer( llm_graph_input_rs * inp, @@ -9815,11 +9979,14 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, - int 
il) const { - const auto * mctx_cur = static_cast(mctx); + int il) { + + const auto * mctx_cur = inp->mctx; const auto kv_head = mctx_cur->get_head(); + const auto & layer = model.layers[il]; + const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; @@ -9829,8 +9996,6 @@ struct llm_build_mamba : public llm_graph_context { const int64_t n_seqs = ubatch.n_seqs; // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; - // Use the same RMS norm as the final layer norm - const float norm_rms_eps = hparams.f_norm_rms_eps; const int64_t n_seq_tokens = ubatch.n_seq_tokens; @@ -9848,7 +10013,7 @@ struct llm_build_mamba : public llm_graph_context { cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} - ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, cur); + ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur); // split the above in two // => {d_inner, n_seq_tokens, n_seqs} ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); @@ -9877,10 +10042,10 @@ struct llm_build_mamba : public llm_graph_context { // then permute away the ne[0] dimension, // and then you're left with the resulting x tensor. // For simultaneous sequences, all sequences need to have the same length. - x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + x = ggml_ssm_conv(ctx0, conv_x, layer.ssm_conv1d); // bias - x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); + x = ggml_add(ctx0, x, layer.ssm_conv1d_b); x = ggml_silu(ctx0, x); } @@ -9888,27 +10053,27 @@ struct llm_build_mamba : public llm_graph_context { // ssm { // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} - ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x); + ggml_tensor * x_db = build_lora_mm(layer.ssm_x, x); // split ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); ggml_tensor * C = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); - // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers - if (ssm_dt_b_c_rms) { - dt = ggml_rms_norm(ctx0, dt, norm_rms_eps); - B = ggml_rms_norm(ctx0, B, norm_rms_eps); - C = ggml_rms_norm(ctx0, C, norm_rms_eps); + // Some Mamba variants (e.g. 
FalconMamba, Jamba) apply RMS norm in B, C & Dt layers + if (ssm_dt_b_c_rms || (layer.ssm_dt_norm && layer.ssm_b_norm && layer.ssm_c_norm)) { + dt = build_norm(dt, layer.ssm_dt_norm, NULL, LLM_NORM_RMS, il); + B = build_norm(B, layer.ssm_b_norm, NULL, LLM_NORM_RMS, il); + C = build_norm(C, layer.ssm_c_norm, NULL, LLM_NORM_RMS, il); } // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} - dt = build_lora_mm(model.layers[il].ssm_dt, dt); - dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + dt = build_lora_mm(layer.ssm_dt, dt); + dt = ggml_add(ctx0, dt, layer.ssm_dt_b); cur = x; x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs); - ggml_tensor * A = model.layers[il].ssm_a; + ggml_tensor * A = layer.ssm_a; // use the states and the indices provided by build_recurrent_state // (this is necessary in order to properly use the states before they are overwritten, @@ -9934,16 +10099,15 @@ struct llm_build_mamba : public llm_graph_context { // TODO: skip computing output earlier for unused tokens - y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d)); + y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d)); y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(model.layers[il].ssm_out, y); + cur = build_lora_mm(layer.ssm_out, y); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); - // cb(cur, "mamba_out", il); return cur; } @@ -9955,7 +10119,8 @@ struct llm_build_mamba : public llm_graph_context { const llama_model & model, const llama_ubatch & ubatch, int il) const { - const auto * mctx_cur = static_cast(mctx); + + const auto * mctx_cur = inp->mctx; const auto kv_head = mctx_cur->get_head(); @@ -10078,6 +10243,515 @@ struct llm_build_mamba : public llm_graph_context { } }; +struct llm_build_mamba : public llm_graph_context_mamba { + llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + auto * rs_inp = build_rs_inp(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + if (model.arch == LLM_ARCH_MAMBA2) { + cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il); + } else { + cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // residual + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + +}; + +struct llm_build_jamba : public llm_graph_context_mamba { + llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + 
ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + auto * inp_hybrid = build_inp_mem_hybrid(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_head_kv = hparams.n_head_kv(il); + + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + if (n_head_kv == 0) { + cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il); + } else { + // Attention + + struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // No RoPE :) + cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // residual + struct ggml_tensor * ffn_inp = ggml_add(ctx0, inpL, cur); + cb(cur, "ffn_inp", il); + + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // FFN + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + } + + // residual + cur = ggml_add(ctx0, ffn_inp, cur); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_plamo2 : public llm_graph_context_mamba { + llm_build_plamo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "embedding_output", -1); + + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_hybrid = build_inp_mem_hybrid(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * residual = inpL; + + // ggml_graph_add_node(gf, model.layers[il].attn_norm); + // cb(model.layers[il].attn_norm, "attn_norm", il); + + // pre_mixer_norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + + // check if this layer is Mamba 
or Attention + bool is_mamba_layer = hparams.is_recurrent(il); + + if (is_mamba_layer) { + // PLaMo-2 Mamba layer + cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il); + } else { + // PLaMo-2 Attention layer + cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, gf, cur, model, il); + } + + // post_mixer_norm + cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + // residual connection + cur = ggml_add(ctx0, cur, residual); + cb(cur, "attn_residual", il); + residual = cur; + + // pre-ffn norm + cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_pre_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + } + cb(cur, "ffn_out", il); + + // post ffn norm + cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + // residual connection + cur = ggml_add(ctx0, cur, residual); + cb(cur, "ffn_residual", il); + + inpL = cur; + } + + cur = inpL; + + // final norm + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + + // Explicitly mark as output tensor to ensure proper backend assignment + ggml_set_output(cur); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + +private: + ggml_tensor * build_plamo2_attn_layer( + llm_graph_input_attn_kv_unified * inp, + ggml_tensor * inp_pos, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_model & model, + int il) { + + // self-attention + { + // PLaMo-2 uses combined QKV tensor + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "qkv", il); + + // split QKV tensor into Q, K, V + const int64_t n_embd_head_q = hparams.n_embd_head_k; + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + int32_t n_head_kv = hparams.n_head_kv(il); + + const int64_t q_offset = 0; + const int64_t k_offset = n_embd_head_q * n_head; + const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; + + ggml_tensor * Qcur = ggml_view_2d(ctx0, qkv, n_embd_head_q * n_head, n_tokens, qkv->nb[1], q_offset * ggml_element_size(qkv)); + ggml_tensor * Kcur = ggml_view_2d(ctx0, qkv, n_embd_head_k * n_head_kv, n_tokens, qkv->nb[1], k_offset * ggml_element_size(qkv)); + ggml_tensor * Vcur = ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv)); + + // make tensors contiguous before reshape + Qcur = ggml_cont(ctx0, Qcur); + Kcur = ggml_cont(ctx0, Kcur); + Vcur = ggml_cont(ctx0, Vcur); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_q, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].wq, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, 
attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].wk, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + // Store original K and V for KV cache (before GQA expansion) + ggml_tensor * Kcur_cache = Kcur; + ggml_tensor * Vcur_cache = Vcur; + + // PLaMo-2 GQA: expand K and V heads to match Q heads (equivalent to _expand_kv) + if (n_head_kv < n_head) { + // const int n_group = n_head / n_head_kv; + + // manually expand K and V tensors to repeat each head n_group times + // create expanded tensors with target dimensions + ggml_tensor * Kcur_expanded = ggml_new_tensor_3d(ctx0, Kcur->type, n_embd_head_k, n_head, n_tokens); + ggml_tensor * Vcur_expanded = ggml_new_tensor_3d(ctx0, Vcur->type, n_embd_head_v, n_head, n_tokens); + + // repeat each head n_group times + Kcur = ggml_repeat(ctx0, Kcur, Kcur_expanded); + Vcur = ggml_repeat(ctx0, Vcur, Vcur_expanded); + + cb(Kcur, "Kcur_expanded", il); + cb(Vcur, "Vcur_expanded", il); + } + + cur = build_attn(inp, gf, model.layers[il].wo, NULL, Qcur, Kcur_cache, Vcur_cache, NULL, NULL, 1.0f, il); + } + + cb(cur, "attn_out", il); + + return cur; + } + + ggml_tensor * build_plamo2_mamba_layer( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) { + + const auto * mctx_cur = inp->mctx; + + const auto kv_head = mctx_cur->get_head(); + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_heads = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_heads; + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = ubatch.n_seqs; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} + ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur); + cb(zx, "mamba_in_proj", il); + // {8192, 5, 1, 1} -> {8192, 1, 5, 1} + zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); + zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); + cb(zx, "mamba_in_proj_out", il); + + // split into z and x + // => {head_dim * n_heads, n_seq_tokens, n_seqs} + ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx)); + x = ggml_cont(ctx0, x); + x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); + // x = ggml_permute(ctx0, x, 0, 2, 1, 3); + cb(x, "mamba_x_split", il); + + ggml_tensor * z = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0); + cb(z, "mamba_z_split", il); + + // conv1d + { + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + x = ggml_view_2d(ctx0, x, 
d_inner, n_seq_tokens * n_seqs, d_inner * x->nb[0], 0); + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0); + cb(conv_x, "mamba_conv1d_input", il); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, + conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + cb(x, "mamba_conv1d", il); + + x = ggml_silu(ctx0, x); + cb(x, "mamba_conv1d_silu", il); + } + + // SSM + { + // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} + ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x); + cb(x_bcdt, "mamba_bcdt_proj", il); + + // split into dt, B, C + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0); + ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state); + ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state)); + cb(B, "mamba_B_raw", il); + cb(C, "mamba_C_raw", il); + cb(dt, "mamba_dt_raw", il); + + // Apply RMS norm to dt, B, C (PLaMo-2 specific) + B = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il); + C = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il); + dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il); + cb(B, "mamba_B_normed", il); + cb(C, "mamba_C_normed", il); + cb(dt, "mamba_dt_normed", il); + + // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} + dt = build_lora_mm(model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + cb(dt, "mamba_dt_proj", il); + + ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads); + cb(A, "mamba_A", il); + + x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); + B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0); + C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0); + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size()); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. 
+ // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + cb(y_ssm, "mamba_ssm_scan", il); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, + kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); + cb(y, "mamba_y_view", il); + + // Add D parameter and apply gating with z + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} + ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads); + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D)); + cb(y, "mamba_y_add_d", il); + + ggml_tensor * z_silu = ggml_silu(ctx0, ggml_cont(ctx0, z)); + cb(z_silu, "mamba_z_silu", il); + + y = ggml_mul(ctx0, y, z_silu); + cb(y, "mamba_y_gated", il); + + // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0); + cur = build_lora_mm(model.layers[il].ssm_out, y); + cb(cur, "mamba_out_proj", il); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + + return cur; + } +}; + struct llm_build_command_r : public llm_graph_context { llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -14854,6 +15528,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_PLAMO2: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_GPT2: { llm = std::make_unique(*this, params, gf); @@ -14899,6 +15577,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_JAMBA: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_XVERSE: { llm = std::make_unique(*this, params, gf); @@ -15157,6 +15839,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA2: + case LLM_ARCH_JAMBA: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: @@ -15215,6 +15898,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PHI3: case LLM_ARCH_PHIMOE: case LLM_ARCH_PLAMO: + case LLM_ARCH_PLAMO2: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: diff --git a/src/llama-model.h b/src/llama-model.h index 979fff62045f9..2c14627e459ac 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -173,6 +173,9 @@ struct llama_layer { struct ggml_tensor * attn_norm_cross = nullptr; struct ggml_tensor * attn_norm_enc = nullptr; struct ggml_tensor * ssm_norm = nullptr; + struct ggml_tensor * ssm_dt_norm = nullptr; + struct ggml_tensor * ssm_b_norm = nullptr; + struct ggml_tensor * ssm_c_norm = nullptr; // attention struct ggml_tensor * wq = nullptr; @@ -244,10 +247,10 @@ struct llama_layer { struct 
ggml_tensor * ffn_exp_probs_b = nullptr; // mamba proj - struct ggml_tensor * ssm_in = nullptr; - struct ggml_tensor * ssm_x = nullptr; - struct ggml_tensor * ssm_dt = nullptr; - struct ggml_tensor * ssm_out = nullptr; + struct ggml_tensor * ssm_in = nullptr; + struct ggml_tensor * ssm_x = nullptr; + struct ggml_tensor * ssm_dt = nullptr; + struct ggml_tensor * ssm_out = nullptr; // mamba struct ggml_tensor * ssm_conv1d = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 5c9eb87566dde..ea76b00990109 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -1159,6 +1160,40 @@ struct llm_tokenizer_rwkv : llm_tokenizer { struct naive_trie token_matcher; }; +struct llm_tokenizer_plamo2 : llm_tokenizer { + llm_tokenizer_plamo2(const llama_vocab & vocab) { + // Build vocabulary entries for PLaMo-2 tokenizer + std::vector vocab_entries; + vocab_entries.reserve(vocab.n_tokens()); + + for (uint32_t id = 0; id < vocab.n_tokens(); ++id) { + const auto & data = vocab.get_token_data(id); + llama_vocab_plamo2::vocab_entry entry; + entry.text = data.text; + entry.score = data.score; + + // Check if this is a byte token + if (vocab.is_byte(id)) { + entry.type = "BYTE"; + } else { + entry.type = ""; + } + + vocab_entries.push_back(entry); + } + + // Build the Aho-Corasick automaton + plamo2_tokenizer.build(vocab_entries); + } + + void tokenize(const std::string & text, std::vector & output) { + std::vector tokens = plamo2_tokenizer.encode(text); + output.insert(output.end(), tokens.begin(), tokens.end()); + } + + llama_vocab_plamo2 plamo2_tokenizer; +}; + struct llm_tokenizer_rwkv_session { llm_tokenizer_rwkv_session(const llama_vocab & vocab, const llm_tokenizer_rwkv & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} @@ -1498,6 +1533,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_unk_id = LLAMA_TOKEN_NULL; special_sep_id = LLAMA_TOKEN_NULL; special_pad_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "plamo2") { + type = LLAMA_VOCAB_TYPE_PLAMO2; + + // PLaMo-2 default special tokens (these will be overridden by model config) + special_bos_id = 1; // <|plamo:bos|> + special_eos_id = 2; // <|plamo:eos|> + special_unk_id = 0; // <|plamo:unk|> + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = 3; // <|plamo:pad|> + special_mask_id = LLAMA_TOKEN_NULL; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -2134,12 +2179,13 @@ enum llama_vocab_type llama_vocab::impl::get_type() const { std::string llama_vocab::impl::type_name() const{ switch (type) { - case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; - case LLAMA_VOCAB_TYPE_SPM: return "SPM"; - case LLAMA_VOCAB_TYPE_BPE: return "BPE"; - case LLAMA_VOCAB_TYPE_WPM: return "WPM"; - case LLAMA_VOCAB_TYPE_UGM: return "UGM"; - case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; + case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; + case LLAMA_VOCAB_TYPE_SPM: return "SPM"; + case LLAMA_VOCAB_TYPE_BPE: return "BPE"; + case LLAMA_VOCAB_TYPE_WPM: return "WPM"; + case LLAMA_VOCAB_TYPE_UGM: return "UGM"; + case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; + case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2"; default: return "unknown"; } } @@ -2223,6 +2269,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) { case LLAMA_VOCAB_TYPE_RWKV: tokenizer = std::make_unique(vocab); break; + case LLAMA_VOCAB_TYPE_PLAMO2: + tokenizer = std::make_unique(vocab); + break; default: 
GGML_ABORT("unsupported vocab type"); } @@ -2565,6 +2614,23 @@ std::vector llama_vocab::impl::tokenize( } } } break; + case LLAMA_VOCAB_TYPE_PLAMO2: + { + auto * plamo2_tokenizer = static_cast(tokenizer.get()); + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + std::string text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); +#endif + + plamo2_tokenizer->tokenize(text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + } break; case LLAMA_VOCAB_TYPE_NONE: GGML_ABORT("fatal error"); } @@ -2653,6 +2719,24 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t memcpy(buf, result.data(), result.size()); return (int)result.size(); } + case LLAMA_VOCAB_TYPE_PLAMO2: { + // PLaMo-2 uses similar token handling as BPE/SPM + if (vocab.is_byte(token)) { + // Handle byte tokens like <0xXX> + if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') { + int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16); + if (length < 1) { + return -1; + } + buf[0] = static_cast(hex_val); + return 1; + } + } + + // Normal token - just copy the text + std::string result = token_text; + return _try_copy(result.data(), result.size()); + } default: GGML_ABORT("fatal error"); } @@ -2897,6 +2981,12 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const { case LLAMA_VOCAB_TYPE_BPE: { return pimpl->token_to_id.at(unicode_byte_to_utf8(ch)); } + case LLAMA_VOCAB_TYPE_PLAMO2: { + // PLaMo-2 uses byte tokens in format <0xXX> + char hex_str[8]; + snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch); + return pimpl->token_to_id.at(hex_str); + } default: GGML_ABORT("fatal error"); } @@ -3375,3 +3465,286 @@ int32_t llama_detokenize( return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special); } +// +// PLaMo-2 Aho-Corasick tokenizer implementation +// + +llama_vocab_plamo2::llama_vocab_plamo2() : bytes_(256, 0) { +} + +llama_vocab_plamo2::~llama_vocab_plamo2() { +} + +void llama_vocab_plamo2::build(const std::vector & vocab) { + // Reset internal structures + tokens_.clear(); + bytes_.assign(256, 0); + to_suffix_id_.clear(); + table_.clear(); + + // Build token list and byte mapping + std::unordered_map suffix_to_score; + std::unordered_map token_to_id; + + for (size_t token_id = 0; token_id < vocab.size(); ++token_id) { + const auto & entry = vocab[token_id]; + tokens_.push_back(entry.text); + token_to_id[entry.text] = static_cast(token_id); + + // Handle byte tokens + if (entry.type == "BYTE") { + if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') { + std::string hex_str = entry.text.substr(3, 2); + int byte_val = std::stoi(hex_str, nullptr, 16); + bytes_[byte_val] = static_cast(token_id); + } + continue; + } + + // Add token and all its suffixes to suffix_to_score + suffix_to_score[entry.text] = entry.score; + for (size_t i = 1; i < entry.text.length(); ++i) { + std::string suffix = entry.text.substr(i); + if (suffix_to_score.find(suffix) == suffix_to_score.end()) { + suffix_to_score[suffix] = std::numeric_limits::quiet_NaN(); + } + } + } + + // Check that all byte tokens are set + for (int i = 0; i < 256; ++i) { + if (bytes_[i] == 0) { + throw std::runtime_error("Byte token for <0x" + 
std::to_string(i) + "> is not set"); + } + } + + // Build suffix list in lexicographical order of reversed strings + std::vector suffixes; + for (const auto & pair : suffix_to_score) { + suffixes.push_back(pair.first); + } + suffixes.push_back(""); // Empty suffix + + std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) { + std::string rev_a(a.rbegin(), a.rend()); + std::string rev_b(b.rbegin(), b.rend()); + return rev_a < rev_b; + }); + + // Build suffix_to_id and to_suffix_id_ + std::unordered_map suffix_to_id; + int32_t num_pieces = 0; + + for (const auto & s : suffixes) { + suffix_to_id[s] = num_pieces; + if (!s.empty()) { + // Convert first character to Unicode code point + std::vector unicode_chars = utf8_to_unicode(s); + if (!unicode_chars.empty()) { + int64_t piece_code = (static_cast(unicode_chars[0]) << 32) | suffix_to_id[s.substr(1)]; + to_suffix_id_[piece_code] = num_pieces; + } + } + + // Count number of pieces for this suffix + int32_t pieces_for_suffix = 1; // sentinel row + for (size_t i = 1; i <= s.length(); ++i) { + std::string prefix = s.substr(0, i); + if (suffix_to_score.find(prefix) != suffix_to_score.end()) { + pieces_for_suffix++; + } + } + num_pieces += pieces_for_suffix; + } + + // Build flattened table + table_.resize(num_pieces, std::vector(4, 0)); + int32_t table_idx = 0; + + for (const auto & suffix : suffixes) { + // Add all prefixes of the suffix to the table (in decreasing order of length) + for (int32_t piece_length = static_cast(suffix.length()); piece_length > 0; --piece_length) { + std::string piece = suffix.substr(0, piece_length); + auto score_it = suffix_to_score.find(piece); + if (score_it == suffix_to_score.end()) { + continue; + } + + table_[table_idx][TABLE_PIECE_LENGTH] = piece_length; + auto token_it = token_to_id.find(piece); + table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1; + + float score = score_it->second; + table_[table_idx][TABLE_SCORE] = std::isfinite(score) ? 
+ static_cast(std::round(score * 1e4)) : INVALID_SCORE; + table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece]; + + table_idx++; + } + + // Add sentinel row + table_[table_idx][TABLE_PIECE_LENGTH] = 1; + table_[table_idx][TABLE_TOKEN_ID] = -1; + table_[table_idx][TABLE_SCORE] = UNKNOWN_SCORE; + table_idx++; + } +} + +std::vector llama_vocab_plamo2::utf8_to_unicode(const std::string & text) const { + std::vector result; + const char * ptr = text.c_str(); + const char * end = ptr + text.length(); + + while (ptr < end) { + int32_t codepoint = 0; + int bytes_read = 0; + + if ((*ptr & 0x80) == 0) { + // ASCII + codepoint = *ptr; + bytes_read = 1; + } else if ((*ptr & 0xE0) == 0xC0) { + // 2-byte UTF-8 + codepoint = (*ptr & 0x1F) << 6; + codepoint |= (*(ptr + 1) & 0x3F); + bytes_read = 2; + } else if ((*ptr & 0xF0) == 0xE0) { + // 3-byte UTF-8 + codepoint = (*ptr & 0x0F) << 12; + codepoint |= (*(ptr + 1) & 0x3F) << 6; + codepoint |= (*(ptr + 2) & 0x3F); + bytes_read = 3; + } else if ((*ptr & 0xF8) == 0xF0) { + // 4-byte UTF-8 + codepoint = (*ptr & 0x07) << 18; + codepoint |= (*(ptr + 1) & 0x3F) << 12; + codepoint |= (*(ptr + 2) & 0x3F) << 6; + codepoint |= (*(ptr + 3) & 0x3F); + bytes_read = 4; + } else { + // Invalid UTF-8, skip this byte + ptr++; + continue; + } + + result.push_back(codepoint); + ptr += bytes_read; + } + + return result; +} + +std::vector llama_vocab_plamo2::encode_unicode(const std::vector & unicode_data) const { + if (unicode_data.empty()) { + return {}; + } + + const size_t data_len = unicode_data.size(); + + // Initialize scores array (dynamic programming) + std::vector scores(data_len + 1, static_cast(1) << 60); + scores[data_len] = 0; + + // Path array to track best tokenization + std::vector> path(data_len + 1, std::vector(3, 0)); + + int32_t suffix_id = 0; + + // Process from end to beginning + for (int i = static_cast(data_len) - 1; i >= 0; --i) { + int32_t c = unicode_data[i]; + + // Find next suffix ID + for (size_t p = suffix_id; p < table_.size(); ++p) { + int64_t piece_code = (static_cast(c) << 32) | table_[p][TABLE_PIECE_ID]; + auto it = to_suffix_id_.find(piece_code); + suffix_id = (it != to_suffix_id_.end()) ? it->second : 0; + + if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) { + break; + } + } + + // Update best path + for (size_t p = suffix_id; p < table_.size(); ++p) { + int32_t score = table_[p][TABLE_SCORE]; + if (score > INVALID_SCORE) { + int32_t piece_length = table_[p][TABLE_PIECE_LENGTH]; + int64_t s = scores[i + piece_length] - score; + + if (s < scores[i]) { + scores[i] = s; + path[i][PATH_TOKEN_LENGTH] = piece_length; + path[i][PATH_TOKEN_ID] = table_[p][TABLE_TOKEN_ID]; + path[i][PATH_NUM_TOKENS] = path[i + piece_length][PATH_NUM_TOKENS] + 1; + + if (score == UNKNOWN_SCORE) { + // Add UTF-8 byte count + path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000); + } + } + } + + if (score == UNKNOWN_SCORE) { + break; + } + } + } + + // Decode the best path + std::vector token_ids; + token_ids.reserve(path[0][PATH_NUM_TOKENS]); + + int pos = 0; + while (pos < static_cast(data_len)) { + if (path[pos][PATH_TOKEN_ID] >= 0) { + token_ids.push_back(path[pos][PATH_TOKEN_ID]); + } else { + // Fall back to byte tokens + int32_t c = unicode_data[pos]; + int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000); + + for (int j = 0; j < s; ++j) { + uint8_t b = (s == 1) ? c : + (j == 0) ? 
(0xF00 >> s) & 0xFF : + 0x80 | ((c >> ((s - j - 1) * 6)) & 0x3F); + token_ids.push_back(bytes_[b]); + } + } + + pos += path[pos][PATH_TOKEN_LENGTH]; + } + + return token_ids; +} + +std::vector llama_vocab_plamo2::encode(const std::string & text) const { + std::vector unicode_data = utf8_to_unicode(text); + return encode_unicode(unicode_data); +} + +std::vector llama_vocab_plamo2::encode_as_tokens(const std::string & text) const { + std::vector token_ids = encode(text); + std::vector result; + result.reserve(token_ids.size()); + + for (llama_token id : token_ids) { + if (id >= 0 && id < static_cast(tokens_.size())) { + result.push_back(tokens_[id]); + } + } + + return result; +} + +const std::string & llama_vocab_plamo2::get_token_text(llama_token id) const { + static const std::string empty_string; + if (id >= 0 && id < static_cast(tokens_.size())) { + return tokens_[id]; + } + return empty_string; +} + +size_t llama_vocab_plamo2::vocab_size() const { + return tokens_.size(); +} diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 40e4d1c05b18e..17314aac8e166 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -5,6 +5,7 @@ #include #include #include +#include struct LLM_KV; struct llama_model_loader; @@ -130,3 +131,68 @@ struct llama_vocab { struct impl; std::unique_ptr pimpl; }; + +// PLaMo-2 Aho-Corasick tokenizer +class llama_vocab_plamo2 { +public: + // Constants for table structure + static constexpr int32_t TABLE_PIECE_LENGTH = 0; + static constexpr int32_t TABLE_TOKEN_ID = 1; + static constexpr int32_t TABLE_SCORE = 2; + static constexpr int32_t TABLE_PIECE_ID = 3; + + // Constants for path array + static constexpr int32_t PATH_TOKEN_LENGTH = 0; + static constexpr int32_t PATH_TOKEN_ID = 1; + static constexpr int32_t PATH_NUM_TOKENS = 2; + + // Score constants + static constexpr int32_t INVALID_SCORE = -20000000; + static constexpr int32_t UNKNOWN_SCORE = -10000000; + + struct vocab_entry { + std::string text; + float score; + std::string type; // "BYTE" for byte tokens, empty for normal tokens + }; + + llama_vocab_plamo2(); + ~llama_vocab_plamo2(); + + // Build the Aho-Corasick automaton from vocabulary + void build(const std::vector & vocab); + + // Encode text to token IDs + std::vector encode(const std::string & text) const; + + // Encode text to token strings + std::vector encode_as_tokens(const std::string & text) const; + + // Get token text by ID + const std::string & get_token_text(llama_token id) const; + + // Get vocabulary size + size_t vocab_size() const; + +private: + // Internal structures for Aho-Corasick automaton + + // List of tokens in the vocabulary + std::vector tokens_; + + // Mapping from byte code point to token ID (for byte fallback) + std::vector bytes_; + + // Mapping from piece code to suffix ID + std::unordered_map to_suffix_id_; + + // Flattened table representing the Trie structure + // Each row contains: [piece_length, token_id, score, piece_id] + std::vector> table_; + + // Helper functions + void build_suffix_map(const std::vector & vocab); + void build_trie_table(const std::vector & vocab); + std::vector utf8_to_unicode(const std::string & text) const; + std::vector encode_unicode(const std::vector & unicode_data) const; +};
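
In llm_build_jamba the split between the recurrent and the attention path is driven entirely by the per-layer KV head count: a layer with hparams.n_head_kv(il) == 0 takes the Mamba branch, any other value takes the attention branch (llm_build_plamo2 makes the same decision through hparams.is_recurrent(il)). A minimal standalone sketch of that routing convention follows; the 8-layer pattern is hypothetical and only meant to illustrate the check.

#include <cstdio>
#include <vector>

// Sketch of the per-layer routing convention used by llm_build_jamba:
// a layer whose KV-head count is 0 is a Mamba (recurrent) layer, otherwise attention.
int main() {
    const std::vector<int> n_head_kv = {0, 0, 0, 8, 0, 0, 0, 8}; // hypothetical 8-layer pattern
    for (size_t il = 0; il < n_head_kv.size(); ++il) {
        printf("layer %zu: %s\n", il, n_head_kv[il] == 0 ? "mamba" : "attention");
    }
    return 0;
}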
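
The PLaMo-2 attention layer projects Q, K and V with a single fused wqkv matrix and then carves the result into three views at fixed row offsets. The sketch below only reproduces that offset arithmetic for hypothetical head counts (head_dim = 128, n_head = 32, n_head_kv = 4); the real values come from the model hparams.

#include <cstdio>

// Offsets into the fused QKV projection: Q rows first, then K, then V.
// Head counts below are made up for illustration.
int main() {
    const int head_dim  = 128; // assumes n_embd_head_k == n_embd_head_v
    const int n_head    = 32;
    const int n_head_kv = 4;

    const int q_offset = 0;
    const int k_offset = head_dim * n_head;               // 4096
    const int v_offset = k_offset + head_dim * n_head_kv; // 4608

    printf("q=%d k=%d v=%d total rows=%d\n",
           q_offset, k_offset, v_offset, v_offset + head_dim * n_head_kv); // total 5120
    return 0;
}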
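
Byte-fallback tokens are spelled literally as <0xXX>: byte_to_token() produces that spelling with snprintf, and llama_vocab_plamo2::build() parses it back with std::stoi on the two hex digits. A small standalone round-trip check, not using any llama.cpp types:

#include <cstdio>
#include <string>

int main() {
    // round-trip of the <0xXX> byte-token naming used by the PLaMo-2 vocab
    char name[8];
    snprintf(name, sizeof(name), "<0x%02X>", 0x8F);
    const int byte_val = std::stoi(std::string(name).substr(3, 2), nullptr, 16);
    printf("%s -> %d\n", name, byte_val); // <0x8F> -> 143
    return 0;
}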
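
Scores in the suffix table are kept as 32-bit fixed-point values: a finite score is rounded to score * 1e4, a suffix that is not itself a token (stored as NaN while the table is built) becomes INVALID_SCORE, and sentinel rows use UNKNOWN_SCORE. A tiny illustration of the packing, reusing the same constant value as the header:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Fixed-point score packing as used for TABLE_SCORE (illustrative, standalone).
static constexpr int32_t INVALID_SCORE = -20000000;

static int32_t pack_score(float score) {
    return std::isfinite(score) ? static_cast<int32_t>(std::round(score * 1e4)) : INVALID_SCORE;
}

int main() {
    printf("%d\n", (int) pack_score(-3.25f)); // -32500
    printf("%d\n", (int) pack_score(NAN));    // -20000000: suffix that is not itself a token
    return 0;
}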
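
utf8_to_unicode() decodes UTF-8 by hand with the usual lead-byte masks. As a worked example of the 3-byte branch, the sequence E3 81 82 decodes to U+3042 (hiragana a); the standalone snippet below applies the same masks outside of the class:

#include <cstdint>
#include <cstdio>
#include <string>

int main() {
    const std::string s = "\xE3\x81\x82"; // U+3042 HIRAGANA LETTER A
    const unsigned char * p = reinterpret_cast<const unsigned char *>(s.c_str());

    // 3-byte UTF-8: 1110xxxx 10xxxxxx 10xxxxxx
    int32_t cp = (p[0] & 0x0F) << 12;
    cp |= (p[1] & 0x3F) << 6;
    cp |=  p[2] & 0x3F;

    printf("U+%04X\n", (unsigned) cp); // U+3042
    return 0;
}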
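
In the byte-fallback path of encode_unicode(), the UTF-8 lead-byte prefix for an s-byte sequence is obtained with (0xF00 >> s) & 0xFF, which evaluates to 0xC0, 0xE0 and 0xF0 for s = 2, 3 and 4. A one-loop check:

#include <cstdio>

int main() {
    for (int s = 2; s <= 4; ++s) {
        printf("s=%d lead prefix=0x%02X\n", s, (unsigned) ((0xF00 >> s) & 0xFF));
    }
    // prints 0xC0, 0xE0, 0xF0
    return 0;
}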
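
encode_unicode() is a dynamic program over the precomputed suffix table: working from the end of the input, it keeps for every position the segmentation with the smallest accumulated negated score and falls back to byte tokens where no piece matches. The sketch below shows the same scoring idea on a toy vocabulary, with a plain substring lookup instead of the suffix automaton; the pieces, scores and fallback cost are made up for illustration and this is not the actual table-driven implementation.

#include <cstdio>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
    // toy vocabulary: piece -> score (higher is better)
    const std::unordered_map<std::string, float> vocab = {
        {"ab", -1.0f}, {"abc", -1.5f}, {"c", -2.0f}, {"b", -2.5f}, {"a", -2.5f},
    };
    const std::string text = "abcab";
    const float FALLBACK = 100.0f; // large cost for a single-character fallback

    const size_t n = text.size();
    std::vector<float>  best(n + 1, std::numeric_limits<float>::infinity());
    std::vector<size_t> len (n + 1, 1);
    best[n] = 0.0f;

    // fill the DP table from right to left
    for (size_t i = n; i-- > 0; ) {
        for (size_t l = 1; i + l <= n; ++l) {
            const std::string piece = text.substr(i, l);
            const auto it = vocab.find(piece);
            const float cost = (it != vocab.end())
                ? -it->second
                : (l == 1 ? FALLBACK : std::numeric_limits<float>::infinity());
            if (cost + best[i + l] < best[i]) {
                best[i] = cost + best[i + l];
                len[i]  = l;
            }
        }
    }

    // read back the best segmentation
    for (size_t pos = 0; pos < n; pos += len[pos]) {
        printf("'%s' ", text.substr(pos, len[pos]).c_str());
    }
    printf("\n"); // expected: 'abc' 'ab'
    return 0;
}

With these toy scores the long piece "abc" wins over stacking shorter pieces, which is the same effect the real tokenizer gets from minimizing the accumulated negated fixed-point scores.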