diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 207f098..cd3e51d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.7 + rev: v0.11.5 hooks: - id: ruff args: ["--fix"] diff --git a/examples/bayes_llama3/llama3/eval.py b/examples/bayes_llama3/llama3/eval.py index 139d1d2..5371f37 100644 --- a/examples/bayes_llama3/llama3/eval.py +++ b/examples/bayes_llama3/llama3/eval.py @@ -43,9 +43,9 @@ def __init__(self, config: FrozenConfigDict): config["pretrained_model_name_or_path"] ) else: - assert os.path.isdir( - config["checkpoints_folder"] - ), "Provided checkpoints is not a path to a folder" + assert os.path.isdir(config["checkpoints_folder"]), ( + "Provided checkpoints is not a path to a folder" + ) checkpoints = [ os.path.join(config["checkpoints_folder"], path) for path in os.listdir(config["checkpoints_folder"]) diff --git a/examples/continual_regression.ipynb b/examples/continual_regression.ipynb index 5e95560..ad605af 100644 --- a/examples/continual_regression.ipynb +++ b/examples/continual_regression.ipynb @@ -62,7 +62,14 @@ "outputs": [], "source": [ "episode_x_boundaries = torch.linspace(0, n_episodes, n_episodes + 1)\n", - "xs = torch.stack([torch.linspace(episode_x_boundaries[i], episode_x_boundaries[i + 1], samps_per_episode) for i in range(n_episodes)])\n", + "xs = torch.stack(\n", + " [\n", + " torch.linspace(\n", + " episode_x_boundaries[i], episode_x_boundaries[i + 1], samps_per_episode\n", + " )\n", + " for i in range(n_episodes)\n", + " ]\n", + ")\n", "ys = torch.stack([true_f(x) + y_sd * torch.randn_like(x) for x in xs])" ] }, @@ -85,18 +92,20 @@ "source": [ "plt_linsp = torch.linspace(-1, episode_x_boundaries[-1] + 1, 1000)\n", "\n", + "\n", "def plot_data(ax, up_to_episode=None):\n", " if up_to_episode is None:\n", " up_to_episode = n_episodes\n", - " \n", - " ax.plot(xs.flatten(), ys.flatten(), 'o', color='gray', alpha=0.2)\n", + "\n", + " ax.plot(xs.flatten(), ys.flatten(), \"o\", color=\"gray\", alpha=0.2)\n", " for i in range(up_to_episode):\n", - " ax.plot(xs[i], ys[i], 'o', color='orange')\n", - " \n", + " ax.plot(xs[i], ys[i], \"o\", color=\"orange\")\n", + "\n", " for v in episode_x_boundaries:\n", - " ax.axvline(v, color='gray', linestyle='--', alpha=0.75)\n", - " ax.plot(plt_linsp, true_f(plt_linsp), color='green', zorder=10)\n", - " ax.set_ylim(-2., 2.5)\n", + " ax.axvline(v, color=\"gray\", linestyle=\"--\", alpha=0.75)\n", + " ax.plot(plt_linsp, true_f(plt_linsp), color=\"green\", zorder=10)\n", + " ax.set_ylim(-2.0, 2.5)\n", + "\n", "\n", "fig, ax = plt.subplots()\n", "plot_data(ax)" @@ -166,11 +175,21 @@ "outputs": [], "source": [ "def log_prior(p, prior_mean, prior_sd: float):\n", - " all_vals = tree_map(lambda p, m, sd: torch.distributions.Normal(m, sd, validate_args=False).log_prob(p).sum(), p, prior_mean, prior_sd)\n", + " all_vals = tree_map(\n", + " lambda p, m, sd: torch.distributions.Normal(m, sd, validate_args=False)\n", + " .log_prob(p)\n", + " .sum(),\n", + " p,\n", + " prior_mean,\n", + " prior_sd,\n", + " )\n", " return tree_reduce(torch.add, all_vals)\n", - " \n", + "\n", + "\n", "def log_likelihood(y_pred, y):\n", - " return torch.distributions.Normal(y_pred, y_sd, validate_args=False).log_prob(y).mean()" + " return (\n", + " torch.distributions.Normal(y_pred, y_sd, validate_args=False).log_prob(y).mean()\n", + " )" ] }, { @@ -182,7 +201,10 @@ "def log_posterior(params, batch, prior_mean, prior_sd):\n", " x, y = batch\n", " y_pred = mlp_functional(params, x)\n", - " log_post = log_likelihood(y_pred, y) + log_prior(params, prior_mean, prior_sd) / samps_per_episode\n", + " log_post = (\n", + " log_likelihood(y_pred, y)\n", + " + log_prior(params, prior_mean, prior_sd) / samps_per_episode\n", + " )\n", " return log_post, y_pred" ] }, @@ -213,7 +235,13 @@ "outputs": [], "source": [ "batch_size = 3\n", - "dataloaders = [torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x.unsqueeze(-1), y.unsqueeze(-1)), batch_size=batch_size) for x, y in zip(xs, ys)]" + "dataloaders = [\n", + " torch.utils.data.DataLoader(\n", + " torch.utils.data.TensorDataset(x.unsqueeze(-1), y.unsqueeze(-1)),\n", + " batch_size=batch_size,\n", + " )\n", + " for x, y in zip(xs, ys)\n", + "]" ] }, { @@ -227,7 +255,9 @@ " for _ in range(n_epochs):\n", " for batch in dataloader:\n", " opt.zero_grad()\n", - " loss = -log_posterior(dict(mlp.named_parameters()), batch, prior_mean, prior_sd)[0]\n", + " loss = -log_posterior(\n", + " dict(mlp.named_parameters()), batch, prior_mean, prior_sd\n", + " )[0]\n", " loss.backward()\n", " opt.step()" ] @@ -252,10 +282,10 @@ "metadata": {}, "outputs": [], "source": [ - "def plot_predictions(params, ax, x, sd=y_sd, alpha=1.):\n", + "def plot_predictions(params, ax, x, sd=y_sd, alpha=1.0):\n", " preds = mlp_functional(params, x.unsqueeze(-1)).detach().numpy().squeeze()\n", - " ax.plot(x, preds, color='blue', alpha=alpha)\n", - " ax.fill_between(x, preds - sd, preds + sd, color='blue', alpha=0.2)" + " ax.plot(x, preds, color=\"blue\", alpha=alpha)\n", + " ax.fill_between(x, preds - sd, preds + sd, color=\"blue\", alpha=0.2)" ] }, { @@ -275,12 +305,14 @@ } ], "source": [ - "fig, axes = plt.subplots(1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True)\n", + "fig, axes = plt.subplots(\n", + " 1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True\n", + ")\n", "\n", "for i, ax in enumerate(axes):\n", - " plot_data(ax, up_to_episode=i+1)\n", + " plot_data(ax, up_to_episode=i + 1)\n", " plot_predictions(trained_params[i], ax, plt_linsp)\n", - " ax.set_title(f\"After Episode {i+1}\")" + " ax.set_title(f\"After Episode {i + 1}\")" ] }, { @@ -318,7 +350,13 @@ "def train_for_vi(dataloader, prior_mean, prior_sd, n_epochs=200, init_log_sds=None):\n", " seq_log_post = partial(log_posterior, prior_mean=prior_mean, prior_sd=prior_sd)\n", " optimizer = torchopt.adam(lr=2e-3)\n", - " transform = posteriors.vi.diag.build(seq_log_post, optimizer, temperature=1/samps_per_episode, init_log_sds=init_log_sds, stl=False)\n", + " transform = posteriors.vi.diag.build(\n", + " seq_log_post,\n", + " optimizer,\n", + " temperature=1 / samps_per_episode,\n", + " init_log_sds=init_log_sds,\n", + " stl=False,\n", + " )\n", " state = transform.init(dict(mlp.named_parameters()))\n", " nelbos = []\n", " for _ in range(n_epochs):\n", @@ -346,9 +384,18 @@ "nelbos = []\n", "for i in range(n_episodes):\n", " seq_prior_mean = prior_mean if i == 0 else vi_states[i - 1].params\n", - " seq_prior_sd = prior_sd if i == 0 else tree_map(lambda lsd: torch.sqrt(torch.exp(lsd) ** 2 + transition_sd ** 2), vi_states[i - 1].log_sd_diag)\n", - " seq_log_sds = tree_map(lambda x: torch.zeros_like(x) - 6., mlp.state_dict())\n", - " state, nelbos_i = train_for_vi(dataloaders[i], seq_prior_mean, seq_prior_sd, init_log_sds=seq_log_sds)\n", + " seq_prior_sd = (\n", + " prior_sd\n", + " if i == 0\n", + " else tree_map(\n", + " lambda lsd: torch.sqrt(torch.exp(lsd) ** 2 + transition_sd**2),\n", + " vi_states[i - 1].log_sd_diag,\n", + " )\n", + " )\n", + " seq_log_sds = tree_map(lambda x: torch.zeros_like(x) - 6.0, mlp.state_dict())\n", + " state, nelbos_i = train_for_vi(\n", + " dataloaders[i], seq_prior_mean, seq_prior_sd, init_log_sds=seq_log_sds\n", + " )\n", " vi_states += [state]\n", " nelbos += [nelbos_i]\n", " mlp.load_state_dict(vi_states[i].params)" @@ -371,7 +418,9 @@ } ], "source": [ - "fig, axes = plt.subplots(1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True)\n", + "fig, axes = plt.subplots(\n", + " 1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True\n", + ")\n", "\n", "n_samples = 20\n", "\n", @@ -379,8 +428,8 @@ " for _ in range(n_samples):\n", " sample = posteriors.vi.diag.sample(vi_states[i])\n", " plot_predictions(sample, ax, plt_linsp, sd=y_sd, alpha=0.2)\n", - " plot_data(ax, up_to_episode=i+1)\n", - " ax.set_title(f\"After Episode {i+1}\")" + " plot_data(ax, up_to_episode=i + 1)\n", + " ax.set_title(f\"After Episode {i + 1}\")" ] }, { diff --git a/examples/pyro_pima_indians_sghmc.ipynb b/examples/pyro_pima_indians_sghmc.ipynb index e701472..5a9499f 100644 --- a/examples/pyro_pima_indians_sghmc.ipynb +++ b/examples/pyro_pima_indians_sghmc.ipynb @@ -365,7 +365,7 @@ " samples[:, i] = torch.stack([state.params for state in states])\n", " if i > N_warmup:\n", " j = i - N_warmup\n", - " gelman_rubin[j] = pyro.ops.stats.gelman_rubin(log_posts[:, N_warmup:i + 1])\n" + " gelman_rubin[j] = pyro.ops.stats.gelman_rubin(log_posts[:, N_warmup : i + 1])" ] }, { @@ -469,7 +469,7 @@ "for ind, ax in enumerate(axes.flatten()):\n", " ax.hist(samples[:, N_warmup:, ind].flatten(), bins=50, density=True)\n", " ax.set_title(column_names[ind])\n", - "fig.tight_layout()\n" + "fig.tight_layout()" ] }, { diff --git a/pyproject.toml b/pyproject.toml index 3feacd4..db9ce1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "posteriors" -version = "0.1.0" +version = "0.1.1" description = "Uncertainty quantification with PyTorch" readme = "README.md" requires-python =">=3.9" diff --git a/tests/test_utils.py b/tests/test_utils.py index 0da3902..d039f83 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -110,7 +110,8 @@ def test_model_to_function(): func_output2 = func_lm(dict(lm.named_parameters()), input_ids, attention_mask) - assert type(output) == type(func_output1) == type(func_output2) + assert type(output) is type(func_output1) + assert type(output) is type(func_output2) assert torch.allclose(output["logits"], func_output1["logits"]) assert torch.allclose(output["logits"], func_output2["logits"])