diff --git a/docs/source/openenv.md b/docs/source/openenv.md
index edc0f500850..146bd5db7c6 100644
--- a/docs/source/openenv.md
+++ b/docs/source/openenv.md
@@ -11,7 +11,7 @@ In this guide, we’ll focus on **how to integrate OpenEnv with TRL**, but feel
 To use OpenEnv with TRL, install the framework:
 
 ```bash
-pip install openenv-core
+pip install git+https://github.com/meta-pytorch/OpenEnv.git
 ```
 
 ## Using `rollout_func` with OpenEnv environments
 
@@ -65,6 +65,33 @@ By using OpenEnv in this loop, you can:
 * Plug in custom simulators, web APIs, or evaluators as environments.
 * Pass structured reward signals back into RL training seamlessly.
 
+## Running the Environments
+
+You can run OpenEnv environments in three different ways:
+
+1. **Local Docker container** *(recommended)*
+
+   To start a Docker container:
+   * Open the environment on the Hugging Face Hub.
+   * Click the **⋮ (three dots)** menu.
+   * Select **“Run locally.”**
+   * Copy and execute the provided command in your terminal.
+
+   Example:
+   ```bash
+   docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
+   ```
+   ![open_env_launch_docker](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/open_env_launch_docker.png)
+2. **Local Python process**: Launch the environment server directly with Uvicorn. For more details about the available environments, refer to the [OpenEnv repository](https://github.com/meta-pytorch/OpenEnv/tree/main/src/envs).
+   ```bash
+   python -m uvicorn envs.echo_env.server.app:app --host 0.0.0.0 --port 8001
+   ```
+3. **Hugging Face Spaces**: Connect to a hosted environment running on the Hugging Face Hub.
+   To find the connection URL, open the Space page, click the **⋮ (three dots)** menu, and select **“Embed this Space.”**
+   You can then use that URL to connect directly from your client, as sketched below.
+   Keep in mind that public Spaces may be rate limited or go to sleep when inactive.
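+
+Whichever way you run the environment, the client is constructed the same way; only the `base_url` changes. A minimal sketch (the Space URL is a placeholder; use the one from **“Embed this Space”**):
+
+```python
+from envs.echo_env import EchoEnv
+
+# Local Docker container or local Python process
+client = EchoEnv(base_url="http://0.0.0.0:8001")
+
+# Hosted Hugging Face Space (placeholder URL)
+client = EchoEnv(base_url="https://openenv-echo-env.hf.space")
+```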
+
 ## A simple example
 
 The [echo.py](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py) script demonstrates a minimal, end-to-end integration between TRL and OpenEnv. In this example, the Echo environment rewards completions based on their text length, encouraging the model to generate longer outputs. This pattern can be extended to any custom environment that provides structured feedback or task-based rewards:
 
@@ -75,6 +102,15 @@ from trl import GRPOConfig, GRPOTrainer
 
 # Create HTTP client for Echo Environment
 client = EchoEnv.from_docker_image("echo-env:latest")
+"""
+Alternatively, you can start the environment manually with Docker and connect to it:
+
+# Step 1: Start the Echo environment
+docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
+
+# Step 2: Connect the client to the running container
+client = EchoEnv(base_url="http://0.0.0.0:8001")
+"""
 
 def rollout_func(prompts, args, processing_class):
     # 1. Generate completions via vLLM inference server (running on port 8000)
@@ -151,6 +187,21 @@ CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host
 CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py
 ```
 
+Alternatively, you can manually start the Echo environment in a Docker container before running the training:
+
+```bash
+# Launch the Echo environment
+docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
+```
+
+Then, initialize the client using:
+
+`client = EchoEnv(base_url="http://0.0.0.0:8001")`
+
+instead of:
+
+`client = EchoEnv.from_docker_image("echo-env:latest")`.
+
 Below is the reward curve from training:
 
@@ -352,7 +403,7 @@ trainer = GRPOTrainer(
 trainer.train()
 ```
 
-### Running the Example
+### Running the Advanced Example
 
 The example requires two GPUs:
 
@@ -364,6 +415,17 @@ CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --p
 CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py
 ```
 
+Again, you can manually start the TextArena environment in a Docker container before running the training.
+In this case, initialize the client with
+`client = TextArenaEnv(base_url="http://0.0.0.0:8001")`
+instead of
+`client = TextArenaEnv.from_docker_image("registry.hf.space/burtenshaw-textarena:latest")`:
+
+```bash
+# Launch the TextArena environment
+docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest
+```
+
 ### Results
 
 The resulting model improves its performance on the game, both by reducing the number of repetitions and by increasing the number of correct guesses. However, the Qwen3-1.7B model we trained is not able to consistently win the game. The following reward curve shows the coverage of the model's guesses and the coverage of correct Y and G letters.
diff --git a/examples/scripts/openenv/catch.py b/examples/scripts/openenv/catch.py
index 9a976542a93..f7263d9d3d1 100644
--- a/examples/scripts/openenv/catch.py
+++ b/examples/scripts/openenv/catch.py
@@ -12,22 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# ruff: noqa: T201
-import os
-import re
-import subprocess
-import sys
-import time
-from pathlib import Path
-
-import requests
-from datasets import Dataset
-from envs.openspiel_env import OpenSpielEnv
-from envs.openspiel_env.models import OpenSpielAction
-
-from trl import GRPOConfig, GRPOTrainer, RichProgressCallback, apply_chat_template
-
-
 """
 Simple script to run GRPO training with OpenEnv's Catch environment (OpenSpiel) and a vLLM server. The reward function is based on the catch game where the agent tries to catch falling balls.
 
@@ -36,11 +20,15 @@
 ```sh
 uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
-uv pip install open_spiel rich trackio
 ```
 
 Usage (2 GPUs required):
 
+# Start the Docker container for the Catch environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
+```sh
+docker run -d -p 8001:8001 registry.hf.space/openenv-openspiel-env:latest
+```
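+
+# Optionally, select where the environment runs via the script’s own flags (the Space URL below is a placeholder)
+```sh
+python examples/scripts/openenv/catch.py --env-mode local
+python examples/scripts/openenv/catch.py --env-mode space --env-host https://your-space.hf.space
+```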
+
 # Spin up vLLM server
 
 ```sh
@@ -54,8 +42,96 @@
 ```
 """
 
-GEN_URL = "http://0.0.0.0:8000/generate/"
-ENV_URL = "http://0.0.0.0:8001"
+# ruff: noqa: T201
+import argparse
+import os
+import re
+import subprocess
+import sys
+import time
+from pathlib import Path
+
+import requests
+from datasets import Dataset
+from envs.openspiel_env import OpenSpielEnv
+from envs.openspiel_env.models import OpenSpielAction
+
+from trl import GRPOConfig, GRPOTrainer, RichProgressCallback, apply_chat_template
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Run GRPO training with OpenSpiel Catch environment and vLLM.")
+
+    # --- Environment settings ---
+    parser.add_argument(
+        "--env-host",
+        type=str,
+        default="0.0.0.0",
+        help="Host for the environment server (pass the full Space URL when --env-mode is 'space').",
+    )
+    parser.add_argument("--env-port", type=int, default=8001, help="Port for the environment server.")
+    parser.add_argument(
+        "--env-mode",
+        choices=["local", "docker", "space"],
+        default="docker",
+        help="Where to run the environment: 'local', 'docker', or 'space'.",
+    )
+    # --- Generation and model config ---
+    parser.add_argument(
+        "--gen-url",
+        type=str,
+        default="http://0.0.0.0:8000/generate/",
+        help="vLLM generation endpoint URL.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="Qwen/Qwen2.5-0.5B-Instruct",
+        help="Model name or path.",
+    )
+    parser.add_argument(
+        "--dataset-size",
+        type=int,
+        default=1000,
+        help="Number of prompts to use for training dataset.",
+    )
+
+    return parser.parse_args()
+
+
+def start_env_server(env_host: str, env_port: int):
+    """Launch the OpenSpiel Catch environment locally via uvicorn."""
+    env_url = f"http://{env_host}:{env_port}"
+    print(f"⚡ Starting FastAPI server for OpenSpiel Catch Environment on {env_url}...")
+
+    work_dir = str(Path.cwd().parent.absolute())
+    process = subprocess.Popen(
+        [
+            sys.executable,
+            "-m",
+            "uvicorn",
+            "envs.openspiel_env.server.app:app",
+            "--host",
+            env_host,
+            "--port",
+            str(env_port),
+        ],
+        env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
+        cwd=work_dir,
+    )
+
+    print("⏳ Waiting for server to start...")
+    time.sleep(5)
+
+    try:
+        requests.get(f"{env_url}/health", timeout=2)
+        print("\n✅ OpenSpiel Catch Environment server is running!")
+    except Exception as e:
+        print(f"\n❌ Server failed to start: {e}")
+        if process.stderr:
+            print(process.stderr.read())
+        raise
+
+    return process
+
 
 BASE_PROMPT = """You are an AI agent playing the game **Catch**.
@@ -68,135 +144,64 @@
 - You get **–1 reward** if you miss it.
 
 ### Observation Format
+Each observation is a flattened 10x5 grid (list of 50 floats).
+- 1.0 → occupied (ball or paddle)
+- 0.0 → empty cell
 
-- `observation`: a list of **50 numbers (floats)** representing the entire grid, flattened row by row.
-  - Each cell contains `1.0` if it is occupied (either by the ball or the paddle), or `0.0` if it is empty.
-  - The positions of the two `1.0` values indicate where the **ball** and **paddle** currently are.
-- `legal_actions`: a list of integers representing which actions are currently allowed.
-
-### Actions
+### Actions:
 Each action is a discrete integer:
-- `0` → Move paddle **left**
-- `1` → **Stay** (no movement)
-- `2` → Move paddle **right**
+- `0` → Move left
+- `1` → Stay
+- `2` → Move right
 
-### Output Format
-Respond **only with one integer** representing your chosen action: `0`, `1`, or `2`.
+Respond **only** with one integer: `0`, `1`, or `2`.
 
 ### Current Observation
 """
 
-# Start the OpenSpiel server in background
-print("⚡ Starting FastAPI server for OpenSpiel Catch Environment...")
-
-# Determine the correct path
-work_dir = str(Path.cwd().parent.absolute())
-
-server_process = subprocess.Popen(
-    [sys.executable, "-m", "uvicorn", "envs.openspiel_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"],
-    env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
-    stdout=subprocess.PIPE,
-    stderr=subprocess.PIPE,
-    text=True,
-    cwd=work_dir,
-)
-
-print("⏳ Waiting for server to start...")
-time.sleep(5)
-
-# Check if server is running
-try:
-    response = requests.get(f"{ENV_URL}/health", timeout=2)
-    print("\n✅ OpenSpiel Catch Environment server is running!")
-except Exception as e:
-    print(f"\n❌ Server failed to start: {e}")
-    print("\n📋 Checking error output...")
-    server_process.poll()
-    if server_process.stderr:
-        stderr = server_process.stderr.read()
-        if stderr:
-            print(stderr)
-    raise
-
-
-# Create HTTP client for OpenSpiel Catch Environment
-client = OpenSpielEnv(base_url=f"{ENV_URL}")
-
-
-def rollout_func(prompts: list[str], args: GRPOConfig, processing_class) -> dict[str, list]:
-    """
-    Custom rollout function that generates completions via vLLM server and computes environment rewards.
-
-    The catch game expects action IDs (integers). We'll parse the model's text output to extract action choices.
-
-    Args:
-        prompts: List of prompts to generate from
-        args: GRPOConfig containing all sampling parameters
-        processing_class: Tokenizer/processor for decoding completions
-
-    Returns:
-        Dict containing prompt_ids, completion_ids, logprobs, and env_reward
-    """
-    # Run full episodes for each generation to get episode rewards
+
+def rollout_func(
+    prompts: list[str], args: GRPOConfig, processing_class, client: OpenSpielEnv, gen_url: str
+) -> dict[str, list]:
+    """Generate completions via vLLM and compute environment rewards."""
     env_rewards = []
-    all_prompt_ids = []
-    all_completion_ids = []
-    all_logprobs = []
+    all_prompt_ids, all_completion_ids, all_logprobs = [], [], []
 
     for base_prompt in prompts:
         for _ in range(args.num_generations):
-            # Run episode: Reset environment and loop until done
             env_result = client.reset()
             obs = env_result.observation
             total_reward = 0.0
-            episode_prompt_ids = []
-            episode_completion_ids = []
-            episode_logprobs = []
+            episode_prompt_ids, episode_completion_ids, episode_logprobs = [], [], []
 
-            # TODO: parallelise!
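+            # Roll out one full episode: query the model for an action at each
+            # step and accumulate the environment reward until the episode ends.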
             while not obs.done:
-                # FIXME: handle the addition of observation to prompt more cleanly, ideally without a train_dataset
                 episode_msg = {"prompt": [{"role": "user", "content": f"{base_prompt}\n\n{obs.info_state}\n"}]}
                 episode_prompt = apply_chat_template(episode_msg, processing_class)
 
-                # Generate action from model
-                gen_payload = {
+                payload = {
                     "prompts": [episode_prompt["prompt"]],
                     "n": 1,
                     "temperature": args.temperature,
                     "top_p": args.top_p,
-                    "top_k": -1 if args.top_k is None else args.top_k,
-                    "min_p": 0.0 if args.min_p is None else args.min_p,
                     "max_tokens": args.max_completion_length,
-                    "repetition_penalty": args.repetition_penalty,
                 }
-                gen_response = requests.post(GEN_URL, json=gen_payload)
-                gen_response.raise_for_status()
-                gen_result = gen_response.json()
+                response = requests.post(gen_url, json=payload)
+                response.raise_for_status()
+                result = response.json()
 
-                # Collect prompt_ids, completion_ids, and logprobs from this step
-                episode_prompt_ids.extend(gen_result["prompt_ids"][0])
-                episode_completion_ids.extend(gen_result["completion_ids"][0])
-                episode_logprobs.extend(gen_result["logprobs"][0])
+                episode_prompt_ids.extend(result["prompt_ids"][0])
+                episode_completion_ids.extend(result["completion_ids"][0])
+                episode_logprobs.extend(result["logprobs"][0])
 
-                completion_text = processing_class.batch_decode(
-                    gen_result["completion_ids"], skip_special_tokens=True
-                )[0]
+                completion_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True)[0]
 
-                # Parse action from completion
-                action_id = 0  # default
                 numbers = re.findall(r"\b([0-2])\b", completion_text)
-                if numbers:
-                    action_id = int(numbers[0])
-                elif obs.legal_actions:
-                    action_id = obs.legal_actions[0]
+                # Use the model's digit if present, otherwise fall back to the first legal action
+                action_id = int(numbers[0]) if numbers else obs.legal_actions[0]
 
-                # Take action in environment
                 env_result = client.step(OpenSpielAction(action_id=action_id, game_name="catch"))
-                reward = env_result.reward if env_result.reward is not None else 0.0
-                total_reward += reward
+                total_reward += env_result.reward or 0.0
                 obs = env_result.observation
 
-            # Store episode results
             env_rewards.append(total_reward)
             all_prompt_ids.append(episode_prompt_ids)
             all_completion_ids.append(episode_completion_ids)
@@ -210,42 +215,60 @@ def rollout_func(prompts: list[str], args: GRPOConfig, processing_class) -> dict
     }
 
 
-dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * 1000})
-
-
 def reward_from_env(completions, **kwargs):
-    """Reward function that uses the environment reward from the catch game."""
-    # Extract environment rewards from kwargs (propagated via extra_fields)
-    env_rewards = kwargs.get("env_reward", [])
-    if env_rewards:
-        return [float(reward) for reward in env_rewards]
+    """Reward function that uses the environment reward from the catch game."""
+    rewards = kwargs.get("env_reward", [])
+    return [float(r) for r in rewards] if rewards else [0.0] * len(completions)
+
+
+def main():
+    args = parse_args()
+
+    # Select environment mode
+    if args.env_mode == "local":
+        env_url = f"http://{args.env_host}:{args.env_port}"
+        server_process = start_env_server(args.env_host, args.env_port)
+    elif args.env_mode == "docker":
+        env_url = f"http://{args.env_host}:{args.env_port}"
+        server_process = None
+        print(f"🌍 Using existing Docker environment at {env_url}")
+    elif args.env_mode == "space":
+        env_url = args.env_host
+        server_process = None
+        print(f"🚀 Using Hugging Face Space environment at {env_url}")
     else:
-        # Fallback if env_reward is not available
-        return [0.0] * len(completions)
-
-
-training_args = GRPOConfig(
-    output_dir="Qwen2.5-0.5B-GRPO-Catch",
-    vllm_mode="server",
-    use_vllm=True,
-    logging_steps=1,
-    report_to="trackio",
-    num_train_epochs=1,
-    max_completion_length=4,
-    gradient_accumulation_steps=4,
-)
-trainer = GRPOTrainer(
-    model="Qwen/Qwen2.5-0.5B-Instruct",
-    reward_funcs=reward_from_env,
-    args=training_args,
-    train_dataset=dataset,
-    rollout_func=rollout_func,
-    callbacks=[RichProgressCallback()],
-)
-trainer.train()
-
-# Give time for background threads to finish
-time.sleep(5)
-
-print("🛑 Terminating environment server...")
-server_process.terminate()
+        raise ValueError(f"Unknown env mode: {args.env_mode}")
+
+    gen_url = args.gen_url
+    client = OpenSpielEnv(base_url=env_url)
+    dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * args.dataset_size})
+
+    training_args = GRPOConfig(
+        output_dir=f"{args.model.split('/')[-1]}-GRPO-Catch",
+        vllm_mode="server",
+        use_vllm=True,
+        logging_steps=1,
+        report_to="trackio",
+        num_train_epochs=1,
+        max_completion_length=4,
+        gradient_accumulation_steps=4,
+    )
+
+    trainer = GRPOTrainer(
+        model=args.model,
+        reward_funcs=reward_from_env,
+        args=training_args,
+        train_dataset=dataset,
+        rollout_func=lambda p, a, pc: rollout_func(p, a, pc, client, gen_url),
+        callbacks=[RichProgressCallback()],
+    )
+
+    trainer.train()
+    time.sleep(5)
+
+    if server_process:
+        print("🛑 Terminating environment server...")
+        server_process.terminate()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/scripts/openenv/echo.py b/examples/scripts/openenv/echo.py
index f2d05aa015b..b5cdb724bd7 100644
--- a/examples/scripts/openenv/echo.py
+++ b/examples/scripts/openenv/echo.py
@@ -12,21 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# ruff: noqa: T201
-import os
-import subprocess
-import sys
-import time
-from pathlib import Path
-
-import requests
-from datasets import load_dataset
-from envs.echo_env import EchoEnv
-from envs.echo_env.models import EchoAction
-
-from trl import GRPOConfig, GRPOTrainer, RichProgressCallback
-
-
 """
 Simple script to run GRPO training with OpenEnv's Echo environment and a vLLM server. The reward function encourages longer completions.
 
@@ -39,6 +24,11 @@
 Usage (2 GPUs required):
 
+# Start the Docker container for the Echo environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
+```sh
+docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
+```
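+
+# Optionally, verify the environment is reachable (the script itself probes this /health endpoint at startup)
+```sh
+curl http://0.0.0.0:8001/health
+```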
+
 # Spin up server
 
 ```sh
@@ -52,55 +42,89 @@
 ```
 """
 
-GEN_URL = "http://0.0.0.0:8000/generate/"
-ENV_URL = "http://0.0.0.0:8001"
-
-print("⚡ Starting FastAPI server for Echo Environment...")
-# Workaround if you can't run the env with Docker
-work_dir = str(Path.cwd().parent.absolute())
-server_process = subprocess.Popen(
-    [sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"],
-    env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
-    stdout=subprocess.PIPE,
-    stderr=subprocess.PIPE,
-    text=True,
-    cwd=work_dir,
-)
-
-print("⏳ Waiting for server to start...")
-time.sleep(5)
-
-try:
-    response = requests.get(f"{ENV_URL}/health", timeout=2)
-    print("\n✅ Echo Environment server is running!")
-except Exception as e:
-    print(f"\n❌ Server failed to start: {e}")
-    print("\n📋 Checking error output...")
-    server_process.poll()
-    if server_process.stderr:
-        stderr = server_process.stderr.read()
-        if stderr:
-            print(stderr)
-    raise
-
-
-# Create HTTP client for Echo Environment
-client = EchoEnv(base_url=f"{ENV_URL}")
-
-
-def rollout_func(prompts: list[str], args: GRPOConfig, processing_class) -> dict[str, list]:
-    """
-    Custom rollout function that generates completions via vLLM server and computes environment rewards.
-
-    Args:
-        prompts: List of prompts to generate from
-        args: GRPOConfig containing all sampling parameters
-        processing_class: Tokenizer/processor for decoding completions
-
-    Returns:
-        Dict containing prompt_ids, completion_ids, logprobs, and env_reward
-    """
-    # 1. Generate completions via vLLM inference server (running on port 8000)
+# ruff: noqa: T201
+import argparse
+import os
+import subprocess
+import sys
+import time
+from pathlib import Path
+
+import requests
+from datasets import load_dataset
+from envs.echo_env import EchoEnv
+from envs.echo_env.models import EchoAction
+
+from trl import GRPOConfig, GRPOTrainer, RichProgressCallback
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Run GRPO training with Echo environment and vLLM.")
+
+    parser.add_argument(
+        "--env-host",
+        type=str,
+        default="0.0.0.0",
+        help="Host for the Echo environment (pass the full Space URL when --env-mode is 'space').",
+    )
+    parser.add_argument("--env-port", type=int, default=8001, help="Port for the Echo environment.")
+    parser.add_argument(
+        "--env-mode",
+        choices=["local", "docker", "space"],
+        default="docker",
+        help="Where to run the Echo environment: 'local' to launch it, 'docker' if already running, or 'space' to use a remote Space URL.",
+    )
+    parser.add_argument(
+        "--gen-url",
+        type=str,
+        default="http://0.0.0.0:8000/generate/",
+        help="Base URL for the vLLM generation endpoint.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="Qwen/Qwen2.5-0.5B-Instruct",
+        help="Model to use for training.",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="trl-lib/ultrafeedback-prompt",
+        help="Dataset to use for training.",
+    )
+
+    return parser.parse_args()
+
+
+def start_env_server(env_host: str, env_port: int):
+    """Launch the Echo environment server locally."""
+    env_url = f"http://{env_host}:{env_port}"
+    print(f"⚡ Starting FastAPI server for Echo Environment on {env_url}...")
+
+    work_dir = str(Path.cwd().parent.absolute())
+    process = subprocess.Popen(
+        [sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", env_host, "--port", str(env_port)],
+        env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
+        cwd=work_dir,
+    )
+
+    print("⏳ Waiting for server to start...")
+    time.sleep(5)
+
+    try:
+        requests.get(f"{env_url}/health", timeout=2)
+        print("\n✅ Echo Environment server is running!")
+    except Exception as e:
+        print(f"\n❌ Server failed to start: {e}")
+        if process.stderr:
+            print(process.stderr.read())
+        raise
+
+    return process
+
+
+def rollout_func(
+    prompts: list[str], args: GRPOConfig, processing_class, client: EchoEnv, gen_url: str
+) -> dict[str, list]:
+    """Generate completions via vLLM and compute environment rewards."""
     payload = {
         "prompts": prompts,
         "n": args.num_generations,
@@ -111,64 +135,80 @@ def rollout_func(prompts: list[str], args: GRPOConfig, processing_class) -> dict
         "max_tokens": args.max_completion_length,
         "repetition_penalty": args.repetition_penalty,
     }
-    response = requests.post(GEN_URL, json=payload)
+    response = requests.post(gen_url, json=payload)
     if response.status_code != 200:
         print(f"Error response: {response.text}")
         response.raise_for_status()
 
     result = response.json()
     completions_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True)
 
-    # 2. Step through the environment to get rewards
     env_result = client.reset()
     env_rewards = []
     for msg in completions_text:
         env_result = client.step(EchoAction(message=msg))
         env_rewards.append(env_result.reward)
 
-    # 3. Add environment rewards as extra field
     result["env_reward"] = env_rewards
-
     return result
 
 
 def reward_from_env(completions, **kwargs):
-    """Reward function that uses the environment reward."""
-    # Extract environment rewards from kwargs (propagated via extra_fields)
+    """Extract environment rewards for training."""
     env_rewards = kwargs.get("env_reward", [])
-    if env_rewards:
-        return [float(reward) for reward in env_rewards]
+    return [float(r) for r in env_rewards] if env_rewards else [0.0] * len(completions)
+
+
+def main():
+    args = parse_args()
+
+    # Select environment mode
+    if args.env_mode == "local":
+        env_url = f"http://{args.env_host}:{args.env_port}"
+        server_process = start_env_server(args.env_host, args.env_port)
+    elif args.env_mode == "docker":
+        env_url = f"http://{args.env_host}:{args.env_port}"
+        server_process = None
+        print(f"🌍 Using existing Echo Environment (Docker) at: {env_url}")
+    elif args.env_mode == "space":
+        env_url = args.env_host
+        server_process = None
+        print(f"🚀 Using Hugging Face Space environment at: {env_url}")
     else:
-        # Fallback if env_reward is not available
-        return [0.0] * len(completions)
-
-
-dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]")
-
-training_args = GRPOConfig(
-    output_dir="Qwen2.5-0.5B-GRPO-Rollout",
-    vllm_mode="server",
-    use_vllm=True,
-    logging_steps=1,
-    report_to="trackio",
-    num_train_epochs=1,
-    max_completion_length=2048,
-    gradient_accumulation_steps=4,
-)
-trainer = GRPOTrainer(
-    model="Qwen/Qwen2.5-0.5B-Instruct",
-    reward_funcs=reward_from_env,
-    args=training_args,
-    train_dataset=dataset,
-    rollout_func=rollout_func,
-    callbacks=[RichProgressCallback()],
-)
-trainer.train()
-
-# Give time for background threads to finish
-time.sleep(5)
-
-print("🛑 Terminating Echo Environment server...")
-server_process.terminate()
+        raise ValueError(f"Unknown environment mode: {args.env_mode}")
+
+    gen_url = args.gen_url
+    client = EchoEnv(base_url=env_url)
+    dataset = load_dataset(args.dataset, split="train[:1000]")
+
+    training_args = GRPOConfig(
+        output_dir=f"{args.model.split('/')[-1]}-GRPO-Rollout",
+        vllm_mode="server",
+        use_vllm=True,
+        logging_steps=1,
+        report_to="trackio",
+        num_train_epochs=1,
+        max_completion_length=2048,
+        gradient_accumulation_steps=4,
+    )
+
+    trainer = GRPOTrainer(
+        model=args.model,
+        reward_funcs=reward_from_env,
+        args=training_args,
+        train_dataset=dataset,
+        rollout_func=lambda p, a, pc: rollout_func(p, a, pc, client, gen_url),
+        callbacks=[RichProgressCallback()],
+    )
+
+    trainer.train()
+    time.sleep(5)
+
+    if server_process:
+        print("🛑 Terminating Echo Environment server...")
+        server_process.terminate()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/scripts/openenv/wordle.py b/examples/scripts/openenv/wordle.py
index bf49795cd22..2683bcdacfa 100644
--- a/examples/scripts/openenv/wordle.py
+++ b/examples/scripts/openenv/wordle.py
@@ -13,18 +13,33 @@
 # limitations under the License.
 
 """
-GRPO training for Wordle using TRL's `GRPOTrainer` and the TextArena OpenEnv environment.
+Simple script to run GRPO training with OpenEnv's Wordle environment and a vLLM server.
 
-Usage:
-    # First, start the TextArena Wordle server (Docker or local):
-    TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 \
-    python -m src.envs.textarena_env.server.app
+Setup:
 
-    # Start the vLLM server with your model
-    CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000
+```sh
+uv pip install git+https://github.com/meta-pytorch/OpenEnv.git
+```
 
-    # Then run this training script:
-    CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py
+Usage (2 GPUs required):
+
+# Start the Docker container for the Wordle environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
+```sh
+docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest
+# or TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 python -m src.envs.textarena_env.server.app
+```
+
+# Spin up vLLM server
+
+```sh
+CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000
+```
+
+# Run training
+
+```sh
+CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py
+```
 """
 
 from __future__ import annotations
@@ -70,8 +85,8 @@ def parse_args() -> argparse.Namespace:
         help="Model identifier passed to GRPOTrainer for fine-tuning.",
     )
     parser.add_argument(
-        "--textarena-url",
-        default="https://burtenshaw-textarena.hf.space",
+        "--env-url",
+        default="http://0.0.0.0:8001",  # or "https://burtenshaw-textarena.hf.space" to use the hosted Space
         help="Base URL for the TextArena Wordle environment.",
     )
     parser.add_argument(
@@ -505,7 +520,7 @@ def main() -> None:
     tokenizer = AutoTokenizer.from_pretrained(cli_args.tokenizer_id)
     tokenizer.pad_token = tokenizer.eos_token
 
-    env = TextArenaEnv(base_url=cli_args.textarena_url)
+    env = TextArenaEnv(base_url=cli_args.env_url)
 
     system_prompt = resolve_system_prompt(cli_args.system_prompt_path)