diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index 7e7c8128a4b..9d3ceeb6cf2 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -338,9 +338,7 @@ def data_collator(features):  # No data collation is needed in GRPO
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
         self.num_generations = args.num_generations  # = G in the GRPO paper
         self.use_vllm = args.use_vllm
-
-        if self.use_vllm:
-            raise ValueError("vLLM not supported yet.")
+        print(f"use_vllm: {self.use_vllm}")
 
         self.beta = args.beta
 
@@ -393,7 +391,6 @@ def data_collator(features):  # No data collation is needed in GRPO
         set_seed(args.seed, device_specific=True)
 
         if self.use_vllm:
-            raise ValueError("vLLM not supported yet.")
             if not is_vllm_available():
                 raise ImportError(
                     "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
@@ -404,9 +401,12 @@ def data_collator(features):  # No data collation is needed in GRPO
                 vllm_device = self.args.vllm_device
                 if vllm_device == "auto":
                     if torch.cuda.device_count() == 1:
+                        print("Only one GPU available, sharing it between vLLM and training.")
                         vllm_device = "cuda:0"  # particular case when training with only 1 GPU: share it
                     else:
                         vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
+                    print(f"Using GPU {vllm_device} for vLLM.")
+
                 # Check that the requested device is available
                 if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
                     raise ValueError(
@@ -432,7 +432,7 @@ def data_collator(features):  # No data collation is needed in GRPO
                     return_value=None,
                 )
                 with world_size_patch, profiling_patch:
-                    self.llm = LLM(
+                    self.vlm = LLM(
                         model=model.name_or_path,
                         device=vllm_device,
                         gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
@@ -442,6 +442,8 @@ def data_collator(features):  # No data collation is needed in GRPO
                         # This is particularly useful here because we generate completions from the same prompts.
                         enable_prefix_caching=True,
                         max_model_len=self.args.vllm_max_model_len,
+                        # Setting this to 1 as we only have one image per prompt for now. Setting it higher requires more resources, which is wasteful until we need it.
+                        limit_mm_per_prompt={"image": 1, "video": 0},
                     )
                 self.sampling_params = SamplingParams(
                     temperature=args.temperature,
@@ -543,7 +545,6 @@ def _get_per_token_logps(
         return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
 
     def _move_model_to_vllm(self):
-        raise ValueError("vLLM not supported yet.")
         with unwrap_model_for_generation(
             self.model,
             self.accelerator,
@@ -568,13 +569,13 @@ def _move_model_to_vllm(self):
             else:
                 state_dict = unwrapped_model.state_dict()
             if self.accelerator.is_main_process:
-                llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
-                llm_model.load_weights(state_dict.items())
+                vlm_model = self.vlm.llm_engine.model_executor.driver_worker.model_runner.model
+                vlm_model.load_weights(state_dict.items())
 
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
 
-        prompt_inputs, prompts_text, prompts = self.tokenize_and_inject_images(
+        prompt_inputs, vllm_inputs, prompts_text, prompts = self.tokenize_and_inject_images(
             inputs=inputs, processing_class=self.processing_class
         )
 
@@ -588,28 +589,35 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
         )
 
         if self.max_prompt_length is not None:
+            if self.use_vllm:
+                raise ValueError(
+                    "max_prompt_length is not supported when using vLLM. Please set it to None if vLLM is used. This is because we don't control tokenization when using vLLM."
+                )
+
             prompt_ids = prompt_ids[:, -self.max_prompt_length :]
             prompt_mask = prompt_mask[:, -self.max_prompt_length :]
 
         # Generate completions using either vLLM or regular generation
-        if self.args.use_vllm:
-            raise ValueError("vLLM not supported yet.")
+        if self.use_vllm:
             # First, have main process load weights if needed
             if self.state.global_step != self._last_loaded_step:
                 self._move_model_to_vllm()
                 self._last_loaded_step = self.state.global_step
 
-            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
+            # Generate completions using vLLM: gather all prompt inputs and use them in a single call in the main process
+            all_vllm_inputs = gather_object(vllm_inputs)
             all_prompts_text = gather_object(prompts_text)
+
             if self.accelerator.is_main_process:
-                outputs = self.llm.generate(
-                    all_prompts_text,
+                outputs = self.vlm.generate(
+                    all_vllm_inputs,
                     sampling_params=self.sampling_params,
                     use_tqdm=False,
                 )
                 completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
             else:
                 completion_ids = [None] * len(all_prompts_text)
+
             # Broadcast the completions from the main process to all processes, ensuring each process receives its
             # corresponding slice.
             completion_ids = broadcast_object_list(completion_ids, from_process=0)
@@ -621,7 +629,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
 
             # Pad the completions, and concatenate them with the prompts
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
-            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
+            completion_ids = pad(completion_ids, padding_value=self.processing_class.tokenizer.pad_token_id)
             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         else:
             # Regular generation path
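For reference, the `vllm_inputs` gathered and passed to `self.vlm.generate(...)` in this diff are expected to follow vLLM's multimodal prompt format, which is what `limit_mm_per_prompt={"image": 1, "video": 0}` constrains. Below is a minimal standalone sketch of that format, assuming a Qwen2-VL checkpoint, a hypothetical local image path, and a simplified prompt; in the trainer the actual entries are produced by `tokenize_and_inject_images`.

from PIL import Image
from vllm import LLM, SamplingParams

# Hypothetical example of the per-prompt dicts consumed by vlm.generate().
# Each entry pairs a text prompt containing the model's image placeholder tokens
# with at most one image, matching limit_mm_per_prompt={"image": 1, "video": 0}.
image = Image.open("example.jpg")  # placeholder path, not from the trainer
vllm_inputs = [
    {
        "prompt": "<|vision_start|><|image_pad|><|vision_end|>Describe the image.",
        "multi_modal_data": {"image": image},
    }
]

vlm = LLM(model="Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": 1, "video": 0})
sampling_params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = vlm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
# Same flattening as in the diff: one list of token id sequences across all completions.
completion_ids = [out.token_ids for completion in outputs for out in completion.outputs]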