@@ -90,7 +90,9 @@ def run_merge(
90
90
pad_to_multiple_of = None
91
91
if merge_config .tokenizer and merge_config .tokenizer .pad_to_multiple_of :
92
92
pad_to_multiple_of = merge_config .tokenizer .pad_to_multiple_of
93
- _update_config_vocab (cfg_out , tokenizer , pad_to_multiple_of = pad_to_multiple_of )
93
+ _update_config_vocab (
94
+ cfg_out , arch_info , tokenizer , pad_to_multiple_of = pad_to_multiple_of
95
+ )
94
96
95
97
logger .info ("Saving config" )
96
98
cfg_out .save_pretrained (out_path )
@@ -308,14 +310,15 @@ def _model_out_config(
308
310
309
311
def _update_config_vocab (
310
312
config : transformers .PretrainedConfig ,
313
+ arch_info : ModelArchitecture ,
311
314
tokenizer : transformers .PreTrainedTokenizerBase ,
312
315
pad_to_multiple_of : Optional [int ] = None ,
313
316
):
314
317
vocab_size = len (tokenizer .get_vocab ())
315
318
if pad_to_multiple_of and vocab_size % pad_to_multiple_of :
316
319
vocab_size = vocab_size + pad_to_multiple_of - (vocab_size % pad_to_multiple_of )
317
320
try :
318
- config . vocab_size = vocab_size
321
+ setattr ( config , arch_info . vocab_size_config_key or "vocab_size" , vocab_size )
319
322
except Exception as e :
320
323
logger .warning (
321
324
"Unable to set vocabulary size in output config - you may need to manually correct it." ,
0 commit comments