diff --git a/demo/notebooks/sklearn_wrappers.ipynb b/demo/notebooks/sklearn_wrappers.ipynb
new file mode 100644
index 00000000..3d636543
--- /dev/null
+++ b/demo/notebooks/sklearn_wrappers.ipynb
@@ -0,0 +1,353 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Scikit-Learn Estimator Wrappers for BART\n",
+    "\n",
+    "`stochtree.BARTModel` is fundamentally a Bayesian interface in which users specify a prior, provide data, sample from the posterior, and manage / inspect the resulting posterior samples.\n",
+    "\n",
+    "However, the basic BART model\n",
+    "\n",
+    "$$y_i \sim \mathcal{N}\left(f(X_i), \sigma^2\right)$$\n",
+    "\n",
+    "involves samples of a nonparametric function $f$ that estimates the expected value of $y$ given $X$. The posterior mean $\bar{f}$, obtained by averaging over these draws, may alone satisfy some supervised learning use cases. To serve this use case straightforwardly, we offer [scikit-learn-compatible estimator](https://scikit-learn.org/stable/developers/develop.html) wrappers around `BARTModel` that implement the familiar API of `sklearn` models.\n",
+    "\n",
+    "For continuous outcomes, the `stochtree.StochTreeBARTRegressor` class provides `fit`, `predict`, and `score` methods.\n",
+    "\n",
+    "For binary outcomes (implemented via probit BART), the `stochtree.StochTreeBARTBinaryClassifier` class provides `fit`, `predict`, `predict_proba`, `decision_function`, and `score` methods.\n",
+    "\n",
+    "Users can fit multi-class classifiers by wrapping a [OneVsRestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html) around `StochTreeBARTBinaryClassifier`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We begin by loading the necessary libraries."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "from sklearn.datasets import load_wine, load_breast_cancer\n",
+    "from sklearn.model_selection import GridSearchCV\n",
+    "from sklearn.multiclass import OneVsRestClassifier\n",
+    "from stochtree import (\n",
+    "    StochTreeBARTRegressor,\n",
+    "    StochTreeBARTBinaryClassifier,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next, we seed a random number generator."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "random_seed = 1234\n",
+    "rng = np.random.default_rng(random_seed)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## BART Regression via `sklearn` Estimator"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We simulate some simple regression data to demonstrate the continuous outcome use case."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "n = 100\n",
+    "p = 10\n",
+    "X = rng.normal(size=(n, p))\n",
+    "y = X[:, 0] * 3 + rng.normal(size=n)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We fit a BART regression model by initializing a `StochTreeBARTRegressor` and calling its `fit()` method.\n",
+    "\n",
+    "Since `stochtree.BARTModel` is configured primarily through parameter dictionaries, the wrapper forwards any parameters we wish to set as those same dictionaries. In this case, we only specify the random seed."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "reg = StochTreeBARTRegressor(general_params={\"random_seed\": random_seed})\n",
+    "reg.fit(X, y)"
+   ]
+  },
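+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Because the wrapper implements the `sklearn` estimator API, its `score()` method (mentioned above) should follow the standard `sklearn` regressor convention and report the $R^2$ of the (posterior mean) predictions. A minimal in-sample sanity check (in-sample fit is optimistic, but it confirms the API wiring):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# In-sample R^2 of the posterior mean predictions, assuming the\n",
+    "# standard sklearn regressor scoring convention\n",
+    "reg.score(X, y)"
+   ]
+  },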
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now, we can predict from this model and compare the (posterior mean) predictions to the true outcome."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pred = reg.predict(X)\n",
+    "plt.scatter(pred, y)\n",
+    "plt.xlabel(\"Predicted\")\n",
+    "plt.ylabel(\"Actual\")\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can also check that the model is deterministic for a fixed seed by refitting with the same seed and comparing the new predictions to those of the first model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "reg2 = StochTreeBARTRegressor(general_params={\"random_seed\": random_seed})\n",
+    "reg2.fit(X, y)\n",
+    "pred2 = reg2.predict(X)\n",
+    "plt.scatter(pred, pred2)\n",
+    "plt.xlabel(\"First model\")\n",
+    "plt.ylabel(\"Second model\")\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Cross-Validating a BART Model\n",
+    "\n",
+    "While the default hyperparameters of `stochtree.BARTModel` are designed to work well \"out of the box,\" we can cross-validate the model's parameters using the accuracy of its posterior mean predictions (here, measured via $R^2$).\n",
+    "\n",
+    "Below we use grid search to consider the effect of several BART parameters:\n",
+    "\n",
+    "1. Number of GFR iterations (`num_gfr`)\n",
+    "2. Number of MCMC iterations (`num_mcmc`)\n",
+    "3. `num_trees`, `alpha` and `beta` for the mean forest"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "param_grid = {\n",
+    "    \"num_gfr\": [10, 40],\n",
+    "    \"num_mcmc\": [0, 1000],\n",
+    "    \"mean_forest_params\": [\n",
+    "        {\"num_trees\": 50, \"alpha\": 0.95, \"beta\": 2.0},\n",
+    "        {\"num_trees\": 100, \"alpha\": 0.90, \"beta\": 1.5},\n",
+    "        {\"num_trees\": 200, \"alpha\": 0.85, \"beta\": 1.0},\n",
+    "    ],\n",
+    "}\n",
+    "grid_search = GridSearchCV(\n",
+    "    estimator=StochTreeBARTRegressor(),\n",
+    "    param_grid=param_grid,\n",
+    "    cv=5,\n",
+    "    scoring=\"r2\",\n",
+    "    n_jobs=-1,\n",
+    ")\n",
+    "grid_search.fit(X, y)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Locate the best-ranked parameter combination in cv_results_\n",
+    "cv_best_ind = np.argwhere(grid_search.cv_results_['rank_test_score'] == 1).item(0)\n",
+    "# Entries of the param_* masked arrays are the original Python objects,\n",
+    "# so no .item() call is needed (or valid) here\n",
+    "best_num_gfr = grid_search.cv_results_['param_num_gfr'][cv_best_ind]\n",
+    "best_num_mcmc = grid_search.cv_results_['param_num_mcmc'][cv_best_ind]\n",
+    "best_mean_forest_params = grid_search.cv_results_['param_mean_forest_params'][cv_best_ind]\n",
+    "best_num_trees = best_mean_forest_params['num_trees']\n",
+    "best_alpha = best_mean_forest_params['alpha']\n",
+    "best_beta = best_mean_forest_params['beta']\n",
+    "print_message = f\"\"\"\n",
+    "Hyperparameters chosen by grid search:\n",
+    "    num_gfr: {best_num_gfr}\n",
+    "    num_mcmc: {best_num_mcmc}\n",
+    "    num_trees: {best_num_trees}\n",
+    "    alpha: {best_alpha}\n",
+    "    beta: {best_beta}\n",
+    "\"\"\"\n",
+    "print(print_message)"
+   ]
+  },
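+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Since `GridSearchCV` refits an estimator on the full dataset with the winning parameters when `refit=True` (its default), we can also retrieve and reuse that fitted model directly. This is standard `sklearn` behavior rather than anything `stochtree`-specific:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# best_estimator_ was refit on the full data because refit=True by default\n",
+    "best_reg = grid_search.best_estimator_\n",
+    "best_reg.score(X, y)"
+   ]
+  },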
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## BART Classification via `sklearn` Estimator\n",
+    "\n",
+    "Now, we demonstrate the same functionality with binary and categorical outcomes, which require working with the `StochTreeBARTBinaryClassifier` class (and a wrapper for multi-class outcomes).\n",
+    "\n",
+    "First, we load a dataset from `sklearn` with a binary outcome."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = load_breast_cancer()\n",
+    "X = dataset.data\n",
+    "y = dataset.target"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We fit a binary classification model as follows."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "clf = StochTreeBARTBinaryClassifier(general_params={\"random_seed\": random_seed})\n",
+    "clf.fit(X=X, y=y)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In addition to class predictions, we can compute and visualize the predicted probability of each class via `predict_proba()`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "probs = clf.predict_proba(X)\n",
+    "plt.hist(probs[:, 1], bins=30)\n",
+    "plt.show()"
+   ]
+  },
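+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The classifier also exposes the `decision_function()` and `score()` methods listed in the introduction. Following the usual `sklearn` classifier conventions, `decision_function()` should return a real-valued score for each observation (larger values favoring the positive class) and `score()` should report classification accuracy. A minimal in-sample sketch:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Real-valued decision scores and in-sample accuracy,\n",
+    "# assuming standard sklearn classifier conventions\n",
+    "decision_scores = clf.decision_function(X)\n",
+    "print(decision_scores[:5])\n",
+    "print(clf.score(X, y))"
+   ]
+  },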
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now, we load a multi-class classification dataset from `sklearn`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = load_wine()\n",
+    "X = dataset.data\n",
+    "y = dataset.target"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We then fit a multi-class classification model by wrapping a `OneVsRestClassifier` around `StochTreeBARTBinaryClassifier`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "clf = OneVsRestClassifier(\n",
+    "    StochTreeBARTBinaryClassifier(general_params={\"random_seed\": random_seed})\n",
+    ")\n",
+    "clf.fit(X=X, y=y)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Finally, we visualize the histogram of predicted probabilities for each outcome category."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig, (ax1, ax2, ax3) = plt.subplots(3, 1)\n",
+    "fig.tight_layout(pad=3.0)\n",
+    "probs = clf.predict_proba(X)\n",
+    "ax1.hist(probs[y == 0, 0], bins=30)\n",
+    "ax1.set_title(\"Predicted Probabilities for Class 0\")\n",
+    "ax1.set_xlim(0, 1)\n",
+    "ax2.hist(probs[y == 1, 1], bins=30)\n",
+    "ax2.set_title(\"Predicted Probabilities for Class 1\")\n",
+    "ax2.set_xlim(0, 1)\n",
+    "ax3.hist(probs[y == 2, 2], bins=30)\n",
+    "ax3.set_title(\"Predicted Probabilities for Class 2\")\n",
+    "ax3.set_xlim(0, 1)\n",
+    "plt.show()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/stochtree/bart.py b/stochtree/bart.py
index cdd3a28c..17328f1e 100644
--- a/stochtree/bart.py
+++ b/stochtree/bart.py
@@ -468,7 +468,7 @@ def sample(
         # Raise a warning if the data have ties and only GFR is being run
         if (num_gfr > 0) and (num_burnin == 0) and (num_mcmc == 0):
             num_values, num_cov_orig = X_train.shape
-            max_grid_size = floor(num_values / cutpoint_grid_size)
+            max_grid_size = floor(num_values / cutpoint_grid_size) if num_values > cutpoint_grid_size else 1
             x_is_df = isinstance(X_train, pd.DataFrame)
             covs_warning_1 = []
             covs_warning_2 = []
diff --git a/stochtree/bcf.py b/stochtree/bcf.py
index b232a46d..627b1c07 100644
--- a/stochtree/bcf.py
+++ b/stochtree/bcf.py
@@ -650,7 +650,7 @@ def sample(
         # Raise a warning if the data have ties and only GFR is being run
         if (num_gfr > 0) and (num_burnin == 0) and (num_mcmc == 0):
             num_values, num_cov_orig = X_train.shape
-            max_grid_size = floor(num_values / cutpoint_grid_size)
+            max_grid_size = floor(num_values / cutpoint_grid_size) if num_values > cutpoint_grid_size else 1
             x_is_df = isinstance(X_train, pd.DataFrame)
             covs_warning_1 = []
             covs_warning_2 = []