Skip to content

Integrate ESMFold inference backend into VizFold (structure export + metadata)#44

Open
jayvenn21 wants to merge 18 commits into
AI2Science:mainfrom
jayvenn21:feature/esmfold-backend
Open

Integrate ESMFold inference backend into VizFold (structure export + metadata)#44
jayvenn21 wants to merge 18 commits into
AI2Science:mainfrom
jayvenn21:feature/esmfold-backend

Conversation

@jayvenn21
Copy link
Copy Markdown

@jayvenn21 jayvenn21 commented Feb 21, 2026

Integrate ESMFold inference backend into VizFold (Fixes #43)

Context

This PR introduces the full ESMFold backend for VizFold — inference, hook-based trace extraction, an interactive visualization dashboard, and HPC deployment tooling. It uses the HuggingFace EsmForProteinFolding model and exports outputs in the VizFold archive format.

This addresses Issue #43 end-to-end: inference, trace extraction across all model components, structure module tracing, a FastAPI bridge, and a React frontend for interactive exploration.


What's Included

HuggingFace-based ESMFold integration

  • Uses facebook/esmfold_v1 via transformers (pinned >=4.36.0)
  • Avoids OpenFold CUDA build dependency for improved portability
  • ESMFoldRunner implements the BackendBase interface (load_model, run_inference, supports_attention, supports_activations)

GPU-enabled inference on ICE (SLURM-tested)

  • Successfully validated with --device cuda on PACE ICE H100 nodes
  • Handles large checkpoint download via configurable HF_HOME
  • Fail-fast sanity check in SLURM script if torch/transformers not importable

Structure export

  • structure/predicted.pdb (full atom37 PDB via HF openfold_utils)
  • structure/predicted.pt (coordinate tensor)
  • Fallback minimal CA-only PDB with valid 3-letter residue codes

Trace extraction via forward hooks

  • ESM-2 encoder attention: hooks on model.esm.encoder.layer[i].attention.self, monkey-patching forward to force output_attentions=True so real attention probabilities [B, H, N, N] are captured
  • ESM-2 activations: per-layer hidden states [B, N, D] with <cls> and <eos> special tokens stripped to align with attention maps
  • Attention slice validation against expected_seq_len with warnings on mismatch
  • Head filtering performed in-hook to reduce GPU memory usage
  • Layer and head subsetting via --layers and --heads

Evoformer trunk intermediates

  • Per-block sequence state [L, C_s] and pair state [L, L, C_z] via hooks on trunk.blocks[i]
  • Final trunk representations s_s and s_z captured from recycling iterations
  • Per-recycle s_s and s_z tensors archived as recycle_{i}_s_s / recycle_{i}_s_z

Structure module tracing

  • IPA (Invariant Point Attention) attention matrices [H, N, N] captured via CapturingSoftmax patching
  • Per-recycle backbone positions and single representations
  • Handles both dict and dataclass outputs from HuggingFace via _extract_from_output helper

Interactive visualization dashboard

  • React + Vite frontend with 3Dmol.js WebGL structure viewer and Plotly.js heatmaps
  • Trunk Evolution mode (s_z averaged across 128 hidden channels) and ESM-2 Attention mode (averaged across heads)
  • Bidirectional crosshair sync: click a residue in 3D → crosshair on heatmap; click a heatmap cell → highlight residue in 3D
  • Camera-preserving style switching (Cartoon/Sticks/Spheres) without viewer recreation
  • Responsive heatmap layout that fills available panel space
  • Loading and error states per panel (backend-down detection, missing trace data)
  • Configurable API URL via VITE_API_URL environment variable

FastAPI bridge server

  • Serves meta.json, PDB structure, and trace tensors over HTTP
  • Server-side averaging for s_z (across hidden channels) and attention (across heads) to minimize payload size
  • Path traversal protection via safe_join() + explicit ..///\\ checks
  • Configurable output directory via --dir CLI argument

Metadata + archive format

  • meta.json consistent with VizFold archive schema
  • trace/index.json mapping layers to files, shapes, dtypes
  • trace/summary.json with per-layer attention entropy and activation norms
  • trace_formats field indicating which output formats were produced (pt, txt)
  • Shapes recorded for attention, activations, trunk, and structure module outputs

CLI support

  • run_pretrained_esmf.py with: --fasta, --out, --device, --dtype, --trace_mode, --layers, --heads, --save_fp16, --seed, --deterministic, --structure_traces

Tests, docs, and HPC support

  • tests/test_esmf_smoke.py — CLI help, missing FASTA, schema, import, full smoke-run, and end-to-end validation tests
  • docs/esmfold.md — setup, usage examples (attention, activations, structure module traces), output layout, and meta.json schema
  • docs/hpc_ice.md — PACE ICE deployment guide
  • scripts/hpc/ice/run_esmf_ice.slurm — SLURM batch script with environment validation

Design Decision

This implementation uses the HuggingFace ESMFold model rather than directly building OpenFold, which avoids CUDA toolchain mismatches on HPC systems and simplifies reproducibility across environments.


Team Contributions (Integrated)

The following team PRs have been reviewed, merged, and integrated into this branch:

  • @rohan5986 (PR #3, PR #6, PR #8): Captured s_s and s_z at every recycling iteration via trunk hooks, enforced use_safetensors=True for HPC stability (CVE-2025-32434), built the complete interactive dashboard (3Dmol.js + Plotly heatmaps + FastAPI bridge), added path traversal security, env-based API URL config
  • @JeevanandanRamasamy (PR #1, PR #4, PR #7): Extracted Evoformer trunk per-block intermediates, added VizFold-compatible .txt attention export with --top_k support, fixed output_attentions positional arg bug, implemented BackendBase interface, added attention slice validation and in-hook head filtering, stripped <cls>/<eos> from activations, added structure module dict/dataclass handling
  • @Mose-Kim02 (PR #5): Validation test suite for ESMFold trace extraction pipeline, reproducibility documentation
  • @jayvenn21 (PR #9): Refactored frontend into StructureViewer/TraceExplorer/TimelineControls components, added camera persistence, responsive heatmap, bidirectional crosshair sync, colorbar labels, loading/error states, CSS cleanup

Validation Evidence

Validated on PACE ICE with GPU and locally on CPU. ESMFold inference is deterministic; repeated runs produce identical archives.

Example run (attention + activations):

HF_HOME=/tmp/$USER/hf_cache \
python run_pretrained_esmf.py \
  --fasta examples/monomer/fasta_dir_6KWC/6KWC.fasta \
  --out outputs/test_trace \
  --device cuda \
  --trace_mode attention+activations

Example run (structure module traces):

python run_pretrained_esmf.py \
  --fasta examples/monomer/fasta_dir_6KWC/6KWC.fasta \
  --out outputs/esmf_6KWC \
  --trace_mode attention+activations \
  --structure_traces \
  --save_fp16

Output Archive

outputs/demo_traces/
├── meta.json
├── logs.txt
├── structure/
│   ├── predicted.pdb
│   └── predicted.pt
├── trace/
│   ├── index.json
│   ├── summary.json
│   ├── attention/              (36 layers)
│   ├── activations/            (36 layers + recycle_*_s_s/s_z)
│   ├── trunk/                  (per-block seq/pair + final s_s/s_z)
│   │   ├── block_000_seq.pt
│   │   ├── block_000_pair.pt
│   │   ├── s_s.pt
│   │   └── s_z.pt
│   └── structure_module/       (with --structure_traces)
│       ├── ipa_attention/
│       │   └── recycle_00_block_00.pt
│       └── backbone/
│           ├── recycle_00_positions.pt
│           └── recycle_00_states.pt
├── attention_files/            (.txt format for VizFold viz tools)
│   └── msa_row_attn_layer0.txt
└── frontend/                   (React dashboard)
    ├── src/
    │   ├── App.jsx
    │   └── components/
    │       ├── StructureViewer.jsx
    │       ├── TraceExplorer.jsx
    │       └── TimelineControls.jsx
    └── server.py               (FastAPI bridge)

Architecture Diagram

ESMFold Diagram-updated

Demo Video

esmfolddemo.mp4

Screenshots

Screenshot 2026-03-24 at 7 18 28 PM Screenshot 2026-03-24 at 7 18 42 PM

What's Next

Most planned extensions from the original PR are complete. Remaining tasks:

  • Visualization validation — confirm existing VizFold notebooks and PyMOL scripts work with ESMFold traces
  • Cross-model comparison — tooling to compare OpenFold and ESMFold archives side-by-side
  • Boltz integration — merge Boltz-2 backend work into the main VizFold repo (coordinating with Boltz team)

…pipeline

- Rewrite hooks.py: target HF encoder.layer[i].attention.self directly
  instead of broad name-matching; monkey-patch forward to force
  output_attentions=True so real attention weights [B,H,N,N] are captured
  (not hidden states); separate attention/activation hooks with correct
  layer indices; slice out <cls>/<eos> tokens from attention maps
- Fix _coords_to_minimal_pdb: use 3-letter residue codes (valid PDB)
- Remove dead code: try_use_outputs() path, shared mutable counter
- Extract structure logic into _extract_structure() method
- Unify FASTA reading into single read_fasta() returning (seq, id, hash)
- Wire --dtype through CLI (float32/float16 model loading)
- Log runner.run() result (attention/activation layer counts)
- Fix trace_adapter: correct head-slicing axis for 3D vs 4D tensors;
  fix entropy calculation (per-row, not per-matrix)
Copy link
Copy Markdown

@JeevanandanRamasamy JeevanandanRamasamy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a solid addition and the architecture is clean. Documentation is very thorough. I highlighted some potential minor issues in my review.

Comment thread docs/.DS_Store Outdated
Comment thread requirements-esmfold.txt Outdated
Comment thread environment.yml Outdated
Comment thread environment-mac.yml
Comment thread vizfold/backends/esmfold/inference.py Outdated
Comment thread vizfold/backends/esmfold/inference.py Outdated
Comment thread vizfold/backends/esmfold/trace_adapter.py
Comment thread vizfold/backends/esmfold/inference.py
jayvenn21 and others added 3 commits March 22, 2026 21:47
…re, trace relpaths, summary logging, layer_count
Co-authored-by: Rohan Singhal <rsinghal49@atl1-1-03-013-19-0.pace.gatech.edu>
* Add VizFold text-file attention export compatible with existing visualization tools

* Bug fix: override the positional arg in-place instead of adding to kwargs

* Fix: trace_formats missing from meta.json

* Robust attention saving & forward signature handling

hooks.py:
- make the EsmSelfAttention forward patch resilient to signature changes by finding the position of output_attentions by name instead of assuming a fixed positional index

trace_adapter.py:
- reuse OpenFold's save_attention_topk if available, and falls back to a self-contained NumPy implementation (no OpenFold dependency) that writes msa_row_attn text files
- layer-index extraction via regex
- compute produced trace_formats dynamically in build_and_write_meta instead of hardcoding ["pt","txt"]
@jayvenn21 jayvenn21 force-pushed the feature/esmfold-backend branch from 0a0a9d7 to 8351a6d Compare March 24, 2026 01:40
@jayvenn21 jayvenn21 force-pushed the feature/esmfold-backend branch from e07be3f to 494396b Compare March 24, 2026 01:44
Co-authored-by: Mose Kim <kimmose2002@gmail.com>
@jayvenn21 jayvenn21 force-pushed the feature/esmfold-backend branch from 494396b to 4b86182 Compare March 24, 2026 01:45
@jayvenn21 jayvenn21 changed the title Integrate ESMFold inference backend into VizFold (structure export + metadata) – Fixes #43 Integrate ESMFold inference backend into VizFold (structure export + metadata) Mar 24, 2026
jayvenn21 and others added 9 commits March 25, 2026 22:06
* Extract s_s folding trunk activations and enforce safetensors

* Update backend pipeline

* Capture s_s and s_z at every recycling iteration via trunk hook

* Remove test output artifacts

---------

Co-authored-by: Rohan Singhal <rsinghal49@atl1-1-03-013-19-0.pace.gatech.edu>
* Add VizFold text-file attention export compatible with existing visualization tools

* Bug fix: override the positional arg in-place instead of adding to kwargs

* Fix: trace_formats missing from meta.json

* Robust attention saving & forward signature handling

hooks.py:
- make the EsmSelfAttention forward patch resilient to signature changes by finding the position of output_attentions by name instead of assuming a fixed positional index

trace_adapter.py:
- reuse OpenFold's save_attention_topk if available, and falls back to a self-contained NumPy implementation (no OpenFold dependency) that writes msa_row_attn text files
- layer-index extraction via regex
- compute produced trace_formats dynamically in build_and_write_meta instead of hardcoding ["pt","txt"]

* Capture and save evoformer trunk intermediates

Add per-block evoformer tracing and output saving for ESMFold.

- hooks.py: introduce register_trunk_hooks and _make_trunk_block_hook to register forward hooks on model.trunk.blocks (EsmFoldTriangularSelfAttentionBlock). Captured per-block sequence_state and pairwise_state are stored in collector.trunk_blocks; clear() updated and warnings added when trunk/blocks are missing.
- inference.py: register the new trunk hooks in ESMFoldRunner, extract and save final folding trunk pair representations (out.s_z), and write per-block evoformer intermediates to trace/trunk/*.pt while recording shapes. Logging messages adjusted.
- trace_adapter.py: update trace layout to include trunk/ files (block_{idx}_seq/pair, s_s, s_z).

* ESMFold: save trunk tensors, CPU attention

Ensure attention tensors are moved to CPU in hooks (detach().cpu()) to avoid GPU tensor serialization. Stop extracting final trunk outputs from model.out and instead collect final s_s/s_z from collector.recycled_s_s/recycled_s_z (avoids redundant copies) and save per-block trunk tensors plus final s_s/s_z into trace/trunk/.

* Squeeze batch dim in hooks; drop recycling archive

Fix tensor shape handling in ESMFoldTraceCollector hooks by squeezing the leading batch dimension before detaching and moving seq and pair states to CPU, preventing stored activations from containing an extra batch axis. Also remove the prior archival of recycled s_s/s_z tensors in the ESMFoldRunner inference flow to avoid redundant/memory-heavy activation copies and logging related to those recycled tensors.
* Add ESMFold backend smoke test and reproducibility documentation

* Add tensor shape validation and fix smoke test per review

* Fix smoke test per review: tmp_path, sys.executable, tensor shape validation

* Fix ESMFold smoke test per review and validate attention tensor shape

* Add validation for trunk intermediates, attention exports, and new ESMFold outputs

* Address review comments: remove duplicate smoke test file and update recycle output paths
* Add VizFold text-file attention export compatible with existing visualization tools

* Bug fix: override the positional arg in-place instead of adding to kwargs

* Fix: trace_formats missing from meta.json

* Robust attention saving & forward signature handling

hooks.py:
- make the EsmSelfAttention forward patch resilient to signature changes by finding the position of output_attentions by name instead of assuming a fixed positional index

trace_adapter.py:
- reuse OpenFold's save_attention_topk if available, and falls back to a self-contained NumPy implementation (no OpenFold dependency) that writes msa_row_attn text files
- layer-index extraction via regex
- compute produced trace_formats dynamically in build_and_write_meta instead of hardcoding ["pt","txt"]

* Capture and save evoformer trunk intermediates

Add per-block evoformer tracing and output saving for ESMFold.

- hooks.py: introduce register_trunk_hooks and _make_trunk_block_hook to register forward hooks on model.trunk.blocks (EsmFoldTriangularSelfAttentionBlock). Captured per-block sequence_state and pairwise_state are stored in collector.trunk_blocks; clear() updated and warnings added when trunk/blocks are missing.
- inference.py: register the new trunk hooks in ESMFoldRunner, extract and save final folding trunk pair representations (out.s_z), and write per-block evoformer intermediates to trace/trunk/*.pt while recording shapes. Logging messages adjusted.
- trace_adapter.py: update trace layout to include trunk/ files (block_{idx}_seq/pair, s_s, s_z).

* ESMFold: save trunk tensors, CPU attention

Ensure attention tensors are moved to CPU in hooks (detach().cpu()) to avoid GPU tensor serialization. Stop extracting final trunk outputs from model.out and instead collect final s_s/s_z from collector.recycled_s_s/recycled_s_z (avoids redundant copies) and save per-block trunk tensors plus final s_s/s_z into trace/trunk/.

* Squeeze batch dim in hooks; drop recycling archive

Fix tensor shape handling in ESMFoldTraceCollector hooks by squeezing the leading batch dimension before detaching and moving seq and pair states to CPU, preventing stored activations from containing an extra batch axis. Also remove the prior archival of recycled s_s/s_z tensors in the ESMFoldRunner inference flow to avoid redundant/memory-heavy activation copies and logging related to those recycled tensors.

* Enhance ESMFold tracing, structure & sanity checks

Add structure-module tracing output and update docs to pin transformers>=4.36.0 and show example for IPA attention + per-recycle backbone traces. Improve HPC slurm script to fail fast if torch/transformers aren't importable. Strengthen trace hooks: validate attention slice length against an expected_seq_len, perform head filtering in-hook, detach tensors to CPU, and add a helper to extract positions/frames/single from dict or dataclass outputs so structure module traces are robust. Small fixes: reset internal counters after each recycle, remove redundant head-filtering in trace_adapter, and tidy FASTA reading variable naming.

* Implement BackendBase in ESMFoldRunner

Make ESMFoldRunner inherit from BackendBase and implement the backend interface (load_model, run_inference, supports_attention, supports_activations). Refactor internal loading into _load_model and update usage sites to call it; add trace configuration handling and pass expected_seq_len to the trace collector. Also remove redundant imports from tests/test_esmf_smoke.py.

* Bug Fix: Strip <cls>/<eos> from transformer activations

Update activation hook to remove the leading <cls> and trailing <eos> tokens so the activation tensor [B, N, D] aligns with attention maps [B, H, N, N].
* Extract s_s folding trunk activations and enforce safetensors

* Update backend pipeline

* Capture s_s and s_z at every recycling iteration via trunk hook

* Remove test output artifacts

* Complete interactive vizfold dashboard with 3Dmol and attention heatmaps

* add frontend package dependencies

* fix: address final security reviews and add frontend env configuration

* add fastapi and uvicorn to backend requirements

---------

Co-authored-by: Rohan Singhal <rsinghal49@atl1-1-03-013-19-0.pace.gatech.edu>
…nused SVGs (#8)

* Extract s_s folding trunk activations and enforce safetensors

* Update backend pipeline

* Capture s_s and s_z at every recycling iteration via trunk hook

* Remove test output artifacts

* Complete interactive vizfold dashboard with 3Dmol and attention heatmaps

* add frontend package dependencies

* fix: address final security reviews and add frontend env configuration

* add fastapi and uvicorn to backend requirements

* Removed duplicate extraction logic, fixed scaffold CSS, and deleted unused SVGs

* fix: re-apply SVG deletions and duplicate block removal

---------

Co-authored-by: Rohan Singhal <rsinghal49@atl1-1-03-013-19-0.pace.gatech.edu>
…ctional sync (#9)

Split monolithic App.jsx into StructureViewer, TraceExplorer, and
TimelineControls components. Preserve 3Dmol camera on style/color
changes by separating viewer creation from style application. Make
Plotly heatmap responsive (autosize) instead of fixed 500x500. Add
bidirectional crosshair sync (heatmap click highlights residue in 3D).
Add colorbar labels per view mode and loading/error states per panel.
Extract all inline styles to CSS classes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Extend VizFold Inference and Visualization to ESMFold

4 participants