Class that stores draws from a random ensemble of decision trees
ForestSamples.Rd
Wrapper around a C++ container of tree ensembles
Methods
Method new()
Create a new ForestContainer object.
Usage
ForestSamples$new(
num_trees,
output_dimension = 1,
is_leaf_constant = FALSE,
is_exponentiated = FALSE
)
Method load_from_json()
Create a new ForestContainer object from a json object
Method append_from_json()
Append to a ForestContainer object from a json object
Method load_from_json_string()
Create a new ForestContainer object from a json string
Method append_from_json_string()
Append to a ForestContainer object from a json string
Method predict()
Predict every tree ensemble on every sample in forest_dataset
Method predict_raw()
Predict "raw" leaf values (without being multiplied by basis) for every tree ensemble on every sample in forest_dataset
Returns
Array of predictions for each observation in forest_dataset and each sample in the ForestSamples class, with each prediction having the dimensionality of the forests' leaf model. In the case of a constant leaf model or univariate leaf regression, this array is two-dimensional (number of observations, number of forest samples). In the case of a multivariate leaf regression, this array is three-dimensional (number of observations, leaf model dimension, number of samples).
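The two return shapes described above can be illustrated with plain base R arrays. This is only a sketch of the layout: the sizes and the object names (`raw_univariate`, `raw_multivariate`) are made up for illustration, and no stochtree call is involved.

```r
# Hypothetical sizes, chosen only for illustration
n_obs <- 5       # observations in forest_dataset
n_samples <- 3   # forest samples retained in the container
leaf_dim <- 2    # leaf model dimension (multivariate case)

# Constant leaf model or univariate leaf regression:
# one raw value per (observation, forest sample)
raw_univariate <- matrix(0, nrow = n_obs, ncol = n_samples)

# Multivariate leaf regression:
# one vector of length leaf_dim per (observation, forest sample)
raw_multivariate <- array(0, dim = c(n_obs, leaf_dim, n_samples))

dim(raw_univariate)
dim(raw_multivariate)

# Raw leaf vector for observation 1 in forest sample 3
first_obs_last_sample <- raw_multivariate[1, , 3]
```

Indexing the middle dimension, as in the last line, is how the per-observation leaf vector would be pulled out of the three-dimensional case.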
Method predict_raw_single_forest()
Predict "raw" leaf values (without being multiplied by basis) for a specific forest on every sample in forest_dataset
Method set_root_leaves()
Set a constant predicted value for every tree in the ensemble. Stops program if any tree is more than a root node.
Method prepare_for_sampler()
Set a constant predicted value for every tree in the ensemble. Stops program if any tree is more than a root node.
Usage
ForestSamples$prepare_for_sampler(
dataset,
outcome,
forest_model,
leaf_model_int,
leaf_value
)
Arguments
dataset
ForestDataset object holding the dataset (covariates, basis, etc.)
outcome
Outcome object (residual / partial residual)
forest_model
ForestModel object storing tracking structures used in training / sampling
leaf_model_int
Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance).
leaf_value
Constant leaf value(s) to be fixed for each tree in the ensemble indexed by forest_num. Can be either a single number or a vector, depending on the forest's leaf dimension.
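Because leaf_model_int is a bare integer code, a named lookup can make calling code easier to read. The sketch below uses only the four codes listed above; the vector `leaf_model_codes` and its element names are hypothetical conveniences, not part of the package.

```r
# Hypothetical lookup table for the integer codes documented above
leaf_model_codes <- c(
  constant_gaussian     = 0L,
  univariate_gaussian   = 1L,
  multivariate_gaussian = 2L,
  log_linear_variance   = 3L
)

# Select a code by name instead of hard-coding the integer
leaf_model_int <- leaf_model_codes[["univariate_gaussian"]]

# Reverse lookup: recover the name from a code
leaf_model_name <- names(leaf_model_codes)[leaf_model_codes == 2L]
```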
Method adjust_residual()
Adjusts residual based on the predictions of a forest
This is typically run just once at the beginning of a forest sampling algorithm. After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.
Usage
ForestSamples$adjust_residual(
dataset,
outcome,
forest_model,
requires_basis,
forest_num,
add
)
Arguments
dataset
ForestDataset object storing the covariates and bases for a given forest
outcome
Outcome object storing the residuals to be updated based on forest predictions
forest_model
ForestModel object storing tracking structures used in training / sampling
requires_basis
Whether or not a forest requires a basis for prediction
forest_num
Index of forest used to update residuals
add
Whether forest predictions should be added to or subtracted from residuals
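The add / subtract semantics can be sketched in base R. The helper `adjust_residual_sketch` below is hypothetical and operates on plain numeric vectors; it only mirrors the arithmetic described for this method, not the package's C++ implementation.

```r
# Hypothetical helper: add forest predictions to, or subtract them from,
# a residual vector
adjust_residual_sketch <- function(residual, forest_pred, add) {
  if (add) residual + forest_pred else residual - forest_pred
}

y <- c(1.0, 2.0, 3.0)
# Assumed sum of the constant root predictions across all trees
root_pred <- rep(0.5, 3)

# Typical first step of a sampler: subtract the root predictions
residual <- adjust_residual_sketch(y, root_pred, add = FALSE)

# Adding the same predictions back restores the original outcome
restored <- adjust_residual_sketch(residual, root_pred, add = TRUE)
```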
Method update_residual()
Updates the residual used for training tree ensembles by iteratively (a) adding back in the previous prediction of each tree, (b) recomputing predictions for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual.
This is useful when a basis (e.g. for leaf regression) is updated outside of a tree sampler (as with adaptive coding for binary treatment BCF). Once a basis has been updated, the overall "function" represented by a tree model has changed, and this change should be propagated to the residual before the next sampling loop is run.
Arguments
dataset
ForestDataset object storing the covariates and bases for a given forest
outcome
Outcome object storing the residuals to be updated based on forest predictions
forest_model
ForestModel object storing tracking structures used in training / sampling
forest_num
Index of forest used to update residuals (starting at 1, in R style)
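The three steps (a), (b), (c) can be traced with a small base R sketch. It assumes a univariate leaf regression in which each tree's prediction is its raw leaf value times a per-observation basis; all names (`raw_leaf`, `cached_pred`, `old_basis`, `new_basis`) are hypothetical, and the real method does this work on the C++ side.

```r
set.seed(1)
n <- 4
num_trees <- 3

# Hypothetical raw leaf value reached by each observation in each tree
raw_leaf <- matrix(rnorm(n * num_trees), nrow = n)
old_basis <- rep(1, n)
new_basis <- c(0.5, 1, 1.5, 2)
y <- rnorm(n)

# Cached per-tree predictions and residual under the old basis
cached_pred <- raw_leaf * old_basis
residual <- y - rowSums(cached_pred)

for (j in seq_len(num_trees)) {
  residual <- residual + cached_pred[, j]         # (a) add back old prediction
  cached_pred[, j] <- raw_leaf[, j] * new_basis   # (b) recompute with new basis
  residual <- residual - cached_pred[, j]         # (c) subtract new prediction
}

# Residual computed directly under the new basis, for comparison
expected <- y - rowSums(raw_leaf * new_basis)
```

After the loop, `residual` matches the residual that would have been obtained had the new basis been in place from the start.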
Method load_json()
Load trees and metadata for an ensemble from a json file. Note that any trees and metadata already present in the ForestContainer class will be overwritten.
Method add_forest_with_constant_leaves()
Add a new all-root ensemble to the container, with all of the leaves set to the value / vector provided
Method add_numeric_split_tree()
Add a numeric (i.e. X[,i] <= c) split to a given tree in the ensemble
Usage
ForestSamples$add_numeric_split_tree(
forest_num,
tree_num,
leaf_num,
feature_num,
split_threshold,
left_leaf_value,
right_leaf_value
)
Arguments
forest_num
Index of the forest which contains the tree to be split
tree_num
Index of the tree to be split
leaf_num
Leaf to be split
feature_num
Feature that defines the new split
split_threshold
Value that defines the cutoff of the new split
left_leaf_value
Value (or vector of values) to assign to the newly created left node
right_leaf_value
Value (or vector of values) to assign to the newly created right node
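To make the split rule concrete, the sketch below applies a single numeric split to a toy covariate matrix in base R. The data and the convention that observations satisfying the rule go to the left node are assumptions for illustration; no stochtree objects are used.

```r
# Toy covariate matrix with two features (hypothetical data)
X <- matrix(c(0.1, 0.9, 0.4, 0.7,
              1.0, 2.0, 3.0, 4.0), ncol = 2)

feature_num <- 1
split_threshold <- 0.5
left_leaf_value <- -1
right_leaf_value <- 1

# Numeric split rule: observations with X[, feature_num] <= threshold
# are assumed to fall in the left node, the rest in the right node
goes_left <- X[, feature_num] <= split_threshold
pred <- ifelse(goes_left, left_leaf_value, right_leaf_value)
```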
Method get_tree_leaves()
Retrieve a vector of indices of leaf nodes for a given tree in a given forest
Method get_tree_split_counts()
Retrieve a vector of split counts for every training set variable in a given tree in a given forest
Method get_forest_split_counts()
Retrieve a vector of split counts for every training set variable in a given forest
Method get_aggregate_split_counts()
Retrieve a vector of split counts for every training set variable in a given forest, aggregated across ensembles and trees
Method get_granular_split_counts()
Retrieve a vector of split counts for every training set variable in a given forest, reported separately for each ensemble and tree
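The relationship between the granular, per-forest, and aggregate counts can be sketched with a base R array. The layout used here (forest sample, tree, variable) is an assumption made for illustration and may not match what the methods actually return.

```r
num_samples <- 2
num_trees <- 3
num_features <- 4

# Assumed layout: granular[s, t, p] is the number of splits on
# variable p in tree t of forest sample s
granular <- array(0L, dim = c(num_samples, num_trees, num_features))
granular[1, 1, 2] <- 1L
granular[1, 3, 2] <- 2L
granular[2, 2, 4] <- 1L

# Per-variable counts for a single tree in a single forest
tree_counts <- granular[1, 3, ]

# Per-variable counts for a single forest, summed over its trees
forest_counts <- colSums(granular[1, , ])

# Per-variable counts aggregated across all forests and trees
aggregate_counts <- apply(granular, 3, sum)
```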
Method ensemble_tree_max_depth()
Maximum depth of a specific tree in a specific ensemble in a ForestContainer object
Method average_ensemble_max_depth()
Average the maximum depth of each tree in a given ensemble in a ForestContainer object
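As a rough illustration of the two depth summaries, the sketch below stores each tree as a vector of parent indices and computes the per-tree maximum depth and its average over an ensemble. The parent-vector representation, the helper names, and the convention that a root-only tree has depth 0 are all assumptions, not the package's internals.

```r
# Depth of one node: number of edges between it and the root.
# parent[i] is the index of node i's parent; the root has parent 0.
node_depth <- function(parent, node) {
  depth <- 0
  while (parent[node] != 0) {
    node <- parent[node]
    depth <- depth + 1
  }
  depth
}

# Maximum depth over all nodes of one tree
tree_max_depth <- function(parent) {
  max(vapply(seq_along(parent),
             function(i) node_depth(parent, i),
             numeric(1)))
}

# Hypothetical ensemble of three trees
trees <- list(
  c(0),              # root only
  c(0, 1, 1),        # one split below the root
  c(0, 1, 1, 2, 2)   # a second split below node 2
)

max_depths <- vapply(trees, tree_max_depth, numeric(1))
average_max_depth <- mean(max_depths)
```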