Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/tinker/lib/public_interfaces/rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def delete_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFutu

parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._delete_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
).future()

@capture_exceptions(fatal=True)
Expand All @@ -418,7 +418,7 @@ async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> No

parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._delete_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
)

def get_telemetry(self) -> Telemetry | None:
Expand All @@ -439,7 +439,7 @@ def get_checkpoint_archive_url_from_tinker_path(
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._get_checkpoint_archive_url_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
).future()

@capture_exceptions(fatal=True)
Expand All @@ -449,7 +449,7 @@ async def get_checkpoint_archive_url_from_tinker_path_async(
"""Async version of get_checkpoint_archive_url_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return await self._get_checkpoint_archive_url_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
)

def _publish_checkpoint_submit(
Expand Down Expand Up @@ -498,15 +498,15 @@ def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFut
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._publish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
).future()

@capture_exceptions(fatal=True)
async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
"""Async version of publish_checkpoint_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._publish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
)

def _unpublish_checkpoint_submit(
Expand Down Expand Up @@ -555,15 +555,15 @@ def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentF
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._unpublish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
).future()

@capture_exceptions(fatal=True)
async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
"""Async version of unpublish_checkpoint_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._unpublish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
)

def _set_checkpoint_ttl_submit(
Expand Down Expand Up @@ -615,7 +615,7 @@ def set_checkpoint_ttl_from_tinker_path(
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._set_checkpoint_ttl_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id, ttl_seconds
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id, ttl_seconds
).future()

@capture_exceptions(fatal=True)
Expand All @@ -625,7 +625,7 @@ async def set_checkpoint_ttl_from_tinker_path_async(
"""Async version of set_checkpoint_ttl_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._set_checkpoint_ttl_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id, ttl_seconds
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id, ttl_seconds
)

def _list_user_checkpoints_submit(
Expand Down
18 changes: 16 additions & 2 deletions src/tinker/types/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,24 @@ class ParsedCheckpointTinkerPath(BaseModel):
"""The training run ID"""

checkpoint_type: CheckpointType
"""The type of checkpoint (training or sampler)"""
""""The type of checkpoint (training or sampler)"""

checkpoint_id: str
"""The checkpoint ID"""
"""The checkpoint ID (includes type prefix for sampler checkpoints)"""

@property
def api_checkpoint_id(self) -> str:
"""Checkpoint ID formatted for API calls.

For training checkpoints: returns just the checkpoint number (e.g., '0001').
For sampler checkpoints: returns the prefixed ID (e.g., 'sampler_weights/0001').
"""
if self.checkpoint_type == "training":
# For training checkpoints, extract just the ID number (e.g., "0001" from "weights/0001")
return self.checkpoint_id.split("/")[-1]
else:
# For sampler checkpoints, use the full checkpoint_id including prefix
return self.checkpoint_id

@classmethod
def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath":
Expand Down
61 changes: 61 additions & 0 deletions tests/test_sampler_weights_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Tests for sampler_weights checkpoint loading functionality."""

import pytest
from tinker.types.checkpoint import ParsedCheckpointTinkerPath


class TestParsedCheckpointTinkerPath:
"""Test parsing of checkpoint tinker paths."""

def test_parse_weights_checkpoint_path(self):
"""Test parsing a weights checkpoint path."""
path = "tinker://run-id123/weights/0001"
parsed = ParsedCheckpointTinkerPath.from_tinker_path(path)

assert parsed.tinker_path == path
assert parsed.training_run_id == "run-id123"
assert parsed.checkpoint_type == "training"
assert parsed.checkpoint_id == "weights/0001"
assert parsed.api_checkpoint_id == "0001"

def test_parse_sampler_weights_checkpoint_path(self):
"""Test parsing a sampler_weights checkpoint path."""
path = "tinker://run-id456/sampler_weights/sampler-001"
parsed = ParsedCheckpointTinkerPath.from_tinker_path(path)

assert parsed.tinker_path == path
assert parsed.training_run_id == "run-id456"
assert parsed.checkpoint_type == "sampler"
assert parsed.checkpoint_id == "sampler_weights/sampler-001"
assert parsed.api_checkpoint_id == "sampler_weights/sampler-001"

def test_api_checkpoint_id_for_training(self):
"""Test api_checkpoint_id returns correct format for training checkpoints."""
path = "tinker://abc123/weights/0100"
parsed = ParsedCheckpointTinkerPath.from_tinker_path(path)

# For training, should return just the checkpoint number
assert parsed.api_checkpoint_id == "0100"

def test_api_checkpoint_id_for_sampler(self):
"""Test api_checkpoint_id returns correct format for sampler checkpoints."""
path = "tinker://xyz789/sampler_weights/final-model"
parsed = ParsedCheckpointTinkerPath.from_tinker_path(path)

# For sampler, should return the full prefixed ID
assert parsed.api_checkpoint_id == "sampler_weights/final-model"

def test_invalid_tinker_path_no_prefix(self):
"""Test that invalid paths without tinker:// prefix raise error."""
with pytest.raises(ValueError, match="Invalid tinker path"):
ParsedCheckpointTinkerPath.from_tinker_path("run-id/weights/0001")

def test_invalid_tinker_path_wrong_parts(self):
"""Test that paths with wrong number of parts raise error."""
with pytest.raises(ValueError, match="Invalid tinker path"):
ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-id/weights")

def test_invalid_checkpoint_type(self):
"""Test that invalid checkpoint type raises error."""
with pytest.raises(ValueError, match="Invalid tinker path"):
ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-id/invalid/0001")