diff --git a/tests/test_analyze_traps.py b/tests/test_analyze_traps.py index b64adb7..c6ab5fc 100644 --- a/tests/test_analyze_traps.py +++ b/tests/test_analyze_traps.py @@ -228,6 +228,21 @@ def test_no_trap_fft_api_or_columns(self): self.assertFalse(any(c.startswith("trap_variance_burden__") for c in df.columns)) + + def test_analyze_traps_public_trap_indices_are_1_based(self): + df, trap_state = self.watcher.analyze_traps(plot=False, savefig=False, return_artifacts=True) + if len(df) == 0: + self.skipTest("No traps detected in this environment") + self.assertGreaterEqual(int(df["trap_index"].min()), 1) + for _, g in df.groupby("layer_id"): + vals = sorted(g["trap_index"].astype(int).tolist()) + self.assertEqual(vals, list(range(1, len(vals) + 1))) + for lid, layer_state in trap_state.get("layers", {}).items(): + arts = layer_state.get("artifacts", []) + if not arts: + continue + self.assertEqual([int(a["trap_index"]) for a in arts], list(range(1, len(arts) + 1))) + def test_analyze_traps_fast_mode_skips_original_basis(self): from unittest.mock import patch with patch.object(ww.WeightWatcher, "compute_original_basis_for_traps", side_effect=AssertionError("should not call")): diff --git a/tests/test_trap_ablation_workflow.py b/tests/test_trap_ablation_workflow.py index c43c322..d679bf4 100644 --- a/tests/test_trap_ablation_workflow.py +++ b/tests/test_trap_ablation_workflow.py @@ -1,6 +1,7 @@ import inspect import pytest import numpy as np +import pandas as pd torch = pytest.importorskip("torch") import weightwatcher as ww from weightwatcher.RMT_Util import permute_matrix, unpermute_matrix @@ -142,3 +143,17 @@ def test_rmt_util_unpermutes_randomized_layer_weight(): assert np.allclose(W_recon, original) W_perm2, pids2 = permute_matrix(original, rng=123) assert np.allclose(unpermute_matrix(W_perm2, pids2), original) + + +def test_remove_traps_rejects_zero_based_public_trap_index(): + model = torch.nn.Sequential(torch.nn.Linear(16, 12, bias=False)) + watcher = ww.WeightWatcher(model=model) + randomized_model, trap_state = watcher.randomize_model(model=model, rng=123, return_state=True, pool=False) + with pytest.raises(ValueError, match="trap_index values are 1-based; got 0"): + watcher.remove_traps( + randomized_model=randomized_model, + traps=pd.DataFrame([{"layer_id": int(sorted(trap_state["permuted_ids"].keys())[0]), "trap_index": 0}]), + trap_state=trap_state, + plot=False, + pool=False, + ) diff --git a/tests/test_trap_component_summary.py b/tests/test_trap_component_summary.py index d46c001..cd9f563 100644 --- a/tests/test_trap_component_summary.py +++ b/tests/test_trap_component_summary.py @@ -19,7 +19,7 @@ def test_top_trap_component_row_extracts_top_10_weight_and_coeff_pairs(): row = { "layer_id": 7, "name": "dense", - "trap_index": 0, + "trap_index": 1, "trap_assessment": "mixed", "trap_risk_score": 0.4, "T_orig": trap, diff --git a/tests/test_trap_compute_efficiency.py b/tests/test_trap_compute_efficiency.py index 714cc42..c35ec25 100644 --- a/tests/test_trap_compute_efficiency.py +++ b/tests/test_trap_compute_efficiency.py @@ -45,10 +45,7 @@ def count(self): def _normalize_selected(df): - selected = df.copy() - if "trap_index" in selected.columns and selected["trap_index"].min() == 0: - selected["trap_index"] = selected["trap_index"].astype(int) + 1 - return selected + return df.copy() def _workflow(monkeypatch): diff --git a/weightwatcher/remove_traps.py b/weightwatcher/remove_traps.py index 9e6871b..2319c5c 100644 --- a/weightwatcher/remove_traps.py +++ b/weightwatcher/remove_traps.py @@ -165,7 +165,7 @@ def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=N requested = sorted(set(trap_indices)) if any(idx < 1 for idx in requested): - raise ValueError("trap indices are 1-based and must be >= 1") + raise ValueError("trap_index values are 1-based; got 0") layer_seed = seed if layer_seed is None and isinstance(params, dict): @@ -248,6 +248,8 @@ def _trap_indices_from_traps_df(traps): if "trap_index" not in trap_df.columns: raise ValueError("traps must include a 'trap_index' column") indices = trap_df["trap_index"].dropna().astype(int).tolist() + if any(i < 1 for i in indices): + raise ValueError("trap_index values are 1-based; got 0") indices = sorted(set(indices)) if len(indices) == 0: raise ValueError("traps did not contain any valid trap_index values") diff --git a/weightwatcher/trap_analysis.py b/weightwatcher/trap_analysis.py index dc6cb3d..e08ab5a 100644 --- a/weightwatcher/trap_analysis.py +++ b/weightwatcher/trap_analysis.py @@ -171,14 +171,14 @@ def analyze_traps( if params.get(wwcore.PLOT, False): trap_infos = [] for row in layer_rows: - trap_idx_zero_based = int(row.get("trap_index", -1)) + trap_idx = int(row.get("trap_index", -1)) trap_matrix = row.get("T_orig", None) - if trap_idx_zero_based < 0 or trap_matrix is None: + if trap_idx < 1 or trap_matrix is None: continue trap_infos.append( { - "trap_index": trap_idx_zero_based + 1, + "trap_index": trap_idx, "trap_matrix": trap_matrix, "trap_assessment": row.get("trap_assessment", "mixed"), "trap_risk_score": row.get("trap_risk_score", 0.0), @@ -238,7 +238,7 @@ def _top_trap_component_row(row, weight_matrix, top_k=10): out = { "layer_id": row.get("layer_id"), "name": row.get("name"), - "trap_index": int(row.get("trap_index", -1)) + 1, + "trap_index": int(row.get("trap_index", -1)), "trap_assessment": row.get("trap_assessment", "mixed"), "trap_risk_score": float(row.get("trap_risk_score", 0.0)), } diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index a953948..7962716 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -3904,11 +3904,11 @@ def apply_analyze_traps(self, ww_layer, params=None): trap_rows=[]; artifacts=[] pids = np.asarray(ww_layer.permute_ids[0]) if len(getattr(ww_layer,'permute_ids',[]))>0 else None fp = hashlib.sha1(pids.tobytes()).hexdigest() if pids is not None else None - for trap_index, mode_index in enumerate(trap_mode_indices): + for trap_index, mode_index in enumerate(trap_mode_indices, start=1): trap_row = self.analyze_single_trap(ww_layer, trap_mode_index=mode_index, original_basis_cache=original_basis_cache, params=params, trap_index=trap_index, bulk_stats=bulk_stats, precomputed_svd=(U_perm,S_perm,Vh_perm)) trap_row.update(bulk_stats); trap_row['permute_fingerprint']=fp trap_rows.append(trap_row) - artifacts.append({"layer_id": int(ww_layer.layer_id), "trap_index": int(trap_index+1), "trap_mode_index": int(mode_index), "sigma_perm": float(S_perm[mode_index]), "permute_fingerprint": fp, "T_perm": trap_row.get("T_perm"), "T_orig_raw": trap_row.get("T_orig"), "u_trap_perm": U_perm[:, mode_index], "v_trap_perm": Vh_perm[mode_index, :]}) + artifacts.append({"layer_id": int(ww_layer.layer_id), "trap_index": int(trap_index), "trap_mode_index": int(mode_index), "sigma_perm": float(S_perm[mode_index]), "permute_fingerprint": fp, "T_perm": trap_row.get("T_perm"), "T_orig_raw": trap_row.get("T_orig"), "u_trap_perm": U_perm[:, mode_index], "v_trap_perm": Vh_perm[mode_index, :]}) if trap_rows: layer_trap_variance_burden = float(np.nansum([row.get("trap_variance_burden_old", np.nan) for row in trap_rows])) for row in trap_rows: row["layer_trap_variance_burden"] = layer_trap_variance_burden