
Supervised learning

High-level functionality for training supervised Bayesian tree ensembles (BART, XBART)

bart()
Run the BART algorithm for supervised learning.
predict(<bartmodel>)
Predict from a sampled BART model on new data
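
A minimal sketch of how the two entries above might be used together, assuming the stochtree package is installed. The argument names (X_train, y_train, num_gfr, num_burnin, num_mcmc) are assumptions that have varied across releases, so check ?bart for the installed version. The simulated data and the guard around the package calls are illustrative only.

```r
# Simulate a small regression problem (illustrative data, not from the package)
set.seed(1)
n <- 200
p <- 5
X <- matrix(runif(n * p), ncol = p)
y <- 5 * X[, 1] + rnorm(n)

# Hold out the last 50 rows for prediction
train_idx <- seq_len(150)
X_train <- X[train_idx, ]
y_train <- y[train_idx]
X_test <- X[-train_idx, ]

# Only run the sampler if stochtree is available
if (requireNamespace("stochtree", quietly = TRUE)) {
  # Assumed argument names; consult ?stochtree::bart for the installed version
  bart_model <- stochtree::bart(
    X_train = X_train, y_train = y_train,
    num_gfr = 10, num_burnin = 0, num_mcmc = 100
  )
  # predict() dispatches to the bartmodel method; new covariates are passed positionally
  preds <- predict(bart_model, X_test)
  # Inspect the top-level structure of the returned predictions
  str(preds, max.level = 1)
}
```

The covariates are passed to predict() by position because the name of that argument may differ between versions.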

Causal inference

High-level functionality for estimating causal effects using Bayesian tree ensembles (BCF, XBCF)

bcf()
Run the Bayesian Causal Forest (BCF) algorithm for regularized causal effect estimation.
predict(<bcf>)
Predict from a sampled BCF model on new data
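
A minimal sketch of a bcf() call on simulated data, assuming the stochtree package is installed and that covariates, treatment and outcome are the first three arguments, in that order. The data-generating process and the num_gfr, num_burnin and num_mcmc settings are illustrative assumptions; see ?bcf for the arguments supported by the installed version.

```r
# Simulate observational data with a heterogeneous treatment effect (illustrative only)
set.seed(2)
n <- 300
p <- 4
X <- matrix(runif(n * p), ncol = p)
pi_x <- 0.25 + 0.5 * X[, 1]        # true propensity score, between 0.25 and 0.75
Z <- rbinom(n, 1, pi_x)            # binary treatment indicator
tau_x <- 1 + 2 * X[, 2]            # treatment effect varies with the second covariate
y <- 3 * X[, 1] + tau_x * Z + rnorm(n)

# Only run the sampler if stochtree is available
if (requireNamespace("stochtree", quietly = TRUE)) {
  # Covariates, treatment and outcome are passed positionally (assumed order)
  bcf_model <- stochtree::bcf(X, Z, y, num_gfr = 10, num_burnin = 0, num_mcmc = 100)
  print(class(bcf_model))
}
```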

Low-level functionality

Serialization

Classes and functions for converting sampling artifacts to JSON and saving to disk

CppJson
Class that wraps a C++ JSON object
createCppJson()
Create a new (empty) C++ Json object
createCppJsonFile()
Create a C++ Json object from a Json file
createCppJsonString()
Create a C++ Json object from a Json string
loadForestContainerJson()
Load a container of forest samples from JSON
loadForestContainerCombinedJson()
Combine multiple JSON model objects containing forests (with the same hierarchy / schema) into a single forest_container
loadForestContainerCombinedJsonString()
Combine multiple JSON strings representing model objects containing forests (with the same hierarchy / schema) into a single forest_container
loadRandomEffectSamplesJson()
Load a container of random effect samples from JSON
loadVectorJson()
Load a vector from JSON
loadScalarJson()
Load a scalar from JSON
convertBARTModelToJson()
Convert the persistent aspects of a BART model to (in-memory) JSON
convertBARTStateToJson()
Convert in-memory BART model objects (forests, random effects, vectors) to in-memory JSON. This function is primarily a convenience function for serialization / deserialization in a parallel BART sampler.
createBARTModelFromCombinedJson()
Convert a list of (in-memory) JSON representations of a BART model to a single combined BART model object which can be used for prediction, etc.
createBARTModelFromCombinedJsonString()
Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object which can be used for prediction, etc.
createBARTModelFromJson()
Convert an (in-memory) JSON representation of a BART model to a BART model object which can be used for prediction, etc.
createBARTModelFromJsonFile()
Convert a JSON file containing sample information on a trained BART model to a BART model object which can be used for prediction, etc.
createBARTModelFromJsonString()
Convert a JSON string containing sample information on a trained BART model to a BART model object which can be used for prediction, etc.
loadRandomEffectSamplesCombinedJson()
Combine multiple JSON model objects containing random effects (with the same hierarchy / schema) into a single container
loadRandomEffectSamplesCombinedJsonString()
Combine multiple JSON strings representing model objects containing random effects (with the same hierarchy / schema) into a single container
saveBARTModelToJsonFile()
Convert the persistent aspects of a BART model to (in-memory) JSON and save to a file
saveBARTModelToJsonString()
Convert the persistent aspects of a BART model to an (in-memory) JSON string
saveBCFModelToJsonFile()
Convert the persistent aspects of a BCF model to (in-memory) JSON and save to a file
saveBCFModelToJsonString()
Convert the persistent aspects of a BCF model to an (in-memory) JSON string
createBCFModelFromJsonFile()
Convert a JSON file containing sample information on a trained BCF model to a BCF model object which can be used for prediction, etc.
createBCFModelFromJsonString()
Convert a JSON string containing sample information on a trained BCF model to a BCF model object which can be used for prediction, etc.
convertBCFModelToJson()
Convert the persistent aspects of a BCF model to (in-memory) JSON
createBCFModelFromJson()
Convert an (in-memory) JSON representation of a BCF model to a BCF model object which can be used for prediction, etc.
createBCFModelFromCombinedJsonString()
Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object which can be used for prediction, etc.
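
A minimal sketch of a string-based round trip using saveBARTModelToJsonString() and createBARTModelFromJsonString() from the list above, assuming the stochtree package is installed and that each function takes the model (or the JSON string) as its first argument. The simulated data and the bart() argument names are illustrative assumptions.

```r
# Simulate a small training set (illustrative data, not from the package)
set.seed(3)
X <- matrix(runif(100 * 3), ncol = 3)
y <- 2 * X[, 1] + rnorm(100)

# Only run if stochtree is available
if (requireNamespace("stochtree", quietly = TRUE)) {
  # Fit a small model; argument names are assumptions, see ?stochtree::bart
  bart_model <- stochtree::bart(
    X_train = X, y_train = y,
    num_gfr = 5, num_burnin = 0, num_mcmc = 20
  )

  # Serialize the persistent parts of the model to an in-memory JSON string
  json_string <- stochtree::saveBARTModelToJsonString(bart_model)

  # Rebuild a model object from that string and use it for prediction
  restored_model <- stochtree::createBARTModelFromJsonString(json_string)
  preds_restored <- predict(restored_model, X)
  str(preds_restored, max.level = 1)
}
```

The file-based pair, saveBARTModelToJsonFile() and createBARTModelFromJsonFile(), is listed for the same purpose when the model needs to persist across sessions.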

Data

Classes and functions for preparing data for sampling algorithms

ForestDataset
Dataset used to sample a forest
createForestDataset()
Create a forest dataset object
Outcome
Outcome / partial residual used to sample an additive model.
createOutcome()
Create an outcome object
RandomEffectsDataset
Dataset used to sample a random effects model
createRandomEffectsDataset()
Create a random effects dataset object
preprocessTrainData()
Preprocess covariates. DataFrames will be preprocessed based on their column types. Matrices will be passed through assuming all columns are numeric.
preprocessPredictionData()
Preprocess covariates. DataFrames will be preprocessed based on their column types. Matrices will be passed through assuming all columns are numeric.
preprocessTrainDataFrame()
Preprocess a dataframe of covariate values, converting categorical variables to integers and one-hot encoding if need be. Returns a list including a matrix of preprocessed covariate values and associated tracking.
preprocessPredictionDataFrame()
Preprocess a dataframe of covariate values, converting categorical variables to integers and one-hot encoding if need be.
preprocessTrainMatrix()
Preprocess a matrix of covariate values, assuming all columns are numeric. Returns a list including a matrix of preprocessed covariate values and associated tracking.
preprocessPredictionMatrix()
Preprocess a matrix of covariate values, assuming all columns are numeric.
createForestCovariates()
Preprocess a dataframe of covariate values, converting categorical variables to integers and one-hot encoding if need be. Returns a list including a matrix of preprocessed covariate values and associated tracking.
createForestCovariatesFromMetadata()
Preprocess a dataframe of covariate values, converting categorical variables to integers and one-hot encoding if need be. Returns a list including a matrix of preprocessed covariate values and associated tracking.
oneHotEncode()
Convert a vector of unordered categorical data (either numeric or character labels) to a "one-hot" encoded matrix in which a 1 in a column indicates the presence of the relevant category.
oneHotInitializeAndEncode()
Convert a vector of unordered categorical data (either numeric or character labels) to a "one-hot" encoded matrix in which a 1 in a column indicates the presence of the relevant category.
orderedCatPreprocess()
Run some simple preprocessing of ordered categorical variables, converting ordered levels to integers if necessary, and storing the unique levels of a variable.
orderedCatInitializeAndPreprocess()
Run some simple preprocessing of ordered categorical variables, converting ordered levels to integers if necessary, and storing the unique levels of a variable.
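
To make the "one-hot" description above concrete, here is a base R sketch of the encoding itself. The helper one_hot_sketch() is hypothetical and is not the package's oneHotEncode(); it only illustrates the idea that each category gets its own column and a 1 marks the category an observation belongs to.

```r
# Hypothetical helper illustrating one-hot encoding; not the package's oneHotEncode()
one_hot_sketch <- function(x, levels = sort(unique(x))) {
  # Compare every observation with every level; multiplying by 1 turns TRUE/FALSE into 1/0
  encoded <- outer(x, levels, FUN = "==") * 1
  colnames(encoded) <- as.character(levels)
  encoded
}

# Encode a small vector of unordered categorical labels
x <- c("b", "a", "c", "a")
encoded_train <- one_hot_sketch(x)
print(encoded_train)

# Reuse the training levels for new data so the columns line up;
# in this sketch a category unseen during training gets a row of zeros
encoded_new <- one_hot_sketch(c("a", "z"), levels = sort(unique(x)))
print(encoded_new)
```

Passing the training levels explicitly mirrors the split between the "train" and "prediction" preprocessing functions listed above, where the prediction step has to reuse what was learned from the training data.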

Forest

Classes and functions for constructing and persisting forests

Forest
Class that stores a single ensemble of decision trees (often treated as the "active forest")
createForest()
Create a forest
ForestModel
Class that defines and samples a forest model
createForestModel()
Create a forest model object
ForestSamples
Class that stores draws from a random ensemble of decision trees
createForestContainer()
Create a container of forest samples
CppRNG
Class that wraps a C++ random number generator (for reproducibility)
createRNG()
Create an R class that wraps a C++ random number generator
calibrate_inverse_gamma_error_variance()
Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al. (2022)
preprocessBartParams()
Preprocess BART parameter list. Override defaults with any provided parameters.
preprocessBcfParams()
Preprocess BCF parameter list. Override defaults with any provided parameters.
computeMaxLeafIndex()
Compute and return the largest possible leaf index computable by computeForestLeafIndices for the forests in a designated forest sample container.
computeForestLeafIndices()
Compute vector of forest leaf indices
computeForestLeafVariances()
Compute vector of forest leaf scale parameters
resetActiveForest()
Re-initialize an active forest from a specific forest in a ForestContainer
resetForestModel()
Re-initialize a forest model (tracking data structures) from a specific forest in a ForestContainer
rootResetActiveForest()
Reset an active forest to an ensemble of single-node (i.e. root) trees

Random Effects

Classes and functions for constructing and persisting random effects terms

RandomEffectSamples
Class that wraps the "persistent" aspects of a C++ random effects model: draws of the parameters and a map from the original label indices to the 0-indexed label numbers used to place group samples in memory (i.e. the first label is stored in column 0 of the sample matrix, the second label is stored in column 1, etc.)
createRandomEffectSamples()
Create a RandomEffectSamples object
RandomEffectsModel
The core "model" class for sampling random effects.
createRandomEffectsModel()
Create a RandomEffectsModel object
RandomEffectsTracker
Class that defines a "tracker" for random effects models, most notably storing the data indices available in each group for quicker posterior computation and sampling of random effects terms.
createRandomEffectsTracker()
Create a RandomEffectsTracker object
getRandomEffectSamples()
Generic function for extracting random effect samples from a model object (BCF, BART, etc.)
getRandomEffectSamples(<bartmodel>)
Extract raw sample values for each of the random effect parameter terms.
getRandomEffectSamples(<bcf>)
Extract raw sample values for each of the random effect parameter terms.
sample_sigma2_one_iteration()
Sample one iteration of the (inverse gamma) global variance model
sample_tau_one_iteration()
Sample one iteration of the leaf parameter variance model (only for univariate basis and constant leaf!)
resetRandomEffectsModel()
Reset a RandomEffectsModel object based on the parameters indexed by sample_num in a RandomEffectSamples object
resetRandomEffectsTracker()
Reset a RandomEffectsTracker object based on the parameters indexed by sample_num in a RandomEffectSamples object
rootResetRandomEffectsModel()
Reset a RandomEffectsModel object to its "default" state
rootResetRandomEffectsTracker()
Reset a RandomEffectsTracker object to its "default" state
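
The sample_sigma2_one_iteration() entry above refers to an inverse gamma global variance model. As a rough illustration of what one such draw can look like, the base R sketch below applies the standard conjugate update for a zero-mean Gaussian residual with an inverse gamma prior. The helper draw_sigma2_sketch() and its prior parameters a and b are hypothetical and are not taken from the package, whose internals may differ.

```r
# Hypothetical helper illustrating a conjugate inverse gamma update;
# not the package's sample_sigma2_one_iteration()
draw_sigma2_sketch <- function(residual, a = 1, b = 1) {
  n <- length(residual)
  # Prior: sigma^2 ~ InvGamma(shape = a, scale = b)
  # Posterior: sigma^2 | residual ~ InvGamma(a + n / 2, b + sum(residual^2) / 2)
  shape_post <- a + n / 2
  scale_post <- b + sum(residual^2) / 2
  # If G ~ Gamma(shape, rate = scale_post), then 1 / G ~ InvGamma(shape, scale_post)
  1 / rgamma(1, shape = shape_post, rate = scale_post)
}

# Residuals simulated with a true variance of 4 (illustrative only)
set.seed(4)
residual <- rnorm(500, sd = 2)

# Repeated draws from the posterior of sigma^2 given these residuals
draws <- replicate(1000, draw_sigma2_sketch(residual))
print(mean(draws))
```

In a full sampler the residual would be recomputed after each forest and random effects update, so a single variance draw would be taken per iteration rather than many draws from a fixed residual.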

Package info

High-level package details

stochtree stochtree-package
stochtree: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference