From e89fb6b5933982a63daeba9e844ea3d4abe4550a Mon Sep 17 00:00:00 2001 From: RoomWithOutRoof Date: Thu, 16 Apr 2026 02:16:47 +0800 Subject: [PATCH] fix: use api_checkpoint_id for sampler_weights loading - Added api_checkpoint_id property to ParsedCheckpointTinkerPath - For training checkpoints: returns just the checkpoint number (e.g., '0001') - For sampler checkpoints: returns prefixed ID (e.g., 'sampler_weights/0001') - Updated rest_client.py to use api_checkpoint_id instead of checkpoint_id --- .../lib/public_interfaces/rest_client.py | 20 +++++++++---------- src/tinker/types/checkpoint.py | 14 +++++++++++++ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/tinker/lib/public_interfaces/rest_client.py b/src/tinker/lib/public_interfaces/rest_client.py index f8e0dda..97704e5 100644 --- a/src/tinker/lib/public_interfaces/rest_client.py +++ b/src/tinker/lib/public_interfaces/rest_client.py @@ -409,7 +409,7 @@ def delete_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFutu parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._delete_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ).future() @capture_exceptions(fatal=True) @@ -418,7 +418,7 @@ async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> No parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._delete_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ) def get_telemetry(self) -> Telemetry | None: @@ -439,7 +439,7 @@ def get_checkpoint_archive_url_from_tinker_path( """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._get_checkpoint_archive_url_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ).future() @capture_exceptions(fatal=True) @@ -449,7 +449,7 @@ async def get_checkpoint_archive_url_from_tinker_path_async( """Async version of get_checkpoint_archive_url_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return await self._get_checkpoint_archive_url_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ) def _publish_checkpoint_submit( @@ -498,7 +498,7 @@ def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFut """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._publish_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ).future() @capture_exceptions(fatal=True) @@ -506,7 +506,7 @@ async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> N """Async version of publish_checkpoint_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._publish_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ) def _unpublish_checkpoint_submit( @@ -555,7 +555,7 @@ def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentF """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._unpublish_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ).future() @capture_exceptions(fatal=True) @@ -563,7 +563,7 @@ async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> """Async version of unpublish_checkpoint_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._unpublish_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ) def _set_checkpoint_ttl_submit( @@ -615,7 +615,7 @@ def set_checkpoint_ttl_from_tinker_path( """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._set_checkpoint_ttl_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id, ttl_seconds + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id, ttl_seconds ).future() @capture_exceptions(fatal=True) @@ -625,7 +625,7 @@ async def set_checkpoint_ttl_from_tinker_path_async( """Async version of set_checkpoint_ttl_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._set_checkpoint_ttl_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id, ttl_seconds + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id, ttl_seconds ) def _list_user_checkpoints_submit( diff --git a/src/tinker/types/checkpoint.py b/src/tinker/types/checkpoint.py index b6a5630..8144f09 100644 --- a/src/tinker/types/checkpoint.py +++ b/src/tinker/types/checkpoint.py @@ -44,6 +44,20 @@ class ParsedCheckpointTinkerPath(BaseModel): checkpoint_id: str """The checkpoint ID""" + @property + def api_checkpoint_id(self) -> str: + """Return the checkpoint ID formatted for API calls. + + For training checkpoints: returns just the checkpoint number (e.g., '0001'). + For sampler checkpoints: returns prefixed ID (e.g., 'sampler_weights/0001'). + """ + if self.checkpoint_type == "training": + # Training checkpoints use just the number + return self.checkpoint_id.split("/")[-1] + else: + # Sampler checkpoints include the prefix + return self.checkpoint_id + @classmethod def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath": """Parse a tinker path to an instance of ParsedCheckpointTinkerPath"""