diff --git a/examples/scripts/openenv/catch.py b/examples/scripts/openenv/catch.py
new file mode 100644
index 00000000000..9405f6d0b2a
--- /dev/null
+++ b/examples/scripts/openenv/catch.py
@@ -0,0 +1,255 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ruff: noqa: T201
+import os
+import subprocess
+import sys
+import time
+from pathlib import Path
+
+import requests
+from datasets import Dataset
+from envs.openspiel_env import OpenSpielEnv
+from envs.openspiel_env.models import OpenSpielAction
+
+from trl import GRPOConfig, GRPOTrainer, apply_chat_template
+
+
+"""
+Simple script to run GRPO training with OpenEnv's Catch environment (OpenSpiel) and a vLLM server. The reward function
+is based on the catch game, where the agent tries to catch falling balls.
+
+Setup:
+
+```sh
+uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
+uv pip install open_spiel
+```
+
+Usage (2 GPUs required):
+
+# Spin up the vLLM server
+
+```sh
+CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
+```
+
+# Run training
+
+```sh
+CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/catch.py
+```
+"""
+
+GEN_URL = "http://0.0.0.0:8000/generate/"
+ENV_URL = "http://0.0.0.0:8001"
+
+BASE_PROMPT = """You are an AI agent playing the game **Catch**.
+
+### Game Description
+- The game is played on a **5×5 grid**.
+- There is one **falling ball** and one **paddle** that you control at the bottom.
+- The objective is to **move the paddle left or right to catch the ball** as it falls.
+- The episode ends when the ball reaches the bottom row:
+  - You get **+1 reward** if you catch it.
+  - You get **–1 reward** if you miss it.
+
+### Observation Format
+You will receive:
+- `observation`: a list of **50 numbers (floats)**.
+  - The first **25 numbers** (indices `0–24`) represent the **ball layer**, flattened from a 5×5 grid. Each cell is
+    `1.0` if the ball is there, `0.0` otherwise.
+  - The next **25 numbers** (indices `25–49`) represent the **paddle layer**, also flattened from a 5×5 grid. Each cell
+    is `1.0` if the paddle occupies that column in the bottom row, `0.0` otherwise.
+- `legal_actions`: a list of integers representing which actions are currently allowed.
+
+### Actions
+Each action is a discrete integer:
+- `0` → Move paddle **left**
+- `1` → **Stay** (no movement)
+- `2` → Move paddle **right**
+
+### Output Format
+Respond **only with one integer** representing your chosen action: `0`, `1`, or `2`.
+
+### Current Observation
+"""
+
+# Start the OpenSpiel server in background
+print("⚡ Starting FastAPI server for OpenSpiel Catch Environment...")
+
+# Determine the correct path
+work_dir = str(Path.cwd().parent.absolute())
+
+server_process = subprocess.Popen(
+    [sys.executable, "-m", "uvicorn", "envs.openspiel_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"],
+    env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
+    stdout=subprocess.PIPE,
+    stderr=subprocess.PIPE,
+    text=True,
+    cwd=work_dir,
+)
+
+print("⏳ Waiting for server to start...")
+time.sleep(5)
+
+# Check if server is running
+try:
+    response = requests.get(f"{ENV_URL}/health", timeout=2)
+    print("\n✅ OpenSpiel Catch Environment server is running!")
+except Exception as e:
+    print(f"\n❌ Server failed to start: {e}")
+    print("\n📋 Checking error output...")
+    server_process.poll()
+    if server_process.stderr:
+        stderr = server_process.stderr.read()
+        if stderr:
+            print(stderr)
+    raise
+
+
+# Create HTTP client for OpenSpiel Catch Environment
+client = OpenSpielEnv(base_url=f"{ENV_URL}")
+
+
+def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, processing_class) -> dict[str, list]:
+    """
+    Custom rollout function that generates completions via vLLM server and computes environment rewards.
+
+    The catch game expects action IDs (integers). We'll parse the model's text output to extract action choices.
+
+    Args:
+        prompts: List of prompt strings to generate from
+        images: Optional images for vision models (not used in this example)
+        args: GRPOConfig containing all sampling parameters
+        processing_class: Tokenizer/processor for decoding completions
+
+    Returns:
+        Dict containing prompt_ids, completion_ids, logprobs, and env_reward
+    """
+    import re
+
+    # Run full episodes for each generation to get episode rewards
+    env_rewards = []
+    all_prompt_ids = []
+    all_completion_ids = []
+    all_logprobs = []
+
+    for base_prompt in prompts:
+        for _ in range(args.num_generations):
+            # Run episode: reset environment and loop until done
+            env_result = client.reset()
+            obs = env_result.observation
+            total_reward = 0.0
+
+            episode_prompt_ids = []
+            episode_completion_ids = []
+            episode_logprobs = []
+
+            # TODO: parallelise!
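+            # Each pass of the loop below plays one environment step: format the current observation
+            # into the prompt, query the vLLM server for a single action, apply it to the environment,
+            # and accumulate the per-step prompt/completion token ids and logprobs so the whole episode
+            # is scored as one completion whose reward is the total episode reward.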
+            while not obs.done:
+                # FIXME: handle the addition of observation to prompt more cleanly, ideally without a train_dataset
+                episode_msg = {"prompt": [{"role": "user", "content": f"{base_prompt}\n\n{obs.info_state}\n"}]}
+                episode_prompt = apply_chat_template(episode_msg, processing_class)
+
+                # Generate action from model
+                gen_payload = {
+                    "prompts": [episode_prompt["prompt"]],
+                    "n": 1,
+                    "temperature": args.temperature,
+                    "top_p": args.top_p,
+                    "top_k": -1 if args.top_k is None else args.top_k,
+                    "min_p": 0.0 if args.min_p is None else args.min_p,
+                    "max_tokens": args.max_completion_length,
+                    "repetition_penalty": args.repetition_penalty,
+                }
+                gen_response = requests.post(GEN_URL, json=gen_payload)
+                gen_response.raise_for_status()
+                gen_result = gen_response.json()
+
+                # Collect prompt_ids, completion_ids, and logprobs from this step
+                episode_prompt_ids.extend(gen_result["prompt_ids"][0])
+                episode_completion_ids.extend(gen_result["completion_ids"][0])
+                episode_logprobs.extend(gen_result["logprobs"][0])
+
+                completion_text = processing_class.batch_decode(
+                    gen_result["completion_ids"], skip_special_tokens=True
+                )[0]
+
+                # Parse action from completion
+                action_id = 0  # default
+                numbers = re.findall(r"\b([0-2])\b", completion_text)
+                if numbers:
+                    action_id = int(numbers[0])
+                elif obs.legal_actions:
+                    action_id = obs.legal_actions[0]
+
+                # Take action in environment
+                env_result = client.step(OpenSpielAction(action_id=action_id, game_name="catch"))
+                reward = env_result.reward if env_result.reward is not None else 0.0
+                total_reward += reward
+                obs = env_result.observation
+
+            # Store episode results
+            env_rewards.append(total_reward)
+            all_prompt_ids.append(episode_prompt_ids)
+            all_completion_ids.append(episode_completion_ids)
+            all_logprobs.append(episode_logprobs)
+
+    return {
+        "prompt_ids": all_prompt_ids,
+        "completion_ids": all_completion_ids,
+        "logprobs": all_logprobs,
+        "env_reward": env_rewards,
+    }
+
+
+dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * 1000})
+
+
+def reward_from_env(completions, **kwargs):
+    """Reward function that uses the environment reward from the catch game."""
+    # Extract environment rewards from kwargs (propagated via extra_fields)
+    env_rewards = kwargs.get("env_reward", [])
+    if env_rewards:
+        return [float(reward) for reward in env_rewards]
+    else:
+        # Fallback if env_reward is not available
+        return [0.0] * len(completions)
+
+
+training_args = GRPOConfig(
+    output_dir="scratch/Qwen2.5-0.5B-GRPO-Catch",
+    vllm_mode="server",
+    use_vllm=True,
+    logging_steps=1,
+    report_to=["trackio", "wandb"],
+    num_train_epochs=1,
+    num_generations=8,
+    max_completion_length=4,
+    per_device_train_batch_size=8,
+    gradient_accumulation_steps=4,
+)
+trainer = GRPOTrainer(
+    model="Qwen/Qwen2.5-0.5B-Instruct",
+    reward_funcs=reward_from_env,
+    args=training_args,
+    train_dataset=dataset,
+    rollout_func=rollout_func,
+)
+trainer.train()
+
+# Give time for background threads to finish
+time.sleep(5)
+
+print("🛑 Terminating environment server...")
+server_process.terminate()
diff --git a/examples/scripts/openenv/echo.py b/examples/scripts/openenv/echo.py
new file mode 100644
index 00000000000..e43ab960276
--- /dev/null
+++ b/examples/scripts/openenv/echo.py
@@ -0,0 +1,184 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ruff: noqa: T201
+import os
+import subprocess
+import sys
+import time
+from pathlib import Path
+
+import requests
+from datasets import Dataset, load_dataset
+from envs.echo_env import EchoEnv
+from envs.echo_env.models import EchoAction
+
+from trl import GRPOConfig, GRPOTrainer
+
+
+"""
+Simple script to run GRPO training with OpenEnv's Echo environment and a vLLM server. The reward function encourages
+longer completions.
+
+Setup:
+
+```sh
+uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
+```
+
+Usage (2 GPUs required):
+
+# Spin up server
+
+```sh
+CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
+```
+
+# Run training
+
+```sh
+CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py
+```
+"""
+
+GEN_URL = "http://0.0.0.0:8000/generate/"
+ENV_URL = "http://0.0.0.0:8001"
+
+print("⚡ Starting FastAPI server for Echo Environment...")
+
+
+# Workaround if you can't run the env with Docker
+work_dir = str(Path.cwd().parent.absolute())
+server_process = subprocess.Popen(
+    [sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"],
+    env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
+    stdout=subprocess.PIPE,
+    stderr=subprocess.PIPE,
+    text=True,
+    cwd=work_dir,
+)
+
+print("⏳ Waiting for server to start...")
+time.sleep(5)
+
+try:
+    response = requests.get(f"{ENV_URL}/health", timeout=2)
+    print("\n✅ Echo Environment server is running!")
+except Exception as e:
+    print(f"\n❌ Server failed to start: {e}")
+    print("\n📋 Checking error output...")
+    server_process.poll()
+    if server_process.stderr:
+        stderr = server_process.stderr.read()
+        if stderr:
+            print(stderr)
+    raise
+
+
+# Create HTTP client for Echo Environment
+client = EchoEnv(base_url=f"{ENV_URL}")
+client.action_class = EchoAction
+
+prompt = """You are connected to an *Echo Environment* that mirrors your input.
+To communicate, send a JSON-formatted message wrapped in `<action>` tags.
+
+Example:
+`<action>{"message": "Hi there!"}</action>`
+
+Try sending your first message to begin the interaction. The next user message will be your message echoed back to you."""
+
+
+dataset = Dataset.from_list([{"prompt": [{"role": "user", "content": prompt}]}] * 1000)
+
+# def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, processing_class) -> dict[str, list]:
+#     """
+#     Custom rollout function that generates completions via vLLM server and computes environment rewards.
+
+#     Args:
+#         prompts: List of prompt strings to generate from
+#         images: Optional images for vision models (not used in this example)
+#         args: GRPOConfig containing all sampling parameters
+#         processing_class: Tokenizer/processor for decoding completions
+
+#     Returns:
+#         Dict containing prompt_ids, completion_ids, logprobs, and env_reward
+#     """
+#     # Make request to TRL's custom /generate/ endpoint
+#     payload = {
+#         "prompts": prompts,
+#         "n": args.num_generations,
+#         "temperature": args.temperature,
+#         "top_p": args.top_p,
+#         "top_k": -1 if args.top_k is None else args.top_k,
+#         "min_p": 0.0 if args.min_p is None else args.min_p,
+#         "max_tokens": args.max_completion_length,
+#         "repetition_penalty": args.repetition_penalty,
+#     }
+#     response = requests.post(GEN_URL, json=payload)

+#     if response.status_code != 200:
+#         print(f"Error response: {response.text}")

+#     response.raise_for_status()
+#     result = response.json()

+#     completions_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True)

+#     # Flush env
+#     env_result = client.reset()

+#     env_rewards = []
+#     for msg in completions_text:
+#         env_result = client.step(EchoAction(message=msg))
+#         env_rewards.append(env_result.reward)

+#     result["env_reward"] = env_rewards

+#     return result


+# dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]")


+def reward_zero(completions, **kwargs):
+    return [0.0] * len(completions)


+training_args = GRPOConfig(
+    output_dir="scratch/Qwen2.5-0.5B-GRPO-Rollout",
+    # vllm_mode="server",
+    # use_vllm=True,
+    logging_steps=1,
+    # report_to=["trackio", "wandb"],
+    # num_train_epochs=1,
+    # num_generations=8,
+    # max_completion_length=2048,
+    # per_device_train_batch_size=8,
+    # gradient_accumulation_steps=4,
+)
+trainer = GRPOTrainer(
+    model="Qwen/Qwen3-0.6B",
+    # reward_funcs=reward_zero,
+    args=training_args,
+    train_dataset=dataset,
+    env=client,
+    # rollout_func=rollout_func,
+)
+trainer.train()

+# # Give time for background threads to finish
+# time.sleep(5)

+# print("🛑 Terminating Echo Environment server...")
+# server_process.terminate()
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 3150f947f3a..1a9d5254fc9 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -99,6 +99,32 @@
 # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
 RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
+# What we call a rollout function is a callable that takes prompts (list), images (optional), args (GRPOConfig),
+# and processing_class as parameters and returns a dict of generation results. Those results must include "prompt_ids",
+# "completion_ids", and "logprobs" fields. Any extra fields (per-completion) are forwarded to the reward functions.
+RolloutFunc = Callable[[list[str], Any, Any, Any], dict[str, Any]]
+
+import re
+import json
+from typing import Any, Optional
+
+def extract_action(text: str) -> Optional[dict[str, Any]]:
+    """
+    Extract and return the first valid JSON object found inside <action>...</action> tags.
+    Returns None if no valid action block is found.
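+
+    Example: `extract_action('<action>{"message": "hi"}</action>')` returns `{"message": "hi"}`.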
+ """ + pattern = r"\s*(\{.*?\})\s*" + match = re.search(pattern, text, re.DOTALL) + if not match: + return None + + try: + action = json.loads(match.group(1)) + return action + except json.JSONDecodeError: + return None class GRPOTrainer(BaseTrainer): """ @@ -200,6 +226,10 @@ def reward_func(completions, **kwargs): model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. peft_config ([`~peft.PeftConfig`], *optional*): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + rollout_func (`RolloutFunc`, *optional*, defaults to `None`): + Function to use for generating completions. It must take prompts, images (optional), args, and + processing_class as parameters and return a dict with "prompt_ids", "completion_ids", and "logprobs" + fields. Any other fields that are forwarded to the reward functions. """ _tag_names = ["trl", "grpo"] @@ -221,7 +251,7 @@ def reward_func(completions, **kwargs): def __init__( self, model: Union[str, PreTrainedModel], - reward_funcs: Union[RewardFunc, list[RewardFunc]], + reward_funcs: Union[RewardFunc, list[RewardFunc]]=None, args: Optional[GRPOConfig] = None, train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, @@ -230,7 +260,10 @@ def __init__( callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, + rollout_func: Optional[RolloutFunc] = None, + env=None, ): + self.env = env # Args if args is None: model_name = model if isinstance(model, str) else model.config._name_or_path @@ -296,6 +329,8 @@ def __init__( self.eos_token_id = tokenizer.eos_token_id # Reward functions + if reward_funcs is None: + reward_funcs = [None] # for the env if not isinstance(reward_funcs, list): reward_funcs = [reward_funcs] self.reward_func_names = [] @@ -306,6 +341,8 @@ def __init__( ) if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1]) + elif reward_funcs[i] is None: + self.reward_func_names.append("env_reward") else: self.reward_func_names.append(reward_funcs[i].__name__) self.reward_funcs = reward_funcs @@ -345,6 +382,9 @@ def __init__( self.reward_processing_classes = reward_processing_classes + # Rollout function + self.rollout_func = rollout_func + # Training arguments self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper @@ -1114,18 +1154,29 @@ def _generate_single_turn(self, prompts: list): "generation_kwargs": self.args.generation_kwargs, } with profiling_context(self, "vLLM.generate"): - if is_conversational({"prompt": ordered_set_of_prompts[0]}): - output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params) + if self.rollout_func is not None: + output = self.rollout_func( + prompts=ordered_set_of_prompts, + images=ordered_set_of_images, + args=self.args, + processing_class=self.processing_class, + ) else: - output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) - payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) + if is_conversational({"prompt": ordered_set_of_prompts[0]}): + output = self.vllm_client.chat(prompts=ordered_set_of_prompts, 
+                        else:
+                            output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
+                # Extract required fields and collect any extra fields for reward functions
+                required_keys = {"prompt_ids", "completion_ids", "logprobs"}
+                extra_fields = {k: v for k, v in output.items() if k not in required_keys}
+                payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields)
             else:
                 payload = None

         # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
         obj_list = [payload]
         broadcast_object_list(obj_list, from_process=0)
-        all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0]
+        all_prompt_ids, all_completion_ids, all_logprobs, all_extra_fields = obj_list[0]

         # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
         all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]
@@ -1138,6 +1189,14 @@ def _generate_single_turn(self, prompts: list):
             completion_ids = all_completion_ids[process_slice]
             logprobs = all_logprobs[process_slice]

+            # Slice extra fields dict-of-lists per process (extra fields are per-completion, like completion_ids)
+            extra_fields = {}
+            for key, values in all_extra_fields.items():
+                if isinstance(values, list):
+                    extra_fields[key] = values[process_slice]
+                else:
+                    extra_fields[key] = values
+
         # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
         elif self.vllm_mode == "colocate":
             if self.guided_decoding_regex:
@@ -1198,6 +1257,8 @@ def _generate_single_turn(self, prompts: list):
             completion_ids = all_completion_ids
             logprobs = all_logprobs

+            extra_fields = {}  # No extra fields for colocate mode
+
             if self.args.vllm_enable_sleep_mode:
                 self.llm.sleep(level=1)

@@ -1232,6 +1293,7 @@ def _generate_single_turn(self, prompts: list):
             completion_ids = [output.generated_tokens for output in all_outputs.values()]
             prompt_ids = generate_inputs["inputs"]
             logprobs = None  # not used in this case
+            extra_fields = {}  # No extra fields for paged mode

         else:
             # Regular generation path
@@ -1276,14 +1338,15 @@ def _generate_single_turn(self, prompts: list):
             prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
             completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
             logprobs = None  # not used in this case
+            extra_fields = {}  # No extra fields for non-rollout_func paths

-        return prompt_ids, completion_ids, logprobs
+        return prompt_ids, completion_ids, logprobs, extra_fields

     def _generate(self, prompts: list):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"

-        prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts)
+        prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts)

         # Get completion length per sequence, used for logging
         prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
@@ -1315,7 +1378,7 @@ def _generate(self, prompts: list):
         self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
         self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())

-        return prompt_ids, completion_ids, total_completion_tokens, logprobs
+        return prompt_ids, completion_ids, total_completion_tokens, logprobs, extra_fields

     def _generate_and_score_completions(
         self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@@ -1341,10 +1404,48 @@ def _generate_and_score_completions(
         if images is not None:
             prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)]

-        prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
+        prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = self._generate(
             prompts
         )

+        completion_contents = self.processing_class.batch_decode(completion_ids_list, skip_special_tokens=True)
+
+        actions = [extract_action(content) for content in completion_contents]
+        idxs_with_action = [i for i, a in enumerate(actions) if a]  # find indices that actually have a tool call
+        actions = [actions[i] for i in idxs_with_action]
+
+        while idxs_with_action:
+            prompts_for_generation = [prompts[i] for i in idxs_with_action]
+            for idx, action, prompt_for_generation in zip(idxs_with_action, actions, prompts_for_generation):
+                prompt_for_generation.append({"role": "assistant", "content": completion_contents[idx]})
+                a = self.env.action_class(**action)
+                output = self.env.step(a)
+                observation_message = {"role": "user", "content": str(output.observation)}
+                prompt_for_generation.append(observation_message)
+
+            prompt_completion_action_ids, post_tool_ids, _, _ = self._generate_single_turn(prompts_for_generation)
+
+            # Truncate post-tool completion so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length
+            for i in range(len(post_tool_ids)):
+                excess_length = (
+                    len(prompt_completion_action_ids[i])
+                    + len(post_tool_ids[i])
+                    - (self.max_prompt_length + self.max_completion_length)
+                )
+                if excess_length > 0:
+                    post_tool_ids[i] = post_tool_ids[i][:-excess_length]
+
+            for idx, pct, post_tool in zip(idxs_with_action, prompt_completion_action_ids, post_tool_ids):
+                completion_ids_list[idx] = pct[len(prompt_ids_list[idx]) :] + post_tool
+
+            cc = self.processing_class.batch_decode(post_tool_ids, skip_special_tokens=True)
+            actions = [extract_action(content) for content in cc]
+            completion_contents = [None] * len(completion_contents)
+            for i, content in zip(idxs_with_action, cc):
+                completion_contents[i] = content
+            idxs_with_action = [idx for idx, tc in zip(idxs_with_action, actions) if tc]
+            actions = [tc for tc in actions if tc]
+
         # Convert lists of token IDs to padded tensors
         prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
         prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
@@ -1461,6 +1562,15 @@ def _generate_and_score_completions(
         else:
             completions = completions_text

+        # Merge extra_fields from rollout_func into inputs for reward functions
+        if extra_fields:
+            for i, inp in enumerate(inputs):
+                for key, values in extra_fields.items():
+                    if isinstance(values, list) and i < len(values):
+                        inp[key] = values[i]
+                    elif not isinstance(values, list):
+                        inp[key] = values
+
         # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
         # important because rewards will be normalized per group, and completions are distributed. We will later slice
         # rewards_per_func to extract each process's subset.