Supervised learning

High-level functionality for training supervised Bayesian tree ensembles (BART, XBART)

bart()
Run the BART algorithm for supervised learning.
predict(<bartmodel>)
Predict from a sampled BART model on new data

Causal inference

High-level functionality for estimating causal effects using Bayesian tree ensembles (BCF, XBCF)

bcf()
Run the Bayesian Causal Forest (BCF) algorithm for regularized causal effect estimation.
predict(<bcf>)
Predict from a sampled BCF model on new data
saveBCFModelToJsonFile()
Convert the persistent aspects of a BCF model to (in-memory) JSON and save to a file
createBCFModelFromJsonFile()
Convert a JSON file containing sample information on a trained BCF model to a BCF model object that can be used for prediction and other downstream tasks
convertBCFModelToJson()
Convert the persistent aspects of a BCF model to (in-memory) JSON
createBCFModelFromJson()
Convert an (in-memory) JSON representation of a BCF model to a BCF model object that can be used for prediction and other downstream tasks

Low-level functionality

Serialization

Classes and functions for converting sampling artifacts to JSON and saving to disk

CppJson
Class that stores draws from a random ensemble of decision trees
createCppJson()
Create a new (empty) C++ JSON object
loadForestContainerJson()
Load a container of forest samples from JSON
loadRandomEffectSamplesJson()
Load a container of random effect samples from JSON
loadVectorJson()
Load a vector from JSON
loadScalarJson()
Load a scalar from JSON
createCppJsonFile()
Create a C++ JSON object from a JSON file

Data

Classes and functions for preparing data for sampling algorithms

ForestDataset
Dataset used to sample a forest
createForestDataset()
Create a forest dataset object
Outcome
Outcome / partial residual used to sample an additive model.
createOutcome()
Create an outcome object
RandomEffectsDataset
Dataset used to sample a random effects model
createRandomEffectsDataset()
Create a random effects dataset object
preprocessTrainData()
Preprocess covariates. DataFrames will be preprocessed based on their column types. Matrices will be passed through assuming all columns are numeric.
preprocessPredictionData()
Preprocess covariates. DataFrames will be preprocessed based on their column types. Matrices will be passed through assuming all columns are numeric.
preprocessTrainDataFrame()
Preprocess a dataframe of covariate values, converting categorical variables to integers and one-hot encoding if need be. Returns a list including a matrix of preprocessed covariate values and associated tracking.
preprocessPredictionDataFrame()
Preprocess a dataframe of covariate values, converting categorical variables to integers and one-hot encoding if need be.
preprocessTrainMatrix()
Preprocess a matrix of covariate values, assuming all columns are numeric. Returns a list including a matrix of preprocessed covariate values and associated tracking.
preprocessPredictionMatrix()
Preprocess a matrix of covariate values, assuming all columns are numeric.
createForestCovariates()
Preprocess a dataframe of covariate values, converting categorical variables to integers and one-hot encoding if need be. Returns a list including a matrix of preprocessed covariate values and associated tracking.
createForestCovariatesFromMetadata()
Preprocess a dataframe of covariate values, converting categorical variables to integers and one-hot encoding if need be. Returns a list including a matrix of preprocessed covariate values and associated tracking.
oneHotEncode()
Convert a vector of unordered categorical data (either numeric or character labels) to a "one-hot" encoded matrix in which a 1 in a column indicates the presence of the relevant category.
oneHotInitializeAndEncode()
Convert a vector of unordered categorical data (either numeric or character labels) to a "one-hot" encoded matrix in which a 1 in a column indicates the presence of the relevant category.
orderedCatPreprocess()
Run some simple preprocessing of ordered categorical variables, converting ordered levels to integers if necessary, and storing the unique levels of a variable.
orderedCatInitializeAndPreprocess()
Run some simple preprocessing of ordered categorical variables, converting ordered levels to integers if necessary, and storing the unique levels of a variable.
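
The two encodings described in this section can be sketched in base R. This is an illustration of the idea only, not the package's implementation: `one_hot_sketch()` is a hypothetical helper, and its handling of unseen categories (an all-zero row) is an assumption that may differ from what `oneHotEncode()` does.

```r
# Hypothetical helper (not part of stochtree): one-hot encode an unordered
# categorical vector against a fixed set of levels.
one_hot_sketch <- function(x, lvls = sort(unique(x))) {
  # One column per known level, one row per observation
  out <- matrix(0L, nrow = length(x), ncol = length(lvls))
  colnames(out) <- as.character(lvls)
  # match() gives each observation's column (NA for categories not in lvls)
  idx <- match(x, lvls)
  seen <- !is.na(idx)
  # Assumption: unseen categories are left as all-zero rows
  out[cbind(which(seen), idx[seen])] <- 1L
  out
}

x_train <- c("b", "a", "c", "a")
encoded <- one_hot_sketch(x_train)

# Ordered categorical: store the unique levels and map each value to its rank
sizes <- factor(c("low", "high", "medium"),
                levels = c("low", "medium", "high"), ordered = TRUE)
size_codes <- as.integer(sizes)
size_levels <- levels(sizes)
```

Passing the training levels explicitly via `lvls` at prediction time keeps the column layout identical between training and prediction data, which is the role of the "associated tracking" returned by the training preprocessors.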

Forest

Classes and functions for constructing and persisting forests

ForestModel
Class that defines and samples a forest model
createForestModel()
Create a forest model object
ForestSamples
Class that stores draws from a random ensemble of decision trees
createForestContainer()
Create a container of forest samples
ForestKernel
Class that provides functionality for statistical kernel definition and computation based on shared leaf membership of observations in a tree ensemble.
createForestKernel()
Create a ForestKernel object
CppRNG
Class that wraps a C++ random number generator (for reproducibility)
createRNG()
Create an R class that wraps a C++ random number generator
calibrate_inverse_gamma_error_variance()
Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al. (2022)

Random Effects

Classes and functions for constructing and persisting random effects terms

RandomEffectSamples
Class that wraps the "persistent" aspects of a C++ random effects model: draws of the parameters and a map from the original label indices to the 0-indexed label numbers used to place group samples in memory (i.e. the first label is stored in column 0 of the sample matrix, the second label is stored in column 1 of the sample matrix, and so on).
createRandomEffectSamples()
Create a RandomEffectSamples object
RandomEffectsModel
The core "model" class for sampling random effects.
createRandomEffectsModel()
Create a RandomEffectsModel object
RandomEffectsTracker
Class that defines a "tracker" for random effects models, most notably storing the data indices available in each group for quicker posterior computation and sampling of random effects terms.
createRandomEffectsTracker()
Create a RandomEffectsTracker object
getRandomEffectSamples()
Generic function for extracting random effect samples from a model object (BCF, BART, etc...)
getRandomEffectSamples(<bartmodel>)
Extract raw sample values for each of the random effect parameter terms.
getRandomEffectSamples(<bcf>)
Extract raw sample values for each of the random effect parameter terms.
sample_sigma2_one_iteration()
Sample one iteration of the (inverse gamma) global variance model
sample_tau_one_iteration()
Sample one iteration of the leaf parameter variance model (only for univariate basis and constant leaf!)
computeForestKernels()
Compute a kernel from a tree ensemble, defined by the fraction of trees of an ensemble in which two observations fall into the same leaf.
computeForestLeafIndices()
Compute and return a vector representation of a forest's leaf predictions for every observation in a dataset. The vector has a "column-major" format that can be easily re-represented as a CSC sparse matrix: elements are organized so that the first n elements correspond to leaf predictions for all n observations in a dataset for the first tree in an ensemble, the next n elements correspond to predictions for the second tree, and so on. The "data" for each element is a uniquely mapped column index that corresponds to a single leaf of a single tree (i.e. if tree 1 has 3 leaves, its column indices range from 0 to 2, and tree 2's leaf indices begin at 3). Users may pass a single dataset (which we refer to here as a "training set") or two datasets (which we refer to as "training and test sets"). This terminology hints at one potential use case for a matrix of leaf indices: defining an ensemble-based kernel for kriging.
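
The leaf-index layout and the shared-leaf kernel described above can be illustrated with a small base R sketch. It does not call `computeForestLeafIndices()` or `computeForestKernels()`; the leaf assignments in `leaf_by_tree` are made up, and the sketch assumes every leaf of each tree is reached by at least one observation so that leaf counts can be read off the data.

```r
# Hypothetical per-tree leaf assignments: rows are observations, columns are
# trees. Entry [i, t] is the 0-based leaf that observation i reaches in tree t.
leaf_by_tree <- cbind(
  c(0, 0, 1, 2),  # tree 1 has 3 leaves
  c(0, 1, 1, 1)   # tree 2 has 2 leaves
)
n <- nrow(leaf_by_tree)
num_trees <- ncol(leaf_by_tree)

# Offset each tree's leaves so every leaf of every tree has a unique column
# index (assumes all leaves of a tree appear in the data)
leaves_per_tree <- apply(leaf_by_tree, 2, max) + 1
offsets <- c(0, cumsum(leaves_per_tree))[seq_len(num_trees)]

# Column-major vector: first n entries are tree 1, the next n are tree 2, ...
leaf_index_vec <- as.vector(sweep(leaf_by_tree, 2, offsets, "+"))

# Rebuild a dense n x (total leaves) indicator matrix from the vector
Z <- matrix(0, nrow = n, ncol = sum(leaves_per_tree))
Z[cbind(rep(seq_len(n), times = num_trees), leaf_index_vec + 1)] <- 1

# Kernel: fraction of trees in which two observations fall into the same leaf
K <- tcrossprod(Z) / num_trees
```

A dense matrix is used here only to keep the sketch dependency-free; for realistic ensemble sizes the indicator matrix is very sparse, which is why the text points to a CSC representation.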

Package info

High-level package details

stochtree stochtree-package
stochtree: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference