Skip to content

Commit 84a093e

Browse files
committed
Correctly set vocab size in Gemma3 merges
1 parent 48b0d48 commit 84a093e

File tree

6 files changed

+18
-5
lines changed

6 files changed

+18
-5
lines changed

mergekit/_data/architectures/gemma3vl.json

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"preprocessor_config.json",
99
"processor_config.json"
1010
],
11+
"vocab_size_config_key": "text_config.vocab_size",
1112
"modules": {
1213
"text_decoder": {
1314
"weight_prefix": "language_model.",

mergekit/architecture/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ class ModelArchitecture(BaseModel, frozen=True):
128128
architectures: List[str]
129129
expected_model_type: str = Field(alias="model_type")
130130
tagalong_files: Optional[List[str]] = None
131+
vocab_size_config_key: Optional[str] = None
131132

132133
def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
133134
res = []

mergekit/architecture/json_definitions.py

+2
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class JsonModularArchitectureDefinition(BaseModel, frozen=True):
105105
architectures: List[str]
106106
expected_model_type: str = Field(alias="model_type")
107107
tagalong_files: Optional[List[str]] = None
108+
vocab_size_config_key: Optional[str] = None
108109

109110

110111
class TemplateWithArithmetic(string.Template):
@@ -154,6 +155,7 @@ def _load_architecture_json(name: str) -> ModelArchitecture:
154155
architectures=parsed.architectures,
155156
model_type=parsed.expected_model_type,
156157
tagalong_files=parsed.tagalong_files,
158+
vocab_size_config_key=parsed.vocab_size_config_key,
157159
)
158160
elif data.get("kind", "module") == "module":
159161
module = JsonModuleArchitecture(

mergekit/merge.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def run_merge(
9090
pad_to_multiple_of = None
9191
if merge_config.tokenizer and merge_config.tokenizer.pad_to_multiple_of:
9292
pad_to_multiple_of = merge_config.tokenizer.pad_to_multiple_of
93-
_update_config_vocab(cfg_out, tokenizer, pad_to_multiple_of=pad_to_multiple_of)
93+
_update_config_vocab(
94+
cfg_out, arch_info, tokenizer, pad_to_multiple_of=pad_to_multiple_of
95+
)
9496

9597
logger.info("Saving config")
9698
cfg_out.save_pretrained(out_path)
@@ -308,14 +310,15 @@ def _model_out_config(
308310

309311
def _update_config_vocab(
310312
config: transformers.PretrainedConfig,
313+
arch_info: ModelArchitecture,
311314
tokenizer: transformers.PreTrainedTokenizerBase,
312315
pad_to_multiple_of: Optional[int] = None,
313316
):
314317
vocab_size = len(tokenizer.get_vocab())
315318
if pad_to_multiple_of and vocab_size % pad_to_multiple_of:
316319
vocab_size = vocab_size + pad_to_multiple_of - (vocab_size % pad_to_multiple_of)
317320
try:
318-
config.vocab_size = vocab_size
321+
setattr(config, arch_info.vocab_size_config_key or "vocab_size", vocab_size)
319322
except Exception as e:
320323
logger.warning(
321324
"Unable to set vocabulary size in output config - you may need to manually correct it.",

mergekit/scripts/tokensurgeon.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,11 @@ def main(
199199
tokenizer.save_pretrained(out_path)
200200
cfg_out = arch_info.config
201201
try:
202-
cfg_out.vocab_size = new_embed.shape[0]
202+
setattr(
203+
cfg_out,
204+
arch_info.info.vocab_size_config_key or "vocab_size",
205+
new_embed.shape[0],
206+
)
203207
except AttributeError:
204208
LOG.error(
205209
"Could not set vocab size in config.json - you may need to update it manually."

mergekit/tokenizer/build.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from pydantic import BaseModel
1414
from typing_extensions import Literal
1515

16-
from mergekit.common import ModelPath, ModelReference
16+
from mergekit.architecture import arch_info_for_config
17+
from mergekit.common import ModelPath, ModelReference, get_config_value
1718
from mergekit.graph import Task
1819

1920
logger = logging.getLogger(__name__)
@@ -26,7 +27,8 @@ def get_vocab_size(model_path: ModelPath, trust_remote_code: bool) -> Optional[i
2627
revision=model_path.revision,
2728
trust_remote_code=trust_remote_code,
2829
)
29-
return cfg.vocab_size
30+
arch_info = arch_info_for_config(cfg)
31+
return get_config_value(cfg, arch_info.vocab_size_config_key or "vocab_size")
3032
except Exception as e:
3133
logger.warning(f"Unable to get vocab size for {model_path}", exc_info=e)
3234

0 commit comments

Comments (0)