
Compute and return a vector representation of a forest's leaf indices (the leaf of each tree into which each observation falls) for every observation in a dataset. The vector has a "column-major" format that can easily be re-represented as a CSC sparse matrix: the first n elements give the leaf indices of all n observations for the first tree in the ensemble, the next n elements give the leaf indices for the second tree, and so on. The value of each element is a uniquely mapped column index that identifies a single leaf of a single tree (for example, if tree 1 has 3 leaves, its column indices range from 0 to 2, and tree 2's leaf indices begin at 3).

Users may pass a single dataset (referred to here as a "training set") or two datasets ("training and test sets"). This terminology reflects one potential use case for a matrix of leaf indices: defining an ensemble-based kernel for kriging.
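To make the layout concrete, the sketch below expands a small, hand-written leaf-index vector into a dense n x (total leaves) 0/1 indicator matrix. The vector values are made up for illustration (they are not output of computeForestLeafIndices), and the helper `leaf_vector_to_indicator` is hypothetical, not part of the package. It assumes n = 3 observations and two trees, where tree 1 owns columns 0 to 2 and tree 2 owns columns 3 to 4.

```r
# Hypothetical helper (not part of the package): expand a column-major
# leaf-index vector into a dense n x num_leaves 0/1 indicator matrix.
#   leaf_vec:   numeric vector of length n * num_trees holding 0-based leaf columns
#   n:          number of observations
#   num_leaves: total number of leaves across all trees
leaf_vector_to_indicator <- function(leaf_vec, n, num_leaves) {
  num_trees <- length(leaf_vec) %/% n
  stopifnot(length(leaf_vec) == n * num_trees)
  # Row indices cycle through 1..n once per tree block
  rows <- rep(seq_len(n), times = num_trees)
  # Shift the 0-based leaf columns to R's 1-based column indices
  cols <- leaf_vec + 1
  W <- matrix(0, nrow = n, ncol = num_leaves)
  W[cbind(rows, cols)] <- 1
  W
}

# Toy vector (made-up values): tree 1 owns columns 0-2, tree 2 owns columns 3-4
leaf_vec <- c(0, 2, 2,  # tree 1: leaves of observations 1, 2, 3
              3, 4, 3)  # tree 2: leaves of observations 1, 2, 3
W <- leaf_vector_to_indicator(leaf_vec, n = 3, num_leaves = 5)
print(W)
```

Each row of W contains exactly one 1 per tree. The (rows, cols) pairs are the triplets a sparse constructor needs; passing them to `Matrix::sparseMatrix(i = rows, j = cols, x = 1, dims = c(n, num_leaves))` from the Matrix package would give a CSC (`dgCMatrix`) object instead of a dense matrix.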

Usage

computeForestLeafIndices(bart_model, X_train, X_test = NULL, forest_num = NULL)

Arguments

bart_model

Object of type bartmodel corresponding to a BART model with at least one sample

X_train

Matrix of "training" data. In a traditional Gaussian process kriging context, this corresponds to the observations for which outcomes are observed.

X_test

(Optional) Matrix of "test" data. In a traditional Gaussian process kriging context, this corresponds to the observations for which outcomes are unobserved and must be estimated based on the kernels k(X_test, X_test), k(X_test, X_train), and k(X_train, X_train). If not provided, this function only computes leaf indices for X_train, which is sufficient for k(X_train, X_train).
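The kernels above are not returned by this function. One common way to build them from leaf indices, assumed here and not necessarily what the package does internally, is a shared-leaf kernel: k(x, x') is the fraction of trees in which x and x' fall in the same leaf (equivalently, W_a W_b^T / num_trees for 0/1 leaf-indicator matrices). The helper `shared_leaf_kernel` and the toy vectors below are hypothetical.

```r
# Hypothetical helper (not part of the package): shared-leaf kernel, where
# k(x, x') is the fraction of trees in which x and x' fall in the same leaf.
#   leaf_a, leaf_b: column-major leaf-index vectors for two sets of observations
#   n_a, n_b:       number of observations in each set
shared_leaf_kernel <- function(leaf_a, n_a, leaf_b, n_b) {
  # matrix() fills column by column, so column t holds tree t's leaf for every observation
  La <- matrix(leaf_a, nrow = n_a)
  Lb <- matrix(leaf_b, nrow = n_b)
  stopifnot(ncol(La) == ncol(Lb))
  num_trees <- ncol(La)
  K <- matrix(0, nrow = n_a, ncol = n_b)
  for (t in seq_len(num_trees)) {
    # outer(..., "==") flags every (a, b) pair that shares a leaf in tree t
    K <- K + outer(La[, t], Lb[, t], "==")
  }
  K / num_trees
}

# Toy leaf-index vectors (made-up values) for 3 training and 2 test observations, 2 trees
leaf_train <- c(0, 2, 2, 3, 4, 3)
leaf_test  <- c(2, 0, 4, 3)

K_train_train <- shared_leaf_kernel(leaf_train, 3, leaf_train, 3)
K_test_train  <- shared_leaf_kernel(leaf_test, 2, leaf_train, 3)
K_test_test   <- shared_leaf_kernel(leaf_test, 2, leaf_test, 2)
print(K_test_train)
```

Because every observation shares all of its leaves with itself, the diagonals of K_train_train and K_test_test are 1, and k(X_train, X_test) is simply `t(K_test_train)`.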

forest_num

(Optional) Index of the forest sample to use when computing leaf indices. If not provided, this function uses the last forest.

Value

List of vectors. If X_test = NULL, the list contains one vector of length n_train * num_trees, where n_train = nrow(X_train) and num_trees is the number of trees in bart_model. If X_test is not NULL, the list also contains a second vector, of length n_test * num_trees (where n_test = nrow(X_test)), holding the leaf indices for X_test.
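Because each vector stores tree blocks of n elements back to back, R's default column-wise fill in `matrix()` turns it into an n x num_trees matrix whose column t holds tree t's leaf index for each observation. The sketch below uses a hand-built list as a stand-in for the return value; its values are made up, and it assumes the training vector comes first and the test vector second.

```r
# Hypothetical stand-in for the return value when X_test is supplied
# (made-up values): n_train = 3, n_test = 2, num_trees = 2
leaf_list <- list(c(0, 2, 2, 3, 4, 3),  # training vector
                  c(2, 0, 4, 3))        # test vector
n_train <- 3
n_test  <- 2

# Each vector has length n * num_trees, so num_trees follows from its length
num_trees <- length(leaf_list[[1]]) %/% n_train
stopifnot(length(leaf_list[[1]]) == n_train * num_trees,
          length(leaf_list[[2]]) == n_test * num_trees)

# matrix() fills column by column, matching the layout: column t is tree t
train_by_tree <- matrix(leaf_list[[1]], nrow = n_train)
test_by_tree  <- matrix(leaf_list[[2]], nrow = n_test)
print(train_by_tree)
```

Row i of `train_by_tree` then lists, tree by tree, the leaves visited by training observation i.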