Skip to content

Issue #8 Visualize Intermediate Representations#73

Open
SreeDan wants to merge 35 commits into
AI2Science:mainfrom
priyavisingh:main
Open

Issue #8 Visualize Intermediate Representations#73
SreeDan wants to merge 35 commits into
AI2Science:mainfrom
priyavisingh:main

Conversation

@SreeDan
Copy link
Copy Markdown

@SreeDan SreeDan commented Apr 29, 2026

Attempts to close #8

This PR adds a complete visualization layer (viz/) to the OpenFold fork, along with the extraction infrastructure needed to capture intermediate Evoformer representations and attention weights during a model run. All new code is additive. We didn't change the existing internals of any OpenFold modules except two small hook registration lines in openfold/__init__.py and openfold/utils/import_weights.py.

Architecture is in 3 layers:

model run
   │  EvoformerRecorder hooks (openfold/utils/evoformer_instrumentation.py)
   │  save_attention_topk text files + reps.pt
   ▼
representation_tensor_utils.py
   │  validate · select_channel · aggregate · normalize
   │  prepare_heatmap_data · prepare_lineplot_data
   ▼
viz/integrations.py
   │  *_from_representation  ·  *_from_artifact
   ▼
viz/plots/  →  matplotlib Figure  →  notebook / Flask UI

Layer 1:

openfold/utils/evoformer_instrumentation.py registers PyTorch forward hooks on the Evoformer stack via EvoformerRecorder. During a model run it captures MSA and pair representation tensors at configurable layers (clone --> detach --> CPU --> optional dtype cast) and writes them to reps.pt via torch.save. Attention weights are written as sparse top-K text files by save_attention_topk. This approach adds zero overhead when enabled=False and leaves the model weights untouched.

openfold/utils/evoformer_run_artifact.py wraps the on-disk outputs. EvoformerRunArtifact provides load_reps() and get_attention_matrix(kind, layer, head, ...) which reconstructs a dense N×N matrix from the sparse text files (mean over heads, optional single-residue slice). The artifact is the handoff point that the viz layer reads from.

run_evoformer_hook_pretrained_openfold.py is the end-to-end inference script: loads pretrained weights, attaches hooks, runs the model, and writes all artifacts to outputs/.

Layer 2:

Sits between extraction and rendering. Designed so the same call works for a raw tensor from a model run or a synthetic array from _fakes.py

Layer 3 - Plot functions

Core plots (viz/plots/)

Function Input shape Purpose
plot_heatmap (R, C) Residue-by-residue map with colorbar and optional red highlight lines
plot_heatmap_grid (K, R, C) K panels in a grid, optional shared colorbar
plot_line (N,) Single residue-indexed signal
plot_lines (K, N) K overlaid signals with legend
plot_layer_trajectory (L, N) One channel across L layers, one line per residue
plot_histogram any shape Value distribution (flattened, NaNs dropped)

Integration bridge (viz/integrations.py)

  • heatmap_from_representation / line_from_representation / lines_from_representation / pair_channel_grid — take a raw pair/msa/single tensor, route it through prepare_*, and call the appropriate plot function. Axis labels and titles are set automatically per kind.
  • attention_heatmap_from_artifact / representation_heatmap_from_artifact / representation_line_from_artifact / representation_tensor_from_artifact / figure_from_artifact — read directly from an EvoformerRunArtifact. figure_from_artifact is a single-call dispatcher: pass attn_kind= for attention or rep_kind= for representations.

Flask UI (viz/web_interface.py, viz/render_for_ui.py)

render_for_ui.py is a CLI that pre-generates 96 PNGs (48 layers × 2 attention types).

web_interface.py serves the PNGs with layer/type dropdowns. Configurable via environment variables (VIZFOLD_IMAGE_DIR, VIZFOLD_PROT, VIZFOLD_TRI_IDX, VIZFOLD_NUM_LAYERS) so it can point at any run without editing source.

Synthetic data helpers (viz/_fakes.py)

Shape-matching synthetic tensors (fake_attention_heads, fake_pair_channel, fake_single_channels, fake_layer_trajectory, fake_distribution) so the entire pipeline can be exercised without a GPU or model weights.

Each layer is independently testable and the plot functions accept plain NumPy arrays, so any layer can be bypassed.

Testing:

# Install viz dependencies
pip install -r viz/requirements.txt

# --- Unit / integration tests ---
python -m unittest tests.test_representation_tensor_utils -v
python -m unittest tests.test_viz_integration -v
python -m unittest tests.test_viz_artifact_bridge -v

# --- Flask UI (synthetic mode, no model run needed) ---
python -m viz.render_for_ui --mode demo --n-res 96
# Expected: "[demo] wrote 96 PNGs into outputs/attention_images_6KWC_demo_tri_18"

python viz/web_interface.py
# Open http://localhost:5001
# Verify: layer dropdown 0-47, both attention types load images, protein name in heading

# --- Security check: path traversal must be blocked ---
curl -s -o /dev/null -w "%{http_code}" "http://localhost:5001/image?path=/etc/passwd"
# Expected: 403

# --- Bad query params must not 500 ---
curl -s -o /dev/null -w "%{http_code}" "http://localhost:5001/?attn_type=EVIL&layer=abc"
# Expected: 200 (falls back to msa_row, layer 0)

# --- Env-var config ---
VIZFOLD_NUM_LAYERS=10 python viz/web_interface.py
# Layer dropdown should show only 0-9

# --- Real model run (if a GPU + weights are available) ---
python run_evoformer_hook_pretrained_openfold.py \
    --config_preset model_1 \
    --output_dir outputs/run_6KWC \
    --demo_attn
python -m viz.render_for_ui --mode real \
    --attn-dir outputs/attention_files_6KWC_demo_tri_18
python viz/web_interface.py
# Verify real attention heatmaps load in the UI

Example of end-to-end flow:
IMG_5765
IMG_6975

demo.mp4

Pranav Narala and others added 30 commits March 13, 2026 13:24
Feature/Issue AI2Science#8 — Web Interface for Visualizing Intermediate Representations
…I2Science#8

Adds viz/ package with plot_heatmap (image plot) and plot_line (line plot)
covering Person 3 (visualization) deliverables for issue AI2Science#8: image and line
plots over residue indices. Both functions are framework-agnostic (numpy
in, matplotlib Figure out) so they can be embedded in the Flask UI or used
directly in notebooks.

- viz/plots/heatmap.py: 2-D residue-indexed heatmap with colorbar and
  optional residue highlighting; supports square (NxN) attention/pair
  channels and rectangular (SxN) MSA channels.
- viz/plots/lineplot.py: 1-D residue-indexed line plot with optional
  residue reference markers.
- viz/plots/common.py: shared utilities (save_or_return, normalize,
  add_residue_axes, add_colorbar, new_figure).
- viz/examples/: committed example PNGs from synthetic data.
- notebooks/viz_plot_demo.ipynb: end-to-end demo against synthetic data.
- viz/README.md: usage snippets and rendered examples.

Made-with: Cursor
Adds four plot functions on top of plot_heatmap / plot_line:
- plot_heatmap_grid: K residue-indexed heatmaps in a grid (e.g. all heads
  of one attention layer) with an optional shared colorbar.
- plot_lines: K residue-indexed signals overlaid on one axis.
- plot_layer_trajectory: one channel's value across L layers, one line per
  selected residue.
- plot_histogram: value distribution of any tensor slice.

Adds viz/_fakes.py with shape-matching synthetic tensor helpers so the
plots can be exercised before the extraction layer (Priyavi/Pranav) lands;
the helpers are explicitly marked as placeholders.

Re-exports the new symbols from viz and viz.plots, regenerates the demo
notebook (one section per function plus shape-validation), expands
viz/README.md to document the full public API, and commits four new
example PNGs in viz/examples/.

Made-with: Cursor
Added representation tensor utilities for Issue AI2Science#8
Adds viz/integrations.py exposing four end-to-end helpers that take a raw
pair / msa / single tensor, route it through prepare_heatmap_data /
prepare_lineplot_data for validation + channel selection + aggregation +
normalization, and hand the result to the existing residue-indexed plot
functions:

- heatmap_from_representation -> plot_heatmap
- line_from_representation    -> plot_line
- lines_from_representation   -> plot_lines
- pair_channel_grid           -> plot_heatmap_grid

Re-exports the bridge from viz, extends the demo notebook with a section
showing the raw-tensor path against synthetic z/m/s tensors of the
expected shapes, commits three new example PNGs, and documents the
integration in viz/README.md so the Flask UI can call these helpers
directly once real OpenFold tensors arrive.

Made-with: Cursor
Jupyter reformatted the JSON cell keys on open (moves execution_count /
outputs / id alongside metadata). No content change.

Made-with: Cursor
For every supported (kind, channel/aggregate, normalize) combination,
asserts that the data the bridge functions paint onto the matplotlib
Figure is bit-identical to the array Priyavi's prepare_* returns.

Covers:
- pair / msa / single heatmaps across minmax / zscore / none normalization
- pair channel + mean / max / l2 aggregation
- single, pair-diagonal, msa depth-mean line plots
- multi-channel overlay (lines_from_representation)
- pair_channel_grid panel ordering
- error paths (non-square pair, channel out of bounds, channel+aggregate
  conflict, unknown kind)
- minmax/zscore numerical invariants

15 tests, all passing alongside the existing 12 in
tests/test_representation_tensor_utils.py.

Made-with: Cursor
Adds viz/render_for_ui.py, a CLI that populates the directory
web_interface.py serves (heatmap_<msa_row|triangle_start>_layer{N}.png
for layers 0..47). Three modes:

- real: parses save_attention_topk text files, reconstructs the N x N
  attention map per layer (mean over heads) and renders via
  viz.plot_heatmap. Use after a model run with --demo_attn.
- demo: synthesizes placeholders from viz._fakes.fake_attention_heads
  so the UI has 96 images to display before any model run.
- auto (default): tries real, falls back to demo per layer.

Smoke-tested end-to-end: ran the generator in auto mode (fell back to
demo, wrote 96 PNGs in ~60s), launched web_interface.py on :5001, and
verified the main page, ?layer=23, and a triangle layer-47 fetch all
return HTTP 200 with valid PNGs.

The generated PNGs themselves are gitignored (regenerable artifacts);
only the script and README updates are committed. Anyone cloning the
repo runs `python -m viz.render_for_ui` to populate the UI.

Made-with: Cursor
Merging in my part for group PR
Extends viz/integrations.py with five artifact-aware helpers:

- attention_heatmap_from_artifact: artifact.get_attention_matrix ->
  plot_heatmap (covers msa_row_attn and triangle_start_attn).
- representation_tensor_from_artifact: pulls layer_{LL:02d}.{kind} from
  artifact.reps with a top-level fallback.
- representation_heatmap_from_artifact / representation_line_from_artifact:
  routes the per-layer tensor through Priyavi's prepare_* and then
  through viz.heatmap_from_representation / line_from_representation.
- figure_from_artifact: single-call dispatcher selecting attention or
  representation path with a `plot=` kwarg.

Adds tests/test_viz_artifact_bridge.py: builds a synthetic artifact in a
temp dir (real-format attention text file + reps.pt) and asserts that
Figure pixel data matches artifact.get_attention_matrix(...) and the
prepare_* output bit-for-bit. Documents the new path in viz/README.md.

Made-with: Cursor
@SreeDan SreeDan changed the title Issue #8 Attention Head Vizualization Issue #8 Visualize Intermediate Representations Apr 29, 2026
@SreeDan SreeDan marked this pull request as ready for review April 29, 2026 02:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Visualize representations from intermediate layers for OpenFold

4 participants