Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
19df29c
adding graphing code from openfold output
Mar 13, 2026
77a2d02
adding in evoformer hook instrumentation and plotting utils
Mar 23, 2026
afa4819
Create run_evoformer_hook_pretrained_openfold.py
PranavNarala1 Apr 13, 2026
9617bff
Add Flask web interface, visualization script, and updated demo noteb…
Apr 25, 2026
aaec509
Merge pull request #1 from SruthiVangavolu7/feature/issue8sruthi
prathampatel21 Apr 27, 2026
f464381
Add viz plot module: residue-indexed heatmap and line plot for Issue #8
prathampatel21 Apr 27, 2026
fa944c6
Extend viz module: heatmap grid, multi-line, layer trajectory, histogram
prathampatel21 Apr 27, 2026
befec41
Add representation tensor processing utilities
priyavisingh Apr 27, 2026
7d48c44
Add tests for representation tensor utilities
priyavisingh Apr 27, 2026
b64a870
Merge pull request #2 from priyavisingh/priyavi-tensor-processing
priyavisingh Apr 27, 2026
4d2b998
Remove headers for consistency
priyavisingh Apr 27, 2026
2d129e0
Merge pull request #3 from priyavisingh/priyavi-tensor-processing
priyavisingh Apr 27, 2026
d3d21db
Merge remote-tracking branch 'origin/main' into prathamheatmap
prathampatel21 Apr 28, 2026
2cec8ab
Add Evoformer extraction hooks and inspection utilities
Apr 28, 2026
4f10378
adding in new cell for testing evoformer hook extraction
Apr 28, 2026
a76a7af
Merge pull request #1 from PranavNarala1/final-checkin
PranavNarala1 Apr 28, 2026
29ca8b3
Integrate viz with Priyavi's representation_tensor_utils
prathampatel21 Apr 28, 2026
40ac71a
Normalize cell key ordering in viz demo notebook
prathampatel21 Apr 28, 2026
a0c3b73
Add integration tests for viz <-> representation_tensor_utils bridge
prathampatel21 Apr 28, 2026
4b569c4
Wire viz heatmaps into the Flask UI via render_for_ui.py
prathampatel21 Apr 28, 2026
a246a85
Add Evoformer extraction hooks and inspection utilities
Apr 28, 2026
bc0f394
adding in new cell for testing evoformer hook extraction
Apr 28, 2026
ac98a6e
Merge branch 'PranavNarala1-main'
Apr 28, 2026
24cad22
Merge branch 'main' of https://github.com/PranavNarala1/attention-viz…
Apr 28, 2026
16b4f67
Resolve notebook conflict against priyavi main
Apr 28, 2026
62ba170
Merge pull request #4 from PranavNarala1/main
PranavNarala1 Apr 28, 2026
63540c1
Merge remote-tracking branch 'origin/main' into prathamheatmap
prathampatel21 Apr 28, 2026
4baeffd
Add EvoformerRunArtifact -> Figure bridge
prathampatel21 Apr 28, 2026
d1a7fce
Merge pull request #5 from priyavisingh/prathamheatmap
prathampatel21 Apr 28, 2026
819df6a
fix(viz): moved web interface into viz folder and added security check
SreeDan Apr 28, 2026
876e574
fix: add edge case handling and quality of life improvements
SreeDan Apr 28, 2026
c921415
refactor(viz)
SreeDan Apr 28, 2026
6e524cb
feat(viz): added requirements
SreeDan Apr 28, 2026
b9c161a
update readme
SreeDan Apr 28, 2026
4b4de7d
refactor(viz): moved evoformer inspect to viz folder
SreeDan Apr 29, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,28 @@ cutlass/
*.sto
*.a3m
*.hhr

outputs/
.ipynb_checkpoints/
*.so
*.pt
*.pkl

outputs/
.ipynb_checkpoints/
*.so
*.pt
*.pkl

outputs/
.ipynb_checkpoints/
*.so
*.pt
*.pkl

outputs/
.ipynb_checkpoints/
*.so
*.pt
*.pkl
*.bak
98 changes: 98 additions & 0 deletions notebooks/intermediate_rep_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os
import numpy as np
import matplotlib.pyplot as plt

# ---- Config ----
PROT = "6KWC"
TRI_RESIDUE_IDX = 18
ATTN_MAP_DIR = f"./outputs/attention_files_{PROT}_demo_tri_{TRI_RESIDUE_IDX}"
IMAGE_OUTPUT_DIR = f"./outputs/attention_images_{PROT}_demo_tri_{TRI_RESIDUE_IDX}"
os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)

# ---- Helper: parse attention txt file ----
def parse_attention_file(filepath):
scores = []
with open(filepath, 'r') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 3:
try:
score = float(parts[2])
i = int(parts[0])
j = int(parts[1])
scores.append((i, j, score))
except:
continue
return scores

# ---- Helper: build matrix from scores ----
def build_matrix(scores, size=200):
matrix = np.zeros((size, size))
for i, j, score in scores:
if i < size and j < size:
matrix[i, j] = score
return matrix

# ---- Plot 1: Heatmap for a single layer ----
def plot_heatmap(layer_idx, attention_type="msa_row"):
if attention_type == "msa_row":
fname = f"msa_row_attn_layer{layer_idx}.txt"
else:
fname = f"triangle_start_attn_layer{layer_idx}_residue_idx_{TRI_RESIDUE_IDX}.txt"

filepath = os.path.join(ATTN_MAP_DIR, fname)
scores = parse_attention_file(filepath)
matrix = build_matrix(scores)

plt.figure(figsize=(8, 6))
plt.imshow(matrix, cmap='viridis', aspect='auto')
plt.colorbar(label='Attention Score')
plt.title(f'{attention_type} Attention - Layer {layer_idx}')
plt.xlabel('Residue j')
plt.ylabel('Residue i')
out_path = os.path.join(IMAGE_OUTPUT_DIR, f'heatmap_{attention_type}_layer{layer_idx}.png')
plt.savefig(out_path, dpi=150, bbox_inches='tight')
plt.close()
print(f"Saved heatmap to {out_path}")

# ---- Plot 2: Line plot of top attention scores across layers ----
def plot_scores_across_layers(attention_type="msa_row", num_layers=48):
top_scores = []
for layer_idx in range(num_layers):
if attention_type == "msa_row":
fname = f"msa_row_attn_layer{layer_idx}.txt"
else:
fname = f"triangle_start_attn_layer{layer_idx}_residue_idx_{TRI_RESIDUE_IDX}.txt"
filepath = os.path.join(ATTN_MAP_DIR, fname)
if os.path.exists(filepath):
scores = parse_attention_file(filepath)
if scores:
top_scores.append(max(s[2] for s in scores))
else:
top_scores.append(0)
else:
top_scores.append(0)

plt.figure(figsize=(12, 5))
plt.plot(range(num_layers), top_scores, marker='o', linewidth=2)
plt.title(f'Top Attention Score Across Layers - {attention_type}')
plt.xlabel('Layer')
plt.ylabel('Top Attention Score')
plt.grid(True)
out_path = os.path.join(IMAGE_OUTPUT_DIR, f'lineplot_{attention_type}_across_layers.png')
plt.savefig(out_path, dpi=150, bbox_inches='tight')
plt.close()
print(f"Saved line plot to {out_path}")

# ---- Run everything ----
if __name__ == "__main__":
# Generate heatmaps for a few layers
for layer in range(48):
plot_heatmap(layer, attention_type="msa_row")
plot_heatmap(layer, attention_type="triangle_start")

# Generate line plots across all 48 layers
plot_scores_across_layers(attention_type="msa_row")
plot_scores_across_layers(attention_type="triangle_start")

print("All done! Check", IMAGE_OUTPUT_DIR)
Loading