Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tests/test_analyze_traps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
import numpy as np
import pandas as pd
from unittest.mock import patch
try:
import torch
import torch.nn as nn
Expand All @@ -9,6 +10,7 @@
TORCH_AVAILABLE = False

import weightwatcher as ww
from weightwatcher import remove_traps as remove_traps_ops


if TORCH_AVAILABLE:
Expand Down Expand Up @@ -220,6 +222,30 @@ def test_old_and_new_burdens_coexist(self):
for col in ["trap_variance_burden_old", "trap_variance_burden_ipr", "trap_variance_burden_top5", "trap_variance_burden"]:
self.assertIn(col, df.columns)

def test_analyze_traps_fast_mode_skips_original_basis(self):
with patch.object(ww.WeightWatcher, "compute_original_basis_for_traps", side_effect=AssertionError("should skip in fast mode")):
df = self.watcher.analyze_traps(
plot=False, savefig=False, trap_burden=True, trap_burden_mode="fast", compute_original_basis=False
)
self.assertIsInstance(df, pd.DataFrame)

def test_analyze_traps_fast_mode_skips_full_bulk_reference(self):
with patch.object(ww.WeightWatcher, "compute_bulk_trap_reference_metrics", side_effect=AssertionError("should skip in fast mode")):
df = self.watcher.analyze_traps(
plot=False, savefig=False, trap_burden=True, trap_burden_mode="fast", compute_full_bulk_reference=False
)
required = {"B_absDelta_ipr_ovlamvar", "spectral_excess_abs", "ipr_lift_excess_pos", "ov_lam_weighted_var", "trap_variance_burden_ipr", "trap_variance_burden"}
self.assertTrue(required.issubset(set(df.columns)))

def test_analyze_traps_full_mode_allows_full_diagnostics(self):
df = self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=True, trap_burden_mode="full")
self.assertIsInstance(df, pd.DataFrame)

def test_analyze_traps_fast_mode_does_not_call_collect_trap_artifacts(self):
with patch.object(remove_traps_ops, "collect_trap_artifacts", side_effect=AssertionError("should not be called")):
df = self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=True, trap_burden_mode="fast")
self.assertIsInstance(df, pd.DataFrame)

def test_no_trap_fft_api_or_columns(self):
with self.assertRaises(TypeError):
self.watcher.analyze_traps(plot=False, savefig=False, trap_fft=True)
Expand Down
7 changes: 4 additions & 3 deletions weightwatcher/remove_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ def apply_trap_mp_fit(ww, ww_layer, params=None):
return ww_layer


def identify_trap_mode_indices(ww, ww_layer):
W = ww_layer.Wmats[0]
_, svals, _ = svd_full(W)
def identify_trap_mode_indices(ww, ww_layer, svals=None):
if svals is None:
W = ww_layer.Wmats[0]
_, svals, _ = svd_full(W)
evals_desc = svals * svals

Q = ww_layer.N / ww_layer.M
Expand Down
18 changes: 18 additions & 0 deletions weightwatcher/trap_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def analyze_traps(
rng=None,
trap_burden=False,
trap_burden_variant="top5",
trap_burden_mode="fast",
compute_original_basis=None,
compute_full_bulk_reference=None,
bulk_mode_sample=10,
compute_original_trap_svd=None,
top_sector_l=1,
):
"""Externalized implementation for WeightWatcher.analyze_traps()."""
Expand Down Expand Up @@ -78,6 +83,19 @@ def analyze_traps(
params["rng"] = remove_traps_ops._normalize_trap_rng(rng=rng)
params["trap_burden"] = bool(trap_burden)
params["trap_burden_variant"] = trap_burden_variant
params["trap_burden_mode"] = trap_burden_mode
if compute_original_basis is None:
compute_original_basis = (trap_burden_mode == "full")
if compute_full_bulk_reference is None:
compute_full_bulk_reference = (trap_burden_mode == "full")
if compute_original_trap_svd is None:
compute_original_trap_svd = (trap_burden_mode == "full")
params["compute_original_basis"] = bool(compute_original_basis)
params["compute_full_bulk_reference"] = bool(compute_full_bulk_reference)
if bulk_mode_sample is None and trap_burden_mode == "fast":
bulk_mode_sample = 10
params["bulk_mode_sample"] = bulk_mode_sample
params["compute_original_trap_svd"] = bool(compute_original_trap_svd)
params["top_sector_l"] = int(top_sector_l)

wwcore.logger.debug("params {}".format(params))
Expand Down
198 changes: 155 additions & 43 deletions weightwatcher/weightwatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3709,6 +3709,11 @@ def analyze_traps(self, model=None, layers=[],
rng=None,
trap_burden=False,
trap_burden_variant="top5",
trap_burden_mode="fast",
compute_original_basis=None,
compute_full_bulk_reference=None,
bulk_mode_sample=10,
compute_original_trap_svd=None,
top_sector_l=1):
"""Analyze randomized correlation traps and return one row per trap.

Expand All @@ -3728,6 +3733,9 @@ def analyze_traps(self, model=None, layers=[],
rng : None, int, or numpy.random.RandomState
Optional random source used for reversible trap permutations.
Passing the same seed/object makes trap detection reproducible across runs.
trap_burden_mode : {"fast","full"}
"fast" is optimized for ablation sweeps and avoids expensive original-basis
diagnostics and full bulk reconstruction loops. "full" preserves richer diagnostics.
"""

from . import trap_analysis
Expand Down Expand Up @@ -3756,6 +3764,11 @@ def analyze_traps(self, model=None, layers=[],
rng=rng,
trap_burden=trap_burden,
trap_burden_variant=trap_burden_variant,
trap_burden_mode=trap_burden_mode,
compute_original_basis=compute_original_basis,
compute_full_bulk_reference=compute_full_bulk_reference,
bulk_mode_sample=bulk_mode_sample,
compute_original_trap_svd=compute_original_trap_svd,
top_sector_l=top_sector_l,
)

Expand Down Expand Up @@ -3788,6 +3801,7 @@ def _trap_result_columns(self):
"ov_rank_mean", "ov_rank_std",
"ov_hhi", "ov_participation", "ov_entropy", "ov_max",
"trap_variance_burden_ipr", "trap_variance_burden_top5", "trap_variance_burden",
"B_absDelta_ipr_ovlamvar", "B_absDelta_logtop5_ovlamvar", "B_evalsq_logtop5_ovrank", "B_old_pr359_paper",
"top_sector_l", "top_sector_l_effective",
"trap_delta", "trap_q", "trap_diffuseness",
"trap_q_uniform", "trap_diffuseness_uniform",
Expand All @@ -3805,13 +3819,23 @@ def apply_analyze_traps(self, ww_layer, params=None):
ww_layer.layer_id, ww_layer.name, len(ww_layer.Wmats))
return []

self.apply_esd(ww_layer, params)
original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params)
if params.get("compute_original_basis", False):
self.apply_esd(ww_layer, params)
original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params)
else:
original_basis_cache = None

self.apply_permute_W(ww_layer, params)
self.apply_trap_mp_fit(ww_layer, params=params)
trap_mode_indices = self.identify_trap_mode_indices(ww_layer, params=params)
bulk_stats = self.compute_bulk_trap_reference_metrics(ww_layer, trap_mode_indices, params=params)
W_perm = ww_layer.Wmats[0].astype(float)
U_perm, S_perm, Vh_perm = svd_full(W_perm, method=params[SVD_METHOD])
trap_mode_indices = remove_traps_ops.identify_trap_mode_indices(self, ww_layer, svals=S_perm)
if params.get("compute_full_bulk_reference", False):
bulk_stats = self.compute_bulk_trap_reference_metrics(ww_layer, trap_mode_indices, params=params)
else:
bulk_stats = self.compute_fast_bulk_trap_reference_metrics(
ww_layer, U_perm, S_perm, Vh_perm, trap_mode_indices, params=params
)

trap_rows = []
for trap_index, mode_index in enumerate(trap_mode_indices):
Expand All @@ -3822,6 +3846,7 @@ def apply_analyze_traps(self, ww_layer, params=None):
params=params,
trap_index=trap_index,
bulk_stats=bulk_stats,
precomputed_svd=(U_perm, S_perm, Vh_perm),
)
trap_row.update(bulk_stats)
trap_rows.append(trap_row)
Expand All @@ -3833,6 +3858,68 @@ def apply_analyze_traps(self, ww_layer, params=None):
self.apply_unpermute_W(ww_layer, params)
return trap_rows

def compute_fast_bulk_trap_reference_metrics(self, ww_layer, U_perm, S_perm, Vh_perm, trap_mode_indices, params=None):
if params is None:
params = DEFAULT_PARAMS.copy()
n_modes = len(S_perm)
trap_set = set(int(i) for i in trap_mode_indices)
bulk_indices = [i for i in range(n_modes) if i not in trap_set]
if len(bulk_indices) == 0:
return {
"bulk_mode_count": 0,
"bulk_localization_mean": np.nan,
"bulk_localization_std": np.nan,
"bulk_top_5_mass_mean": np.nan,
"bulk_top_5_mass_std": np.nan,
"bulk_top_10_mass_mean": np.nan,
"bulk_top_10_mass_std": np.nan,
"bulk_ipr_mean": np.nan,
"bulk_ipr_std": np.nan,
"bulk_mode_sample_used": 0,
}
sample_n = params.get("bulk_mode_sample", None)
if sample_n is None:
sample_n = 10
if sample_n is not None and len(bulk_indices) > int(sample_n):
sample_n = int(sample_n)
seed = params.get("seed", 12345) if isinstance(params, dict) else 12345
rs = np.random.RandomState(int(seed) if seed is not None else 12345)
edges = np.linspace(0, len(bulk_indices), sample_n + 1, dtype=int)
sampled = []
for i in range(sample_n):
lo, hi = edges[i], edges[i + 1]
if hi <= lo:
continue
sampled.append(bulk_indices[rs.randint(lo, hi)])
bulk_indices = sampled if len(sampled) > 0 else bulk_indices[:sample_n]
bulk_ipr_vals = []
bulk_loc_vals = []
top5_vals = []
top10_vals = []
for mode_idx in bulk_indices:
u = U_perm[:, mode_idx]
v = Vh_perm[mode_idx, :]
u_ipr = float(np.sum(np.asarray(u, dtype=float) ** 4))
v_ipr = float(np.sum(np.asarray(v, dtype=float) ** 4))
bulk_ipr_vals.append(0.5 * (u_ipr + v_ipr))
u_metrics = self._trap_vector_metrics(u)
v_metrics = self._trap_vector_metrics(v)
bulk_loc_vals.append(0.5 * (float(u_metrics.get("localization_ratio", np.nan)) + float(v_metrics.get("localization_ratio", np.nan))))
top5_vals.append(0.5 * (float(u_metrics.get("top5_mass", np.nan)) + float(v_metrics.get("top5_mass", np.nan))))
top10_vals.append(0.5 * (float(u_metrics.get("top10_mass", np.nan)) + float(v_metrics.get("top10_mass", np.nan))))
return {
"bulk_mode_count": int(len(bulk_indices)),
"bulk_localization_mean": float(np.nanmean(bulk_loc_vals)),
"bulk_localization_std": float(np.nanstd(bulk_loc_vals)),
"bulk_top_5_mass_mean": float(np.nanmean(top5_vals)),
"bulk_top_5_mass_std": float(np.nanstd(top5_vals)),
"bulk_top_10_mass_mean": float(np.nanmean(top10_vals)),
"bulk_top_10_mass_std": float(np.nanstd(top10_vals)),
"bulk_ipr_mean": float(np.nanmean(bulk_ipr_vals)),
"bulk_ipr_std": float(np.nanstd(bulk_ipr_vals)),
"bulk_mode_sample_used": int(len(bulk_indices)),
}

def compute_trap_delta(self, eval_perm, mp_bulk_max):
if not (np.isfinite(eval_perm) and np.isfinite(mp_bulk_max)) or mp_bulk_max <= 0:
return np.nan
Expand Down Expand Up @@ -4004,7 +4091,7 @@ def compute_original_basis_for_traps(self, ww_layer, params=None):
}


def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=None, params=None, trap_index=0, bulk_stats=None):
def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=None, params=None, trap_index=0, bulk_stats=None, precomputed_svd=None):
if params is None: params = DEFAULT_PARAMS.copy()
if original_basis_cache is None:
original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params)
Expand All @@ -4014,7 +4101,10 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
W_perm = ww_layer.Wmats[0].astype(float)
p_ids = ww_layer.permute_ids[0]

U_perm, S_perm, Vh_perm = svd_full(W_perm, method=params[SVD_METHOD])
if precomputed_svd is None:
U_perm, S_perm, Vh_perm = svd_full(W_perm, method=params[SVD_METHOD])
else:
U_perm, S_perm, Vh_perm = precomputed_svd
V_perm = Vh_perm.T

sigma_perm = float(S_perm[trap_mode_index])
Expand All @@ -4027,12 +4117,19 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
top_5_mass = self._top_percent_abs_mass(T_orig, 5.0)
top_10_mass = self._top_percent_abs_mass(T_orig, 10.0)

Ut, St, Vht = svd_full(T_orig, method=params[SVD_METHOD])
u_trap = Ut[:, 0]
v_trap = Vht.T[:, 0]
compute_original_trap_svd = bool(params.get("compute_original_trap_svd", False))
if compute_original_trap_svd:
Ut, St, Vht = svd_full(T_orig, method=params[SVD_METHOD])
u_trap = Ut[:, 0]
v_trap = Vht.T[:, 0]
sigma_trap_top = float(St[0])
else:
u_trap = u_perm
v_trap = v_perm
sigma_trap_top = float(sigma_perm)

U0 = original_basis_cache["U0"]
V0 = original_basis_cache["V0"]
U0 = original_basis_cache["U0"] if original_basis_cache is not None else U_perm
V0 = original_basis_cache["V0"] if original_basis_cache is not None else Vh_perm.T

left_overlaps = np.abs(U0.T @ u_trap) ** 2
right_overlaps = np.abs(V0.T @ v_trap) ** 2
Expand All @@ -4048,8 +4145,7 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
left_overlap_ipr = float(np.sum(left_overlaps ** 2))
right_overlap_ipr = float(np.sum(right_overlaps ** 2))

st_sq = St * St
rank1_mass_after_unpermute = float(st_sq[0] / (np.sum(st_sq) + eps))
rank1_mass_after_unpermute = 1.0

u_metrics = self._trap_vector_metrics(u_trap)
v_metrics = self._trap_vector_metrics(v_trap)
Expand Down Expand Up @@ -4097,7 +4193,7 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
"sigma_mp": float(ww_layer.sigma_mp),
"num_spikes": int(ww_layer.num_spikes),
"rank1_mass_after_unpermute": rank1_mass_after_unpermute,
"sigma_trap_top": float(St[0]),
"sigma_trap_top": sigma_trap_top,
"left_top_mode": left_top_mode,
"right_top_mode": right_top_mode,
"left_top_mass": left_top_mass,
Expand Down Expand Up @@ -4148,36 +4244,46 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
top_5_lift_excess_pos = float(max(top_5_lift - 1.0, 0.0))
log1p_top_5_lift = float(np.log1p(top_5_lift))

W_orig = original_basis_cache["W_true"]
X = W_orig.T @ W_orig
evals_x, evecs_x = np.linalg.eigh(X)
order = np.argsort(evals_x)[::-1]
lam = evals_x[order]
phi = evecs_x[:, order]
v_norm = np.linalg.norm(v_trap)
if v_norm > 0:
v_unit = v_trap / v_norm
p = np.abs(phi.T @ v_unit) ** 2
p_sum = float(np.sum(p))
if p_sum > 0:
p = p / p_sum
ov_lam_weighted_mean = float(np.sum(p * lam))
ov_lam_weighted_var = float(np.sum(p * (lam - ov_lam_weighted_mean) ** 2))
ranks = np.arange(1, len(lam) + 1, dtype=float)
ov_rank_mean = float(np.sum(p * ranks))
ov_rank_std = float(np.sqrt(np.sum(p * (ranks - ov_rank_mean) ** 2)))
ov_hhi = float(np.sum(p ** 2))
ov_participation = float(1.0 / ov_hhi) if ov_hhi > 0 else np.nan
p_safe = p[p > 0]
ov_entropy = float(-np.sum(p_safe * np.log(p_safe))) if p_safe.size else np.nan
ov_max = float(np.max(p)) if p.size else np.nan
if original_basis_cache is not None:
W_orig = original_basis_cache["W_true"]
X = W_orig.T @ W_orig
evals_x, evecs_x = np.linalg.eigh(X)
order = np.argsort(evals_x)[::-1]
lam = evals_x[order]
phi = evecs_x[:, order]
v_norm = np.linalg.norm(v_trap)
if v_norm > 0:
v_unit = v_trap / v_norm
p = np.abs(phi.T @ v_unit) ** 2
p_sum = float(np.sum(p))
if p_sum > 0:
p = p / p_sum
ov_lam_weighted_mean = float(np.sum(p * lam))
ov_lam_weighted_var = float(np.sum(p * (lam - ov_lam_weighted_mean) ** 2))
ranks = np.arange(1, len(lam) + 1, dtype=float)
ov_rank_mean = float(np.sum(p * ranks))
ov_rank_std = float(np.sqrt(np.sum(p * (ranks - ov_rank_mean) ** 2)))
ov_hhi = float(np.sum(p ** 2))
ov_participation = float(1.0 / ov_hhi) if ov_hhi > 0 else np.nan
p_safe = p[p > 0]
ov_entropy = float(-np.sum(p_safe * np.log(p_safe))) if p_safe.size else np.nan
ov_max = float(np.max(p)) if p.size else np.nan
else:
ov_lam_weighted_mean = np.nan
ov_lam_weighted_var = np.nan
ov_rank_mean = np.nan
ov_rank_std = np.nan
ov_hhi = np.nan
ov_participation = np.nan
ov_entropy = np.nan
ov_max = np.nan
else:
ov_lam_weighted_mean = np.nan
ov_lam_weighted_var = np.nan
ov_rank_mean = np.nan
ov_rank_std = np.nan
ov_hhi = np.nan
ov_participation = np.nan
ov_lam_weighted_mean = float(np.nanmean(S_perm * S_perm))
ov_lam_weighted_var = float(np.nanvar(S_perm * S_perm))
ov_rank_mean = float((trap_mode_index + 1))
ov_rank_std = 0.0
ov_hhi = trap_ipr if np.isfinite(trap_ipr) else np.nan
ov_participation = float(1.0 / ov_hhi) if np.isfinite(ov_hhi) and ov_hhi > 0 else np.nan
ov_entropy = np.nan
ov_max = np.nan

Expand Down Expand Up @@ -4224,6 +4330,12 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
"trap_variance_burden_ipr": trap_variance_burden_ipr,
"trap_variance_burden_top5": trap_variance_burden_top5,
"trap_variance_burden": trap_variance_burden,
"B_absDelta_ipr_ovlamvar": trap_variance_burden_ipr,
"B_absDelta_logtop5_ovlamvar": float(spectral_excess_abs * log1p_top_5_lift * ov_lam_weighted_var) if np.all(
np.isfinite([spectral_excess_abs, log1p_top_5_lift, ov_lam_weighted_var])
) else np.nan,
"B_evalsq_logtop5_ovrank": trap_variance_burden_top5,
"B_old_pr359_paper": trap_variance_burden_old,
})

for k, v in u_metrics.items():
Expand Down