diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index b9e73a78ff4..f98dfe987c7 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -126,6 +126,9 @@ class DPOConfig(TrainingArguments):
         tools (`Optional[list[Union[dict, Callable]]]`, *optional*):
             List of tools (callable functions) that will be accessible to the model. If the template does not support
             function calling, this argument will have no effect.
+        dataset_kwargs (`dict[str, Any]`, *optional*):
+            Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
+            `skip_prepare_dataset`.
 
         > Parameters that control the training
 
@@ -301,6 +304,13 @@ class DPOConfig(TrainingArguments):
         default=None,
         metadata={"help": "Number of processes to use for processing the dataset."},
     )
+    dataset_kwargs: Optional[dict[str, Any]] = field(
+        default=None,
+        metadata={
+            "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is "
+            "`skip_prepare_dataset`."
+        },
+    )
     pad_token: Optional[str] = field(
         default=None,
         metadata={
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index bfcc4b4c53e..8cdd0f4b1da 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -470,15 +470,19 @@ def __init__(
         self.dataset_num_proc = args.dataset_num_proc
 
         # Dataset preparation
-        train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
-        if eval_dataset is not None:
-            if isinstance(eval_dataset, dict):
-                eval_dataset = {
-                    key: self._prepare_dataset(dataset, processing_class, args, key)
-                    for key, dataset in eval_dataset.items()
-                }
-            else:
-                eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
+        skip_prepare_dataset = args.dataset_kwargs is not None and args.dataset_kwargs.get(
+            "skip_prepare_dataset", False
+        )
+        if not skip_prepare_dataset:
+            train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
+            if eval_dataset is not None:
+                if isinstance(eval_dataset, dict):
+                    eval_dataset = {
+                        key: self._prepare_dataset(dataset, processing_class, args, key)
+                        for key, dataset in eval_dataset.items()
+                    }
+                else:
+                    eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
 
         super().__init__(
             model=model,
@@ -991,6 +995,8 @@ def concatenated_inputs(
             )
         if "image_sizes" in batch:
             output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
+        if "image_grid_thw" in batch:
+            output["image_grid_thw"] = torch.cat([batch["image_grid_thw"], batch["image_grid_thw"]], dim=0)
 
         # Concatenate the chosen and rejected completions
         max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
@@ -1249,6 +1255,9 @@ def _compute_loss_liger(
             model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
         if "image_sizes" in concatenated_batch:
             model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
+        # For Qwen-VL models
+        if "image_grid_thw" in concatenated_batch:
+            model_kwargs["image_grid_thw"] = concatenated_batch["image_grid_thw"]
 
         prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
         completion_attention_mask = concatenated_batch["completion_attention_mask"]
@@ -1496,6 +1505,9 @@ def concatenated_forward(
             model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
         if "image_sizes" in concatenated_batch:
             model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
+        # For Qwen-VL models
+        if "image_grid_thw" in concatenated_batch:
+            model_kwargs["image_grid_thw"] = concatenated_batch["image_grid_thw"]
 
         prompt_input_ids = concatenated_batch["prompt_input_ids"]
         prompt_attention_mask = concatenated_batch["prompt_attention_mask"]