Issue #8 Visualize Intermediate Representations #73
Open
SreeDan wants to merge 35 commits into
Feature/Issue AI2Science#8 — Web Interface for Visualizing Intermediate Representations
…I2Science#8
Adds viz/ package with plot_heatmap (image plot) and plot_line (line plot), covering the Person 3 (visualization) deliverables for issue AI2Science#8: image and line plots over residue indices. Both functions are framework-agnostic (numpy in, matplotlib Figure out) so they can be embedded in the Flask UI or used directly in notebooks.
- viz/plots/heatmap.py: 2-D residue-indexed heatmap with colorbar and optional residue highlighting; supports square (NxN) attention/pair channels and rectangular (SxN) MSA channels.
- viz/plots/lineplot.py: 1-D residue-indexed line plot with optional residue reference markers.
- viz/plots/common.py: shared utilities (save_or_return, normalize, add_residue_axes, add_colorbar, new_figure).
- viz/examples/: committed example PNGs from synthetic data.
- notebooks/viz_plot_demo.ipynb: end-to-end demo against synthetic data.
- viz/README.md: usage snippets and rendered examples.
Made-with: Cursor
Adds four plot functions on top of plot_heatmap / plot_line:
- plot_heatmap_grid: K residue-indexed heatmaps in a grid (e.g. all heads of one attention layer) with an optional shared colorbar.
- plot_lines: K residue-indexed signals overlaid on one axis.
- plot_layer_trajectory: one channel's value across L layers, one line per selected residue.
- plot_histogram: value distribution of any tensor slice.
Adds viz/_fakes.py with shape-matching synthetic tensor helpers so the plots can be exercised before the extraction layer (Priyavi/Pranav) lands; the helpers are explicitly marked as placeholders.
Re-exports the new symbols from viz and viz.plots, regenerates the demo notebook (one section per function plus shape-validation), expands viz/README.md to document the full public API, and commits four new example PNGs in viz/examples/.
Made-with: Cursor
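To make the "numpy in, matplotlib Figure out" contract concrete, here is a minimal sketch of a layer-trajectory plot: an (L, N) array goes in, and one line per selected residue is drawn across layers. The function name `sketch_layer_trajectory` and its parameters are hypothetical and are not the signature of viz's plot_layer_trajectory.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display

import matplotlib.pyplot as plt
import numpy as np


def sketch_layer_trajectory(values, residues):
    """Hypothetical stand-in: plot values[:, r] across layers for each residue r.

    values   -- array of shape (L, N): L layers, N residues
    residues -- iterable of residue indices to draw
    """
    values = np.asarray(values, dtype=float)
    if values.ndim != 2:
        raise ValueError(f"expected an (L, N) array, got shape {values.shape}")

    fig, ax = plt.subplots(figsize=(6, 3))
    layers = np.arange(values.shape[0])
    for r in residues:
        ax.plot(layers, values[:, r], label=f"residue {r}")
    ax.set_xlabel("layer")
    ax.set_ylabel("value")
    ax.legend()
    return fig


# Synthetic (48, 20) input, similar in spirit to the helpers in viz/_fakes.py.
rng = np.random.default_rng(0)
fig = sketch_layer_trajectory(rng.normal(size=(48, 20)), residues=[0, 5, 10])
```

Returning the Figure rather than calling `plt.show()` is what lets the same function serve both a notebook and a web handler that saves to PNG.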
Added representation tensor utilities for Issue AI2Science#8
Remove headers for consistency
Final checkin
Adds viz/integrations.py exposing four end-to-end helpers that take a raw pair / msa / single tensor, route it through prepare_heatmap_data / prepare_lineplot_data for validation + channel selection + aggregation + normalization, and hand the result to the existing residue-indexed plot functions:
- heatmap_from_representation -> plot_heatmap
- line_from_representation -> plot_line
- lines_from_representation -> plot_lines
- pair_channel_grid -> plot_heatmap_grid
Re-exports the bridge from viz, extends the demo notebook with a section showing the raw-tensor path against synthetic z/m/s tensors of the expected shapes, commits three new example PNGs, and documents the integration in viz/README.md so the Flask UI can call these helpers directly once real OpenFold tensors arrive.
Made-with: Cursor
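The validation, channel selection, and aggregation step that the bridge delegates to prepare_heatmap_data might look roughly like the sketch below. The function name `reduce_pair_channels`, the (N, N, C) layout, and the error messages are assumptions made for illustration; the real prepare_* signatures are not shown in this commit message.

```python
import numpy as np


def reduce_pair_channels(z, channel=None, aggregate=None):
    """Hypothetical stand-in: reduce an (N, N, C) pair tensor to an (N, N) map.

    Exactly one of `channel` (an index into C) or `aggregate`
    ("mean", "max", or "l2" over the channel axis) should be given.
    """
    z = np.asarray(z, dtype=float)
    if z.ndim != 3 or z.shape[0] != z.shape[1]:
        raise ValueError(f"expected an (N, N, C) pair tensor, got {z.shape}")
    if channel is not None and aggregate is not None:
        raise ValueError("pass either channel or aggregate, not both")

    if channel is not None:
        if not 0 <= channel < z.shape[2]:
            raise IndexError(f"channel {channel} out of bounds for C={z.shape[2]}")
        return z[:, :, channel]

    if aggregate == "mean":
        return z.mean(axis=-1)
    if aggregate == "max":
        return z.max(axis=-1)
    if aggregate == "l2":
        return np.sqrt((z ** 2).sum(axis=-1))
    raise ValueError(f"unknown aggregate: {aggregate!r}")


z = np.arange(18, dtype=float).reshape(3, 3, 2)  # tiny synthetic pair tensor
single_channel = reduce_pair_channels(z, channel=1)
l2_map = reduce_pair_channels(z, aggregate="l2")
```

Rejecting the channel-plus-aggregate combination up front keeps the two selection modes mutually exclusive, so a caller never silently gets one when they asked for both.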
Jupyter reformatted the JSON cell keys on open (moves execution_count / outputs / id alongside metadata). No content change. Made-with: Cursor
For every supported (kind, channel/aggregate, normalize) combination, asserts that the data the bridge functions paint onto the matplotlib Figure is bit-identical to the array Priyavi's prepare_* returns. Covers:
- pair / msa / single heatmaps across minmax / zscore / none normalization
- pair channel + mean / max / l2 aggregation
- single, pair-diagonal, msa depth-mean line plots
- multi-channel overlay (lines_from_representation)
- pair_channel_grid panel ordering
- error paths (non-square pair, channel out of bounds, channel+aggregate conflict, unknown kind)
- minmax/zscore numerical invariants
15 tests, all passing alongside the existing 12 in tests/test_representation_tensor_utils.py.
Made-with: Cursor
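For readers unfamiliar with the "numerical invariants" mentioned above, this sketch shows what such checks could look like. The `minmax` and `zscore` functions here are hypothetical stand-ins, and mapping a constant array to zeros is an assumed convention; neither is taken from the prepare_* implementation.

```python
import numpy as np


def minmax(x):
    """Scale to [0, 1]; a constant array maps to zeros (assumed convention)."""
    x = np.asarray(x, dtype=float)
    span = x.max() - x.min()
    return np.zeros_like(x) if span == 0 else (x - x.min()) / span


def zscore(x):
    """Shift to zero mean and unit standard deviation; constant input maps to zeros."""
    x = np.asarray(x, dtype=float)
    std = x.std()
    return np.zeros_like(x) if std == 0 else (x - x.mean()) / std


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(16, 16))

# Invariant 1: minmax output spans exactly [0, 1].
mm = minmax(x)
assert np.isclose(mm.min(), 0.0) and np.isclose(mm.max(), 1.0)

# Invariant 2: zscore output has mean 0 and standard deviation 1.
zs = zscore(x)
assert abs(zs.mean()) < 1e-9 and abs(zs.std() - 1.0) < 1e-9
```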
Adds viz/render_for_ui.py, a CLI that populates the directory
web_interface.py serves (heatmap_<msa_row|triangle_start>_layer{N}.png
for layers 0..47). Three modes:
- real: parses save_attention_topk text files, reconstructs the N x N
attention map per layer (mean over heads) and renders via
viz.plot_heatmap. Use after a model run with --demo_attn.
- demo: synthesizes placeholders from viz._fakes.fake_attention_heads
so the UI has 96 images to display before any model run.
- auto (default): tries real, falls back to demo per layer.
Smoke-tested end-to-end: ran the generator in auto mode (fell back to
demo, wrote 96 PNGs in ~60s), launched web_interface.py on :5001, and
verified the main page, ?layer=23, and a triangle layer-47 fetch all
return HTTP 200 with valid PNGs.
The generated PNGs themselves are gitignored (regenerable artifacts);
only the script and README updates are committed. Anyone cloning the
repo runs `python -m viz.render_for_ui` to populate the UI.
Made-with: Cursor
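The per-layer fallback in auto mode can be pictured with a small planning function. This is a sketch only: `plan_renders` and its `real_layers` argument are hypothetical, and the filename pattern is copied from the commit message above rather than from the script itself.

```python
ATTN_KINDS = ("msa_row", "triangle_start")  # the two kinds named in the commit message


def png_name(kind, layer):
    """Filename pattern as quoted above: heatmap_<kind>_layer{N}.png."""
    return f"heatmap_{kind}_layer{layer}.png"


def plan_renders(real_layers, num_layers=48):
    """Return {filename: "real" | "demo"} for auto mode.

    real_layers -- hypothetical mapping from attention kind to the set of
                   layer indices for which parsed attention data exists
    """
    plan = {}
    for kind in ATTN_KINDS:
        available = real_layers.get(kind, set())
        for layer in range(num_layers):
            # Auto mode: use real data where present, otherwise fall back to demo.
            plan[png_name(kind, layer)] = "real" if layer in available else "demo"
    return plan


plan = plan_renders({"msa_row": {0, 1}})
print(len(plan))  # 96 (48 layers x 2 attention kinds)
```

Deciding per layer rather than per run means a partial model run still shows its real layers, with placeholders only where data is missing.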
Merging in my part for group PR
Extends viz/integrations.py with five artifact-aware helpers:
- attention_heatmap_from_artifact: artifact.get_attention_matrix ->
plot_heatmap (covers msa_row_attn and triangle_start_attn).
- representation_tensor_from_artifact: pulls layer_{LL:02d}.{kind} from
artifact.reps with a top-level fallback.
- representation_heatmap_from_artifact / representation_line_from_artifact:
route the per-layer tensor through Priyavi's prepare_* and then
through viz.heatmap_from_representation / line_from_representation.
- figure_from_artifact: single-call dispatcher selecting attention or
representation path with a `plot=` kwarg.
Adds tests/test_viz_artifact_bridge.py: builds a synthetic artifact in a
temp dir (real-format attention text file + reps.pt) and asserts that
Figure pixel data matches artifact.get_attention_matrix(...) and the
prepare_* output bit-for-bit. Documents the new path in viz/README.md.
Made-with: Cursor
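The layer-keyed lookup with a top-level fallback could be sketched as follows. It assumes `reps` is a flat dict with string keys such as "layer_03.pair"; the actual layout of reps.pt and the name `lookup_rep` are assumptions for illustration.

```python
def lookup_rep(reps, layer, kind):
    """Hypothetical sketch: prefer 'layer_LL.kind', fall back to top-level 'kind'."""
    key = f"layer_{layer:02d}.{kind}"
    if key in reps:
        return reps[key]
    if kind in reps:
        # Fallback for artifacts that store a single tensor per kind.
        return reps[kind]
    raise KeyError(f"neither {key!r} nor {kind!r} found in reps")


# Plain strings stand in for tensors to keep the sketch dependency-free.
reps = {"layer_03.pair": "pair@3", "msa": "msa-top-level"}
per_layer = lookup_rep(reps, layer=3, kind="pair")
fallback = lookup_rep(reps, layer=7, kind="msa")
```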
pratham's visualizations
Attempts to close #8
This PR adds a complete visualization layer (`viz/`) to the OpenFold fork, along with the extraction infrastructure needed to capture intermediate Evoformer representations and attention weights during a model run. All new code is additive. We didn't change the existing internals of any OpenFold modules except two small hook registration lines in `openfold/__init__.py` and `openfold/utils/import_weights.py`.
Architecture is in 3 layers:
Layer 1:
- `openfold/utils/evoformer_instrumentation.py` registers PyTorch forward hooks on the Evoformer stack via `EvoformerRecorder`. During a model run it captures MSA and pair representation tensors at configurable layers (clone --> detach --> CPU --> optional dtype cast) and writes them to `reps.pt` via `torch.save`. Attention weights are written as sparse top-K text files by `save_attention_topk`. This approach adds zero overhead when `enabled=False` and leaves the model weights untouched.
- `openfold/utils/evoformer_run_artifact.py` wraps the on-disk outputs. `EvoformerRunArtifact` provides `load_reps()` and `get_attention_matrix(kind, layer, head, ...)`, which reconstructs a dense N×N matrix from the sparse text files (mean over heads, optional single-residue slice). The artifact is the handoff point that the viz layer reads from.
- `run_evoformer_hook_pretrained_openfold.py` is the end-to-end inference script: loads pretrained weights, attaches hooks, runs the model, and writes all artifacts to `outputs/`.
Layer 2:
Sits between extraction and rendering. Designed so the same call works for a raw tensor from a model run or a synthetic array from `_fakes.py`.
Layer 3 - Plot functions
Core plots (`viz/plots/`)
- `plot_heatmap`: `(R, C)`
- `plot_heatmap_grid`: `(K, R, C)`
- `plot_line`: `(N,)`
- `plot_lines`: `(K, N)`
- `plot_layer_trajectory`: `(L, N)`
- `plot_histogram`
Integration bridge (`viz/integrations.py`)
- `heatmap_from_representation` / `line_from_representation` / `lines_from_representation` / `pair_channel_grid`: take a raw `pair`/`msa`/`single` tensor, route it through `prepare_*`, and call the appropriate plot function. Axis labels and titles are set automatically per `kind`.
- `attention_heatmap_from_artifact` / `representation_heatmap_from_artifact` / `representation_line_from_artifact` / `representation_tensor_from_artifact` / `figure_from_artifact`: read directly from an `EvoformerRunArtifact`. `figure_from_artifact` is a single-call dispatcher: pass `attn_kind=` for attention or `rep_kind=` for representations.
Flask UI (`viz/web_interface.py`, `viz/render_for_ui.py`)
`render_for_ui.py` is a CLI that pre-generates 96 PNGs (48 layers × 2 attention types). `web_interface.py` serves the PNGs with layer/type dropdowns. Configurable via environment variables (`VIZFOLD_IMAGE_DIR`, `VIZFOLD_PROT`, `VIZFOLD_TRI_IDX`, `VIZFOLD_NUM_LAYERS`) so it can point at any run without editing source.
Synthetic data helpers (`viz/_fakes.py`)
Shape-matching synthetic tensors (`fake_attention_heads`, `fake_pair_channel`, `fake_single_channels`, `fake_layer_trajectory`, `fake_distribution`) so the entire pipeline can be exercised without a GPU or model weights.
Each layer is independently testable and the plot functions accept plain NumPy arrays, so any layer can be bypassed.
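Because the artifact's dense attention matrix is the handoff into the plot functions, a rough sketch of the reconstruction described under Layer 1 may help reviewers. It assumes the sparse entries have already been parsed into `(head, i, j, weight)` tuples; the on-disk text format, the function name `densify_topk`, and treating entries outside the top-K as zero are all assumptions, not the behaviour of `get_attention_matrix`.

```python
import numpy as np


def densify_topk(entries, n_res, n_heads):
    """Hypothetical sketch: sparse top-K entries -> dense (N, N) mean over heads.

    entries -- iterable of (head, i, j, weight) tuples (assumed in-memory form)
    n_res   -- number of residues N
    n_heads -- number of attention heads
    """
    per_head = np.zeros((n_heads, n_res, n_res), dtype=float)
    for head, i, j, weight in entries:
        per_head[head, i, j] = weight  # positions outside the top-K stay 0
    return per_head.mean(axis=0)


entries = [(0, 0, 1, 0.8), (1, 0, 1, 0.4), (1, 2, 2, 1.0)]
dense = densify_topk(entries, n_res=3, n_heads=2)
```

The result is a plain NumPy array, so it can be passed to a heatmap function directly or swapped for a synthetic array of the same shape.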
Testing:
Example of end-to-end flow:


demo.mp4