3434from torch .utils .data import DataLoader , Sampler
3535from transformers import (
3636 AutoConfig ,
37+ AutoModelForCausalLM ,
3738 AutoModelForSequenceClassification ,
3839 AutoProcessor ,
3940 AutoTokenizer ,
@@ -239,9 +240,13 @@ def __init__(
239240 f"a `torch.dtype` (e.g., 'float32'), but got { dtype } ."
240241 )
241242 # Disable caching if gradient checkpointing is enabled (not supported)
242- config = AutoConfig .from_pretrained (model_id )
243- architecture = getattr (transformers , config .architectures [0 ])
244- model = architecture .from_pretrained (model_id , ** model_init_kwargs )
243+ config = AutoConfig .from_pretrained (model_id , trust_remote_code = args .trust_remote_code )
244+ if architecture := getattr (transformers , config .architectures [0 ], None ):
245+ model = architecture .from_pretrained (model_id , ** model_init_kwargs )
246+ else :
247+ model = AutoModelForCausalLM .from_pretrained (
248+ model_id , trust_remote_code = args .trust_remote_code , ** model_init_kwargs
249+ )
245250 else :
246251 model_id = model .config ._name_or_path
247252 if args .model_init_kwargs is not None :
@@ -263,7 +268,9 @@ def __init__(
263268
264269 # Processing class
265270 if processing_class is None :
266- processing_class = AutoProcessor .from_pretrained (model .config ._name_or_path )
271+ processing_class = AutoProcessor .from_pretrained (
272+ model .config ._name_or_path , trust_remote_code = args .trust_remote_code
273+ )
267274
268275 # Handle pad token for processors or tokenizers
269276 if isinstance (processing_class , ProcessorMixin ):
@@ -427,9 +434,13 @@ def __init__(
427434 self .ref_model = None
428435 else :
429436 # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
430- config = AutoConfig .from_pretrained (model_id )
431- architecture = getattr (transformers , config .architectures [0 ])
432- self .ref_model = architecture .from_pretrained (model_id , ** model_init_kwargs )
437+ config = AutoConfig .from_pretrained (model_id , trust_remote_code = args .trust_remote_code )
438+ if architecture := getattr (transformers , config .architectures [0 ], None ):
439+ self .ref_model = architecture .from_pretrained (model_id , ** model_init_kwargs )
440+ else :
441+ self .ref_model = AutoModelForCausalLM .from_pretrained (
442+ model_id , trust_remote_code = args .trust_remote_code , ** model_init_kwargs
443+ )
433444
434445 # Disable dropout in the models
435446 if args .disable_dropout :
@@ -537,6 +548,7 @@ def __init__(
537548 max_num_batched_tokens = 4096 ,
538549 model_impl = self .args .vllm_model_impl ,
539550 enable_sleep_mode = self .args .vllm_enable_sleep_mode ,
551+ trust_remote_code = self .args .trust_remote_code ,
540552 )
541553 if self .args .vllm_enable_sleep_mode :
542554 self .llm .sleep (level = 1 )
0 commit comments