feat: safetensors checkpoint support for Torch backend #227

Merged
nikopueringer merged 7 commits into nikopueringer:main from
alexandremendoncaalvaro:feat/safetensors-checkpoint
Apr 17, 2026

Conversation

@alexandremendoncaalvaro
Contributor

Why

CorridorKey.pth loads via pickle, which executes arbitrary code — a real
concern for a model shipped to a wide user base. .safetensors is the safe
standard (no pickle, memory-mapped, zero-copy). The MLX side already uses it
(MLX_EXT = ".safetensors"); this PR brings the Torch side in line while
keeping every existing .pth install working untouched.

Design: prefer safetensors, fall back to .pth

No flag, no env var. The file that's there wins:

  1. Discovery in CorridorKeyModule/checkpoints/: .safetensors first, then
    .pth (legacy); if neither is present → auto-download.
  2. Auto-download: request .safetensors from HF first; only on
    EntryNotFoundError (HF-specific "file absent from repo", not a network
    error) fall back to downloading .pth. Network / disk / auth failures
    propagate as today.
  3. _load_model dispatches by extension to safetensors.torch.load_file
    or torch.load. The _orig_mod. strip and pos_embed reinterpolation
    logic is format-blind and unchanged.

Back-compat contract

  • Cached CorridorKey.pth → still loads, no action required.
  • New install → auto-download fetches .safetensors.
  • Both present → .safetensors wins with an INFO log.

Every fallback branch is tagged # DEPRECATED: remove after .pth sunset so
the follow-up cleanup PR is a trivial grep.

Merge ordering

This PR is safe to land before the .safetensors upload to
nikopueringer/CorridorKey_v1.0: until the upload lands, first-run users
fall through to .pth via the EntryNotFoundError path. After the upload,
fresh installs pick up .safetensors automatically. Existing users are
unaffected either way.


Validation

Automated

  • uv run ruff check → clean.
  • uv run pytest -m "not gpu" → 326 passed. 9 pre-existing failures
    (OPENCV_IO_ENABLE_OPENEXR disabled in opencv-python on Windows, see
    opencv#21326) — confirmed
    identical on upstream/main, unrelated to this PR.
  • Touched test files in isolation: 106 passed.

New test coverage

  • Discovery preference (both present, only-one present, empty dir).
  • HF download fallback on EntryNotFoundError downloads .pth.
  • HF download non-fallback on any other exception (network / auth /
    disk) surfaces a RuntimeError — a regression guard against a future
    except widening that would silently downgrade users to the less-safe
    format.
  • _load_model extension dispatch: .safetensors routes to
    safetensors.torch.load_file, .pth routes to torch.load. Each test
    mocks the other loader with an AssertionError side-effect to prove the
    wrong path isn't taken.
  • Converter script (scripts/convert_pth_to_safetensors.py): wrapper
    unwrap, _orig_mod. prefix strip, non-tensor metadata skip, contiguity
    enforcement, round-trip detection of missing keys and shape mismatches,
    end-to-end .pth → .safetensors preserving tensor values bit-identically,
    CLI error paths.
  • PBT auto-download properties updated — .safetensors is no longer
    treated as "junk" that should not trigger download.
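
The dispatch tests' mocking pattern can be sketched with `unittest.mock`: the loader for the *other* format is patched with an `AssertionError` side-effect, so any call down the wrong path fails loudly. `load_state_dict`, `safetensors_load`, and `torch_load` are illustrative stand-ins, not the project's real names.

```python
from unittest import mock


def safetensors_load(path):  # stand-in for safetensors.torch.load_file
    return {"loader": "safetensors"}


def torch_load(path):  # stand-in for torch.load(weights_only=True)
    return {"loader": "pth"}


def load_state_dict(path: str):
    """Stand-in for the engine's extension dispatch."""
    if path.endswith(".safetensors"):
        return safetensors_load(path)
    return torch_load(path)


def test_safetensors_never_hits_torch_load():
    # If dispatch ever routed .safetensors to torch.load, this would raise.
    with mock.patch(f"{__name__}.torch_load",
                    side_effect=AssertionError("torch.load must not run")):
        assert load_state_dict("ck.safetensors")["loader"] == "safetensors"


def test_pth_never_hits_safetensors_load():
    with mock.patch(f"{__name__}.safetensors_load",
                    side_effect=AssertionError("load_file must not run")):
        assert load_state_dict("ck.pth")["loader"] == "pth"
```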

Manual end-to-end (Windows, NVIDIA CUDA)

Converted the real 400 MB HF .pth via the new script (367 tensors,
round-trip verified) and ran three scenarios against a real 2048×2048 clip
through clip_manager.py --action run_inference --backend torch:

| Scenario | Checkpoint loaded | Result |
| --- | --- | --- |
| Both files present | .safetensors | Inference succeeds |
| Only .safetensors | .safetensors | Outputs byte-identical (SHA-256) to previous run |
| Only .pth (legacy) | .pth | Outputs byte-identical (SHA-256) to the safetensors run |

Separately, a deterministic-seed smoke script compared engine outputs
directly: max|diff| = 0.000e+00 on both alpha and fg, confirming
bit-for-bit equivalence at the tensor level.
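
The smoke check's metric is just an elementwise max absolute difference; a dependency-free stand-in (the real script compares torch tensors for alpha and fg) looks like:

```python
def max_abs_diff(a, b):
    """Max elementwise |a - b| over two flattened outputs of equal length."""
    assert len(a) == len(b), "outputs must have the same shape"
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)
```

`max_abs_diff(old, new) == 0.0` is the bit-for-bit claim above.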

Follow-up (out of scope)

Sunset .pth: delete from the HF repo and remove every branch tagged
# DEPRECATED: remove after .pth sunset. Trivial once the deprecation
window closes.
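
The converter's state-dict normalization described under "New test coverage" (wrapper unwrap, `_orig_mod.` prefix strip, non-tensor skip) can be sketched in plain Python. This is a simplified stand-in for the script's `_extract_state_dict`: `is_tensor` is injected in place of `torch.is_tensor`, and the real script additionally calls `.contiguous()` on each tensor, since safetensors requires contiguous storage.

```python
def extract_state_dict(ckpt: dict, is_tensor) -> dict:
    """Normalize a raw checkpoint into a flat {name: tensor} dict."""
    state = ckpt.get("state_dict", ckpt)  # unwrap the "state_dict" wrapper
    prefix = "_orig_mod."  # left behind by torch.compile
    out = {}
    for key, value in state.items():
        if not is_tensor(value):
            continue  # skip optimizer state / scalar metadata
        if key.startswith(prefix):
            key = key[len(prefix):]
        out[key] = value  # real script: value.contiguous()
    return out
```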

alexandremendoncaalvaro and others added 7 commits April 14, 2026 21:42
Adds the safetensors library as a core dependency to support loading
.safetensors model checkpoints. Used by the Torch backend's new loader
path and by the scripts/convert_pth_to_safetensors.py maintainer tool.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Teaches the Torch backend to prefer .safetensors over the legacy .pth:

- _discover_checkpoint() scans for .safetensors first, then .pth.
  If both exist, .safetensors wins and the choice is logged at INFO.
- _ensure_torch_checkpoint() tries CorridorKey.safetensors from HF
  first and falls back to CorridorKey.pth on EntryNotFoundError, so
  the code can ship before the HF upload lands without breaking
  first-run users.
- Existing .pth-only installs continue to work unchanged — no forced
  re-download, no migration required.

Fallback branches are tagged `# DEPRECATED: remove after .pth sunset`
for a clean grep when the .pth path is eventually removed.

Tests cover: safetensors preferred over .pth, .pth-only still works,
safetensors-only works, and HF EntryNotFoundError triggers .pth
fallback. The PBT strategy for "non-Torch extensions" drops
.safetensors since it is now a valid Torch checkpoint.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Branches on the checkpoint extension when building the state dict:
.safetensors goes through safetensors.torch.load_file, .pth stays on
torch.load(weights_only=True). The downstream state-dict processing
(_orig_mod. prefix strip, pos_embed bicubic reinterpolation,
load_state_dict(strict=False)) is format-blind and unchanged.

Tests verify the dispatch: a safetensors path doesn't invoke
torch.load, and a .pth path doesn't invoke safetensors.load_file.

The .pth branch is tagged for removal after the .pth sunset.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
A maintainer utility for (re)publishing the official CorridorKey
safetensors file and for contributors converting custom fine-tunes.
Does not change tensor values — only re-serializes. Matches the
engine's state-dict unwrapping (strips "state_dict" wrapper and
_orig_mod. prefix left by torch.compile), enforces contiguous
tensors (required by safetensors), and verifies the round-trip by
reloading and diffing key sets and shapes against the source.

Usage:
    uv run python scripts/convert_pth_to_safetensors.py \
        --input  CorridorKeyModule/checkpoints/CorridorKey.pth \
        --output CorridorKeyModule/checkpoints/CorridorKey.safetensors

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Both installers now pull CorridorKey.safetensors from HuggingFace
as the primary target and transparently fall back to CorridorKey.pth
if the safetensors file is not yet present in the repo. This makes
the change safe to merge before the HF upload completes — new users
get whatever is currently published, and the engine's discovery
logic handles either file.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Reflects the new checkpoint preference order across user-facing
docs, the module-level README, the LLM handover guide, and the
contributor guide. Each mention notes that legacy .pth continues
to work unchanged. README also points maintainers at the new
scripts/convert_pth_to_safetensors.py utility.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Two gaps identified during coverage audit:

1. `_ensure_torch_checkpoint` distinguishes EntryNotFoundError (file absent
   from the HF repo → fall back to legacy .pth) from any other exception
   (network/auth/disk → propagate as actionable RuntimeError). The fallback
   branch was tested; the "do NOT fall back on a generic network error"
   invariant was not. Without this guard, a future widening of the `except`
   clause could silently downgrade users to the less-safe .pth format
   whenever the network hiccups.

2. `scripts/convert_pth_to_safetensors.py` had zero tests. It is a
   maintainer-facing tool run once per released model whose output is
   published to HuggingFace for every downstream user; silent bugs
   (dropped tensors, missed _orig_mod. prefix, altered values) would
   corrupt the published file. New tests cover:

   - `_extract_state_dict`: wrapper unwrap, prefix strip, non-tensor skip,
     contiguity enforcement
   - `_verify_round_trip`: happy path, missing key, shape mismatch
   - End-to-end: a wrapped checkpoint with _orig_mod. prefixes and
     optimizer metadata round-trips to a clean .safetensors with
     bit-identical tensor values
   - CLI error paths: missing `--input`, wrong `--output` suffix
@nikopueringer nikopueringer merged commit 3f398d3 into nikopueringer:main Apr 17, 2026
7 checks passed

2 participants