diff --git a/Python-3.11.3.tgz b/Python-3.11.3.tgz
new file mode 100644
index 0000000..9de33e6
Binary files /dev/null and b/Python-3.11.3.tgz differ
diff --git a/activation_dataset.py b/activation_dataset.py
index f495af1..6cfcb5f 100644
--- a/activation_dataset.py
+++ b/activation_dataset.py
@@ -10,7 +10,7 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, Literal
 
 import numpy as np
 import numpy.typing as npt
@@ -29,6 +29,7 @@
 from transformer_lens import HookedTransformer
 from transformer_lens.loading_from_pretrained import get_official_model_name, convert_hf_model_config
 from transformers import GPT2Tokenizer, PreTrainedTokenizerBase
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from utils import *
@@ -54,6 +55,7 @@ def get_activation_size(model_name: str, layer_loc: str):
         "residual",
         "mlp",
         "attn",
+        "attn_concat",
         "mlpout",
     ], f"Layer location {layer_loc} not supported"
     model_cfg = convert_hf_model_config(model_name)
@@ -65,6 +67,8 @@ def get_activation_size(model_name: str, layer_loc: str):
         return model_cfg["d_head"] * model_cfg["n_heads"]
     elif layer_loc == "mlpout":
         return model_cfg["d_model"]
+    elif layer_loc == "attn_concat":
+        return model_cfg["d_head"] * model_cfg["n_heads"]
 
 
 def check_transformerlens_model(model_name: str):
@@ -81,6 +85,7 @@ def make_tensor_name(layer: int, layer_loc: str, model_name: str) -> str:
         "residual",
         "mlp",
         "attn",
+        "attn_concat",
         "mlpout",
     ], f"Layer location {layer_loc} not supported"
     if layer_loc == "residual":
@@ -88,6 +93,11 @@ def make_tensor_name(layer: int, layer_loc: str, model_name: str) -> str:
             tensor_name = f"blocks.{layer}.hook_resid_post"
         else:
             raise NotImplementedError(f"Model {model_name} not supported for residual stream")
+    elif layer_loc == "attn_concat":
+        if check_transformerlens_model(model_name):
+            tensor_name = f"blocks.{layer}.attn.hook_z"
+        else:
+            raise NotImplementedError(f"Model {model_name} not supported for attention output")
     elif layer_loc == "mlp":
         if check_transformerlens_model(model_name):
             tensor_name = f"blocks.{layer}.mlp.hook_post"
@@ -324,7 +334,7 @@
         print(f"Saved undersized chunk {n_saved_chunks} of activations, total size: {batch_idx * activation_size} ")
 
 
-def make_activation_dataset_hf(
+def make_activation_dataset_tl(
     sentence_dataset: DataLoader,
     model: HookedTransformer,
     activation_width: int,
@@ -366,7 +376,10 @@
         for layer in layers:
             tensor_name = make_tensor_name(layer, tensor_loc, model.cfg.model_name)
             activation_data = cache[tensor_name].to(torch.float16)
-            activation_data = rearrange(activation_data, "b s n -> (b s) n")
+            if tensor_loc == "attn_concat":
+                activation_data = rearrange(activation_data, "b s n d -> (b s) (n d)")
+            else:
+                activation_data = rearrange(activation_data, "b s n -> (b s) n")
             if layer == layers[0]:
                 n_activations += activation_data.shape[0]
             datasets[layer].append(activation_data)
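The attn_concat case added above reads TransformerLens's blocks.{layer}.attn.hook_z, whose per-head outputs have shape (batch, seq, n_heads, d_head); the rearrange flattens the heads into one d_head * n_heads vector per token, matching get_activation_size. A minimal sketch of that flattening (shapes are illustrative):

import torch
from einops import rearrange

b, s, n_heads, d_head = 2, 128, 8, 64
z = torch.randn(b, s, n_heads, d_head)         # what hook_z stores
flat = rearrange(z, "b s n d -> (b s) (n d)")  # heads concatenated per token
assert flat.shape == (b * s, n_heads * d_head)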
@@ -392,12 +405,156 @@
     #return ((chunk_means, chunk_stds) if center_dataset else None, n_activations)
     return n_activations
 
+
+def make_activation_dataset_hf(
+    sentence_dataset: Dataset,
+    model: AutoModelForCausalLM,
+    tensor_names: List[str],
+    chunk_size: int,
+    n_chunks: int,
+    output_folder: str = "activation_data",
+    skip_chunks: int = 0,
+    device: Optional[torch.device] = torch.device("cuda:0"),
+    max_length: int = 2048,
+    model_batch_size: int = 4,
+    precision: Literal["float16", "float32"] = "float16",
+    shuffle_seed: Optional[int] = None,
+):
+    with torch.no_grad():
+        model.eval()
+
+        dtype = None
+        if precision == "float16":
+            dtype = torch.float16
+        elif precision == "float32":
+            dtype = torch.float32
+        else:
+            raise ValueError(f"Invalid precision '{precision}'")
+
+        chunk_batches = chunk_size // (model_batch_size * max_length)
+        batches_to_skip = skip_chunks * chunk_batches
+
+        if shuffle_seed is not None:
+            torch.manual_seed(shuffle_seed)
+
+        dataloader = DataLoader(
+            sentence_dataset,
+            batch_size=model_batch_size,
+            shuffle=shuffle_seed is not None,
+        )
+
+        dataloader_iter = iter(dataloader)
+
+        for _ in range(batches_to_skip):
+            next(dataloader_iter)
+
+        # configure hooks for the model
+        # (PyTorch forward hooks are called as hook(module, inputs, output))
+        tensor_buffer: Dict[str, Any] = {}
+
+        hook_handles = []
+
+        for tensor_name in tensor_names:
+            tensor_buffer[tensor_name] = []
+
+            def hook(module, inputs, output, tensor_name=tensor_name):
+                if isinstance(output, tuple):
+                    out = output[0]
+                else:
+                    out = output
+                tensor_buffer[tensor_name].append(rearrange(out, "b l ... -> (b l) (...)").to(dtype=dtype).cpu())
+                return output
+
+            for name, module in model.named_modules():
+                if name == tensor_name:
+                    handle = module.register_forward_hook(hook)
+                    hook_handles.append(handle)
+
+        def reset_buffers():
+            for tensor_name in tensor_names:
+                tensor_buffer[tensor_name] = []
+
+        reset_buffers()
+
+        chunk_idx = 0
+
+        progress_bar = tqdm(total=chunk_size * n_chunks)
+
+        for batch_idx, batch in enumerate(dataloader_iter):
+            batch = batch["input_ids"].to(device)
+
+            _ = model(batch)
+
+            progress_bar.update(model_batch_size * max_length)
+
+            if (batch_idx + 1) % chunk_batches == 0:
+                for tensor_name in tensor_names:
+                    save_activation_chunk(tensor_buffer[tensor_name], chunk_idx, os.path.join(output_folder, tensor_name))
+
+                n_act = batch_idx * model_batch_size * max_length
+                print(f"Saved chunk {chunk_idx} of activations, total size: {n_act / 1e6:.2f}M activations")
+
+                chunk_idx += 1
+
+                reset_buffers()
+                if chunk_idx >= n_chunks:
+                    break
+
+        # undersized final chunk
+        if chunk_idx < n_chunks:
+            for tensor_name in tensor_names:
+                save_activation_chunk(tensor_buffer[tensor_name], chunk_idx, os.path.join(output_folder, tensor_name))
+
+            n_act = batch_idx * model_batch_size * max_length
+            print(f"Saved undersized chunk {chunk_idx} of activations, total size: {n_act / 1e6:.2f}M activations")
+
+        for hook_handle in hook_handles:
+            hook_handle.remove()
+
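Hooks are attached by exact match against model.named_modules(), so tensor_names must be full HuggingFace module paths. A quick way to list candidates (the model name here is illustrative):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped")
for name, _ in model.named_modules():
    if name.endswith("mlp"):
        print(name)  # e.g. gpt_neox.layers.0.mlp, gpt_neox.layers.1.mlp, ...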
 
 def save_activation_chunk(dataset, n_saved_chunks, dataset_folder):
     dataset_t = torch.cat(dataset, dim=0).to("cpu")
     os.makedirs(dataset_folder, exist_ok=True)
     with open(dataset_folder + "/" + str(n_saved_chunks) + ".pt", "wb") as f:
         torch.save(dataset_t, f)
 
+
+def setup_data_new(
+    model_name: str,
+    dataset_name: str,
+    output_folder: str,
+    tensor_names: List[str],
+    chunk_size: int,
+    n_chunks: int,
+    skip_chunks: int = 0,
+    device: Optional[torch.device] = torch.device("cuda:0"),
+    max_length: int = 2048,
+    model_batch_size: int = 4,
+    precision: Literal["float16", "float32"] = "float16",
+    shuffle_seed: Optional[int] = None,
+):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(model_name).to(device=device)
+
+    # weak upper bound on number of lines
+    max_lines = int((chunk_size * (n_chunks + skip_chunks)) / max_length) * 2
+
+    print(f"Processing first {max_lines} lines of dataset...")
+
+    sentence_dataset = make_sentence_dataset(dataset_name, max_lines=max_lines)
+    tokenized_sentence_dataset, _ = chunk_and_tokenize(sentence_dataset, tokenizer, max_length=max_length)
+    make_activation_dataset_hf(
+        tokenized_sentence_dataset,
+        model,
+        tensor_names,
+        chunk_size,
+        n_chunks,
+        output_folder=output_folder,
+        skip_chunks=skip_chunks,
+        device=device,
+        max_length=max_length,
+        model_batch_size=model_batch_size,
+        precision=precision,
+        shuffle_seed=shuffle_seed,
+    )
 
 def setup_data(
     tokenizer,
@@ -444,7 +601,7 @@ def setup_data(
         )
     else:
         dataset_folder = [dataset_folder] if isinstance(dataset_folder, str) else dataset_folder
-        n_datapoints = make_activation_dataset_hf(
+        n_datapoints = make_activation_dataset_tl(
             sentence_dataset=token_loader,
             model=model,
             activation_width=activation_width,
diff --git a/autoencoders/__init__.py b/autoencoders/__init__.py
new file mode 100644
index 0000000..e69de29
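A sketch of driving the new HuggingFace pipeline end to end; the model name, dataset name, and module path below are placeholders, and chunk_size counts activations (tokens), not sequences:

from activation_dataset import setup_data_new

setup_data_new(
    model_name="EleutherAI/pythia-70m-deduped",   # any AutoModelForCausalLM
    dataset_name="NeelNanda/pile-10k",            # placeholder text dataset
    output_folder="activation_data",
    tensor_names=["gpt_neox.layers.2.mlp"],       # one subfolder of chunks per name
    chunk_size=2**20,
    n_chunks=10,
)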
diff --git a/autoencoders/direct_coef_search.py b/autoencoders/direct_coef_search.py
deleted file mode 100644
index c269349..0000000
--- a/autoencoders/direct_coef_search.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import optree
-import torch
-import torch.nn.functional as F
-
-import optimizers.sgdm
-from autoencoders.learned_dict import LearnedDict
-
-N_ITERS_OPT = 100
-
-
-class DirectCoefOptimizer:
-    @staticmethod
-    def init(d_activation, n_features, l1_alpha, lr=1e-3, dtype=torch.float32):
-        params = {}
-        params["decoder"] = torch.randn(n_features, d_activation, dtype=dtype)
-
-        buffers = {}
-        buffers["l1_alpha"] = torch.tensor(l1_alpha, dtype=dtype)
-        buffers["lr"] = torch.tensor(lr, dtype=dtype)
-
-        return params, buffers
-
-    @staticmethod
-    def objective(c, normed_dict, batch, l1_alpha):
-        x_hat = torch.einsum("ij,bi->bj", normed_dict, c)
-
-        l_reconstruction = (x_hat - batch).pow(2).mean()
-        l_sparsity = l1_alpha * torch.norm(c, 1, dim=-1).mean()
-
-        losses = {
-            "loss": l_reconstruction + l_sparsity,
-            "l_reconstruction": l_reconstruction,
-            "l_l1": l_sparsity,
-        }
-
-        aux = {"c": c}
-
-        return l_reconstruction + l_sparsity, (losses, aux)
-
-    @staticmethod
-    def basis_pursuit(params, buffers, batch, normed_dict=None):
-        if normed_dict is None:
-            decoder_norms = torch.norm(params["decoder"], 2, dim=-1)
-            normed_dict = params["decoder"] / torch.clamp(decoder_norms, 1e-8)[:, None]
-
-        # hate this
-        c = torch.zeros_like(torch.einsum("ij,bj->bi", normed_dict, batch))
-
-        optimizer = optimizers.sgdm.SGDM(buffers["lr"], 0.9)
-        optim_state = optimizer.init(c)
-
-        for _ in range(N_ITERS_OPT):
-            grads, _ = torch.func.grad(DirectCoefOptimizer.objective, has_aux=True)(c, normed_dict, batch, buffers["l1_alpha"])
-            updates, optim_state = optimizer.update(grads, optim_state)
-            c += updates
-            c = F.relu(c)
-
-        return c
-
-    @staticmethod
-    def loss(params, buffers, batch):
-        decoder_norms = torch.norm(params["decoder"], 2, dim=-1)
-        normed_dict = params["decoder"] / torch.clamp(decoder_norms, 1e-8)[:, None]
-
-        with torch.no_grad():
-            c = DirectCoefOptimizer.basis_pursuit(params, buffers, batch, normed_dict=normed_dict)
-
-        x_hat = torch.einsum("ij,bi->bj", normed_dict, c)
-        l_reconstruction = (x_hat - batch).pow(2).mean()
-
-        return l_reconstruction, ({"loss": l_reconstruction}, {"c": c})
-
-    @staticmethod
-    def to_learned_dict(params, buffers):
-        return DirectCoefSearch(params, buffers)
-
-
-class DirectCoefSearch(LearnedDict):
-    def __init__(self, params, buffers):
-        self.params = params
-        self.buffers = buffers
-
-    def encode(self, x):
-        return DirectCoefOptimizer.basis_pursuit(self.params, self.buffers, x)
-
-    def get_learned_dict(self):
-        decoder_norms = torch.norm(self.params["decoder"], 2, dim=-1)
-        return self.params["decoder"] / torch.clamp(decoder_norms, 1e-8)[:, None]
-
-    def to_device(self, device):
-        self.params = optree.tree_map(lambda t: t.to(device), self.params)
-        self.buffers = optree.tree_map(lambda t: t.to(device), self.buffers)
diff --git a/autoencoders/ica.py b/autoencoders/ica.py
index 902004a..4cd1c6d 100644
--- a/autoencoders/ica.py
+++ b/autoencoders/ica.py
@@ -6,11 +6,14 @@
 from sklearn.preprocessing import StandardScaler
 from torchtyping import TensorType
 
+from typing import Tuple
+
 from autoencoders.learned_dict import LearnedDict
 from autoencoders.topk_encoder import TopKLearnedDict
 
-_n_samples, _activation_size = None, None
-
+_n_samples, _activation_size = (
+    None, None
+)  # type: Tuple[None, None]
 
 class ICAEncoder(LearnedDict):
     def __init__(self, activation_size, n_components: int = 0):
@@ -51,3 +54,29 @@ def to_topk_dict(self, sparsity):
         negatives = -positives
         components = np.concatenate([positives, negatives], axis=0)
         return TopKLearnedDict(components, sparsity)
+
+    def to_nneg_dict(self):
+        return NNegICAEncoder(self.activation_size, self.ica, self.scaler)
+
+class NNegICAEncoder(LearnedDict):
+    def __init__(self, activation_size, ica, scaler):
+        self.activation_size = activation_size
+        self.ica = ica
+        self.scaler = scaler
+
+    def to_device(self, device):
+        pass
+
+    def encode(self, x):
+        assert x.shape[1] == self.activation_size
+        x_standardized = self.scaler.transform(x.cpu().numpy().astype(np.float64))
+        c = self.ica.transform(x_standardized)
+        c_neg = -c
+        c = np.clip(c, 0, None)
+        c_neg = np.clip(c_neg, 0, None)
+        return torch.cat([torch.tensor(c, device=x.device), torch.tensor(c_neg, device=x.device)], dim=-1)
+
+    def get_learned_dict(self):
+        components = torch.tensor(self.ica.components_, dtype=torch.float32)
+        components = torch.cat([components, -components], dim=0)
+        return components / torch.norm(components, dim=-1, keepdim=True)
\ No newline at end of file
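NNegICAEncoder above, like IdentityPositive and to_pve_rotation_dict below, relies on the same construction: every direction is paired with its negation so codes can be clamped non-negative without losing sign information. A worked example:

import torch

x = torch.tensor([[1.5, -2.0, 0.0]])
c = torch.clamp(torch.cat([x, -x], dim=-1), min=0.0)
# c == [[1.5, 0.0, 0.0, 0.0, 2.0, 0.0]]: feature i fires for +d_i, feature i+3 for -d_i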
diff --git a/autoencoders/learned_dict.py b/autoencoders/learned_dict.py
index a1d1177..a7028c0 100644
--- a/autoencoders/learned_dict.py
+++ b/autoencoders/learned_dict.py
@@ -5,10 +5,13 @@
 from torch import nn
 from torchtyping import TensorType
 
-from autoencoders.ensemble import DictSignature
+from typing import Tuple
 
-_n_dict_components, _activation_size, _batch_size = None, None, None
+from autoencoders.ensemble import DictSignature
 
+_n_dict_components, _activation_size, _batch_size = (
+    None, None, None
+)  # type: Tuple[None, None, None]
 
 class LearnedDict(ABC):
     n_feats: int
@@ -51,19 +54,34 @@ def n_dict_components(self):
 
 
 class Identity(LearnedDict):
-    def __init__(self, activation_size):
+    def __init__(self, activation_size, device=None):
         self.n_feats = activation_size
         self.activation_size = activation_size
+        self.device = "cpu" if device is None else device
 
     def get_learned_dict(self):
-        return torch.eye(self.n_feats)
+        return torch.eye(self.n_feats, device=self.device)
 
     def encode(self, batch):
         return batch
 
     def to_device(self, device):
-        pass
+        self.device = device
+
+
+class IdentityPositive(LearnedDict):
+    def __init__(self, activation_size, device=None):
+        self.n_feats = activation_size
+        self.activation_size = activation_size
+        self.device = "cpu" if device is None else device
+
+    def get_learned_dict(self):
+        return torch.cat([torch.eye(self.n_feats, device=self.device), -torch.eye(self.n_feats, device=self.device)], dim=0)
+
+    def encode(self, batch):
+        return torch.clamp(torch.cat([batch, -batch], dim=-1), min=0.0)
+
+    def to_device(self, device):
+        self.device = device
 
 
 class IdentityReLU(LearnedDict):
     def __init__(self, activation_size, bias: Optional[torch.Tensor] = None):
@@ -132,7 +150,7 @@ def encode(self, batch):
 
 
 class TiedSAE(LearnedDict):
-    def __init__(self, encoder, encoder_bias, centering=(None, None, None), norm_encoder=False):
+    def __init__(self, encoder, encoder_bias, centering=(None, None, None), norm_encoder=True):
         self.encoder = encoder
         self.encoder_bias = encoder_bias
         self.norm_encoder = norm_encoder
diff --git a/autoencoders/nmf.py b/autoencoders/nmf.py
index 936803e..175edab 100644
--- a/autoencoders/nmf.py
+++ b/autoencoders/nmf.py
@@ -17,10 +17,14 @@
 from sklearn.decomposition import NMF
 from torchtyping import TensorType
 
+from typing import Tuple
+
 from autoencoders.learned_dict import LearnedDict
 from autoencoders.topk_encoder import TopKLearnedDict
 
-_n_samples, _activation_size = None, None
+_n_samples, _activation_size = (
+    None, None
+)  # type: Tuple[None, None]
 
 
 class NMFEncoder(LearnedDict):
diff --git a/autoencoders/pca.py b/autoencoders/pca.py
index 286c1ef..fbaf2c6 100644
--- a/autoencoders/pca.py
+++ b/autoencoders/pca.py
@@ -1,6 +1,6 @@
 import torch
 
-from autoencoders.learned_dict import LearnedDict, Rotation
+from autoencoders.learned_dict import LearnedDict, Rotation, TiedSAE
 from autoencoders.topk_encoder import TopKLearnedDict
 
 def calc_pca(activations, batch_size=512, device="cuda:0"):
@@ -97,9 +97,18 @@ def to_topk_dict(self, sparsity):
         eigvecs_ = torch.cat([eigvecs, -eigvecs], dim=0)
         return TopKLearnedDict(eigvecs_, sparsity)
 
-    def to_rotation_dict(self, n_components):
+    def to_rotation_dict(self, n_components=None):
+        if n_components is None:
+            n_components = self.n_dims
         return Rotation(self.get_dict()[:n_components])
 
+    def to_pve_rotation_dict(self, n_components=None):
+        if n_components is None:
+            n_components = self.n_dims
+        dirs = self.get_dict()[:n_components]
+        dirs_ = torch.cat([dirs, -dirs], dim=0)
+        return TiedSAE(dirs_, torch.zeros(2 * n_components), centering=(self.get_mean(), None, None), norm_encoder=True)
+
 
 class PCAEncoder(LearnedDict):
     def __init__(self, pca_dict, sparsity):
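A hypothetical use of the new to_pve_rotation_dict, assuming the BatchedPCA interface implied by calc_pca and the code above (get_dict, get_mean, and the LearnedDict encode/decode pair):

import torch
from autoencoders.pca import calc_pca

activations = torch.randn(10_000, 512)        # placeholder activation matrix
pca = calc_pca(activations, batch_size=512, device="cpu")
pve = pca.to_pve_rotation_dict(n_components=256)
codes = pve.encode(activations[:32])          # [32, 512]: 256 directions plus their negations
recon = pve.decode(codes)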
diff --git a/basic_l1_sweep.py b/basic_l1_sweep.py
index 0e91058..3c81775 100644
--- a/basic_l1_sweep.py
+++ b/basic_l1_sweep.py
@@ -122,7 +122,9 @@ def basic_l1_sweep(
 
     args = parser.parse_args()
 
-    l1_values = np.logspace(args.l1_value_min, args.l1_value_max, args.l1_value_n)
+    #l1_values = list(np.logspace(args.l1_value_min, args.l1_value_max, args.l1_value_n))
+
+    l1_values = [0, 1e-3, 3e-4, 1e-4]
     basic_l1_sweep(
         args.dataset_dir,
         args.output_dir,
diff --git a/big_sweep_experiments.py b/big_sweep_experiments.py
index 3874067..b82fa43 100644
--- a/big_sweep_experiments.py
+++ b/big_sweep_experiments.py
@@ -24,6 +24,8 @@
 from cluster_runs import dispatch_job_on_chunk
 from utils import dotdict
 
+from typing import Optional
+
 # an example function that builds a list of ensembles to run
 # you could use this as a template for other experiments
 
@@ -34,7 +36,7 @@
 # - a list of hyperparameters that vary between models in the same ensemble
 # - a dict of hyperparameter ranges
 
-DICT_RATIO = None
+DICT_RATIO: Optional[int] = None
 
 def tied_vs_not_experiment(cfg: dotdict):
     l1_values = list(np.logspace(-3.5, -2, 4))
@@ -92,7 +94,6 @@ def tied_vs_not_experiment(cfg: dotdict):
                 cfg.activation_width,
                 cfg.activation_width * 8,
                 l1_alpha,
-                bias_decay=bias_decay,
                 dtype=cfg.dtype,
             )
             for l1_alpha, bias_decay in cfgs
@@ -140,7 +141,6 @@ def tied_vs_not_experiment(cfg: dotdict):
                 cfg.activation_width,
                 cfg.activation_width * 4,
                 l1_alpha,
-                bias_decay=bias_decay,
                 dtype=cfg.dtype,
             )
             for l1_alpha, bias_decay in cfgs
@@ -188,7 +188,6 @@ def tied_vs_not_experiment(cfg: dotdict):
                 cfg.activation_width,
                 cfg.activation_width * 2,
                 l1_alpha,
-                bias_decay=bias_decay,
                 dtype=cfg.dtype,
             )
             for l1_alpha, bias_decay in cfgs
@@ -305,7 +304,6 @@ def dense_l1_range_experiment(cfg: dotdict):
                 cfg.activation_width,
                 dict_size,
                 l1_alpha,
-                bias_decay=0.0,
                 dtype=cfg.dtype,
             )
             for l1_alpha in cfgs
@@ -507,7 +505,6 @@ def zero_l1_baseline(cfg: dotdict):
                 cfg.activation_width,
                 dict_size,
                 l1_alpha,
-                bias_decay=0.0,
                 dtype=cfg.dtype,
             )
             for l1_alpha in cfgs
@@ -874,7 +871,7 @@ def pythia_1_4_b_dict(cfg: dotdict):
             },
             device=device
         )
-        args = {"batch_size": cfg.batch_size, "device": device, "dict_size": dict_sizes[i]}
+        args = {"batch_size": cfg.batch_size, "device": device, "dict_size": dict_size}
         name = f"l1_{i}"
         ensembles.append((ensemble, args, name))
 
@@ -926,7 +923,6 @@ def run_zeros_only(cfg: dotdict):
                 cfg.activation_width,
                 dict_size,
                 l1_alpha,
-                bias_decay=0.0,
                 dtype=cfg.dtype,
             )
             for l1_alpha in l1_values
@@ -974,7 +970,6 @@ def long_mlp_sweep(cfg: dotdict):
                 cfg.activation_width,
                 dict_size,
                 l1_alpha,
-                bias_decay=0.0,
                 dtype=cfg.dtype,
             )
             for l1_alpha in l1_values
@@ -1112,7 +1107,6 @@ def simple_setoff(cfg: dotdict) -> Tuple[List[Tuple[FunctionalEnsemble, dict, st
                 cfg.activation_width,
                 dict_size,
                 l1_alpha,
-                bias_decay=0.0,
                 dtype=cfg.dtype,
             )
             for l1_alpha in l1_values
diff --git a/do_ioi_multiple_layers.sh b/do_ioi_multiple_layers.sh
new file mode 100644
index 0000000..54f6a2a
--- /dev/null
+++ b/do_ioi_multiple_layers.sh
@@ -0,0 +1,4 @@
+python generate_test_data.py --model="EleutherAI/pythia-410m-deduped" --n_chunks=30 --layers 3
+
+python basic_l1_sweep.py --dataset_dir="activation_data/layer_3" --output_dir="dicts_l3" --ratio=4
+python ioi_feature_ident.py 3
\ No newline at end of file
diff --git a/ensemble_training_example.py b/ensemble_training_example.py
deleted file mode 100644
index b290265..0000000
--- a/ensemble_training_example.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchopt
-
-from autoencoders.sae_ensemble import SAE, SAEEnsemble
-from sc_datasets.random_dataset import RandomDatasetGenerator
-
-# we calculate gradients functionally so disable autograd for memory
-torch.set_grad_enabled(False)
-
-l1_exp_base = 10 ** (1 / 4)
-n_features = 1024
-d_activation = 512
-n_dict_components = 2048
-batch_size = 256
-dataset = RandomDatasetGenerator(d_activation, n_features, batch_size, 5, 0.99, True, "cuda")
-
-
-def mmcs(truth, dict):
-    # truth: [n_features, d_activation]
-    # dict: [n_dict_components, d_activation]
-
-    cosine_sim = truth @ dict.T
-    max_cosine_sim, _ = torch.max(cosine_sim, dim=0)
-    return max_cosine_sim.mean()
-
-
-l1_coefs = [1 * l1_exp_base**i for i in range(-16, -11)]
-
-models = [SAE(d_activation, n_dict_components, l1_coef=l1_coef).to("cuda") for l1_coef in l1_coefs]
-ensemble = SAEEnsemble(models, torchopt.adam(lr=1e-3))
-
-for i in range(1000):
-    minibatch = dataset.__next__().unsqueeze(0).expand(len(models), -1, -1)
-    # minibatch = torch.randn(len(models), batch_size, d_activation, device="cuda", requires_grad=True)
-    losses = ensemble.step_batch(minibatch)
-
-    if i % 100 == 0:
-        
mmcss = torch.vmap(lambda y: mmcs(dataset.feats, y))(ensemble.params["decoder"]) - - print(f"Step {i}") - print(f" Losses: {losses}") - print(f" MMCS: {mmcss}") diff --git a/erasure.py b/erasure.py deleted file mode 100644 index 3978ccd..0000000 --- a/erasure.py +++ /dev/null @@ -1,1243 +0,0 @@ -from functools import partial -from itertools import product -from typing import List, Tuple, Union, Any, Dict, Literal, Optional, Callable - -from datasets import load_dataset -from einops import rearrange -import matplotlib.pyplot as plt -import matplotlib -import numpy as np -from PIL import Image -from sklearn.cluster import KMeans -from sklearn.manifold import TSNE -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader -from torchtyping import TensorType - -import tqdm - -from transformer_lens import HookedTransformer - -from autoencoders.learned_dict import LearnedDict, RandomDict -from autoencoders.pca import BatchedPCA - -from activation_dataset import setup_data - -import standard_metrics - -import copy - -from test_datasets.ioi import generate_ioi_dataset -from test_datasets.gender import generate_gender_dataset, generate_pronoun_dataset -from test_datasets.winobias import generate_winobias_dataset - -from concept_erasure import LeaceFitter, LeaceEraser - -from sklearn.metrics import roc_auc_score - -from dataclasses import dataclass - -import os - -from sklearn.linear_model import LogisticRegression, LinearRegression, RidgeClassifier - -class NullspaceProjector: - def __init__(self, nullspace): - self.d_activation = nullspace.shape[0] - self.nullspace = nullspace.detach().clone() / torch.linalg.norm(nullspace) - - def project(self, tensor: TensorType["batch", "d_activation"]) -> TensorType["batch", "d_activation"]: - return tensor - (torch.einsum("bd,d->b", tensor, self.nullspace)[..., None] * self.nullspace) - - @staticmethod - def class_means( - activations: TensorType["batch", "d_activation"], - class_labels: TensorType["batch"], - ) -> "NullspaceProjector": - class_means = torch.stack([ - activations[class_labels == i].mean(dim=0) - for i in range(activations.shape[0]) - ], dim=0) - - class_means_diff = class_means[1] - class_means[0] - - return NullspaceProjector(class_means_diff) - -def resample_ablation_hook( - lens: LearnedDict, - features_to_ablate: List[int], - corrupted_codes: Optional[TensorType["batch", "sequence", "n_dict_components"]] = None, - ablation_type: Literal["ablation", "reconstruction"] = "ablation", - handicap: Optional[TensorType["batch", "sequence", "d_activation"]] = None, - ablation_rank: Literal["full", "partial"] = "partial", - ablation_mask: Optional[TensorType["batch", "sequence"]] = None, -): - if corrupted_codes is None: - corrupted_codes_ = None - else: - corrupted_codes_ = corrupted_codes.reshape(-1, corrupted_codes.shape[-1]) - - activation_dict = {"output": None} - - def reconstruction_intervention(tensor, hook=None): - nonlocal activation_dict - B, L, D = tensor.shape - code = lens.encode(tensor.reshape(-1, D)) - - if corrupted_codes_ is None: - code[:, features_to_ablate] = 0.0 - else: - code[:, features_to_ablate] = corrupted_codes_[:, features_to_ablate] - - reconstr = lens.decode(code).reshape(tensor.shape) - - if handicap is not None: - output = reconstr + handicap - else: - output = reconstr - - if ablation_mask is not None: - output[~ablation_mask] = tensor[~ablation_mask] - - activation_dict["output"] = output.clone() - return output - - def partial_ablation_intervention(tensor, hook=None): - nonlocal 
activation_dict - B, L, D = tensor.shape - code = lens.encode(tensor.reshape(-1, D)) - - ablation_code = torch.zeros_like(code) - - if corrupted_codes_ is None: - ablation_code[:, features_to_ablate] = -code[:, features_to_ablate] - else: - ablation_code[:, features_to_ablate] = corrupted_codes_[:, features_to_ablate] - code[:, features_to_ablate] - - ablation = lens.decode(ablation_code).reshape(tensor.shape) - - if handicap is not None: - output = tensor + ablation + handicap - else: - output = tensor + ablation - - if ablation_mask is not None: - output[~ablation_mask] = tensor[~ablation_mask] - - activation_dict["output"] = output.clone() - return output - - def full_ablation_intervention(tensor, hook=None): - nonlocal activation_dict - B, L, D = tensor.shape - code = torch.einsum("bd,nd->bn", tensor.reshape(-1,D), lens.get_learned_dict()) - - ablation_code = torch.zeros_like(code) - - if corrupted_codes_ is None: - ablation_code[:, features_to_ablate] = -code[:, features_to_ablate] - else: - ablation_code[:, features_to_ablate] = corrupted_codes_[:, features_to_ablate] - code[:, features_to_ablate] - - ablation = torch.einsum("bn,nd->bd", ablation_code, lens.get_learned_dict()).reshape(tensor.shape) - output = tensor + ablation - - if ablation_mask is not None: - output[~ablation_mask] = tensor[~ablation_mask] - - activation_dict["output"] = output.clone() - return tensor + ablation - - ablation_func = None - if ablation_type == "reconstruction": - ablation_func = reconstruction_intervention - elif ablation_type == "ablation" and ablation_rank == "partial": - ablation_func = partial_ablation_intervention - elif ablation_type == "ablation" and ablation_rank == "full": - ablation_func = full_ablation_intervention - else: - raise ValueError(f"Unknown ablation type '{ablation_type}' with rank '{ablation_rank}'") - - return ablation_func, activation_dict - -def resample_ablation( - model: HookedTransformer, - lens: LearnedDict, - location: standard_metrics.Location, - clean_tokens: TensorType["batch", "sequence"], - features_to_ablate: List[int], - corrupted_codes: Optional[TensorType["batch", "sequence", "n_dict_components"]] = None, - ablation_type: Literal["ablation", "reconstruction"] = "ablation", - handicap: Optional[TensorType["batch", "sequence", "d_activation"]] = None, - ablation_rank: Literal["full", "partial"] = "partial", - ablation_mask: Optional[TensorType["batch", "sequence"]] = None, - **kwargs, -) -> Tuple[Any, TensorType["batch", "sequence", "d_activation"]]: - ablation_func, activation_dict = resample_ablation_hook( - lens, - features_to_ablate, - corrupted_codes=corrupted_codes, - ablation_type=ablation_type, - handicap=handicap, - ablation_rank=ablation_rank, - ablation_mask=ablation_mask, - ) - - logits = model.run_with_hooks( - clean_tokens, - fwd_hooks=[( - standard_metrics.get_model_tensor_name(location), - ablation_func, - )], - **kwargs, - ) - - return logits, activation_dict["output"] - -def save_dataset_activations( - model: HookedTransformer, - dataset: TensorType["batch", "sequence"], - location: standard_metrics.Location, - n_classes: int, - classes: TensorType["batch"], - sequence_lengths: Optional[TensorType["batch"]] = None, - batch_size: int = 32, - skip_tokens: int = 0, - filename: str = "activation_data_erasure.pt", -): - # {filename} is a tuple of (activations, classes, sequence_positions) - - if skip_tokens is None: - skip_tokens = 0 - - if sequence_lengths is None: - sequence_lengths = torch.tensor([dataset.shape[1]]*dataset.shape[0], 
dtype=torch.long, device=dataset.device) - - max_seq_len = dataset.shape[1] - - saved_activations = [] - saved_class_labels = [] - saved_sequence_lengths = [] - - with torch.no_grad(): - for i in tqdm.tqdm(range(0, dataset.shape[0], batch_size)): - j = min(i+batch_size, dataset.shape[0]) - batch = dataset[i:j] - batch_lengths = sequence_lengths[i:j] - batch_classes = classes[i:j] - - logits, activations = model.run_with_cache( - batch, - names_filter=lambda name: name == standard_metrics.get_model_tensor_name(location), - return_type="logits", - stop_at_layer=location[0] + 1, - ) - activations = activations[standard_metrics.get_model_tensor_name(location)] - - for k in range(batch.shape[0]): - class_id = batch_classes[k].item() - seq_len = batch_lengths[k].item() - - activation = activations[k] - activation[seq_len:] = 0.0 - - saved_activations.append(activation) - saved_class_labels.append(class_id) - saved_sequence_lengths.append(seq_len) - - saved_activations = torch.stack(saved_activations, dim=0) - saved_class_labels = torch.tensor(saved_class_labels, dtype=torch.long) - saved_sequence_lengths = torch.tensor(saved_sequence_lengths, dtype=torch.long) - - torch.save((saved_activations, saved_class_labels, saved_sequence_lengths, skip_tokens), filename) - -def ce_distance(clean_activation, activation): - return torch.linalg.norm(clean_activation - activation, dim=-1) - -def ablation_mask_from_seq_lengths( - seq_lengths: TensorType["batch"], - max_length: int, -) -> TensorType["batch", "sequence"]: - B = seq_lengths.shape[0] - mask = torch.zeros((B, max_length), dtype=torch.bool) - for i in range(B): - mask[i, :seq_lengths[i]] = True - return mask - -def approx_feature_erasure( - model: HookedTransformer, - lens: LearnedDict, - location: standard_metrics.Location, - dataset: TensorType["batch", "sequence"], - class_labels: TensorType["batch"], - sequence_lengths: TensorType["batch"], - scoring_function: Callable[[TensorType["batch", "sequence", "vocab_size"], TensorType["batch"], TensorType["batch"]], TensorType["batch"]], - directions_filter: Optional[List[int]] = None, - ablation_type: Literal["ablation", "reconstruction"] = "ablation", - ablation_rank: Literal["full", "partial"] = "partial", - test_batch_size: int = 32, -) -> List[Tuple[int, float]]: - """Try ablations with directions and see which ones are best""" - if directions_filter is None: - directions_filter = list(range(lens.get_learned_dict().shape[0])) - - scores = [] - - for i in tqdm.tqdm(directions_filter): - batch_idxs = np.random.choice(dataset.shape[0], size=test_batch_size, replace=False) - batch = dataset[batch_idxs] - batch_classes = class_labels[batch_idxs] - batch_lengths = sequence_lengths[batch_idxs] - - batch_logits, _ = resample_ablation( - model, - lens, - location, - batch, - [i], - ablation_type=ablation_type, - ablation_rank=ablation_rank, - return_type="logits", - ) - - score = scoring_function(batch_logits, batch_classes, batch_lengths).mean().item() - scores.append((i, score)) - - return sorted(scores, key=lambda x: x[1]) - -def filter_activation_threshold( - lens: LearnedDict, - dataset: TensorType["batch", "sequence", "d_activation"], - sequence_lengths: TensorType["batch"], - activation_proportion_threshold: float = 0.05, - batch_size: int = 32, - last_position_only: bool = False, -) -> List[int]: - if last_position_only: - zero_mask = torch.zeros((dataset.shape[0], dataset.shape[1]), dtype=torch.bool) - zero_mask[torch.arange(sequence_lengths.shape[0]), sequence_lengths-1] = True - else: - zero_mask 
= ablation_mask_from_seq_lengths(sequence_lengths, dataset.shape[1]) - - - feat_activation_count = torch.zeros(lens.get_learned_dict().shape[0], dtype=torch.long, device=dataset.device) - total_activations = 0 - - for i in tqdm.tqdm(range(0, dataset.shape[0], batch_size)): - j = min(i+batch_size, dataset.shape[0]) - batch = dataset[i:j] - - encoded_batch = lens.encode(batch.reshape(-1, batch.shape[-1])).reshape(batch.shape[0], batch.shape[1], -1) - batch_nz = (encoded_batch != 0.0).long() - - batch_nz[~zero_mask[i:j]] = 0 - - feat_activation_count += batch_nz.sum(dim=(0, 1)) - - if last_position_only: - total_activations += batch.shape[0] - else: - total_activations += sequence_lengths[i:j].sum().item() - - feat_activation_proportions = feat_activation_count.float() / total_activations - - return torch.where(feat_activation_proportions > activation_proportion_threshold)[0].tolist() - -def eval_hook( - model: HookedTransformer, - hook_func: Callable[[TensorType["batch", "sequence", "d_activation"], TensorType["batch"], TensorType["batch"], Any], TensorType["batch", "sequence", "d_activation"]], - dataset: TensorType["batch", "sequence"], - class_labels: TensorType["batch"], - sequence_lengths: TensorType["batch"], - location: standard_metrics.Location, - task_score_func: Callable[[TensorType["batch", "sequence", "vocab_size"], TensorType["batch"], TensorType["batch"]], TensorType["batch"]], - activation_dist_func: Callable[[TensorType["batch", "sequence", "d_activation"], TensorType["batch", "sequence", "d_activation"]], TensorType["batch"]] = ce_distance, - batch_size: int = 4, - last_position_only: bool = False, - device: torch.device = torch.device("cpu"), -) -> Tuple[float, float, float]: - # returns (task_score, activation_dist) - - model.eval() - - mean_activation_dist = 0.0 - mean_task_score = 0.0 - - for i in range(0, dataset.shape[0], batch_size): - j = min(i+batch_size, dataset.shape[0]) - batch = dataset[i:j].to(device) - batch_lengths = sequence_lengths[i:j].to(device) - batch_classes = class_labels[i:j].to(device) - - activation_dist = None - - def hook_func_wrapper(tensor, hook=None): - nonlocal activation_dist - _, L, D = tensor.shape - uneditied = tensor.clone() - if last_position_only: - edited = tensor.clone() - edited[torch.arange(batch_lengths.shape[0]), batch_lengths-1] = hook_func( - tensor[torch.arange(batch_lengths.shape[0]), batch_lengths-1], - batch_classes, - batch_lengths, - hook=hook - ) - activation_dist = activation_dist_func( - uneditied[torch.arange(batch_lengths.shape[0]), batch_lengths-1], - edited[torch.arange(batch_lengths.shape[0]), batch_lengths-1], - ) - else: - edited = hook_func(tensor, batch_classes, batch_lengths, hook=hook) - activation_dist = activation_dist_func(uneditied, edited) - - return edited - - logits = model.run_with_hooks( - batch, - fwd_hooks=[( - standard_metrics.get_model_tensor_name(location), - hook_func_wrapper, - )], - return_type="logits", - ) - - mean_task_score += task_score_func(logits, batch_classes, batch_lengths).sum().item() - - mean_activation_dist += activation_dist.sum().item() - - mean_activation_dist /= dataset.shape[0] - mean_task_score /= dataset.shape[0] - - return mean_task_score, mean_activation_dist - -def generate_activation_data(cfg): - model_name = cfg.model_name - device = cfg.device - - model = HookedTransformer.from_pretrained(model_name) - model.to(device) - model.eval() - model.requires_grad_(False) - - prompts, classes, _, sequence_lengths, skip_tokens = generate_gender_dataset( - model_name, - 
count_cutoff=cfg.count_cutoff, - sample_n=cfg.unique_names, - prompts_per_name=cfg.prompts_per_name, - n_few_shot=cfg.k_shot, - randomise=False, - ) - - prompts = prompts.to(device) - classes = classes.to(device) - sequence_lengths = sequence_lengths.to(device) - - save_dataset_activations( - model, - prompts, - (cfg.layer, "residual"), - 2, - classes, - sequence_lengths=sequence_lengths, - batch_size=32, - skip_tokens=skip_tokens, - filename=cfg.activation_filename - ) - -def gen_pca_simplification(cfg): - device = cfg.device - - activations, class_labels, sequence_lengths, skip_tokens = torch.load(cfg.activation_filename) - - B, L, D = activations.shape - - activations = activations[:, skip_tokens:] - - pca_components = torch.empty((L-skip_tokens, 2, D), dtype=torch.float, device=device) - - optimal_activations = torch.empty((B, L-skip_tokens, D), dtype=torch.float, device=device) - optimal_activations_proj = torch.empty((B, L-skip_tokens, 2), dtype=torch.float, device=device) - - for i in tqdm.tqdm(range(L-skip_tokens)): - u, s, v = torch.linalg.svd(activations[:, i]) - pca_components[i] = v[:2] - optimal_eraser = LeaceEraser.fit( - activations[:, i], - class_labels, - ) - optimal_activations[:, i] = optimal_eraser(activations[:, i]) - - optimal_activations_proj = torch.einsum("bld,lnd->bln", optimal_activations, pca_components) - - projected_activations = torch.einsum("bld,lnd->bln", activations, pca_components) - #projected_activations = projected_activations.reshape(B, L-skip_tokens, 2) - - leace_eraser = torch.load(f"{cfg.output_folder}/leace_eraser_layer_{cfg.layer}.pt") - - erased_activations = leace_eraser(activations) - erased_activations_proj = torch.einsum("bld,lnd->bln", erased_activations, pca_components) - #erased_activations = erased_activations.reshape(B, L-skip_tokens, 2) - - from contextlib import nullcontext - - import matplotlib.pyplot as plt - import matplotlib - - # have a color for every combination of class and sequence length - - male_projected = projected_activations[class_labels == 0].detach().cpu().numpy() - male_erased = erased_activations_proj[class_labels == 0].detach().cpu().numpy() - male_optimal = optimal_activations_proj[class_labels == 0].detach().cpu().numpy() - - female_projected = projected_activations[class_labels == 1].detach().cpu().numpy() - female_erased = erased_activations_proj[class_labels == 1].detach().cpu().numpy() - female_optimal = optimal_activations_proj[class_labels == 1].detach().cpu().numpy() - - male_cmap = matplotlib.cm.get_cmap("Blues") - female_cmap = matplotlib.cm.get_cmap("Reds") - - token_positions = np.arange(L-skip_tokens) - hues = np.linspace(0.3, 0.7, L-skip_tokens) - - os.makedirs(f"{cfg.output_folder}/pca_img", exist_ok=True) - - for t in token_positions: - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(30, 10)) - - ax1.scatter(male_projected[:, t, 0], male_projected[:, t, 1], color=male_cmap(hues[t])) - ax2.scatter(male_erased[:, t, 0], male_erased[:, t, 1], color=male_cmap(hues[t])) - ax3.scatter(male_optimal[:, t, 0], male_optimal[:, t, 1], color=male_cmap(hues[t])) - - ax1.scatter(female_projected[:, t, 0], female_projected[:, t, 1], color=female_cmap(hues[t])) - ax2.scatter(female_erased[:, t, 0], female_erased[:, t, 1], color=female_cmap(hues[t])) - ax3.scatter(female_optimal[:, t, 0], female_optimal[:, t, 1], color=female_cmap(hues[t])) - - ax1.set_title("Original") - ax2.set_title("Erased") - ax3.set_title("Optimal") - - 
plt.savefig(f"{cfg.output_folder}/pca_img/pca_simplification_layer_{cfg.layer}_pos_{t}.png") - - plt.close() - - overall_pca = torch.empty((2, D), dtype=torch.float, device=device) - - with nullcontext(): - u, s, v = torch.linalg.svd(activations.reshape(-1, D)) - overall_pca = v[:2] - - optimal_activations_proj = torch.einsum("bld,nd->bln", optimal_activations, overall_pca) - erased_activations_proj = torch.einsum("bld,nd->bln", erased_activations, overall_pca) - projected_activations = torch.einsum("bld,nd->bln", activations, overall_pca) - - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(30, 10)) - - male_projected = projected_activations[class_labels == 0].detach().cpu().numpy() - male_erased = erased_activations_proj[class_labels == 0].detach().cpu().numpy() - male_optimal = optimal_activations_proj[class_labels == 0].detach().cpu().numpy() - - female_projected = projected_activations[class_labels == 1].detach().cpu().numpy() - female_erased = erased_activations_proj[class_labels == 1].detach().cpu().numpy() - female_optimal = optimal_activations_proj[class_labels == 1].detach().cpu().numpy() - - for t in token_positions: - ax1.scatter(male_projected[:, t, 0], male_projected[:, t, 1], color=male_cmap(hues[t])) - ax2.scatter(male_erased[:, t, 0], male_erased[:, t, 1], color=male_cmap(hues[t])) - ax3.scatter(male_optimal[:, t, 0], male_optimal[:, t, 1], color=male_cmap(hues[t])) - - ax1.scatter(female_projected[:, t, 0], female_projected[:, t, 1], color=female_cmap(hues[t])) - ax2.scatter(female_erased[:, t, 0], female_erased[:, t, 1], color=female_cmap(hues[t])) - ax3.scatter(female_optimal[:, t, 0], female_optimal[:, t, 1], color=female_cmap(hues[t])) - - ax1.set_title("Original") - ax2.set_title("Erased") - ax3.set_title("Optimal") - - plt.savefig(f"{cfg.output_folder}/pca_img/pca_simplification_layer_{cfg.layer}_overall.png") - - plt.close() - -def fit_leace_eraser(cfg): - device = cfg.device - - activations, class_labels, sequence_lengths, skip_tokens = torch.load(cfg.activation_filename) - - B, L, D = activations.shape - - if cfg.last_position_only: - eraser = LeaceEraser.fit( - activations[torch.arange(sequence_lengths.shape[0]), sequence_lengths-1], - class_labels, - ) - else: - mask = ablation_mask_from_seq_lengths(sequence_lengths, L-skip_tokens) - - activations = activations[:, skip_tokens:][mask] - class_labels = class_labels.unsqueeze(1).expand(-1, L-skip_tokens)[mask] - - eraser = LeaceEraser.fit( - activations, - class_labels, - ) - - torch.save(eraser, f"{cfg.output_folder}/leace_eraser_layer_{cfg.layer}.pt") - -def fit_means_eraser(cfg): - device = cfg.device - - activations, class_labels, sequence_lengths, skip_tokens = torch.load(cfg.activation_filename) - - B, L, D = activations.shape - - if cfg.last_position_only: - projector = NullspaceProjector.class_means( - activations[torch.arange(sequence_lengths.shape[0]), sequence_lengths-1], - class_labels, - ) - else: - mask = ablation_mask_from_seq_lengths(sequence_lengths, L-skip_tokens) - - activations = activations[:, skip_tokens:][mask] - class_labels = class_labels.unsqueeze(1).expand(-1, L-skip_tokens)[mask] - - projector = NullspaceProjector.class_means( - activations, - class_labels, - ) - - torch.save(projector, f"{cfg.output_folder}/means_eraser_layer_{cfg.layer}.pt") - -def gender_prediction(class_tokens): - def go(logits, class_labels, sequence_lengths): - preds = logits[torch.arange(sequence_lengths.shape[0]), sequence_lengths-1] - preds = F.softmax(preds[:, [class_tokens[0], class_tokens[1]]], dim=-1) - 
labels_one_hot = F.one_hot(class_labels, num_classes=2).float() - return torch.einsum("bc,bc->b",preds,labels_one_hot) - return go - -def skip_tokens_distance(skip_tokens): - def go(unedited, edited): - return torch.linalg.norm(unedited[:, skip_tokens:] - edited[:, skip_tokens:], dim=(-1, -2)) - - return go - -def eval_features_classification_positive( - features: TensorType["n_features", "d_activation"], - activations: TensorType["batch", "sequence", "d_activation"], - class_labels: TensorType["batch"], - sequence_lengths: TensorType["batch"], - skip_tokens: int = 0, -): - N, D = features.shape - - selection_mask = ablation_mask_from_seq_lengths(sequence_lengths, activations.shape[1]-skip_tokens) - - expanded_class_labels = class_labels.unsqueeze(1).expand(-1, activations.shape[1]-skip_tokens)[selection_mask] - expanded_class_labels = expanded_class_labels.detach().cpu() - activations = activations[:, skip_tokens:][selection_mask] - - scores = [] - - for feature_idx in tqdm.tqdm(range(N)): - projected_activations = torch.einsum("bd,d->b", activations, features[feature_idx]).reshape(-1, 1) - projected_activations = projected_activations.detach().cpu() - # train a classifier on this - model = LogisticRegression() - model.fit(projected_activations, expanded_class_labels) - scores.append((feature_idx, model.score(projected_activations, expanded_class_labels))) - - return sorted(scores, key=lambda x: x[1], reverse=True) - -class LogisticRegression(torch.nn.Module): - def __init__(self, input_dim): - super(LogisticRegression, self).__init__() - self.linear = torch.nn.Linear(input_dim, 1) - - def forward(self, x): - outputs = torch.sigmoid(self.linear(x).reshape(-1)) - return outputs - - @staticmethod - def fit(x, y, iters=100): - torch.autograd.set_grad_enabled(True) - - model = LogisticRegression(x.shape[1]).to(x.device) - criterion = torch.nn.BCELoss() - optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - - for epoch in range(iters): - optimizer.zero_grad() - #idxs = np.random.choice(x.shape[0], size=batch_size, replace=False) - outputs = model(x) - #loss = -(outputs * y + (1-outputs) * (1-y)).mean() - loss = criterion(outputs, y) - loss.backward() - optimizer.step() - - torch.autograd.set_grad_enabled(False) - - return model - - def score(self, x, y): - outputs = self.forward(x) - return (outputs * y + (1-outputs) * (1-y)).mean().item() - -def eval_features_classification_negative( - features: TensorType["n_features", "d_activation"], - activations: TensorType["batch", "sequence", "d_activation"], - class_labels: TensorType["batch"], - sequence_lengths: TensorType["batch"], - skip_tokens: int = 0, -): - N, D = features.shape - - selection_mask = ablation_mask_from_seq_lengths(sequence_lengths, activations.shape[1]-skip_tokens) - - expanded_class_labels = class_labels.unsqueeze(1).expand(-1, activations.shape[1]-skip_tokens)[selection_mask] - expanded_class_labels = expanded_class_labels.to(dtype=torch.float32, device=activations.device) - activations = activations[:, skip_tokens:][selection_mask] - - scores = [] - - for feature_idx in tqdm.tqdm(range(N)): - projection = NullspaceProjector(features[feature_idx]) - projected_activations = projection.project(activations) - - # scale and shift activations to unit gaussian - - projected_activations = (projected_activations - projected_activations.mean(dim=0)) / projected_activations.std(dim=0) - - # linear classifier, not logistic regression - model = LogisticRegression.fit(projected_activations, expanded_class_labels) - 
scores.append((feature_idx, model.score(projected_activations, expanded_class_labels))) - - #print(scores[-1]) - - scores = sorted(scores, key=lambda x: x[1], reverse=False) - print(scores[:10]) - return scores - -def rank_dict_features_classifier(cfg): - activations, class_labels, sequence_lengths, skip_tokens = torch.load(cfg.activation_filename) - dicts = torch.load(cfg.dict_filename.format(layer=cfg.layer)) - - target_l1 = cfg.target_l1 - target_dict_size = cfg.dict_size - best_dist = None - best_dict = None - - for dict, hyperparams in dicts: - if hyperparams["dict_size"] == target_dict_size: - dist = abs(hyperparams["l1_alpha"] - target_l1) - if best_dist is None or dist < best_dist: - best_dist = dist - best_dict = dict - - best_dict.to_device(cfg.device) - features = best_dict.get_learned_dict() - - random_dict = RandomDict(features.shape[1], features.shape[0]) - random_dict.to_device(cfg.device) - random_features = random_dict.get_learned_dict() - - erasure_scores = eval_features_classification_negative(features, activations, class_labels, sequence_lengths, skip_tokens) - random_erasure_scores = eval_features_classification_negative(random_features, activations, class_labels, sequence_lengths, skip_tokens) - - erasure_scores = sorted(erasure_scores, key=lambda x: x[1]) - random_erasure_scores = sorted(random_erasure_scores, key=lambda x: x[1]) - - filtered_idxs = [idx for idx, _ in erasure_scores[:cfg.test_n_scores]] - random_filtered_idxs = [idx for idx, _ in random_erasure_scores[:cfg.test_n_scores]] - - del activations, class_labels, sequence_lengths, skip_tokens - - prompts, class_labels, class_tokens, sequence_lengths, skip_tokens = generate_gender_dataset( - cfg.model_name, - count_cutoff=cfg.count_cutoff, - sample_n=cfg.unique_names, - n_few_shot=cfg.k_shot, - prompts_per_name=1, # max name diversity - ) - - model = HookedTransformer.from_pretrained(cfg.model_name) - model.to(cfg.device) - model.eval() - model.requires_grad_(False) - - base_logits = model(prompts, return_type="logits") - - scores = [] - - for idx in tqdm.tqdm(filtered_idxs): - feature = features[idx].to(cfg.device) - - projector = NullspaceProjector(feature) - - def hook(tensor, class_labels, seq_lengths, hook=None): - if cfg.last_position_only: - return projector.project(tensor) - else: - B, L, D = tensor.shape - tensor[:, skip_tokens:] = projector.project(tensor[:, skip_tokens:].reshape(-1, D)).reshape(B, L-skip_tokens, D) - return tensor - - task_score, _ = eval_hook( - model, - hook, - prompts, - class_labels, - sequence_lengths, - (cfg.layer, "residual"), - task_score_func=gender_prediction(class_tokens), - batch_size=cfg.batch_size, - last_position_only=cfg.last_position_only, - device=cfg.device, - activation_dist_func=skip_tokens_distance(skip_tokens), - ) - - scores.append((idx, task_score)) - - scores = sorted(scores, key=lambda x: x[1]) - - random_scores = [] - - for idx in tqdm.tqdm(random_filtered_idxs): - feature = random_features[idx].to(cfg.device) - - projector = NullspaceProjector(feature) - - def hook(tensor, class_labels, seq_lengths, hook=None): - if cfg.last_position_only: - return projector.project(tensor) - else: - B, L, D = tensor.shape - tensor[:, skip_tokens:] = projector.project(tensor[:, skip_tokens:].reshape(-1, D)).reshape(B, L-skip_tokens, D) - return tensor - - task_score, _ = eval_hook( - model, - hook, - prompts, - class_labels, - sequence_lengths, - (cfg.layer, "residual"), - task_score_func=gender_prediction(class_tokens), - batch_size=cfg.batch_size, - 
last_position_only=cfg.last_position_only, - device=cfg.device, - activation_dist_func=skip_tokens_distance(skip_tokens), - ) - - random_scores.append((idx, task_score)) - - torch.save(scores, f"{cfg.output_folder}/dict_feature_scores_layer_{cfg.layer}.pt") - torch.save(best_dict, f"{cfg.output_folder}/best_dict_layer_{cfg.layer}.pt") - - torch.save(random_scores, f"{cfg.output_folder}/random_dict_feature_scores_layer_{cfg.layer}.pt") - torch.save(random_dict, f"{cfg.output_folder}/random_dict_layer_{cfg.layer}.pt") - -def rank_dict_features_expensive(cfg): - model_name = cfg.model_name - device = cfg.device - - activations, _, sequence_lengths, _ = torch.load(cfg.activation_filename) - - dicts = torch.load(cfg.dict_filename.format(layer=cfg.layer)) - - target_l1 = cfg.target_l1 - target_dict_size = cfg.dict_size - best_dist = None - best_dict = None - - for dict, hyperparams in dicts: - if hyperparams["dict_size"] == target_dict_size: - dist = abs(hyperparams["l1_alpha"] - target_l1) - if best_dist is None or dist < best_dist: - best_dist = dist - best_dict = dict - - best_dict.to_device(device) - - filtered_idxs = filter_activation_threshold( - best_dict, - activations, - sequence_lengths, - batch_size=32, - activation_proportion_threshold=cfg.feature_freq_threshold, - last_position_only=cfg.last_position_only, - ) - - del activations, sequence_lengths - - prompts, class_labels, class_tokens, sequence_lengths, skip_tokens = generate_gender_dataset( - model_name, - count_cutoff=cfg.count_cutoff, - sample_n=cfg.unique_names, - n_few_shot=cfg.k_shot, - prompts_per_name=1, # max name diversity - ) - - #prompts = prompts.to(device) - #class_labels = class_labels.to(device) - #sequence_lengths = sequence_lengths.to(device) - - model = HookedTransformer.from_pretrained(model_name) - model.to(device) - model.eval() - model.requires_grad_(False) - - scores = [] - - base_logits = model(prompts, return_type="logits") - - for idx in tqdm.tqdm(filtered_idxs): - feature = best_dict.get_learned_dict()[idx].to(device) - - projector = NullspaceProjector(feature) - - def hook(tensor, class_labels, seq_lengths, hook=None): - if cfg.last_position_only: - return projector.project(tensor) - else: - B, L, D = tensor.shape - tensor[:, skip_tokens:] = projector.project(tensor[:, skip_tokens:].reshape(-1, D)).reshape(B, L-skip_tokens, D) - return tensor - - task_score, _ = eval_hook( - model, - hook, - prompts, - class_labels, - sequence_lengths, - (cfg.layer, "residual"), - task_score_func=gender_prediction(class_tokens), - batch_size=cfg.batch_size, - last_position_only=cfg.last_position_only, - device=cfg.device, - activation_dist_func=skip_tokens_distance(skip_tokens), - ) - - scores.append((idx, task_score)) - - scores = sorted(scores, key=lambda x: x[1]) - - torch.save(scores, f"{cfg.output_folder}/dict_feature_scores_layer_{cfg.layer}.pt") - torch.save(best_dict, f"{cfg.output_folder}/best_dict_layer_{cfg.layer}.pt") - -def evaluate_interventions(cfg, dataset_fn, dataset_name): - torch.autograd.set_grad_enabled(False) - - device = cfg.device - model_name = cfg.model_name - layer = cfg.layer - - model = HookedTransformer.from_pretrained(model_name) - model.to(device) - model.eval() - model.requires_grad_(False) - - prompts, class_labels, class_tokens, sequence_lengths, skip_tokens = dataset_fn( - model_name, - count_cutoff=cfg.count_cutoff, - sample_n=cfg.unique_names, - prompts_per_name=cfg.prompts_per_name, - n_few_shot=cfg.k_shot, - randomise=False, - ) - - dict_scores = [] - - #base_logits = 
model(prompts, return_type="logits") - - sum_base_performance = 0.0 - - for i in tqdm.tqdm(range(0, prompts.shape[0], cfg.batch_size)): - j = min(i+cfg.batch_size, prompts.shape[0]) - batch = prompts[i:j].to(device) - batch_lengths = sequence_lengths[i:j].to(device) - batch_classes = class_labels[i:j].to(device) - - batch_logits = model(batch, return_type="logits") - - sum_base_performance += gender_prediction(class_tokens)(batch_logits, batch_classes, batch_lengths).sum().item() - - base_score = sum_base_performance / prompts.shape[0] - print(f"base score: {base_score}") - - best_dict = torch.load(f"{cfg.output_folder}/best_dict_layer_{cfg.layer}.pt") - best_dict.to_device(device) - best_dict_scores = torch.load(f"{cfg.output_folder}/dict_feature_scores_layer_{cfg.layer}.pt") - #best_dict_scores = best_dict_scores[:cfg.test_n_scores] - best_dict_scores = best_dict_scores[:1] - - for feat_idx, _ in best_dict_scores: - feature = best_dict.get_learned_dict()[feat_idx].to(device) - - projector = NullspaceProjector(feature) - - def hook(tensor, class_labels, seq_lengths, hook=None): - if cfg.last_position_only: - return projector.project(tensor) - else: - B, L, D = tensor.shape - tensor[:, skip_tokens:] = projector.project(tensor[:, skip_tokens:].reshape(-1, D)).reshape(B, L-skip_tokens, D) - return tensor - - task_score, activation_dist = eval_hook( - model, - hook, - prompts, - class_labels, - sequence_lengths, - (layer, "residual"), - task_score_func=gender_prediction(class_tokens), - batch_size=cfg.batch_size, - last_position_only=cfg.last_position_only, - device=cfg.device, - activation_dist_func=skip_tokens_distance(skip_tokens), - ) - - dict_scores.append((feat_idx, task_score, activation_dist)) - print(f"feat: {feat_idx}, score: {task_score}, dist: {activation_dist}") - - random_dict = torch.load(f"{cfg.output_folder}/random_dict_layer_{cfg.layer}.pt") - random_dict.to_device(device) - random_dict_scores = torch.load(f"{cfg.output_folder}/random_dict_feature_scores_layer_{cfg.layer}.pt") - #best_dict_scores = best_dict_scores[:cfg.test_n_scores] - random_dict_scores = random_dict_scores[:1] - - random_dict_scores_out = [] - - for feat_idx, _ in random_dict_scores: - feature = random_dict.get_learned_dict()[feat_idx].to(device) - - projector = NullspaceProjector(feature) - - def hook(tensor, class_labels, seq_lengths, hook=None): - if cfg.last_position_only: - return projector.project(tensor) - else: - B, L, D = tensor.shape - tensor[:, skip_tokens:] = projector.project(tensor[:, skip_tokens:].reshape(-1, D)).reshape(B, L-skip_tokens, D) - return tensor - - task_score, activation_dist = eval_hook( - model, - hook, - prompts, - class_labels, - sequence_lengths, - (layer, "residual"), - task_score_func=gender_prediction(class_tokens), - batch_size=cfg.batch_size, - last_position_only=cfg.last_position_only, - device=cfg.device, - activation_dist_func=skip_tokens_distance(skip_tokens), - ) - - random_dict_scores_out.append((feat_idx, task_score, activation_dist)) - print(f"random feat: {feat_idx}, score: {task_score}, dist: {activation_dist}") - - leace_eraser = torch.load(f"{cfg.output_folder}/leace_eraser_layer_{cfg.layer}.pt", map_location=device) - - def leace_hook(tensor, class_labels, seq_lengths, hook=None): - if cfg.last_position_only: - return projector.project(tensor) - else: - B, L, D = tensor.shape - tensor[:, skip_tokens:] = leace_eraser(tensor[:, skip_tokens:].reshape(-1, D)).reshape(B, L-skip_tokens, D) - return tensor - - leace_score, leace_dist = eval_hook( - model, - 
leace_hook, - prompts, - class_labels, - sequence_lengths, - (layer, "residual"), - task_score_func=gender_prediction(class_tokens), - batch_size=cfg.batch_size, - last_position_only=cfg.last_position_only, - device=cfg.device, - activation_dist_func=skip_tokens_distance(skip_tokens), - ) - - print(f"leace score: {leace_score}, dist: {leace_dist}") - - means_eraser = torch.load(f"{cfg.output_folder}/means_eraser_layer_{cfg.layer}.pt", map_location=device) - - def means_hook(tensor, class_labels, seq_lengths, hook=None): - if cfg.last_position_only: - return projector.project(tensor) - else: - B, L, D = tensor.shape - tensor[:, skip_tokens:] = means_eraser.project(tensor[:, skip_tokens:].reshape(-1, D)).reshape(B, L-skip_tokens, D) - return tensor - - means_score, means_dist = eval_hook( - model, - means_hook, - prompts, - class_labels, - sequence_lengths, - (layer, "residual"), - task_score_func=gender_prediction(class_tokens), - batch_size=cfg.batch_size, - last_position_only=cfg.last_position_only, - device=cfg.device, - activation_dist_func=skip_tokens_distance(skip_tokens), - ) - - print(f"means score: {means_score}, dist: {means_dist}") - - torch.save({ - "leace": (leace_score, leace_dist), - "means": (means_score, means_dist), - "dict": dict_scores, - "random": random_dict_scores_out, - "base": base_score}, - f"{cfg.output_folder}/eval_layer_{cfg.layer}_{dataset_name}.pt") - -def gender_prediction_everything(layer, device, done_flag=None): - from utils import dotdict - - cfg = dotdict({ - "model_name": "EleutherAI/pythia-410m-deduped", - "device": device, - "layer": layer, - "count_cutoff": 100000, - "k_shot": 3, - "unique_names": 100, - "prompts_per_name": 5, - "output_folder": "output_erasure_410m", - "activation_filename": f"activation_data_erasure_410m_l{layer}.pt", - "dict_filename": f"/mnt/ssd-cluster/pythia410/tied_residual_l{layer}_r4/_79/learned_dicts.pt", - "target_l1": 8e-4, - "dict_size": 4096, - "feature_freq_threshold": 0.05, - "test_n_scores": 4, - "estimation_sample_n": 16, - "last_position_only": False, - "batch_size": 32, - }) - - #layers = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22] - - os.makedirs(cfg.output_folder, exist_ok=True) - - generate_activation_data(cfg) - fit_leace_eraser(cfg) - fit_means_eraser(cfg) - #rank_dict_features_expensive(cfg) - rank_dict_features_classifier(cfg) - evaluate_interventions(cfg, generate_gender_dataset, "gender") - evaluate_interventions(cfg, generate_pronoun_dataset, "pronoun") - - if done_flag is not None: - done_flag.value = 1 - -def gender_prediction_everything_multigpu(): - layers = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22] - free_gpus = ["cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5"] - - import torch.multiprocessing as mp - import time - mp.set_start_method("spawn") - - processes = [] - - # while some gpus are still free - while True: - new_processes = [] - for process, gpu, done_flag in processes: - if done_flag.value == 1: - process.join() - free_gpus.append(gpu) - print(f"finished layer {process} on gpu {gpu}") - else: - new_processes.append((process, gpu, done_flag)) - - processes = new_processes - - if len(processes) == 0 and len(layers) == 0: - break - - if len(free_gpus) == 0: - time.sleep(0.1) - continue - - if len(layers) == 0: - time.sleep(0.1) - continue - - layer = layers.pop(0) - gpu = free_gpus.pop(0) - - print(f"starting layer {layer} on gpu {gpu}") - - done_flag = mp.Value("i", 0) - - process = mp.Process( - target=gender_prediction_everything, - args=(layer, gpu, done_flag), - ) - - 
process.start() - - processes.append((process, gpu, done_flag)) - -def winobias_prediction_everything(): - from utils import dotdict - - cfg = dotdict({ - "model_name": "EleutherAI/pythia-70m-deduped", - "device": "cuda:4", - "layer": None, - "count_cutoff": 10000, - "output_folder": "output_erasure_pca", - "activation_filename": "activation_data_erasure.pt", - "dict_filename": "/mnt/ssd-cluster/bigrun0308/tied_residual_l{layer}_r4/_9/learned_dicts.pt", - "target_l1": 8e-4, - "dict_size": 2048, - "feature_freq_threshold": 0.05, - "test_n_scores": 10, - }) - - layers = [0, 1, 2, 3, 4, 5] - - os.makedirs(cfg.output_folder, exist_ok=True) - - for layer in layers: - cfg.layer = layer - eval_on_winobias(cfg) - -if __name__ == "__main__": - from sys import argv - - if argv[1] == "gender": - gender_prediction_everything_multigpu() - elif argv[1] == "winobias": - winobias_prediction_everything() - elif argv[1] == "pca": - from utils import dotdict - - cfg = dotdict({ - "model_name": "EleutherAI/pythia-70m-deduped", - "device": "cuda:4", - "layer": 5, - "count_cutoff": 10000, - "output_folder": "output_erasure_pca", - "activation_filename": "activation_data_erasure.pt", - "dict_filename": "/mnt/ssd-cluster/bigrun0308/tied_residual_l{layer}_r4/_9/learned_dicts.pt", - "target_l1": 8e-4, - "dict_size": 2048, - "feature_freq_threshold": 0.05, - "test_n_scores": 10, - "last_position_only": False, - }) - - generate_activation_data(cfg) - fit_leace_eraser(cfg) - gen_pca_simplification(cfg) \ No newline at end of file diff --git a/experiments/ablate_test.py b/experiments/ablate_test.py index 87edb1f..b16a514 100644 --- a/experiments/ablate_test.py +++ b/experiments/ablate_test.py @@ -33,11 +33,13 @@ from neuron_explainer.explanations.simulator import ExplanationNeuronSimulator REPLACEMENT_CHAR = "�" -MAX_CONCURRENT = None +MAX_CONCURRENT = None # type: None EXPLAINER_MODEL_NAME = "gpt-4" # "gpt-3.5-turbo" SIMULATOR_MODEL_NAME = "text-davinci-003" -_n_dict_components, _n_sentences, _fragment_len = None, None, None +_n_dict_components, _n_sentences, _fragment_len = ( + None, None, None +) # type: Tuple[None, None, None] def run_ablating_model_directions( diff --git a/experiments/bottleneck_test.py b/experiments/bottleneck_test.py index 04c15da..265c4eb 100644 --- a/experiments/bottleneck_test.py +++ b/experiments/bottleneck_test.py @@ -1,1099 +1,770 @@ -import sys -import os - -sys.path.append(os.path.join(os.path.dirname(__file__), "..")) - -import copy -from functools import partial -from itertools import product -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union - -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn.functional as F -import tqdm -from concept_erasure import LeaceEraser -from datasets import load_dataset -from einops import rearrange -from PIL import Image -from sklearn.cluster import KMeans -from sklearn.manifold import TSNE -from torch.utils.data import DataLoader -from torchtyping import TensorType -from transformer_lens import HookedTransformer - -import standard_metrics -from activation_dataset import setup_data -from autoencoders.learned_dict import LearnedDict -from autoencoders.pca import BatchedPCA -from test_datasets.gender import generate_gender_dataset -from test_datasets.ioi import generate_ioi_dataset - -_batch, _sequence, _n_dict_components, _d_activation, _vocab_size = ( - None, - None, - None, - None, - None, -) - -BASE_FOLDER = "~/sparse_coding_aidan" - - -def logits_under_ablation( - model: 
HookedTransformer, - lens: LearnedDict, - location: standard_metrics.Location, - ablated_directions: List[int], - tokens: TensorType["_batch", "_sequence"], - calc_fvu: bool = False, -) -> Tuple[TensorType["_batch", "_sequence"], Optional[TensorType["_batch", "_sequence"]]]: - fvu = None - - def intervention(tensor, hook=None): - B, L, D = tensor.shape - tensor = tensor.reshape(-1, D) - codes = lens.encode(tensor) - ablation = torch.einsum( - "be,ed->bd", - codes[:, ablated_directions], - lens.get_learned_dict()[ablated_directions], - ) - ablated = tensor - ablation - - if calc_fvu: - nonlocal fvu - fvu = (ablation**2).sum() / (tensor**2).sum() - - return ablated.reshape(B, L, D) - - logits = model.run_with_hooks( - tokens, - return_type="logits", - fwd_hooks=[ - ( - standard_metrics.get_model_tensor_name(location), - intervention, - ) - ], - ) - - return logits, fvu - - -def logits_under_reconstruction( - model: HookedTransformer, - lens: LearnedDict, - location: standard_metrics.Location, - ablated_directions: List[int], - tokens: TensorType["_batch", "_sequence"], - calc_fvu: bool = False, - resample: Optional[TensorType["_batch", "_sequence", "_n_dict_components"]] = None, -) -> Tuple[TensorType["_batch", "_sequence"], Optional[TensorType["_batch", "_sequence"]]]: - fvu = None - - def intervention(tensor, hook=None): - B, L, D = tensor.shape - code = lens.encode(tensor.reshape(-1, D)) - if resample is not None: - code[:, ablated_directions] = resample.reshape(-1, code.shape[-1])[:, ablated_directions] - else: - code[:, ablated_directions] = 0.0 - reconstruction = lens.decode(code).reshape(B, L, D) - - if calc_fvu: - nonlocal fvu - residuals = reconstruction - tensor - fvu = (residuals**2).sum() / (tensor**2).sum() - - return reconstruction - - logits = model.run_with_hooks( - tokens, - return_type="logits", - fwd_hooks=[ - ( - standard_metrics.get_model_tensor_name(location), - intervention, - ) - ], - ) - - return logits, fvu - - -def bottleneck_test( - model: HookedTransformer, - lens: LearnedDict, - location: standard_metrics.Location, - tokens: TensorType["_batch", "_sequence"], - logit_metric: Callable[[TensorType["_batch", "_sequence"]], TensorType["_batch"]], - calc_fvu: bool = False, - ablation_type: Literal["ablation", "reconstruction"] = "ablation", - feature_sample_size: Optional[int] = None, -) -> List[Tuple[int, Optional[float], float]]: - # iteratively ablate away the least useful directions in the bottleneck - - remaining_directions = list(range(lens.n_dict_components())) - - results = [] - ablated_directions: List[int] = [] - - for i in tqdm.tqdm(range(lens.n_dict_components())): - min_score = None - min_direction = -1 - min_fvu = None - - features_to_test: List[int] = [] - - if feature_sample_size is not None: - if feature_sample_size < len(remaining_directions): - features_to_test = list(np.random.choice(remaining_directions, size=feature_sample_size, replace=False)) - else: - features_to_test = remaining_directions - else: - features_to_test = remaining_directions - - for direction in features_to_test: - if ablation_type == "ablation": - logits, fvu = logits_under_ablation( - model, - lens, - location, - ablated_directions + [direction], - tokens, - calc_fvu=calc_fvu, - ) - elif ablation_type == "reconstruction": - logits, fvu = logits_under_reconstruction( - model, - lens, - location, - ablated_directions + [direction], - tokens, - calc_fvu=calc_fvu, - ) - else: - raise ValueError(f"Unknown ablation type '{ablation_type}'") - - score = logit_metric(logits).item() - 
- if calc_fvu: - assert fvu is not None - fvu_item: float = fvu.item() - - if min_score is None or score < min_score: - min_score = score - min_direction = direction - min_fvu = fvu_item - - assert min_direction != -1 - assert min_score is not None - results.append((min_direction, min_fvu, min_score)) - ablated_directions.append(min_direction) - remaining_directions.remove(min_direction) - - return results - - -def resample_ablation_hook( - lens: LearnedDict, - features_to_ablate: List[int], - corrupted_codes: Optional[TensorType["_batch", "_sequence", "_n_dict_components"]] = None, - ablation_type: Literal["ablation", "reconstruction"] = "ablation", - handicap: Optional[TensorType["_batch", "_sequence", "_d_activation"]] = None, - ablation_rank: Literal["full", "partial"] = "partial", - ablation_mask: Optional[TensorType["_batch", "_sequence"]] = None, -): - if corrupted_codes is None: - corrupted_codes_ = None - else: - corrupted_codes_ = corrupted_codes.reshape(-1, corrupted_codes.shape[-1]) - - activation_dict = {"output": None} - - def reconstruction_intervention(tensor, hook=None): - nonlocal activation_dict - B, L, D = tensor.shape - code = lens.encode(tensor.reshape(-1, D)) - - if corrupted_codes_ is None: - code[:, features_to_ablate] = 0.0 - else: - code[:, features_to_ablate] = corrupted_codes_[:, features_to_ablate] - - reconstr = lens.decode(code).reshape(tensor.shape) - - if handicap is not None: - output = reconstr + handicap - else: - output = reconstr - - if ablation_mask is not None: - output[~ablation_mask] = tensor[~ablation_mask] - - activation_dict["output"] = output.clone() - return output - - def partial_ablation_intervention(tensor, hook=None): - nonlocal activation_dict - B, L, D = tensor.shape - code = lens.encode(tensor.reshape(-1, D)) - - ablation_code = torch.zeros_like(code) - - if corrupted_codes_ is None: - ablation_code[:, features_to_ablate] = -code[:, features_to_ablate] - else: - ablation_code[:, features_to_ablate] = corrupted_codes_[:, features_to_ablate] - code[:, features_to_ablate] - - ablation = lens.decode(ablation_code).reshape(tensor.shape) - - if handicap is not None: - output = tensor + ablation + handicap - else: - output = tensor + ablation - - if ablation_mask is not None: - output[~ablation_mask] = tensor[~ablation_mask] - - activation_dict["output"] = output.clone() - return output - - def full_ablation_intervention(tensor, hook=None): - nonlocal activation_dict - B, L, D = tensor.shape - code = torch.einsum("bd,nd->bn", tensor.reshape(-1, D), lens.get_learned_dict()) - - ablation_code = torch.zeros_like(code) - - if corrupted_codes_ is None: - ablation_code[:, features_to_ablate] = -code[:, features_to_ablate] - else: - ablation_code[:, features_to_ablate] = corrupted_codes_[:, features_to_ablate] - code[:, features_to_ablate] - - ablation = torch.einsum("bn,nd->bd", ablation_code, lens.get_learned_dict()).reshape(tensor.shape) - output = tensor + ablation - - if ablation_mask is not None: - output[~ablation_mask] = tensor[~ablation_mask] - - activation_dict["output"] = output.clone() - return tensor + ablation - - ablation_func = None - if ablation_type == "reconstruction": - ablation_func = reconstruction_intervention - elif ablation_type == "ablation" and ablation_rank == "partial": - ablation_func = partial_ablation_intervention - elif ablation_type == "ablation" and ablation_rank == "full": - ablation_func = full_ablation_intervention - else: - raise ValueError(f"Unknown ablation type '{ablation_type}' with rank '{ablation_rank}'") 
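# Mode summary for the three interventions above: 'reconstruction' returns the autoencoder's
# reconstruction with the ablated codes zeroed or resampled from the corrupted run; partial-rank
# 'ablation' keeps the original activation and only subtracts (or swaps) the decoded contribution
# of the ablated features; full-rank 'ablation' does the same but computes codes as plain dot
# products with the dictionary rows, bypassing the sparse encoder entirely.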
- - return ablation_func, activation_dict - - -def resample_ablation( - model: HookedTransformer, - lens: LearnedDict, - location: standard_metrics.Location, - clean_tokens: TensorType["_batch", "_sequence"], - features_to_ablate: List[int], - corrupted_codes: Optional[TensorType["_batch", "_sequence", "_n_dict_components"]] = None, - ablation_type: Literal["ablation", "reconstruction"] = "ablation", - handicap: Optional[TensorType["_batch", "_sequence", "_d_activation"]] = None, - ablation_rank: Literal["full", "partial"] = "partial", - ablation_mask: Optional[TensorType["_batch", "_sequence"]] = None, - **kwargs, -) -> Tuple[Any, TensorType["_batch", "_sequence", "_d_activation"]]: - ablation_func, activation_dict = resample_ablation_hook( - lens, - features_to_ablate, - corrupted_codes=corrupted_codes, - ablation_type=ablation_type, - handicap=handicap, - ablation_rank=ablation_rank, - ablation_mask=ablation_mask, - ) - - logits = model.run_with_hooks( - clean_tokens, - fwd_hooks=[ - ( - standard_metrics.get_model_tensor_name(location), - ablation_func, - ) - ], - **kwargs, - ) - - return logits, activation_dict["output"] - - -def activation_info( - model: HookedTransformer, - lens: LearnedDict, - location: standard_metrics.Location, - tokens: TensorType["_batch", "_sequence"], - ablation_type: Literal["ablation", "reconstruction"] = "ablation", - replacement_residuals: Optional[TensorType["_batch", "_sequence", "_d_activation"]] = None, -) -> Tuple[ - TensorType["_batch", "_sequence", "_d_activation"], - TensorType["_batch", "_sequence", "_n_dict_components"], - TensorType["_batch", "_sequence", "_d_activation"], - TensorType["_batch", "_sequence", "_vocab_size"], -]: - residuals = None - codes = None - activations = None - logits = None - - def intervention(tensor, hook=None): - nonlocal residuals, codes, activations - B, L, D = tensor.shape - activations = tensor.clone() - code = lens.encode(tensor.reshape(-1, D)) - codes = code.reshape(B, L, -1).clone() - output = lens.decode(code).reshape(tensor.shape) - residuals = tensor - output - - if ablation_type == "reconstruction": - return output - else: - if replacement_residuals is not None: - return output + replacement_residuals - else: - return tensor - - logits = model.run_with_hooks( - tokens, - fwd_hooks=[ - ( - standard_metrics.get_model_tensor_name(location), - intervention, - ) - ], - return_type="logits", - ) - - return residuals, codes, activations, logits - - -def scaled_distance_to_clean(clean_activation, corrupted_activation, activation): - total_dist = torch.norm(clean_activation - corrupted_activation, dim=(-1, -2)) - dist = torch.norm(clean_activation - activation, dim=(-1, -2)) - return dist / total_dist - - -def dot_difference_metric(clean_activation, corrupted_activation, activation): - dataset_diff_vector = corrupted_activation - clean_activation - diff_vector = activation - clean_activation - return torch.einsum("bld,bld->b", diff_vector, dataset_diff_vector) / torch.norm(dataset_diff_vector, dim=(-1, -2)) ** 2 - - -def acdc_test( - model: HookedTransformer, - lens: LearnedDict, - location: standard_metrics.Location, - clean_tokens: TensorType["_batch", "_sequence"], - corrupted_tokens: TensorType["_batch", "_sequence"], - logit_metric: Callable[ - [ - TensorType["_batch", "_sequence", "_vocab_size"], - TensorType["_batch", "_sequence", "_vocab_size"], - ], - float, - ], - thresholds: List[float] = [0.05], - base_logits: Optional[TensorType["_batch", "_sequence", "_vocab_size"]] = None, - ablation_type: 
Literal["ablation", "reconstruction"] = "reconstruction", - ablation_handicap: bool = False, - distance_metric: Callable[ - [ - TensorType["_batch", "_sequence", "_d_activation"], - TensorType["_batch", "_sequence", "_d_activation"], - TensorType["_batch", "_sequence", "_d_activation"], - ], - TensorType["_batch"], - ] = scaled_distance_to_clean, - initial_directions: Optional[List[int]] = None, -) -> List[Tuple[List[int], float, float]]: - if initial_directions is None: - initial_directions = list(range(lens.n_dict_components())) - - ablated_directions = [x for x in range(lens.n_dict_components()) if x not in initial_directions] - remaining_directions = list(initial_directions) - - corrupted_residuals, corrupted_codes, corrupted_activation, _ = activation_info( - model, lens, location, corrupted_tokens, ablation_type=ablation_type - ) - - clean_residuals, _, clean_activation, _ = activation_info( - model, - lens, - location, - clean_tokens, - ablation_type=ablation_type, - replacement_residuals=corrupted_residuals, - ) - - handicap = None - if ablation_handicap: - handicap = corrupted_residuals - clean_residuals - - reconstruction_logits, _ = resample_ablation( - model, - lens, - location, - clean_tokens, - corrupted_codes=corrupted_codes, - features_to_ablate=ablated_directions, - return_type="logits", - ablation_type=ablation_type, - handicap=handicap, - ) - - if base_logits is None: - base_logits = reconstruction_logits - - prev_divergence = logit_metric(reconstruction_logits, base_logits) - - scores = [] - - #print(ablated_directions, remaining_directions) - - for tau in sorted(thresholds): - if len(remaining_directions) > 0: - activation = None - - assert len(ablated_directions) + len(remaining_directions) == lens.n_dict_components() - - for i in tqdm.tqdm(remaining_directions.copy()): - logits, activation = resample_ablation( - model, - lens, - location, - clean_tokens, - corrupted_codes=corrupted_codes, - features_to_ablate=ablated_directions + [i], - return_type="logits", - ablation_type=ablation_type, - handicap=handicap, - ) - - divergence = logit_metric(logits, base_logits) - - if divergence - prev_divergence < tau: - prev_divergence = divergence - ablated_directions.append(i) - remaining_directions.remove(i) - - distance = distance_metric(clean_activation, corrupted_activation, activation) - scores.append((remaining_directions.copy(), prev_divergence, distance.mean().item())) - - print(f"graph size: {len(remaining_directions)} div: {prev_divergence} edit: {distance.mean().item()}") - - zero_logits, zero_activation = resample_ablation( - model, - lens, - location, - clean_tokens, - corrupted_codes=corrupted_codes, - features_to_ablate=list(range(lens.n_dict_components())), - return_type="logits", - ablation_type=ablation_type, - handicap=handicap, - ) - - zero_divergence = logit_metric(zero_logits, base_logits) - zero_distance = distance_metric(clean_activation, corrupted_activation, zero_activation) - - scores.append(([], zero_divergence, zero_distance.mean().item())) - - full_logits, full_activation = resample_ablation( - model, - lens, - location, - clean_tokens, - corrupted_codes=corrupted_codes, - features_to_ablate=[], - return_type="logits", - ablation_type=ablation_type, - handicap=handicap, - ) - - full_divergence = logit_metric(full_logits, base_logits) - full_distance = distance_metric(clean_activation, corrupted_activation, full_activation) - - scores.append((list(range(lens.n_dict_components())), full_divergence, full_distance.mean().item())) - - return scores - 
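# A minimal, self-contained sketch of the greedy thresholded pruning that acdc_test performs.
# A toy divergence function stands in for the real KL-of-patched-logits metric, and
# `greedy_prune` / `divergence` are hypothetical names, not part of this repo. For each
# threshold tau, every remaining direction is tentatively dropped, and the drop is kept
# whenever the divergence increase stays below tau.

def greedy_prune(divergence_of, n_directions, thresholds):
    """divergence_of(kept: list) -> float; lower means closer to the clean run."""
    remaining = list(range(n_directions))
    prev_div = divergence_of(remaining)
    trace = []
    for tau in sorted(thresholds):
        for i in list(remaining):  # iterate over a copy so we can remove in place
            div = divergence_of([d for d in remaining if d != i])
            if div - prev_div < tau:  # direction i is cheap to drop at this threshold
                prev_div = div
                remaining.remove(i)
        trace.append((tau, list(remaining), prev_div))
    return trace

# Toy usage: only directions 0 and 1 carry signal, so they survive pruning.
def divergence(kept):
    signal = {0: 1.0, 1: 2.0}
    total = sum(v for k, v in signal.items() if k in kept)
    return (total - 3.0) ** 2

print(greedy_prune(divergence, n_directions=4, thresholds=[0.01, 0.1, 0.5]))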
-def diff_mean_activation_editing( - model: HookedTransformer, - location: standard_metrics.Location, - clean_tokens: TensorType["_batch", "_sequence"], - corrupted_tokens: TensorType["_batch", "_sequence"], - logit_metric: Callable[ - [ - TensorType["_batch", "_sequence", "_vocab_size"], - TensorType["_batch", "_sequence", "_vocab_size"], - ], - float, - ], - scale_range: Tuple[float, float] = (0.0, 1.0), - n_points: int = 10, - distance_metric: Callable[ - [ - TensorType["_batch", "_sequence", "_d_activation"], - TensorType["_batch", "_sequence", "_d_activation"], - TensorType["_batch", "_sequence", "_d_activation"], - ], - TensorType["_batch"], - ] = scaled_distance_to_clean, -) -> List[Tuple[float, float, float]]: - clean_logits, activation_cache = model.run_with_cache( - clean_tokens, - # names_filter=[standard_metrics.get_model_tensor_name(location)], - return_type="logits", - ) - clean_activation = activation_cache[standard_metrics.get_model_tensor_name(location)] - - _, activation_cache = model.run_with_cache( - corrupted_tokens, - # names_filter=[standard_metrics.get_model_tensor_name(location)], - return_type="logits", - ) - corrupted_activation = activation_cache[standard_metrics.get_model_tensor_name(location)] - - diff_means_vector = clean_activation.mean(dim=0) - corrupted_activation.mean(dim=0) - - scales = torch.linspace(*scale_range, n_points) - - scores = [] - for scale in tqdm.tqdm(scales): - activation = None - - def intervention(tensor, hook): - nonlocal activation - activation = tensor + scale * diff_means_vector - return activation - - logits = model.run_with_hooks( - corrupted_tokens, - fwd_hooks=[ - ( - standard_metrics.get_model_tensor_name(location), - intervention, - ) - ], - return_type="logits", - ) - - distance = distance_metric(clean_activation, corrupted_activation, activation).mean().item() - logit_score = logit_metric(logits, clean_logits) - scores.append((scale, distance, logit_score)) - - return scores - - -def ce_distance(clean_activation, activation): - return torch.linalg.norm(clean_activation - activation, dim=(-1, -2)) - - -def ablation_mask_from_seq_lengths( - seq_lengths: TensorType["_batch"], - max_length: int, -) -> TensorType["_batch", "_sequence"]: - B = seq_lengths.shape[0] - mask = torch.zeros((B, max_length), dtype=torch.bool) - for i in range(B): - mask[i, : seq_lengths[i]] = True - return mask - - -def concept_ablation( - model: HookedTransformer, - lens: LearnedDict, - location: standard_metrics.Location, - dataset: TensorType["_batch", "_sequence"], - scoring_function: Callable[[TensorType["_batch", "_sequence", "_vocab_size"]], float], - max_features_removed: int = 10, - distance_metric: Callable[ - [ - TensorType["_batch", "_sequence", "_d_activation"], - TensorType["_batch", "_sequence", "_d_activation"], - ], - TensorType["_batch"], - ] = ce_distance, - ablation_type: Literal["ablation", "reconstruction"] = "ablation", - ablation_rank: Literal["full", "partial"] = "partial", - sequence_lengths: Optional[TensorType["_batch"]] = None, - scale_by_magnitude: bool = False, - min_perf_decrease: float = 1.0, # to stop scale_by_magnitude from removing unimportant features -) -> List[Tuple[List[int], float, float]]: - """Try and add as much data back as possible while keeping a specific concept erased""" - if sequence_lengths is not None: - ablation_mask = ablation_mask_from_seq_lengths(sequence_lengths, dataset.shape[1]) - else: - ablation_mask = None - - ablated_directions: List[int] = [] - remaining_directions = 
list(range(lens.n_dict_components())) - - _, activation_cache = model.run_with_cache( - dataset, - names_filter=lambda name: name == standard_metrics.get_model_tensor_name(location), - return_type="logits", - ) - clean_activation = activation_cache[standard_metrics.get_model_tensor_name(location)] - - logits, activation = resample_ablation( - model, - lens, - location, - dataset, - corrupted_codes=None, - features_to_ablate=ablated_directions, - return_type="logits", - ablation_type=ablation_type, - ablation_rank=ablation_rank, - ablation_mask=ablation_mask, - ) - - scores = [] - - prev_score = float("inf") - for iteration in range(max_features_removed): - min_weighted_score = float("inf") - min_score = float("inf") - min_idx = None - min_activation_dist = float("inf") - - for i in tqdm.tqdm(range(lens.n_dict_components())): - logits, activation = resample_ablation( - model, - lens, - location, - dataset, - corrupted_codes=None, - features_to_ablate=ablated_directions + [i], - return_type="logits", - ablation_type=ablation_type, - ablation_rank=ablation_rank, - ablation_mask=ablation_mask, - ) - - score = scoring_function(logits) - - if scale_by_magnitude: - weighted_score = score * distance_metric(clean_activation, activation).mean().item() - else: - weighted_score = score - - if weighted_score < min_weighted_score and score < prev_score * min_perf_decrease: - min_weighted_score = weighted_score - min_score = score - min_idx = i - min_activation_dist = distance_metric(clean_activation, activation).mean().item() - - if min_idx is None: - print("Early stopped at iteration", iteration, "with score", prev_score) - break - - ablated_directions.append(min_idx) - remaining_directions.remove(min_idx) - prev_score = min_score - - print(f"Removed {min_idx} with score {min_score} and activation distance {min_activation_dist}") - - scores.append((ablated_directions.copy(), min_score, min_activation_dist)) - - return scores - - -def least_squares_erasure( - model: HookedTransformer, - location: standard_metrics.Location, - dataset: TensorType["_batch", "_sequence"], - classes: TensorType["_batch"], - scoring_function: Callable[[TensorType["_batch", "_sequence", "_vocab_size"]], float], - distance_metric: Callable[ - [ - TensorType["_batch", "_sequence", "_d_activation"], - TensorType["_batch", "_sequence", "_d_activation"], - ], - TensorType["_batch"], - ] = ce_distance, - sequence_lengths: Optional[TensorType["_batch"]] = None, -) -> Tuple[float, float, Any]: - if sequence_lengths is not None: - ablation_mask = ablation_mask_from_seq_lengths(sequence_lengths, dataset.shape[1]) - else: - ablation_mask = None - - _, activation_cache = model.run_with_cache( - dataset, - names_filter=lambda name: name == standard_metrics.get_model_tensor_name(location), - return_type="logits", - ) - - if ablation_mask is None: - ablation_mask = torch.ones_like(dataset, dtype=torch.bool) - - B, L, D = activation_cache[standard_metrics.get_model_tensor_name(location)].shape - - activations_flattened = activation_cache[standard_metrics.get_model_tensor_name(location)].reshape(B * L, D) - classes_flattened = classes.repeat_interleave(L) - mask_flattened = ablation_mask.reshape(B * L) - - activations = activations_flattened[mask_flattened] - classes_ = classes_flattened[mask_flattened] - - print(activations.shape, classes_.shape) - - eraser = LeaceEraser.fit(activations, classes_) # type: ignore - - distance = None - - def erasure(tensor, hook): - nonlocal distance - erased = eraser(tensor.reshape(B * L, D)).reshape(B, L, D) 
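# The step below restores activations at masked-out positions (padding beyond each
# sequence length) so the fitted eraser only rewrites real-token activations.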
- - if ablation_mask is not None: - erased[~ablation_mask] = tensor[~ablation_mask] - - distance = distance_metric(tensor, erased) - return erased - - logits = model.run_with_hooks( - dataset, - fwd_hooks=[ - ( - standard_metrics.get_model_tensor_name(location), - erasure, - ) - ], - return_type="logits", - ) - - score = scoring_function(logits) - assert distance is not None - - return score, distance.mean().item(), eraser - -def clean_logits_and_activations( - model: HookedTransformer, - location: standard_metrics.Location, - dataset: TensorType["_batch", "_sequence"], -): - base_logits, activation_cache = model.run_with_cache( - dataset, - names_filter=lambda name: name == standard_metrics.get_model_tensor_name(location), - return_type="logits", - ) - return base_logits, activation_cache[standard_metrics.get_model_tensor_name(location)] - -def filter_active_components( - activations: TensorType["_batch", "_sequence", "_d_activation"], - dict: LearnedDict, - threshold: float = 0.01, -): - B, L, D = activations.shape - activations_flattened = activations.reshape(B * L, D) - codes = dict.encode(activations_flattened) - print(codes.shape) - codes_nz = codes.count_nonzero(dim=0).float() / codes.shape[0] - # get indexes of components that are active in at least threshold of the dataset - active_components = torch.where(codes_nz > threshold)[0] - - components = list(active_components.cpu().numpy()) - print(f"Found {len(components)} active components") - - return components - -def new_bottleneck_test(cfg, layer, device, done_flag): - torch.autograd.set_grad_enabled(False) - - # Train PCA - - activation_dataset = torch.load(f"activation_data/layer_{layer}/0.pt") - activation_dataset = activation_dataset.to(device, dtype=torch.float32) - - pca = BatchedPCA(n_dims=activation_dataset.shape[-1], device=device) - batch_size = 2048 - - print("training pca") - for i in tqdm.trange(0, activation_dataset.shape[0], batch_size): - j = min(i + batch_size, activation_dataset.shape[0]) - pca.train_batch(activation_dataset[i:j]) - - pca_dict = pca.to_rotation_dict(activation_dataset.shape[-1]) - - pca_dict.to_device(device) - - del activation_dataset - - # Load model - - model = HookedTransformer.from_pretrained(cfg.model_name) - - model.to(device) - - ioi_clean_full, ioi_corrupted_full = generate_ioi_dataset(model.tokenizer, cfg.dataset_size, cfg.dataset_size) - ioi_clean = ioi_clean_full[:, :-1].to(device) - ioi_corrupted = ioi_corrupted_full[:, :-1].to(device) - ioi_correct = ioi_clean_full[:, -1].to(device) - ioi_incorrect = ioi_corrupted_full[:, -1].to(device) - - base_logits, base_activations = clean_logits_and_activations(model, (layer, "residual"), ioi_clean) - - def divergence_metric(new_logits, base_logits): - B, L, V = base_logits.shape - new_logprobs = F.log_softmax(new_logits[:, -1], dim=-1) - base_logprobs = F.log_softmax(base_logits[:, -1], dim=-1) - return F.kl_div(new_logprobs, base_logprobs, log_target=True, reduction="none").sum(dim=-1).mean().item() - - def logit_diff(new_logits, base_logits): - B, L, V = base_logits.shape - correct = new_logits[:, -1, ioi_correct] - incorrect = new_logits[:, -1, ioi_incorrect] - return -(correct - incorrect).mean().item() - - l1_alphas = [1e-3, 3e-4, 1e-4] - name_fmt = "learned_r{ratio}_{l1_alpha:.0e}" - best_dicts = {} - ratios = [4] - dict_sets = [ - ( - ratio, - torch.load(f"/mnt/ssd-cluster/pythia410/tied_residual_l{layer}_r{ratio}/_79/learned_dicts.pt"), - ) - for ratio in ratios - ] - - print("evaluating dicts") - for l1_alpha in l1_alphas: - for 
ratio, dicts in tqdm.tqdm(dict_sets): - best_approx_dist = float("inf") - best_dict = None - for dict, hyperparams in dicts: - dist = abs(hyperparams["l1_alpha"] - l1_alpha) - if dist < best_approx_dist: - best_approx_dist = dist - best_dict = (dict, hyperparams) - - best_dicts[name_fmt.format(ratio=ratio, l1_alpha=l1_alpha)] = best_dict - - print("found satisfying dicts:", list(best_dicts.keys())) - - dictionaries = {} - dictionaries["pca"] = (pca_dict, {"pca": True}) - for name, (dict, hyperparams) in best_dicts.items(): - dictionaries[name] = (dict, hyperparams) - - tau_values = list(np.logspace(cfg.tau_min, cfg.tau_max, cfg.tau_n)) - - scores: Dict[str, List] = {} - - for name, (dict, _) in dictionaries.items(): - dict.to_device(device) - print("evaluating", name) - - active_components = filter_active_components(base_activations, dict, threshold=cfg.activity_threshold) - scores[name] = acdc_test( - model, - dict, - (layer, "residual"), - ioi_clean, - ioi_corrupted, - divergence_metric, - thresholds=tau_values, - ablation_type="ablation", - base_logits=base_logits, - ablation_handicap=True, - distance_metric=scaled_distance_to_clean, - initial_directions=active_components, - ) - - torch.save(scores, f"{cfg.output_dir}/dict_scores_layer_{layer}.pt") - torch.save(dictionaries, f"{cfg.output_dir}/dictionaries_layer_{layer}.pt") - - done_flag.value = 1 - - # torch.save(diff_mean_scores, os.path.join(BASE_FOLDER, f"diff_mean_scores_layer_{layer}.pt")) - -def bottleneck_everything_multigpu(cfg): - layers = [4, 6, 8, 10, 12, 14, 16, 18] - free_gpus = ["cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7"] - - import torch.multiprocessing as mp - import time - mp.set_start_method("spawn") - - processes = [] - - # while some gpus are still free - while True: - new_processes = [] - for process, gpu, done_flag in processes: - if done_flag.value == 1: - process.join() - free_gpus.append(gpu) - print(f"finished layer {process} on gpu {gpu}") - else: - new_processes.append((process, gpu, done_flag)) - - processes = new_processes - - if len(processes) == 0 and len(layers) == 0: - break - - if len(free_gpus) == 0: - time.sleep(0.1) - continue - - if len(layers) == 0: - time.sleep(0.1) - continue - - layer = layers.pop(0) - gpu = free_gpus.pop(0) - - print(f"starting layer {layer} on gpu {gpu}") - - done_flag = mp.Value("i", 0) - - process = mp.Process( - target=new_bottleneck_test, - args=(cfg, layer, gpu, done_flag), - ) - - process.start() - - processes.append((process, gpu, done_flag)) - -def erasure_test(): - torch.autograd.set_grad_enabled(False) - - model_name = "EleutherAI/pythia-410m-deduped" - - model = HookedTransformer.from_pretrained(model_name) - - device = "cuda:1" - - model.to(device) - - prompts, classes, class_tokens, sequence_lengths = generate_gender_dataset(model_name, 100, 100, model.tokenizer.pad_token_id) - prompts = prompts.to(device) - classes = classes.to(device) - sequence_lengths = sequence_lengths.to(device) - print(sequence_lengths) - class_one_hot = F.one_hot(classes, num_classes=2).float() - - def gender_erasure_metric(predictions): - predictions = predictions[torch.arange(sequence_lengths.shape[0]), sequence_lengths - 1] - predictions = predictions[:, [class_tokens[0], class_tokens[1]]] - probs = F.softmax(predictions, dim=-1) - - pred = torch.einsum("bc,bc->b", probs, class_one_hot).mean() - return torch.abs(pred - 0.5).item() + 0.5 - - layer = 12 - activation_dataset = torch.load(os.path.join(BASE_FOLDER, 
f"activation_data/layer_{layer}/0.pt")) - activation_dataset = activation_dataset.to(device, dtype=torch.float32) - - max_fvu = 0.05 - best_dicts = {} - ratios = [4] - dict_sets = [ - ( - ratio, - #torch.load(f"/mnt/ssd-cluster/bigrun0308/tied_residual_l{layer}_r{ratio}/_9/learned_dicts.pt"), - torch.load(f"/mnt/ssd-cluster/pythia410/tied_residual_l{layer}_r{ratio}/_79/learned_dicts.pt"), - ) - for ratio in ratios - ] - - print("evaluating dicts") - for ratio, dicts in tqdm.tqdm(dict_sets): - for dict, hyperparams in dicts: - dict.to_device(device) - sample_idxs = np.random.choice(activation_dataset.shape[0], size=50000, replace=False) - fvu = standard_metrics.fraction_variance_unexplained(dict, activation_dataset[sample_idxs]).item() - if fvu < max_fvu: - if hyperparams["dict_size"] not in best_dicts: - best_dicts[hyperparams["dict_size"]] = (fvu, hyperparams, dict) - else: - if fvu > best_dicts[hyperparams["dict_size"]][0]: - best_dicts[hyperparams["dict_size"]] = (fvu, hyperparams, dict) - - del activation_dataset - - dictionaries = {} - for dict_size, (_, hyperparams, dict) in best_dicts.items(): - dictionaries[f"learned_{dict_size}"] = (dict, hyperparams) - - leace_score, leace_edit, leace_eraser = least_squares_erasure( - model, - (layer, "residual"), - prompts, - classes, - scoring_function=gender_erasure_metric, - distance_metric=ce_distance, - sequence_lengths=sequence_lengths, - ) - - print(f"LEACE score: {leace_score:.3e}, LEACE edit: {leace_edit:.2f}") - - base_logits = model(prompts, return_type="logits") - base_score = gender_erasure_metric(base_logits) - - print(f"base score: {base_score:.3e}") - - torch.save( - (leace_score, leace_edit, base_score), - os.path.join(BASE_FOLDER, f"leace_scores_layer_{layer}.pt"), - ) - torch.save(leace_eraser, os.path.join(BASE_FOLDER, f"leace_eraser_layer_{layer}.pt")) - - scores = {} - tau_values = np.logspace(-4, 0, 10) - for name, (dict, _) in dictionaries.items(): - scores[name] = concept_ablation( - model, - dict, - (layer, "residual"), - prompts, - scoring_function=gender_erasure_metric, - scale_by_magnitude=False, - sequence_lengths=sequence_lengths, - ablation_rank="full", - ) - - torch.save(scores, os.path.join(BASE_FOLDER, f"erasure_scores_layer_{layer}.pt")) - torch.save( - dictionaries, - os.path.join(BASE_FOLDER, f"erasure_dictionaries_layer_{layer}.pt"), - ) - - -if __name__ == "__main__": - from utils import dotdict - - cfg = dotdict({ - "model_name": "EleutherAI/pythia-410m-deduped", - "dataset_size": 50, - "tau_min": -6, - "tau_max": -1, - "tau_n": 30, - "activity_threshold": 0.01, - "output_dir": "bottleneck_410m" - }) - - os.makedirs(cfg.output_dir, exist_ok=True) - +import sys +import os + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + +import copy +from functools import partial +from itertools import product +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +from concept_erasure import LeaceEraser +from datasets import load_dataset +from einops import rearrange +from PIL import Image +from sklearn.cluster import KMeans +from sklearn.manifold import TSNE +from torch.utils.data import DataLoader +from torchtyping import TensorType +from transformer_lens import HookedTransformer + +import standard_metrics +from activation_dataset import setup_data +from autoencoders.learned_dict import LearnedDict +from autoencoders.pca import BatchedPCA +from 
test_datasets.gender import generate_gender_dataset +from test_datasets.ioi import generate_ioi_dataset + +_batch, _sequence, _n_dict_components, _d_activation, _vocab_size = ( + None, + None, + None, + None, + None, +) # type: Tuple[None, None, None, None, None] + +BASE_FOLDER = "~/sparse_coding_aidan" + + +def logits_under_ablation( + model: HookedTransformer, + lens: LearnedDict, + location: standard_metrics.Location, + ablated_directions: List[int], + tokens: TensorType["_batch", "_sequence"], + calc_fvu: bool = False, +) -> Tuple[TensorType["_batch", "_sequence"], Optional[TensorType["_batch", "_sequence"]]]: + fvu = None + + def intervention(tensor, hook=None): + B, L, D = tensor.shape + tensor = tensor.reshape(-1, D) + codes = lens.encode(tensor) + ablation = torch.einsum( + "be,ed->bd", + codes[:, ablated_directions], + lens.get_learned_dict()[ablated_directions], + ) + ablated = tensor - ablation + + if calc_fvu: + nonlocal fvu + fvu = (ablation**2).sum() / (tensor**2).sum() + + return ablated.reshape(B, L, D) + + logits = model.run_with_hooks( + tokens, + return_type="logits", + fwd_hooks=[ + ( + standard_metrics.get_model_tensor_name(location), + intervention, + ) + ], + ) + + return logits, fvu + + +def logits_under_reconstruction( + model: HookedTransformer, + lens: LearnedDict, + location: standard_metrics.Location, + ablated_directions: List[int], + tokens: TensorType["_batch", "_sequence"], + calc_fvu: bool = False, + resample: Optional[TensorType["_batch", "_sequence", "_n_dict_components"]] = None, +) -> Tuple[TensorType["_batch", "_sequence"], Optional[TensorType["_batch", "_sequence"]]]: + fvu = None + + def intervention(tensor, hook=None): + B, L, D = tensor.shape + code = lens.encode(tensor.reshape(-1, D)) + if resample is not None: + code[:, ablated_directions] = resample.reshape(-1, code.shape[-1])[:, ablated_directions] + else: + code[:, ablated_directions] = 0.0 + reconstruction = lens.decode(code).reshape(B, L, D) + + if calc_fvu: + nonlocal fvu + residuals = reconstruction - tensor + fvu = (residuals**2).sum() / (tensor**2).sum() + + return reconstruction + + logits = model.run_with_hooks( + tokens, + return_type="logits", + fwd_hooks=[ + ( + standard_metrics.get_model_tensor_name(location), + intervention, + ) + ], + ) + + return logits, fvu + + +def bottleneck_test( + model: HookedTransformer, + lens: LearnedDict, + location: standard_metrics.Location, + tokens: TensorType["_batch", "_sequence"], + logit_metric: Callable[[TensorType["_batch", "_sequence"]], TensorType["_batch"]], + calc_fvu: bool = False, + ablation_type: Literal["ablation", "reconstruction"] = "ablation", + feature_sample_size: Optional[int] = None, +) -> List[Tuple[int, Optional[float], float]]: + # iteratively ablate away the least useful directions in the bottleneck + + remaining_directions = list(range(lens.n_dict_components())) + + results = [] + ablated_directions: List[int] = [] + + for i in tqdm.tqdm(range(lens.n_dict_components())): + min_score = None + min_direction = -1 + min_fvu = None + + features_to_test: List[int] = [] + + if feature_sample_size is not None: + if feature_sample_size < len(remaining_directions): + features_to_test = list(np.random.choice(remaining_directions, size=feature_sample_size, replace=False)) + else: + features_to_test = remaining_directions + else: + features_to_test = remaining_directions + + for direction in features_to_test: + if ablation_type == "ablation": + logits, fvu = logits_under_ablation( + model, + lens, + location, + ablated_directions + 
[direction], + tokens, + calc_fvu=calc_fvu, + ) + elif ablation_type == "reconstruction": + logits, fvu = logits_under_reconstruction( + model, + lens, + location, + ablated_directions + [direction], + tokens, + calc_fvu=calc_fvu, + ) + else: + raise ValueError(f"Unknown ablation type '{ablation_type}'") + + score = logit_metric(logits).item() + + fvu_item: Optional[float] = None + if calc_fvu: + assert fvu is not None + fvu_item = fvu.item() + + if min_score is None or score < min_score: + min_score = score + min_direction = direction + min_fvu = fvu_item + + assert min_direction != -1 + assert min_score is not None + results.append((min_direction, min_fvu, min_score)) + ablated_directions.append(min_direction) + remaining_directions.remove(min_direction) + + return results + + +def resample_ablation_hook( + lens: LearnedDict, + features_to_ablate: List[int], + corrupted_codes: Optional[TensorType["_batch", "_sequence", "_n_dict_components"]] = None, + ablation_type: Literal["ablation", "reconstruction"] = "ablation", + handicap: Optional[TensorType["_batch", "_sequence", "_d_activation"]] = None, + ablation_rank: Literal["full", "partial"] = "partial", + ablation_mask: Optional[TensorType["_batch", "_sequence"]] = None, +): + if corrupted_codes is None: + corrupted_codes_ = None + else: + corrupted_codes_ = corrupted_codes.reshape(-1, corrupted_codes.shape[-1]) + + activation_dict = {"output": None} + + def reconstruction_intervention(tensor, hook=None): + nonlocal activation_dict + B, L, D = tensor.shape + code = lens.encode(tensor.reshape(-1, D)) + + if corrupted_codes_ is None: + code[:, features_to_ablate] = 0.0 + else: + code[:, features_to_ablate] = corrupted_codes_[:, features_to_ablate] + + reconstr = lens.decode(code).reshape(tensor.shape) + + if handicap is not None: + output = reconstr + handicap + else: + output = reconstr + + if ablation_mask is not None: + output[~ablation_mask] = tensor[~ablation_mask] + + activation_dict["output"] = output.clone() + return output + + def partial_ablation_intervention(tensor, hook=None): + nonlocal activation_dict + B, L, D = tensor.shape + code = lens.encode(tensor.reshape(-1, D)) + + ablation_code = torch.zeros_like(code) + + if corrupted_codes_ is None: + ablation_code[:, features_to_ablate] = -code[:, features_to_ablate] + else: + ablation_code[:, features_to_ablate] = corrupted_codes_[:, features_to_ablate] - code[:, features_to_ablate] + + ablation = lens.decode(ablation_code).reshape(tensor.shape) + + if handicap is not None: + output = tensor + ablation + handicap + else: + output = tensor + ablation + + if ablation_mask is not None: + output[~ablation_mask] = tensor[~ablation_mask] + + activation_dict["output"] = output.clone() + return output + + def full_ablation_intervention(tensor, hook=None): + nonlocal activation_dict + B, L, D = tensor.shape + code = torch.einsum("bd,nd->bn", tensor.reshape(-1, D), lens.get_learned_dict()) + + ablation_code = torch.zeros_like(code) + + if corrupted_codes_ is None: + ablation_code[:, features_to_ablate] = -code[:, features_to_ablate] + else: + ablation_code[:, features_to_ablate] = corrupted_codes_[:, features_to_ablate] - code[:, features_to_ablate] + + ablation = torch.einsum("bn,nd->bd", ablation_code, lens.get_learned_dict()).reshape(tensor.shape) + output = tensor + ablation + + if ablation_mask is not None: + output[~ablation_mask] = tensor[~ablation_mask] + + activation_dict["output"] = output.clone() + return output + + ablation_func = None + if ablation_type == "reconstruction": + ablation_func = 
reconstruction_intervention + elif ablation_type == "ablation" and ablation_rank == "partial": + ablation_func = partial_ablation_intervention + elif ablation_type == "ablation" and ablation_rank == "full": + ablation_func = full_ablation_intervention + else: + raise ValueError(f"Unknown ablation type '{ablation_type}' with rank '{ablation_rank}'") + + return ablation_func, activation_dict + + +def resample_ablation( + model: HookedTransformer, + lens: LearnedDict, + location: standard_metrics.Location, + clean_tokens: TensorType["_batch", "_sequence"], + features_to_ablate: List[int], + corrupted_codes: Optional[TensorType["_batch", "_sequence", "_n_dict_components"]] = None, + ablation_type: Literal["ablation", "reconstruction"] = "ablation", + handicap: Optional[TensorType["_batch", "_sequence", "_d_activation"]] = None, + ablation_rank: Literal["full", "partial"] = "partial", + ablation_mask: Optional[TensorType["_batch", "_sequence"]] = None, + **kwargs, +) -> Tuple[Any, TensorType["_batch", "_sequence", "_d_activation"]]: + ablation_func, activation_dict = resample_ablation_hook( + lens, + features_to_ablate, + corrupted_codes=corrupted_codes, + ablation_type=ablation_type, + handicap=handicap, + ablation_rank=ablation_rank, + ablation_mask=ablation_mask, + ) + + logits = model.run_with_hooks( + clean_tokens, + fwd_hooks=[ + ( + standard_metrics.get_model_tensor_name(location), + ablation_func, + ) + ], + **kwargs, + ) + + return logits, activation_dict["output"] + + +def activation_info( + model: HookedTransformer, + lens: LearnedDict, + location: standard_metrics.Location, + tokens: TensorType["_batch", "_sequence"], + ablation_type: Literal["ablation", "reconstruction"] = "ablation", + replacement_residuals: Optional[TensorType["_batch", "_sequence", "_d_activation"]] = None, +) -> Tuple[ + TensorType["_batch", "_sequence", "_d_activation"], + TensorType["_batch", "_sequence", "_n_dict_components"], + TensorType["_batch", "_sequence", "_d_activation"], + TensorType["_batch", "_sequence", "_vocab_size"], +]: + residuals = None + codes = None + activations = None + logits = None + + def intervention(tensor, hook=None): + nonlocal residuals, codes, activations + B, L, D = tensor.shape + activations = tensor.clone() + code = lens.encode(lens.center(tensor.reshape(-1, D))) + codes = code.reshape(B, L, -1).clone() + output = lens.uncenter(lens.decode(code)).reshape(tensor.shape) + residuals = tensor - output + + if ablation_type == "reconstruction": + return output + else: + if replacement_residuals is not None: + return output + replacement_residuals + else: + return tensor + + logits = model.run_with_hooks( + tokens, + fwd_hooks=[ + ( + standard_metrics.get_model_tensor_name(location), + intervention, + ) + ], + return_type="logits", + ) + + return residuals, codes, activations, logits + + +def scaled_distance_to_clean(clean_activation, corrupted_activation, activation): + total_dist = torch.norm(clean_activation - corrupted_activation, dim=(-1, -2)) + dist = torch.norm(clean_activation - activation, dim=(-1, -2)) + return dist / total_dist + + +def dot_difference_metric(clean_activation, corrupted_activation, activation): + dataset_diff_vector = corrupted_activation - clean_activation + diff_vector = activation - clean_activation + return torch.einsum("bld,bld->b", diff_vector, dataset_diff_vector) / torch.norm(dataset_diff_vector, dim=(-1, -2)) ** 2 + +def acdc_intervention( + model: HookedTransformer, + lens: LearnedDict, + location: standard_metrics.Location, + clean_codes: 
TensorType["_batch", "_sequence", "_n_dict_components"], + corrupted_tokens: TensorType["_batch", "_sequence"], + ablated_directions: List[int], +): + activation = None + def intervention(tensor, hook=None): + nonlocal activation + B, L, D = tensor.shape + _, _, N = clean_codes.shape + + centered_tensor = lens.center(tensor.reshape(-1, D)) + + corrupted_codes = lens.encode(centered_tensor).reshape(B, L, -1) + + corrupted_codes_to_ablate = torch.zeros_like(corrupted_codes) + corrupted_codes_to_ablate[:, :, ablated_directions] = corrupted_codes[:, :, ablated_directions] + + corrupted_difference = lens.decode(corrupted_codes_to_ablate.reshape(-1, N)) + + clean_codes_to_ablate = torch.zeros_like(clean_codes) + clean_codes_to_ablate[:, :, ablated_directions] = clean_codes[:, :, ablated_directions] + + clean_difference = lens.decode(clean_codes_to_ablate.reshape(-1, N)) + + edited_centered_tensor = centered_tensor - corrupted_difference + clean_difference + activation = lens.uncenter(edited_centered_tensor).reshape(B, L, D) + return activation.clone() + + logits = model.run_with_hooks( + corrupted_tokens, + fwd_hooks=[ + ( + standard_metrics.get_model_tensor_name(location), + intervention, + ) + ], + return_type="logits", + ) + + return logits, activation + +def acdc_test( + model: HookedTransformer, + lens: LearnedDict, + location: standard_metrics.Location, + clean_tokens: TensorType["_batch", "_sequence"], + corrupted_tokens: TensorType["_batch", "_sequence"], + logit_metric: Callable[ + [ + TensorType["_batch", "_sequence", "_vocab_size"], + TensorType["_batch", "_sequence", "_vocab_size"], + ], + float, + ], + thresholds: List[float] = [0.05], + base_logits: Optional[TensorType["_batch", "_sequence", "_vocab_size"]] = None, + ablation_handicap: bool = False, + distance_metric: Callable[ + [ + TensorType["_batch", "_sequence", "_d_activation"], + TensorType["_batch", "_sequence", "_d_activation"], + TensorType["_batch", "_sequence", "_d_activation"], + ], + TensorType["_batch"], + ] = scaled_distance_to_clean, + initial_directions: Optional[List[int]] = None, +) -> List[Tuple[List[int], float, float]]: + ablation_type: Literal["ablation"] = "ablation" + + if initial_directions is None: + initial_directions = list(range(lens.n_dict_components())) + + ablated_directions = [x for x in range(lens.n_dict_components()) if x not in initial_directions] + remaining_directions = list(initial_directions) + + _, corrupted_codes, corrupted_activation, _ = activation_info( + model, + lens, + location, + corrupted_tokens, + ablation_type=ablation_type + ) + + _, clean_codes, clean_activation, _ = activation_info( + model, + lens, + location, + clean_tokens, + ablation_type=ablation_type, + ) + + base_logits = model( + clean_tokens, + return_type="logits", + ) + + scores: List[Any] = [] + + zero_logits, zero_activation = acdc_intervention( + model, + lens, + location, + clean_codes, + corrupted_tokens, + remaining_directions, + ) + + zero_divergence = logit_metric(zero_logits, base_logits) + zero_distance = distance_metric(clean_activation, corrupted_activation, zero_activation) + + scores.append(([], zero_divergence, zero_distance.mean().item())) + + prev_divergence = zero_divergence + + #print(ablated_directions, remaining_directions) + + for tau in sorted(thresholds): + if len(remaining_directions) > 0: + activation = None + + assert len(ablated_directions) + len(remaining_directions) == lens.n_dict_components() + + for i in tqdm.tqdm(remaining_directions.copy()): + #logits, activation = 
resample_ablation( + # model, + # lens, + # location, + # clean_tokens, + # corrupted_codes=corrupted_codes, + # features_to_ablate=ablated_directions + [i], + # return_type="logits", + # ablation_type=ablation_type, + # handicap=handicap, + #) + logits, activation = acdc_intervention( + model, + lens, + location, + clean_codes, + corrupted_tokens, + [x for x in remaining_directions if x != i], + ) + + divergence = logit_metric(logits, base_logits) + + if divergence - prev_divergence < tau: + prev_divergence = divergence + ablated_directions.append(i) + remaining_directions.remove(i) + + distance = distance_metric(clean_activation, corrupted_activation, activation) + scores.append((remaining_directions.copy(), prev_divergence, distance.mean().item())) + + print(f"graph size: {len(remaining_directions)} div: {prev_divergence} edit: {distance.mean().item()}") + + full_logits, full_activation = acdc_intervention( + model, + lens, + location, + clean_codes, + corrupted_tokens, + [], + ) + + full_divergence = logit_metric(full_logits, base_logits) + full_distance = distance_metric(clean_activation, corrupted_activation, full_activation) + + scores.append((list(range(lens.n_dict_components())), full_divergence, full_distance.mean().item())) + + return scores + +def ce_distance(clean_activation, activation): + return torch.linalg.norm(clean_activation - activation, dim=(-1, -2)) + + +def ablation_mask_from_seq_lengths( + seq_lengths: TensorType["_batch"], + max_length: int, +) -> TensorType["_batch", "_sequence"]: + B = seq_lengths.shape[0] + mask = torch.zeros((B, max_length), dtype=torch.bool) + for i in range(B): + mask[i, : seq_lengths[i]] = True + return mask + +def clean_logits_and_activations( + model: HookedTransformer, + location: standard_metrics.Location, + dataset: TensorType["_batch", "_sequence"], +): + base_logits, activation_cache = model.run_with_cache( + dataset, + names_filter=lambda name: name == standard_metrics.get_model_tensor_name(location), + return_type="logits", + ) + return base_logits, activation_cache[standard_metrics.get_model_tensor_name(location)] + +def new_bottleneck_test(cfg, layer, device, done_flag): + torch.autograd.set_grad_enabled(False) + + # Train PCA + + activation_dataset = torch.load(f"activation_data/layer_{layer}/0.pt") + activation_dataset = activation_dataset.to(device, dtype=torch.float32) + + pca = BatchedPCA(n_dims=activation_dataset.shape[-1], device=device) + batch_size = 2048 + + print("training pca") + for i in tqdm.trange(0, activation_dataset.shape[0], batch_size): + j = min(i + batch_size, activation_dataset.shape[0]) + pca.train_batch(activation_dataset[i:j]) + + pca_dict = pca.to_rotation_dict(activation_dataset.shape[-1]) + #pca_dict_nz = + + pca_dict.to_device(device) + + del activation_dataset + + # Load model + + model = HookedTransformer.from_pretrained(cfg.model_name) + + model.to(device) + + ioi_clean_full, ioi_corrupted_full = generate_ioi_dataset(model.tokenizer, cfg.dataset_size, cfg.dataset_size) + ioi_clean = ioi_clean_full[:, :-1].to(device) + ioi_corrupted = ioi_corrupted_full[:, :-1].to(device) + ioi_correct = ioi_clean_full[:, -1].to(device) + ioi_incorrect = ioi_corrupted_full[:, -1].to(device) + + def divergence_metric(new_logits, base_logits): + B, L, V = base_logits.shape + new_logprobs = F.log_softmax(new_logits[:, -1], dim=-1) + base_logprobs = F.log_softmax(base_logits[:, -1], dim=-1) + return F.kl_div(new_logprobs, base_logprobs, log_target=True, reduction="batchmean").item() + + def logit_diff(new_logits, 
base_logits): + B, L, V = base_logits.shape + correct = new_logits[torch.arange(B, device=new_logits.device), -1, ioi_correct] + incorrect = new_logits[torch.arange(B, device=new_logits.device), -1, ioi_incorrect] + return -(correct - incorrect).mean().item() + + l1_alphas = [1e-3, 3e-4, 1e-4] + name_fmt = "learned_r{ratio}_{l1_alpha:.0e}" + best_dicts = {} + ratios = [4] + dict_sets = [ + ( + ratio, + #torch.load(f"/mnt/ssd-cluster/pythia410/tied_residual_l{layer}_r{ratio}/_79/learned_dicts.pt"), + torch.load(f"/mnt/ssd-cluster/bigrun0308/tied_residual_l{layer}_r{ratio}/_9/learned_dicts.pt") + ) + for ratio in ratios + ] + + print("evaluating dicts") + for l1_alpha in l1_alphas: + for ratio, dicts in tqdm.tqdm(dict_sets): + best_approx_dist = float("inf") + best_dict = None + for dict, hyperparams in dicts: + dist = abs(hyperparams["l1_alpha"] - l1_alpha) + if dist < best_approx_dist: + best_approx_dist = dist + best_dict = (dict, hyperparams) + + best_dicts[name_fmt.format(ratio=ratio, l1_alpha=l1_alpha)] = best_dict + + print("found satisfying dicts:", list(best_dicts.keys())) + + dictionaries = {} + dictionaries["pca"] = (pca_dict, {"pca": True}) + for name, (dict, hyperparams) in best_dicts.items(): + dictionaries[name] = (dict, hyperparams) + + tau_values = list(np.linspace(0, 10 ** cfg.tau_lin_max, cfg.tau_n_lin)[1:]) + list(np.logspace(cfg.tau_lin_max, cfg.tau_log_max, cfg.tau_n_log)[1:]) + + scores: Dict[str, List] = {} + + for name, (dict, _) in dictionaries.items(): + dict.to_device(device) + print("evaluating", name) + + #active_components = filter_active_components(base_activations, dict, threshold=cfg.activity_threshold) + scores[name] = acdc_test( + model, + dict, + (layer, "residual"), + ioi_clean, + ioi_corrupted, + logit_metric=divergence_metric, + thresholds=tau_values, + distance_metric=scaled_distance_to_clean, + # initial_directions=active_components, + ) + + torch.save(scores, f"{cfg.output_dir}/dict_scores_layer_{layer}.pt") + torch.save(dictionaries, f"{cfg.output_dir}/dictionaries_layer_{layer}.pt") + + done_flag.value = 1 + + # torch.save(diff_mean_scores, os.path.join(BASE_FOLDER, f"diff_mean_scores_layer_{layer}.pt")) + +def bottleneck_everything_multigpu(cfg): + layers = [3] + free_gpus = ["cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7"] + + import torch.multiprocessing as mp + import time + mp.set_start_method("spawn") + + processes = [] + + # schedule layers onto free gpus until all workers have finished + while True: + new_processes = [] + for process, gpu, done_flag in processes: + if done_flag.value == 1: + process.join() + free_gpus.append(gpu) + print(f"finished worker {process.pid} on gpu {gpu}") + else: + new_processes.append((process, gpu, done_flag)) + + processes = new_processes + + if len(processes) == 0 and len(layers) == 0: + break + + if len(free_gpus) == 0: + time.sleep(0.1) + continue + + if len(layers) == 0: + time.sleep(0.1) + continue + + layer = layers.pop(0) + gpu = free_gpus.pop(0) + + print(f"starting layer {layer} on gpu {gpu}") + + done_flag = mp.Value("i", 0) + + process = mp.Process( + target=new_bottleneck_test, + args=(cfg, layer, gpu, done_flag), + ) + + process.start() + + processes.append((process, gpu, done_flag)) + +if __name__ == "__main__": + from utils import dotdict + + cfg = dotdict({ + "model_name": "EleutherAI/pythia-70m-deduped", + "dataset_size": 25, + "tau_lin_max": -3.5, + "tau_log_max": -2.5, + "tau_n_lin": 2, + "tau_n_log": 10, + "output_dir": "bottleneck_70m" + }) + + os.makedirs(cfg.output_dir, exist_ok=True) + bottleneck_everything_multigpu(cfg) \ No newline at end of file diff --git 
a/experiments/deep_ae_testing.py b/experiments/deep_ae_testing.py index bf2bce0..dda1399 100644 --- a/experiments/deep_ae_testing.py +++ b/experiments/deep_ae_testing.py @@ -17,7 +17,7 @@ def forward(self, z, x, x_hat): in_vec = torch.cat([z, x, x_hat], dim=-1) h = F.gelu(self.f_in(in_vec)) z = z + self.f_out(h) - return F.softplus(z) + return z class SparseAutoencoder(nn.Module): def __init__(self, d_activation, n_hidden, d_latent): @@ -56,6 +56,41 @@ def losses(self, x, c, x_hat, l1_coef): l1_reg = l1_coef * torch.linalg.norm(c, ord=1, dim=-1).mean() return reconstr + l1_reg, reconstr, l1_reg +class NonlinearSparseAutoencoder(nn.Module): + def __init__(self, d_activation, d_hidden, d_latent): + super().__init__() + self.d_activation = d_activation + self.d_latent = d_latent + + self.encoder = nn.Sequential( + nn.Linear(d_activation, d_hidden), + nn.GELU(), + nn.Linear(d_hidden, d_hidden), + nn.GELU(), + nn.Linear(d_hidden, d_latent), + nn.Softplus(beta=100) + ) + + self.decoder = nn.Sequential( + nn.Linear(d_latent, d_hidden), + nn.GELU(), + nn.Linear(d_hidden, d_hidden), + nn.GELU(), + nn.Linear(d_hidden, d_activation) + ) + + def forward(self, x): + c = self.encoder(x) + # scale so that c has unit norm + c = c / torch.linalg.norm(c, ord=2, dim=-1, keepdim=True) + x_hat = self.decoder(c) + return x_hat, c + + def losses(self, x, c, x_hat, l1_coef): + reconstr = F.mse_loss(x, x_hat) + l1_reg = l1_coef * torch.linalg.norm(c, ord=1, dim=-1).mean() + return reconstr + l1_reg, reconstr, l1_reg + def l1_schedule(max_l1=1e-3, warmup_steps=1000): def schedule(step): if step < warmup_steps: @@ -69,7 +104,7 @@ def schedule(step): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = SparseAutoencoder(512, 4, 2048).to(device) + model = NonlinearSparseAutoencoder(512, 1024, 2048).to(device) secrets = json.load(open("secrets.json")) wandb.login(key=secrets["wandb_key"]) @@ -80,7 +115,7 @@ def schedule(step): entity="sparse_coding", ) - n_epochs = 10 + n_epochs = 100 batch_size = 256 optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5) @@ -104,8 +139,12 @@ def schedule(step): loss, reconstr, l1_reg = model.losses(x, c, x_hat, schedule(steps)) + sparsity = {} + fvu = (x - x_hat).pow(2).sum() / (x - x.mean()).pow(2).sum() - sparsity = (c > 1e-5).float().sum(dim=-1).mean() + sparsity["1e-5"] = (c > 1e-5).float().sum(dim=-1).mean() + sparsity["1e-6"] = (c > 1e-6).float().sum(dim=-1).mean() + sparsity["1e-7"] = (c > 1e-7).float().sum(dim=-1).mean() optimizer.zero_grad() loss.backward() @@ -116,7 +155,9 @@ def schedule(step): "reconstr": reconstr.item(), "l1_reg": l1_reg.item(), "fvu": fvu.item(), - "sparsity": sparsity.item() + "sparsity_1e-5": sparsity["1e-5"].item(), + "sparsity_1e-6": sparsity["1e-6"].item(), + "sparsity_1e-7": sparsity["1e-7"].item(), }) steps += 1 \ No newline at end of file diff --git a/generate_test_data.py b/generate_test_data.py index f4803ad..d69242c 100644 --- a/generate_test_data.py +++ b/generate_test_data.py @@ -14,35 +14,50 @@ parser.add_argument("--n_chunks", type=int, default=1) parser.add_argument("--skip_chunks", type=int, default=0) parser.add_argument("--chunk_size_gb", type=float, default=2) + parser.add_argument("--chunk_size_acts", type=int, default=8192 * 1024) parser.add_argument("--dataset", type=str, default="NeelNanda/pile-10k") parser.add_argument("--layers", type=int, nargs="+", default=[2]) + parser.add_argument("--locations", type=str, nargs="+", default=[]) parser.add_argument("--location", type=str, 
default="residual") parser.add_argument("--dataset_folder", type=str, default="activation_data") parser.add_argument("--layer_folder_fmt", type=str, default="layer_{layer}") - parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--device", type=str, default=None) + parser.add_argument("--use_tl", type=bool, default=True) args = parser.parse_args() - model = HookedTransformer.from_pretrained(args.model, device=args.device) - tokenizer = AutoTokenizer.from_pretrained(args.model) - - layer_folders = [args.layer_folder_fmt.format(layer=layer) for layer in args.layers] - - os.makedirs(args.dataset_folder, exist_ok=True) - for layer_folder in layer_folders: - os.makedirs(os.path.join(args.dataset_folder, layer_folder), exist_ok=True) - - dataset_folders = [os.path.join(args.dataset_folder, layer_folder) for layer_folder in layer_folders] - - activation_dataset.setup_data( - tokenizer, - model, - args.dataset, - dataset_folders, - layer=args.layers, - layer_loc=args.location, - n_chunks=args.n_chunks, - chunk_size_gb=args.chunk_size_gb, - device=args.device, - skip_chunks=args.skip_chunks, - ) \ No newline at end of file + if args.use_tl: + model = HookedTransformer.from_pretrained(args.model, device=args.device) + tokenizer = AutoTokenizer.from_pretrained(args.model) + + layer_folders = [args.layer_folder_fmt.format(layer=layer) for layer in args.layers] + + os.makedirs(args.dataset_folder, exist_ok=True) + for layer_folder in layer_folders: + os.makedirs(os.path.join(args.dataset_folder, layer_folder), exist_ok=True) + + dataset_folders = [os.path.join(args.dataset_folder, layer_folder) for layer_folder in layer_folders] + + activation_dataset.setup_data( + tokenizer, + model, + args.dataset, + dataset_folders, + layer=args.layers, + layer_loc=args.location, + n_chunks=args.n_chunks, + chunk_size_gb=args.chunk_size_gb, + device=args.device, + skip_chunks=args.skip_chunks, + ) + else: + activation_dataset.setup_data_new( + args.model, + args.dataset, + args.dataset_folder, + args.locations, + args.chunk_size_act, + args.n_chunks, + skip_chunks=args.skip_chunks, + device=args.device, + ) \ No newline at end of file diff --git a/interpret.py b/interpret.py index 920e986..567a643 100644 --- a/interpret.py +++ b/interpret.py @@ -57,7 +57,7 @@ N_SPLITS = 4 TOTAL_EXAMPLES = OPENAI_EXAMPLES_PER_SPLIT * N_SPLITS REPLACEMENT_CHAR = "�" -MAX_CONCURRENT = None +MAX_CONCURRENT: Any = None BASE_FOLDER = "/mnt/ssd-cluster/sweep_interp" diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..31ea411 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,4 @@ +[mypy] +ignore_missing_imports = True +mypy_path = . +var_annotated = False \ No newline at end of file diff --git a/optimizers/adam.py b/optimizers/adam.py deleted file mode 100644 index c6bde95..0000000 --- a/optimizers/adam.py +++ /dev/null @@ -1,68 +0,0 @@ -import optree -import torch - -# functional adam optimizer -# torchopt.adam doesn't implement -# parameter groups annoyingly, so had -# to reimplement adam - -# ref: -# ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION -# Diederik P. 
Kingma, Jimmy Lei Ba -# https://arxiv.org/pdf/1412.6980.pdf - - -class Adam: - def __init__(self, lr_groups, betas: tuple[float, float], eps: float): - self.lr_groups = lr_groups - self.betas = betas - self.eps = eps - - def init(self, params): - device = optree.tree_flatten(params)[0][0].device - - first_moments = optree.tree_map(torch.zeros_like, params) - second_moments = optree.tree_map(torch.zeros_like, params) - - return { - "first_moments": first_moments, - "second_moments": second_moments, - "step": torch.tensor(0, dtype=torch.long, device=device), - } - - def update(self, grads, states): - first_moments = states["first_moments"] - second_moments = states["second_moments"] - step = states["step"] + 1 - - updated_first_moments = optree.tree_map( - lambda m, g: self.betas[0] * m + (1.0 - self.betas[0]) * g, - first_moments, - grads, - ) - updated_second_moments = optree.tree_map( - lambda v, g: self.betas[1] * v + (1.0 - self.betas[1]) * (g * g), - second_moments, - grads, - ) - - bias_correction_0 = 1 - self.betas[0] ** step - bias_correction_1 = 1 - self.betas[1] ** step - - lr_scaling = torch.sqrt(bias_correction_1) / bias_correction_0 - - corrected_first_moments = optree.tree_map(lambda m: m / bias_correction_0, updated_first_moments) - corrected_second_moments = optree.tree_map(lambda v: v / bias_correction_1, updated_second_moments) - - updates = optree.tree_map( - lambda lr, m_hat, v_hat: -lr * m_hat / (torch.sqrt(v_hat) + self.eps), - self.lr_groups, - corrected_first_moments, - corrected_second_moments, - ) - - return updates, { - "first_moments": updated_first_moments, - "second_moments": updated_second_moments, - "step": step, - } diff --git a/optimizers/sgdm.py b/optimizers/sgdm.py deleted file mode 100644 index e5c1301..0000000 --- a/optimizers/sgdm.py +++ /dev/null @@ -1,28 +0,0 @@ -import optree -import torch - -# functional SGD + momentum optimizer - - -class SGDM: - def __init__(self, lr_groups, momentum: float): - self.lr_groups = lr_groups - self.momentum = momentum - - def init(self, params): - device = optree.tree_flatten(params)[0][0].device - - momentum = optree.tree_map(torch.zeros_like, params) - - return {"momentum": momentum} - - def update(self, grads, states): - momentum = states["momentum"] - - updated_momentum = optree.tree_map(lambda m, g: self.momentum * m + (1 - m) * g, momentum, grads) - - corrected_momentum = optree.tree_map(lambda m: m / (1 - self.momentum), updated_momentum) - - updates = optree.tree_map(lambda lr, m: -lr * m, self.lr_groups, corrected_momentum) - - return updates, {"momentum": updated_momentum} diff --git a/plotting/__init__.py b/plotting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plotting/bottleneck_plot.py b/plotting/bottleneck_plot.py index f6c4ca8..768b16e 100644 --- a/plotting/bottleneck_plot.py +++ b/plotting/bottleneck_plot.py @@ -1,89 +1,145 @@ -import sys - -sys.path.append("..") - -import os -import shutil - -from itertools import product -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import torch - -def plot_bottleneck_scores(layer): - base_folder = "sparse_coding_aidan" - scores = torch.load(f"/mnt/ssd-cluster/bottleneck_410m/dict_scores_layer_{layer}.pt") - - print(scores.keys()) - - #diff_mean_scores = torch.load("diff_mean_scores_layer_2.pt") - - fig, ax = plt.subplots() - - ax.grid(True, alpha=0.5, linestyle="dashed") - ax.set_axisbelow(True) - - #colors = ["red", "blue", "green", "orange", "purple", "brown", "pink", "gray", "olive", "cyan"] - #markers = ["x", 
"+", "*", "o", "v", "^", "<", ">", "s", "."] - #styles = ["solid", "dashed", "dashdot", "dotted"] - - xs, ys, keys = [], [], [] - for key, score in scores.items(): - graph, div, corruption = zip(*sorted(score, key=lambda x: len(x[0]))) - graph_size = [len(g) for g in graph] - print(graph_size, div, corruption) - xs.append(graph_size) - ys.append(div) - keys.append(key) - - for key, x, y in zip(keys, xs, ys): - style = "dashed" - if key == "pca": - label = "PCA" - color = "Reds" - style = "dotted" - c = 0.5 - elif key == "learned_r4_1e-03": - label = "Dict. alpha=1e-3" - color = "Blues" - c = 0.7 - elif key == "learned_r4_3e-04": - label = "Dict. alpha=3e-4" - color = "Blues" - c = 0.5 - elif key == "learned_r4_1e-04": - label = "Dict. alpha=1e-4" - color = "Blues" - c = 0.3 - - cmap = plt.get_cmap(color) - c = cmap(c) - - ax.plot(x, y, color=c, linestyle=style, label=label, alpha=1) - - ax.set_xlabel("No. Uncorrupted Features") - ax.set_ylabel("KL-Divergence From Base") - - ax.set_title(f"Precision-Complexity Tradeoff Curve - Layer {layer}") - - ax.set_xlim(0, 1024) - - ax.legend( - loc="upper right", - framealpha=1, - ) - - #shutil.rmtree("graphs", ignore_errors=True) - #os.mkdir("graphs", exist_ok=True) - - plt.savefig(f"graphs/bottleneck_scores_layer_{layer}.png") - - plt.close(fig) - del fig, ax - -if __name__ == "__main__": - layers = [4, 6, 8, 10, 12, 14, 16, 18] - - for layer in layers: - plot_bottleneck_scores(layer) \ No newline at end of file +import sys + +sys.path.append("..") + +import os +import shutil + +from itertools import product +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch + +stylemap = { + "dict_1.00e-03": ("Dict, α=1e-03, R=4", ("dashed", "D"), "Blues", 1.0), + "dict_3.00e-04": ("Dict, α=3e-04, R=4", ("dashed", "D"), "Blues", 0.8), + "dict_1.00e-04": ("Dict, α=1e-04, R=4", ("dashed", "D"), "Blues", 0.6), + "dict_0.00e+00": ("Dict, α=0, R=4", ("dashed", "D"), "Blues", 0.4), + "pca_rot": ("PCA", ("dashdot", "o"), "Reds", 0.3), + "pca_pve": ("Nonneg. PCA", ("dashdot", "o"), "Oranges", 0.7), +} + +def plot_bottleneck_scores(layer, title=False): + base_folder = "sparse_coding_aidan" + scores = torch.load(f"ioi_feat/feat_ident_results_l{layer}.pt") + + xs, ys, zs, keys = [], [], [], [] + for key, score in scores: + graph, div, corruption = zip(*sorted(score, key=lambda x: len(x[0]))) + graph_size = [len(g) for g in graph] + print(graph_size, div, corruption) + xs.append(graph_size) + zs.append(corruption) + ys.append(div) + keys.append(key) + + colors = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys"] + #markers = ["x", "+", "*", "o", "v", "^", "<", ">", "s", "."] + styles = ["solid", "dashed", "dashdot", "dotted"] + markers = ["x", "+", "*", "o"] + + #print(scores.keys()) + + #diff_mean_scores = torch.load("diff_mean_scores_layer_2.pt") + + fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True, figsize=(6.0 * 2 if not title else 4.8 * 2, 4.8)) + + if title: + fig.suptitle(f"Layer {layer}") + + ax1.grid(True, alpha=0.5, linestyle="dashed") + ax1.set_axisbelow(True) + + for key, x, y, (style, color) in zip(keys, xs, ys, product(styles, colors)): + #style = "dashed" + #if key == "pca": + # label = "PCA" + # color = "Reds" + # style = "dotted" + # c = 0.5 + #elif key == "learned_r4_1e-03": + # label = "Dict. alpha=1e-3" + # color = "Blues" + # c = 0.7 + #elif key == "learned_r4_3e-04": + # label = "Dict. alpha=3e-4" + # color = "Blues" + # c = 0.5 + #elif key == "learned_r4_1e-04": + # label = "Dict. 
alpha=1e-4" + # color = "Blues" + # c = 0.3 + + #c = 0.5 + #label = key + + label, (style, _), color, c = stylemap[key] + + cmap = plt.get_cmap(color) + c = cmap(c) + + ax1.plot(x, y, color=c, linestyle=style, label=label, alpha=1) + + ax1.set_xlabel("Number of Patched Features") + ax1.set_ylabel("KL Divergence From Target") + + #ax.set_xscale("log") + + #ax.set_yscale("log") + + ax1.set_xlim(0, 512) + + #ax.set_ylim(0, 1.2) + + #shutil.rmtree("graphs", ignore_errors=True) + #os.mkdir("graphs", exist_ok=True) + + ax2.grid(True, alpha=0.5, linestyle="dashed") + ax2.set_axisbelow(True) + + for key, x, y, (marker, color) in zip(keys, zs, ys, product(markers, colors)): + #c = 0.5 + #label = key + + label, (style, marker), color, c = stylemap[key] + + cmap = plt.get_cmap(color) + c = cmap(c) + + ax2.plot(x, y, color=c, linestyle=style, marker=matplotlib.markers.MarkerStyle(marker).scaled(0.75), alpha=0.5, label=label) + + ax2.set_xlabel("Mean Edit Magnitude") + #ax2.set_ylabel("KL Divergence From Target") + + #ax.set_xscale("log") + + #ax.set_yscale("log") + + #ax.set_xlim(0.6, 1.2) + + #ax.set_ylim(0, 1.2) + + ax2.legend( + #loc="upper left", + framealpha=1, + ) + + #shutil.rmtree("graphs", ignore_errors=True) + #os.mkdir("graphs", exist_ok=True) + + plt.tight_layout() + + plt.savefig(f"graphs_ioi/feature_ident_curve_l{layer}.png") + + plt.close(fig) + +if __name__ == "__main__": + TITLE = True + plot_bottleneck_scores(3, title=TITLE) + plot_bottleneck_scores(5, title=TITLE) + plot_bottleneck_scores(7, title=TITLE) + plot_bottleneck_scores(11, title=TITLE) + plot_bottleneck_scores(15, title=TITLE) + plot_bottleneck_scores(19, title=TITLE) + plot_bottleneck_scores(23, title=TITLE) \ No newline at end of file diff --git a/plotting/erasure_plot.py b/plotting/erasure_plot.py index 2318b85..a6ab71e 100644 --- a/plotting/erasure_plot.py +++ b/plotting/erasure_plot.py @@ -56,7 +56,6 @@ def plot_bottleneck_scores(): fig.savefig(os.path.join(graphs_folder, "bottleneck_scores.png")) - def plot_erasure_scores(): graphs_folder = os.path.join(BASE_FOLDER, "graphs") shutil.rmtree(graphs_folder, ignore_errors=True) @@ -127,8 +126,77 @@ def plot_erasure_scores(): plt.savefig(os.path.join(graphs_folder, "erasure_by_kl_div.png")) +def plot_leace_scores_across_depth(title="various settings", name="various_settings"): + layers = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 23] + + files = [ + torch.load(f"output_erasure_410m/general_{layer}_gender.pt") + for layer in layers + ] + + from matplotlib.legend_handler import HandlerTuple + + leace_scores = [files[l]["leace"][0] for l in range(len(layers))] + mean_scores = [files[l]["mean"][0] for l in range(len(layers))] + mean_affine_scores = [files[l]["mean_affine"][0] for l in range(len(layers))] + + base_score = files[0]["base"] + + fig, (ax2, ax1) = plt.subplots(2, 1, sharex=True) + + ax1.grid(True, alpha=0.5, linestyle="dashed") + ax1.set_axisbelow(True) + + ax1.plot(leace_scores, label="LEACE", marker="+") + ax1.plot(mean_scores, label="Mean", marker="x") + ax1.plot(mean_affine_scores, label="Mean, Affine", marker=".") + + ax1.set_xticks(range(len(layers))) + ax1.set_xticklabels(layers) + + ax1.axhline(y=base_score, color="red", linestyle="dashed", label="Base Perf.") + ax1.axhline(y=0.5, color="grey", linestyle="dashed", label="Majority") + + #ax1.set_xlabel("Layer") + ax1.set_ylabel("Model Prediction Ability") + + leace_edits = [files[l]["leace"][1] for l in range(len(layers))] + mean_edits = [files[l]["mean"][1] for l in range(len(layers))] + mean_affine_edits 
= [files[l]["mean_affine"][1] for l in range(len(layers))] + + ax2.grid(True, alpha=0.5, linestyle="dashed") + ax2.set_axisbelow(True) + + ax2.plot(leace_edits, label="LEACE", marker="+") + ax2.plot(mean_edits, label="Mean", marker="x") + ax2.plot(mean_affine_edits, label="Mean, Affine", marker=".") + + ax2.set_xticks(range(len(layers))) + ax2.set_xticklabels(layers) + + ax2.set_xlabel("Layer") + ax2.set_ylabel("Mean Edit Magnitude") + + #ax.set_yscale("log") + + ax2.set_ylim(bottom=0) + + handles, labels = ax1.get_legend_handles_labels() + ax2.legend( + handles, + labels, + loc='upper center', + facecolor="white", + framealpha=1, + ncol=2, + ) + + fig.suptitle(title) + + plt.savefig(f"graphs/erasure_across_depth_410m_{name}.png") + def plot_scores_across_depth(both_datasets=True): - layers = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 22] + layers = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22] files = [ torch.load(f"output_erasure_410m/eval_layer_{layer}_gender.pt") @@ -158,10 +226,10 @@ def do_dataset_plot(files, name, layers, title): ax1.grid(True, alpha=0.5, linestyle="dashed") ax1.set_axisbelow(True) - ax1.plot(leace_scores, label="LEACE", marker="+") - ax1.plot(mean_scores, label="Mean", marker="x") - ax1.plot(max_dict_scores, label="Dict. Feature", marker=".") - ax1.plot(max_rand_scores, label="Rand. Feature", marker=".") + #ax1.plot(leace_scores, label="LEACE", marker="+") + ax1.plot(mean_scores, label="Mean", marker="x", color="orange") + ax1.plot(max_dict_scores, label="Dict. Feature", marker=".", color="green") + ax1.plot(max_rand_scores, label="Rand. Feature", marker=".", color="red") ax1.set_xticks(range(len(layers))) ax1.set_xticklabels(layers) @@ -180,10 +248,10 @@ def do_dataset_plot(files, name, layers, title): ax2.grid(True, alpha=0.5, linestyle="dashed") ax2.set_axisbelow(True) - ax2.plot(leace_edits, label="LEACE", marker="+") - ax2.plot(mean_edits, label="Mean", marker="x") - ax2.plot(max_dict_edits, label="Dict Feature", marker=".") - ax2.plot(max_rand_edits, label="Rand. Feature", marker=".") + #ax2.plot(leace_edits, label="LEACE", marker="+") + ax2.plot(mean_edits, label="Mean", marker="x", color="orange") + ax2.plot(max_dict_edits, label="Dict Feature", marker=".", color="green") + ax2.plot(max_rand_edits, label="Rand. 
Feature", marker=".", color="red") ax2.set_xticks(range(len(layers))) ax2.set_xticklabels(layers) @@ -270,4 +338,5 @@ def plot_kl_div_across_depth(): if __name__ == "__main__": # plot_bottleneck_scores() plot_scores_across_depth() - plot_kl_div_across_depth() \ No newline at end of file + #plot_kl_div_across_depth() + #plot_leace_scores_across_depth() \ No newline at end of file diff --git a/plotting/fvu_sparsity_plot.py b/plotting/fvu_sparsity_plot.py index 8769c7f..ebf0b0b 100644 --- a/plotting/fvu_sparsity_plot.py +++ b/plotting/fvu_sparsity_plot.py @@ -27,7 +27,7 @@ def score_dict(score, label, hyperparams, learned_dict, dataset, ground_truth=No elif score == "l1": return hyperparams["l1_alpha"] elif score == "neg_log_l1": - return -np.log(hyperparams["l1_alpha"]) + return -np.log10(hyperparams["l1_alpha"]) elif score == "dict_size": return hyperparams["dict_size"] elif score == "top_fvu": @@ -244,7 +244,7 @@ def scores_logy(scores): return scores_ -def plot_scores(scores, settings, xlabel, ylabel, xrange, yrange, title, filename): +def plot_scores(scores, settings, xlabel, ylabel, xrange, yrange, title, filename, logx=False, logy=False, crange=None): fig = plt.figure() ax = fig.add_subplot(111) legend_lines = [] @@ -257,6 +257,10 @@ def plot_scores(scores, settings, xlabel, ylabel, xrange, yrange, title, filenam points = np.array([x, y]).T.reshape(-1, 1, 2) c = np.array(shade) + + if crange is not None: + c = (c - crange[0]) / (crange[1] - crange[0]) + segments = np.concatenate([points[:-1], points[1:]], axis=1) cs = 0.5 * (c[:-1] + c[1:]) @@ -296,6 +300,12 @@ def plot_scores(scores, settings, xlabel, ylabel, xrange, yrange, title, filenam ax.set_xlim(*xrange) ax.set_ylim(*yrange) + if logx: + ax.set_xscale("log") + + if logy: + ax.set_yscale("log") + ax.legend(legend_lines, legend_names) plt.savefig(f"{filename}.png") @@ -324,138 +334,42 @@ def get_limits(scores): shutil.rmtree("graphs", ignore_errors=True) os.makedirs("graphs", exist_ok=True) - colors = ["Purples", "Blues", "Greens", "Oranges"] + colors = ["Reds", "Blues", "Greens", "Purples", "Oranges", "Greys"] styles = ["x", "+", ".", "*"] - # styles = ["solid", "dashed", "dashdot", "dotted"] + styles = ["dashed", "dashdot", "dotted", "solid"] device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - # labels = ["Linear " + str(256*i) for i in range(16)] - - # ratio_names = [0, 1, 2, 4, 8, 16, 32] - # path_fmt = "/mnt/ssd-cluster/bigrun0308/output_hoagy_dense_sweep_tied_resid_l3_r{ratio}/_{chunk}/learned_dicts.pt" - for _ in range(1): layer = 2 - files = [ - #("Learned Center (zero init)", "outputs_sphere/learned_centered_10.pt"), - #("Learned Center (mean init)", "outputs_sphere/learned_centered_mean_init_10.pt"), - #("Sphered", "outputs_sphere/sphered_7.pt") + files = [ + (f"Dicts", f"output/learned_dicts_epoch_0.pt") ] - # layer = 3 - # files += [ - # # ("Linear L3", f"/mnt/ssd-cluster/bigrun0308/tied_residual_l{layer}_r0/_9/learned_dicts.pt"), - # # ("Linear L3", f"/mnt/ssd-cluster/bigrun0308/tied_residual_l{layer}_r1/_9/learned_dicts.pt"), - # ( - # "Linear L3", - # f"/mnt/ssd-cluster/bigrun0308/tied_residual_l{layer}_r2/_9/learned_dicts.pt", - # ), - # ( - # "Linear L3", - # f"/mnt/ssd-cluster/bigrun0308/tied_residual_l{layer}_r4/_9/learned_dicts.pt", - # ), - # ( - # "Linear L3", - # f"/mnt/ssd-cluster/bigrun0308/tied_residual_l{layer}_r8/_9/learned_dicts.pt", - # ), - # ( - # "Linear L3", - # f"/mnt/ssd-cluster/bigrun0308/tied_residual_l{layer}_r16/_9/learned_dicts.pt", - # ), - # # ("Linear L3", 
f"/mnt/ssd-cluster/bigrun0308/tied_residual_l{layer}_r32/_9/learned_dicts.pt"), - # ] - - title = "FVU Top-2 vs FVU Rest" - filename = "fvu_top" - - dataset_file = "activation_data_sphere/0.pt" - - scores = generate_scores(files, dataset_file, y_score="top_fvu", group_by="dict_size", device=device) + title = "Tied Residual" + filename = "fvu_sparsity" - print(scores) + dataset_file = "activation_data/layer_9/0.pt" - # for chunk in range(0, 10): - # file = "output_dict_ratio/_" + str(chunk) + "/learned_dicts.pt" - # areas = area_under_fvu_sparsity_curve([("Chunk " + str(chunk), file)], dataset_file=dataset_file) - # derivs = scores_derivative_(areas) - # scores = score_representedness([("Chunk " + str(chunk), file)], generator_file, device="cuda:7") - # scores["Chunk " + str(chunk)] = [(hyperparams["dict_size"], score, -np.log(hyperparams["l1_alpha"])) for hyperparams, score in scores.items()] + scores = generate_scores(files, dataset_file, c_score="neg_log_l1", group_by="dict_size", device=device) - # area_scores[f"Chunk {chunk}"] = [(dict_size, area, chunk / 28) for dict_size, area in areas] - # deriv_scores[f"Chunk {chunk}"] = [(dict_size, deriv, chunk / 28) for dict_size, deriv in derivs] + print(scores) settings = { - label: {"style": style, "color": color, "points": True} + label: {"style": style, "color": color, "points": False} for (style, color), label in zip(itertools.product(styles, colors), scores.keys()) } - # xlim, ylim = get_limits(scores) plot_scores( scores, settings, "sparsity", "top-fvu", - (0, 512), + (0, 768), (0, 1), "Threshold Activation Perf.", f"graphs/{filename}", - ) - - # settings = {f"Layer {layer}": {"style": "solid", "color": "Blues", "points": False}} - - # x_lim, y_lim = get_limits(area_scores) - # plot_scores(area_scores, settings, "dict_size", "area", x_lim, y_lim, title, f"graphs/{filename}.png") - - # title = "Derivative of Area Under FVU-Sparsity Curve" - - # x_lim, y_lim = get_limits(deriv_scores) - # plot_scores(deriv_scores, settings, "dict_size", "d(area)/d(dict_size)", x_lim, y_lim, title, f"graphs/{filename}_deriv.png") - - # file = "output_dict_ratio/_27/learned_dicts.pt" - # fuv_sparsity = generate_scores([("Linear", file)], dataset_file=dataset_file) - - # settings = { - # label: {"style": style, "color": color, "points": False} for (style, color), label in zip(itertools.product(styles, colors), fuv_sparsity.keys()) - # } - - # plot_scores(fuv_sparsity, settings, "sparsity", "fvu", (0, 512), (0, 1), "FVU vs Sparsity", f"graphs/fvu_sparsity_layer_4.png") - - # file_sets = [ - # (chunk, [("Linear", f"output_dict_ratio/_{chunk}/learned_dicts.pt")]) for chunk in range(8) - # ] - - # files = [ - # ("Linear", "/mnt/ssd-cluster/bigrun0308/output_hoagy_dense_sweep_tied_resid_l3_r0/_9/learned_dicts.pt"), - # ("Linear", "/mnt/ssd-cluster/bigrun0308/output_hoagy_dense_sweep_tied_resid_l3_r1/_9/learned_dicts.pt"), - # ("Linear", "/mnt/ssd-cluster/bigrun0308/output_hoagy_dense_sweep_tied_resid_l3_r2/_9/learned_dicts.pt"), - # ("Linear", "/mnt/ssd-cluster/bigrun0308/output_hoagy_dense_sweep_tied_resid_l3_r4/_9/learned_dicts.pt"), - # ("Linear", "/mnt/ssd-cluster/bigrun0308/output_hoagy_dense_sweep_tied_resid_l3_r8/_9/learned_dicts.pt"), - # ("Linear", "/mnt/ssd-cluster/bigrun0308/output_hoagy_dense_sweep_tied_resid_l3_r16/_9/learned_dicts.pt"), - # ("Linear", "/mnt/ssd-cluster/bigrun0308/output_hoagy_dense_sweep_tied_resid_l3_r32/_9/learned_dicts.pt"), - # ] - - # title = "Area Under FVU-Sparsity Curve" - # filename = "sparsity_fvu_area" - - # dataset_file = 
"activation_data/0.pt" - # generator_file = "output_synthetic_1024_100/generator.pt" - - # area_scores = {} - # for chunk, files in file_sets: - # areas = area_under_fvu_sparsity_curve(files, dataset_file=dataset_file) - # areas = scores_derivative_(areas) - # area_scores["Chunk " + str(chunk)] = [(dict_size, area, 0.5) for dict_size, area in areas] - # area_scores = {"Areas": [(dict_size, area, 0.5) for dict_size, area in areas]} - - # area_settings = { - # label: {"style": style, "color": color, "points": False} for (style, color), label in zip(itertools.product(styles, colors), area_scores.keys()) - # } - # xlim, ylim = get_limits(area_scores) - # plot_scores(area_scores, area_settings, "dict_size", "area under curve", xlim, ylim, title, f"graphs/sparsity_fvu_area.png") - - # scores = generate_scores(files, dataset_file=dataset_file) - # settings = { - # label: {"style": style, "color": color, "points": False} for (style, color), label in zip(itertools.product(styles, colors), scores.keys()) - # } - # plot_scores(scores, settings, "sparsity", "fvu", (0, 512), (0, 1), title, f"graphs/sparsity_fvu.png") + logx=False, + logy=False, + crange=(2, 4) + ) \ No newline at end of file diff --git a/replicate_toy_models.py b/replicate_toy_models.py index 5872feb..1adb80e 100644 --- a/replicate_toy_models.py +++ b/replicate_toy_models.py @@ -24,7 +24,7 @@ from utils import dotdict -n_ground_truth_components, activation_dim, dataset_size = None, None, None +n_ground_truth_components, activation_dim, dataset_size = None, None, None # type: Tuple[None, None, None] @dataclass diff --git a/sc_datasets/__init__.py b/sc_datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sc_datasets/random_dataset.py b/sc_datasets/random_dataset.py index a678e73..e51806c 100644 --- a/sc_datasets/random_dataset.py +++ b/sc_datasets/random_dataset.py @@ -10,7 +10,7 @@ None, None, None, -) # for tensortype vars +) # type: Tuple[None, None, None] @dataclass diff --git a/standard_metrics.py b/standard_metrics.py index aab829a..77c96cb 100644 --- a/standard_metrics.py +++ b/standard_metrics.py @@ -31,7 +31,7 @@ matplotlib.use('Agg') -_batch_size, _activation_size, _n_dict_components, _fragment_len, _n_sentences, _n_dicts = None, None, None, None, None, None +_batch_size, _activation_size, _n_dict_components, _fragment_len, _n_sentences, _n_dicts = None, None, None, None, None, None # type: Tuple[None, None, None, None, None, None] def run_with_model_intervention(transformer: HookedTransformer, model: LearnedDict, tensor_name, tokens, other_hooks=[], **kwargs): def intervention(tensor, hook=None): @@ -60,6 +60,8 @@ def get_model_tensor_name(location: Location) -> str: return f"blocks.{location[0]}.hook_resid_post" elif location[1] == "mlp": return f"blocks.{location[0]}.mlp.hook_post" + elif location[1] == "attn_concat": + return f"blocks.{location[0]}.attn.hook_z" else: raise ValueError(f"Location '{location[1]}' not supported") diff --git a/sweep_baselines.py b/sweep_baselines.py index ed6d950..5139ad5 100644 --- a/sweep_baselines.py +++ b/sweep_baselines.py @@ -13,6 +13,16 @@ from autoencoders.pca import BatchedPCA from standard_metrics import mean_nonzero_activations +def run_ica(chunk, output_file): + chunk = torch.load(chunk, map_location="cpu") + + activation_dim = chunk.shape[1] + + ica = ICAEncoder(activation_size=activation_dim) + print("Training ICA") + ica.train(chunk) + + torch.save(ica, output_file) def run_layer_baselines(args) -> None: layer: int @@ -161,6 +171,5 @@ def run_all() -> 
None: with mp.Pool(processes=len(layers)) as pool: pool.map(run_layer_baselines, args_list) - if __name__ == "__main__": - run_all() + run_ica("activation_data/layer_12/0.pt", "ica.pt") \ No newline at end of file diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index 53688eb..0000000 --- a/test/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# unittest init file \ No newline at end of file diff --git a/test/test_end_to_end.py b/test/test_end_to_end.py index 6f017a1..4b9f1ed 100644 --- a/test/test_end_to_end.py +++ b/test/test_end_to_end.py @@ -30,7 +30,6 @@ def single_setoff(cfg: dotdict): cfg.activation_width, dict_size, l1_alpha, - bias_decay=0.0, dtype=cfg.dtype, ) for l1_alpha in l1_values diff --git a/test_datasets/__init__.py b/test_datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_datasets/gender.py b/test_datasets/gender.py index d1828b3..39378e9 100644 --- a/test_datasets/gender.py +++ b/test_datasets/gender.py @@ -1,173 +1,173 @@ -import pickle -import random - -import torch -from transformers import AutoTokenizer - -COUNT_CUTOFF = 100000 - -def generate_gender_dataset( - tokenizer_name, - pad_token_id=0, - count_cutoff=COUNT_CUTOFF, - sample_n=10, - prompts_per_name=1, - n_few_shot=3, - randomise=True, -): - prompt = " My name is{name} and I am a{answer}." - prompt_q = " My name is{name} and I am a" - prompt_len = 10 - prompt_q_len = 8 - - codes_map = {"M": 0, "F": 1} - answer_map = {"M": " male", "F": " female"} - - skip_tokens = 2 * n_few_shot * prompt_len + 3 - - random.seed(42) - #if sample_n is not None: - # targets = {"M": sample_n, "F": sample_n} - # counts = {"M": 0, "F": 0} - - with open("gender_dataset.pkl", "rb") as f: - name_toks, entries = pickle.load(f) - - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - tokenizer.pad_token_id = pad_token_id - - entries = [entry for entry in entries if int(entry[2]) > count_cutoff] - - if randomise: - random.shuffle(entries) - - # split into male and female names - names_classes = {"M": [], "F": []} - - for (name, gender, _, _) in entries: - if len(names_classes[gender]) >= sample_n: - continue - - names_classes[gender].append(name) - - assert len(names_classes["M"]) == sample_n, len(names_classes["M"]) - assert len(names_classes["F"]) == sample_n, len(names_classes["F"]) - - names_classes_list = [(k, name) for k, names in names_classes.items() for name in names] - - prompts = [] - classes = [] - sequence_lengths = [] - - for gender, name in names_classes_list: - for _ in range(prompts_per_name): - names_few_shot = [] - for k, names in names_classes.items(): - names_ = [name_ for name_ in names if name_ != name] - names_few_shot += [(k, name_) for name_ in random.sample(names_, n_few_shot)] - - random.shuffle(names_few_shot) - - strprompt = [prompt.format(name=" "+name_,answer=answer_map[k]) for k, name_ in names_few_shot] - strprompt = "".join(strprompt) + prompt_q.format(name=" "+name) - - prompt_tokens = tokenizer(strprompt)["input_ids"] - prompts.append(prompt_tokens) - sequence_lengths.append(len(prompt_tokens)) - classes.append(codes_map[gender]) - - completion_tokens = {codes_map[g]: tokenizer(a)["input_ids"][0] for g, a in answer_map.items()} - - prompts = torch.tensor(prompts) - classes = torch.tensor(classes) - sequence_lengths = torch.tensor(sequence_lengths) - - return prompts, classes, completion_tokens, sequence_lengths, skip_tokens - - -def generate_pronoun_dataset( - tokenizer_name, - pad_token_id=0, - count_cutoff=COUNT_CUTOFF, - sample_n=10, - 
prompts_per_name=1, - n_few_shot=3, - randomise=True, -): - objects = ["cat", "car", "dog", "book", "pen", "mouse", "chair", "table", "phone", "computer"] - - prompt = "{name} went to the store, and{answer} bought a {object}." - prompt_q = "{name} went to the store, and" - prompt_len = 11 - prompt_q_len = 7 - - codes_map = {"M": 0, "F": 1} - answer_map = {"M": " he", "F": " she"} - - skip_tokens = 2 * n_few_shot * prompt_len - - random.seed(42) - #if sample_n is not None: - # targets = {"M": sample_n, "F": sample_n} - # counts = {"M": 0, "F": 0} - - with open("gender_dataset.pkl", "rb") as f: - name_toks, entries = pickle.load(f) - - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - tokenizer.pad_token_id = pad_token_id - - entries = [entry for entry in entries if int(entry[2]) > count_cutoff] - - if randomise: - random.shuffle(entries) - - # split into male and female names - names_classes = {"M": [], "F": []} - - for (name, gender, _, _) in entries: - if len(names_classes[gender]) >= sample_n: - continue - - names_classes[gender].append(name) - - assert len(names_classes["M"]) == sample_n, len(names_classes["M"]) - assert len(names_classes["F"]) == sample_n, len(names_classes["F"]) - - names_classes_list = [(k, name) for k, names in names_classes.items() for name in names] - - prompts = [] - classes = [] - sequence_lengths = [] - - for gender, name in names_classes_list: - for _ in range(prompts_per_name): - names_few_shot = [] - for k, names in names_classes.items(): - names_ = [name_ for name_ in names if name_ != name] - names_few_shot += [(k, name_) for name_ in random.sample(names_, n_few_shot)] - - random.shuffle(names_few_shot) - - prompt_objects = random.sample(objects, n_few_shot) - - strprompt = [ - prompt.format(name=" "+name_,answer=answer_map[k], object=object_) - for (k, name_), object_ in zip(names_few_shot, objects) - ] - strprompt = "".join(strprompt) + prompt_q.format(name=" "+name) - - #print(strprompt) - - prompt_tokens = tokenizer(strprompt)["input_ids"] - prompts.append(prompt_tokens) - sequence_lengths.append(len(prompt_tokens)) - classes.append(codes_map[gender]) - - completion_tokens = {codes_map[g]: tokenizer(a)["input_ids"][0] for g, a in answer_map.items()} - - prompts = torch.tensor(prompts) - classes = torch.tensor(classes) - sequence_lengths = torch.tensor(sequence_lengths) - +import pickle +import random + +import torch +from transformers import AutoTokenizer + +COUNT_CUTOFF = 100000 + +def generate_gender_dataset( + tokenizer_name, + pad_token_id=0, + count_cutoff=COUNT_CUTOFF, + sample_n=10, + prompts_per_name=1, + n_few_shot=3, + randomise=True, +): + prompt = " My name is{name} and I am a{answer}." 
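+    # prompt holds the full few-shot template (answer included); prompt_q is the same template cut off before the answer, so the model must complete it with " male" / " female" +    # prompt_len and prompt_q_len below are the hard-coded token counts of these templates (assumed to match the tokenizer loaded from tokenizer_name)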
+ prompt_q = " My name is{name} and I am a" + prompt_len = 10 + prompt_q_len = 8 + + codes_map = {"M": 0, "F": 1} + answer_map = {"M": " male", "F": " female"} + + skip_tokens = 2 * n_few_shot * prompt_len + 3 + + random.seed(42) + #if sample_n is not None: + # targets = {"M": sample_n, "F": sample_n} + # counts = {"M": 0, "F": 0} + + with open("gender_dataset.pkl", "rb") as f: + name_toks, entries = pickle.load(f) + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer.pad_token_id = pad_token_id + + entries = [entry for entry in entries if int(entry[2]) > count_cutoff] + + if randomise: + random.shuffle(entries) + + # split into male and female names + names_classes = {"M": [], "F": []} + + for (name, gender, _, _) in entries: + if len(names_classes[gender]) >= sample_n: + continue + + names_classes[gender].append(name) + + assert len(names_classes["M"]) == sample_n, len(names_classes["M"]) + assert len(names_classes["F"]) == sample_n, len(names_classes["F"]) + + names_classes_list = [(k, name) for k, names in names_classes.items() for name in names] + + prompts = [] + classes = [] + sequence_lengths = [] + + for gender, name in names_classes_list: + for _ in range(prompts_per_name): + names_few_shot = [] + for k, names in names_classes.items(): + names_ = [name_ for name_ in names if name_ != name] + names_few_shot += [(k, name_) for name_ in random.sample(names_, n_few_shot)] + + random.shuffle(names_few_shot) + + strprompt = [prompt.format(name=" "+name_,answer=answer_map[k]) for k, name_ in names_few_shot] + strprompt = "".join(strprompt) + prompt_q.format(name=" "+name) + + prompt_tokens = tokenizer(strprompt)["input_ids"] + prompts.append(prompt_tokens) + sequence_lengths.append(len(prompt_tokens)) + classes.append(codes_map[gender]) + + completion_tokens = {codes_map[g]: tokenizer(a)["input_ids"][0] for g, a in answer_map.items()} + + prompts = torch.tensor(prompts) + classes = torch.tensor(classes) + sequence_lengths = torch.tensor(sequence_lengths) + + return prompts, classes, completion_tokens, sequence_lengths, skip_tokens + + +def generate_pronoun_dataset( + tokenizer_name, + pad_token_id=0, + count_cutoff=COUNT_CUTOFF, + sample_n=10, + prompts_per_name=1, + n_few_shot=3, + randomise=True, +): + objects = ["cat", "car", "dog", "book", "pen", "mouse", "chair", "table", "phone", "computer"] + + prompt = "{name} went to the store, and{answer} bought a {object}." 
+ prompt_q = "{name} went to the store, and" + prompt_len = 11 + prompt_q_len = 7 + + codes_map = {"M": 0, "F": 1} + answer_map = {"M": " he", "F": " she"} + + skip_tokens = 2 * n_few_shot * prompt_len + + random.seed(42) + #if sample_n is not None: + # targets = {"M": sample_n, "F": sample_n} + # counts = {"M": 0, "F": 0} + + with open("gender_dataset.pkl", "rb") as f: + name_toks, entries = pickle.load(f) + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer.pad_token_id = pad_token_id + + entries = [entry for entry in entries if int(entry[2]) > count_cutoff] + + if randomise: + random.shuffle(entries) + + # split into male and female names + names_classes = {"M": [], "F": []} + + for (name, gender, _, _) in entries: + if len(names_classes[gender]) >= sample_n: + continue + + names_classes[gender].append(name) + + assert len(names_classes["M"]) == sample_n, len(names_classes["M"]) + assert len(names_classes["F"]) == sample_n, len(names_classes["F"]) + + names_classes_list = [(k, name) for k, names in names_classes.items() for name in names] + + prompts = [] + classes = [] + sequence_lengths = [] + + for gender, name in names_classes_list: + for _ in range(prompts_per_name): + names_few_shot = [] + for k, names in names_classes.items(): + names_ = [name_ for name_ in names if name_ != name] + names_few_shot += [(k, name_) for name_ in random.sample(names_, n_few_shot)] + + random.shuffle(names_few_shot) + + prompt_objects = random.sample(objects, n_few_shot) + + strprompt = [ + prompt.format(name=" "+name_,answer=answer_map[k], object=object_) + for (k, name_), object_ in zip(names_few_shot, objects) + ] + strprompt = "".join(strprompt) + prompt_q.format(name=" "+name) + + #print(strprompt) + + prompt_tokens = tokenizer(strprompt)["input_ids"] + prompts.append(prompt_tokens) + sequence_lengths.append(len(prompt_tokens)) + classes.append(codes_map[gender]) + + completion_tokens = {codes_map[g]: tokenizer(a)["input_ids"][0] for g, a in answer_map.items()} + + prompts = torch.tensor(prompts) + classes = torch.tensor(classes) + sequence_lengths = torch.tensor(sequence_lengths) + return prompts, classes, completion_tokens, sequence_lengths, skip_tokens \ No newline at end of file diff --git a/test_datasets/ioi.py b/test_datasets/ioi.py index 7dc9561..cb131ba 100644 --- a/test_datasets/ioi.py +++ b/test_datasets/ioi.py @@ -1,67 +1,67 @@ -import numpy as np -import torch - -abb_a_prompt = "Then, {name_a} and {name_b} were working at the {location}. {name_b} decided to give a {object} to {name_a}" -aba_b_prompt = "Then, {name_a} and {name_b} were working at the {location}. 
{name_a} decided to give a {object} to {name_b}" - -names_ = ['James', 'John', 'Robert', 'Michael', 'William', 'Mary', 'David', 'Joseph', 'Richard', 'Charles', 'Thomas', 'Christopher', 'Daniel', 'Matthew', 'Elizabeth', 'Patricia', 'Jennifer', 'Anthony', 'George', 'Linda', 'Barbara', 'Donald', 'Paul', 'Mark', 'Andrew', 'Steven', 'Kenneth', 'Edward', 'Joshua', 'Margaret', 'Brian', 'Kevin', 'Jessica', 'Sarah', 'Susan', 'Timothy', 'Dorothy', 'Jason', 'Ronald', 'Helen', 'Ryan', 'Jeffrey', 'Karen', 'Nancy', 'Betty', 'Lisa', 'Jacob', 'Nicholas', 'Ashley', 'Eric', 'Frank', 'Gary', 'Anna', 'Stephen', 'Jonathan', 'Sandra', 'Emily', 'Amanda', 'Kimberly', 'Michelle', 'Donna', 'Justin', 'Laura', 'Ruth', 'Carol', 'Brandon', 'Larry', 'Scott', 'Melissa', 'Stephanie', 'Benjamin', 'Raymond', 'Samuel', 'Rebecca', 'Deborah', 'Gregory', 'Sharon', 'Kathleen', 'Amy', 'Alexander', 'Patrick', 'Jack', 'Henry', 'Angela', 'Shirley', 'Emma', 'Catherine', 'Katherine', 'Virginia', 'Nicole', 'Dennis', 'Walter', 'Tyler', 'Peter', 'Aaron', 'Jerry', 'Christine'] -locations = ["plateau", "cafe", "home", "bridge", "station"] -objects = ["feather", "towel", "fins", "ring", "tape", "shorts"] - -def generate_ioi_dataset( - tokenizer, - n_abb_a, - n_abb_b, -): - np.random.seed(42) - # validate dataset lengths - error = False - - # filter names instead - names = [] - for name in names_: - if len(tokenizer(" " + name)["input_ids"]) != 1: - print(f"Name {name} is not a single token, skipping") - else: - names.append(name) - - for location in locations: - if len(tokenizer(" " + location)["input_ids"]) != 1: - print(f"Location {location} is not a single token") - error = True - - for object in objects: - if len(tokenizer(" " + object)["input_ids"]) != 1: - print(f"Object {object} is not a single token") - error = True - - assert not error, "Dataset is not valid" - - clean = [] - corrupted = [] - - for i in range(n_abb_a): - (name_a, name_b), location, object = ( - np.random.choice(names, size=2, replace=False), - np.random.choice(locations), - np.random.choice(objects), - ) - clean.append(abb_a_prompt.format(name_a=name_a, name_b=name_b, location=location, object=object)) - corrupted.append(aba_b_prompt.format(name_a=name_a, name_b=name_b, location=location, object=object)) - - for i in range(n_abb_b): - (name_a, name_b), location, object = ( - np.random.choice(names, size=2, replace=False), - np.random.choice(locations), - np.random.choice(objects), - ) - clean.append(aba_b_prompt.format(name_a=name_a, name_b=name_b, location=location, object=object)) - corrupted.append(abb_a_prompt.format(name_a=name_a, name_b=name_b, location=location, object=object)) - - # print(clean) - # print(corrupted) - - clean = torch.tensor(tokenizer(clean)["input_ids"]) - corrupted = torch.tensor(tokenizer(corrupted)["input_ids"]) - - return clean, corrupted +import numpy as np +import torch + +abb_a_prompt = "Then, {name_a} and {name_b} were working at the {location}. {name_b} decided to give a {object} to {name_a}" +aba_b_prompt = "Then, {name_a} and {name_b} were working at the {location}. 
{name_a} decided to give a {object} to {name_b}" + +names_ = ['James', 'John', 'Robert', 'Michael', 'William', 'Mary', 'David', 'Joseph', 'Richard', 'Charles', 'Thomas', 'Christopher', 'Daniel', 'Matthew', 'Elizabeth', 'Patricia', 'Jennifer', 'Anthony', 'George', 'Linda', 'Barbara', 'Donald', 'Paul', 'Mark', 'Andrew', 'Steven', 'Kenneth', 'Edward', 'Joshua', 'Margaret', 'Brian', 'Kevin', 'Jessica', 'Sarah', 'Susan', 'Timothy', 'Dorothy', 'Jason', 'Ronald', 'Helen', 'Ryan', 'Jeffrey', 'Karen', 'Nancy', 'Betty', 'Lisa', 'Jacob', 'Nicholas', 'Ashley', 'Eric', 'Frank', 'Gary', 'Anna', 'Stephen', 'Jonathan', 'Sandra', 'Emily', 'Amanda', 'Kimberly', 'Michelle', 'Donna', 'Justin', 'Laura', 'Ruth', 'Carol', 'Brandon', 'Larry', 'Scott', 'Melissa', 'Stephanie', 'Benjamin', 'Raymond', 'Samuel', 'Rebecca', 'Deborah', 'Gregory', 'Sharon', 'Kathleen', 'Amy', 'Alexander', 'Patrick', 'Jack', 'Henry', 'Angela', 'Shirley', 'Emma', 'Catherine', 'Katherine', 'Virginia', 'Nicole', 'Dennis', 'Walter', 'Tyler', 'Peter', 'Aaron', 'Jerry', 'Christine'] +locations = ["plateau", "cafe", "home", "bridge", "station"] +objects = ["feather", "towel", "fins", "ring", "tape", "shorts"] + +def generate_ioi_dataset( + tokenizer, + n_abb_a, + n_abb_b, +): + np.random.seed(42) + # validate dataset lengths + error = False + + # filter names instead + names = [] + for name in names_: + if len(tokenizer(" " + name)["input_ids"]) != 1: + print(f"Name {name} is not a single token, skipping") + else: + names.append(name) + + for location in locations: + if len(tokenizer(" " + location)["input_ids"]) != 1: + print(f"Location {location} is not a single token") + error = True + + for object in objects: + if len(tokenizer(" " + object)["input_ids"]) != 1: + print(f"Object {object} is not a single token") + error = True + + assert not error, "Dataset is not valid" + + clean = [] + corrupted = [] + + for i in range(n_abb_a): + (name_a, name_b), location, object = ( + np.random.choice(names, size=2, replace=False), + np.random.choice(locations), + np.random.choice(objects), + ) + clean.append(abb_a_prompt.format(name_a=name_a, name_b=name_b, location=location, object=object)) + corrupted.append(aba_b_prompt.format(name_a=name_a, name_b=name_b, location=location, object=object)) + + for i in range(n_abb_b): + (name_a, name_b), location, object = ( + np.random.choice(names, size=2, replace=False), + np.random.choice(locations), + np.random.choice(objects), + ) + clean.append(aba_b_prompt.format(name_a=name_a, name_b=name_b, location=location, object=object)) + corrupted.append(abb_a_prompt.format(name_a=name_a, name_b=name_b, location=location, object=object)) + + # print(clean) + # print(corrupted) + + clean = torch.tensor(tokenizer(clean)["input_ids"]) + corrupted = torch.tensor(tokenizer(corrupted)["input_ids"]) + + return clean, corrupted diff --git a/test_datasets/ioi_counterfact.py b/test_datasets/ioi_counterfact.py new file mode 100644 index 0000000..a2ec589 --- /dev/null +++ b/test_datasets/ioi_counterfact.py @@ -0,0 +1,373 @@ +# code & dataset from: https://github.com/redwoodresearch/Easy-Transformer/blob/main/easy_transformer/ioi_dataset.py + +import io +from logging import warning +from typing import Union, List +from site import PREFIXES +import warnings +import torch +import numpy as np +from tqdm import tqdm +import pandas as pd +from transformers import AutoTokenizer +import random +import re +import matplotlib.pyplot as plt +import random as rd +import copy + +NAMES = [ + "Michael", + "Christopher", + "Jessica", + "Matthew", 
+ "Ashley", + "Jennifer", + "Joshua", + "Amanda", + "Daniel", + "David", + "James", + "Robert", + "John", + "Joseph", + "Andrew", + "Ryan", + "Brandon", + "Jason", + "Justin", + "Sarah", + "William", + "Jonathan", + "Stephanie", + "Brian", + "Nicole", + "Nicholas", + "Anthony", + "Heather", + "Eric", + "Elizabeth", + "Adam", + "Megan", + "Melissa", + "Kevin", + "Steven", + "Thomas", + "Timothy", + "Christina", + "Kyle", + "Rachel", + "Laura", + "Lauren", + "Amber", + "Brittany", + "Danielle", + "Richard", + "Kimberly", + "Jeffrey", + "Amy", + "Crystal", + "Michelle", + "Tiffany", + "Jeremy", + "Benjamin", + "Mark", + "Emily", + "Aaron", + "Charles", + "Rebecca", + "Jacob", + "Stephen", + "Patrick", + "Sean", + "Erin", + "Jamie", + "Kelly", + "Samantha", + "Nathan", + "Sara", + "Dustin", + "Paul", + "Angela", + "Tyler", + "Scott", + "Katherine", + "Andrea", + "Gregory", + "Erica", + "Mary", + "Travis", + "Lisa", + "Kenneth", + "Bryan", + "Lindsey", + "Kristen", + "Jose", + "Alexander", + "Jesse", + "Katie", + "Lindsay", + "Shannon", + "Vanessa", + "Courtney", + "Christine", + "Alicia", + "Cody", + "Allison", + "Bradley", + "Samuel", +] + +ABC_TEMPLATES = [ + "Then, [A], [B] and [C] went to the [PLACE]. [B] and [C] gave a [OBJECT] to [A]", + "Afterwards [A], [B] and [C] went to the [PLACE]. [B] and [C] gave a [OBJECT] to [A]", + "When [A], [B] and [C] arrived at the [PLACE], [B] and [C] gave a [OBJECT] to [A]", + "Friends [A], [B] and [C] went to the [PLACE]. [B] and [C] gave a [OBJECT] to [A]", +] + +BAC_TEMPLATES = [ + template.replace("[B]", "[A]", 1).replace("[A]", "[B]", 1) + for template in ABC_TEMPLATES +] + +BABA_TEMPLATES = [ + "Then, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", + "Then, [B] and [A] had a lot of fun at the [PLACE]. [B] gave a [OBJECT] to [A]", + "Then, [B] and [A] were working at the [PLACE]. [B] decided to give a [OBJECT] to [A]", + "Then, [B] and [A] were thinking about going to the [PLACE]. [B] wanted to give a [OBJECT] to [A]", + "Then, [B] and [A] had a long argument, and afterwards [B] said to [A]", + "After [B] and [A] went to the [PLACE], [B] gave a [OBJECT] to [A]", + "When [B] and [A] got a [OBJECT] at the [PLACE], [B] decided to give it to [A]", + "When [B] and [A] got a [OBJECT] at the [PLACE], [B] decided to give the [OBJECT] to [A]", + "While [B] and [A] were working at the [PLACE], [B] gave a [OBJECT] to [A]", + "While [B] and [A] were commuting to the [PLACE], [B] gave a [OBJECT] to [A]", + "After the lunch, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", + "Afterwards, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", + "Then, [B] and [A] had a long argument. Afterwards [B] said to [A]", + "The [PLACE] [B] and [A] went to had a [OBJECT]. [B] gave it to [A]", + "Friends [B] and [A] found a [OBJECT] at the [PLACE]. [B] gave it to [A]", +] + +BABA_LONG_TEMPLATES = [ + "Then in the morning, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", + "Then in the morning, [B] and [A] had a lot of fun at the [PLACE]. [B] gave a [OBJECT] to [A]", + "Then in the morning, [B] and [A] were working at the [PLACE]. [B] decided to give a [OBJECT] to [A]", + "Then in the morning, [B] and [A] were thinking about going to the [PLACE]. 
[B] wanted to give a [OBJECT] to [A]", + "Then in the morning, [B] and [A] had a long argument, and afterwards [B] said to [A]", + "After taking a long break [B] and [A] went to the [PLACE], [B] gave a [OBJECT] to [A]", + "When soon afterwards [B] and [A] got a [OBJECT] at the [PLACE], [B] decided to give it to [A]", + "When soon afterwards [B] and [A] got a [OBJECT] at the [PLACE], [B] decided to give the [OBJECT] to [A]", + "While spending time together [B] and [A] were working at the [PLACE], [B] gave a [OBJECT] to [A]", + "While spending time together [B] and [A] were commuting to the [PLACE], [B] gave a [OBJECT] to [A]", + "After the lunch in the afternoon, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", + "Afterwards, while spending time together [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", + "Then in the morning afterwards, [B] and [A] had a long argument. Afterwards [B] said to [A]", + "The local big [PLACE] [B] and [A] went to had a [OBJECT]. [B] gave it to [A]", + "Friends separated at birth [B] and [A] found a [OBJECT] at the [PLACE]. [B] gave it to [A]", +] + +BABA_LATE_IOS = [ + "Then, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", + "Then, [B] and [A] had a lot of fun at the [PLACE]. [B] gave a [OBJECT] to [A]", + "Then, [B] and [A] were working at the [PLACE]. [B] decided to give a [OBJECT] to [A]", + "Then, [B] and [A] were thinking about going to the [PLACE]. [B] wanted to give a [OBJECT] to [A]", + "Then, [B] and [A] had a long argument and after that [B] said to [A]", + "After the lunch, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", + "Afterwards, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", + "Then, [B] and [A] had a long argument. Afterwards [B] said to [A]", +] + +BABA_EARLY_IOS = [ + "Then [B] and [A] went to the [PLACE], and [B] gave a [OBJECT] to [A]", + "Then [B] and [A] had a lot of fun at the [PLACE], and [B] gave a [OBJECT] to [A]", + "Then [B] and [A] were working at the [PLACE], and [B] decided to give a [OBJECT] to [A]", + "Then [B] and [A] were thinking about going to the [PLACE], and [B] wanted to give a [OBJECT] to [A]", + "Then [B] and [A] had a long argument, and after that [B] said to [A]", + "After the lunch [B] and [A] went to the [PLACE], and [B] gave a [OBJECT] to [A]", + "Afterwards [B] and [A] went to the [PLACE], and [B] gave a [OBJECT] to [A]", + "Then [B] and [A] had a long argument, and afterwards [B] said to [A]", +] + +TEMPLATES_VARIED_MIDDLE = [ + "", +] + +# no end of texts, GPT-2 small wasn't trained this way (ask Arthur) +# warnings.warn("Adding end of text prefixes!") +# for TEMPLATES in [BABA_TEMPLATES, BABA_EARLY_IOS, BABA_LATE_IOS]: +# for i in range(len(TEMPLATES)): +# TEMPLATES[i] = "<|endoftext|>" + TEMPLATES[i] + +ABBA_TEMPLATES = BABA_TEMPLATES[:] +ABBA_LATE_IOS = BABA_LATE_IOS[:] +ABBA_EARLY_IOS = BABA_EARLY_IOS[:] + +for TEMPLATES in [ABBA_TEMPLATES, ABBA_LATE_IOS, ABBA_EARLY_IOS]: + for i in range(len(TEMPLATES)): + first_clause = True + for j in range(1, len(TEMPLATES[i]) - 1): + if TEMPLATES[i][j - 1 : j + 2] == "[B]" and first_clause: + TEMPLATES[i] = TEMPLATES[i][:j] + "A" + TEMPLATES[i][j + 1 :] + elif TEMPLATES[i][j - 1 : j + 2] == "[A]" and first_clause: + first_clause = False + TEMPLATES[i] = TEMPLATES[i][:j] + "B" + TEMPLATES[i][j + 1 :] + +VERBS = [" tried", " said", " decided", " wanted", " gave"] +PLACES = [ + "store", + "garden", + "restaurant", + "school", + "hospital", + "office", + "house", + "station", +] +OBJECTS = [ + "ring", + 
"kiss", + "bone", + "basketball", + "computer", + "necklace", + "drink", + "snack", +] + +ANIMALS = [ + "dog", + "cat", + "snake", + "elephant", + "beetle", + "hippo", + "giraffe", + "tiger", + "husky", + "lion", + "panther", + "whale", + "dolphin", + "beaver", + "rabbit", + "fox", + "lamb", + "ferret", +] + +def multiple_replace(dict, text): + # from: https://stackoverflow.com/questions/15175142/how-can-i-do-multiple-substitutions-using-regex + # Create a regular expression from the dictionary keys + regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys()))) + + # For each match, look-up corresponding value in dictionary + return regex.sub(lambda mo: dict[mo.string[mo.start() : mo.end()]], text) + + +def iter_sample_fast(iterable, samplesize): + results = [] + # Fill in the first samplesize elements: + try: + for _ in range(samplesize): + results.append(next(iterable)) + except StopIteration: + raise ValueError("Sample larger than population.") + random.shuffle(results) # Randomize their positions + + return results + +NOUNS_DICT = NOUNS_DICT = {"[PLACE]": PLACES, "[OBJECT]": OBJECTS} + + +def gen_prompt_counterfact( + tokenizer, + templates, names, nouns_dict, N +): + nb_gen = 0 + ioi_prompts = [] + ioi_prompts_counterfact = [] + for nb_gen in range(N): + temp = rd.choice(templates) + temp_id = templates.index(temp) + name_1 = "" + name_2 = "" + name_3 = "" + while len(set([name_1, name_2, name_3])) < 3: + name_1 = rd.choice(names) + name_2 = rd.choice(names) + name_3 = rd.choice(names) + + for name in [name_1, name_2, name_3]: + if len(tokenizer(" " + name)["input_ids"]) != 1: + name_1 = "" + name_2 = "" + name_3 = "" + break + + assert all([len(tokenizer(" " + name)["input_ids"]) == 1 for name in [name_1, name_2, name_3]]) + + nouns = {} + ioi_prompt = {} + ioi_prompt_counterfact = {} + for k in nouns_dict: + nouns[k] = rd.choice(nouns_dict[k]) + ioi_prompt[k] = nouns[k] + ioi_prompt_counterfact[k] = nouns[k] + prompt = temp + for k in nouns_dict: + prompt = prompt.replace(k, nouns[k]) + + prompt1 = prompt.replace("[A]", name_1) + prompt1 = prompt1.replace("[B]", name_2) + ioi_prompt["text"] = prompt1 + ioi_prompt["IO"] = name_1 + ioi_prompt["S"] = name_2 + ioi_prompt["TEMPLATE_IDX"] = temp_id + ioi_prompts.append(ioi_prompt) + + prompt2 = prompt.replace("[A]", name_3) + prompt2 = prompt2.replace("[B]", name_2) + ioi_prompt_counterfact["text"] = prompt2 + ioi_prompt_counterfact["IO"] = name_3 + ioi_prompt_counterfact["S"] = name_2 + ioi_prompt_counterfact["TEMPLATE_IDX"] = temp_id + ioi_prompts_counterfact.append(ioi_prompt_counterfact) + + return ioi_prompts, ioi_prompts_counterfact + +def gen_ioi_dataset( + tokenizer, + n_prompts, +): + assertion = False + while not assertion: + prompts, prompts_cf = gen_prompt_counterfact( + tokenizer, + ABBA_TEMPLATES + BABA_TEMPLATES, + NAMES, + NOUNS_DICT, + n_prompts, + ) + + prompts = [prompt["text"] for prompt in prompts] + prompts_cf = [prompt["text"] for prompt in prompts_cf] + + # ignore final token (indirect object) + prompts = tokenizer(prompts)["input_ids"] + prompts_cf = tokenizer(prompts_cf)["input_ids"] + + assertion = all([len(prompt) == len(prompt_cf) for prompt, prompt_cf in zip(prompts, prompts_cf)]) + + # calc seq lengths & pad + seq_lengths = torch.tensor([len(prompt)-1 for prompt in prompts]) + max_seq_length = torch.max(seq_lengths) + + prompts = torch.stack( + [torch.tensor(prompt[:-1] + [0]*(max_seq_length - len(prompt[:-1]))) for prompt in prompts] + ) + + prompts_cf = torch.stack( + [torch.tensor(prompt[:-1] + 
[0]*(max_seq_length - len(prompt[:-1]))) for prompt in prompts_cf] + ) + + return prompts, prompts_cf, seq_lengths \ No newline at end of file diff --git a/test_datasets/preprocess_gender_dataset.py b/test_datasets/preprocess_gender_dataset.py index fecdee3..fea8c1c 100644 --- a/test_datasets/preprocess_gender_dataset.py +++ b/test_datasets/preprocess_gender_dataset.py @@ -1,46 +1,46 @@ -import csv -import os -import pickle -import sys - -import tqdm -from transformers import AutoTokenizer - -# dataset: https://archive.ics.uci.edu/dataset/591/gender+by+name - -max_tok_len = 1 -min_tok_len = 1 -name_fmt = " {name}" - -if __name__ == "__main__": - model = "EleutherAI/pythia-70m-deduped" - - if len(sys.argv) > 1: - min_tok_len = int(sys.argv[1]) - - if len(sys.argv) > 2: - max_tok_len = int(sys.argv[2]) - - if len(sys.argv) > 3: - model = sys.argv[3] - - tokenizer = AutoTokenizer.from_pretrained(model) - - entries = [] - - with open("name_gender_dataset.csv", "r", newline="") as f: - reader = csv.reader(f) - reader.__next__() - for entry in tqdm.tqdm(reader): - name = entry[0] - - t = tokenizer(name_fmt.format(name=name)) - - # filter names that are more than one token - if min_tok_len <= len(t["input_ids"]) <= max_tok_len: - entries.append(entry) - - print(f"Found {len(entries)} entries") - - with open("gender_dataset.pkl", "wb") as f: - pickle.dump((max_tok_len, entries), f) +import csv +import os +import pickle +import sys + +import tqdm +from transformers import AutoTokenizer + +# dataset: https://archive.ics.uci.edu/dataset/591/gender+by+name + +max_tok_len = 1 +min_tok_len = 1 +name_fmt = " {name}" + +if __name__ == "__main__": + model = "EleutherAI/pythia-70m-deduped" + + if len(sys.argv) > 1: + min_tok_len = int(sys.argv[1]) + + if len(sys.argv) > 2: + max_tok_len = int(sys.argv[2]) + + if len(sys.argv) > 3: + model = sys.argv[3] + + tokenizer = AutoTokenizer.from_pretrained(model) + + entries = [] + + with open("name_gender_dataset.csv", "r", newline="") as f: + reader = csv.reader(f) + reader.__next__() + for entry in tqdm.tqdm(reader): + name = entry[0] + + t = tokenizer(name_fmt.format(name=name)) + + # filter names that are more than one token + if min_tok_len <= len(t["input_ids"]) <= max_tok_len: + entries.append(entry) + + print(f"Found {len(entries)} entries") + + with open("gender_dataset.pkl", "wb") as f: + pickle.dump((max_tok_len, entries), f) diff --git a/utils.py b/utils.py index 577f58c..515e70e 100644 --- a/utils.py +++ b/utils.py @@ -9,11 +9,11 @@ from botocore.exceptions import ClientError, NoCredentialsError VAST_NUM = 4 -# DEST_ADDR = f"root@ssh{VAST_NUM}.vast.ai" -DEST_ADDR = "mchorse@198.176.96.64" +DEST_ADDR = f"root@ssh{VAST_NUM}.vast.ai" +# DEST_ADDR = "mchorse@198.176.96.64" SSH_PYTHON = "/opt/conda/bin/python" -PORT = 22 +PORT = 14040 USER = "aidan"