Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions tests/test_trap_bundles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
import pandas as pd
import pytest

import weightwatcher as ww
from weightwatcher.trap_bundles import analyze_traps_bundle, load_trap_bundle, remove_single_trap_from_bundle

try:
import torch
import torch.nn as nn
except Exception:
torch = None
nn = None


pytestmark = pytest.mark.skipif(torch is None, reason="torch required")


if torch is not None:
    class TinyTrapNet(nn.Module):
        """Minimal linear net whose single weight matrix is a scaled
        rank-one outer product, large enough (x35) to produce spectral
        outliers ("traps") for the analyzer to find."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(16, 12, bias=False)
            with torch.no_grad():
                left = torch.linspace(1.0, 2.0, steps=12)
                right = torch.linspace(-2.0, 1.0, steps=16)
                self.fc1.weight.copy_(35.0 * torch.outer(left, right))


def test_analyze_traps_bundle_roundtrip(tmp_path):
    """Bundles saved to disk round-trip with matching permutation fingerprints."""
    watcher = ww.WeightWatcher(model=TinyTrapNet())
    trap_df, bundles = analyze_traps_bundle(
        watcher, save_bundle=True, bundle_dir=str(tmp_path), checkpoint_id="1"
    )
    assert isinstance(trap_df, pd.DataFrame)
    assert len(bundles) >= 1
    if trap_df.empty:
        pytest.skip("no traps")
    first = trap_df.iloc[0]
    bundle = bundles[int(first.layer_id)]
    assert first.permute_fingerprint == bundle.permute_fingerprint
    assert "bundle_path" in trap_df.columns
    reloaded = load_trap_bundle(first.bundle_path)
    assert reloaded.permute_fingerprint == bundle.permute_fingerprint


def test_remove_single_trap_from_bundle_independent_of_seed():
    """Ablation driven by a saved bundle must not depend on later RNG activity."""
    model = TinyTrapNet()
    trap_df, bundles = analyze_traps_bundle(ww.WeightWatcher(model=model), seed=111)
    if trap_df.empty:
        pytest.skip("no traps")
    first = trap_df.iloc[0]
    bundle = bundles[int(first.layer_id)]

    _model_a, meta_a = remove_single_trap_from_bundle(model, bundle, first)
    # A fresh analysis under a different seed should leave prior
    # bundle-based ablation results untouched.
    analyze_traps_bundle(ww.WeightWatcher(model=TinyTrapNet()), seed=999)
    _model_b, meta_b = remove_single_trap_from_bundle(model, bundle, first)
    assert meta_a["permute_fingerprint"] == meta_b["permute_fingerprint"]


def test_remove_single_trap_mismatch_raises():
    """A tampered permutation fingerprint is rejected with a clear ValueError."""
    model = TinyTrapNet()
    trap_df, bundles = analyze_traps_bundle(ww.WeightWatcher(model=model), seed=123)
    if trap_df.empty:
        pytest.skip("no traps")
    tampered = trap_df.iloc[0].copy()
    bundle = bundles[int(tampered.layer_id)]
    tampered["permute_fingerprint"] = "bad"
    with pytest.raises(ValueError, match="permute_fingerprint mismatch"):
        remove_single_trap_from_bundle(model, bundle, tampered)


def test_legacy_api_still_works():
    """The pre-bundle analyze_traps() entry point still returns a DataFrame."""
    watcher = ww.WeightWatcher(model=TinyTrapNet())
    details = watcher.analyze_traps(plot=False, savefig=False, rng=1337)
    assert isinstance(details, pd.DataFrame)
3 changes: 1 addition & 2 deletions weightwatcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

__name__ = "weightwatcher"

__version__ = "0.8.3"
__version__ = "0.8.4"

__license__ = "Apache License, Version 2.0"
__description__ = "Diagnostic Tool for Deep Neural Networks"
Expand All @@ -30,4 +30,3 @@

__all__ = ["__name__", "__version__", "__license__", "__description__",
"__url__", "__author__", "__email__", "__copyright__"]

22 changes: 22 additions & 0 deletions weightwatcher/trap_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def analyze_traps(
trap_burden=False,
trap_burden_variant="top5",
top_sector_l=1,
bundle_collector=None,
):
"""Externalized implementation for WeightWatcher.analyze_traps()."""
if layers is None:
Expand Down Expand Up @@ -136,6 +137,27 @@ def analyze_traps(
for row in layer_rows:
row.pop("T_orig", None)
trap_rows.extend(layer_rows)
if bundle_collector is not None and len(layer_rows) > 0:
try:
from .RMT_Util import svd_full
W_perm = np.asarray(ww_layer.Wmats[0]).copy()
U_perm, S_perm, Vh_perm = svd_full(W_perm)
permute_ids = np.asarray(ww_layer.permute_ids[0]).copy() if len(ww_layer.permute_ids) > 0 else None
bundle_collector[int(ww_layer.layer_id)] = {
"layer_id": int(ww_layer.layer_id),
"name": str(getattr(ww_layer, "name", "")),
"longname": str(getattr(ww_layer, "longname", "")),
"W_orig": np.asarray(getattr(ww_layer, "W_orig", ww_layer.Wmats[0])).copy(),
"W_perm": W_perm,
"permute_ids": permute_ids,
"U_perm": U_perm,
"S_perm": S_perm,
"Vh_perm": Vh_perm,
"mp_bulk_max": float(layer_rows[0].get("mp_bulk_max", np.nan)),
"rows": [dict(r) for r in layer_rows],
}
except Exception:
wwcore.logger.exception("Failed to collect trap bundle artifacts for layer %s", ww_layer.layer_id)

if len(trap_rows) > 0:
details = pd.DataFrame.from_records(trap_rows)
Expand Down
222 changes: 222 additions & 0 deletions weightwatcher/trap_bundles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
import copy
import hashlib
import logging
import os
import pickle
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from . import remove_traps as remove_traps_ops
from . import trap_analysis as trap_analysis_ops
from .RMT_Util import unpermute_matrix
from .constants import DEFAULT_PARAMS, DEFAULT_START_ID, FAST_SVD, LAYERS, PEFT, PLOT, POOL, SVD_METHOD, WW_NAME
from .weightwatcher import WeightWatcher

logger = logging.getLogger(WW_NAME)


@dataclass
class TrapAnalysisBundle:
    """Self-contained record of one layer's trap analysis.

    Captures the original and permuted weight matrices together with the SVD
    of the permuted matrix, so a specific trap mode can be re-ablated later
    without re-running the seeded permutation.
    """

    checkpoint_id: Optional[str]  # training checkpoint label, or None when not supplied
    layer_id: int  # weightwatcher layer id this bundle describes
    layer_name: str
    layer_longname: str
    W_orig: np.ndarray  # layer weights before permutation
    W_perm: np.ndarray  # permuted layer weights the SVD below was taken from
    permute_ids: np.ndarray  # permutation indices mapping W_orig -> W_perm
    permute_mode: str  # currently always "shuffle" (set by analyze_traps_bundle)
    seed: Optional[int]
    rng_state: Optional[dict]  # RNG state captured at bundle-build time, if any
    permute_fingerprint: str  # hash of permute_ids; used to validate trap rows
    U_perm: np.ndarray  # left singular vectors of W_perm
    S_perm: np.ndarray  # singular values of W_perm
    Vh_perm: np.ndarray  # right singular vectors (transposed) of W_perm
    trap_metrics: pd.DataFrame  # per-trap rows for this layer
    trap_mode_map: Dict[int, int]  # trap_index -> trap_mode_index
    mp_bulk_edge: Dict[str, float] = field(default_factory=dict)  # e.g. {"mp_bulk_max": ...}
    bundle_path: Optional[str] = None  # filled in after save_trap_bundle()


def _fingerprint_perm(permute_ids: np.ndarray) -> str:
return hashlib.sha1(np.asarray(permute_ids).tobytes()).hexdigest()


def _bundle_filename(bundle: TrapAnalysisBundle) -> str:
    """Build the canonical on-disk filename for a trap bundle."""
    checkpoint = str(bundle.checkpoint_id or "na")
    seed_tag = str(bundle.seed) if bundle.seed is not None else "none"
    return (
        f"trap_bundle_step_{checkpoint}_layer_{bundle.layer_id}"
        f"_{bundle.permute_fingerprint}_seed_{seed_tag}.pkl"
    )


def save_trap_bundle(bundle: TrapAnalysisBundle, bundle_dir: str) -> str:
    """Pickle *bundle* into *bundle_dir* (created if missing).

    Records the resulting path on ``bundle.bundle_path`` and returns it.
    """
    os.makedirs(bundle_dir, exist_ok=True)
    target = os.path.join(bundle_dir, _bundle_filename(bundle))
    with open(target, "wb") as fh:
        pickle.dump(bundle, fh)
    bundle.bundle_path = target
    return target


def load_trap_bundle(path: str) -> TrapAnalysisBundle:
    """Unpickle a bundle previously written by save_trap_bundle().

    NOTE(review): ``pickle.load`` executes arbitrary code from the file —
    only load bundle files produced by a trusted run.
    """
    with open(path, "rb") as fh:
        bundle = pickle.load(fh)
    return bundle


def analyze_traps_bundle(model_or_watcher, layers=None, save_bundle=False, bundle_dir=None, return_bundle=True, checkpoint_id=None, seed=None, rng=None, plot=False, pool=True):
    """Run trap analysis and package per-layer reproducibility bundles.

    Accepts either a WeightWatcher instance or a raw model.  Runs
    ``trap_analysis_ops.analyze_traps`` with a ``bundle_collector`` dict to
    capture each analyzed layer's weights, permutation ids, and SVD, then
    wraps them into TrapAnalysisBundle objects keyed by layer_id.

    Returns ``(trap_df, bundles)`` where ``trap_df`` gains
    ``permute_fingerprint``, ``bundle_id`` and (when ``save_bundle``)
    ``bundle_path`` columns; ``bundles`` is ``None`` when
    ``return_bundle=False``.
    """
    watcher = model_or_watcher if isinstance(model_or_watcher, WeightWatcher) else WeightWatcher(model=model_or_watcher)
    model = watcher.model if isinstance(model_or_watcher, WeightWatcher) else model_or_watcher
    params = DEFAULT_PARAMS.copy()
    params[POOL] = pool
    params[PLOT] = plot
    params[LAYERS] = [] if layers is None else layers
    params[SVD_METHOD] = FAST_SVD
    # NOTE(review): this re-assigns the default back to itself (no-op) —
    # confirm whether a PEFT override was intended here.
    params[PEFT] = params.get(PEFT)
    params["seed"] = seed
    # Normalized RNG is captured below as rng_state; analyze_traps itself
    # receives the raw rng/seed (see call below) — confirm both stay in sync.
    params["rng"] = remove_traps_ops._normalize_trap_rng(rng=rng, seed=seed)
    params = watcher.normalize_params(params)

    # analyze_traps fills this with per-layer artifacts (W_orig/W_perm,
    # permute_ids, SVD factors, trap rows) keyed by layer_id.
    bundle_collector = {}
    trap_df = trap_analysis_ops.analyze_traps(
        watcher,
        model=model,
        layers=layers or [],
        plot=plot,
        pool=pool,
        rng=rng if rng is not None else seed,
        bundle_collector=bundle_collector,
    )
    bundles = {}
    rows = []
    for layer_id, artifacts in bundle_collector.items():
        # Only build a bundle for layers that actually produced trap rows.
        layer_traps = trap_df[trap_df["layer_id"].astype(int) == int(layer_id)].copy()
        if len(layer_traps) == 0:
            continue
        W_perm = np.asarray(artifacts["W_perm"]).copy()
        permute_ids = np.asarray(artifacts["permute_ids"]).copy()
        fp = _fingerprint_perm(permute_ids)
        U, S, Vh = np.asarray(artifacts["U_perm"]), np.asarray(artifacts["S_perm"]), np.asarray(artifacts["Vh_perm"])

        # Stamp the rows so later ablation can verify they match this bundle.
        layer_traps["permute_fingerprint"] = fp
        layer_traps["bundle_id"] = f"{checkpoint_id}:{layer_id}:{fp}"
        trap_mode_map = {int(r.trap_index): int(r.trap_mode_index) for r in layer_traps.itertuples()}
        bundle = TrapAnalysisBundle(
            checkpoint_id=str(checkpoint_id) if checkpoint_id is not None else None,
            layer_id=int(layer_id),
            layer_name=str(artifacts.get("name", "")),
            layer_longname=str(artifacts.get("longname", "")),
            W_orig=np.asarray(artifacts["W_orig"]).copy(),
            W_perm=W_perm,
            permute_ids=permute_ids,
            permute_mode="shuffle",
            seed=seed,
            rng_state=params["rng"].get_state() if params.get("rng") is not None else None,
            permute_fingerprint=fp,
            U_perm=U,
            S_perm=S,
            Vh_perm=Vh,
            trap_metrics=layer_traps.copy(),
            trap_mode_map=trap_mode_map,
            mp_bulk_edge={"mp_bulk_max": float(artifacts.get("mp_bulk_max", np.nan))},
        )
        if save_bundle:
            path = save_trap_bundle(bundle, bundle_dir or "trap_bundles")
            layer_traps["bundle_path"] = path
            bundle.bundle_path = path
        bundles[int(layer_id)] = bundle
        rows.append(layer_traps)
        logger.info(f"trap bundle layer_id={layer_id} traps={len(layer_traps)} permute_fingerprint={fp} bundle_path={bundle.bundle_path}")

    # Fall back to the raw analyze_traps frame when no layer yielded traps.
    out_df = pd.concat(rows, ignore_index=True) if rows else trap_df.copy()
    return (out_df, bundles) if return_bundle else (out_df, None)


def remove_single_trap_from_bundle(model, bundle: TrapAnalysisBundle, trap_row, inplace=False, allow_model_mismatch=False, atol=1e-8, rtol=1e-5):
    """Ablate a single trap mode from *model* using a saved bundle.

    Validates that *trap_row* belongs to *bundle* (layer id, permutation
    fingerprint, and — when present — the ``sigma_perm`` singular value),
    checks the model's current weights against the bundle's ``W_orig``,
    subtracts the rank-one trap component in permuted space, un-permutes,
    and writes the ablated matrix back into the (copied, unless
    ``inplace=True``) model.

    Returns ``(model, meta)`` where ``meta`` records what was removed.
    Raises ValueError on any mismatch between row, bundle, and model.
    """
    row = trap_row if isinstance(trap_row, dict) else trap_row.to_dict()
    if int(row["layer_id"]) != int(bundle.layer_id):
        raise ValueError("layer_id mismatch between trap row and bundle")
    if str(row.get("permute_fingerprint")) != str(bundle.permute_fingerprint):
        raise ValueError("permute_fingerprint mismatch between trap row and bundle")
    tmi = int(row["trap_mode_index"])
    if tmi < 0 or tmi >= len(bundle.S_perm):
        raise ValueError("trap_mode_index not found in bundle decomposition")
    if "sigma_perm" in row and not np.isclose(float(row["sigma_perm"]), float(bundle.S_perm[tmi]), rtol=rtol, atol=atol):
        raise ValueError("sigma_perm mismatch between trap row and bundle decomposition")

    # Verify the live model still carries the weights the bundle was built from.
    watcher = WeightWatcher(model=model)
    params = watcher.normalize_params(DEFAULT_PARAMS.copy())
    target_layer = None
    for ww_layer in watcher.make_layer_iterator(model=watcher.model, layers=[bundle.layer_id], params=params, base_model=watcher.base_model):
        if int(ww_layer.layer_id) == int(bundle.layer_id):
            target_layer = ww_layer
            break
    if target_layer is None:
        raise ValueError("target layer not found in model")
    current_W = np.asarray(target_layer.Wmats[0])
    if current_W.shape != bundle.W_orig.shape:
        raise ValueError("current model layer shape mismatch vs bundle")
    if (not allow_model_mismatch) and (not np.allclose(current_W, bundle.W_orig, rtol=rtol, atol=atol)):
        raise ValueError("current model layer weights do not match bundle original weights")

    # Remove the rank-one trap component in permuted space, then un-permute.
    u = bundle.U_perm[:, tmi]
    v = bundle.Vh_perm[tmi, :]
    T_perm = bundle.S_perm[tmi] * np.outer(u, v)
    W_perm_abl = bundle.W_perm - T_perm
    W_abl = unpermute_matrix(W_perm_abl, bundle.permute_ids)

    out_model = model if inplace else copy.deepcopy(model)
    out_watcher = WeightWatcher(model=out_model)
    out_params = out_watcher.normalize_params(DEFAULT_PARAMS.copy())
    replaced = False
    for out_layer in out_watcher.make_layer_iterator(model=out_watcher.model, layers=[bundle.layer_id], params=out_params, base_model=out_watcher.base_model):
        if int(out_layer.layer_id) == int(bundle.layer_id):
            out_watcher.replace_layer_weights(out_layer.layer_id, out_layer.framework_layer, W_abl)
            replaced = True
            break
    if not replaced:
        # Previously this fell through silently and returned an unmodified
        # model while logging success — fail loudly instead.
        raise ValueError("target layer not found in output model; ablation not applied")
    meta = {"ok": True, "layer_id": bundle.layer_id, "trap_index": int(row.get("trap_index", -1)), "trap_mode_index": tmi, "permute_fingerprint": bundle.permute_fingerprint}
    logger.info(f"bundle ablation checkpoint={bundle.checkpoint_id} layer_id={bundle.layer_id} trap_index={meta['trap_index']} trap_mode_index={tmi} permute_fingerprint={bundle.permute_fingerprint} verification passed")
    return out_model, meta


def remove_traps_from_bundle(model, bundle: TrapAnalysisBundle, trap_indices=None, trap_mode_indices=None, inplace=False, allow_model_mismatch=False):
    """Ablate several traps from *model* using one bundle.

    Selects rows from ``bundle.trap_metrics`` by ``trap_indices`` and/or
    ``trap_mode_indices`` (both optional filters) and removes each selected
    trap in sequence.  Returns ``(model, metas_df)``.

    BUG FIX: the original started from ``out_model = model`` and delegated
    with ``inplace=True``, mutating the caller's model even when
    ``inplace=False`` (and then returning a deepcopy of the already-mutated
    original).  We now copy up front so ``inplace=False`` truly leaves the
    input model untouched.

    Raises ValueError when the filters select no traps.
    """
    rows = bundle.trap_metrics.copy()
    if trap_indices is not None:
        rows = rows[rows["trap_index"].isin(list(trap_indices))]
    if trap_mode_indices is not None:
        rows = rows[rows["trap_mode_index"].isin(list(trap_mode_indices))]
    if len(rows) == 0:
        raise ValueError("No traps selected from bundle")
    # Copy once, then mutate the copy in place for every selected trap.
    out_model = model if inplace else copy.deepcopy(model)
    metas = []
    for _, row in rows.iterrows():
        out_model, meta = remove_single_trap_from_bundle(out_model, bundle, row, inplace=True, allow_model_mismatch=allow_model_mismatch)
        metas.append(meta)
    return out_model, pd.DataFrame.from_records(metas)


def run_trap_bundle_ablation_experiment(model, checkpoint_step, checkpoint_path, layers, evaluate_fn, bulk_baseline_fn=None, bundle_dir=None, save_bundles=True):
    """Analyze traps, ablate each one from its bundle, and score the damage.

    For every trap row the model is evaluated once ablated (via
    ``evaluate_fn``, expected to return a dict with ``train_accuracy`` /
    ``test_accuracy``) and compared against the unablated baseline; an
    optional ``bulk_baseline_fn(model, bundle, row)`` contributes a
    bulk-mode comparison.  Per-trap failures are recorded in the ``error``
    column rather than raised.  Returns one DataFrame row per trap.
    """
    trap_df, bundles = analyze_traps_bundle(
        model,
        layers=layers,
        save_bundle=save_bundles,
        bundle_dir=bundle_dir,
        checkpoint_id=checkpoint_step,
        return_bundle=True,
        plot=False,
    )
    baseline = evaluate_fn(model)
    records = []
    for _, trap_row in trap_df.iterrows():
        bundle = bundles[int(trap_row["layer_id"])]
        record = dict(trap_row)
        record.update({
            "checkpoint_step": checkpoint_step,
            "checkpoint_path": checkpoint_path,
            "ok": False,
            "error": None,
            "base_train_accuracy": baseline.get("train_accuracy"),
            "base_test_accuracy": baseline.get("test_accuracy"),
        })
        try:
            ablated, meta = remove_single_trap_from_bundle(model, bundle, trap_row, inplace=False)
            scores = evaluate_fn(ablated)
            record["trap_train_accuracy"] = scores.get("train_accuracy")
            record["trap_test_accuracy"] = scores.get("test_accuracy")
            record["trap_delta_train_accuracy"] = record["trap_train_accuracy"] - record["base_train_accuracy"]
            record["trap_delta_test_accuracy"] = record["trap_test_accuracy"] - record["base_test_accuracy"]
            if bulk_baseline_fn is not None:
                record.update(bulk_baseline_fn(model, bundle, trap_row))
                if "bulk_delta_test_accuracy_mean" in record:
                    record["trap_damage_excess_vs_bulk"] = record["trap_delta_test_accuracy"] - record["bulk_delta_test_accuracy_mean"]
            record["ok"] = bool(meta.get("ok", False))
            record["bundle_path"] = bundle.bundle_path
        except Exception as exc:
            # Best-effort experiment loop: record the failure and continue.
            record["error"] = str(exc)
        records.append(record)
    return pd.DataFrame.from_records(records)