Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ To build the R docs, first run a script that lays out the package as needed

```{bash}
cd stochtree_repo
Rscript cran-bootstrap.R 1
cp _pkgdown.yml stochtree_cran/_pkgdown.yml
cp R_README.md stochtree_cran/README.md
Rscript cran-bootstrap.R 1 1 1
cd ..
```

Expand Down
25 changes: 9 additions & 16 deletions dev/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,18 @@

# Clone stochtree repo
git clone --recursive [email protected]:StochasticTree/stochtree.git stochtree_repo
cd stochtree_repo
git checkout documentation-overhaul
cd ..

# Setup python virtual environment and the stochtree python package
# Set up python virtual environment and dependencies
python -m venv venv
source venv/bin/activate
cd stochtree_repo
pip install --upgrade pip
pip install numpy scipy pytest pandas scikit-learn pybind11
pip install -r requirements.txt

# Install python package
cd stochtree_repo
pip install .
cd ..

# Install python dependencies for the doc site
pip install mkdocs-material
pip install mkdocstrings-python
pip install mkdocs-jupyter

# Build the C++ doxygen output
sed -i '' 's|^OUTPUT_DIRECTORY *=.*|OUTPUT_DIRECTORY = ../docs/cpp_docs/|' stochtree_repo/Doxyfile
sed -i '' 's|^GENERATE_XML *=.*|GENERATE_XML = NO|' stochtree_repo/Doxyfile
Expand All @@ -34,12 +28,11 @@ Rscript -e 'install.packages(c("remotes", "devtools", "roxygen2", "ggplot2", "la

# Build the R package doc site
cd stochtree_repo
Rscript cran-bootstrap.R
cp _pkgdown.yml stochtree_cran/_pkgdown.yml
cp R_README.md stochtree_cran/README.md
Rscript cran-bootstrap.R 1 1 1
cd ..
mkdir -p docs/R_docs
Rscript -e 'pkgdown::build_site_github_pages("stochtree_repo/stochtree_cran", dest_dir = "../../docs/R_docs", install = TRUE)'
mkdir -p docs/R_docs/pkgdown
Rscript -e 'pkgdown::build_site_github_pages("stochtree_repo/stochtree_cran", dest_dir = "../../docs/R_docs/pkgdown", install = TRUE)'
rm -rf stochtree_repo/stochtree_cran

# Copy Jupyter notebook demos over to docs directory
cp stochtree_repo/demo/notebooks/supervised_learning.ipynb docs/python_docs/demo/supervised_learning.ipynb
Expand Down
103 changes: 103 additions & 0 deletions docs/about.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Overview of Stochastic Tree Models

Stochastic tree models are a powerful addition to your modeling toolkit.
As with many machine learning methods, understanding these models in depth is an involved task.

There are many excellent published papers on stochastic tree models
(to name a few, the [original BART paper](https://projecteuclid.org/journals/annals-of-applied-statistics/volume-4/issue-1/BART-Bayesian-additive-regression-trees/10.1214/09-AOAS285.full),
[the XBART paper](https://www.tandfonline.com/doi/full/10.1080/01621459.2021.1942012),
and [the BCF paper](https://projecteuclid.org/journals/bayesian-analysis/volume-15/issue-3/Bayesian-Regression-Tree-Models-for-Causal-Inference--Regularization-Confounding/10.1214/19-BA1195.full)).
Here, we aim to build up an abbreviated intuition for these models from their conceptually-simple building blocks.

## Notation

We're going to introduce some notation to make these concepts precise.
In a traditional supervised learning setting, we hope to predict some **outcome** from **features** in a training dataset.
We'll call the outcome $y$ and the features $X$.
Our goal is to come up with a function $f$ that predicts the outcome $y$ as well as possible from $X$ alone.

## Decision Trees

[Decision tree learning](https://en.wikipedia.org/wiki/Decision_tree_learning) is a simple machine learning method that
constructs a function $f$ from a series of conditional statements. Consider the tree below.

```mermaid
stateDiagram-v2
state split_one <<choice>>
state split_two <<choice>>
split_one --> split_two: if x1 <= 1
split_one --> c : if x1 > 1
split_two --> a: if x2 <= -2
split_two --> b : if x2 > -2
```

We evaluate two conditional statments (`X[,1] > 1` and `X[,2] > -2`), arranged in a tree-like sequence of branches,
which determine whether the model predicts `a`, `b`, or `c`. We could similarly express this tree in math notation as

\begin{equation*}
f(X_i) = \begin{cases}
a & ; \;\;\; X_{i,1} \leq 1, \;\; X_{i,2} \leq -2\\
b & ; \;\;\; X_{i,1} \leq 1, \;\; X_{i,2} > -2\\
c & ; \;\;\; X_{i,1} > 1
\end{cases}
\end{equation*}

We won't belabor the discussion of trees as there are many good textbooks and online articles on the topic,
but we'll close by noting that training decision trees introduces a delicate balance between
[overfitting and underfitting](https://en.wikipedia.org/wiki/Overfitting).
Simple trees like the one above do not capture much complexity in a dataset and may potentially be underfit
while deep, complex trees are vulnerable to overfitting and tend to have high variance.

## Boosted Decision Tree Ensembles

One way to address the overfitting-underfitting tradeoff of decision trees is to build an "ensemble" of decision
trees, so that the function $f$ is defined by a sum of $k$ individual decision trees $g_i$

\begin{equation*}
f(X_i) = g_1(X_i) + \dots + g_k(X_i)
\end{equation*}

There are several ways to train an ensemble of decision trees (sometimes called "forests"), the most popular of which are [random forests](https://en.wikipedia.org/wiki/Random_forest) and
[gradient boosting](https://en.wikipedia.org/wiki/Gradient_boosting). Their main difference is that random forests train
all $m$ trees independently of one another, while boosting trains tree sequentially, so that tree $j$ depends on the result of training trees 1 through $j-1$.
Libraries like [xgboost](https://xgboost.readthedocs.io/en/stable/) and [LightGBM](https://lightgbm.readthedocs.io/en/latest/) are popular examples of boosted tree ensembles.

Tree ensembles often [outperform neural networks and other machine learning methods on tabular datasets](https://arxiv.org/abs/2207.08815),
but classic tree ensemble methods return a single estimated function $f$, without expressing uncertainty around its estimates.

## Stochastic Tree Ensembles

[Stochastic](https://en.wikipedia.org/wiki/Stochastic) tree ensembles differ from their classical counterparts in their use of randomness in learning a function.
Rather than returning a single "best" tree ensemble, stochastic tree ensembles return a range of tree ensembles that fit the data well.
Mechanically, it's useful to think of "sampling" -- rather than "fitting" -- a stochastic tree ensemble model.

Why is this useful? Suppose we've sampled $m$ forests. For each observation $i$, we obtain $m$ predictions: $[f_1(X_i), \dots, f_m(X_i)]$.
From this "dataset" of predictions, we can compute summary statistics, where a mean or median would give something akin to the predictions of an xgboost or lightgbm model,
and the $\alpha$ and $1-\alpha$ quantiles give a [credible interval](https://en.wikipedia.org/wiki/Credible_interval).

Rather than explain each of the models that `stochtree` supports in depth here, we provide a high-level overview, with pointers to the relevant literature.

### Supervised Learning

The [`bart`](R_docs/pkgdown/reference/bart.html) R function and the [`BARTModel`](python_docs/api/bart.md) python class are the primary interface for supervised
prediction tasks in `stochtree`. The primary references for these models are
[BART (Chipman, George, McCulloch 2010)](https://projecteuclid.org/journals/annals-of-applied-statistics/volume-4/issue-1/BART-Bayesian-additive-regression-trees/10.1214/09-AOAS285.full) and
[XBART (He and Hahn 2021)](https://www.tandfonline.com/doi/full/10.1080/01621459.2021.1942012).

In addition to the standard BART / XBART models, in which each tree's leaves return a constant prediction, `stochtree` also supports
arbitrary leaf regression on a user-provided basis (i.e. an expanded version of [Chipman et al 2002](https://link.springer.com/article/10.1023/A:1013916107446) and [Gramacy and Lee 2012](https://www.tandfonline.com/doi/abs/10.1198/016214508000000689)).

### Causal Inference

The [`bcf`](R_docs/pkgdown/reference/bcf.html) R function and the [`BCFModel`](python_docs/api/bcf.md) python class are the primary interface for causal effect
estimation in `stochtree`. The primary references for these models are
[BCF (Hahn, Murray, Carvalho 2021)](https://projecteuclid.org/journals/bayesian-analysis/volume-15/issue-3/Bayesian-Regression-Tree-Models-for-Causal-Inference--Regularization-Confounding/10.1214/19-BA1195.full) and
[XBCF (Krantsevich, He, Hahn 2022)](https://arxiv.org/abs/2209.06998).

### Additional Modeling Features

Both the BART and BCF interfaces in `stochtree` support the following extensions:

* Accelerated / "warm-start" sampling of forests (i.e. [He and Hahn 2021](https://www.tandfonline.com/doi/full/10.1080/01621459.2021.1942012))
* Forest-based heteroskedasticity (i.e. [Murray 2021](https://www.tandfonline.com/doi/abs/10.1080/01621459.2020.1813587))
* Additive random effects (i.e. [Gelman et al 2008](https://www.tandfonline.com/doi/abs/10.1198/106186008X287337))
5 changes: 2 additions & 3 deletions docs/cpp_docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
This page documents the data structures and interfaces that constitute the `stochtree` C++ core.
It may be useful to researchers building novel tree algorithms or users seeking a deeper understanding of the algorithms implemented in `stochtree`. This resource is split into:

1. Technical documentation of the design / computational aspects of the C++ core
1. [Tree API](tree.md): decision tree class which underpins the ensembles that `stochtree` samples
2. [Tracker API](tracking.md): temporary data structures that synchronize a training dataset and the current state of a decision tree ensemble for faster sampling
1. Technical documentation of the design / computational aspects of the C++ core (in progress!)
1. [Tracker API](tracking.md): temporary data structures that synchronize a training dataset and the current state of a decision tree ensemble for faster sampling
2. [Doxygen site](doxygen/index.html) with auto-generated documentation of C++ classes and functions
7 changes: 1 addition & 6 deletions docs/cpp_docs/tracking.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Forest Sampling Tracker API
# Forest Sampling Tracker

A truly minimalist tree ensemble library only needs

Expand All @@ -21,8 +21,3 @@ These operations both perform unnecessary computation which can be avoided with

1. A mapping from dataset row index to leaf node id for every tree in an ensemble (so that we can skip the tree traversal during prediction)
2. A mapping from leaf node id to dataset row indices every tree in an ensemble (so that we can skip the full pass through the training data at split evaluation)

## Forest Tracker

The `ForestTracker` class is a wrapper around several implementations of the mappings discussed above.

9 changes: 0 additions & 9 deletions docs/cpp_docs/tree.md

This file was deleted.

5 changes: 5 additions & 0 deletions docs/development/contributing.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Contributing

`stochtree` is hosted on [Github](https://github.com/StochasticTree/stochtree/).
Any feedback, requests, or bug reports can be submitted as [issues](https://github.com/StochasticTree/stochtree/issues).
Moreover, if you have ideas for how to improve stochtree, we welcome [pull requests](https://github.com/StochasticTree/stochtree/pulls) :)
1 change: 1 addition & 0 deletions docs/development/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

`stochtree` is in active development. Here, we detail some aspects of the development process

* [Contributing](contributing.md): how to get involved with stochtree, by contributing code, documentation, or helpful feedback
* [Roadmap](roadmap.md): timelines for new feature development and releases
28 changes: 17 additions & 11 deletions docs/getting-started.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
# Getting Started

`stochtree` is composed of a C++ "core" and R / Python interfaces to that core.
Details on installation and use are available below:
Below, we detail how to install the R / Python packages, or work directly with the C++ codebase.

## R Package

The R package can be installed from CRAN via

```
install.packages("stochtree")
```

The development version of `stochtree` can be installed from Github via

```
remotes::install_github("StochasticTree/stochtree", ref="r-dev")
```

## Python Package

Expand Down Expand Up @@ -39,7 +53,7 @@ Then install the package from github via pip
pip install git+https://github.com/StochasticTree/stochtree.git
```

(*Note*: if you'd also like to run `stochtree`'s notebook examples, you will also need jupyterlab, seaborn, and matplotlib)
(*Note*: if you'd like to run `stochtree`'s notebook examples, you will also need `jupyterlab`, `seaborn`, and `matplotlib`)

```{bash}
conda install matplotlib seaborn
Expand Down Expand Up @@ -70,20 +84,12 @@ Then install stochtree via
pip install git+https://github.com/StochasticTree/stochtree.git
```

As above, if you'd like to run the notebook examples in the `demo/` subfolder, you will also need jupyterlab, seaborn, and matplotlib and you will have to [clone the repo](#cloning-the-repository)
As above, if you'd like to run the notebook examples in the `demo/` subfolder, you will also need `jupyterlab`, `seaborn`, and `matplotlib` and you will have to [clone the repo](#cloning-the-repository)

```{bash}
pip install matplotlib seaborn jupyterlab
```

## R Package

The package can be installed in R via

```
remotes::install_github("StochasticTree/stochtree", ref="r-dev")
```

## C++ Core

While the C++ core links to both R and Python for a performant, high-level interface,
Expand Down
18 changes: 17 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# StochTree

`stochtree` (stochastic tree) is software for building stochastic tree ensembles (i.e. BART, XBART) for supervised learning and causal inference.
`stochtree` (short for "stochastic tree") is a software suite for flexible decision tree modeling.
It has two primary interfaces:

1. "High-level": robust implementations of many popular stochastic tree algorithms (BART, XBART, BCF, XBCF), with support for serialization and parallelism.
2. "Low-level": access to the "inner loop" of a stochastic forest sampler, allowing custom tree algorithm development in <50 lines of code.

The "core" of the software is written in C++, but it provides R and Python APIs.
The R package is [available on CRAN](https://cran.r-project.org/web/packages/stochtree/index.html) and the python package will soon be on PyPI.

## Table of Contents

* [Getting Started](getting-started.md): Details on how to install and use `stochtree`
* [About](about.md): Overview of the models supported by stochtree and pointers to further reading
* [R Package](R_docs/index.md): Complete documentation of the R package
* [Python Package](python_docs/index.md): Complete documentation of the python package
* [C++ Core API and Architecture](cpp_docs/index.md): Overview and documentation of the C++ codebase that supports stochtree
* [Development](development/index.md): Roadmap and how to contribute
10 changes: 8 additions & 2 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ theme:
nav:
- Home: index.md
- 'Getting Started': getting-started.md
- 'About StochTree': about.md
- 'R Package':
- 'Overview': R_docs/index.md
- 'Pkgdown Site': 'R_docs/pkgdown/index.html'
Expand All @@ -61,11 +62,11 @@ nav:
- 'BCF': python_docs/demo/causal_inference.ipynb
- 'C++ Core API and Architecture':
- cpp_docs/index.md
- 'Tree Data Structure': cpp_docs/tree.md
- 'Tracking Data Structure': cpp_docs/tracking.md
- 'Tracking Data Structures': cpp_docs/tracking.md
- 'C++ Doxygen Site': 'cpp_docs/doxygen/index.html'
- 'Development':
- 'Overview': development/index.md
- 'Contributing': development/contributing.md
- 'Roadmap': development/roadmap.md
extra:
social:
Expand All @@ -89,6 +90,11 @@ markdown_extensions:
- pymdownx.snippets
- pymdownx.arithmatex:
generic: true
- pymdownx.superfences:
custom_fences:
- name: mermaid
class: mermaid
format: !!python/name:pymdownx.superfences.fence_code_format
plugins:
- offline
- search
Expand Down