Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 75 additions & 45 deletions src/tinker/cli/commands/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,12 @@ def _export_checkpoint_to_hub(
create_pr: bool,
exist_ok: bool,
allow_patterns: list[str] | None,
ignore_patterns: list[str] | None,
add_model_card: bool,
overwrite: bool,
) -> str:
# Lazy imports to keep CLI startup fast
try:
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub import HfApi
except ImportError as exc:
raise TinkerCliError(
"huggingface_hub is required for this command.",
Expand All @@ -349,7 +349,6 @@ def _export_checkpoint_to_hub(

import json
import os
import re
import tempfile
from pathlib import Path

Expand Down Expand Up @@ -493,47 +492,77 @@ def _sanitize_repo_name(value: str) -> str:
model_card.append("")
readme_path.write_text("\n".join(model_card), encoding="utf-8")

api.create_repo(repo_id=repo_id, private=private, exist_ok=exist_ok)
if api.repo_exists(repo_id=repo_id):
repo_was_created = False
if not exist_ok:
raise TinkerCliError(f"Repository {repo_id} already exists")
else:
api.create_repo(repo_id=repo_id, private=private, exist_ok=False)
repo_was_created = True

def _readme_tinker_path() -> str | None:
try:
readme_file = hf_hub_download(
repo_id=repo_id,
filename="README.md",
revision=revision,
token=None,
)
except Exception:
return None
def _get_branch_names() -> set[str]:
try:
text = Path(readme_file).read_text(encoding="utf-8", errors="ignore")
refs = api.list_repo_refs(repo_id=repo_id)
return {ref.name for ref in refs.branches}
except Exception:
return None
match = re.search(r"tinker://[^\s`]+", text)
return match.group(0) if match else None
return set()

existing_tinker_path = _readme_tinker_path()
if existing_tinker_path and existing_tinker_path != tinker_path:
try:
branch_names = _get_branch_names()
target_revision = revision or "main"
uploaded_to_target = False

if "main" not in branch_names:
if create_pr:
click.echo(
"Warning: --create-pr was requested, but this upload is creating a new repository. "
"Uploading content to 'main' without a PR.",
err=True,
)
api.upload_folder(
folder_path=os.fspath(extract_dir),
repo_id=repo_id,
path_in_repo="",
revision="main",
commit_message=commit_message or "Upload checkpoint from Tinker",
create_pr=False,
allow_patterns=list(allow_patterns) if allow_patterns else None,
)
branch_names = _get_branch_names()

if target_revision == "main":
uploaded_to_target = True
else:
if target_revision not in branch_names:
api.create_branch(repo_id=repo_id, branch=target_revision, exist_ok=True)
uploaded_to_target = True
elif target_revision != "main" and target_revision not in branch_names:
api.create_branch(repo_id=repo_id, branch=target_revision, exist_ok=True)
except Exception as exc:
raise TinkerCliError(
"Repo ID appears to contain a different Tinker checkpoint.",
f"Found {existing_tinker_path}, expected {tinker_path}.",
)
f"Failed to prepare revision {revision or 'main'} in repo {repo_id}",
f"Error: {exc}",
) from exc

if allow_patterns is None:
ignore_patterns = list(ignore_patterns) if ignore_patterns else []
if "checkpoint_complete" not in ignore_patterns:
ignore_patterns.append("checkpoint_complete")

api.upload_folder(
folder_path=os.fspath(extract_dir),
repo_id=repo_id,
path_in_repo="",
revision=revision,
commit_message=commit_message,
create_pr=create_pr,
allow_patterns=list(allow_patterns) if allow_patterns else None,
ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
)
checkpoint_complete.unlink(missing_ok=True)

if not uploaded_to_target:
if not repo_was_created and not overwrite and target_revision in branch_names:
raise TinkerCliError(
f"Branch '{target_revision}' already exists in repo {repo_id}",
"Use --overwrite to add a new commit to the existing branch, "
"or specify a different --revision",
)
api.upload_folder(
folder_path=os.fspath(extract_dir),
repo_id=repo_id,
path_in_repo="",
revision=target_revision,
commit_message=commit_message,
create_pr=create_pr,
allow_patterns=list(allow_patterns) if allow_patterns else None,
)

return repo_id

Expand Down Expand Up @@ -1003,17 +1032,16 @@ def download(
multiple=True,
help="Only upload files matching this pattern (can be repeated).",
)
@click.option(
"--ignore-pattern",
"ignore_patterns",
multiple=True,
help="Skip files matching this pattern (can be repeated).",
)
@click.option(
"--no-model-card",
is_flag=True,
help="Do not create a README.md model card if one is missing.",
)
@click.option(
"--overwrite",
is_flag=True,
help="Allow uploading a new commit to an existing branch (default: False).",
)
@click.pass_obj
@handle_api_errors
def push_hf(
Expand All @@ -1025,12 +1053,14 @@ def push_hf(
commit_message: str | None,
create_pr: bool,
allow_patterns: tuple[str, ...],
ignore_patterns: tuple[str, ...],
no_model_card: bool,
overwrite: bool,
) -> None:
"""Upload a checkpoint to the Hugging Face Hub as a PEFT adapter.

CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/sampler_weights/0001).
If --overwrite is set and the target branch exists, this command uploads a new commit
to that branch (it does not replace branch history).
"""
# Validate it's a tinker path
if not checkpoint_path.startswith("tinker://"):
Expand All @@ -1050,8 +1080,8 @@ def push_hf(
create_pr=create_pr,
exist_ok=True,
allow_patterns=list(allow_patterns) if allow_patterns else None,
ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
add_model_card=not no_model_card,
overwrite=overwrite,
)

output_obj = CheckpointHubUploadOutput(
Expand Down