From 304b8b232d0d0181c058cd8b6d9088d5e3d562aa Mon Sep 17 00:00:00 2001
From: Charles Martin
Date: Thu, 30 Apr 2026 11:37:52 -0700
Subject: [PATCH] Add analyze_traps_bundle: reuse first-pass artifacts to
 speed up trap ablation
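
Each analyze_traps() pass shuffles the layer matrices with an RNG-driven
permutation and decomposes them, so re-running the analysis to ablate
traps one at a time repeats the expensive work and, under a different
seed, can land on a different permutation. analyze_traps_bundle() runs
the analysis once and captures the permuted matrix, the permutation ids,
and the SVD factors in a per-layer TrapAnalysisBundle; later ablations
reuse those artifacts directly. A SHA-1 fingerprint of the permutation
ties every trap row back to its bundle, so a stale or mismatched bundle
raises a ValueError instead of silently ablating the wrong mode.

Sketch of the intended flow, using only the API added here (the
checkpoint id is illustrative):

    trap_df, bundles = analyze_traps_bundle(model, save_bundle=True,
                                            bundle_dir="trap_bundles",
                                            checkpoint_id="1")
    row = trap_df.iloc[0]
    bundle = bundles[int(row.layer_id)]
    ablated, meta = remove_single_trap_from_bundle(model, bundle, row)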
---
 tests/test_trap_bundles.py     |  78 ++++++++++
 weightwatcher/__init__.py      |   3 +-
 weightwatcher/trap_analysis.py |  22 +++
 weightwatcher/trap_bundles.py  | 245 +++++++++++++++++++++++++++++++++
 4 files changed, 346 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_trap_bundles.py
 create mode 100644 weightwatcher/trap_bundles.py

diff --git a/tests/test_trap_bundles.py b/tests/test_trap_bundles.py
new file mode 100644
index 0000000..50e29ea
--- /dev/null
+++ b/tests/test_trap_bundles.py
@@ -0,0 +1,78 @@
+import numpy as np
+import pandas as pd
+import pytest
+
+import weightwatcher as ww
+from weightwatcher.trap_bundles import analyze_traps_bundle, load_trap_bundle, remove_single_trap_from_bundle
+
+try:
+    import torch
+    import torch.nn as nn
+except Exception:
+    torch = None
+    nn = None
+
+
+pytestmark = pytest.mark.skipif(torch is None, reason="torch required")
+
+
+if torch is not None:
+    class TinyTrapNet(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc1 = nn.Linear(16, 12, bias=False)
+            with torch.no_grad():
+                u = torch.linspace(1.0, 2.0, steps=12)
+                v = torch.linspace(-2.0, 1.0, steps=16)
+                self.fc1.weight.copy_(35.0 * torch.outer(u, v))
+
+
+def test_analyze_traps_bundle_roundtrip(tmp_path):
+    model = TinyTrapNet()
+    watcher = ww.WeightWatcher(model=model)
+    trap_df, bundles = analyze_traps_bundle(watcher, save_bundle=True, bundle_dir=str(tmp_path), checkpoint_id="1")
+    assert isinstance(trap_df, pd.DataFrame)
+    assert len(bundles) >= 1
+    if len(trap_df) == 0:
+        pytest.skip("no traps")
+    row = trap_df.iloc[0]
+    b = bundles[int(row.layer_id)]
+    assert row.permute_fingerprint == b.permute_fingerprint
+    assert "bundle_path" in trap_df.columns
+    loaded = load_trap_bundle(row.bundle_path)
+    assert loaded.permute_fingerprint == b.permute_fingerprint
+
+
+def test_remove_single_trap_from_bundle_independent_of_seed():
+    model = TinyTrapNet()
+    watcher = ww.WeightWatcher(model=model)
+    trap_df, bundles = analyze_traps_bundle(watcher, seed=111)
+    if len(trap_df) == 0:
+        pytest.skip("no traps")
+    row = trap_df.iloc[0]
+    b = bundles[int(row.layer_id)]
+
+    m1, meta1 = remove_single_trap_from_bundle(model, b, row)
+    # changing seed in fresh analysis should not affect prior bundle-based ablation
+    _df2, _bundles2 = analyze_traps_bundle(ww.WeightWatcher(model=TinyTrapNet()), seed=999)
+    m2, meta2 = remove_single_trap_from_bundle(model, b, row)
+    assert meta1["permute_fingerprint"] == meta2["permute_fingerprint"]
+
+
+def test_remove_single_trap_mismatch_raises():
+    model = TinyTrapNet()
+    trap_df, bundles = analyze_traps_bundle(ww.WeightWatcher(model=model), seed=123)
+    if len(trap_df) == 0:
+        pytest.skip("no traps")
+    row = trap_df.iloc[0].copy()
+    b = bundles[int(row.layer_id)]
+    row["permute_fingerprint"] = "bad"
+    with pytest.raises(ValueError, match="permute_fingerprint mismatch"):
+        remove_single_trap_from_bundle(model, b, row)
+
+
+def test_legacy_api_still_works():
+    model = TinyTrapNet()
+    watcher = ww.WeightWatcher(model=model)
+    df = watcher.analyze_traps(plot=False, savefig=False, rng=1337)
+    assert isinstance(df, pd.DataFrame)
diff --git a/weightwatcher/__init__.py b/weightwatcher/__init__.py
index adf20c3..a4f9739 100644
--- a/weightwatcher/__init__.py
+++ b/weightwatcher/__init__.py
@@ -19,7 +19,7 @@
 
 __name__ = "weightwatcher"
 
-__version__ = "0.8.3"
+__version__ = "0.8.4"
 
 __license__ = "Apache License, Version 2.0"
 __description__ = "Diagnostic Tool for Deep Neural Networks"
@@ -30,4 +30,3 @@
 
 __all__ = ["__name__", "__version__", "__license__", "__description__",
            "__url__", "__author__", "__email__", "__copyright__"]
-
diff --git a/weightwatcher/trap_analysis.py b/weightwatcher/trap_analysis.py
index 8a6e5f4..d7fbd5f 100644
--- a/weightwatcher/trap_analysis.py
+++ b/weightwatcher/trap_analysis.py
@@ -32,6 +32,7 @@ def analyze_traps(
     trap_burden=False,
     trap_burden_variant="top5",
     top_sector_l=1,
+    bundle_collector=None,
 ):
     """Externalized implementation for WeightWatcher.analyze_traps()."""
     if layers is None:
@@ -136,6 +137,27 @@
         for row in layer_rows:
             row.pop("T_orig", None)
         trap_rows.extend(layer_rows)
+        if bundle_collector is not None and len(layer_rows) > 0:
+            try:
+                from .RMT_Util import svd_full
+                W_perm = np.asarray(ww_layer.Wmats[0]).copy()
+                U_perm, S_perm, Vh_perm = svd_full(W_perm)
+                permute_ids = np.asarray(ww_layer.permute_ids[0]).copy() if len(ww_layer.permute_ids) > 0 else None
+                bundle_collector[int(ww_layer.layer_id)] = {
+                    "layer_id": int(ww_layer.layer_id),
+                    "name": str(getattr(ww_layer, "name", "")),
+                    "longname": str(getattr(ww_layer, "longname", "")),
+                    "W_orig": np.asarray(getattr(ww_layer, "W_orig", ww_layer.Wmats[0])).copy(),
+                    "W_perm": W_perm,
+                    "permute_ids": permute_ids,
+                    "U_perm": U_perm,
+                    "S_perm": S_perm,
+                    "Vh_perm": Vh_perm,
+                    "mp_bulk_max": float(layer_rows[0].get("mp_bulk_max", np.nan)),
+                    "rows": [dict(r) for r in layer_rows],
+                }
+            except Exception:
+                wwcore.logger.exception("Failed to collect trap bundle artifacts for layer %s", ww_layer.layer_id)
 
     if len(trap_rows) > 0:
         details = pd.DataFrame.from_records(trap_rows)
diff --git a/weightwatcher/trap_bundles.py b/weightwatcher/trap_bundles.py
new file mode 100644
index 0000000..9b30d8f
--- /dev/null
+++ b/weightwatcher/trap_bundles.py
@@ -0,0 +1,245 @@
+import copy
+import hashlib
+import logging
+import os
+import pickle
+from dataclasses import dataclass, field
+from typing import Dict, Optional
+
+import numpy as np
+import pandas as pd
+
+from . import remove_traps as remove_traps_ops
+from . import trap_analysis as trap_analysis_ops
+from .RMT_Util import unpermute_matrix
+from .constants import DEFAULT_PARAMS, FAST_SVD, LAYERS, PEFT, PLOT, POOL, SVD_METHOD, WW_NAME
+from .weightwatcher import WeightWatcher
+
+logger = logging.getLogger(WW_NAME)
+
+
+@dataclass
+class TrapAnalysisBundle:
+    """Per-layer artifacts from one analyze_traps() pass: the original and
+    permuted weight matrices, the permutation ids, and the SVD of the
+    permuted matrix, kept so later ablations can reuse them directly."""
+    checkpoint_id: Optional[str]
+    layer_id: int
+    layer_name: str
+    layer_longname: str
+    W_orig: np.ndarray
+    W_perm: np.ndarray
+    permute_ids: np.ndarray
+    permute_mode: str
+    seed: Optional[int]
+    rng_state: Optional[dict]
+    permute_fingerprint: str
+    U_perm: np.ndarray
+    S_perm: np.ndarray
+    Vh_perm: np.ndarray
+    trap_metrics: pd.DataFrame
+    trap_mode_map: Dict[int, int]
+    mp_bulk_edge: Dict[str, float] = field(default_factory=dict)
+    bundle_path: Optional[str] = None
+
+
+def _fingerprint_perm(permute_ids: np.ndarray) -> str:
+    return hashlib.sha1(np.asarray(permute_ids).tobytes()).hexdigest()
+
+
+def _bundle_filename(bundle: TrapAnalysisBundle) -> str:
+    ck = str(bundle.checkpoint_id or "na")
+    seed = "none" if bundle.seed is None else str(bundle.seed)
+    return f"trap_bundle_step_{ck}_layer_{bundle.layer_id}_{bundle.permute_fingerprint}_seed_{seed}.pkl"
+
+
+def save_trap_bundle(bundle: TrapAnalysisBundle, bundle_dir: str) -> str:
+    os.makedirs(bundle_dir, exist_ok=True)
+    path = os.path.join(bundle_dir, _bundle_filename(bundle))
+    with open(path, "wb") as f:
+        pickle.dump(bundle, f)
+    bundle.bundle_path = path
+    return path
+
+
+def load_trap_bundle(path: str) -> TrapAnalysisBundle:
+    with open(path, "rb") as f:
+        return pickle.load(f)
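+
+
+# analyze_traps_bundle() is the bundle-aware front end to analyze_traps(): it
+# runs the normal first pass, then packages the per-layer artifacts captured
+# by the collector into TrapAnalysisBundle objects keyed by layer_id.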
+def analyze_traps_bundle(model_or_watcher, layers=None, save_bundle=False, bundle_dir=None, return_bundle=True, checkpoint_id=None, seed=None, rng=None, plot=False, pool=True):
+    watcher = model_or_watcher if isinstance(model_or_watcher, WeightWatcher) else WeightWatcher(model=model_or_watcher)
+    model = watcher.model if isinstance(model_or_watcher, WeightWatcher) else model_or_watcher
+    params = DEFAULT_PARAMS.copy()
+    params[POOL] = pool
+    params[PLOT] = plot
+    params[LAYERS] = [] if layers is None else layers
+    params[SVD_METHOD] = FAST_SVD
+    params[PEFT] = params.get(PEFT)
+    params["seed"] = seed
+    params["rng"] = remove_traps_ops._normalize_trap_rng(rng=rng, seed=seed)
+    params = watcher.normalize_params(params)
+
+    bundle_collector = {}
+    trap_df = trap_analysis_ops.analyze_traps(
+        watcher,
+        model=model,
+        layers=layers or [],
+        plot=plot,
+        pool=pool,
+        rng=rng if rng is not None else seed,
+        bundle_collector=bundle_collector,
+    )
+    bundles = {}
+    rows = []
+    for layer_id, artifacts in bundle_collector.items():
+        layer_traps = trap_df[trap_df["layer_id"].astype(int) == int(layer_id)].copy()
+        if len(layer_traps) == 0:
+            continue
+        if artifacts["permute_ids"] is None:
+            logger.warning(f"layer_id={layer_id} has no permute_ids; skipping bundle")
+            continue
+        W_perm = np.asarray(artifacts["W_perm"]).copy()
+        permute_ids = np.asarray(artifacts["permute_ids"]).copy()
+        fp = _fingerprint_perm(permute_ids)
+        U, S, Vh = np.asarray(artifacts["U_perm"]), np.asarray(artifacts["S_perm"]), np.asarray(artifacts["Vh_perm"])
+
+        layer_traps["permute_fingerprint"] = fp
+        layer_traps["bundle_id"] = f"{checkpoint_id}:{layer_id}:{fp}"
+        trap_mode_map = {int(r.trap_index): int(r.trap_mode_index) for r in layer_traps.itertuples()}
+        bundle = TrapAnalysisBundle(
+            checkpoint_id=str(checkpoint_id) if checkpoint_id is not None else None,
+            layer_id=int(layer_id),
+            layer_name=str(artifacts.get("name", "")),
+            layer_longname=str(artifacts.get("longname", "")),
+            W_orig=np.asarray(artifacts["W_orig"]).copy(),
+            W_perm=W_perm,
+            permute_ids=permute_ids,
+            permute_mode="shuffle",
+            seed=seed,
+            rng_state=params["rng"].get_state() if params.get("rng") is not None else None,
+            permute_fingerprint=fp,
+            U_perm=U,
+            S_perm=S,
+            Vh_perm=Vh,
+            trap_metrics=layer_traps.copy(),
+            trap_mode_map=trap_mode_map,
+            mp_bulk_edge={"mp_bulk_max": float(artifacts.get("mp_bulk_max", np.nan))},
+        )
+        if save_bundle:
+            path = save_trap_bundle(bundle, bundle_dir or "trap_bundles")
+            layer_traps["bundle_path"] = path
+            bundle.bundle_path = path
+        bundles[int(layer_id)] = bundle
+        rows.append(layer_traps)
+        logger.info(f"trap bundle layer_id={layer_id} traps={len(layer_traps)} permute_fingerprint={fp} bundle_path={bundle.bundle_path}")
+
+    out_df = pd.concat(rows, ignore_index=True) if rows else trap_df.copy()
+    return (out_df, bundles) if return_bundle else (out_df, None)
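+
+
+# Every ablation below is verified against the bundle (layer_id,
+# permute_fingerprint, and sigma_perm when present) before any weights are
+# touched, so a trap row from a different run is rejected, not applied.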
permute_mode="shuffle", + seed=seed, + rng_state=params["rng"].get_state() if params.get("rng") is not None else None, + permute_fingerprint=fp, + U_perm=U, + S_perm=S, + Vh_perm=Vh, + trap_metrics=layer_traps.copy(), + trap_mode_map=trap_mode_map, + mp_bulk_edge={"mp_bulk_max": float(artifacts.get("mp_bulk_max", np.nan))}, + ) + if save_bundle: + path = save_trap_bundle(bundle, bundle_dir or "trap_bundles") + layer_traps["bundle_path"] = path + bundle.bundle_path = path + bundles[int(layer_id)] = bundle + rows.append(layer_traps) + logger.info(f"trap bundle layer_id={layer_id} traps={len(layer_traps)} permute_fingerprint={fp} bundle_path={bundle.bundle_path}") + + out_df = pd.concat(rows, ignore_index=True) if rows else trap_df.copy() + return (out_df, bundles) if return_bundle else (out_df, None) + + +def remove_single_trap_from_bundle(model, bundle: TrapAnalysisBundle, trap_row, inplace=False, allow_model_mismatch=False, atol=1e-8, rtol=1e-5): + row = trap_row if isinstance(trap_row, dict) else trap_row.to_dict() + if int(row["layer_id"]) != int(bundle.layer_id): + raise ValueError("layer_id mismatch between trap row and bundle") + if str(row.get("permute_fingerprint")) != str(bundle.permute_fingerprint): + raise ValueError("permute_fingerprint mismatch between trap row and bundle") + tmi = int(row["trap_mode_index"]) + if tmi < 0 or tmi >= len(bundle.S_perm): + raise ValueError("trap_mode_index not found in bundle decomposition") + if "sigma_perm" in row and not np.isclose(float(row["sigma_perm"]), float(bundle.S_perm[tmi]), rtol=rtol, atol=atol): + raise ValueError("sigma_perm mismatch between trap row and bundle decomposition") + + watcher = WeightWatcher(model=model) + params = watcher.normalize_params(DEFAULT_PARAMS.copy()) + target_layer = None + for ww_layer in watcher.make_layer_iterator(model=watcher.model, layers=[bundle.layer_id], params=params, base_model=watcher.base_model): + if int(ww_layer.layer_id) == int(bundle.layer_id): + target_layer = ww_layer + break + if target_layer is None: + raise ValueError("target layer not found in model") + current_W = np.asarray(target_layer.Wmats[0]) + if current_W.shape != bundle.W_orig.shape: + raise ValueError("current model layer shape mismatch vs bundle") + if (not allow_model_mismatch) and (not np.allclose(current_W, bundle.W_orig, rtol=rtol, atol=atol)): + raise ValueError("current model layer weights do not match bundle original weights") + + u = bundle.U_perm[:, tmi] + v = bundle.Vh_perm[tmi, :] + T_perm = bundle.S_perm[tmi] * np.outer(u, v) + W_perm_abl = bundle.W_perm - T_perm + W_abl = unpermute_matrix(W_perm_abl, bundle.permute_ids) + + out_model = model if inplace else copy.deepcopy(model) + out_watcher = WeightWatcher(model=out_model) + out_params = out_watcher.normalize_params(DEFAULT_PARAMS.copy()) + for l in out_watcher.make_layer_iterator(model=out_watcher.model, layers=[bundle.layer_id], params=out_params, base_model=out_watcher.base_model): + if int(l.layer_id) == int(bundle.layer_id): + out_watcher.replace_layer_weights(l.layer_id, l.framework_layer, W_abl) + break + meta = {"ok": True, "layer_id": bundle.layer_id, "trap_index": int(row.get("trap_index", -1)), "trap_mode_index": tmi, "permute_fingerprint": bundle.permute_fingerprint} + logger.info(f"bundle ablation checkpoint={bundle.checkpoint_id} layer_id={bundle.layer_id} trap_index={meta['trap_index']} trap_mode_index={tmi} permute_fingerprint={bundle.permute_fingerprint} verification passed") + return out_model, meta + + +def remove_traps_from_bundle(model, 
+def run_trap_bundle_ablation_experiment(model, checkpoint_step, checkpoint_path, layers, evaluate_fn, bulk_baseline_fn=None, bundle_dir=None, save_bundles=True):
+    trap_df, bundles = analyze_traps_bundle(model, layers=layers, save_bundle=save_bundles, bundle_dir=bundle_dir, checkpoint_id=checkpoint_step, return_bundle=True, plot=False)
+    base = evaluate_fn(model)
+    out_rows = []
+    for _, row in trap_df.iterrows():
+        layer_id = int(row["layer_id"])
+        bundle = bundles.get(layer_id)
+        if bundle is None:
+            continue
+        rec = dict(row)
+        rec.update({"checkpoint_step": checkpoint_step, "checkpoint_path": checkpoint_path, "ok": False, "error": None, "base_train_accuracy": base.get("train_accuracy"), "base_test_accuracy": base.get("test_accuracy")})
+        try:
+            ablated_model, meta = remove_single_trap_from_bundle(model, bundle, row, inplace=False)
+            res = evaluate_fn(ablated_model)
+            rec["trap_train_accuracy"] = res.get("train_accuracy")
+            rec["trap_test_accuracy"] = res.get("test_accuracy")
+            rec["trap_delta_train_accuracy"] = rec["trap_train_accuracy"] - rec["base_train_accuracy"]
+            rec["trap_delta_test_accuracy"] = rec["trap_test_accuracy"] - rec["base_test_accuracy"]
+            if bulk_baseline_fn is not None:
+                bulk = bulk_baseline_fn(model, bundle, row)
+                rec.update(bulk)
+                if "bulk_delta_test_accuracy_mean" in rec:
+                    rec["trap_damage_excess_vs_bulk"] = rec["trap_delta_test_accuracy"] - rec["bulk_delta_test_accuracy_mean"]
+            rec["ok"] = bool(meta.get("ok", False))
+            rec["bundle_path"] = bundle.bundle_path
+        except Exception as exc:
+            rec["error"] = str(exc)
+        out_rows.append(rec)
+    return pd.DataFrame.from_records(out_rows)