diff --git a/README.md b/README.md index d3301dfcd..a880817ac 100644 --- a/README.md +++ b/README.md @@ -58,9 +58,7 @@ To build the R docs, first run a script that lays out the package as needed ```{bash} cd stochtree_repo -Rscript cran-bootstrap.R 1 -cp _pkgdown.yml stochtree_cran/_pkgdown.yml -cp R_README.md stochtree_cran/README.md +Rscript cran-bootstrap.R 1 1 1 cd .. ``` diff --git a/dev/build.sh b/dev/build.sh index a9f25007a..87a77f44c 100644 --- a/dev/build.sh +++ b/dev/build.sh @@ -2,24 +2,18 @@ # Clone stochtree repo git clone --recursive git@github.com:StochasticTree/stochtree.git stochtree_repo -cd stochtree_repo -git checkout documentation-overhaul -cd .. -# Setup python virtual environment and the stochtree python package +# Set up python virtual environment and dependencies python -m venv venv source venv/bin/activate -cd stochtree_repo pip install --upgrade pip -pip install numpy scipy pytest pandas scikit-learn pybind11 +pip install -r requirements.txt + +# Install python package +cd stochtree_repo pip install . cd .. -# Install python dependencies for the doc site -pip install mkdocs-material -pip install mkdocstrings-python -pip install mkdocs-jupyter - # Build the C++ doxygen output sed -i '' 's|^OUTPUT_DIRECTORY *=.*|OUTPUT_DIRECTORY = ../docs/cpp_docs/|' stochtree_repo/Doxyfile sed -i '' 's|^GENERATE_XML *=.*|GENERATE_XML = NO|' stochtree_repo/Doxyfile @@ -34,12 +28,11 @@ Rscript -e 'install.packages(c("remotes", "devtools", "roxygen2", "ggplot2", "la # Build the R package doc site cd stochtree_repo -Rscript cran-bootstrap.R -cp _pkgdown.yml stochtree_cran/_pkgdown.yml -cp R_README.md stochtree_cran/README.md +Rscript cran-bootstrap.R 1 1 1 cd .. 
-mkdir -p docs/R_docs -Rscript -e 'pkgdown::build_site_github_pages("stochtree_repo/stochtree_cran", dest_dir = "../../docs/R_docs", install = TRUE)' +mkdir -p docs/R_docs/pkgdown +Rscript -e 'pkgdown::build_site_github_pages("stochtree_repo/stochtree_cran", dest_dir = "../../docs/R_docs/pkgdown", install = TRUE)' +rm -rf stochtree_repo/stochtree_cran # Copy Jupyter notebook demos over to docs directory cp stochtree_repo/demo/notebooks/supervised_learning.ipynb docs/python_docs/demo/supervised_learning.ipynb diff --git a/docs/about.md b/docs/about.md new file mode 100644 index 000000000..6672deb7a --- /dev/null +++ b/docs/about.md @@ -0,0 +1,103 @@ +# Overview of Stochastic Tree Models + +Stochastic tree models are a powerful addition to your modeling toolkit. +As with many machine learning methods, understanding these models in depth is an involved task. + +There are many excellent published papers on stochastic tree models +(to name a few, the [original BART paper](https://projecteuclid.org/journals/annals-of-applied-statistics/volume-4/issue-1/BART-Bayesian-additive-regression-trees/10.1214/09-AOAS285.full), +[the XBART paper](https://www.tandfonline.com/doi/full/10.1080/01621459.2021.1942012), +and [the BCF paper](https://projecteuclid.org/journals/bayesian-analysis/volume-15/issue-3/Bayesian-Regression-Tree-Models-for-Causal-Inference--Regularization-Confounding/10.1214/19-BA1195.full)). +Here, we aim to build up an abbreviated intuition for these models from their conceptually simple building blocks. + +## Notation + +We're going to introduce some notation to make these concepts precise. +In a traditional supervised learning setting, we hope to predict some **outcome** from **features** in a training dataset. +We'll call the outcome $y$ and the features $X$. +Our goal is to come up with a function $f$ that predicts the outcome $y$ as well as possible from $X$ alone.
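To make this notation concrete, here is a tiny sketch in Python (the data and the deliberately naive candidate $f$ are invented for illustration; this is not `stochtree` code):

```python
import numpy as np

# Features X: n = 3 observations, 2 features; outcome y: one value per observation
X = np.array([[0.5, -3.0],
              [0.5, 1.0],
              [2.0, 0.0]])
y = np.array([1.0, 2.0, 3.0])

# A (deliberately naive) candidate prediction function f
def f(X):
    return X[:, 0] + X[:, 1]

# "As well as possible" is usually made precise via a loss, e.g. mean squared error
mse = np.mean((y - f(X)) ** 2)
print(mse)  # 4.5
```

A learning algorithm's job is to search over candidate functions $f$ to drive a loss like this down.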
+ +## Decision Trees + +[Decision tree learning](https://en.wikipedia.org/wiki/Decision_tree_learning) is a simple machine learning method that +constructs a function $f$ from a series of conditional statements. Consider the tree below. + +```mermaid +stateDiagram-v2 + state split_one <<choice>> + state split_two <<choice>> + split_one --> split_two: if x1 <= 1 + split_one --> c : if x1 > 1 + split_two --> a: if x2 <= -2 + split_two --> b : if x2 > -2 +``` + +We evaluate two conditional statements (`X[,1] > 1` and `X[,2] > -2`), arranged in a tree-like sequence of branches, +which determine whether the model predicts `a`, `b`, or `c`. We could similarly express this tree in math notation as + +\begin{equation*} +f(X_i) = \begin{cases} +a & ; \;\;\; X_{i,1} \leq 1, \;\; X_{i,2} \leq -2\\ +b & ; \;\;\; X_{i,1} \leq 1, \;\; X_{i,2} > -2\\ +c & ; \;\;\; X_{i,1} > 1 +\end{cases} +\end{equation*} + +We won't belabor the discussion of trees as there are many good textbooks and online articles on the topic, +but we'll close by noting that training decision trees introduces a delicate balance between +[overfitting and underfitting](https://en.wikipedia.org/wiki/Overfitting). +Simple trees like the one above do not capture much complexity in a dataset and may be underfit, +while deep, complex trees are vulnerable to overfitting and tend to have high variance. + +## Boosted Decision Tree Ensembles + +One way to address the overfitting-underfitting tradeoff of decision trees is to build an "ensemble" of decision +trees, so that the function $f$ is defined by a sum of $k$ individual decision trees $g_1, \dots, g_k$ + +\begin{equation*} +f(X_i) = g_1(X_i) + \dots + g_k(X_i) +\end{equation*} + +There are several ways to train an ensemble of decision trees (sometimes called "forests"), the most popular of which are [random forests](https://en.wikipedia.org/wiki/Random_forest) and +[gradient boosting](https://en.wikipedia.org/wiki/Gradient_boosting).
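The tree and the ensemble sum above translate directly into code. A minimal sketch (the leaf values `a`, `b`, `c` are arbitrary choices, and none of this is `stochtree`'s API):

```python
def tree_predict(x1, x2, a=-1.0, b=0.0, c=1.0):
    """Evaluate the two-split decision tree from the diagram above."""
    if x1 <= 1:
        return a if x2 <= -2 else b  # left branch: split again on x2
    return c                         # right branch: single leaf

# An ensemble prediction f is the sum of k individual trees g_1, ..., g_k;
# here a toy "ensemble" of three copies of the same tree
trees = [tree_predict] * 3

def f(x1, x2):
    return sum(g(x1, x2) for g in trees)

print(tree_predict(0.0, -3.0))  # -1.0 (x1 <= 1 and x2 <= -2, so leaf a)
print(f(2.0, 0.0))              # 3.0 (x1 > 1 in every tree, so c + c + c)
```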
Their main difference is that random forests train +all $k$ trees independently of one another, while boosting trains trees sequentially, so that tree $j$ depends on the result of training trees 1 through $j-1$. +Libraries like [xgboost](https://xgboost.readthedocs.io/en/stable/) and [LightGBM](https://lightgbm.readthedocs.io/en/latest/) are popular examples of boosted tree ensembles. + +Tree ensembles often [outperform neural networks and other machine learning methods on tabular datasets](https://arxiv.org/abs/2207.08815), +but classic tree ensemble methods return a single estimated function $f$, without expressing uncertainty around its estimates. + +## Stochastic Tree Ensembles + +[Stochastic](https://en.wikipedia.org/wiki/Stochastic) tree ensembles differ from their classical counterparts in their use of randomness in learning a function. +Rather than returning a single "best" tree ensemble, stochastic tree ensembles return a range of tree ensembles that fit the data well. +Mechanically, it's useful to think of "sampling" -- rather than "fitting" -- a stochastic tree ensemble model. + +Why is this useful? Suppose we've sampled $m$ forests. For each observation $i$, we obtain $m$ predictions: $[f_1(X_i), \dots, f_m(X_i)]$. +From this "dataset" of predictions, we can compute summary statistics, where a mean or median would give something akin to the predictions of an xgboost or LightGBM model, +and the $\alpha$ and $1-\alpha$ quantiles give a central $1-2\alpha$ [credible interval](https://en.wikipedia.org/wiki/Credible_interval). + +Rather than explain each of the models that `stochtree` supports in depth here, we provide a high-level overview, with pointers to the relevant literature. + +### Supervised Learning + +The [`bart`](R_docs/pkgdown/reference/bart.html) R function and the [`BARTModel`](python_docs/api/bart.md) python class are the primary interface for supervised +prediction tasks in `stochtree`.
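Whichever interface is used, post-processing the $m$ prediction draws follows the recipe sketched above: compute a point estimate and a credible interval from the "dataset" of predictions. A sketch with simulated draws standing in for actual sampler output (nothing here assumes `stochtree`'s output format):

```python
import numpy as np

rng = np.random.default_rng(seed=42)
# Stand-in for m = 1000 sampled forests' predictions at a single observation;
# real draws would come from a fitted model, not a normal distribution
preds = rng.normal(loc=2.0, scale=0.5, size=1000)

point_estimate = preds.mean()                    # akin to an xgboost/LightGBM prediction
lower, upper = np.quantile(preds, [0.05, 0.95])  # central 90% credible interval
```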
The primary references for these models are +[BART (Chipman, George, McCulloch 2010)](https://projecteuclid.org/journals/annals-of-applied-statistics/volume-4/issue-1/BART-Bayesian-additive-regression-trees/10.1214/09-AOAS285.full) and +[XBART (He and Hahn 2021)](https://www.tandfonline.com/doi/full/10.1080/01621459.2021.1942012). + +In addition to the standard BART / XBART models, in which each tree's leaves return a constant prediction, `stochtree` also supports +arbitrary leaf regression on a user-provided basis (i.e. an expanded version of [Chipman et al 2002](https://link.springer.com/article/10.1023/A:1013916107446) and [Gramacy and Lee 2012](https://www.tandfonline.com/doi/abs/10.1198/016214508000000689)). + +### Causal Inference + +The [`bcf`](R_docs/pkgdown/reference/bcf.html) R function and the [`BCFModel`](python_docs/api/bcf.md) python class are the primary interface for causal effect +estimation in `stochtree`. The primary references for these models are +[BCF (Hahn, Murray, Carvalho 2021)](https://projecteuclid.org/journals/bayesian-analysis/volume-15/issue-3/Bayesian-Regression-Tree-Models-for-Causal-Inference--Regularization-Confounding/10.1214/19-BA1195.full) and +[XBCF (Krantsevich, He, Hahn 2022)](https://arxiv.org/abs/2209.06998). + +### Additional Modeling Features + +Both the BART and BCF interfaces in `stochtree` support the following extensions: + +* Accelerated / "warm-start" sampling of forests (i.e. [He and Hahn 2021](https://www.tandfonline.com/doi/full/10.1080/01621459.2021.1942012)) +* Forest-based heteroskedasticity (i.e. [Murray 2021](https://www.tandfonline.com/doi/abs/10.1080/01621459.2020.1813587)) +* Additive random effects (i.e. 
[Gelman et al 2008](https://www.tandfonline.com/doi/abs/10.1198/106186008X287337)) diff --git a/docs/cpp_docs/index.md b/docs/cpp_docs/index.md index 4dcf64279..46794927e 100644 --- a/docs/cpp_docs/index.md +++ b/docs/cpp_docs/index.md @@ -3,7 +3,6 @@ This page documents the data structures and interfaces that constitute the `stochtree` C++ core. It may be useful to researchers building novel tree algorithms or users seeking a deeper understanding of the algorithms implemented in `stochtree`. This resource is split into: -1. Technical documentation of the design / computational aspects of the C++ core - 1. [Tree API](tree.md): decision tree class which underpins the ensembles that `stochtree` samples - 2. [Tracker API](tracking.md): temporary data structures that synchronize a training dataset and the current state of a decision tree ensemble for faster sampling +1. Technical documentation of the design / computational aspects of the C++ core (in progress!) + 1. [Tracker API](tracking.md): temporary data structures that synchronize a training dataset and the current state of a decision tree ensemble for faster sampling 2. [Doxygen site](doxygen/index.html) with auto-generated documentation of C++ classes and functions diff --git a/docs/cpp_docs/tracking.md b/docs/cpp_docs/tracking.md index ff54dd62d..8d0611abf 100644 --- a/docs/cpp_docs/tracking.md +++ b/docs/cpp_docs/tracking.md @@ -1,4 +1,4 @@ -# Forest Sampling Tracker API +# Forest Sampling Tracker A truly minimalist tree ensemble library only needs @@ -21,8 +21,3 @@ These operations both perform unnecessary computation which can be avoided with 1. A mapping from dataset row index to leaf node id for every tree in an ensemble (so that we can skip the tree traversal during prediction) 2. 
A mapping from leaf node id to dataset row indices for every tree in an ensemble (so that we can skip the full pass through the training data at split evaluation) - -## Forest Tracker - -The `ForestTracker` class is a wrapper around several implementations of the mappings discussed above. - diff --git a/docs/cpp_docs/tree.md b/docs/cpp_docs/tree.md deleted file mode 100644 index 116c6422a..000000000 --- a/docs/cpp_docs/tree.md +++ /dev/null @@ -1,9 +0,0 @@ -# Decision Tree API - -## Tree - -The fundamental building block of the C++ tree interface is the `Tree` class. - -## Tree Split - -Numeric and categorical splits are represented by a `TreeSplit` class. diff --git a/docs/development/contributing.md b/docs/development/contributing.md new file mode 100644 index 000000000..dff02d051 --- /dev/null +++ b/docs/development/contributing.md @@ -0,0 +1,5 @@ +# Contributing + +`stochtree` is hosted on [Github](https://github.com/StochasticTree/stochtree/). +Any feedback, requests, or bug reports can be submitted as [issues](https://github.com/StochasticTree/stochtree/issues). +Moreover, if you have ideas for how to improve stochtree, we welcome [pull requests](https://github.com/StochasticTree/stochtree/pulls) :) diff --git a/docs/development/index.md b/docs/development/index.md index 4243219c8..617379d84 100644 --- a/docs/development/index.md +++ b/docs/development/index.md @@ -2,4 +2,5 @@ `stochtree` is in active development. Here, we detail some aspects of the development process +* [Contributing](contributing.md): how to get involved with stochtree by contributing code, documentation, or helpful feedback * [Roadmap](roadmap.md): timelines for new feature development and releases diff --git a/docs/getting-started.md b/docs/getting-started.md index 7a06d5cf3..679a65fc7 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -1,7 +1,21 @@ # Getting Started `stochtree` is composed of a C++ "core" and R / Python interfaces to that core.
-Details on installation and use are available below: +Below, we detail how to install the R / Python packages, or work directly with the C++ codebase. + +## R Package + +The R package can be installed from CRAN via + +``` +install.packages("stochtree") +``` + +The development version of `stochtree` can be installed from Github via + +``` +remotes::install_github("StochasticTree/stochtree", ref="r-dev") +``` ## Python Package @@ -39,7 +53,7 @@ Then install the package from github via pip pip install git+https://github.com/StochasticTree/stochtree.git ``` -(*Note*: if you'd also like to run `stochtree`'s notebook examples, you will also need jupyterlab, seaborn, and matplotlib) +(*Note*: if you'd like to run `stochtree`'s notebook examples, you will also need `jupyterlab`, `seaborn`, and `matplotlib`) ```{bash} conda install matplotlib seaborn @@ -70,20 +84,12 @@ Then install stochtree via pip install git+https://github.com/StochasticTree/stochtree.git ``` -As above, if you'd like to run the notebook examples in the `demo/` subfolder, you will also need jupyterlab, seaborn, and matplotlib and you will have to [clone the repo](#cloning-the-repository) +As above, if you'd like to run the notebook examples in the `demo/` subfolder, you will also need `jupyterlab`, `seaborn`, and `matplotlib` and you will have to [clone the repo](#cloning-the-repository) ```{bash} pip install matplotlib seaborn jupyterlab ``` -## R Package - -The package can be installed in R via - -``` -remotes::install_github("StochasticTree/stochtree", ref="r-dev") -``` - ## C++ Core While the C++ core links to both R and Python for a performant, high-level interface, diff --git a/docs/index.md b/docs/index.md index 55a3da1f3..c4c360038 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,3 +1,19 @@ # StochTree -`stochtree` (stochastic tree) is software for building stochastic tree ensembles (i.e. BART, XBART) for supervised learning and causal inference. 
+`stochtree` (short for "stochastic tree") is a software suite for flexible decision tree modeling. +It has two primary interfaces: + +1. "High-level": robust implementations of many popular stochastic tree algorithms (BART, XBART, BCF, XBCF), with support for serialization and parallelism. +2. "Low-level": access to the "inner loop" of a stochastic forest sampler, allowing custom tree algorithm development in <50 lines of code. + +The "core" of the software is written in C++, but it provides R and Python APIs. +The R package is [available on CRAN](https://cran.r-project.org/web/packages/stochtree/index.html) and the python package will soon be on PyPI. + +## Table of Contents + +* [Getting Started](getting-started.md): Details on how to install and use `stochtree` +* [About](about.md): Overview of the models supported by stochtree and pointers to further reading +* [R Package](R_docs/index.md): Complete documentation of the R package +* [Python Package](python_docs/index.md): Complete documentation of the python package +* [C++ Core API and Architecture](cpp_docs/index.md): Overview and documentation of the C++ codebase that supports stochtree +* [Development](development/index.md): Roadmap and how to contribute diff --git a/mkdocs.yml b/mkdocs.yml index a191cafe6..dff193b79 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -40,6 +40,7 @@ theme: nav: - Home: index.md - 'Getting Started': getting-started.md + - 'About StochTree': about.md - 'R Package': - 'Overview': R_docs/index.md - 'Pkgdown Site': 'R_docs/pkgdown/index.html' @@ -61,11 +62,11 @@ nav: - 'BCF': python_docs/demo/causal_inference.ipynb - 'C++ Core API and Architecture': - cpp_docs/index.md - - 'Tree Data Structure': cpp_docs/tree.md - - 'Tracking Data Structure': cpp_docs/tracking.md + - 'Tracking Data Structures': cpp_docs/tracking.md - 'C++ Doxygen Site': 'cpp_docs/doxygen/index.html' - 'Development': - 'Overview': development/index.md + - 'Contributing': development/contributing.md - 'Roadmap': 
development/roadmap.md extra: social: @@ -89,6 +90,11 @@ markdown_extensions: - pymdownx.snippets - pymdownx.arithmatex: generic: true + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format plugins: - offline - search