Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests/test_analyze_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,21 @@ def test_no_trap_fft_api_or_columns(self):
self.assertFalse(any(c.startswith("trap_variance_burden__") for c in df.columns))



def test_analyze_traps_public_trap_indices_are_1_based(self):
    """Public trap indices must form a contiguous 1-based run per layer.

    Checks both the returned DataFrame and the artifact entries in the
    trap-state dict; skips when the environment yields no traps.
    """
    frame, state = self.watcher.analyze_traps(plot=False, savefig=False, return_artifacts=True)
    if len(frame) == 0:
        self.skipTest("No traps detected in this environment")
    # Global floor: nothing may be numbered below 1.
    self.assertGreaterEqual(int(frame["trap_index"].min()), 1)
    # Per-layer: indices are exactly 1..k with no gaps.
    for _, group in frame.groupby("layer_id"):
        observed = sorted(int(v) for v in group["trap_index"])
        self.assertEqual(observed, list(range(1, len(observed) + 1)))
    # Artifacts in the state dict must agree with the 1-based numbering.
    for layer_state in state.get("layers", {}).values():
        artifacts = layer_state.get("artifacts", [])
        if artifacts:
            self.assertEqual(
                [int(item["trap_index"]) for item in artifacts],
                list(range(1, len(artifacts) + 1)),
            )

def test_analyze_traps_fast_mode_skips_original_basis(self):
from unittest.mock import patch
with patch.object(ww.WeightWatcher, "compute_original_basis_for_traps", side_effect=AssertionError("should not call")):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_trap_ablation_workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import pytest
import numpy as np
import pandas as pd
torch = pytest.importorskip("torch")
import weightwatcher as ww
from weightwatcher.RMT_Util import permute_matrix, unpermute_matrix
Expand Down Expand Up @@ -142,3 +143,17 @@ def test_rmt_util_unpermutes_randomized_layer_weight():
assert np.allclose(W_recon, original)
W_perm2, pids2 = permute_matrix(original, rng=123)
assert np.allclose(unpermute_matrix(W_perm2, pids2), original)


def test_remove_traps_rejects_zero_based_public_trap_index():
    """remove_traps must reject trap_index 0: the public API is 1-based."""
    net = torch.nn.Sequential(torch.nn.Linear(16, 12, bias=False))
    ww_watcher = ww.WeightWatcher(model=net)
    randomized, state = ww_watcher.randomize_model(
        model=net, rng=123, return_state=True, pool=False
    )
    # Pick a real permuted layer id so only the 0 index is at fault.
    first_layer = int(sorted(state["permuted_ids"].keys())[0])
    bad_traps = pd.DataFrame([{"layer_id": first_layer, "trap_index": 0}])
    with pytest.raises(ValueError, match="trap_index values are 1-based; got 0"):
        ww_watcher.remove_traps(
            randomized_model=randomized,
            traps=bad_traps,
            trap_state=state,
            plot=False,
            pool=False,
        )
2 changes: 1 addition & 1 deletion tests/test_trap_component_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_top_trap_component_row_extracts_top_10_weight_and_coeff_pairs():
row = {
"layer_id": 7,
"name": "dense",
"trap_index": 0,
"trap_index": 1,
"trap_assessment": "mixed",
"trap_risk_score": 0.4,
"T_orig": trap,
Expand Down
5 changes: 1 addition & 4 deletions tests/test_trap_compute_efficiency.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@ def count(self):


def _normalize_selected(df):
selected = df.copy()
if "trap_index" in selected.columns and selected["trap_index"].min() == 0:
selected["trap_index"] = selected["trap_index"].astype(int) + 1
return selected
return df.copy()


def _workflow(monkeypatch):
Expand Down
4 changes: 3 additions & 1 deletion weightwatcher/remove_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=N

requested = sorted(set(trap_indices))
if any(idx < 1 for idx in requested):
raise ValueError("trap indices are 1-based and must be >= 1")
raise ValueError("trap_index values are 1-based; got 0")

layer_seed = seed
if layer_seed is None and isinstance(params, dict):
Expand Down Expand Up @@ -248,6 +248,8 @@ def _trap_indices_from_traps_df(traps):
if "trap_index" not in trap_df.columns:
raise ValueError("traps must include a 'trap_index' column")
indices = trap_df["trap_index"].dropna().astype(int).tolist()
if any(i < 1 for i in indices):
raise ValueError("trap_index values are 1-based; got 0")
indices = sorted(set(indices))
if len(indices) == 0:
raise ValueError("traps did not contain any valid trap_index values")
Expand Down
8 changes: 4 additions & 4 deletions weightwatcher/trap_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,14 @@ def analyze_traps(
if params.get(wwcore.PLOT, False):
trap_infos = []
for row in layer_rows:
trap_idx_zero_based = int(row.get("trap_index", -1))
trap_idx = int(row.get("trap_index", -1))
trap_matrix = row.get("T_orig", None)
if trap_idx_zero_based < 0 or trap_matrix is None:
if trap_idx < 1 or trap_matrix is None:
continue

trap_infos.append(
{
"trap_index": trap_idx_zero_based + 1,
"trap_index": trap_idx,
"trap_matrix": trap_matrix,
"trap_assessment": row.get("trap_assessment", "mixed"),
"trap_risk_score": row.get("trap_risk_score", 0.0),
Expand Down Expand Up @@ -238,7 +238,7 @@ def _top_trap_component_row(row, weight_matrix, top_k=10):
out = {
"layer_id": row.get("layer_id"),
"name": row.get("name"),
"trap_index": int(row.get("trap_index", -1)) + 1,
"trap_index": int(row.get("trap_index", -1)),
"trap_assessment": row.get("trap_assessment", "mixed"),
"trap_risk_score": float(row.get("trap_risk_score", 0.0)),
}
Expand Down
4 changes: 2 additions & 2 deletions weightwatcher/weightwatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3904,11 +3904,11 @@ def apply_analyze_traps(self, ww_layer, params=None):
trap_rows=[]; artifacts=[]
pids = np.asarray(ww_layer.permute_ids[0]) if len(getattr(ww_layer,'permute_ids',[]))>0 else None
fp = hashlib.sha1(pids.tobytes()).hexdigest() if pids is not None else None
for trap_index, mode_index in enumerate(trap_mode_indices):
for trap_index, mode_index in enumerate(trap_mode_indices, start=1):
trap_row = self.analyze_single_trap(ww_layer, trap_mode_index=mode_index, original_basis_cache=original_basis_cache, params=params, trap_index=trap_index, bulk_stats=bulk_stats, precomputed_svd=(U_perm,S_perm,Vh_perm))
trap_row.update(bulk_stats); trap_row['permute_fingerprint']=fp
trap_rows.append(trap_row)
artifacts.append({"layer_id": int(ww_layer.layer_id), "trap_index": int(trap_index+1), "trap_mode_index": int(mode_index), "sigma_perm": float(S_perm[mode_index]), "permute_fingerprint": fp, "T_perm": trap_row.get("T_perm"), "T_orig_raw": trap_row.get("T_orig"), "u_trap_perm": U_perm[:, mode_index], "v_trap_perm": Vh_perm[mode_index, :]})
artifacts.append({"layer_id": int(ww_layer.layer_id), "trap_index": int(trap_index), "trap_mode_index": int(mode_index), "sigma_perm": float(S_perm[mode_index]), "permute_fingerprint": fp, "T_perm": trap_row.get("T_perm"), "T_orig_raw": trap_row.get("T_orig"), "u_trap_perm": U_perm[:, mode_index], "v_trap_perm": Vh_perm[mode_index, :]})
if trap_rows:
layer_trap_variance_burden = float(np.nansum([row.get("trap_variance_burden_old", np.nan) for row in trap_rows]))
for row in trap_rows: row["layer_trap_variance_burden"] = layer_trap_variance_burden
Expand Down