From 6bd8cd26c65c8dc43a14d4b224322cb424061ee1 Mon Sep 17 00:00:00 2001 From: Charles Martin Date: Tue, 28 Apr 2026 22:36:14 -0700 Subject: [PATCH] Allow remove_traps to replace selected/random bulk vectors --- tests/test_remove_traps.py | 57 +++++++++++++ weightwatcher/remove_traps.py | 142 +++++++++++++++++++++++++++++---- weightwatcher/weightwatcher.py | 23 ++++-- 3 files changed, 199 insertions(+), 23 deletions(-) diff --git a/tests/test_remove_traps.py b/tests/test_remove_traps.py index e17ee1d..257abcf 100644 --- a/tests/test_remove_traps.py +++ b/tests/test_remove_traps.py @@ -244,6 +244,63 @@ def test_remove_traps_invalid_indices_warns_and_skips(monkeypatch, caplog): assert len(post_artifacts) == 0 +def test_remove_traps_can_remove_explicit_bulk_vector(): + watcher = WeightWatcher(model=None) + rng = np.random.default_rng(123) + W = rng.normal(0.0, 0.1, size=(64, 64)) + ww_layer = make_ww_layer(W) + + # For a pure random matrix there should be no traps and many bulk modes. + trap_artifacts = watcher._collect_trap_artifacts(ww_layer, params=make_test_params(), seed=17) + assert len(trap_artifacts) == 0 + + W_before = ww_layer.framework_layer._W.copy() + watcher.apply_remove_traps( + ww_layer, + trap_indices=[], + bulk_indices=[1], + params=make_test_params(), + seed=17, + ) + W_after = ww_layer.framework_layer._W + assert not np.allclose(W_before, W_after) + assert np.isclose(np.linalg.norm(W_after, "fro"), np.linalg.norm(W_before, "fro"), rtol=0.15) + + +def test_remove_traps_can_remove_random_bulk_vector(): + watcher = WeightWatcher(model=None) + rng = np.random.default_rng(321) + W = rng.normal(0.0, 0.1, size=(64, 64)) + ww_a = make_ww_layer(W) + ww_b = make_ww_layer(W) + ww_c = make_ww_layer(W) + + watcher.apply_remove_traps( + ww_a, + trap_indices=[], + num_random_bulk_vectors=1, + params=make_test_params(), + seed=55, + ) + watcher.apply_remove_traps( + ww_b, + trap_indices=[], + num_random_bulk_vectors=1, + params=make_test_params(), + seed=55, + ) + watcher.apply_remove_traps( + ww_c, + trap_indices=[], + num_random_bulk_vectors=1, + params=make_test_params(), + seed=56, + ) + + assert np.allclose(ww_a.framework_layer._W, ww_b.framework_layer._W) + assert not np.allclose(ww_a.framework_layer._W, ww_c.framework_layer._W) + + @pytest.mark.skipif(torch is None, reason="PyTorch not installed") def test_trap_rng_consistency_analyze_vs_collect_single_and_multi_layer(): model = torch.nn.Sequential( diff --git a/weightwatcher/remove_traps.py b/weightwatcher/remove_traps.py index cba8ce0..b8d3e77 100644 --- a/weightwatcher/remove_traps.py +++ b/weightwatcher/remove_traps.py @@ -64,6 +64,25 @@ def identify_trap_mode_indices(ww, ww_layer): return trap_mode_indices.tolist() +def identify_bulk_mode_indices(ww, ww_layer): + W = ww_layer.Wmats[0] + _, svals, _ = svd_full(W) + evals_desc = svals * svals + + Q = ww_layer.N / ww_layer.M + M = ww_layer.M + sigma_mp = ww_layer.sigma_mp + Wscale = ww_layer.W_scale + + bulk_max = (sigma_mp * (1 + 1 / np.sqrt(Q))) ** 2 + TW = 1 / np.sqrt(Q) * np.power(bulk_max, 2 / 3) * np.power(M, -2 / 3) + bulk_max_TW = bulk_max + np.sqrt(TW) + threshold = bulk_max_TW / (Wscale * Wscale) + + bulk_mode_indices = np.where(evals_desc <= threshold)[0] + return bulk_mode_indices.tolist() + + def analyze_single_trap(ww, ww_layer, trap_mode_index): W_perm = ww_layer.Wmats[0] U_perm, S_perm, Vh_perm = svd_full(W_perm) @@ -115,6 +134,34 @@ def collect_trap_artifacts(ww, ww_layer, params=None, seed=None, rng=None): return artifacts +def collect_bulk_artifacts(ww, ww_layer, params=None, seed=None, rng=None): + if params is None: + params = DEFAULT_PARAMS.copy() + + if rng is None and seed is None and isinstance(params, dict): + seed = params.get("seed", None) + rng = _normalize_trap_rng(rng=rng, seed=seed) + + analysis_layer = ww_layer.copy() + analysis_layer.Wmats = [ww_layer.Wmats[0].copy()] + analysis_layer.w_norm = 1.0 + analysis_layer.permute_ids = [] + + ww.apply_normalize_Wmats(analysis_layer, params) + ww.apply_permute_W(analysis_layer, params, rng=rng) + apply_trap_mp_fit(ww, analysis_layer, params) + bulk_mode_indices = identify_bulk_mode_indices(ww, analysis_layer) + + artifacts = [] + for i, bulk_mode_index in enumerate(bulk_mode_indices, start=1): + artifact = analyze_single_trap(ww, analysis_layer, bulk_mode_index) + artifact["bulk_index"] = i + artifact["T_orig_raw"] = artifact["T_orig_norm"] / analysis_layer.w_norm + artifacts.append(artifact) + + return artifacts + + def make_stat_matched_random_matrix(T, rng): G = rng.standard_normal(T.shape) G = G - np.mean(G) @@ -126,18 +173,32 @@ def make_stat_matched_random_matrix(T, rng): return np.mean(T) + np.std(T) * G -def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=None): +def apply_remove_traps( + ww, + ww_layer, + trap_indices=None, + bulk_indices=None, + num_random_bulk_vectors=0, + params=None, + seed=None, + rng=None, +): if params is None: params = DEFAULT_PARAMS.copy() - if trap_indices is None or len(trap_indices) == 0: - raise ValueError("trap_indices must be a non-empty list of 1-based indices") + if (trap_indices is None or len(trap_indices) == 0) and (bulk_indices is None or len(bulk_indices) == 0) and int(num_random_bulk_vectors) <= 0: + raise ValueError("Specify at least one of trap_indices, bulk_indices, or num_random_bulk_vectors > 0") if ww_layer.the_type != LAYER_TYPE.DENSE or len(ww_layer.Wmats) != 1 or ww_layer.Wmats[0].ndim != 2: raise NotImplementedError("remove_traps currently supports single 2D dense matrices only") - requested = sorted(set(trap_indices)) + requested = sorted(set(trap_indices or [])) if any(idx < 1 for idx in requested): raise ValueError("trap indices are 1-based and must be >= 1") + requested_bulk = sorted(set(bulk_indices or [])) + if any(idx < 1 for idx in requested_bulk): + raise ValueError("bulk indices are 1-based and must be >= 1") + if int(num_random_bulk_vectors) < 0: + raise ValueError("num_random_bulk_vectors must be >= 0") layer_seed = seed if layer_seed is None and isinstance(params, dict): @@ -151,29 +212,61 @@ def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=N replacement_seed = None if layer_seed is None else layer_seed + 1 replacement_rng = np.random.default_rng(replacement_seed) - artifacts = collect_trap_artifacts( + trap_artifacts = collect_trap_artifacts( ww, ww_layer, params=params, seed=None if permute_rng is not None else permute_seed, rng=permute_rng, ) - valid_indices = [idx for idx in requested if idx <= len(artifacts)] + valid_indices = [idx for idx in requested if idx <= len(trap_artifacts)] if len(valid_indices) < len(requested): logger.warning( f"Skipping invalid trap indices {set(requested) - set(valid_indices)}; " - f"only {len(artifacts)} traps detected" + f"only {len(trap_artifacts)} traps detected" ) - if len(valid_indices) == 0: - logger.warning("No valid traps to remove for this layer; skipping") - return ww_layer requested = valid_indices + bulk_artifacts = collect_bulk_artifacts( + ww, + ww_layer, + params=params, + seed=None if permute_rng is not None else permute_seed, + rng=permute_rng, + ) + valid_bulk_indices = [idx for idx in requested_bulk if idx <= len(bulk_artifacts)] + if len(valid_bulk_indices) < len(requested_bulk): + logger.warning( + f"Skipping invalid bulk indices {set(requested_bulk) - set(valid_bulk_indices)}; " + f"only {len(bulk_artifacts)} bulk vectors available" + ) + requested_bulk = valid_bulk_indices + + random_bulk_count = int(num_random_bulk_vectors) + if random_bulk_count > 0: + available_bulk = [idx for idx in range(1, len(bulk_artifacts) + 1) if idx not in requested_bulk] + if len(available_bulk) == 0: + logger.warning("Requested random bulk vector removal, but no bulk vectors are available") + else: + if random_bulk_count > len(available_bulk): + logger.warning( + "Requested %d random bulk vectors but only %d are available; selecting all available", + random_bulk_count, + len(available_bulk), + ) + random_bulk_count = len(available_bulk) + random_choices = replacement_rng.choice(available_bulk, size=random_bulk_count, replace=False).tolist() + requested_bulk = sorted(set(requested_bulk + random_choices)) + + if len(requested) == 0 and len(requested_bulk) == 0: + logger.warning("No valid trap or bulk vectors to remove for this layer; skipping") + return ww_layer + if params.get(PLOT, False): - max_sigma = max(float(a.get("sigma_perm", 0.0)) for a in artifacts) if len(artifacts) > 0 else 0.0 + max_sigma = max(float(a.get("sigma_perm", 0.0)) for a in trap_artifacts) if len(trap_artifacts) > 0 else 0.0 trap_infos = [] for idx in requested: - artifact = artifacts[idx - 1] + artifact = trap_artifacts[idx - 1] rel_sigma = float(artifact.get("sigma_perm", 0.0)) / (max_sigma + 1e-12) if rel_sigma >= 0.8: assessment = "localized_risky" @@ -200,7 +293,12 @@ def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=N old_W = ww_layer.Wmats[0] new_W = old_W.copy() for idx in requested: - T_orig_raw = artifacts[idx - 1]["T_orig_raw"] + T_orig_raw = trap_artifacts[idx - 1]["T_orig_raw"] + R_orig = make_stat_matched_random_matrix(T_orig_raw, replacement_rng) + new_W = new_W - T_orig_raw + R_orig + + for idx in requested_bulk: + T_orig_raw = bulk_artifacts[idx - 1]["T_orig_raw"] R_orig = make_stat_matched_random_matrix(T_orig_raw, replacement_rng) new_W = new_W - T_orig_raw + R_orig @@ -209,10 +307,11 @@ def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=N return ww_layer -def remove_traps(ww, model=None, layers=[], trap_indices=None, seed=None, rng=None, pool=True, plot=True, +def remove_traps(ww, model=None, layers=[], trap_indices=None, bulk_indices=None, num_random_bulk_vectors=0, + seed=None, rng=None, pool=True, plot=True, start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, base_model=None, peft=DEFAULT_PEFT): - if trap_indices is None or len(trap_indices) == 0: - raise ValueError("trap_indices must be provided and non-empty") + if (trap_indices is None or len(trap_indices) == 0) and (bulk_indices is None or len(bulk_indices) == 0) and int(num_random_bulk_vectors) <= 0: + raise ValueError("Specify at least one of trap_indices, bulk_indices, or num_random_bulk_vectors > 0") ww.set_model_(model) params = DEFAULT_PARAMS.copy() @@ -232,6 +331,15 @@ def remove_traps(ww, model=None, layers=[], trap_indices=None, seed=None, rng=No layer_iterator = ww.make_layer_iterator(model=ww.model, layers=layers, params=params, base_model=base_model) for ww_layer in layer_iterator: if not ww_layer.skipped and ww_layer.has_weights: - apply_remove_traps(ww, ww_layer, trap_indices=trap_indices, params=params, seed=seed, rng=params["rng"]) + apply_remove_traps( + ww, + ww_layer, + trap_indices=trap_indices, + bulk_indices=bulk_indices, + num_random_bulk_vectors=num_random_bulk_vectors, + params=params, + seed=seed, + rng=params["rng"], + ) return model diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index 0da873b..6b8b63c 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -5654,15 +5654,26 @@ def _make_stat_matched_random_matrix(self, T, rng): """Build a random matrix with matched shape, mean, and variance.""" return remove_traps_ops.make_stat_matched_random_matrix(T, rng) - def apply_remove_traps(self, ww_layer, trap_indices, params=None, seed=None, rng=None): - """Remove selected traps from one dense WWLayer and replace with matched random matrices.""" - return remove_traps_ops.apply_remove_traps(self, ww_layer, trap_indices, params=params, seed=seed, rng=rng) + def apply_remove_traps(self, ww_layer, trap_indices=None, bulk_indices=None, num_random_bulk_vectors=0, params=None, seed=None, rng=None): + """Remove selected trap and/or bulk rank-1 vectors from one dense WWLayer and replace with matched random matrices.""" + return remove_traps_ops.apply_remove_traps( + self, + ww_layer, + trap_indices=trap_indices, + bulk_indices=bulk_indices, + num_random_bulk_vectors=num_random_bulk_vectors, + params=params, + seed=seed, + rng=rng, + ) - def remove_traps(self, model=None, layers=[], trap_indices=None, seed=None, rng=None, pool=True, plot=True, + def remove_traps(self, model=None, layers=[], trap_indices=None, bulk_indices=None, num_random_bulk_vectors=0, + seed=None, rng=None, pool=True, plot=True, start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, base_model=None, peft=DEFAULT_PEFT): - """Remove selected randomized MP/TW traps from dense layers.""" + """Remove selected randomized MP/TW traps and/or bulk vectors from dense layers.""" return remove_traps_ops.remove_traps( - self, model=model, layers=layers, trap_indices=trap_indices, seed=seed, rng=rng, + self, model=model, layers=layers, trap_indices=trap_indices, bulk_indices=bulk_indices, + num_random_bulk_vectors=num_random_bulk_vectors, seed=seed, rng=rng, pool=pool, plot=plot, start_ids=start_ids, svd_method=svd_method, base_model=base_model, peft=peft )