From 36dfc7d84a7e09c6e7df6f02baf80fed8a608f1f Mon Sep 17 00:00:00 2001 From: Charles Martin Date: Fri, 1 May 2026 23:08:55 -0700 Subject: [PATCH] Fix analyze_traps notebook assertions and add notebook cell ids --- .../analyze_traps_iid_gaussian_sanity.ipynb | 239 ++++++++++++++++++ 1 file changed, 239 insertions(+) create mode 100644 notebooks/analyze_traps_iid_gaussian_sanity.ipynb diff --git a/notebooks/analyze_traps_iid_gaussian_sanity.ipynb b/notebooks/analyze_traps_iid_gaussian_sanity.ipynb new file mode 100644 index 0000000..9f31514 --- /dev/null +++ b/notebooks/analyze_traps_iid_gaussian_sanity.ipynb @@ -0,0 +1,239 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Analyze Traps sanity notebook (single injected localized singular mode)\n", + "\n", + "This notebook builds a very simple controlled matrix experiment for `analyze_traps`.\n", + "\n", + "Plan:\n", + "1. Build a `200 x 400` iid Gaussian matrix (MP-like random bulk).\n", + "2. Replace the leading (index 0) left/right singular vector pair with a highly localized pair and set its singular value to 5x the largest original one.\n", + "3. Compute the expected metrics of the injected vectors directly (localization ratio + top-5 mass).\n", + "4. Randomize with `randomize_model`.\n", + "5. Identify the strongest localized detected trap (largest `top_5_mass`).\n", + "6. Compare trap metrics from `analyze_traps` against the injected expectations.\n", + "7. 
Re-run `analyze_traps` with `plot=True`.\n" + ], + "id": "f97385d9" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import weightwatcher as ww\n" + ], + "id": "7b10620d" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "seed = 123\n", + "np.random.seed(seed)\n", + "torch.manual_seed(seed)\n" + ], + "id": "040428ef" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m, n = 200, 400\n", + "W0 = np.random.randn(m, n)\n", + "U0, S0, Vh0 = np.linalg.svd(W0, full_matrices=False)\n", + "print('rank =', len(S0))\n" + ], + "id": "2ffb5e43" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "u_loc = np.zeros(m)\n", + "u_loc[:5] = 1.0\n", + "u_loc = u_loc / np.linalg.norm(u_loc)\n", + "\n", + "v_loc = np.zeros(n)\n", + "v_loc[:5] = 1.0\n", + "v_loc = v_loc / np.linalg.norm(v_loc)\n", + "\n", + "k = 0\n", + "S_new = S0.copy()\n", + "S_new[k] = S0.max() * 5.0\n", + "U_new = U0.copy()\n", + "Vh_new = Vh0.copy()\n", + "U_new[:, k] = u_loc\n", + "Vh_new[k, :] = v_loc\n", + "W_injected = U_new @ np.diag(S_new) @ Vh_new\n", + "\n", + "print('Injected mode index (0-based):', k)\n", + "print('Injected sigma:', S_new[k])\n" + ], + "id": "38d9de01" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def top_k_mass(x, k=5):\n", + " p = np.abs(x) ** 2\n", + " idx = np.argsort(p)[::-1][:k]\n", + " return float(p[idx].sum())\n", + "\n", + "def localization_ratio(x):\n", + " return float(np.max(np.abs(x)) / np.linalg.norm(x))\n", + "\n", + "expected_left_top5 = top_k_mass(u_loc, 5)\n", + "expected_right_top5 = top_k_mass(v_loc, 5)\n", + "expected_top5_avg = 0.5 * (expected_left_top5 + expected_right_top5)\n", + "expected_left_loc = 
localization_ratio(u_loc)\n", + "expected_right_loc = localization_ratio(v_loc)\n", + "\n", + "print('Expected left top-5 mass :', expected_left_top5)\n", + "print('Expected right top-5 mass:', expected_right_top5)\n", + "print('Expected mean top-5 mass :', expected_top5_avg)\n", + "print('Expected left localization ratio :', expected_left_loc)\n", + "print('Expected right localization ratio:', expected_right_loc)\n" + ], + "id": "18eae44b" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class OneLayerMatrix(nn.Module):\n", + " def __init__(self, W):\n", + " super().__init__()\n", + " self.fc = nn.Linear(W.shape[1], W.shape[0], bias=False)\n", + " with torch.no_grad():\n", + " self.fc.weight.copy_(torch.tensor(W, dtype=torch.float32))\n", + "\n", + " def forward(self, x):\n", + " return self.fc(x)\n", + "\n", + "model = OneLayerMatrix(W_injected)\n", + "watcher = ww.WeightWatcher(model=model)\n" + ], + "id": "8654319c" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "randomized_model, trap_state = watcher.randomize_model(model=model, rng=seed, return_state=True)\n", + "print('Randomized model ready. 
permuted layer_ids =', sorted(trap_state['permuted_ids'].keys()))\n" + ], + "id": "0790789f" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trap_df, trap_state_out = watcher.analyze_traps(\n", + " randomized_model=randomized_model,\n", + " trap_state=trap_state,\n", + " return_artifacts=True,\n", + " plot=False,\n", + ")\n", + "\n", + "print('Number of detected traps:', len(trap_df))\n", + "show_cols = ['layer_id','trap_index','perm_mode_index','left_top_mass','right_top_mass','top_5_mass','sigma_perm']\n", + "trap_df[show_cols].sort_values('top_5_mass', ascending=False).head(10)\n" + ], + "id": "d1e95282" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We may see several traps from random fluctuations; choose the strongest localized trap.\n", + "row = trap_df.sort_values('top_5_mass', ascending=False).iloc[0]\n", + "print('Selected trap (max top_5_mass):')\n", + "print(row[['layer_id','trap_index','perm_mode_index','left_top_mass','right_top_mass','top_5_mass']])\n" + ], + "id": "73c66dbf" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "observed_left_top5 = float(row['left_top_mass'])\n", + "observed_right_top5 = float(row['right_top_mass'])\n", + "observed_top5_avg = float(row['top_5_mass'])\n", + "\n", + "print('Observed left top-5 mass :', observed_left_top5)\n", + "print('Observed right top-5 mass:', observed_right_top5)\n", + "print('Observed mean top-5 mass :', observed_top5_avg)\n", + "print('Absolute error left :', abs(observed_left_top5 - expected_left_top5))\n", + "print('Absolute error right:', abs(observed_right_top5 - expected_right_top5))\n", + "print('Absolute error mean :', abs(observed_top5_avg - expected_top5_avg))\n", + "\n", + "# Strong agreement checks for the injected localized mode.\n", + "assert observed_left_top5 > 0.95\n", + "assert 
observed_right_top5 > 0.95\n", + "assert observed_top5_avg > 0.95\n", + "\n", + "# The selected trap should be close to our injected values.\n", + "assert abs(observed_left_top5 - expected_left_top5) < 0.05\n", + "assert abs(observed_right_top5 - expected_right_top5) < 0.05\n", + "assert abs(observed_top5_avg - expected_top5_avg) < 0.05\n" + ], + "id": "9ca14646" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_ = watcher.analyze_traps(\n", + " randomized_model=randomized_model,\n", + " trap_state=trap_state,\n", + " return_artifacts=False,\n", + " plot=True,\n", + ")\n", + "print('Done: analyze_traps(plot=True) executed.')\n" + ], + "id": "d471148f" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.x" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file