diff --git a/src/tinker/lib/public_interfaces/rest_client.py b/src/tinker/lib/public_interfaces/rest_client.py index f8e0dda..97704e5 100644 --- a/src/tinker/lib/public_interfaces/rest_client.py +++ b/src/tinker/lib/public_interfaces/rest_client.py @@ -409,7 +409,7 @@ def delete_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFutu parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._delete_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ).future() @capture_exceptions(fatal=True) @@ -418,7 +418,7 @@ async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> No parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._delete_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ) def get_telemetry(self) -> Telemetry | None: @@ -439,7 +439,7 @@ def get_checkpoint_archive_url_from_tinker_path( """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._get_checkpoint_archive_url_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ).future() @capture_exceptions(fatal=True) @@ -449,7 +449,7 @@ async def get_checkpoint_archive_url_from_tinker_path_async( """Async version of get_checkpoint_archive_url_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return await self._get_checkpoint_archive_url_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ) def _publish_checkpoint_submit( @@ -498,7 +498,7 @@ def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFut """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._publish_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ).future() @capture_exceptions(fatal=True) @@ -506,7 +506,7 @@ async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> N """Async version of publish_checkpoint_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._publish_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ) def _unpublish_checkpoint_submit( @@ -555,7 +555,7 @@ def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentF """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._unpublish_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ).future() @capture_exceptions(fatal=True) @@ -563,7 +563,7 @@ async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> """Async version of unpublish_checkpoint_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._unpublish_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ) def _set_checkpoint_ttl_submit( @@ -615,7 +615,7 @@ def set_checkpoint_ttl_from_tinker_path( """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._set_checkpoint_ttl_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id, ttl_seconds + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id, ttl_seconds ).future() @capture_exceptions(fatal=True) @@ -625,7 +625,7 @@ async def set_checkpoint_ttl_from_tinker_path_async( """Async version of set_checkpoint_ttl_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._set_checkpoint_ttl_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id, ttl_seconds + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id, ttl_seconds ) def _list_user_checkpoints_submit( diff --git a/src/tinker/types/checkpoint.py b/src/tinker/types/checkpoint.py index b6a5630..b717dd5 100644 --- a/src/tinker/types/checkpoint.py +++ b/src/tinker/types/checkpoint.py @@ -42,7 +42,14 @@ class ParsedCheckpointTinkerPath(BaseModel): """The type of checkpoint (training or sampler)""" checkpoint_id: str - """The checkpoint ID""" + """The checkpoint ID (without the type prefix)""" + + @property + def api_checkpoint_id(self) -> str: + """Returns the checkpoint ID as used in API calls (includes type prefix for sampler).""" + if self.checkpoint_type == "sampler": + return f"sampler_weights/{self.checkpoint_id}" + return self.checkpoint_id @classmethod def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath": @@ -55,9 +62,11 @@ def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath": if parts[1] not in ["weights", "sampler_weights"]: raise ValueError(f"Invalid tinker path: {tinker_path}") checkpoint_type = "training" if parts[1] == "weights" else "sampler" + # checkpoint_id should be just the ID (e.g., "0001"), not including the type prefix + checkpoint_id = parts[2] return cls( tinker_path=tinker_path, training_run_id=parts[0], checkpoint_type=checkpoint_type, - checkpoint_id="/".join(parts[1:]), + checkpoint_id=checkpoint_id, ) diff --git a/tests/test_sampler_weights_loading.py b/tests/test_sampler_weights_loading.py new file mode 100644 index 0000000..34518fd --- /dev/null +++ b/tests/test_sampler_weights_loading.py @@ -0,0 +1,59 @@ +"""Tests for sampler_weights checkpoint loading (Issue #25).""" + +import pytest + +from tinker.types.checkpoint import ParsedCheckpointTinkerPath + + +class TestParsedCheckpointTinkerPath: + """Tests for ParsedCheckpointTinkerPath to ensure sampler_weights loading works correctly.""" + + def test_parse_weights_checkpoint(self) -> None: + """Test parsing a weights checkpoint path.""" + parsed = ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-123/weights/0001") + assert parsed.training_run_id == "run-123" + assert parsed.checkpoint_type == "training" + assert parsed.checkpoint_id == "0001" + assert parsed.api_checkpoint_id == "0001" + + def test_parse_sampler_weights_checkpoint(self) -> None: + """Test parsing a sampler_weights checkpoint path (Issue #25).""" + parsed = ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-123/sampler_weights/0001") + assert parsed.training_run_id == "run-123" + assert parsed.checkpoint_type == "sampler" + assert parsed.checkpoint_id == "0001" + # This is the key fix: api_checkpoint_id should include sampler_weights prefix + assert parsed.api_checkpoint_id == "sampler_weights/0001" + + def test_api_checkpoint_id_for_weights(self) -> None: + """Test that weights checkpoints don't add prefix to checkpoint_id.""" + parsed = ParsedCheckpointTinkerPath.from_tinker_path("tinker://model-789/weights/final") + assert parsed.api_checkpoint_id == "final" + + def test_api_checkpoint_id_for_sampler_weights(self) -> None: + """Test that sampler_weights checkpoints include sampler_weights prefix in api_checkpoint_id.""" + parsed = ParsedCheckpointTinkerPath.from_tinker_path("tinker://model-789/sampler_weights/ckpt-001") + assert parsed.api_checkpoint_id == "sampler_weights/ckpt-001" + + def test_invalid_tinker_path_no_prefix(self) -> None: + """Test that paths without tinker:// prefix are rejected.""" + with pytest.raises(ValueError, match="Invalid tinker path"): + ParsedCheckpointTinkerPath.from_tinker_path("run-123/weights/0001") + + def test_invalid_tinker_path_wrong_type(self) -> None: + """Test that invalid checkpoint types are rejected.""" + with pytest.raises(ValueError, match="Invalid tinker path"): + ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-123/invalid/0001") + + def test_invalid_tinker_path_wrong_parts(self) -> None: + """Test that paths with wrong number of parts are rejected.""" + with pytest.raises(ValueError, match="Invalid tinker path"): + ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-123/weights") + + def test_sampler_weights_with_custom_name(self) -> None: + """Test sampler_weights with custom checkpoint name.""" + parsed = ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-456/sampler_weights/my-sampler-v2") + assert parsed.training_run_id == "run-456" + assert parsed.checkpoint_type == "sampler" + assert parsed.checkpoint_id == "my-sampler-v2" + assert parsed.api_checkpoint_id == "sampler_weights/my-sampler-v2" \ No newline at end of file