Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions experiments/aggregate_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ def clear_memory():
# Run DAS (Direct Attribution with Subspace) method
if "DAS" in methods:
if verbose:
print("Running DAS method...")
print("Running DAS method with alignment map "+config["alignment_map"]+"...")

config["method_name"] = "DAS"
experiment = PatchResidualStream(pipeline, task, list(range(start, end)), token_positions, checker, config=config)
method_model_dir = os.path.join(model_dir, f"DAS_{pipeline.model.__class__.__name__}_{"-".join(target_variables)}")
method_model_dir = os.path.join(model_dir, f"DAS_{config['alignment_map']}_{pipeline.model.__class__.__name__}_{"-".join(target_variables)}")
experiment.train_interventions(train_data, target_variables, method="DAS", verbose=verbose, model_dir=method_model_dir)
raw_results = experiment.perform_interventions(test_data, verbose=verbose, target_variables_list=[target_variables], save_dir=results_dir)
heatmaps(experiment, raw_results, config)
Expand Down Expand Up @@ -353,4 +353,4 @@ def clear_memory():

# Release memory before next experiment
del experiment, raw_results
clear_memory()
clear_memory()
27 changes: 19 additions & 8 deletions experiments/intervention_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,14 +428,24 @@ def train_interventions(self, datasets, target_variables, method="DAS", model_di
for model_unit in model_units:
if method == "DAS":
# For DAS, use trainable subspace featurizer
model_unit.set_featurizer(
SubspaceFeaturizer(
shape=(model_unit.shape[0], self.config["n_features"]),
trainable=True,
id="DAS"
if self.config["alignment_map"]=="Rotation":
model_unit.set_featurizer(
SubspaceFeaturizer(
shape=(model_unit.shape[0], self.config["n_features"]),
trainable=True,
id="DAS_Rotation"
)
)
)
model_unit.set_feature_indices(None) # Use all features
model_unit.set_feature_indices(None) # Use all features
if self.config["alignment_map"]=="RevNet":
model_unit.set_featurizer(
SubspaceFeaturizerRevNet(
shape=(self.config["number_blocks"], model_unit.shape[0], self.config["hidden_size"]),
trainable=True,
id="DAS_RevNet"
)
)
model_unit.set_feature_indices(list(range(self.config["n_features"]))) # Use n_features first features

# Train the intervention
_train_intervention(self.pipeline, model_units_list, counterfactual_dataset,
Expand All @@ -445,4 +455,5 @@ def train_interventions(self, datasets, target_variables, method="DAS", model_di
if model_dir is not None:
self.save_featurizers([model_unit for model_units in model_units_list for model_unit in model_units], model_dir)

return self
return self

190 changes: 190 additions & 0 deletions neural/featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
import pyvene as pv


Expand Down Expand Up @@ -149,6 +150,13 @@ def save_modules(self, path: str) -> Tuple[str, str]:
additional_config["requires_grad"] = (
self.featurizer.rotate.weight.requires_grad
)

# Extra config needed for Subspace featurizers with RevNet
elif featurizer_class == "SubspaceFeaturizerModule_RevNet":
additional_config["architectureshape"] = self.featurizer.shape
additional_config["state_dict"] = self.featurizer.RevNet.state_dict()
additional_config["requires_grad"] = self.featurizer.trainable


model_info = {
"featurizer_class": featurizer_class,
Expand Down Expand Up @@ -205,6 +213,19 @@ def load_modules(cls, path: str) -> "Featurizer":
assert (
featurizer.rotate.weight.shape == rot.shape
), "Rotation-matrix shape mismatch after deserialisation."

elif featurizer_class == "SubspaceFeaturizerModule_RevNet":
trainable = model_info["additional_config"]["requires_grad"]

# Re-build the RevNet.
shape=model_info["additional_config"]["architectureshape"]
RN = RevNet(shape[0], shape[1], shape[2])
RN.load_state_dict(model_info["additional_config"]["state_dict"])
RN.requires_grad_(trainable)

featurizer = SubspaceFeaturizerModule_RevNet(RN,shape,trainable)
inverse = SubspaceInverseFeaturizerModule_RevNet(RN,shape,trainable)

elif featurizer_class == "IdentityFeaturizerModule":
featurizer = IdentityFeaturizerModule()
inverse = IdentityInverseFeaturizerModule()
Expand Down Expand Up @@ -388,6 +409,32 @@ def __init__(self, rotate_layer: pv.models.layers.LowRankRotateLayer):
def forward(self, f, error):
r = self.rotate.weight.T
return (f.to(r.dtype) @ r).to(f.dtype) + error.to(f.dtype)


class SubspaceFeaturizerModule_RevNet(torch.nn.Module):
    """Invertible featurizer that maps activations through a RevNet.

    Unlike the linear rotation featurizer, this feature map is an exact
    (nonlinear) bijection, so no residual "error" term is required: the
    second element of the returned tuple is always zero.

    Args:
        RevNet: module implementing ``forward`` and ``inverse`` (the RevNet).
        sh: architecture shape tuple, presumably
            ``(number_blocks, in_features, hidden_size)`` — stored so
            ``save_modules`` can serialize/rebuild the network.
        trainable: whether the RevNet parameters receive gradients; stored
            here for serialization only (gradients are toggled by the caller).
    """

    def __init__(self, RevNet, sh, trainable):
        super().__init__()
        self.RevNet = RevNet
        self.shape = sh
        self.trainable = trainable

    def forward(self, x: torch.Tensor):
        # Cast to float32 before the RevNet; the error component is zero
        # because the transform is exactly invertible.
        return self.RevNet(x.to(torch.float32)), torch.zeros_like(x)


class SubspaceInverseFeaturizerModule_RevNet(torch.nn.Module):
    """Inverse of :class:`SubspaceFeaturizerModule_RevNet`.

    Applies the RevNet's exact inverse to the features and adds back the
    error term produced by the forward featurizer (always zero for this
    featurizer, kept for interface compatibility with other featurizers).

    Args:
        RevNet: module implementing ``inverse`` (the same RevNet instance
            used by the forward featurizer).
        sh: architecture shape tuple, kept for serialization.
        trainable: whether the RevNet parameters receive gradients; stored
            for serialization only.
    """

    def __init__(self, RevNet, sh, trainable):
        super().__init__()
        self.RevNet = RevNet
        self.shape = sh
        self.trainable = trainable

    def forward(self, f, error):
        # RevNet.inverse exactly undoes the forward pass; cast back to the
        # caller's dtype before re-adding the error component.
        return self.RevNet.inverse(f).to(error.dtype) + error


class SubspaceFeaturizer(Featurizer):
Expand Down Expand Up @@ -421,6 +468,31 @@ def __init__(
n_features=rotate.weight.shape[1],
id=id,
)


class SubspaceFeaturizerRevNet(Featurizer):
    """RevNet sub-space featurizer.

    Builds a single RevNet shared between the forward and inverse
    featurizer modules, so the pair is an exact bijection by construction.

    Keyword Args:
        shape: ``(number_blocks, in_features, hidden_size)`` architecture
            tuple; required.
        trainable: whether the RevNet parameters receive gradients.
        id: identifier string forwarded to the base :class:`Featurizer`.
    """

    def __init__(
        self,
        *,
        shape: Tuple[int, int, int] | None = None,
        trainable: bool = True,
        id: str = "subspace",
    ):
        # The architecture shape is mandatory; there is no sensible default.
        assert (
            shape is not None
        ), "Provide `shape`."

        number_blocks, in_features, hidden_size = shape
        revnet = RevNet(number_blocks, in_features, hidden_size)
        revnet.requires_grad_(trainable)

        super().__init__(
            SubspaceFeaturizerModule_RevNet(revnet, shape, trainable),
            SubspaceInverseFeaturizerModule_RevNet(revnet, shape, trainable),
            n_features=in_features,
            id=id,
        )


class SAEFeaturizerModule(torch.nn.Module):
Expand Down Expand Up @@ -477,3 +549,121 @@ def _subspace_is_all_none(subspaces) -> bool:
return subspaces is None or all(
inner is None or all(elem is None for elem in inner) for inner in subspaces
)








# --------------------------------------------------------------------------- #
# RevNet Implementation #
# --------------------------------------------------------------------------- #
"""
Implementation of the RevNet:
The Reversible Residual Network: Backpropagation Without Storing Activations
From Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse
Url: https://arxiv.org/abs/1707.04585

Code was partially implemented with the help of ChatGPT (https://chatgpt.com/)
"""


class MLP(nn.Module):
    """A Multi-Layer Perceptron (MLP) with configurable depth and width.

    The architecture is specified dynamically through ``hidden_sizes``,
    which fixes both the width and the depth of the network.

    Args:
        input_size (int): Dimensionality of the input features.
        output_size (int): Dimensionality of the output features.
        hidden_sizes (list of int): Size of each hidden layer, in order.
        activation (nn.Module, optional): Activation applied after every
            hidden linear layer. Defaults to ``nn.ReLU()``.
        dropout_rate (float, optional): Dropout probability applied after
            each activation; no dropout layers are added when 0.
            Defaults to 0.0.

    Raises:
        ValueError: If ``hidden_sizes`` is not a non-empty list.
    """

    def __init__(self, input_size, output_size, hidden_sizes, activation=nn.ReLU(), dropout_rate=0.0):
        super().__init__()

        # Validate up front so a bad spec fails loudly at construction time.
        if not isinstance(hidden_sizes, list) or len(hidden_sizes) == 0:
            raise ValueError("hidden_sizes must be a non-empty list of integers")

        # Walk consecutive width pairs: input -> h_0 -> h_1 -> ... -> h_last.
        widths = [input_size] + hidden_sizes
        modules = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(w_in, w_out))
            modules.append(activation)
            if dropout_rate > 0:
                modules.append(nn.Dropout(dropout_rate))

        # Final projection to the output dimensionality (no activation).
        modules.append(nn.Linear(widths[-1], output_size))

        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Forward pass through the MLP."""
        return self.network(x)


class RevNet_Block(nn.Module):
    """A single additive-coupling (reversible) block.

    The input is split channel-wise into two partitions ``(x_1, x_2)`` and
    transformed as::

        y_1 = x_1 + F(x_2)
        y_2 = x_2 + G(y_1)

    which can be inverted exactly without storing activations.

    Args:
        in_features: total channel width of the block input.
        hidden_size: hidden width of the coupling MLPs ``F`` and ``G``.
        depth: number of hidden layers in each coupling MLP.
    """

    def __init__(self, in_features, hidden_size, depth=1):
        super().__init__()
        self.half_in_features = in_features // 2
        # For odd widths the second partition is one unit wider, so F maps
        # the larger partition onto the smaller one and G maps back.
        # (The original hard-coded half/half widths and crashed for odd
        # in_features; for even widths this is parameter-identical.)
        second_half = in_features - self.half_in_features

        self.F = MLP(second_half, self.half_in_features, [hidden_size] * depth)
        self.G = MLP(self.half_in_features, second_half, [hidden_size] * depth)

    def forward(self, x):
        """Apply the coupling transform to a ``(batch, in_features)`` tensor."""
        x_1 = x[:, :self.half_in_features]
        x_2 = x[:, self.half_in_features:]
        F_O = self.F(x_2)
        y_1 = x_1 + F_O
        G_O = self.G(y_1)
        y_2 = x_2 + G_O
        y = torch.cat((y_1, y_2), dim=1)
        return y

    def inverse(self, y):
        """Exactly undo :meth:`forward` (subtract G, then F, in reverse order)."""
        y_1 = y[:, :self.half_in_features]
        y_2 = y[:, self.half_in_features:]
        G_O = self.G(y_1)
        x_2 = y_2 - G_O
        F_O = self.F(x_2)
        x_1 = y_1 - F_O
        x = torch.cat((x_1, x_2), dim=1)
        return x


class RevNet(nn.Module):
    """Stack of reversible residual blocks (Gomez et al., 2017).

    Args:
        number_blocks: how many :class:`RevNet_Block` instances to stack.
        in_features: channel width shared by every block.
        hidden_size: hidden width of each block's coupling MLPs.
        depth: number of hidden layers per coupling MLP.
    """

    def __init__(self, number_blocks, in_features, hidden_size, depth=1):
        super().__init__()
        # Keep the original attribute name so saved state_dicts stay loadable.
        self.Model_Layers = nn.ModuleList(
            RevNet_Block(in_features, hidden_size, depth)
            for _ in range(number_blocks)
        )

    def forward(self, x):
        """Run every block in order."""
        for block in self.Model_Layers:
            x = block(x)
        return x

    def inverse(self, y):
        """Applies inverse transformation with high precision."""
        for block in reversed(self.Model_Layers):
            y = block.inverse(y)
        return y
5 changes: 3 additions & 2 deletions neural/model_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import pyvene as pv

from neural.featurizers import Featurizer, SubspaceFeaturizer
from neural.featurizers import Featurizer, SubspaceFeaturizer, SubspaceFeaturizerRevNet


class ComponentIndexer:
Expand Down Expand Up @@ -250,4 +250,5 @@ def set_layer(self, layer: int):
self.component.layer = layer

def get_layer(self):
return self.component.layer
return self.component.layer