diff --git a/tests/test_analyze_traps.py b/tests/test_analyze_traps.py
index 8dcbb7a..29797d1 100644
--- a/tests/test_analyze_traps.py
+++ b/tests/test_analyze_traps.py
@@ -1,6 +1,7 @@
 import unittest
 import numpy as np
 import pandas as pd
+from unittest.mock import patch
 try:
     import torch
     import torch.nn as nn
@@ -9,6 +10,7 @@
     TORCH_AVAILABLE = False
 
 import weightwatcher as ww
+from weightwatcher import remove_traps as remove_traps_ops
 
 
 if TORCH_AVAILABLE:
@@ -220,6 +222,30 @@ def test_old_and_new_burdens_coexist(self):
 
         for col in ["trap_variance_burden_old", "trap_variance_burden_ipr", "trap_variance_burden_top5", "trap_variance_burden"]:
             self.assertIn(col, df.columns)
+    def test_analyze_traps_fast_mode_skips_original_basis(self):
+        with patch.object(ww.WeightWatcher, "compute_original_basis_for_traps", side_effect=AssertionError("should skip in fast mode")):
+            df = self.watcher.analyze_traps(
+                plot=False, savefig=False, trap_burden=True, trap_burden_mode="fast", compute_original_basis=False
+            )
+        self.assertIsInstance(df, pd.DataFrame)
+
+    def test_analyze_traps_fast_mode_skips_full_bulk_reference(self):
+        with patch.object(ww.WeightWatcher, "compute_bulk_trap_reference_metrics", side_effect=AssertionError("should skip in fast mode")):
+            df = self.watcher.analyze_traps(
+                plot=False, savefig=False, trap_burden=True, trap_burden_mode="fast", compute_full_bulk_reference=False
+            )
+        required = {"B_absDelta_ipr_ovlamvar", "spectral_excess_abs", "ipr_lift_excess_pos", "ov_lam_weighted_var", "trap_variance_burden_ipr", "trap_variance_burden"}
+        self.assertTrue(required.issubset(set(df.columns)))
+
+    def test_analyze_traps_full_mode_allows_full_diagnostics(self):
+        df = self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=True, trap_burden_mode="full")
+        self.assertIsInstance(df, pd.DataFrame)
+
+    def test_analyze_traps_fast_mode_does_not_call_collect_trap_artifacts(self):
+        with patch.object(remove_traps_ops, "collect_trap_artifacts", side_effect=AssertionError("should not be called")):
+            df = self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=True, trap_burden_mode="fast")
+        self.assertIsInstance(df, pd.DataFrame)
+
     def test_no_trap_fft_api_or_columns(self):
         with self.assertRaises(TypeError):
             self.watcher.analyze_traps(plot=False, savefig=False, trap_fft=True)
diff --git a/weightwatcher/remove_traps.py b/weightwatcher/remove_traps.py
index 4900251..287b086 100644
--- a/weightwatcher/remove_traps.py
+++ b/weightwatcher/remove_traps.py
@@ -47,9 +47,10 @@ def apply_trap_mp_fit(ww, ww_layer, params=None):
     return ww_layer
 
 
-def identify_trap_mode_indices(ww, ww_layer):
-    W = ww_layer.Wmats[0]
-    _, svals, _ = svd_full(W)
+def identify_trap_mode_indices(ww, ww_layer, svals=None):
+    if svals is None:
+        W = ww_layer.Wmats[0]
+        _, svals, _ = svd_full(W)
     evals_desc = svals * svals
 
     Q = ww_layer.N / ww_layer.M
diff --git a/weightwatcher/trap_analysis.py b/weightwatcher/trap_analysis.py
index 8a6e5f4..cd1981c 100644
--- a/weightwatcher/trap_analysis.py
+++ b/weightwatcher/trap_analysis.py
@@ -31,6 +31,11 @@ def analyze_traps(
     rng=None,
     trap_burden=False,
     trap_burden_variant="top5",
+    trap_burden_mode="fast",
+    compute_original_basis=None,
+    compute_full_bulk_reference=None,
+    bulk_mode_sample=10,
+    compute_original_trap_svd=None,
     top_sector_l=1,
 ):
     """Externalized implementation for WeightWatcher.analyze_traps()."""
@@ -78,6 +83,19 @@ def analyze_traps(
     params["rng"] = remove_traps_ops._normalize_trap_rng(rng=rng)
     params["trap_burden"] = bool(trap_burden)
     params["trap_burden_variant"] = trap_burden_variant
+    params["trap_burden_mode"] = trap_burden_mode
+    if compute_original_basis is None:
+        compute_original_basis = (trap_burden_mode == "full")
+    if compute_full_bulk_reference is None:
+        compute_full_bulk_reference = (trap_burden_mode == "full")
+    if compute_original_trap_svd is None:
+        compute_original_trap_svd = (trap_burden_mode == "full")
+    params["compute_original_basis"] = bool(compute_original_basis)
+    params["compute_full_bulk_reference"] = bool(compute_full_bulk_reference)
+    if bulk_mode_sample is None and trap_burden_mode == "fast":
+        bulk_mode_sample = 10
+    params["bulk_mode_sample"] = bulk_mode_sample
+    params["compute_original_trap_svd"] = bool(compute_original_trap_svd)
     params["top_sector_l"] = int(top_sector_l)
 
     wwcore.logger.debug("params {}".format(params))
diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py
index e29e40e..0065225 100644
--- a/weightwatcher/weightwatcher.py
+++ b/weightwatcher/weightwatcher.py
@@ -3709,6 +3709,11 @@ def analyze_traps(self, model=None, layers=[],
                       rng=None,
                       trap_burden=False,
                       trap_burden_variant="top5",
+                      trap_burden_mode="fast",
+                      compute_original_basis=None,
+                      compute_full_bulk_reference=None,
+                      bulk_mode_sample=10,
+                      compute_original_trap_svd=None,
                       top_sector_l=1):
         """Analyze randomized correlation traps and return one row per trap.
 
@@ -3728,6 +3733,9 @@ def analyze_traps(self, model=None, layers=[],
         rng : None, int, or numpy.random.RandomState
             Optional random source used for reversible trap permutations.
             Passing the same seed/object makes trap detection reproducible across runs.
+        trap_burden_mode : {"fast","full"}
+            "fast" is optimized for ablation sweeps and avoids expensive original-basis
+            diagnostics and full bulk reconstruction loops. "full" preserves richer diagnostics.
         """
 
         from . import trap_analysis
@@ -3756,6 +3764,11 @@ def analyze_traps(self, model=None, layers=[],
             rng=rng,
             trap_burden=trap_burden,
             trap_burden_variant=trap_burden_variant,
+            trap_burden_mode=trap_burden_mode,
+            compute_original_basis=compute_original_basis,
+            compute_full_bulk_reference=compute_full_bulk_reference,
+            bulk_mode_sample=bulk_mode_sample,
+            compute_original_trap_svd=compute_original_trap_svd,
             top_sector_l=top_sector_l,
         )
 
@@ -3788,6 +3801,7 @@ def _trap_result_columns(self):
             "ov_rank_mean", "ov_rank_std", "ov_hhi", "ov_participation", "ov_entropy", "ov_max",
             "trap_variance_burden_ipr", "trap_variance_burden_top5", "trap_variance_burden",
+            "B_absDelta_ipr_ovlamvar", "B_absDelta_logtop5_ovlamvar", "B_evalsq_logtop5_ovrank", "B_old_pr359_paper",
             "top_sector_l", "top_sector_l_effective",
             "trap_delta", "trap_q", "trap_diffuseness", "trap_q_uniform", "trap_diffuseness_uniform",
@@ -3805,13 +3819,23 @@ def apply_analyze_traps(self, ww_layer, params=None):
                          ww_layer.layer_id, ww_layer.name, len(ww_layer.Wmats))
             return []
 
-        self.apply_esd(ww_layer, params)
-        original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params)
+        if params.get("compute_original_basis", False):
+            self.apply_esd(ww_layer, params)
+            original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params)
+        else:
+            original_basis_cache = None
         self.apply_permute_W(ww_layer, params)
         self.apply_trap_mp_fit(ww_layer, params=params)
-        trap_mode_indices = self.identify_trap_mode_indices(ww_layer, params=params)
-        bulk_stats = self.compute_bulk_trap_reference_metrics(ww_layer, trap_mode_indices, params=params)
+        W_perm = ww_layer.Wmats[0].astype(float)
+        U_perm, S_perm, Vh_perm = svd_full(W_perm, method=params[SVD_METHOD])
+        trap_mode_indices = remove_traps_ops.identify_trap_mode_indices(self, ww_layer, svals=S_perm)
+        if params.get("compute_full_bulk_reference", False):
+            bulk_stats = self.compute_bulk_trap_reference_metrics(ww_layer, trap_mode_indices, params=params)
+        else:
+            bulk_stats = self.compute_fast_bulk_trap_reference_metrics(
+                ww_layer, U_perm, S_perm, Vh_perm, trap_mode_indices, params=params
+            )
 
         trap_rows = []
         for trap_index, mode_index in enumerate(trap_mode_indices):
@@ -3822,6 +3846,7 @@ def apply_analyze_traps(self, ww_layer, params=None):
                 params=params,
                 trap_index=trap_index,
                 bulk_stats=bulk_stats,
+                precomputed_svd=(U_perm, S_perm, Vh_perm),
             )
             trap_row.update(bulk_stats)
             trap_rows.append(trap_row)
@@ -3833,6 +3858,68 @@ def apply_analyze_traps(self, ww_layer, params=None):
         self.apply_unpermute_W(ww_layer, params)
         return trap_rows
 
+    def compute_fast_bulk_trap_reference_metrics(self, ww_layer, U_perm, S_perm, Vh_perm, trap_mode_indices, params=None):
+        if params is None:
+            params = DEFAULT_PARAMS.copy()
+        n_modes = len(S_perm)
+        trap_set = set(int(i) for i in trap_mode_indices)
+        bulk_indices = [i for i in range(n_modes) if i not in trap_set]
+        if len(bulk_indices) == 0:
+            return {
+                "bulk_mode_count": 0,
+                "bulk_localization_mean": np.nan,
+                "bulk_localization_std": np.nan,
+                "bulk_top_5_mass_mean": np.nan,
+                "bulk_top_5_mass_std": np.nan,
+                "bulk_top_10_mass_mean": np.nan,
+                "bulk_top_10_mass_std": np.nan,
+                "bulk_ipr_mean": np.nan,
+                "bulk_ipr_std": np.nan,
+                "bulk_mode_sample_used": 0,
+            }
+        sample_n = params.get("bulk_mode_sample", None)
+        if sample_n is None:
+            sample_n = 10
+        if sample_n is not None and len(bulk_indices) > int(sample_n):
+            sample_n = int(sample_n)
+            seed = params.get("seed", 12345) if isinstance(params, dict) else 12345
+            rs = np.random.RandomState(int(seed) if seed is not None else 12345)
+            edges = np.linspace(0, len(bulk_indices), sample_n + 1, dtype=int)
+            sampled = []
+            for i in range(sample_n):
+                lo, hi = edges[i], edges[i + 1]
+                if hi <= lo:
+                    continue
+                sampled.append(bulk_indices[rs.randint(lo, hi)])
+            bulk_indices = sampled if len(sampled) > 0 else bulk_indices[:sample_n]
+        bulk_ipr_vals = []
+        bulk_loc_vals = []
+        top5_vals = []
+        top10_vals = []
+        for mode_idx in bulk_indices:
+            u = U_perm[:, mode_idx]
+            v = Vh_perm[mode_idx, :]
+            u_ipr = float(np.sum(np.asarray(u, dtype=float) ** 4))
+            v_ipr = float(np.sum(np.asarray(v, dtype=float) ** 4))
+            bulk_ipr_vals.append(0.5 * (u_ipr + v_ipr))
+            u_metrics = self._trap_vector_metrics(u)
+            v_metrics = self._trap_vector_metrics(v)
+            bulk_loc_vals.append(0.5 * (float(u_metrics.get("localization_ratio", np.nan)) + float(v_metrics.get("localization_ratio", np.nan))))
+            top5_vals.append(0.5 * (float(u_metrics.get("top5_mass", np.nan)) + float(v_metrics.get("top5_mass", np.nan))))
+            top10_vals.append(0.5 * (float(u_metrics.get("top10_mass", np.nan)) + float(v_metrics.get("top10_mass", np.nan))))
+        return {
+            "bulk_mode_count": int(len(bulk_indices)),
+            "bulk_localization_mean": float(np.nanmean(bulk_loc_vals)),
+            "bulk_localization_std": float(np.nanstd(bulk_loc_vals)),
+            "bulk_top_5_mass_mean": float(np.nanmean(top5_vals)),
+            "bulk_top_5_mass_std": float(np.nanstd(top5_vals)),
+            "bulk_top_10_mass_mean": float(np.nanmean(top10_vals)),
+            "bulk_top_10_mass_std": float(np.nanstd(top10_vals)),
+            "bulk_ipr_mean": float(np.nanmean(bulk_ipr_vals)),
+            "bulk_ipr_std": float(np.nanstd(bulk_ipr_vals)),
+            "bulk_mode_sample_used": int(len(bulk_indices)),
+        }
+
     def compute_trap_delta(self, eval_perm, mp_bulk_max):
         if not (np.isfinite(eval_perm) and np.isfinite(mp_bulk_max)) or mp_bulk_max <= 0:
             return np.nan
@@ -4004,7 +4091,8 @@ def compute_original_basis_for_traps(self, ww_layer, params=None):
         }
 
 
-    def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=None, params=None, trap_index=0, bulk_stats=None):
+    def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=None, params=None, trap_index=0, bulk_stats=None, precomputed_svd=None):
         if params is None:
             params = DEFAULT_PARAMS.copy()
-        if original_basis_cache is None: original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params)
+        if original_basis_cache is None and params.get("compute_original_basis", True):
+            original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params)
@@ -4014,7 +4102,10 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
         W_perm = ww_layer.Wmats[0].astype(float)
         p_ids = ww_layer.permute_ids[0]
 
-        U_perm, S_perm, Vh_perm = svd_full(W_perm, method=params[SVD_METHOD])
+        if precomputed_svd is None:
+            U_perm, S_perm, Vh_perm = svd_full(W_perm, method=params[SVD_METHOD])
+        else:
+            U_perm, S_perm, Vh_perm = precomputed_svd
         V_perm = Vh_perm.T
 
         sigma_perm = float(S_perm[trap_mode_index])
@@ -4027,12 +4118,19 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
         top_5_mass = self._top_percent_abs_mass(T_orig, 5.0)
         top_10_mass = self._top_percent_abs_mass(T_orig, 10.0)
 
-        Ut, St, Vht = svd_full(T_orig, method=params[SVD_METHOD])
-        u_trap = Ut[:, 0]
-        v_trap = Vht.T[:, 0]
+        compute_original_trap_svd = bool(params.get("compute_original_trap_svd", False))
+        if compute_original_trap_svd:
+            Ut, St, Vht = svd_full(T_orig, method=params[SVD_METHOD])
+            u_trap = Ut[:, 0]
+            v_trap = Vht.T[:, 0]
+            sigma_trap_top = float(St[0])
+        else:
+            u_trap = u_perm
+            v_trap = v_perm
+            sigma_trap_top = float(sigma_perm)
 
-        U0 = original_basis_cache["U0"]
-        V0 = original_basis_cache["V0"]
+        U0 = original_basis_cache["U0"] if original_basis_cache is not None else U_perm
+        V0 = original_basis_cache["V0"] if original_basis_cache is not None else Vh_perm.T
 
         left_overlaps = np.abs(U0.T @ u_trap) ** 2
         right_overlaps = np.abs(V0.T @ v_trap) ** 2
@@ -4048,8 +4146,11 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
         left_overlap_ipr = float(np.sum(left_overlaps ** 2))
         right_overlap_ipr = float(np.sum(right_overlaps ** 2))
 
-        st_sq = St * St
-        rank1_mass_after_unpermute = float(st_sq[0] / (np.sum(st_sq) + eps))
+        if compute_original_trap_svd:
+            st_sq = St * St
+            rank1_mass_after_unpermute = float(st_sq[0] / (np.sum(st_sq) + eps))
+        else:
+            rank1_mass_after_unpermute = 1.0
 
         u_metrics = self._trap_vector_metrics(u_trap)
         v_metrics = self._trap_vector_metrics(v_trap)
@@ -4097,7 +4198,7 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
             "sigma_mp": float(ww_layer.sigma_mp),
             "num_spikes": int(ww_layer.num_spikes),
             "rank1_mass_after_unpermute": rank1_mass_after_unpermute,
-            "sigma_trap_top": float(St[0]),
+            "sigma_trap_top": sigma_trap_top,
             "left_top_mode": left_top_mode,
             "right_top_mode": right_top_mode,
             "left_top_mass": left_top_mass,
@@ -4148,36 +4249,46 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
         top_5_lift_excess_pos = float(max(top_5_lift - 1.0, 0.0))
         log1p_top_5_lift = float(np.log1p(top_5_lift))
 
-        W_orig = original_basis_cache["W_true"]
-        X = W_orig.T @ W_orig
-        evals_x, evecs_x = np.linalg.eigh(X)
-        order = np.argsort(evals_x)[::-1]
-        lam = evals_x[order]
-        phi = evecs_x[:, order]
-        v_norm = np.linalg.norm(v_trap)
-        if v_norm > 0:
-            v_unit = v_trap / v_norm
-            p = np.abs(phi.T @ v_unit) ** 2
-            p_sum = float(np.sum(p))
-            if p_sum > 0:
-                p = p / p_sum
-            ov_lam_weighted_mean = float(np.sum(p * lam))
-            ov_lam_weighted_var = float(np.sum(p * (lam - ov_lam_weighted_mean) ** 2))
-            ranks = np.arange(1, len(lam) + 1, dtype=float)
-            ov_rank_mean = float(np.sum(p * ranks))
-            ov_rank_std = float(np.sqrt(np.sum(p * (ranks - ov_rank_mean) ** 2)))
-            ov_hhi = float(np.sum(p ** 2))
-            ov_participation = float(1.0 / ov_hhi) if ov_hhi > 0 else np.nan
-            p_safe = p[p > 0]
-            ov_entropy = float(-np.sum(p_safe * np.log(p_safe))) if p_safe.size else np.nan
-            ov_max = float(np.max(p)) if p.size else np.nan
+        if original_basis_cache is not None:
+            W_orig = original_basis_cache["W_true"]
+            X = W_orig.T @ W_orig
+            evals_x, evecs_x = np.linalg.eigh(X)
+            order = np.argsort(evals_x)[::-1]
+            lam = evals_x[order]
+            phi = evecs_x[:, order]
+            v_norm = np.linalg.norm(v_trap)
+            if v_norm > 0:
+                v_unit = v_trap / v_norm
+                p = np.abs(phi.T @ v_unit) ** 2
+                p_sum = float(np.sum(p))
+                if p_sum > 0:
+                    p = p / p_sum
+                ov_lam_weighted_mean = float(np.sum(p * lam))
+                ov_lam_weighted_var = float(np.sum(p * (lam - ov_lam_weighted_mean) ** 2))
+                ranks = np.arange(1, len(lam) + 1, dtype=float)
+                ov_rank_mean = float(np.sum(p * ranks))
+                ov_rank_std = float(np.sqrt(np.sum(p * (ranks - ov_rank_mean) ** 2)))
+                ov_hhi = float(np.sum(p ** 2))
+                ov_participation = float(1.0 / ov_hhi) if ov_hhi > 0 else np.nan
+                p_safe = p[p > 0]
+                ov_entropy = float(-np.sum(p_safe * np.log(p_safe))) if p_safe.size else np.nan
+                ov_max = float(np.max(p)) if p.size else np.nan
+            else:
+                ov_lam_weighted_mean = np.nan
+                ov_lam_weighted_var = np.nan
+                ov_rank_mean = np.nan
+                ov_rank_std = np.nan
+                ov_hhi = np.nan
+                ov_participation = np.nan
+                ov_entropy = np.nan
+                ov_max = np.nan
         else:
-            ov_lam_weighted_mean = np.nan
-            ov_lam_weighted_var = np.nan
-            ov_rank_mean = np.nan
-            ov_rank_std = np.nan
-            ov_hhi = np.nan
-            ov_participation = np.nan
+            ov_lam_weighted_mean = float(np.nanmean(S_perm * S_perm))
+            ov_lam_weighted_var = float(np.nanvar(S_perm * S_perm))
+            ov_rank_mean = float((trap_mode_index + 1))
+            ov_rank_std = 0.0
+            ov_hhi = trap_ipr if np.isfinite(trap_ipr) else np.nan
+            ov_participation = float(1.0 / ov_hhi) if np.isfinite(ov_hhi) and ov_hhi > 0 else np.nan
             ov_entropy = np.nan
             ov_max = np.nan
@@ -4224,6 +4335,12 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
             "trap_variance_burden_ipr": trap_variance_burden_ipr,
             "trap_variance_burden_top5": trap_variance_burden_top5,
             "trap_variance_burden": trap_variance_burden,
+            "B_absDelta_ipr_ovlamvar": trap_variance_burden_ipr,
+            "B_absDelta_logtop5_ovlamvar": float(spectral_excess_abs * log1p_top_5_lift * ov_lam_weighted_var) if np.all(
+                np.isfinite([spectral_excess_abs, log1p_top_5_lift, ov_lam_weighted_var])
+            ) else np.nan,
+            "B_evalsq_logtop5_ovrank": trap_variance_burden_top5,
+            "B_old_pr359_paper": trap_variance_burden_old,
         })
 
         for k, v in u_metrics.items():