The Non-Linear Representation Dilemma: Is Causal Abstraction Enough for Mechanistic Interpretability?
Note that the distributed alignment search (DAS) implemented in this repository is a reimplementation of the DAS proposed in "Finding Alignments Between Interpretable Causal Variables and Distributed Neural Representations" (Geiger et al., 2024b).
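For readers unfamiliar with DAS, here is a minimal NumPy sketch of the interchange intervention at its core: rotate hidden states into a learned basis, swap a subspace between a base and a source representation, and rotate back. All names here are illustrative, not this repository's API:

```python
import numpy as np

def interchange_intervention(h_base, h_source, R, k):
    """Swap the first k rotated coordinates of h_base with h_source's.

    R is an orthogonal matrix (standing in for the learned DAS rotation);
    k is the dimensionality of the aligned subspace. Illustrative only.
    """
    z_base, z_source = h_base @ R, h_source @ R  # rotate into the aligned basis
    z_base[:k] = z_source[:k]                    # interchange the subspace
    return z_base @ R.T                          # rotate back to neuron space

rng = np.random.default_rng(0)
R, _ = np.linalg.qr(rng.normal(size=(4, 4)))     # random orthogonal "rotation"
h_b, h_s = rng.normal(size=4), rng.normal(size=4)
h_new = interchange_intervention(h_b, h_s, R, k=2)
```

In actual DAS training, `R` is optimized so that interventions in the rotated subspace match the counterfactual behavior predicted by the causal model.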
The required libraries can be installed with:
pip install -r requirements.txt
To reproduce our experiments, first generate the results for the Hierarchical Equality and Distributive Law Setting Experiments and the Indirect Object Identification Experiments, then run the plotting scripts (Plot Generation).
The scripts (in the notebook folder) used for these settings are:
- Standard_DAS.py: Script used to create the results of applying the standard DAS approach as proposed in "Finding Alignments Between Interpretable Causal Variables and Distributed Neural Representations" (Geiger et al., 2024b).
- Standard_DAS_DAS_fitted.py: Similar to "Standard_DAS.py", but the MLP model is not only trained to predict the task output: it is additionally biased toward either the And-Or-And or the And-Or algorithm in the Distributive Law Setting.
- Hidden_Size_Progression.py: Script used to create the results of applying DAS with RevNets of different hidden layer sizes.
- Training_Progression.py: Script used to create the results of DAS over the course of MLP training.
- Standard_DAS_Neurons.py: Script used to create the results when using greedy search over neurons for interventions instead of DAS.
- Standard_DAS_Injectivity.py: Script used to create the results for the injectivity experiments mentioned in the Appendix.
Note that the script script_commands.sh can be used to evaluate the different algorithms (with corresponding tasks) and bijective functions ϕ. We do not recommend running this script as is, since it would take far too long; instead, split it into multiple parts and run them in parallel on multiple machines. We also provide a zip file, "Results.zip", containing all results generated by our experiments (script_commands.sh) in these two settings.
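One way to split the command file, assuming GNU coreutils (`split -n l/K` splits a file by lines into K parts). The toy commands.txt below is a stand-in; for the real run, apply the same `split` call to script_commands.sh:

```shell
# Toy stand-in: a file of 8 one-line commands, split into 4 equal parts.
printf 'echo job %d\n' 1 2 3 4 5 6 7 8 > commands.txt
split -n l/4 commands.txt part_   # GNU coreutils: 4 files, lines kept intact
ls part_*                         # part_aa part_ab part_ac part_ad
# Then run each part on its own machine, e.g.: bash part_aa
```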
The main training code for DAS on the Pythia suite is implemented in scripts/das_llm.py. To reproduce our Pythia experiments, we provide several bash scripts in the run folder. Training automatically logs to wandb.ai under the CausalAbstractionLLM project. You can change the entity, project, or run name via command-line arguments; run python scripts/das_llm.py --help for details.
For training progression experiments on Pythia-410M, run:
bash run/pythia410m.sh
For the experiments with different Pythia sizes, run:
bash run/pythia_size.sh
For the different seeds on Pythia-410M, run:
bash run/pythia_seeds.sh
These scripts will populate the results folder with their outputs. Once the results are generated, you can run the plotting scripts to visualize the data.
The plotting scripts automatically generate all plots from the paper and export them to the notebooks/plots folder.
# Hierarchical Equality and Distributive Law Setting Experiments
python notebooks/plot_mlp_hidden_size.py # Hidden Size plots for all MLP experiments
python notebooks/plot_mlp_training_progression.py # Training progression plot for all MLP experiments
# IOI Experiments
python notebooks/plot_pythia410m.py # Training progression plot for pythia 410m
python notebooks/plot_pythiasizes.py # Plot for multiple pythia sizes
python notebooks/plot_pythia410m_seeds.py # Plot for multiple 410m seeds

The das folder contains the main code for our experiments:
- DAS.py: Includes the code for the general DAS class facilitating the DAS experiments
- DAS_MLP.py: Includes the specific code of the DAS class applied to an MLP.
- DAS_LLM.py: Includes the specific code of the DAS class applied to an LLM.
- Dataset_Generation.py: Code for generating the datasets used in the experiments.
- Helper_Functions.py: General helper code
- Classification_Model.py: Includes the code for the MLP model used for the Hierarchical Equality and Distributive Law Setting Experiments.
- LLM_Model.py: Code required to load the LLM used for the Indirect Object Identification task.
- Rotation_Model.py: Rotation and its inverse
- RevNet.py: RevNet and its inverse
- plotting.py: Plotting utilities.
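To illustrate why a RevNet is exactly invertible, which is what lets it serve as the bijective function ϕ, here is a minimal additive-coupling sketch. The class and function names are illustrative, not the API of this repository's RevNet.py; the real model stacks several such blocks with learned networks in place of `f` and `g`:

```python
import numpy as np

class TinyRevNetBlock:
    """Additive-coupling block: invertible by construction, even when
    f and g themselves are not invertible."""

    def __init__(self, f, g):
        self.f, self.g = f, g  # arbitrary elementwise functions

    def forward(self, x1, x2):
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return y1, y2

    def inverse(self, y1, y2):
        # Undo the coupling in reverse order with subtractions.
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return x1, x2

block = TinyRevNetBlock(np.tanh, np.sin)
x1, x2 = np.array([0.3, -1.2]), np.array([0.7, 2.0])
r1, r2 = block.inverse(*block.forward(x1, x2))  # recovers (x1, x2) exactly
```

The inverse follows mechanically from the forward pass, which is why growing the hidden size of `f` and `g` changes expressivity but never breaks bijectivity.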