Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
f641698
Updated partition_tracker to track auxiliary data for CLogLog Ordinal…
Entejar Sep 17, 2025
e99791b
Added leaf model for CLogLog Ordinal BART
Entejar Sep 17, 2025
8f77e15
Added ordinal_sampler
Entejar Sep 17, 2025
8547425
Updated tree_sampler.h
Entejar Sep 17, 2025
6c1d3ce
Updated sampler.cpp
Entejar Sep 17, 2025
955a211
Merge branch 'StochasticTree:main' into main
Entejar Sep 28, 2025
084be88
Added cloglog_ordinal_bart.R function
Entejar Sep 28, 2025
c8492fb
Tested CLogLog Ordinal BART — running successfully!
Entejar Sep 28, 2025
444c067
Added vignette for CLogLog Ordinal Bart
Entejar Sep 29, 2025
132071e
Update leaf_model.h
Entejar Sep 30, 2025
74cff51
Merge branch 'StochasticTree:main' into main
Entejar Oct 8, 2025
7bc3eb3
Merge branch 'main' into pr/196
andrewherren Oct 24, 2025
18c9e15
Migrated auxiliary data to ForestDataset from ForestTracker
andrewherren Oct 27, 2025
36b6a98
Removed call to deprecated cpp function
andrewherren Oct 27, 2025
2de707c
Fixed indexing bug
andrewherren Oct 27, 2025
2d19399
Refactored and fixed bugs
andrewherren Oct 27, 2025
f9a0b5a
Updated multinomial cloglog vignette
andrewherren Oct 27, 2025
66164f5
Added binary outcome cloglog model demo
andrewherren Oct 27, 2025
d5de763
Reworking sampler implementation to match current stochtree::main API
andrewherren Oct 27, 2025
0302459
Reflecting num_threads further through the interface
andrewherren Oct 27, 2025
a7c79d4
Refactoring out unused slice sampler for leaf scale parameter
andrewherren Oct 27, 2025
6ffdef7
Adding num_threads (back) to GFR interface
andrewherren Oct 27, 2025
853b129
Continue building in multithreading support to cloglog branch
andrewherren Oct 27, 2025
9edad36
Update tree_sampler.h
andrewherren Oct 27, 2025
04de102
Updating GFR to reflect multithreading capabilities in the main branch
andrewherren Oct 27, 2025
cdca915
Reflecting num_threads through the MCMC and GFR interface
andrewherren Oct 27, 2025
bf21447
Set up cloglog to work with GFR and updated examples
andrewherren Oct 27, 2025
a5cee2b
Updating vignettes
andrewherren Oct 28, 2025
815c538
WIP fix for data augmentation in the binary case
andrewherren Oct 28, 2025
2106f32
Updating sampler
andrewherren Oct 28, 2025
02dc2ac
Remove unused slice sampler code
andrewherren Oct 28, 2025
8f92425
Cleaned up PR
andrewherren Oct 28, 2025
95a0ce9
Including variant in leaf model header file
andrewherren Oct 28, 2025
42f9ac4
Updated vignettes and function defaults
andrewherren Oct 28, 2025
4f576a6
Added a release candidate readme
andrewherren Oct 28, 2025
dc25a1d
Updated demo scripts
andrewherren Oct 28, 2025
5089faf
Prototype of outcome model interface
andrewherren Jan 15, 2026
3d8533a
Merge branch 'main' into pr-temp-rc
andrewherren Jan 27, 2026
db538f1
Formatted R code
andrewherren Jan 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ S3method(summary,bcfmodel)
export(bart)
export(bcf)
export(calibrateInverseGammaErrorVariance)
export(cloglog_ordinal_bart)
export(computeForestLeafIndices)
export(computeForestLeafVariances)
export(computeForestMaxLeafIndex)
Expand Down
2 changes: 2 additions & 0 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#' - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default `1`. Setting `keep_every <- k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`. Note that if `num_chains > 1`, the returned model object will contain samples from all chains, stored consecutively. That is, if there are 4 chains with 100 samples each, the first 100 samples will be from chain 1, the next 100 samples will be from chain 2, etc... For more detail on working with multi-chain BART models, see [the multi chain vignette](https://stochtree.ai/R_docs/pkgdown/articles/MultiChain.html).
#' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
#' - `outcome_model` An object of class `outcome_model` specifying the outcome model to be used. Default: `outcome_model(outcome = "continuous", link = "identity")`.
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
#' - `num_threads` Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads.
#'
Expand Down Expand Up @@ -151,6 +152,7 @@ bart <- function(
keep_every = 1,
num_chains = 1,
verbose = FALSE,
outcome_model = outcome_model(outcome = "continuous", link = "identity"),
probit_outcome_model = FALSE,
num_threads = -1
)
Expand Down
285 changes: 285 additions & 0 deletions R/cloglog_ordinal_bart.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
#' Run the BART algorithm for ordinal outcomes using a complementary log-log link
#'
#' Each iteration updates the forest (grow-from-root for the first `num_gfr`
#' iterations, MCMC afterwards), copies the current forest predictions into the
#' dataset's auxiliary storage, and then updates the latent variables, the
#' log-scale cutpoint parameters, and the cumulative exponentiated cutpoints.
#'
#' @param X A numeric matrix of predictors (training data).
#' @param y A numeric vector of ordinal outcomes (positive integers starting from 1).
#'   Must have one entry per row of `X` and contain no missing values.
#' @param X_test An optional numeric matrix of predictors (test data). Must have
#'   the same number of columns as `X`.
#' @param n_trees Number of trees in the BART ensemble. Default: `50`.
#' @param num_gfr Number of GFR samples to draw at the beginning of the sampler. Default: `0`.
#' @param num_burnin Number of burn-in MCMC samples to discard. Default: `1000`.
#' @param num_mcmc Number of MCMC iterations to run after GFR and burn-in; every
#'   `n_thin`-th of these is retained. Default: `500`.
#' @param n_thin Thinning interval for MCMC samples. Default: `1`.
#' @param alpha_gamma Shape parameter for the log-gamma prior on cutpoints. Default: `2.0`.
#' @param beta_gamma Rate parameter for the log-gamma prior on cutpoints. Default: `2.0`.
#' @param variable_weights (Optional) vector of variable weights for splitting (default: equal weights).
#' @param feature_types (Optional) vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous).
#' @param seed (Optional) random seed for reproducibility. When supplied it is
#'   passed both to `set.seed()` and to the C++ random number generator.
#' @param num_threads (Optional) Number of threads to use in split evaluations and other compute-intensive operations. Default: 1.
#'
#' @return A list of class `cloglog_ordinal_bart` with elements:
#' - `forest_predictions_train`: `nrow(X)` by `n_keep` matrix of forest predictions on the training data.
#' - `forest_predictions_test`: `nrow(X_test)` by `n_keep` matrix of forest predictions on the test data, or `NULL` if `X_test` was not supplied.
#' - `gamma_samples`: `max(y) - 1` by `n_keep` matrix of retained log-scale cutpoint parameters.
#' - `latent_samples`: `nrow(X)` by `n_keep` matrix of retained latent variables.
#' - `scale_leaf`: the leaf scale used by the sampler (`2 / sqrt(n_trees)`).
#' - `ordinal_outcome`: the outcome recoded to 0-based integers.
#' - `n_trees`: number of trees.
#' - `n_keep`: number of retained samples.
#' @export
cloglog_ordinal_bart <- function(
  X,
  y,
  X_test = NULL,
  n_trees = 50,
  num_gfr = 0,
  num_burnin = 1000,
  num_mcmc = 500,
  n_thin = 1,
  alpha_gamma = 2.0,
  beta_gamma = 2.0,
  variable_weights = NULL,
  feature_types = NULL,
  seed = NULL,
  num_threads = 1
) {
  # BART parameters
  alpha_bart <- 0.95
  beta_bart <- 2
  min_samples_in_leaf <- 5
  max_depth <- 10
  scale_leaf <- 2 / sqrt(n_trees)
  cutpoint_grid_size <- 100 # Needed for stochtree::sample_gfr_one_iteration_cpp, not used in MCMC BART

  # Fixed for identifiability (can be passed as an argument later if desired)
  gamma_0 <- 0.0 # First gamma cutpoint fixed at gamma_0 = 0

  # Determine whether a test dataset is provided
  has_test <- !is.null(X_test)

  # Data checks
  if (!is.matrix(X)) {
    X <- as.matrix(X)
  }
  if (!is.numeric(y)) {
    y <- as.numeric(y)
  }
  if (has_test && !is.matrix(X_test)) {
    X_test <- as.matrix(X_test)
  }

  n_samples <- nrow(X)
  n_features <- ncol(X)

  if (length(y) != n_samples) {
    stop("Ordinal outcome y must have one entry per row of X")
  }
  if (has_test && ncol(X_test) != n_features) {
    stop("X_test must have the same number of columns as X")
  }
  # Check for NA explicitly: otherwise the comparison below yields NA and
  # `if` fails with an uninformative "missing value" error.
  if (anyNA(y)) {
    stop("Ordinal outcome y must not contain missing values")
  }
  if (any(y < 1) || any(y != round(y))) {
    stop("Ordinal outcome y must contain positive integers starting from 1")
  }

  # Convert from 1-based (R) to 0-based (C++) indexing
  ordinal_outcome <- as.integer(y - 1)
  n_levels <- max(y) # Number of ordinal categories

  if (n_levels < 2) {
    stop("Ordinal outcome must have at least 2 categories")
  }

  if (is.null(variable_weights)) {
    variable_weights <- rep(1.0, n_features)
  }

  if (is.null(feature_types)) {
    feature_types <- rep(0L, n_features)
  }

  if (!is.null(seed)) {
    set.seed(seed)
  }

  # Indices of MCMC samples to keep after GFR, burn-in, and thinning.
  # seq() errors ("wrong sign in 'by'") when its end precedes its start, which
  # happens for num_mcmc = 0, so that case is handled explicitly.
  n_warmup <- num_gfr + num_burnin
  keep_idx <- if (num_mcmc > 0) {
    seq(n_warmup + 1, n_warmup + num_mcmc, by = n_thin)
  } else {
    integer(0)
  }
  n_keep <- length(keep_idx)

  # Storage for MCMC samples
  forest_pred_train <- matrix(0, n_samples, n_keep)
  if (has_test) {
    n_samples_test <- nrow(X_test)
    forest_pred_test <- matrix(0, n_samples_test, n_keep)
  }
  gamma_samples <- matrix(0, n_levels - 1, n_keep)
  latent_samples <- matrix(0, n_samples, n_keep)

  # Initialize samplers
  ordinal_sampler <- stochtree:::ordinal_sampler_cpp()
  rng <- stochtree::createCppRNG(
    if (is.null(seed)) sample.int(.Machine$integer.max, 1) else seed
  )

  # Initialize other model structures as before
  dataX <- stochtree::createForestDataset(X)
  if (has_test) {
    dataXtest <- stochtree::createForestDataset(X_test)
  }
  outcome_data <- stochtree::createOutcome(as.numeric(ordinal_outcome))
  active_forest <- stochtree::createForest(as.integer(n_trees), 1L, TRUE, FALSE) # Use constant leaves
  active_forest$set_root_leaves(0.0)
  split_prior <- stochtree:::tree_prior_cpp(
    alpha_bart,
    beta_bart,
    min_samples_in_leaf,
    max_depth
  )
  forest_samples <- stochtree::createForestSamples(
    as.integer(n_trees),
    1L,
    TRUE,
    FALSE
  ) # Use constant leaves
  forest_tracker <- stochtree:::forest_tracker_cpp(
    dataX$data_ptr,
    as.integer(feature_types),
    as.integer(n_trees),
    as.integer(n_samples)
  )

  # Latent variable (Z in Alam et al (2025) notation)
  dataX$add_auxiliary_dimension(n_samples)
  # Forest predictions (eta in Alam et al (2025) notation)
  dataX$add_auxiliary_dimension(n_samples)
  # Log-scale non-cumulative cutpoint (gamma in Alam et al (2025) notation)
  dataX$add_auxiliary_dimension(n_levels - 1)
  # Exponentiated cumulative cutpoints (exp(c_k) in Alam et al (2025) notation)
  # This auxiliary series is designed so that the element stored at position `i`
  # corresponds to the sum of all exponentiated gamma_j values for j < i.
  # It has n_levels elements instead of n_levels - 1 because even the largest
  # categorical index has a valid value of sum_{j < i} exp(gamma_j)
  dataX$add_auxiliary_dimension(n_levels)

  # Initialize gamma parameters to zero (3rd auxiliary data series, mapped to `dim_idx = 2` with 0-indexing)
  initial_gamma <- rep(0.0, n_levels - 1)
  for (k in seq_along(initial_gamma)) {
    dataX$set_auxiliary_data_value(2, k - 1, initial_gamma[k])
  }

  # Convert the log-scale parameters into cumulative exponentiated parameters.
  # This is done under the hood in a C++ function for efficiency.
  stochtree:::ordinal_sampler_update_cumsum_exp_cpp(
    ordinal_sampler,
    dataX$data_ptr
  )

  # Initialize forest predictions to zero (slot 1)
  for (obs in seq_len(n_samples)) {
    dataX$set_auxiliary_data_value(1, obs - 1, 0.0)
  }

  # Initialize latent variables to zero (slot 0)
  for (obs in seq_len(n_samples)) {
    dataX$set_auxiliary_data_value(0, obs - 1, 0.0)
  }

  # Set up sweep indices for tree updates (sample all trees each iteration)
  sweep_indices <- as.integer(seq(0, n_trees - 1))

  sample_counter <- 0
  n_iter <- num_mcmc + num_burnin + num_gfr
  for (iter in seq_len(n_iter)) {
    keep_sample <- iter %in% keep_idx
    if (keep_sample) {
      sample_counter <- sample_counter + 1
    }

    # 1. Sample forest: GFR for the first num_gfr iterations, MCMC afterwards
    if (iter > num_gfr) {
      stochtree:::sample_mcmc_one_iteration_cpp(
        dataX$data_ptr,
        outcome_data$data_ptr,
        forest_samples$forest_container_ptr,
        active_forest$forest_ptr,
        forest_tracker,
        split_prior,
        rng$rng_ptr,
        sweep_indices,
        as.integer(feature_types),
        as.integer(cutpoint_grid_size),
        scale_leaf,
        variable_weights,
        alpha_gamma,
        beta_gamma,
        1.0,
        4L,
        keep_sample,
        num_threads
      )
    } else {
      stochtree:::sample_gfr_one_iteration_cpp(
        dataX$data_ptr,
        outcome_data$data_ptr,
        forest_samples$forest_container_ptr,
        active_forest$forest_ptr,
        forest_tracker,
        split_prior,
        rng$rng_ptr,
        sweep_indices,
        as.integer(feature_types),
        as.integer(cutpoint_grid_size),
        scale_leaf,
        variable_weights,
        alpha_gamma,
        beta_gamma,
        1.0,
        4L,
        keep_sample,
        ncol(X),
        num_threads
      )
    }

    # Set auxiliary data slot 1 to current forest predictions = lambda_hat = sum of all the tree predictions
    # This is needed for updating gamma parameters, latent z_i's
    forest_pred_current <- active_forest$predict(dataX)
    for (obs in seq_len(n_samples)) {
      dataX$set_auxiliary_data_value(1, obs - 1, forest_pred_current[obs])
    }

    # 2. Sample latent z_i's using truncated exponential
    stochtree:::ordinal_sampler_update_latent_variables_cpp(
      ordinal_sampler,
      dataX$data_ptr,
      outcome_data$data_ptr,
      rng$rng_ptr
    )

    # 3. Sample gamma parameters
    stochtree:::ordinal_sampler_update_gamma_params_cpp(
      ordinal_sampler,
      dataX$data_ptr,
      outcome_data$data_ptr,
      alpha_gamma,
      beta_gamma,
      gamma_0,
      rng$rng_ptr
    )

    # 4. Update cumulative sum of exp(gamma) values
    stochtree:::ordinal_sampler_update_cumsum_exp_cpp(
      ordinal_sampler,
      dataX$data_ptr
    )

    if (keep_sample) {
      # Steps 2-4 are not handed the forest, so the training predictions
      # computed right after the forest update are still current; reuse them
      # instead of calling predict() a second time.
      forest_pred_train[, sample_counter] <- forest_pred_current
      if (has_test) {
        forest_pred_test[, sample_counter] <- active_forest$predict(dataXtest)
      }
      gamma_samples[, sample_counter] <- dataX$get_auxiliary_data_vector(2)
      latent_samples[, sample_counter] <- dataX$get_auxiliary_data_vector(0)
    }
  }

  result <- list(
    forest_predictions_train = forest_pred_train,
    forest_predictions_test = if (has_test) forest_pred_test else NULL,
    gamma_samples = gamma_samples,
    latent_samples = latent_samples,
    scale_leaf = scale_leaf,
    ordinal_outcome = ordinal_outcome,
    n_trees = n_trees,
    n_keep = n_keep
  )

  class(result) <- "cloglog_ordinal_bart"
  result
}
40 changes: 40 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,30 @@ forest_dataset_get_variance_weights_cpp <- function(dataset_ptr) {
.Call(`_stochtree_forest_dataset_get_variance_weights_cpp`, dataset_ptr)
}

# Forwards `dataset_ptr` and `dim_idx` to the native routine
# `_stochtree_forest_dataset_has_auxiliary_dimension_cpp` and returns its result.
forest_dataset_has_auxiliary_dimension_cpp <- function(dataset_ptr, dim_idx) {
  native_result <- .Call(
    `_stochtree_forest_dataset_has_auxiliary_dimension_cpp`,
    dataset_ptr,
    dim_idx
  )
  native_result
}

# Forwards `dataset_ptr` and `dim_size` to the native routine
# `_stochtree_forest_dataset_add_auxiliary_dimension_cpp`; the result is
# returned invisibly.
forest_dataset_add_auxiliary_dimension_cpp <- function(dataset_ptr, dim_size) {
  native_result <- .Call(
    `_stochtree_forest_dataset_add_auxiliary_dimension_cpp`,
    dataset_ptr,
    dim_size
  )
  invisible(native_result)
}

# Forwards `dataset_ptr`, `dim_idx` and `element_idx` to the native routine
# `_stochtree_forest_dataset_get_auxiliary_data_value_cpp` and returns its result.
forest_dataset_get_auxiliary_data_value_cpp <- function(dataset_ptr, dim_idx, element_idx) {
  native_result <- .Call(
    `_stochtree_forest_dataset_get_auxiliary_data_value_cpp`,
    dataset_ptr,
    dim_idx,
    element_idx
  )
  native_result
}

# Forwards `dataset_ptr`, `dim_idx`, `element_idx` and `value` to the native
# routine `_stochtree_forest_dataset_set_auxiliary_data_value_cpp`; the result
# is returned invisibly.
forest_dataset_set_auxiliary_data_value_cpp <- function(dataset_ptr, dim_idx, element_idx, value) {
  native_result <- .Call(
    `_stochtree_forest_dataset_set_auxiliary_data_value_cpp`,
    dataset_ptr,
    dim_idx,
    element_idx,
    value
  )
  invisible(native_result)
}

# Forwards `dataset_ptr` and `dim_idx` to the native routine
# `_stochtree_forest_dataset_get_auxiliary_data_vector_cpp` and returns its result.
forest_dataset_get_auxiliary_data_vector_cpp <- function(dataset_ptr, dim_idx) {
  native_result <- .Call(
    `_stochtree_forest_dataset_get_auxiliary_data_vector_cpp`,
    dataset_ptr,
    dim_idx
  )
  native_result
}

# Forwards `dataset_ptr`, `output_matrix`, `dim_idx` and `matrix_col_idx` to the
# native routine `_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp`
# and returns its result.
forest_dataset_store_auxiliary_data_vector_as_column_cpp <- function(dataset_ptr, output_matrix, dim_idx, matrix_col_idx) {
  native_result <- .Call(
    `_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp`,
    dataset_ptr,
    output_matrix,
    dim_idx,
    matrix_col_idx
  )
  native_result
}

# Forwards `outcome` to the native routine `_stochtree_create_column_vector_cpp`
# and returns its result.
create_column_vector_cpp <- function(outcome) {
  native_result <- .Call(
    `_stochtree_create_column_vector_cpp`,
    outcome
  )
  native_result
}
Expand Down Expand Up @@ -692,6 +716,22 @@ sample_without_replacement_integer_cpp <- function(population_vector, sampling_p
.Call(`_stochtree_sample_without_replacement_integer_cpp`, population_vector, sampling_probs, sample_size)
}

# Calls the native routine `_stochtree_ordinal_sampler_cpp` with no arguments
# and returns its result.
ordinal_sampler_cpp <- function() {
  native_result <- .Call(`_stochtree_ordinal_sampler_cpp`)
  native_result
}

# Forwards `sampler_ptr`, `data_ptr`, `outcome_ptr` and `rng_ptr` to the native
# routine `_stochtree_ordinal_sampler_update_latent_variables_cpp`; the result
# is returned invisibly.
ordinal_sampler_update_latent_variables_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, rng_ptr) {
  native_result <- .Call(
    `_stochtree_ordinal_sampler_update_latent_variables_cpp`,
    sampler_ptr,
    data_ptr,
    outcome_ptr,
    rng_ptr
  )
  invisible(native_result)
}

# Forwards the sampler, dataset and outcome pointers, the `alpha_gamma`,
# `beta_gamma` and `gamma_0` values, and `rng_ptr` to the native routine
# `_stochtree_ordinal_sampler_update_gamma_params_cpp`; the result is returned
# invisibly.
ordinal_sampler_update_gamma_params_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr) {
  native_result <- .Call(
    `_stochtree_ordinal_sampler_update_gamma_params_cpp`,
    sampler_ptr,
    data_ptr,
    outcome_ptr,
    alpha_gamma,
    beta_gamma,
    gamma_0,
    rng_ptr
  )
  invisible(native_result)
}

# Forwards `sampler_ptr` and `data_ptr` to the native routine
# `_stochtree_ordinal_sampler_update_cumsum_exp_cpp`; the result is returned
# invisibly.
ordinal_sampler_update_cumsum_exp_cpp <- function(sampler_ptr, data_ptr) {
  native_result <- .Call(
    `_stochtree_ordinal_sampler_update_cumsum_exp_cpp`,
    sampler_ptr,
    data_ptr
  )
  invisible(native_result)
}

# Calls the native routine `_stochtree_init_json_cpp` with no arguments and
# returns its result.
init_json_cpp <- function() {
  native_result <- .Call(`_stochtree_init_json_cpp`)
  native_result
}
Expand Down
Loading