diff --git a/.gitignore b/.gitignore
index f8d1a64..c17a082 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,7 @@
wandb/
.data/
.checkpoints/
+tests/saes_for_tests/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
diff --git a/pyproject.toml b/pyproject.toml
index 14f97e7..7498d1f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
"zstandard~=0.22.0",
"matplotlib>=3.5.3",
"eindex-callum@git+https://github.com/callummcdougall/eindex",
- "sae_vis@git+https://github.com/callummcdougall/sae_vis.git@b28a0f7c7e936f4bea05528d952dfcd438533cce"
+ "sae_vis@git+https://github.com/callummcdougall/sae_vis"
]
[project.urls]
@@ -98,4 +98,4 @@ reportPrivateImportUsage = false
filterwarnings = [
# https://github.com/google/python-fire/pull/447
"ignore::DeprecationWarning:fire:59",
-]
\ No newline at end of file
+]
diff --git a/sparsify/scripts/dashboards.yaml b/sparsify/scripts/dashboards.yaml
index c3f44fd..0e4bcb9 100644
--- a/sparsify/scripts/dashboards.yaml
+++ b/sparsify/scripts/dashboards.yaml
@@ -1,22 +1,28 @@
pretrained_sae_paths: null # Paths of the pretrained SAEs to load. Should be a path to a .pt file, or a list of them. Can also be provided as a second argument in the command line.
sae_config_path: null # Path to the config file used to train the SAEs (if null, we'll assume it's at pretrained_sae_paths[0].parent / "config.yaml")
-n_samples: 10_000
-batch_size: 64
-minibatch_size_features: 128 # Num features in each batch of calculations. Lower to avoid OOM errors
+n_samples: 3000
+batch_size: 16
+minibatch_size_features: 100 # Num features in each batch of calculations. Lower to avoid OOM errors
data: # DatasetConfig for the data which will be used to generate the dashboards
- dataset_name: 'apollo-research/roneneldan-TinyStories-tokenizer-gpt2'
+ dataset_name: 'apollo-research/Skylion007-openwebtext-tokenizer-gpt2'
is_tokenized: True
tokenizer_name: 'gpt2'
- split: "validation"
- n_ctx: 512
+ split: "train"
+ n_ctx: 1024
save_dir: null # The directory for saving the HTML feature dashboard files
+save_json_data: false
sae_positions: null # The names of the SAE positions to generate dashboards for. e.g.'blocks.2.hook_resid_post'. If None, then all positions will be generated
feature_indices: null # The features for which to generate dashboards on each SAE. If none, then we'll generate dashbaords for every feature.
prompt_centric: # Used to generate prompt-centric (rather than feature-centric) dashboards. Feature-centric dashboards will also be generated for every feature appaearing in these
- n_random_prompt_dashboards: 50 # The number of random prompts to generate prompt-centric dashboards for.
+ n_random_prompt_dashboards: 10 # The number of random prompts to generate prompt-centric dashboards for.
data: null # "DatasetConfig for getting random prompts. If None, then non-prompt-centric data will be used
prompts: # Specific prompts on which to generate prompt-centric feature dashboards. A feature-centric dashboard will be generated for every token position in each prompt.
- - "Sally met Mike at the show. She brought popcorn for him."
- str_score: "loss_effect" # The ordering metric for which features are most important in prompt-centric dashboards. Can be one of 'act_size', 'act_quantile', or 'loss_effect'
+ - "Sally met Mike at the show. She brought popcorn for him. They ate it together"
+ - 'Lily asked, "Mommy, can I go on the slide?"'
+ - "It was time for the lecture to begin."
+ - "A man was taken to hospital after the crash"
+ - "CAMPAIGN The campaign will focus on three core goals:"
+ - "new_list = [n**2 for n in numbers if n%2==0]"
+ str_score: "act_quantile" # The ordering metric for which features are most important in prompt-centric dashboards. Can be one of 'act_size', 'act_quantile', or 'loss_effect'
num_top_features: 10 # How many of the most relevant features to show for each prompt in the prompt-centric dashboards
seed: 0
\ No newline at end of file
diff --git a/sparsify/scripts/generate_dashboards.py b/sparsify/scripts/generate_dashboards.py
index 3d7c85b..4bb6fe3 100644
--- a/sparsify/scripts/generate_dashboards.py
+++ b/sparsify/scripts/generate_dashboards.py
@@ -28,37 +28,36 @@
https://github.com/callummcdougall/sae_vis/commit/b28a0f7c7e936f4bea05528d952dfcd438533cce
"""
import math
-from collections.abc import Iterable
+from copy import deepcopy
from pathlib import Path
-from typing import Annotated, Literal
+from typing import Annotated, Any, Literal
import fire
import numpy as np
import torch
-from eindex import eindex
-from einops import einsum, rearrange
+from einops import einsum
from jaxtyping import Float, Int
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, NonNegativeInt, PositiveInt
-from sae_vis.data_fetching_fns import get_sequences_data
-from sae_vis.data_storing_fns import (
- FeatureData,
- FeatureVisParams,
- HistogramData,
- MiddlePlotsData,
- MultiFeatureData,
- MultiPromptData,
- PromptData,
- SequenceData,
- SequenceMultiGroupData,
+from sae_vis.data_config_classes import (
+ ActsHistogramConfig,
+ Column,
+ LogitsHistogramConfig,
+ LogitsTableConfig,
+ PromptConfig,
+ SaeVisConfig,
+ SaeVisLayoutConfig,
+ SequencesConfig,
)
-from sae_vis.utils_fns import QuantileCalculator, TopK, process_str_tok
+from sae_vis.data_fetching_fns import parse_feature_data, parse_prompt_data
+from sae_vis.data_storing_fns import SaeVisData
+from sae_vis.html_fns import HTML
+from sae_vis.utils_fns import get_decode_html_safe_fn
from torch import Tensor
from torch.utils.data.dataset import IterableDataset
from tqdm import tqdm
from transformers import (
AutoTokenizer,
PreTrainedTokenizer,
- PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
)
@@ -68,7 +67,28 @@
from sparsify.models.transformers import SAETransformer
from sparsify.scripts.train_tlens_saes.run_train_tlens_saes import Config
from sparsify.types import RootPath
-from sparsify.utils import filter_names, load_config, to_numpy
+from sparsify.utils import filter_names, load_config
+
+LAYOUT_FEATURE_VIS = SaeVisLayoutConfig(
+ columns=[
+ Column(ActsHistogramConfig(), LogitsTableConfig(), LogitsHistogramConfig()),
+ Column(SequencesConfig(stack_mode="stack-none")),
+ ],
+ height=750,
+)
+
+LAYOUT_PROMPT_VIS = SaeVisLayoutConfig(
+ columns=[
+ Column(
+ PromptConfig(),
+ ActsHistogramConfig(),
+ LogitsTableConfig(n_rows=5),
+ SequencesConfig(top_acts_group_size=10, n_quantiles=0),
+ width=450,
+ ),
+ ],
+ height=1000,
+)
FeatureIndicesType = dict[str, list[int]] | dict[str, Int[Tensor, "some_feats"]] # noqa: F821 (jaxtyping/pyright doesn't like single dimensions)
StrScoreType = Literal["act_size", "act_quantile", "loss_effect"]
@@ -77,9 +97,9 @@
class PromptDashboardsConfig(BaseModel):
model_config = ConfigDict(extra="forbid", frozen=True)
n_random_prompt_dashboards: NonNegativeInt = Field(
- default=50,
+ default=10,
description="The number of random prompts to generate prompt-centric dashboards for."
- "A feature-centric dashboard will be generated for random token positions in each prompt.",
+ "Feature-centric dashboards will be generated for each prompt.",
)
data: DatasetConfig | None = Field(
default=None,
@@ -125,7 +145,12 @@ class DashboardsConfig(BaseModel):
)
save_dir: RootPath | None = Field(
default=None,
- description="The directory for saving the HTML feature dashboard files",
+ description="The directory for saving the HTML feature dashboard files. If none, they "
+ "will be saved in pretrained_sae_paths[0].parent",
+ )
+ save_json_data: bool = Field(
+ default=False,
+ description="Whether to save JSON data which can be used to re-generate the HTML dashboards",
)
sae_positions: Annotated[
list[str] | None, BeforeValidator(lambda x: [x] if isinstance(x, str) else x)
@@ -188,7 +213,7 @@ def compute_feature_acts(
feature_acts[hook_name] = feature_acts[hook_name].to("cpu")
if feature_indices is not None:
feature_acts[hook_name] = feature_acts[hook_name][..., feature_indices[hook_name]]
- return feature_acts, final_resid_acts
+ return feature_acts, final_resid_acts.to("cpu")
def compute_feature_acts_on_distribution(
@@ -221,16 +246,16 @@ def compute_feature_acts_on_distribution(
- The residual stream activations of the model at the final layer (or at stop_at_layer)
- The tokens used as input to the model
"""
- data_loader, _ = create_data_loader(
- dataset_config, batch_size=batch_size, buffer_size=batch_size
- )
+ data_loader, _ = create_data_loader(dataset_config, batch_size, buffer_size=batch_size)
if raw_sae_positions is None:
raw_sae_positions = model.raw_sae_positions
assert raw_sae_positions is not None
device = model.saes[raw_sae_positions[0].replace(".", "-")].device
if n_samples is None:
# If streaming (i.e. if the dataset is an IterableDataset), we don't know the length
- n_batches = None if isinstance(data_loader.dataset, IterableDataset) else len(data_loader)
+ n_batches = (
+ float("inf") if isinstance(data_loader.dataset, IterableDataset) else len(data_loader)
+ )
else:
n_batches = math.ceil(n_samples / batch_size)
if not isinstance(data_loader.dataset, IterableDataset):
@@ -245,8 +270,8 @@ def compute_feature_acts_on_distribution(
for batch in tqdm(data_loader, total=n_batches, desc="Computing feature acts"):
batch_tokens: Int[Tensor, "..."] = batch[dataset_config.column_name].to(device=device)
batch_feature_acts, batch_final_resid_acts = compute_feature_acts(
- model=model,
- tokens=batch_tokens,
+ model,
+ batch_tokens,
raw_sae_positions=raw_sae_positions,
feature_indices=feature_indices,
stop_at_layer=stop_at_layer,
@@ -255,127 +280,17 @@ def compute_feature_acts_on_distribution(
feature_acts_lists[sae_name].append(batch_feature_acts[sae_name])
final_resid_acts_list.append(batch_final_resid_acts)
tokens_list.append(batch_tokens)
- total_samples += batch_tokens.shape[0]
if n_samples is not None and total_samples > n_samples:
break
+ total_samples += batch_tokens.shape[0]
final_resid_acts: Float[Tensor, "... d_resid"] = torch.cat(final_resid_acts_list, dim=0)
tokens: Int[Tensor, "..."] = torch.cat(tokens_list, dim=0)
feature_acts: dict[str, Float[Tensor, "... some_feats"]] = {}
for sae_name in raw_sae_positions:
- feature_acts[sae_name] = torch.cat(tensors=feature_acts_lists[sae_name], dim=0)
+ feature_acts[sae_name] = torch.cat(feature_acts_lists[sae_name], dim=0)
return feature_acts, final_resid_acts, tokens
-def create_vocab_dict(tokenizer: PreTrainedTokenizerBase) -> dict[int, str]:
- """
- Creates a vocab dict suitable for dashboards by replacing all the special tokens with their
- HTML representations. This function is adapted from sae_vis.create_vocab_dict()
- """
- vocab_dict: dict[str, int] = tokenizer.get_vocab()
- vocab_dict_processed: dict[int, str] = {v: process_str_tok(k) for k, v in vocab_dict.items()}
- return vocab_dict_processed
-
-
-@torch.inference_mode()
-def parse_activation_data(
- tokens: Int[Tensor, "batch pos"],
- feature_acts: Float[Tensor, "... some_feats"],
- final_resid_acts: Float[Tensor, "... d_resid"],
- feature_resid_dirs: Float[Tensor, "some_feats dim"],
- feature_indices_list: Iterable[int],
- W_U: Float[Tensor, "dim d_vocab"],
- vocab_dict: dict[int, str],
- fvp: FeatureVisParams,
-) -> MultiFeatureData:
- """Convert generic activation data into a MultiFeatureData object, which can be used to create
- the feature-centric visualisation.
- Adapted from sae_vis.data_fetching_fns._get_feature_data()
- final_resid_acts + W_U are used for the logit lens.
-
- Args:
- tokens: The inputs to the model
- feature_acts: The activations values of the features
- final_resid_acts: The activations of the final layer of the model
- feature_resid_dirs: The directions that each feature writes to the logit output
- feature_indices_list: The indices of the features we're interested in
- W_U: The unembed weights for the logit lens
- vocab_dict: A dictionary mapping vocab indices to strings
- fvp: FeatureVisParams, containing a bunch of settings. See the FeatureVisParams docstring in
- sae_vis for more information.
-
- Returns:
- A MultiFeatureData containing data for creating each feature's visualization,
- as well as data for rank-ordering the feature visualizations when it comes time
- to make the prompt-centric view (the `feature_act_quantiles` attribute).
- Use MultiFeatureData[feature_idx].get_html() to generate the HTML dashboard for a
- particular feature (returns a string of HTML).
-
- """
- device = W_U.device
- feature_acts.to(device)
- sequence_data_dict: dict[int, SequenceMultiGroupData] = {}
- middle_plots_data_dict: dict[int, MiddlePlotsData] = {}
- features_data: dict[int, FeatureData] = {}
- # Calculate all data for the right-hand visualisations, i.e. the sequences
- for i, feat in enumerate(feature_indices_list):
- # Add this feature's sequence data to the list
- sequence_data_dict[feat] = get_sequences_data(
- tokens=tokens,
- feat_acts=feature_acts[..., i],
- resid_post=final_resid_acts,
- feature_resid_dir=feature_resid_dirs[i],
- W_U=W_U,
- fvp=fvp,
- )
-
- # Get the logits of all features (i.e. the directions this feature writes to the logit output)
- logits = einsum(
- feature_resid_dirs,
- W_U,
- "feats d_model, d_model d_vocab -> feats d_vocab",
- )
- for i, (feat, logit) in enumerate(zip(feature_indices_list, logits, strict=True)):
- # Get data for logits (the histogram, and the table)
- logits_histogram_data = HistogramData(logit, 40, "5 ticks")
- top10_logits = TopK(logit, k=15, largest=True)
- bottom10_logits = TopK(logit, k=15, largest=False)
-
- # Get data for feature activations histogram (the title, and the histogram)
- feat_acts = feature_acts[..., i]
- nonzero_feat_acts = feat_acts[feat_acts > 0]
- frac_nonzero = nonzero_feat_acts.numel() / feat_acts.numel()
- freq_histogram_data = HistogramData(nonzero_feat_acts, 40, "ints")
-
- # Create a MiddlePlotsData object from this, and add it to the dict
- middle_plots_data_dict[feat] = MiddlePlotsData(
- bottom10_logits=bottom10_logits,
- top10_logits=top10_logits,
- logits_histogram_data=logits_histogram_data,
- freq_histogram_data=freq_histogram_data,
- frac_nonzero=frac_nonzero,
- )
-
- # Return the output, as a dict of FeatureData items
- for i, feat in enumerate(feature_indices_list):
- features_data[feat] = FeatureData(
- # Data-containing inputs (for the feature-centric visualisation)
- sequence_data=sequence_data_dict[feat],
- middle_plots_data=middle_plots_data_dict[feat],
- left_tables_data=None,
- # Non data-containing inputs
- feature_idx=feat,
- vocab_dict=vocab_dict,
- fvp=fvp,
- )
-
- # Also get the quantiles, which will be useful for the prompt-centric visualisation
-
- feature_act_quantiles = QuantileCalculator(
- data=rearrange(feature_acts, "... feats -> feats (...)")
- )
- return MultiFeatureData(features_data, feature_act_quantiles)
-
-
def feature_indices_to_tensordict(
feature_indices_in: FeatureIndicesType | list[int] | None,
raw_sae_positions: list[str],
@@ -412,9 +327,8 @@ def get_dashboards_data(
n_samples: PositiveInt | None = None,
batch_size: PositiveInt | None = None,
minibatch_size_features: PositiveInt | None = None,
- fvp: FeatureVisParams | None = None,
- vocab_dict: dict[int, str] | None = None,
-) -> dict[str, MultiFeatureData]:
+ cfg: SaeVisConfig | None = None,
+) -> dict[str, SaeVisData]:
"""Gets data that needed to create the sequences in the feature-centric HTML visualisation
Adapted from sae_vis.data_fetching_fns._get_feature_data()
@@ -442,30 +356,18 @@ def get_dashboards_data(
using dataset_config
minibatch_size_features:
Num features in each batch of calculations (break up the features to avoid OOM errors).
- fvp:
+ cfg:
Feature visualization parameters, containing a bunch of other stuff. See the
- FeatureVisParams docstring in sae_vis for more information.
- vocab_dict:
- vocab dict suitable for dashboards with all the special tokens replaced with their
- HTML representations. If None then it will be created using create_vocab_dict(tokenizer)
+ SaeVisConfig docstring in sae_vis for more information.
Returns:
- A dict of [sae_position_name: MultiFeatureData]. Each MultiFeatureData contains data for
+ A dict of [sae_position_name: SaeVisData]. Each SaeVisData contains data for
creating each feature's visualization, as well as data for rank-ordering the feature
visualizations when it comes time to make the prompt-centric view
(the `feature_act_quantiles` attribute).
Use dashboards_data[sae_name][feature_idx].get_html() to generate the HTML
dashboard for a particular feature (returns a string of HTML)
"""
- # Get the vocab dict, which we'll use at the end
- if vocab_dict is None:
- assert (
- model.tlens_model.tokenizer is not None
- ), "If voacab_dict is not supplied, the model must have a tokenizer"
- vocab_dict = create_vocab_dict(model.tlens_model.tokenizer)
-
- if fvp is None:
- fvp = FeatureVisParams(include_left_tables=False)
if sae_positions is None:
raw_sae_positions: list[str] = model.raw_sae_positions
@@ -475,9 +377,9 @@ def get_dashboards_data(
)
# If we haven't supplied any feature indicies, assume that we want all of them
feature_indices_tensors = feature_indices_to_tensordict(
- feature_indices_in=feature_indices,
- raw_sae_positions=raw_sae_positions,
- model=model,
+ feature_indices,
+ raw_sae_positions,
+ model,
)
for sae_name in raw_sae_positions:
assert (
@@ -493,9 +395,9 @@ def get_dashboards_data(
batch_size is not None
), "If no tokens are supplied, then a batch_size must be supplied"
feature_acts, final_resid_acts, tokens = compute_feature_acts_on_distribution(
- model=model,
- dataset_config=dataset_config,
- batch_size=batch_size,
+ model,
+ dataset_config,
+ batch_size,
raw_sae_positions=raw_sae_positions,
feature_indices=feature_indices_tensors,
n_samples=n_samples,
@@ -503,8 +405,8 @@ def get_dashboards_data(
else:
tokens.to(device)
feature_acts, final_resid_acts = compute_feature_acts(
- model=model,
- tokens=tokens,
+ model,
+ tokens,
raw_sae_positions=raw_sae_positions,
feature_indices=feature_indices_tensors,
)
@@ -516,11 +418,18 @@ def get_dashboards_data(
feature_indices_tensors[sae_name] = feature_indices_tensors[sae_name][acts_sum > 0]
del acts_sum
- dashboards_data: dict[str, MultiFeatureData] = {
- name: MultiFeatureData() for name in raw_sae_positions
- }
+ dashboards_data: dict[str, SaeVisData] = {name: SaeVisData() for name in raw_sae_positions}
for sae_name in raw_sae_positions:
+ if cfg is None:
+ cfg = SaeVisConfig(
+ hook_point=sae_name,
+ features=feature_indices_tensors[sae_name].tolist(),
+ feature_centric_layout=LAYOUT_FEATURE_VIS,
+ )
+ dashboards_data[sae_name].cfg = cfg
+ dashboards_data[sae_name].model = model.tlens_model
+
sae = model.saes[sae_name.replace(".", "-")]
W_dec: Float[Tensor, "feats dim"] = sae.decoder.weight.T
feature_resid_dirs: Float[Tensor, "some_feats dim"] = W_dec[
@@ -540,222 +449,29 @@ def get_dashboards_data(
]
feature_resid_dir_batches = feature_resid_dirs.split(minibatch_size_features)
for i in tqdm(iterable=range(len(feature_batches)), desc="Parsing activation data"):
- new_feature_data = parse_activation_data(
- tokens=tokens,
- feature_acts=feature_acts_batches[i].to_dense().to(device),
- final_resid_acts=final_resid_acts,
- feature_resid_dirs=feature_resid_dir_batches[i],
- feature_indices_list=feature_batches[i],
- W_U=W_U,
- vocab_dict=vocab_dict,
- fvp=fvp,
+ new_feature_data, _ = parse_feature_data(
+ tokens,
+ feature_batches[i],
+ feature_acts_batches[i].to_dense().to(device),
+ feature_resid_dir_batches[i].to(device),
+ final_resid_acts.to(device),
+ W_U.to(device),
+ cfg,
)
dashboards_data[sae_name].update(new_feature_data)
return dashboards_data
-@torch.inference_mode()
-def parse_prompt_data(
- tokens: Int[Tensor, "batch pos"],
- str_tokens: list[str],
- features_data: MultiFeatureData,
- feature_acts: Float[Tensor, "seq some_feats"],
- final_resid_acts: Float[Tensor, "seq d_resid"],
- feature_resid_dirs: Float[Tensor, "some_feats dim"],
- feature_indices_list: list[int],
- W_U: Float[Tensor, "dim d_vocab"],
- num_top_features: int = 10,
-) -> MultiPromptData:
- """Gets data needed to create the sequences in the prompt-centric HTML visualisation.
-
- This visualization displays dashboards for the most relevant features on a prompt.
- Adapted from sae_vis.data_fetching_fns.get_prompt_data().
-
- Args:
- tokens: The input prompt to the model as tokens
- str_tokens: The input prompt to the model as a list of strings (one string per token)
- features_data: A MultiFeatureData containing information required to plot the features.
- feature_acts: The activations values of the features
- final_resid_acts: The activations of the final layer of the model
- feature_resid_dirs: The directions that each feature writes to the logit output
- feature_indices_list: The indices of the features we're interested in
- W_U: The unembed weights for the logit lens
- num_top_features: The number of top features to display in this view, for any given metric.
- Returns:
- A MultiPromptData object containing data for visualizing the most relevant features
- given the prompt.
-
- Similar to parse_feature_data, except it just gets the data relevant for a particular
- sequence (i.e. a custom one that the user inputs on their own).
-
- The ordering metric for relevant features is set by the str_score parameter in the
- MultiPromptData.get_html() method: it can be "act_size", "act_quantile", or "loss_effect"
- """
- torch.cuda.empty_cache()
- device = W_U.device
- n_feats = len(feature_indices_list)
- batch, seq_len = tokens.shape
- feats_contribution_to_loss = torch.empty(size=(n_feats, seq_len - 1), device=device)
-
- # Some logit computations which we only need to do once
- correct_token_unembeddings = W_U[:, tokens[0, 1:]] # [d_model seq]
- orig_logits = (
- final_resid_acts / final_resid_acts.std(dim=-1, keepdim=True)
- ) @ W_U # [seq d_vocab]
-
- sequence_data_dict: dict[int, SequenceData] = {}
-
- for i, feat in enumerate(feature_indices_list):
- # Calculate all data for the sequences
- # (this is the only truly 'new' bit of calculation we need to do)
-
- # Get this feature's output vector, using an outer product over feature acts for all tokens
- final_resid_acts_feature_effect = einsum(
- feature_acts[..., i].to_dense().to(device),
- feature_resid_dirs[i],
- "seq, d_model -> seq d_model",
- )
-
- # Ablate the output vector from the residual stream, and get logits post-ablation
- new_final_resid_acts = final_resid_acts - final_resid_acts_feature_effect
- new_logits = (new_final_resid_acts / new_final_resid_acts.std(dim=-1, keepdim=True)) @ W_U
-
- # Get the top5 & bottom5 changes in logits
- contribution_to_logprobs = orig_logits.log_softmax(dim=-1) - new_logits.log_softmax(dim=-1)
- top5_contribution_to_logits = TopK(contribution_to_logprobs[:-1], k=5)
- bottom5_contribution_to_logits = TopK(contribution_to_logprobs[:-1], k=5, largest=False)
-
- # Get the change in loss (which is negative of change of logprobs for correct token)
- contribution_to_loss = eindex(-contribution_to_logprobs[:-1], tokens[0, 1:], "seq [seq]")
- feats_contribution_to_loss[i, :] = contribution_to_loss
-
- # Store the sequence data
- sequence_data_dict[feat] = SequenceData(
- token_ids=tokens.squeeze(0).tolist(),
- feat_acts=feature_acts[..., i].tolist(),
- contribution_to_loss=[0.0] + contribution_to_loss.tolist(),
- top5_token_ids=top5_contribution_to_logits.indices.tolist(),
- top5_logit_contributions=top5_contribution_to_logits.values.tolist(),
- bottom5_token_ids=bottom5_contribution_to_logits.indices.tolist(),
- bottom5_logit_contributions=bottom5_contribution_to_logits.values.tolist(),
- )
-
- # Get the logits for the correct tokens
- logits_for_correct_tokens = einsum(
- feature_resid_dirs[i], correct_token_unembeddings, "d_model, d_model seq -> seq"
- )
-
- # Add the annotations data (feature activations and logit effect) to the histograms
- freq_line_posn = feature_acts[..., i].tolist()
- freq_line_text = [
- f"\\'{str_tok}\\'<br>{act:.3f}"
- for str_tok, act in zip(str_tokens[1:], freq_line_posn, strict=False)
- ]
- middle_plots_data = features_data[feat].middle_plots_data
- assert middle_plots_data is not None
- middle_plots_data.freq_histogram_data.line_posn = freq_line_posn
- middle_plots_data.freq_histogram_data.line_text = freq_line_text # type: ignore (due to typing bug in sae_vis)
- logits_line_posn = logits_for_correct_tokens.tolist()
- logits_line_text = [
- f"\\'{str_tok}\\'<br>{logits:.3f}"
- for str_tok, logits in zip(str_tokens[1:], logits_line_posn, strict=False)
- ]
- middle_plots_data.logits_histogram_data.line_posn = logits_line_posn
- middle_plots_data.logits_histogram_data.line_text = logits_line_text # type: ignore (due to typing bug in sae_vis)
-
- # Lastly, use the criteria (act size, act quantile, loss effect) to find top-scoring features
-
- # Construct a scores dict, which maps from things like ("act_quantile", seq_pos)
- # to a list of the top-scoring features
- scores_dict: dict[tuple[str, str], tuple[TopK, list[str]]] = {}
-
- for seq_pos in range(len(str_tokens)):
- # Filter the feature activations, since we only need the ones that are non-zero
- feat_acts_nonzero_filter = to_numpy(feature_acts[seq_pos] > 0)
- feat_acts_nonzero_locations = np.nonzero(feat_acts_nonzero_filter)[0].tolist()
- _feature_acts = (
- feature_acts[seq_pos, feat_acts_nonzero_filter].to_dense().to(device)
- ) # [feats_filtered,]
- _feature_indices_list = np.array(feature_indices_list)[feat_acts_nonzero_filter]
-
- if feat_acts_nonzero_filter.sum() > 0:
- k = min(num_top_features, _feature_acts.numel())
-
- # Get the "act_size" scores (we return it as a TopK object)
- act_size_topk = TopK(_feature_acts, k=k, largest=True)
- # Replace the indices with feature indices (these are different when
- # feature_indices_list argument is not [0, 1, 2, ...])
- act_size_topk.indices[:] = _feature_indices_list[act_size_topk.indices]
- scores_dict[("act_size", seq_pos)] = (act_size_topk, ".3f") # type: ignore (due to typing bug in sae_vis)
-
- # Get the "act_quantile" scores, which is just the fraction of cached feat acts that it
- # is larger than
- act_quantile, act_precision = features_data.feature_act_quantiles.get_quantile(
- _feature_acts, feat_acts_nonzero_locations
- )
- act_quantile_topk = TopK(act_quantile, k=k, largest=True)
- act_formatting_topk = [f".{act_precision[i]-2}%" for i in act_quantile_topk.indices]
- # Replace the indices with feature indices (these are different when
- # feature_indices_list argument is not [0, 1, 2, ...])
- act_quantile_topk.indices[:] = _feature_indices_list[act_quantile_topk.indices]
- scores_dict[("act_quantile", seq_pos)] = (act_quantile_topk, act_formatting_topk) # type: ignore (due to typing bug in sae_vis)
-
- # We don't measure loss effect on the first token
- if seq_pos == 0:
- continue
-
- # Filter the loss effects, since we only need the ones which have non-zero feature acts on
- # the tokens before them
- prev_feat_acts_nonzero_filter = to_numpy(feature_acts[seq_pos - 1] > 0)
- _contribution_to_loss = feats_contribution_to_loss[
- prev_feat_acts_nonzero_filter, seq_pos - 1
- ] # [feats_filtered,]
- _feature_indices_list_prev = np.array(feature_indices_list)[prev_feat_acts_nonzero_filter]
-
- if prev_feat_acts_nonzero_filter.sum() > 0:
- k = min(num_top_features, _contribution_to_loss.numel())
-
- # Get the "loss_effect" scores, which are just the min of features' contributions to
- # loss (min because we're looking for helpful features, not harmful ones)
- contribution_to_loss_topk = TopK(_contribution_to_loss, k=k, largest=False)
- # Replace the indices with feature indices (these are different when
- # feature_indices_list argument is not [0, 1, 2, ...])
- contribution_to_loss_topk.indices[:] = _feature_indices_list_prev[
- contribution_to_loss_topk.indices
- ]
- scores_dict[("loss_effect", seq_pos)] = (contribution_to_loss_topk, ".3f") # type: ignore (due to typing bug in sae_vis)
-
- # Get all the features which are required (i.e. all the sequence position indices)
- feature_indices_list_required = set()
- for score_topk, _ in scores_dict.values():
- feature_indices_list_required.update(set(score_topk.indices.tolist()))
- prompt_data_dict = {}
- for feat in feature_indices_list_required:
- middle_plots_data = features_data[feat].middle_plots_data
- assert middle_plots_data is not None
- prompt_data_dict[feat] = PromptData(
- prompt_data=sequence_data_dict[feat],
- sequence_data=features_data[feat].sequence_data[0],
- middle_plots_data=middle_plots_data,
- )
-
- return MultiPromptData(
- prompt_str_toks=str_tokens,
- prompt_data_dict=prompt_data_dict,
- scores_dict=scores_dict,
- )
-
-
@torch.inference_mode()
def get_prompt_data(
model: SAETransformer,
tokens: Int[Tensor, "batch pos"],
str_tokens: list[str],
- dashboards_data: dict[str, MultiFeatureData],
+ dashboards_data: dict[str, SaeVisData],
sae_positions: list[str] | None = None,
num_top_features: PositiveInt = 10,
-) -> dict[str, MultiPromptData]:
+) -> tuple[dict[str, SaeVisData], dict[str, dict[str, tuple[list[int], list[str]]]]]:
"""Gets data needed to create the sequences in the prompt-centric HTML visualisation.
This visualization displays dashboards for the most relevant features on a prompt.
@@ -769,7 +485,7 @@ def get_prompt_data(
str_tokens:
The input prompt to the model as a list of strings (one string per token)
dashboards_data:
- For each SAE, a MultiFeatureData containing information required to plot its features.
+ For each SAE, a SaeVisData containing information required to plot its features.
sae_positions:
The names of the SAEs we want to find relevant features in.
eg. ['blocks.0.hook_resid_pre']. If none, then we'll do all of them.
@@ -777,13 +493,13 @@ def get_prompt_data(
The number of top features to display in this view, for any given metric.
Returns:
- A dict of [sae_position_name: MultiPromptData]. Each MultiPromptData contains data for
+ A dict of [sae_position_name: SaeVisData]. Each SaeVisData contains data for
visualizing the most relevant features in that SAE given the prompt.
Similar to get_feature_data, except it just gets the data relevant for a particular
sequence (i.e. a custom one that the user inputs on their own).
The ordering metric for relevant features is set by the str_score parameter in the
- MultiPromptData.get_html() method: it can be "act_size", "act_quantile", or "loss_effect"
+ SaeVisData.get_html() method: it can be "act_size", "act_quantile", or "loss_effect"
"""
assert tokens.shape[-1] == len(
str_tokens
@@ -794,21 +510,24 @@ def get_prompt_data(
raw_sae_positions: list[str] = filter_names(
list(model.tlens_model.hook_dict.keys()), sae_positions
)
+ device = model.saes[raw_sae_positions[0].replace(".", "-")].device
+ tokens = tokens.to(device)
feature_indices: dict[str, list[int]] = {}
for sae_name in raw_sae_positions:
feature_indices[sae_name] = list(dashboards_data[sae_name].feature_data_dict.keys())
feature_acts, final_resid_acts = compute_feature_acts(
- model=model,
- tokens=tokens,
+ model,
+ tokens,
raw_sae_positions=raw_sae_positions,
feature_indices=feature_indices,
)
final_resid_acts = final_resid_acts.squeeze(dim=0)
- prompt_data: dict[str, MultiPromptData] = {}
+ scores_dicts: dict[str, dict[str, tuple[list[int], list[str]]]] = {}
for sae_name in raw_sae_positions:
+ dashboards_data[sae_name].model = model.tlens_model
sae = model.saes[sae_name.replace(".", "-")]
feature_act_dir: Float[Tensor, "dim some_feats"] = sae.encoder[0].weight.T[
:, feature_indices[sae_name]
@@ -822,45 +541,59 @@ def get_prompt_data(
== (len(feature_indices[sae_name]), sae.input_size)
)
- prompt_data[sae_name] = parse_prompt_data(
- tokens=tokens,
- str_tokens=str_tokens,
- features_data=dashboards_data[sae_name],
- feature_acts=feature_acts[sae_name].squeeze(dim=0),
- final_resid_acts=final_resid_acts,
- feature_resid_dirs=feature_resid_dirs,
- feature_indices_list=feature_indices[sae_name],
- W_U=model.tlens_model.W_U,
+ scores_dicts[sae_name] = parse_prompt_data(
+ tokens,
+ str_tokens,
+ dashboards_data[sae_name],
+ feature_acts[sae_name].squeeze(dim=0).to(device),
+ feature_resid_dirs.to(device),
+ final_resid_acts.to(device),
+ model.tlens_model.W_U.to(device),
+ feature_idx=feature_indices[sae_name],
num_top_features=num_top_features,
)
- return prompt_data
+ return dashboards_data, scores_dicts
@torch.inference_mode()
def generate_feature_dashboard_html_files(
- dashboards_data: dict[str, MultiFeatureData],
- feature_indices: FeatureIndicesType | dict[str, set[int]] | None,
+ dashboards_data: dict[str, SaeVisData],
+ minibatch_size_features: PositiveInt | None,
save_dir: str | Path = "",
):
"""Generates viewable HTML dashboards for every feature in every SAE in dashboards_data"""
- if feature_indices is None:
- feature_indices = {name: dashboards_data[name].keys() for name in dashboards_data}
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
- for sae_name in feature_indices:
+ for sae_name in dashboards_data:
logger.info(f"Saving HTML feature dashboards for the SAE at {sae_name}:")
folder = save_dir / Path(f"dashboards_{sae_name}")
folder.mkdir(parents=True, exist_ok=True)
- for feature_idx in tqdm(feature_indices[sae_name], desc="Dashboard HTML files"):
- feature_idx = (
- int(feature_idx.item()) if isinstance(feature_idx, Tensor) else feature_idx
+ model = dashboards_data[sae_name].model
+ assert model is not None
+ feature_ids = sorted(list(dashboards_data[sae_name].feature_data_dict.keys()))
+ if minibatch_size_features is None:
+ feature_ids_split = [feature_ids]
+ else:
+
+ def split_list(lst: list[Any], chunk_size: int) -> list[list[Any]]:
+ chunks = [[] for _ in range((len(lst) + chunk_size - 1) // chunk_size)]
+ for i, item in enumerate(lst):
+ chunks[i // chunk_size].append(item)
+ return chunks
+
+ feature_ids_split = split_list(feature_ids, minibatch_size_features)
+ for batch_feature_ids in tqdm(feature_ids_split, desc="Dashboard HTML files"):
+ batch_feature_data_dict = {
+ i: dashboards_data[sae_name].feature_data_dict[i] for i in batch_feature_ids
+ }
+ batch_dashboards_data = SaeVisData(
+ feature_data_dict=batch_feature_data_dict,
+ cfg=dashboards_data[sae_name].cfg,
+ model=dashboards_data[sae_name].model,
+ )
+ batch_dashboards_data.save_feature_centric_vis(
+ filename=folder / f"features-{batch_feature_ids[0]}-to-{batch_feature_ids[-1]}.html"
)
- if feature_idx in dashboards_data[sae_name].keys():
- html_str = dashboards_data[sae_name][feature_idx].get_html()
- filepath = folder / Path(f"feature-{feature_idx}.html")
- with open(filepath, "w") as f:
- f.write(html_str)
- logger.info(f"Saved HTML feature dashboards in {folder}")
@torch.inference_mode()
@@ -868,33 +601,38 @@ def generate_prompt_dashboard_html_files(
model: SAETransformer,
tokens: Int[Tensor, "batch pos"],
str_tokens: list[str],
- dashboards_data: dict[str, MultiFeatureData],
- seq_pos: int | list[int] | None = None,
- vocab_dict: dict[int, str] | None = None,
- str_score: StrScoreType = "loss_effect",
+ dashboards_data: dict[str, SaeVisData],
+ seq_pos: list[int] | int | None = None,
save_dir: str | Path = "",
) -> dict[str, set[int]]:
"""Generates viewable HTML dashboards for the most relevant features (measured by str_score) for
every SAE in dashboards_data.
+ This function is adapted from sae_vis.data_storing_functions.SaeVisData.save_prompt_centric_vis
+
Returns the set of feature indices which were active"""
+
assert tokens.shape[-1] == len(
str_tokens
), "Error: the number of tokens does not equal the number of str_tokens"
- str_tokens = [s.replace("Ġ", " ") for s in str_tokens]
+
if isinstance(seq_pos, int):
seq_pos = [seq_pos]
if seq_pos is None: # Generate a dashboard for every position if none is specified
seq_pos = list(range(2, len(str_tokens) - 2))
- if vocab_dict is None:
- assert (
- model.tlens_model.tokenizer is not None
- ), "If voacab_dict is not supplied, the model must have a tokenizer"
- vocab_dict = create_vocab_dict(model.tlens_model.tokenizer)
- prompt_data = get_prompt_data(
- model=model, tokens=tokens, str_tokens=str_tokens, dashboards_data=dashboards_data
- )
- prompt = "".join(str_tokens)
+
+ str_toks = [t.replace("|", "│") for t in str_tokens] # pipe -> vertical line
+ str_toks_list = [f"{t!r} ({i})" for i, t in enumerate(str_toks)]
+ metric_list = ["act_quantile", "act_size", "loss_effect"]
+
+ # Get default values for dropdowns
+ first_metric = "act_quantile"
+ first_seq_pos = str_toks_list[seq_pos[0]]
+ first_key = f"{first_metric}|{first_seq_pos}"
+
+ # Get tokenize function (we only need to define it once)
+ decode_fn = get_decode_html_safe_fn(model.tlens_model.tokenizer)
+
# Use the beginning of the prompt for the filename, but make sure that it's safe for a filename
str_tokens_safe_for_filenames = [
"".join(c for c in token if c.isalpha() or c.isdigit() or c == " ")
@@ -906,43 +644,94 @@ def generate_prompt_dashboard_html_files(
filename_from_prompt = filename_from_prompt[:50]
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
+
+ METRIC_TITLES = {
+ "act_size": "Activation Size",
+ "act_quantile": "Activation Quantile",
+ "loss_effect": "Loss Effect",
+ }
+
+ # Run forward passes on our prompt, and store the data within each FeatureData object
+ # as `self.prompt_data` as well as returning the scores_dict (which maps from score hash to a
+ # list of feature indices & formatted scores)
+ dashboards_data, scores_dicts = get_prompt_data(
+ model,
+ tokens,
+ str_tokens,
+ dashboards_data=dashboards_data,
+ )
used_features: dict[str, set[int]] = {sae_name: set() for sae_name in dashboards_data}
+
for sae_name in dashboards_data:
- seq_pos_with_scores: set[int] = {
- int(x[1])
- for x in prompt_data["blocks.1.hook_resid_post"].scores_dict
- if x[0] == str_score
- }
- for seq_pos_i in seq_pos_with_scores.intersection(seq_pos):
- # Find the most relevant features (by {str_score}) for the token
- # '{str_tokens[seq_pos_i]}' in the prompt '{prompt}:
- html_str = prompt_data[sae_name].get_html(seq_pos_i, str_score, vocab_dict)
- # Insert a title
- title: str = (
- f"