Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,24 @@ trap_df = watcher.analyze_traps(layers=[3, 5], plot=True, savefig="trap_images")

For a complete walkthrough (including `remove_traps`), see: [Correlation Trap Workflow (`analyze_traps` + `remove_traps`)](./docs_trap_features.md)

#### Trap variance burden

`analyze_traps` can optionally compute trap burden metrics with `trap_burden=True`.

IPR version:

`trap_variance_burden_ipr = spectral_excess_abs * ipr_lift_excess_pos * ov_lam_weighted_var`

Top5 version:

`trap_variance_burden_top5 = spectral_excess_abs * log1p_top_5_lift * ov_rank_mean`

- `spectral_excess_abs` measures trap strength above MP bulk.
- `ipr_lift_excess_pos` measures localization relative to bulk.
- `top_5_lift` measures concentration of the trap matrix relative to bulk modes.
- `ov_lam_weighted_var` measures how broadly the trap overlaps eigenvalue scales of `X = W^T W`.
- `ov_rank_mean` measures where the trap lives in the eigenbasis of `X`.

Fig (a) is well trained; Fig (b) may be over-fit.

That orange spike on the far right is the tell-tale clue; it's called a **Correlation Trap**.
Expand Down
39 changes: 39 additions & 0 deletions tests/test_analyze_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,45 @@ def test_order_invariant_stats_are_finite(self):
]:
self.assertTrue(np.isfinite(row[col]))

def test_trap_burden_backward_compat(self):
    """With trap_burden=False the result is still a DataFrame carrying the legacy mass columns."""
    result = self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=False)
    self.assertIsInstance(result, pd.DataFrame)
    # Legacy columns must survive regardless of the new opt-in flag.
    for legacy_col in ("top_5_mass", "bulk_top_5_mass_mean"):
        self.assertIn(legacy_col, result.columns)
def test_trap_burden_columns_appear(self):
    """Enabling trap_burden=True must add every burden-related column."""
    result = self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=True)
    expected_cols = (
        "spectral_excess_abs", "spectral_excess_rel", "trap_ipr", "bulk_ipr_mean",
        "ipr_lift_excess_pos", "top_5_lift", "log1p_top_5_lift", "ov_lam_weighted_var",
        "ov_rank_mean", "trap_variance_burden_ipr", "trap_variance_burden_top5",
        "trap_variance_burden",
    )
    # Checking column-by-column gives a clearer failure message than issubset().
    for col in expected_cols:
        self.assertIn(col, result.columns)

def test_trap_burden_finite_values_if_traps(self):
    """If any traps are detected, at least one row must have fully finite burden metrics."""
    result = self.watcher.analyze_traps(plot=False, savefig=False, trap_burden=True)
    if result.empty:
        self.skipTest("No traps detected in this environment")
    key_cols = ["spectral_excess_abs", "ov_lam_weighted_var", "ov_rank_mean", "trap_variance_burden"]
    # A row counts only when every key metric is finite simultaneously.
    rows_all_finite = np.isfinite(result[key_cols]).all(axis=1)
    self.assertTrue(bool(rows_all_finite.any()))

def test_trap_burden_variant_selection(self):
    """trap_variance_burden must mirror the variant-specific column for each variant."""
    for variant in ("ipr", "top5"):
        result = self.watcher.analyze_traps(
            plot=False, savefig=False, trap_burden=True, trap_burden_variant=variant
        )
        specific_col = "trap_variance_burden_" + variant
        # Only compare rows where both the generic and specific values are finite.
        comparable = np.isfinite(result["trap_variance_burden"]) & np.isfinite(result[specific_col])
        if comparable.any():
            self.assertTrue(
                np.allclose(
                    result.loc[comparable, "trap_variance_burden"],
                    result.loc[comparable, specific_col],
                )
            )

def test_trap_burden_variant_invalid(self):
    """An unrecognized trap_burden_variant must raise ValueError."""
    self.assertRaises(
        ValueError,
        self.watcher.analyze_traps,
        plot=False, savefig=False, trap_burden=True, trap_burden_variant="bad",
    )


if __name__ == "__main__":
unittest.main()
4 changes: 4 additions & 0 deletions weightwatcher/trap_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def analyze_traps(
base_model=None,
peft=wwcore.DEFAULT_PEFT,
rng=None,
trap_burden=False,
trap_burden_variant="top5",
):
"""Externalized implementation for WeightWatcher.analyze_traps()."""
if layers is None:
Expand Down Expand Up @@ -73,6 +75,8 @@ def analyze_traps(
params[wwcore.PEFT] = peft
params[wwcore.INVERSE] = False
params["rng"] = remove_traps_ops._normalize_trap_rng(rng=rng)
params["trap_burden"] = bool(trap_burden)
params["trap_burden_variant"] = trap_burden_variant

wwcore.logger.debug("params {}".format(params))
if not watcher.valid_params(params):
Expand Down
128 changes: 126 additions & 2 deletions weightwatcher/weightwatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3705,7 +3705,9 @@ def analyze_traps(self, model=None, layers=[],
start_ids=DEFAULT_START_ID,
base_model=None,
peft=DEFAULT_PEFT,
rng=None):
rng=None,
trap_burden=False,
trap_burden_variant="top5"):
"""Analyze randomized correlation traps and return one row per trap.

This method follows the randomized/permuted trap workflow:
Expand Down Expand Up @@ -3750,6 +3752,8 @@ def analyze_traps(self, model=None, layers=[],
base_model=base_model,
peft=peft,
rng=rng,
trap_burden=trap_burden,
trap_burden_variant=trap_burden_variant,
)

def _trap_result_columns(self):
Expand All @@ -3768,6 +3772,13 @@ def _trap_result_columns(self):
"v_top1_mass", "v_top5_mass", "v_top10_mass", "v_squared_amp_entropy", "v_stable_rank_surrogate",
"trap_balance_ratio", "trap_detected", "trap_eval_minus_bulk",
"trap_diffuseness_score", "trap_risk_score", "trap_assessment",
"top_5_mass", "bulk_top_5_mass_mean", "bulk_top_5_mass_std",
"spectral_excess_abs", "spectral_excess_rel", "trap_ipr",
"bulk_ipr_mean", "bulk_ipr_std", "ipr_lift", "ipr_lift_excess_pos",
"top_5_lift", "top_5_lift_excess_pos", "log1p_top_5_lift",
"ov_lam_weighted_mean", "ov_lam_weighted_var", "ov_rank_mean", "ov_rank_std",
"ov_hhi", "ov_participation", "ov_entropy", "ov_max",
"trap_variance_burden_ipr", "trap_variance_burden_top5", "trap_variance_burden",
]


Expand All @@ -3785,6 +3796,9 @@ def apply_analyze_traps(self, ww_layer, params=None):
self.apply_permute_W(ww_layer, params)
self.apply_trap_mp_fit(ww_layer, params=params)
trap_mode_indices = self.identify_trap_mode_indices(ww_layer, params=params)
bulk_reference_metrics = self.compute_bulk_trap_reference_metrics(
ww_layer, trap_mode_indices=trap_mode_indices, params=params
)

trap_rows = []
for trap_index, mode_index in enumerate(trap_mode_indices):
Expand All @@ -3794,12 +3808,40 @@ def apply_analyze_traps(self, ww_layer, params=None):
original_basis_cache=original_basis_cache,
params=params,
trap_index=trap_index,
bulk_reference_metrics=bulk_reference_metrics,
)
trap_rows.append(trap_row)

self.apply_unpermute_W(ww_layer, params)
return trap_rows

def compute_bulk_trap_reference_metrics(self, ww_layer, trap_mode_indices=None, params=None):
    """Compute reference statistics over the bulk (non-trap) singular modes.

    Each singular mode of the permuted weight matrix that is not listed in
    ``trap_mode_indices`` is treated as a bulk mode: its rank-1 component
    is mapped back to the original basis via ``unpermute_matrix`` and two
    concentration measures are recorded — the top-5% mass fraction and the
    IPR of the leading singular-vector pair.  Means/stds of these samples
    are later used to normalize the per-trap metrics.

    Returns a dict with ``bulk_top_5_mass_mean/std`` and
    ``bulk_ipr_mean/std`` (NaN when there are no bulk modes).

    NOTE(review): this runs one full SVD per bulk mode, which is O(n^3)
    each — acceptable for small layers, but worth revisiting for large W.
    """
    if params is None: params = DEFAULT_PARAMS.copy()
    if trap_mode_indices is None:
        trap_mode_indices = []
    permuted_W = ww_layer.Wmats[0].astype(float)
    perm_ids = ww_layer.permute_ids[0]
    U, S, Vh = svd_full(permuted_W, method=params[SVD_METHOD])
    V = Vh.T
    excluded = set(int(i) for i in trap_mode_indices)

    mass_samples, ipr_samples = [], []
    for mode in range(len(S)):
        if mode in excluded:
            continue
        # Rank-1 component of this bulk mode, mapped back to the original basis.
        rank1 = float(S[mode]) * np.outer(U[:, mode], V[:, mode])
        rank1_orig = unpermute_matrix(rank1, perm_ids)
        mass_samples.append(float(self._top_mass_fraction(rank1_orig, frac=0.05)))
        Um, _, Vhm = svd_full(rank1_orig, method=params[SVD_METHOD])
        lead_u = Um[:, 0]
        lead_v = Vhm.T[:, 0]
        ipr_samples.append(float(0.5 * (np.sum(lead_u ** 4) + np.sum(lead_v ** 4))))

    def _mean_std(samples):
        # Empty sample sets (all modes were traps) yield NaN statistics.
        if not samples:
            return np.nan, np.nan
        return float(np.mean(samples)), float(np.std(samples))

    top5_mean, top5_std = _mean_std(mass_samples)
    ipr_mean, ipr_std = _mean_std(ipr_samples)
    return {
        "bulk_top_5_mass_mean": top5_mean,
        "bulk_top_5_mass_std": top5_std,
        "bulk_ipr_mean": ipr_mean,
        "bulk_ipr_std": ipr_std,
    }


def apply_trap_mp_fit(self, ww_layer, params=None):
if params is None: params = DEFAULT_PARAMS.copy()
Expand Down Expand Up @@ -3831,7 +3873,7 @@ def compute_original_basis_for_traps(self, ww_layer, params=None):
}


def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=None, params=None, trap_index=0):
def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=None, params=None, trap_index=0, bulk_reference_metrics=None):
if params is None: params = DEFAULT_PARAMS.copy()
if original_basis_cache is None:
original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params)
Expand Down Expand Up @@ -3923,6 +3965,43 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No
trap_result["u_effective_support"] / (trap_result["v_effective_support"] + 1e-12)
)
trap_result.update(self.assess_trap_diffuseness(trap_result))
if bulk_reference_metrics is None:
bulk_reference_metrics = {}
trap_result.update(bulk_reference_metrics)
trap_result["top_5_mass"] = float(self._top_mass_fraction(T_orig, frac=0.05))

if params.get("trap_burden", False):
variant = params.get("trap_burden_variant", "top5")
if variant not in {"ipr", "top5", "both"}:
raise ValueError("trap_burden_variant must be one of: 'ipr', 'top5', 'both'")
mp_bulk_max = float(trap_result.get("mp_bulk_max", np.nan))
spectral_excess_abs = float(max(eval_perm - mp_bulk_max, 0.0)) if np.isfinite(mp_bulk_max) else np.nan
spectral_excess_rel = float(spectral_excess_abs / mp_bulk_max) if np.isfinite(mp_bulk_max) and mp_bulk_max != 0 else np.nan
trap_ipr = float(0.5 * (np.sum(u_trap ** 4) + np.sum(v_trap ** 4)))
bulk_ipr_mean = float(trap_result.get("bulk_ipr_mean", np.nan))
ipr_lift = float(trap_ipr / bulk_ipr_mean) if np.isfinite(bulk_ipr_mean) and bulk_ipr_mean != 0 else np.nan
ipr_lift_excess_pos = float(max(ipr_lift - 1.0, 0.0)) if np.isfinite(ipr_lift) else np.nan
bulk_top_5_mass_mean = float(trap_result.get("bulk_top_5_mass_mean", np.nan))
top_5_lift = float(trap_result["top_5_mass"] / bulk_top_5_mass_mean) if np.isfinite(bulk_top_5_mass_mean) and bulk_top_5_mass_mean != 0 else np.nan
top_5_lift_excess_pos = float(max(top_5_lift - 1.0, 0.0)) if np.isfinite(top_5_lift) else np.nan
log1p_top_5_lift = float(np.log1p(top_5_lift)) if np.isfinite(top_5_lift) else np.nan
ov = self._compute_overlap_metrics(original_basis_cache["W_true"], v_trap)
burden_ipr = float(spectral_excess_abs * ipr_lift_excess_pos * ov["ov_lam_weighted_var"]) if np.isfinite(spectral_excess_abs) and np.isfinite(ipr_lift_excess_pos) and np.isfinite(ov["ov_lam_weighted_var"]) else np.nan
burden_top5 = float(spectral_excess_abs * log1p_top_5_lift * ov["ov_rank_mean"]) if np.isfinite(spectral_excess_abs) and np.isfinite(log1p_top_5_lift) and np.isfinite(ov["ov_rank_mean"]) else np.nan
trap_result.update({
"spectral_excess_abs": spectral_excess_abs,
"spectral_excess_rel": spectral_excess_rel,
"trap_ipr": trap_ipr,
"ipr_lift": ipr_lift,
"ipr_lift_excess_pos": ipr_lift_excess_pos,
"top_5_lift": top_5_lift,
"top_5_lift_excess_pos": top_5_lift_excess_pos,
"log1p_top_5_lift": log1p_top_5_lift,
"trap_variance_burden_ipr": burden_ipr,
"trap_variance_burden_top5": burden_top5,
})
trap_result.update(ov)
trap_result["trap_variance_burden"] = burden_ipr if variant == "ipr" else burden_top5

trap_result["left_overlaps"] = left_overlaps
trap_result["right_overlaps"] = right_overlaps
Expand Down Expand Up @@ -3997,6 +4076,51 @@ def assess_trap_diffuseness(self, trap_result):
"trap_assessment": assessment,
}

def _top_mass_fraction(self, matrix, frac=0.05):
arr = np.abs(np.asarray(matrix, dtype=float).ravel())
if arr.size == 0:
return np.nan
total = np.sum(arr)
if not np.isfinite(total) or total <= 0:
return np.nan
k = max(1, int(np.ceil(frac * arr.size)))
idx = np.argpartition(arr, -k)[-k:]
return float(np.sum(arr[idx]) / total)

def _compute_overlap_metrics(self, W_orig, v_trap):
X = W_orig.T @ W_orig
lam, phi = np.linalg.eigh(X)
order = np.argsort(lam)[::-1]
lam = lam[order]
phi = phi[:, order]
v = np.asarray(v_trap, dtype=float)
nv = np.linalg.norm(v)
if nv == 0 or not np.isfinite(nv):
return {k: np.nan for k in ["ov_lam_weighted_mean", "ov_lam_weighted_var", "ov_rank_mean", "ov_rank_std", "ov_hhi", "ov_participation", "ov_entropy", "ov_max"]}
v = v / nv
p = np.abs(phi.T @ v) ** 2
ps = np.sum(p)
if ps <= 0 or not np.isfinite(ps):
return {k: np.nan for k in ["ov_lam_weighted_mean", "ov_lam_weighted_var", "ov_rank_mean", "ov_rank_std", "ov_hhi", "ov_participation", "ov_entropy", "ov_max"]}
p = p / ps
ov_mean = float(np.sum(p * lam))
ov_var = float(np.sum(p * (lam - ov_mean) ** 2))
ranks = np.arange(1, len(lam) + 1, dtype=float)
rk_mean = float(np.sum(p * ranks))
rk_std = float(np.sqrt(np.sum(p * (ranks - rk_mean) ** 2)))
eps = 1e-12
hhi = float(np.sum(p ** 2))
return {
"ov_lam_weighted_mean": ov_mean,
"ov_lam_weighted_var": ov_var,
"ov_rank_mean": rk_mean,
"ov_rank_std": rk_std,
"ov_hhi": hhi,
"ov_participation": float(1.0 / hhi) if hhi > 0 else np.nan,
"ov_entropy": float(-np.sum(p * np.log(p + eps))),
"ov_max": float(np.max(p)),
}


def plot_trap_analysis(self, ww_layer, trap_result, params=None):
if params is None: params = DEFAULT_PARAMS.copy()
Expand Down