Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/tinker/lib/public_interfaces/rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def delete_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFutu

parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._delete_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
).future()

@capture_exceptions(fatal=True)
Expand All @@ -418,7 +418,7 @@ async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> No

parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._delete_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
)

def get_telemetry(self) -> Telemetry | None:
Expand All @@ -439,7 +439,7 @@ def get_checkpoint_archive_url_from_tinker_path(
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._get_checkpoint_archive_url_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
).future()

@capture_exceptions(fatal=True)
Expand All @@ -449,7 +449,7 @@ async def get_checkpoint_archive_url_from_tinker_path_async(
"""Async version of get_checkpoint_archive_url_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return await self._get_checkpoint_archive_url_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
)

def _publish_checkpoint_submit(
Expand Down Expand Up @@ -498,15 +498,15 @@ def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFut
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._publish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
).future()

@capture_exceptions(fatal=True)
async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
"""Async version of publish_checkpoint_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._publish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
)

def _unpublish_checkpoint_submit(
Expand Down Expand Up @@ -555,15 +555,15 @@ def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentF
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._unpublish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
).future()

@capture_exceptions(fatal=True)
async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
"""Async version of unpublish_checkpoint_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._unpublish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
)

def _set_checkpoint_ttl_submit(
Expand Down Expand Up @@ -615,7 +615,7 @@ def set_checkpoint_ttl_from_tinker_path(
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._set_checkpoint_ttl_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id, ttl_seconds
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id, ttl_seconds
).future()

@capture_exceptions(fatal=True)
Expand All @@ -625,7 +625,7 @@ async def set_checkpoint_ttl_from_tinker_path_async(
"""Async version of set_checkpoint_ttl_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._set_checkpoint_ttl_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id, ttl_seconds
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id, ttl_seconds
)

def _list_user_checkpoints_submit(
Expand Down
14 changes: 14 additions & 0 deletions src/tinker/types/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ class ParsedCheckpointTinkerPath(BaseModel):
checkpoint_id: str
"""The checkpoint ID"""

@property
def api_checkpoint_id(self) -> str:
"""Return the checkpoint ID formatted for API calls.

For training checkpoints: returns just the checkpoint number (e.g., '0001').
For sampler checkpoints: returns prefixed ID (e.g., 'sampler_weights/0001').
"""
if self.checkpoint_type == "training":
# Training checkpoints use just the number
return self.checkpoint_id.split("/")[-1]
else:
# Sampler checkpoints include the prefix
return self.checkpoint_id

@classmethod
def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath":
"""Parse a tinker path to an instance of ParsedCheckpointTinkerPath"""
Expand Down