diff --git a/src/tinker/cli/commands/checkpoint.py b/src/tinker/cli/commands/checkpoint.py index 3fdf87b..4102de9 100644 --- a/src/tinker/cli/commands/checkpoint.py +++ b/src/tinker/cli/commands/checkpoint.py @@ -335,12 +335,12 @@ def _export_checkpoint_to_hub( create_pr: bool, exist_ok: bool, allow_patterns: list[str] | None, - ignore_patterns: list[str] | None, add_model_card: bool, + overwrite: bool, ) -> str: # Lazy imports to keep CLI startup fast try: - from huggingface_hub import HfApi, hf_hub_download + from huggingface_hub import HfApi except ImportError as exc: raise TinkerCliError( "huggingface_hub is required for this command.", @@ -349,7 +349,6 @@ def _export_checkpoint_to_hub( import json import os - import re import tempfile from pathlib import Path @@ -493,47 +492,77 @@ def _sanitize_repo_name(value: str) -> str: model_card.append("") readme_path.write_text("\n".join(model_card), encoding="utf-8") - api.create_repo(repo_id=repo_id, private=private, exist_ok=exist_ok) + if api.repo_exists(repo_id=repo_id): + repo_was_created = False + if not exist_ok: + raise TinkerCliError(f"Repository {repo_id} already exists") + else: + api.create_repo(repo_id=repo_id, private=private, exist_ok=False) + repo_was_created = True - def _readme_tinker_path() -> str | None: - try: - readme_file = hf_hub_download( - repo_id=repo_id, - filename="README.md", - revision=revision, - token=None, - ) - except Exception: - return None + def _get_branch_names() -> set[str]: try: - text = Path(readme_file).read_text(encoding="utf-8", errors="ignore") + refs = api.list_repo_refs(repo_id=repo_id) + return {ref.name for ref in refs.branches} except Exception: - return None - match = re.search(r"tinker://[^\s`]+", text) - return match.group(0) if match else None + return set() - existing_tinker_path = _readme_tinker_path() - if existing_tinker_path and existing_tinker_path != tinker_path: + try: + branch_names = _get_branch_names() + target_revision = revision or "main" + uploaded_to_target = False + + if "main" not in branch_names: + if create_pr: + click.echo( + "Warning: --create-pr was requested, but this upload is creating a new repository. " + "Uploading content to 'main' without a PR.", + err=True, + ) + api.upload_folder( + folder_path=os.fspath(extract_dir), + repo_id=repo_id, + path_in_repo="", + revision="main", + commit_message=commit_message or "Upload checkpoint from Tinker", + create_pr=False, + allow_patterns=list(allow_patterns) if allow_patterns else None, + ) + branch_names = _get_branch_names() + + if target_revision == "main": + uploaded_to_target = True + else: + if target_revision not in branch_names: + api.create_branch(repo_id=repo_id, branch=target_revision, exist_ok=True) + uploaded_to_target = True + elif target_revision != "main" and target_revision not in branch_names: + api.create_branch(repo_id=repo_id, branch=target_revision, exist_ok=True) + except Exception as exc: raise TinkerCliError( - "Repo ID appears to contain a different Tinker checkpoint.", - f"Found {existing_tinker_path}, expected {tinker_path}.", - ) + f"Failed to prepare revision {revision or 'main'} in repo {repo_id}", + f"Error: {exc}", + ) from exc if allow_patterns is None: - ignore_patterns = list(ignore_patterns) if ignore_patterns else [] - if "checkpoint_complete" not in ignore_patterns: - ignore_patterns.append("checkpoint_complete") - - api.upload_folder( - folder_path=os.fspath(extract_dir), - repo_id=repo_id, - path_in_repo="", - revision=revision, - commit_message=commit_message, - create_pr=create_pr, - allow_patterns=list(allow_patterns) if allow_patterns else None, - ignore_patterns=list(ignore_patterns) if ignore_patterns else None, - ) + checkpoint_complete.unlink(missing_ok=True) + + if not uploaded_to_target: + if not repo_was_created and not overwrite and target_revision in branch_names: + raise TinkerCliError( + f"Branch '{target_revision}' already exists in repo {repo_id}", + "Use --overwrite to add a new commit to the existing branch, " + "or specify a different --revision", + ) + api.upload_folder( + folder_path=os.fspath(extract_dir), + repo_id=repo_id, + path_in_repo="", + revision=target_revision, + commit_message=commit_message, + create_pr=create_pr, + allow_patterns=list(allow_patterns) if allow_patterns else None, + ) return repo_id @@ -1003,17 +1032,16 @@ def download( multiple=True, help="Only upload files matching this pattern (can be repeated).", ) -@click.option( - "--ignore-pattern", - "ignore_patterns", - multiple=True, - help="Skip files matching this pattern (can be repeated).", -) @click.option( "--no-model-card", is_flag=True, help="Do not create a README.md model card if one is missing.", ) +@click.option( + "--overwrite", + is_flag=True, + help="Allow uploading a new commit to an existing branch (default: False).", +) @click.pass_obj @handle_api_errors def push_hf( @@ -1025,12 +1053,14 @@ def push_hf( commit_message: str | None, create_pr: bool, allow_patterns: tuple[str, ...], - ignore_patterns: tuple[str, ...], no_model_card: bool, + overwrite: bool, ) -> None: """Upload a checkpoint to the Hugging Face Hub as a PEFT adapter. CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/sampler_weights/0001). + If --overwrite is set and the target branch exists, this command uploads a new commit + to that branch (it does not replace branch history). """ # Validate it's a tinker path if not checkpoint_path.startswith("tinker://"): @@ -1050,8 +1080,8 @@ def push_hf( create_pr=create_pr, exist_ok=True, allow_patterns=list(allow_patterns) if allow_patterns else None, - ignore_patterns=list(ignore_patterns) if ignore_patterns else None, add_model_card=not no_model_card, + overwrite=overwrite, ) output_obj = CheckpointHubUploadOutput(