2
2
3
3
import re
4
4
from pathlib import Path
5
- from typing import List , Literal
5
+ from typing import Dict , List , Literal
6
6
7
7
from rich import print as rprint
8
8
30
30
TrainingMethodSFT ,
31
31
TrainingType ,
32
32
)
33
- from together .types .finetune import (
34
- DownloadCheckpointType ,
35
- FinetuneEvent ,
36
- FinetuneEventType ,
37
- )
38
- from together .utils import (
39
- get_event_step ,
40
- log_warn_once ,
41
- normalize_key ,
42
- )
33
+ from together .types .finetune import DownloadCheckpointType
34
+ from together .utils import log_warn_once , normalize_key
43
35
44
36
45
37
# Pattern for strings of the form "ft-<chars>:<digits>": the literal prefix
# "ft-", one or more digits / letters a-f / hyphens, a colon, then one or
# more digits, anchored at both ends.
# NOTE(review): presumably used to recognise a fine-tune job ID carrying a
# ":<step>" suffix -- confirm against the call sites.
_FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
@@ -222,68 +214,38 @@ def create_finetune_request(
222
214
return finetune_request
223
215
224
216
225
- def _process_checkpoints_from_events (
226
- events : List [FinetuneEvent ], id : str
217
+ def _parse_raw_checkpoints (
218
+ checkpoints : List [Dict [ str , str ] ], id : str
227
219
) -> List [FinetuneCheckpoint ]:
228
220
"""
229
- Helper function to process events and create checkpoint list.
221
+ Helper function to process raw checkpoints and create checkpoint list.
230
222
231
223
Args:
232
- events (List[FinetuneEvent] ): List of fine-tune events to process
224
+ checkpoints (List[Dict[str, str]] ): List of raw checkpoints metadata
233
225
id (str): Fine-tune job ID
234
226
235
227
Returns:
236
228
List[FinetuneCheckpoint]: List of available checkpoints
237
229
"""
238
- checkpoints : List [FinetuneCheckpoint ] = []
239
-
240
- for event in events :
241
- event_type = event .type
242
-
243
- if event_type == FinetuneEventType .CHECKPOINT_SAVE :
244
- step = get_event_step (event )
245
- checkpoint_name = f"{ id } :{ step } " if step is not None else id
246
-
247
- checkpoints .append (
248
- FinetuneCheckpoint (
249
- type = (
250
- f"Intermediate (step { step } )"
251
- if step is not None
252
- else "Intermediate"
253
- ),
254
- timestamp = event .created_at ,
255
- name = checkpoint_name ,
256
- )
257
- )
258
- elif event_type == FinetuneEventType .JOB_COMPLETE :
259
- if hasattr (event , "model_path" ):
260
- checkpoints .append (
261
- FinetuneCheckpoint (
262
- type = (
263
- "Final Merged"
264
- if hasattr (event , "adapter_path" )
265
- else "Final"
266
- ),
267
- timestamp = event .created_at ,
268
- name = id ,
269
- )
270
- )
271
230
272
- if hasattr (event , "adapter_path" ):
273
- checkpoints .append (
274
- FinetuneCheckpoint (
275
- type = (
276
- "Final Adapter" if hasattr (event , "model_path" ) else "Final"
277
- ),
278
- timestamp = event .created_at ,
279
- name = id ,
280
- )
281
- )
231
+ parsed_checkpoints = []
232
+ for checkpoint in checkpoints :
233
+ step = checkpoint ["step" ]
234
+ checkpoint_type = checkpoint ["checkpoint_type" ]
235
+ checkpoint_name = (
236
+ f"{ id } :{ step } " if "intermediate" in checkpoint_type .lower () else id
237
+ )
282
238
283
- # Sort by timestamp (newest first)
284
- checkpoints .sort (key = lambda x : x .timestamp , reverse = True )
239
+ parsed_checkpoints .append (
240
+ FinetuneCheckpoint (
241
+ type = checkpoint_type ,
242
+ timestamp = checkpoint ["created_at" ],
243
+ name = checkpoint_name ,
244
+ )
245
+ )
285
246
286
- return checkpoints
247
+ parsed_checkpoints .sort (key = lambda x : x .timestamp , reverse = True )
248
+ return parsed_checkpoints
287
249
288
250
289
251
class FineTuning :
@@ -561,8 +523,21 @@ def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
561
523
Returns:
562
524
List[FinetuneCheckpoint]: List of available checkpoints
563
525
"""
564
- events = self .list_events (id ).data or []
565
- return _process_checkpoints_from_events (events , id )
526
+ requestor = api_requestor .APIRequestor (
527
+ client = self ._client ,
528
+ )
529
+
530
+ response , _ , _ = requestor .request (
531
+ options = TogetherRequest (
532
+ method = "GET" ,
533
+ url = f"fine-tunes/{ id } /checkpoints" ,
534
+ ),
535
+ stream = False ,
536
+ )
537
+ assert isinstance (response , TogetherResponse )
538
+
539
+ raw_checkpoints = response .data ["data" ]
540
+ return _parse_raw_checkpoints (raw_checkpoints , id )
566
541
567
542
def download (
568
543
self ,
@@ -936,11 +911,9 @@ async def list_events(self, id: str) -> FinetuneListEvents:
936
911
),
937
912
stream = False ,
938
913
)
914
+ assert isinstance (events_response , TogetherResponse )
939
915
940
- # FIXME: API returns "data" field with no object type (should be "list")
941
- events_list = FinetuneListEvents (object = "list" , ** events_response .data )
942
-
943
- return events_list
916
+ return FinetuneListEvents (** events_response .data )
944
917
945
918
async def list_checkpoints (self , id : str ) -> List [FinetuneCheckpoint ]:
946
919
"""
@@ -950,11 +923,23 @@ async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
950
923
id (str): Unique identifier of the fine-tune job to list checkpoints for
951
924
952
925
Returns:
953
- List[FinetuneCheckpoint]: Object containing list of available checkpoints
926
+ List[FinetuneCheckpoint]: List of available checkpoints
954
927
"""
955
- events_list = await self .list_events (id )
956
- events = events_list .data or []
957
- return _process_checkpoints_from_events (events , id )
928
+ requestor = api_requestor .APIRequestor (
929
+ client = self ._client ,
930
+ )
931
+
932
+ response , _ , _ = await requestor .arequest (
933
+ options = TogetherRequest (
934
+ method = "GET" ,
935
+ url = f"fine-tunes/{ id } /checkpoints" ,
936
+ ),
937
+ stream = False ,
938
+ )
939
+ assert isinstance (response , TogetherResponse )
940
+
941
+ raw_checkpoints = response .data ["data" ]
942
+ return _parse_raw_checkpoints (raw_checkpoints , id )
958
943
959
944
async def download (
960
945
self , id : str , * , output : str | None = None , checkpoint_step : int = - 1
0 commit comments