Merged
353 changes: 353 additions & 0 deletions demo/notebooks/sklearn_wrappers.ipynb
@@ -0,0 +1,353 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Scikit-Learn Estimator Wrappers for BART\n",
"\n",
"`stochtree.BARTModel` is fundamentally a Bayesian interface: users specify a prior, provide data, sample from the posterior, and manage and inspect the resulting posterior samples.\n",
"\n",
"However, the basic BART model \n",
"\n",
"$$y_i \\sim \\mathcal{N}\\left(f(X_i), \\sigma^2\\right)$$\n",
"\n",
"involves samples of a nonparametric function $f$ that estimates the expected value of $y$ given $X$. Averaging over these draws, the posterior mean $\\bar{f}$ alone may satisfy some supervised learning use cases. To serve this use case directly, we offer [scikit-learn-compatible estimator](https://scikit-learn.org/stable/developers/develop.html) wrappers around `BARTModel` that implement the familiar API of `sklearn` models.\n",
"\n",
"For continuous outcomes, the `stochtree.StochTreeBARTRegressor` class provides `fit`, `predict` and `score` methods.\n",
"\n",
"For binary outcomes (deployed via probit BART), the `stochtree.StochTreeBARTBinaryClassifier` class provides `fit`, `predict`, `predict_proba`, `decision_function`, and `score` methods.\n",
"\n",
"Users can fit multi-class classifiers by wrapping a [OneVsRestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html) around `StochTreeBARTBinaryClassifier`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We begin by loading the necessary libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.datasets import load_wine, load_breast_cancer\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.multiclass import OneVsRestClassifier\n",
"from stochtree import (\n",
" StochTreeBARTRegressor, \n",
" StochTreeBARTBinaryClassifier, \n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we seed a random number generator"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"random_seed = 1234\n",
"rng = np.random.default_rng(random_seed)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## BART Regression via `sklearn` Estimator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We simulate some simple regression data to demonstrate the continuous outcome use case"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n = 100\n",
"p = 10\n",
"X = rng.normal(size=(n, p))\n",
"y = X[:, 0] * 3 + rng.normal(size=n)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We fit a BART regression model by initializing a `StochTreeBARTRegressor` and calling its `fit()` method.\n",
"\n",
"Since `stochtree.BARTModel` is configured primarily through parameter dictionaries, any downstream parameters we wish to set are passed through in the same way. Here, we only specify the random seed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reg = StochTreeBARTRegressor(general_params={\"random_seed\": random_seed})\n",
"reg.fit(X, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can predict from this model and compare the (posterior mean) predictions to the true outcome"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred = reg.predict(X)\n",
"plt.scatter(pred, y)\n",
"plt.xlabel(\"Predicted\")\n",
"plt.ylabel(\"Actual\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also verify that the model is deterministic for a fixed seed by refitting with the same seed and comparing its predictions to those of the first model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reg2 = StochTreeBARTRegressor(general_params={\"random_seed\": random_seed})\n",
"reg2.fit(X, y)\n",
"pred2 = reg2.predict(X)\n",
"plt.scatter(pred, pred2)\n",
"plt.xlabel(\"First model\")\n",
"plt.ylabel(\"Second model\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cross-Validating a BART Model\n",
"\n",
"While the default hyperparameters of `stochtree.BARTModel` are designed to work well \"out of the box,\" we can use posterior mean prediction error to cross-validate the model's parameters.\n",
"\n",
"Below we use grid search to consider the effect of several BART parameters: \n",
"\n",
"1. Number of GFR iterations (`num_gfr`)\n",
"2. Number of MCMC iterations (`num_mcmc`)\n",
"3. `num_trees`, `alpha` and `beta` for the mean forest"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"param_grid = {\n",
" \"num_gfr\": [10, 40],\n",
" \"num_mcmc\": [0, 1000],\n",
" \"mean_forest_params\": [\n",
" {\"num_trees\": 50, \"alpha\": 0.95, \"beta\": 2.0},\n",
" {\"num_trees\": 100, \"alpha\": 0.90, \"beta\": 1.5},\n",
" {\"num_trees\": 200, \"alpha\": 0.85, \"beta\": 1.0},\n",
" ],\n",
"}\n",
"grid_search = GridSearchCV(\n",
" estimator=StochTreeBARTRegressor(),\n",
" param_grid=param_grid,\n",
" cv=5,\n",
" scoring=\"r2\",\n",
" n_jobs=-1,\n",
")\n",
"grid_search.fit(X, y)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cv_best_ind = np.argwhere(grid_search.cv_results_['rank_test_score'] == 1).item(0)\n",
"best_num_gfr = grid_search.cv_results_['param_num_gfr'][cv_best_ind].item(0)\n",
"best_num_mcmc = grid_search.cv_results_['param_num_mcmc'][cv_best_ind].item(0)\n",
"best_mean_forest_params = grid_search.cv_results_['param_mean_forest_params'][cv_best_ind]\n",
"best_num_trees = best_mean_forest_params['num_trees']\n",
"best_alpha = best_mean_forest_params['alpha']\n",
"best_beta = best_mean_forest_params['beta']\n",
"print_message = f\"\"\"\n",
"Hyperparameters chosen by grid search: \n",
" num_gfr: {best_num_gfr} \n",
" num_mcmc: {best_num_mcmc} \n",
" num_trees: {best_num_trees} \n",
" alpha: {best_alpha} \n",
" beta: {best_beta}\n",
"\"\"\"\n",
"print(print_message)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## BART Classification via `sklearn` Estimator\n",
"\n",
"Now, we demonstrate the same functionality with binary and categorical outcomes, which require working with the `StochTreeBARTBinaryClassifier` class (and a wrapper for multi-class outcomes).\n",
"\n",
"First, we load a dataset from `sklearn` with a binary outcome."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = load_breast_cancer()\n",
"X = dataset.data\n",
"y = dataset.target"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we fit a binary classification model as follows"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"clf = StochTreeBARTBinaryClassifier(general_params={\"random_seed\": random_seed})\n",
"clf.fit(X=X, y=y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to class predictions, we can compute and visualize the predicted probability of each class via `predict_proba()`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"probs = clf.predict_proba(X)\n",
"plt.hist(probs[:, 1], bins=30)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we load a multi-class classification dataset from `sklearn`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = load_wine()\n",
"X = dataset.data\n",
"y = dataset.target"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And fit a multi-class classification model by wrapping a `OneVsRestClassifier` around `StochTreeBARTBinaryClassifier`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"clf = OneVsRestClassifier(\n",
" StochTreeBARTBinaryClassifier(general_params={\"random_seed\": random_seed})\n",
")\n",
"clf.fit(X=X, y=y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we visualize the histogram of predicted probabilities for each outcome category."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, (ax1, ax2, ax3) = plt.subplots(3, 1)\n",
"fig.tight_layout(pad=3.0)\n",
"probs = clf.predict_proba(X)\n",
"ax1.hist(probs[y == 0, 0], bins=30)\n",
"ax1.set_title(\"Predicted Probabilities for Class 0\")\n",
"ax1.set_xlim(0, 1)\n",
"ax2.hist(probs[y == 1, 1], bins=30)\n",
"ax2.set_title(\"Predicted Probabilities for Class 1\")\n",
"ax2.set_xlim(0, 1)\n",
"ax3.hist(probs[y == 2, 2], bins=30)\n",
"ax3.set_title(\"Predicted Probabilities for Class 2\")\n",
"ax3.set_xlim(0, 1)\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
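A note on the grid-search cell in the notebook above: the manual lookup into `cv_results_` (finding the index where `rank_test_score == 1`) selects the same candidate as `GridSearchCV`'s built-in `best_index_` and `best_params_` attributes, since `best_index_` is defined as the argmin of the rank array. The sketch below illustrates this equivalence; `DecisionTreeRegressor` stands in for `StochTreeBARTRegressor` purely so the sketch runs without `stochtree` installed.

```python
# Equivalence sketch: manual cv_results_ lookup vs. GridSearchCV's
# best_index_ / best_params_. DecisionTreeRegressor is a stand-in
# estimator, used only so this snippet has no stochtree dependency.
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(1234)
X = rng.normal(size=(100, 3))
y = X[:, 0] * 3 + rng.normal(size=100)

grid_search = GridSearchCV(
    estimator=DecisionTreeRegressor(random_state=1234),
    param_grid={"max_depth": [2, 4], "min_samples_leaf": [1, 5]},
    cv=5,
    scoring="r2",
)
grid_search.fit(X, y)

# best_index_ is computed as rank_test_score.argmin(), so the manual
# "find the rank-1 entry" lookup picks the same candidate
manual_ind = int(np.argmin(grid_search.cv_results_["rank_test_score"]))
assert manual_ind == grid_search.best_index_
print(grid_search.best_params_)
```

When only the winning configuration is needed, `grid_search.best_params_` avoids the per-parameter extraction entirely; the notebook's manual approach is still useful when inspecting the full ranking.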
2 changes: 1 addition & 1 deletion stochtree/bart.py
@@ -468,7 +468,7 @@ def sample(
         # Raise a warning if the data have ties and only GFR is being run
         if (num_gfr > 0) and (num_burnin == 0) and (num_mcmc == 0):
             num_values, num_cov_orig = X_train.shape
-            max_grid_size = floor(num_values / cutpoint_grid_size)
+            max_grid_size = floor(num_values / cutpoint_grid_size) if num_values > cutpoint_grid_size else 1
             x_is_df = isinstance(X_train, pd.DataFrame)
             covs_warning_1 = []
             covs_warning_2 = []
2 changes: 1 addition & 1 deletion stochtree/bcf.py
@@ -650,7 +650,7 @@ def sample(
         # Raise a warning if the data have ties and only GFR is being run
         if (num_gfr > 0) and (num_burnin == 0) and (num_mcmc == 0):
             num_values, num_cov_orig = X_train.shape
-            max_grid_size = floor(num_values / cutpoint_grid_size)
+            max_grid_size = floor(num_values / cutpoint_grid_size) if num_values > cutpoint_grid_size else 1
             x_is_df = isinstance(X_train, pd.DataFrame)
             covs_warning_1 = []
             covs_warning_2 = []
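The one-line change applied identically in `bart.py` and `bcf.py` guards the grid-size computation against small training sets: without the guard, `floor(num_values / cutpoint_grid_size)` evaluates to 0 whenever the training set is no larger than the cutpoint grid. The helper below is illustrative only (not part of `stochtree`) and mirrors the patched expression:

```python
from math import floor

def max_grid_size(num_values: int, cutpoint_grid_size: int) -> int:
    # Illustrative helper mirroring the patched expression. Without the
    # guard, floor(num_values / cutpoint_grid_size) is 0 whenever
    # num_values <= cutpoint_grid_size, which would zero out the grid
    # used by the downstream ties warning.
    if num_values > cutpoint_grid_size:
        return floor(num_values / cutpoint_grid_size)
    return 1

assert max_grid_size(1000, 100) == 10  # large-sample behavior unchanged
assert max_grid_size(50, 100) == 1     # small-sample case now floors at 1
```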