Skip to content

Support model deployment #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@ [email protected]
# A GitHub token might be required for commiting to the private `agent-reinforcement-training` repository
GITHUB_TOKEN=YOUR_GITHUB_TOKEN

# HuggingFace Token (optional for most models, necessary for training gated models like Llama 3.1)
HF_TOKEN=YOUR_HUGGINGFACE_TOKEN

# Optional, OpenPipe API key
OPENPIPE_API_KEY=YOUR_OPENPIPE_API_KEY
# Optional, Together API key (used for deploying models to Together)
TOGETHER_API_KEY=YOUR_TOGETHER_API_KEY

# Optional, S3 configuration for log and model backups
AWS_ACCESS_KEY_ID=YOUR_AWS_ACCESS_KEY_ID
Expand Down
73 changes: 31 additions & 42 deletions examples/tic_tac_toe/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import math
import os
from dotenv import load_dotenv

from pydantic import BaseModel
from openpipe.client import OpenPipe

from utils import (
Expand All @@ -24,10 +24,12 @@
op_client = OpenPipe(api_key=os.getenv("OPENPIPE_API_KEY"))


class TicTacToeScenario(BaseModel):
step: int


@art.retry(exceptions=(openai.LengthFinishReasonError,))
async def rollout(
model: art.Model, iteration: int, is_validation: bool
) -> art.Trajectory:
async def rollout(model: art.Model, scenario: TicTacToeScenario) -> art.Trajectory:
game = generate_game()

trajectory = art.Trajectory(
Expand Down Expand Up @@ -73,26 +75,6 @@ async def rollout(
failing_trajectory = trajectory
raise e

try:
op_client.report(
requested_at=requested_at,
received_at=int(time.time() * 1000),
req_payload={
"model": model.name,
"messages": messages,
"metadata": {
"notebook-id": "tic-tac-toe",
"iteration": str(iteration),
"validation": str(is_validation),
"move_number": str(move_number),
},
},
resp_payload=chat_completion,
status_code=200,
)
except Exception as e:
print(f"Error reporting to OpenPipe: {e}")

choice = chat_completion.choices[0]
content = choice.message.content
assert isinstance(content, str)
Expand All @@ -104,9 +86,9 @@ async def rollout(
trajectory.reward = -1 + (math.log(move_number + 1) / math.log(100))
break

move_number += 1
if check_winner(game["board"]) is not None:
break
move_number += 1

opponent_move = get_opponent_move(game)
game["board"][opponent_move[0]][opponent_move[1]] = game["opponent_symbol"]
Expand All @@ -125,22 +107,29 @@ async def rollout(

trajectory.metrics["num_moves"] = move_number

try:
op_client.update_log_metadata(
filters=[
{
"field": "completionId",
"equals": last_completion.id,
}
],
metadata={
"reward": str(trajectory.reward),
"reward_assigned": "true",
},
)
except Exception as e:
print(f"Error updating log metadata: {e}")

print(trajectory.reward)
if op_client.api_key:
try:
reported_win = (
trajectory.metrics["win"] if "win" in trajectory.metrics else -1
)
op_client.report(
requested_at=requested_at,
received_at=int(time.time() * 1000),
req_payload={
"model": model.name,
"messages": messages,
"metadata": {
"notebook-id": "tic-tac-toe",
"step": str(scenario.step),
"num_moves": str(move_number),
"win": str(reported_win),
"reward": str(trajectory.reward),
},
},
resp_payload=chat_completion,
status_code=200,
)
except Exception as e:
print(f"Error reporting to OpenPipe: {e}")

return trajectory
12 changes: 7 additions & 5 deletions examples/tic_tac_toe/tic-tac-toe-dev.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,14 @@
"import openai\n",
"import time\n",
"import math\n",
"from pydantic import BaseModel\n",
"\n",
"class TicTacToeScenario(BaseModel):\n",
" step: int\n",
"\n",
"@art.retry(exceptions=(openai.LengthFinishReasonError,))\n",
"async def rollout(\n",
" model: art.Model, iteration: int, is_validation: bool\n",
" model: art.Model, scenario: TicTacToeScenario\n",
") -> art.Trajectory:\n",
" game = generate_game()\n",
"\n",
Expand Down Expand Up @@ -303,8 +306,7 @@
" \"messages\": messages,\n",
" \"metadata\": {\n",
" \"notebook-id\": \"tic-tac-toe\",\n",
" \"iteration\": str(iteration),\n",
" \"validation\": str(is_validation),\n",
" \"step\": str(scenario.step),\n",
" \"move_number\": str(move_number),\n",
" },\n",
" },\n",
Expand Down Expand Up @@ -372,7 +374,7 @@
" train_groups = await art.gather_trajectory_groups(\n",
" (\n",
" art.TrajectoryGroup(\n",
" rollout(model, i, is_validation=False) for _ in range(48)\n",
" rollout(model, TicTacToeScenario(step=i)) for _ in range(48)\n",
" )\n",
" for _ in range(1)\n",
" ),\n",
Expand All @@ -393,7 +395,7 @@
"async def log_comparison_model(comparison_model: art.Model):\n",
" trajectories = await art.gather_trajectory_groups(\n",
" (\n",
" art.TrajectoryGroup(rollout(comparison_model, 0, is_validation=True) for _ in range(12))\n",
" art.TrajectoryGroup(rollout(comparison_model, TicTacToeScenario(step=0)) for _ in range(12))\n",
" for _ in range(1)\n",
" ),\n",
" pbar_desc=f\"gather {comparison_model.name}\",\n",
Expand Down
6 changes: 3 additions & 3 deletions examples/tic_tac_toe/tic-tac-toe-local.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"\n",
"from art.utils.get_repo_root_path import get_repo_root_path\n",
"from art.local import LocalBackend\n",
"from rollout import rollout\n",
"from rollout import rollout, TicTacToeScenario\n",
"\n",
"load_dotenv()\n",
"\n",
Expand All @@ -47,7 +47,7 @@
" train_groups = await art.gather_trajectory_groups(\n",
" (\n",
" art.TrajectoryGroup(\n",
" rollout(model, i, is_validation=False) for _ in range(100)\n",
" rollout(model, TicTacToeScenario(step=i)) for _ in range(100)\n",
" )\n",
" for _ in range(1)\n",
" ),\n",
Expand Down Expand Up @@ -149,7 +149,7 @@
"async def log_comparison_model(comparison_model: art.Model):\n",
" trajectories = await art.gather_trajectory_groups(\n",
" (\n",
" art.TrajectoryGroup(rollout(comparison_model, 0, is_validation=True) for _ in range(40))\n",
" art.TrajectoryGroup(rollout(comparison_model, TicTacToeScenario(step=0)) for _ in range(40))\n",
" for _ in range(1)\n",
" ),\n",
" pbar_desc=f\"gather {comparison_model.name}\",\n",
Expand Down
54 changes: 47 additions & 7 deletions examples/tic_tac_toe/tic-tac-toe-local.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import random
import asyncio
from dotenv import load_dotenv

import art
from rollout import rollout
from art.utils.deploy_model import previously_deployed_model_name
from rollout import rollout, TicTacToeScenario
from art.local.backend import LocalBackend


Expand All @@ -12,31 +14,69 @@
random.seed(42)

DESTROY_AFTER_RUN = False
STEP = 35


async def main():
# run from the root of the repo
backend = LocalBackend(path="examples/tic_tac_toe/.art")
backend = LocalBackend()

model = art.TrainableModel(
name="001-script",
project="tic-tac-toe-local",
base_model="Qwen/Qwen2.5-3B-Instruct",
name="llama-8b-001",
project="tic-tac-toe",
base_model="meta-llama/Meta-Llama-3.1-8B-Instruct",
)
print("pulling from s3")
await backend._experimental_pull_from_s3(model)

print("registering")
await model.register(backend)

for i in range(await model.get_step(), 100):
print("training")
for i in range(await model.get_step(), STEP):
train_groups = await art.gather_trajectory_groups(
(
art.TrajectoryGroup(
rollout(model, i, is_validation=False) for _ in range(200)
rollout(model, TicTacToeScenario(step=i)) for _ in range(48)
)
for _ in range(1)
),
pbar_desc="gather",
)
await model.delete_checkpoints()
await model.train(train_groups, config=art.TrainConfig(learning_rate=1e-4))
await backend._experimental_push_to_s3(model)

deployed_model_name = await previously_deployed_model_name(model, STEP)

if deployed_model_name:
print(f"skipping deployment because model {deployed_model_name} already exists")
else:
deployment_result = await backend._experimental_deploy(
deploy_to="together",
model=model,
step=STEP,
verbose=True,
pull_s3=False,
wait_for_completion=True,
)
if deployment_result.status == "Failed":
raise Exception(f"Deployment failed: {deployment_result.failure_reason}")

deployed_model_name = deployment_result.model_name

lora_model = art.Model(
name=deployed_model_name,
project="tic-tac-toe",
inference_api_key=os.environ["TOGETHER_API_KEY"],
inference_base_url="https://api.together.xyz/v1",
inference_model_name=deployed_model_name,
)

print("Starting a rollout using the deployed model!")
traj = await rollout(lora_model, TicTacToeScenario(step=0))

print(traj)

if DESTROY_AFTER_RUN:
await backend.down()
Expand Down
44 changes: 25 additions & 19 deletions examples/tic_tac_toe/tic-tac-toe.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,9 @@
" trajectory.reward = -1 + (math.log(move_number + 1) / math.log(100))\n",
" break\n",
"\n",
" move_number += 1\n",
" if check_winner(game[\"board\"]) is not None:\n",
" break\n",
" move_number += 1\n",
"\n",
" opponent_move = get_opponent_move(game)\n",
" game[\"board\"][opponent_move[0]][opponent_move[1]] = game[\"opponent_symbol\"]\n",
Expand All @@ -434,24 +434,30 @@
"\n",
" trajectory.metrics[\"num_moves\"] = move_number\n",
"\n",
" try:\n",
" if op_client.api_key:\n",
" op_client.update_log_metadata(\n",
" filters=[\n",
" {\n",
" \"field\": \"completionId\",\n",
" \"equals\": last_completion.id,\n",
" }\n",
" ],\n",
" metadata={\n",
" \"reward\": str(trajectory.reward),\n",
" \"reward_assigned\": \"true\",\n",
" if op_client.api_key:\n",
" try:\n",
" reported_win = (\n",
" trajectory.metrics[\"win\"] if \"win\" in trajectory.metrics else -1\n",
" )\n",
" op_client.report(\n",
" requested_at=requested_at,\n",
" received_at=int(time.time() * 1000),\n",
" req_payload={\n",
" \"model\": model.name,\n",
" \"messages\": messages,\n",
" \"metadata\": {\n",
" \"notebook-id\": \"tic-tac-toe\",\n",
" \"step\": str(scenario.step),\n",
" \"num_moves\": str(move_number),\n",
" \"win\": str(reported_win),\n",
" \"reward\": str(trajectory.reward),\n",
" },\n",
" },\n",
" resp_payload=chat_completion,\n",
" status_code=200,\n",
" )\n",
" except Exception as e:\n",
" print(f\"Error updating log metadata: {e}\")\n",
"\n",
" print(trajectory.reward)\n",
" except Exception as e:\n",
" print(f\"Error reporting to OpenPipe: {e}\")\n",
"\n",
" return trajectory\n"
]
Expand All @@ -477,11 +483,11 @@
"metadata": {},
"outputs": [],
"source": [
"for i in range(await model.get_step(), 100):\n",
"for i in range(await model.get_step(), 50):\n",
" train_groups = await art.gather_trajectory_groups(\n",
" (\n",
" art.TrajectoryGroup(\n",
" rollout(model, TicTacToeScenario(step=i)) for _ in range(200)\n",
" rollout(model, TicTacToeScenario(step=i)) for _ in range(48)\n",
" )\n",
" for _ in range(1)\n",
" ),\n",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"awscli>=1.38.1",
"hf-xet>=1.1.0",
"panza",
"semver>=3.0.4",
]

[project.scripts]
Expand All @@ -50,7 +51,6 @@ dev-dependencies = [
"openpipe>=4.49.0",
"skypilot[aws,cudo,do,fluidstack,gcp,lambda,paperspace,runpod]>=0.8.0",
"hatch>=1.14.1",
"semver>=3.0.4",
]
override-dependencies = [
"bitsandbytes; sys_platform == 'linux'",
Expand Down
Loading