feat: safetensors checkpoint support for Torch backend #227
Merged
nikopueringer merged 7 commits into nikopueringer:main on Apr 17, 2026
Conversation
Adds the safetensors library as a core dependency to support loading .safetensors model checkpoints. Used by the Torch backend's new loader path and by the scripts/convert_pth_to_safetensors.py maintainer tool. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Teaches the Torch backend to prefer .safetensors over the legacy .pth:
- _discover_checkpoint() scans for .safetensors first, then .pth. If both exist, .safetensors wins and the choice is logged at INFO.
- _ensure_torch_checkpoint() tries CorridorKey.safetensors from HF first and falls back to CorridorKey.pth on EntryNotFoundError, so the code can ship before the HF upload lands without breaking first-run users.
- Existing .pth-only installs continue to work unchanged: no forced re-download, no migration required.
Fallback branches are tagged `# DEPRECATED: remove after .pth sunset` for a clean grep when the .pth path is eventually removed.
Tests cover: safetensors preferred over .pth, .pth-only still works, safetensors-only works, and HF EntryNotFoundError triggers the .pth fallback. The PBT strategy for "non-Torch extensions" drops .safetensors since it is now a valid Torch checkpoint.
Co-Authored-By: Claude Opus 4.6 <[email protected]>
Branches on the checkpoint extension when building the state dict: .safetensors goes through safetensors.torch.load_file, .pth stays on torch.load(weights_only=True). The downstream state-dict processing (_orig_mod. prefix strip, pos_embed bicubic reinterpolation, load_state_dict(strict=False)) is format-blind and unchanged. Tests verify the dispatch: a safetensors path doesn't invoke torch.load, and a .pth path doesn't invoke safetensors.load_file. The .pth branch is tagged for removal after the .pth sunset. Co-Authored-By: Claude Opus 4.6 <[email protected]>
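A minimal sketch of that extension dispatch; the function names here are assumptions, while the two loader calls (`safetensors.torch.load_file` and `torch.load(weights_only=True)`) are the ones the commit names:

```python
def checkpoint_format(path: str) -> str:
    """Classify a checkpoint path as 'safetensors' or legacy 'pth'."""
    if path.endswith(".safetensors"):
        return "safetensors"
    if path.endswith(".pth"):
        return "pth"  # DEPRECATED: remove after .pth sunset
    raise ValueError(f"unsupported checkpoint extension: {path!r}")

def load_checkpoint(path: str) -> dict:
    """Load a state dict via the format-appropriate loader."""
    if checkpoint_format(path) == "safetensors":
        from safetensors.torch import load_file  # no pickle execution
        return load_file(path)
    import torch
    return torch.load(path, weights_only=True)
```

Keeping the loader imports inside the branches means a safetensors load never touches `torch.load` and vice versa, which is exactly what the dispatch tests assert.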
A maintainer utility for (re)publishing the official CorridorKey
safetensors file and for contributors converting custom fine-tunes.
Does not change tensor values — only re-serializes. Matches the
engine's state-dict unwrapping (strips "state_dict" wrapper and
_orig_mod. prefix left by torch.compile), enforces contiguous
tensors (required by safetensors), and verifies the round-trip by
reloading and diffing key sets and shapes against the source.
Usage:
  uv run python scripts/convert_pth_to_safetensors.py \
    --input CorridorKeyModule/checkpoints/CorridorKey.pth \
    --output CorridorKeyModule/checkpoints/CorridorKey.safetensors
Co-Authored-By: Claude Opus 4.6 <[email protected]>
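The unwrapping steps the tool performs could be sketched as below. The helper name and the injected `is_tensor` predicate are assumptions (the real script presumably uses `torch.is_tensor`); the predicate is injected here only so the sketch stays framework-free:

```python
def extract_state_dict(checkpoint: dict, is_tensor) -> dict:
    """Unwrap, strip torch.compile prefixes, drop non-tensors, force contiguity."""
    # Some checkpoints nest weights under a "state_dict" wrapper key.
    state = checkpoint.get("state_dict", checkpoint)
    out = {}
    for key, value in state.items():
        if not is_tensor(value):
            continue  # drop optimizer state, step counters, other metadata
        key = key.removeprefix("_orig_mod.")  # artifact left by torch.compile
        # safetensors refuses non-contiguous tensors, so re-lay them out.
        out[key] = value if value.is_contiguous() else value.contiguous()
    return out
```

Tensor values are never modified, only re-laid-out when non-contiguous, matching the "does not change tensor values" guarantee above.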
Both installers now pull CorridorKey.safetensors from HuggingFace as the primary target and transparently fall back to CorridorKey.pth if the safetensors file is not yet present in the repo. This makes the change safe to merge before the HF upload completes — new users get whatever is currently published, and the engine's discovery logic handles either file. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Reflects the new checkpoint preference order across user-facing docs, the module-level README, the LLM handover guide, and the contributor guide. Each mention notes that legacy .pth continues to work unchanged. README also points maintainers at the new scripts/convert_pth_to_safetensors.py utility. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Two gaps identified during coverage audit:
1. `_ensure_torch_checkpoint` distinguishes EntryNotFoundError (file absent
from the HF repo → fall back to legacy .pth) from any other exception
(network/auth/disk → propagate as actionable RuntimeError). The fallback
branch was tested; the "do NOT fall back on a generic network error"
invariant was not. Without this guard, a future widening of the `except`
clause could silently downgrade users to the less-safe .pth format
whenever the network hiccups.
2. `scripts/convert_pth_to_safetensors.py` had zero tests. It is a
maintainer-facing tool run once per released model whose output is
published to HuggingFace for every downstream user; silent bugs
(dropped tensors, missed _orig_mod. prefix, altered values) would
corrupt the published file. New tests cover:
- `_extract_state_dict`: wrapper unwrap, prefix strip, non-tensor skip,
contiguity enforcement
- `_verify_round_trip`: happy path, missing key, shape mismatch
- End-to-end: a wrapped checkpoint with _orig_mod. prefixes and
optimizer metadata round-trips to a clean .safetensors with
bit-identical tensor values
- CLI error paths: missing `--input`, wrong `--output` suffix
Why
`CorridorKey.pth` loads via pickle, which executes arbitrary code — a real concern for a model shipped to a wide user base. `.safetensors` is the safe standard (no pickle, memory-mapped, zero-copy). The MLX side already uses it (`MLX_EXT = ".safetensors"`); this PR brings the Torch side in line while keeping every existing `.pth` install working untouched.

Design: prefer safetensors, fall back to .pth
No flag, no env var. The file that's there wins:
- In `CorridorKeyModule/checkpoints/`: `.safetensors` → `.pth` (legacy) → auto-download.
- Auto-download fetches `.safetensors` from HF first; only on `EntryNotFoundError` (HF-specific "file absent from repo", not a network error) does it fall back to downloading `.pth`. Network / disk / auth failures propagate as today.
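That fallback contract reads roughly like this sketch, with a stand-in exception class and an injected `download` callable (both assumptions; the real code uses huggingface_hub and its `EntryNotFoundError`):

```python
class EntryNotFoundError(Exception):
    """Stand-in for huggingface_hub.utils.EntryNotFoundError (repo-level 404)."""

def ensure_checkpoint(download) -> str:
    """Try the safetensors file first; fall back to .pth ONLY on a repo-level 404."""
    try:
        return download("CorridorKey.safetensors")
    except EntryNotFoundError:
        # DEPRECATED: remove after .pth sunset — fires only when the file
        # is genuinely absent from the repo, never on transient failures.
        return download("CorridorKey.pth")
    except Exception as exc:
        # Network / auth / disk problems must NOT silently downgrade to .pth.
        raise RuntimeError(f"checkpoint download failed: {exc}") from exc
```

Ordering the narrow `except EntryNotFoundError` before the broad `except Exception` is what keeps a network hiccup from masquerading as "file not published yet".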
- `_load_model` dispatches by extension to `safetensors.torch.load_file` or `torch.load`. The `_orig_mod.` strip and `pos_embed` reinterpolation logic is format-blind and unchanged.
Back-compat contract
- `CorridorKey.pth` → still loads, no action required.
- If both `.pth` and `.safetensors` are present, `.safetensors` wins with an INFO log.
- Every fallback branch is tagged `# DEPRECATED: remove after .pth sunset` so the follow-up cleanup PR is a trivial grep.
Merge ordering
This PR is safe to land before the `.safetensors` upload to nikopueringer/CorridorKey_v1.0: until the upload lands, first-run users fall through to `.pth` via the `EntryNotFoundError` path. After the upload, fresh installs pick up `.safetensors` automatically. Existing users are unaffected either way.
Validation
Automated
- `uv run ruff check` → clean.
- `uv run pytest -m "not gpu"` → 326 passed. 9 pre-existing failures (`OPENCV_IO_ENABLE_OPENEXR` disabled in opencv-python on Windows, see opencv#21326) — confirmed identical on upstream/main, unrelated to this PR.

New test coverage
- HF `EntryNotFoundError` falls back to downloading `.pth`.
- Any other failure (network / auth / disk) surfaces a `RuntimeError` — a regression guard against a future `except` widening that would silently downgrade users to the less-safe format.
- `_load_model` extension dispatch: `.safetensors` routes to `safetensors.torch.load_file`, `.pth` routes to `torch.load`. Each test mocks the other loader with an `AssertionError` side-effect to prove the wrong path isn't taken.
- Converter (`scripts/convert_pth_to_safetensors.py`): wrapper unwrap, `_orig_mod.` prefix strip, non-tensor metadata skip, contiguity enforcement, round-trip detection of missing keys and shape mismatches, end-to-end `.pth` → `.safetensors` preserving tensor values bit-identically, CLI error paths.
- PBT strategy: `.safetensors` is no longer treated as "junk" that should not trigger download.
Manual end-to-end (Windows, NVIDIA CUDA)
Converted the real 400 MB HF `.pth` via the new script (367 tensors, round-trip verified) and ran three scenarios against a real 2048×2048 clip through `clip_manager.py --action run_inference --backend torch`:
- Only `.safetensors` present → loads `.safetensors` ✅
- Both present → `.safetensors` wins ✅
- Only `.pth` present (legacy) → loads `.pth` ✅

Separately, a deterministic-seed smoke script compared engine outputs directly: `max|diff| = 0.000e+00` on both `alpha` and `fg`, confirming bit-for-bit equivalence at the tensor level.
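The smoke comparison boils down to a check like this NumPy sketch; the engine calls that produce the `alpha` and `fg` arrays are not shown, and the helper name is illustrative:

```python
import numpy as np

def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Largest element-wise absolute difference between two same-shape outputs."""
    assert a.shape == b.shape, "outputs must have identical shapes to compare"
    # Upcast to float64 so the subtraction itself introduces no rounding.
    return float(np.max(np.abs(a.astype(np.float64) - b.astype(np.float64))))
```

A result of exactly `0.0` for both outputs is what justifies the bit-for-bit equivalence claim above.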
Follow-up (out of scope)
Sunset `.pth`: delete it from the HF repo and remove every branch tagged `# DEPRECATED: remove after .pth sunset`. Trivial once the deprecation window closes.