Skip to content

Commit fbc32ea

Browse files
authored
Use /checkpoints instead of events parsing (#312)
* Use /checkpoints instead of events parsing
* Fix events listing for async client
* Change the parsing logic due to the api changes
* Fix string formatting
* Parsing updated
* Formatting
* Remove old implementation
1 parent 72c5066 commit fbc32ea

File tree

2 files changed

+57
-72
lines changed

2 files changed

+57
-72
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
1212

1313
[tool.poetry]
1414
name = "together"
15-
version = "1.5.7"
15+
version = "1.5.8"
1616
authors = ["Together AI <[email protected]>"]
1717
description = "Python client for Together's Cloud Platform!"
1818
readme = "README.md"

src/together/resources/finetune.py

Lines changed: 56 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import re
44
from pathlib import Path
5-
from typing import List, Literal
5+
from typing import Dict, List, Literal
66

77
from rich import print as rprint
88

@@ -30,16 +30,8 @@
3030
TrainingMethodSFT,
3131
TrainingType,
3232
)
33-
from together.types.finetune import (
34-
DownloadCheckpointType,
35-
FinetuneEvent,
36-
FinetuneEventType,
37-
)
38-
from together.utils import (
39-
get_event_step,
40-
log_warn_once,
41-
normalize_key,
42-
)
33+
from together.types.finetune import DownloadCheckpointType
34+
from together.utils import log_warn_once, normalize_key
4335

4436

4537
_FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
@@ -222,68 +214,38 @@ def create_finetune_request(
222214
return finetune_request
223215

224216

225-
def _process_checkpoints_from_events(
226-
events: List[FinetuneEvent], id: str
217+
def _parse_raw_checkpoints(
218+
checkpoints: List[Dict[str, str]], id: str
227219
) -> List[FinetuneCheckpoint]:
228220
"""
229-
Helper function to process events and create checkpoint list.
221+
Helper function to process raw checkpoints and create checkpoint list.
230222
231223
Args:
232-
events (List[FinetuneEvent]): List of fine-tune events to process
224+
checkpoints (List[Dict[str, str]]): List of raw checkpoints metadata
233225
id (str): Fine-tune job ID
234226
235227
Returns:
236228
List[FinetuneCheckpoint]: List of available checkpoints
237229
"""
238-
checkpoints: List[FinetuneCheckpoint] = []
239-
240-
for event in events:
241-
event_type = event.type
242-
243-
if event_type == FinetuneEventType.CHECKPOINT_SAVE:
244-
step = get_event_step(event)
245-
checkpoint_name = f"{id}:{step}" if step is not None else id
246-
247-
checkpoints.append(
248-
FinetuneCheckpoint(
249-
type=(
250-
f"Intermediate (step {step})"
251-
if step is not None
252-
else "Intermediate"
253-
),
254-
timestamp=event.created_at,
255-
name=checkpoint_name,
256-
)
257-
)
258-
elif event_type == FinetuneEventType.JOB_COMPLETE:
259-
if hasattr(event, "model_path"):
260-
checkpoints.append(
261-
FinetuneCheckpoint(
262-
type=(
263-
"Final Merged"
264-
if hasattr(event, "adapter_path")
265-
else "Final"
266-
),
267-
timestamp=event.created_at,
268-
name=id,
269-
)
270-
)
271230

272-
if hasattr(event, "adapter_path"):
273-
checkpoints.append(
274-
FinetuneCheckpoint(
275-
type=(
276-
"Final Adapter" if hasattr(event, "model_path") else "Final"
277-
),
278-
timestamp=event.created_at,
279-
name=id,
280-
)
281-
)
231+
parsed_checkpoints = []
232+
for checkpoint in checkpoints:
233+
step = checkpoint["step"]
234+
checkpoint_type = checkpoint["checkpoint_type"]
235+
checkpoint_name = (
236+
f"{id}:{step}" if "intermediate" in checkpoint_type.lower() else id
237+
)
282238

283-
# Sort by timestamp (newest first)
284-
checkpoints.sort(key=lambda x: x.timestamp, reverse=True)
239+
parsed_checkpoints.append(
240+
FinetuneCheckpoint(
241+
type=checkpoint_type,
242+
timestamp=checkpoint["created_at"],
243+
name=checkpoint_name,
244+
)
245+
)
285246

286-
return checkpoints
247+
parsed_checkpoints.sort(key=lambda x: x.timestamp, reverse=True)
248+
return parsed_checkpoints
287249

288250

289251
class FineTuning:
@@ -561,8 +523,21 @@ def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
561523
Returns:
562524
List[FinetuneCheckpoint]: List of available checkpoints
563525
"""
564-
events = self.list_events(id).data or []
565-
return _process_checkpoints_from_events(events, id)
526+
requestor = api_requestor.APIRequestor(
527+
client=self._client,
528+
)
529+
530+
response, _, _ = requestor.request(
531+
options=TogetherRequest(
532+
method="GET",
533+
url=f"fine-tunes/{id}/checkpoints",
534+
),
535+
stream=False,
536+
)
537+
assert isinstance(response, TogetherResponse)
538+
539+
raw_checkpoints = response.data["data"]
540+
return _parse_raw_checkpoints(raw_checkpoints, id)
566541

567542
def download(
568543
self,
@@ -936,11 +911,9 @@ async def list_events(self, id: str) -> FinetuneListEvents:
936911
),
937912
stream=False,
938913
)
914+
assert isinstance(events_response, TogetherResponse)
939915

940-
# FIXME: API returns "data" field with no object type (should be "list")
941-
events_list = FinetuneListEvents(object="list", **events_response.data)
942-
943-
return events_list
916+
return FinetuneListEvents(**events_response.data)
944917

945918
async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
946919
"""
@@ -950,11 +923,23 @@ async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
950923
id (str): Unique identifier of the fine-tune job to list checkpoints for
951924
952925
Returns:
953-
List[FinetuneCheckpoint]: Object containing list of available checkpoints
926+
List[FinetuneCheckpoint]: List of available checkpoints
954927
"""
955-
events_list = await self.list_events(id)
956-
events = events_list.data or []
957-
return _process_checkpoints_from_events(events, id)
928+
requestor = api_requestor.APIRequestor(
929+
client=self._client,
930+
)
931+
932+
response, _, _ = await requestor.arequest(
933+
options=TogetherRequest(
934+
method="GET",
935+
url=f"fine-tunes/{id}/checkpoints",
936+
),
937+
stream=False,
938+
)
939+
assert isinstance(response, TogetherResponse)
940+
941+
raw_checkpoints = response.data["data"]
942+
return _parse_raw_checkpoints(raw_checkpoints, id)
958943

959944
async def download(
960945
self, id: str, *, output: str | None = None, checkpoint_step: int = -1

0 commit comments

Comments
 (0)