Rough draft of tool support #3
base: improve_performance
Conversation
nice! left some questions.
trl/trainer/qwen_grpo_trainer.py (Outdated)

# Check if the stop string is in the completions
# We need to convert the tensor to a string.
if self.tool_defn.completion_has_tool_call(prompt_completion_str):
    tool_response_str = self.tool_defn.call_tool(prompt_completion_str)
why doesn't the dataclass have call_tool?
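For what it's worth, one way to read this suggestion is to hang call_tool on the tool-definition dataclass right next to completion_has_tool_call. A minimal sketch, assuming a ToolDefinition dataclass with a stop_string field and a handler callable (both names are mine, not from the diff):

from dataclasses import dataclass
from typing import Callable

@dataclass
class ToolDefinition:
    # Hypothetical fields; the real dataclass in the diff may differ.
    stop_string: str               # marker the model emits to request the tool
    handler: Callable[[str], str]  # the function that actually runs the tool

    def completion_has_tool_call(self, completion: str) -> bool:
        # The completion requested a tool if it ends with the stop string.
        return completion.rstrip().endswith(self.stop_string)

    def call_tool(self, completion: str) -> str:
        # Parse the request out of the completion and run the handler on it.
        return self.handler(completion)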
# We need to convert the tensor to a string.
if self.tool_defn.completion_has_tool_call(prompt_completion_str):
    tool_response_str = self.tool_defn.call_tool(prompt_completion_str)
    tool_response_ids_list = self.processing_class.tokenizer.encode(tool_response_str, add_special_tokens=False)
I'm assuming this doesn't add an extra BOS token?
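Whether encode() would otherwise prepend a BOS depends on the tokenizer's config, but add_special_tokens=False skips special tokens either way. A quick check one could run (the checkpoint name here is just a placeholder):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # placeholder checkpoint

with_specials = tok.encode("tool response: 42")
without_specials = tok.encode("tool response: 42", add_special_tokens=False)

# If this tokenizer adds a BOS (or any other special tokens), the two lists
# differ; the second form is what we want when splicing tokens mid-sequence.
print(with_specials == without_specials)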
    return inputs


def _generate_completion(
    self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor]
I think prompt_inputs might be a BatchFeature: https://huggingface.co/docs/transformers/en/main_classes/feature_extractor#transformers.BatchFeature
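Since BatchFeature subclasses UserDict, code written against dict[str, torch.Tensor] should mostly keep working; the type hint is the main thing that's off. A tiny illustration:

import torch
from transformers import BatchFeature

feats = BatchFeature(data={
    "input_ids": torch.ones(1, 4, dtype=torch.long),
    "attention_mask": torch.ones(1, 4, dtype=torch.long),
})

# Dict-style access works, plus BatchFeature extras like .to(device/dtype).
feats = feats.to("cpu")
print(feats["input_ids"].shape, list(feats.keys()))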
Also, thanks for making this a function, clearly the right move.
    return prompt_completion_ids


def _generate_single_completion_with_tools(
    self, model: PreTrainedModel, prompt_inputs: dict[str, torch.Tensor], max_steps: int = 10
same nit here - BatchFeature
(Note that 46*44 is 2024).
"""
conv = SingleConversationWithTools(prompt_inputs, self.tool_defn, self.processing_class)
# Loop until tool isn't called, of we max out
Suggested change:
- # Loop until tool isn't called, of we max out
+ # Loop until tool isn't called, or we max out
- input_ids: [1, 710] ints. Some stuff at the beginning and the end, the middle full of 151655
- attention_mask: [1, 710] ints. All 1
- pixel_values: 2024x1176 floats. The image.
- image_grid_thw: a 1x3 tensor with values: [1, 46, 44].
nit: maybe add a short comment about what max_steps is.

My understanding: generation will stop once a tool is called, then this code processes the tool call. max_steps is the maximum number of tool calls we're willing to process for a single completion?
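For what it's worth, here's the control flow as I read it, reduced to plain callables (generate, has_tool_call, and call_tool are stand-ins for the trainer's actual methods, not the real API):

from typing import Callable

def generate_with_tools(
    generate: Callable[[str], str],        # one generation step: text so far -> new text
    has_tool_call: Callable[[str], bool],  # e.g. tool_defn.completion_has_tool_call
    call_tool: Callable[[str], str],       # e.g. tool_defn.call_tool
    prompt: str,
    max_steps: int = 10,                   # max tool calls processed for one completion
) -> str:
    text = prompt
    for _ in range(max_steps):
        text += generate(text)
        if not has_tool_call(text):
            break                  # model finished without requesting a tool
        text += call_tool(text)    # splice the tool's response in and keep generating
    return text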