74 changes: 74 additions & 0 deletions docs/source/Head_Visualization_Heatmap_Network.md
@@ -0,0 +1,74 @@
# Head-Level Heatmap and Network Visualizations

This document describes the new attention-head visualizations (heatmaps, network-style plots, and chord diagrams) and how they compare to the existing arc diagrams and 3D PyMOL overlays.

## Overview

Arc diagrams and PyMOL overlays are useful but show one head at a time. The new tools let you:

- **Compare all heads in one layer at once** via a grid of heatmaps or a grid of small network plots.
- **See aggregated attention** across heads as a single network, with hub residues highlighted.
- **Use the same attention text files** produced by `run_pretrained_openfold.py --demo_attn` (no change to the inference pipeline).

## New Components

| Component | Purpose |
|-----------|---------|
| `visualize_attention_data.py` | Shared attention-map parser and FASTA reader used by every visualization module. |
| `visualize_attention_head_heatmaps.py` | Builds dense per-head matrices from shared parsed attention data and plots a grid of residue–residue heatmaps (one per head). |
| `visualize_attention_networks.py` | Aggregates heads into one weighted graph and/or draws one small network per head; supports circular or linear layout and hub highlighting. |
| `visualize_attention_chord_diagrams.py` | Renders circular chord diagrams for single heads, all-head grids, and aggregated mean attention. |
| Notebook cells (in `viz_attention_demo.ipynb` and `viz_attention_demo_base.ipynb`) | After running 3D and arc visualizations, a new cell runs heatmap and network code for MSA row and (optionally) triangle-start attention. |
| `scripts/run_head_heatmap_network_demo.py` | Standalone script that generates **synthetic** attention data and runs the new visualizations to produce example outputs without full OpenFold inference. |
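
As a rough illustration of the FASTA-reading half of the shared module, the sketch below joins all non-header lines into one sequence string, which is the behavior exercised by the `parse_fasta_sequence` test in this PR. The function name `read_single_fasta_sequence` is hypothetical, and the real implementation may handle edge cases (multiple records, comments) differently.

```python
def read_single_fasta_sequence(fasta_path):
    """Hypothetical stand-in: return the residues of a single-record FASTA file."""
    sequence_parts = []
    with open(fasta_path) as handle:
        for line in handle:
            line = line.strip()
            # Skip blank lines and the ">" header line.
            if not line or line.startswith(">"):
                continue
            sequence_parts.append(line)
    return "".join(sequence_parts)
```

The length of the returned string is what the demo script uses as `n_residues`, so the heatmap axes and network nodes line up with the sequence.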

## How to Run

1. **From the notebook (real data)**
   Run the usual inference cell so that `ATTN_MAP_DIR` contains files like `msa_row_attn_layer47.txt`. Then run the new cell titled *"Head-level heatmap and network-style visualizations"*. Outputs go to:
   - `IMAGE_OUTPUT_DIR/head_heatmaps/` (combined heatmap panel, optional per-head heatmaps)
   - `IMAGE_OUTPUT_DIR/network_plots/` (aggregated network, per-head network grid)

2. **Example outputs without inference**
   From the repo root:
   ```bash
   python scripts/run_head_heatmap_network_demo.py
   ```
   This writes synthetic attention into `examples/monomer/sample_attention_viz_outputs/` and runs the same heatmap, network, and chord diagram functions. Use it to check that the pipeline runs and to get sample figures.
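
Before running the notebook cell on real data, it can help to confirm which layers actually have attention files on disk. The sketch below is a minimal check, assuming the file naming shown in this PR's tests (`msa_row_attn_layer47.txt`, `triangle_start_attn_layer47_residue_idx_18.txt`). Both helper names are hypothetical, and the default of 48 layers is an assumption, not something the pipeline guarantees.

```python
import os


def expected_attention_path(attn_dir, attention_type, layer_idx, residue_idx=None):
    """Hypothetical helper mirroring the file names used in this PR's tests."""
    name = f"{attention_type}_attn_layer{layer_idx}"
    if residue_idx is not None:
        # Triangle attention files carry the residue index in their name.
        name += f"_residue_idx_{residue_idx}"
    return os.path.join(attn_dir, name + ".txt")


def layers_with_attention(attn_dir, attention_type="msa_row", layers=range(48)):
    """Return the layer indices whose attention file exists in attn_dir."""
    return [
        layer
        for layer in layers
        if os.path.exists(expected_attention_path(attn_dir, attention_type, layer))
    ]
```

For example, `layers_with_attention(ATTN_MAP_DIR)` would list the layers that the heatmap and network cell can be pointed at.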

## Comparing All Heads at Once

- **Heatmap grid**: One subplot per head; each shows the residue × residue attention matrix, either in full or filled only from the top-k edges kept when loading. A shared color scale across heads makes it easy to see which heads focus on similar pairs and which are sparse or different.
- **Per-head network grid**: Same idea as heatmaps but each head is shown as a 2D network (nodes = residues, edges = attention). Good for seeing structural “clusters” and long-range links per head.
- **Aggregated network**: One graph where edge weight is the mean (or sum) of attention over heads. Top-k hubs by total incident weight are highlighted, so you see which residues are attended to most across the layer.
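
The aggregation and hub ranking described above can be sketched in a few lines. This is an illustrative version only: it averages each edge over the heads in which that edge appears (matching the expectation in the `aggregate_chord_edges` test in this PR), whereas `build_aggregated_graph` with `normalize_by_heads=True` may normalize differently. The function names `aggregate_mean_edges` and `top_hubs` are hypothetical.

```python
from collections import defaultdict


def aggregate_mean_edges(heads):
    """Average each (i, j) edge over the heads in which it appears."""
    totals = defaultdict(float)
    counts = defaultdict(int)
    for edges in heads.values():
        for i, j, weight in edges:
            totals[(i, j)] += weight
            counts[(i, j)] += 1
    merged = [(i, j, totals[(i, j)] / counts[(i, j)]) for (i, j) in totals]
    # Strongest edges first, so a max_edges cutoff keeps the most important ones.
    return sorted(merged, key=lambda edge: edge[2], reverse=True)


def top_hubs(edges, k=10):
    """Rank residues by total incident weight (incoming plus outgoing)."""
    strength = defaultdict(float)
    for i, j, weight in edges:
        strength[i] += weight
        strength[j] += weight
    ranked = sorted(strength.items(), key=lambda item: item[1], reverse=True)
    return [residue for residue, _ in ranked[:k]]
```

Calling `top_hubs(aggregate_mean_edges(heads), k=10)` gives the residues that would be highlighted under these assumptions.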

## What These Visualizations Capture That Arc Diagrams Might Miss

- **Global head similarity**: Arc diagrams are one-head-at-a-time. Heatmaps and small-multiples networks show the whole layer in one view, so you can quickly see redundant vs. diverse heads.
- **Node-centric importance**: The aggregated network plus hub highlighting shows which residues are “important” in the sense of total incoming/outgoing attention, which is harder to read from a single-head arc.
- **Dense pattern vs. sparse pattern**: Heatmaps make it obvious when a head is diffuse (many weak links) vs. focused (few strong links), and where on the sequence those links lie (diagonal vs. off-diagonal).
- **Layer-wise comparison**: The same functions can be called for multiple layers (change `layer_idx` and the attention file path); then comparing saved heatmap/network figures across layers shows how attention evolves with depth.
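
To make the dense vs. sparse and diagonal vs. off-diagonal distinction concrete, the sketch below scatters an edge list into a dense matrix and measures how much of a head's attention falls on long-range pairs. It is not the implementation of `build_head_matrices`; the helper names and the `min_separation=6` threshold are assumptions chosen for illustration.

```python
import numpy as np


def dense_head_matrix(edges, n_residues):
    """Scatter (i, j, weight) triples into an n_residues x n_residues array."""
    mat = np.zeros((n_residues, n_residues), dtype=float)
    for i, j, weight in edges:
        mat[i, j] = weight  # a repeated (i, j) pair keeps the last weight seen
    return mat


def long_range_fraction(mat, min_separation=6):
    """Share of total attention weight on pairs with |i - j| >= min_separation."""
    total = mat.sum()
    if total == 0:
        return 0.0
    idx = np.arange(mat.shape[0])
    far = np.abs(idx[:, None] - idx[None, :]) >= min_separation
    return float(mat[far].sum() / total)


def shared_color_limits(mats):
    """One (vmin, vmax) pair so every heatmap in a grid uses the same scale."""
    return 0.0, max(float(m.max()) for m in mats)
```

A head with a low `long_range_fraction` is concentrated near the diagonal, while a high value points to off-diagonal, long-range links. Computing the same number per layer is one simple way to compare layers.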

## Evaluation Summary

| Visualization | Best for |
|---------------|----------|
| Arc diagram | Single-head, sequence-linear view of top-k edges; easy to match residues to sequence. |
| 3D PyMOL overlay | Same head in 3D structure context; good for spatial interpretation. |
| **Heatmap grid** | Comparing all heads in one layer; seeing dense vs. sparse and similarity across heads. |
| **Aggregated network** | Which residues are hubs across the whole layer; one picture for “consensus” attention. |
| **Per-head network grid** | Same as heatmap grid but with a network layout; can be easier for seeing clusters and long-range ties. |
| **Chord diagrams** | Circular residue-residue attention view; useful for seeing long-range links without forcing residues into a straight line. |

The new visualizations do not replace arc or 3D views; they complement them by answering “what do all heads in this layer look like together?” and “which residues matter most when we aggregate heads?”

## File and Function Reference

- **Build matrices**: `build_head_matrices(heads, n_residues)` in `visualize_attention_head_heatmaps.py`.
- **Load attention data**: `load_attention_map(connections_file, top_k=...)` in `visualize_attention_data.py`.
- **Heatmap panel**: `plot_head_heatmaps(head_mats, residue_sequence, layer_idx, protein, output_dir, ...)` in `visualize_attention_head_heatmaps.py`.
- **Aggregated graph**: `build_aggregated_graph(heads, aggregation="mean", normalize_by_heads=True)` in `visualize_attention_networks.py`.
- **Single network plot**: `plot_residue_network(edges, n_residues, residue_sequence, layer_idx, protein, output_dir, layout="circular", max_edges=200, top_k_hubs=10)` in `visualize_attention_networks.py`.
- **Per-head networks**: `plot_residue_network_per_head(heads, n_residues, ...)` in `visualize_attention_networks.py`.
- **Chord diagrams**: `generate_chord_diagrams(attention_dir, residue_sequence, output_dir, protein, ...)` in `visualize_attention_chord_diagrams.py`.

Input `heads` is the dict returned by `load_attention_map(connections_file, top_k=...)` from `visualize_attention_data.py`. The older `load_all_heads()` import path is still available for compatibility.
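
For orientation, the shape of `heads` can be illustrated with a small parser for the attention text format used in this PR's tests: a `layer L head H` header line followed by `i j weight` lines. The sketch below is an assumption-laden approximation of what `load_attention_map` returns (a dict mapping head index to edges sorted by descending weight and truncated to `top_k`). The name `parse_attention_text` is hypothetical, and the real loader may differ in details.

```python
def parse_attention_text(text, top_k=None):
    """Group "i j weight" lines under their "layer L head H" header."""
    heads = {}
    current = None
    for raw_line in text.splitlines():
        parts = raw_line.split()
        if not parts:
            continue
        if len(parts) == 4 and parts[0] == "layer" and parts[2] == "head":
            current = int(parts[3])
            heads.setdefault(current, [])
        elif len(parts) == 3 and current is not None:
            heads[current].append((int(parts[0]), int(parts[1]), float(parts[2])))
    # Keep the strongest edges first; slicing with None keeps all of them.
    return {
        head: sorted(edges, key=lambda edge: edge[2], reverse=True)[:top_k]
        for head, edges in heads.items()
    }
```

With `top_k=1`, a file containing two edges per head reduces to one `(i, j, weight)` tuple per head, which is the structure the heatmap and network functions consume.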
1 change: 1 addition & 0 deletions docs/source/index.md
@@ -106,6 +106,7 @@ Single_Sequence_Inference.md
Multimer_Inference.md
OpenFold_Training_Setup.md
Training_OpenFold.md
Head_Visualization_Heatmap_Network.md
```

```{toctree}
17 changes: 17 additions & 0 deletions examples/monomer/sample_attention_viz_outputs/README.md
@@ -0,0 +1,17 @@
# Sample attention visualization outputs

This directory is populated when you run:

```bash
python scripts/run_head_heatmap_network_demo.py
```

from the repository root. The script uses **synthetic** attention data (same file format as real OpenFold `--demo_attn` output) to generate:

- **head_heatmaps/** — One combined PNG with a grid of residue–residue heatmaps (one per attention head).
- **network_plots/** — Aggregated attention network (all heads combined, hub residues highlighted) and a per-head network grid.
- **chord_diagrams/** — Per-head chord diagrams, a combined per-head grid, and an aggregated mean chord diagram.

To generate visualizations from **real** inference, run the full pipeline in `viz_attention_demo.ipynb` or `viz_attention_demo_base.ipynb`; the same heatmap, network, and chord diagram code can run on the saved attention maps.

See [Head_Visualization_Heatmap_Network.md](../../../docs/source/Head_Visualization_Heatmap_Network.md) for a full description and comparison with arc diagrams and 3D overlays.
138 changes: 138 additions & 0 deletions scripts/run_head_heatmap_network_demo.py
@@ -0,0 +1,138 @@
"""
Generate example head heatmap and network visualizations using synthetic attention data.

Use this to produce sample outputs without running full OpenFold inference.
Reads a FASTA for sequence length/labels and writes a minimal attention file
in the same format as run_pretrained_openfold.py --demo_attn, then runs
the new visualization utilities.

Usage:
    python scripts/run_head_heatmap_network_demo.py

Outputs are written to examples/monomer/sample_attention_viz_outputs/
"""

import os
import sys

# Allow importing from repo root
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)

import numpy as np
from visualize_attention_data import load_attention_map, parse_fasta_sequence
from visualize_attention_chord_diagrams import generate_chord_diagrams
from visualize_attention_head_heatmaps import build_head_matrices, plot_head_heatmaps
from visualize_attention_networks import (
    build_aggregated_graph,
    plot_residue_network,
    plot_residue_network_per_head,
)


def write_synthetic_msa_row_attention(
    output_path: str,
    n_residues: int,
    num_heads: int = 8,
    edges_per_head: int = 80,
    seed: int = 42,
) -> None:
    """Write a minimal msa_row_attn_layer{L}.txt with synthetic (i, j, weight) lines."""
    rng = np.random.default_rng(seed)
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "w") as f:
        for h in range(num_heads):
            f.write(f"layer 47 head {h}\n")
            for _ in range(edges_per_head):
                i = rng.integers(0, n_residues)
                j = rng.integers(0, n_residues)
                if i == j:
                    j = (j + 1) % n_residues
                w = float(rng.uniform(0.05, 0.6))
                f.write(f"{i} {j} {w}\n")
    print(f"Wrote synthetic MSA row attention: {output_path}")


def main():
    fasta_path = os.path.join(
        REPO_ROOT, "examples", "monomer", "fasta_dir_6KWC", "6KWC.fasta"
    )
    if not os.path.exists(fasta_path):
        print(f"FASTA not found: {fasta_path}")
        return 1
    residue_seq = parse_fasta_sequence(fasta_path)
    n_residues = len(residue_seq)

    out_base = os.path.join(
        REPO_ROOT, "examples", "monomer", "sample_attention_viz_outputs"
    )
    attn_dir = os.path.join(out_base, "attention_files_6KWC_demo")
    heatmap_dir = os.path.join(out_base, "head_heatmaps")
    network_dir = os.path.join(out_base, "network_plots")
    chord_dir = os.path.join(out_base, "chord_diagrams")

    layer_idx = 47
    top_k = 50
    protein = "6KWC"

    write_synthetic_msa_row_attention(
        os.path.join(attn_dir, f"msa_row_attn_layer{layer_idx}.txt"),
        n_residues=n_residues,
        num_heads=8,
        edges_per_head=100,
    )

    msa_file = os.path.join(attn_dir, f"msa_row_attn_layer{layer_idx}.txt")
    heads = load_attention_map(msa_file, top_k=top_k)
    head_mats = build_head_matrices(heads, n_residues=n_residues)
    plot_head_heatmaps(
        head_mats,
        residue_sequence=residue_seq,
        layer_idx=layer_idx,
        protein=protein,
        output_dir=heatmap_dir,
        cols=4,
        save_combined=True,
        save_individual=False,
    )
    agg_edges = build_aggregated_graph(heads, aggregation="mean", normalize_by_heads=True)
    plot_residue_network(
        agg_edges,
        n_residues=n_residues,
        residue_sequence=residue_seq,
        layer_idx=layer_idx,
        protein=protein,
        output_dir=network_dir,
        layout="circular",
        max_edges=200,
        top_k_hubs=10,
    )
    plot_residue_network_per_head(
        heads,
        n_residues=n_residues,
        residue_sequence=residue_seq,
        layer_idx=layer_idx,
        protein=protein,
        output_dir=network_dir,
        layout="circular",
        max_edges_per_head=50,
        cols=4,
    )
    generate_chord_diagrams(
        attention_dir=attn_dir,
        residue_sequence=residue_seq,
        output_dir=chord_dir,
        protein=protein,
        attention_type="msa_row",
        top_k=top_k,
        layer_idx=layer_idx,
        save_individual=True,
        save_grid=True,
        save_aggregated=True,
    )
    print(f"Example outputs saved under {out_base}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
55 changes: 55 additions & 0 deletions tests/test_visualize_attention_chord_diagrams.py
@@ -0,0 +1,55 @@
import os
import tempfile
import unittest

from visualize_attention_chord_diagrams import (
    aggregate_chord_edges,
    generate_chord_diagrams,
)


class TestVisualizeAttentionChordDiagrams(unittest.TestCase):
    def test_aggregate_chord_edges_uses_mean_by_default(self):
        heads = {
            0: [(0, 1, 0.2), (2, 3, 0.5)],
            1: [(0, 1, 0.6)],
        }

        self.assertEqual(
            aggregate_chord_edges(heads),
            [(2, 3, 0.5), (0, 1, 0.4)],
        )

    def test_generate_chord_diagrams_writes_expected_msa_outputs(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            attn_dir = os.path.join(tmpdir, "attention")
            out_dir = os.path.join(tmpdir, "chords")
            os.makedirs(attn_dir)
            with open(os.path.join(attn_dir, "msa_row_attn_layer47.txt"), "w") as f:
                f.write("layer 47 head 0\n")
                f.write("0 1 0.8\n")
                f.write("1 2 0.2\n")
                f.write("layer 47 head 1\n")
                f.write("2 3 0.7\n")
                f.write("3 4 0.1\n")

            paths = generate_chord_diagrams(
                attention_dir=attn_dir,
                residue_sequence="ABCDE",
                output_dir=out_dir,
                protein="TEST",
                attention_type="msa_row",
                top_k=1,
                layer_idx=47,
                save_individual=True,
                save_grid=True,
                save_aggregated=True,
            )

            self.assertEqual(len(paths), 4)
            for path in paths:
                self.assertTrue(os.path.exists(path))


if __name__ == "__main__":
    unittest.main()
50 changes: 50 additions & 0 deletions tests/test_visualize_attention_data.py
@@ -0,0 +1,50 @@
import os
import tempfile
import unittest

from visualize_attention_data import (
    get_attention_file_path,
    load_attention_map,
    parse_fasta_sequence,
)


class TestVisualizeAttentionData(unittest.TestCase):
    def test_load_attention_map_groups_sorts_and_filters_heads(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "msa_row_attn_layer47.txt")
            with open(path, "w") as f:
                f.write("layer 47 head 0\n")
                f.write("2 3 0.1\n")
                f.write("0 1 0.9\n")
                f.write("layer 47 head 1\n")
                f.write("4 5 0.4\n")
                f.write("6 7 0.2\n")

            heads = load_attention_map(path, top_k=1)

            self.assertEqual(heads, {0: [(0, 1, 0.9)], 1: [(4, 5, 0.4)]})

    def test_get_attention_file_path_uses_expected_names(self):
        self.assertEqual(
            get_attention_file_path("/tmp/attn", "msa_row", 47),
            "/tmp/attn/msa_row_attn_layer47.txt",
        )
        self.assertEqual(
            get_attention_file_path("/tmp/attn", "triangle_start", 47, residue_idx=18),
            "/tmp/attn/triangle_start_attn_layer47_residue_idx_18.txt",
        )

    def test_parse_fasta_sequence_joins_sequence_lines(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "protein.fasta")
            with open(path, "w") as f:
                f.write(">protein\n")
                f.write("ACD\n")
                f.write("EFG\n")

            self.assertEqual(parse_fasta_sequence(path), "ACDEFG")


if __name__ == "__main__":
    unittest.main()