From 5e21192c5d61cd24ba54e0d37f96899207ea0363 Mon Sep 17 00:00:00 2001
From: Тимур Ханипов
Date: Tue, 3 Sep 2024 14:49:18 +0300
Subject: [PATCH 1/2] Read weights for BertModel and RobertaModel if --model_dir is provided

---
 examples/bert/build.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/examples/bert/build.py b/examples/bert/build.py
index 300740951..5f90c31d5 100644
--- a/examples/bert/build.py
+++ b/examples/bert/build.py
@@ -221,8 +221,12 @@ def prepare_inputs():
     output_name = 'hidden_states'
 
     if args.model == 'BertModel' or args.model == 'RobertaModel':
-        hf_bert = globals()[f'{model_type}Model'](bert_config,
-                                                  add_pooling_layer=False)
+        if args.model_dir:
+            # Shall the same be done for other models?
+            hf_bert = globals()[f'{model_type}Model'].from_pretrained(args.model_dir, config=bert_config, add_pooling_layer=False)
+        else:
+            hf_bert = globals()[f'{model_type}Model'](bert_config,
+                                                  add_pooling_layer=False)
         tensorrt_llm_bert = tensorrt_llm.models.BertModel(
             num_layers=bert_config.num_hidden_layers,
             num_heads=bert_config.num_attention_heads,

From b0f93055dc789fafdf706152091a8c26c1a1b7fb Mon Sep 17 00:00:00 2001
From: Тимур Ханипов
Date: Tue, 3 Sep 2024 14:59:55 +0300
Subject: [PATCH 2/2] fix indentation

---
 examples/bert/build.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/bert/build.py b/examples/bert/build.py
index 5f90c31d5..f9a363880 100644
--- a/examples/bert/build.py
+++ b/examples/bert/build.py
@@ -226,7 +226,7 @@ def prepare_inputs():
             hf_bert = globals()[f'{model_type}Model'].from_pretrained(args.model_dir, config=bert_config, add_pooling_layer=False)
         else:
             hf_bert = globals()[f'{model_type}Model'](bert_config,
-                                                  add_pooling_layer=False)
+                                                      add_pooling_layer=False)
         tensorrt_llm_bert = tensorrt_llm.models.BertModel(
             num_layers=bert_config.num_hidden_layers,
             num_heads=bert_config.num_attention_heads,
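
Note (not part of the patch): the series makes the example load real Hugging Face
weights via from_pretrained() when --model_dir is given, instead of always building
the engine from a randomly initialized model. Below is a minimal standalone sketch
of that logic. The helper name load_hf_encoder, the config values, and the example
checkpoint path are assumptions made for this sketch; only the transformers
BertModel/RobertaModel classes and the add_pooling_layer=False argument come from
the patch itself.

    from transformers import BertConfig, BertModel, RobertaModel


    def load_hf_encoder(model_cls, bert_config, model_dir=None):
        """Return an HF encoder without its pooling head.

        If model_dir points at a checkpoint directory, load the pretrained
        weights so the TensorRT-LLM engine is built from real parameters;
        otherwise fall back to the previous behaviour of a randomly
        initialized model constructed from the config alone.
        """
        if model_dir:
            return model_cls.from_pretrained(model_dir,
                                             config=bert_config,
                                             add_pooling_layer=False)
        return model_cls(bert_config, add_pooling_layer=False)


    # Usage mirroring build.py: pick the class by model type and pass
    # args.model_dir. RobertaModel is handled the same way as BertModel.
    config = BertConfig(num_hidden_layers=2,
                        num_attention_heads=2,
                        hidden_size=128,
                        intermediate_size=512)
    hf_bert = load_hf_encoder(BertModel, config)  # random init (no --model_dir)
    # hf_bert = load_hf_encoder(BertModel, config, "/path/to/checkpoint")  # pretrained weights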