Skip to content

Commit d72b736

Browse files
authored
Support model deployment (#113)
* Add skeleton of _experimental_deploy
* Get presigned url
* Upload LoRAs to Together
* Add deploy_model utils
* Document HF tokens
* Better OP reporting
* Deploy models to Together
* Document TOGETHER_API_KEY
* Remove _experimental_deploy from Model
* Document HF_TOKEN
* Shorten ttt training loop
* Introduce `LoRADeploymentJobStatus`, wait for deployment
* Properly handle deployment failures
* Move checking for previous deployment to `_experimental_deploy`
* Add support for more Together statuses
* Update tic-tac-toe.py
* Possibly fix `_experimental_deploy` via cli
* Parse backend in tic tac toe
* Deal with more Together errors
* Actually return deployment job
* deploy_model TYPE_CHECKING
1 parent aec08f1 commit d72b736

15 files changed

+3928
-3500
lines changed

.env.example

+5
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@ [email protected]
88
# A GitHub token might be required for committing to the private `agent-reinforcement-training` repository
99
GITHUB_TOKEN=YOUR_GITHUB_TOKEN
1010

11+
# HuggingFace Token (optional for most models, necessary for training gated models like Llama 3.1)
12+
HF_TOKEN=YOUR_HUGGINGFACE_TOKEN
13+
1114
# Optional, OpenPipe API key
1215
OPENPIPE_API_KEY=YOUR_OPENPIPE_API_KEY
16+
# Optional, Together API key (used for deploying models to Together)
17+
TOGETHER_API_KEY=YOUR_TOGETHER_API_KEY
1318

1419
# Optional, S3 configuration for log and model backups
1520
AWS_ACCESS_KEY_ID=YOUR_AWS_ACCESS_KEY_ID

examples/tic_tac_toe/rollout.py

+31-42
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import math
55
import os
66
from dotenv import load_dotenv
7-
7+
from pydantic import BaseModel
88
from openpipe.client import OpenPipe
99

1010
from utils import (
@@ -24,10 +24,12 @@
2424
op_client = OpenPipe(api_key=os.getenv("OPENPIPE_API_KEY"))
2525

2626

27+
class TicTacToeScenario(BaseModel):
28+
step: int
29+
30+
2731
@art.retry(exceptions=(openai.LengthFinishReasonError,))
28-
async def rollout(
29-
model: art.Model, iteration: int, is_validation: bool
30-
) -> art.Trajectory:
32+
async def rollout(model: art.Model, scenario: TicTacToeScenario) -> art.Trajectory:
3133
game = generate_game()
3234

3335
trajectory = art.Trajectory(
@@ -73,26 +75,6 @@ async def rollout(
7375
failing_trajectory = trajectory
7476
raise e
7577

76-
try:
77-
op_client.report(
78-
requested_at=requested_at,
79-
received_at=int(time.time() * 1000),
80-
req_payload={
81-
"model": model.name,
82-
"messages": messages,
83-
"metadata": {
84-
"notebook-id": "tic-tac-toe",
85-
"iteration": str(iteration),
86-
"validation": str(is_validation),
87-
"move_number": str(move_number),
88-
},
89-
},
90-
resp_payload=chat_completion,
91-
status_code=200,
92-
)
93-
except Exception as e:
94-
print(f"Error reporting to OpenPipe: {e}")
95-
9678
choice = chat_completion.choices[0]
9779
content = choice.message.content
9880
assert isinstance(content, str)
@@ -104,9 +86,9 @@ async def rollout(
10486
trajectory.reward = -1 + (math.log(move_number + 1) / math.log(100))
10587
break
10688

89+
move_number += 1
10790
if check_winner(game["board"]) is not None:
10891
break
109-
move_number += 1
11092

11193
opponent_move = get_opponent_move(game)
11294
game["board"][opponent_move[0]][opponent_move[1]] = game["opponent_symbol"]
@@ -125,22 +107,29 @@ async def rollout(
125107

126108
trajectory.metrics["num_moves"] = move_number
127109

128-
try:
129-
op_client.update_log_metadata(
130-
filters=[
131-
{
132-
"field": "completionId",
133-
"equals": last_completion.id,
134-
}
135-
],
136-
metadata={
137-
"reward": str(trajectory.reward),
138-
"reward_assigned": "true",
139-
},
140-
)
141-
except Exception as e:
142-
print(f"Error updating log metadata: {e}")
143-
144-
print(trajectory.reward)
110+
if op_client.api_key:
111+
try:
112+
reported_win = (
113+
trajectory.metrics["win"] if "win" in trajectory.metrics else -1
114+
)
115+
op_client.report(
116+
requested_at=requested_at,
117+
received_at=int(time.time() * 1000),
118+
req_payload={
119+
"model": model.name,
120+
"messages": messages,
121+
"metadata": {
122+
"notebook-id": "tic-tac-toe",
123+
"step": str(scenario.step),
124+
"num_moves": str(move_number),
125+
"win": str(reported_win),
126+
"reward": str(trajectory.reward),
127+
},
128+
},
129+
resp_payload=chat_completion,
130+
status_code=200,
131+
)
132+
except Exception as e:
133+
print(f"Error reporting to OpenPipe: {e}")
145134

146135
return trajectory

examples/tic_tac_toe/tic-tac-toe-dev.ipynb

+7-5
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,14 @@
243243
"import openai\n",
244244
"import time\n",
245245
"import math\n",
246+
"from pydantic import BaseModel\n",
246247
"\n",
248+
"class TicTacToeScenario(BaseModel):\n",
249+
" step: int\n",
247250
"\n",
248251
"@art.retry(exceptions=(openai.LengthFinishReasonError,))\n",
249252
"async def rollout(\n",
250-
" model: art.Model, iteration: int, is_validation: bool\n",
253+
" model: art.Model, scenario: TicTacToeScenario\n",
251254
") -> art.Trajectory:\n",
252255
" game = generate_game()\n",
253256
"\n",
@@ -303,8 +306,7 @@
303306
" \"messages\": messages,\n",
304307
" \"metadata\": {\n",
305308
" \"notebook-id\": \"tic-tac-toe\",\n",
306-
" \"iteration\": str(iteration),\n",
307-
" \"validation\": str(is_validation),\n",
309+
" \"step\": str(scenario.step),\n",
308310
" \"move_number\": str(move_number),\n",
309311
" },\n",
310312
" },\n",
@@ -372,7 +374,7 @@
372374
" train_groups = await art.gather_trajectory_groups(\n",
373375
" (\n",
374376
" art.TrajectoryGroup(\n",
375-
" rollout(model, i, is_validation=False) for _ in range(48)\n",
377+
" rollout(model, TicTacToeScenario(step=i)) for _ in range(48)\n",
376378
" )\n",
377379
" for _ in range(1)\n",
378380
" ),\n",
@@ -393,7 +395,7 @@
393395
"async def log_comparison_model(comparison_model: art.Model):\n",
394396
" trajectories = await art.gather_trajectory_groups(\n",
395397
" (\n",
396-
" art.TrajectoryGroup(rollout(comparison_model, 0, is_validation=True) for _ in range(12))\n",
398+
" art.TrajectoryGroup(rollout(comparison_model, TicTacToeScenario(step=0)) for _ in range(12))\n",
397399
" for _ in range(1)\n",
398400
" ),\n",
399401
" pbar_desc=f\"gather {comparison_model.name}\",\n",

examples/tic_tac_toe/tic-tac-toe-local.ipynb

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
"\n",
2121
"from art.utils.get_repo_root_path import get_repo_root_path\n",
2222
"from art.local import LocalBackend\n",
23-
"from rollout import rollout\n",
23+
"from rollout import rollout, TicTacToeScenario\n",
2424
"\n",
2525
"load_dotenv()\n",
2626
"\n",
@@ -47,7 +47,7 @@
4747
" train_groups = await art.gather_trajectory_groups(\n",
4848
" (\n",
4949
" art.TrajectoryGroup(\n",
50-
" rollout(model, i, is_validation=False) for _ in range(100)\n",
50+
" rollout(model, TicTacToeScenario(step=i)) for _ in range(100)\n",
5151
" )\n",
5252
" for _ in range(1)\n",
5353
" ),\n",
@@ -149,7 +149,7 @@
149149
"async def log_comparison_model(comparison_model: art.Model):\n",
150150
" trajectories = await art.gather_trajectory_groups(\n",
151151
" (\n",
152-
" art.TrajectoryGroup(rollout(comparison_model, 0, is_validation=True) for _ in range(40))\n",
152+
" art.TrajectoryGroup(rollout(comparison_model, TicTacToeScenario(step=0)) for _ in range(40))\n",
153153
" for _ in range(1)\n",
154154
" ),\n",
155155
" pbar_desc=f\"gather {comparison_model.name}\",\n",

examples/tic_tac_toe/tic-tac-toe-local.py

+41-7
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import os
12
import random
23
import asyncio
34
from dotenv import load_dotenv
45

56
import art
6-
from rollout import rollout
7+
from rollout import rollout, TicTacToeScenario
78
from art.local.backend import LocalBackend
89

910

@@ -12,31 +13,64 @@
1213
random.seed(42)
1314

1415
DESTROY_AFTER_RUN = False
16+
STEP = 36
1517

1618

1719
async def main():
1820
# run from the root of the repo
19-
backend = LocalBackend(path="examples/tic_tac_toe/.art")
21+
backend = LocalBackend()
2022

2123
model = art.TrainableModel(
22-
name="001-script",
23-
project="tic-tac-toe-local",
24-
base_model="Qwen/Qwen2.5-3B-Instruct",
24+
name="llama-8b-001",
25+
project="tic-tac-toe",
26+
base_model="meta-llama/Meta-Llama-3.1-8B-Instruct",
2527
)
28+
print("pulling from s3")
29+
await backend._experimental_pull_from_s3(model)
30+
31+
print("registering")
2632
await model.register(backend)
2733

28-
for i in range(await model.get_step(), 100):
34+
print("training")
35+
for i in range(await model.get_step(), STEP):
2936
train_groups = await art.gather_trajectory_groups(
3037
(
3138
art.TrajectoryGroup(
32-
rollout(model, i, is_validation=False) for _ in range(200)
39+
rollout(model, TicTacToeScenario(step=i)) for _ in range(48)
3340
)
3441
for _ in range(1)
3542
),
3643
pbar_desc="gather",
3744
)
3845
await model.delete_checkpoints()
3946
await model.train(train_groups, config=art.TrainConfig(learning_rate=1e-4))
47+
await backend._experimental_push_to_s3(model)
48+
49+
deployment_result = await backend._experimental_deploy(
50+
deploy_to="together",
51+
model=model,
52+
step=STEP,
53+
verbose=True,
54+
pull_s3=False,
55+
wait_for_completion=True,
56+
)
57+
if deployment_result.status == "Failed":
58+
raise Exception(f"Deployment failed: {deployment_result.failure_reason}")
59+
60+
deployed_model_name = deployment_result.model_name
61+
62+
lora_model = art.Model(
63+
name=deployed_model_name,
64+
project="tic-tac-toe",
65+
inference_api_key=os.environ["TOGETHER_API_KEY"],
66+
inference_base_url="https://api.together.xyz/v1",
67+
inference_model_name=deployed_model_name,
68+
)
69+
70+
print("Starting a rollout using the deployed model!")
71+
traj = await rollout(lora_model, TicTacToeScenario(step=0))
72+
73+
print(traj)
4074

4175
if DESTROY_AFTER_RUN:
4276
await backend.down()

examples/tic_tac_toe/tic-tac-toe.ipynb

+25-19
Original file line numberDiff line numberDiff line change
@@ -413,9 +413,9 @@
413413
" trajectory.reward = -1 + (math.log(move_number + 1) / math.log(100))\n",
414414
" break\n",
415415
"\n",
416+
" move_number += 1\n",
416417
" if check_winner(game[\"board\"]) is not None:\n",
417418
" break\n",
418-
" move_number += 1\n",
419419
"\n",
420420
" opponent_move = get_opponent_move(game)\n",
421421
" game[\"board\"][opponent_move[0]][opponent_move[1]] = game[\"opponent_symbol\"]\n",
@@ -434,24 +434,30 @@
434434
"\n",
435435
" trajectory.metrics[\"num_moves\"] = move_number\n",
436436
"\n",
437-
" try:\n",
438-
" if op_client.api_key:\n",
439-
" op_client.update_log_metadata(\n",
440-
" filters=[\n",
441-
" {\n",
442-
" \"field\": \"completionId\",\n",
443-
" \"equals\": last_completion.id,\n",
444-
" }\n",
445-
" ],\n",
446-
" metadata={\n",
447-
" \"reward\": str(trajectory.reward),\n",
448-
" \"reward_assigned\": \"true\",\n",
437+
" if op_client.api_key:\n",
438+
" try:\n",
439+
" reported_win = (\n",
440+
" trajectory.metrics[\"win\"] if \"win\" in trajectory.metrics else -1\n",
441+
" )\n",
442+
" op_client.report(\n",
443+
" requested_at=requested_at,\n",
444+
" received_at=int(time.time() * 1000),\n",
445+
" req_payload={\n",
446+
" \"model\": model.name,\n",
447+
" \"messages\": messages,\n",
448+
" \"metadata\": {\n",
449+
" \"notebook-id\": \"tic-tac-toe\",\n",
450+
" \"step\": str(scenario.step),\n",
451+
" \"num_moves\": str(move_number),\n",
452+
" \"win\": str(reported_win),\n",
453+
" \"reward\": str(trajectory.reward),\n",
454+
" },\n",
449455
" },\n",
456+
" resp_payload=chat_completion,\n",
457+
" status_code=200,\n",
450458
" )\n",
451-
" except Exception as e:\n",
452-
" print(f\"Error updating log metadata: {e}\")\n",
453-
"\n",
454-
" print(trajectory.reward)\n",
459+
" except Exception as e:\n",
460+
" print(f\"Error reporting to OpenPipe: {e}\")\n",
455461
"\n",
456462
" return trajectory\n"
457463
]
@@ -477,11 +483,11 @@
477483
"metadata": {},
478484
"outputs": [],
479485
"source": [
480-
"for i in range(await model.get_step(), 100):\n",
486+
"for i in range(await model.get_step(), 50):\n",
481487
" train_groups = await art.gather_trajectory_groups(\n",
482488
" (\n",
483489
" art.TrajectoryGroup(\n",
484-
" rollout(model, TicTacToeScenario(step=i)) for _ in range(200)\n",
490+
" rollout(model, TicTacToeScenario(step=i)) for _ in range(48)\n",
485491
" )\n",
486492
" for _ in range(1)\n",
487493
" ),\n",

0 commit comments

Comments
 (0)