Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mason.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ def get_env_vars(
"WANDB_API_KEY",
"BEAKER_TOKEN",
"OPENAI_API_KEY",
# Needed for tool use scripts.
"OPEN_INSTRUCT_TOOL_API_KEY",
# litellm expects these env vars
"AZURE_API_KEY",
"AZURE_API_BASE",
Expand Down
17 changes: 11 additions & 6 deletions open_instruct/tool_utils/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@ FROM python:3.10-slim
# Set working directory in container
WORKDIR /app

# Install uv
COPY --from=ghcr.io/astral-sh/uv:0.8.8 /uv /bin/uv

# Copy requirements first to leverage Docker cache
COPY requirements.txt requirements.txt
COPY open_instruct/tool_utils/requirements.txt requirements.txt

# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt
# Install dependencies using uv
RUN uv pip install --system --no-cache -r requirements.txt

# Copy the rest of the application
COPY . .
# Copy the tool server files
COPY open_instruct/__init__.py open_instruct/__init__.py
COPY open_instruct/logger_utils.py open_instruct/logger_utils.py
COPY open_instruct/tool_utils/tool_server.py tool_server.py

# Create cache directory for code execution
RUN mkdir -p cache && chmod 777 cache
Expand All @@ -23,4 +28,4 @@ ENV PYTHONUNBUFFERED=1
EXPOSE 8080

# Command to run the application
CMD ["python", "tool_server.py"]
CMD ["uv", "run", "--no-project", "tool_server.py"]
10 changes: 7 additions & 3 deletions open_instruct/tool_utils/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,19 @@ class TestPythonCodeTool(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""Start the tool server for tests."""
# Start the server in a subprocess
import os

env = os.environ.copy()
env.pop("OPEN_INSTRUCT_TOOL_API_KEY", None)

cls.server_process = subprocess.Popen(
["uv", "run", "uvicorn", "tool_server:app", "--host", "0.0.0.0", "--port", "1212"],
cwd="open_instruct/tool_utils",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
start_new_session=True, # Create new process group
start_new_session=True,
env=env,
)
# Wait for server to start
time.sleep(3)
cls.api_endpoint = "http://localhost:1212/execute"

Expand Down
65 changes: 59 additions & 6 deletions open_instruct/tool_utils/tool_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,19 @@

This script sets up a FastAPI server that allows users to execute Python code snippets

# API Key Authentication

The server requires an API key for authentication. Set the OPEN_INSTRUCT_TOOL_API_KEY environment variable:

```bash
export OPEN_INSTRUCT_TOOL_API_KEY="your-api-key-here"
```

When running locally:
```bash
cd open_instruct/tool_utils
PREIMPORT_PKGS=pandas,numpy,sympy,time,math,networkx uv run uvicorn tool_server:app --host 0.0.0.0 --port 1212
OPEN_INSTRUCT_TOOL_API_KEY="your-api-key-here" PREIMPORT_PKGS=pandas,numpy,sympy,time,math,networkx uv run uvicorn tool_server:app --host 0.0.0.0 --port 1212
```

```bash
docker build -t tool-server .
Expand All @@ -16,8 +27,8 @@
docker build -t ghcr.io/allenai/open-instruct/python-code-executor -f open_instruct/tool_utils/Dockerfile .
docker push ghcr.io/allenai/open-instruct/python-code-executor

# Run the server
docker run -p 1212:8080 tool-server
# Run the server (pass API key via environment variable)
docker run -p 1212:8080 -e OPEN_INSTRUCT_TOOL_API_KEY="your-api-key-here" tool-server

# gcloud run deploy:
gcloud run deploy open-instruct-tool-server --project ai2-allennlp --region us-central1 --source .
Expand All @@ -39,25 +50,31 @@
1) the timeout works
2) the timeout in the first curl does not block the second curl

All requests now require the X-API-Key header:

```
curl -X POST https://open-instruct-tool-server-10554368204.us-central1.run.app/execute \
-H "Content-Type: application/json" \
-H "X-API-Key: $OPEN_INSTRUCT_TOOL_API_KEY" \
-d '{"code": "import time;time.sleep(4)", "timeout": 3}' \
-w '\nTotal time: %{time_total}s\n'


curl -X POST https://open-instruct-tool-server-10554368204.us-central1.run.app/execute \
-H "Content-Type: application/json" \
-H "X-API-Key: $OPEN_INSTRUCT_TOOL_API_KEY" \
-d '{"code": "print(1)", "timeout": 3}' \
-w '\nTotal time: %{time_total}s\n'

curl -X POST https://open-instruct-tool-server-10554368204.us-central1.run.app/execute \
-H "Content-Type: application/json" \
-H "X-API-Key: $OPEN_INSTRUCT_TOOL_API_KEY" \
-d '{"code": "import sympy", "timeout": 3}' \
-w '\nTotal time: %{time_total}s\n'

curl -X POST https://open-instruct-tool-server-10554368204.us-central1.run.app/execute \
-H "Content-Type: application/json" \
-H "X-API-Key: $OPEN_INSTRUCT_TOOL_API_KEY" \
-d '{"code": "import sympy", "timeout": 3}' \
-w '\nTotal time: %{time_total}s\n'
```
Expand All @@ -80,7 +97,7 @@
from contextlib import redirect_stderr, redirect_stdout
from typing import Optional

from fastapi import FastAPI
from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel

from open_instruct import logger_utils
Expand Down Expand Up @@ -202,11 +219,30 @@ class CodeResponse(BaseModel):
success: bool


###############################################################################
# API Key Authentication
###############################################################################
EXPECTED_API_KEY = os.getenv("OPEN_INSTRUCT_TOOL_API_KEY")


async def verify_api_key(x_api_key: str = Header(None, alias="X-API-Key")):
if not EXPECTED_API_KEY:
logger.warning("OPEN_INSTRUCT_TOOL_API_KEY not set - API key validation disabled")
return
if not x_api_key:
logger.warning("Missing API key in request")
raise HTTPException(status_code=401, detail="Missing API key")
if x_api_key != EXPECTED_API_KEY:
logger.warning("Invalid API key attempt")
raise HTTPException(status_code=401, detail="Invalid API key")
return x_api_key


###############################################################################
# Endpoints
###############################################################################
@app.post("/execute", response_model=CodeResponse)
async def execute_code(req: CodeRequest): # noqa: D401
async def execute_code(req: CodeRequest, api_key: str = Depends(verify_api_key)): # noqa: D401
global process_pool # noqa: PLW0603

# Log input (truncate to 200 chars to avoid huge logs)
Expand Down Expand Up @@ -244,4 +280,21 @@ async def execute_code(req: CodeRequest): # noqa: D401

@app.get("/")
async def root(): # noqa: D401
return {"message": "Python Code Executor API — POST /execute {code, timeout}"}
host = os.getenv("HOST", "http://localhost:1212")

examples = f"""Python Code Executor API

Example usage:

curl -X POST {host}/execute \\
-H "Content-Type: application/json" \\
-H "X-API-Key: $OPEN_INSTRUCT_TOOL_API_KEY" \\
-d '{{"code": "print(1 + 1)", "timeout": 3}}'

curl -X POST {host}/execute \\
-H "Content-Type: application/json" \\
-H "X-API-Key: $OPEN_INSTRUCT_TOOL_API_KEY" \\
-d '{{"code": "import sympy; print(sympy.__version__)", "timeout": 3}}'
"""

return {"message": examples}
17 changes: 11 additions & 6 deletions open_instruct/tool_utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ class PythonCodeTool(Tool):
"""@vwxyzjn: I recommend using something like a FastAPI for this kind of stuff; 1) you
won't accidentally block the main vLLM process and 2) way easier to parallelize via load balancing."""

def __init__(self, api_endpoint: str, *args, **kwargs):
def __init__(self, api_endpoint: str, api_key: str = None, *args, **kwargs):
self.api_endpoint = api_endpoint
self.api_key = api_key
super().__init__(*args, **kwargs)

def __call__(self, prompt: str) -> ToolOutput:
Expand Down Expand Up @@ -79,17 +80,21 @@ def find_sum_of_a():
timeout_seconds = 3
start_time = time.time()
try:
# Call the FastAPI endpoint to execute the code with client-side timeout
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["X-API-Key"] = self.api_key

response = requests.post(
self.api_endpoint,
json={"code": code, "timeout": timeout_seconds}, # Server-side timeout (keeping this)
timeout=timeout_seconds, # Client-side timeout
json={"code": code, "timeout": timeout_seconds},
headers=headers,
timeout=timeout_seconds,
)

# Parse the response
response.raise_for_status()

result = response.json()

# Process the API response
output = result["output"]
error = result.get("error") or ""

Expand Down
Loading