diff --git a/tests/test_remove_traps.py b/tests/test_remove_traps.py index 6aa2f8e..f4a7d65 100644 --- a/tests/test_remove_traps.py +++ b/tests/test_remove_traps.py @@ -451,3 +451,5 @@ def test_remove_traps_preserves_pytorch_layer_dtype_and_forward_pass(): x_mps = torch.randn(4, 96, dtype=torch.float32, device=mps_device) y_mps = model_mps(x_mps) assert y_mps.shape == (4, 10) + + diff --git a/tests/test_trap_ablation_workflow.py b/tests/test_trap_ablation_workflow.py new file mode 100644 index 0000000..1019159 --- /dev/null +++ b/tests/test_trap_ablation_workflow.py @@ -0,0 +1,217 @@ +import unittest +import numpy as np +import pandas as pd +import pytest + +import weightwatcher as ww +from weightwatcher.weightwatcher import WeightWatcher + +try: + import torch + import torch.nn as nn + TORCH_AVAILABLE = True +except Exception: + TORCH_AVAILABLE = False + + +class OneLayerTrapNet(nn.Module if TORCH_AVAILABLE else object): + def __init__(self, W): + if not TORCH_AVAILABLE: + return + super().__init__() + out_dim, in_dim = W.shape + self.fc = nn.Linear(in_dim, out_dim, bias=False) + with torch.no_grad(): + self.fc.weight.copy_(torch.tensor(W, dtype=torch.float32)) + + def forward(self, x): + return self.fc(x) + + +def _get_first_linear_weight_np(model): + for m in model.modules(): + if TORCH_AVAILABLE and isinstance(m, nn.Linear): + return m.weight.detach().cpu().numpy().copy() + raise AssertionError("No Linear layer found") + + +def _make_one_layer_trap_model(seed=123, out_dim=96, in_dim=96, spike_strength=100.0): + rng = np.random.RandomState(seed) + W_noise = 0.005 * rng.standard_normal((out_dim, in_dim)) + u = np.zeros(out_dim) + v = np.zeros(in_dim) + u[3] = 1.0 + u[4] = 0.5 + v[7] = 1.0 + v[8] = -0.5 + u = u / np.linalg.norm(u) + v = v / np.linalg.norm(v) + W = W_noise + spike_strength * np.outer(u, v) + return OneLayerTrapNet(W), W + + +def _select_trap_row(trap_df): + col = "B_absDelta_ipr_ovlamvar" + if col in trap_df.columns: + vals = pd.to_numeric(trap_df[col], errors="coerce") + if np.isfinite(vals).any(): + return int(vals.idxmax()) + return int(trap_df.index[0]) + + +@unittest.skipUnless(TORCH_AVAILABLE, "torch is required") +def test_randomized_model_analyze_then_remove_full_workflow(): + model, W_orig = _make_one_layer_trap_model() + watcher = ww.WeightWatcher(model=model) + + randomized_model, trap_state = watcher.randomize_model(model=model, layers=[], rng=123, return_state=True) + assert randomized_model is not None + assert isinstance(trap_state, dict) + assert "permuted_ids" in trap_state + assert "layers" in trap_state + + W_rand_before = _get_first_linear_weight_np(randomized_model) + assert not np.allclose(W_rand_before, W_orig) + + trap_df, trap_state = watcher.analyze_traps( + randomized_model=randomized_model, + trap_state=trap_state, + return_artifacts=True, + trap_burden=True, + plot=False, + savefig=False, + ) + + W_rand_after_analyze = _get_first_linear_weight_np(randomized_model) + assert np.allclose(W_rand_before, W_rand_after_analyze) + assert isinstance(trap_df, pd.DataFrame) + assert len(trap_df) > 0 + for col in ["layer_id", "trap_index", "B_absDelta_ipr_ovlamvar", "spectral_excess_abs", "ipr_lift_excess_pos", "ov_lam_weighted_var"]: + assert col in trap_df.columns + + assert isinstance(trap_state, dict) + assert "layers" in trap_state and len(trap_state["layers"]) > 0 + + idx = _select_trap_row(trap_df) + row = trap_df.loc[idx] + single_trap_df = trap_df.loc[[idx]].copy() + + selected_layer_id = int(row["layer_id"]) + selected_trap_index = int(row["trap_index"]) + layer_state = trap_state["layers"][selected_layer_id] + for key in ["artifacts", "U_perm", "S_perm", "Vh_perm", "permuted_ids", "permute_fingerprint"]: + assert key in layer_state + assert len(layer_state["artifacts"]) > 0 + + artifact = None + for a in layer_state["artifacts"]: + if int(a.get("trap_index", -1)) == selected_trap_index: + artifact = a + break + assert artifact is not None + if "trap_mode_index" in row and "trap_mode_index" in artifact: + assert int(row["trap_mode_index"]) == int(artifact["trap_mode_index"]) + if "sigma_perm" in row and "sigma_perm" in artifact: + assert np.isclose(float(row["sigma_perm"]), float(artifact["sigma_perm"]), rtol=1e-5, atol=1e-8) + + W_before_remove = _get_first_linear_weight_np(randomized_model) + ablated_model = watcher.remove_traps( + randomized_model=randomized_model, + traps=single_trap_df, + trap_state=trap_state, + plot=False, + ) + assert ablated_model is not None + W_after_remove = _get_first_linear_weight_np(randomized_model) + assert W_after_remove.shape == W_before_remove.shape + assert not np.allclose(W_after_remove, W_before_remove) + + with pytest.raises(ValueError, match="requires randomized_model"): + watcher.remove_traps(model=model, traps=single_trap_df, trap_state=trap_state, plot=False) + + +@unittest.skipUnless(TORCH_AVAILABLE, "torch is required") +def test_randomized_workflow_remove_traps_uses_cached_artifacts_no_svd(monkeypatch): + model, _ = _make_one_layer_trap_model() + watcher = WeightWatcher(model=model) + randomized_model, trap_state = watcher.randomize_model(model=model, rng=123, return_state=True) + trap_df, trap_state = watcher.analyze_traps( + randomized_model=randomized_model, + trap_state=trap_state, + return_artifacts=True, + trap_burden=True, + plot=False, + savefig=False, + ) + assert len(trap_df) > 0 + single_trap_df = trap_df.iloc[[0]].copy() + + from weightwatcher import remove_traps as remove_traps_ops + + def fail_collect(*args, **kwargs): + raise AssertionError("remove_traps should use cached trap_state artifacts, not recollect artifacts") + + monkeypatch.setattr(remove_traps_ops, "collect_trap_artifacts", fail_collect) + + out = watcher.remove_traps(randomized_model=randomized_model, traps=single_trap_df, trap_state=trap_state, plot=False) + assert out is not None + + +def test_randomized_model_workflow_public_api(monkeypatch): + from tests.test_remove_traps import _single_trap_setup, make_ww_layer + W, _, _, _ = _single_trap_setup(seed=111) + ww_layer = make_ww_layer(W) + watcher = WeightWatcher(model={"dummy_weight": np.array([1.0])}) + + monkeypatch.setattr( + watcher, + "make_layer_iterator", + lambda model=None, layers=None, params=None, base_model=None: [ww_layer], + ) + + randomized_model, trap_state = watcher.randomize_model(model={"dummy_weight": np.array([1.0])}, return_state=True, rng=123) + assert isinstance(trap_state, dict) + assert "permuted_ids" in trap_state + + trap_df, trap_state = watcher.analyze_traps( + randomized_model=randomized_model, + trap_state=trap_state, + return_artifacts=True, + trap_burden=True, + plot=False, + savefig=False, + ) + assert isinstance(trap_df, pd.DataFrame) + assert "B_absDelta_ipr_ovlamvar" in trap_df.columns + assert "layers" in trap_state + + +def test_remove_traps_requires_randomized_model_when_trap_state(): + watcher = WeightWatcher(model={"dummy_weight": np.array([1.0])}) + with pytest.raises(ValueError, match="requires randomized_model"): + watcher.remove_traps(model={"dummy_weight": np.array([1.0])}, trap_state={"layers": {}}, traps=pd.DataFrame([{"trap_index": 1}]), plot=False) + + +@unittest.skipUnless(TORCH_AVAILABLE, "torch is required") +def test_analyze_traps_return_artifacts_does_not_recollect_artifacts(monkeypatch): + model, _ = _make_one_layer_trap_model() + watcher = WeightWatcher(model=model) + randomized_model, trap_state = watcher.randomize_model(model=model, rng=123, return_state=True) + + from weightwatcher import remove_traps as remove_traps_ops + + def fail_collect(*args, **kwargs): + raise AssertionError("analyze_traps(return_artifacts=True) must not recollect artifacts or recompute SVD") + + monkeypatch.setattr(remove_traps_ops, "collect_trap_artifacts", fail_collect) + + trap_df, trap_state = watcher.analyze_traps( + randomized_model=randomized_model, + trap_state=trap_state, + return_artifacts=True, + trap_burden=True, + plot=False, + savefig=False, + ) + assert len(trap_df) > 0 + assert "layers" in trap_state and len(trap_state["layers"]) > 0 diff --git a/weightwatcher/__init__.py b/weightwatcher/__init__.py index a4f9739..d791e4e 100644 --- a/weightwatcher/__init__.py +++ b/weightwatcher/__init__.py @@ -19,7 +19,7 @@ __name__ = "weightwatcher" -__version__ = "0.8.4" +__version__ = "0.8.6" __license__ = "Apache License, Version 2.0" __description__ = "Diagnostic Tool for Deep Neural Networks" diff --git a/weightwatcher/remove_traps.py b/weightwatcher/remove_traps.py index 4900251..f77b166 100644 --- a/weightwatcher/remove_traps.py +++ b/weightwatcher/remove_traps.py @@ -1,6 +1,7 @@ import logging import numbers import hashlib +import copy import numpy as np import pandas as pd @@ -47,10 +48,12 @@ def apply_trap_mp_fit(ww, ww_layer, params=None): return ww_layer -def identify_trap_mode_indices(ww, ww_layer): - W = ww_layer.Wmats[0] - _, svals, _ = svd_full(W) - evals_desc = svals * svals +def identify_trap_mode_indices(ww, ww_layer, svals=None, evals_desc=None): + if evals_desc is None: + if svals is None: + W = ww_layer.Wmats[0] + _, svals, _ = svd_full(W) + evals_desc = np.asarray(svals, dtype=float) * np.asarray(svals, dtype=float) Q = ww_layer.N / ww_layer.M M = ww_layer.M @@ -66,7 +69,7 @@ def identify_trap_mode_indices(ww, ww_layer): return trap_mode_indices.tolist() -def analyze_single_trap(ww, ww_layer, trap_mode_index): +def analyze_single_trap(ww, ww_layer, trap_mode_index, U_perm=None, S_perm=None, Vh_perm=None): def _top_percent_abs_mass(mat, percent): flat = np.abs(np.asarray(mat, dtype=float)).ravel() if flat.size == 0: @@ -79,8 +82,9 @@ def _top_percent_abs_mass(mat, percent): top_sum = float(np.sum(np.partition(flat, -k)[-k:])) return top_sum / total - W_perm = ww_layer.Wmats[0] - U_perm, S_perm, Vh_perm = svd_full(W_perm) + if U_perm is None or S_perm is None or Vh_perm is None: + W_perm = ww_layer.Wmats[0] + U_perm, S_perm, Vh_perm = svd_full(W_perm) sigma_perm = S_perm[trap_mode_index] u_trap = U_perm[:, trap_mode_index] @@ -105,7 +109,7 @@ def _top_percent_abs_mass(mat, percent): } -def collect_trap_artifacts(ww, ww_layer, params=None, seed=None, rng=None): +def collect_trap_artifacts(ww, ww_layer, params=None, seed=None, rng=None, already_randomized=False, permuted_ids=None, return_state=False): if params is None: params = DEFAULT_PARAMS.copy() @@ -119,22 +123,43 @@ def collect_trap_artifacts(ww, ww_layer, params=None, seed=None, rng=None): analysis_layer.permute_ids = [] ww.apply_normalize_Wmats(analysis_layer, params) - ww.apply_permute_W(analysis_layer, params, rng=rng) + if already_randomized: + if permuted_ids is not None: + analysis_layer.permute_ids = [np.asarray(permuted_ids).astype(int)] + elif permuted_ids is not None: + pids = np.asarray(permuted_ids).astype(int) + analysis_layer.permute_ids = [pids] + analysis_layer.Wmats = [analysis_layer.Wmats[0].flatten()[pids].reshape(analysis_layer.Wmats[0].shape)] + else: + ww.apply_permute_W(analysis_layer, params, rng=rng) permute_fingerprint = None if len(analysis_layer.permute_ids) > 0: perm_arr = np.asarray(analysis_layer.permute_ids[0]) permute_fingerprint = hashlib.sha1(perm_arr.tobytes()).hexdigest() apply_trap_mp_fit(ww, analysis_layer, params) - trap_mode_indices = identify_trap_mode_indices(ww, analysis_layer) + W_perm = analysis_layer.Wmats[0] + U_perm, S_perm, Vh_perm = svd_full(W_perm) + trap_mode_indices = identify_trap_mode_indices(ww, analysis_layer, svals=S_perm) artifacts = [] for i, trap_mode_index in enumerate(trap_mode_indices, start=1): - artifact = analyze_single_trap(ww, analysis_layer, trap_mode_index) + artifact = analyze_single_trap(ww, analysis_layer, trap_mode_index, U_perm=U_perm, S_perm=S_perm, Vh_perm=Vh_perm) artifact["trap_index"] = i artifact["T_orig_raw"] = artifact["T_orig_norm"] / analysis_layer.w_norm artifact["permute_fingerprint"] = permute_fingerprint artifacts.append(artifact) + if return_state: + state = { + "layer_id": int(getattr(ww_layer, "layer_id", -1)), + "permuted_ids": np.asarray(analysis_layer.permute_ids[0]).copy() if analysis_layer.permute_ids else None, + "permute_fingerprint": permute_fingerprint, + "U_perm": U_perm, "S_perm": S_perm, "Vh_perm": Vh_perm, + "trap_mode_indices": list(trap_mode_indices), + "artifacts": artifacts, + "already_randomized": bool(already_randomized), + } + return artifacts, state return artifacts @@ -149,7 +174,7 @@ def make_stat_matched_random_matrix(T, rng): return np.mean(T) + np.std(T) * G -def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=None): +def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=None, trap_artifacts=None, trap_state=None): if params is None: params = DEFAULT_PARAMS.copy() if trap_indices is None or len(trap_indices) == 0: @@ -174,13 +199,19 @@ def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=N replacement_seed = None if layer_seed is None else layer_seed + 1 replacement_rng = np.random.default_rng(replacement_seed) - artifacts = collect_trap_artifacts( - ww, - ww_layer, - params=params, - seed=None if permute_rng is not None else permute_seed, - rng=permute_rng, - ) + artifacts = trap_artifacts + if artifacts is None and isinstance(trap_state, dict): + artifacts = trap_state.get("artifacts") + if artifacts is None: + artifacts = collect_trap_artifacts( + ww, + ww_layer, + params=params, + seed=None if permute_rng is not None else permute_seed, + rng=permute_rng, + already_randomized=bool(isinstance(trap_state, dict) and trap_state.get("already_randomized", False)), + permuted_ids=(trap_state.get("permuted_ids") if isinstance(trap_state, dict) else None), + ) valid_indices = [idx for idx in requested if idx <= len(artifacts)] if len(valid_indices) < len(requested): logger.warning( @@ -251,7 +282,7 @@ def _trap_indices_from_traps_df(traps): def remove_traps(ww, model=None, layers=[], trap_indices=None, traps=None, seed=None, rng=None, pool=True, plot=True, verify_traps=False, return_analyze=False, start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, - base_model=None, peft=DEFAULT_PEFT): + base_model=None, peft=DEFAULT_PEFT, trap_artifacts=None, trap_state=None, already_randomized=False): # PR359 compatibility path: passing traps= instead of trap_indices=[...] if trap_indices is None and traps is not None: trap_indices = _trap_indices_from_traps_df(traps) @@ -269,6 +300,7 @@ def remove_traps(ww, model=None, layers=[], trap_indices=None, traps=None, seed= params[PEFT] = peft params["seed"] = seed params["rng"] = _normalize_trap_rng(rng=rng, seed=seed) + params["already_randomized"] = bool(already_randomized) if not ww.__class__.valid_params(params): raise Exception(f"Error, params not valid: \n {params}") @@ -288,13 +320,28 @@ def remove_traps(ww, model=None, layers=[], trap_indices=None, traps=None, seed= else: raise ValueError("traps DataFrame must include trap_index") - pre_artifacts = collect_trap_artifacts( - ww, - ww_layer, - params=params, - seed=None if params["rng"] is not None else seed, - rng=params["rng"], - ) + pre_artifacts = trap_artifacts + if pre_artifacts is None and isinstance(trap_state, dict): + layer_state = trap_state.get("layers", {}).get(int(ww_layer.layer_id), {}) + pre_artifacts = layer_state.get("artifacts") + if pre_artifacts is None: + layer_perm = None + if isinstance(layer_state, dict) and layer_state.get("permuted_ids") is not None: + layer_perm = layer_state.get("permuted_ids") + elif isinstance(trap_state, dict): + pid_map = trap_state.get("permuted_ids", {}) + if isinstance(pid_map, dict): + layer_perm = pid_map.get(int(ww_layer.layer_id)) + pre_artifacts = collect_trap_artifacts( + ww, + ww_layer, + params=params, + seed=None if params["rng"] is not None else seed, + rng=params["rng"], + already_randomized=bool(already_randomized), + permuted_ids=layer_perm, + ) + pre_by_index = {int(a["trap_index"]): a for a in pre_artifacts} identity_ok = True identity_reason = "ok" @@ -320,7 +367,7 @@ def remove_traps(ww, model=None, layers=[], trap_indices=None, traps=None, seed= if not identity_ok: raise ValueError(f"Trap identity verification failed for layer {ww_layer.layer_id}: {identity_reason}") - apply_remove_traps(ww, ww_layer, trap_indices=trap_indices, params=params, seed=seed, rng=params["rng"]) + apply_remove_traps(ww, ww_layer, trap_indices=trap_indices, params=params, seed=seed, rng=params["rng"], trap_artifacts=pre_artifacts, trap_state=(trap_state.get("layers", {}).get(int(ww_layer.layer_id), {}) if isinstance(trap_state, dict) else None)) if verify_traps: remaining = collect_trap_artifacts( ww, @@ -344,3 +391,41 @@ def remove_traps(ww, model=None, layers=[], trap_indices=None, traps=None, seed= verify_df = pd.DataFrame.from_records(verify_rows) return model, verify_df return model + + +def randomize_model(ww, model=None, layers=None, pool=True, start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, + base_model=None, peft=DEFAULT_PEFT, rng=None, return_state=False): + """Return a randomized copy of model and per-layer permutation ids.""" + if layers is None: + layers = [] + + randomized_model = copy.deepcopy(model if model is not None else ww.model) + ww.set_model_(randomized_model) + + params = DEFAULT_PARAMS.copy() + params[POOL] = pool + params[LAYERS] = layers + params[START_IDS] = start_ids + params[SVD_METHOD] = svd_method + params[PEFT] = peft + params["rng"] = _normalize_trap_rng(rng=rng) + + params = ww.normalize_params(params) + layer_iterator = ww.make_layer_iterator(model=ww.model, layers=layers, params=params, base_model=base_model) + + permuted_ids = {} + layer_states = {} + for ww_layer in layer_iterator: + if not ww_layer.skipped and ww_layer.has_weights: + ww.apply_normalize_Wmats(ww_layer, params) + ww.apply_permute_W(ww_layer, params, rng=params["rng"]) + if ww_layer.permute_ids: + pid = np.asarray(ww_layer.permute_ids[0]).copy() + lid = int(ww_layer.layer_id) + permuted_ids[lid] = pid + layer_states[lid] = {"layer_id": lid, "permuted_ids": pid, "permute_fingerprint": hashlib.sha1(pid.tobytes()).hexdigest(), "already_randomized": True} + ww.replace_layer_weights(ww_layer.layer_id, ww_layer.framework_layer, ww_layer.Wmats[0]) + + if return_state: + return randomized_model, {"permuted_ids": permuted_ids, "layers": layer_states, "already_randomized": True} + return randomized_model, permuted_ids diff --git a/weightwatcher/trap_analysis.py b/weightwatcher/trap_analysis.py index 8a6e5f4..ba31a9d 100644 --- a/weightwatcher/trap_analysis.py +++ b/weightwatcher/trap_analysis.py @@ -29,6 +29,10 @@ def analyze_traps( base_model=None, peft=wwcore.DEFAULT_PEFT, rng=None, + permuted_ids=None, + trap_state=None, + already_randomized=False, + return_artifacts=False, trap_burden=False, trap_burden_variant="top5", top_sector_l=1, @@ -79,6 +83,7 @@ def analyze_traps( params["trap_burden"] = bool(trap_burden) params["trap_burden_variant"] = trap_burden_variant params["top_sector_l"] = int(top_sector_l) + params["already_randomized"] = bool(already_randomized) wwcore.logger.debug("params {}".format(params)) if not watcher.valid_params(params): @@ -90,6 +95,7 @@ def analyze_traps( layer_iterator = watcher.make_layer_iterator(model=watcher.model, layers=layers, params=params, base_model=watcher.base_model) trap_rows = [] trap_component_rows = [] + state_layers = {} for ww_layer in layer_iterator: if not ww_layer.skipped and ww_layer.has_weights: @@ -99,8 +105,19 @@ def analyze_traps( watcher.apply_FFT(ww_layer, params) layer_params = dict(params) + if trap_state is not None and isinstance(trap_state, dict) and "permuted_ids" in trap_state: + layer_params["permuted_ids"] = trap_state.get("permuted_ids", {}) + elif permuted_ids is not None: + layer_params["permuted_ids"] = permuted_ids + layer_params["already_randomized"] = bool(already_randomized) layer_params["_keep_trap_matrix"] = bool(params.get(wwcore.PLOT, False)) - layer_rows = watcher.apply_analyze_traps(ww_layer, params=layer_params) + layer_params["return_artifacts"] = bool(return_artifacts) + result = watcher.apply_analyze_traps(ww_layer, params=layer_params) + if return_artifacts: + layer_rows, layer_state = result + state_layers[int(ww_layer.layer_id)] = layer_state + else: + layer_rows = result if layer_rows: if params.get(wwcore.PLOT, False): trap_infos = [] @@ -162,6 +179,15 @@ def analyze_traps( else: watcher.trap_component_summary = pd.DataFrame() + if return_artifacts: + out_state = trap_state.copy() if isinstance(trap_state, dict) else {} + out_state.setdefault("permuted_ids", permuted_ids if permuted_ids is not None else {}) + if trap_state is not None and isinstance(trap_state, dict) and "permuted_ids" in trap_state: + out_state["permuted_ids"] = trap_state.get("permuted_ids", {}) + out_state["layers"] = state_layers + out_state["details_rows"] = details.to_dict(orient="records") + out_state["already_randomized"] = bool(already_randomized) + return details, out_state return details diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index e29e40e..b24bb8f 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -3694,7 +3694,7 @@ def get_details(self): return self.details - def analyze_traps(self, model=None, layers=[], + def analyze_traps(self, model=None, randomized_model=None, permuted_ids=None, trap_state=None, return_artifacts=False, layers=[], min_evals=DEFAULT_MIN_EVALS, max_evals=DEFAULT_MAX_EVALS, min_size=None, max_size=None, max_N=DEFAULT_MAX_N, glorot_fix=False, @@ -3730,10 +3730,13 @@ def analyze_traps(self, model=None, layers=[], Passing the same seed/object makes trap detection reproducible across runs. """ + if randomized_model is not None and model is not None: + raise ValueError("Pass either model or randomized_model, not both") + from . import trap_analysis return trap_analysis.analyze_traps( self, - model=model, + model=randomized_model if randomized_model is not None else model, layers=layers, min_evals=min_evals, max_evals=max_evals, @@ -3754,6 +3757,10 @@ def analyze_traps(self, model=None, layers=[], base_model=base_model, peft=peft, rng=rng, + permuted_ids=permuted_ids, + trap_state=trap_state, + already_randomized=(randomized_model is not None), + return_artifacts=return_artifacts, trap_burden=trap_burden, trap_burden_variant=trap_burden_variant, top_sector_l=top_sector_l, @@ -3808,12 +3815,24 @@ def apply_analyze_traps(self, ww_layer, params=None): self.apply_esd(ww_layer, params) original_basis_cache = self.compute_original_basis_for_traps(ww_layer, params=params) - self.apply_permute_W(ww_layer, params) + if params.get("already_randomized", False): + pids_map = params.get("permuted_ids", {}) + if int(ww_layer.layer_id) in pids_map: + ww_layer.permute_ids = [np.asarray(pids_map[int(ww_layer.layer_id)]).astype(int)] + elif params.get("permuted_ids") is not None and int(ww_layer.layer_id) in params.get("permuted_ids", {}): + pids = np.asarray(params["permuted_ids"][int(ww_layer.layer_id)]).astype(int) + ww_layer.permute_ids = [pids] + ww_layer.Wmats = [ww_layer.Wmats[0].flatten()[pids].reshape(ww_layer.Wmats[0].shape)] + else: + self.apply_permute_W(ww_layer, params) self.apply_trap_mp_fit(ww_layer, params=params) - trap_mode_indices = self.identify_trap_mode_indices(ww_layer, params=params) + W_perm = ww_layer.Wmats[0].astype(float) + U_perm, S_perm, Vh_perm = svd_full(W_perm, method=params[SVD_METHOD]) + trap_mode_indices = remove_traps_ops.identify_trap_mode_indices(self, ww_layer, svals=S_perm) bulk_stats = self.compute_bulk_trap_reference_metrics(ww_layer, trap_mode_indices, params=params) trap_rows = [] + trap_artifacts = [] for trap_index, mode_index in enumerate(trap_mode_indices): trap_row = self.analyze_single_trap( ww_layer, @@ -3825,12 +3844,36 @@ def apply_analyze_traps(self, ww_layer, params=None): ) trap_row.update(bulk_stats) trap_rows.append(trap_row) + artifact = remove_traps_ops.analyze_single_trap(self, ww_layer, mode_index, U_perm=U_perm, S_perm=S_perm, Vh_perm=Vh_perm) + artifact["trap_index"] = trap_index + 1 + artifact["layer_id"] = int(ww_layer.layer_id) + artifact["permute_fingerprint"] = trap_row.get("permute_fingerprint") + trap_artifacts.append(artifact) if trap_rows: layer_trap_variance_burden = float(np.nansum([row.get("trap_variance_burden_old", np.nan) for row in trap_rows])) for row in trap_rows: row["layer_trap_variance_burden"] = layer_trap_variance_burden + layer_state = { + "layer_id": int(ww_layer.layer_id), + "name": ww_layer.name, + "longname": ww_layer.longname, + "permuted_ids": np.asarray(ww_layer.permute_ids[0]).copy() if len(ww_layer.permute_ids) > 0 else None, + "permute_fingerprint": trap_rows[0].get("permute_fingerprint") if trap_rows else None, + "W_perm_shape": tuple(W_perm.shape), + "U_perm": U_perm, + "S_perm": S_perm, + "Vh_perm": Vh_perm, + "trap_mode_indices": [int(i) for i in trap_mode_indices], + "artifacts": trap_artifacts, + "trap_rows": trap_rows, + "mp_bulk_max": float(getattr(ww_layer, "bulk_max", np.nan)), + "already_randomized": bool(params.get("already_randomized", False)), + } + self.apply_unpermute_W(ww_layer, params) + if params.get("return_artifacts", False): + return trap_rows, layer_state return trap_rows def compute_trap_delta(self, eval_perm, mp_bulk_max): @@ -4224,6 +4267,10 @@ def analyze_single_trap(self, ww_layer, trap_mode_index, original_basis_cache=No "trap_variance_burden_ipr": trap_variance_burden_ipr, "trap_variance_burden_top5": trap_variance_burden_top5, "trap_variance_burden": trap_variance_burden, + "B_absDelta_ipr_ovlamvar": trap_variance_burden_ipr, + "B_absDelta_logtop5_ovlamvar": float(spectral_excess_abs * log1p_top_5_lift * ov_lam_weighted_var) if np.all(np.isfinite([spectral_excess_abs, log1p_top_5_lift, ov_lam_weighted_var])) else np.nan, + "B_evalsq_logtop5_ovrank": trap_variance_burden_top5, + "B_old_pr359_paper": trap_variance_burden_old, }) for k, v in u_metrics.items(): @@ -5978,14 +6025,28 @@ def apply_remove_traps(self, ww_layer, trap_indices, params=None, seed=None, rng """Remove selected traps from one dense WWLayer and replace with matched random matrices.""" return remove_traps_ops.apply_remove_traps(self, ww_layer, trap_indices, params=params, seed=seed, rng=rng) - def remove_traps(self, model=None, layers=[], trap_indices=None, traps=None, seed=None, rng=None, pool=True, plot=True, + def randomize_model(self, model=None, layers=[], pool=True, start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, + base_model=None, peft=DEFAULT_PEFT, rng=None, return_state=False): + """Randomize model weights using reversible permutations and return permutation ids.""" + return remove_traps_ops.randomize_model( + self, model=model, layers=layers, pool=pool, start_ids=start_ids, + svd_method=svd_method, base_model=base_model, peft=peft, rng=rng, return_state=return_state + ) + + def remove_traps(self, model=None, randomized_model=None, layers=[], trap_indices=None, traps=None, seed=None, rng=None, pool=True, plot=True, verify_traps=False, return_analyze=False, start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, - base_model=None, peft=DEFAULT_PEFT): + base_model=None, peft=DEFAULT_PEFT, trap_artifacts=None, trap_state=None): """Remove selected randomized MP/TW traps from dense layers.""" + if randomized_model is not None and model is not None: + raise ValueError("Pass either model or randomized_model, not both") + if trap_state is not None and randomized_model is None: + raise ValueError("trap_state-based remove_traps requires randomized_model") + active_model = randomized_model if randomized_model is not None else model return remove_traps_ops.remove_traps( - self, model=model, layers=layers, trap_indices=trap_indices, traps=traps, seed=seed, rng=rng, + self, model=active_model, layers=layers, trap_indices=trap_indices, traps=traps, seed=seed, rng=rng, pool=pool, plot=plot, verify_traps=verify_traps, return_analyze=return_analyze, - start_ids=start_ids, svd_method=svd_method, base_model=base_model, peft=peft + start_ids=start_ids, svd_method=svd_method, base_model=base_model, peft=peft, + trap_artifacts=trap_artifacts, trap_state=trap_state, already_randomized=(randomized_model is not None) )