Run the BART algorithm for supervised learning.
BART.Rd
Description
Run the BART algorithm for supervised learning.
Usage
bart(
X_train,
y_train,
W_train = NULL,
group_ids_train = NULL,
rfx_basis_train = NULL,
X_test = NULL,
W_test = NULL,
group_ids_test = NULL,
rfx_basis_test = NULL,
cutpoint_grid_size = 100,
tau_init = NULL,
alpha = 0.95,
beta = 2,
min_samples_leaf = 5,
leaf_model = 0,
nu = 3,
lambda = NULL,
a_leaf = 3,
b_leaf = NULL,
q = 0.9,
sigma2_init = NULL,
num_trees = 200,
num_gfr = 5,
num_burnin = 0,
num_mcmc = 100,
sample_sigma = T,
sample_tau = T,
random_seed = -1,
keep_burnin = F,
keep_gfr = F,
verbose = F
)
Arguments
- X_train
Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata that the column is ordered categorical).
- y_train
Outcome to be modeled by the ensemble.
- W_train
(Optional) Bases used to define a regression model y ~ W in each leaf of each regression tree. By default, BART assumes constant leaf node parameters, implicitly regressing on a constant basis of ones (i.e. y ~ 1).
- group_ids_train
(Optional) Group labels used for an additive random effects model.
- rfx_basis_train
(Optional) Basis for "random-slope" regression in an additive random effects model. If group_ids_train is provided without a regression basis, an intercept-only random effects model will be estimated.
- X_test
(Optional) Test set of covariates used to define "out of sample" evaluation data. May be provided either as a dataframe or a matrix, but the format of X_test must be consistent with that of X_train.
- W_test
(Optional) Test set of bases used to define "out of sample" evaluation data. While a test set is optional, the structure of any provided test set must match that of the training set (i.e. if both X_train and W_train are provided, then a test set must consist of X_test and W_test with the same number of columns).
- group_ids_test
(Optional) Test set group labels used for an additive random effects model. We do not currently support test set evaluation for group labels that were not in the training set (but plan to in the near future).
- rfx_basis_test
(Optional) Test set basis for "random-slope" regression in additive random effects model.
- cutpoint_grid_size
Maximum size of the "grid" of potential cutpoints to consider. Default: 100.
- tau_init
Starting value of the leaf node scale parameter. Calibrated internally as 1/num_trees if not set here.
- alpha
Prior probability of splitting for a tree of depth 0. The tree split prior combines alpha and beta via alpha*(1+node_depth)^-beta.
- beta
Exponent that decreases split probabilities for nodes of depth > 0. The tree split prior combines alpha and beta via alpha*(1+node_depth)^-beta.
- min_samples_leaf
Minimum allowable size of a leaf, in terms of training samples. Default: 5.
- leaf_model
Model to use in the leaves, coded as an integer (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.
- nu
Shape parameter in the IG(nu, nu*lambda) global error variance model. Default: 3.
- lambda
Component of the scale parameter in the IG(nu, nu*lambda) global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
- a_leaf
Shape parameter in the IG(a_leaf, b_leaf) leaf node parameter variance model. Default: 3.
- b_leaf
Scale parameter in the IG(a_leaf, b_leaf) leaf node parameter variance model. Calibrated internally as 0.5/num_trees if not set here.
- q
Quantile used to calibrate lambda as in Sparapani et al (2021). Default: 0.9.
- sigma2_init
Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.
- num_trees
Number of trees in the ensemble. Default: 200.
- num_gfr
Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
- num_burnin
Number of "burn-in" iterations of the MCMC sampler. Default: 0.
- num_mcmc
Number of "retained" iterations of the MCMC sampler. Default: 100.
- sample_sigma
Whether or not to update the sigma^2 global error variance parameter based on IG(nu, nu*lambda). Default: TRUE.
- sample_tau
Whether or not to update the tau leaf scale variance parameter based on IG(a_leaf, b_leaf). Cannot (currently) be set to TRUE if ncol(W_train) > 1. Default: TRUE.
- random_seed
Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to std::random_device.
- keep_burnin
Whether or not "burnin" samples should be included in cached predictions. Default: FALSE. Ignored if num_mcmc = 0.
- keep_gfr
Whether or not "grow-from-root" samples should be included in cached predictions. Default: FALSE. Ignored if num_mcmc = 0.
- verbose
Whether or not to print progress during the sampling loops. Default: FALSE.
- variable_weights
Vector of length ncol(X_train) indicating a "weight" placed on each variable for sampling purposes. Default: rep(1/ncol(X_train), ncol(X_train)).
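The depth-based split prior shared by alpha and beta can be checked numerically. A minimal sketch in base R (no package required), using the documented defaults alpha = 0.95 and beta = 2:

```r
# Prior probability that a node at a given depth splits:
# p(split) = alpha * (1 + node_depth)^(-beta)
alpha <- 0.95
beta <- 2
node_depth <- 0:4
split_prob <- alpha * (1 + node_depth)^(-beta)
round(split_prob, 3)
# A depth-0 node keeps the full prior probability (0.95); deeper nodes
# are increasingly unlikely to split, which keeps trees shallow.
```

Larger beta values make the decay steeper, so raising beta is one way to regularize toward shallower trees.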
Value
List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
Examples
n <- 100
p <- 5
X <- matrix(runif(n*p), ncol = p)
f_XW <- (
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
noise_sd <- 1
y <- f_XW + rnorm(n, 0, noise_sd)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = F))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
y_test <- y[test_inds]
y_train <- y[train_inds]
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test)
# plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual")
# abline(0,1,col="red",lty=3,lwd=3)
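The commented lines above summarize fit by averaging bart_model$y_hat_test (rows index test observations, columns index retained draws) across draws. A minimal sketch of that summary step, using stand-in draws rather than a fitted model:

```r
# Stand-in for bart_model$y_hat_test: 3 test points, 2 posterior draws
y_test_demo <- c(1.0, -2.0, 0.5)
draws <- cbind(y_test_demo + 0.1, y_test_demo - 0.1)

# Point prediction: posterior mean across retained draws
y_hat_mean <- rowMeans(draws)

# Held-out root mean squared error
rmse <- sqrt(mean((y_hat_mean - y_test_demo)^2))
print(rmse)  # 0 here, since the fake draws are symmetric around y_test_demo
```

With a real fit, replace `draws` with bart_model$y_hat_test and `y_test_demo` with y_test.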