diff --git a/DESCRIPTION b/DESCRIPTION index 6dc68ff2..3aa06675 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -13,7 +13,7 @@ Description: Stochastic tree ensembles (XBART and BART) for supervised learning License: MIT + file LICENSE Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 LinkingTo: cpp11, BH Suggests: diff --git a/NAMESPACE b/NAMESPACE index 365e9dfd..aa71c7fc 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -51,6 +51,7 @@ export(orderedCatInitializeAndPreprocess) export(orderedCatPreprocess) export(preprocessBartParams) export(preprocessBcfParams) +export(preprocessParams) export(preprocessPredictionData) export(preprocessPredictionDataFrame) export(preprocessPredictionMatrix) diff --git a/R/bart.R b/R/bart.R index 9098b8df..2f7e8333 100644 --- a/R/bart.R +++ b/R/bart.R @@ -29,63 +29,45 @@ #' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. #' @param previous_model_json (Optional) JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: `NULL`. #' @param warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BART sampler. One-indexed (so that the first sample is used for warm-start by setting `warmstart_sample_num = 1`). Default: `NULL`. -#' @param params The list of model parameters, each of which has a default value. +#' @param general_params (Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional. #' -#' **1. Global Parameters** -#' -#' - `cutpoint_grid_size` Maximum size of the "grid" of potential cutpoints to consider. Default: `100`. -#' - `sigma2_init` Starting value of global error variance parameter. Calibrated internally as `pct_var_sigma2_init*var((y-mean(y))/sd(y))` if not set. -#' - `pct_var_sigma2_init` Percentage of standardized outcome variance used to initialize global error variance parameter. Default: `1`. Superseded by `sigma2_init`. -#' - `variance_scale` Variance after the data have been scaled. Default: `1`. -#' - `a_global` Shape parameter in the `IG(a_global, b_global)` global error variance model. Default: `0`. -#' - `b_global` Scale parameter in the `IG(a_global, b_global)` global error variance model. Default: `0`. -#' - `random_seed` Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`. -#' - `sample_sigma_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: `TRUE`. -#' - `keep_burnin` Whether or not "burnin" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`. -#' - `keep_gfr` Whether or not "grow-from-root" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`. +#' - `cutpoint_grid_size` Maximum size of the "grid" of potential cutpoints to consider in the GFR algorithm. Default: `100`. #' - `standardize` Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: `TRUE`. +#' - `sample_sigma2_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(sigma2_global_shape, sigma2_global_scale)`. Default: `TRUE`. 
+#' - `sigma2_global_init` Starting value of global error variance parameter. Calibrated internally as `1.0*var(y_train)`, where `y_train` is the possibly standardized outcome, if not set. +#' - `sigma2_global_shape` Shape parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Default: `0`. +#' - `sigma2_global_scale` Scale parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Default: `0`. +#' - `random_seed` Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`. +#' - `keep_burnin` Whether or not "burnin" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`. +#' - `keep_gfr` Whether or not "grow-from-root" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`. #' - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default `1`. Setting `keep_every <- k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. #' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`. #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`. #' -#' **2. Mean Forest Parameters** -#' -#' - `num_trees_mean` Number of trees in the ensemble for the conditional mean model. Default: `200`. If `num_trees_mean = 0`, the conditional mean will not be modeled using a forest, and the function will only proceed if `num_trees_variance > 0`. -#' - `sample_sigma_leaf` Whether or not to update the `tau` leaf scale variance parameter based on `IG(a_leaf, b_leaf)`. Cannot (currently) be set to true if `ncol(W_train)>1`. Default: `FALSE`. -#' -#' **2.1. Tree Prior Parameters** -#' -#' - `alpha_mean` Prior probability of splitting for a tree of depth 0 in the mean model. Tree split prior combines `alpha_mean` and `beta_mean` via `alpha_mean*(1+node_depth)^-beta_mean`. Default: `0.95`. -#' - `beta_mean` Exponent that decreases split probabilities for nodes of depth > 0 in the mean model. Tree split prior combines `alpha_mean` and `beta_mean` via `alpha_mean*(1+node_depth)^-beta_mean`. Default: `2`. -#' - `min_samples_leaf_mean` Minimum allowable size of a leaf, in terms of training samples, in the mean model. Default: `5`. -#' - `max_depth_mean` Maximum depth of any tree in the ensemble in the mean model. Default: `10`. Can be overridden with ``-1`` which does not enforce any depth limits on trees. +#' @param mean_forest_params (Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional. #' -#' **2.2. Leaf Model Parameters** -#' -#' - `variable_weights_mean` Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. 
-#' - `sigma_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees_mean` if not set here. -#' - `a_leaf` Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: `3`. -#' - `b_leaf` Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees_mean` if not set here. +#' - `num_trees` Number of trees in the ensemble for the conditional mean model. Default: `200`. If `num_trees = 0`, the conditional mean will not be modeled using a forest, and the function will only proceed if the variance forest's `num_trees > 0`. +#' - `alpha` Prior probability of splitting for a tree of depth 0 in the mean model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `0.95`. +#' - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the mean model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `2`. +#' - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the mean model. Default: `5`. +#' - `max_depth` Maximum depth of any tree in the ensemble in the mean model. Default: `10`. Can be overridden with ``-1`` which does not enforce any depth limits on trees. +#' - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. +#' - `sample_sigma2_leaf` Whether or not to update the leaf scale variance parameter based on `IG(sigma2_leaf_shape, sigma2_leaf_scale)`. Cannot (currently) be set to true if `ncol(W_train)>1`. Default: `TRUE`. +#' - `sigma2_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. +#' - `sigma2_leaf_shape` Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Default: `3`. +#' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. #' -#' **3. Conditional Variance Forest Parameters** -#' -#' - `num_trees_variance` Number of trees in the ensemble for the conditional variance model. Default: `0`. Variance is only modeled using a tree / forest if `num_trees_variance > 0`. -#' - `variance_forest_init` Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as `log(pct_var_variance_forest_init*var((y-mean(y))/sd(y)))/num_trees_variance` if not set. -#' - `pct_var_variance_forest_init` Percentage of standardized outcome variance used to initialize global error variance parameter. Default: `1`. Superseded by `variance_forest_init`. -#' -#' **3.1. Tree Prior Parameters** -#' -#' - `alpha_variance` Prior probability of splitting for a tree of depth 0 in the variance model. Tree split prior combines `alpha_variance` and `beta_variance` via `alpha_variance*(1+node_depth)^-beta_variance`. Default: `0.95`. -#' - `beta_variance` Exponent that decreases split probabilities for nodes of depth > 0 in the variance model. Tree split prior combines `alpha_variance` and `beta_variance` via `alpha_variance*(1+node_depth)^-beta_variance`. Default: `2`. -#' - `min_samples_leaf_variance` Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: `5`.
-#' - `max_depth_variance` Maximum depth of any tree in the ensemble in the variance model. Default: `10`. Can be overridden with ``-1`` which does not enforce any depth limits on trees. +#' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional. #' -#' **3.2. Leaf Model Parameters** -#' -#' - `variable_weights_variance` Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. -#' - `sigma_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees_mean` if not set here. -#' - `a_forest` Shape parameter in the `IG(a_forest, b_forest)` conditional error variance model (which is only sampled if `num_trees_variance > 0`). Calibrated internally as `num_trees_variance / 1.5^2 + 0.5` if not set. -#' - `b_forest` Scale parameter in the `IG(a_forest, b_forest)` conditional error variance model (which is only sampled if `num_trees_variance > 0`). Calibrated internally as `num_trees_variance / 1.5^2` if not set. +#' - `num_trees` Number of trees in the ensemble for the conditional variance model. Default: `0`. Variance is only modeled using a tree / forest if `num_trees > 0`. +#' - `alpha` Prior probability of splitting for a tree of depth 0 in the variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `0.95`. +#' - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `2`. +#' - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: `5`. +#' - `max_depth` Maximum depth of any tree in the ensemble in the variance model. Default: `10`. Can be overridden with ``-1`` which does not enforce any depth limits on trees. +#' - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. +#' - `var_forest_leaf_init` Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as `log(0.6*var(y_train))/num_trees`, where `y_train` is the possibly standardized outcome, if not set. +#' - `var_forest_prior_shape` Shape parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2 + 0.5` if not set. +#' - `var_forest_prior_scale` Scale parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2` if not set. #' #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). 
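A minimal usage sketch of the refactored three-list `bart()` interface documented above; the simulated data and the specific parameter values here are illustrative assumptions, not defaults taken from this patch:

```r
# Simulated data for illustration only
n <- 500; p <- 10
X <- matrix(runif(n * p), ncol = p)
y <- 10 * X[, 1] + rnorm(n, sd = exp(X[, 2]))

# Forest-specific settings now travel in separate lists; omitted entries
# fall back to the defaults documented above
bart_model <- bart(
    X_train = X, y_train = y, num_gfr = 10, num_mcmc = 100,
    general_params = list(keep_every = 5, standardize = TRUE),
    mean_forest_params = list(num_trees = 100, min_samples_leaf = 10),
    variance_forest_params = list(num_trees = 50)  # num_trees > 0 enables the variance forest
)
```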
#' @export @@ -119,43 +101,80 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, group_ids_test = NULL, rfx_basis_test = NULL, num_gfr = 5, num_burnin = 0, num_mcmc = 100, previous_model_json = NULL, warmstart_sample_num = NULL, - params = list()) { - # Extract BART parameters - bart_params <- preprocessBartParams(params) - cutpoint_grid_size <- bart_params$cutpoint_grid_size - sigma_leaf_init <- bart_params$sigma_leaf_init - alpha_mean <- bart_params$alpha_mean - beta_mean <- bart_params$beta_mean - min_samples_leaf_mean <- bart_params$min_samples_leaf_mean - max_depth_mean <- bart_params$max_depth_mean - alpha_variance <- bart_params$alpha_variance - beta_variance <- bart_params$beta_variance - min_samples_leaf_variance <- bart_params$min_samples_leaf_variance - max_depth_variance <- bart_params$max_depth_variance - a_global <- bart_params$a_global - b_global <- bart_params$b_global - a_leaf <- bart_params$a_leaf - b_leaf <- bart_params$b_leaf - a_forest <- bart_params$a_forest - b_forest <- bart_params$b_forest - variance_scale <- bart_params$variance_scale - sigma2_init <- bart_params$sigma2_init - variance_forest_init <- bart_params$variance_forest_init - pct_var_sigma2_init <- bart_params$pct_var_sigma2_init - pct_var_variance_forest_init <- bart_params$pct_var_variance_forest_init - variable_weights_mean <- bart_params$variable_weights_mean - variable_weights_variance <- bart_params$variable_weights_variance - num_trees_mean <- bart_params$num_trees_mean - num_trees_variance <- bart_params$num_trees_variance - sample_sigma_global <- bart_params$sample_sigma_global - sample_sigma_leaf <- bart_params$sample_sigma_leaf - random_seed <- bart_params$random_seed - keep_burnin <- bart_params$keep_burnin - keep_gfr <- bart_params$keep_gfr - standardize <- bart_params$standardize - keep_every <- bart_params$keep_every - num_chains <- bart_params$num_chains - verbose <- bart_params$verbose + general_params = list(), mean_forest_params = list(), + variance_forest_params = list()) { + # Update general BART parameters + general_params_default <- list( + cutpoint_grid_size = 100, standardize = T, + sample_sigma2_global = T, sigma2_global_init = NULL, + sigma2_global_shape = 0, sigma2_global_scale = 0, + random_seed = -1, keep_burnin = F, keep_gfr = F, + keep_every = 1, num_chains = 1, verbose = F + ) + general_params_updated <- preprocessParams( + general_params_default, general_params + ) + + # Update mean forest BART parameters + mean_forest_params_default <- list( + num_trees = 200, alpha = 0.95, beta = 2.0, + min_samples_leaf = 5, max_depth = 10, + variable_weights = NULL, + sample_sigma2_leaf = T, sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL + ) + mean_forest_params_updated <- preprocessParams( + mean_forest_params_default, mean_forest_params + ) + + # Update variance forest BART parameters + variance_forest_params_default <- list( + num_trees = 0, alpha = 0.95, beta = 2.0, + min_samples_leaf = 5, max_depth = 10, + variable_weights = NULL, var_forest_leaf_init = NULL, + var_forest_prior_shape = NULL, var_forest_prior_scale = NULL + ) + variance_forest_params_updated <- preprocessParams( + variance_forest_params_default, variance_forest_params + ) + + ### Unpack all parameter values + # 1. 
General parameters + cutpoint_grid_size <- general_params_updated$cutpoint_grid_size + standardize <- general_params_updated$standardize + sample_sigma_global <- general_params_updated$sample_sigma2_global + sigma2_init <- general_params_updated$sigma2_global_init + a_global <- general_params_updated$sigma2_global_shape + b_global <- general_params_updated$sigma2_global_scale + random_seed <- general_params_updated$random_seed + keep_burnin <- general_params_updated$keep_burnin + keep_gfr <- general_params_updated$keep_gfr + keep_every <- general_params_updated$keep_every + num_chains <- general_params_updated$num_chains + verbose <- general_params_updated$verbose + + # 2. Mean forest parameters + num_trees_mean <- mean_forest_params_updated$num_trees + alpha_mean <- mean_forest_params_updated$alpha + beta_mean <- mean_forest_params_updated$beta + min_samples_leaf_mean <- mean_forest_params_updated$min_samples_leaf + max_depth_mean <- mean_forest_params_updated$max_depth + variable_weights_mean <- mean_forest_params_updated$variable_weights + sample_sigma_leaf <- mean_forest_params_updated$sample_sigma2_leaf + sigma_leaf_init <- mean_forest_params_updated$sigma2_leaf_init + a_leaf <- mean_forest_params_updated$sigma2_leaf_shape + b_leaf <- mean_forest_params_updated$sigma2_leaf_scale + + # 3. Variance forest parameters + num_trees_variance <- variance_forest_params_updated$num_trees + alpha_variance <- variance_forest_params_updated$alpha + beta_variance <- variance_forest_params_updated$beta + min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf + max_depth_variance <- variance_forest_params_updated$max_depth + variable_weights_variance <- variance_forest_params_updated$variable_weights + variance_forest_init <- variance_forest_params_updated$var_forest_leaf_init + a_forest <- variance_forest_params_updated$var_forest_prior_shape + b_forest <- variance_forest_params_updated$var_forest_prior_scale # Check if there are enough GFR samples to seed num_chains samplers if (num_gfr > 0) { @@ -174,7 +193,6 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, previous_bart_model <- createBARTModelFromJsonString(previous_model_json) previous_y_bar <- previous_bart_model$model_params$outcome_mean previous_y_scale <- previous_bart_model$model_params$outcome_scale - previous_var_scale <- previous_bart_model$model_params$variance_scale if (previous_bart_model$model_params$include_mean_forest) { previous_forest_samples_mean <- previous_bart_model$mean_forests } else previous_forest_samples_mean <- NULL @@ -182,8 +200,8 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, previous_forest_samples_variance <- previous_bart_model$variance_forests } else previous_forest_samples_variance <- NULL if (previous_bart_model$model_params$sample_sigma_global) { - previous_global_var_samples <- previous_bart_model$sigma2_global_samples*( - previous_var_scale / (previous_y_scale*previous_y_scale) + previous_global_var_samples <- previous_bart_model$sigma2_global_samples / ( + previous_y_scale*previous_y_scale ) } else previous_global_var_samples <- NULL if (previous_bart_model$model_params$sample_sigma_leaf) { @@ -195,7 +213,6 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, } else { previous_y_bar <- NULL previous_y_scale <- NULL - previous_var_scale <- NULL previous_global_var_samples <- NULL previous_leaf_var_samples <- NULL previous_rfx_samples <- NULL @@ -372,14 +389,13 @@ bart <- function(X_train, y_train, 
W_train = NULL, group_ids_train = NULL, y_std_train <- 1 } resid_train <- (y_train-y_bar_train)/y_std_train - resid_train <- resid_train*sqrt(variance_scale) # Compute initial value of root nodes in mean forest init_val_mean <- mean(resid_train) # Calibrate priors for sigma^2 and tau - if (is.null(sigma2_init)) sigma2_init <- pct_var_sigma2_init*var(resid_train) - if (is.null(variance_forest_init)) variance_forest_init <- pct_var_variance_forest_init*var(resid_train) + if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train) + if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train) if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean) if (has_basis) { if (ncol(W_train) > 1) { @@ -702,8 +718,8 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, # Mean forest predictions if (include_mean_forest) { - y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train/sqrt(variance_scale) + y_bar_train - if (has_test) y_hat_test <- forest_samples_mean$predict(forest_dataset_test)*y_std_train/sqrt(variance_scale) + y_bar_train + y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train + if (has_test) y_hat_test <- forest_samples_mean$predict(forest_dataset_test)*y_std_train + y_bar_train } # Variance forest predictions @@ -714,16 +730,16 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, # Random effects predictions if (has_rfx) { - rfx_preds_train <- rfx_samples$predict(group_ids_train, rfx_basis_train)*y_std_train/sqrt(variance_scale) + rfx_preds_train <- rfx_samples$predict(group_ids_train, rfx_basis_train)*y_std_train y_hat_train <- y_hat_train + rfx_preds_train } if ((has_rfx_test) && (has_test)) { - rfx_preds_test <- rfx_samples$predict(group_ids_test, rfx_basis_test)*y_std_train/sqrt(variance_scale) + rfx_preds_test <- rfx_samples$predict(group_ids_test, rfx_basis_test)*y_std_train y_hat_test <- y_hat_test + rfx_preds_test } # Global error variance - if (sample_sigma_global) sigma2_samples <- global_var_samples*(y_std_train^2)/variance_scale + if (sample_sigma_global) sigma2_samples <- global_var_samples*(y_std_train^2) # Leaf parameter variance if (sample_sigma_leaf) tau_samples <- leaf_scale_samples @@ -734,14 +750,12 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, sigma_x_hat_train <- sapply(1:num_retained_samples, function(i) sqrt(sigma_x_hat_train[,i]*sigma2_samples[i])) if (has_test) sigma_x_hat_test <- sapply(1:num_retained_samples, function(i) sqrt(sigma_x_hat_test[,i]*sigma2_samples[i])) } else { - sigma_x_hat_train <- sqrt(sigma_x_hat_train*sigma2_init)*y_std_train/sqrt(variance_scale) - if (has_test) sigma_x_hat_test <- sqrt(sigma_x_hat_test*sigma2_init)*y_std_train/sqrt(variance_scale) + sigma_x_hat_train <- sqrt(sigma_x_hat_train*sigma2_init)*y_std_train + if (has_test) sigma_x_hat_test <- sqrt(sigma_x_hat_test*sigma2_init)*y_std_train } } # Return results as a list - # TODO: store variance_scale and propagate through predict function - # TODO: refactor out the "num_retained_samples" variable now that we burn-in/thin correctly model_params <- list( "sigma2_init" = sigma2_init, "sigma_leaf_init" = sigma_leaf_init, @@ -773,8 +787,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL, "sample_sigma_global" = sample_sigma_global, "sample_sigma_leaf" = sample_sigma_leaf, "include_mean_forest" = include_mean_forest, - "include_variance_forest" = include_variance_forest, - "variance_scale" = 
variance_scale + "include_variance_forest" = include_variance_forest ) result <- list( "model_params" = model_params, @@ -910,12 +923,11 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL # Compute mean forest predictions num_samples <- bart$model_params$num_samples - variance_scale <- bart$model_params$variance_scale y_std <- bart$model_params$outcome_scale y_bar <- bart$model_params$outcome_mean sigma2_init <- bart$model_params$sigma2_init if (bart$model_params$include_mean_forest) { - mean_forest_predictions <- bart$mean_forests$predict(prediction_dataset)*y_std/sqrt(variance_scale) + y_bar + mean_forest_predictions <- bart$mean_forests$predict(prediction_dataset)*y_std + y_bar } # Compute variance forest predictions @@ -925,7 +937,7 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL # Compute rfx predictions (if needed) if (bart$model_params$has_rfx) { - rfx_predictions <- bart$rfx_samples$predict(group_ids_test, rfx_basis_test)*y_std/sqrt(variance_scale) + rfx_predictions <- bart$rfx_samples$predict(group_ids_test, rfx_basis_test)*y_std } # Scale variance forest predictions @@ -934,7 +946,7 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL sigma2_samples <- bart$sigma2_global_samples variance_forest_predictions <- sapply(1:num_samples, function(i) sqrt(s_x_raw[,i]*sigma2_samples[i])) } else { - variance_forest_predictions <- sqrt(s_x_raw*sigma2_init)*y_std/sqrt(variance_scale) + variance_forest_predictions <- sqrt(s_x_raw*sigma2_init)*y_std } } @@ -1090,7 +1102,6 @@ convertBARTModelToJson <- function(object){ } # Add global parameters - jsonobj$add_scalar("variance_scale", object$model_params$variance_scale) jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) jsonobj$add_boolean("standardize", object$model_params$standardize) @@ -1145,7 +1156,6 @@ convertBARTStateToJson <- function(param_list, mean_forest = NULL, variance_fore jsonobj <- createCppJson() # Add global parameters - jsonobj$add_scalar("variance_scale", param_list$variance_scale) jsonobj$add_scalar("outcome_scale", param_list$outcome_scale) jsonobj$add_scalar("outcome_mean", param_list$outcome_mean) jsonobj$add_boolean("standardize", param_list$standardize) @@ -1334,7 +1344,6 @@ createBARTModelFromJson <- function(json_object){ # Unpack model params model_params = list() - model_params[["variance_scale"]] <- json_object$get_scalar("variance_scale") model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") model_params[["standardize"]] <- json_object$get_boolean("standardize") @@ -1680,7 +1689,6 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ # Unpack model params model_params = list() - model_params[["variance_scale"]] <- json_object_default$get_scalar("variance_scale") model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale") model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") model_params[["standardize"]] <- json_object_default$get_boolean("standardize") diff --git a/R/bcf.R b/R/bcf.R index 4a645986..87f4359c 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -26,91 +26,69 @@ #' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100. #' @param previous_model_json (Optional) JSON string containing a previous BCF model. 
This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: `NULL`. #' @param warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BCF sampler. One-indexed (so that the first sample is used for warm-start by setting `warmstart_sample_num = 1`). Default: `NULL`. -#' @param params The list of model parameters, each of which has a default value. -#' -#' **1. Global Parameters** +#' @param general_params (Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional. #' -#' - `cutpoint_grid_size` Maximum size of the "grid" of potential cutpoints to consider. Default: `100`. -#' - `a_global` Shape parameter in the `IG(a_global, b_global)` global error variance model. Default: `0`. -#' - `b_global` Scale parameter in the `IG(a_global, b_global)` global error variance model. Default: `0`. -#' - `sigma2_init` Starting value of global error variance parameter. Calibrated internally as `pct_var_sigma2_init*var((y-mean(y))/sd(y))` if not set. -#' - `pct_var_sigma2_init` Percentage of standardized outcome variance used to initialize global error variance parameter. Default: `1`. Superseded by `sigma2_init`. +#' - `cutpoint_grid_size` Maximum size of the "grid" of potential cutpoints to consider in the GFR algorithm. Default: `100`. +#' - `standardize` Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: `TRUE`. +#' - `sample_sigma2_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(sigma2_global_shape, sigma2_global_scale)`. Default: `TRUE`. +#' - `sigma2_global_init` Starting value of global error variance parameter. Calibrated internally as `1.0*var((y_train-mean(y_train))/sd(y_train))` if not set. +#' - `sigma2_global_shape` Shape parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Default: `0`. +#' - `sigma2_global_scale` Scale parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Default: `0`. #' - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to `1/ncol(X_train)`. A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in `X_train` and then set `propensity_covariate` to `'none'` adjust `keep_vars_mu`, `keep_vars_tau` and `keep_vars_variance` accordingly. #' - `propensity_covariate` Whether to include the propensity score as a covariate in either or both of the forests. Enter `"none"` for neither, `"mu"` for the prognostic forest, `"tau"` for the treatment forest, and `"both"` for both forests. If this is not `"none"` and a propensity score is not provided, it will be estimated from (`X_train`, `Z_train`) using `stochtree::bart()`. Default: `"mu"`. #' - `adaptive_coding` Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters `b_0` and `b_1` that attach to the outcome model `[b_0 (1-Z) + b_1 Z] tau(X)`. This is ignored when Z is not binary. Default: `TRUE`. 
-#' - `b_0` Initial value of the "control" group coding parameter. This is ignored when Z is not binary. Default: `-0.5`. -#' - `b_1` Initial value of the "treatment" group coding parameter. This is ignored when Z is not binary. Default: `0.5`. +#' - `control_coding_init` Initial value of the "control" group coding parameter. This is ignored when Z is not binary. Default: `-0.5`. +#' - `treated_coding_init` Initial value of the "treatment" group coding parameter. This is ignored when Z is not binary. Default: `0.5`. +#' - `rfx_prior_var` Prior on the (diagonals of the) covariance of the additive group-level random regression coefficients. Must be a vector of length `ncol(rfx_basis_train)`. Default: `rep(1, ncol(rfx_basis_train))` #' - `random_seed` Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`. -#' - `keep_burnin` Whether or not "burnin" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`. -#' - `keep_gfr` Whether or not "grow-from-root" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`. -#' - `standardize` Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: `TRUE`. +#' - `keep_burnin` Whether or not "burnin" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`. +#' - `keep_gfr` Whether or not "grow-from-root" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`. #' - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default `1`. Setting `keep_every <- k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. #' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`. #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`. -#' - `sample_sigma_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: `TRUE`. -#' -#' **2. Prognostic Forest Parameters** -#' -#' - `num_trees_mu` Number of trees in the prognostic forest. Default: `200`. -#' - `sample_sigma_leaf_mu` Whether or not to update the `sigma_leaf_mu` leaf scale variance parameter in the prognostic forest based on `IG(a_leaf_mu, b_leaf_mu)`. Default: `TRUE`. -#' -#' **2.1. Tree Prior Parameters** -#' -#' - `alpha_mu` Prior probability of splitting for a tree of depth 0 for the prognostic forest. Tree split prior combines `alpha` and `beta` via `alpha_mu*(1+node_depth)^-beta_mu`. Default: `0.95`. -#' - `beta_mu` Exponent that decreases split probabilities for nodes of depth > 0 for the prognostic forest. Tree split prior combines `alpha` and `beta` via `alpha_mu*(1+node_depth)^-beta_mu`. Default: `2.0`. -#' - `min_samples_leaf_mu` Minimum allowable size of a leaf, in terms of training samples, for the prognostic forest. Default: `5`. 
-#' - `max_depth_mu` Maximum depth of any tree in the mu ensemble. Default: `10`. Can be overridden with `-1` which does not enforce any depth limits on trees. -#' -#' **2.2. Leaf Model Parameters** -#' -#' - `keep_vars_mu` Vector of variable names or column indices denoting variables that should be included in the prognostic (`mu(X)`) forest. Default: `NULL`. -#' - `drop_vars_mu` Vector of variable names or column indices denoting variables that should be excluded from the prognostic (`mu(X)`) forest. Default: `NULL`. If both `drop_vars_mu` and `keep_vars_mu` are set, `drop_vars_mu` will be ignored. -#' - `sigma_leaf_mu` Starting value of leaf node scale parameter for the prognostic forest. Calibrated internally as `1/num_trees_mu` if not set here. -#' - `a_leaf_mu` Shape parameter in the `IG(a_leaf_mu, b_leaf_mu)` leaf node parameter variance model for the prognostic forest. Default: `3`. -#' - `b_leaf_mu` Scale parameter in the `IG(a_leaf_mu, b_leaf_mu)` leaf node parameter variance model for the prognostic forest. Calibrated internally as `0.5/num_trees` if not set here. #' -#' **3. Treatment Effect Forest Parameters** -#' -#' - `num_trees_tau` Number of trees in the treatment effect forest. Default: `50`. -#' - `sample_sigma_leaf_tau` Whether or not to update the `sigma_leaf_tau` leaf scale variance parameter in the treatment effect forest based on `IG(a_leaf_tau, b_leaf_tau)`. Default: `TRUE`. +#' @param mu_forest_params (Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional. #' -#' **3.1. Tree Prior Parameters** -#' -#' - `alpha_tau` Prior probability of splitting for a tree of depth 0 for the treatment effect forest. Tree split prior combines `alpha` and `beta` via `alpha_tau*(1+node_depth)^-beta_tau`. Default: `0.25`. -#' - `beta_tau` Exponent that decreases split probabilities for nodes of depth > 0 for the treatment effect forest. Tree split prior combines `alpha` and `beta` via `alpha_tau*(1+node_depth)^-beta_tau`. Default: `3.0`. -#' - `min_samples_leaf_tau` Minimum allowable size of a leaf, in terms of training samples, for the treatment effect forest. Default: `5`. -#' - `max_depth_tau` Maximum depth of any tree in the tau ensemble. Default: `5`. Can be overridden with `-1` which does not enforce any depth limits on trees. +#' - `num_trees` Number of trees in the ensemble for the prognostic forest. Default: `250`. Must be a positive integer. +#' - `alpha` Prior probability of splitting for a tree of depth 0 in the prognostic forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `0.95`. +#' - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the prognostic forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `2`. +#' - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the prognostic forest. Default: `5`. +#' - `max_depth` Maximum depth of any tree in the ensemble in the prognostic forest. Default: `10`. Can be overridden with ``-1`` which does not enforce any depth limits on trees. +#' - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable in the prognostic forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. 
+#' - `sample_sigma2_leaf` Whether or not to update the leaf scale variance parameter based on `IG(sigma2_leaf_shape, sigma2_leaf_scale)`. Default: `TRUE`. +#' - `sigma2_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. +#' - `sigma2_leaf_shape` Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Default: `3`. +#' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. +#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`. +#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. #' -#' **3.2. Leaf Model Parameters** -#' -#' - `a_leaf_tau` Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the treatment effect forest. Default: `3`. -#' - `b_leaf_tau` Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the treatment effect forest. Calibrated internally as `0.5/num_trees` if not set here. -#' - `keep_vars_tau` Vector of variable names or column indices denoting variables that should be included in the treatment effect (`tau(X)`) forest. Default: `NULL`. -#' - `drop_vars_tau` Vector of variable names or column indices denoting variables that should be excluded from the treatment effect (`tau(X)`) forest. Default: `NULL`. If both `drop_vars_tau` and `keep_vars_tau` are set, `drop_vars_tau` will be ignored. +#' @param tau_forest_params (Optional) A list of treatment effect forest model parameters, each of which has a default value processed internally, so this argument list is optional. #' -#' **4. Conditional Variance Forest Parameters** -#' -#' - `num_trees_variance` Number of trees in the (optional) conditional variance forest model. Default: `0`. -#' - `variance_forest_init` Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as `log(pct_var_variance_forest_init*var((y-mean(y))/sd(y)))/num_trees_variance` if not set. -#' - `pct_var_variance_forest_init` Percentage of standardized outcome variance used to initialize global error variance parameter. Default: `1`. Superseded by `variance_forest_init`. +#' - `num_trees` Number of trees in the ensemble for the treatment effect forest. Default: `50`. Must be a positive integer. +#' - `alpha` Prior probability of splitting for a tree of depth 0 in the treatment effect forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `0.25`. +#' - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the treatment effect forest. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `3`. +#' - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the treatment effect forest. Default: `5`. +#' - `max_depth` Maximum depth of any tree in the ensemble in the treatment effect forest. Default: `5`. Can be overridden with ``-1`` which does not enforce any depth limits on trees.
+#' - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable in the treatment effect forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. +#' - `sample_sigma2_leaf` Whether or not to update the leaf scale variance parameter based on `IG(sigma2_leaf_shape, sigma2_leaf_scale)`. Cannot (currently) be set to true if the treatment is multivariate (i.e. `ncol(Z_train) > 1`). Default: `FALSE`. +#' - `sigma2_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. +#' - `sigma2_leaf_shape` Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Default: `3`. +#' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. +#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`. +#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. #' -#' **4.1. Tree Prior Parameters** -#' -#' - `alpha_variance` Prior probability of splitting for a tree of depth 0 in the (optional) conditional variance model. Tree split prior combines `alpha_variance` and `beta_variance` via `alpha_variance*(1+node_depth)^-beta_variance`. Default: `0.95`. -#' - `beta_variance` Exponent that decreases split probabilities for nodes of depth > 0 in the (optional) conditional variance model. Tree split prior combines `alpha_variance` and `beta_variance` via `alpha_variance*(1+node_depth)^-beta_variance`. Default: `2.0`. -#' - `min_samples_leaf_variance` Minimum allowable size of a leaf, in terms of training samples, in the (optional) conditional variance model. Default: `5`. -#' - `max_depth_variance` Maximum depth of any tree in the ensemble in the (optional) conditional variance model. Default: `10`. Can be overridden with `-1` which does not enforce any depth limits on trees. +#' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional. #' -#' **4.2. Leaf Model Parameters** -#' -#' - `a_forest` Shape parameter in the `IG(a_forest, b_forest)` conditional error variance model (which is only sampled if `num_trees_variance > 0`). Calibrated internally as `num_trees_variance / 1.5^2 + 0.5` if not set. -#' - `b_forest` Scale parameter in the `IG(a_forest, b_forest)` conditional error variance model (which is only sampled if `num_trees_variance > 0`). Calibrated internally as `num_trees_variance / 1.5^2` if not set. -#' - `keep_vars_variance` Vector of variable names or column indices denoting variables that should be included in the (optional) conditional variance forest. Default: `NULL`. -#' - `drop_vars_variance` Vector of variable names or column indices denoting variables that should be excluded from the (optional) conditional variance forest. Default: NULL. If both `drop_vars_variance` and `keep_vars_variance` are set, `drop_vars_variance` will be ignored. -#' -#' **5. Random Effects Parameters** -#' -#' - `rfx_prior_var` Prior on the (diagonals of the) covariance of the additive group-level random regression coefficients. Must be a vector of length `ncol(rfx_basis_train)`.
Default: `rep(1, ncol(rfx_basis_train))` +#' - `num_trees` Number of trees in the ensemble for the conditional variance model. Default: `0`. Variance is only modeled using a tree / forest if `num_trees > 0`. +#' - `alpha` Prior probability of splitting for a tree of depth 0 in the variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `0.95`. +#' - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `2`. +#' - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: `5`. +#' - `max_depth` Maximum depth of any tree in the ensemble in the variance model. Default: `10`. Can be overridden with ``-1`` which does not enforce any depth limits on trees. +#' - `variance_forest_init` Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as `log(0.6*var((y_train-mean(y_train))/sd(y_train)))/num_trees` if not set. +#' - `var_forest_prior_shape` Shape parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2 + 0.5` if not set. +#' - `var_forest_prior_scale` Scale parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2` if not set. +#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`. +#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. #' #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). 
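A corresponding usage sketch for `bcf()` with the new per-forest parameter lists; the data-generating code and the settings shown are illustrative assumptions only:

```r
# Simulated causal-inference data for illustration only
n <- 500
X <- matrix(runif(n * 4), ncol = 4)
pi_x <- 0.25 + 0.5 * X[, 1]              # true propensity score
Z <- rbinom(n, 1, pi_x)                  # binary treatment
y <- X[, 1] + (0.5 + X[, 2]) * Z + rnorm(n)

bcf_model <- bcf(
    X_train = X, Z_train = Z, y_train = y, pi_train = pi_x,
    num_gfr = 10, num_mcmc = 100,
    general_params = list(adaptive_coding = TRUE, keep_every = 5),
    mu_forest_params = list(num_trees = 250),
    tau_forest_params = list(num_trees = 50, alpha = 0.25, beta = 3.0)
)
```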
#' @export @@ -166,64 +144,121 @@ #' # abline(0,1,col="red",lty=3,lwd=3) bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NULL, rfx_basis_train = NULL, X_test = NULL, Z_test = NULL, pi_test = NULL, - group_ids_test = NULL, rfx_basis_test = NULL, num_gfr = 5, - num_burnin = 0, num_mcmc = 100, previous_model_json = NULL, - warmstart_sample_num = NULL, params = list()) { - # Extract BCF parameters - bcf_params <- preprocessBcfParams(params) - cutpoint_grid_size <- bcf_params$cutpoint_grid_size - sigma_leaf_mu <- bcf_params$sigma_leaf_mu - sigma_leaf_tau <- bcf_params$sigma_leaf_tau - alpha_mu <- bcf_params$alpha_mu - alpha_tau <- bcf_params$alpha_tau - alpha_variance <- bcf_params$alpha_variance - beta_mu <- bcf_params$beta_mu - beta_tau <- bcf_params$beta_tau - beta_variance <- bcf_params$beta_variance - min_samples_leaf_mu <- bcf_params$min_samples_leaf_mu - min_samples_leaf_tau <- bcf_params$min_samples_leaf_tau - min_samples_leaf_variance <- bcf_params$min_samples_leaf_variance - max_depth_mu <- bcf_params$max_depth_mu - max_depth_tau <- bcf_params$max_depth_tau - max_depth_variance <- bcf_params$max_depth_variance - a_global <- bcf_params$a_global - b_global <- bcf_params$b_global - a_leaf_mu <- bcf_params$a_leaf_mu - a_leaf_tau <- bcf_params$a_leaf_tau - b_leaf_mu <- bcf_params$b_leaf_mu - b_leaf_tau <- bcf_params$b_leaf_tau - a_forest <- bcf_params$a_forest - b_forest <- bcf_params$b_forest - sigma2_init <- bcf_params$sigma2_init - variance_forest_init <- bcf_params$variance_forest_init - pct_var_sigma2_init <- bcf_params$pct_var_sigma2_init - pct_var_variance_forest_init <- bcf_params$pct_var_variance_forest_init - variable_weights <- bcf_params$variable_weights - keep_vars_mu <- bcf_params$keep_vars_mu - drop_vars_mu <- bcf_params$drop_vars_mu - keep_vars_tau <- bcf_params$keep_vars_tau - drop_vars_tau <- bcf_params$drop_vars_tau - keep_vars_variance <- bcf_params$keep_vars_variance - drop_vars_variance <- bcf_params$drop_vars_variance - num_trees_mu <- bcf_params$num_trees_mu - num_trees_tau <- bcf_params$num_trees_tau - num_trees_variance <- bcf_params$num_trees_variance - sample_sigma_global <- bcf_params$sample_sigma_global - sample_sigma_leaf_mu <- bcf_params$sample_sigma_leaf_mu - sample_sigma_leaf_tau <- bcf_params$sample_sigma_leaf_tau - propensity_covariate <- bcf_params$propensity_covariate - adaptive_coding <- bcf_params$adaptive_coding - b_0 <- bcf_params$b_0 - b_1 <- bcf_params$b_1 - rfx_prior_var <- bcf_params$rfx_prior_var - random_seed <- bcf_params$random_seed - keep_burnin <- bcf_params$keep_burnin - keep_gfr <- bcf_params$keep_gfr - standardize <- bcf_params$standardize - keep_every <- bcf_params$keep_every - num_chains <- bcf_params$num_chains - verbose <- bcf_params$verbose + group_ids_test = NULL, rfx_basis_test = NULL, + num_gfr = 5, num_burnin = 0, num_mcmc = 100, + previous_model_json = NULL, warmstart_sample_num = NULL, + general_params = list(), mu_forest_params = list(), + tau_forest_params = list(), variance_forest_params = list()) { + # Update general BCF parameters + general_params_default <- list( + cutpoint_grid_size = 100, standardize = T, + sample_sigma2_global = T, sigma2_global_init = NULL, + sigma2_global_shape = 0, sigma2_global_scale = 0, + variable_weights = NULL, propensity_covariate = "mu", + adaptive_coding = T, control_coding_init = -0.5, + treated_coding_init = 0.5, rfx_prior_var = NULL, + random_seed = -1, keep_burnin = F, keep_gfr = F, + keep_every = 1, num_chains = 1, verbose = F + ) + 
general_params_updated <- preprocessParams( + general_params_default, general_params + ) + + # Update mu forest BCF parameters + mu_forest_params_default <- list( + num_trees = 250, alpha = 0.95, beta = 2.0, + min_samples_leaf = 5, max_depth = 10, + sample_sigma2_leaf = T, sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL, + keep_vars = NULL, drop_vars = NULL + ) + mu_forest_params_updated <- preprocessParams( + mu_forest_params_default, mu_forest_params + ) + + # Update tau forest BCF parameters + tau_forest_params_default <- list( + num_trees = 50, alpha = 0.25, beta = 3.0, + min_samples_leaf = 5, max_depth = 5, + sample_sigma2_leaf = F, sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL, + keep_vars = NULL, drop_vars = NULL + ) + tau_forest_params_updated <- preprocessParams( + tau_forest_params_default, tau_forest_params + ) + + # Update variance forest BCF parameters + variance_forest_params_default <- list( + num_trees = 0, alpha = 0.95, beta = 2.0, + min_samples_leaf = 5, max_depth = 10, + variance_forest_init = NULL, var_forest_prior_shape = NULL, + var_forest_prior_scale = NULL, + keep_vars = NULL, drop_vars = NULL + ) + variance_forest_params_updated <- preprocessParams( + variance_forest_params_default, variance_forest_params + ) + ### Unpack all parameter values + # 1. General parameters + cutpoint_grid_size <- general_params_updated$cutpoint_grid_size + standardize <- general_params_updated$standardize + sample_sigma_global <- general_params_updated$sample_sigma2_global + sigma2_init <- general_params_updated$sigma2_global_init + a_global <- general_params_updated$sigma2_global_shape + b_global <- general_params_updated$sigma2_global_scale + variable_weights <- general_params_updated$variable_weights + propensity_covariate <- general_params_updated$propensity_covariate + adaptive_coding <- general_params_updated$adaptive_coding + b_0 <- general_params_updated$control_coding_init + b_1 <- general_params_updated$treated_coding_init + rfx_prior_var <- general_params_updated$rfx_prior_var + random_seed <- general_params_updated$random_seed + keep_burnin <- general_params_updated$keep_burnin + keep_gfr <- general_params_updated$keep_gfr + keep_every <- general_params_updated$keep_every + num_chains <- general_params_updated$num_chains + verbose <- general_params_updated$verbose + + # 2. Mu forest parameters + num_trees_mu <- mu_forest_params_updated$num_trees + alpha_mu <- mu_forest_params_updated$alpha + beta_mu <- mu_forest_params_updated$beta + min_samples_leaf_mu <- mu_forest_params_updated$min_samples_leaf + max_depth_mu <- mu_forest_params_updated$max_depth + sample_sigma_leaf_mu <- mu_forest_params_updated$sample_sigma2_leaf + sigma_leaf_mu <- mu_forest_params_updated$sigma2_leaf_init + a_leaf_mu <- mu_forest_params_updated$sigma2_leaf_shape + b_leaf_mu <- mu_forest_params_updated$sigma2_leaf_scale + keep_vars_mu <- mu_forest_params_updated$keep_vars + drop_vars_mu <- mu_forest_params_updated$drop_vars + + # 3. 
Tau forest parameters +    num_trees_tau <- tau_forest_params_updated$num_trees +    alpha_tau <- tau_forest_params_updated$alpha +    beta_tau <- tau_forest_params_updated$beta +    min_samples_leaf_tau <- tau_forest_params_updated$min_samples_leaf +    max_depth_tau <- tau_forest_params_updated$max_depth +    sample_sigma_leaf_tau <- tau_forest_params_updated$sample_sigma2_leaf +    sigma_leaf_tau <- tau_forest_params_updated$sigma2_leaf_init +    a_leaf_tau <- tau_forest_params_updated$sigma2_leaf_shape +    b_leaf_tau <- tau_forest_params_updated$sigma2_leaf_scale +    keep_vars_tau <- tau_forest_params_updated$keep_vars +    drop_vars_tau <- tau_forest_params_updated$drop_vars +     +    # 4. Variance forest parameters +    num_trees_variance <- variance_forest_params_updated$num_trees +    alpha_variance <- variance_forest_params_updated$alpha +    beta_variance <- variance_forest_params_updated$beta +    min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf +    max_depth_variance <- variance_forest_params_updated$max_depth +    variance_forest_init <- variance_forest_params_updated$variance_forest_init +    a_forest <- variance_forest_params_updated$var_forest_prior_shape +    b_forest <- variance_forest_params_updated$var_forest_prior_scale +    keep_vars_variance <- variance_forest_params_updated$keep_vars +    drop_vars_variance <- variance_forest_params_updated$drop_vars +     # Override keep_gfr if there are no MCMC samples if (num_mcmc == 0) keep_gfr <- T @@ -234,7 +269,6 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU previous_bcf_model <- createBCFModelFromJsonString(previous_model_json) previous_y_bar <- previous_bcf_model$model_params$outcome_mean previous_y_scale <- previous_bcf_model$model_params$outcome_scale -    previous_var_scale <- previous_bcf_model$model_params$variance_scale previous_forest_samples_mu <- previous_bcf_model$forests_mu previous_forest_samples_tau <- previous_bcf_model$forests_tau if (previous_bcf_model$model_params$include_variance_forest) { @@ -634,8 +668,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU resid_train <- (y_train-y_bar_train)/y_std_train # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau -    if (is.null(sigma2_init)) sigma2_init <- pct_var_sigma2_init*var(resid_train) -    if (is.null(variance_forest_init)) variance_forest_init <- pct_var_variance_forest_init*var(resid_train) +    if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train) +    if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train) if (is.null(b_leaf_mu)) b_leaf_mu <- var(resid_train)/(num_trees_mu) if (is.null(b_leaf_tau)) b_leaf_tau <- var(resid_train)/(2*num_trees_tau) if (is.null(sigma_leaf_mu)) sigma_leaf_mu <- var(resid_train)/(num_trees_mu) @@ -1101,6 +1135,10 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU if (sample_sigma_leaf_tau) { leaf_scale_tau_samples <- leaf_scale_tau_samples[(num_gfr+1):length(leaf_scale_tau_samples)] } +    if (adaptive_coding) { +        b_1_samples <- b_1_samples[(num_gfr+1):length(b_1_samples)] +        b_0_samples <- b_0_samples[(num_gfr+1):length(b_0_samples)] +    } num_retained_samples <- num_retained_samples - num_gfr } diff --git a/R/utils.R b/R/utils.R index dd755c70..66c9b702 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1,35 +1,78 @@ +#' Preprocess a parameter list, overriding defaults with any provided parameters. +#' +#' @param default_params List of parameters with default values set.
+#' @param user_params (Optional) User-supplied overrides to `default_params`. +#' +#' @return Parameter list with defaults overridden by values supplied in `user_params` +#' @export +preprocessParams <- function(default_params, user_params = NULL) { + # Override defaults with any values supplied in user_params + if (!is.null(user_params)) { + for (key in names(user_params)) { + if (key %in% names(default_params)) { + val <- user_params[[key]] + if (!is.null(val)) default_params[[key]] <- val + } + } + } + + # Return result + return(default_params) +} + #' Preprocess BART parameter list. Override defaults with any provided parameters. #' -#' @param params Parameter list +#' @param general_params List of any non-forest-specific parameters +#' @param mean_forest_params List of any mean forest parameters +#' @param variance_forest_params List of any variance forest parameters #' -#' @return Parameter list with defaults overriden by values supplied in `params` +#' @return Parameter list with defaults overridden by values supplied in parameter lists #' @export -preprocessBartParams <- function(params) { +preprocessBartParams <- function(general_params, mean_forest_params, variance_forest_params) { # Default parameter values processed_params <- list( - cutpoint_grid_size = 100, sigma_leaf_init = NULL, + cutpoint_grid_size = 100, alpha_mean = 0.95, beta_mean = 2.0, min_samples_leaf_mean = 5, max_depth_mean = 10, + variable_weights_mean = NULL, num_trees_mean = 200, alpha_variance = 0.95, beta_variance = 2.0, min_samples_leaf_variance = 5, max_depth_variance = 10, - a_global = 0, b_global = 0, a_leaf = 3, b_leaf = NULL, - a_forest = NULL, b_forest = NULL, variance_scale = 1, - sigma2_init = NULL, variance_forest_init = NULL, - pct_var_sigma2_init = 1, pct_var_variance_forest_init = 1, - variable_weights_mean = NULL, variable_weights_variance = NULL, - num_trees_mean = 200, num_trees_variance = 0, - sample_sigma_global = T, sample_sigma_leaf = F, + variable_weights_variance = NULL, num_trees_variance = 0, + sample_sigma2_global = T, sigma2_global_init = NULL, + sigma2_global_shape = 0, sigma2_global_scale = 0, + sample_sigma2_leaf = F, sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL, + var_forest_prior_shape = NULL, var_forest_prior_scale = NULL, + init_root_val = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F, keep_every = 1, num_chains = 1, standardize = T, verbose = F ) - # Override defaults - for (key in names(params)) { - if (!key %in% names(processed_params)) { - stop("Variable ", key, " is not a valid BART model parameter") + # Override defaults from general_params + for (key in names(general_params)) { + if (key %in% names(processed_params)) { + val <- general_params[[key]] + if (!is.null(val)) processed_params[[key]] <- val + } + } + + # Override defaults from mean_forest_params (suffixed keys where they exist, bare keys otherwise) + for (key in names(mean_forest_params)) { + modified_key <- paste0(key, "_mean") + if (!modified_key %in% names(processed_params)) modified_key <- key + if (modified_key %in% names(processed_params)) { + val <- mean_forest_params[[key]] + if (!is.null(val)) processed_params[[modified_key]] <- val + } + } + + # Override defaults from variance_forest_params (suffixed keys where they exist, bare keys otherwise) + for (key in names(variance_forest_params)) { + modified_key <- paste0(key, "_variance") + if (!modified_key %in% names(processed_params)) modified_key <- key + if (modified_key %in% names(processed_params)) { + val <- variance_forest_params[[key]] + if (!is.null(val)) processed_params[[modified_key]] <- val } - val <- params[[key]] - if (!is.null(val)) processed_params[[key]] <- val } # Return result @@ -38,9 +81,12 @@
preprocessBartParams <- function(params) { #' Preprocess BCF parameter list. Override defaults with any provided parameters. #' -#' @param params Parameter list +#' @param general_params List of any non-forest-specific parameters +#' @param mu_forest_params List of any mu forest parameters +#' @param tau_forest_params List of any tau forest parameters +#' @param variance_forest_params List of any variance forest parameters #' -#' @return Parameter list with defaults overriden by values supplied in `params` +#' @return Parameter list with defaults overridden by values supplied in parameter lists #' @export -preprocessBcfParams <- function(params) { +preprocessBcfParams <- function(general_params, mu_forest_params, tau_forest_params, variance_forest_params) { # Default parameter values @@ -57,7 +103,7 @@ preprocessBcfParams <- function(params) { keep_vars_tau = NULL, drop_vars_tau = NULL, keep_vars_variance = NULL, drop_vars_variance = NULL, num_trees_mu = 250, num_trees_tau = 50, num_trees_variance = 0, num_gfr = 5, num_burnin = 0, num_mcmc = 100, - sample_sigma_global = T, sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = F, + sample_sigma2_global = T, sample_sigma2_leaf_mu = T, sample_sigma2_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5, rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F, keep_every = 1, num_chains = 1, standardize = T, verbose = F @@ -65,11 +111,37 @@ preprocessBcfParams <- function(params) { - # Override defaults - for (key in names(params)) { - if (!key %in% names(processed_params)) { - stop("Variable ", key, " is not a valid BART model parameter") + # Override defaults from general_params + for (key in names(general_params)) { + if (key %in% names(processed_params)) { + val <- general_params[[key]] + if (!is.null(val)) processed_params[[key]] <- val + } + } + + # Override defaults from mu_forest_params + for (key in names(mu_forest_params)) { + modified_key <- paste0(key, "_mu") + if (modified_key %in% names(processed_params)) { + val <- mu_forest_params[[key]] + if (!is.null(val)) processed_params[[modified_key]] <- val + } + } + + # Override defaults from tau_forest_params + for (key in names(tau_forest_params)) { + modified_key <- paste0(key, "_tau") + if (modified_key %in% names(processed_params)) { + val <- tau_forest_params[[key]] + if (!is.null(val)) processed_params[[modified_key]] <- val + } + } + + # Override defaults from variance_forest_params + for (key in names(variance_forest_params)) { + modified_key <- paste0(key, "_variance") + if (modified_key %in% names(processed_params)) { + val <- variance_forest_params[[key]] + if (!is.null(val)) processed_params[[modified_key]] <- val } - val <- params[[key]] - if (!is.null(val)) processed_params[[key]] <- val } # Return result diff --git a/demo/notebooks/causal_inference.ipynb index c6d39642..4c5eb17c 100644 --- a/demo/notebooks/causal_inference.ipynb +++ b/demo/notebooks/causal_inference.ipynb @@ -103,7 +103,8 @@ "outputs": [], "source": [ "bcf_model = BCFModel()\n", - "bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100, params={\"keep_every\": 5})" + "general_params = {\"keep_every\": 5}\n", + "bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100, general_params=general_params)" ] }, { diff --git a/demo/notebooks/supervised_learning.ipynb index 4fe6465d..9a49289a 100644 --- a/demo/notebooks/supervised_learning.ipynb +++ b/demo/notebooks/supervised_learning.ipynb @@ -119,8 +119,8 @@ "outputs": [], "source": [ "bart_model = BARTModel()\n", - "param_dict = {\"num_chains\": 3}\n", -
"bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=10, num_mcmc=100, params=param_dict)" + "general_params = {\"num_chains\": 3}\n", + "bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=10, num_mcmc=100, general_params=general_params)" ] }, { diff --git a/man/bart.Rd b/man/bart.Rd index d5d7bacd..1e8b8dd5 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -19,7 +19,9 @@ bart( num_mcmc = 100, previous_model_json = NULL, warmstart_sample_num = NULL, - params = list() + general_params = list(), + mean_forest_params = list(), + variance_forest_params = list() ) } \arguments{ @@ -66,69 +68,47 @@ that were not in the training set.} \item{warmstart_sample_num}{(Optional) Sample number from \code{previous_model_json} that will be used to warmstart this BART sampler. One-indexed (so that the first sample is used for warm-start by setting \code{warmstart_sample_num = 1}). Default: \code{NULL}.} -\item{params}{The list of model parameters, each of which has a default value. - -\strong{1. Global Parameters} +\item{general_params}{(Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional. \itemize{ -\item \code{cutpoint_grid_size} Maximum size of the "grid" of potential cutpoints to consider. Default: \code{100}. -\item \code{sigma2_init} Starting value of global error variance parameter. Calibrated internally as \code{pct_var_sigma2_init*var((y-mean(y))/sd(y))} if not set. -\item \code{pct_var_sigma2_init} Percentage of standardized outcome variance used to initialize global error variance parameter. Default: \code{1}. Superseded by \code{sigma2_init}. -\item \code{variance_scale} Variance after the data have been scaled. Default: \code{1}. -\item \code{a_global} Shape parameter in the \code{IG(a_global, b_global)} global error variance model. Default: \code{0}. -\item \code{b_global} Scale parameter in the \code{IG(a_global, b_global)} global error variance model. Default: \code{0}. -\item \code{random_seed} Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to \code{std::random_device}. -\item \code{sample_sigma_global} Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(a_global, b_global)}. Default: \code{TRUE}. -\item \code{keep_burnin} Whether or not "burnin" samples should be included in cached predictions. Default \code{FALSE}. Ignored if \code{num_mcmc = 0}. -\item \code{keep_gfr} Whether or not "grow-from-root" samples should be included in cached predictions. Default \code{FALSE}. Ignored if \code{num_mcmc = 0}. +\item \code{cutpoint_grid_size} Maximum size of the "grid" of potential cutpoints to consider in the GFR algorithm. Default: \code{100}. \item \code{standardize} Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: \code{TRUE}. +\item \code{sample_sigma2_global} Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(sigma2_global_shape, sigma2_global_scale)}. Default: \code{TRUE}. +\item \code{sigma2_init} Starting value of global error variance parameter. Calibrated internally as \code{1.0*var((y_train-mean(y_train))/sd(y_train))} if not set. 
+\item \code{sigma2_global_shape} Shape parameter in the \code{IG(sigma2_global_shape, sigma2_global_scale)} global error variance model. Default: \code{0}. +\item \code{sigma2_global_scale} Scale parameter in the \code{IG(sigma2_global_shape, sigma2_global_scale)} global error variance model. Default: \code{0}. +\item \code{random_seed} Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to \code{std::random_device}. +\item \code{keep_burnin} Whether or not "burnin" samples should be included in the stored samples of forests and other parameters. Default \code{FALSE}. Ignored if \code{num_mcmc = 0}. +\item \code{keep_gfr} Whether or not "grow-from-root" samples should be included in the stored samples of forests and other parameters. Default \code{FALSE}. Ignored if \code{num_mcmc = 0}. \item \code{keep_every} How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default \code{1}. Setting \code{keep_every <- k} for some \code{k > 1} will "thin" the MCMC samples by retaining every \code{k}-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. \item \code{num_chains} How many independent MCMC chains should be sampled. If \code{num_mcmc = 0}, this is ignored. If \code{num_gfr = 0}, then each chain is run from root for \code{num_mcmc * keep_every + num_burnin} iterations, with \code{num_mcmc} samples retained. If \code{num_gfr > 0}, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that \code{num_gfr >= num_chains}. Default: \code{1}. \item \code{verbose} Whether or not to print progress during the sampling loops. Default: \code{FALSE}. -} - -\strong{2. Mean Forest Parameters} -\itemize{ -\item \code{num_trees_mean} Number of trees in the ensemble for the conditional mean model. Default: \code{200}. If \code{num_trees_mean = 0}, the conditional mean will not be modeled using a forest, and the function will only proceed if \code{num_trees_variance > 0}. -\item \code{sample_sigma_leaf} Whether or not to update the \code{tau} leaf scale variance parameter based on \code{IG(a_leaf, b_leaf)}. Cannot (currently) be set to true if \code{ncol(W_train)>1}. Default: \code{FALSE}. -} - -\strong{2.1. Tree Prior Parameters} -\itemize{ -\item \code{alpha_mean} Prior probability of splitting for a tree of depth 0 in the mean model. Tree split prior combines \code{alpha_mean} and \code{beta_mean} via \code{alpha_mean*(1+node_depth)^-beta_mean}. Default: \code{0.95}. -\item \code{beta_mean} Exponent that decreases split probabilities for nodes of depth > 0 in the mean model. Tree split prior combines \code{alpha_mean} and \code{beta_mean} via \code{alpha_mean*(1+node_depth)^-beta_mean}. Default: \code{2}. -\item \code{min_samples_leaf_mean} Minimum allowable size of a leaf, in terms of training samples, in the mean model. Default: \code{5}. -\item \code{max_depth_mean} Maximum depth of any tree in the ensemble in the mean model. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. -} - -\strong{2.2. Leaf Model Parameters} -\itemize{ -\item \code{variable_weights_mean} Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. 
-\item \code{sigma_leaf_init} Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees_mean} if not set here. -\item \code{a_leaf} Shape parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Default: \code{3}. -\item \code{b_leaf} Scale parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model. Calibrated internally as \code{0.5/num_trees_mean} if not set here. -} - -\strong{3. Conditional Variance Forest Parameters} -\itemize{ -\item \code{num_trees_variance} Number of trees in the ensemble for the conditional variance model. Default: \code{0}. Variance is only modeled using a tree / forest if \code{num_trees_variance > 0}. -\item \code{variance_forest_init} Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as \code{log(pct_var_variance_forest_init*var((y-mean(y))/sd(y)))/num_trees_variance} if not set. -\item \code{pct_var_variance_forest_init} Percentage of standardized outcome variance used to initialize global error variance parameter. Default: \code{1}. Superseded by \code{variance_forest_init}. -} +}} -\strong{3.1. Tree Prior Parameters} +\item{mean_forest_params}{(Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional. \itemize{ -\item \code{alpha_variance} Prior probability of splitting for a tree of depth 0 in the variance model. Tree split prior combines \code{alpha_variance} and \code{beta_variance} via \code{alpha_variance*(1+node_depth)^-beta_variance}. Default: \code{0.95}. -\item \code{beta_variance} Exponent that decreases split probabilities for nodes of depth > 0 in the variance model. Tree split prior combines \code{alpha_variance} and \code{beta_variance} via \code{alpha_variance*(1+node_depth)^-beta_variance}. Default: \code{2}. -\item \code{min_samples_leaf_variance} Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: \code{5}. -\item \code{max_depth_variance} Maximum depth of any tree in the ensemble in the variance model. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. -} +\item \code{num_trees} Number of trees in the ensemble for the conditional mean model. Default: \code{200}. If \code{num_trees = 0}, the conditional mean will not be modeled using a forest, and the function will only proceed if \code{num_trees > 0} for the variance forest. +\item \code{alpha} Prior probability of splitting for a tree of depth 0 in the mean model. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{0.95}. +\item \code{beta} Exponent that decreases split probabilities for nodes of depth > 0 in the mean model. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{2}. +\item \code{min_samples_leaf} Minimum allowable size of a leaf, in terms of training samples, in the mean model. Default: \code{5}. +\item \code{max_depth} Maximum depth of any tree in the ensemble in the mean model. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. +\item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. 
+\item \code{sample_sigma2_leaf} Whether or not to update the leaf scale variance parameter based on \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)}. Cannot (currently) be set to true if \code{ncol(W_train)>1}. Default: \code{FALSE}. +\item \code{sigma2_leaf_init} Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here. +\item \code{sigma2_leaf_shape} Shape parameter in the \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)} leaf node parameter variance model. Default: \code{3}. +\item \code{sigma2_leaf_scale} Scale parameter in the \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)} leaf node parameter variance model. Calibrated internally as \code{0.5/num_trees} if not set here. +}} -\strong{3.2. Leaf Model Parameters} +\item{variance_forest_params}{(Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional. \itemize{ -\item \code{variable_weights_variance} Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. -\item \code{sigma_leaf_init} Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees_mean} if not set here. -\item \code{a_forest} Shape parameter in the \code{IG(a_forest, b_forest)} conditional error variance model (which is only sampled if \code{num_trees_variance > 0}). Calibrated internally as \code{num_trees_variance / 1.5^2 + 0.5} if not set. -\item \code{b_forest} Scale parameter in the \code{IG(a_forest, b_forest)} conditional error variance model (which is only sampled if \code{num_trees_variance > 0}). Calibrated internally as \code{num_trees_variance / 1.5^2} if not set. +\item \code{num_trees} Number of trees in the ensemble for the conditional variance model. Default: \code{0}. Variance is only modeled using a tree / forest if \code{num_trees > 0}. +\item \code{alpha} Prior probability of splitting for a tree of depth 0 in the variance model. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{0.95}. +\item \code{beta} Exponent that decreases split probabilities for nodes of depth > 0 in the variance model. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{2}. +\item \code{min_samples_leaf} Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: \code{5}. +\item \code{max_depth} Maximum depth of any tree in the ensemble in the variance model. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. +\item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. +\item \code{init_root_val} Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as \code{log(0.6*var((y_train-mean(y_train))/sd(y_train)))/num_trees} if not set. +\item \code{var_forest_prior_shape} Shape parameter in the \code{IG(var_forest_prior_shape, var_forest_prior_scale)} conditional error variance model (which is only sampled if \code{num_trees > 0}). Calibrated internally as \code{num_trees / 1.5^2 + 0.5} if not set. 
+\item \code{var_forest_prior_scale} Scale parameter in the \code{IG(var_forest_prior_shape, var_forest_prior_scale)} conditional error variance model (which is only sampled if \code{num_trees > 0}). Calibrated internally as \code{num_trees / 1.5^2} if not set. }} } \value{ diff --git a/man/bcf.Rd index d5743a6d..653ea440 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -21,7 +21,10 @@ bcf( num_mcmc = 100, previous_model_json = NULL, warmstart_sample_num = NULL, - params = list() + general_params = list(), + mu_forest_params = list(), + tau_forest_params = list(), + variance_forest_params = list() ) } \arguments{ @@ -67,101 +70,73 @@ that were not in the training set.} \item{warmstart_sample_num}{(Optional) Sample number from \code{previous_model_json} that will be used to warmstart this BCF sampler. One-indexed (so that the first sample is used for warm-start by setting \code{warmstart_sample_num = 1}). Default: \code{NULL}.} -\item{params}{The list of model parameters, each of which has a default value. - -\strong{1. Global Parameters} +\item{general_params}{(Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional. \itemize{ -\item \code{cutpoint_grid_size} Maximum size of the "grid" of potential cutpoints to consider. Default: \code{100}. -\item \code{a_global} Shape parameter in the \code{IG(a_global, b_global)} global error variance model. Default: \code{0}. -\item \code{b_global} Scale parameter in the \code{IG(a_global, b_global)} global error variance model. Default: \code{0}. -\item \code{sigma2_init} Starting value of global error variance parameter. Calibrated internally as \code{pct_var_sigma2_init*var((y-mean(y))/sd(y))} if not set. -\item \code{pct_var_sigma2_init} Percentage of standardized outcome variance used to initialize global error variance parameter. Default: \code{1}. Superseded by \code{sigma2_init}. +\item \code{cutpoint_grid_size} Maximum size of the "grid" of potential cutpoints to consider in the GFR algorithm. Default: \code{100}. +\item \code{standardize} Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: \code{TRUE}. +\item \code{sample_sigma2_global} Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(sigma2_global_shape, sigma2_global_scale)}. Default: \code{TRUE}. +\item \code{sigma2_global_init} Starting value of global error variance parameter. Calibrated internally as \code{1.0*var((y_train-mean(y_train))/sd(y_train))} if not set. +\item \code{sigma2_global_shape} Shape parameter in the \code{IG(sigma2_global_shape, sigma2_global_scale)} global error variance model. Default: \code{0}. +\item \code{sigma2_global_scale} Scale parameter in the \code{IG(sigma2_global_shape, sigma2_global_scale)} global error variance model. Default: \code{0}. \item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to \code{1/ncol(X_train)}.
A workaround, if you wish to provide a custom weight for the propensity score, is to include it as a column in \code{X_train}, set \code{propensity_covariate} to \code{'none'}, and adjust \code{keep_vars} in \code{mu_forest_params}, \code{tau_forest_params}, and \code{variance_forest_params} accordingly. \item \code{propensity_covariate} Whether to include the propensity score as a covariate in either or both of the forests. Enter \code{"none"} for neither, \code{"mu"} for the prognostic forest, \code{"tau"} for the treatment forest, and \code{"both"} for both forests. If this is not \code{"none"} and a propensity score is not provided, it will be estimated from (\code{X_train}, \code{Z_train}) using \code{stochtree::bart()}. Default: \code{"mu"}. \item \code{adaptive_coding} Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters \code{b_0} and \code{b_1} that attach to the outcome model \verb{[b_0 (1-Z) + b_1 Z] tau(X)}. This is ignored when Z is not binary. Default: \code{TRUE}. -\item \code{b_0} Initial value of the "control" group coding parameter. This is ignored when Z is not binary. Default: \code{-0.5}. -\item \code{b_1} Initial value of the "treatment" group coding parameter. This is ignored when Z is not binary. Default: \code{0.5}. +\item \code{control_coding_init} Initial value of the "control" group coding parameter. This is ignored when Z is not binary. Default: \code{-0.5}. +\item \code{treated_coding_init} Initial value of the "treatment" group coding parameter. This is ignored when Z is not binary. Default: \code{0.5}. +\item \code{rfx_prior_var} Prior on the (diagonals of the) covariance of the additive group-level random regression coefficients. Must be a vector of length \code{ncol(rfx_basis_train)}. Default: \code{rep(1, ncol(rfx_basis_train))}. \item \code{random_seed} Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to \code{std::random_device}. -\item \code{keep_burnin} Whether or not "burnin" samples should be included in cached predictions. Default \code{FALSE}. Ignored if \code{num_mcmc = 0}. -\item \code{keep_gfr} Whether or not "grow-from-root" samples should be included in cached predictions. Default \code{FALSE}. Ignored if \code{num_mcmc = 0}. -\item \code{standardize} Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: \code{TRUE}. +\item \code{keep_burnin} Whether or not "burnin" samples should be included in the stored samples of forests and other parameters. Default \code{FALSE}. Ignored if \code{num_mcmc = 0}. +\item \code{keep_gfr} Whether or not "grow-from-root" samples should be included in the stored samples of forests and other parameters. Default \code{FALSE}. Ignored if \code{num_mcmc = 0}. \item \code{keep_every} How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default \code{1}. Setting \code{keep_every <- k} for some \code{k > 1} will "thin" the MCMC samples by retaining every \code{k}-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. \item \code{num_chains} How many independent MCMC chains should be sampled. If \code{num_mcmc = 0}, this is ignored. If \code{num_gfr = 0}, then each chain is run from root for \code{num_mcmc * keep_every + num_burnin} iterations, with \code{num_mcmc} samples retained.
If \code{num_gfr > 0}, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that \code{num_gfr >= num_chains}. Default: \code{1}. \item \code{verbose} Whether or not to print progress during the sampling loops. Default: \code{FALSE}. -\item \code{sample_sigma_global} Whether or not to update the \code{sigma^2} global error variance parameter based on \code{IG(a_global, b_global)}. Default: \code{TRUE}. -} - -\strong{2. Prognostic Forest Parameters} -\itemize{ -\item \code{num_trees_mu} Number of trees in the prognostic forest. Default: \code{200}. -\item \code{sample_sigma_leaf_mu} Whether or not to update the \code{sigma_leaf_mu} leaf scale variance parameter in the prognostic forest based on \code{IG(a_leaf_mu, b_leaf_mu)}. Default: \code{TRUE}. -} - -\strong{2.1. Tree Prior Parameters} -\itemize{ -\item \code{alpha_mu} Prior probability of splitting for a tree of depth 0 for the prognostic forest. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha_mu*(1+node_depth)^-beta_mu}. Default: \code{0.95}. -\item \code{beta_mu} Exponent that decreases split probabilities for nodes of depth > 0 for the prognostic forest. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha_mu*(1+node_depth)^-beta_mu}. Default: \code{2.0}. -\item \code{min_samples_leaf_mu} Minimum allowable size of a leaf, in terms of training samples, for the prognostic forest. Default: \code{5}. -\item \code{max_depth_mu} Maximum depth of any tree in the mu ensemble. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. -} - -\strong{2.2. Leaf Model Parameters} -\itemize{ -\item \code{keep_vars_mu} Vector of variable names or column indices denoting variables that should be included in the prognostic (\code{mu(X)}) forest. Default: \code{NULL}. -\item \code{drop_vars_mu} Vector of variable names or column indices denoting variables that should be excluded from the prognostic (\code{mu(X)}) forest. Default: \code{NULL}. If both \code{drop_vars_mu} and \code{keep_vars_mu} are set, \code{drop_vars_mu} will be ignored. -\item \code{sigma_leaf_mu} Starting value of leaf node scale parameter for the prognostic forest. Calibrated internally as \code{1/num_trees_mu} if not set here. -\item \code{a_leaf_mu} Shape parameter in the \code{IG(a_leaf_mu, b_leaf_mu)} leaf node parameter variance model for the prognostic forest. Default: \code{3}. -\item \code{b_leaf_mu} Scale parameter in the \code{IG(a_leaf_mu, b_leaf_mu)} leaf node parameter variance model for the prognostic forest. Calibrated internally as \code{0.5/num_trees} if not set here. -} - -\strong{3. Treatment Effect Forest Parameters} -\itemize{ -\item \code{num_trees_tau} Number of trees in the treatment effect forest. Default: \code{50}. -\item \code{sample_sigma_leaf_tau} Whether or not to update the \code{sigma_leaf_tau} leaf scale variance parameter in the treatment effect forest based on \code{IG(a_leaf_tau, b_leaf_tau)}. Default: \code{TRUE}. -} - -\strong{3.1. Tree Prior Parameters} -\itemize{ -\item \code{alpha_tau} Prior probability of splitting for a tree of depth 0 for the treatment effect forest. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha_tau*(1+node_depth)^-beta_tau}. Default: \code{0.25}. -\item \code{beta_tau} Exponent that decreases split probabilities for nodes of depth > 0 for the treatment effect forest. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha_tau*(1+node_depth)^-beta_tau}. 
Default: \code{3.0}. -\item \code{min_samples_leaf_tau} Minimum allowable size of a leaf, in terms of training samples, for the treatment effect forest. Default: \code{5}. -\item \code{max_depth_tau} Maximum depth of any tree in the tau ensemble. Default: \code{5}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. -} - -\strong{3.2. Leaf Model Parameters} -\itemize{ -\item \code{a_leaf_tau} Shape parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model for the treatment effect forest. Default: \code{3}. -\item \code{b_leaf_tau} Scale parameter in the \code{IG(a_leaf, b_leaf)} leaf node parameter variance model for the treatment effect forest. Calibrated internally as \code{0.5/num_trees} if not set here. -\item \code{keep_vars_tau} Vector of variable names or column indices denoting variables that should be included in the treatment effect (\code{tau(X)}) forest. Default: \code{NULL}. -\item \code{drop_vars_tau} Vector of variable names or column indices denoting variables that should be excluded from the treatment effect (\code{tau(X)}) forest. Default: \code{NULL}. If both \code{drop_vars_tau} and \code{keep_vars_tau} are set, \code{drop_vars_tau} will be ignored. -} - -\strong{4. Conditional Variance Forest Parameters} -\itemize{ -\item \code{num_trees_variance} Number of trees in the (optional) conditional variance forest model. Default: \code{0}. -\item \code{variance_forest_init} Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as \code{log(pct_var_variance_forest_init*var((y-mean(y))/sd(y)))/num_trees_variance} if not set. -\item \code{pct_var_variance_forest_init} Percentage of standardized outcome variance used to initialize global error variance parameter. Default: \code{1}. Superseded by \code{variance_forest_init}. -} +}} -\strong{4.1. Tree Prior Parameters} +\item{mu_forest_params}{(Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional. \itemize{ -\item \code{alpha_variance} Prior probability of splitting for a tree of depth 0 in the (optional) conditional variance model. Tree split prior combines \code{alpha_variance} and \code{beta_variance} via \code{alpha_variance*(1+node_depth)^-beta_variance}. Default: \code{0.95}. -\item \code{beta_variance} Exponent that decreases split probabilities for nodes of depth > 0 in the (optional) conditional variance model. Tree split prior combines \code{alpha_variance} and \code{beta_variance} via \code{alpha_variance*(1+node_depth)^-beta_variance}. Default: \code{2.0}. -\item \code{min_samples_leaf_variance} Minimum allowable size of a leaf, in terms of training samples, in the (optional) conditional variance model. Default: \code{5}. -\item \code{max_depth_variance} Maximum depth of any tree in the ensemble in the (optional) conditional variance model. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. -} +\item \code{num_trees} Number of trees in the prognostic (\code{mu(X)}) forest. Default: \code{250}. +\item \code{alpha} Prior probability of splitting for a tree of depth 0 in the prognostic forest. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}.
Default: \code{0.95}. +\item \code{beta} Exponent that decreases split probabilities for nodes of depth > 0 in the prognostic forest. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{2}. +\item \code{min_samples_leaf} Minimum allowable size of a leaf, in terms of training samples, in the prognostic forest. Default: \code{5}. +\item \code{max_depth} Maximum depth of any tree in the prognostic forest. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. +\item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable in the prognostic forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. +\item \code{sample_sigma2_leaf} Whether or not to update the leaf scale variance parameter based on \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)}. Default: \code{TRUE}. +\item \code{sigma2_leaf_init} Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here. +\item \code{sigma2_leaf_shape} Shape parameter in the \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)} leaf node parameter variance model. Default: \code{3}. +\item \code{sigma2_leaf_scale} Scale parameter in the \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)} leaf node parameter variance model. Calibrated internally as \code{0.5/num_trees} if not set here. +\item \code{keep_vars} Vector of variable names or column indices denoting variables that should be included in the forest. Default: \code{NULL}. +\item \code{drop_vars} Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: \code{NULL}. If both \code{drop_vars} and \code{keep_vars} are set, \code{drop_vars} will be ignored. +}} -\strong{4.2. Leaf Model Parameters} +\item{tau_forest_params}{(Optional) A list of treatment effect forest model parameters, each of which has a default value processed internally, so this argument list is optional. \itemize{ -\item \code{a_forest} Shape parameter in the \code{IG(a_forest, b_forest)} conditional error variance model (which is only sampled if \code{num_trees_variance > 0}). Calibrated internally as \code{num_trees_variance / 1.5^2 + 0.5} if not set. -\item \code{b_forest} Scale parameter in the \code{IG(a_forest, b_forest)} conditional error variance model (which is only sampled if \code{num_trees_variance > 0}). Calibrated internally as \code{num_trees_variance / 1.5^2} if not set. -\item \code{keep_vars_variance} Vector of variable names or column indices denoting variables that should be included in the (optional) conditional variance forest. Default: \code{NULL}. -\item \code{drop_vars_variance} Vector of variable names or column indices denoting variables that should be excluded from the (optional) conditional variance forest. Default: NULL. If both \code{drop_vars_variance} and \code{keep_vars_variance} are set, \code{drop_vars_variance} will be ignored. -} +\item \code{num_trees} Number of trees in the treatment effect (\code{tau(X)}) forest. Default: \code{50}. +\item \code{alpha} Prior probability of splitting for a tree of depth 0 in the treatment effect forest.
Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{0.25}. +\item \code{beta} Exponent that decreases split probabilities for nodes of depth > 0 in the treatment effect forest. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{3}. +\item \code{min_samples_leaf} Minimum allowable size of a leaf, in terms of training samples, in the treatment effect forest. Default: \code{5}. +\item \code{max_depth} Maximum depth of any tree in the treatment effect forest. Default: \code{5}. Can be overridden with \code{-1} which does not enforce any depth limits on trees. +\item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable in the treatment effect forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. +\item \code{sample_sigma2_leaf} Whether or not to update the leaf scale variance parameter based on \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)}. Default: \code{FALSE}. +\item \code{sigma2_leaf_init} Starting value of leaf node scale parameter. Calibrated internally as \code{1/num_trees} if not set here. +\item \code{sigma2_leaf_shape} Shape parameter in the \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)} leaf node parameter variance model. Default: \code{3}. +\item \code{sigma2_leaf_scale} Scale parameter in the \code{IG(sigma2_leaf_shape, sigma2_leaf_scale)} leaf node parameter variance model. Calibrated internally as \code{0.5/num_trees} if not set here. +\item \code{keep_vars} Vector of variable names or column indices denoting variables that should be included in the forest. Default: \code{NULL}. +\item \code{drop_vars} Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: \code{NULL}. If both \code{drop_vars} and \code{keep_vars} are set, \code{drop_vars} will be ignored. +}} -\strong{5. Random Effects Parameters} +\item{variance_forest_params}{(Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional. \itemize{ -\item \code{rfx_prior_var} Prior on the (diagonals of the) covariance of the additive group-level random regression coefficients. Must be a vector of length \code{ncol(rfx_basis_train)}. Default: \code{rep(1, ncol(rfx_basis_train))} +\item \code{num_trees} Number of trees in the ensemble for the conditional variance model. Default: \code{0}. Variance is only modeled using a tree / forest if \code{num_trees > 0}. +\item \code{alpha} Prior probability of splitting for a tree of depth 0 in the variance model. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{0.95}. +\item \code{beta} Exponent that decreases split probabilities for nodes of depth > 0 in the variance model. Tree split prior combines \code{alpha} and \code{beta} via \code{alpha*(1+node_depth)^-beta}. Default: \code{2}. +\item \code{min_samples_leaf} Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: \code{5}. +\item \code{max_depth} Maximum depth of any tree in the ensemble in the variance model. Default: \code{10}. Can be overridden with \code{-1} which does not enforce any depth limits on trees.
+\item \code{variable_weights} Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to \code{rep(1/ncol(X_train), ncol(X_train))} if not set here. +\item \code{init_root_val} Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as \code{log(0.6*var((y_train-mean(y_train))/sd(y_train)))/num_trees} if not set. +\item \code{var_forest_prior_shape} Shape parameter in the \code{IG(var_forest_prior_shape, var_forest_prior_scale)} conditional error variance model (which is only sampled if \code{num_trees > 0}). Calibrated internally as \code{num_trees / 1.5^2 + 0.5} if not set. +\item \code{var_forest_prior_scale} Scale parameter in the \code{IG(var_forest_prior_shape, var_forest_prior_scale)} conditional error variance model (which is only sampled if \code{num_trees > 0}). Calibrated internally as \code{num_trees / 1.5^2} if not set. +\item \code{keep_vars} Vector of variable names or column indices denoting variables that should be included in the forest. Default: \code{NULL}. +\item \code{drop_vars} Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: \code{NULL}. If both \code{drop_vars} and \code{keep_vars} are set, \code{drop_vars} will be ignored. }} } \value{ diff --git a/man/preprocessBartParams.Rd index f06ef30f..31d29c0b 100644 --- a/man/preprocessBartParams.Rd +++ b/man/preprocessBartParams.Rd @@ -4,13 +4,21 @@ \alias{preprocessBartParams} \title{Preprocess BART parameter list. Override defaults with any provided parameters.} \usage{ -preprocessBartParams(params) +preprocessBartParams( + general_params, + mean_forest_params, + variance_forest_params +) } \arguments{ -\item{params}{Parameter list} +\item{general_params}{List of any non-forest-specific parameters} + +\item{mean_forest_params}{List of any mean forest parameters} + +\item{variance_forest_params}{List of any variance forest parameters} } \value{ -Parameter list with defaults overriden by values supplied in \code{params} +Parameter list with defaults overridden by values supplied in parameter lists } \description{ Preprocess BART parameter list. Override defaults with any provided parameters. diff --git a/man/preprocessBcfParams.Rd index fdc1d498..84b15f74 100644 --- a/man/preprocessBcfParams.Rd +++ b/man/preprocessBcfParams.Rd @@ -7,10 +7,16 @@ -preprocessBcfParams(params) +preprocessBcfParams( + general_params, + mu_forest_params, + tau_forest_params, + variance_forest_params +) } \arguments{ -\item{params}{Parameter list} +\item{general_params}{List of any non-forest-specific parameters} + +\item{mu_forest_params}{List of any mu forest parameters} + +\item{tau_forest_params}{List of any tau forest parameters} + +\item{variance_forest_params}{List of any variance forest parameters} } \value{ -Parameter list with defaults overriden by values supplied in \code{params} +Parameter list with defaults overridden by values supplied in parameter lists } \description{ Preprocess BCF parameter list. Override defaults with any provided parameters.
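As a usage note on the refactored interface (a minimal sketch; the toy data, override values, and variable names below are invented for illustration and are not taken from the package's tests or notebooks), the merge semantics implemented by preprocessParams() and the new split-parameter-list call style for bcf() look like this:

library(stochtree)

# preprocessParams() overrides defaults key-by-key: keys absent from the
# defaults list are silently ignored, and NULL values leave the default in place
defaults <- list(num_trees = 50, alpha = 0.25, beta = 3.0)
merged <- preprocessParams(defaults, list(alpha = 0.5, beta = NULL))
stopifnot(merged$num_trees == 50, merged$alpha == 0.5, merged$beta == 3.0)

# bcf() now takes per-component lists in place of the old single `params` list;
# any parameter left unspecified falls back to its internal default
n <- 500
X <- matrix(runif(n * 4), ncol = 4)
Z <- rbinom(n, 1, 0.5)
y <- 2 * X[, 1] + Z * (1 + X[, 2]) + rnorm(n)
bcf_fit <- bcf(
  X_train = X, Z_train = Z, y_train = y,
  num_gfr = 10, num_mcmc = 100,
  general_params = list(keep_every = 5),
  tau_forest_params = list(num_trees = 100)
)

This mirrors the updated causal_inference notebook above, where the Python sampler is called with general_params={"keep_every": 5}.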
diff --git a/man/preprocessParams.Rd new file mode 100644 index 00000000..93a808c0 --- /dev/null +++ b/man/preprocessParams.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{preprocessParams} +\alias{preprocessParams} +\title{Preprocess a parameter list, overriding defaults with any provided parameters.} +\usage{ +preprocessParams(default_params, user_params = NULL) +} +\arguments{ +\item{default_params}{List of parameters with default values set.} + +\item{user_params}{(Optional) User-supplied overrides to \code{default_params}.} +} +\value{ +Parameter list with defaults overridden by values supplied in \code{user_params} +} +\description{ +Preprocess a parameter list, overriding defaults with any provided parameters. +} diff --git a/stochtree/bart.py index 7f852142..0dd5f340 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -8,7 +8,7 @@ from typing import Optional, Dict, Any from .data import Dataset, Residual from .forest import ForestContainer, Forest -from .preprocessing import CovariateTransformer, _preprocess_bart_params +from .preprocessing import CovariateTransformer, _preprocess_params from .sampler import ForestSampler, RNG, GlobalVarianceModel, LeafVarianceModel from .serialization import JSONSerializer from .utils import NotSampledError @@ -26,7 +26,8 @@ def is_sampled(self) -> bool: return self.sampled def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = None, X_test: np.array = None, basis_test: np.array = None, - num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, params: Optional[Dict[str, Any]] = None) -> None: + num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, general_params: Optional[Dict[str, Any]] = None, + mean_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None) -> None: """Runs a BART sampler on provided training set. Predictions will be cached for the training set and (if provided) the test set. Does not require a leaf regression basis. @@ -49,81 +50,140 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N Number of "burn-in" iterations of the MCMC sampler. Defaults to ``0``. Ignored if ``num_gfr > 0``. num_mcmc : :obj:`int`, optional Number of "retained" iterations of the MCMC sampler. Defaults to ``100``. If this is set to 0, GFR (XBART) samples will be retained. - params : :obj:`dict`, optional - Dictionary of model parameters, each of which has a default value. + general_params : :obj:`dict`, optional + Dictionary of general model parameters, each of which has a default value processed internally, so this argument is optional. * ``cutpoint_grid_size`` (``int``): Maximum number of cutpoints to consider for each feature. Defaults to ``100``. - * ``sigma_leaf`` (``float``): Scale parameter on the (conditional mean) leaf node regression model.
Defaults to ``5``. - * ``max_depth_mean`` (``int``): Maximum depth of any tree in the ensemble in the conditional mean model. Defaults to ``10``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. - * ``alpha_variance`` (``float``): Prior probability of splitting for a tree of depth 0 in the conditional variance model. Tree split prior combines ``alpha_variance`` and ``beta_variance`` via ``alpha_variance*(1+node_depth)^-beta_variance``. - * ``beta_variance`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional variance model. Tree split prior combines ``alpha_variance`` and ``beta_variance`` via ``alpha_variance*(1+node_depth)^-beta_variance``. - * ``min_samples_leaf_variance`` (``int``): Minimum allowable size of a leaf, in terms of training samples in the conditional variance model. Defaults to ``5``. - * ``max_depth_variance`` (``int``): Maximum depth of any tree in the ensemble in the conditional variance model. Defaults to ``10``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. - * ``a_global`` (``float``): Shape parameter in the ``IG(a_global, b_global)`` global error variance model. Defaults to ``0``. - * ``b_global`` (``float``): Scale parameter in the ``IG(a_global, b_global)`` global error variance prior. Defaults to ``0``. - * ``a_leaf`` (``float``): Shape parameter in the ``IG(a_leaf, b_leaf)`` leaf node parameter variance model. Defaults to ``3``. - * ``b_leaf`` (``float``): Scale parameter in the ``IG(a_leaf, b_leaf)`` leaf node parameter variance model. Calibrated internally as ``0.5/num_trees_mean`` if not set here. - * ``a_forest`` (``float``): Shape parameter in the [optional] ``IG(a_forest, b_forest)`` conditional error variance forest (which is only sampled if ``num_trees_variance > 0``). Calibrated internally as ``num_trees_variance / 1.5^2 + 0.5`` if not set here. - * ``b_forest`` (``float``): Scale parameter in the [optional] ``IG(a_forest, b_forest)`` conditional error variance forest (which is only sampled if ``num_trees_variance > 0``). Calibrated internally as ``num_trees_variance / 1.5^2`` if not set here. - * ``sigma2_init`` (``float``): Starting value of global variance parameter. Set internally as a percentage of the standardized outcome variance if not set here. - * ``variance_forest_leaf_init`` (``float``): Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as ``np.log(pct_var_variance_forest_init*np.var((y-np.mean(y))/np.std(y)))/num_trees_variance`` if not set. - * ``pct_var_sigma2_init`` (``float``): Percentage of standardized outcome variance used to initialize global error variance parameter. Superseded by ``sigma2``. Defaults to ``1``. - * ``pct_var_variance_forest_init`` (``float``): Percentage of standardized outcome variance used to initialize global error variance parameter. Default: ``1``. Superseded by ``variance_forest_init``. - * ``variance_scale`` (``float``): Variance after the data have been scaled. Default: ``1``. - * ``variable_weights_mean`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided. - * ``variable_weights_variance`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. 
Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided. - * ``num_trees_mean`` (``int``): Number of trees in the ensemble for the conditional mean model. Defaults to ``200``. If ``num_trees_mean = 0``, the conditional mean will not be modeled using a forest and the function will only proceed if ``num_trees_variance > 0``. - * ``num_trees_variance`` (``int``): Number of trees in the ensemble for the conditional variance model. Defaults to ``0``. Variance is only modeled using a tree / forest if ``num_trees_variance > 0``. - * ``sample_sigma_global`` (``bool``): Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(a_global, b_global)``. Defaults to ``True``. - * ``sample_sigma_leaf`` (``bool``): Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(a_leaf, b_leaf)``. Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``False``. + * ``standardize`` (``bool``): Whether or not to standardize the outcome (and store the offset / scale in the model object). Defaults to ``True``. + * ``sample_sigma2_global`` (``bool``): Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(sigma2_global_shape, sigma2_global_scale)``. Defaults to ``True``. + * ``sigma2_init`` (``float``): Starting value of global variance parameter. Set internally to the outcome variance (standardized if `standardize = True`) if not set here. + * ``sigma2_global_shape`` (``float``): Shape parameter in the ``IG(sigma2_global_shape, sigma2_global_scale)`` global error variance model. Defaults to ``0``. + * ``sigma2_global_scale`` (``float``): Scale parameter in the ``IG(sigma2_global_shape, sigma2_global_scale)`` global error variance model. Defaults to ``0``. * ``random_seed`` (``int``): Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to ``std::random_device``. * ``keep_burnin`` (``bool``): Whether or not "burnin" samples should be included in predictions. Defaults to ``False``. Ignored if ``num_mcmc == 0``. * ``keep_gfr`` (``bool``): Whether or not "warm-start" / grow-from-root samples should be included in predictions. Defaults to ``False``. Ignored if ``num_mcmc == 0``. * ``keep_every`` (``int``): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to ``1``. Setting ``keep_every = k`` for some ``k > 1`` will "thin" the MCMC samples by retaining every ``k``-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. + * ``num_chains`` (``int``): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`. + + mean_forest_params : :obj:`dict`, optional + Dictionary of mean forest model parameters, each of which has a default value processed internally, so this argument is optional. + + * ``num_trees`` (``int``): Number of trees in the conditional mean model. Defaults to ``200``. If ``num_trees = 0``, the conditional mean will not be modeled using a forest and sampling will only proceed if ``num_trees > 0`` for the variance forest.
+ * ``alpha`` (``float``): Prior probability of splitting for a tree of depth 0 in the conditional mean model. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``0.95``. + * ``beta`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional mean model. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``2``. + * ``min_samples_leaf`` (``int``): Minimum allowable size of a leaf, in terms of training samples, in the conditional mean model. Defaults to ``5``. + * ``max_depth`` (``int``): Maximum depth of any tree in the ensemble in the conditional mean model. Defaults to ``10``. Can be overridden with ``-1`` which does not enforce any depth limits on trees. + * ``variable_weights`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided. + * ``sample_sigma2_leaf`` (``bool``): Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(sigma2_leaf_shape, sigma2_leaf_scale)``. Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``False``. + * ``sigma2_leaf_init`` (``float``): Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. + * ``sigma2_leaf_shape`` (``float``): Shape parameter in the ``IG(sigma2_leaf_shape, sigma2_leaf_scale)`` leaf node parameter variance model. Defaults to ``3``. + * ``sigma2_leaf_scale`` (``float``): Scale parameter in the ``IG(sigma2_leaf_shape, sigma2_leaf_scale)`` leaf node parameter variance model. Calibrated internally as ``0.5/num_trees`` if not set here. + + variance_forest_params : :obj:`dict`, optional + Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional. + + * ``num_trees`` (``int``): Number of trees in the conditional variance model. Defaults to ``0``. Variance is only modeled using a tree / forest if ``num_trees > 0``. + * ``alpha`` (``float``): Prior probability of splitting for a tree of depth 0 in the conditional variance model. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``0.95``. + * ``beta`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional variance model. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``2``. + * ``min_samples_leaf`` (``int``): Minimum allowable size of a leaf, in terms of training samples, in the conditional variance model. Defaults to ``5``. + * ``max_depth`` (``int``): Maximum depth of any tree in the ensemble in the conditional variance model. Defaults to ``10``. Can be overridden with ``-1`` which does not enforce any depth limits on trees. + * ``variable_weights`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided. + * ``var_forest_leaf_init`` (``float``): Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as ``np.log(0.6*np.var(y_train))/num_trees``, where `y_train` is the possibly standardized outcome, if not set.
+ * ``var_forest_prior_shape`` (``float``): Shape parameter in the [optional] ``IG(var_forest_prior_shape, var_forest_prior_scale)`` conditional error variance model (which is only sampled if ``num_trees > 0``). Calibrated internally as ``num_trees / 1.5^2 + 0.5`` if not set here. + * ``var_forest_prior_scale`` (``float``): Scale parameter in the [optional] ``IG(var_forest_prior_shape, var_forest_prior_scale)`` conditional error variance model (which is only sampled if ``num_trees > 0``). Calibrated internally as ``num_trees / 1.5^2`` if not set here. Returns ------- self : BARTModel Sampled BART Model. """ - # Unpack parameters - bart_params = _preprocess_bart_params(params) - cutpoint_grid_size = bart_params['cutpoint_grid_size'] - sigma_leaf = bart_params['sigma_leaf'] - alpha_mean = bart_params['alpha_mean'] - beta_mean = bart_params['beta_mean'] - min_samples_leaf_mean = bart_params['min_samples_leaf_mean'] - max_depth_mean = bart_params['max_depth_mean'] - alpha_variance = bart_params['alpha_variance'] - beta_variance = bart_params['beta_variance'] - min_samples_leaf_variance = bart_params['min_samples_leaf_variance'] - max_depth_variance = bart_params['max_depth_variance'] - a_global = bart_params['a_global'] - b_global = bart_params['b_global'] - a_leaf = bart_params['a_leaf'] - b_leaf = bart_params['b_leaf'] - a_forest = bart_params['a_forest'] - b_forest = bart_params['b_forest'] - sigma2_init = bart_params['sigma2_init'] - variance_forest_leaf_init = bart_params['variance_forest_leaf_init'] - pct_var_sigma2_init = bart_params['pct_var_sigma2_init'] - pct_var_variance_forest_init = bart_params['pct_var_variance_forest_init'] - variance_scale = bart_params['variance_scale'] - variable_weights_mean = bart_params['variable_weights_mean'] - variable_weights_variance = bart_params['variable_weights_variance'] - num_trees_mean = bart_params['num_trees_mean'] - num_trees_variance = bart_params['num_trees_variance'] - sample_sigma_global = bart_params['sample_sigma_global'] - sample_sigma_leaf = bart_params['sample_sigma_leaf'] - random_seed = bart_params['random_seed'] - keep_burnin = bart_params['keep_burnin'] - keep_gfr = bart_params['keep_gfr'] - self.standardize = bart_params['standardize'] - num_chains = bart_params['num_chains'] - keep_every = bart_params['keep_every'] + # Update general BART parameters + general_params_default = { + 'cutpoint_grid_size' : 100, + 'standardize' : True, + 'sample_sigma2_global' : True, + 'sigma2_init' : None, + 'sigma2_global_shape' : 0, + 'sigma2_global_scale' : 0, + 'random_seed' : -1, + 'keep_burnin' : False, + 'keep_gfr' : False, + 'keep_every' : 1, + 'num_chains' : 1 + } + general_params_updated = _preprocess_params( + general_params_default, general_params + ) + + # Update mean forest BART parameters + mean_forest_params_default = { + 'num_trees' : 200, + 'alpha' : 0.95, + 'beta' : 2.0, + 'min_samples_leaf' : 5, + 'max_depth' : 10, + 'variable_weights' : None, + 'sample_sigma2_leaf' : True, + 'sigma2_leaf_init' : None, + 'sigma2_leaf_shape' : 3, + 'sigma2_leaf_scale' : None + } + mean_forest_params_updated = _preprocess_params( + mean_forest_params_default, mean_forest_params + ) + + # Update variance forest BART parameters + variance_forest_params_default = { + 'num_trees' : 0, + 'alpha' : 0.95, + 'beta' : 2.0, + 'min_samples_leaf' : 5, + 'max_depth' : 10, + 'variable_weights' : None, + 'var_forest_leaf_init' : None, + 'var_forest_prior_shape' : None, + 'var_forest_prior_scale' : None + } + variance_forest_params_updated = 
_preprocess_params( + variance_forest_params_default, variance_forest_params + ) + + ### Unpack all parameter values + # 1. General parameters + cutpoint_grid_size = general_params_updated['cutpoint_grid_size'] + self.standardize = general_params_updated['standardize'] + sample_sigma_global = general_params_updated['sample_sigma2_global'] + sigma2_init = general_params_updated['sigma2_init'] + a_global = general_params_updated['sigma2_global_shape'] + b_global = general_params_updated['sigma2_global_scale'] + random_seed = general_params_updated['random_seed'] + keep_burnin = general_params_updated['keep_burnin'] + keep_gfr = general_params_updated['keep_gfr'] + keep_every = general_params_updated['keep_every'] + num_chains = general_params_updated['num_chains'] + + # 2. Mean forest parameters + num_trees_mean = mean_forest_params_updated['num_trees'] + alpha_mean = mean_forest_params_updated['alpha'] + beta_mean = mean_forest_params_updated['beta'] + min_samples_leaf_mean = mean_forest_params_updated['min_samples_leaf'] + max_depth_mean = mean_forest_params_updated['max_depth'] + variable_weights_mean = mean_forest_params_updated['variable_weights'] + sample_sigma_leaf = mean_forest_params_updated['sample_sigma2_leaf'] + sigma_leaf = mean_forest_params_updated['sigma2_leaf_init'] + a_leaf = mean_forest_params_updated['sigma2_leaf_shape'] + b_leaf = mean_forest_params_updated['sigma2_leaf_scale'] + + # 3. Variance forest parameters + num_trees_variance = variance_forest_params_updated['num_trees'] + alpha_variance = variance_forest_params_updated['alpha'] + beta_variance = variance_forest_params_updated['beta'] + min_samples_leaf_variance = variance_forest_params_updated['min_samples_leaf'] + max_depth_variance = variance_forest_params_updated['max_depth'] + variable_weights_variance = variance_forest_params_updated['variable_weights'] + variance_forest_leaf_init = variance_forest_params_updated['var_forest_leaf_init'] + a_forest = variance_forest_params_updated['var_forest_prior_shape'] + b_forest = variance_forest_params_updated['var_forest_prior_scale'] # Check that num_chains >= 1 if not isinstance(num_chains, Integral) or num_chains < 1: @@ -234,18 +294,13 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N else: self.y_bar = 0 self.y_std = 1 - if variance_scale > 0: - self.variance_scale = variance_scale - else: - raise ValueError("variance_scale must be positive") resid_train = (y_train-self.y_bar)/self.y_std - resid_train = resid_train*np.sqrt(self.variance_scale) # Calibrate priors for global sigma^2 and sigma_leaf (don't use regression initializer for warm-start or XBART) if not sigma2_init: - sigma2_init = pct_var_sigma2_init*np.var(resid_train) + sigma2_init = 1.0*np.var(resid_train) if not variance_forest_leaf_init: - variance_forest_leaf_init = pct_var_variance_forest_init*np.var(resid_train) + variance_forest_leaf_init = 0.6*np.var(resid_train) current_sigma2 = sigma2_init self.sigma2_init = sigma2_init if self.include_mean_forest: @@ -469,17 +524,17 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N # Store predictions if self.sample_sigma_global: - self.global_var_samples = self.global_var_samples*self.y_std*self.y_std/self.variance_scale + self.global_var_samples = self.global_var_samples*self.y_std*self.y_std if self.sample_sigma_leaf: self.leaf_scale_samples = self.leaf_scale_samples if self.include_mean_forest: yhat_train_raw = 
self.forest_container_mean.forest_container_cpp.Predict(forest_dataset_train.dataset_cpp) - self.y_hat_train = yhat_train_raw*self.y_std/np.sqrt(self.variance_scale) + self.y_bar + self.y_hat_train = yhat_train_raw*self.y_std + self.y_bar if self.has_test: yhat_test_raw = self.forest_container_mean.forest_container_cpp.Predict(forest_dataset_test.dataset_cpp) - self.y_hat_test = yhat_test_raw*self.y_std/np.sqrt(self.variance_scale) + self.y_bar + self.y_hat_test = yhat_test_raw*self.y_std + self.y_bar if self.include_variance_forest: sigma_x_train_raw = self.forest_container_variance.forest_container_cpp.Predict(forest_dataset_train.dataset_cpp) @@ -488,7 +543,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N for i in range(self.num_samples): self.sigma2_x_train[:,i] = sigma_x_train_raw[:,i]*self.global_var_samples[i] else: - self.sigma2_x_train = sigma_x_train_raw*self.sigma2_init*self.y_std*self.y_std/self.variance_scale + self.sigma2_x_train = sigma_x_train_raw*self.sigma2_init*self.y_std*self.y_std if self.has_test: sigma_x_test_raw = self.forest_container_variance.forest_container_cpp.Predict(forest_dataset_test.dataset_cpp) if self.sample_sigma_global: @@ -496,7 +551,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N for i in range(self.num_samples): self.sigma2_x_test[:,i] = sigma_x_test_raw[:,i]*self.global_var_samples[i] else: - self.sigma2_x_test = sigma_x_test_raw*self.sigma2_init*self.y_std*self.y_std/self.variance_scale + self.sigma2_x_test = sigma_x_test_raw*self.sigma2_init*self.y_std*self.y_std def predict(self, covariates: np.array, basis: np.array = None) -> np.array: """Return predictions from every forest sampled (either / both of mean and variance) @@ -538,7 +593,7 @@ def predict(self, covariates: np.array, basis: np.array = None) -> np.array: pred_dataset.add_basis(basis) if self.include_mean_forest: mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict(pred_dataset.dataset_cpp) - mean_pred = mean_pred_raw*self.y_std/np.sqrt(self.variance_scale) + self.y_bar + mean_pred = mean_pred_raw*self.y_std + self.y_bar if self.include_variance_forest: variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict(pred_dataset.dataset_cpp) if self.sample_sigma_global: @@ -546,7 +601,7 @@ def predict(self, covariates: np.array, basis: np.array = None) -> np.array: for i in range(self.num_samples): variance_pred[:,i] = np.sqrt(variance_pred_raw[:,i]*self.global_var_samples[i]) else: - variance_pred = np.sqrt(variance_pred_raw*self.sigma2_init)*self.y_std/np.sqrt(self.variance_scale) + variance_pred = np.sqrt(variance_pred_raw*self.sigma2_init)*self.y_std if self.include_mean_forest and self.include_variance_forest: return (mean_pred, variance_pred) @@ -601,7 +656,7 @@ def predict_mean(self, covariates: np.array, basis: np.array = None) -> np.array if basis is not None: pred_dataset.add_basis(basis) mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict(pred_dataset.dataset_cpp) - mean_pred = mean_pred_raw*self.y_std/np.sqrt(self.variance_scale) + self.y_bar + mean_pred = mean_pred_raw*self.y_std + self.y_bar return mean_pred @@ -644,7 +699,7 @@ def predict_variance(self, covariates: np.array) -> np.array: for i in range(self.num_samples): variance_pred[:,i] = variance_pred_raw[:,i]*self.global_var_samples[i] else: - variance_pred = variance_pred_raw*self.sigma2_init*self.y_std*self.y_std/self.variance_scale + variance_pred = 
variance_pred_raw*self.sigma2_init*self.y_std*self.y_std return variance_pred @@ -675,7 +730,6 @@ def to_json(self) -> str: bart_json.add_forest(self.forest_container_variance) # Add global parameters - bart_json.add_scalar("variance_scale", self.variance_scale) bart_json.add_scalar("outcome_scale", self.y_std) bart_json.add_scalar("outcome_mean", self.y_bar) bart_json.add_boolean("standardize", self.standardize) @@ -729,7 +783,6 @@ def from_json(self, json_string: str) -> None: self.forest_container_variance.forest_container_cpp.LoadFromJson(bart_json.json_cpp, "forest_0") # Unpack global parameters - self.variance_scale = bart_json.get_scalar("variance_scale") self.y_std = bart_json.get_scalar("outcome_scale") self.y_bar = bart_json.get_scalar("outcome_mean") self.standardize = bart_json.get_boolean("standardize") diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 9e045e1d..fc16e7dc 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -8,7 +8,7 @@ from .bart import BARTModel from .data import Dataset, Residual from .forest import ForestContainer, Forest -from .preprocessing import CovariateTransformer, _preprocess_bcf_params +from .preprocessing import CovariateTransformer, _preprocess_params from .sampler import ForestSampler, RNG, GlobalVarianceModel, LeafVarianceModel from .serialization import JSONSerializer from .utils import NotSampledError @@ -27,7 +27,9 @@ def is_sampled(self) -> bool: def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_train: np.array, pi_train: np.array = None, X_test: Union[pd.DataFrame, np.array] = None, Z_test: np.array = None, pi_test: np.array = None, - num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, params: Optional[Dict[str, Any]] = None) -> None: + num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, general_params: Optional[Dict[str, Any]] = None, + mu_forest_params: Optional[Dict[str, Any]] = None, tau_forest_params: Optional[Dict[str, Any]] = None, + variance_forest_params: Optional[Dict[str, Any]] = None) -> None: """Runs a BCF sampler on provided training set. Outcome predictions and estimates of the prognostic and treatment effect functions will be cached for the training set and (if provided) the test set. @@ -54,152 +56,219 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr Number of "burn-in" iterations of the MCMC sampler. Defaults to ``0``. Ignored if ``num_gfr > 0``. num_mcmc : :obj:`int`, optional Number of "retained" iterations of the MCMC sampler. Defaults to ``100``. If this is set to 0, GFR (XBART) samples will be retained. - params : :obj:`dict`, optional - Dictionary of model parameters, each of which has a default value. + general_params : :obj:`dict`, optional + Dictionary of general model parameters, each of which has a default value processed internally, so this argument is optional. * ``cutpoint_grid_size`` (``int``): Maximum number of cutpoints to consider for each feature. Defaults to ``100``. - * ``sigma_leaf_mu`` (``float``): Starting value of leaf node scale parameter for the prognostic forest. Calibrated internally as ``2/num_trees_mu`` if not set here. - * ``sigma_leaf_tau`` (``float`` or ``np.array``): Starting value of leaf node scale parameter for the treatment effect forest. - When treatment (``Z_train``) is multivariate, this can be either a ``float`` or a square 2-dimensional ``np.array`` - with ``sigma_leaf_tau.shape[0] == Z_train.shape[1]`` and ``sigma_leaf_tau.shape[1] == Z_train.shape[1]``. 
- If ``sigma_leaf_tau`` is provided as a float for multivariate treatment, the leaf scale term will be set as a - diagonal matrix with ``sigma_leaf_tau`` on every diagonal. If not passed as an argument, this parameter is - calibrated internally as ``1/num_trees_tau`` (and propagated to a diagonal matrix if necessary). - * ``alpha_mu`` (``float``): Prior probability of splitting for a tree of depth 0 for the prognostic forest. - Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. - * ``alpha_tau`` (``float``): Prior probability of splitting for a tree of depth 0 for the treatment effect forest. - Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. - * ``alpha_variance`` (``float``): Prior probability of splitting for a tree of depth 0 in the conditional variance model. Tree split prior combines ``alpha_variance`` and ``beta_variance`` via ``alpha_variance*(1+node_depth)^-beta_variance``. - * ``beta_mu`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 for the prognostic forest. - Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. - * ``beta_tau`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 for the treatment effect forest. - Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. - * ``beta_variance`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional variance model. Tree split prior combines ``alpha_variance`` and ``beta_variance`` via ``alpha_variance*(1+node_depth)^-beta_variance``. - * ``min_samples_leaf_mu`` (``int``): Minimum allowable size of a leaf, in terms of training samples, for the prognostic forest. Defaults to ``5``. - * ``min_samples_leaf_tau`` (``int``): Minimum allowable size of a leaf, in terms of training samples, for the treatment effect forest. Defaults to ``5``. - * ``min_samples_leaf_variance`` (``int``): Minimum allowable size of a leaf, in terms of training samples in the conditional variance model. Defaults to ``5``. - * ``max_depth_mu`` (``int``): Maximum depth of any tree in the mu ensemble. Defaults to ``10``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. - * ``max_depth_tau`` (``int``): Maximum depth of any tree in the tau ensemble. Defaults to ``5``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. - * ``max_depth_variance`` (``int``): Maximum depth of any tree in the ensemble in the conditional variance model. Defaults to ``10``. Can be overriden with ``-1`` which does not enforce any depth limits on trees. - * ``a_global`` (``float``): Shape parameter in the ``IG(a_global, b_global)`` global error variance model. Defaults to ``0``. - * ``b_global`` (``float``): Component of the scale parameter in the ``IG(a_global, b_global)`` global error variance prior. Defaults to ``0``. - * ``a_leaf_mu`` (``float``): Shape parameter in the ``IG(a_leaf, b_leaf)`` leaf node parameter variance model for the prognostic forest. Defaults to ``3``. - * ``a_leaf_tau`` (``float``): Shape parameter in the ``IG(a_leaf, b_leaf)`` leaf node parameter variance model for the treatment effect forest. Defaults to ``3``. - * ``b_leaf_mu`` (``float``): Scale parameter in the ``IG(a_leaf, b_leaf)`` leaf node parameter variance model for the prognostic forest. Calibrated internally as ``0.5/num_trees`` if not set here. 
- * ``b_leaf_tau`` (``float``): Scale parameter in the ``IG(a_leaf, b_leaf)`` leaf node parameter variance model for the treatment effect forest. Calibrated internally as ``0.5/num_trees`` if not set here. - * ``sigma2_init`` (``float``): Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here. - * ``variance_forest_leaf_init`` (``float``): Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as ``np.log(pct_var_variance_forest_init*np.var((y-np.mean(y))/np.std(y)))/num_trees_variance`` if not set. - * ``pct_var_sigma2_init`` (``float``): Percentage of standardized outcome variance used to initialize global error variance parameter. Superseded by ``sigma2``. Defaults to ``0.25``. - * ``pct_var_variance_forest_init`` (``float``): Percentage of standardized outcome variance used to initialize global error variance parameter. Default: ``1``. Superseded by ``variance_forest_init``. - * ``variable_weights_mean`` (`np.`array``): Numeric weights reflecting the relative probability of splitting on each variable in the prognostic and treatment effect forests. Does not need to sum to 1 but cannot be negative. Defaults to ``np.repeat(1/X_train.shape[1], X_train.shape[1])`` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to ``1/X_train.shape[1]``. A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in ``X_train`` and then set ``propensity_covariate`` to ``'none'`` and adjust ``keep_vars_mu`` and ``keep_vars_tau`` accordingly. - * ``variable_weights_variance`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided. - * ``keep_vars_mu`` (``list`` or ``np.array``): Vector of variable names or column indices denoting variables that should be included in the prognostic (``mu(X)``) forest. Defaults to ``None``. - * ``drop_vars_mu`` (``list`` or ``np.array``): Vector of variable names or column indices denoting variables that should be excluded from the prognostic (``mu(X)``) forest. Defaults to ``None``. If both ``drop_vars_mu`` and ``keep_vars_mu`` are set, ``drop_vars_mu`` will be ignored. - * ``drop_vars_variance`` (``list`` or ``np.array``): Vector of variable names or column indices denoting variables that should be excluded from the variance (``sigma^2(X)``) forest. Defaults to ``None``. If both ``drop_vars_variance`` and ``keep_vars_variance`` are set, ``drop_vars_variance`` will be ignored. - * ``keep_vars_tau`` (``list`` or ``np.array``): Vector of variable names or column indices denoting variables that should be included in the treatment effect (``tau(X)``) forest. Defaults to ``None``. - * ``drop_vars_tau`` (``list`` or ``np.array``): Vector of variable names or column indices denoting variables that should be excluded from the treatment effect (``tau(X)``) forest. Defaults to ``None``. If both ``drop_vars_tau`` and ``keep_vars_tau`` are set, ``drop_vars_tau`` will be ignored. - * ``drop_vars_variance`` (``list`` or ``np.array``): Vector of variable names or column indices denoting variables that should be excluded from the variance (``sigma^2(X)``) forest. Defaults to ``None``. 
If both ``drop_vars_variance`` and ``keep_vars_variance`` are set, ``drop_vars_variance`` will be ignored. - * ``keep_vars_variance`` (``list`` or ``np.array``): Vector of variable names or column indices denoting variables that should be included in the variance (``sigma^2(X)``) forest. Defaults to ``None``. - * ``num_trees_mu`` (``int``): Number of trees in the prognostic forest. Defaults to ``200``. - * ``num_trees_tau`` (``int``): Number of trees in the treatment effect forest. Defaults to ``50``. - * ``num_trees_variance`` (``int``): Number of trees in the ensemble for the conditional variance model. Defaults to ``0``. Variance is only modeled using a tree / forest if ``num_trees_variance > 0``. - * ``sample_sigma_global`` (``bool``): Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(a_global, b_global)``. Defaults to ``True``. - * ``sample_sigma_leaf_mu`` (``bool``): Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(a_leaf, b_leaf)`` for the prognostic forest. - Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``True``. - * ``sample_sigma_leaf_tau`` (``bool``): Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(a_leaf, b_leaf)`` for the treatment effect forest. - Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``True``. + * ``standardize`` (``bool``): Whether or not to standardize the outcome (and store the offset / scale in the model object). Defaults to ``True``. + * ``sample_sigma2_global`` (``bool``): Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(sigma2_global_shape, sigma2_global_scale)``. Defaults to ``True``. + * ``sigma2_global_init`` (``float``): Starting value of global variance parameter. Set internally to the outcome variance (standardized if `standardize = True`) if not set here. + * ``sigma2_global_shape`` (``float``): Shape parameter in the ``IG(sigma2_global_shape, sigma2_global_scale)`` global error variance model. Defaults to ``0``. + * ``sigma2_global_scale`` (``float``): Scale parameter in the ``IG(sigma2_global_shape, sigma2_global_scale)`` global error variance model. Defaults to ``0``. + * ``variable_weights`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in each of the forests. Does not need to sum to 1 but cannot be negative. Defaults to ``np.repeat(1/X_train.shape[1], X_train.shape[1])`` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to ``1/X_train.shape[1]``. A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in ``X_train`` and then set ``propensity_covariate`` to ``'none'`` and adjust ``keep_vars`` accordingly for the mu or tau forests. * ``propensity_covariate`` (``str``): Whether to include the propensity score as a covariate in either or both of the forests. Enter ``"none"`` for neither, ``"mu"`` for the prognostic forest, ``"tau"`` for the treatment forest, and ``"both"`` for both forests. If this is not ``"none"`` and a propensity score is not provided, it will be estimated from (``X_train``, ``Z_train``) using ``BARTModel``. Defaults to ``"mu"``.
* ``adaptive_coding`` (``bool``): Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters ``b_0`` and ``b_1`` that attach to the outcome model ``[b_0 (1-Z) + b_1 Z] tau(X)``. This is ignored when Z is not binary. Defaults to True. - * ``b_0`` (``float``): Initial value of the "control" group coding parameter. This is ignored when ``Z`` is not binary. Default: ``-0.5``. - * ``b_1`` (``float``): Initial value of the "treated" group coding parameter. This is ignored when ``Z`` is not binary. Default: ``0.5``. + * ``control_coding_init`` (``float``): Initial value of the "control" group coding parameter. This is ignored when ``Z`` is not binary. Default: ``-0.5``. + * ``treated_coding_init`` (``float``): Initial value of the "treated" group coding parameter. This is ignored when ``Z`` is not binary. Default: ``0.5``. * ``random_seed`` (``int``): Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to ``std::random_device``. * ``keep_burnin`` (``bool``): Whether or not "burnin" samples should be included in predictions. Defaults to ``False``. Ignored if ``num_mcmc == 0``. * ``keep_gfr`` (``bool``): Whether or not "warm-start" / grow-from-root samples should be included in predictions. Defaults to ``False``. Ignored if ``num_mcmc == 0``. * ``keep_every`` (``int``): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to ``1``. Setting ``keep_every = k`` for some ``k > 1`` will "thin" the MCMC samples by retaining every ``k``-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. + * ``num_chains`` (``int``): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`. + + mu_forest_params : :obj:`dict`, optional + Dictionary of prognostic forest model parameters, each of which has a default value processed internally, so this argument is optional. + + * ``num_trees`` (``int``): Number of trees in the prognostic forest. Defaults to ``250``. Must be a positive integer. + * ``alpha`` (``float``): Prior probability of splitting for a tree of depth 0 in the prognostic forest. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``0.95``. + * ``beta`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 in the prognostic forest. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``2``. + * ``min_samples_leaf`` (``int``): Minimum allowable size of a leaf, in terms of training samples, in the prognostic forest. Defaults to ``5``. + * ``max_depth`` (``int``): Maximum depth of any tree in the ensemble in the prognostic forest. Defaults to ``10``. Can be overridden with ``-1`` which does not enforce any depth limits on trees. + * ``variable_weights`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the prognostic forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided.
+ * ``sample_sigma2_leaf`` (``bool``): Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(sigma2_leaf_shape, sigma2_leaf_scale)``. Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``True``. + * ``sigma2_leaf_init`` (``float``): Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. + * ``sigma2_leaf_shape`` (``float``): Shape parameter in the ``IG(sigma2_leaf_shape, sigma2_leaf_scale)`` leaf node parameter variance model. Defaults to ``3``. + * ``sigma2_leaf_scale`` (``float``): Scale parameter in the ``IG(sigma2_leaf_shape, sigma2_leaf_scale)`` leaf node parameter variance model. Calibrated internally as ``0.5/num_trees`` if not set here. + * ``keep_vars`` (``list`` or ``np.array``): Vector of variable names or column indices denoting variables that should be included in the prognostic (``mu(X)``) forest. Defaults to ``None``. + * ``drop_vars`` (``list`` or ``np.array``): Vector of variable names or column indices denoting variables that should be excluded from the prognostic (``mu(X)``) forest. Defaults to ``None``. If both ``drop_vars`` and ``keep_vars`` are set, ``drop_vars`` will be ignored. + + tau_forest_params : :obj:`dict`, optional + Dictionary of treatment effect forest model parameters, each of which has a default value processed internally, so this argument is optional. + + * ``num_trees`` (``int``): Number of trees in the treatment effect forest. Defaults to ``50``. Must be a positive integer. + * ``alpha`` (``float``): Prior probability of splitting for a tree of depth 0 in the treatment effect forest. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``0.25``. + * ``beta`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 in the treatment effect forest. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``3``. + * ``min_samples_leaf`` (``int``): Minimum allowable size of a leaf, in terms of training samples, in the treatment effect forest. Defaults to ``5``. + * ``max_depth`` (``int``): Maximum depth of any tree in the ensemble in the treatment effect forest. Defaults to ``5``. Can be overridden with ``-1`` which does not enforce any depth limits on trees. + * ``variable_weights`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the treatment effect forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided. + * ``sample_sigma2_leaf`` (``bool``): Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(sigma2_leaf_shape, sigma2_leaf_scale)``. Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``False``. + * ``sigma2_leaf_init`` (``float``): Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. + * ``sigma2_leaf_shape`` (``float``): Shape parameter in the ``IG(sigma2_leaf_shape, sigma2_leaf_scale)`` leaf node parameter variance model. Defaults to ``3``. + * ``sigma2_leaf_scale`` (``float``): Scale parameter in the ``IG(sigma2_leaf_shape, sigma2_leaf_scale)`` leaf node parameter variance model. Calibrated internally as ``0.5/num_trees`` if not set here.
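To make the new calling convention concrete, here is a minimal sketch of a BCF fit under the split-dictionary API documented above, using toy simulated data; the specific override values are illustrative, not recommendations, and the package-level ``BCFModel`` import follows the test suite changes later in this diff.

```python
import numpy as np
from stochtree import BCFModel  # package-level import, as in the updated tests

# Toy data purely for illustration
rng = np.random.default_rng(1234)
n, p = 500, 5
X_train = rng.uniform(size=(n, p))
pi_train = 0.25 + 0.5 * X_train[:, 0]
Z_train = rng.binomial(1, pi_train, size=n).astype(float)
y_train = X_train[:, 1] + 0.5 * Z_train + rng.normal(size=n)

# Only the keys a user wants to override need to be supplied; every other
# key falls back to the documented default via the merge helper
bcf_model = BCFModel()
bcf_model.sample(
    X_train=X_train, Z_train=Z_train, y_train=y_train, pi_train=pi_train,
    num_gfr=10, num_burnin=0, num_mcmc=100,
    general_params={"keep_gfr": True},        # general, non-forest-specific settings
    mu_forest_params={"num_trees": 250},      # prognostic forest settings
    tau_forest_params={"max_depth": 5},       # treatment effect forest settings
    variance_forest_params={"num_trees": 0},  # 0 trees: no heteroskedastic variance forest
)
```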
+ + variance_forest_params : :obj:`dict`, optional + Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional. + + * ``num_trees`` (``int``): Number of trees in the conditional variance model. Defaults to ``0``. Variance is only modeled using a tree / forest if ``num_trees > 0``. + * ``alpha`` (``float``): Prior probability of splitting for a tree of depth 0 in the conditional variance model. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``0.95``. + * ``beta`` (``float``): Exponent that decreases split probabilities for nodes of depth > 0 in the conditional variance model. Tree split prior combines ``alpha`` and ``beta`` via ``alpha*(1+node_depth)^-beta``. Defaults to ``2``. + * ``min_samples_leaf`` (``int``): Minimum allowable size of a leaf, in terms of training samples, in the conditional variance model. Defaults to ``5``. + * ``max_depth`` (``int``): Maximum depth of any tree in the ensemble in the conditional variance model. Defaults to ``10``. Can be overridden with ``-1`` which does not enforce any depth limits on trees. + * ``variable_weights`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided. + * ``var_forest_leaf_init`` (``float``): Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as ``np.log(0.6*np.var(y_train))/num_trees``, where `y_train` is the possibly standardized outcome, if not set. + * ``var_forest_prior_shape`` (``float``): Shape parameter in the [optional] ``IG(var_forest_prior_shape, var_forest_prior_scale)`` conditional error variance model (which is only sampled if ``num_trees > 0``). Calibrated internally as ``num_trees / 1.5^2 + 0.5`` if not set here. + * ``var_forest_prior_scale`` (``float``): Scale parameter in the [optional] ``IG(var_forest_prior_shape, var_forest_prior_scale)`` conditional error variance model (which is only sampled if ``num_trees > 0``). Calibrated internally as ``num_trees / 1.5^2`` if not set here. Returns ------- self : BCFModel Sampled BCF Model.
""" - # Unpack parameters - bcf_params = _preprocess_bcf_params(params) - cutpoint_grid_size = bcf_params['cutpoint_grid_size'] - sigma_leaf_mu = bcf_params['sigma_leaf_mu'] - sigma_leaf_tau = bcf_params['sigma_leaf_tau'] - alpha_mu = bcf_params['alpha_mu'] - alpha_tau = bcf_params['alpha_tau'] - alpha_variance = bcf_params['alpha_variance'] - beta_mu = bcf_params['beta_mu'] - beta_tau = bcf_params['beta_tau'] - beta_variance = bcf_params['beta_variance'] - min_samples_leaf_mu = bcf_params['min_samples_leaf_mu'] - min_samples_leaf_tau = bcf_params['min_samples_leaf_tau'] - min_samples_leaf_variance = bcf_params['min_samples_leaf_variance'] - max_depth_mu = bcf_params['max_depth_mu'] - max_depth_tau = bcf_params['max_depth_tau'] - max_depth_variance = bcf_params['max_depth_variance'] - a_global = bcf_params['a_global'] - b_global = bcf_params['b_global'] - a_forest = bcf_params['a_forest'] - b_forest = bcf_params['b_forest'] - a_leaf_mu = bcf_params['a_leaf_mu'] - a_leaf_tau = bcf_params['a_leaf_tau'] - b_leaf_mu = bcf_params['b_leaf_mu'] - b_leaf_tau = bcf_params['b_leaf_tau'] - sigma2_init = bcf_params['sigma2_init'] - variance_forest_leaf_init = bcf_params['variance_forest_leaf_init'] - pct_var_sigma2_init = bcf_params['pct_var_sigma2_init'] - pct_var_variance_forest_init = bcf_params['pct_var_variance_forest_init'] - variable_weights_mu = bcf_params['variable_weights_mu'] - variable_weights_tau = bcf_params['variable_weights_tau'] - variable_weights_variance = bcf_params['variable_weights_variance'] - keep_vars_mu = bcf_params['keep_vars_mu'] - drop_vars_mu = bcf_params['drop_vars_mu'] - keep_vars_tau = bcf_params['keep_vars_tau'] - drop_vars_tau = bcf_params['drop_vars_tau'] - keep_vars_variance = bcf_params['keep_vars_variance'] - drop_vars_variance = bcf_params['drop_vars_variance'] - num_trees_mu = bcf_params['num_trees_mu'] - num_trees_tau = bcf_params['num_trees_tau'] - num_trees_variance = bcf_params['num_trees_variance'] - sample_sigma_global = bcf_params['sample_sigma_global'] - sample_sigma_leaf_mu = bcf_params['sample_sigma_leaf_mu'] - sample_sigma_leaf_tau = bcf_params['sample_sigma_leaf_tau'] - propensity_covariate = bcf_params['propensity_covariate'] - adaptive_coding = bcf_params['adaptive_coding'] - b_0 = bcf_params['b_0'] - b_1 = bcf_params['b_1'] - random_seed = bcf_params['random_seed'] - keep_burnin = bcf_params['keep_burnin'] - keep_gfr = bcf_params['keep_gfr'] - self.standardize = bcf_params['standardize'] - keep_every = bcf_params['keep_every'] + # Update general BART parameters + general_params_default = { + 'cutpoint_grid_size' : 100, + 'standardize' : True, + 'sample_sigma2_global' : True, + 'sigma2_global_init' : None, + 'sigma2_global_shape' : 0, + 'sigma2_global_scale' : 0, + 'variable_weights' : None, + 'propensity_covariate' : "mu", + 'adaptive_coding' : True, + 'control_coding_init' : -0.5, + 'treated_coding_init' : 0.5, + 'random_seed' : -1, + 'keep_burnin' : False, + 'keep_gfr' : False, + 'keep_every' : 1, + 'num_chains' : 1 + } + general_params_updated = _preprocess_params( + general_params_default, general_params + ) + + # Update mu forest BART parameters + mu_forest_params_default = { + 'num_trees' : 250, + 'alpha' : 0.95, + 'beta' : 2.0, + 'min_samples_leaf' : 5, + 'max_depth' : 10, + 'sample_sigma2_leaf' : True, + 'sigma2_leaf_init' : None, + 'sigma2_leaf_shape' : 3, + 'sigma2_leaf_scale' : None, + 'keep_vars' : None, + 'drop_vars' : None + } + mu_forest_params_updated = _preprocess_params( + mu_forest_params_default, mu_forest_params + ) + + # 
Update tau forest BART parameters + tau_forest_params_default = { + 'num_trees' : 50, + 'alpha' : 0.25, + 'beta' : 3.0, + 'min_samples_leaf' : 5, + 'max_depth' : 5, + 'sample_sigma2_leaf' : False, + 'sigma2_leaf_init' : None, + 'sigma2_leaf_shape' : 3, + 'sigma2_leaf_scale' : None, + 'keep_vars' : None, + 'drop_vars' : None + } + tau_forest_params_updated = _preprocess_params( + tau_forest_params_default, tau_forest_params + ) + # Update variance forest BART parameters + variance_forest_params_default = { + 'num_trees' : 0, + 'alpha' : 0.95, + 'beta' : 2.0, + 'min_samples_leaf' : 5, + 'max_depth' : 10, + 'var_forest_leaf_init' : None, + 'var_forest_prior_shape' : None, + 'var_forest_prior_scale' : None, + 'keep_vars' : None, + 'drop_vars' : None + } + variance_forest_params_updated = _preprocess_params( + variance_forest_params_default, variance_forest_params + ) + + ### Unpack all parameter values + # 1. General parameters + cutpoint_grid_size = general_params_updated['cutpoint_grid_size'] + self.standardize = general_params_updated['standardize'] + sample_sigma_global = general_params_updated['sample_sigma2_global'] + sigma2_init = general_params_updated['sigma2_global_init'] + a_global = general_params_updated['sigma2_global_shape'] + b_global = general_params_updated['sigma2_global_scale'] + variable_weights = general_params_updated['variable_weights'] + propensity_covariate = general_params_updated['propensity_covariate'] + adaptive_coding = general_params_updated['adaptive_coding'] + b_0 = general_params_updated['control_coding_init'] + b_1 = general_params_updated['treated_coding_init'] + random_seed = general_params_updated['random_seed'] + keep_burnin = general_params_updated['keep_burnin'] + keep_gfr = general_params_updated['keep_gfr'] + keep_every = general_params_updated['keep_every'] + + # 2. Mu forest parameters + num_trees_mu = mu_forest_params_updated['num_trees'] + alpha_mu = mu_forest_params_updated['alpha'] + beta_mu = mu_forest_params_updated['beta'] + min_samples_leaf_mu = mu_forest_params_updated['min_samples_leaf'] + max_depth_mu = mu_forest_params_updated['max_depth'] + sample_sigma_leaf_mu = mu_forest_params_updated['sample_sigma2_leaf'] + sigma_leaf_mu = mu_forest_params_updated['sigma2_leaf_init'] + a_leaf_mu = mu_forest_params_updated['sigma2_leaf_shape'] + b_leaf_mu = mu_forest_params_updated['sigma2_leaf_scale'] + keep_vars_mu = mu_forest_params_updated['keep_vars'] + drop_vars_mu = mu_forest_params_updated['drop_vars'] + + # 3. Tau forest parameters + num_trees_tau = tau_forest_params_updated['num_trees'] + alpha_tau = tau_forest_params_updated['alpha'] + beta_tau = tau_forest_params_updated['beta'] + min_samples_leaf_tau = tau_forest_params_updated['min_samples_leaf'] + max_depth_tau = tau_forest_params_updated['max_depth'] + sample_sigma_leaf_tau = tau_forest_params_updated['sample_sigma2_leaf'] + sigma_leaf_tau = tau_forest_params_updated['sigma2_leaf_init'] + a_leaf_tau = tau_forest_params_updated['sigma2_leaf_shape'] + b_leaf_tau = tau_forest_params_updated['sigma2_leaf_scale'] + keep_vars_tau = tau_forest_params_updated['keep_vars'] + drop_vars_tau = tau_forest_params_updated['drop_vars'] + + # 4. 
Variance forest parameters + num_trees_variance = variance_forest_params_updated['num_trees'] + alpha_variance = variance_forest_params_updated['alpha'] + beta_variance = variance_forest_params_updated['beta'] + min_samples_leaf_variance = variance_forest_params_updated['min_samples_leaf'] + max_depth_variance = variance_forest_params_updated['max_depth'] + variance_forest_leaf_init = variance_forest_params_updated['var_forest_leaf_init'] + a_forest = variance_forest_params_updated['var_forest_prior_shape'] + b_forest = variance_forest_params_updated['var_forest_prior_scale'] + keep_vars_variance = variance_forest_params_updated['keep_vars'] + drop_vars_variance = variance_forest_params_updated['drop_vars'] + # Variable weight preprocessing (and initialization if necessary) - if variable_weights_mu is None: + if variable_weights is None: if X_train.ndim > 1: - variable_weights_mu = np.repeat(1/X_train.shape[1], X_train.shape[1]) + variable_weights = np.repeat(1/X_train.shape[1], X_train.shape[1]) else: - variable_weights_mu = np.repeat(1., 1) - if np.any(variable_weights_mu < 0): - raise ValueError("variable_weights_mu cannot have any negative weights") - if variable_weights_tau is None: - if X_train.ndim > 1: - variable_weights_tau = np.repeat(1/X_train.shape[1], X_train.shape[1]) - else: - variable_weights_tau = np.repeat(1., 1) - if np.any(variable_weights_tau < 0): - raise ValueError("variable_weights_tau cannot have any negative weights") - if variable_weights_variance is None: - if X_train.ndim > 1: - variable_weights_variance = np.repeat(1/X_train.shape[1], X_train.shape[1]) - else: - variable_weights_variance = np.repeat(1., 1) - if np.any(variable_weights_variance < 0): - raise ValueError("variable_weights_variance cannot have any negative weights") + variable_weights = np.repeat(1., 1) + if np.any(variable_weights < 0): + raise ValueError("variable_weights cannot have any negative weights") + variable_weights_mu = variable_weights + variable_weights_tau = variable_weights + variable_weights_variance = variable_weights # Determine whether conditional variance model will be fit self.include_variance_forest = True if num_trees_variance > 0 else False @@ -611,9 +680,9 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau (don't use regression initializer for warm-start or XBART) if not sigma2_init: - sigma2_init = pct_var_sigma2_init*np.var(resid_train) + sigma2_init = 1.0*np.var(resid_train) if not variance_forest_leaf_init: - variance_forest_leaf_init = pct_var_variance_forest_init*np.var(resid_train) + variance_forest_leaf_init = 0.6*np.var(resid_train) b_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu if b_leaf_mu is None else b_leaf_mu b_leaf_tau = np.squeeze(np.var(resid_train)) / (2*num_trees_tau) if b_leaf_tau is None else b_leaf_tau sigma_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu if sigma_leaf_mu is None else sigma_leaf_mu @@ -639,7 +708,6 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr if not b_forest: b_forest = 1. 
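For reference, the internal prior calibration in the hunk above reduces to a few one-liners. The following standalone sketch restates those defaults exactly as the diff computes them, substituting toy residuals for the standardized outcome ``resid_train`` and the default BCF tree counts for ``num_trees_mu`` / ``num_trees_tau``.

```python
import numpy as np

# Toy standardized residuals purely for illustration
resid_train = np.random.default_rng(0).normal(size=1000)
num_trees_mu, num_trees_tau = 250, 50

sigma2_init = 1.0 * np.var(resid_train)                # global error variance start
variance_forest_leaf_init = 0.6 * np.var(resid_train)  # variance forest leaf start
b_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu           # mu leaf scale (IG scale)
b_leaf_tau = np.squeeze(np.var(resid_train)) / (2 * num_trees_tau)   # tau leaf scale (IG scale)
```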
- # Update variable weights variable_counts = [original_var_indices.count(i) for i in original_var_indices] variable_weights_mu_adj = [1/i for i in variable_counts] @@ -649,7 +717,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr variable_weights_tau = variable_weights_tau[original_var_indices]*variable_weights_tau_adj variable_weights_variance = variable_weights_variance[original_var_indices]*variable_weights_variance_adj - # Create mu, tau, and variance specific variable weights with weights zeroed out for excluded variables + # Zero out weights for excluded variables variable_weights_mu[[variable_subset_mu.count(i) == 0 for i in original_var_indices]] = 0 variable_weights_tau[[variable_subset_tau.count(i) == 0 for i in original_var_indices]] = 0 variable_weights_variance[[variable_subset_variance.count(i) == 0 for i in original_var_indices]] = 0 diff --git a/stochtree/preprocessing.py b/stochtree/preprocessing.py index c913cab7..7a796198 100644 --- a/stochtree/preprocessing.py +++ b/stochtree/preprocessing.py @@ -10,6 +10,15 @@ import pandas as pd import warnings +def _preprocess_params(default_params: Dict[str, Any], user_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + if user_params: + for key, value in user_params.items(): + if key in default_params: + default_params[key] = value + + return default_params + + def _preprocess_bart_params(params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: processed_params = { 'cutpoint_grid_size' : 100, diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index d7a5f36f..b0372236 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -24,35 +24,35 @@ test_that("MCMC BART", { y_train <- y[train_inds] # 1 chain, no thinning - param_list <- list(num_chains = 1, keep_every = 1) + general_param_list <- list(num_chains = 1, keep_every = 1) expect_no_error( bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 0, num_burnin = 10, num_mcmc = 10, - params = param_list) + general_params = general_param_list) ) # 3 chains, no thinning - param_list <- list(num_chains = 3, keep_every = 1) + general_param_list <- list(num_chains = 3, keep_every = 1) expect_no_error( bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 0, num_burnin = 10, num_mcmc = 10, - params = param_list) + general_params = general_param_list) ) # 1 chain, thinning - param_list <- list(num_chains = 1, keep_every = 5) + general_param_list <- list(num_chains = 1, keep_every = 5) expect_no_error( bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 0, num_burnin = 10, num_mcmc = 10, - params = param_list) + general_params = general_param_list) ) # 3 chains, thinning - param_list <- list(num_chains = 3, keep_every = 5) + general_param_list <- list(num_chains = 3, keep_every = 5) expect_no_error( bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 0, num_burnin = 10, num_mcmc = 10, - params = param_list) + general_params = general_param_list) ) }) @@ -82,50 +82,50 @@ test_that("GFR BART", { y_train <- y[train_inds] # 1 chain, no thinning - param_list <- list(num_chains = 1, keep_every = 1) + general_param_list <- list(num_chains = 1, keep_every = 1) expect_no_error( bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_burnin = 10, num_mcmc = 10, - params = param_list) + general_params = general_param_list) ) # 3 chains, no thinning - 
param_list <- list(num_chains = 3, keep_every = 1) + general_param_list <- list(num_chains = 3, keep_every = 1) expect_no_error( bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_burnin = 10, num_mcmc = 10, - params = param_list) + general_params = general_param_list) ) # 1 chain, thinning - param_list <- list(num_chains = 1, keep_every = 5) + general_param_list <- list(num_chains = 1, keep_every = 5) expect_no_error( bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_burnin = 10, num_mcmc = 10, - params = param_list) + general_params = general_param_list) ) # 3 chains, thinning - param_list <- list(num_chains = 3, keep_every = 5) + general_param_list <- list(num_chains = 3, keep_every = 5) expect_no_error( bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_burnin = 10, num_mcmc = 10, - params = param_list) + general_params = general_param_list) ) # Check for error when more chains than GFR forests - param_list <- list(num_chains = 11, keep_every = 1) + general_param_list <- list(num_chains = 11, keep_every = 1) expect_error( bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_burnin = 10, num_mcmc = 10, - params = param_list) + general_params = general_param_list) ) # Check for error when more chains than GFR forests - param_list <- list(num_chains = 11, keep_every = 5) + general_param_list <- list(num_chains = 11, keep_every = 5) expect_error( bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_burnin = 10, num_mcmc = 10, - params = param_list) + general_params = general_param_list) ) }) diff --git a/test/python/test_bart.py b/test/python/test_bart.py index 26d1fb91..e49931ab 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -224,9 +224,11 @@ def conditional_stddev(X): # Run BCF with test set and propensity score bart_model = BARTModel() - bart_params = {'num_trees_variance': 50, 'sample_sigma_global': True} - bart_model.sample(X_train=X_train, y_train=y_train, X_test=X_test, params=bart_params, - num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc) + general_params = {'sample_sigma2_global': True} + variance_forest_params = {'num_trees': 50} + bart_model.sample(X_train=X_train, y_train=y_train, X_test=X_test, general_params=general_params, + variance_forest_params=variance_forest_params, num_gfr=num_gfr, + num_burnin=num_burnin, num_mcmc=num_mcmc) # Assertions assert (bart_model.y_hat_train.shape == (n_train, num_mcmc)) @@ -293,10 +295,12 @@ def conditional_stddev(X): # Run BCF with test set and propensity score bart_model = BARTModel() - bart_params = {'num_trees_variance': 50, 'sample_sigma_global': True} + general_params = {'sample_sigma2_global': True} + variance_forest_params = {'num_trees': 50} bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=num_gfr, - num_burnin=num_burnin, num_mcmc=num_mcmc, params=bart_params) + num_burnin=num_burnin, num_mcmc=num_mcmc, general_params=general_params, + variance_forest_params=variance_forest_params) # Assertions assert (bart_model.y_hat_train.shape == (n_train, num_mcmc)) @@ -363,10 +367,12 @@ def conditional_stddev(X): # Run BCF with test set and propensity score bart_model = BARTModel() - bart_params = {'num_trees_variance': 50, 'sample_sigma_global': True} + general_params = {'sample_sigma2_global': True} + variance_forest_params = 
{'num_trees': 50} bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=num_gfr, - num_burnin=num_burnin, num_mcmc=num_mcmc, params=bart_params) + num_burnin=num_burnin, num_mcmc=num_mcmc, general_params=general_params, + variance_forest_params=variance_forest_params) # Assertions assert (bart_model.y_hat_train.shape == (n_train, num_mcmc)) diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index 01b733da..6b558753 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -49,10 +49,10 @@ def test_binary_bcf(self): # Run BCF with test set and propensity score bcf_model = BCFModel() - bcf_params = {'num_trees_variance': 0} + variance_forest_params = {'num_trees': 0} bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, pi_train=pi_train, X_test=X_test, Z_test=Z_test, pi_test=pi_test, num_gfr=num_gfr, - num_burnin=num_burnin, num_mcmc=num_mcmc, params=bcf_params) + num_burnin=num_burnin, num_mcmc=num_mcmc, variance_forest_params=variance_forest_params) # Assertions assert (bcf_model.y_hat_train.shape == (n_train, num_mcmc)) @@ -74,9 +74,10 @@ def test_binary_bcf(self): # Run BCF without test set and with propensity score bcf_model = BCFModel() - bcf_params = {'num_trees_variance': 0} + variance_forest_params = {'num_trees': 0} bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, pi_train=pi_train, - num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, params=bcf_params) + num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params) # Assertions assert (bcf_model.y_hat_train.shape == (n_train, num_mcmc)) @@ -95,10 +96,11 @@ def test_binary_bcf(self): # Run BCF with test set and without propensity score bcf_model = BCFModel() - bcf_params = {'num_trees_variance': 0} + variance_forest_params = {'num_trees': 0} bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, X_test=X_test, Z_test=Z_test, num_gfr=num_gfr, - num_burnin=num_burnin, num_mcmc=num_mcmc, params=bcf_params) + num_burnin=num_burnin, num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params) # Assertions assert (bcf_model.y_hat_train.shape == (n_train, num_mcmc)) @@ -122,9 +124,10 @@ def test_binary_bcf(self): # Run BCF without test set and without propensity score bcf_model = BCFModel() - bcf_params = {'num_trees_variance': 0} + variance_forest_params = {'num_trees': 0} bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, - num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, params=bcf_params) + num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params) # Assertions assert (bcf_model.y_hat_train.shape == (n_train, num_mcmc)) @@ -186,10 +189,11 @@ def test_continuous_univariate_bcf(self): # Run BCF with test set and propensity score bcf_model = BCFModel() - bcf_params = {'num_trees_variance': 0} + variance_forest_params = {'num_trees': 0} bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, pi_train=pi_train, X_test=X_test, Z_test=Z_test, pi_test=pi_test, num_gfr=num_gfr, - num_burnin=num_burnin, num_mcmc=num_mcmc, params=bcf_params) + num_burnin=num_burnin, num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params) # Assertions assert (bcf_model.y_hat_train.shape == (n_train, num_mcmc)) @@ -211,9 +215,10 @@ def test_continuous_univariate_bcf(self): # Run BCF without test set and with propensity score bcf_model = BCFModel() - bcf_params = 
{'num_trees_variance': 0} + variance_forest_params = {'num_trees': 0} bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, pi_train=pi_train, - num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, params=bcf_params) + num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params) # Assertions assert (bcf_model.y_hat_train.shape == (n_train, num_mcmc)) @@ -232,10 +237,11 @@ def test_continuous_univariate_bcf(self): # Run BCF with test set and without propensity score bcf_model = BCFModel() - bcf_params = {'num_trees_variance': 0} + variance_forest_params = {'num_trees': 0} bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, X_test=X_test, Z_test=Z_test, num_gfr=num_gfr, - num_burnin=num_burnin, num_mcmc=num_mcmc, params=bcf_params) + num_burnin=num_burnin, num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params) # Assertions assert (bcf_model.y_hat_train.shape == (n_train, num_mcmc)) @@ -259,9 +265,10 @@ def test_continuous_univariate_bcf(self): # Run BCF without test set and without propensity score bcf_model = BCFModel() - bcf_params = {'num_trees_variance': 0} + variance_forest_params = {'num_trees': 0} bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, - num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, params=bcf_params) + num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params) # Assertions assert (bcf_model.y_hat_train.shape == (n_train, num_mcmc)) @@ -325,10 +332,11 @@ def test_multivariate_bcf(self): # Run BCF with test set and propensity score bcf_model = BCFModel() - bcf_params = {'num_trees_variance': 0} + variance_forest_params = {'num_trees': 0} bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, pi_train=pi_train, X_test=X_test, Z_test=Z_test, pi_test=pi_test, num_gfr=num_gfr, - num_burnin=num_burnin, num_mcmc=num_mcmc, params=bcf_params) + num_burnin=num_burnin, num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params) # Assertions assert (bcf_model.y_hat_train.shape == (n_train, num_mcmc)) @@ -350,9 +358,10 @@ def test_multivariate_bcf(self): # Run BCF without test set and with propensity score bcf_model = BCFModel() - bcf_params = {'num_trees_variance': 0} + variance_forest_params = {'num_trees': 0} bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, pi_train=pi_train, - num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, params=bcf_params) + num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params) # Assertions assert (bcf_model.y_hat_train.shape == (n_train, num_mcmc)) @@ -372,14 +381,16 @@ def test_multivariate_bcf(self): # Run BCF with test set and without propensity score with pytest.raises(ValueError): bcf_model = BCFModel() - bcf_params = {'num_trees_variance': 0} + variance_forest_params = {'num_trees': 0} bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, X_test=X_test, Z_test=Z_test, num_gfr=num_gfr, - num_burnin=num_burnin, num_mcmc=num_mcmc, params=bcf_params) + num_burnin=num_burnin, num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params) # Run BCF without test set and without propensity score with pytest.raises(ValueError): bcf_model = BCFModel() - bcf_params = {'num_trees_variance': 0} + variance_forest_params = {'num_trees': 0} bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, - num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, 
diff --git a/vignettes/CausalInference.Rmd b/vignettes/CausalInference.Rmd
index fb461708..9086b759 100644
--- a/vignettes/CausalInference.Rmd
+++ b/vignettes/CausalInference.Rmd
@@ -112,13 +112,15 @@
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 1000
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_warmstart <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
@@ -160,13 +162,15 @@
 num_gfr <- 0
 num_burnin <- 1000
 num_mcmc <- 1000
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_root <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
@@ -277,13 +281,15 @@
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_warmstart <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
@@ -325,13 +331,15 @@
 num_gfr <- 0
 num_burnin <- 100
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_root <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
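Under the new interface, general options (here, thinning via `keep_every`) live in `general_params`, while the prognostic forest (`mu`) and treatment-effect forest (`tau`) each get their own list. A self-contained sketch of the pattern from the hunks above, on simulated causal data:

```r
# Sketch of a full bcf() call with the split parameter lists. Data simulated.
library(stochtree)

n <- 1000
X_train <- matrix(runif(n * 5), ncol = 5)
pi_train <- pnorm(X_train[, 1] - 0.5)        # true propensity
Z_train <- rbinom(n, 1, pi_train)
tau_true <- 1 + X_train[, 2]                 # heterogeneous treatment effect
y_train <- X_train[, 3] + tau_true * Z_train + rnorm(n)

general_params <- list(keep_every = 5)               # retain every 5th MCMC draw
mu_forest_params <- list(sample_sigma2_leaf = FALSE) # prognostic forest options
tau_forest_params <- list(sample_sigma2_leaf = FALSE)# treatment forest options
bcf_model <- bcf(
    X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
    num_gfr = 10, num_burnin = 0, num_mcmc = 1000, 
    general_params = general_params, mu_forest_params = mu_forest_params, 
    tau_forest_params = tau_forest_params
)
```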
@@ -442,13 +450,15 @@
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_warmstart <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
@@ -490,13 +500,15 @@
 num_gfr <- 0
 num_burnin <- 100
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_root <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
@@ -605,13 +617,15 @@
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_warmstart <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
@@ -653,13 +667,15 @@
 num_gfr <- 0
 num_burnin <- 100
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_root <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
@@ -761,14 +777,16 @@
 num_gfr <- 100
 num_burnin <- 0
 num_mcmc <- 500
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_warmstart <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
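The hunk just above adds group-level random effects to the BCF model: `group_ids_*` index the grouping factor and `rfx_basis_*` supply the random-effects basis. A self-contained sketch of that call pattern, with random intercepts over four simulated groups (group structure chosen purely for illustration):

```r
# Sketch of bcf() with group random effects under the new interface.
library(stochtree)

n <- 1000
X_train <- matrix(runif(n * 5), ncol = 5)
pi_train <- pnorm(X_train[, 1] - 0.5)
Z_train <- rbinom(n, 1, pi_train)
group_ids_train <- sample(1:4, n, replace = TRUE)   # four groups
rfx_basis_train <- matrix(1, nrow = n, ncol = 1)    # random-intercept basis
group_effects <- c(-1, -0.5, 0.5, 1)
y_train <- X_train[, 2] + Z_train + group_effects[group_ids_train] + rnorm(n)

bcf_model <- bcf(
    X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
    group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, 
    num_gfr = 100, num_burnin = 0, num_mcmc = 500, 
    general_params = list(keep_every = 5), 
    mu_forest_params = list(sample_sigma2_leaf = FALSE), 
    tau_forest_params = list(sample_sigma2_leaf = FALSE)
)
```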
@@ -876,13 +894,15 @@
 num_gfr <- 0
 num_burnin <- 1000
 num_mcmc <- 1000
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_mcmc <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
@@ -936,13 +956,15 @@
 num_gfr <- 0
 num_burnin <- 1000
 num_mcmc <- 1000
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_vars_tau = c("x1","x2"), keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F, keep_vars = c("x1","x2"))
 bcf_model_mcmc <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
@@ -996,13 +1018,15 @@
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 1000
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_warmstart <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
@@ -1056,13 +1080,15 @@
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 1000
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_vars_tau = c("x1","x2"), keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F, keep_vars = c("x1","x2"))
 bcf_model_warmstart <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
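Note the variable-selection change in the `keep_vars` hunks above: the old global entry `keep_vars_tau` becomes a plain `keep_vars` entry inside the forest's own list, so each forest can restrict its split candidates independently. A small sketch of the resulting configuration:

```r
# Sketch: per-forest variable selection after the renaming.
# The tau forest may split only on x1 and x2; the mu forest sees all columns.
tau_forest_params <- list(sample_sigma2_leaf = FALSE, 
                          keep_vars = c("x1", "x2"))
mu_forest_params <- list(sample_sigma2_leaf = FALSE)
```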
@@ -1182,13 +1208,15 @@
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 1000
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_warmstart <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
@@ -1230,13 +1258,15 @@
 num_gfr <- 0
 num_burnin <- 1000
 num_mcmc <- 1000
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F, 
-                   keep_every = 5)
+general_params <- list(keep_every = 5)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model_root <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    general_params = general_params, mu_forest_params = mu_forest_params, 
+    tau_forest_params = tau_forest_params
 )
 ```
diff --git a/vignettes/EnsembleKernel.Rmd b/vignettes/EnsembleKernel.Rmd
index e808163d..2863efe3 100644
--- a/vignettes/EnsembleKernel.Rmd
+++ b/vignettes/EnsembleKernel.Rmd
@@ -97,8 +97,8 @@ sigma_leaf <- 1/num_trees
 X_train <- as.data.frame(X_train)
 X_test <- as.data.frame(X_test)
 colnames(X_train) <- colnames(X_test) <- "x1"
-bart_params <- list(num_trees_mean=num_trees, sigma_leaf_init=sigma_leaf)
-bart_model <- bart(X_train=X_train, y_train=y_train, X_test=X_test, params = bart_params)
+mean_forest_params <- list(num_trees=num_trees, sigma2_leaf_init=sigma_leaf)
+bart_model <- bart(X_train=X_train, y_train=y_train, X_test=X_test, mean_forest_params = mean_forest_params)
 
 # Extract kernels needed for kriging
 leaf_mat_train <- computeForestLeafIndices(bart_model, X_train, forest_type = "mean", 
@@ -174,8 +174,8 @@ num_trees <- 200
 sigma_leaf <- 1/num_trees
 X_train <- as.data.frame(X_train)
 X_test <- as.data.frame(X_test)
-bart_params <- list(num_trees_mean=num_trees)
-bart_model <- bart(X_train=X_train, y_train=y_train, X_test=X_test, params = bart_params)
+mean_forest_params <- list(num_trees=num_trees, sigma2_leaf_init=sigma_leaf)
+bart_model <- bart(X_train=X_train, y_train=y_train, X_test=X_test, mean_forest_params = mean_forest_params)
 
 # Extract kernels needed for kriging
 leaf_mat_train <- computeForestLeafIndices(bart_model, X_train, forest_type = "mean", 
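The kriging workflow touched above builds a kernel from leaf memberships: two observations are similar when they land in the same leaf across many trees. A purely illustrative sketch of that idea follows; it assumes a leaf-index matrix of shape n x num_trees, and the actual return format of `computeForestLeafIndices` may differ, so treat the helper name and shapes here as assumptions.

```r
# Sketch: a tree-ensemble kernel from leaf memberships (assumed n x num_trees
# matrices of leaf identifiers; NOT the confirmed output format of
# computeForestLeafIndices).
ensemble_kernel <- function(leaf_mat_a, leaf_mat_b) {
    n_a <- nrow(leaf_mat_a)
    n_b <- nrow(leaf_mat_b)
    K <- matrix(0, n_a, n_b)
    for (i in seq_len(n_a)) {
        for (j in seq_len(n_b)) {
            # Fraction of trees in which the two observations share a leaf
            K[i, j] <- mean(leaf_mat_a[i, ] == leaf_mat_b[j, ])
        }
    }
    K
}
```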
diff --git a/vignettes/Heteroskedasticity.Rmd b/vignettes/Heteroskedasticity.Rmd
index d29008b0..0ffe8f72 100644
--- a/vignettes/Heteroskedasticity.Rmd
+++ b/vignettes/Heteroskedasticity.Rmd
@@ -93,12 +93,14 @@
 num_mcmc <- 100
 num_trees <- 20
 a_0 <- 1.5
 num_samples <- num_gfr + num_burnin + num_mcmc
-bart_params <- list(num_trees_mean = 0, num_trees_variance = num_trees, 
-                    sample_sigma_global = F, sample_sigma_leaf = F)
+general_params <- list(sample_sigma2_global = F)
+mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 0)
+variance_forest_params <- list(num_trees = num_trees)
 bart_model_warmstart <- stochtree::bart(
     X_train = X_train, y_train = y_train, X_test = X_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bart_params
+    general_params = general_params, mean_forest_params = mean_forest_params, 
+    variance_forest_params = variance_forest_params
 )
 ```
@@ -120,12 +122,14 @@
 num_gfr <- 0
 num_burnin <- 1000
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bart_params <- list(num_trees_mean = 0, num_trees_variance = 50, 
-                    sample_sigma_global = F, sample_sigma_leaf = F)
+general_params <- list(sample_sigma2_global = F)
+mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 0)
+variance_forest_params <- list(num_trees = num_trees)
 bart_model_mcmc <- stochtree::bart(
     X_train = X_train, y_train = y_train, X_test = X_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bart_params
+    general_params = general_params, mean_forest_params = mean_forest_params, 
+    variance_forest_params = variance_forest_params
 )
 ```
@@ -198,18 +202,21 @@
 initialization (@he2023stochastic). This is the default in 
 `stochtree`.
 
 ```{r}
+num_trees <- 20
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bart_params <- list(num_trees_mean = 0, num_trees_variance = 50, 
-                    alpha_mean = 0.95, beta_mean = 2, min_samples_leaf_mean = 5, 
-                    alpha_variance = 0.95, beta_variance = 1.25, min_samples_leaf_variance = 1, 
-                    sample_sigma_global = F, sample_sigma_leaf = F)
+general_params <- list(sample_sigma2_global = F)
+mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 0, 
+                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
+variance_forest_params <- list(num_trees = num_trees, alpha = 0.95, 
+                               beta = 1.25, min_samples_leaf = 1)
 bart_model_warmstart <- stochtree::bart(
     X_train = X_train, y_train = y_train, X_test = X_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bart_params
+    general_params = general_params, mean_forest_params = mean_forest_params, 
+    variance_forest_params = variance_forest_params
 )
 ```
@@ -231,14 +238,16 @@
 num_gfr <- 0
 num_burnin <- 1000
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bart_params <- list(num_trees_mean = 0, num_trees_variance = 50, 
-                    alpha_mean = 0.95, beta_mean = 2, min_samples_leaf_mean = 5, 
-                    alpha_variance = 0.95, beta_variance = 1.25, min_samples_leaf_variance = 5, 
-                    sample_sigma_global = F, sample_sigma_leaf = F)
+general_params <- list(sample_sigma2_global = F)
+mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 0, 
+                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
+variance_forest_params <- list(num_trees = num_trees, alpha = 0.95, 
+                               beta = 1.25, min_samples_leaf = 1)
 bart_model_mcmc <- stochtree::bart(
     X_train = X_train, y_train = y_train, X_test = X_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bart_params
+    general_params = general_params, mean_forest_params = mean_forest_params, 
+    variance_forest_params = variance_forest_params
 )
 ```
@@ -326,14 +335,16 @@
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bart_params <- list(num_trees_mean = 50, num_trees_variance = 50, 
-                    alpha_mean = 0.95, beta_mean = 2, min_samples_leaf_mean = 5, 
-                    alpha_variance = 0.95, beta_variance = 1.25, min_samples_leaf_variance = 5, 
-                    sample_sigma_global = F, sample_sigma_leaf = F)
+general_params <- list(sample_sigma2_global = F)
+mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 50, 
+                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
+variance_forest_params <- list(num_trees = 50, alpha = 0.95, 
+                               beta = 1.25, min_samples_leaf = 5)
 bart_model_warmstart <- stochtree::bart(
     X_train = X_train, y_train = y_train, X_test = X_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bart_params
+    general_params = general_params, mean_forest_params = mean_forest_params, 
+    variance_forest_params = variance_forest_params
 )
 ```
@@ -358,14 +369,16 @@
 num_gfr <- 0
 num_burnin <- 1000
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bart_params <- list(num_trees_mean = 50, num_trees_variance = 50, 
-                    alpha_mean = 0.95, beta_mean = 2, min_samples_leaf_mean = 5, 
-                    alpha_variance = 0.95, beta_variance = 1.25, min_samples_leaf_variance = 5, 
-                    sample_sigma_global = F, sample_sigma_leaf = F)
+general_params <- list(sample_sigma2_global = F)
+mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 50, 
+                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
+variance_forest_params <- list(num_trees = 50, alpha = 0.95, 
+                               beta = 1.25, min_samples_leaf = 5)
 bart_model_mcmc <- stochtree::bart(
     X_train = X_train, y_train = y_train, X_test = X_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bart_params
+    general_params = general_params, mean_forest_params = mean_forest_params, 
+    variance_forest_params = variance_forest_params
 )
 ```
@@ -457,14 +470,16 @@
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bart_params <- list(num_trees_mean = 50, num_trees_variance = 50, 
-                    alpha_mean = 0.95, beta_mean = 2, min_samples_leaf_mean = 5, 
-                    alpha_variance = 0.95, beta_variance = 1.25, min_samples_leaf_variance = 5, 
-                    sample_sigma_global = F, sample_sigma_leaf = F)
+general_params <- list(sample_sigma2_global = F)
+mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 50, 
+                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
+variance_forest_params <- list(num_trees = 50, alpha = 0.95, 
+                               beta = 1.25, min_samples_leaf = 5)
 bart_model_warmstart <- stochtree::bart(
     X_train = X_train, y_train = y_train, X_test = X_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bart_params
+    general_params = general_params, mean_forest_params = mean_forest_params, 
+    variance_forest_params = variance_forest_params
 )
 ```
@@ -489,14 +504,16 @@
 num_gfr <- 0
 num_burnin <- 1000
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bart_params <- list(num_trees_mean = 50, num_trees_variance = 50, 
-                    alpha_mean = 0.95, beta_mean = 2, min_samples_leaf_mean = 5, 
-                    alpha_variance = 0.95, beta_variance = 1.25, min_samples_leaf_variance = 5, 
-                    sample_sigma_global = F, sample_sigma_leaf = F)
+general_params <- list(sample_sigma2_global = F)
+mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 50, 
+                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
+variance_forest_params <- list(num_trees = 50, alpha = 0.95, 
+                               beta = 1.25, min_samples_leaf = 5)
 bart_model_mcmc <- stochtree::bart(
     X_train = X_train, y_train = y_train, X_test = X_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bart_params
+    general_params = general_params, mean_forest_params = mean_forest_params, 
+    variance_forest_params = variance_forest_params
 )
 ```
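A practical payoff of the split lists, visible throughout the Heteroskedasticity hunks, is that the mean and variance forests can now carry distinct tree priors (`alpha`, `beta`, `min_samples_leaf`) without the `_mean`/`_variance` suffix bookkeeping. A self-contained sketch, with noise that genuinely depends on a covariate so the variance forest has something to learn:

```r
# Sketch: joint mean/variance BART with separate tree priors per forest.
# Data simulated with covariate-dependent noise.
library(stochtree)

n <- 500
X_train <- matrix(runif(n * 5), ncol = 5)
s_x <- 0.5 + X_train[, 2]                    # noise sd grows with x2
y_train <- sin(2 * pi * X_train[, 1]) + rnorm(n, 0, s_x)

general_params <- list(sample_sigma2_global = FALSE)
mean_forest_params <- list(sample_sigma2_leaf = FALSE, num_trees = 50, 
                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
variance_forest_params <- list(num_trees = 50, alpha = 0.95, 
                               beta = 1.25, min_samples_leaf = 5)
bart_model <- stochtree::bart(
    X_train = X_train, y_train = y_train, 
    num_gfr = 10, num_burnin = 0, num_mcmc = 100, 
    general_params = general_params, mean_forest_params = mean_forest_params, 
    variance_forest_params = variance_forest_params
)
```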
diff --git a/vignettes/ModelSerialization.Rmd b/vignettes/ModelSerialization.Rmd
index e22b199b..c36ced79 100644
--- a/vignettes/ModelSerialization.Rmd
+++ b/vignettes/ModelSerialization.Rmd
@@ -100,14 +100,15 @@
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bcf_params <- list(sample_sigma_leaf_mu = F, sample_sigma_leaf_tau = F)
+mu_forest_params <- list(sample_sigma2_leaf = F)
+tau_forest_params <- list(sample_sigma2_leaf = F)
 bcf_model <- bcf(
     X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
     group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train, 
     X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
     group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bcf_params
+    mu_forest_params = mu_forest_params, tau_forest_params = tau_forest_params
 )
 ```
@@ -190,15 +191,16 @@
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 100
 num_samples <- num_gfr + num_burnin + num_mcmc
-bart_params <- list(num_trees_mean = 100, num_trees_variance = 50, 
-                    alpha_mean = 0.95, beta_mean = 2, min_samples_leaf_mean = 5, 
-                    alpha_variance = 0.95, beta_variance = 1.25, 
-                    min_samples_leaf_variance = 1, 
-                    sample_sigma_global = F, sample_sigma_leaf = F)
+general_params <- list(sample_sigma2_global = F)
+mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 100, 
+                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
+variance_forest_params <- list(num_trees = 50, alpha = 0.95, 
+                               beta = 1.25, min_samples_leaf = 1)
 bart_model <- stochtree::bart(
     X_train = X_train, y_train = y_train, X_test = X_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bart_params
+    general_params = general_params, mean_forest_params = mean_forest_params, 
+    variance_forest_params = variance_forest_params
 )
 ```
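The serialization vignette pairs naturally with the warm-start arguments that appear elsewhere in this diff (`previous_model_json`, `warmstart_sample_num`): fit, serialize to a JSON string, then continue sampling from a stored forest. A sketch on simulated data, using only calls shown in this diff:

```r
# Sketch: serialize a GFR-only fit, then warm-start MCMC from its last forest.
library(stochtree)

n <- 500
X_train <- matrix(runif(n * 5), ncol = 5)
y_train <- X_train[, 1] + rnorm(n)

bart_model <- stochtree::bart(
    X_train = X_train, y_train = y_train, 
    num_gfr = 10, num_burnin = 0, num_mcmc = 0
)
bart_json <- stochtree::saveBARTModelToJsonString(bart_model)

# warmstart_sample_num is one-indexed; 10 selects the final GFR forest here
bart_continued <- stochtree::bart(
    X_train = X_train, y_train = y_train, 
    num_gfr = 0, num_burnin = 100, num_mcmc = 100, 
    previous_model_json = bart_json, warmstart_sample_num = 10
)
```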
diff --git a/vignettes/MultiChain.Rmd b/vignettes/MultiChain.Rmd
index c9d69b50..425f771f 100644
--- a/vignettes/MultiChain.Rmd
+++ b/vignettes/MultiChain.Rmd
@@ -97,12 +97,13 @@ Run the sampler, storing the resulting BART objects in a list
 
 ```{r}
 bart_models <- list()
-bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T, num_trees_mean = num_trees)
+general_params <- list(sample_sigma2_global = T)
+mean_forest_params <- list(sample_sigma2_leaf = T, num_trees = num_trees)
 for (i in 1:num_chains) {
     bart_models[[i]] <- stochtree::bart(
         X_train = X_train, W_train = W_train, y_train = y_train, X_test = X_test, W_test = W_test, 
         num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-        params = bart_params
+        general_params = general_params, mean_forest_params = mean_forest_params
     )
 }
 ```
@@ -177,11 +178,12 @@ storing the resulting BART JSON strings in a list.
 ```{r}
 bart_model_outputs <- foreach (i = 1:num_chains) %dopar% {
     random_seed <- i
-    bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T, 
-                        num_trees_mean = num_trees, random_seed = random_seed)
+    general_params <- list(sample_sigma2_global = T, random_seed = random_seed)
+    mean_forest_params <- list(sample_sigma2_leaf = T, num_trees = num_trees)
    bart_model <- stochtree::bart(
        X_train = X_train, W_train = W_train, y_train = y_train, X_test = X_test, W_test = W_test, 
-       num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, params = bart_params
+       num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
+       general_params = general_params, mean_forest_params = mean_forest_params
    )
    bart_model_string <- stochtree::saveBARTModelToJsonString(bart_model)
    y_hat_test <- bart_model$y_hat_test
@@ -267,11 +269,12 @@ First, we sample this model using the grow-from-root algorithm in the main R session
 for several iterations (we will use these forests to seed independent parallel chains 
 in a moment).
 
 ```{r}
-xbart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T, 
-                     num_trees_mean = num_trees)
+general_params <- list(sample_sigma2_global = T)
+mean_forest_params <- list(sample_sigma2_leaf = T, num_trees = num_trees)
 xbart_model <- stochtree::bart(
     X_train = X_train, W_train = W_train, y_train = y_train, X_test = X_test, W_test = W_test, 
-    num_gfr = num_gfr, num_burnin = 0, num_mcmc = 0, params = xbart_params
+    num_gfr = num_gfr, num_burnin = 0, num_mcmc = 0, 
+    general_params = general_params, mean_forest_params = mean_forest_params
 )
 xbart_model_string <- stochtree::saveBARTModelToJsonString(xbart_model)
 ```
@@ -303,9 +306,11 @@
 bart_model_outputs <- foreach (i = 1:num_chains) %dopar% {
     random_seed <- i
-    bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = T, num_trees_mean = num_trees, random_seed = random_seed)
+    general_params <- list(sample_sigma2_global = T, random_seed = random_seed)
+    mean_forest_params <- list(sample_sigma2_leaf = T, num_trees = num_trees)
     bart_model <- stochtree::bart(
         X_train = X_train, W_train = W_train, y_train = y_train, X_test = X_test, W_test = W_test, 
-        num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, params = bart_params, 
+        num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, 
+        general_params = general_params, mean_forest_params = mean_forest_params, 
        previous_model_json = xbart_model_string, warmstart_sample_num = num_gfr - i + 1, 
    )
    bart_model_string <- stochtree::saveBARTModelToJsonString(bart_model)
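Note that in the `@@ -303` hunk the stale `bart_params` assignment is now marked for removal alongside the renaming, since it referenced the retired parameter names and was left behind by the original change. The full pattern this vignette builds toward, one warm-started chain per parallel worker, each with its own `random_seed` in `general_params`, can be sketched end to end as follows. It assumes `num_trees`, `num_gfr`, `num_burnin`, `num_mcmc`, the training data, and `xbart_model_string` from the vignette are in scope:

```r
# Sketch: warm-started parallel chains, collected as serialized JSON strings.
library(foreach)
library(doParallel)

num_chains <- 4
cl <- makeCluster(num_chains)
registerDoParallel(cl)
bart_model_strings <- foreach(i = 1:num_chains, .packages = "stochtree") %dopar% {
    general_params <- list(sample_sigma2_global = TRUE, random_seed = i)
    mean_forest_params <- list(sample_sigma2_leaf = TRUE, num_trees = num_trees)
    bart_model <- stochtree::bart(
        X_train = X_train, W_train = W_train, y_train = y_train, 
        num_gfr = 0, num_burnin = num_burnin, num_mcmc = num_mcmc, 
        general_params = general_params, mean_forest_params = mean_forest_params, 
        previous_model_json = xbart_model_string, 
        warmstart_sample_num = num_gfr - i + 1  # each chain starts from a different GFR forest
    )
    stochtree::saveBARTModelToJsonString(bart_model)
}
stopCluster(cl)
```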
diff --git a/vignettes/PriorCalibration.Rmd b/vignettes/PriorCalibration.Rmd
index 5e30202b..87a09e46 100644
--- a/vignettes/PriorCalibration.Rmd
+++ b/vignettes/PriorCalibration.Rmd
@@ -87,10 +87,10 @@ lambda <- calibrate_inverse_gamma_error_variance(y_train, X_train, nu = nu)
 Now we run a BART model with this variance parameterization
 
 ```{r}
-bart_params <- list(a_global = nu/2, b_global = (nu*lambda)/2)
+general_params <- list(sigma2_global_shape = nu/2, sigma2_global_scale = (nu*lambda)/2)
 bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, 
                    num_gfr = 0, num_burnin = 1000, num_mcmc = 100, 
-                   params = bart_params)
+                   general_params = general_params)
 ```
 
 Inspect the out-of-sample predictions of the model
diff --git a/vignettes/TreeInspection.Rmd b/vignettes/TreeInspection.Rmd
index fd7b41d6..66140d67 100644
--- a/vignettes/TreeInspection.Rmd
+++ b/vignettes/TreeInspection.Rmd
@@ -66,11 +66,11 @@ Run BART.
 ```{r}
 num_gfr <- 10
 num_burnin <- 0
 num_mcmc <- 100
-bart_params <- list(keep_gfr = T)
+general_params <- list(keep_gfr = T)
 bart_model <- stochtree::bart(
     X_train = X_train, y_train = y_train, X_test = X_test, 
     num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
-    params = bart_params
+    general_params = general_params
 )
 ```
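The PriorCalibration hunk encodes the classical inverse-gamma calibration: a chosen degrees-of-freedom `nu` and a data-calibrated `lambda` map to an `IG(nu/2, (nu*lambda)/2)` prior on the global error variance, now passed through the renamed `sigma2_global_shape` and `sigma2_global_scale` entries. A self-contained sketch, with the data and the choice `nu <- 3` simulated and assumed purely for illustration:

```r
# Sketch: calibrate lambda from the data, map (nu, lambda) to the renamed
# IG shape/scale entries in general_params, then fit.
library(stochtree)

n <- 500
X_train <- matrix(runif(n * 5), ncol = 5)
y_train <- X_train[, 1] + rnorm(n)

nu <- 3   # illustrative choice of prior degrees of freedom
lambda <- calibrate_inverse_gamma_error_variance(y_train, X_train, nu = nu)

# IG(nu/2, (nu*lambda)/2) prior on the global error variance
general_params <- list(sigma2_global_shape = nu / 2, 
                       sigma2_global_scale = (nu * lambda) / 2)
bart_model <- bart(X_train = X_train, y_train = y_train, 
                   num_gfr = 0, num_burnin = 1000, num_mcmc = 100, 
                   general_params = general_params)
```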