diff --git a/delphi/__main__.py b/delphi/__main__.py index d69d7b10..60c6855c 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -10,7 +10,7 @@ from simple_parsing import ArgumentParser from torch import Tensor from transformers import ( - AutoModel, + AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, @@ -27,7 +27,12 @@ from delphi.latents.neighbours import NeighbourCalculator from delphi.log.result_analysis import log_results from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator +from delphi.scorers import ( + DetectionScorer, + FuzzingScorer, + OpenAISimulator, + SurprisalInterventionScorer, +) from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders from delphi.utils import assert_type, load_tokenized_data @@ -40,7 +45,7 @@ def load_artifacts(run_cfg: RunConfig): else: dtype = "auto" - model = AutoModel.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( run_cfg.model, device_map={"": "cuda"}, quantization_config=( @@ -118,6 +123,8 @@ async def process_cache( hookpoints: list[str], tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, latent_range: Tensor | None, + model, + hookpoint_to_sparse_encode, ): """ Converts SAE latent activations in on-disk cache in the `latents_path` directory @@ -219,6 +226,12 @@ def none_postprocessor(result): ) ) + def custom_serializer(obj): + """A custom serializer for orjson to handle specific types.""" + if isinstance(obj, Tensor): + return obj.tolist() + raise TypeError + # Builds the record from result returned by the pipeline def scorer_preprocess(result): if isinstance(result, list): @@ -230,11 +243,18 @@ def scorer_preprocess(result): return record # Saves the score to a file - def scorer_postprocess(result, score_dir): + # In your __main__.py file + + def scorer_postprocess(result, score_dir, scorer_name=None): + if isinstance(result, list): + if not result: + return + result = result[0] + safe_latent_name = str(result.record.latent).replace("/", "--") with open(score_dir / f"{safe_latent_name}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) + f.write(orjson.dumps(result.score, default=custom_serializer)) scorers = [] for scorer_name in run_cfg.scorers: @@ -265,6 +285,16 @@ def scorer_postprocess(result, score_dir): verbose=run_cfg.verbose, log_prob=run_cfg.log_probs, ) + + elif scorer_name == "surprisal_intervention": + scorer = SurprisalInterventionScorer( + model, + hookpoint_to_sparse_encode, + hookpoints=run_cfg.hookpoints, + n_examples_shown=run_cfg.num_examples_per_scorer_prompt, + verbose=run_cfg.verbose, + log_prob=run_cfg.log_probs, + ) else: raise ValueError(f"Scorer {scorer_name} not supported") @@ -404,6 +434,8 @@ async def run( hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg) tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token) + model.tokenizer = tokenizer + nrh = assert_type( dict, non_redundant_hookpoints( @@ -420,7 +452,6 @@ async def run( transcode, ) - del model, hookpoint_to_sparse_encode if run_cfg.constructor_cfg.non_activating_source == "neighbours": nrh = assert_type( list, @@ -453,10 +484,20 @@ async def run( nrh, tokenizer, latent_range, + model, + hookpoint_to_sparse_encode, ) + del model, hookpoint_to_sparse_encode + if run_cfg.verbose: - log_results(scores_path, visualize_path, run_cfg.hookpoints, run_cfg.scorers) + log_results( + scores_path, + visualize_path, + run_cfg.hookpoints, + 
run_cfg.scorers, + model_name=run_cfg.model, + ) if __name__ == "__main__": diff --git a/delphi/config.py b/delphi/config.py index de806157..11321684 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -148,18 +148,14 @@ class RunConfig(Serializable): the default single token explainer, and 'none' for no explanation generation.""" scorers: list[str] = list_field( - choices=[ - "fuzz", - "detection", - "simulation", - ], + choices=["fuzz", "detection", "simulation", "surprisal_intervention"], default=[ "fuzz", "detection", ], ) - """Scorer methods to score latent explanations. Options are 'fuzz', 'detection', and - 'simulation'.""" + """Scorer methods to score latent explanations. Options are 'fuzz', 'detection', + 'simulation' and 'surprisal_intervention'.""" fuzz_type: Literal["default", "active"] = "default" """Type of fuzzing to use for the fuzz scorer. Default uses non-activating examples and highlights n_incorrect tokens. Active uses activating examples diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index 0f4ff94d..ca08ffaa 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -157,6 +157,13 @@ class LatentRecord: """Frequency of the latent. Number of activations in a context per total number of contexts.""" + @property + def feature_id(self) -> int: + """ + Returns the unique feature index for this latent. + """ + return self.latent.latent_index + @property def max_activation(self) -> float: """ diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 9937bd96..8afd4388 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Optional import orjson import pandas as pd @@ -9,12 +8,143 @@ from sklearn.metrics import roc_auc_score, roc_curve -def plot_firing_vs_f1( - latent_df: pd.DataFrame, num_tokens: int, out_dir: Path, run_label: str -) -> None: +def plot_fuzz_vs_intervention(latent_df: pd.DataFrame, out_dir: Path, run_label: str): + """ + Replicates the Scatter Plot from the paper (Figure 3/Appendix G). + Plots Fuzz Score vs. Intervention Score for the same latents. + """ + + # Extract Fuzz Scores (using F1 or Accuracy as the metric) + fuzz_df = latent_df[latent_df["score_type"] == "fuzz"].copy() + if fuzz_df.empty: + return + + # Calculate per-latent F1 for fuzzing + fuzz_metrics = ( + fuzz_df.groupby(["module", "latent_idx"]) + .apply( + lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] + ) + .reset_index(name="fuzz_score") + ) + + # Extract Intervention Scores + int_df = latent_df[latent_df["score_type"] == "surprisal_intervention"].copy() + if int_df.empty: + return + + int_metrics = int_df.drop_duplicates(subset=["module", "latent_idx"])[ + ["module", "latent_idx", "avg_kl_divergence", "final_score"] + ] + + merged = pd.merge(fuzz_metrics, int_metrics, on=["module", "latent_idx"]) + + if merged.empty: + print("Could not merge Fuzz and Intervention scores (no matching latents).") + return + + # Plot 1: KL vs Fuzz (Causal Impact vs Correlational Quality) + fig_kl = px.scatter( + merged, + x="fuzz_score", + y="avg_kl_divergence", + hover_data=["latent_idx"], + title=f"Correlation vs. 
Causation (KL) - {run_label}", + labels={ + "fuzz_score": "Fuzzing Score (Correlation)", + "avg_kl_divergence": "Intervention KL (Causation)", + }, + trendline="ols", # Adds a regression line to show the negative/zero correlation + ) + fig_kl.write_image(out_dir / "scatter_fuzz_vs_kl.pdf") + + # Plot 2: Score vs Fuzz (Original Paper Metric) + fig_score = px.scatter( + merged, + x="fuzz_score", + y="final_score", + hover_data=["latent_idx"], + title=f"Correlation vs. Causation (Score) - {run_label}", + labels={ + "fuzz_score": "Fuzzing Score (Correlation)", + "final_score": "Intervention Score (Surprisal)", + }, + trendline="ols", + ) + fig_score.write_image(out_dir / "scatter_fuzz_vs_score.pdf") + print("Generated Fuzz vs. Intervention scatter plots.") + + +def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str): + """ + Improved histograms. Plots two versions: + 1. All Features (Log Scale) - to show the dead features. + 2. Live Features Only - to show the distribution of the ones that work. + """ + out_dir.mkdir(exist_ok=True, parents=True) + display_name = model_name.split("/")[-1] if "/" in model_name else model_name + + # 1. Live/Dead Split Bar Chart + threshold = 0.01 + df["status"] = df["avg_kl_divergence"].apply( + lambda x: "Decoder-Live" if x > threshold else "Decoder-Dead" + ) + counts = df["status"].value_counts().reset_index() + counts.columns = ["Status", "Count"] + + total = counts["Count"].sum() + live = ( + counts[counts["Status"] == "Decoder-Live"]["Count"].sum() + if "Decoder-Live" in counts["Status"].values + else 0 + ) + pct = (live / total * 100) if total > 0 else 0 + + fig_bar = px.bar( + counts, + x="Status", + y="Count", + color="Status", + text="Count", + title=f"Causal Relevance: {pct:.1f}% Live ({display_name})", + color_discrete_map={"Decoder-Live": "green", "Decoder-Dead": "red"}, + ) + fig_bar.write_image(out_dir / "intervention_live_dead_split.pdf") + + # 2. "Live Features Only" Histogram + live_df = df[df["avg_kl_divergence"] > threshold] + if not live_df.empty: + fig_live = px.histogram( + live_df, + x="avg_kl_divergence", + nbins=20, + title=f"Distribution of LIVE Features Only ({display_name})", + labels={"avg_kl_divergence": "KL Divergence (Causal Effect)"}, + ) + fig_live.update_layout(showlegend=False) + fig_live.write_image(out_dir / "intervention_kl_dist_LIVE_ONLY.pdf") + + # 3. 
All Features Histogram (Log Scale) + fig_all = px.histogram( + df, + x="avg_kl_divergence", + nbins=50, + title=f"Distribution of All Features ({display_name})", + labels={"avg_kl_divergence": "KL Divergence"}, + log_y=True, # Log scale to handle the massive spike at 0 + ) + fig_all.write_image(out_dir / "intervention_kl_dist_log_scale.pdf") + + +def plot_firing_vs_f1(latent_df, num_tokens, out_dir, run_label): out_dir.mkdir(parents=True, exist_ok=True) for module, module_df in latent_df.groupby("module"): - module_df = module_df.copy() + if "firing_count" not in module_df.columns: + continue + module_df = module_df[module_df["f1_score"].notna()] + if module_df.empty: + continue + module_df["firing_rate"] = module_df["firing_count"] / num_tokens fig = px.scatter(module_df, x="firing_rate", y="f1_score", log_x=True) fig.update_layout( @@ -24,48 +154,33 @@ def plot_firing_vs_f1( def import_plotly(): - """Import plotly with mitigiation for MathJax bug.""" try: - import plotly.express as px # type: ignore - import plotly.io as pio # type: ignore + import plotly.express as px + import plotly.io as pio except ImportError: - raise ImportError( - "Plotly is not installed.\n" - "Please install it using `pip install plotly`, " - "or install the `[visualize]` extra." - ) - pio.kaleido.scope.mathjax = None # https://github.com/plotly/plotly.py/issues/3469 + raise ImportError("Install plotly: pip install plotly") + pio.kaleido.scope.mathjax = None return px -def compute_auc(df: pd.DataFrame) -> float | None: - if not df.probability.nunique(): - return None - - valid_df = df[df.probability.notna()] - - return roc_auc_score(valid_df.activating, valid_df.probability) # type: ignore - - -def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): +def plot_accuracy_hist(df, out_dir): out_dir.mkdir(exist_ok=True, parents=True) for label in df["score_type"].unique(): + if label == "surprisal_intervention": + continue fig = px.histogram( df[df["score_type"] == label], x="accuracy", nbins=100, - title=f"Accuracy distribution: {label}", + title=f"Accuracy: {label}", ) fig.write_image(out_dir / f"{label}_accuracy.pdf") -def plot_roc_curve(df: pd.DataFrame, out_dir: Path): - if not df.probability.nunique(): - return - - # filter out NANs +def plot_roc_curve(df, out_dir): valid_df = df[df.probability.notna()] - + if valid_df.empty or valid_df.activating.nunique() <= 1: + return fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability) auc = roc_auc_score(valid_df.activating, valid_df.probability) fig = go.Figure( @@ -74,310 +189,188 @@ def plot_roc_curve(df: pd.DataFrame, out_dir: Path): go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash")), ] ) - fig.update_layout( - title="ROC Curve", - xaxis_title="FPR", - yaxis_title="TPR", - ) + fig.update_layout(title="ROC Curve", xaxis_title="FPR", yaxis_title="TPR") out_dir.mkdir(exist_ok=True, parents=True) fig.write_image(out_dir / "roc_curve.pdf") -def compute_confusion(df: pd.DataFrame, threshold: float = 0.5) -> dict: +def compute_confusion(df, threshold=0.5): df_valid = df[df["prediction"].notna()] + if df_valid.empty: + return dict( + true_positives=0, + true_negatives=0, + false_positives=0, + false_negatives=0, + total_positives=0, + total_negatives=0, + ) act = df_valid["activating"].astype(bool) - - total = len(df_valid) - pos = act.sum() - neg = total - pos - - tp = ((df_valid.prediction >= threshold) & act).sum() - tn = ((df_valid.prediction < threshold) & ~act).sum() - fp = ((df_valid.prediction >= threshold) & ~act).sum() - fn = 
((df_valid.prediction < threshold) & act).sum() - - assert fp <= neg and tn <= neg and tp <= pos and fn <= pos - + pred = df_valid["prediction"] >= threshold + tp, tn = (pred & act).sum(), (~pred & ~act).sum() + fp, fn = (pred & ~act).sum(), (~pred & act).sum() return dict( true_positives=tp, true_negatives=tn, false_positives=fp, false_negatives=fn, - total_examples=total, - total_positives=pos, - total_negatives=neg, - failed_count=len(df_valid) - total, + total_positives=act.sum(), + total_negatives=(~act).sum(), ) -def compute_classification_metrics(conf: dict) -> dict: - tp = conf["true_positives"] - tn = conf["true_negatives"] - fp = conf["false_positives"] - fn = conf["false_negatives"] - total = conf["total_examples"] - pos = conf["total_positives"] - neg = conf["total_negatives"] - - assert pos + neg == total, "pos + neg must equal total" - - # accuracy = (tp + tn) / total if total > 0 else 0 - balanced_accuracy = ( - (tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0) - ) / 2 - - precision = tp / (tp + fp) if tp + fp > 0 else 0 - recall = tp / pos if pos > 0 else 0 - f1 = ( - 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 - ) - - return dict( - precision=precision, - recall=recall, - f1_score=f1, - accuracy=balanced_accuracy, - true_positive_rate=tp / pos if pos > 0 else 0, - true_negative_rate=tn / neg if neg > 0 else 0, - false_positive_rate=fp / neg if neg > 0 else 0, - false_negative_rate=fn / pos if pos > 0 else 0, - total_examples=total, - total_positives=pos, - total_negatives=neg, - positive_class_ratio=pos / total if total > 0 else 0, - negative_class_ratio=neg / total if total > 0 else 0, +def compute_classification_metrics(conf): + tp, tn, fp, _ = ( + conf["true_positives"], + conf["true_negatives"], + conf["false_positives"], + conf["false_negatives"], ) + pos, neg = conf["total_positives"], conf["total_negatives"] + acc = ((tp / pos if pos else 0) + (tn / neg if neg else 0)) / 2 + prec = tp / (tp + fp) if (tp + fp) else 0 + rec = tp / pos if pos else 0 + f1 = 2 * (prec * rec) / (prec + rec) if (prec + rec) else 0 + return dict(accuracy=acc, precision=prec, recall=rec, f1_score=f1) -def load_data(scores_path: Path, modules: list[str]): - """Load all on-disk data into a single DataFrame.""" - - def parse_score_file(path: Path) -> pd.DataFrame: - """ - Load a score file and return a raw DataFrame - """ - try: - data = orjson.loads(path.read_bytes()) - except orjson.JSONDecodeError: - print(f"Error decoding JSON from {path}. 
Skipping file.") - return pd.DataFrame() - - latent_idx = int(path.stem.split("latent")[-1]) - - return pd.DataFrame( - [ - { - "text": "".join(ex["str_tokens"]), - "distance": ex["distance"], - "activating": ex["activating"], - "prediction": ex["prediction"], - "probability": ex["probability"], - "correct": ex["correct"], - "activations": ex["activations"], - "latent_idx": latent_idx, - } - for ex in data - ] - ) - - counts_file = scores_path.parent / "log" / "hookpoint_firing_counts.pt" - counts = torch.load(counts_file, weights_only=True) if counts_file.exists() else {} - if not all(module in counts for module in modules): - print("Missing firing counts for some modules, setting counts to None.") - print(f"Missing modules: {[m for m in modules if m not in counts]}") - counts = None - - # Collect per-latent data - latent_dfs = [] - for score_type_dir in scores_path.iterdir(): - if not score_type_dir.is_dir(): - continue - for module in modules: - for file in score_type_dir.glob(f"*{module}*"): - latent_idx = int(file.stem.split("latent")[-1]) - - latent_df = parse_score_file(file) - latent_df["score_type"] = score_type_dir.name - latent_df["module"] = module - latent_df["latent_idx"] = latent_idx - if counts: - latent_df["firing_count"] = ( - counts[module][latent_idx].item() - if latent_idx in counts[module] - else None - ) - - latent_dfs.append(latent_df) - - return pd.concat(latent_dfs, ignore_index=True), counts +def compute_auc(df): + valid = df[df.probability.notna()] + if valid.probability.nunique() <= 1: + return None + return roc_auc_score(valid.activating, valid.probability) -def frequency_weighted_f1( - df: pd.DataFrame, counts: dict[str, torch.Tensor] -) -> float | None: +def get_agg_metrics(df): rows = [] - for (module, latent_idx), grp in df.groupby(["module", "latent_idx"]): - f1 = compute_classification_metrics(compute_confusion(grp))["f1_score"] - fire = counts[module][latent_idx].item() + for scorer, group in df.groupby("score_type"): + if scorer == "surprisal_intervention": + continue + conf = compute_confusion(group) rows.append( { - "module": module, - "latent_idx": latent_idx, - "f1_score": f1, - "firing_count": fire, + "score_type": scorer, + **conf, + **compute_classification_metrics(conf), + "auc": compute_auc(group), } ) + return pd.DataFrame(rows) - latent_df = pd.DataFrame(rows) - per_module_f1 = [] - for module in latent_df["module"].unique(): - module_df = latent_df[latent_df["module"] == module] - - firing_weights = counts[module][module_df["latent_idx"]].float() - total_weight = firing_weights.sum() - if total_weight == 0: - continue - - f1_tensor = torch.as_tensor(module_df["f1_score"].values, dtype=torch.float32) - module_f1 = (f1_tensor * firing_weights).sum() / firing_weights.sum() - per_module_f1.append(module_f1) - - overall_frequency_weighted_f1 = torch.stack(per_module_f1).mean() - return ( - overall_frequency_weighted_f1.item() - if not overall_frequency_weighted_f1.isnan() - else None +def add_latent_f1(df): + f1s = ( + df.groupby(["module", "latent_idx"]) + .apply( + lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] + ) + .reset_index(name="f1_score") ) + return df.merge(f1s, on=["module", "latent_idx"]) -def get_agg_metrics( - latent_df: pd.DataFrame, counts: Optional[dict[str, torch.Tensor]] -) -> pd.DataFrame: - processed_rows = [] - for score_type, group_df in latent_df.groupby("score_type"): - conf = compute_confusion(group_df) - class_m = compute_classification_metrics(conf) - auc = compute_auc(group_df) - f1_w = 
frequency_weighted_f1(group_df, counts) if counts else None - - row = { - "score_type": score_type, - **conf, - **class_m, - "auc": auc, - "weighted_f1": f1_w, - } - processed_rows.append(row) +def load_data(scores_path, modules): + def parse_file(path): + try: + data = orjson.loads(path.read_bytes()) + if not isinstance(data, list): + return pd.DataFrame() + latent_idx = int(path.stem.split("latent")[-1]) + return pd.DataFrame( + [ + { + "text": "".join(ex.get("str_tokens", [])), + "activating": ex.get("activating"), + "prediction": ex.get("prediction"), + "probability": ex.get("probability"), + "final_score": ex.get("final_score"), + "avg_kl_divergence": ex.get("avg_kl_divergence"), + "latent_idx": latent_idx, + } + for ex in data + ] + ) + except Exception: + return pd.DataFrame() - return pd.DataFrame(processed_rows) + counts_file = scores_path.parent / "log" / "hookpoint_firing_counts.pt" + counts = torch.load(counts_file, weights_only=True) if counts_file.exists() else {} + dfs = [] + for scorer_dir in scores_path.iterdir(): + if not scorer_dir.is_dir(): + continue + for module in modules: + for f in scorer_dir.glob(f"*{module}*"): + df = parse_file(f) + if df.empty: + continue + df["score_type"] = scorer_dir.name + df["module"] = module + if module in counts: + idx = df["latent_idx"].iloc[0] + if idx < len(counts[module]): + df["firing_count"] = counts[module][idx].item() + dfs.append(df) -def add_latent_f1(latent_df: pd.DataFrame) -> pd.DataFrame: - f1s = ( - latent_df.groupby(["module", "latent_idx"]) - .apply( - lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] - ) - .reset_index(name="f1_score") # <- naive (un-weighted) F1 - ) - return latent_df.merge(f1s, on=["module", "latent_idx"]) + return (pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()), counts def log_results( - scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str] + scores_path: Path, + viz_path: Path, + modules: list[str], + scorer_names: list[str], + model_name: str = "Unknown", ): import_plotly() latent_df, counts = load_data(scores_path, modules) - latent_df = latent_df[latent_df["score_type"].isin(scorer_names)] - latent_df = add_latent_f1(latent_df) - - plot_firing_vs_f1( - latent_df, num_tokens=10_000_000, out_dir=viz_path, run_label=scores_path.name - ) - if latent_df.empty: - print("No data found") + print("No data found.") return - dead = sum((counts[m] == 0).sum().item() for m in modules) - print(f"Number of dead features: {dead}") - print(f"Number of interpreted live features: {len(latent_df)}") + print(f"Generating report for: {latent_df['score_type'].unique()}") - # Load constructor config for run - with open(scores_path.parent / "run_config.json", "r") as f: - run_cfg = orjson.loads(f.read()) - constructor_cfg = run_cfg.get("constructor_cfg", {}) - min_examples = constructor_cfg.get("min_examples", None) - print("min examples", min_examples) + # Split Data + class_mask = latent_df["score_type"] != "surprisal_intervention" + class_df = latent_df[class_mask] + int_df = latent_df[~class_mask] - if min_examples is not None: - uninterpretable_features = sum( - [(counts[m] < min_examples).sum() for m in modules] - ) - print( - f"Number of features below the interpretation firing" - f" count threshold: {uninterpretable_features}" - ) + # 1. 
Handle Classification (Fuzz/Detection) + if not class_df.empty: + class_df = add_latent_f1(class_df) + if counts: + plot_firing_vs_f1(class_df, 10_000_000, viz_path, scores_path.name) + plot_roc_curve(class_df, viz_path) - plot_roc_curve(latent_df, viz_path) + agg_df = get_agg_metrics(class_df) + plot_accuracy_hist(agg_df, viz_path) - processed_df = get_agg_metrics(latent_df, counts) + for _, row in agg_df.iterrows(): + print(f"\n[ {row['score_type'].title()} ]") + print(f"Accuracy: {row['accuracy']:.3f}") + print(f"F1 Score: {row['f1_score']:.3f}") - plot_accuracy_hist(processed_df, viz_path) + # 2. Handle Intervention + if not int_df.empty: + unique_latents = int_df.drop_duplicates(subset=["module", "latent_idx"]).copy() - for score_type in processed_df.score_type.unique(): - score_type_summary = processed_df[processed_df.score_type == score_type].iloc[0] - print(f"\n--- {score_type.title()} Metrics ---") - print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") - print(f"F1 Score: {score_type_summary['f1_score']:.3f}") - print(f"Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}") - print( - "Note: the frequency-weighted F1 score is computed over each" - " hookpoint and averaged" - ) - print(f"Precision: {score_type_summary['precision']:.3f}") - print(f"Recall: {score_type_summary['recall']:.3f}") - # Only print AUC if unbalanced AUC is not -1. - if score_type_summary["auc"] is not None: - print(f"AUC: {score_type_summary['auc']:.3f}") - else: - print("Logits not available.") - - fractions_failed = [ - score_type_summary["failed_count"] - / ( - ( - score_type_summary["total_examples"] - + score_type_summary["failed_count"] - ) - ) - ] - print( - f"""Average fraction of failed examples: \ -{sum(fractions_failed) / len(fractions_failed)}""" - ) + avg_score = unique_latents["final_score"].mean() + avg_kl = unique_latents["avg_kl_divergence"].mean() - print("\nConfusion Matrix:") - print( - f"True Positive Rate: {score_type_summary['true_positive_rate']:.3f} " - f"({score_type_summary['true_positives'].sum()})" - ) - print( - f"True Negative Rate: {score_type_summary['true_negative_rate']:.3f} " - f"({score_type_summary['true_negatives'].sum()})" - ) - print( - f"False Positive Rate: {score_type_summary['false_positive_rate']:.3f} " - f"({score_type_summary['false_positives'].sum()})" - ) - print( - f"False Negative Rate: {score_type_summary['false_negative_rate']:.3f} " - f"({score_type_summary['false_negatives'].sum()})" - ) + threshold = 0.01 + n_total = len(unique_latents) + n_live = len(unique_latents[unique_latents["avg_kl_divergence"] > threshold]) + pct = (n_live / n_total * 100) if n_total > 0 else 0 + + print("\n--- Surprisal Intervention Analysis ---") + print(f"Avg Normalized Score: {avg_score:.3f}") + print(f"Avg KL Divergence: {avg_kl:.3f}") + print(f"Decoder-Live %: {pct:.2f}%") + + plot_intervention_stats(unique_latents, viz_path, model_name) - print("\nClass Distribution:") - print(f"""Positives: {score_type_summary['total_positives'].sum():.0f}""") - print(f"""Negatives: {score_type_summary['total_negatives'].sum():.0f}""") - print(f"Total: {score_type_summary['total_examples'].sum():.0f}") + # 3. Generate Scatter Plot (Fuzz vs. 
Intervention)
+    if not class_df.empty and not int_df.empty:
+        plot_fuzz_vs_intervention(latent_df, viz_path, scores_path.name)
diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py
index 782430f1..4e520493 100644
--- a/delphi/scorers/__init__.py
+++ b/delphi/scorers/__init__.py
@@ -3,6 +3,7 @@
 from .classifier.intruder import IntruderScorer
 from .embedding.embedding import EmbeddingScorer
 from .embedding.example_embedding import ExampleEmbeddingScorer
+from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer
 from .scorer import Scorer
 from .simulator.simulation.oai_simulator import (
     RefactoredOpenAISimulator as OpenAISimulator,
 )
@@ -18,4 +19,5 @@
     "EmbeddingScorer",
     "IntruderScorer",
     "ExampleEmbeddingScorer",
+    "SurprisalInterventionScorer",
 ]
diff --git a/delphi/scorers/intervention/__init__.py b/delphi/scorers/intervention/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py
new file mode 100644
index 00000000..f2ae6179
--- /dev/null
+++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py
@@ -0,0 +1,665 @@
+import copy
+import functools
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoTokenizer
+
+from ...latents import LatentRecord
+from ..scorer import Scorer, ScorerResult
+
+
+@dataclass
+class SurprisalInterventionResult:
+    """
+    Detailed results from the SurprisalInterventionScorer.
+
+    Attributes:
+        score: The final computed score.
+        avg_kl: The average KL divergence between the clean and intervened
+            next-token distributions.
+        explanation: The explanation string that was scored.
+        tuned_strength: The intervention strength selected by the KL-targeted
+            search.
+    """
+
+    score: float
+    avg_kl: float
+    explanation: str
+    tuned_strength: float = 0.0
+
+
+class SurprisalInterventionScorer(Scorer):
+    """
+    Implements the Surprisal / Log-Probability Intervention Scorer.
+
+    This scorer evaluates an explanation for a model's latent feature by measuring
+    how much an intervention in the feature's direction increases the model's belief
+    (log-probability) in the explanation. The intervention strength is tuned so that
+    the KL divergence between the clean and intervened next-token distributions
+    matches a configurable target, keeping the size of the intervention comparable
+    across latents.
+
+    Reference: Paulo et al., "Automatically Interpreting Millions of Features in LLMs"
+    (https://arxiv.org/pdf/2410.13928), Section 3.3.5.
+
+    Pipeline:
+        1. Truncate each activating prompt after the first activating token and tune
+           the intervention strength to the target KL.
+        2. For each truncated prompt:
+           a. Generate a continuation and get the next-token distribution ("clean").
+           b. Add the feature's decoder vector to the hidden activations and
+              generate again ("intervened").
+        3. Compute the log-probability of the explanation conditioned on both the
+           clean and intervened generated texts: log P(explanation | text).
+        4. The final score is the mean change in explanation log-prob:
+           score = mean(log_prob_intervened - log_prob_clean).
+           The mean KL divergence between the clean and intervened next-token
+           distributions is reported alongside it as avg_kl.
+    """
+
+    name = "surprisal_intervention"
+
+    def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs):
+        """
+        Args:
+            subject_model: The language model to generate from and score with.
+            explainer_model: A model (e.g., an SAE) used to get feature directions.
+            **kwargs: Configuration options.
+                strength (float): The magnitude of the intervention. Default: 5.0.
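+                target_kl (float): KL divergence that the strength search tries
+                    to match. Default: 1.0.
+                kl_tolerance (float): Allowed deviation from target_kl before the
+                    search stops. Default: 0.1.
+                max_search_steps (int): Maximum number of binary-search steps
+                    when tuning the strength. Default: 15.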
+ num_prompts (int): Number of activating examples to test. Default: 3. + max_new_tokens (int): Max tokens to generate for continuations. + hookpoint (str): The module name (e.g., 'transformer.h.10.mlp') + for the intervention. + """ + self.subject_model = subject_model + self.explainer_model = explainer_model + self.strength = float(kwargs.get("strength", 5.0)) + self.num_prompts = int(kwargs.get("num_prompts", 3)) + self.max_new_tokens = int(kwargs.get("max_new_tokens", 8)) + self.hookpoints = kwargs.get("hookpoints") + + self.target_kl = float(kwargs.get("target_kl", 1.0)) + self.kl_tolerance = float(kwargs.get("kl_tolerance", 0.1)) + self.max_search_steps = int(kwargs.get("max_search_steps", 15)) + + if len(self.hookpoints): + self.hookpoint_str = self.hookpoints[0] + + if hasattr(subject_model, "tokenizer"): + self.tokenizer = subject_model.tokenizer + else: + self.tokenizer = AutoTokenizer.from_pretrained("gpt2") + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.subject_model.config.pad_token_id = self.tokenizer.eos_token_id + + def _get_device(self) -> torch.device: + """Safely gets the device of the subject model.""" + try: + return next(self.subject_model.parameters()).device + except StopIteration: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def _find_layer(self, model: Any, name: str) -> torch.nn.Module: + """Resolves a module by its dotted path name.""" + if name is None: + raise ValueError("Hookpoint name is not configured.") + current = model + for part in name.split("."): + if part.isdigit(): + current = current[int(part)] + else: + current = getattr(current, part) + return current + + def _get_full_hookpoint_path(self, hookpoint_str: str) -> str: + """ + Heuristically finds the model's prefix and constructs the full hookpoint + path string. + e.g., 'layers.6.mlp' -> 'model.layers.6.mlp' + """ + # Heuristically find the model prefix. + prefix = None + for p in ["gpt_neox", "transformer", "model"]: + if hasattr(self.subject_model, p): + candidate_body = getattr(self.subject_model, p) + if hasattr(candidate_body, "h") or hasattr(candidate_body, "layers"): + prefix = p + break + + return f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str + + def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: + """ + Finds and returns the actual module object for a given hookpoint string. + """ + full_path = self._get_full_hookpoint_path(hookpoint_str) + try: + return self._find_layer(model, full_path) + except AttributeError as e: + raise AttributeError( + f"""Could not resolve path '{full_path}'. + Model structure might be unexpected. + Original error: {e}""" + ) + + def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: + """ + Function used for formatting results to run smoothly in the delphi pipeline + """ + sanitized = [] + for ex in examples: + if hasattr(ex, "str_tokens") and ex.str_tokens is not None: + sanitized.append({"str_tokens": ex.str_tokens}) + + elif isinstance(ex, dict) and "str_tokens" in ex: + sanitized.append(ex) + + elif isinstance(ex, str): + sanitized.append({"str_tokens": [ex]}) + + elif isinstance(ex, (list, tuple)): + sanitized.append({"str_tokens": [str(t) for t in ex]}) + + else: + sanitized.append({"str_tokens": [str(ex)]}) + + return sanitized + + def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor: + """ + Calculates the feature's decoder vector, subtracting the decoder bias. 
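+
+        The vector is computed as sae.decode(one_hot_feature) - sae.decode(zeros),
+        which (for a linear decoder) isolates the decoder direction for
+        `feature_id` while cancelling the decoder bias. A near-zero norm
+        therefore indicates a "decoder-dead" feature.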
+ """ + + d_latent = sae.encoder.out_features + sae_device = sae.encoder.weight.device + + # Create a one-hot activation for our single feature. + one_hot_activation = torch.zeros(1, 1, d_latent, device=sae_device) + + if feature_id >= d_latent: + print( + f"""DEBUG: ERROR - Feature ID {feature_id} is out of bounds + for d_latent {d_latent}""" + ) + return torch.zeros(1) + + one_hot_activation[0, 0, feature_id] = 1.0 + + # Create the corresponding indices needed for the decode method. + indices = torch.tensor([[[feature_id]]], device=sae_device, dtype=torch.long) + + with torch.no_grad(): + try: + decoded_zero = sae.decode(torch.zeros_like(one_hot_activation), indices) + vector_before_sub = sae.decode(one_hot_activation, indices) + except Exception as e: + print(f"DEBUG: ERROR during sae.decode: {e}") + return torch.zeros(1) + + decoder_vector = vector_before_sub - decoded_zero + + final_norm = decoder_vector.norm().item() + + # --- MODIFIED DEBUG BLOCK --- + # Only print if the feature is "decoder-live" + if final_norm > 1e-6: + print(f"\n--- DEBUG: 'Decoder-Live' Feature Found: {feature_id} ---") + print(f"DEBUG: sae.encoder.out_features (d_latent): {d_latent}") + print(f"DEBUG: sae.encoder.weight.device (sae_device): {sae_device}") + print(f"DEBUG: Norm of decoded_zero: {decoded_zero.norm().item()}") + print( + f"DEBUG: Norm of vector_before_sub: {vector_before_sub.norm().item()}" + ) + print(f"DEBUG: Feature {feature_id}, FINAL Vector Norm: {final_norm}") + print("--- END DEBUG ---\n") + # --- END MODIFIED BLOCK --- + + return decoder_vector.squeeze() + + async def __call__(self, record: LatentRecord) -> ScorerResult: + + record_copy = copy.deepcopy(record) + + raw_examples = getattr(record_copy, "test", []) or [] + + if not raw_examples: + result = SurprisalInterventionResult( + score=0.0, avg_kl=0.0, explanation=record_copy.explanation + ) + return ScorerResult(record=record, score=[result.__dict__]) + + examples = self._sanitize_examples(raw_examples) + + prompts = ["".join(ex["str_tokens"]) for ex in examples[: self.num_prompts]] + + # Step 1 - Truncate prompts before tuning or scoring. + truncated_prompts = [ + await self._truncate_prompt(p, record_copy) for p in prompts + ] + + # Step 2 - Tune intervention strength to match target KL. 
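+        # (_tune_strength binary-searches the strength over [0, 40] until the
+        # clean-vs-intervened next-token KL is within kl_tolerance of target_kl,
+        # or max_search_steps iterations have been used.)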
+ hookpoint_str = self.hookpoint_str or getattr(record_copy, "hookpoint", None) + sae = self._get_sae_for_hookpoint(hookpoint_str, record_copy) + if not sae: + raise ValueError(f"Could not find SAE for hookpoint {hookpoint_str}") + + intervention_vector = self._get_intervention_vector(sae, record_copy.feature_id) + + tuned_strength, initial_kl = await self._tune_strength( + truncated_prompts, record_copy, intervention_vector + ) + + total_diff = 0.0 + total_kl = 0.0 + n = 0 + + for prompt in truncated_prompts: + clean_text, clean_logp_dist = await self._generate_with_intervention( + prompt, + record_copy, + strength=0.0, + intervention_vector=intervention_vector, + get_logp_dist=True, + ) + int_text, int_logp_dist = await self._generate_with_intervention( + prompt, + record_copy, + strength=tuned_strength, + intervention_vector=intervention_vector, + get_logp_dist=True, + ) + + logp_clean = await self._score_explanation( + clean_text, record_copy.explanation + ) + logp_int = await self._score_explanation(int_text, record_copy.explanation) + + p_clean = torch.exp(clean_logp_dist) + kl_div = F.kl_div( + int_logp_dist, p_clean, reduction="sum", log_target=False + ).item() + + total_diff += logp_int - logp_clean + total_kl += kl_div + n += 1 + + avg_diff = total_diff / n if n > 0 else 0.0 + avg_kl = total_kl / n if n > 0 else 0.0 + + # Final score is the average difference, not normalized by KL. + final_score = avg_diff + + final_output_list = [] + for i, ex in enumerate(examples[: self.num_prompts]): + final_output_list.append( + { + "str_tokens": ex["str_tokens"], + "truncated_prompt": truncated_prompts[i], + "final_score": final_score, + "avg_kl_divergence": avg_kl, + "tuned_strength": tuned_strength, + "target_kl": self.target_kl, + # Placeholder keys + "distance": None, + "activating": None, + "prediction": None, + "correct": None, + "probability": None, + "activations": None, + } + ) + return ScorerResult(record=record_copy, score=final_output_list) + + async def _get_latent_activations( + self, prompt: str, record: LatentRecord + ) -> torch.Tensor: + """ + Runs a forward pass to get the SAE's latent activations for a prompt. + """ + device = self._get_device() + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + sae = self._get_sae_for_hookpoint(hookpoint_str, record) + if not sae: + return torch.empty(0) # Return empty tensor if no SAE to encode with + + captured_hidden_states = [] + + def capture_hook(module, inp, out): + hidden_states = out[0] if isinstance(out, tuple) else out + captured_hidden_states.append(hidden_states.detach().cpu()) + + layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) + handle = layer_to_hook.register_forward_hook(capture_hook) + + try: + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device) + with torch.no_grad(): + self.subject_model(input_ids) + finally: + handle.remove() + + if not captured_hidden_states: + return torch.empty(0) + + hidden_states = captured_hidden_states[0].to(device) + + encoding_result = sae.encode(hidden_states) + feature_acts = encoding_result[2] + + return feature_acts[0, :, record.feature_id].cpu() + + async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str: + """ + Truncates prompt to end just before the first token where latent activates. 
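+
+        More precisely, the returned prompt keeps every token up to and
+        including the first non-BOS token whose latent activation is non-zero;
+        if the latent never fires, the prompt is returned unchanged.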
+ """ + activations = await self._get_latent_activations(prompt, record) + if activations.numel() == 0: + return prompt # Cannot truncate if no activations found + + # Find the index of the first token with non-zero activation + # Get ALL non-zero indices first + all_activation_indices = (activations > 1e-6).nonzero(as_tuple=True)[0] + + # Filter out activations at position 0 (BOS) + first_activation_idx = all_activation_indices[all_activation_indices > 0] + + if first_activation_idx.numel() > 0: + truncation_point = first_activation_idx[0].item() + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids[0] + truncated_ids = input_ids[: truncation_point + 1] + return self.tokenizer.decode(truncated_ids, skip_special_tokens=True) + + return prompt + + async def _tune_strength( + self, + prompts: List[str], + record: LatentRecord, + intervention_vector: torch.Tensor, + ) -> Tuple[float, float]: + """ + Performs a binary search to find intervention strength that matches target_kl. + """ + low_strength, high_strength = 0.0, 40.0 # Heuristic search range + best_strength = self.target_kl # Default to target_kl if search fails + + for _ in range(self.max_search_steps): + mid_strength = (low_strength + high_strength) / 2 + + # Estimate KL at mid_strength + total_kl = 0.0 + n = 0 + for prompt in prompts: + _, clean_logp = await self._generate_with_intervention( + prompt, record, 0.0, intervention_vector, True + ) + _, int_logp = await self._generate_with_intervention( + prompt, record, mid_strength, intervention_vector, True + ) + + p_clean = torch.exp(clean_logp) + kl_div = F.kl_div( + int_logp, p_clean, reduction="sum", log_target=False + ).item() + total_kl += kl_div + n += 1 + + current_kl = total_kl / n if n > 0 else 0.0 + + if abs(current_kl - self.target_kl) < self.kl_tolerance: + return mid_strength, current_kl + + if current_kl < self.target_kl: + low_strength = mid_strength + else: + high_strength = mid_strength + + best_strength = mid_strength + + # Return the best found strength and the corresponding KL + final_kl = await self._calculate_avg_kl( + prompts, record, best_strength, intervention_vector + ) + return best_strength, final_kl + + async def _calculate_avg_kl( + self, + prompts: List[str], + record: LatentRecord, + strength: float, + intervention_vector: torch.Tensor, + ) -> float: + total_kl = 0.0 + n = 0 + for prompt in prompts: + _, clean_logp = await self._generate_with_intervention( + prompt, record, 0.0, intervention_vector, True + ) + _, int_logp = await self._generate_with_intervention( + prompt, record, strength, intervention_vector, True + ) + p_clean = torch.exp(clean_logp) + kl_div = F.kl_div( + int_logp, p_clean, reduction="sum", log_target=False + ).item() + total_kl += kl_div + n += 1 + return total_kl / n if n > 0 else 0.0 + + async def _generate_with_intervention( + self, + prompt: str, + record: LatentRecord, + strength: float, + intervention_vector: torch.Tensor, + get_logp_dist: bool = False, + ) -> Tuple[str, torch.Tensor]: + device = self._get_device() + enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True) + input_ids = enc["input_ids"].to(device) + attention_mask = enc["attention_mask"].to(device) + + prompt_length = input_ids.shape[1] + delta = strength * intervention_vector + + hooks = [] + if strength > 0: + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + if hookpoint_str is None: + raise ValueError("No hookpoint string specified for intervention.") + + layer_to_hook = 
self._resolve_hookpoint(self.subject_model, hookpoint_str) + sae = self._get_sae_for_hookpoint(hookpoint_str, record) + if not sae: + raise ValueError( + f"Couldn't find a valid SAE for hookpoint {hookpoint_str}" + ) + + def hook_fn(module, inp, out): + hidden_states = out[0] if isinstance(out, tuple) else out + original_dtype = hidden_states.dtype + + current_seq_len = hidden_states.shape[1] + new_hiddens = hidden_states.detach().clone() + + intervention_start_index = prompt_length - 1 + + if current_seq_len >= prompt_length: + new_hiddens[:, intervention_start_index:, :] += delta.to( + original_dtype + ) + + return ( + (new_hiddens,) + out[1:] if isinstance(out, tuple) else new_hiddens + ) + + hooks.append(layer_to_hook.register_forward_hook(hook_fn)) + + try: + with torch.no_grad(): + outputs = self.subject_model(input_ids, attention_mask=attention_mask) + next_token_logits = outputs.logits[0, -1, :] + log_probs_next_token = ( + F.log_softmax(next_token_logits, dim=-1) if get_logp_dist else None + ) + + gen_ids = self.subject_model.generate( + input_ids, + attention_mask=attention_mask, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.pad_token_id, + ) + generated_text = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True) + finally: + for h in hooks: + h.remove() + + return generated_text, ( + log_probs_next_token.cpu() if get_logp_dist else torch.empty(0) + ) + + async def _score_explanation(self, generated_text: str, explanation: str) -> float: + """ + Computes log P(explanation | generated_text) using the paper's + prompt format. + """ + device = self._get_device() + + # Build the prompt from Appendix G.1 + prompt_template = ( + "\n" + f"{generated_text}\n" + 'The above passage contains an amplified amount of "' + ) + explanation_suffix = f'{explanation}"' + + # Tokenize the parts + context_enc = self.tokenizer(prompt_template, return_tensors="pt") + explanation_enc = self.tokenizer(explanation_suffix, return_tensors="pt") + + full_input_ids = torch.cat( + [context_enc.input_ids, explanation_enc.input_ids], dim=1 + ).to(device) + + with torch.no_grad(): + outputs = self.subject_model(full_input_ids) + logits = outputs.logits + + # We only need to score the explanation part + context_len = context_enc.input_ids.shape[1] + + # Get logits for positions that predict the explanation tokens + # Shape: [batch_size, explanation_len, vocab_size] + explanation_logits = logits[:, context_len - 1 : -1, :] + + # Get the target token IDs for the explanation + # Shape: [batch_size, explanation_len] + target_ids = explanation_enc.input_ids.to(device) + + log_probs = F.log_softmax(explanation_logits, dim=-1) + + # Gather the log-probabilities of the actual explanation tokens + token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1) + + # Return the sum of log-probs for the explanation + return token_log_probs.sum().item() + + def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any: + """ + Retrieves the correct SAE model, handling the specific functools.partial + wrapper provided by the Delphi framework. 
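+
+        Resolution order: `record.sae` if present, otherwise a lookup in
+        `self.explainer_model` under the raw, fully-prefixed, and shortened
+        hookpoint keys; a matching functools.partial is unwrapped via its
+        'sae' keyword argument.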
+ """ + candidate = None + + if hasattr(record, "sae") and record.sae: + candidate = record.sae + elif self.explainer_model and isinstance(self.explainer_model, dict): + full_key = self._get_full_hookpoint_path(hookpoint_str) + short_key = ".".join(hookpoint_str.split(".")[-2:]) # e.g., "layers.6.mlp" + + for key in [hookpoint_str, full_key, short_key]: + if self.explainer_model.get(key) is not None: + candidate = self.explainer_model.get(key) + break + + if candidate is None: + # This will raise an error if the key isn't found + raise ValueError( + f"ERROR: Surprisal scorer could not find an SAE " + f"for hookpoint '{hookpoint_str}' in self.explainer_model" + ) + + if isinstance(candidate, functools.partial): + # As shown in load_sparsify.py, the SAE is in the 'sae' keyword. + if candidate.keywords and "sae" in candidate.keywords: + return candidate.keywords["sae"] # Unwrapped successfully + else: + # This will raise an error if the partial is missing the keyword + raise ValueError( + f"""ERROR: Found a partial for + {hookpoint_str} but could not + find the 'sae' keyword. + func: {candidate.func} + args: {candidate.args} + keywords: {candidate.keywords}""" + ) + + # This will raise an error if the candidate isn't a partial + raise ValueError( + f"""ERROR: Candidate for {hookpoint_str} was not a partial + object, which was not expected. Type: {type(candidate)}""" + ) + + def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + + sae = self._get_sae_for_hookpoint(hookpoint_str, record) + + if sae and hasattr(sae, "get_feature_vector"): + direction = sae.get_feature_vector(record.feature_id) + if not isinstance(direction, torch.Tensor): + direction = torch.tensor(direction, dtype=torch.float32) + direction = direction.squeeze() + return F.normalize(direction, p=2, dim=0) + + return self._estimate_direction_from_examples(record) + + def _estimate_direction_from_examples(self, record: LatentRecord) -> torch.Tensor: + """Estimates an intervention direction by averaging activations.""" + device = self._get_device() + + examples = self._sanitize_examples(getattr(record, "test", []) or []) + if not examples: + hidden_dim = self.subject_model.config.hidden_size + return torch.zeros(hidden_dim, device=device) + + captured_activations = [] + + def capture_hook(module, inp, out): + hidden_states = out[0] if isinstance(out, tuple) else out + + captured_activations.append(hidden_states[:, -1, :].detach().cpu()) + + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) + handle = layer_to_hook.register_forward_hook(capture_hook) + + try: + for ex in examples[: min(8, self.num_prompts)]: + prompt = "".join(ex["str_tokens"]) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to( + device + ) + with torch.no_grad(): + self.subject_model(input_ids) + finally: + handle.remove() + + if not captured_activations: + hidden_dim = self.subject_model.config.hidden_size + return torch.zeros(hidden_dim, device=device) + + activations = torch.cat(captured_activations, dim=0).to(device) + direction = activations.mean(dim=0) + + return F.normalize(direction, p=2, dim=0) diff --git a/delphi/temp.py b/delphi/temp.py new file mode 100644 index 00000000..19c29522 --- /dev/null +++ b/delphi/temp.py @@ -0,0 +1,14 @@ +# Create a file named run_analysis.py with these contents +from pathlib import Path + +from 
delphi.log.result_analysis import log_results + +# Adjust the path to your results folder +scores_path = Path("results/pythia_100_test/scores") +viz_path = Path("results/pythia_100_test/visualize") +modules = ["layers.6.mlp"] +scorer_names = ["fuzz", "detection", "surprisal_intervention"] + +log_results( + scores_path, viz_path, modules, scorer_names, model_name="EleutherAI/pythia-160m" +)