diff --git a/src/tinker/lib/public_interfaces/rest_client.py b/src/tinker/lib/public_interfaces/rest_client.py index f8e0dda..97704e5 100644 --- a/src/tinker/lib/public_interfaces/rest_client.py +++ b/src/tinker/lib/public_interfaces/rest_client.py @@ -409,7 +409,7 @@ def delete_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFutu parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._delete_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ).future() @capture_exceptions(fatal=True) @@ -418,7 +418,7 @@ async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> No parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._delete_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ) def get_telemetry(self) -> Telemetry | None: @@ -439,7 +439,7 @@ def get_checkpoint_archive_url_from_tinker_path( """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._get_checkpoint_archive_url_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ).future() @capture_exceptions(fatal=True) @@ -449,7 +449,7 @@ async def get_checkpoint_archive_url_from_tinker_path_async( """Async version of get_checkpoint_archive_url_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return await self._get_checkpoint_archive_url_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ) def _publish_checkpoint_submit( @@ -498,7 +498,7 @@ def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFut """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._publish_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ).future() @capture_exceptions(fatal=True) @@ -506,7 +506,7 @@ async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> N """Async version of publish_checkpoint_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._publish_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ) def _unpublish_checkpoint_submit( @@ -555,7 +555,7 @@ def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentF """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._unpublish_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ).future() @capture_exceptions(fatal=True) @@ -563,7 +563,7 @@ async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> """Async version of unpublish_checkpoint_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._unpublish_checkpoint_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id ) def _set_checkpoint_ttl_submit( @@ -615,7 +615,7 @@ def set_checkpoint_ttl_from_tinker_path( """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._set_checkpoint_ttl_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id, ttl_seconds + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id, ttl_seconds ).future() @capture_exceptions(fatal=True) @@ -625,7 +625,7 @@ async def set_checkpoint_ttl_from_tinker_path_async( """Async version of set_checkpoint_ttl_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._set_checkpoint_ttl_submit( - parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id, ttl_seconds + parsed_tinker_path.training_run_id, parsed_tinker_path.api_checkpoint_id, ttl_seconds ) def _list_user_checkpoints_submit( diff --git a/src/tinker/types/checkpoint.py b/src/tinker/types/checkpoint.py index b6a5630..0c91970 100644 --- a/src/tinker/types/checkpoint.py +++ b/src/tinker/types/checkpoint.py @@ -39,10 +39,24 @@ class ParsedCheckpointTinkerPath(BaseModel): """The training run ID""" checkpoint_type: CheckpointType - """The type of checkpoint (training or sampler)""" + """"The type of checkpoint (training or sampler)""" checkpoint_id: str - """The checkpoint ID""" + """The checkpoint ID (includes type prefix for sampler checkpoints)""" + + @property + def api_checkpoint_id(self) -> str: + """Checkpoint ID formatted for API calls. + + For training checkpoints: returns just the checkpoint number (e.g., '0001'). + For sampler checkpoints: returns the prefixed ID (e.g., 'sampler_weights/0001'). + """ + if self.checkpoint_type == "training": + # For training checkpoints, extract just the ID number (e.g., "0001" from "weights/0001") + return self.checkpoint_id.split("/")[-1] + else: + # For sampler checkpoints, use the full checkpoint_id including prefix + return self.checkpoint_id @classmethod def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath": diff --git a/tests/test_sampler_weights_loading.py b/tests/test_sampler_weights_loading.py new file mode 100644 index 0000000..16fc173 --- /dev/null +++ b/tests/test_sampler_weights_loading.py @@ -0,0 +1,61 @@ +"""Tests for sampler_weights checkpoint loading functionality.""" + +import pytest +from tinker.types.checkpoint import ParsedCheckpointTinkerPath + + +class TestParsedCheckpointTinkerPath: + """Test parsing of checkpoint tinker paths.""" + + def test_parse_weights_checkpoint_path(self): + """Test parsing a weights checkpoint path.""" + path = "tinker://run-id123/weights/0001" + parsed = ParsedCheckpointTinkerPath.from_tinker_path(path) + + assert parsed.tinker_path == path + assert parsed.training_run_id == "run-id123" + assert parsed.checkpoint_type == "training" + assert parsed.checkpoint_id == "weights/0001" + assert parsed.api_checkpoint_id == "0001" + + def test_parse_sampler_weights_checkpoint_path(self): + """Test parsing a sampler_weights checkpoint path.""" + path = "tinker://run-id456/sampler_weights/sampler-001" + parsed = ParsedCheckpointTinkerPath.from_tinker_path(path) + + assert parsed.tinker_path == path + assert parsed.training_run_id == "run-id456" + assert parsed.checkpoint_type == "sampler" + assert parsed.checkpoint_id == "sampler_weights/sampler-001" + assert parsed.api_checkpoint_id == "sampler_weights/sampler-001" + + def test_api_checkpoint_id_for_training(self): + """Test api_checkpoint_id returns correct format for training checkpoints.""" + path = "tinker://abc123/weights/0100" + parsed = ParsedCheckpointTinkerPath.from_tinker_path(path) + + # For training, should return just the checkpoint number + assert parsed.api_checkpoint_id == "0100" + + def test_api_checkpoint_id_for_sampler(self): + """Test api_checkpoint_id returns correct format for sampler checkpoints.""" + path = "tinker://xyz789/sampler_weights/final-model" + parsed = ParsedCheckpointTinkerPath.from_tinker_path(path) + + # For sampler, should return the full prefixed ID + assert parsed.api_checkpoint_id == "sampler_weights/final-model" + + def test_invalid_tinker_path_no_prefix(self): + """Test that invalid paths without tinker:// prefix raise error.""" + with pytest.raises(ValueError, match="Invalid tinker path"): + ParsedCheckpointTinkerPath.from_tinker_path("run-id/weights/0001") + + def test_invalid_tinker_path_wrong_parts(self): + """Test that paths with wrong number of parts raise error.""" + with pytest.raises(ValueError, match="Invalid tinker path"): + ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-id/weights") + + def test_invalid_checkpoint_type(self): + """Test that invalid checkpoint type raises error.""" + with pytest.raises(ValueError, match="Invalid tinker path"): + ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-id/invalid/0001") \ No newline at end of file