Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tests/test_bulk_mode_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,25 @@ def test_invalid_bulk_id_errors():
_, state = watcher.analyze_traps(randomized_model=randomized_model, trap_state=state, return_artifacts=True, return_bulk_ids=True, plot=False)
with pytest.raises(ValueError):
watcher.remove_modes(mode_ids_by_layer={999:[1]}, mode_type='bulk', randomized_model=randomized_model, trap_state=state, plot=False)


def test_bulk_only_has_mp_edges_and_nonempty():
    """A bulk_only run must produce bulk-typed rows and finite MP edges per layer."""
    net = OneLayer()
    watcher = ww.WeightWatcher(model=net)
    rand_net, trap_state = watcher.randomize_model(model=net, rng=123, return_state=True)

    df, result_state = watcher.analyze_traps(
        randomized_model=rand_net,
        trap_state=trap_state,
        return_artifacts=True,
        return_bulk_ids=True,
        bulk_only=True,
        max_bulk_modes_per_layer=5,
        bulk_sampling_seed=123,
        plot=False,
    )

    # The details frame must be non-empty and contain nothing but bulk rows.
    assert len(df) > 0
    assert set(df["mode_type"]) == {"bulk"}

    # Every layer's state must record finite MP bulk edges and at least one bulk index.
    for layer_state in result_state["layers"].values():
        assert np.isfinite(layer_state["mp_bulk_min"])
        assert np.isfinite(layer_state["mp_bulk_max"])
        assert len(layer_state["bulk_svd_indices"]) > 0
87 changes: 76 additions & 11 deletions weightwatcher/trap_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,73 @@ def _sample_bulk_modes(svd_indices, eigenvalues, max_modes=None, seed=None, stra
return sorted(int(x) for x in out)
raise ValueError("bulk_sampling_strategy must be one of: all, uniform, stratified")


def _build_trap_bulk_rows(layer_state, layer_rows, return_bulk_ids=False, bulk_only=False, trap_only=False, max_bulk_modes_per_layer=None, bulk_sampling_seed=None, bulk_sampling_strategy='all'):
def _extract_mp_bulk_edges(layer_result=None, details=None, bulk_stats=None):
candidates_min = ["mp_bulk_min", "bulk_min", "lambda_min", "mp_lambda_min", "lambda_minus", "xmin"]
candidates_max = ["mp_bulk_max", "bulk_max", "lambda_max", "mp_lambda_max", "lambda_plus", "xmax"]

def read(obj, keys):
if obj is None:
return None
if isinstance(obj, dict):
for k in keys:
if k in obj and obj[k] is not None:
return obj[k]
else:
for k in keys:
if hasattr(obj, k):
v = getattr(obj, k)
if v is not None:
return v
return None

mp_min = read(bulk_stats, candidates_min)
if mp_min is None:
mp_min = read(details, candidates_min)
if mp_min is None:
mp_min = read(layer_result, candidates_min)
mp_max = read(bulk_stats, candidates_max)
if mp_max is None:
mp_max = read(details, candidates_max)
if mp_max is None:
mp_max = read(layer_result, candidates_max)
mp_min = 0.0 if mp_min is None else float(mp_min)
mp_max = np.nan if mp_max is None else float(mp_max)
return mp_min, mp_max


def _build_trap_bulk_rows(layer_state, layer_rows, return_bulk_ids=False, bulk_only=False, trap_only=False, max_bulk_modes_per_layer=None, bulk_sampling_seed=None, bulk_sampling_strategy='all', allow_bulk_without_mp_edges=False):
trap_svd = [int(i) for i in layer_state.get('trap_mode_indices_0based', [])]
S = np.asarray(layer_state.get('S_perm', []), dtype=float)
evals = S*S
mp_max=float(layer_state.get('bulk_stats',{}).get('mp_bulk_max', np.nan))
mp_min=float(layer_state.get('bulk_stats',{}).get('mp_bulk_min', 0.0))
bulk_stats = layer_state.get('bulk_stats', {}) or {}
mp_min, mp_max = _extract_mp_bulk_edges(layer_result=layer_state, details=layer_rows[0] if layer_rows else None, bulk_stats=bulk_stats)
if not np.isfinite(mp_max):
if allow_bulk_without_mp_edges:
wwcore.logger.warning("Missing MP bulk upper edge for layer_id=%s; falling back to non-trap finite modes.", layer_state.get("layer_id"))
else:
raise ValueError("Cannot build bulk IDs because MP bulk upper edge is missing. Pass allow_bulk_without_mp_edges=True to fall back.")
bulk_stats["mp_bulk_min"] = float(mp_min)
bulk_stats["mp_bulk_max"] = float(mp_max) if np.isfinite(mp_max) else np.nan
layer_state["bulk_stats"] = bulk_stats
layer_state["mp_bulk_min"] = float(mp_min)
layer_state["mp_bulk_max"] = float(mp_max) if np.isfinite(mp_max) else np.nan
layer_state["evals"] = np.asarray(evals, dtype=float)
layer_state["singular_values"] = np.asarray(S, dtype=float)
trap_set=set(trap_svd)
inside=[i for i,e in enumerate(evals) if np.isfinite(mp_max) and e>=mp_min and e<=mp_max]
if np.isfinite(mp_max):
inside=[i for i,e in enumerate(evals) if np.isfinite(e) and e>=mp_min and e<=mp_max]
else:
inside=[i for i,e in enumerate(evals) if np.isfinite(e) and i not in trap_set]
bulk=[i for i in inside if i not in trap_set]
if return_bulk_ids and len(bulk) == 0:
raise ValueError(
f"No eligible bulk modes found for layer_id={layer_state.get('layer_id')} "
f"name={layer_state.get('name')} len_evals={len(evals)} mp_bulk_min={mp_min} "
f"mp_bulk_max={mp_max} min_eval={np.nanmin(evals) if len(evals) else np.nan} "
f"max_eval={np.nanmax(evals) if len(evals) else np.nan} n_traps={len(trap_svd)}. "
"This is an invalid state for bulk ID generation and indicates a broken MP fit "
"or a corrupted trap/bulk classification path."
)
bulk=_sample_bulk_modes(bulk, evals, max_bulk_modes_per_layer, bulk_sampling_seed, bulk_sampling_strategy)
layer_state['trap_svd_indices']=trap_svd
layer_state['bulk_svd_indices']=bulk
Expand All @@ -84,7 +141,7 @@ def _build_trap_bulk_rows(layer_state, layer_rows, return_bulk_ids=False, bulk_o
if return_bulk_ids:
for bi,svd_i in enumerate(bulk, start=1):
ev=float(evals[svd_i])
bulk_rows.append({'layer_id':int(layer_state['layer_id']),'name':layer_state.get('name'),'longname':layer_state.get('longname'),'mode_type':'bulk','ablation_type':'bulk','mode_id':bi,'trap_id':np.nan,'trap_index':np.nan,'bulk_id':bi,'bulk_index':bi,'is_trap':False,'is_bulk':True,'svd_mode_index':svd_i,'mode_index':svd_i,'singular_value':float(S[svd_i]),'eigenvalue':ev,'eval_perm':ev,'mp_lambda_min':mp_min,'mp_lambda_max':mp_max,'is_inside_mp_bulk':True,'is_above_mp_edge':False,'is_below_mp_edge':False})
bulk_rows.append({'layer_id':int(layer_state['layer_id']),'name':layer_state.get('name'),'longname':layer_state.get('longname'),'mode_type':'bulk','ablation_type':'bulk','mode_id':bi,'trap_id':np.nan,'trap_index':np.nan,'bulk_id':bi,'bulk_index':bi,'is_trap':False,'is_bulk':True,'svd_mode_index':svd_i,'mode_index':svd_i,'singular_value':float(S[svd_i]),'eigenvalue':ev,'eval_perm':ev,'mp_lambda_min':mp_min,'mp_lambda_max':mp_max,'mp_bulk_min':mp_min,'mp_bulk_max':mp_max,'bulk_quantile':float(bi)/float(max(1,len(bulk))),'is_inside_mp_bulk':True,'is_above_mp_edge':False,'is_below_mp_edge':False})
if bulk_only: return bulk_rows
if trap_only: return layer_rows
return layer_rows + bulk_rows
Expand Down Expand Up @@ -130,6 +187,7 @@ def analyze_traps(
max_bulk_modes_per_layer=None,
bulk_sampling_seed=None,
bulk_sampling_strategy="all",
allow_bulk_without_mp_edges=False,
):
"""Externalized implementation for WeightWatcher.analyze_traps()."""
if layers is None:
Expand Down Expand Up @@ -229,14 +287,19 @@ def analyze_traps(
layer_out = watcher.apply_analyze_traps(ww_layer, params=layer_params)
if return_artifacts:
layer_rows, layer_state = layer_out
if isinstance(layer_state, dict):
if layer_state.get("mp_bulk_max") is None and hasattr(ww_layer, "bulk_max"):
layer_state["mp_bulk_max"] = float(ww_layer.bulk_max) if ww_layer.bulk_max is not None else None
if layer_state.get("mp_bulk_min") is None and hasattr(ww_layer, "bulk_min"):
layer_state["mp_bulk_min"] = float(ww_layer.bulk_min) if ww_layer.bulk_min is not None else 0.0
if trap_state is None:
trap_state = {"already_randomized": bool(already_randomized), "permuted_ids": {}, "layers": {}}
trap_state.setdefault("layers", {})[int(ww_layer.layer_id)] = layer_state
trap_state.setdefault("permuted_ids", {})[int(ww_layer.layer_id)] = layer_state.get("permuted_ids")
else:
layer_rows = layer_out
if layer_rows or return_bulk_ids:
layer_rows = _build_trap_bulk_rows(layer_state if return_artifacts else {"layer_id": int(ww_layer.layer_id), "name": ww_layer.name, "longname": ww_layer.longname, "S_perm": np.array([]), "trap_mode_indices_0based": []}, layer_rows or [], return_bulk_ids=return_bulk_ids, bulk_only=bulk_only, trap_only=trap_only, max_bulk_modes_per_layer=max_bulk_modes_per_layer, bulk_sampling_seed=bulk_sampling_seed, bulk_sampling_strategy=bulk_sampling_strategy)
layer_rows = _build_trap_bulk_rows(layer_state if return_artifacts else {"layer_id": int(ww_layer.layer_id), "name": ww_layer.name, "longname": ww_layer.longname, "S_perm": np.array([]), "trap_mode_indices_0based": []}, layer_rows or [], return_bulk_ids=return_bulk_ids, bulk_only=bulk_only, trap_only=trap_only, max_bulk_modes_per_layer=max_bulk_modes_per_layer, bulk_sampling_seed=bulk_sampling_seed, bulk_sampling_strategy=bulk_sampling_strategy, allow_bulk_without_mp_edges=allow_bulk_without_mp_edges)
if params.get(wwcore.PLOT, False):
trap_infos = []
for row in layer_rows:
Expand Down Expand Up @@ -282,11 +345,13 @@ def analyze_traps(

if len(details) > 0:
if "perm_mode_index" in details.columns:
details["perm_mode_index_0based"] = details["perm_mode_index"].astype(int)
details["perm_mode_index"] = details["perm_mode_index_0based"].apply(remove_traps_ops._internal_trap_index_to_api)
mask = details["perm_mode_index"].notna()
details.loc[mask, "perm_mode_index_0based"] = details.loc[mask, "perm_mode_index"].astype(int)
details.loc[mask, "perm_mode_index"] = details.loc[mask, "perm_mode_index_0based"].apply(remove_traps_ops._internal_trap_index_to_api)
if "trap_mode_index" in details.columns:
details["trap_mode_index_0based"] = details["trap_mode_index"].astype(int)
details["trap_mode_index"] = details["trap_mode_index_0based"].apply(remove_traps_ops._internal_trap_index_to_api)
mask = details["trap_mode_index"].notna()
details.loc[mask, "trap_mode_index_0based"] = details.loc[mask, "trap_mode_index"].astype(int)
details.loc[mask, "trap_mode_index"] = details.loc[mask, "trap_mode_index_0based"].apply(remove_traps_ops._internal_trap_index_to_api)
lead_cols = ["layer_id", "name"]
details = details[lead_cols + [c for c in details.columns if c not in lead_cols]]

Expand Down
4 changes: 3 additions & 1 deletion weightwatcher/weightwatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3770,7 +3770,8 @@ def analyze_traps(self, model=None, layers=[],
trap_only=False,
max_bulk_modes_per_layer=None,
bulk_sampling_seed=None,
bulk_sampling_strategy="all"):
bulk_sampling_strategy="all",
allow_bulk_without_mp_edges=False):
"""Analyze randomized correlation traps and return one row per trap.

This method follows the randomized/permuted trap workflow:
Expand Down Expand Up @@ -3838,6 +3839,7 @@ def analyze_traps(self, model=None, layers=[],
max_bulk_modes_per_layer=max_bulk_modes_per_layer,
bulk_sampling_seed=bulk_sampling_seed,
bulk_sampling_strategy=bulk_sampling_strategy,
allow_bulk_without_mp_edges=allow_bulk_without_mp_edges,
)

def analyze_bulk_modes(self, bulk_ids_by_layer=None, layers=None, randomized_model=None, trap_state=None, **kwargs):
Expand Down