This repository builds on code from the original MIB repository. It modifies that code to add new alignment maps and adds an experimental pipeline for analyzing different alignment maps.
This guide explains how to reproduce the RevNet, Rotation+RevNet, and Rotation+Linear experiment results, as well as how to generate the corresponding visualizations and analyses.
To run the RevNet experiment:
sh experiment_RevNet.sh
To modify the intervention size, edit line 305 of the configuration file:
"n_features": 1,
Change it to:
"n_features": 2,
or
"n_features": 16,
After running the experiment, open and execute the corresponding Jupyter notebook to generate the visualizations:
ModelAnalysis_RevNet.ipynb
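If you want to run all three intervention sizes, editing the value by hand each time gets tedious. The following is a minimal Python sketch of a hypothetical helper, set_n_features, that rewrites the "n_features" entry in the configuration text. The configuration file's path is not given in this guide, so reading and writing the actual file is left to you. The snippet in the example is made up and only stands in for the real file.

```python
import re

# Hypothetical helper: matches "n_features", a colon, and an integer value.
_N_FEATURES = re.compile(r'("n_features"\s*:\s*)\d+')


def set_n_features(config_text: str, n: int) -> str:
    """Return config_text with every "n_features" value replaced by n."""
    new_text, count = _N_FEATURES.subn(lambda m: f"{m.group(1)}{n}", config_text)
    if count == 0:
        raise ValueError('no "n_features" entry found in the configuration text')
    return new_text


if __name__ == "__main__":
    # Made-up configuration snippet standing in for the real file.
    example = '{\n    "n_features": 1,\n    "other_setting": 0\n}'
    for size in (1, 2, 16):
        # Prints the edited "n_features" line for each intervention size.
        print(set_n_features(example, size).splitlines()[1].strip())
```

The regular expression only replaces the digits after "n_features":, so the trailing comma and the rest of the line stay as they are.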
To run the Rotation+RevNet experiment:
sh experiment_RotationRevNet.sh
Afterwards, run the corresponding Jupyter notebook to produce the plots and figures:
ModelAnalysis_RotationRevNet.ipynb
To run the Rotation+Linear experiment:
sh experiment_RotationLinear.sh
Then run:
sh experiment_RotationLinear_Result_Difference_visualization_true.sh
- Determine the range of the intervention (x) values from step 1.
- Set the first two numbers in experiment_RotationLinear_Result_Difference_visualization_X.sh to that range.
- Run the script to generate RDA results for random interventions within that range.
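The range step above amounts to taking the minimum and maximum of the true intervention values and then drawing random values between them. The sketch below assumes the script samples uniformly from that interval and that the two numbers you edit are its lower and upper bounds; neither detail is stated in this guide. The function names, the optional margin, and the example values are hypothetical placeholders, not results from this repository.

```python
import random


def intervention_range(true_values, margin=0.0):
    """Return (low, high) covering the observed true intervention values.

    margin widens the interval on each side by that fraction of its width.
    """
    low, high = min(true_values), max(true_values)
    pad = (high - low) * margin
    return low - pad, high + pad


def sample_interventions(low, high, n, seed=0):
    """Draw n intervention values uniformly from [low, high], seeded for reproducibility."""
    rng = random.Random(seed)
    return [rng.uniform(low, high) for _ in range(n)]


if __name__ == "__main__":
    true_x = [-1.5, 0.2, 2.5]  # placeholder values, not results from this repository
    low, high = intervention_range(true_x)
    print(low, high)  # candidate first two numbers for the _X.sh script (assumption)
    print(sample_interventions(low, high, 5))
```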
- From step 1, identify the region where the true interventions fall in rotation space.
- Choose ranges for the intervention (x) and second latent (y) dimensions.
- In experiment_RotationLinear_Result_Difference_visualization_XY_Rotation.sh, set the first two numbers to the chosen x-range and the next two to the chosen y-range.
- Run the script to analyze behavior in rotation space.
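The rotation-space steps can be sketched as follows, under the assumption (borrowed from distributed alignment search, not confirmed by this guide) that a hidden vector h is mapped into rotation space as h @ R for a learned orthogonal matrix R, with the first rotated coordinate being the intervention (x) and the second the second latent (y). The matrix, the hidden vectors, and all function names are hypothetical; in practice R would come from the trained alignment map. Whether the XY script samples randomly or on a grid is also not stated, so uniform sampling is assumed.

```python
import random


def to_rotation_space(h, R):
    """Return h @ R, with R given as a list of rows (assumed orthogonal)."""
    return [sum(h[i] * R[i][j] for i in range(len(h))) for j in range(len(R[0]))]


def xy_box(points, margin=0.0):
    """Bounding box ((x_lo, x_hi), (y_lo, y_hi)) of the first two rotated coordinates."""
    box = []
    for k in (0, 1):
        vals = [p[k] for p in points]
        lo, hi = min(vals), max(vals)
        pad = (hi - lo) * margin
        box.append((lo - pad, hi + pad))
    return tuple(box)


def sample_xy(x_range, y_range, n, seed=0):
    """Draw n (x, y) points uniformly from the rectangle spanned by x_range and y_range."""
    rng = random.Random(seed)
    return [(rng.uniform(*x_range), rng.uniform(*y_range)) for _ in range(n)]


if __name__ == "__main__":
    R = [[0.0, 1.0], [-1.0, 0.0]]  # hypothetical 90-degree rotation
    hidden = [[1.0, 0.0], [0.0, 2.0], [-1.0, 1.0]]  # placeholder hidden vectors
    rotated = [to_rotation_space(h, R) for h in hidden]
    x_range, y_range = xy_box(rotated)
    print(x_range, y_range)  # candidate x and y numbers for the _XY_Rotation.sh script
    print(sample_xy(x_range, y_range, 3))
```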
For the logit difference analysis, first run:
sh experiment_RotationLinear_Logit_Difference_visualization_true.sh
- Based on step 1, identify the range of intervention values in the dataset.
- Set the first two numbers (the x-range) in experiment_RotationLinear_Logit_Difference_visualization_X.sh to that range.
- Run the script to perform the logit difference analysis with random interventions.
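A logit difference is commonly defined as the logit of one token minus the logit of another at the answer position. The sketch below assumes the analysis compares a target token (for example, the answer expected after the intervention) against a baseline token (for example, the model's original answer); which tokens the script actually compares is an assumption, and the function names are hypothetical.

```python
def logit_difference(logits, target_id, baseline_id):
    """Logit of the target token minus the logit of the baseline token."""
    return logits[target_id] - logits[baseline_id]


def mean_logit_difference(batch_logits, target_ids, baseline_ids):
    """Average logit difference over a batch of (logits, target, baseline) triples."""
    diffs = [
        logit_difference(logits, t, b)
        for logits, t, b in zip(batch_logits, target_ids, baseline_ids)
    ]
    return sum(diffs) / len(diffs)


if __name__ == "__main__":
    # Two toy examples over a two-token vocabulary.
    print(mean_logit_difference([[0.0, 3.0], [1.0, 0.0]], [1, 0], [0, 1]))  # 2.0
```

A positive value means the intervention pushes the model toward the target token more than toward the baseline.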
To experiment with an algorithm that computes and uses the Carry-One value in a number system other than base 10 (e.g., binary), change line 16 of
tasks/two_digit_addition_task/arithmetic.py
For example, for binary, change
MNS=10
to
MNS=2
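How arithmetic.py uses MNS is not shown in this guide. As a sketch, assuming the operands are written in base MNS, the Carry-One value of a two-digit addition is 1 exactly when the ones digits sum to MNS or more. The function name carry_one is hypothetical, and the MNS constant below only mirrors the setting in arithmetic.py.

```python
MNS = 2  # stand-in for the MNS setting in arithmetic.py (here: binary)


def carry_one(a: int, b: int, base: int = MNS) -> int:
    """Return the carry (0 or 1) from adding the ones digits of a and b in `base`."""
    ones_a = a % base  # ones digit of a in `base`
    ones_b = b % base  # ones digit of b in `base`
    # Each ones digit is at most base - 1, so the sum is below 2 * base
    # and integer division by `base` yields 0 or 1.
    return (ones_a + ones_b) // base


if __name__ == "__main__":
    print(carry_one(27, 15, base=10))  # 7 + 5 = 12 -> carry 1
    print(carry_one(0b11, 0b01))       # binary 11 + 01: 1 + 1 = 10 -> carry 1
    print(carry_one(0b10, 0b01))       # binary 10 + 01: 0 + 1 = 1 -> carry 0
```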