diff --git a/.gitignore b/.gitignore index f8d1a64..c17a082 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ wandb/ .data/ .checkpoints/ +tests/saes_for_tests/ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/pyproject.toml b/pyproject.toml index 14f97e7..7498d1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "zstandard~=0.22.0", "matplotlib>=3.5.3", "eindex-callum@git+https://github.com/callummcdougall/eindex", - "sae_vis@git+https://github.com/callummcdougall/sae_vis.git@b28a0f7c7e936f4bea05528d952dfcd438533cce" + "sae_vis@git+https://github.com/callummcdougall/sae_vis" ] [project.urls] @@ -98,4 +98,4 @@ reportPrivateImportUsage = false filterwarnings = [ # https://github.com/google/python-fire/pull/447 "ignore::DeprecationWarning:fire:59", -] \ No newline at end of file +] diff --git a/sparsify/scripts/dashboards.yaml b/sparsify/scripts/dashboards.yaml index c3f44fd..0e4bcb9 100644 --- a/sparsify/scripts/dashboards.yaml +++ b/sparsify/scripts/dashboards.yaml @@ -1,22 +1,28 @@ pretrained_sae_paths: null # Paths of the pretrained SAEs to load. Should be a path to a .pt file, or a list of them. Can also be provided as a second argument in the command line. sae_config_path: null # Path to the config file used to train the SAEs (if null, we'll assume it's at pretrained_sae_paths[0].parent / "config.yaml") -n_samples: 10_000 -batch_size: 64 -minibatch_size_features: 128 # Num features in each batch of calculations. Lower to avoid OOM errors +n_samples: 3000 +batch_size: 16 +minibatch_size_features: 100 # Num features in each batch of calculations. Lower to avoid OOM errors data: # DatasetConfig for the data which will be used to generate the dashboards - dataset_name: 'apollo-research/roneneldan-TinyStories-tokenizer-gpt2' + dataset_name: 'apollo-research/Skylion007-openwebtext-tokenizer-gpt2' is_tokenized: True tokenizer_name: 'gpt2' - split: "validation" - n_ctx: 512 + split: "train" + n_ctx: 1024 save_dir: null # The directory for saving the HTML feature dashboard files +save_json_data: false sae_positions: null # The names of the SAE positions to generate dashboards for. e.g.'blocks.2.hook_resid_post'. If None, then all positions will be generated feature_indices: null # The features for which to generate dashboards on each SAE. If none, then we'll generate dashbaords for every feature. prompt_centric: # Used to generate prompt-centric (rather than feature-centric) dashboards. Feature-centric dashboards will also be generated for every feature appaearing in these - n_random_prompt_dashboards: 50 # The number of random prompts to generate prompt-centric dashboards for. + n_random_prompt_dashboards: 10 # The number of random prompts to generate prompt-centric dashboards for. data: null # "DatasetConfig for getting random prompts. If None, then non-prompt-centric data will be used prompts: # Specific prompts on which to generate prompt-centric feature dashboards. A feature-centric dashboard will be generated for every token position in each prompt. - - "Sally met Mike at the show. She brought popcorn for him." - str_score: "loss_effect" # The ordering metric for which features are most important in prompt-centric dashboards. Can be one of 'act_size', 'act_quantile', or 'loss_effect' + - "Sally met Mike at the show. She brought popcorn for him. They ate it together" + - 'Lily asked, "Mommy, can I go on the slide?"' + - "It was time for the lecture to begin." + - "A man was taken to hospital after the crash" + - "CAMPAIGN The campaign will focus on three core goals:" + - "new_list = [n**2 for n in numbers if n%2==0]" + str_score: "act_quantile" # The ordering metric for which features are most important in prompt-centric dashboards. Can be one of 'act_size', 'act_quantile', or 'loss_effect' num_top_features: 10 # How many of the most relevant features to show for each prompt in the prompt-centric dashboards seed: 0 \ No newline at end of file diff --git a/sparsify/scripts/generate_dashboards.py b/sparsify/scripts/generate_dashboards.py index 3d7c85b..4bb6fe3 100644 --- a/sparsify/scripts/generate_dashboards.py +++ b/sparsify/scripts/generate_dashboards.py @@ -28,37 +28,36 @@ https://github.com/callummcdougall/sae_vis/commit/b28a0f7c7e936f4bea05528d952dfcd438533cce """ import math -from collections.abc import Iterable +from copy import deepcopy from pathlib import Path -from typing import Annotated, Literal +from typing import Annotated, Any, Literal import fire import numpy as np import torch -from eindex import eindex -from einops import einsum, rearrange +from einops import einsum from jaxtyping import Float, Int from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, NonNegativeInt, PositiveInt -from sae_vis.data_fetching_fns import get_sequences_data -from sae_vis.data_storing_fns import ( - FeatureData, - FeatureVisParams, - HistogramData, - MiddlePlotsData, - MultiFeatureData, - MultiPromptData, - PromptData, - SequenceData, - SequenceMultiGroupData, +from sae_vis.data_config_classes import ( + ActsHistogramConfig, + Column, + LogitsHistogramConfig, + LogitsTableConfig, + PromptConfig, + SaeVisConfig, + SaeVisLayoutConfig, + SequencesConfig, ) -from sae_vis.utils_fns import QuantileCalculator, TopK, process_str_tok +from sae_vis.data_fetching_fns import parse_feature_data, parse_prompt_data +from sae_vis.data_storing_fns import SaeVisData +from sae_vis.html_fns import HTML +from sae_vis.utils_fns import get_decode_html_safe_fn from torch import Tensor from torch.utils.data.dataset import IterableDataset from tqdm import tqdm from transformers import ( AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerBase, PreTrainedTokenizerFast, ) @@ -68,7 +67,28 @@ from sparsify.models.transformers import SAETransformer from sparsify.scripts.train_tlens_saes.run_train_tlens_saes import Config from sparsify.types import RootPath -from sparsify.utils import filter_names, load_config, to_numpy +from sparsify.utils import filter_names, load_config + +LAYOUT_FEATURE_VIS = SaeVisLayoutConfig( + columns=[ + Column(ActsHistogramConfig(), LogitsTableConfig(), LogitsHistogramConfig()), + Column(SequencesConfig(stack_mode="stack-none")), + ], + height=750, +) + +LAYOUT_PROMPT_VIS = SaeVisLayoutConfig( + columns=[ + Column( + PromptConfig(), + ActsHistogramConfig(), + LogitsTableConfig(n_rows=5), + SequencesConfig(top_acts_group_size=10, n_quantiles=0), + width=450, + ), + ], + height=1000, +) FeatureIndicesType = dict[str, list[int]] | dict[str, Int[Tensor, "some_feats"]] # noqa: F821 (jaxtyping/pyright doesn't like single dimensions) StrScoreType = Literal["act_size", "act_quantile", "loss_effect"] @@ -77,9 +97,9 @@ class PromptDashboardsConfig(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) n_random_prompt_dashboards: NonNegativeInt = Field( - default=50, + default=10, description="The number of random prompts to generate prompt-centric dashboards for." - "A feature-centric dashboard will be generated for random token positions in each prompt.", + "Feature-centric dashboards will be generated for each prompt.", ) data: DatasetConfig | None = Field( default=None, @@ -125,7 +145,12 @@ class DashboardsConfig(BaseModel): ) save_dir: RootPath | None = Field( default=None, - description="The directory for saving the HTML feature dashboard files", + description="The directory for saving the HTML feature dashboard files. If none, they " + "will be saved in pretrained_sae_paths[0].parent", + ) + save_json_data: bool = Field( + default=False, + description="Whether to save JSON data which can be used to re-generate the HTML dashboards", ) sae_positions: Annotated[ list[str] | None, BeforeValidator(lambda x: [x] if isinstance(x, str) else x) @@ -188,7 +213,7 @@ def compute_feature_acts( feature_acts[hook_name] = feature_acts[hook_name].to("cpu") if feature_indices is not None: feature_acts[hook_name] = feature_acts[hook_name][..., feature_indices[hook_name]] - return feature_acts, final_resid_acts + return feature_acts, final_resid_acts.to("cpu") def compute_feature_acts_on_distribution( @@ -221,16 +246,16 @@ def compute_feature_acts_on_distribution( - The residual stream activations of the model at the final layer (or at stop_at_layer) - The tokens used as input to the model """ - data_loader, _ = create_data_loader( - dataset_config, batch_size=batch_size, buffer_size=batch_size - ) + data_loader, _ = create_data_loader(dataset_config, batch_size, buffer_size=batch_size) if raw_sae_positions is None: raw_sae_positions = model.raw_sae_positions assert raw_sae_positions is not None device = model.saes[raw_sae_positions[0].replace(".", "-")].device if n_samples is None: # If streaming (i.e. if the dataset is an IterableDataset), we don't know the length - n_batches = None if isinstance(data_loader.dataset, IterableDataset) else len(data_loader) + n_batches = ( + float("inf") if isinstance(data_loader.dataset, IterableDataset) else len(data_loader) + ) else: n_batches = math.ceil(n_samples / batch_size) if not isinstance(data_loader.dataset, IterableDataset): @@ -245,8 +270,8 @@ def compute_feature_acts_on_distribution( for batch in tqdm(data_loader, total=n_batches, desc="Computing feature acts"): batch_tokens: Int[Tensor, "..."] = batch[dataset_config.column_name].to(device=device) batch_feature_acts, batch_final_resid_acts = compute_feature_acts( - model=model, - tokens=batch_tokens, + model, + batch_tokens, raw_sae_positions=raw_sae_positions, feature_indices=feature_indices, stop_at_layer=stop_at_layer, @@ -255,127 +280,17 @@ def compute_feature_acts_on_distribution( feature_acts_lists[sae_name].append(batch_feature_acts[sae_name]) final_resid_acts_list.append(batch_final_resid_acts) tokens_list.append(batch_tokens) - total_samples += batch_tokens.shape[0] if n_samples is not None and total_samples > n_samples: break + total_samples += batch_tokens.shape[0] final_resid_acts: Float[Tensor, "... d_resid"] = torch.cat(final_resid_acts_list, dim=0) tokens: Int[Tensor, "..."] = torch.cat(tokens_list, dim=0) feature_acts: dict[str, Float[Tensor, "... some_feats"]] = {} for sae_name in raw_sae_positions: - feature_acts[sae_name] = torch.cat(tensors=feature_acts_lists[sae_name], dim=0) + feature_acts[sae_name] = torch.cat(feature_acts_lists[sae_name], dim=0) return feature_acts, final_resid_acts, tokens -def create_vocab_dict(tokenizer: PreTrainedTokenizerBase) -> dict[int, str]: - """ - Creates a vocab dict suitable for dashboards by replacing all the special tokens with their - HTML representations. This function is adapted from sae_vis.create_vocab_dict() - """ - vocab_dict: dict[str, int] = tokenizer.get_vocab() - vocab_dict_processed: dict[int, str] = {v: process_str_tok(k) for k, v in vocab_dict.items()} - return vocab_dict_processed - - -@torch.inference_mode() -def parse_activation_data( - tokens: Int[Tensor, "batch pos"], - feature_acts: Float[Tensor, "... some_feats"], - final_resid_acts: Float[Tensor, "... d_resid"], - feature_resid_dirs: Float[Tensor, "some_feats dim"], - feature_indices_list: Iterable[int], - W_U: Float[Tensor, "dim d_vocab"], - vocab_dict: dict[int, str], - fvp: FeatureVisParams, -) -> MultiFeatureData: - """Convert generic activation data into a MultiFeatureData object, which can be used to create - the feature-centric visualisation. - Adapted from sae_vis.data_fetching_fns._get_feature_data() - final_resid_acts + W_U are used for the logit lens. - - Args: - tokens: The inputs to the model - feature_acts: The activations values of the features - final_resid_acts: The activations of the final layer of the model - feature_resid_dirs: The directions that each feature writes to the logit output - feature_indices_list: The indices of the features we're interested in - W_U: The unembed weights for the logit lens - vocab_dict: A dictionary mapping vocab indices to strings - fvp: FeatureVisParams, containing a bunch of settings. See the FeatureVisParams docstring in - sae_vis for more information. - - Returns: - A MultiFeatureData containing data for creating each feature's visualization, - as well as data for rank-ordering the feature visualizations when it comes time - to make the prompt-centric view (the `feature_act_quantiles` attribute). - Use MultiFeatureData[feature_idx].get_html() to generate the HTML dashboard for a - particular feature (returns a string of HTML). - - """ - device = W_U.device - feature_acts.to(device) - sequence_data_dict: dict[int, SequenceMultiGroupData] = {} - middle_plots_data_dict: dict[int, MiddlePlotsData] = {} - features_data: dict[int, FeatureData] = {} - # Calculate all data for the right-hand visualisations, i.e. the sequences - for i, feat in enumerate(feature_indices_list): - # Add this feature's sequence data to the list - sequence_data_dict[feat] = get_sequences_data( - tokens=tokens, - feat_acts=feature_acts[..., i], - resid_post=final_resid_acts, - feature_resid_dir=feature_resid_dirs[i], - W_U=W_U, - fvp=fvp, - ) - - # Get the logits of all features (i.e. the directions this feature writes to the logit output) - logits = einsum( - feature_resid_dirs, - W_U, - "feats d_model, d_model d_vocab -> feats d_vocab", - ) - for i, (feat, logit) in enumerate(zip(feature_indices_list, logits, strict=True)): - # Get data for logits (the histogram, and the table) - logits_histogram_data = HistogramData(logit, 40, "5 ticks") - top10_logits = TopK(logit, k=15, largest=True) - bottom10_logits = TopK(logit, k=15, largest=False) - - # Get data for feature activations histogram (the title, and the histogram) - feat_acts = feature_acts[..., i] - nonzero_feat_acts = feat_acts[feat_acts > 0] - frac_nonzero = nonzero_feat_acts.numel() / feat_acts.numel() - freq_histogram_data = HistogramData(nonzero_feat_acts, 40, "ints") - - # Create a MiddlePlotsData object from this, and add it to the dict - middle_plots_data_dict[feat] = MiddlePlotsData( - bottom10_logits=bottom10_logits, - top10_logits=top10_logits, - logits_histogram_data=logits_histogram_data, - freq_histogram_data=freq_histogram_data, - frac_nonzero=frac_nonzero, - ) - - # Return the output, as a dict of FeatureData items - for i, feat in enumerate(feature_indices_list): - features_data[feat] = FeatureData( - # Data-containing inputs (for the feature-centric visualisation) - sequence_data=sequence_data_dict[feat], - middle_plots_data=middle_plots_data_dict[feat], - left_tables_data=None, - # Non data-containing inputs - feature_idx=feat, - vocab_dict=vocab_dict, - fvp=fvp, - ) - - # Also get the quantiles, which will be useful for the prompt-centric visualisation - - feature_act_quantiles = QuantileCalculator( - data=rearrange(feature_acts, "... feats -> feats (...)") - ) - return MultiFeatureData(features_data, feature_act_quantiles) - - def feature_indices_to_tensordict( feature_indices_in: FeatureIndicesType | list[int] | None, raw_sae_positions: list[str], @@ -412,9 +327,8 @@ def get_dashboards_data( n_samples: PositiveInt | None = None, batch_size: PositiveInt | None = None, minibatch_size_features: PositiveInt | None = None, - fvp: FeatureVisParams | None = None, - vocab_dict: dict[int, str] | None = None, -) -> dict[str, MultiFeatureData]: + cfg: SaeVisConfig | None = None, +) -> dict[str, SaeVisData]: """Gets data that needed to create the sequences in the feature-centric HTML visualisation Adapted from sae_vis.data_fetching_fns._get_feature_data() @@ -442,30 +356,18 @@ def get_dashboards_data( using dataset_config minibatch_size_features: Num features in each batch of calculations (break up the features to avoid OOM errors). - fvp: + cfg: Feature visualization parameters, containing a bunch of other stuff. See the - FeatureVisParams docstring in sae_vis for more information. - vocab_dict: - vocab dict suitable for dashboards with all the special tokens replaced with their - HTML representations. If None then it will be created using create_vocab_dict(tokenizer) + SaeVisConfig docstring in sae_vis for more information. Returns: - A dict of [sae_position_name: MultiFeatureData]. Each MultiFeatureData contains data for + A dict of [sae_position_name: SaeVisData]. Each SaeVisData contains data for creating each feature's visualization, as well as data for rank-ordering the feature visualizations when it comes time to make the prompt-centric view (the `feature_act_quantiles` attribute). Use dashboards_data[sae_name][feature_idx].get_html() to generate the HTML dashboard for a particular feature (returns a string of HTML) """ - # Get the vocab dict, which we'll use at the end - if vocab_dict is None: - assert ( - model.tlens_model.tokenizer is not None - ), "If voacab_dict is not supplied, the model must have a tokenizer" - vocab_dict = create_vocab_dict(model.tlens_model.tokenizer) - - if fvp is None: - fvp = FeatureVisParams(include_left_tables=False) if sae_positions is None: raw_sae_positions: list[str] = model.raw_sae_positions @@ -475,9 +377,9 @@ def get_dashboards_data( ) # If we haven't supplied any feature indicies, assume that we want all of them feature_indices_tensors = feature_indices_to_tensordict( - feature_indices_in=feature_indices, - raw_sae_positions=raw_sae_positions, - model=model, + feature_indices, + raw_sae_positions, + model, ) for sae_name in raw_sae_positions: assert ( @@ -493,9 +395,9 @@ def get_dashboards_data( batch_size is not None ), "If no tokens are supplied, then a batch_size must be supplied" feature_acts, final_resid_acts, tokens = compute_feature_acts_on_distribution( - model=model, - dataset_config=dataset_config, - batch_size=batch_size, + model, + dataset_config, + batch_size, raw_sae_positions=raw_sae_positions, feature_indices=feature_indices_tensors, n_samples=n_samples, @@ -503,8 +405,8 @@ def get_dashboards_data( else: tokens.to(device) feature_acts, final_resid_acts = compute_feature_acts( - model=model, - tokens=tokens, + model, + tokens, raw_sae_positions=raw_sae_positions, feature_indices=feature_indices_tensors, ) @@ -516,11 +418,18 @@ def get_dashboards_data( feature_indices_tensors[sae_name] = feature_indices_tensors[sae_name][acts_sum > 0] del acts_sum - dashboards_data: dict[str, MultiFeatureData] = { - name: MultiFeatureData() for name in raw_sae_positions - } + dashboards_data: dict[str, SaeVisData] = {name: SaeVisData() for name in raw_sae_positions} for sae_name in raw_sae_positions: + if cfg is None: + cfg = SaeVisConfig( + hook_point=sae_name, + features=feature_indices_tensors[sae_name].tolist(), + feature_centric_layout=LAYOUT_FEATURE_VIS, + ) + dashboards_data[sae_name].cfg = cfg + dashboards_data[sae_name].model = model.tlens_model + sae = model.saes[sae_name.replace(".", "-")] W_dec: Float[Tensor, "feats dim"] = sae.decoder.weight.T feature_resid_dirs: Float[Tensor, "some_feats dim"] = W_dec[ @@ -540,222 +449,29 @@ def get_dashboards_data( ] feature_resid_dir_batches = feature_resid_dirs.split(minibatch_size_features) for i in tqdm(iterable=range(len(feature_batches)), desc="Parsing activation data"): - new_feature_data = parse_activation_data( - tokens=tokens, - feature_acts=feature_acts_batches[i].to_dense().to(device), - final_resid_acts=final_resid_acts, - feature_resid_dirs=feature_resid_dir_batches[i], - feature_indices_list=feature_batches[i], - W_U=W_U, - vocab_dict=vocab_dict, - fvp=fvp, + new_feature_data, _ = parse_feature_data( + tokens, + feature_batches[i], + feature_acts_batches[i].to_dense().to(device), + feature_resid_dir_batches[i].to(device), + final_resid_acts.to(device), + W_U.to(device), + cfg, ) dashboards_data[sae_name].update(new_feature_data) return dashboards_data -@torch.inference_mode() -def parse_prompt_data( - tokens: Int[Tensor, "batch pos"], - str_tokens: list[str], - features_data: MultiFeatureData, - feature_acts: Float[Tensor, "seq some_feats"], - final_resid_acts: Float[Tensor, "seq d_resid"], - feature_resid_dirs: Float[Tensor, "some_feats dim"], - feature_indices_list: list[int], - W_U: Float[Tensor, "dim d_vocab"], - num_top_features: int = 10, -) -> MultiPromptData: - """Gets data needed to create the sequences in the prompt-centric HTML visualisation. - - This visualization displays dashboards for the most relevant features on a prompt. - Adapted from sae_vis.data_fetching_fns.get_prompt_data(). - - Args: - tokens: The input prompt to the model as tokens - str_tokens: The input prompt to the model as a list of strings (one string per token) - features_data: A MultiFeatureData containing information required to plot the features. - feature_acts: The activations values of the features - final_resid_acts: The activations of the final layer of the model - feature_resid_dirs: The directions that each feature writes to the logit output - feature_indices_list: The indices of the features we're interested in - W_U: The unembed weights for the logit lens - num_top_features: The number of top features to display in this view, for any given metric. - Returns: - A MultiPromptData object containing data for visualizing the most relevant features - given the prompt. - - Similar to parse_feature_data, except it just gets the data relevant for a particular - sequence (i.e. a custom one that the user inputs on their own). - - The ordering metric for relevant features is set by the str_score parameter in the - MultiPromptData.get_html() method: it can be "act_size", "act_quantile", or "loss_effect" - """ - torch.cuda.empty_cache() - device = W_U.device - n_feats = len(feature_indices_list) - batch, seq_len = tokens.shape - feats_contribution_to_loss = torch.empty(size=(n_feats, seq_len - 1), device=device) - - # Some logit computations which we only need to do once - correct_token_unembeddings = W_U[:, tokens[0, 1:]] # [d_model seq] - orig_logits = ( - final_resid_acts / final_resid_acts.std(dim=-1, keepdim=True) - ) @ W_U # [seq d_vocab] - - sequence_data_dict: dict[int, SequenceData] = {} - - for i, feat in enumerate(feature_indices_list): - # Calculate all data for the sequences - # (this is the only truly 'new' bit of calculation we need to do) - - # Get this feature's output vector, using an outer product over feature acts for all tokens - final_resid_acts_feature_effect = einsum( - feature_acts[..., i].to_dense().to(device), - feature_resid_dirs[i], - "seq, d_model -> seq d_model", - ) - - # Ablate the output vector from the residual stream, and get logits post-ablation - new_final_resid_acts = final_resid_acts - final_resid_acts_feature_effect - new_logits = (new_final_resid_acts / new_final_resid_acts.std(dim=-1, keepdim=True)) @ W_U - - # Get the top5 & bottom5 changes in logits - contribution_to_logprobs = orig_logits.log_softmax(dim=-1) - new_logits.log_softmax(dim=-1) - top5_contribution_to_logits = TopK(contribution_to_logprobs[:-1], k=5) - bottom5_contribution_to_logits = TopK(contribution_to_logprobs[:-1], k=5, largest=False) - - # Get the change in loss (which is negative of change of logprobs for correct token) - contribution_to_loss = eindex(-contribution_to_logprobs[:-1], tokens[0, 1:], "seq [seq]") - feats_contribution_to_loss[i, :] = contribution_to_loss - - # Store the sequence data - sequence_data_dict[feat] = SequenceData( - token_ids=tokens.squeeze(0).tolist(), - feat_acts=feature_acts[..., i].tolist(), - contribution_to_loss=[0.0] + contribution_to_loss.tolist(), - top5_token_ids=top5_contribution_to_logits.indices.tolist(), - top5_logit_contributions=top5_contribution_to_logits.values.tolist(), - bottom5_token_ids=bottom5_contribution_to_logits.indices.tolist(), - bottom5_logit_contributions=bottom5_contribution_to_logits.values.tolist(), - ) - - # Get the logits for the correct tokens - logits_for_correct_tokens = einsum( - feature_resid_dirs[i], correct_token_unembeddings, "d_model, d_model seq -> seq" - ) - - # Add the annotations data (feature activations and logit effect) to the histograms - freq_line_posn = feature_acts[..., i].tolist() - freq_line_text = [ - f"\\'{str_tok}\\'
{act:.3f}" - for str_tok, act in zip(str_tokens[1:], freq_line_posn, strict=False) - ] - middle_plots_data = features_data[feat].middle_plots_data - assert middle_plots_data is not None - middle_plots_data.freq_histogram_data.line_posn = freq_line_posn - middle_plots_data.freq_histogram_data.line_text = freq_line_text # type: ignore (due to typing bug in sae_vis) - logits_line_posn = logits_for_correct_tokens.tolist() - logits_line_text = [ - f"\\'{str_tok}\\'
{logits:.3f}" - for str_tok, logits in zip(str_tokens[1:], logits_line_posn, strict=False) - ] - middle_plots_data.logits_histogram_data.line_posn = logits_line_posn - middle_plots_data.logits_histogram_data.line_text = logits_line_text # type: ignore (due to typing bug in sae_vis) - - # Lastly, use the criteria (act size, act quantile, loss effect) to find top-scoring features - - # Construct a scores dict, which maps from things like ("act_quantile", seq_pos) - # to a list of the top-scoring features - scores_dict: dict[tuple[str, str], tuple[TopK, list[str]]] = {} - - for seq_pos in range(len(str_tokens)): - # Filter the feature activations, since we only need the ones that are non-zero - feat_acts_nonzero_filter = to_numpy(feature_acts[seq_pos] > 0) - feat_acts_nonzero_locations = np.nonzero(feat_acts_nonzero_filter)[0].tolist() - _feature_acts = ( - feature_acts[seq_pos, feat_acts_nonzero_filter].to_dense().to(device) - ) # [feats_filtered,] - _feature_indices_list = np.array(feature_indices_list)[feat_acts_nonzero_filter] - - if feat_acts_nonzero_filter.sum() > 0: - k = min(num_top_features, _feature_acts.numel()) - - # Get the "act_size" scores (we return it as a TopK object) - act_size_topk = TopK(_feature_acts, k=k, largest=True) - # Replace the indices with feature indices (these are different when - # feature_indices_list argument is not [0, 1, 2, ...]) - act_size_topk.indices[:] = _feature_indices_list[act_size_topk.indices] - scores_dict[("act_size", seq_pos)] = (act_size_topk, ".3f") # type: ignore (due to typing bug in sae_vis) - - # Get the "act_quantile" scores, which is just the fraction of cached feat acts that it - # is larger than - act_quantile, act_precision = features_data.feature_act_quantiles.get_quantile( - _feature_acts, feat_acts_nonzero_locations - ) - act_quantile_topk = TopK(act_quantile, k=k, largest=True) - act_formatting_topk = [f".{act_precision[i]-2}%" for i in act_quantile_topk.indices] - # Replace the indices with feature indices (these are different when - # feature_indices_list argument is not [0, 1, 2, ...]) - act_quantile_topk.indices[:] = _feature_indices_list[act_quantile_topk.indices] - scores_dict[("act_quantile", seq_pos)] = (act_quantile_topk, act_formatting_topk) # type: ignore (due to typing bug in sae_vis) - - # We don't measure loss effect on the first token - if seq_pos == 0: - continue - - # Filter the loss effects, since we only need the ones which have non-zero feature acts on - # the tokens before them - prev_feat_acts_nonzero_filter = to_numpy(feature_acts[seq_pos - 1] > 0) - _contribution_to_loss = feats_contribution_to_loss[ - prev_feat_acts_nonzero_filter, seq_pos - 1 - ] # [feats_filtered,] - _feature_indices_list_prev = np.array(feature_indices_list)[prev_feat_acts_nonzero_filter] - - if prev_feat_acts_nonzero_filter.sum() > 0: - k = min(num_top_features, _contribution_to_loss.numel()) - - # Get the "loss_effect" scores, which are just the min of features' contributions to - # loss (min because we're looking for helpful features, not harmful ones) - contribution_to_loss_topk = TopK(_contribution_to_loss, k=k, largest=False) - # Replace the indices with feature indices (these are different when - # feature_indices_list argument is not [0, 1, 2, ...]) - contribution_to_loss_topk.indices[:] = _feature_indices_list_prev[ - contribution_to_loss_topk.indices - ] - scores_dict[("loss_effect", seq_pos)] = (contribution_to_loss_topk, ".3f") # type: ignore (due to typing bug in sae_vis) - - # Get all the features which are required (i.e. all the sequence position indices) - feature_indices_list_required = set() - for score_topk, _ in scores_dict.values(): - feature_indices_list_required.update(set(score_topk.indices.tolist())) - prompt_data_dict = {} - for feat in feature_indices_list_required: - middle_plots_data = features_data[feat].middle_plots_data - assert middle_plots_data is not None - prompt_data_dict[feat] = PromptData( - prompt_data=sequence_data_dict[feat], - sequence_data=features_data[feat].sequence_data[0], - middle_plots_data=middle_plots_data, - ) - - return MultiPromptData( - prompt_str_toks=str_tokens, - prompt_data_dict=prompt_data_dict, - scores_dict=scores_dict, - ) - - @torch.inference_mode() def get_prompt_data( model: SAETransformer, tokens: Int[Tensor, "batch pos"], str_tokens: list[str], - dashboards_data: dict[str, MultiFeatureData], + dashboards_data: dict[str, SaeVisData], sae_positions: list[str] | None = None, num_top_features: PositiveInt = 10, -) -> dict[str, MultiPromptData]: +) -> tuple[dict[str, SaeVisData], dict[str, dict[str, tuple[list[int], list[str]]]]]: """Gets data needed to create the sequences in the prompt-centric HTML visualisation. This visualization displays dashboards for the most relevant features on a prompt. @@ -769,7 +485,7 @@ def get_prompt_data( str_tokens: The input prompt to the model as a list of strings (one string per token) dashboards_data: - For each SAE, a MultiFeatureData containing information required to plot its features. + For each SAE, a SaeVisData containing information required to plot its features. sae_positions: The names of the SAEs we want to find relevant features in. eg. ['blocks.0.hook_resid_pre']. If none, then we'll do all of them. @@ -777,13 +493,13 @@ def get_prompt_data( The number of top features to display in this view, for any given metric. Returns: - A dict of [sae_position_name: MultiPromptData]. Each MultiPromptData contains data for + A dict of [sae_position_name: SaeVisData]. Each SaeVisData contains data for visualizing the most relevant features in that SAE given the prompt. Similar to get_feature_data, except it just gets the data relevant for a particular sequence (i.e. a custom one that the user inputs on their own). The ordering metric for relevant features is set by the str_score parameter in the - MultiPromptData.get_html() method: it can be "act_size", "act_quantile", or "loss_effect" + SaeVisData.get_html() method: it can be "act_size", "act_quantile", or "loss_effect" """ assert tokens.shape[-1] == len( str_tokens @@ -794,21 +510,24 @@ def get_prompt_data( raw_sae_positions: list[str] = filter_names( list(model.tlens_model.hook_dict.keys()), sae_positions ) + device = model.saes[raw_sae_positions[0].replace(".", "-")].device + tokens = tokens.to(device) feature_indices: dict[str, list[int]] = {} for sae_name in raw_sae_positions: feature_indices[sae_name] = list(dashboards_data[sae_name].feature_data_dict.keys()) feature_acts, final_resid_acts = compute_feature_acts( - model=model, - tokens=tokens, + model, + tokens, raw_sae_positions=raw_sae_positions, feature_indices=feature_indices, ) final_resid_acts = final_resid_acts.squeeze(dim=0) - prompt_data: dict[str, MultiPromptData] = {} + scores_dicts: dict[str, dict[str, tuple[list[int], list[str]]]] = {} for sae_name in raw_sae_positions: + dashboards_data[sae_name].model = model.tlens_model sae = model.saes[sae_name.replace(".", "-")] feature_act_dir: Float[Tensor, "dim some_feats"] = sae.encoder[0].weight.T[ :, feature_indices[sae_name] @@ -822,45 +541,59 @@ def get_prompt_data( == (len(feature_indices[sae_name]), sae.input_size) ) - prompt_data[sae_name] = parse_prompt_data( - tokens=tokens, - str_tokens=str_tokens, - features_data=dashboards_data[sae_name], - feature_acts=feature_acts[sae_name].squeeze(dim=0), - final_resid_acts=final_resid_acts, - feature_resid_dirs=feature_resid_dirs, - feature_indices_list=feature_indices[sae_name], - W_U=model.tlens_model.W_U, + scores_dicts[sae_name] = parse_prompt_data( + tokens, + str_tokens, + dashboards_data[sae_name], + feature_acts[sae_name].squeeze(dim=0).to(device), + feature_resid_dirs.to(device), + final_resid_acts.to(device), + model.tlens_model.W_U.to(device), + feature_idx=feature_indices[sae_name], num_top_features=num_top_features, ) - return prompt_data + return dashboards_data, scores_dicts @torch.inference_mode() def generate_feature_dashboard_html_files( - dashboards_data: dict[str, MultiFeatureData], - feature_indices: FeatureIndicesType | dict[str, set[int]] | None, + dashboards_data: dict[str, SaeVisData], + minibatch_size_features: PositiveInt | None, save_dir: str | Path = "", ): """Generates viewable HTML dashboards for every feature in every SAE in dashboards_data""" - if feature_indices is None: - feature_indices = {name: dashboards_data[name].keys() for name in dashboards_data} save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) - for sae_name in feature_indices: + for sae_name in dashboards_data: logger.info(f"Saving HTML feature dashboards for the SAE at {sae_name}:") folder = save_dir / Path(f"dashboards_{sae_name}") folder.mkdir(parents=True, exist_ok=True) - for feature_idx in tqdm(feature_indices[sae_name], desc="Dashboard HTML files"): - feature_idx = ( - int(feature_idx.item()) if isinstance(feature_idx, Tensor) else feature_idx + model = dashboards_data[sae_name].model + assert model is not None + feature_ids = sorted(list(dashboards_data[sae_name].feature_data_dict.keys())) + if minibatch_size_features is None: + feature_ids_split = [feature_ids] + else: + + def split_list(lst: list[Any], chunk_size: int) -> list[list[Any]]: + chunks = [[] for _ in range((len(lst) + chunk_size - 1) // chunk_size)] + for i, item in enumerate(lst): + chunks[i // chunk_size].append(item) + return chunks + + feature_ids_split = split_list(feature_ids, minibatch_size_features) + for batch_feature_ids in tqdm(feature_ids_split, desc="Dashboard HTML files"): + batch_feature_data_dict = { + i: dashboards_data[sae_name].feature_data_dict[i] for i in batch_feature_ids + } + batch_dashboards_data = SaeVisData( + feature_data_dict=batch_feature_data_dict, + cfg=dashboards_data[sae_name].cfg, + model=dashboards_data[sae_name].model, + ) + batch_dashboards_data.save_feature_centric_vis( + filename=folder / f"features-{batch_feature_ids[0]}-to-{batch_feature_ids[-1]}.html" ) - if feature_idx in dashboards_data[sae_name].keys(): - html_str = dashboards_data[sae_name][feature_idx].get_html() - filepath = folder / Path(f"feature-{feature_idx}.html") - with open(filepath, "w") as f: - f.write(html_str) - logger.info(f"Saved HTML feature dashboards in {folder}") @torch.inference_mode() @@ -868,33 +601,38 @@ def generate_prompt_dashboard_html_files( model: SAETransformer, tokens: Int[Tensor, "batch pos"], str_tokens: list[str], - dashboards_data: dict[str, MultiFeatureData], - seq_pos: int | list[int] | None = None, - vocab_dict: dict[int, str] | None = None, - str_score: StrScoreType = "loss_effect", + dashboards_data: dict[str, SaeVisData], + seq_pos: list[int] | int | None = None, save_dir: str | Path = "", ) -> dict[str, set[int]]: """Generates viewable HTML dashboards for the most relevant features (measured by str_score) for every SAE in dashboards_data. + This function is adapted from sae_vis.data_storing_functions.SaeVisData.save_prompt_centric_vis + Returns the set of feature indices which were active""" + assert tokens.shape[-1] == len( str_tokens ), "Error: the number of tokens does not equal the number of str_tokens" - str_tokens = [s.replace("Ġ", " ") for s in str_tokens] + if isinstance(seq_pos, int): seq_pos = [seq_pos] if seq_pos is None: # Generate a dashboard for every position if none is specified seq_pos = list(range(2, len(str_tokens) - 2)) - if vocab_dict is None: - assert ( - model.tlens_model.tokenizer is not None - ), "If voacab_dict is not supplied, the model must have a tokenizer" - vocab_dict = create_vocab_dict(model.tlens_model.tokenizer) - prompt_data = get_prompt_data( - model=model, tokens=tokens, str_tokens=str_tokens, dashboards_data=dashboards_data - ) - prompt = "".join(str_tokens) + + str_toks = [t.replace("|", "│") for t in str_tokens] # vertical line -> pipe + str_toks_list = [f"{t!r} ({i})" for i, t in enumerate(str_toks)] + metric_list = ["act_quantile", "act_size", "loss_effect"] + + # Get default values for dropdowns + first_metric = "act_quantile" + first_seq_pos = str_toks_list[seq_pos[0]] + first_key = f"{first_metric}|{first_seq_pos}" + + # Get tokenize function (we only need to define it once) + decode_fn = get_decode_html_safe_fn(model.tlens_model.tokenizer) + # Use the beginning of the prompt for the filename, but make sure that it's safe for a filename str_tokens_safe_for_filenames = [ "".join(c for c in token if c.isalpha() or c.isdigit() or c == " ") @@ -906,43 +644,94 @@ def generate_prompt_dashboard_html_files( filename_from_prompt = filename_from_prompt[:50] save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) + + METRIC_TITLES = { + "act_size": "Activation Size", + "act_quantile": "Activation Quantile", + "loss_effect": "Loss Effect", + } + + # Run forward passes on our prompt, and store the data within each FeatureData object + # as `self.prompt_data` as well as returning the scores_dict (which maps from score hash to a + # list of feature indices & formatted scores) + dashboards_data, scores_dicts = get_prompt_data( + model, + tokens, + str_tokens, + dashboards_data=dashboards_data, + ) used_features: dict[str, set[int]] = {sae_name: set() for sae_name in dashboards_data} + for sae_name in dashboards_data: - seq_pos_with_scores: set[int] = { - int(x[1]) - for x in prompt_data["blocks.1.hook_resid_post"].scores_dict - if x[0] == str_score - } - for seq_pos_i in seq_pos_with_scores.intersection(seq_pos): - # Find the most relevant features (by {str_score}) for the token - # '{str_tokens[seq_pos_i]}' in the prompt '{prompt}: - html_str = prompt_data[sae_name].get_html(seq_pos_i, str_score, vocab_dict) - # Insert a title - title: str = ( - f"

  The most relevant features from {sae_name},
  " - f"measured by {str_score} on the '{str_tokens[seq_pos_i].replace('Ġ',' ')}' " - f"token (token number {seq_pos_i}) in the prompt '{prompt}':

" + for _metric in metric_list: + # Initialize the object we'll eventually get_html from + HTML_OBJ = HTML() + + # For each (metric, seqpos) object, we merge the prompt-centric views of each of the top + # features, then we merge + # these all together into our HTML_OBJ + for _seq_pos in seq_pos: + # Create the key for this given combination of metric & seqpos, and get our + # top features & scores + key = f"{_metric}|{str_toks_list[_seq_pos]}" + if key not in scores_dicts[sae_name]: + continue + feature_idx_list, scores_formatted = scores_dicts[sae_name][key] + used_features[sae_name] = used_features[sae_name].union(feature_idx_list) + + # Create HTML object, to store each feature column for all the top features for + # this particular key + html_obj = HTML() + + for i, (feature_idx, score_formatted) in enumerate( + zip(feature_idx_list, scores_formatted, strict=True) + ): + # Get HTML object at this column (which includes JavaScript to set the title) + html_obj += ( + dashboards_data[sae_name] + .feature_data_dict[feature_idx] + ._get_html_data_prompt_centric( + layout=LAYOUT_PROMPT_VIS, + decode_fn=decode_fn, + column_idx=i, + bold_idx=_seq_pos, + title=f"

#{feature_idx}
{METRIC_TITLES[_metric]}" + f" = {score_formatted}


", + ) + ) + + # Add the JavaScript (which includes the titles for each column) + HTML_OBJ.js_data[key] = deepcopy(html_obj.js_data) + + # Set the HTML data to be the one with the most columns + if len(HTML_OBJ.html_data) < len(html_obj.html_data): + HTML_OBJ.html_data = deepcopy(html_obj.html_data) + + # Check our first key is in the scores_dict (if not, we should pick a different key) + assert first_key in scores_dicts[sae_name], ( + f"Key {first_key} not found in " + "{scores_dicts[sae_name].keys()=}. Have you tried " + "computing your initial data with more features and/or tokens, " + "to make sure you have enough positive examples?" ) - substr = "
" - html_str = html_str.replace( - substr, "
" + title + "
\n" + substr + + filename = save_dir / Path( + f"prompt-{filename_from_prompt}_{_metric}_sae-{sae_name}.html" ) - filepath = save_dir / Path( - f"prompt-{filename_from_prompt}_token-{seq_pos_i}-" - f"{str_tokens_safe_for_filenames[seq_pos_i]}_-{str_score.replace('_','-')}_" - f"sae-{sae_name}.html" + + # Save our full HTML + HTML_OBJ.get_html( + LAYOUT_PROMPT_VIS, + filename, + first_key, ) - with open(filepath, "w") as f: - f.write(html_str) - scores = prompt_data[sae_name].scores_dict[(str_score, seq_pos_i)][0] # type: ignore - used_features[sae_name] = used_features[sae_name].union(scores.indices.tolist()) return used_features @torch.inference_mode() def generate_random_prompt_dashboards( model: SAETransformer, - dashboards_data: dict[str, MultiFeatureData], + dashboards_data: dict[str, SaeVisData], dashboards_config: DashboardsConfig, use_model_tokenizer: bool = False, save_dir: RootPath | None = None, @@ -951,7 +740,9 @@ def generate_random_prompt_dashboards( A data_loader is created using the dashboards_config.prompt_centric.data if it exists, otherwise using the dashboards_config.data config. - For each random prompt, dashboards are generated for three consecutive sequence positions.""" + For each random prompt, dashboards are generated for three consecutive sequence positions. + + Returns the set of feature indices which were active""" np.random.seed(dashboards_config.seed) if save_dir is None: save_dir = dashboards_config.save_dir @@ -975,7 +766,6 @@ def generate_random_prompt_dashboards( assert isinstance(tokenizer, PreTrainedTokenizer | PreTrainedTokenizerFast) else: tokenizer = AutoTokenizer.from_pretrained(dashboards_config.data.tokenizer_name) - vocab_dict = create_vocab_dict(tokenizer) if dashboards_config.sae_positions is None: raw_sae_positions: list[str] = model.raw_sae_positions else: @@ -985,7 +775,7 @@ def generate_random_prompt_dashboards( used_features: dict[str, set[int]] = {sae_name: set() for sae_name in dashboards_data} device = model.saes[raw_sae_positions[0].replace(".", "-")].device - n_prompts = (dashboards_config.prompt_centric.n_random_prompt_dashboards + 2) // 3 + n_prompts = dashboards_config.prompt_centric.n_random_prompt_dashboards for prompt_idx, batch in tqdm( enumerate(data_loader), total=n_prompts, @@ -999,25 +789,27 @@ def generate_random_prompt_dashboards( bos_inds = torch.argwhere(batch_tokens == tokenizer.bos_token_id)[:, 1] if len(bos_inds) > 1: batch_tokens = batch_tokens[:, bos_inds[0] : bos_inds[1]] + if batch_tokens.shape[1] > 50: + batch_tokens = batch_tokens[:, :50] str_tokens = tokenizer.convert_ids_to_tokens(batch_tokens.squeeze(dim=0).tolist()) assert isinstance(str_tokens, list) seq_len: int = batch_tokens.shape[1] # Generate dashboards for three consecutive positions in the prompt, chosen randomly + seq_pos = None if seq_len > 4: # Ensure the prompt is long enough for three positions + next token effect seq_pos_c = np.random.randint(1, seq_len - 3) seq_pos = [seq_pos_c - 1, seq_pos_c, seq_pos_c + 1] - used_features_now = generate_prompt_dashboard_html_files( - model=model, - tokens=batch_tokens, - str_tokens=str_tokens, - dashboards_data=dashboards_data, - seq_pos=seq_pos, - vocab_dict=vocab_dict, - str_score=dashboards_config.prompt_centric.str_score, - save_dir=save_dir, - ) - for sae_name in used_features: - used_features[sae_name] = used_features[sae_name].union(used_features_now[sae_name]) + # Generate dashboards for three consecutive positions in the prompt, chosen randomly + used_features_now = generate_prompt_dashboard_html_files( + model, + batch_tokens, + str_tokens, + dashboards_data, + seq_pos=seq_pos, + save_dir=save_dir, + ) + for sae_name in used_features: + used_features[sae_name] = used_features[sae_name].union(used_features_now[sae_name]) if prompt_idx > n_prompts: break @@ -1043,6 +835,7 @@ def generate_dashboards( "generate_dashboards() saves HTML files, but no save_dir was specified in the" + " dashboards_config or given as input" ) + save_dir.mkdir(parents=True, exist_ok=True) # Deal with the possible input typles of sae_positions if dashboards_config.sae_positions is None: raw_sae_positions = model.raw_sae_positions @@ -1056,8 +849,8 @@ def generate_dashboards( ) # Get the data used in the dashboards - dashboards_data: dict[str, MultiFeatureData] = get_dashboards_data( - model=model, + dashboards_data: dict[str, SaeVisData] = get_dashboards_data( + model, dataset_config=dashboards_config.data, sae_positions=raw_sae_positions, # We need data for every feature if we're generating prompt-centric dashboards: @@ -1067,6 +860,23 @@ def generate_dashboards( minibatch_size_features=dashboards_config.minibatch_size_features, ) + if dashboards_config.save_json_data: + logger.info(f"Saving dashboards data json files to {save_dir}:") + for sae_name in dashboards_data: + dashboards_data[sae_name].save_json( + str(save_dir / Path(f"dashboards_data_{sae_name.replace('.','-')}.json")) + ) + logger.info("Saved.") + + # Generate the viewable HTML feature dashboard files + dashboard_html_saving_folder = save_dir / Path("feature-dashboards") + dashboard_html_saving_folder.mkdir(parents=True, exist_ok=True) + generate_feature_dashboard_html_files( + dashboards_data, + minibatch_size_features=dashboards_config.minibatch_size_features, + save_dir=dashboard_html_saving_folder, + ) + # Generate the prompt-centric dashboards and record which features were active on them used_features: dict[str, set[int]] = {sae_name: set() for sae_name in dashboards_data} if dashboards_config.prompt_centric: @@ -1075,9 +885,9 @@ def generate_dashboards( # Generate random prompt-centric dashboards if dashboards_config.prompt_centric.n_random_prompt_dashboards > 0: used_features_now = generate_random_prompt_dashboards( - model=model, - dashboards_data=dashboards_data, - dashboards_config=dashboards_config, + model, + dashboards_data, + dashboards_config, save_dir=prompt_dashboard_saving_folder, ) for sae_name in used_features: @@ -1086,7 +896,6 @@ def generate_dashboards( # Generate dashboards for specific prompts if dashboards_config.prompt_centric.prompts is not None: tokenizer = AutoTokenizer.from_pretrained(dashboards_config.data.tokenizer_name) - vocab_dict = create_vocab_dict(tokenizer) for prompt in dashboards_config.prompt_centric.prompts: tokens = tokenizer(prompt)["input_ids"] list_tokens = tokens.tolist() if isinstance(tokens, Tensor) else tokens @@ -1094,12 +903,10 @@ def generate_dashboards( str_tokens = tokenizer.convert_ids_to_tokens(list_tokens) assert isinstance(str_tokens, list) used_features_now = generate_prompt_dashboard_html_files( - model=model, - tokens=torch.Tensor(tokens).to(dtype=torch.int).unsqueeze(dim=0), - str_tokens=str_tokens, - dashboards_data=dashboards_data, - str_score=dashboards_config.prompt_centric.str_score, - vocab_dict=vocab_dict, + model, + torch.Tensor(tokens).to(dtype=torch.int).unsqueeze(dim=0), + str_tokens, + dashboards_data, save_dir=prompt_dashboard_saving_folder, ) for sae_name in used_features: @@ -1112,15 +919,6 @@ def generate_dashboards( set(feature_indices[sae_name].tolist()) ) - # Generate the viewable HTML feature dashboard files - dashboard_html_saving_folder = save_dir / Path("feature-dashboards") - dashboard_html_saving_folder.mkdir(parents=True, exist_ok=True) - generate_feature_dashboard_html_files( - dashboards_data=dashboards_data, - feature_indices=used_features if dashboards_config.prompt_centric else feature_indices, - save_dir=dashboard_html_saving_folder, - ) - # Load the saved SAEs and the corresponding model def load_SAETransformer_from_saes_paths( @@ -1154,9 +952,7 @@ def load_SAETransformer_from_saes_paths( assert pretrained_sae_paths is not None, "pretrained_sae_paths must be given or in config" logger.info(config) - tlens_model = load_tlens_model( - tlens_model_name=config.tlens_model_name, tlens_model_path=config.tlens_model_path - ) + tlens_model = load_tlens_model(config.tlens_model_name, config.tlens_model_path) assert tlens_model is not None if sae_positions is None: @@ -1164,17 +960,17 @@ def load_SAETransformer_from_saes_paths( raw_sae_positions = filter_names(list(tlens_model.hook_dict.keys()), sae_positions) model = SAETransformer( - tlens_model=tlens_model, - raw_sae_positions=raw_sae_positions, - dict_size_to_input_ratio=config.saes.dict_size_to_input_ratio, + tlens_model, + raw_sae_positions, + config.saes.dict_size_to_input_ratio, ).to(device=device) all_param_names = [name for name, _ in model.saes.named_parameters()] trainable_param_names = load_pretrained_saes( - saes=model.saes, - pretrained_sae_paths=pretrained_sae_paths, - all_param_names=all_param_names, - retrain_saes=config.saes.retrain_saes, + model.saes, + pretrained_sae_paths, + all_param_names, + config.saes.retrain_saes, ) return model, config, trainable_param_names @@ -1183,7 +979,7 @@ def main( config_path_or_obj: Path | str | DashboardsConfig, pretrained_sae_paths: Path | str | list[Path] | list[str] | None, ) -> None: - dashboards_config = load_config(config_path_or_obj, config_model=DashboardsConfig) + dashboards_config = load_config(config_path_or_obj, DashboardsConfig) logger.info(dashboards_config) if pretrained_sae_paths is None: @@ -1207,6 +1003,7 @@ def main( save_dir = dashboards_config.save_dir or Path(pretrained_sae_paths[0]).parent logger.info(f"The HTML dashboards will be saved in {save_dir}") generate_dashboards(model, dashboards_config, save_dir=save_dir) + logger.info("Finished.") if __name__ == "__main__": diff --git a/tests/test_dashboards.py b/tests/test_dashboards.py index 3cdf26a..96348c3 100644 --- a/tests/test_dashboards.py +++ b/tests/test_dashboards.py @@ -10,11 +10,10 @@ from sparsify.scripts.generate_dashboards import ( DashboardsConfig, compute_feature_acts, - create_vocab_dict, generate_dashboards, ) from sparsify.utils import set_seed -from tests.utils import get_tinystories_config +from tests.utils import TINYSTORIES_CONFIG, get_tinystories_config Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast @@ -52,17 +51,6 @@ def test_compute_feature_acts(tinystories_model: SAETransformer): assert acts.shape[2] == 7 # feature_indices -def test_create_vocab_dict(tinystories_model: SAETransformer): - tokenizer = tinystories_model.tlens_model.tokenizer - assert tokenizer is not None - vocab_dict = create_vocab_dict(tokenizer) - assert isinstance(tokenizer, PreTrainedTokenizerFast) - assert len(vocab_dict) == len(tokenizer.vocab) - for token_id, token_str in vocab_dict.items(): - assert isinstance(token_id, int) - assert isinstance(token_str, str) - - def check_valid_feature_dashboard_htmls(folder: Path): assert folder.exists() for html_file in folder.iterdir(): @@ -72,42 +60,25 @@ def check_valid_feature_dashboard_htmls(folder: Path): html_content = f.read() assert isinstance(html_content, str) assert len(html_content) > 100 - assert "Plotly.newPlot('histogram-acts'" in html_content - assert '
'" in html_content - assert "Feature #" in html_content - assert '
" in html_content -@pytest.mark.skip("Currently only works with a GPU") +@pytest.mark.slow() def test_generate_dashboards(tinystories_model: SAETransformer, tmp_path: Path): # This function also tests compute_feature_acts_on_distribution() set_seed(0) dashboards_config = DashboardsConfig( n_samples=10, - batch_size=2, - minibatch_size_features=5, + batch_size=10, + minibatch_size_features=100, save_dir=Path(tmp_path), + sae_config_path=Path(TINYSTORIES_CONFIG), sae_positions=["blocks.2.hook_resid_post"], pretrained_sae_paths=None, feature_indices=list(range(5)), @@ -120,5 +91,5 @@ def test_generate_dashboards(tinystories_model: SAETransformer, tmp_path: Path): ) generate_dashboards(tinystories_model, dashboards_config) check_valid_feature_dashboard_htmls( - tmp_path / "feature-dashboards" / "dashboards_blocks.2.hook_resid_post" + tmp_path / Path("feature-dashboards") / Path("dashboards_blocks.2.hook_resid_post") )