8 changes: 8 additions & 0 deletions trl/trainer/grpo_config.py
@@ -249,3 +249,11 @@ class GRPOConfig(TrainingArguments):
default=False,
metadata={"help": "Whether to log the completions during training."},
)
limit_image_per_prompt: int = field(
default=1,
metadata={"help": "Limit the number of images per prompt for vllm generation."},
)
limit_video_per_prompt: int = field(
default=0,
metadata={"help": "Limit the number of videos per prompt for vllm generation."},
)
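
These two fields are forwarded to vLLM's `limit_mm_per_prompt` when the engine is built (see the trainer diff below). A minimal sketch, not part of the diff, of setting them; the output directory and limit values are placeholders, and both fields default to the single-image, no-video behaviour shown above:

```python
# Hedged sketch of configuring the new multimodal limits (placeholder values).
from trl.trainer.grpo_config import GRPOConfig

config = GRPOConfig(
    output_dir="qwen-grpo-output",      # placeholder
    limit_image_per_prompt=4,           # forwarded to vLLM's limit_mm_per_prompt={"image": ...}
    limit_video_per_prompt=0,           # videos stay disabled (the default)
)
```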
94 changes: 87 additions & 7 deletions trl/trainer/qwen_grpo_trainer.py
@@ -228,6 +228,8 @@ def __init__(
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
peft_config: Optional["PeftConfig"] = None,
shuffle_dataset: bool = True,
image_pad_id: int = 151655,  # <|image_pad|> token id in the Qwen2-VL tokenizer
inputs_to_log: Optional[list[str]] = None,  # dataset columns to add to the completions log table
):
# Args
if args is None:
@@ -439,7 +441,7 @@ def data_collator(features): # No data collation is needed in GRPO
enable_prefix_caching=True,
max_model_len=self.args.vllm_max_model_len,
# The multimodal limits are configurable via GRPOConfig; higher limits reserve more resources in vLLM, so keep them as low as the data requires.
limit_mm_per_prompt={"image": 1, "video": 0},
limit_mm_per_prompt={"image": self.args.limit_image_per_prompt, "video": self.args.limit_video_per_prompt},
)
self.sampling_params = SamplingParams(
temperature=args.temperature,
@@ -478,6 +480,9 @@ def data_collator(features): # No data collation is needed in GRPO
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

self.image_pad_id = image_pad_id
self.inputs_to_log = inputs_to_log if inputs_to_log is not None else []

def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
@@ -621,15 +626,22 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
if self.env is None:
raise ValueError("No environment provided. Only supporting envs now.")
else:
completion_ids = self.env.generate(
generated_output = self.env.generate(
conversations=all_conversations,
vlm_inputs=all_env_inputs,
vlm=self.vlm,
sampling_params=self.sampling_params,
)

completion_ids = generated_output['ids']
completion_messages = generated_output.get('messages', None)
completion_mask = generated_output.get('mask', None)


else:
completion_ids = [None] * len(all_env_inputs)
completion_messages = [None] * len(all_env_inputs)
completion_mask = [None] * len(all_env_inputs)

# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
@@ -640,14 +652,61 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
)
completion_ids = completion_ids[process_slice]

eos_idx = torch.tensor([len(ids) - 1 for ids in completion_ids], device=device)

# Pad completion_ids to uniform length, mask from last output token (EOS)
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.tokenizer.pad_token_id)
sequence_indices = torch.arange(completion_ids.size(1), device=device).expand(completion_ids.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

# broadcast and slice completion messages too.
completion_messages = broadcast_object_list(completion_messages, from_process=0)
completion_messages = completion_messages[process_slice]

# Handle completion mask: broadcast from main process to all processes if available
if completion_mask is not None:
# Broadcast the completion_mask from the main process to all processes
completion_mask = broadcast_object_list(completion_mask, from_process=0)

# Each process takes its corresponding slice based on process index
completion_mask = completion_mask[process_slice]

# Convert mask elements to tensors and move to correct device
completion_mask = [torch.tensor(mask, device=device) for mask in completion_mask]
# Pad masks to uniform length
completion_mask = pad(completion_mask, padding_value=0)
else:
print("No completion mask provided. Computing mask based on EOS positions.")
# Fallback: compute mask based on EOS positions if not provided
eos_idx = torch.tensor([len(ids) - 1 for ids in completion_ids], device=device)
sequence_indices = torch.arange(completion_ids.size(1), device=device).expand(completion_ids.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

# Handle the potential new images generated from the environment (tool) in completion_messages
new_images = []
for i, completion_message in enumerate(completion_messages):
if completion_message is not None:
for message in completion_message:
for content in message["content"]:
if content.get("type", None) == "image":
new_images.append(content["image"])

if len(new_images) > 0:
# use the processor to get pixel_values and image_grid_thw for the new images
new_images_info = self.processing_class(
text='',
images=new_images,
return_tensors='pt',
padding=True,
)
new_pixel_values = new_images_info["pixel_values"]
new_image_grid_thw = new_images_info["image_grid_thw"]

# Concatenate the new pixel_values and image_grid_thw with the existing ones
# Move the new tensors to the same device as the existing pixel_values and image_grid_thw before concatenating
new_pixel_values = new_pixel_values.to(device)
new_image_grid_thw = new_image_grid_thw.to(device)
pixel_values = torch.cat([pixel_values, new_pixel_values], dim=0)
image_grid_thw = torch.cat([image_grid_thw, new_image_grid_thw], dim=0)
else:
raise ValueError("Attempted to generate with HF. Only supporting vllm now.")

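The unpacking of `generated_output` above, and the later scan for tool-produced images, assume a particular return shape from `self.env.generate`. A hedged sketch of that assumed shape with made-up values (the real environment may differ): `ids` is required, while `messages` and `mask` are optional and fall back to EOS-based masking.

```python
# Sketch of the return shape this code relies on; all values are placeholders.
from PIL import Image

tool_image = Image.new("RGB", (28, 28))  # stand-in for an image produced by a tool call

generated_output = {
    # one list of completion token ids per conversation
    "ids": [[9707, 11, 1879, 0], [151655, 151655, 13]],
    # optional per-token mask (1 = trainable token, 0 = masked), same lengths as "ids"
    "mask": [[1, 1, 1, 1], [0, 0, 1]],
    # optional tool-augmented messages; image entries are re-processed by the trainer above
    "messages": [
        [{"role": "assistant", "content": [{"type": "text", "text": "Hello, world!"}]}],
        [{"role": "user", "content": [{"type": "image", "image": tool_image}]}],
    ],
}

completion_ids = generated_output["ids"]
completion_mask = generated_output.get("mask", None)
completion_messages = generated_output.get("messages", None)
```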
@@ -682,7 +741,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(conversations, completions_text):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
bootstrap = prompt[-1]["content"] if prompt[-1]["role"] == "assistant" else ""
if isinstance(bootstrap, list):
if len(bootstrap) > 1:
raise ValueError("Only one bootstrap is supported for now.")
@@ -697,6 +756,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
raise NotImplementedError("Models as reward functions are not supported yet.")
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(conversations, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
@@ -717,6 +777,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
reward_kwargs["prompts_text"] = prompts_text
reward_kwargs["completions_messages"] = completion_messages
output_reward_func = reward_func(prompts=conversations, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

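With `reward_kwargs["completions_messages"]` added above, custom reward functions can inspect the full tool-augmented rollout rather than only the completion text. A hedged sketch of such a reward function; the signature mirrors the call above, but the scoring heuristic is invented:

```python
# Hypothetical reward function consuming the new `completions_messages` kwarg.
def tool_use_reward(prompts, completions, completions_messages=None, **kwargs):
    rewards = []
    for i, _ in enumerate(completions):
        messages = completions_messages[i] if completions_messages else None
        # Invented heuristic: +1 if the rollout contains at least one tool-produced image.
        has_tool_image = bool(messages) and any(
            isinstance(content, dict) and content.get("type") == "image"
            for message in messages
            for content in (message.get("content") or [])
        )
        rewards.append(1.0 if has_tool_image else 0.0)
    return rewards
```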
@@ -779,11 +840,30 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
import pandas as pd

# For logging
inputs_data_to_log = {
key: gather_object(
[i[key] for i in inputs if key in i]
) for key in self.inputs_to_log
}
# Convert any gathered tensor values to plain lists so they can go into the logging table
for key, value in inputs_data_to_log.items():
inputs_data_to_log[key] = [v.tolist() if isinstance(v, torch.Tensor) else v for v in value]

# gather completion_ids and get num_image_pad_ids
# completion_ids shape: (B*G, C), where B is the batch size, G the number of generations, and C the completion length
gathered_completion_ids = gather_object(completion_ids)
# After gathering there are B*G items, each a tensor of its own length (C,); count image pad tokens per item
num_image_pad_ids = [(ids == self.image_pad_id).sum().item() for ids in gathered_completion_ids]
table = {
"step": [str(self.state.global_step)] * len(rewards),
"prompt": gather_object(prompts_text),
"completion": gather_object(completions_text),
"reward": rewards.tolist(),
"reward_per_func": rewards_per_func.tolist(),
"num_image_pad_ids": num_image_pad_ids,
**inputs_data_to_log,
}
df = pd.DataFrame(table)

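For the new `num_image_pad_ids` column, each gathered completion is scanned for `image_pad_id` (151655, the `<|image_pad|>` token in the Qwen2-VL tokenizer, per the trainer's default). A toy sketch of that count on dummy data:

```python
# Toy sketch of the num_image_pad_ids computation on dummy tensors.
import torch

image_pad_id = 151655  # default in the trainer signature; Qwen2-VL's <|image_pad|> token
gathered_completion_ids = [
    torch.tensor([151655, 151655, 9707, 0]),  # two image pad tokens
    torch.tensor([9707, 11, 1879]),           # none
]
num_image_pad_ids = [(ids == image_pad_id).sum().item() for ids in gathered_completion_ids]
print(num_image_pad_ids)  # [2, 0]
```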