From bb02f23476031f8652d26352381b0a141be954e4 Mon Sep 17 00:00:00 2001 From: Sai Sagili Date: Sun, 8 Mar 2026 20:11:53 -0400 Subject: [PATCH 1/3] Added heatmap and network visualizations 3/8/2026 --- .../Head_Visualization_Heatmap_Network.md | 69 ++ docs/source/index.md | 1 + .../sample_attention_viz_outputs/README.md | 16 + scripts/run_head_heatmap_network_demo.py | 124 ++++ visualize_attention_head_heatmaps.py | 193 ++++++ visualize_attention_networks.py | 298 +++++++++ viz_attention_demo.ipynb | 604 ++++++++++-------- viz_attention_demo_base.ipynb | 492 ++++++++------ 8 files changed, 1343 insertions(+), 454 deletions(-) create mode 100644 docs/source/Head_Visualization_Heatmap_Network.md create mode 100644 examples/monomer/sample_attention_viz_outputs/README.md create mode 100644 scripts/run_head_heatmap_network_demo.py create mode 100644 visualize_attention_head_heatmaps.py create mode 100644 visualize_attention_networks.py diff --git a/docs/source/Head_Visualization_Heatmap_Network.md b/docs/source/Head_Visualization_Heatmap_Network.md new file mode 100644 index 00000000..de03587b --- /dev/null +++ b/docs/source/Head_Visualization_Heatmap_Network.md @@ -0,0 +1,69 @@ +# Head-Level Heatmap and Network Visualizations + +This document describes the new attention-head visualizations (heatmaps and network-style plots) and how they compare to the existing arc diagrams and 3D PyMOL overlays. + +## Overview + +Arc diagrams and PyMOL overlays are useful but show one head at a time. The new tools let you: + +- **Compare all heads in one layer at once** via a grid of heatmaps or a grid of small network plots. +- **See aggregated attention** across heads as a single network, with hub residues highlighted. +- **Use the same attention text files** produced by `run_pretrained_openfold.py --demo_attn` (no change to the inference pipeline). 
+ +## New Components + +| Component | Purpose | +|-----------|---------| +| `visualize_attention_head_heatmaps.py` | Builds dense per-head matrices from `load_all_heads()` and plots a grid of residue–residue heatmaps (one per head). | +| `visualize_attention_networks.py` | Aggregates heads into one weighted graph and/or draws one small network per head; supports circular or linear layout and hub highlighting. | +| Notebook cells (in `viz_attention_demo.ipynb` and `viz_attention_demo_base.ipynb`) | After running 3D and arc visualizations, a new cell runs heatmap and network code for MSA row and (optionally) triangle-start attention. | +| `scripts/run_head_heatmap_network_demo.py` | Standalone script that generates **synthetic** attention data and runs the new visualizations to produce example outputs without full OpenFold inference. | + +## How to Run + +1. **From the notebook (real data)** + Run the usual inference cell so that `ATTN_MAP_DIR` contains files like `msa_row_attn_layer47.txt`. Then run the new cell titled *"Head-level heatmap and network-style visualizations"*. Outputs go to: + - `IMAGE_OUTPUT_DIR/head_heatmaps/` (combined heatmap panel, optional per-head heatmaps) + - `IMAGE_OUTPUT_DIR/network_plots/` (aggregated network, per-head network grid) + +2. **Example outputs without inference** + From the repo root: + ```bash + python scripts/run_head_heatmap_network_demo.py + ``` + This writes synthetic attention into `examples/monomer/sample_attention_viz_outputs/` and runs the same heatmap and network functions. Use it to check that the pipeline runs and to get sample figures. + +## Comparing All Heads at Once + +- **Heatmap grid**: One subplot per head; each shows the full residue × residue attention matrix (or top-k filled). Shared color scale across heads makes it easy to see which heads focus on similar pairs and which are sparse or different. 
+- **Per-head network grid**: Same idea as heatmaps but each head is shown as a 2D network (nodes = residues, edges = attention). Good for seeing structural “clusters” and long-range links per head. +- **Aggregated network**: One graph where edge weight is the mean (or sum) of attention over heads. Top-k hubs by total incident weight are highlighted, so you see which residues are attended to most across the layer. + +## What These Visualizations Capture That Arc Diagrams Might Miss + +- **Global head similarity**: Arc diagrams are one-head-at-a-time. Heatmaps and small-multiples networks show the whole layer in one view, so you can quickly see redundant vs. diverse heads. +- **Node-centric importance**: The aggregated network plus hub highlighting shows which residues are “important” in the sense of total incoming/outgoing attention, which is harder to read from a single-head arc. +- **Dense pattern vs. sparse pattern**: Heatmaps make it obvious when a head is diffuse (many weak links) vs. focused (few strong links), and where on the sequence those links lie (diagonal vs. off-diagonal). +- **Layer-wise comparison**: The same functions can be called for multiple layers (change `layer_idx` and the attention file path); then comparing saved heatmap/network figures across layers shows how attention evolves with depth. + +## Evaluation Summary + +| Visualization | Best for | +|---------------|----------| +| Arc diagram | Single-head, sequence-linear view of top-k edges; easy to match residues to sequence. | +| 3D PyMOL overlay | Same head in 3D structure context; good for spatial interpretation. | +| **Heatmap grid** | Comparing all heads in one layer; seeing dense vs. sparse and similarity across heads. | +| **Aggregated network** | Which residues are hubs across the whole layer; one picture for “consensus” attention. | +| **Per-head network grid** | Same as heatmap grid but with a network layout; can be easier for seeing clusters and long-range ties. 
| + +The new visualizations do not replace arc or 3D views; they complement them by answering “what do all heads in this layer look like together?” and “which residues matter most when we aggregate heads?” + +## File and Function Reference + +- **Build matrices**: `build_head_matrices(heads, n_residues)` in `visualize_attention_head_heatmaps.py`. +- **Heatmap panel**: `plot_head_heatmaps(head_mats, residue_sequence, layer_idx, protein, output_dir, ...)`. +- **Aggregated graph**: `build_aggregated_graph(heads, aggregation="mean", normalize_by_heads=True)` in `visualize_attention_networks.py`. +- **Single network plot**: `plot_residue_network(edges, n_residues, residue_sequence, layer_idx, protein, output_dir, layout="circular", max_edges=200, top_k_hubs=10)`. +- **Per-head networks**: `plot_residue_network_per_head(heads, n_residues, ...)`. + +Input `heads` is the dict returned by `load_all_heads(connections_file, top_k=...)` from `visualize_attention_arc_diagram_demo_utils` (or the same function in `visualize_attention_3d_demo_utils`). diff --git a/docs/source/index.md b/docs/source/index.md index 5da44919..fec23059 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -106,6 +106,7 @@ Single_Sequence_Inference.md Multimer_Inference.md OpenFold_Training_Setup.md Training_OpenFold.md +Head_Visualization_Heatmap_Network.md ``` ```{toctree} diff --git a/examples/monomer/sample_attention_viz_outputs/README.md b/examples/monomer/sample_attention_viz_outputs/README.md new file mode 100644 index 00000000..3396b865 --- /dev/null +++ b/examples/monomer/sample_attention_viz_outputs/README.md @@ -0,0 +1,16 @@ +# Sample attention visualization outputs + +This directory is populated when you run: + +```bash +python scripts/run_head_heatmap_network_demo.py +``` + +from the repository root. 
The script uses **synthetic** attention data (same file format as real OpenFold `--demo_attn` output) to generate: + +- **head_heatmaps/** — One combined PNG with a grid of residue–residue heatmaps (one per attention head). +- **network_plots/** — Aggregated attention network (all heads combined, hub residues highlighted) and a per-head network grid. + +To generate visualizations from **real** inference, run the full pipeline in `viz_attention_demo.ipynb` or `viz_attention_demo_base.ipynb`; the same heatmap and network code runs in the notebook and writes to `IMAGE_OUTPUT_DIR/head_heatmaps/` and `IMAGE_OUTPUT_DIR/network_plots/`. + +See [Head_Visualization_Heatmap_Network.md](../../../docs/source/Head_Visualization_Heatmap_Network.md) for a full description and comparison with arc diagrams and 3D overlays. diff --git a/scripts/run_head_heatmap_network_demo.py b/scripts/run_head_heatmap_network_demo.py new file mode 100644 index 00000000..61c71d8f --- /dev/null +++ b/scripts/run_head_heatmap_network_demo.py @@ -0,0 +1,124 @@ +""" +Generate example head heatmap and network visualizations using synthetic attention data. + +Use this to produce sample outputs without running full OpenFold inference. +Reads a FASTA for sequence length/labels and writes a minimal attention file +in the same format as run_pretrained_openfold.py --demo_attn, then runs +the new visualization utilities. 
import os
import sys

import numpy as np

# Allow importing from repo root
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)


def write_synthetic_msa_row_attention(
    output_path: str,
    n_residues: int,
    num_heads: int = 8,
    edges_per_head: int = 80,
    seed: int = 42,
    layer_idx: int = 47,
) -> None:
    """Write a synthetic msa_row_attn_layer{L}.txt attention file.

    For each head a ``layer <layer_idx> head <h>`` header line is written,
    followed by ``edges_per_head`` lines of the form ``i j weight`` where
    ``i != j`` are random residue indices in ``[0, n_residues)`` and the
    weight is drawn uniformly from ``[0.05, 0.6)``.

    Args:
        output_path: Destination text file; parent directories are created.
        n_residues: Sequence length; must be at least 2 so that every edge
            can connect two distinct residues.
        num_heads: Number of head sections to write.
        edges_per_head: Number of edge lines per head.
        seed: Seed for the NumPy random generator (output is deterministic).
        layer_idx: Layer number written into each header line. Defaults to
            47, the value that used to be hard-coded.

    Raises:
        ValueError: If ``n_residues`` is smaller than 2.
    """
    if n_residues < 2:
        raise ValueError(
            f"n_residues must be >= 2 to generate edges between distinct residues, got {n_residues}"
        )

    rng = np.random.default_rng(seed)
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        for h in range(num_heads):
            f.write(f"layer {layer_idx} head {h}\n")
            for _ in range(edges_per_head):
                # Draw order (i, j, weight) is kept fixed so a given seed
                # always reproduces the same file.
                i = rng.integers(0, n_residues)
                j = rng.integers(0, n_residues)
                if i == j:
                    # Shift to the next residue to avoid a self-loop.
                    j = (j + 1) % n_residues
                w = float(rng.uniform(0.05, 0.6))
                f.write(f"{i} {j} {w}\n")
    print(f"Wrote synthetic MSA row attention: {output_path}")


def main() -> int:
    """Generate synthetic attention for 6KWC and render the sample figures.

    Returns:
        0 on success, 1 if the example FASTA file is missing.
    """
    # Project modules are imported here (after sys.path was extended above)
    # so that the synthetic-data writer stays importable on its own.
    from visualize_attention_arc_diagram_demo_utils import (
        load_all_heads,
        parse_fasta_sequence,
    )
    from visualize_attention_head_heatmaps import (
        build_head_matrices,
        plot_head_heatmaps,
    )
    from visualize_attention_networks import (
        build_aggregated_graph,
        plot_residue_network,
        plot_residue_network_per_head,
    )

    fasta_path = os.path.join(
        REPO_ROOT, "examples", "monomer", "fasta_dir_6KWC", "6KWC.fasta"
    )
    if not os.path.exists(fasta_path):
        print(f"FASTA not found: {fasta_path}")
        return 1
    residue_seq = parse_fasta_sequence(fasta_path)
    n_residues = len(residue_seq)

    out_base = os.path.join(
        REPO_ROOT, "examples", "monomer", "sample_attention_viz_outputs"
    )
    attn_dir = os.path.join(out_base, "attention_files_6KWC_demo")
    heatmap_dir = os.path.join(out_base, "head_heatmaps")
    network_dir = os.path.join(out_base, "network_plots")

    layer_idx = 47
    top_k = 50
    protein = "6KWC"

    # Single source of truth for the attention file path (written, then read).
    msa_file = os.path.join(attn_dir, f"msa_row_attn_layer{layer_idx}.txt")
    write_synthetic_msa_row_attention(
        msa_file,
        n_residues=n_residues,
        num_heads=8,
        edges_per_head=100,
        layer_idx=layer_idx,
    )

    heads = load_all_heads(msa_file, top_k=top_k)
    head_mats = build_head_matrices(heads, n_residues=n_residues)
    plot_head_heatmaps(
        head_mats,
        residue_sequence=residue_seq,
        layer_idx=layer_idx,
        protein=protein,
        output_dir=heatmap_dir,
        cols=4,
        save_combined=True,
        save_individual=False,
    )
    agg_edges = build_aggregated_graph(heads, aggregation="mean", normalize_by_heads=True)
    plot_residue_network(
        agg_edges,
        n_residues=n_residues,
        residue_sequence=residue_seq,
        layer_idx=layer_idx,
        protein=protein,
        output_dir=network_dir,
        layout="circular",
        max_edges=200,
        top_k_hubs=10,
    )
    plot_residue_network_per_head(
        heads,
        n_residues=n_residues,
        residue_sequence=residue_seq,
        layer_idx=layer_idx,
        protein=protein,
        output_dir=network_dir,
        layout="circular",
        max_edges_per_head=50,
        cols=4,
    )
    print(f"Example outputs saved under {out_base}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
import os
from typing import Dict, List, Optional

import numpy as np


def build_head_matrices(
    heads: Dict[int, List[tuple]],
    n_residues: Optional[int] = None,
    symmetrize: bool = False,
) -> Dict[int, np.ndarray]:
    """
    Convert sparse per-head connection lists into dense residue x residue matrices.

    Args:
        heads: Dict mapping head_idx -> list of (res_i, res_j, weight).
        n_residues: Number of residues (sequence length). If None, inferred
            from the largest index appearing in any head.
        symmetrize: If True, set A[i,j] = A[j,i] = max(weight(i,j), weight(j,i)).

    Returns:
        Dict mapping head_idx -> float64 array of shape (n_residues, n_residues).
        Empty dict when n_residues is None and no connection is present.
        Connections whose indices fall outside [0, n_residues) are skipped.
    """
    if n_residues is None:
        n_residues = 0
        for conns in heads.values():
            for res_i, res_j, _ in conns:
                # Cast here as well: indices are int()-cast when filling the
                # matrix, so the inferred size must be an int too.
                n_residues = max(n_residues, int(res_i) + 1, int(res_j) + 1)
        if n_residues == 0:
            return {}

    out = {}
    for head_idx, conns in heads.items():
        A = np.zeros((n_residues, n_residues), dtype=np.float64)
        for res_i, res_j, weight in conns:
            i, j = int(res_i), int(res_j)
            if not (0 <= i < n_residues and 0 <= j < n_residues):
                continue
            if symmetrize:
                A[i, j] = max(A[i, j], weight)
                A[j, i] = max(A[j, i], weight)
            else:
                A[i, j] = weight
        out[head_idx] = A
    return out


def _draw_heatmap(ax, matrix, cmap, mask_zeros, vmin, vmax):
    """Draw one attention matrix on ``ax``; zeros are masked when requested."""
    plot_mat = np.ma.masked_where(matrix == 0, matrix) if mask_zeros else matrix
    return ax.imshow(
        plot_mat,
        aspect="auto",
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        interpolation="nearest",
    )


def plot_head_heatmaps(
    head_mats: Dict[int, np.ndarray],
    residue_sequence: Optional[str],
    layer_idx: int,
    protein: str,
    output_dir: str,
    cols: int = 4,
    cmap: str = "viridis",
    mask_zeros: bool = True,
    save_combined: bool = True,
    save_individual: bool = False,
    figsize_per_subplot: float = 3.0,
    show_plot: bool = False,
) -> List[str]:
    """
    Plot a grid of heatmaps, one per head, for comparing all heads in a layer.

    All heads share one color scale (global min/max over the matrices).

    Args:
        head_mats: From build_head_matrices().
        residue_sequence: Optional sequence string for axis labels; used only
            when it has one character per residue and there are <= 80 residues.
        layer_idx: Layer index (for titles and filenames).
        protein: Protein name (for titles and filenames).
        output_dir: Directory to save PNGs (created if missing).
        cols: Number of columns in the grid (values below 1 are treated as 1).
        cmap: Matplotlib colormap name.
        mask_zeros: If True, zero entries are masked so the pattern is clearer.
        save_combined: Save one combined multi-head panel.
        save_individual: If True, also save one PNG per head.
        figsize_per_subplot: Size per subplot in inches.
        show_plot: If True, call plt.show() on the combined panel instead of
            closing it.

    Returns:
        List of saved file paths (empty if there was nothing to plot).
    """
    # Imported lazily so build_head_matrices() is usable without matplotlib.
    import matplotlib.pyplot as plt

    os.makedirs(output_dir, exist_ok=True)
    saved_paths: List[str] = []

    head_indices = sorted(head_mats.keys())
    if not head_indices:
        print("[Warning] No head matrices to plot.")
        return saved_paths

    non_empty = [A for A in head_mats.values() if A.size > 0]
    if not non_empty:
        # min()/max() over nothing would raise; there is no data to show.
        print("[Warning] All head matrices are empty; nothing to plot.")
        return saved_paths

    n_heads = len(head_indices)
    n_res = next(iter(head_mats.values())).shape[0]
    cols = max(1, int(cols))  # guard the ceil-division below

    # Residue letters are only legible for short sequences.
    show_residue_labels = (
        residue_sequence is not None
        and n_res <= 80
        and len(residue_sequence) == n_res
    )

    rows = (n_heads + cols - 1) // cols
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(cols * figsize_per_subplot, rows * figsize_per_subplot),
        squeeze=False,
    )
    axes_flat = axes.flatten()

    vmin = min(np.min(A) for A in non_empty)
    vmax = max(np.max(A) for A in non_empty)
    if vmax <= vmin:
        vmax = vmin + 1e-6  # avoid a degenerate color range

    for ax, head_idx in zip(axes_flat, head_indices):
        _draw_heatmap(ax, head_mats[head_idx], cmap, mask_zeros, vmin, vmax)
        ax.set_title(f"Head {head_idx}", fontsize=10)
        if show_residue_labels:
            ax.set_xticks(np.arange(n_res))
            ax.set_xticklabels(list(residue_sequence), fontsize=5, rotation=90)
            ax.set_yticks(np.arange(n_res))
            ax.set_yticklabels(list(residue_sequence), fontsize=5)
        else:
            ax.set_xlabel("Residue j")
            ax.set_ylabel("Residue i")
        ax.tick_params(axis="both", labelsize=6)

    # Hide unused cells of the grid.
    for ax in axes_flat[n_heads:]:
        ax.set_visible(False)

    fig.suptitle(
        f"{protein} — Layer {layer_idx} — All heads (residue–residue attention)",
        fontsize=12,
        weight="bold",
        y=1.02,
    )
    fig.tight_layout()

    if save_combined:
        combined_path = os.path.join(
            output_dir, f"head_heatmaps_layer_{layer_idx}_{protein}.png"
        )
        # Save through the figure handle rather than pyplot's "current figure".
        fig.savefig(combined_path, dpi=150, bbox_inches="tight")
        saved_paths.append(combined_path)
        print(f"[Saved] Combined heatmaps: {combined_path}")

    if show_plot:
        plt.show()
    else:
        plt.close(fig)

    if save_individual:
        for head_idx in head_indices:
            fig1, ax1 = plt.subplots(figsize=(6, 5))
            _draw_heatmap(ax1, head_mats[head_idx], cmap, mask_zeros, vmin, vmax)
            ax1.set_title(f"{protein} — Layer {layer_idx} — Head {head_idx}")
            ax1.set_xlabel("Residue j")
            ax1.set_ylabel("Residue i")
            path = os.path.join(
                output_dir, f"head_heatmap_layer_{layer_idx}_head_{head_idx}_{protein}.png"
            )
            fig1.savefig(path, dpi=150, bbox_inches="tight")
            saved_paths.append(path)
            plt.close(fig1)
        print(f"[Saved] {n_heads} individual heatmap(s) to {output_dir}")

    return saved_paths
+""" + +import os +from typing import Dict, List, Optional, Tuple + +import numpy as np +import matplotlib.pyplot as plt + + +def build_aggregated_graph( + heads: Dict[int, List[tuple]], + aggregation: str = "mean", + normalize_by_heads: bool = True, +) -> List[Tuple[int, int, float]]: + """ + Aggregate all heads' edges into a single weighted graph. + + Args: + heads: Dict mapping head_idx -> list of (res_i, res_j, weight). + aggregation: "mean" or "sum" over heads for each (i, j). + normalize_by_heads: If True and aggregation is "sum", divide by number + of heads that have that edge (so we get mean). + + Returns: + List of (res_i, res_j, aggregated_weight), sorted by weight descending. + """ + from collections import defaultdict + + edge_to_weights = defaultdict(list) + for head_idx, conns in heads.items(): + for res_i, res_j, w in conns: + key = (int(res_i), int(res_j)) + edge_to_weights[key].append(w) + + aggregated = [] + for (res_i, res_j), weights in edge_to_weights.items(): + if aggregation == "mean": + agg_w = np.mean(weights) + elif aggregation == "sum": + agg_w = np.sum(weights) + if normalize_by_heads: + agg_w /= len(weights) + else: + agg_w = np.mean(weights) + aggregated.append((res_i, res_j, float(agg_w))) + + aggregated.sort(key=lambda x: x[2], reverse=True) + return aggregated + + +def _layout_circular(n_residues: int) -> np.ndarray: + """Place n_residues nodes on a circle (radius 1).""" + angles = np.linspace(0, 2 * np.pi, n_residues, endpoint=False) + return np.column_stack([np.cos(angles), np.sin(angles)]) + + +def _layout_linear(n_residues: int) -> np.ndarray: + """Place n_residues nodes in a line (x = 0..1, y = 0).""" + x = np.linspace(0, 1, n_residues) + return np.column_stack([x, np.zeros(n_residues)]) + + +def plot_residue_network( + edges: List[Tuple[int, int, float]], + n_residues: int, + residue_sequence: Optional[str], + layer_idx: int, + protein: str, + output_dir: str, + layout: str = "circular", + threshold: Optional[float] = None, + 
max_edges: Optional[int] = 200, + top_k_hubs: int = 10, + figsize: Tuple[float, float] = (10, 10), + show_plot: bool = False, +) -> str: + """ + Draw a 2D network: residues as nodes, attention as weighted edges. + Optionally highlight top-k hub residues by total incident weight. + + Args: + edges: List of (res_i, res_j, weight) from build_aggregated_graph(). + n_residues: Number of residues (nodes). + residue_sequence: Optional sequence for node labels (if short). + layer_idx: Layer index for title. + protein: Protein name for title and filename. + output_dir: Where to save the PNG. + layout: "circular" or "linear". + threshold: If set, only draw edges with weight >= threshold. + max_edges: Cap number of edges drawn (take top by weight). + top_k_hubs: Number of hub nodes to highlight (by total incident weight). + figsize: Figure size in inches. + show_plot: If True, call plt.show(). + + Returns: + Path to saved PNG. + """ + os.makedirs(output_dir, exist_ok=True) + + if layout == "circular": + pos = _layout_circular(n_residues) + else: + pos = _layout_linear(n_residues) + + # Filter and limit edges + if threshold is not None: + edges = [(i, j, w) for i, j, w in edges if w >= threshold] + if max_edges is not None: + edges = edges[:max_edges] + + if not edges: + print("[Warning] No edges to draw for network.") + fig, ax = plt.subplots(figsize=figsize) + ax.set_title(f"{protein} — Layer {layer_idx} — Aggregated attention (no edges)") + path = os.path.join(output_dir, f"network_layer_{layer_idx}_{protein}.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + return path + + weights = [w for _, _, w in edges] + w_min, w_max = min(weights), max(weights) + if w_max <= w_min: + w_max = w_min + 1e-6 + + # Hub strength: sum of incident edge weights + hub_strength = np.zeros(n_residues) + for res_i, res_j, w in edges: + hub_strength[res_i] += w + hub_strength[res_j] += w + hub_rank = np.argsort(hub_strength)[::-1] + hub_set = set(hub_rank[:top_k_hubs]) + + 
fig, ax = plt.subplots(figsize=figsize) + + # Draw edges + for res_i, res_j, weight in edges: + xi, yi = pos[res_i] + xj, yj = pos[res_j] + norm_w = (weight - w_min) / (w_max - w_min) + lw = 0.3 + 2.0 * norm_w + alpha = 0.4 + 0.5 * norm_w + ax.plot([xi, xj], [yi, yj], color="steelblue", linewidth=lw, alpha=alpha, zorder=1) + + # Draw nodes + node_size = 20 + ax.scatter( + pos[:, 0], + pos[:, 1], + s=node_size, + c="lightgray", + edgecolors="gray", + linewidths=0.5, + zorder=2, + ) + if hub_set: + hub_pos = pos[list(hub_set)] + ax.scatter( + hub_pos[:, 0], + hub_pos[:, 1], + s=80, + c="coral", + edgecolors="darkred", + linewidths=1.5, + zorder=3, + label=f"Top-{top_k_hubs} hubs", + ) + ax.legend(loc="upper right", fontsize=8) + + show_labels = residue_sequence is not None and n_residues <= 60 + if show_labels and len(residue_sequence) == n_residues: + for i in range(n_residues): + ax.annotate( + residue_sequence[i], + (pos[i, 0], pos[i, 1]), + fontsize=4, + ha="center", + va="center", + ) + else: + ax.set_xlabel("Layout x") + ax.set_ylabel("Layout y") + + ax.set_title( + f"{protein} — Layer {layer_idx} — Aggregated attention network ({len(edges)} edges)" + ) + ax.set_aspect("equal") + ax.axis("off") + plt.tight_layout() + + path = os.path.join(output_dir, f"network_layer_{layer_idx}_{protein}.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + if show_plot: + plt.show() + else: + plt.close() + print(f"[Saved] Network: {path}") + return path + + +def plot_residue_network_per_head( + heads: Dict[int, List[tuple]], + n_residues: int, + residue_sequence: Optional[str], + layer_idx: int, + protein: str, + output_dir: str, + layout: str = "circular", + max_edges_per_head: int = 50, + cols: int = 4, + figsize_per_subplot: float = 3.0, + show_plot: bool = False, +) -> List[str]: + """ + Draw one small network per head in a grid (small multiples). + + Args: + heads: From load_all_heads(). 
+ n_residues, residue_sequence, layer_idx, protein, output_dir: As in plot_residue_network. + layout: "circular" or "linear". + max_edges_per_head: Max edges to draw per head. + cols: Grid columns. + figsize_per_subplot: Inches per subplot. + show_plot: If True, plt.show(). + + Returns: + List of saved file paths. + """ + os.makedirs(output_dir, exist_ok=True) + head_indices = sorted(heads.keys()) + if not head_indices: + return [] + + if layout == "circular": + pos = _layout_circular(n_residues) + else: + pos = _layout_linear(n_residues) + + n_heads = len(head_indices) + rows = (n_heads + cols - 1) // cols + fig, axes = plt.subplots( + rows, + cols, + figsize=(cols * figsize_per_subplot, rows * figsize_per_subplot), + squeeze=False, + ) + axes_flat = axes.flatten() + + for idx, head_idx in enumerate(head_indices): + ax = axes_flat[idx] + conns = heads[head_idx][:max_edges_per_head] + if not conns: + ax.set_title(f"Head {head_idx}") + ax.axis("off") + continue + w_min = min(w for _, _, w in conns) + w_max = max(w for _, _, w in conns) + if w_max <= w_min: + w_max = w_min + 1e-6 + for res_i, res_j, weight in conns: + xi, yi = pos[res_i] + xj, yj = pos[res_j] + norm_w = (weight - w_min) / (w_max - w_min) + lw = 0.2 + 1.2 * norm_w + alpha = 0.3 + 0.5 * norm_w + ax.plot([xi, xj], [yi, yj], color="steelblue", linewidth=lw, alpha=alpha) + ax.scatter( + pos[:, 0], pos[:, 1], s=8, c="lightgray", edgecolors="gray", linewidths=0.3 + ) + ax.set_title(f"Head {head_idx}", fontsize=9) + ax.set_aspect("equal") + ax.axis("off") + + for idx in range(n_heads, len(axes_flat)): + axes_flat[idx].set_visible(False) + + fig.suptitle( + f"{protein} — Layer {layer_idx} — Per-head networks", + fontsize=12, + weight="bold", + y=1.02, + ) + plt.tight_layout() + path = os.path.join( + output_dir, f"network_per_head_layer_{layer_idx}_{protein}.png" + ) + plt.savefig(path, dpi=150, bbox_inches="tight") + if show_plot: + plt.show() + else: + plt.close() + print(f"[Saved] Per-head networks: 
{path}") + return [path] diff --git a/viz_attention_demo.ipynb b/viz_attention_demo.ipynb index ef0cc701..ea99ddf7 100644 --- a/viz_attention_demo.ipynb +++ b/viz_attention_demo.ipynb @@ -1,257 +1,351 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "04e484dd", - "metadata": {}, - "outputs": [], - "source": [ - "# Install Airavata SDK with Jupyter support\n", - "%pip install -q \"airavata-python-sdk[notebook]\"\n", - "\n", - "# Import Airavata Jupyter magics for managing remote HPC jobs\n", - "import airavata_jupyter_magic\n", - "\n", - "# Authenticate the session (prompts for credentials or uses tokens)\n", - "%authenticate\n", - "\n", - "# Request an HPC runtime with GPU support using a YAML job config\n", - "%request_runtime hpc_gpu --file=cybershuttle.yml --walltime=30 --use=AnvilGPU:gpu-debug\n", - "\n", - "# Wait for the runtime to be ready\n", - "%wait_for_runtime hpc_gpu --live\n", - "\n", - "# Switch to the requested HPC runtime\n", - "%switch_runtime hpc_gpu" - ] + "cells": [ + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Install Airavata SDK with Jupyter support\n", + "%pip install -q \"airavata-python-sdk[notebook]\"\n", + "\n", + "# Import Airavata Jupyter magics for managing remote HPC jobs\n", + "import airavata_jupyter_magic\n", + "\n", + "# Authenticate the session (prompts for credentials or uses tokens)\n", + "%authenticate\n", + "\n", + "# Request an HPC runtime with GPU support using a YAML job config\n", + "%request_runtime hpc_gpu --file=cybershuttle.yml --walltime=30 --use=AnvilGPU:gpu-debug\n", + "\n", + "# Wait for the runtime to be ready\n", + "%wait_for_runtime hpc_gpu --live\n", + "\n", + "# Switch to the requested HPC runtime\n", + "%switch_runtime hpc_gpu" + ], + "execution_count": null, + "outputs": [], + "id": "04e484dd" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "%%bash\n", + "\n", + "# clone the attention visualization demo repo into the workspace\n", + "git 
clone \\\n", + " https://github.com/vizfold/attention-viz-demo \\\n", + " attention-viz-demo\n", + "\n", + "# cd to attention-viz-demo directory\n", + "cd attention-viz-demo\n", + "\n", + "# Download required files for OpenFold from external sources\n", + "wget -N --no-check-certificate \\\n", + " -P openfold/resources \\\n", + " https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n", + "\n", + "# Download pretrained OpenFold model weights\n", + "./scripts/download_openfold_params.sh openfold/resources/params\n", + "\n", + "# Create output directory to save results\n", + "mkdir -p outputs\n", + "\n", + "# Set up openfold tools\n", + "pip install -e ." + ], + "execution_count": null, + "outputs": [], + "id": "26bb31dc" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Required imports\n", + "import os\n", + "import subprocess\n", + "\n", + "# Define target protein and the residue to center triangle attention on\n", + "PROT = \"6KWC\"\n", + "TRI_RESIDUE_IDX = 18\n", + "\n", + "# Define all relevant directories\n", + "BASE_DATA_DIR = \"/depot/itap/datasets/alphafold/db\" # path to AlphaFold2 data\n", + "ATTN_MAP_DIR = f\"./outputs/attention_files_{PROT}_demo_tri_{TRI_RESIDUE_IDX}\" # directory for saving text files with top-k attention scores\n", + "ALIGNMENT_DIR = \"./examples/monomer/alignments\" # directory containing pre-computed alignment files (and MSAs)\n", + "OUTPUT_DIR = f\"./outputs/my_outputs_align_{PROT}_demo_tri_{TRI_RESIDUE_IDX}\" # directory to save outputs\n", + "IMAGE_OUTPUT_DIR = f\"./outputs/attention_images_{PROT}_demo_tri_{TRI_RESIDUE_IDX}\"\n", + "FASTA_DIR = f\"/u/thayes/vizfold/examples/monomer/fasta_dir_{PROT}\"\n", + "\n", + "# Note: If this is a new protein, the ALIGNMENT_DIR does not need to be specified here or in the next cell\n", + "# In this case, the code will compute MSAs and alignments, which can take several hours\n" + ], 
+ "execution_count": null, + "outputs": [], + "id": "1400375b" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Run OpenFold inference and save top attention scores to text files \n", + "inference_cmd = f\"\"\"\n", + "python3 run_pretrained_openfold.py \\\n", + " {FASTA_DIR} \\\n", + " {BASE_DATA_DIR}/pdb_mmcif/mmcif_files \\\n", + " --use_precomputed_alignments {ALIGNMENT_DIR} \\\n", + " --output_dir {OUTPUT_DIR} \\\n", + " --config_preset model_1_ptm \\\n", + " --uniref90_database_path {BASE_DATA_DIR}/uniref90/uniref90.fasta \\\n", + " --mgnify_database_path {BASE_DATA_DIR}/mgnify/mgy_clusters_2022_05.fa \\\n", + " --pdb70_database_path {BASE_DATA_DIR}/pdb70/pdb70 \\\n", + " --uniclust30_database_path {BASE_DATA_DIR}/uniclust30/uniclust30_2018_08 \\\n", + " --bfd_database_path {BASE_DATA_DIR}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \\\n", + " --save_outputs \\\n", + " --model_device \"cuda:0\" \\\n", + " --attn_map_dir {ATTN_MAP_DIR} \\\n", + " --num_recycles_save 1 \\\n", + " --triangle_residue_idx {TRI_RESIDUE_IDX} \\\n", + " --demo_attn\n", + "\"\"\"\n", + "\n", + "subprocess.run(inference_cmd, shell=True, check=True)\n" + ], + "execution_count": null, + "outputs": [], + "id": "3fd757f9" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Render predicted 3D structure and save as PNG image\n", + "from visualize_attention_general_utils import render_pdb_to_image\n", + "\n", + "PDB_FILE = os.path.join(OUTPUT_DIR, f\"predictions/{PROT}_1_model_1_ptm_relaxed.pdb\")\n", + "FNAME = f\"predicted_structure_{PROT}_tri_{TRI_RESIDUE_IDX}.png\"\n", + "\n", + "render_pdb_to_image(PDB_FILE, IMAGE_OUTPUT_DIR, FNAME)\n" + ], + "execution_count": null, + "outputs": [], + "id": "66a306cc" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Import visualization utilities\n", + "from visualize_attention_3d_demo_utils import plot_pymol_attention_heads\n", + "from visualize_attention_arc_diagram_demo_utils 
import generate_arc_diagrams, parse_fasta_sequence\n", + "\n", + "# Setup visualization output directories\n", + "output_dir_msa = os.path.join(IMAGE_OUTPUT_DIR, 'msa_row_attention_plots') # directory for saving msa attention 3D visuals\n", + "output_dir_tri = os.path.join(IMAGE_OUTPUT_DIR, 'tri_start_attention_plots') # directory for saving triangle attention 3D visuals\n", + "FASTA_PATH = f\"/u/thayes/vizfold/examples/monomer/fasta_dir_{PROT}/{PROT}.fasta\"\n", + "LAYER_IDX = 47 # selected layer for attention evaluation\n", + "TOP_K = 50 # show top-k attention links (limit to 500)\n", + "\n", + "# Generate 3D attention plots for MSA row attention\n", + "plot_pymol_attention_heads(\n", + " pdb_file=PDB_FILE,\n", + " attention_dir=ATTN_MAP_DIR,\n", + " output_dir=output_dir_msa,\n", + " protein=PROT,\n", + " attention_type=\"msa_row\",\n", + " top_k=TOP_K,\n", + " layer_idx=LAYER_IDX\n", + ")\n", + "\n", + "# Generate 3D attention plots for triangle start attention\n", + "plot_pymol_attention_heads(\n", + " pdb_file=PDB_FILE,\n", + " attention_dir=ATTN_MAP_DIR,\n", + " output_dir=output_dir_tri,\n", + " protein=PROT,\n", + " attention_type=\"triangle_start\",\n", + " residue_indices=[TRI_RESIDUE_IDX],\n", + " top_k=TOP_K,\n", + " layer_idx=LAYER_IDX\n", + ")\n", + "\n", + "# Parse FASTA for arc diagrams\n", + "residue_seq = parse_fasta_sequence(FASTA_PATH)\n", + "\n", + "# Generate arc diagrams for MSA row attention\n", + "generate_arc_diagrams(\n", + " attention_dir=ATTN_MAP_DIR,\n", + " residue_sequence=residue_seq,\n", + " output_dir=output_dir_msa,\n", + " protein=PROT,\n", + " attention_type=\"msa_row\",\n", + " top_k=TOP_K,\n", + " layer_idx=LAYER_IDX\n", + ")\n", + "\n", + "# Generate arc diagrams for triangle start attention\n", + "generate_arc_diagrams(\n", + " attention_dir=ATTN_MAP_DIR,\n", + " residue_sequence=residue_seq,\n", + " output_dir=output_dir_tri,\n", + " protein=PROT,\n", + " attention_type=\"triangle_start\",\n", + " 
residue_indices=[TRI_RESIDUE_IDX],\n", + " top_k=TOP_K,\n", + " layer_idx=LAYER_IDX\n", + ")\n" + ], + "execution_count": null, + "outputs": [], + "id": "941bd83f" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Import function for combining attention plots\n", + "from visualize_attention_general_utils import generate_combined_attention_panels\n", + "\n", + "# Combine MSA row plots\n", + "generate_combined_attention_panels(\n", + " attention_type=\"msa_row\",\n", + " protein=PROT,\n", + " layer_idx=LAYER_IDX,\n", + " output_dir_3d=output_dir_msa,\n", + " output_dir_arc=output_dir_msa,\n", + " combined_output_dir=IMAGE_OUTPUT_DIR,\n", + ")\n", + "\n", + "# Combine triangle start plots\n", + "generate_combined_attention_panels(\n", + " attention_type=\"triangle_start\",\n", + " protein=PROT,\n", + " layer_idx=LAYER_IDX,\n", + " output_dir_3d=output_dir_tri,\n", + " output_dir_arc=output_dir_tri,\n", + " combined_output_dir=IMAGE_OUTPUT_DIR,\n", + " residue_indices=[TRI_RESIDUE_IDX]\n", + ")\n" + ], + "execution_count": null, + "outputs": [], + "id": "756ecf74" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Head-level heatmap and network-style visualizations (compare all heads in one layer)\n", + "from visualize_attention_arc_diagram_demo_utils import load_all_heads\n", + "from visualize_attention_head_heatmaps import build_head_matrices, plot_head_heatmaps\n", + "from visualize_attention_networks import (\n", + " build_aggregated_graph,\n", + " plot_residue_network,\n", + " plot_residue_network_per_head,\n", + ")\n", + "\n", + "output_dir_heatmaps = os.path.join(IMAGE_OUTPUT_DIR, \"head_heatmaps\")\n", + "output_dir_networks = os.path.join(IMAGE_OUTPUT_DIR, \"network_plots\")\n", + "n_residues = len(residue_seq)\n", + "\n", + "# --- MSA Row: heatmaps and aggregated network ---\n", + "msa_attn_file = os.path.join(ATTN_MAP_DIR, f\"msa_row_attn_layer{LAYER_IDX}.txt\")\n", + "if os.path.exists(msa_attn_file):\n", + " msa_heads = 
load_all_heads(msa_attn_file, top_k=TOP_K)\n", + " head_mats = build_head_matrices(msa_heads, n_residues=n_residues)\n", + " plot_head_heatmaps(\n", + " head_mats,\n", + " residue_sequence=residue_seq,\n", + " layer_idx=LAYER_IDX,\n", + " protein=PROT,\n", + " output_dir=output_dir_heatmaps,\n", + " cols=4,\n", + " save_combined=True,\n", + " save_individual=False,\n", + " )\n", + " agg_edges = build_aggregated_graph(msa_heads, aggregation=\"mean\", normalize_by_heads=True)\n", + " plot_residue_network(\n", + " agg_edges,\n", + " n_residues=n_residues,\n", + " residue_sequence=residue_seq,\n", + " layer_idx=LAYER_IDX,\n", + " protein=PROT,\n", + " output_dir=output_dir_networks,\n", + " layout=\"circular\",\n", + " max_edges=200,\n", + " top_k_hubs=10,\n", + " )\n", + " plot_residue_network_per_head(\n", + " msa_heads,\n", + " n_residues=n_residues,\n", + " residue_sequence=residue_seq,\n", + " layer_idx=LAYER_IDX,\n", + " protein=PROT,\n", + " output_dir=output_dir_networks,\n", + " layout=\"circular\",\n", + " max_edges_per_head=50,\n", + " cols=4,\n", + " )\n", + "else:\n", + " print(f\"[Skip] MSA attention file not found: {msa_attn_file}\")\n", + "\n", + "# --- Triangle Start (optional): heatmaps and network for one residue ---\n", + "tri_attn_file = os.path.join(\n", + " ATTN_MAP_DIR, f\"triangle_start_attn_layer{LAYER_IDX}_residue_idx_{TRI_RESIDUE_IDX}.txt\"\n", + ")\n", + "if os.path.exists(tri_attn_file):\n", + " tri_heads = load_all_heads(tri_attn_file, top_k=TOP_K)\n", + " tri_head_mats = build_head_matrices(tri_heads, n_residues=n_residues)\n", + " plot_head_heatmaps(\n", + " tri_head_mats,\n", + " residue_sequence=residue_seq,\n", + " layer_idx=LAYER_IDX,\n", + " protein=f\"{PROT}_tri_res{TRI_RESIDUE_IDX}\",\n", + " output_dir=output_dir_heatmaps,\n", + " cols=4,\n", + " save_combined=True,\n", + " save_individual=False,\n", + " )\n", + " tri_agg = build_aggregated_graph(tri_heads, aggregation=\"mean\", normalize_by_heads=True)\n", + " 
plot_residue_network(\n", + " tri_agg,\n", + " n_residues=n_residues,\n", + " residue_sequence=residue_seq,\n", + " layer_idx=LAYER_IDX,\n", + " protein=f\"{PROT}_tri_res{TRI_RESIDUE_IDX}\",\n", + " output_dir=output_dir_networks,\n", + " layout=\"circular\",\n", + " max_edges=200,\n", + " top_k_hubs=10,\n", + " )\n", + "else:\n", + " print(f\"[Skip] Triangle attention file not found: {tri_attn_file}\")" + ], + "id": "1cc0df28", + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } }, - { - "cell_type": "code", - "execution_count": null, - "id": "26bb31dc", - "metadata": {}, - "outputs": [], - "source": [ - "%%bash\n", - "\n", - "# clone the attention visualization demo repo into the workspace\n", - "git clone \\\n", - " https://github.com/vizfold/attention-viz-demo \\\n", - " attention-viz-demo\n", - "\n", - "# cd to attention-viz-demo directory\n", - "cd attention-viz-demo\n", - "\n", - "# Download required files for OpenFold from external sources\n", - "wget -N --no-check-certificate \\\n", - " -P openfold/resources \\\n", - " https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n", - "\n", - "# Download pretrained OpenFold model weights\n", - "./scripts/download_openfold_params.sh openfold/resources/params\n", - "\n", - "# Create output directory to save results\n", - "mkdir -p outputs\n", - "\n", - "# Set up openfold tools\n", - "pip install -e ." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1400375b", - "metadata": {}, - "outputs": [], - "source": [ - "# Required imports\n", - "import os\n", - "import subprocess\n", - "\n", - "# Define target protein and the residue to center triangle attention on\n", - "PROT = \"6KWC\"\n", - "TRI_RESIDUE_IDX = 18\n", - "\n", - "# Define all relevant directories\n", - "BASE_DATA_DIR = \"/depot/itap/datasets/alphafold/db\" # path to AlphaFold2 data\n", - "ATTN_MAP_DIR = f\"./outputs/attention_files_{PROT}_demo_tri_{TRI_RESIDUE_IDX}\" # directory for saving text files with top-k attention scores\n", - "ALIGNMENT_DIR = \"./examples/monomer/alignments\" # directory containing pre-computed alignment files (and MSAs)\n", - "OUTPUT_DIR = f\"./outputs/my_outputs_align_{PROT}_demo_tri_{TRI_RESIDUE_IDX}\" # directory to save outputs\n", - "IMAGE_OUTPUT_DIR = f\"./outputs/attention_images_{PROT}_demo_tri_{TRI_RESIDUE_IDX}\"\n", - "FASTA_DIR = f\"/u/thayes/vizfold/examples/monomer/fasta_dir_{PROT}\"\n", - "\n", - "# Note: If this is a new protein, the ALIGNMENT_DIR does not need to be specified here or in the next cell\n", - "# In this case, the code will compute MSAs and alignments, which can take several hours\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3fd757f9", - "metadata": {}, - "outputs": [], - "source": [ - "# Run OpenFold inference and save top attention scores to text files \n", - "inference_cmd = f\"\"\"\n", - "python3 run_pretrained_openfold.py \\\n", - " {FASTA_DIR} \\\n", - " {BASE_DATA_DIR}/pdb_mmcif/mmcif_files \\\n", - " --use_precomputed_alignments {ALIGNMENT_DIR} \\\n", - " --output_dir {OUTPUT_DIR} \\\n", - " --config_preset model_1_ptm \\\n", - " --uniref90_database_path {BASE_DATA_DIR}/uniref90/uniref90.fasta \\\n", - " --mgnify_database_path {BASE_DATA_DIR}/mgnify/mgy_clusters_2022_05.fa \\\n", - " --pdb70_database_path {BASE_DATA_DIR}/pdb70/pdb70 \\\n", - " --uniclust30_database_path 
{BASE_DATA_DIR}/uniclust30/uniclust30_2018_08 \\\n", - " --bfd_database_path {BASE_DATA_DIR}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \\\n", - " --save_outputs \\\n", - " --model_device \"cuda:0\" \\\n", - " --attn_map_dir {ATTN_MAP_DIR} \\\n", - " --num_recycles_save 1 \\\n", - " --triangle_residue_idx {TRI_RESIDUE_IDX} \\\n", - " --demo_attn\n", - "\"\"\"\n", - "\n", - "subprocess.run(inference_cmd, shell=True, check=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "66a306cc", - "metadata": {}, - "outputs": [], - "source": [ - "# Render predicted 3D structure and save as PNG image\n", - "from visualize_attention_general_utils import render_pdb_to_image\n", - "\n", - "PDB_FILE = os.path.join(OUTPUT_DIR, f\"predictions/{PROT}_1_model_1_ptm_relaxed.pdb\")\n", - "FNAME = f\"predicted_structure_{PROT}_tri_{TRI_RESIDUE_IDX}.png\"\n", - "\n", - "render_pdb_to_image(PDB_FILE, IMAGE_OUTPUT_DIR, FNAME)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "941bd83f", - "metadata": {}, - "outputs": [], - "source": [ - "# Import visualization utilities\n", - "from visualize_attention_3d_demo_utils import plot_pymol_attention_heads\n", - "from visualize_attention_arc_diagram_demo_utils import generate_arc_diagrams, parse_fasta_sequence\n", - "\n", - "# Setup visualization output directories\n", - "output_dir_msa = os.path.join(IMAGE_OUTPUT_DIR, 'msa_row_attention_plots') # directory for saving msa attention 3D visuals\n", - "output_dir_tri = os.path.join(IMAGE_OUTPUT_DIR, 'tri_start_attention_plots') # directory for saving triangle attention 3D visuals\n", - "FASTA_PATH = f\"/u/thayes/vizfold/examples/monomer/fasta_dir_{PROT}/{PROT}.fasta\"\n", - "LAYER_IDX = 47 # selected layer for attention evaluation\n", - "TOP_K = 50 # show top-k attention links (limit to 500)\n", - "\n", - "# Generate 3D attention plots for MSA row attention\n", - "plot_pymol_attention_heads(\n", - " pdb_file=PDB_FILE,\n", - " 
attention_dir=ATTN_MAP_DIR,\n", - " output_dir=output_dir_msa,\n", - " protein=PROT,\n", - " attention_type=\"msa_row\",\n", - " top_k=TOP_K,\n", - " layer_idx=LAYER_IDX\n", - ")\n", - "\n", - "# Generate 3D attention plots for triangle start attention\n", - "plot_pymol_attention_heads(\n", - " pdb_file=PDB_FILE,\n", - " attention_dir=ATTN_MAP_DIR,\n", - " output_dir=output_dir_tri,\n", - " protein=PROT,\n", - " attention_type=\"triangle_start\",\n", - " residue_indices=[TRI_RESIDUE_IDX],\n", - " top_k=TOP_K,\n", - " layer_idx=LAYER_IDX\n", - ")\n", - "\n", - "# Parse FASTA for arc diagrams\n", - "residue_seq = parse_fasta_sequence(FASTA_PATH)\n", - "\n", - "# Generate arc diagrams for MSA row attention\n", - "generate_arc_diagrams(\n", - " attention_dir=ATTN_MAP_DIR,\n", - " residue_sequence=residue_seq,\n", - " output_dir=output_dir_msa,\n", - " protein=PROT,\n", - " attention_type=\"msa_row\",\n", - " top_k=TOP_K,\n", - " layer_idx=LAYER_IDX\n", - ")\n", - "\n", - "# Generate arc diagrams for triangle start attention\n", - "generate_arc_diagrams(\n", - " attention_dir=ATTN_MAP_DIR,\n", - " residue_sequence=residue_seq,\n", - " output_dir=output_dir_tri,\n", - " protein=PROT,\n", - " attention_type=\"triangle_start\",\n", - " residue_indices=[TRI_RESIDUE_IDX],\n", - " top_k=TOP_K,\n", - " layer_idx=LAYER_IDX\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "756ecf74", - "metadata": {}, - "outputs": [], - "source": [ - "# Import function for combining attention plots\n", - "from visualize_attention_general_utils import generate_combined_attention_panels\n", - "\n", - "# Combine MSA row plots\n", - "generate_combined_attention_panels(\n", - " attention_type=\"msa_row\",\n", - " protein=PROT,\n", - " layer_idx=LAYER_IDX,\n", - " output_dir_3d=output_dir_msa,\n", - " output_dir_arc=output_dir_msa,\n", - " combined_output_dir=IMAGE_OUTPUT_DIR,\n", - ")\n", - "\n", - "# Combine triangle start plots\n", - 
"generate_combined_attention_panels(\n", - " attention_type=\"triangle_start\",\n", - " protein=PROT,\n", - " layer_idx=LAYER_IDX,\n", - " output_dir_3d=output_dir_tri,\n", - " output_dir_arc=output_dir_tri,\n", - " combined_output_dir=IMAGE_OUTPUT_DIR,\n", - " residue_indices=[TRI_RESIDUE_IDX]\n", - ")\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.16" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/viz_attention_demo_base.ipynb b/viz_attention_demo_base.ipynb index d10f0b1b..5d8a1d42 100644 --- a/viz_attention_demo_base.ipynb +++ b/viz_attention_demo_base.ipynb @@ -1,201 +1,295 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "1400375b", - "metadata": {}, - "outputs": [], - "source": [ - "# Required imports\n", - "import os\n", - "import subprocess\n", - "\n", - "# Define target protein and the residue to center triangle attention on\n", - "PROT = \"6KWC\"\n", - "TRI_RESIDUE_IDX = 18\n", - "\n", - "# Define all relevant directories\n", - "BASE_DATA_DIR = \"/ime/hdd/rhaas/SUP-5301/database\" # path to AlphaFold database\n", - "\n", - "# Local paths for saving results (these probably can remain unchanged)\n", - "ATTN_MAP_DIR = f\"./outputs/attention_files_{PROT}_demo_tri_{TRI_RESIDUE_IDX}\" # directory for saving text files with top-k attention scores\n", - "ALIGNMENT_DIR = \"./examples/monomer/alignments\" # directory containing pre-computed alignment files (and MSAs)\n", - "OUTPUT_DIR = f\"./outputs/my_outputs_align_{PROT}_demo_tri_{TRI_RESIDUE_IDX}\" # directory to save outputs\n", - "IMAGE_OUTPUT_DIR = 
f\"./outputs/attention_images_{PROT}_demo_tri_{TRI_RESIDUE_IDX}\"\n", - "FASTA_DIR = f\"./examples/monomer/fasta_dir_{PROT}\"\n", - "\n", - "# Note: If this is a new protein, the ALIGNMENT_DIR does not need to be specified here or in the next cell\n", - "# In this case, the code will compute MSAs and alignments, which can take several hours\n" - ] + "cells": [ + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Required imports\n", + "import os\n", + "import subprocess\n", + "\n", + "# Define target protein and the residue to center triangle attention on\n", + "PROT = \"6KWC\"\n", + "TRI_RESIDUE_IDX = 18\n", + "\n", + "# Define all relevant directories\n", + "BASE_DATA_DIR = \"/ime/hdd/rhaas/SUP-5301/database\" # path to AlphaFold database\n", + "\n", + "# Local paths for saving results (these probably can remain unchanged)\n", + "ATTN_MAP_DIR = f\"./outputs/attention_files_{PROT}_demo_tri_{TRI_RESIDUE_IDX}\" # directory for saving text files with top-k attention scores\n", + "ALIGNMENT_DIR = \"./examples/monomer/alignments\" # directory containing pre-computed alignment files (and MSAs)\n", + "OUTPUT_DIR = f\"./outputs/my_outputs_align_{PROT}_demo_tri_{TRI_RESIDUE_IDX}\" # directory to save outputs\n", + "IMAGE_OUTPUT_DIR = f\"./outputs/attention_images_{PROT}_demo_tri_{TRI_RESIDUE_IDX}\"\n", + "FASTA_DIR = f\"./examples/monomer/fasta_dir_{PROT}\"\n", + "\n", + "# Note: If this is a new protein, the ALIGNMENT_DIR does not need to be specified here or in the next cell\n", + "# In this case, the code will compute MSAs and alignments, which can take several hours\n" + ], + "execution_count": null, + "outputs": [], + "id": "1400375b" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Run OpenFold inference and save top attention scores to text files \n", + "inference_cmd = f\"\"\"\n", + "python3 run_pretrained_openfold.py \\\n", + " {FASTA_DIR} \\\n", + " {BASE_DATA_DIR}/pdb_mmcif/mmcif_files \\\n", + " --use_precomputed_alignments 
{ALIGNMENT_DIR} \\\n", + " --output_dir {OUTPUT_DIR} \\\n", + " --config_preset model_1_ptm \\\n", + " --uniref90_database_path {BASE_DATA_DIR}/uniref90/uniref90.fasta \\\n", + " --mgnify_database_path {BASE_DATA_DIR}/mgnify/mgy_clusters_2022_05.fa \\\n", + " --pdb70_database_path {BASE_DATA_DIR}/pdb70/pdb70 \\\n", + " --uniclust30_database_path {BASE_DATA_DIR}/uniclust30/uniclust30_2018_08 \\\n", + " --bfd_database_path {BASE_DATA_DIR}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \\\n", + " --save_outputs \\\n", + " --model_device \"cuda:0\" \\\n", + " --attn_map_dir {ATTN_MAP_DIR} \\\n", + " --num_recycles_save 1 \\\n", + " --triangle_residue_idx {TRI_RESIDUE_IDX} \\\n", + " --demo_attn\n", + "\"\"\"\n", + "\n", + "subprocess.run(inference_cmd, shell=True, check=True)\n" + ], + "execution_count": null, + "outputs": [], + "id": "3fd757f9" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Render predicted 3D structure and save as PNG image\n", + "from visualize_attention_general_utils import render_pdb_to_image\n", + "\n", + "PDB_FILE = os.path.join(OUTPUT_DIR, f\"predictions/{PROT}_1_model_1_ptm_relaxed.pdb\")\n", + "FNAME = f\"predicted_structure_{PROT}_tri_{TRI_RESIDUE_IDX}.png\"\n", + "\n", + "render_pdb_to_image(PDB_FILE, IMAGE_OUTPUT_DIR, FNAME)\n" + ], + "execution_count": null, + "outputs": [], + "id": "66a306cc" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Import visualization utilities\n", + "from visualize_attention_3d_demo_utils import plot_pymol_attention_heads\n", + "from visualize_attention_arc_diagram_demo_utils import generate_arc_diagrams, parse_fasta_sequence\n", + "\n", + "# Setup visualization output directories\n", + "output_dir_msa = os.path.join(IMAGE_OUTPUT_DIR, 'msa_row_attention_plots') # directory for saving msa attention 3D visuals\n", + "output_dir_tri = os.path.join(IMAGE_OUTPUT_DIR, 'tri_start_attention_plots') # directory for saving triangle attention 3D visuals\n", + 
"FASTA_PATH = f\"/u/thayes/vizfold/examples/monomer/fasta_dir_{PROT}/{PROT}.fasta\"\n", + "LAYER_IDX = 47 # selected layer for attention evaluation\n", + "TOP_K = 50 # show top-k attention links (limit to 500)\n", + "\n", + "# Generate 3D attention plots for MSA row attention\n", + "plot_pymol_attention_heads(\n", + " pdb_file=PDB_FILE,\n", + " attention_dir=ATTN_MAP_DIR,\n", + " output_dir=output_dir_msa,\n", + " protein=PROT,\n", + " attention_type=\"msa_row\",\n", + " top_k=TOP_K,\n", + " layer_idx=LAYER_IDX\n", + ")\n", + "\n", + "# Generate 3D attention plots for triangle start attention\n", + "plot_pymol_attention_heads(\n", + " pdb_file=PDB_FILE,\n", + " attention_dir=ATTN_MAP_DIR,\n", + " output_dir=output_dir_tri,\n", + " protein=PROT,\n", + " attention_type=\"triangle_start\",\n", + " residue_indices=[TRI_RESIDUE_IDX],\n", + " top_k=TOP_K,\n", + " layer_idx=LAYER_IDX\n", + ")\n", + "\n", + "# Parse FASTA for arc diagrams\n", + "residue_seq = parse_fasta_sequence(FASTA_PATH)\n", + "\n", + "# Generate arc diagrams for MSA row attention\n", + "generate_arc_diagrams(\n", + " attention_dir=ATTN_MAP_DIR,\n", + " residue_sequence=residue_seq,\n", + " output_dir=output_dir_msa,\n", + " protein=PROT,\n", + " attention_type=\"msa_row\",\n", + " top_k=TOP_K,\n", + " layer_idx=LAYER_IDX\n", + ")\n", + "\n", + "# Generate arc diagrams for triangle start attention\n", + "generate_arc_diagrams(\n", + " attention_dir=ATTN_MAP_DIR,\n", + " residue_sequence=residue_seq,\n", + " output_dir=output_dir_tri,\n", + " protein=PROT,\n", + " attention_type=\"triangle_start\",\n", + " residue_indices=[TRI_RESIDUE_IDX],\n", + " top_k=TOP_K,\n", + " layer_idx=LAYER_IDX\n", + ")\n" + ], + "execution_count": null, + "outputs": [], + "id": "941bd83f" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Import function for combining attention plots\n", + "from visualize_attention_general_utils import generate_combined_attention_panels\n", + "\n", + "# Combine MSA row 
plots\n", + "generate_combined_attention_panels(\n", + " attention_type=\"msa_row\",\n", + " protein=PROT,\n", + " layer_idx=LAYER_IDX,\n", + " output_dir_3d=output_dir_msa,\n", + " output_dir_arc=output_dir_msa,\n", + " combined_output_dir=IMAGE_OUTPUT_DIR,\n", + ")\n", + "\n", + "# Combine triangle start plots\n", + "generate_combined_attention_panels(\n", + " attention_type=\"triangle_start\",\n", + " protein=PROT,\n", + " layer_idx=LAYER_IDX,\n", + " output_dir_3d=output_dir_tri,\n", + " output_dir_arc=output_dir_tri,\n", + " combined_output_dir=IMAGE_OUTPUT_DIR,\n", + " residue_indices=[TRI_RESIDUE_IDX]\n", + ")\n" + ], + "execution_count": null, + "outputs": [], + "id": "756ecf74" + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Head-level heatmap and network-style visualizations (compare all heads in one layer)\n", + "from visualize_attention_arc_diagram_demo_utils import load_all_heads\n", + "from visualize_attention_head_heatmaps import build_head_matrices, plot_head_heatmaps\n", + "from visualize_attention_networks import (\n", + " build_aggregated_graph,\n", + " plot_residue_network,\n", + " plot_residue_network_per_head,\n", + ")\n", + "\n", + "output_dir_heatmaps = os.path.join(IMAGE_OUTPUT_DIR, \"head_heatmaps\")\n", + "output_dir_networks = os.path.join(IMAGE_OUTPUT_DIR, \"network_plots\")\n", + "n_residues = len(residue_seq)\n", + "\n", + "# --- MSA Row: heatmaps and aggregated network ---\n", + "msa_attn_file = os.path.join(ATTN_MAP_DIR, f\"msa_row_attn_layer{LAYER_IDX}.txt\")\n", + "if os.path.exists(msa_attn_file):\n", + " msa_heads = load_all_heads(msa_attn_file, top_k=TOP_K)\n", + " head_mats = build_head_matrices(msa_heads, n_residues=n_residues)\n", + " plot_head_heatmaps(\n", + " head_mats,\n", + " residue_sequence=residue_seq,\n", + " layer_idx=LAYER_IDX,\n", + " protein=PROT,\n", + " output_dir=output_dir_heatmaps,\n", + " cols=4,\n", + " save_combined=True,\n", + " save_individual=False,\n", + " )\n", + " agg_edges = 
build_aggregated_graph(msa_heads, aggregation=\"mean\", normalize_by_heads=True)\n", + " plot_residue_network(\n", + " agg_edges,\n", + " n_residues=n_residues,\n", + " residue_sequence=residue_seq,\n", + " layer_idx=LAYER_IDX,\n", + " protein=PROT,\n", + " output_dir=output_dir_networks,\n", + " layout=\"circular\",\n", + " max_edges=200,\n", + " top_k_hubs=10,\n", + " )\n", + " plot_residue_network_per_head(\n", + " msa_heads,\n", + " n_residues=n_residues,\n", + " residue_sequence=residue_seq,\n", + " layer_idx=LAYER_IDX,\n", + " protein=PROT,\n", + " output_dir=output_dir_networks,\n", + " layout=\"circular\",\n", + " max_edges_per_head=50,\n", + " cols=4,\n", + " )\n", + "else:\n", + " print(f\"[Skip] MSA attention file not found: {msa_attn_file}\")\n", + "\n", + "# --- Triangle Start (optional): heatmaps and network for one residue ---\n", + "tri_attn_file = os.path.join(\n", + " ATTN_MAP_DIR, f\"triangle_start_attn_layer{LAYER_IDX}_residue_idx_{TRI_RESIDUE_IDX}.txt\"\n", + ")\n", + "if os.path.exists(tri_attn_file):\n", + " tri_heads = load_all_heads(tri_attn_file, top_k=TOP_K)\n", + " tri_head_mats = build_head_matrices(tri_heads, n_residues=n_residues)\n", + " plot_head_heatmaps(\n", + " tri_head_mats,\n", + " residue_sequence=residue_seq,\n", + " layer_idx=LAYER_IDX,\n", + " protein=f\"{PROT}_tri_res{TRI_RESIDUE_IDX}\",\n", + " output_dir=output_dir_heatmaps,\n", + " cols=4,\n", + " save_combined=True,\n", + " save_individual=False,\n", + " )\n", + " tri_agg = build_aggregated_graph(tri_heads, aggregation=\"mean\", normalize_by_heads=True)\n", + " plot_residue_network(\n", + " tri_agg,\n", + " n_residues=n_residues,\n", + " residue_sequence=residue_seq,\n", + " layer_idx=LAYER_IDX,\n", + " protein=f\"{PROT}_tri_res{TRI_RESIDUE_IDX}\",\n", + " output_dir=output_dir_networks,\n", + " layout=\"circular\",\n", + " max_edges=200,\n", + " top_k_hubs=10,\n", + " )\n", + "else:\n", + " print(f\"[Skip] Triangle attention file not found: {tri_attn_file}\")" + ], + 
"id": "917682c7", + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "openfold_env3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.17" + } }, - { - "cell_type": "code", - "execution_count": null, - "id": "3fd757f9", - "metadata": {}, - "outputs": [], - "source": [ - "# Run OpenFold inference and save top attention scores to text files \n", - "inference_cmd = f\"\"\"\n", - "python3 run_pretrained_openfold.py \\\n", - " {FASTA_DIR} \\\n", - " {BASE_DATA_DIR}/pdb_mmcif/mmcif_files \\\n", - " --use_precomputed_alignments {ALIGNMENT_DIR} \\\n", - " --output_dir {OUTPUT_DIR} \\\n", - " --config_preset model_1_ptm \\\n", - " --uniref90_database_path {BASE_DATA_DIR}/uniref90/uniref90.fasta \\\n", - " --mgnify_database_path {BASE_DATA_DIR}/mgnify/mgy_clusters_2022_05.fa \\\n", - " --pdb70_database_path {BASE_DATA_DIR}/pdb70/pdb70 \\\n", - " --uniclust30_database_path {BASE_DATA_DIR}/uniclust30/uniclust30_2018_08 \\\n", - " --bfd_database_path {BASE_DATA_DIR}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \\\n", - " --save_outputs \\\n", - " --model_device \"cuda:0\" \\\n", - " --attn_map_dir {ATTN_MAP_DIR} \\\n", - " --num_recycles_save 1 \\\n", - " --triangle_residue_idx {TRI_RESIDUE_IDX} \\\n", - " --demo_attn\n", - "\"\"\"\n", - "\n", - "subprocess.run(inference_cmd, shell=True, check=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "66a306cc", - "metadata": {}, - "outputs": [], - "source": [ - "# Render predicted 3D structure and save as PNG image\n", - "from visualize_attention_general_utils import render_pdb_to_image\n", - "\n", - "PDB_FILE = os.path.join(OUTPUT_DIR, f\"predictions/{PROT}_1_model_1_ptm_relaxed.pdb\")\n", - 
"FNAME = f\"predicted_structure_{PROT}_tri_{TRI_RESIDUE_IDX}.png\"\n", - "\n", - "render_pdb_to_image(PDB_FILE, IMAGE_OUTPUT_DIR, FNAME)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "941bd83f", - "metadata": {}, - "outputs": [], - "source": [ - "# Import visualization utilities\n", - "from visualize_attention_3d_demo_utils import plot_pymol_attention_heads\n", - "from visualize_attention_arc_diagram_demo_utils import generate_arc_diagrams, parse_fasta_sequence\n", - "\n", - "# Setup visualization output directories\n", - "output_dir_msa = os.path.join(IMAGE_OUTPUT_DIR, 'msa_row_attention_plots') # directory for saving msa attention 3D visuals\n", - "output_dir_tri = os.path.join(IMAGE_OUTPUT_DIR, 'tri_start_attention_plots') # directory for saving triangle attention 3D visuals\n", - "FASTA_PATH = f\"/u/thayes/vizfold/examples/monomer/fasta_dir_{PROT}/{PROT}.fasta\"\n", - "LAYER_IDX = 47 # selected layer for attention evaluation\n", - "TOP_K = 50 # show top-k attention links (limit to 500)\n", - "\n", - "# Generate 3D attention plots for MSA row attention\n", - "plot_pymol_attention_heads(\n", - " pdb_file=PDB_FILE,\n", - " attention_dir=ATTN_MAP_DIR,\n", - " output_dir=output_dir_msa,\n", - " protein=PROT,\n", - " attention_type=\"msa_row\",\n", - " top_k=TOP_K,\n", - " layer_idx=LAYER_IDX\n", - ")\n", - "\n", - "# Generate 3D attention plots for triangle start attention\n", - "plot_pymol_attention_heads(\n", - " pdb_file=PDB_FILE,\n", - " attention_dir=ATTN_MAP_DIR,\n", - " output_dir=output_dir_tri,\n", - " protein=PROT,\n", - " attention_type=\"triangle_start\",\n", - " residue_indices=[TRI_RESIDUE_IDX],\n", - " top_k=TOP_K,\n", - " layer_idx=LAYER_IDX\n", - ")\n", - "\n", - "# Parse FASTA for arc diagrams\n", - "residue_seq = parse_fasta_sequence(FASTA_PATH)\n", - "\n", - "# Generate arc diagrams for MSA row attention\n", - "generate_arc_diagrams(\n", - " attention_dir=ATTN_MAP_DIR,\n", - " residue_sequence=residue_seq,\n", - " 
output_dir=output_dir_msa,\n", - " protein=PROT,\n", - " attention_type=\"msa_row\",\n", - " top_k=TOP_K,\n", - " layer_idx=LAYER_IDX\n", - ")\n", - "\n", - "# Generate arc diagrams for triangle start attention\n", - "generate_arc_diagrams(\n", - " attention_dir=ATTN_MAP_DIR,\n", - " residue_sequence=residue_seq,\n", - " output_dir=output_dir_tri,\n", - " protein=PROT,\n", - " attention_type=\"triangle_start\",\n", - " residue_indices=[TRI_RESIDUE_IDX],\n", - " top_k=TOP_K,\n", - " layer_idx=LAYER_IDX\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "756ecf74", - "metadata": {}, - "outputs": [], - "source": [ - "# Import function for combining attention plots\n", - "from visualize_attention_general_utils import generate_combined_attention_panels\n", - "\n", - "# Combine MSA row plots\n", - "generate_combined_attention_panels(\n", - " attention_type=\"msa_row\",\n", - " protein=PROT,\n", - " layer_idx=LAYER_IDX,\n", - " output_dir_3d=output_dir_msa,\n", - " output_dir_arc=output_dir_msa,\n", - " combined_output_dir=IMAGE_OUTPUT_DIR,\n", - ")\n", - "\n", - "# Combine triangle start plots\n", - "generate_combined_attention_panels(\n", - " attention_type=\"triangle_start\",\n", - " protein=PROT,\n", - " layer_idx=LAYER_IDX,\n", - " output_dir_3d=output_dir_tri,\n", - " output_dir_arc=output_dir_tri,\n", - " combined_output_dir=IMAGE_OUTPUT_DIR,\n", - " residue_indices=[TRI_RESIDUE_IDX]\n", - ")\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "openfold_env3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.17" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file From 
3ca68e462cb04aedfa63f54eac9a0c1e137e4166 Mon Sep 17 00:00:00 2001 From: ssagili3 Date: Sun, 26 Apr 2026 20:11:38 -0400 Subject: [PATCH 2/3] Refactor attention visualization data loading --- .../Head_Visualization_Heatmap_Network.md | 6 +- scripts/run_head_heatmap_network_demo.py | 4 +- tests/test_visualize_attention_data.py | 50 +++++++++ visualize_attention_3d_demo_utils.py | 44 ++------ visualize_attention_arc_diagram_demo_utils.py | 53 +++------- visualize_attention_data.py | 100 ++++++++++++++++++ visualize_attention_head_heatmaps.py | 3 +- visualize_attention_networks.py | 2 +- 8 files changed, 180 insertions(+), 82 deletions(-) create mode 100644 tests/test_visualize_attention_data.py create mode 100644 visualize_attention_data.py diff --git a/docs/source/Head_Visualization_Heatmap_Network.md b/docs/source/Head_Visualization_Heatmap_Network.md index de03587b..8be4660a 100644 --- a/docs/source/Head_Visualization_Heatmap_Network.md +++ b/docs/source/Head_Visualization_Heatmap_Network.md @@ -14,7 +14,8 @@ Arc diagrams and PyMOL overlays are useful but show one head at a time. The new | Component | Purpose | |-----------|---------| -| `visualize_attention_head_heatmaps.py` | Builds dense per-head matrices from `load_all_heads()` and plots a grid of residue–residue heatmaps (one per head). | +| `visualize_attention_data.py` | Shared attention-map parser and FASTA reader used by every visualization module. | +| `visualize_attention_head_heatmaps.py` | Builds dense per-head matrices from shared parsed attention data and plots a grid of residue–residue heatmaps (one per head). | | `visualize_attention_networks.py` | Aggregates heads into one weighted graph and/or draws one small network per head; supports circular or linear layout and hub highlighting. 
| | Notebook cells (in `viz_attention_demo.ipynb` and `viz_attention_demo_base.ipynb`) | After running 3D and arc visualizations, a new cell runs heatmap and network code for MSA row and (optionally) triangle-start attention. | | `scripts/run_head_heatmap_network_demo.py` | Standalone script that generates **synthetic** attention data and runs the new visualizations to produce example outputs without full OpenFold inference. | @@ -61,9 +62,10 @@ The new visualizations do not replace arc or 3D views; they complement them by a ## File and Function Reference - **Build matrices**: `build_head_matrices(heads, n_residues)` in `visualize_attention_head_heatmaps.py`. +- **Load attention data**: `load_attention_map(connections_file, top_k=...)` in `visualize_attention_data.py`. - **Heatmap panel**: `plot_head_heatmaps(head_mats, residue_sequence, layer_idx, protein, output_dir, ...)`. - **Aggregated graph**: `build_aggregated_graph(heads, aggregation="mean", normalize_by_heads=True)` in `visualize_attention_networks.py`. - **Single network plot**: `plot_residue_network(edges, n_residues, residue_sequence, layer_idx, protein, output_dir, layout="circular", max_edges=200, top_k_hubs=10)`. - **Per-head networks**: `plot_residue_network_per_head(heads, n_residues, ...)`. -Input `heads` is the dict returned by `load_all_heads(connections_file, top_k=...)` from `visualize_attention_arc_diagram_demo_utils` (or the same function in `visualize_attention_3d_demo_utils`). +Input `heads` is the dict returned by `load_attention_map(connections_file, top_k=...)` from `visualize_attention_data.py`. The older `load_all_heads()` import path is still available for compatibility. 
diff --git a/scripts/run_head_heatmap_network_demo.py b/scripts/run_head_heatmap_network_demo.py index 61c71d8f..898d5f6a 100644 --- a/scripts/run_head_heatmap_network_demo.py +++ b/scripts/run_head_heatmap_network_demo.py @@ -20,7 +20,7 @@ sys.path.insert(0, REPO_ROOT) import numpy as np -from visualize_attention_arc_diagram_demo_utils import load_all_heads, parse_fasta_sequence +from visualize_attention_data import load_attention_map, parse_fasta_sequence from visualize_attention_head_heatmaps import build_head_matrices, plot_head_heatmaps from visualize_attention_networks import ( build_aggregated_graph, @@ -81,7 +81,7 @@ def main(): ) msa_file = os.path.join(attn_dir, f"msa_row_attn_layer{layer_idx}.txt") - heads = load_all_heads(msa_file, top_k=top_k) + heads = load_attention_map(msa_file, top_k=top_k) head_mats = build_head_matrices(heads, n_residues=n_residues) plot_head_heatmaps( head_mats, diff --git a/tests/test_visualize_attention_data.py b/tests/test_visualize_attention_data.py new file mode 100644 index 00000000..f5367d5b --- /dev/null +++ b/tests/test_visualize_attention_data.py @@ -0,0 +1,50 @@ +import os +import tempfile +import unittest + +from visualize_attention_data import ( + get_attention_file_path, + load_attention_map, + parse_fasta_sequence, +) + + +class TestVisualizeAttentionData(unittest.TestCase): + def test_load_attention_map_groups_sorts_and_filters_heads(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "msa_row_attn_layer47.txt") + with open(path, "w") as f: + f.write("layer 47 head 0\n") + f.write("2 3 0.1\n") + f.write("0 1 0.9\n") + f.write("layer 47 head 1\n") + f.write("4 5 0.4\n") + f.write("6 7 0.2\n") + + heads = load_attention_map(path, top_k=1) + + self.assertEqual(heads, {0: [(0, 1, 0.9)], 1: [(4, 5, 0.4)]}) + + def test_get_attention_file_path_uses_expected_names(self): + self.assertEqual( + get_attention_file_path("/tmp/attn", "msa_row", 47), + "/tmp/attn/msa_row_attn_layer47.txt", + ) 
+ self.assertEqual( + get_attention_file_path("/tmp/attn", "triangle_start", 47, residue_idx=18), + "/tmp/attn/triangle_start_attn_layer47_residue_idx_18.txt", + ) + + def test_parse_fasta_sequence_joins_sequence_lines(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "protein.fasta") + with open(path, "w") as f: + f.write(">protein\n") + f.write("ACD\n") + f.write("EFG\n") + + self.assertEqual(parse_fasta_sequence(path), "ACDEFG") + + +if __name__ == "__main__": + unittest.main() diff --git a/visualize_attention_3d_demo_utils.py b/visualize_attention_3d_demo_utils.py index 378979c8..e84b9f45 100644 --- a/visualize_attention_3d_demo_utils.py +++ b/visualize_attention_3d_demo_utils.py @@ -10,39 +10,11 @@ import matplotlib.image as mpimg import os +from visualize_attention_data import get_attention_file_path, load_attention_map -# ========== Attention File I/O ========== -def load_all_heads(connections_file, top_k=None): - """ - Loads all heads' connections from a combined text file. - Returns a dict mapping head_index -> list of (res1, res2, weight). 
- """ - heads = {} - current_head = None - with open(connections_file, 'r') as f: - for line in f: - line = line.strip() - if not line: - continue - if line.lower().startswith('layer'): - # New head section - parts = line.replace(',', '').split() - head_idx = int(parts[-1]) - current_head = head_idx - heads[current_head] = [] - else: - # Residue-residue-weight line - res1, res2, weight = map(float, line.split()) - heads[current_head].append((int(res1), int(res2), weight)) - - # Sort each head's connections - for head_idx, conns in heads.items(): - conns.sort(key=lambda x: x[2], reverse=True) - if top_k is not None: - heads[head_idx] = conns[:top_k] - - return heads +# ========== Attention File I/O ========== +load_all_heads = load_attention_map def load_connections(connections_file, top_k=None): @@ -297,8 +269,8 @@ def plot_pymol_attention_heads( os.makedirs(output_dir, exist_ok=True) if attention_type == "msa_row": - msa_file = os.path.join(attention_dir, f"msa_row_attn_layer{layer_idx}.txt") - msa_heads = load_all_heads(msa_file, top_k=top_k) + msa_file = get_attention_file_path(attention_dir, attention_type, layer_idx) + msa_heads = load_attention_map(msa_file, top_k=top_k) image_paths = [] for head_idx, connections in msa_heads.items(): @@ -316,12 +288,14 @@ def plot_pymol_attention_heads( assert residue_indices is not None, "Must supply residue_indices for triangle attention" for res_idx in residue_indices: - tri_file = os.path.join(attention_dir, f"triangle_start_attn_layer{layer_idx}_residue_idx_{res_idx}.txt") + tri_file = get_attention_file_path( + attention_dir, attention_type, layer_idx, residue_idx=res_idx + ) if not os.path.exists(tri_file): print(f"[Warning] Missing attention file for residue {res_idx}") continue - tri_heads = load_all_heads(tri_file, top_k=top_k) + tri_heads = load_attention_map(tri_file, top_k=top_k) res_pngs = [] for head_idx, connections in tri_heads.items(): output_png = os.path.join(output_dir, 
f"tri_start_residue_{res_idx}_head_{head_idx}_layer_{layer_idx}_{protein}.png") diff --git a/visualize_attention_arc_diagram_demo_utils.py b/visualize_attention_arc_diagram_demo_utils.py index 5033f88b..f1ee612f 100644 --- a/visualize_attention_arc_diagram_demo_utils.py +++ b/visualize_attention_arc_diagram_demo_utils.py @@ -2,43 +2,14 @@ import numpy as np import matplotlib.pyplot as plt +from visualize_attention_data import ( + get_attention_file_path, + load_attention_map, + parse_fasta_sequence, +) -# ========== Input Parsing ========== -def load_all_heads(connections_file, top_k=None): - heads = {} - current_head = None - with open(connections_file, 'r') as f: - for line in f: - line = line.strip() - if not line: - continue - if line.lower().startswith('layer'): - parts = line.replace(',', '').split() - head_idx = int(parts[-1]) - current_head = head_idx - heads[current_head] = [] - else: - res1, res2, weight = map(float, line.split()) - heads[current_head].append((int(res1), int(res2), weight)) - - for head_idx, conns in heads.items(): - conns.sort(key=lambda x: x[2], reverse=True) - if top_k is not None: - heads[head_idx] = conns[:top_k] - - return heads - - -def parse_fasta_sequence(fasta_path): - """ - Parse a single-entry FASTA file and return the sequence string. - """ - with open(fasta_path, 'r') as f: - lines = f.readlines() - - seq_lines = [line.strip() for line in lines if not line.startswith('>')] - sequence = ''.join(seq_lines) - return sequence +# Backward-compatible import path for notebooks/scripts that already use this file. 
+load_all_heads = load_attention_map # ========== Arc Plotting ========== @@ -116,8 +87,8 @@ def generate_arc_diagrams( os.makedirs(output_dir, exist_ok=True) if attention_type == "msa_row": - file_path = os.path.join(attention_dir, f"msa_row_attn_layer{layer_idx}.txt") - heads = load_all_heads(file_path, top_k=top_k) + file_path = get_attention_file_path(attention_dir, attention_type, layer_idx) + heads = load_attention_map(file_path, top_k=top_k) pngs = [] for head_idx, connections in heads.items(): @@ -131,12 +102,14 @@ def generate_arc_diagrams( assert residue_indices is not None, "residue_indices required for triangle_start attention" for res_idx in residue_indices: - file_path = os.path.join(attention_dir, f"triangle_start_attn_layer{layer_idx}_residue_idx_{res_idx}.txt") + file_path = get_attention_file_path( + attention_dir, attention_type, layer_idx, residue_idx=res_idx + ) if not os.path.exists(file_path): print(f"[Warning] Missing file for residue {res_idx}") continue - heads = load_all_heads(file_path, top_k=top_k) + heads = load_attention_map(file_path, top_k=top_k) pngs = [] for head_idx, connections in heads.items(): diff --git a/visualize_attention_data.py b/visualize_attention_data.py new file mode 100644 index 00000000..c0e0c82e --- /dev/null +++ b/visualize_attention_data.py @@ -0,0 +1,100 @@ +""" +Shared attention-map loading helpers for visualization modules. + +All visualization types should consume the parsed ``heads`` structure from this +module instead of re-parsing attention text files independently. 
+""" + +import os +from typing import Dict, List, Optional, Tuple + + +AttentionEdge = Tuple[int, int, float] +AttentionHeads = Dict[int, List[AttentionEdge]] + + +def filter_top_k_edges(heads: AttentionHeads, top_k: Optional[int] = None) -> AttentionHeads: + """Sort each head by descending weight and optionally keep only top-k edges.""" + filtered = {} + for head_idx, conns in heads.items(): + sorted_conns = sorted(conns, key=lambda x: x[2], reverse=True) + filtered[head_idx] = sorted_conns[:top_k] if top_k is not None else sorted_conns + return filtered + + +def load_attention_map(connections_file: str, top_k: Optional[int] = None) -> AttentionHeads: + """ + Load a combined attention text file into ``head_idx -> [(res_i, res_j, weight)]``. + + Expected format: + layer 47 head 0 + 12 39 0.41 + 8 14 0.32 + layer 47 head 1 + ... + """ + heads = {} + current_head = None + + with open(connections_file, "r") as f: + for line_number, line in enumerate(f, start=1): + line = line.strip() + if not line: + continue + + if line.lower().startswith("layer"): + parts = line.replace(",", "").split() + try: + current_head = int(parts[-1]) + except (IndexError, ValueError) as exc: + raise ValueError( + f"Could not parse head index on line {line_number}: {line}" + ) from exc + heads[current_head] = [] + continue + + if current_head is None: + raise ValueError( + f"Found attention edge before any head header on line {line_number}: {line}" + ) + + try: + res1, res2, weight = map(float, line.split()) + except ValueError as exc: + raise ValueError( + f"Could not parse attention edge on line {line_number}: {line}" + ) from exc + heads[current_head].append((int(res1), int(res2), weight)) + + return filter_top_k_edges(heads, top_k=top_k) + + +def get_attention_file_path( + attention_dir: str, + attention_type: str, + layer_idx: int, + residue_idx: Optional[int] = None, +) -> str: + """Return the canonical attention text-file path for a visualization request.""" + if attention_type == 
"msa_row": + return os.path.join(attention_dir, f"msa_row_attn_layer{layer_idx}.txt") + + if attention_type == "triangle_start": + if residue_idx is None: + raise ValueError("residue_idx is required for triangle_start attention") + return os.path.join( + attention_dir, + f"triangle_start_attn_layer{layer_idx}_residue_idx_{residue_idx}.txt", + ) + + raise ValueError(f"Unsupported attention_type: {attention_type}") + + +def parse_fasta_sequence(fasta_path: str) -> str: + """Parse a single-entry FASTA file and return the sequence string.""" + with open(fasta_path, "r") as f: + return "".join(line.strip() for line in f if not line.startswith(">")) + + +# Backward-compatible name used by existing notebooks and scripts. +load_all_heads = load_attention_map diff --git a/visualize_attention_head_heatmaps.py b/visualize_attention_head_heatmaps.py index 2467a3e8..e709ea47 100644 --- a/visualize_attention_head_heatmaps.py +++ b/visualize_attention_head_heatmaps.py @@ -23,8 +23,7 @@ def build_head_matrices( Args: heads: Dict mapping head_idx -> list of (res_i, res_j, weight). - From load_all_heads() in visualize_attention_arc_diagram_demo_utils - or visualize_attention_3d_demo_utils. + From load_attention_map() in visualize_attention_data. n_residues: Number of residues (sequence length). If None, inferred from max index appearing in any head. symmetrize: If True, set A[i,j] = A[j,i] = max(weight(i,j), weight(j,i)). diff --git a/visualize_attention_networks.py b/visualize_attention_networks.py index 73c5820b..2f74380e 100644 --- a/visualize_attention_networks.py +++ b/visualize_attention_networks.py @@ -220,7 +220,7 @@ def plot_residue_network_per_head( Draw one small network per head in a grid (small multiples). Args: - heads: From load_all_heads(). + heads: From load_attention_map() in visualize_attention_data. n_residues, residue_sequence, layer_idx, protein, output_dir: As in plot_residue_network. layout: "circular" or "linear". 
max_edges_per_head: Max edges to draw per head. From 5fc22ade4b02e06c2fe7458bc19a2b824a81feae Mon Sep 17 00:00:00 2001 From: ssagili3 Date: Tue, 28 Apr 2026 16:09:26 -0400 Subject: [PATCH 3/3] Add chord diagram attention visualization What changed: - Add visualize_attention_chord_diagrams.py for circular residue-residue attention diagrams. - Support per-head chord PNGs, an all-head small-multiples grid, and an aggregated mean chord diagram. - Build chord diagrams from the shared load_attention_map data structure added in the prior refactor. - Wire the sample heatmap/network demo script to also generate chord diagram outputs. - Document the chord_diagrams output folder and add tests for chord aggregation and PNG generation. --- .../Head_Visualization_Heatmap_Network.md | 3 + .../sample_attention_viz_outputs/README.md | 3 +- scripts/run_head_heatmap_network_demo.py | 14 + ...test_visualize_attention_chord_diagrams.py | 55 +++ visualize_attention_chord_diagrams.py | 386 ++++++++++++++++++ 5 files changed, 460 insertions(+), 1 deletion(-) create mode 100644 tests/test_visualize_attention_chord_diagrams.py create mode 100644 visualize_attention_chord_diagrams.py diff --git a/docs/source/Head_Visualization_Heatmap_Network.md b/docs/source/Head_Visualization_Heatmap_Network.md index 8be4660a..41c663fd 100644 --- a/docs/source/Head_Visualization_Heatmap_Network.md +++ b/docs/source/Head_Visualization_Heatmap_Network.md @@ -17,6 +17,7 @@ Arc diagrams and PyMOL overlays are useful but show one head at a time. The new | `visualize_attention_data.py` | Shared attention-map parser and FASTA reader used by every visualization module. | | `visualize_attention_head_heatmaps.py` | Builds dense per-head matrices from shared parsed attention data and plots a grid of residue–residue heatmaps (one per head). | | `visualize_attention_networks.py` | Aggregates heads into one weighted graph and/or draws one small network per head; supports circular or linear layout and hub highlighting. 
| +| `visualize_attention_chord_diagrams.py` | Renders circular chord diagrams for single heads, all-head grids, and aggregated mean attention. | | Notebook cells (in `viz_attention_demo.ipynb` and `viz_attention_demo_base.ipynb`) | After running 3D and arc visualizations, a new cell runs heatmap and network code for MSA row and (optionally) triangle-start attention. | | `scripts/run_head_heatmap_network_demo.py` | Standalone script that generates **synthetic** attention data and runs the new visualizations to produce example outputs without full OpenFold inference. | @@ -56,6 +57,7 @@ Arc diagrams and PyMOL overlays are useful but show one head at a time. The new | **Heatmap grid** | Comparing all heads in one layer; seeing dense vs. sparse and similarity across heads. | | **Aggregated network** | Which residues are hubs across the whole layer; one picture for “consensus” attention. | | **Per-head network grid** | Same as heatmap grid but with a network layout; can be easier for seeing clusters and long-range ties. | +| **Chord diagrams** | Circular residue-residue attention view; useful for seeing long-range links without forcing residues into a straight line. | The new visualizations do not replace arc or 3D views; they complement them by answering “what do all heads in this layer look like together?” and “which residues matter most when we aggregate heads?” @@ -67,5 +69,6 @@ The new visualizations do not replace arc or 3D views; they complement them by a - **Aggregated graph**: `build_aggregated_graph(heads, aggregation="mean", normalize_by_heads=True)` in `visualize_attention_networks.py`. - **Single network plot**: `plot_residue_network(edges, n_residues, residue_sequence, layer_idx, protein, output_dir, layout="circular", max_edges=200, top_k_hubs=10)`. - **Per-head networks**: `plot_residue_network_per_head(heads, n_residues, ...)`. 
+- **Chord diagrams**: `generate_chord_diagrams(attention_dir, residue_sequence, output_dir, protein, ...)` in `visualize_attention_chord_diagrams.py`. Input `heads` is the dict returned by `load_attention_map(connections_file, top_k=...)` from `visualize_attention_data.py`. The older `load_all_heads()` import path is still available for compatibility. diff --git a/examples/monomer/sample_attention_viz_outputs/README.md b/examples/monomer/sample_attention_viz_outputs/README.md index 3396b865..6bdc0098 100644 --- a/examples/monomer/sample_attention_viz_outputs/README.md +++ b/examples/monomer/sample_attention_viz_outputs/README.md @@ -10,7 +10,8 @@ from the repository root. The script uses **synthetic** attention data (same fil - **head_heatmaps/** — One combined PNG with a grid of residue–residue heatmaps (one per attention head). - **network_plots/** — Aggregated attention network (all heads combined, hub residues highlighted) and a per-head network grid. +- **chord_diagrams/** — Per-head chord diagrams, a combined per-head grid, and an aggregated mean chord diagram. -To generate visualizations from **real** inference, run the full pipeline in `viz_attention_demo.ipynb` or `viz_attention_demo_base.ipynb`; the same heatmap and network code runs in the notebook and writes to `IMAGE_OUTPUT_DIR/head_heatmaps/` and `IMAGE_OUTPUT_DIR/network_plots/`. +To generate visualizations from **real** inference, run the full pipeline in `viz_attention_demo.ipynb` or `viz_attention_demo_base.ipynb`; the same heatmap, network, and chord diagram code can run on the saved attention maps. See [Head_Visualization_Heatmap_Network.md](../../../docs/source/Head_Visualization_Heatmap_Network.md) for a full description and comparison with arc diagrams and 3D overlays. 
diff --git a/scripts/run_head_heatmap_network_demo.py b/scripts/run_head_heatmap_network_demo.py index 898d5f6a..e6f05362 100644 --- a/scripts/run_head_heatmap_network_demo.py +++ b/scripts/run_head_heatmap_network_demo.py @@ -21,6 +21,7 @@ import numpy as np from visualize_attention_data import load_attention_map, parse_fasta_sequence +from visualize_attention_chord_diagrams import generate_chord_diagrams from visualize_attention_head_heatmaps import build_head_matrices, plot_head_heatmaps from visualize_attention_networks import ( build_aggregated_graph, @@ -68,6 +69,7 @@ def main(): attn_dir = os.path.join(out_base, "attention_files_6KWC_demo") heatmap_dir = os.path.join(out_base, "head_heatmaps") network_dir = os.path.join(out_base, "network_plots") + chord_dir = os.path.join(out_base, "chord_diagrams") layer_idx = 47 top_k = 50 @@ -116,6 +118,18 @@ def main(): max_edges_per_head=50, cols=4, ) + generate_chord_diagrams( + attention_dir=attn_dir, + residue_sequence=residue_seq, + output_dir=chord_dir, + protein=protein, + attention_type="msa_row", + top_k=top_k, + layer_idx=layer_idx, + save_individual=True, + save_grid=True, + save_aggregated=True, + ) print(f"Example outputs saved under {out_base}") return 0 diff --git a/tests/test_visualize_attention_chord_diagrams.py b/tests/test_visualize_attention_chord_diagrams.py new file mode 100644 index 00000000..c6fd0e7c --- /dev/null +++ b/tests/test_visualize_attention_chord_diagrams.py @@ -0,0 +1,55 @@ +import os +import tempfile +import unittest + +from visualize_attention_chord_diagrams import ( + aggregate_chord_edges, + generate_chord_diagrams, +) + + +class TestVisualizeAttentionChordDiagrams(unittest.TestCase): + def test_aggregate_chord_edges_uses_mean_by_default(self): + heads = { + 0: [(0, 1, 0.2), (2, 3, 0.5)], + 1: [(0, 1, 0.6)], + } + + self.assertEqual( + aggregate_chord_edges(heads), + [(2, 3, 0.5), (0, 1, 0.4)], + ) + + def test_generate_chord_diagrams_writes_expected_msa_outputs(self): + with 
tempfile.TemporaryDirectory() as tmpdir: + attn_dir = os.path.join(tmpdir, "attention") + out_dir = os.path.join(tmpdir, "chords") + os.makedirs(attn_dir) + with open(os.path.join(attn_dir, "msa_row_attn_layer47.txt"), "w") as f: + f.write("layer 47 head 0\n") + f.write("0 1 0.8\n") + f.write("1 2 0.2\n") + f.write("layer 47 head 1\n") + f.write("2 3 0.7\n") + f.write("3 4 0.1\n") + + paths = generate_chord_diagrams( + attention_dir=attn_dir, + residue_sequence="ABCDE", + output_dir=out_dir, + protein="TEST", + attention_type="msa_row", + top_k=1, + layer_idx=47, + save_individual=True, + save_grid=True, + save_aggregated=True, + ) + + self.assertEqual(len(paths), 4) + for path in paths: + self.assertTrue(os.path.exists(path)) + + +if __name__ == "__main__": + unittest.main() diff --git a/visualize_attention_chord_diagrams.py b/visualize_attention_chord_diagrams.py new file mode 100644 index 00000000..90e2a0ff --- /dev/null +++ b/visualize_attention_chord_diagrams.py @@ -0,0 +1,386 @@ +""" +Chord-style attention visualizations. + +This module renders circular residue-residue attention diagrams from the shared +``heads`` structure produced by ``visualize_attention_data.load_attention_map``. +""" + +import os +from collections import defaultdict +from typing import Dict, List, Optional, Tuple + +import matplotlib.pyplot as plt +from matplotlib.path import Path +from matplotlib.patches import PathPatch +import numpy as np + +from visualize_attention_data import get_attention_file_path, load_attention_map + + +AttentionEdge = Tuple[int, int, float] + + +def aggregate_chord_edges( + heads: Dict[int, List[AttentionEdge]], + aggregation: str = "mean", +) -> List[AttentionEdge]: + """ + Aggregate per-head attention edges for an all-head chord diagram. + + Args: + heads: Dict mapping head_idx -> list of (res_i, res_j, weight). + aggregation: "mean" or "sum" over weights for each residue pair. 
+ + Returns: + List of aggregated (res_i, res_j, weight), sorted by descending weight. + """ + edge_to_weights = defaultdict(list) + for conns in heads.values(): + for res_i, res_j, weight in conns: + edge_to_weights[(int(res_i), int(res_j))].append(float(weight)) + + aggregated = [] + for (res_i, res_j), weights in edge_to_weights.items(): + if aggregation == "sum": + value = np.sum(weights) + else: + value = np.mean(weights) + aggregated.append((res_i, res_j, float(value))) + + aggregated.sort(key=lambda x: x[2], reverse=True) + return aggregated + + +def _residue_positions(n_residues: int) -> np.ndarray: + angles = np.linspace(0, 2 * np.pi, n_residues, endpoint=False) + (np.pi / 2) + return np.column_stack([np.cos(angles), np.sin(angles)]) + + +def _draw_chord(ax, p1, p2, weight, w_min, w_max, color): + if w_max <= w_min: + norm_weight = 0.5 + else: + norm_weight = (weight - w_min) / (w_max - w_min) + + path = Path( + [p1, (0.0, 0.0), p2], + [Path.MOVETO, Path.CURVE3, Path.CURVE3], + ) + patch = PathPatch( + path, + facecolor="none", + edgecolor=color, + linewidth=0.4 + 3.0 * norm_weight, + alpha=0.25 + 0.55 * norm_weight, + zorder=1, + ) + ax.add_patch(patch) + + +def plot_chord_diagram( + connections: List[AttentionEdge], + n_residues: int, + residue_sequence: Optional[str], + layer_idx: int, + protein: str, + output_path: str, + head_idx: Optional[int] = None, + title: Optional[str] = None, + highlight_residue_index: Optional[int] = None, + max_edges: Optional[int] = 80, + figsize: Tuple[float, float] = (9, 9), + show_plot: bool = False, +) -> str: + """ + Draw one circular chord diagram for a set of attention connections. + + Args: + connections: List of (res_i, res_j, weight) attention edges. + n_residues: Number of residues in the sequence. + residue_sequence: Optional sequence labels for small proteins. + layer_idx: Attention layer index. + protein: Protein name for plot title. + output_path: PNG output path. + head_idx: Optional head index for titles. 
+ title: Optional explicit title. + highlight_residue_index: Optional residue index to highlight. + max_edges: Cap edges drawn by descending weight. + figsize: Matplotlib figure size. + show_plot: If True, call plt.show(). + + Returns: + Path to saved PNG. + """ + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + if max_edges is not None: + connections = sorted(connections, key=lambda x: x[2], reverse=True)[:max_edges] + + pos = _residue_positions(n_residues) + fig, ax = plt.subplots(figsize=figsize) + + circle = plt.Circle((0, 0), 1.0, fill=False, color="gray", linewidth=1.0, alpha=0.7) + ax.add_patch(circle) + + if connections: + weights = [weight for _, _, weight in connections] + w_min, w_max = min(weights), max(weights) + for res_i, res_j, weight in connections: + if 0 <= res_i < n_residues and 0 <= res_j < n_residues: + _draw_chord( + ax, + pos[res_i], + pos[res_j], + weight, + w_min, + w_max, + color="royalblue", + ) + + node_colors = ["lightgray"] * n_residues + if highlight_residue_index is not None and 0 <= highlight_residue_index < n_residues: + node_colors[highlight_residue_index] = "coral" + + ax.scatter( + pos[:, 0], + pos[:, 1], + s=24, + c=node_colors, + edgecolors="dimgray", + linewidths=0.5, + zorder=2, + ) + + show_labels = residue_sequence is not None and len(residue_sequence) == n_residues and n_residues <= 90 + if show_labels: + for idx, (x, y) in enumerate(pos): + label_x, label_y = 1.08 * x, 1.08 * y + ax.text( + label_x, + label_y, + residue_sequence[idx], + ha="center", + va="center", + fontsize=5, + color="darkred" if idx == highlight_residue_index else "black", + ) + + if title is None: + head_label = f"Head {head_idx}" if head_idx is not None else "Aggregated Heads" + title = f"{protein} - Layer {layer_idx} - {head_label} chord diagram" + + ax.set_title(title, fontsize=12, weight="bold") + ax.set_aspect("equal") + ax.axis("off") + ax.set_xlim(-1.25, 1.25) + ax.set_ylim(-1.25, 1.25) + plt.tight_layout() + 
plt.savefig(output_path, dpi=200, bbox_inches="tight") + + if show_plot: + plt.show() + else: + plt.close() + + print(f"[Saved] Chord diagram: {output_path}") + return output_path + + +def plot_chord_diagrams_per_head( + heads: Dict[int, List[AttentionEdge]], + n_residues: int, + residue_sequence: Optional[str], + layer_idx: int, + protein: str, + output_dir: str, + max_edges_per_head: int = 80, + show_plot: bool = False, +) -> List[str]: + """Save one chord diagram per attention head.""" + os.makedirs(output_dir, exist_ok=True) + saved_paths = [] + + for head_idx, connections in sorted(heads.items()): + output_path = os.path.join( + output_dir, + f"chord_head_{head_idx}_layer_{layer_idx}_{protein}.png", + ) + saved_paths.append( + plot_chord_diagram( + connections, + n_residues=n_residues, + residue_sequence=residue_sequence, + layer_idx=layer_idx, + protein=protein, + output_path=output_path, + head_idx=head_idx, + max_edges=max_edges_per_head, + show_plot=show_plot, + ) + ) + + return saved_paths + + +def plot_chord_diagram_grid( + heads: Dict[int, List[AttentionEdge]], + n_residues: int, + layer_idx: int, + protein: str, + output_dir: str, + max_edges_per_head: int = 50, + cols: int = 4, + figsize_per_subplot: float = 3.0, + show_plot: bool = False, +) -> List[str]: + """Save a small-multiples grid of chord diagrams, one subplot per head.""" + os.makedirs(output_dir, exist_ok=True) + head_indices = sorted(heads.keys()) + if not head_indices: + return [] + + rows = (len(head_indices) + cols - 1) // cols + fig, axes = plt.subplots( + rows, + cols, + figsize=(cols * figsize_per_subplot, rows * figsize_per_subplot), + squeeze=False, + ) + axes_flat = axes.flatten() + pos = _residue_positions(n_residues) + + for idx, head_idx in enumerate(head_indices): + ax = axes_flat[idx] + conns = sorted(heads[head_idx], key=lambda x: x[2], reverse=True)[:max_edges_per_head] + ax.add_patch(plt.Circle((0, 0), 1.0, fill=False, color="gray", linewidth=0.7, alpha=0.6)) + if conns: + 
weights = [weight for _, _, weight in conns] + w_min, w_max = min(weights), max(weights) + for res_i, res_j, weight in conns: + if 0 <= res_i < n_residues and 0 <= res_j < n_residues: + _draw_chord(ax, pos[res_i], pos[res_j], weight, w_min, w_max, "royalblue") + ax.scatter(pos[:, 0], pos[:, 1], s=6, c="lightgray", edgecolors="gray", linewidths=0.2, zorder=2) + ax.set_title(f"Head {head_idx}", fontsize=9) + ax.set_aspect("equal") + ax.axis("off") + ax.set_xlim(-1.15, 1.15) + ax.set_ylim(-1.15, 1.15) + + for idx in range(len(head_indices), len(axes_flat)): + axes_flat[idx].set_visible(False) + + fig.suptitle(f"{protein} - Layer {layer_idx} - Chord diagrams per head", fontsize=12, weight="bold", y=1.02) + plt.tight_layout() + output_path = os.path.join(output_dir, f"chord_heads_layer_{layer_idx}_{protein}_grid.png") + plt.savefig(output_path, dpi=200, bbox_inches="tight") + + if show_plot: + plt.show() + else: + plt.close() + + print(f"[Saved] Chord grid: {output_path}") + return [output_path] + + +def generate_chord_diagrams( + attention_dir: str, + residue_sequence: str, + output_dir: str, + protein: str, + attention_type: str = "msa_row", + residue_indices: Optional[List[int]] = None, + top_k: int = 50, + layer_idx: int = 47, + save_individual: bool = True, + save_grid: bool = True, + save_aggregated: bool = True, +) -> List[str]: + """ + Generate chord diagrams from saved attention text files. + + For MSA row attention, this can save per-head diagrams, a per-head grid, and + an aggregated mean diagram. For triangle-start attention, diagrams are + generated separately for each requested residue index. 
+ """ + os.makedirs(output_dir, exist_ok=True) + n_residues = len(residue_sequence) + saved_paths = [] + + if attention_type == "msa_row": + file_path = get_attention_file_path(attention_dir, attention_type, layer_idx) + heads = load_attention_map(file_path, top_k=top_k) + + if save_individual: + saved_paths.extend( + plot_chord_diagrams_per_head( + heads, + n_residues, + residue_sequence, + layer_idx, + protein, + output_dir, + max_edges_per_head=top_k, + ) + ) + if save_grid: + saved_paths.extend( + plot_chord_diagram_grid( + heads, + n_residues, + layer_idx, + protein, + output_dir, + max_edges_per_head=top_k, + ) + ) + if save_aggregated: + aggregated = aggregate_chord_edges(heads, aggregation="mean") + output_path = os.path.join(output_dir, f"chord_aggregated_layer_{layer_idx}_{protein}.png") + saved_paths.append( + plot_chord_diagram( + aggregated, + n_residues, + residue_sequence, + layer_idx, + protein, + output_path, + title=f"{protein} - Layer {layer_idx} - Mean attention chord diagram", + max_edges=top_k, + ) + ) + + elif attention_type == "triangle_start": + if residue_indices is None: + raise ValueError("residue_indices is required for triangle_start attention") + + for res_idx in residue_indices: + file_path = get_attention_file_path( + attention_dir, attention_type, layer_idx, residue_idx=res_idx + ) + if not os.path.exists(file_path): + print(f"[Warning] Missing file for residue {res_idx}") + continue + heads = load_attention_map(file_path, top_k=top_k) + for head_idx, connections in sorted(heads.items()): + output_path = os.path.join( + output_dir, + f"chord_tri_start_res_{res_idx}_head_{head_idx}_layer_{layer_idx}_{protein}.png", + ) + saved_paths.append( + plot_chord_diagram( + connections, + n_residues, + residue_sequence, + layer_idx, + protein, + output_path, + head_idx=head_idx, + highlight_residue_index=res_idx, + max_edges=top_k, + ) + ) + else: + raise ValueError(f"Unsupported attention_type: {attention_type}") + + return saved_paths