@@ -758,7 +758,14 @@ def __init__(
         else:
             self.completion_only_loss = args.completion_only_loss

-        if data_collator is None and not self._is_vlm:
+        self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
+        if self._is_vision_dataset and not self._is_vlm:
+            raise ValueError(
+                "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
+                "model does not seem to be a vision-language model. Please check your model and dataset."
+            )
+
+        if data_collator is None and not self._is_vision_dataset:
             # Get the pad token: if not provided, use the one from the processing class or the eos token
             # if the processing class does not have a pad token.
             pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
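
The new heuristic keys off the columns of a single dataset row rather than the model class. A standalone sketch of the check, with made-up rows (`looks_like_vision_dataset` is a hypothetical name for the inline expression above):

    # Made-up rows; the real check inspects one sample of the training dataset.
    text_row = {"messages": [{"role": "user", "content": "Hi"}]}
    vision_row = {"images": ["<PIL.Image>"], "messages": []}

    def looks_like_vision_dataset(sample):
        return "image" in sample or "images" in sample

    assert not looks_like_vision_dataset(text_row)
    assert looks_like_vision_dataset(vision_row)
    # A vision dataset paired with a text-only model now fails fast with the ValueError above.
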
@@ -777,7 +784,7 @@ def __init__(
                 return_position_ids=use_flash_attention,
                 pad_to_multiple_of=args.pad_to_multiple_of,
             )
-        elif data_collator is None and self._is_vlm:
+        elif data_collator is None and self._is_vision_dataset:
             data_collator = DataCollatorForVisionLanguageModeling(
                 processor=processing_class,
                 max_length=args.max_length,
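
For context on what the text-only branch above configures: it pads token sequences to a common length, optionally rounded up via `pad_to_multiple_of`. A rough, list-only sketch of that padding, assuming the real collators additionally handle tensors, labels, and position ids:

    def pad_batch(seqs, pad_token_id, pad_to_multiple_of=None):
        longest = max(len(s) for s in seqs)
        if pad_to_multiple_of is not None:
            longest = -(-longest // pad_to_multiple_of) * pad_to_multiple_of  # ceil to multiple
        return {
            "input_ids": [s + [pad_token_id] * (longest - len(s)) for s in seqs],
            "attention_mask": [[1] * len(s) + [0] * (longest - len(s)) for s in seqs],
        }

    batch = pad_batch([[5, 6], [7, 8, 9]], pad_token_id=0, pad_to_multiple_of=4)
    # batch["input_ids"] == [[5, 6, 0, 0], [7, 8, 9, 0]]
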
@@ -805,7 +812,9 @@ def __init__(
         # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where
         # preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead.
         skip_prepare_dataset = (
-            args.dataset_kwargs is not None and args.dataset_kwargs.get("skip_prepare_dataset", False) or self._is_vlm
+            args.dataset_kwargs is not None
+            and args.dataset_kwargs.get("skip_prepare_dataset", False)
+            or self._is_vision_dataset
         )
         if not skip_prepare_dataset:
             if self.completion_only_loss and formatting_func:
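
The reflowed condition relies on `and` binding tighter than `or` in Python, so it reads as "(kwargs given and flag set) or vision dataset". An equivalent, fully parenthesized form with stand-in values:

    dataset_kwargs = None
    is_vision_dataset = True

    skip = (dataset_kwargs is not None and dataset_kwargs.get("skip_prepare_dataset", False)) or is_vision_dataset
    assert skip  # vision datasets always skip preparation, even with no dataset_kwargs
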
@@ -959,22 +968,36 @@ def add_eos(example, eos_token):
                 if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                     map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

-                def tokenize(example, processing_class, dataset_text_field, assistant_only_loss):
+                def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_loss):
                     if "prompt" in example:  # prompt-completion case
                         output = {}
                         if is_conversational(example):
+                            if self._is_vlm:
+                                prepare_multimodal_messages(example["prompt"], num_images=0)
+                                prepare_multimodal_messages(example["completion"], num_images=0)
                             prompt_ids = processing_class.apply_chat_template(
                                 example["prompt"],
+                                tokenize=True,
                                 tools=example.get("tools"),
                                 **example.get("chat_template_kwargs", {}),
                             )
+                            # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
+                            # even for single examples, while for LLMs it returns lists of ints.
+                            prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids
                             prompt_completion_processed = processing_class.apply_chat_template(
                                 example["prompt"] + example["completion"],
                                 return_dict=True,
+                                tokenize=True,
                                 return_assistant_tokens_mask=assistant_only_loss,
                                 tools=example.get("tools"),
                                 **example.get("chat_template_kwargs", {}),
                             )
+                            # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
+                            # even for single examples, while for LLMs it returns lists of ints.
+                            prompt_completion_processed = {
+                                k: v[0] if isinstance(v[0], list) else v
+                                for k, v in prompt_completion_processed.items()
+                            }
                             prompt_completion_ids = prompt_completion_processed["input_ids"]
                             if "assistant_masks" in prompt_completion_processed:
                                 output["assistant_masks"] = prompt_completion_processed["assistant_masks"]
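
To make the "transformers inconsistency" comments concrete: with a VLM processor, `apply_chat_template(..., tokenize=True)` appears to come back batched even for a single conversation, so the code unwraps the first element whenever it sees a nested list. A toy version of that normalization (token ids are made up):

    llm_output = [101, 2023, 102]    # LLM tokenizer: flat list of ints
    vlm_output = [[101, 2023, 102]]  # VLM processor: batch of one

    def unwrap(ids):
        return ids[0] if isinstance(ids[0], list) else ids

    assert unwrap(llm_output) == unwrap(vlm_output) == [101, 2023, 102]
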
@@ -999,13 +1022,19 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss)

                     else:  # language modeling case
                         if is_conversational(example):
+                            if self._is_vlm:
+                                prepare_multimodal_messages(example["messages"], num_images=0)
                             processed = processing_class.apply_chat_template(
                                 example["messages"],
                                 return_dict=True,
+                                tokenize=True,
                                 return_assistant_tokens_mask=assistant_only_loss,
                                 tools=example.get("tools"),
                                 **example.get("chat_template_kwargs", {}),
                             )
+                            # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
+                            # even for single examples, while for LLMs it returns lists of ints.
+                            processed = {k: v[0] if isinstance(v[0], list) else v for k, v in processed.items()}
                             if "assistant_masks" in processed and 1 not in processed["assistant_masks"]:
                                 raise RuntimeError(
                                     "You're using `assistant_only_loss=True`, but at least one example has no "
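
On the `assistant_masks` guard: the mask marks assistant tokens with 1, and a mask containing no 1s (typically because the chat template lacks `{% generation %}` blocks) would leave nothing to compute the loss on. A small illustration with made-up masks:

    good_mask = [0, 0, 0, 1, 1, 1]  # the last three tokens belong to the assistant turn
    bad_mask = [0, 0, 0, 0, 0, 0]   # template never marked a generation region

    assert 1 in good_mask      # fine: there are tokens to train on
    assert 1 not in bad_mask   # the case the RuntimeError above rejects
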
@@ -1020,7 +1049,7 @@ def tokenize(example, processing_class, dataset_text_field, assistant_only_loss)
                     return output

                 dataset = dataset.map(
-                    tokenize,
+                    tokenize_fn,
                     fn_kwargs={
                         "processing_class": processing_class,
                         "dataset_text_field": args.dataset_text_field,
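
For readers unfamiliar with `fn_kwargs`: `datasets.Dataset.map` forwards it as extra keyword arguments on every call, which is how `tokenize_fn` receives the processing class and config flags. A minimal usage example (requires the `datasets` package):

    from datasets import Dataset

    def shout(example, suffix):
        return {"text": example["text"].upper() + suffix}

    ds = Dataset.from_dict({"text": ["hello", "world"]})
    ds = ds.map(shout, fn_kwargs={"suffix": "!"})
    print(ds["text"])  # ['HELLO!', 'WORLD!']
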
@@ -1064,7 +1093,7 @@ def _set_signature_columns_if_needed(self):
         # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the
         # dataset. So we need to override the default signature columns to include "completion_mask" as well.
         if self._signature_columns is None:
-            if self._is_vlm:
+            if self._is_vision_dataset:
                 self._signature_columns = ["messages", "prompt", "completion", "images"]
             else:
                 self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"]
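
As a reminder of why this list matters: `Trainer` drops dataset columns that are not in `_signature_columns` before batching, so the raw conversational and image columns must be listed for them to survive until the collator runs. A pure-Python sketch of that filtering (the row is made up):

    signature_columns = ["messages", "prompt", "completion", "images"]
    row = {"messages": [], "images": [], "source": "web"}

    kept = {k: v for k, v in row.items() if k in signature_columns}
    assert "source" not in kept  # dropped before the batch reaches the collator
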