Rough draft of tool support #3
base: improve_performance
@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import textwrap
import warnings

from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
import copy
import math
import os
import textwrap
import warnings

import torch
import torch.utils.data
@@ -50,6 +53,7 @@
from .grpo_config import GRPOConfig
from .utils import generate_model_card, get_comet_experiment_url, pad

VERBOSE = int(os.environ.get("VERBOSE", "0"))

if is_peft_available():
    from peft import PeftConfig, get_peft_model
@@ -65,6 +69,73 @@
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


@dataclass
class ToolDefinition:
    """Basic metadata that the trainer needs to know about the tools."""

    stop_string: str
    # The tool is invoked with the decoded completion text and returns its
    # output as text (see SingleConversationWithTools.process_response).
    call_tool: Callable[[str], str]

    def completion_has_tool_call(self, completion_str: str) -> bool:
        """Check if the completion has a tool call."""
        return self.stop_string in completion_str
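
A usage sketch (not part of the diff; the tag syntax and the calculator are made-up examples) of how a ToolDefinition might be wired up:

    # Hypothetical tool: the model is expected to end a request with the stop
    # string, e.g. "... <calc>3*7</calc>", and the tool's output text is fed
    # back into the conversation.
    def run_calculator(completion_str: str) -> str:
        expr = completion_str.rsplit("<calc>", 1)[-1].split("</calc>", 1)[0]
        return f"<result>{eval(expr)}</result>\n"  # toy evaluator, fine for a sketch only

    calc_tool = ToolDefinition(stop_string="</calc>", call_tool=run_calculator)
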
class SingleConversationWithTools:
    """Keeps track of the prompt, and knows how to stitch together the partial responses and the tool call responses."""

    def __init__(self, prompt_inputs: dict[str, torch.Tensor], tool_defn: ToolDefinition, processing_class: PreTrainedTokenizerBase):
        self.prompt_inputs = prompt_inputs
        # self.prompt_inputs grows as responses are appended, so keep a handle
        # on the original prompt ids for get_prompt_completion_ids.
        self.initial_prompt_ids = prompt_inputs["input_ids"]
        self.tool_defn = tool_defn
        self.response = []
        self.processing_class = processing_class

    def process_response(self, prompt_completion_ids: torch.Tensor) -> bool:
        """Adds the response to the conversation, calling the tool if necessary.

        Returns True if there was a tool call and the conversation should continue.
        Returns False if there was no tool call and the conversation is complete.
        """
        self.response.append(prompt_completion_ids)
        # The tool-call check works on text, so convert the tensor to a string first.
        prompt_completion_str = self.processing_class.tokenizer.decode(prompt_completion_ids[0], skip_special_tokens=True)
        if self.tool_defn.completion_has_tool_call(prompt_completion_str):
            try:
                tool_response_str = self.tool_defn.call_tool(prompt_completion_str)
            except Exception as e:
                tool_response_str = f"Tool failed: {e}\n"
            tool_response_ids_list = self.processing_class.tokenizer.encode(tool_response_str, add_special_tokens=False)
            tool_response_ids = torch.tensor(tool_response_ids_list, device=prompt_completion_ids.device)  # [L]
            tool_response_ids = tool_response_ids[None, :]  # [1, L]
            self.response.append(tool_response_ids)
            self.prompt_inputs = self._add_response_to_prompt_inputs(self.prompt_inputs, prompt_completion_ids)
            self.prompt_inputs = self._add_response_to_prompt_inputs(self.prompt_inputs, tool_response_ids)
            # Note: we're going to have to figure out images.
            return True
        else:
            # No tool call, so we're done.
            return False

    def _add_response_to_prompt_inputs(self, prompt_inputs: dict[str, torch.Tensor], response: torch.Tensor) -> dict[str, torch.Tensor]:
        """Append the response ids (and a matching attention mask) to the prompt inputs."""
        if VERBOSE > 0:
            addition_str = self.processing_class.decode(response[0])
            print(f"Adding response: {addition_str}")
        prompt_inputs["input_ids"] = torch.cat([prompt_inputs["input_ids"], response], dim=1)
        ones = torch.ones_like(response, device=response.device)
        prompt_inputs["attention_mask"] = torch.cat([prompt_inputs["attention_mask"], ones], dim=1)
        return prompt_inputs

    def get_just_completion_ids(self) -> torch.Tensor:
        """Returns the response (not including the prompt) as a tensor."""
        # Concatenate all the response tensors along their sequence dimension.
        return torch.cat(self.response, dim=1)

    def get_prompt_completion_ids(self) -> torch.Tensor:
        """Returns the original prompt plus the full completion (partial responses and tool responses) as one tensor."""
        # Use the initial prompt ids here: self.prompt_inputs has already had the
        # intermediate responses appended to it, so using it would duplicate them.
        return torch.cat([self.initial_prompt_ids, self.get_just_completion_ids()], dim=1)

class QwenGRPOTrainer(Trainer):
    """
    Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the

@@ -147,6 +218,10 @@ class QwenGRPOTrainer(Trainer):
            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
        tool_defn ([`~trl.ToolDefinition`], *optional*, defaults to `None`):
            Definition of the tool the model may call during generation: the stop string that marks a tool call and the callable that executes it.
        loss_magnifier (float, *optional*, defaults to 1.0):
            Multiplies the loss on the way out to avoid underflow.
    """

    _tag_names = ["trl", "grpo"]
@@ -164,6 +239,8 @@ def __init__(
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        tool_defn: Optional[ToolDefinition] = None,
        loss_magnifier: float = 1.0,
    ):
        # Args
        if args is None:
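
To illustrate the two new arguments in context (the model, reward function, config, and dataset names below are hypothetical placeholders for the usual trainer setup):

    trainer = QwenGRPOTrainer(
        model=model,
        reward_funcs=my_reward_func,
        args=training_args,
        train_dataset=train_dataset,
        tool_defn=calc_tool,     # ToolDefinition from the sketch above
        loss_magnifier=1e3,      # lift tiny GRPO losses away from underflow
    )
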
@@ -220,6 +297,8 @@ def __init__(
        )
        self.reward_funcs = reward_funcs

        self.tool_defn = tool_defn

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
@@ -250,6 +329,7 @@ def data_collator(features):  # No data collation is needed in GRPO
        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.use_vllm = args.use_vllm
        self.loss_magnifier = loss_magnifier  # the constructor argument (documented above), not a field on args

        self.beta = args.beta

@@ -330,12 +410,19 @@ def data_collator(features):  # No data collation is needed in GRPO
            # synchronize all processes after vLLM has been fully initialized.
            self.accelerator.wait_for_everyone()
        else:
            # No vLLM, so we use the regular generation config.
            stop_strings = None
            if self.tool_defn:
                stop_strings = [self.tool_defn.stop_string]

            self.generation_config = GenerationConfig(
                max_new_tokens=self.max_completion_length,
                do_sample=True,
                temperature=args.temperature,
                num_return_sequences=self.num_generations,
                pad_token_id=processing_class.tokenizer.pad_token_id,
                stop_strings=stop_strings,
            )

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
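
A note on stop_strings, with a minimal sketch (assuming a recent transformers release; model, inputs, and tokenizer are placeholders): generation halts as soon as the decoded text contains one of the strings, and generate must be given the tokenizer so it can decode incrementally, which is why the trainer passes tokenizer= in _generate_completion below.

    from transformers import GenerationConfig
    config = GenerationConfig(max_new_tokens=64, stop_strings=["</calc>"])
    out = model.generate(**inputs, generation_config=config, tokenizer=tokenizer)
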
@@ -369,6 +456,59 @@ def _set_signature_columns_if_needed(self):
    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        return inputs

    def _generate_completion(
        self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Generate completion(s) using the model."""
        temp_generation_config = copy.deepcopy(self.generation_config)
        temp_generation_config.num_return_sequences = 1
        prompt_completion_ids = model.generate(
            **prompt_inputs,
            generation_config=temp_generation_config,
            tokenizer=self.processing_class.tokenizer,
        )
        return prompt_completion_ids

Review comments:
- I think
- Also, thanks for making this a function, clearly the right move.

    def _generate_single_completion_with_tools(
        self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor], max_steps: int = 10
    ) -> torch.Tensor:
        """Alternates between generation and tool calling.

        Generation stops whenever the tool's stop string is produced; the tool is
        then run and its output is appended to the context before generating again.
        max_steps caps the number of generate/tool-call rounds we're willing to
        process for a single completion.

        Note this is currently only called from the non-vLLM path.

        prompt_inputs is a dict with the following keys:
        - input_ids: [1, 710] ints. Some tokens at the beginning and the end, the middle full of 151655 (the image placeholder token).
        - attention_mask: [1, 710] ints. All 1.
        - pixel_values: 2024x1176 floats. The image.
        - image_grid_thw: a 1x3 tensor with values: [1, 46, 44].
          (Note that 46*44 is 2024.)
        """
        # Shallow-copy so that appending tool responses doesn't leak into the
        # caller's dict across generations; the conversation mutates this dict in place.
        prompt_inputs = dict(prompt_inputs)
        conv = SingleConversationWithTools(prompt_inputs, self.tool_defn, self.processing_class)
        # Loop until the tool isn't called, or we max out.
        for step in range(max_steps):
            if VERBOSE > 1:
                print(f"\n\n\nGenerating completion with tool call. Step {step}. Shapes of inputs:")
                for key, val in prompt_inputs.items():
                    print(f"{key}: {val.shape}")
                print(f"Text of the prompt: {self.processing_class.decode(prompt_inputs['input_ids'][0])}")
            prompt_completion_ids = self._generate_completion(model, prompt_inputs)
            # prompt_completion_ids is a tensor of shape (1, L) because we only generated one completion.
            # Note that L includes both the prompt and the response.
            # We only want to process the response, so we strip the prompt.
            input_length = len(prompt_inputs["input_ids"][0])
            ids_to_process = prompt_completion_ids[:, input_length:]
            tool_was_used = conv.process_response(ids_to_process)
            if not tool_was_used:
                break

        full_completion = conv.get_prompt_completion_ids()
        if VERBOSE > 0:
            print("\nDONE!")
            print(f"Final completion (with prompt):\n{self.processing_class.decode(full_completion[0, :])}")
            print("^^^ Final Response!\n\n\n\n\n\n")
        return full_completion

Review comments:
- (on the signature) same nit here -
- (on the docstring) nit: maybe add a short comment about what `max_steps` does. My understanding: The generation will stop once a tool is called, then this code processes the tool call. max_steps is the maximum number of tools we're willing to process for a single completion?
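
A sketch of the resulting control flow, using the hypothetical tool syntax from the ToolDefinition example above:

    # step 0: model emits "...<calc>3*7</calc>"   <- stop string hit; the tool runs
    #         and "<result>21</result>" is appended to the context
    # step 1: model emits "So the answer is 21."  <- no stop string; the loop exits
    # Returned tensor: prompt + step-0 text + tool output + step-1 text.
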
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

@@ -419,17 +559,17 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            prompt_inputs_repeated = torch.repeat_interleave(prompt_inputs["input_ids"], self.num_generations, dim=0)
            prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
        else:
            # Regular generation path
            # Regular generation path (not using vLLM)
            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                # Generate N times, each generate one with the temp_generation_config
                num_generations = self.generation_config.num_return_sequences
                temp_generation_config = copy.deepcopy(self.generation_config)
                temp_generation_config.num_return_sequences = 1

                all_completions = []

                for i in range(num_generations):
                    completion = unwrapped_model.generate(**prompt_inputs, generation_config=temp_generation_config)
                    if self.tool_defn:
                        completion = self._generate_single_completion_with_tools(unwrapped_model, prompt_inputs)
                    else:
                        completion = self._generate_completion(unwrapped_model, prompt_inputs)
                    all_completions.append(completion)

                # Stack all completions and pad if needed
@@ -449,6 +589,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):

            # Stack all padded completions
            prompt_completion_ids = torch.cat(padded_completions, dim=0)
            if VERBOSE > 0:
                print(f"Done generating {num_generations} completions.")

        prompt_length = prompt_inputs["input_ids"].size(1)
        completion_ids = prompt_completion_ids[:, prompt_length:]
@@ -542,6 +684,12 @@ def get_per_token_logps(model, input_ids):
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        # Broken down, the line above is:
        #   loss = (per_token_loss * completion_mask).sum(dim=1)  # [B] tensor of losses, one per example
        #   loss = loss / completion_mask.sum(dim=1)              # normalize by the number of unmasked tokens; still [B]
        #   loss = loss.mean()                                    # average across the batch
        # Rescale to avoid underflow - we see losses underflow to 0 when they're around 1e-7, which is common.
        loss = loss * self.loss_magnifier

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()

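On the underflow point, a toy illustration (assuming the loss is held in fp16):

    import torch
    x = torch.tensor(1.0e-7, dtype=torch.float16)
    print(x)      # tensor(1.1921e-07, dtype=torch.float16) -- already subnormal, little precision left
    print(x / 4)  # tensor(0., dtype=torch.float16) -- underflowed to zero
    # Multiplying the loss by a constant scales every gradient by the same factor
    # and leaves the optimum unchanged, so a magnifier such as 1e3 lifts these
    # values back into a well-represented range.
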
@@ -644,3 +792,8 @@ def create_model_card(
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))


def simple_stats(x) -> str:
    return f"Min: {x.min()}, Mean: {x.mean()}, Max: {x.max()}"

Review comment: I'm assuming this doesn't add an extra BOS token?