Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,9 @@ bart <- function(
# Set sigma2_init to 1, ignoring default provided
sigma2_init <- 1.0
# Skip variance_forest_init, since variance forests are not supported with probit link
b_leaf <- 1 / (num_trees_mean)
if (is.null(b_leaf)) {
b_leaf <- 1 / (num_trees_mean)
}
if (has_basis) {
if (ncol(leaf_basis_train) > 1) {
if (is.null(sigma2_leaf_init)) {
Expand Down Expand Up @@ -1225,9 +1227,21 @@ bart <- function(
# Initialize the leaves of each tree in the mean forest
if (include_mean_forest) {
if (requires_basis) {
init_values_mean_forest <- rep(0., ncol(leaf_basis_train))
# Handle the case in which we must initialize root values in a leaf basis regression
# when init_val_mean != 0. To do this, we regress rep(init_val_mean, nrow(y_train))
# on leaf_basis_train and use (coefs / num_trees_mean) as initial values
if (abs(init_val_mean) > 0.00001) {
init_val_y <- rep(init_val_mean, nrow(y_train))
init_val_model <- lm(init_val_y ~ 0 + leaf_basis_train)
init_values_mean_forest <- coef(init_val_model)
if (any(is.na(init_values_mean_forest))) {
init_values_mean_forest[which(is.na(init_values_mean_forest))] <- 0.
}
} else {
init_values_mean_forest <- rep(init_val_mean, ncol(leaf_basis_train))
}
} else {
init_values_mean_forest <- 0.
init_values_mean_forest <- init_val_mean
}
active_forest_mean$prepare_for_sampler(
forest_dataset_train,
Expand All @@ -1236,13 +1250,6 @@ bart <- function(
leaf_model_mean_forest,
init_values_mean_forest
)
active_forest_mean$adjust_residual(
forest_dataset_train,
outcome_train,
forest_model_mean,
requires_basis,
FALSE
)
}

# Initialize the leaves of each tree in the variance forest
Expand Down
4 changes: 4 additions & 0 deletions notes/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# StochTree Notes

This is a directory of lightweight "developer" notes / docs, which may in time make it to the official documentation.

Loading
Loading