Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 239 additions & 0 deletions notebooks/analyze_traps_iid_gaussian_sanity.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Analyze Traps sanity notebook (single injected localized singular mode)\n",
"\n",
"This notebook builds a very simple controlled matrix experiment for `analyze_traps`.\n",
"\n",
"Plan:\n",
"1. Build a `200 x 400` iid Gaussian matrix (Marchenko-Pastur-like random bulk).\n",
"2. Replace one left/right singular vector pair with a highly localized pair and raise its singular value to 5x the largest original one.\n",
"3. Compute the expected metrics of the injected vectors directly (localization ratio + top-5 mass).\n",
"4. Randomize with `randomize_model`.\n",
"5. Identify the strongest localized detected trap.\n",
"6. Compare trap metrics from `analyze_traps` against the injected expectations.\n",
"7. Re-run `analyze_traps` with `plot=True` to exercise the plotting path.\n"
],
"id": "f97385d9"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import weightwatcher as ww\n"
],
"id": "7b10620d"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Seed every RNG this notebook draws from so that Restart & Run All reproduces the same matrix.\n",
"# `seed` is also passed to randomize_model as its `rng` argument in a later cell.\n",
"seed = 123\n",
"np.random.seed(seed)\n",
"# The trailing ';' keeps the returned torch.Generator repr out of the cell output.\n",
"torch.manual_seed(seed);\n"
],
"id": "040428ef"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Base matrix: 200 x 400 iid standard-normal entries drawn from the global NumPy RNG.\n",
"m, n = 200, 400\n",
"W0 = np.random.standard_normal(size=(m, n))\n",
"\n",
"# Thin SVD (full_matrices=False): S0 holds min(m, n) singular values.\n",
"U0, S0, Vh0 = np.linalg.svd(W0, full_matrices=False)\n",
"print('rank =', S0.shape[0])\n"
],
"id": "2ffb5e43"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _localized_unit_vector(size, support=5):\n",
"    \"\"\"Return a unit-norm vector of length `size` with equal weight on its first `support` entries.\"\"\"\n",
"    vec = np.zeros(size)\n",
"    vec[:support] = 1.0\n",
"    return vec / np.linalg.norm(vec)\n",
"\n",
"u_loc = _localized_unit_vector(m)\n",
"v_loc = _localized_unit_vector(n)\n",
"\n",
"# Overwrite mode k of the SVD factors with the localized pair and a singular value\n",
"# 5x the largest original one, then rebuild the matrix from the modified factors.\n",
"# The new pair is not re-orthogonalized against the untouched modes, so in general\n",
"# (S_new[k], u_loc, v_loc) is only approximately a singular triplet of W_injected.\n",
"k = 0\n",
"U_new, S_new, Vh_new = U0.copy(), S0.copy(), Vh0.copy()\n",
"U_new[:, k] = u_loc\n",
"Vh_new[k, :] = v_loc\n",
"S_new[k] = 5.0 * S0.max()\n",
"W_injected = U_new @ np.diag(S_new) @ Vh_new\n",
"\n",
"print('Injected mode index (0-based):', k)\n",
"print('Injected sigma:', S_new[k])\n"
],
"id": "38d9de01"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def top_k_mass(x, k=5):\n",
"    \"\"\"Return the sum of the `k` largest squared magnitudes of `x` as a Python float.\n",
"\n",
"    The sum is not divided by the total squared norm, so it is a mass fraction\n",
"    only when `x` has unit norm.\n",
"    \"\"\"\n",
"    squared = np.square(np.abs(x))\n",
"    largest_first = np.argsort(squared)[::-1]\n",
"    keep = largest_first[:k]\n",
"    return float(squared[keep].sum())\n",
"\n",
"def localization_ratio(x):\n",
"    \"\"\"Return the largest entry magnitude of `x` divided by `np.linalg.norm(x)`, as a Python float.\"\"\"\n",
"    peak = np.max(np.abs(x))\n",
"    scale = np.linalg.norm(x)\n",
"    return float(peak / scale)\n",
"\n",
"# Reference values computed directly from the injected vectors u_loc / v_loc.\n",
"expected_left_top5, expected_right_top5 = top_k_mass(u_loc, 5), top_k_mass(v_loc, 5)\n",
"expected_top5_avg = (expected_left_top5 + expected_right_top5) / 2.0\n",
"expected_left_loc, expected_right_loc = localization_ratio(u_loc), localization_ratio(v_loc)\n",
"\n",
"for label, value in (\n",
"    ('Expected left top-5 mass :', expected_left_top5),\n",
"    ('Expected right top-5 mass:', expected_right_top5),\n",
"    ('Expected mean top-5 mass :', expected_top5_avg),\n",
"    ('Expected left localization ratio :', expected_left_loc),\n",
"    ('Expected right localization ratio:', expected_right_loc),\n",
"):\n",
"    print(label, value)\n"
],
"id": "18eae44b"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class OneLayerMatrix(nn.Module):\n",
"    \"\"\"Single bias-free `nn.Linear` layer whose weight is a float32 copy of the matrix `W`.\"\"\"\n",
"\n",
"    def __init__(self, W):\n",
"        \"\"\"Build the layer from a 2-D matrix `W`, read as (out_features, in_features).\"\"\"\n",
"        super().__init__()\n",
"        self.fc = nn.Linear(W.shape[1], W.shape[0], bias=False)\n",
"        # Overwrite the default initialization in place; no_grad keeps the copy out of autograd.\n",
"        with torch.no_grad():\n",
"            self.fc.weight.copy_(torch.tensor(W, dtype=torch.float32))\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply the linear layer to `x`.\"\"\"\n",
"        return self.fc(x)\n",
"\n",
"# Wrap the injected matrix in the one-layer model and build a WeightWatcher instance for it.\n",
"model = OneLayerMatrix(W_injected)\n",
"watcher = ww.WeightWatcher(model=model)\n"
],
"id": "8654319c"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Randomize the model and keep the returned state; both are handed to analyze_traps later.\n",
"randomized_model, trap_state = watcher.randomize_model(\n",
"    model=model,\n",
"    rng=seed,\n",
"    return_state=True,\n",
")\n",
"permuted_layer_ids = sorted(trap_state['permuted_ids'].keys())\n",
"print('Randomized model ready. permuted layer_ids =', permuted_layer_ids)\n"
],
"id": "0790789f"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# First pass without plots; with return_artifacts=True the call is unpacked into two values.\n",
"trap_df, trap_state_out = watcher.analyze_traps(\n",
"    randomized_model=randomized_model,\n",
"    trap_state=trap_state,\n",
"    return_artifacts=True,\n",
"    plot=False,\n",
")\n",
"print('Number of detected traps:', len(trap_df))\n",
"\n",
"# Show the ten traps with the highest top-5 mass, restricted to the columns of interest.\n",
"show_cols = [\n",
"    'layer_id',\n",
"    'trap_index',\n",
"    'perm_mode_index',\n",
"    'left_top_mass',\n",
"    'right_top_mass',\n",
"    'top_5_mass',\n",
"    'sigma_perm',\n",
"]\n",
"top_traps = trap_df[show_cols].sort_values('top_5_mass', ascending=False).head(10)\n",
"top_traps\n"
],
"id": "d1e95282"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# We may see several traps from random fluctuations; choose the strongest localized trap.\n",
"# Guard first: .iloc[0] on an empty frame would raise an opaque IndexError.\n",
"assert len(trap_df) > 0, 'analyze_traps returned no traps; cannot select the injected mode'\n",
"row = trap_df.sort_values('top_5_mass', ascending=False).iloc[0]\n",
"print('Selected trap (max top_5_mass):')\n",
"print(row[['layer_id','trap_index','perm_mode_index','left_top_mass','right_top_mass','top_5_mass']])\n"
],
"id": "73c66dbf"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Thresholds shared by the sanity checks at the bottom of this cell.\n",
"MIN_TOP5_MASS = 0.95      # observed top-5 mass must exceed this\n",
"MAX_TOP5_ABS_ERR = 0.05   # |observed - expected| top-5 mass must stay below this\n",
"\n",
"# Metrics reported by analyze_traps for the selected trap.\n",
"observed_left_top5 = float(row['left_top_mass'])\n",
"observed_right_top5 = float(row['right_top_mass'])\n",
"observed_top5_avg = float(row['top_5_mass'])\n",
"\n",
"# Compute each error once so the printed value is exactly the value that is asserted on.\n",
"err_left = abs(observed_left_top5 - expected_left_top5)\n",
"err_right = abs(observed_right_top5 - expected_right_top5)\n",
"err_avg = abs(observed_top5_avg - expected_top5_avg)\n",
"\n",
"print('Observed left top-5 mass :', observed_left_top5)\n",
"print('Observed right top-5 mass:', observed_right_top5)\n",
"print('Observed mean top-5 mass :', observed_top5_avg)\n",
"print('Absolute error left :', err_left)\n",
"print('Absolute error right:', err_right)\n",
"print('Absolute error mean :', err_avg)\n",
"\n",
"# Strong agreement checks for the injected localized mode.\n",
"assert observed_left_top5 > MIN_TOP5_MASS, f'left top-5 mass {observed_left_top5:.4f} is not above {MIN_TOP5_MASS}'\n",
"assert observed_right_top5 > MIN_TOP5_MASS, f'right top-5 mass {observed_right_top5:.4f} is not above {MIN_TOP5_MASS}'\n",
"assert observed_top5_avg > MIN_TOP5_MASS, f'mean top-5 mass {observed_top5_avg:.4f} is not above {MIN_TOP5_MASS}'\n",
"\n",
"# The selected trap should be close to our injected values.\n",
"assert err_left < MAX_TOP5_ABS_ERR, f'left top-5 mass error {err_left:.4f} is not below {MAX_TOP5_ABS_ERR}'\n",
"assert err_right < MAX_TOP5_ABS_ERR, f'right top-5 mass error {err_right:.4f} is not below {MAX_TOP5_ABS_ERR}'\n",
"assert err_avg < MAX_TOP5_ABS_ERR, f'mean top-5 mass error {err_avg:.4f} is not below {MAX_TOP5_ABS_ERR}'\n"
],
"id": "9ca14646"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Second pass over the same randomized model and state, this time with plotting enabled.\n",
"plot_kwargs = dict(\n",
"    randomized_model=randomized_model,\n",
"    trap_state=trap_state,\n",
"    return_artifacts=False,\n",
"    plot=True,\n",
")\n",
"_ = watcher.analyze_traps(**plot_kwargs)\n",
"print('Done: analyze_traps(plot=True) executed.')\n"
],
"id": "d471148f"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.x"
}
},
"nbformat": 4,
"nbformat_minor": 5
}