
Supervised learning

High-level functionality for training supervised Bayesian tree ensembles (BART, XBART)

bart()
Run the BART algorithm for supervised learning.
predict(<bartmodel>)
Predict from a sampled BART model on new data

Causal inference

High-level functionality for estimating causal effects using Bayesian tree ensembles (BCF, XBCF)

bcf()
Run the Bayesian Causal Forest (BCF) algorithm for regularized causal effect estimation.
predict(<bcf>)
Predict from a sampled BCF model on new data
saveBCFModelToJsonFile()
Convert the persistent aspects of a BCF model to (in-memory) JSON and save to a file
createBCFModelFromJsonFile()
Convert a JSON file containing sample information on a trained BCF model to a BCF model object that can be used for prediction, etc.
createBCFModelFromJsonString()
Convert a JSON string containing sample information on a trained BCF model to a BCF model object that can be used for prediction, etc.
convertBCFModelToJson()
Convert the persistent aspects of a BCF model to (in-memory) JSON
createBCFModelFromJson()
Convert an (in-memory) JSON representation of a BCF model to a BCF model object that can be used for prediction, etc.

Low-level functionality

Serialization

Classes and functions for converting sampling artifacts to JSON and saving to disk

CppJson
Class that wraps a C++ JSON object used to serialize and persist sampling artifacts
createCppJson()
Create a new (empty) C++ JSON object
createCppJsonFile()
Create a C++ JSON object from a JSON file
createCppJsonString()
Create a C++ JSON object from a JSON string
loadForestContainerJson()
Load a container of forest samples from JSON
loadForestContainerCombinedJson()
Combine multiple JSON model objects containing forests (with the same hierarchy / schema) into a single forest_container
loadForestContainerCombinedJsonString()
Combine multiple JSON strings representing model objects containing forests (with the same hierarchy / schema) into a single forest_container
loadRandomEffectSamplesJson()
Load a container of random effect samples from JSON
loadVectorJson()
Load a vector from JSON
loadScalarJson()
Load a scalar from JSON
convertBARTModelToJson()
Convert the persistent aspects of a BART model to (in-memory) JSON
createBARTModelFromCombinedJson()
Convert a list of (in-memory) JSON representations of BART models to a single combined BART model object that can be used for prediction, etc.
createBARTModelFromCombinedJsonString()
Convert a list of JSON strings that represent BART models to a single combined BART model object that can be used for prediction, etc.
createBARTModelFromJson()
Convert an (in-memory) JSON representation of a BART model to a BART model object that can be used for prediction, etc.
createBARTModelFromJsonFile()
Convert a JSON file containing sample information on a trained BART model to a BART model object that can be used for prediction, etc.
createBARTModelFromJsonString()
Convert a JSON string containing sample information on a trained BART model to a BART model object that can be used for prediction, etc.
loadRandomEffectSamplesCombinedJson()
Combine multiple JSON model objects containing random effects (with the same hierarchy / schema) into a single container
loadRandomEffectSamplesCombinedJsonString()
Combine multiple JSON strings representing model objects containing random effects (with the same hierarchy / schema) into a single container
saveBARTModelToJsonFile()
Convert the persistent aspects of a BART model to (in-memory) JSON and save to a file
saveBARTModelToJsonString()
Convert the persistent aspects of a BART model to an (in-memory) JSON string
saveBCFModelToJsonString()
Convert the persistent aspects of a BCF model to an (in-memory) JSON string
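Several of the loaders above combine multiple JSON documents "with the same hierarchy / schema" into a single container, for instance samples from separate runs. The following Python sketch illustrates that idea using only the standard `json` module. The document layout (`num_trees`, `output_dimension`, `samples`) and the function `combine_forest_samples` are hypothetical. They are not stochtree's actual JSON format or API.

```python
import json

# Illustrative only: the layout {"num_trees", "output_dimension", "samples"}
# is a hypothetical stand-in, not stochtree's real JSON format.


def combine_forest_samples(json_strings):
    """Merge per-sample entries from several JSON documents into one.

    Every document must agree on all fields other than "samples"
    (i.e. share the same hierarchy / schema); the "samples" lists are
    concatenated in the order the documents are given.
    """
    docs = [json.loads(s) for s in json_strings]
    if not docs:
        raise ValueError("need at least one JSON document")
    # Everything except the samples defines the schema to be matched.
    reference = {k: v for k, v in docs[0].items() if k != "samples"}
    combined = []
    for i, doc in enumerate(docs):
        header = {k: v for k, v in doc.items() if k != "samples"}
        if header != reference:
            raise ValueError(f"document {i} does not match the schema of document 0")
        combined.extend(doc["samples"])
    return {**reference, "samples": combined}
```

Suppose one document holds samples `["s0", "s1"]` and another holds `["s2"]`, and both have the same `num_trees` and `output_dimension`. Merging them yields a single document whose `samples` list is `["s0", "s1", "s2"]`. A mismatched `num_trees` raises `ValueError` instead.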

Data

Classes and functions for preparing data for sampling algorithms

ForestDataset
Dataset used to sample a forest
createForestDataset()
Create a forest dataset object
Outcome
Outcome / partial residual used to sample an additive model.
createOutcome()
Create an outcome object
RandomEffectsDataset
Dataset used to sample a random effects model
createRandomEffectsDataset()
Create a random effects dataset object
preprocessTrainData()
Preprocess covariates for training. Data frames are preprocessed based on their column types; matrices are passed through on the assumption that all columns are numeric.
preprocessPredictionData()
Preprocess covariates for prediction. Data frames are preprocessed based on their column types; matrices are passed through on the assumption that all columns are numeric.
preprocessTrainDataFrame()
Preprocess a data frame of covariate values, converting categorical variables to integers and one-hot encoding them where needed. Returns a list containing a matrix of preprocessed covariate values and the associated tracking metadata.
preprocessPredictionDataFrame()
Preprocess a data frame of covariate values for prediction, converting categorical variables to integers and one-hot encoding them where needed.
preprocessTrainMatrix()
Preprocess a matrix of covariate values, assuming all columns are numeric. Returns a list containing a matrix of preprocessed covariate values and the associated tracking metadata.
preprocessPredictionMatrix()
Preprocess a matrix of covariate values for prediction, assuming all columns are numeric.
createForestCovariates()
Preprocess a data frame of covariate values, converting categorical variables to integers and one-hot encoding them where needed. Returns a list containing a matrix of preprocessed covariate values and the associated tracking metadata.
createForestCovariatesFromMetadata()
Preprocess a data frame of covariate values using previously computed preprocessing metadata, converting categorical variables to integers and one-hot encoding them where needed. Returns a list containing a matrix of preprocessed covariate values and the associated tracking metadata.
oneHotEncode()
Convert a vector of unordered categorical data (either numeric or character labels) to a "one-hot" encoded matrix, using a previously stored set of category levels. A 1 in a column indicates the presence of the corresponding category.
oneHotInitializeAndEncode()
Determine the unique levels of a vector of unordered categorical data (either numeric or character labels) and convert it to a "one-hot" encoded matrix in which a 1 in a column indicates the presence of the corresponding category.
orderedCatPreprocess()
Run simple preprocessing of an ordered categorical variable, converting its ordered levels to integers if necessary, based on a previously stored set of unique levels.
orderedCatInitializeAndPreprocess()
Run simple preprocessing of an ordered categorical variable, converting its ordered levels to integers if necessary and storing the unique levels of the variable.
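The "initialize" and "apply" pairs above split categorical preprocessing into two steps. The first step learns the category levels from training data. The second step encodes new data against those stored levels, so that prediction-time covariates get the same columns or codes. The following Python sketch mirrors that split. The snake_case names are hypothetical analogues, not stochtree functions. The sketch also makes three assumptions of its own:

- Unseen categories map to an all-zero row.
- Ordered categories receive 1-based integer codes.
- Levels are sorted, so the labels must be mutually comparable.

```python
# Hypothetical Python analogues of oneHotInitializeAndEncode() / oneHotEncode()
# and orderedCatInitializeAndPreprocess() / orderedCatPreprocess();
# these are not stochtree functions.


def one_hot_encode(values, levels):
    """One-hot encode `values` against previously stored `levels`.

    Row i has a 1 in column j when values[i] == levels[j]. A value that was
    not seen at initialization gets an all-zero row (an assumption here).
    """
    column = {level: j for j, level in enumerate(levels)}
    rows = []
    for v in values:
        row = [0] * len(levels)
        if v in column:
            row[column[v]] = 1
        rows.append(row)
    return rows


def one_hot_initialize_and_encode(values):
    """Record the sorted unique levels of `values`, then one-hot encode them."""
    levels = sorted(set(values))
    return one_hot_encode(values, levels), levels


def ordered_cat_preprocess(values, levels):
    """Map ordered labels to integer codes 1..k using stored `levels`."""
    code = {level: i + 1 for i, level in enumerate(levels)}
    return [code[v] for v in values]


def ordered_cat_initialize_and_preprocess(values, order):
    """Keep the levels of `order` (lowest to highest) that occur in `values`,
    then convert `values` to integer codes."""
    present = set(values)
    levels = [level for level in order if level in present]
    return ordered_cat_preprocess(values, levels), levels
```

For example, `one_hot_initialize_and_encode(["b", "a", "c", "a"])` returns the levels `["a", "b", "c"]`. Calling `one_hot_encode(["c", "z"], levels)` later reuses those three columns, and the unseen `"z"` becomes a row of zeros.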

Forest

Classes and functions for constructing and persisting forests

ForestModel
Class that defines and samples a forest model
createForestModel()
Create a forest model object
ForestSamples
Class that stores draws from a random ensemble of decision trees
createForestContainer()
Create a container of forest samples
CppRNG
Class that wraps a C++ random number generator (for reproducibility)
createRNG()
Create an R class that wraps a C++ random number generator
calibrate_inverse_gamma_error_variance()
Calibrate the scale parameter of an inverse gamma prior on the global error variance, as in Chipman et al. (2022)
preprocessBartParams()
Preprocess BART parameter list. Override defaults with any provided parameters.
preprocessBcfParams()
Preprocess BCF parameter list. Override defaults with any provided parameters.
computeForestLeafIndices()
Compute and return a vector representation of the leaf membership of every observation in a dataset, for each tree in a forest. The vector has a "column-major" format that can easily be re-represented as a CSC sparse matrix. The first n elements give the leaves of all n observations in the first tree of the ensemble, the next n elements give the leaves for the second tree, and so on. The value of each element is a uniquely mapped column index that corresponds to a single leaf of a single tree. For example, if tree 1 has 3 leaves, its column indices range from 0 to 2 and tree 2's leaf indices begin at 3.
computeMaxLeafIndex()
Compute and return the largest possible leaf index computable by computeForestLeafIndices() for the forests in a designated forest sample container.
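The column-major layout described for computeForestLeafIndices() can be illustrated with the following Python sketch. It takes per-tree leaf assignments and builds the flattened vector, then converts that vector into CSC arrays for the n-by-(total leaves) indicator matrix W, with W[i, col] = 1. The input shape and the function names `forest_leaf_indices` and `to_csc` are hypothetical. The sketch does not reproduce stochtree's internals. Because indices are 0-based, the number of columns would be one more than the largest possible leaf index (compare computeMaxLeafIndex()).

```python
def forest_leaf_indices(leaf_per_tree, leaves_per_tree):
    """Flatten per-tree leaf assignments into the column-major vector.

    leaf_per_tree[t][i] is the 0-based leaf of observation i within tree t;
    leaves_per_tree[t] is the number of leaves in tree t. Tree t's leaves are
    offset by the total leaf count of trees 0..t-1, so every leaf of every
    tree gets a distinct global column index.
    """
    out, offset = [], 0
    for local, num_leaves in zip(leaf_per_tree, leaves_per_tree):
        out.extend(offset + leaf for leaf in local)
        offset += num_leaves
    return out


def to_csc(column_indices, n, num_columns):
    """Build the n x num_columns indicator matrix W (W[i, col] = 1) in CSC form.

    Element k of the vector belongs to observation k % n, because each
    block of n consecutive elements covers all observations for one tree.
    Returns (data, indices, indptr) as used by CSC storage.
    """
    rows_by_col = [[] for _ in range(num_columns)]
    for k, col in enumerate(column_indices):
        rows_by_col[col].append(k % n)
    indices, indptr = [], [0]
    for rows in rows_by_col:
        indices.extend(sorted(rows))
        indptr.append(len(indices))
    data = [1.0] * len(indices)
    return data, indices, indptr
```

Consider two trees over n = 3 observations, where the first tree has 3 leaves and the second has 2. Leaf assignments `[[0, 2, 1], [1, 0, 1]]` flatten to `[0, 2, 1, 4, 3, 4]`. Each row of W then contains exactly one 1 per tree. If SciPy is available, `scipy.sparse.csc_matrix((data, indices, indptr), shape=(n, num_columns))` wraps the same arrays.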

Random Effects

Classes and functions for constructing and persisting random effects terms

RandomEffectSamples
Class that wraps the "persistent" aspects of a C++ random effects model: draws of the parameters, and a map from the original group labels to the 0-indexed label numbers used to place group samples in memory (i.e. the first label is stored in column 0 of the sample matrix, the second label is stored in column 1, etc.)
createRandomEffectSamples()
Create a RandomEffectSamples object
RandomEffectsModel
The core "model" class for sampling random effects.
createRandomEffectsModel()
Create a RandomEffectsModel object
RandomEffectsTracker
Class that defines a "tracker" for random effects models, most notably storing the data indices available in each group for quicker posterior computation and sampling of random effects terms.
createRandomEffectsTracker()
Create a RandomEffectsTracker object
getRandomEffectSamples()
Generic function for extracting random effect samples from a model object (BCF, BART, etc.)
getRandomEffectSamples(<bartmodel>)
Extract raw sample values for each of the random effect parameter terms.
getRandomEffectSamples(<bcf>)
Extract raw sample values for each of the random effect parameter terms.
sample_sigma2_one_iteration()
Sample one iteration of the (inverse gamma) global variance model
sample_tau_one_iteration()
Sample one iteration of the leaf parameter variance model (only for univariate basis and constant leaf!)
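calibrate_inverse_gamma_error_variance() (listed under Forest), sample_sigma2_one_iteration() and sample_tau_one_iteration() all work with an inverse gamma prior on a variance. The following Python sketch assumes the BART parameterization of Chipman et al., sigma^2 ~ IG(nu/2, nu*lambda/2). Under this prior, nu*lambda/sigma^2 follows a chi-squared distribution with nu degrees of freedom. Requiring P(sigma^2 < sigma2_hat) = q then gives lambda = sigma2_hat * F^{-1}(1 - q) / nu, where F is the chi-squared CDF. Each Gibbs step is the standard conjugate update IG(a + n/2, b + sum(v^2)/2). All function names are hypothetical, and stochtree's exact parameterization and defaults may differ.

```python
import math
import random

# Illustrative sketch, not stochtree's implementation. Assumes the BART-style
# parameterization sigma^2 ~ IG(nu/2, nu*lambda/2), i.e. nu*lambda/sigma^2 ~ chi^2_nu.


def chi2_cdf(x, nu):
    """Chi-squared CDF via the series for the regularized lower incomplete gamma."""
    if x <= 0:
        return 0.0
    a, z = nu / 2.0, x / 2.0
    term = total = 1.0 / a
    n = 0
    while term > 1e-15 * total:
        n += 1
        term *= z / (a + n)
        total += term
    return math.exp(a * math.log(z) - z - math.lgamma(a)) * total


def chi2_quantile(p, nu):
    """Invert chi2_cdf by bisection."""
    lo, hi = 0.0, max(1.0, float(nu))
    while chi2_cdf(hi, nu) < p:
        hi *= 2.0
    for _ in range(200):
        mid = 0.5 * (lo + hi)
        if chi2_cdf(mid, nu) < p:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def calibrate_lambda(sigma2_hat, nu=3.0, q=0.9):
    """Pick lambda so that the prior puts probability q on sigma^2 < sigma2_hat."""
    # P(sigma^2 < s) = P(chi^2_nu > nu*lambda/s) = q  =>  nu*lambda/s = F^{-1}(1 - q)
    return sigma2_hat * chi2_quantile(1.0 - q, nu) / nu


def sample_variance_one_iteration(values, shape, scale, rng):
    """One conjugate Gibbs draw for a variance with an IG(shape, scale) prior.

    `values` are zero-mean Gaussian quantities with that variance: residuals for
    the global error variance, or leaf parameters for the leaf scale (tau).
    """
    post_shape = shape + len(values) / 2.0
    post_scale = scale + sum(v * v for v in values) / 2.0
    # If G ~ Gamma(post_shape, rate=post_scale) then 1/G ~ IG(post_shape, post_scale);
    # random.gammavariate takes a scale argument, hence 1 / post_scale.
    return 1.0 / rng.gammavariate(post_shape, 1.0 / post_scale)
```

In this sketch, sigma2_hat would be a rough data-based estimate of the error variance. The prior passed to the sampler would then be `shape = nu / 2` and `scale = nu * lambda / 2`. The values nu = 3 and q = 0.9 used as defaults here are the conventional BART choices.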

Package info

High-level package details

stochtree stochtree-package
stochtree: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference