Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions tests/test_remove_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,63 @@ def test_remove_traps_invalid_indices_warns_and_skips(monkeypatch, caplog):
assert len(post_artifacts) == 0


def test_remove_traps_can_remove_explicit_bulk_vector():
    """Replacing an explicitly chosen bulk vector must change W but roughly keep its norm.

    A 64x64 Gaussian matrix is used; the test first confirms that no trap
    artifacts are collected for it, then asks for bulk vector #1 (1-based)
    to be replaced and compares the layer weights before and after.
    """
    watcher = WeightWatcher(model=None)
    gaussian = np.random.default_rng(123).normal(0.0, 0.1, size=(64, 64))
    layer = make_ww_layer(gaussian)

    # For a pure random matrix there should be no traps and many bulk modes.
    detected_traps = watcher._collect_trap_artifacts(
        layer, params=make_test_params(), seed=17
    )
    assert len(detected_traps) == 0

    # Snapshot the weights so the in-place update can be compared afterwards.
    snapshot = layer.framework_layer._W.copy()
    watcher.apply_remove_traps(
        layer,
        trap_indices=[],
        bulk_indices=[1],
        params=make_test_params(),
        seed=17,
    )
    updated = layer.framework_layer._W

    # The weights must have changed ...
    assert not np.allclose(snapshot, updated)
    # ... while the Frobenius norm stays within 15% (relative) of the original.
    norm_before = np.linalg.norm(snapshot, "fro")
    norm_after = np.linalg.norm(updated, "fro")
    assert np.isclose(norm_after, norm_before, rtol=0.15)


def test_remove_traps_can_remove_random_bulk_vector():
    """Random bulk-vector removal is reproducible per seed and differs across seeds.

    Three layers are built from the same Gaussian matrix. Two are processed
    with seed 55 and must end up with matching weights; the third uses seed 56
    and must end up with different weights.
    """
    watcher = WeightWatcher(model=None)
    gaussian = np.random.default_rng(321).normal(0.0, 0.1, size=(64, 64))

    # Build every layer up front, then apply the removal with its own seed.
    same_seed_first, same_seed_second, other_seed = (
        make_ww_layer(gaussian) for _ in range(3)
    )
    runs = (
        (same_seed_first, 55),
        (same_seed_second, 55),
        (other_seed, 56),
    )
    for layer, run_seed in runs:
        watcher.apply_remove_traps(
            layer,
            trap_indices=[],
            num_random_bulk_vectors=1,
            params=make_test_params(),
            seed=run_seed,
        )

    reference = same_seed_first.framework_layer._W
    assert np.allclose(reference, same_seed_second.framework_layer._W)
    assert not np.allclose(reference, other_seed.framework_layer._W)


@pytest.mark.skipif(torch is None, reason="PyTorch not installed")
def test_trap_rng_consistency_analyze_vs_collect_single_and_multi_layer():
model = torch.nn.Sequential(
Expand Down
142 changes: 125 additions & 17 deletions weightwatcher/remove_traps.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,25 @@ def identify_trap_mode_indices(ww, ww_layer):
return trap_mode_indices.tolist()


def identify_bulk_mode_indices(ww, ww_layer):
    """Return the 0-based positions of singular modes at or below the bulk threshold.

    The squared singular values of ``ww_layer.Wmats[0]`` are compared against
    a threshold built from the layer's ``sigma_mp``, ``N``, ``M`` and
    ``W_scale`` attributes (NOTE(review): the variable names suggest a
    Marchenko-Pastur bulk edge with a Tracy-Widom correction -- confirm
    against the trap-detection counterpart).

    Args:
        ww: Unused here; kept so the signature parallels the other helpers.
        ww_layer: Layer exposing ``Wmats``, ``N``, ``M``, ``sigma_mp`` and
            ``W_scale``.

    Returns:
        A plain ``list`` of integer indices into the singular-value array
        returned by ``svd_full`` for which ``sval**2 <= threshold``.
    """
    _, singular_values, _ = svd_full(ww_layer.Wmats[0])
    squared_svals = singular_values * singular_values

    aspect_ratio = ww_layer.N / ww_layer.M
    num_small_dim = ww_layer.M
    mp_sigma = ww_layer.sigma_mp
    scale = ww_layer.W_scale

    inv_sqrt_q = 1 / np.sqrt(aspect_ratio)
    bulk_edge = (mp_sigma * (1 + inv_sqrt_q)) ** 2
    tw_term = inv_sqrt_q * np.power(bulk_edge, 2 / 3) * np.power(num_small_dim, -2 / 3)
    bulk_edge_with_tw = bulk_edge + np.sqrt(tw_term)
    # Divide by W_scale**2 so the threshold is comparable to the squared singular values.
    cutoff = bulk_edge_with_tw / (scale * scale)

    in_bulk = squared_svals <= cutoff
    return np.nonzero(in_bulk)[0].tolist()


def analyze_single_trap(ww, ww_layer, trap_mode_index):
W_perm = ww_layer.Wmats[0]
U_perm, S_perm, Vh_perm = svd_full(W_perm)
Expand Down Expand Up @@ -115,6 +134,34 @@ def collect_trap_artifacts(ww, ww_layer, params=None, seed=None, rng=None):
return artifacts


def collect_bulk_artifacts(ww, ww_layer, params=None, seed=None, rng=None):
    """Build one artifact dict per bulk mode of a normalized, permuted copy of the layer.

    The input layer is not modified: a copy is made, its first weight matrix
    is duplicated, and normalization, permutation and the trap MP fit are
    applied to that copy before bulk modes are identified.

    Args:
        ww: Object providing ``apply_normalize_Wmats`` and ``apply_permute_W``.
        ww_layer: Layer to analyze; only ``Wmats[0]`` is used.
        params: Parameter mapping; defaults to a copy of ``DEFAULT_PARAMS``.
        seed: Optional seed. If both ``seed`` and ``rng`` are ``None`` and
            ``params`` is a dict, ``params["seed"]`` is used when present.
        rng: Optional random generator, normalized via ``_normalize_trap_rng``.

    Returns:
        A list of dicts as produced by ``analyze_single_trap``, each extended
        with ``"bulk_index"`` (1-based position in the bulk-mode list) and
        ``"T_orig_raw"`` (``"T_orig_norm"`` divided by the copy's ``w_norm``).
    """
    params = DEFAULT_PARAMS.copy() if params is None else params

    # Fall back to the seed stored in params only when the caller gave neither rng nor seed.
    if rng is None and seed is None and isinstance(params, dict):
        seed = params.get("seed", None)
    rng = _normalize_trap_rng(rng=rng, seed=seed)

    # Work on a scratch copy so the caller's layer stays untouched.
    scratch = ww_layer.copy()
    scratch.Wmats = [ww_layer.Wmats[0].copy()]
    scratch.w_norm = 1.0
    scratch.permute_ids = []

    ww.apply_normalize_Wmats(scratch, params)
    ww.apply_permute_W(scratch, params, rng=rng)
    apply_trap_mp_fit(ww, scratch, params)

    def _describe(position, mode_index):
        # Annotate the per-mode analysis with its 1-based rank and the un-normalized matrix.
        record = analyze_single_trap(ww, scratch, mode_index)
        record["bulk_index"] = position
        record["T_orig_raw"] = record["T_orig_norm"] / scratch.w_norm
        return record

    bulk_modes = identify_bulk_mode_indices(ww, scratch)
    return [_describe(pos, mode) for pos, mode in enumerate(bulk_modes, start=1)]


def make_stat_matched_random_matrix(T, rng):
G = rng.standard_normal(T.shape)
G = G - np.mean(G)
Expand All @@ -126,18 +173,32 @@ def make_stat_matched_random_matrix(T, rng):
return np.mean(T) + np.std(T) * G


def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=None):
def apply_remove_traps(
ww,
ww_layer,
trap_indices=None,
bulk_indices=None,
num_random_bulk_vectors=0,
params=None,
seed=None,
rng=None,
):
if params is None:
params = DEFAULT_PARAMS.copy()
if trap_indices is None or len(trap_indices) == 0:
raise ValueError("trap_indices must be a non-empty list of 1-based indices")
if (trap_indices is None or len(trap_indices) == 0) and (bulk_indices is None or len(bulk_indices) == 0) and int(num_random_bulk_vectors) <= 0:
raise ValueError("Specify at least one of trap_indices, bulk_indices, or num_random_bulk_vectors > 0")

if ww_layer.the_type != LAYER_TYPE.DENSE or len(ww_layer.Wmats) != 1 or ww_layer.Wmats[0].ndim != 2:
raise NotImplementedError("remove_traps currently supports single 2D dense matrices only")

requested = sorted(set(trap_indices))
requested = sorted(set(trap_indices or []))
if any(idx < 1 for idx in requested):
raise ValueError("trap indices are 1-based and must be >= 1")
requested_bulk = sorted(set(bulk_indices or []))
if any(idx < 1 for idx in requested_bulk):
raise ValueError("bulk indices are 1-based and must be >= 1")
if int(num_random_bulk_vectors) < 0:
raise ValueError("num_random_bulk_vectors must be >= 0")

layer_seed = seed
if layer_seed is None and isinstance(params, dict):
Expand All @@ -151,29 +212,61 @@ def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=N
replacement_seed = None if layer_seed is None else layer_seed + 1
replacement_rng = np.random.default_rng(replacement_seed)

artifacts = collect_trap_artifacts(
trap_artifacts = collect_trap_artifacts(
ww,
ww_layer,
params=params,
seed=None if permute_rng is not None else permute_seed,
rng=permute_rng,
)
valid_indices = [idx for idx in requested if idx <= len(artifacts)]
valid_indices = [idx for idx in requested if idx <= len(trap_artifacts)]
if len(valid_indices) < len(requested):
logger.warning(
f"Skipping invalid trap indices {set(requested) - set(valid_indices)}; "
f"only {len(artifacts)} traps detected"
f"only {len(trap_artifacts)} traps detected"
)
if len(valid_indices) == 0:
logger.warning("No valid traps to remove for this layer; skipping")
return ww_layer
requested = valid_indices

bulk_artifacts = collect_bulk_artifacts(
ww,
ww_layer,
params=params,
seed=None if permute_rng is not None else permute_seed,
rng=permute_rng,
)
valid_bulk_indices = [idx for idx in requested_bulk if idx <= len(bulk_artifacts)]
if len(valid_bulk_indices) < len(requested_bulk):
logger.warning(
f"Skipping invalid bulk indices {set(requested_bulk) - set(valid_bulk_indices)}; "
f"only {len(bulk_artifacts)} bulk vectors available"
)
requested_bulk = valid_bulk_indices

random_bulk_count = int(num_random_bulk_vectors)
if random_bulk_count > 0:
available_bulk = [idx for idx in range(1, len(bulk_artifacts) + 1) if idx not in requested_bulk]
if len(available_bulk) == 0:
logger.warning("Requested random bulk vector removal, but no bulk vectors are available")
else:
if random_bulk_count > len(available_bulk):
logger.warning(
"Requested %d random bulk vectors but only %d are available; selecting all available",
random_bulk_count,
len(available_bulk),
)
random_bulk_count = len(available_bulk)
random_choices = replacement_rng.choice(available_bulk, size=random_bulk_count, replace=False).tolist()
requested_bulk = sorted(set(requested_bulk + random_choices))

if len(requested) == 0 and len(requested_bulk) == 0:
logger.warning("No valid trap or bulk vectors to remove for this layer; skipping")
return ww_layer

if params.get(PLOT, False):
max_sigma = max(float(a.get("sigma_perm", 0.0)) for a in artifacts) if len(artifacts) > 0 else 0.0
max_sigma = max(float(a.get("sigma_perm", 0.0)) for a in trap_artifacts) if len(trap_artifacts) > 0 else 0.0
trap_infos = []
for idx in requested:
artifact = artifacts[idx - 1]
artifact = trap_artifacts[idx - 1]
rel_sigma = float(artifact.get("sigma_perm", 0.0)) / (max_sigma + 1e-12)
if rel_sigma >= 0.8:
assessment = "localized_risky"
Expand All @@ -200,7 +293,12 @@ def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=N
old_W = ww_layer.Wmats[0]
new_W = old_W.copy()
for idx in requested:
T_orig_raw = artifacts[idx - 1]["T_orig_raw"]
T_orig_raw = trap_artifacts[idx - 1]["T_orig_raw"]
R_orig = make_stat_matched_random_matrix(T_orig_raw, replacement_rng)
new_W = new_W - T_orig_raw + R_orig

for idx in requested_bulk:
T_orig_raw = bulk_artifacts[idx - 1]["T_orig_raw"]
R_orig = make_stat_matched_random_matrix(T_orig_raw, replacement_rng)
new_W = new_W - T_orig_raw + R_orig

Expand All @@ -209,10 +307,11 @@ def apply_remove_traps(ww, ww_layer, trap_indices, params=None, seed=None, rng=N
return ww_layer


def remove_traps(ww, model=None, layers=[], trap_indices=None, seed=None, rng=None, pool=True, plot=True,
def remove_traps(ww, model=None, layers=[], trap_indices=None, bulk_indices=None, num_random_bulk_vectors=0,
seed=None, rng=None, pool=True, plot=True,
start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, base_model=None, peft=DEFAULT_PEFT):
if trap_indices is None or len(trap_indices) == 0:
raise ValueError("trap_indices must be provided and non-empty")
if (trap_indices is None or len(trap_indices) == 0) and (bulk_indices is None or len(bulk_indices) == 0) and int(num_random_bulk_vectors) <= 0:
raise ValueError("Specify at least one of trap_indices, bulk_indices, or num_random_bulk_vectors > 0")

ww.set_model_(model)
params = DEFAULT_PARAMS.copy()
Expand All @@ -232,6 +331,15 @@ def remove_traps(ww, model=None, layers=[], trap_indices=None, seed=None, rng=No
layer_iterator = ww.make_layer_iterator(model=ww.model, layers=layers, params=params, base_model=base_model)
for ww_layer in layer_iterator:
if not ww_layer.skipped and ww_layer.has_weights:
apply_remove_traps(ww, ww_layer, trap_indices=trap_indices, params=params, seed=seed, rng=params["rng"])
apply_remove_traps(
ww,
ww_layer,
trap_indices=trap_indices,
bulk_indices=bulk_indices,
num_random_bulk_vectors=num_random_bulk_vectors,
params=params,
seed=seed,
rng=params["rng"],
)

return model
23 changes: 17 additions & 6 deletions weightwatcher/weightwatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5654,15 +5654,26 @@ def _make_stat_matched_random_matrix(self, T, rng):
"""Build a random matrix with matched shape, mean, and variance."""
return remove_traps_ops.make_stat_matched_random_matrix(T, rng)

def apply_remove_traps(self, ww_layer, trap_indices, params=None, seed=None, rng=None):
"""Remove selected traps from one dense WWLayer and replace with matched random matrices."""
return remove_traps_ops.apply_remove_traps(self, ww_layer, trap_indices, params=params, seed=seed, rng=rng)
def apply_remove_traps(self, ww_layer, trap_indices=None, bulk_indices=None, num_random_bulk_vectors=0, params=None, seed=None, rng=None):
    """Remove selected trap and/or bulk rank-1 vectors from one dense WWLayer.

    Thin wrapper: every argument is forwarded unchanged to
    ``remove_traps_ops.apply_remove_traps`` with this watcher as its first
    argument, and that function's return value is returned as-is.
    """
    forwarded = {
        "trap_indices": trap_indices,
        "bulk_indices": bulk_indices,
        "num_random_bulk_vectors": num_random_bulk_vectors,
        "params": params,
        "seed": seed,
        "rng": rng,
    }
    return remove_traps_ops.apply_remove_traps(self, ww_layer, **forwarded)

def remove_traps(self, model=None, layers=[], trap_indices=None, seed=None, rng=None, pool=True, plot=True,
def remove_traps(self, model=None, layers=[], trap_indices=None, bulk_indices=None, num_random_bulk_vectors=0,
seed=None, rng=None, pool=True, plot=True,
start_ids=DEFAULT_START_ID, svd_method=FAST_SVD, base_model=None, peft=DEFAULT_PEFT):
"""Remove selected randomized MP/TW traps from dense layers."""
"""Remove selected randomized MP/TW traps and/or bulk vectors from dense layers."""
return remove_traps_ops.remove_traps(
self, model=model, layers=layers, trap_indices=trap_indices, seed=seed, rng=rng,
self, model=model, layers=layers, trap_indices=trap_indices, bulk_indices=bulk_indices,
num_random_bulk_vectors=num_random_bulk_vectors, seed=seed, rng=rng,
pool=pool, plot=plot, start_ids=start_ids, svd_method=svd_method, base_model=base_model, peft=peft
)

Expand Down