Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,28 @@ cutlass/
*.sto
*.a3m
*.hhr

outputs/
.ipynb_checkpoints/
*.so
*.pt
*.pkl

outputs/
.ipynb_checkpoints/
*.so
*.pt
*.pkl

outputs/
.ipynb_checkpoints/
*.so
*.pt
*.pkl

outputs/
.ipynb_checkpoints/
*.so
*.pt
*.pkl
*.bak
64 changes: 64 additions & 0 deletions inspect_evoformer_reps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import argparse
from pathlib import Path

import torch


def summarize_tensor(name: str, value: torch.Tensor):
value_f = value.float()
print(
f"{name}: "
f"shape={tuple(value.shape)}, "
f"dtype={value.dtype}, "
f"mean={float(value_f.mean()):.5f}, "
f"std={float(value_f.std()):.5f}, "
f"min={float(value_f.min()):.5f}, "
f"max={float(value_f.max()):.5f}"
)


def compare_tensors(payload, key_a: str, key_b: str):
if key_a not in payload or key_b not in payload:
print(f"Skipping comparison {key_a} vs {key_b} (missing key)")
return

a = payload[key_a]
b = payload[key_b]

print(f"Comparing {key_a} vs {key_b}")
print(f"same_shape={a.shape == b.shape}")
if a.shape == b.shape:
print(f"allclose={torch.allclose(a, b)}")
print()


def main(path: str):
path = Path(path)

if not path.exists():
raise FileNotFoundError(f"File not found: {path}")

payload = torch.load(path, map_location="cpu")

print(f"Loaded: {path}")
print(f"Number of keys: {len(payload)}")
print()

for key in sorted(payload.keys()):
value = payload[key]
if isinstance(value, torch.Tensor):
summarize_tensor(key, value)
else:
print(f"{key}: type={type(value)}")

print()
print("Layer comparison check:")
compare_tensors(payload, "layer_00.msa", "layer_47.msa")
compare_tensors(payload, "layer_00.pair", "layer_47.pair")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("path", type=str, help="Path to saved evoformer reps .pt file")
args = parser.parse_args()
main(args.path)
98 changes: 98 additions & 0 deletions notebooks/intermediate_rep_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os
import numpy as np
import matplotlib.pyplot as plt

# ---- Config ----
PROT = "6KWC"
TRI_RESIDUE_IDX = 18
ATTN_MAP_DIR = f"./outputs/attention_files_{PROT}_demo_tri_{TRI_RESIDUE_IDX}"
IMAGE_OUTPUT_DIR = f"./outputs/attention_images_{PROT}_demo_tri_{TRI_RESIDUE_IDX}"
os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)

# ---- Helper: parse attention txt file ----
def parse_attention_file(filepath):
scores = []
with open(filepath, 'r') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 3:
try:
score = float(parts[2])
i = int(parts[0])
j = int(parts[1])
scores.append((i, j, score))
except:
continue
return scores

# ---- Helper: build matrix from scores ----
def build_matrix(scores, size=200):
matrix = np.zeros((size, size))
for i, j, score in scores:
if i < size and j < size:
matrix[i, j] = score
return matrix

# ---- Plot 1: Heatmap for a single layer ----
def plot_heatmap(layer_idx, attention_type="msa_row"):
if attention_type == "msa_row":
fname = f"msa_row_attn_layer{layer_idx}.txt"
else:
fname = f"triangle_start_attn_layer{layer_idx}_residue_idx_{TRI_RESIDUE_IDX}.txt"

filepath = os.path.join(ATTN_MAP_DIR, fname)
scores = parse_attention_file(filepath)
matrix = build_matrix(scores)

plt.figure(figsize=(8, 6))
plt.imshow(matrix, cmap='viridis', aspect='auto')
plt.colorbar(label='Attention Score')
plt.title(f'{attention_type} Attention - Layer {layer_idx}')
plt.xlabel('Residue j')
plt.ylabel('Residue i')
out_path = os.path.join(IMAGE_OUTPUT_DIR, f'heatmap_{attention_type}_layer{layer_idx}.png')
plt.savefig(out_path, dpi=150, bbox_inches='tight')
plt.close()
print(f"Saved heatmap to {out_path}")

# ---- Plot 2: Line plot of top attention scores across layers ----
def plot_scores_across_layers(attention_type="msa_row", num_layers=48):
top_scores = []
for layer_idx in range(num_layers):
if attention_type == "msa_row":
fname = f"msa_row_attn_layer{layer_idx}.txt"
else:
fname = f"triangle_start_attn_layer{layer_idx}_residue_idx_{TRI_RESIDUE_IDX}.txt"
filepath = os.path.join(ATTN_MAP_DIR, fname)
if os.path.exists(filepath):
scores = parse_attention_file(filepath)
if scores:
top_scores.append(max(s[2] for s in scores))
else:
top_scores.append(0)
else:
top_scores.append(0)

plt.figure(figsize=(12, 5))
plt.plot(range(num_layers), top_scores, marker='o', linewidth=2)
plt.title(f'Top Attention Score Across Layers - {attention_type}')
plt.xlabel('Layer')
plt.ylabel('Top Attention Score')
plt.grid(True)
out_path = os.path.join(IMAGE_OUTPUT_DIR, f'lineplot_{attention_type}_across_layers.png')
plt.savefig(out_path, dpi=150, bbox_inches='tight')
plt.close()
print(f"Saved line plot to {out_path}")

# ---- Run everything ----
if __name__ == "__main__":
# Generate heatmaps for a few layers
for layer in range(48):
plot_heatmap(layer, attention_type="msa_row")
plot_heatmap(layer, attention_type="triangle_start")

# Generate line plots across all 48 layers
plot_scores_across_layers(attention_type="msa_row")
plot_scores_across_layers(attention_type="triangle_start")

print("All done! Check", IMAGE_OUTPUT_DIR)
4 changes: 1 addition & 3 deletions openfold/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,4 @@
from . import utils
from . import data
from . import np
from . import resources

__all__ = ["model", "utils", "np", "data", "resources"]
__all__ = ["model", "utils", "np", "data"]
188 changes: 188 additions & 0 deletions openfold/utils/evoformer_instrumentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch


@dataclass
class InstrumentationConfig:
layers: Optional[List[int]] = None
capture_msa: bool = True
capture_pair: bool = True
clone: bool = True
to_cpu: bool = True
dtype: Optional[torch.dtype] = torch.float32


class EvoformerRecorder:
def __init__(self, config: InstrumentationConfig):
self.config = config
self.records: Dict[str, Any] = {}
self.enabled = True

def clear(self):
self.records.clear()

def _process_tensor(self, x: torch.Tensor) -> torch.Tensor:
if self.config.clone:
x = x.clone()
x = x.detach()
if self.config.to_cpu:
x = x.cpu()
if self.config.dtype is not None:
x = x.to(self.config.dtype)
return x

def record(self, key: str, value: Any):
if not self.enabled:
return

if isinstance(value, torch.Tensor):
self.records[key] = self._process_tensor(value)
elif isinstance(value, (tuple, list)):
out = []
for item in value:
if isinstance(item, torch.Tensor):
out.append(self._process_tensor(item))
else:
out.append(item)
self.records[key] = out
elif isinstance(value, dict):
out = {}
for k, v in value.items():
if isinstance(v, torch.Tensor):
out[k] = self._process_tensor(v)
else:
out[k] = v
self.records[key] = out
else:
self.records[key] = value

def save(self, path: str | Path):
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
torch.save(self.records, path)

def summary(self) -> Dict[str, Any]:
s = {}
for k, v in self.records.items():
if isinstance(v, torch.Tensor):
s[k] = {
"shape": tuple(v.shape),
"dtype": str(v.dtype),
"mean": float(v.float().mean()),
"std": float(v.float().std()),
}
elif isinstance(v, list):
s[k] = {"type": "list", "len": len(v)}
elif isinstance(v, dict):
inner = {}
for kk, vv in v.items():
if isinstance(vv, torch.Tensor):
inner[kk] = {"shape": tuple(vv.shape), "dtype": str(vv.dtype)}
else:
inner[kk] = str(type(vv))
s[k] = inner
else:
s[k] = str(type(v))
return s


def find_evoformer_blocks(model):
candidates = [
"evoformer.blocks",
"evoformer.trunk.blocks",
"model.evoformer.blocks",
]
for name in candidates:
obj = model
ok = True
for part in name.split("."):
if not hasattr(obj, part):
ok = False
break
obj = getattr(obj, part)
if ok:
return obj

for name, module in model.named_modules():
if name.endswith("evoformer.blocks"):
return module

raise RuntimeError("Could not find Evoformer blocks in model")


def _selected(layer_idx: int, layers: Optional[List[int]]) -> bool:
return layers is None or layer_idx in layers


def attach_evoformer_block_output_hooks(model, recorder: EvoformerRecorder):
handles = []
blocks = find_evoformer_blocks(model)

for layer_idx, block in enumerate(blocks):
if not _selected(layer_idx, recorder.config.layers):
continue

def block_hook(module, inputs, output, layer_idx=layer_idx):
if not recorder.enabled:
return

if isinstance(output, (tuple, list)):
if recorder.config.capture_msa and len(output) > 0 and isinstance(output[0], torch.Tensor):
recorder.record(f"layer_{layer_idx:02d}.msa", output[0])
if recorder.config.capture_pair and len(output) > 1 and isinstance(output[1], torch.Tensor):
recorder.record(f"layer_{layer_idx:02d}.pair", output[1])
elif isinstance(output, dict):
if recorder.config.capture_msa:
for k, v in output.items():
if "msa" in str(k).lower() and isinstance(v, torch.Tensor):
recorder.record(f"layer_{layer_idx:02d}.msa", v)
break
if recorder.config.capture_pair:
for k, v in output.items():
if "pair" in str(k).lower() and isinstance(v, torch.Tensor):
recorder.record(f"layer_{layer_idx:02d}.pair", v)
break
elif isinstance(output, torch.Tensor):
recorder.record(f"layer_{layer_idx:02d}.output", output)

handles.append(block.register_forward_hook(block_hook))

return handles


def remove_hooks(handles):
for h in handles:
h.remove()


def attach_and_run_on_batch(model, batch: Dict[str, Any], layers=None, device: str = "cpu", out_path: str | Path | None = None):
model = model.to(device)
cfg = InstrumentationConfig(layers=layers, capture_msa=True, capture_pair=True)
recorder = EvoformerRecorder(cfg)
handles = attach_evoformer_block_output_hooks(model, recorder)

with torch.no_grad():
_ = model(batch)

remove_hooks(handles)

if out_path is not None:
recorder.save(out_path)

return recorder


def make_dummy_batch(n_res=64, n_seq=16, c_msa=49, device="cpu"):
return {
"target_feat": torch.randn(1, n_res, 22, device=device),
"residue_index": torch.arange(n_res, device=device).unsqueeze(0),
"msa_feat": torch.randn(1, n_seq, n_res, c_msa, device=device),
"seq_mask": torch.ones(1, n_res, device=device),
"msa_mask": torch.ones(1, n_seq, n_res, device=device),
"aatype": torch.zeros(1, n_res, dtype=torch.long, device=device),
}
Loading