Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/tinker/lib/public_interfaces/rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def delete_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFutu

parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._delete_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
).future()

@capture_exceptions(fatal=True)
Expand All @@ -418,7 +418,7 @@ async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> No

parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._delete_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
)

def get_telemetry(self) -> Telemetry | None:
Expand All @@ -439,7 +439,7 @@ def get_checkpoint_archive_url_from_tinker_path(
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._get_checkpoint_archive_url_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
).future()

@capture_exceptions(fatal=True)
Expand All @@ -449,7 +449,7 @@ async def get_checkpoint_archive_url_from_tinker_path_async(
"""Async version of get_checkpoint_archive_url_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return await self._get_checkpoint_archive_url_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
)

def _publish_checkpoint_submit(
Expand Down Expand Up @@ -498,15 +498,15 @@ def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFut
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._publish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
).future()

@capture_exceptions(fatal=True)
async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
"""Async version of publish_checkpoint_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._publish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
)

def _unpublish_checkpoint_submit(
Expand Down Expand Up @@ -555,15 +555,15 @@ def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentF
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._unpublish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
).future()

@capture_exceptions(fatal=True)
async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
"""Async version of unpublish_checkpoint_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._unpublish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id
)

def _set_checkpoint_ttl_submit(
Expand Down Expand Up @@ -615,7 +615,7 @@ def set_checkpoint_ttl_from_tinker_path(
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._set_checkpoint_ttl_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id, ttl_seconds
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id, ttl_seconds
).future()

@capture_exceptions(fatal=True)
Expand All @@ -625,7 +625,7 @@ async def set_checkpoint_ttl_from_tinker_path_async(
"""Async version of set_checkpoint_ttl_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._set_checkpoint_ttl_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id, ttl_seconds
parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id, ttl_seconds
)

def _list_user_checkpoints_submit(
Expand Down
13 changes: 11 additions & 2 deletions src/tinker/types/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,14 @@ class ParsedCheckpointTinkerPath(BaseModel):
"""The type of checkpoint (training or sampler)"""

checkpoint_id: str
"""The checkpoint ID"""
"""The checkpoint ID (without the type prefix)"""

@property
def api_checkpoint_id(self) -> str:
"""Returns the checkpoint ID as used in API calls (includes type prefix for sampler)."""
if self.checkpoint_type == "sampler":
return f"sampler_weights/{self.checkpoint_id}"
return self.checkpoint_id

@classmethod
def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath":
Expand All @@ -55,9 +62,11 @@ def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath":
if parts[1] not in ["weights", "sampler_weights"]:
raise ValueError(f"Invalid tinker path: {tinker_path}")
checkpoint_type = "training" if parts[1] == "weights" else "sampler"
# checkpoint_id should be just the ID (e.g., "0001"), not including the type prefix
checkpoint_id = parts[2]
return cls(
tinker_path=tinker_path,
training_run_id=parts[0],
checkpoint_type=checkpoint_type,
checkpoint_id="/".join(parts[1:]),
checkpoint_id=checkpoint_id,
)
59 changes: 59 additions & 0 deletions tests/test_sampler_weights_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Tests for sampler_weights checkpoint loading (Issue #25)."""

import pytest

from tinker.types.checkpoint import ParsedCheckpointTinkerPath


class TestParsedCheckpointTinkerPath:
"""Tests for ParsedCheckpointTinkerPath to ensure sampler_weights loading works correctly."""

def test_parse_weights_checkpoint(self) -> None:
"""Test parsing a weights checkpoint path."""
parsed = ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-123/weights/0001")
assert parsed.training_run_id == "run-123"
assert parsed.checkpoint_type == "training"
assert parsed.checkpoint_id == "0001"
assert parsed.api_checkpoint_id == "0001"

def test_parse_sampler_weights_checkpoint(self) -> None:
"""Test parsing a sampler_weights checkpoint path (Issue #25)."""
parsed = ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-123/sampler_weights/0001")
assert parsed.training_run_id == "run-123"
assert parsed.checkpoint_type == "sampler"
assert parsed.checkpoint_id == "0001"
# This is the key fix: api_checkpoint_id should include sampler_weights prefix
assert parsed.api_checkpoint_id == "sampler_weights/0001"

def test_api_checkpoint_id_for_weights(self) -> None:
"""Test that weights checkpoints don't add prefix to checkpoint_id."""
parsed = ParsedCheckpointTinkerPath.from_tinker_path("tinker://model-789/weights/final")
assert parsed.api_checkpoint_id == "final"

def test_api_checkpoint_id_for_sampler_weights(self) -> None:
"""Test that sampler_weights checkpoints include sampler_weights prefix in api_checkpoint_id."""
parsed = ParsedCheckpointTinkerPath.from_tinker_path("tinker://model-789/sampler_weights/ckpt-001")
assert parsed.api_checkpoint_id == "sampler_weights/ckpt-001"

def test_invalid_tinker_path_no_prefix(self) -> None:
"""Test that paths without tinker:// prefix are rejected."""
with pytest.raises(ValueError, match="Invalid tinker path"):
ParsedCheckpointTinkerPath.from_tinker_path("run-123/weights/0001")

def test_invalid_tinker_path_wrong_type(self) -> None:
"""Test that invalid checkpoint types are rejected."""
with pytest.raises(ValueError, match="Invalid tinker path"):
ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-123/invalid/0001")

def test_invalid_tinker_path_wrong_parts(self) -> None:
"""Test that paths with wrong number of parts are rejected."""
with pytest.raises(ValueError, match="Invalid tinker path"):
ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-123/weights")

def test_sampler_weights_with_custom_name(self) -> None:
"""Test sampler_weights with custom checkpoint name."""
parsed = ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-456/sampler_weights/my-sampler-v2")
assert parsed.training_run_id == "run-456"
assert parsed.checkpoint_type == "sampler"
assert parsed.checkpoint_id == "my-sampler-v2"
assert parsed.api_checkpoint_id == "sampler_weights/my-sampler-v2"