Skip to content

Add Evoformer extraction hooks for intermediate MSA and pair representations#71

Open
PranavNarala1 wants to merge 18 commits into
AI2Science:mainfrom
PranavNarala1:main
Open

Add Evoformer extraction hooks for intermediate MSA and pair representations#71
PranavNarala1 wants to merge 18 commits into
AI2Science:mainfrom
PranavNarala1:main

Conversation

@PranavNarala1
Copy link
Copy Markdown

Summary

This PR adds Evoformer extraction support to the OpenFold inference workflow by introducing forward-hook-based instrumentation for selected Evoformer layers. During inference, the new extraction path captures intermediate MSA and pair representations, stores them in a structured dictionary, and saves them as a reusable artifact for downstream analysis and visualization.

What this PR adds

This PR adds the extraction layer for intermediate Evoformer representations. In particular, it introduces:

  • forward hooks for selected Evoformer blocks
  • capture of intermediate msa and pair tensors during inference
  • clean hook registration and removal
  • structured dictionary output keyed by layer and tensor type
  • a small inspection utility for validating saved extraction artifacts

Files added / updated

  • run_evoformer_hook_pretrained_openfold.py
    Adds Evoformer instrumentation flags and integrates hook-based extraction into the inference path.

  • openfold/utils/evoformer_instrumentation.py
    Contains the extraction logic for attaching hooks, recording tensors, and saving captured outputs.

  • openfold/utils/evoformer_run_artifact.py
    Provides utilities for working with saved extraction artifacts and downstream visualization workflows.

  • inspect_evoformer_reps.py
    Helper script for validating saved .pt extraction files by printing keys, shapes, and summary statistics.

  • openfold/utils/import_weights.py
    Includes the local compatibility fix needed for the current environment.

  • openfold/__init__.py
    Small import cleanup needed for this setup.

Output format

Captured intermediate outputs are saved in a dictionary with keys of the form:

  • layer_00.msa
  • layer_00.pair
  • layer_12.msa
  • layer_12.pair
  • layer_24.msa
  • layer_24.pair
  • layer_47.msa
  • layer_47.pair

This makes the output easy to consume for downstream tensor processing, visualization, and interface work.

How to test

Run inference with Evoformer extraction enabled:

python3 run_evoformer_hook_pretrained_openfold.py \
    ./examples/monomer/fasta_dir_6KWC \
    /storage/ice1/shared/d-pace_community/alphafold/alphafold_2.3.2_data/pdb_mmcif/mmcif_files \
    --use_precomputed_alignments ./examples/monomer/alignments \
    --output_dir ./outputs/my_outputs_align_6KWC_demo_tri_18 \
    --config_preset model_1_ptm \
    --jax_param_path /storage/ice1/shared/d-pace_community/alphafold/alphafold_2.3.2_data/params/params_model_1_ptm.npz \
    --uniref90_database_path /storage/ice1/shared/d-pace_community/alphafold/alphafold_2.3.2_data/uniref90/uniref90.fasta \
    --mgnify_database_path /storage/ice1/shared/d-pace_community/alphafold/alphafold_2.3.2_data/mgnify/mgy_clusters_2022_05.fa \
    --pdb70_database_path /storage/ice1/shared/d-pace_community/alphafold/alphafold_2.3.2_data/pdb70/pdb70 \
    --uniclust30_database_path /storage/ice1/shared/d-pace_community/alphafold/alphafold_2.3.2_data/uniclust30/uniclust30_2018_08 \
    --bfd_database_path /storage/ice1/shared/d-pace_community/alphafold/alphafold_2.3.2_data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
    --save_outputs \
    --skip_relaxation \
    --model_device cuda:0 \
    --attn_map_dir ./outputs/attention_files_6KWC_demo_tri_18 \
    --num_recycles_save 1 \
    --triangle_residue_idx 18 \
    --demo_attn \
    --instrument_evoformer \
    --instrument_layers 0,12,24,47 \
    --instrument_out_dir ./outputs/instrumentation

After the run, confirm that the extraction artifact exists:

  • ./outputs/instrumentation/6KWC_1_model_1_ptm_evoformer_reps.pt

Then inspect it with:

python inspect_evoformer_reps.py ./outputs/instrumentation/6KWC_1_model_1_ptm_evoformer_reps.pt

Expected behavior:

  • inference completes successfully
  • a .pt file is saved
  • the output contains layer-specific msa and pair tensors
  • keys and shapes are printed correctly by the inspection script

Validation performed

Tested on a real inference run with selected Evoformer layers 0, 12, 24, 47.

The saved artifact contained:

  • layer_00.msa
  • layer_00.pair
  • layer_12.msa
  • layer_12.pair
  • layer_24.msa
  • layer_24.pair
  • layer_47.msa
  • layer_47.pair

Observed tensor shapes:

  • msa: (516, 191, 256)
  • pair: (191, 191, 128)

Additional validation:

  • early vs. late msa layers had matching shapes but were not identical
  • early vs. late pair layers had matching shapes but were not identical

This confirmed that the hooks fired correctly, captured nontrivial intermediate tensors, and produced stable layer-specific outputs.

Limitations

  • Current extraction is based on Evoformer block outputs, not finer-grained submodule attention hooks.
  • The saved tensors can be large, so extraction is currently best used on selected layers rather than all layers at once.
  • This PR focuses on extraction only; downstream tensor processing, visualization, and interface features are handled separately.

Why this matters

Before this change, the workflow supported exported attention summaries, but not general intermediate Evoformer representation capture. This PR adds the extraction backbone needed for downstream visualization and analysis of internal model representations.

@sherrylicodes
Copy link
Copy Markdown

Really cool! Great for downstream visualization work. One suggestion is to include a small metadata object alongside the saved .pt artifact, like selected layers, tensor shapes, residue count, model/config preset, recycle index. Could make it easier for offline readers or visualization tools to validate the artifact before loading large tensors and to map msa or pair outputs into future UI views.

Also, since the tensors can be large, it could be good to document if the saved artifact supports partial loading or if downstream tools should convert it into a chunked format like Zarr.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants