From 3ac784de4f9ba08efce1b507ae763fc9c9c120bb Mon Sep 17 00:00:00 2001 From: Charles Martin Date: Tue, 28 Apr 2026 22:51:10 -0700 Subject: [PATCH] Add optional trap burden metrics and variants to analyze_traps --- README.md | 18 +++++ tests/test_analyze_traps.py | 39 ++++++++++ weightwatcher/trap_analysis.py | 4 ++ weightwatcher/weightwatcher.py | 128 ++++++++++++++++++++++++++++++++- 4 files changed, 187 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 248f568..38f4e4b 100644 --- a/README.md +++ b/README.md @@ -291,6 +291,24 @@ trap_df = watcher.analyze_traps(layers=[3, 5], plot=True, savefig="trap_images") For a complete walkthrough (including `remove_traps`), see: [Correlation Trap Workflow (`analyze_traps` + `remove_traps`)](./docs_trap_features.md) +#### Trap variance burden + +`analyze_traps` can optionally compute trap burden metrics with `trap_burden=True`; the canonical `trap_variance_burden` column is selected with `trap_burden_variant` (`"ipr"`, `"top5"` — the default — or `"both"`). + +IPR version: + +`trap_variance_burden_ipr = spectral_excess_abs * ipr_lift_excess_pos * ov_lam_weighted_var` + +Top5 version: + +`trap_variance_burden_top5 = spectral_excess_abs * log1p_top_5_lift * ov_rank_mean` + +- `spectral_excess_abs` measures trap strength above MP bulk. +- `ipr_lift_excess_pos` measures localization relative to bulk. +- `top_5_lift` measures concentration of the trap matrix relative to bulk modes. +- `ov_lam_weighted_var` measures how broadly the trap overlaps eigenvalue scales of `X = W^T W`. +- `ov_rank_mean` measures where the trap lives in the eigenbasis of `X`. + Fig (a) is well trained; Fig (b) may be over-fit. That orange spike on the far right is the tell-tale clue; it's caled a **Correlation Trap**. 
diff --git a/tests/test_analyze_traps.py b/tests/test_analyze_traps.py index 4ce08f1..7ab2b03 100644 --- a/tests/test_analyze_traps.py +++ b/tests/test_analyze_traps.py @@ -137,6 +137,45 @@ def test_order_invariant_stats_are_finite(self): ]: self.assertTrue(np.isfinite(row[col])) + def test_trap_burden_backward_compat(self): + df = self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=False) + self.assertIsInstance(df, pd.DataFrame) + self.assertIn("top_5_mass", df.columns) + self.assertIn("bulk_top_5_mass_mean", df.columns) + + def test_trap_burden_columns_appear(self): + df = self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=True) + required = { + "spectral_excess_abs", "spectral_excess_rel", "trap_ipr", "bulk_ipr_mean", + "ipr_lift_excess_pos", "top_5_lift", "log1p_top_5_lift", "ov_lam_weighted_var", + "ov_rank_mean", "trap_variance_burden_ipr", "trap_variance_burden_top5", + "trap_variance_burden", + } + self.assertTrue(required.issubset(set(df.columns))) + + def test_trap_burden_finite_values_if_traps(self): + df = self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=True) + if len(df) == 0: + self.skipTest("No traps detected in this environment") + cols = ["spectral_excess_abs", "ov_lam_weighted_var", "ov_rank_mean", "trap_variance_burden"] + finite_mask = np.isfinite(df[cols]).all(axis=1) + self.assertTrue(bool(finite_mask.any())) + + def test_trap_burden_variant_selection(self): + df_ipr = self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=True, trap_burden_variant="ipr") + mask_ipr = np.isfinite(df_ipr["trap_variance_burden"]) & np.isfinite(df_ipr["trap_variance_burden_ipr"]) + if mask_ipr.any(): + self.assertTrue(np.allclose(df_ipr.loc[mask_ipr, "trap_variance_burden"], df_ipr.loc[mask_ipr, "trap_variance_burden_ipr"])) + + df_top5 = self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=True, trap_burden_variant="top5") + mask_top5 = np.isfinite(df_top5["trap_variance_burden"]) & 
np.isfinite(df_top5["trap_variance_burden_top5"]) + if mask_top5.any(): + self.assertTrue(np.allclose(df_top5.loc[mask_top5, "trap_variance_burden"], df_top5.loc[mask_top5, "trap_variance_burden_top5"])) + + def test_trap_burden_variant_invalid(self): + with self.assertRaises(ValueError): + self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=True, trap_burden_variant="bad") + if __name__ == "__main__": unittest.main() diff --git a/weightwatcher/trap_analysis.py b/weightwatcher/trap_analysis.py index c35bdd6..5ff8ebb 100644 --- a/weightwatcher/trap_analysis.py +++ b/weightwatcher/trap_analysis.py @@ -29,6 +29,8 @@ def analyze_traps( base_model=None, peft=wwcore.DEFAULT_PEFT, rng=None, + trap_burden=False, + trap_burden_variant="top5", ): """Externalized implementation for WeightWatcher.analyze_traps().""" if layers is None: @@ -73,6 +75,8 @@ def analyze_traps( params[wwcore.PEFT] = peft params[wwcore.INVERSE] = False params["rng"] = remove_traps_ops._normalize_trap_rng(rng=rng) + params["trap_burden"] = bool(trap_burden) + params["trap_burden_variant"] = trap_burden_variant wwcore.logger.debug("params {}".format(params)) if not watcher.valid_params(params): diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 0da873b..b935366 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -3705,7 +3705,9 @@ def analyze_traps(self, model=None, layers=[], start_ids=DEFAULT_START_ID, base_model=None, peft=DEFAULT_PEFT, - rng=None): + rng=None, + trap_burden=False, + trap_burden_variant="top5"): """Analyze randomized correlation traps and return one row per trap. 
This method follows the randomized/permuted trap workflow: @@ -3750,6 +3752,8 @@ def analyze_traps(self, model=None, layers=[], base_model=base_model, peft=peft, rng=rng, + trap_burden=trap_burden, + trap_burden_variant=trap_burden_variant, ) def _trap_result_columns(self): @@ -3768,6 +3772,13 @@ def _trap_result_columns(self): "v_top1_mass", "v_top5_mass", "v_top10_mass", "v_squared_amp_entropy", "v_stable_rank_surrogate", "trap_balance_ratio", "trap_detected", "trap_eval_minus_bulk", "trap_diffuseness_score", "trap_risk_score", "trap_assessment", + "top_5_mass", "bulk_top_5_mass_mean", "bulk_top_5_mass_std", + "spectral_excess_abs", "spectral_excess_rel", "trap_ipr", + "bulk_ipr_mean", "bulk_ipr_std", "ipr_lift", "ipr_lift_excess_pos", + "top_5_lift", "top_5_lift_excess_pos", "log1p_top_5_lift", + "ov_lam_weighted_mean", "ov_lam_weighted_var", "ov_rank_mean", "ov_rank_std", + "ov_hhi", "ov_participation", "ov_entropy", "ov_max", + "trap_variance_burden_ipr", "trap_variance_burden_top5", "trap_variance_burden", ] @@ -3785,6 +3796,9 @@ def apply_analyze_traps(self, ww_layer, params=None): self.apply_permute_W(ww_layer, params) self.apply_trap_mp_fit(ww_layer, params=params) trap_mode_indices = self.identify_trap_mode_indices(ww_layer, params=params) + bulk_reference_metrics = self.compute_bulk_trap_reference_metrics( + ww_layer, trap_mode_indices=trap_mode_indices, params=params + ) trap_rows = [] for trap_index, mode_index in enumerate(trap_mode_indices): @@ -3794,12 +3808,40 @@ def apply_analyze_traps(self, ww_layer, params=None): original_basis_cache=original_basis_cache, params=params, trap_index=trap_index, + bulk_reference_metrics=bulk_reference_metrics, ) trap_rows.append(trap_row) self.apply_unpermute_W(ww_layer, params) return trap_rows + def compute_bulk_trap_reference_metrics(self, ww_layer, trap_mode_indices=None, params=None): + if params is None: params = DEFAULT_PARAMS.copy() + if trap_mode_indices is None: + trap_mode_indices = [] + W_perm = 
ww_layer.Wmats[0].astype(float) + p_ids = ww_layer.permute_ids[0] + U_perm, S_perm, Vh_perm = svd_full(W_perm, method=params[SVD_METHOD]) + V_perm = Vh_perm.T + trap_set = set(int(i) for i in trap_mode_indices) + total_modes = len(S_perm) + bulk_modes = [k for k in range(total_modes) if k not in trap_set] + top5_vals, ipr_vals = [], [] + for k in bulk_modes: + sigma_k = float(S_perm[k]) + T_perm = sigma_k * np.outer(U_perm[:, k], V_perm[:, k]) + T_orig = unpermute_matrix(T_perm, p_ids) + top5_vals.append(float(self._top_mass_fraction(T_orig, frac=0.05))) + Ut, _, Vht = svd_full(T_orig, method=params[SVD_METHOD]) + u = Ut[:, 0]; v = Vht.T[:, 0] + ipr_vals.append(float(0.5 * (np.sum(u ** 4) + np.sum(v ** 4)))) + return { + "bulk_top_5_mass_mean": float(np.mean(top5_vals)) if len(top5_vals) else np.nan, + "bulk_top_5_mass_std": float(np.std(top5_vals)) if len(top5_vals) else np.nan, + "bulk_ipr_mean": float(np.mean(ipr_vals)) if len(ipr_vals) else np.nan, + "bulk_ipr_std": float(np.std(ipr_vals)) if len(ipr_vals) else np.nan, + } + def apply_trap_mp_fit(self, ww_layer, params=None): if params is None: params = DEFAULT_PARAMS.copy() @@ -3831,7 +3873,7 @@ def compute_original_basis_for_traps(self, ww_layer, params=None): } - def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=None, params=None, trap_index=0): + def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=None, params=None, trap_index=0, bulk_reference_metrics=None): if params is None: params = DEFAULT_PARAMS.copy() if original_basis_cache is None: original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params) @@ -3923,6 +3965,43 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No trap_result["u_effective_support"] / (trap_result["v_effective_support"] + 1e-12) ) trap_result.update(self.assess_trap_diffuseness(trap_result)) + if bulk_reference_metrics is None: + bulk_reference_metrics = {} + 
trap_result.update(bulk_reference_metrics) + trap_result["top_5_mass"] = float(self._top_mass_fraction(T_orig, frac=0.05)) + + if params.get("trap_burden", False): + variant = params.get("trap_burden_variant", "top5") + if variant not in {"ipr", "top5", "both"}: + raise ValueError("trap_burden_variant must be one of: 'ipr', 'top5', 'both'") + mp_bulk_max = float(trap_result.get("mp_bulk_max", np.nan)) + spectral_excess_abs = float(max(eval_perm - mp_bulk_max, 0.0)) if np.isfinite(mp_bulk_max) else np.nan + spectral_excess_rel = float(spectral_excess_abs / mp_bulk_max) if np.isfinite(mp_bulk_max) and mp_bulk_max != 0 else np.nan + trap_ipr = float(0.5 * (np.sum(u_trap ** 4) + np.sum(v_trap ** 4))) + bulk_ipr_mean = float(trap_result.get("bulk_ipr_mean", np.nan)) + ipr_lift = float(trap_ipr / bulk_ipr_mean) if np.isfinite(bulk_ipr_mean) and bulk_ipr_mean != 0 else np.nan + ipr_lift_excess_pos = float(max(ipr_lift - 1.0, 0.0)) if np.isfinite(ipr_lift) else np.nan + bulk_top_5_mass_mean = float(trap_result.get("bulk_top_5_mass_mean", np.nan)) + top_5_lift = float(trap_result["top_5_mass"] / bulk_top_5_mass_mean) if np.isfinite(bulk_top_5_mass_mean) and bulk_top_5_mass_mean != 0 else np.nan + top_5_lift_excess_pos = float(max(top_5_lift - 1.0, 0.0)) if np.isfinite(top_5_lift) else np.nan + log1p_top_5_lift = float(np.log1p(top_5_lift)) if np.isfinite(top_5_lift) else np.nan + ov = self._compute_overlap_metrics(original_basis_cache["W_true"], v_trap) + burden_ipr = float(spectral_excess_abs * ipr_lift_excess_pos * ov["ov_lam_weighted_var"]) if np.isfinite(spectral_excess_abs) and np.isfinite(ipr_lift_excess_pos) and np.isfinite(ov["ov_lam_weighted_var"]) else np.nan + burden_top5 = float(spectral_excess_abs * log1p_top_5_lift * ov["ov_rank_mean"]) if np.isfinite(spectral_excess_abs) and np.isfinite(log1p_top_5_lift) and np.isfinite(ov["ov_rank_mean"]) else np.nan + trap_result.update({ + "spectral_excess_abs": spectral_excess_abs, + "spectral_excess_rel": 
spectral_excess_rel, + "trap_ipr": trap_ipr, + "ipr_lift": ipr_lift, + "ipr_lift_excess_pos": ipr_lift_excess_pos, + "top_5_lift": top_5_lift, + "top_5_lift_excess_pos": top_5_lift_excess_pos, + "log1p_top_5_lift": log1p_top_5_lift, + "trap_variance_burden_ipr": burden_ipr, + "trap_variance_burden_top5": burden_top5, + }) + trap_result.update(ov) + trap_result["trap_variance_burden"] = burden_ipr if variant == "ipr" else burden_top5 trap_result["left_overlaps"] = left_overlaps trap_result["right_overlaps"] = right_overlaps @@ -3997,6 +4076,51 @@ def assess_trap_diffuseness(self, trap_result): "trap_assessment": assessment, } + def _top_mass_fraction(self, matrix, frac=0.05): + arr = np.abs(np.asarray(matrix, dtype=float).ravel()) + if arr.size == 0: + return np.nan + total = np.sum(arr) + if not np.isfinite(total) or total <= 0: + return np.nan + k = max(1, int(np.ceil(frac * arr.size))) + idx = np.argpartition(arr, -k)[-k:] + return float(np.sum(arr[idx]) / total) + + def _compute_overlap_metrics(self, W_orig, v_trap): + X = W_orig.T @ W_orig + lam, phi = np.linalg.eigh(X) + order = np.argsort(lam)[::-1] + lam = lam[order] + phi = phi[:, order] + v = np.asarray(v_trap, dtype=float) + nv = np.linalg.norm(v) + if nv == 0 or not np.isfinite(nv): + return {k: np.nan for k in ["ov_lam_weighted_mean", "ov_lam_weighted_var", "ov_rank_mean", "ov_rank_std", "ov_hhi", "ov_participation", "ov_entropy", "ov_max"]} + v = v / nv + p = np.abs(phi.T @ v) ** 2 + ps = np.sum(p) + if ps <= 0 or not np.isfinite(ps): + return {k: np.nan for k in ["ov_lam_weighted_mean", "ov_lam_weighted_var", "ov_rank_mean", "ov_rank_std", "ov_hhi", "ov_participation", "ov_entropy", "ov_max"]} + p = p / ps + ov_mean = float(np.sum(p * lam)) + ov_var = float(np.sum(p * (lam - ov_mean) ** 2)) + ranks = np.arange(1, len(lam) + 1, dtype=float) + rk_mean = float(np.sum(p * ranks)) + rk_std = float(np.sqrt(np.sum(p * (ranks - rk_mean) ** 2))) + eps = 1e-12 + hhi = float(np.sum(p ** 2)) + return { + 
"ov_lam_weighted_mean": ov_mean, + "ov_lam_weighted_var": ov_var, + "ov_rank_mean": rk_mean, + "ov_rank_std": rk_std, + "ov_hhi": hhi, + "ov_participation": float(1.0 / hhi) if hhi > 0 else np.nan, + "ov_entropy": float(-np.sum(p * np.log(p + eps))), + "ov_max": float(np.max(p)), + } + def plot_trap_analysis(self, ww_layer, trap_result, params=None): if params is None: params = DEFAULT_PARAMS.copy()