diff --git a/examples/bert/build.py b/examples/bert/build.py index 300740951..f9a363880 100644 --- a/examples/bert/build.py +++ b/examples/bert/build.py @@ -221,8 +221,12 @@ def prepare_inputs(): output_name = 'hidden_states' if args.model == 'BertModel' or args.model == 'RobertaModel': - hf_bert = globals()[f'{model_type}Model'](bert_config, - add_pooling_layer=False) + if args.model_dir: + # Shall the same be done for other models? + hf_bert = globals()[f'{model_type}Model'].from_pretrained(args.model_dir, config=bert_config, add_pooling_layer=False) + else: + hf_bert = globals()[f'{model_type}Model'](bert_config, + add_pooling_layer=False) tensorrt_llm_bert = tensorrt_llm.models.BertModel( num_layers=bert_config.num_hidden_layers, num_heads=bert_config.num_attention_heads,