Skip to content

Commit 271af98

Browse files
authored
Improve fine-tuning checkpoint support (#259)
1 parent 583de09 commit 271af98

File tree

10 files changed

+271
-27
lines changed

10 files changed

+271
-27
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ print(response.choices[0].message.content)
6767
response = client.chat.completions.create(
6868
model="meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
6969
messages=[{
70-
"role": "user",
70+
"role": "user",
7171
"content": [
7272
{
7373
"type": "text",
@@ -91,7 +91,7 @@ response = client.chat.completions.create(
9191
"role": "user",
9292
"content": [
9393
{
94-
"type": "text",
94+
"type": "text",
9595
"text": "Compare these two images."
9696
},
9797
{

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
1212

1313
[tool.poetry]
1414
name = "together"
15-
version = "1.4.2"
15+
version = "1.4.3"
1616
authors = [
1717
"Together AI <[email protected]>"
1818
]

src/together/cli/api/finetune.py

+51-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

33
import json
4-
from datetime import datetime
4+
from datetime import datetime, timezone
55
from textwrap import wrap
66
from typing import Any, Literal
7+
import re
78

89
import click
910
from click.core import ParameterSource # type: ignore[attr-defined]
@@ -17,8 +18,13 @@
1718
log_warn,
1819
log_warn_once,
1920
parse_timestamp,
21+
format_timestamp,
22+
)
23+
from together.types.finetune import (
24+
DownloadCheckpointType,
25+
FinetuneTrainingLimits,
26+
FinetuneEventType,
2027
)
21-
from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
2228

2329

2430
_CONFIRMATION_MESSAGE = (
@@ -126,6 +132,14 @@ def fine_tuning(ctx: click.Context) -> None:
126132
help="Whether to mask the user messages in conversational data or prompts in instruction data. "
127133
"`auto` will automatically determine whether to mask the inputs based on the data format.",
128134
)
135+
@click.option(
136+
"--from-checkpoint",
137+
type=str,
138+
default=None,
139+
help="The checkpoint identifier to continue training from a previous fine-tuning job. "
140+
"The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}. "
141+
"The step value is optional, without it the final checkpoint will be used.",
142+
)
129143
def create(
130144
ctx: click.Context,
131145
training_file: str,
@@ -152,6 +166,7 @@ def create(
152166
wandb_name: str,
153167
confirm: bool,
154168
train_on_inputs: bool | Literal["auto"],
169+
from_checkpoint: str,
155170
) -> None:
156171
"""Start fine-tuning"""
157172
client: Together = ctx.obj
@@ -180,6 +195,7 @@ def create(
180195
wandb_project_name=wandb_project_name,
181196
wandb_name=wandb_name,
182197
train_on_inputs=train_on_inputs,
198+
from_checkpoint=from_checkpoint,
183199
)
184200

185201
model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
@@ -261,7 +277,9 @@ def list(ctx: click.Context) -> None:
261277

262278
response.data = response.data or []
263279

264-
response.data.sort(key=lambda x: parse_timestamp(x.created_at or ""))
280+
# Use a default datetime for None values to make sure the key function always returns a comparable value
281+
epoch_start = datetime.fromtimestamp(0, tz=timezone.utc)
282+
response.data.sort(key=lambda x: parse_timestamp(x.created_at or "") or epoch_start)
265283

266284
display_list = []
267285
for i in response.data:
@@ -344,6 +362,34 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
344362
click.echo(table)
345363

346364

365+
@fine_tuning.command()
366+
@click.pass_context
367+
@click.argument("fine_tune_id", type=str, required=True)
368+
def list_checkpoints(ctx: click.Context, fine_tune_id: str) -> None:
369+
"""List available checkpoints for a fine-tuning job"""
370+
client: Together = ctx.obj
371+
372+
checkpoints = client.fine_tuning.list_checkpoints(fine_tune_id)
373+
374+
display_list = []
375+
for checkpoint in checkpoints:
376+
display_list.append(
377+
{
378+
"Type": checkpoint.type,
379+
"Timestamp": format_timestamp(checkpoint.timestamp),
380+
"Name": checkpoint.name,
381+
}
382+
)
383+
384+
if display_list:
385+
click.echo(f"Job {fine_tune_id} contains the following checkpoints:")
386+
table = tabulate(display_list, headers="keys", tablefmt="grid")
387+
click.echo(table)
388+
click.echo("\nTo download a checkpoint, use `together fine-tuning download`")
389+
else:
390+
click.echo(f"No checkpoints found for job {fine_tune_id}")
391+
392+
347393
@fine_tuning.command()
348394
@click.pass_context
349395
@click.argument("fine_tune_id", type=str, required=True)
@@ -358,7 +404,7 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
358404
"--checkpoint-step",
359405
type=int,
360406
required=False,
361-
default=-1,
407+
default=None,
362408
help="Download fine-tuning checkpoint. Defaults to latest.",
363409
)
364410
@click.option(
@@ -372,7 +418,7 @@ def download(
372418
ctx: click.Context,
373419
fine_tune_id: str,
374420
output_dir: str,
375-
checkpoint_step: int,
421+
checkpoint_step: int | None,
376422
checkpoint_type: DownloadCheckpointType,
377423
) -> None:
378424
"""Download fine-tuning checkpoint"""

src/together/legacy/finetune.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def download(
161161
cls,
162162
fine_tune_id: str,
163163
output: str | None = None,
164-
step: int = -1,
164+
step: int | None = None,
165165
) -> Dict[str, Any]:
166166
"""Legacy finetuning download function."""
167167

0 commit comments

Comments (0)