13 changes: 11 additions & 2 deletions trl/trainer/__init__.py
@@ -36,7 +36,7 @@
"dpo_trainer": ["DPOTrainer"],
"gkd_config": ["GKDConfig"],
"gkd_trainer": ["GKDTrainer"],
"qwen_grpo_trainer": ["QwenGRPOTrainer"],
"qwen_grpo_trainer": ["QwenGRPOTrainer", "ToolDefinition"],
"grpo_config": ["GRPOConfig"],
"grpo_trainer": ["GRPOTrainer"],
"iterative_sft_trainer": ["IterativeSFTTrainer"],
@@ -134,7 +134,7 @@
from .ppo_trainer import PPOTrainer
from .prm_config import PRMConfig
from .prm_trainer import PRMTrainer
from .qwen_grpo_trainer import QwenGRPOTrainer
from .qwen_grpo_trainer import QwenGRPOTrainer, ToolDefinition
from .reward_config import RewardConfig
from .reward_trainer import RewardTrainer
from .rloo_config import RLOOConfig
@@ -164,3 +164,12 @@
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

# Importing from qwen_grpo_trainer to expose them at the package level.
from .qwen_grpo_trainer import QwenGRPOTrainer, ToolDefinition

# Define __all__ to explicitly list the public API of this package.
__all__ = [
"QwenGRPOTrainer",
"ToolDefinition",
]
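For reference, after this change both names can be imported directly from the subpackage; a minimal sketch of the intended import (assuming a source checkout of this branch):

```python
from trl.trainer import QwenGRPOTrainer, ToolDefinition
```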
5 changes: 5 additions & 0 deletions trl/trainer/grpo_config.py
@@ -174,3 +174,8 @@ class GRPOConfig(TrainingArguments):
default=0.04,
metadata={"help": "KL coefficient."},
)

loss_magnifier: float = field(
default=1.0,
metadata={"help": "Multiplies the loss on the way out to avoid underflow."},
)
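A minimal sketch of how the new field might be set; the magnitude here is illustrative, not a tuned value:

```python
from trl.trainer.grpo_config import GRPOConfig

# Scale the final loss up so fp16 values around 1e-7 don't round to zero.
args = GRPOConfig(output_dir="./grpo-out", loss_magnifier=1e3)
```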
169 changes: 161 additions & 8 deletions trl/trainer/qwen_grpo_trainer.py
@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import textwrap
import warnings

from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
import copy
import math
import os
import textwrap
import warnings

import torch
import torch.utils.data
@@ -50,6 +53,7 @@
from .grpo_config import GRPOConfig
from .utils import generate_model_card, get_comet_experiment_url, pad

VERBOSE = int(os.environ.get("VERBOSE", "0"))

if is_peft_available():
from peft import PeftConfig, get_peft_model
@@ -65,6 +69,73 @@
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


@dataclass
class ToolDefinition:
"""Basic metadata that the trainer needs to know about the tools."""
stop_string: str
call_tool: Callable[[str], str]  # takes the decoded completion text, returns the tool's text response

def completion_has_tool_call(self, completion_str: str) -> bool:
"""Check if the completion has a tool call."""
return self.stop_string in completion_str


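To make the contract concrete, here is a minimal sketch of a `ToolDefinition` for a toy calculator. The tag format, parsing, and tool behavior are illustrative assumptions, not part of this PR:

```python
import re

def call_calculator(completion: str) -> str:
    # Hypothetical convention: the model writes <tool>EXPR</tool>, and
    # generation stops on "</tool>"; we evaluate EXPR and return the result.
    match = re.search(r"<tool>(.*?)</tool>", completion, re.DOTALL)
    if match is None:
        return "Tool error: no <tool>...</tool> block found\n"
    # eval() is unsafe in general; acceptable only for this toy sketch.
    # Exceptions are caught by the trainer and surfaced as "Tool failed: ...".
    return f"<tool_response>{eval(match.group(1))}</tool_response>\n"

calculator = ToolDefinition(stop_string="</tool>", call_tool=call_calculator)
```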
class SingleConversationWithTools:
"""Keeps track of the prompt, and knows how to put together the partial responses and the tool call responses."""

def __init__(self, prompt_inputs: dict[str, torch.Tensor], tool_defn: ToolDefinition, processing_class: PreTrainedTokenizerBase):
self.prompt_inputs = prompt_inputs
self.tool_defn = tool_defn
self.response = []
self.processing_class = processing_class

def process_response(self, prompt_completion_ids: torch.Tensor) -> bool:
"""Adds the response to the conversation, including calling the tool if necessary.
Returns True if there was a tool call, and the conversation should continue.
Returns False if there was no tool call, and the conversation is complete.
"""
self.response.append(prompt_completion_ids)
prompt_completion_str = self.processing_class.tokenizer.decode(prompt_completion_ids[0], skip_special_tokens=True)
# Check if the stop string is in the completion.
# The ids were decoded to a string above so we can search for it.
if self.tool_defn.completion_has_tool_call(prompt_completion_str):
try:
tool_response_str = self.tool_defn.call_tool(prompt_completion_str)
except Exception as e:
tool_response_str = f"Tool failed: {e}\n"
tool_response_ids_list = self.processing_class.tokenizer.encode(tool_response_str, add_special_tokens=False)
[Review comment] I'm assuming this doesn't add an extra BOS token?
tool_response_ids = torch.tensor(tool_response_ids_list, device=prompt_completion_ids.device) # [L]
tool_response_ids = tool_response_ids[None, :] # [1, L]
self.response.append(tool_response_ids)
self.prompt_inputs = self._add_response_to_prompt_inputs(self.prompt_inputs, prompt_completion_ids)
self.prompt_inputs = self._add_response_to_prompt_inputs(self.prompt_inputs, tool_response_ids)
# Note: we're gonna have to figure out images.
return True
else:
# No tool call, so we're done.
return False

def _add_response_to_prompt_inputs(self, prompt_inputs: dict[str, torch.Tensor], response: torch.Tensor) -> dict[str, torch.Tensor]:
"""Add the response to the prompt inputs."""
if VERBOSE > 0:
addition_str = self.processing_class.decode(response[0])
print(f"Adding response: {addition_str}")
prompt_inputs["input_ids"] = torch.cat([prompt_inputs["input_ids"], response], dim=1)
ones = torch.ones_like(response, device=response.device)
prompt_inputs["attention_mask"] = torch.cat([prompt_inputs["attention_mask"], ones], dim=1)
return prompt_inputs


def get_just_completion_ids(self) -> torch.Tensor:
"""Returns the response (not including the prompt) as a tensor."""
# Concatenate all the response tensors along the sequence dimension.
return torch.cat(self.response, dim=1)

def get_prompt_completion_ids(self) -> torch.Tensor:
"""Returns the prompt and completion as a tensor. The full completion includes the prompt and the response."""
out = torch.cat([self.prompt_inputs["input_ids"], self.get_just_completion_ids()], dim=1)
return out

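A minimal sketch of the conversation bookkeeping in isolation, using stub tokenizer/processor classes so it runs without a real model; all stub names here are illustrative, not part of this PR:

```python
import torch

class _StubTokenizer:
    def decode(self, ids, skip_special_tokens=True):
        return "".join(chr(i) for i in ids.tolist())

    def encode(self, text, add_special_tokens=False):
        return [ord(c) for c in text]

class _StubProcessor:
    tokenizer = _StubTokenizer()

    def decode(self, ids):
        return self.tokenizer.decode(ids)

tool = ToolDefinition(stop_string="!", call_tool=lambda s: "ok")
prompt = {
    "input_ids": torch.tensor([[65, 66]]),  # "AB"
    "attention_mask": torch.ones(1, 2, dtype=torch.long),
}
conv = SingleConversationWithTools(prompt, tool, _StubProcessor())
assert conv.process_response(torch.tensor([[67, 33]]))  # "C!" -> tool call, continue
assert not conv.process_response(torch.tensor([[68]]))  # "D"  -> no tool call, done
print(conv.get_just_completion_ids().shape)  # [1, 5]: "C!" + "ok" + "D"
```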
class QwenGRPOTrainer(Trainer):
"""
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
@@ -147,6 +218,10 @@ class QwenGRPOTrainer(Trainer):
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
tool_defn ([`~trl.ToolDefinition`], *optional*, defaults to `None`):
Tool definition describing the stop string that marks a tool call and the callable used to execute it.
loss_magnifier (float, *optional*, defaults to 1.0):
Multiplies the loss on the way out to avoid underflow.
"""

_tag_names = ["trl", "grpo"]
@@ -164,6 +239,8 @@ def __init__(
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
peft_config: Optional["PeftConfig"] = None,
tool_defn: Optional[ToolDefinition] = None,
loss_magnifier: float = 1.0,
):
# Args
if args is None:
@@ -220,6 +297,8 @@ def __init__(
)
self.reward_funcs = reward_funcs

self.tool_defn = tool_defn

# Reward processing class
if reward_processing_classes is None:
reward_processing_classes = [None] * len(reward_funcs)
@@ -250,6 +329,7 @@ def data_collator(features): # No data collation is needed in GRPO
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
self.num_generations = args.num_generations # = G in the GRPO paper
self.use_vllm = args.use_vllm
self.loss_magnifier = args.loss_magnifier

self.beta = args.beta

@@ -330,12 +410,19 @@ def data_collator(features): # No data collation is needed in GRPO
# synchronize all processes after vLLM has been fully initialized.
self.accelerator.wait_for_everyone()
else:
# No vLLM, so we use the regular generation config

stop_strings = None
if self.tool_defn:
stop_strings = [self.tool_defn.stop_string]

self.generation_config = GenerationConfig(
max_new_tokens=self.max_completion_length,
do_sample=True,
temperature=args.temperature,
num_return_sequences=self.num_generations,
pad_token_id=processing_class.tokenizer.pad_token_id,
stop_strings=stop_strings,
)
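# Note: transformers requires a tokenizer to be passed to generate() when
# stop_strings is set (see _generate_completion below).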

# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
@@ -369,6 +456,59 @@ def _set_signature_columns_if_needed(self):
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
return inputs

def _generate_completion(
self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor]
[Review comment] Also, thanks for making this a function, clearly the right move.
) -> torch.Tensor:
"""Generate completion(s) using the model."""
temp_generation_config = copy.deepcopy(self.generation_config)
temp_generation_config.num_return_sequences = 1
prompt_completion_ids = model.generate(
**prompt_inputs,
generation_config=temp_generation_config,
tokenizer=self.processing_class.tokenizer,
)
return prompt_completion_ids

def _generate_single_completion_with_tools(
self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor], max_steps: int = 10
[Review comment] same nit here - BatchFeature
) -> torch.Tensor:
"""Iterates between generation and tool calling.

Note: this is currently only called from the non-vLLM path.

prompt_inputs is a dict with the following keys:
- input_ids: [1, 710] ints. Text tokens at the beginning and end, with the middle full of image placeholder tokens (151655).
- attention_mask: [1, 710] ints. All 1
- pixel_values: 2024x1176 floats. The image.
- image_grid_thw: a 1x3 tensor with values: [1, 46, 44].
(Note that 46*44 is 2024).
"""
[Review comment] nit: maybe add a short comment about what max_steps is. My understanding: the generation will stop once a tool is called, then this code processes the tool call. max_steps is the maximum number of tool calls we're willing to process for a single completion?
conv = SingleConversationWithTools(prompt_inputs, self.tool_defn, self.processing_class)
# Loop until tool isn't called, or we max out
for step in range(max_steps):
if VERBOSE > 1:
print(f"\n\n\nGenerating completion with tool call. Step {step}. Shapes of inputs:")
for key, val in prompt_inputs.items():
print(f"{key}: {val.shape}")
print(f"Text of the prompt: {self.processing_class.decode(prompt_inputs['input_ids'][0])}")
prompt_completion_ids = self._generate_completion(model, prompt_inputs)
# prompt_completion_ids is a tensor of shape (1, L), because we only generated one completion.
# Note that L includes both the prompt and the response.
# We only want to process the response, so we'll strip the prompt.
input_length = len(prompt_inputs["input_ids"][0])
ids_to_process = prompt_completion_ids[:, input_length:]
tool_was_used = conv.process_response(ids_to_process)
if not tool_was_used:
break

full_completion = conv.get_prompt_completion_ids()
if VERBOSE > 0:
print(f"\nDONE!")
print(f"Final completion (with prompt):\n{self.processing_class.decode(full_completion[0,:])}")
print(f"^^^ Final Response!\n\n\n\n\n\n")
return full_completion


def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
@@ -419,17 +559,17 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
prompt_inputs_repeated = torch.repeat_interleave(prompt_inputs["input_ids"], self.num_generations, dim=0)
prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
else:
# Regular generation path
# Regular generation path (not using vLLM)
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
# Generate num_generations completions, one at a time (each call returns a single sequence)
num_generations = self.generation_config.num_return_sequences
temp_generation_config = copy.deepcopy(self.generation_config)
temp_generation_config.num_return_sequences = 1

all_completions = []

for i in range(num_generations):
completion = unwrapped_model.generate(**prompt_inputs, generation_config=temp_generation_config)
if self.tool_defn:
completion = self._generate_single_completion_with_tools(unwrapped_model, prompt_inputs)
else:
completion = self._generate_completion(unwrapped_model, prompt_inputs)
all_completions.append(completion)

# Stack all completions and pad if needed
@@ -449,6 +589,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):

# Stack all padded completions
prompt_completion_ids = torch.cat(padded_completions, dim=0)
if VERBOSE > 0:
print(f"Done generating {num_generations} completions.")

prompt_length = prompt_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]
@@ -542,6 +684,12 @@ def get_per_token_logps(model, input_ids):
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# We could break this down like this:
# loss = (per_token_loss * completion_mask).sum(dim=1)  # [B] tensor of per-example losses
# loss = loss / completion_mask.sum(dim=1)  # normalize by number of unmasked tokens; still [B]
# loss = loss.mean()  # average across the batch
# Rescale to avoid underflow: we see losses underflow to 0 when they're around 1e-7, which is common.
loss = loss * self.loss_magnifier

# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
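Why a multiplier helps, in concrete terms: fp16 normal values bottom out near 6e-5 and subnormals near 6e-8, so a loss around 1e-7 keeps almost no precision (and can flush to zero under flush-to-zero kernels). A minimal sketch, independent of the trainer:

```python
import torch

loss = torch.tensor(1e-7, dtype=torch.float16)
print(loss)         # stored as a subnormal, with heavy rounding error
print(loss * 1000)  # ~1e-4, back in the well-behaved normal range
```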
@@ -644,3 +792,8 @@ def create_model_card(
)

model_card.save(os.path.join(self.args.output_dir, "README.md"))



def simple_stats(x: torch.Tensor) -> str:
return f"Min: {x.min()}, Mean: {x.mean()}, Max: {x.max()}"