Integrate ESMFold inference backend into VizFold (structure export + metadata) #44
Open
jayvenn21 wants to merge 18 commits into
…d build dependency
…tates; use hooks for traces
…pipeline
- Rewrite hooks.py: target HF encoder.layer[i].attention.self directly instead of broad name-matching; monkey-patch forward to force output_attentions=True so real attention weights [B,H,N,N] are captured (not hidden states); separate attention/activation hooks with correct layer indices; slice out <cls>/<eos> tokens from attention maps
- Fix _coords_to_minimal_pdb: use 3-letter residue codes (valid PDB)
- Remove dead code: try_use_outputs() path, shared mutable counter
- Extract structure logic into _extract_structure() method
- Unify FASTA reading into single read_fasta() returning (seq, id, hash)
- Wire --dtype through CLI (float32/float16 model loading)
- Log runner.run() result (attention/activation layer counts)
- Fix trace_adapter: correct head-slicing axis for 3D vs 4D tensors; fix entropy calculation (per-row, not per-matrix)
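The per-row entropy fix named in the trace_adapter item can be illustrated with a minimal sketch. It assumes one head's attention map is an N x N matrix whose rows each sum to 1, represented here as nested lists rather than tensors. The names `row_entropies` and `mean_row_entropy` are hypothetical and are not taken from trace_adapter.py.

```python
import math


def row_entropies(attn):
    """Shannon entropy (in nats) of each query row of one [N, N] attention map.

    Each row is treated as its own probability distribution over keys.
    Zero entries are skipped because p * log(p) tends to 0 as p tends to 0.
    """
    return [-sum(p * math.log(p) for p in row if p > 0.0) for row in attn]


def mean_row_entropy(attn):
    """Average of the per-row entropies, usable as a single per-head summary."""
    rows = row_entropies(attn)
    return sum(rows) / len(rows)


# Hypothetical 2 x 2 map: the first query attends uniformly, the second
# attends to a single key.
head = [[0.5, 0.5], [1.0, 0.0]]
per_row = row_entropies(head)
summary = mean_row_entropy(head)
```

Treating the whole matrix as one distribution would mix rows that are normalized independently, which is why the entropy is taken row by row and only then averaged.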
JeevanandanRamasamy
left a comment
This is a solid addition and the architecture is clean. Documentation is very thorough. I highlighted some potential minor issues in my review.
…re, trace relpaths, summary logging, layer_count
Co-authored-by: Rohan Singhal <rsinghal49@atl1-1-03-013-19-0.pace.gatech.edu>
0a0a9d7 to 8351a6d
e07be3f to 494396b
Co-authored-by: Mose Kim <kimmose2002@gmail.com>
494396b to 4b86182
* Extract s_s folding trunk activations and enforce safetensors
* Update backend pipeline
* Capture s_s and s_z at every recycling iteration via trunk hook
* Remove test output artifacts
---------
Co-authored-by: Rohan Singhal <rsinghal49@atl1-1-03-013-19-0.pace.gatech.edu>
* Add VizFold text-file attention export compatible with existing visualization tools
* Bug fix: override the positional arg in-place instead of adding to kwargs
* Fix: trace_formats missing from meta.json
* Robust attention saving & forward signature handling
hooks.py:
- make the EsmSelfAttention forward patch resilient to signature changes by finding the position of output_attentions by name instead of assuming a fixed positional index
trace_adapter.py:
- reuse OpenFold's save_attention_topk if available, and fall back to a self-contained NumPy implementation (no OpenFold dependency) that writes msa_row_attn text files
- layer-index extraction via regex
- compute produced trace_formats dynamically in build_and_write_meta instead of hardcoding ["pt","txt"]
* Capture and save evoformer trunk intermediates
Add per-block evoformer tracing and output saving for ESMFold.
- hooks.py: introduce register_trunk_hooks and _make_trunk_block_hook to register forward hooks on model.trunk.blocks (EsmFoldTriangularSelfAttentionBlock). Captured per-block sequence_state and pairwise_state are stored in collector.trunk_blocks; clear() updated and warnings added when trunk/blocks are missing.
- inference.py: register the new trunk hooks in ESMFoldRunner, extract and save final folding trunk pair representations (out.s_z), and write per-block evoformer intermediates to trace/trunk/*.pt while recording shapes. Logging messages adjusted.
- trace_adapter.py: update trace layout to include trunk/ files (block_{idx}_seq/pair, s_s, s_z).
* ESMFold: save trunk tensors, CPU attention
Ensure attention tensors are moved to CPU in hooks (detach().cpu()) to avoid GPU tensor serialization. Stop extracting final trunk outputs from model.out and instead collect final s_s/s_z from collector.recycled_s_s/recycled_s_z (avoids redundant copies) and save per-block trunk tensors plus final s_s/s_z into trace/trunk/.
* Squeeze batch dim in hooks; drop recycling archive
Fix tensor shape handling in ESMFoldTraceCollector hooks by squeezing the leading batch dimension before detaching and moving seq and pair states to CPU, preventing stored activations from containing an extra batch axis. Also remove the prior archival of recycled s_s/s_z tensors in the ESMFoldRunner inference flow to avoid redundant/memory-heavy activation copies and logging related to those recycled tensors.
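The signature-handling fix described under hooks.py (locating output_attentions by name and overriding a positional value in place) can be sketched as below. This is an illustration under stated assumptions, not the code in hooks.py: `force_flag` is a hypothetical helper, and the `forward` function is a stand-in whose parameter list is invented for the example rather than copied from EsmSelfAttention.

```python
import inspect


def force_flag(func, args, kwargs, name="output_attentions", value=True):
    """Return (args, kwargs) with parameter `name` set to `value`.

    The parameter is located by name in the signature of `func`. If the
    caller supplied it positionally, that slot is overwritten in place;
    adding it as a keyword as well would raise
    "got multiple values for argument". Otherwise it is set as a keyword,
    which assumes `func` accepts `name` (or arbitrary keywords).
    """
    params = list(inspect.signature(func).parameters)
    new_args = list(args)
    new_kwargs = dict(kwargs)
    if name in params:
        idx = params.index(name)
        if idx < len(new_args):
            new_args[idx] = value
            return tuple(new_args), new_kwargs
    new_kwargs[name] = value
    return tuple(new_args), new_kwargs


# Stand-in with a made-up signature, used only to exercise the helper.
def forward(hidden_states, attention_mask=None, head_mask=None,
            output_attentions=False):
    return output_attentions


# Flag passed positionally: the fourth slot is replaced.
pos_args, pos_kwargs = force_flag(forward, ("h", None, None, False), {})
# Flag passed by keyword: the keyword value is replaced.
kw_args, kw_kwargs = force_flag(forward, ("h",), {"output_attentions": False})
```

Because the index comes from the signature at call time, reordering or adding parameters upstream does not silently redirect the override to a different argument.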
* Add ESMFold backend smoke test and reproducibility documentation
* Add tensor shape validation and fix smoke test per review
* Fix smoke test per review: tmp_path, sys.executable, tensor shape validation
* Fix ESMFold smoke test per review and validate attention tensor shape
* Add validation for trunk intermediates, attention exports, and new ESMFold outputs
* Address review comments: remove duplicate smoke test file and update recycle output paths
* Enhance ESMFold tracing, structure & sanity checks
Add structure-module tracing output and update docs to pin transformers>=4.36.0 and show example for IPA attention + per-recycle backbone traces. Improve HPC slurm script to fail fast if torch/transformers aren't importable. Strengthen trace hooks: validate attention slice length against an expected_seq_len, perform head filtering in-hook, detach tensors to CPU, and add a helper to extract positions/frames/single from dict or dataclass outputs so structure module traces are robust. Small fixes: reset internal counters after each recycle, remove redundant head-filtering in trace_adapter, and tidy FASTA reading variable naming.
* Implement BackendBase in ESMFoldRunner
Make ESMFoldRunner inherit from BackendBase and implement the backend interface (load_model, run_inference, supports_attention, supports_activations). Refactor internal loading into _load_model and update usage sites to call it; add trace configuration handling and pass expected_seq_len to the trace collector. Also remove redundant imports from tests/test_esmf_smoke.py.
* Bug Fix: Strip <cls>/<eos> from transformer activations
Update activation hook to remove the leading <cls> and trailing <eos> tokens so the activation tensor [B, N, D] aligns with attention maps [B, H, N, N].
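The alignment described in the last item can be sketched as follows. The example uses nested Python lists in place of tensors and omits the batch dimension for brevity, so activations are [N+2, D] and attention is [H, N+2, N+2] before stripping. `strip_special_tokens` is a hypothetical name, and the length check only mirrors the expected_seq_len validation mentioned earlier in this commit; it is not the hook code itself.

```python
import warnings


def strip_special_tokens(activations, attention, expected_seq_len):
    """Drop the leading <cls> and trailing <eos> positions from both inputs.

    activations: list of N+2 rows, each a list of D features.
    attention:   list of H heads, each an (N+2) x (N+2) nested list.
    Returns activations with N rows and attention with H heads of N x N.
    """
    acts = activations[1:-1]
    # Remove the special-token positions along both the query and key axes.
    attn = [[row[1:-1] for row in head[1:-1]] for head in attention]
    if len(acts) != expected_seq_len:
        warnings.warn(
            f"expected {expected_seq_len} residues, got {len(acts)}"
        )
    return acts, attn


# Hypothetical sizes: 3 residues plus 2 special tokens, D = 2, H = 2.
tokens = 5
activations = [[float(i), float(i) + 0.5] for i in range(tokens)]
attention = [[[1.0 / tokens] * tokens for _ in range(tokens)] for _ in range(2)]

acts, attn = strip_special_tokens(activations, attention, expected_seq_len=3)
```

After stripping, row i of the activations and row or column i of every attention map refer to the same residue.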
* Extract s_s folding trunk activations and enforce safetensors
* Update backend pipeline
* Capture s_s and s_z at every recycling iteration via trunk hook
* Remove test output artifacts
* Complete interactive vizfold dashboard with 3Dmol and attention heatmaps
* add frontend package dependencies
* fix: address final security reviews and add frontend env configuration
* add fastapi and uvicorn to backend requirements
---------
Co-authored-by: Rohan Singhal <rsinghal49@atl1-1-03-013-19-0.pace.gatech.edu>
…nused SVGs (#8)
* Extract s_s folding trunk activations and enforce safetensors
* Update backend pipeline
* Capture s_s and s_z at every recycling iteration via trunk hook
* Remove test output artifacts
* Complete interactive vizfold dashboard with 3Dmol and attention heatmaps
* add frontend package dependencies
* fix: address final security reviews and add frontend env configuration
* add fastapi and uvicorn to backend requirements
* Removed duplicate extraction logic, fixed scaffold CSS, and deleted unused SVGs
* fix: re-apply SVG deletions and duplicate block removal
---------
Co-authored-by: Rohan Singhal <rsinghal49@atl1-1-03-013-19-0.pace.gatech.edu>
…ctional sync (#9)
Split monolithic App.jsx into StructureViewer, TraceExplorer, and TimelineControls components. Preserve 3Dmol camera on style/color changes by separating viewer creation from style application. Make Plotly heatmap responsive (autosize) instead of fixed 500x500. Add bidirectional crosshair sync (heatmap click highlights residue in 3D). Add colorbar labels per view mode and loading/error states per panel. Extract all inline styles to CSS classes.
Integrate ESMFold inference backend into VizFold (Fixes #43)
Context
This PR introduces the full ESMFold backend for VizFold: inference, hook-based trace extraction, an interactive visualization dashboard, and HPC deployment tooling. It uses the HuggingFace `EsmForProteinFolding` model and exports outputs in the VizFold archive format.
This addresses Issue #43 end-to-end: inference, trace extraction across all model components, structure module tracing, a FastAPI bridge, and a React frontend for interactive exploration.
What's Included
HuggingFace-based ESMFold integration
- `facebook/esmfold_v1` via `transformers` (pinned `>=4.36.0`)
- `ESMFoldRunner` implements the `BackendBase` interface (`load_model`, `run_inference`, `supports_attention`, `supports_activations`)
GPU-enabled inference on ICE (SLURM-tested)
- `--device cuda` on PACE ICE H100 nodes
- Model cache location set through `HF_HOME`
- Fails fast if `torch`/`transformers` are not importable
Structure export
- `structure/predicted.pdb` (full atom37 PDB via HF openfold_utils)
- `structure/predicted.pt` (coordinate tensor)
Trace extraction via forward hooks
- Attention hooks on `model.esm.encoder.layer[i].attention.self`, monkey-patching forward to force `output_attentions=True` so real attention probabilities `[B, H, N, N]` are captured
- Activations `[B, N, D]` with `<cls>` and `<eos>` special tokens stripped to align with attention maps
- Attention slice length validated against `expected_seq_len` with warnings on mismatch
- Layer and head filtering via `--layers` and `--heads`
Evoformer trunk intermediates
- Per-block sequence state `[L, C_s]` and pair state `[L, L, C_z]` via hooks on `trunk.blocks[i]`
- `s_s` and `s_z` captured from recycling iterations
- `s_s` and `s_z` tensors archived as `recycle_{i}_s_s` / `recycle_{i}_s_z`
Structure module tracing
- IPA attention `[H, N, N]` captured via `CapturingSoftmax` patching
- Dict and dataclass outputs handled by the `_extract_from_output` helper
Interactive visualization dashboard
- API URL configured through the `VITE_API_URL` environment variable
FastAPI bridge server
- Serves `meta.json`, PDB structure, and trace tensors over HTTP
- Path traversal protection: `safe_join()` + explicit `..` / `/` / `\\` checks
- `--dir` CLI argument
Metadata + archive format
- `meta.json` consistent with VizFold archive schema
- `trace/index.json` mapping layers to files, shapes, dtypes
- `trace/summary.json` with per-layer attention entropy and activation norms
- `trace_formats` field indicating which output formats were produced (`pt`, `txt`)
CLI support
- `run_pretrained_esmf.py` with: `--fasta`, `--out`, `--device`, `--dtype`, `--trace_mode`, `--layers`, `--heads`, `--save_fp16`, `--seed`, `--deterministic`, `--structure_traces`
Tests, docs, and HPC support
- `tests/test_esmf_smoke.py`: CLI help, missing FASTA, schema, import, full smoke-run, and end-to-end validation tests
- `docs/esmfold.md`: setup, usage examples (attention, activations, structure module traces), output layout, and meta.json schema
- `docs/hpc_ice.md`: PACE ICE deployment guide
- `scripts/hpc/ice/run_esmf_ice.slurm`: SLURM batch script with environment validation
Design Decision
This implementation uses the HuggingFace ESMFold model rather than directly building OpenFold, which avoids CUDA toolchain mismatches on HPC systems and simplifies reproducibility across environments.
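The HuggingFace route still depends on `torch` and a sufficiently recent `transformers` being importable in the job environment, which is what the fail-fast environment validation guards against. A minimal Python sketch of such a check is shown below. It is an assumption about how the check could be written, not the contents of the SLURM script; `check_environment` is a hypothetical name, and the default minimum version simply restates the `>=4.36.0` pin.

```python
import importlib.util
import re
from importlib import metadata


def check_environment(required=("torch", "transformers"), min_versions=None):
    """Return a list of problems; an empty list means the checks passed.

    A module counts as available if importlib can locate it. Version
    comparison uses only the leading numeric major.minor.patch fields.
    """
    if min_versions is None:
        min_versions = {"transformers": (4, 36, 0)}
    problems = []
    for name in required:
        if importlib.util.find_spec(name) is None:
            problems.append(f"{name} is not importable")
            continue
        wanted = min_versions.get(name)
        if wanted is None:
            continue
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name} has no installed distribution metadata")
            continue
        match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", found)
        if match is None:
            problems.append(f"{name} version {found!r} could not be parsed")
            continue
        parsed = tuple(int(part or 0) for part in match.groups())
        if parsed < wanted:
            needed = ".".join(str(n) for n in wanted)
            problems.append(f"{name} {found} is older than {needed}")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    for issue in issues:
        print(issue)
    # A non-zero exit status lets a batch script stop before inference starts.
    raise SystemExit(1 if issues else 0)
```

Returning a list of problems rather than raising on the first one means a single failed job reports every missing dependency at once.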
Team Contributions (Integrated)
The following team PRs have been reviewed, merged, and integrated into this branch:
- Captured `s_s` and `s_z` at every recycling iteration via trunk hooks, enforced `use_safetensors=True` for HPC stability (CVE-2025-32434), built the complete interactive dashboard (3Dmol.js + Plotly heatmaps + FastAPI bridge), added path traversal security, env-based API URL config
- `.txt` attention export with `--top_k` support, fixed `output_attentions` positional arg bug, implemented `BackendBase` interface, added attention slice validation and in-hook head filtering, stripped `<cls>`/`<eos>` from activations, added structure module dict/dataclass handling
Validation Evidence
Validated on PACE ICE with GPU and locally on CPU. ESMFold inference is deterministic; repeated runs produce identical archives.
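One way to check the "identical archives" claim is to hash two output directories and compare the digests. The sketch below is a hedged illustration: `archive_digest` and its `skip` parameter are hypothetical, and whether every file in an archive is byte-stable across runs (for example, whether `meta.json` records a timestamp) is an assumption that would need checking against real outputs.

```python
import hashlib
from pathlib import Path


def archive_digest(root, skip=()):
    """SHA-256 over relative paths and file contents under `root`.

    Files are visited in sorted order so the digest does not depend on
    directory listing order. File names listed in `skip` are excluded,
    which is useful for files expected to differ between runs.
    """
    root = Path(root)
    digest = hashlib.sha256()
    files = sorted(p for p in root.rglob("*") if p.is_file())
    for path in files:
        if path.name in skip:
            continue
        rel = path.relative_to(root).as_posix()
        digest.update(rel.encode("utf-8"))
        digest.update(b"\0")  # separator between path and contents
        digest.update(path.read_bytes())
        digest.update(b"\0")  # separator between entries
    return digest.hexdigest()


def archives_match(run_a, run_b, skip=()):
    """True if both directories hold the same relative paths and bytes."""
    return archive_digest(run_a, skip) == archive_digest(run_b, skip)
```

Calling `archives_match("outputs/run_a", "outputs/run_b")` on two runs of the same FASTA (both directory names are placeholders) would then return `True` only if every compared file is byte-identical.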
Example run (attention + activations):
    HF_HOME=/tmp/$USER/hf_cache \
    python run_pretrained_esmf.py \
      --fasta examples/monomer/fasta_dir_6KWC/6KWC.fasta \
      --out outputs/test_trace \
      --device cuda \
      --trace_mode attention+activations
Example run (structure module traces):
Output Archive
Architecture Diagram
Demo Video
esmfolddemo.mp4
Screenshots
What's Next
Most planned extensions from the original PR are complete. Remaining tasks: