Wrapper around a C++ tree ensemble

Public fields

forest_ptr

External pointer to a C++ TreeEnsemble class

Methods


Method new()

Create a new Forest object.

Usage

Forest$new(
  num_trees,
  output_dimension = 1,
  is_leaf_constant = FALSE,
  is_exponentiated = FALSE
)

Arguments

num_trees

Number of trees in the forest

output_dimension

Dimensionality of the outcome model

is_leaf_constant

Whether the forest uses a constant leaf model (as opposed to a leaf regression on a basis)

is_exponentiated

Whether forest predictions should be exponentiated before being returned

Returns

A new Forest object.


Method predict()

Compute forest predictions for every observation in forest_dataset

Usage

Forest$predict(forest_dataset)

Arguments

forest_dataset

ForestDataset object

Returns

Vector of predictions, with one entry per observation in forest_dataset
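As a rough illustration of what a forest prediction represents, the sketch below sums per-tree contributions for each observation and optionally exponentiates the total, mirroring the is_exponentiated flag. It uses a plain R matrix rather than the package's C++ ensemble. The names tree_contributions and predict_sketch are hypothetical, and the sum-then-exponentiate order is an assumption about the underlying model, not a statement about the package's implementation.

```r
# Hypothetical stand-in for a forest: one column of per-observation
# contributions per tree (3 observations, 2 trees).
tree_contributions <- cbind(
  c(0.1, 0.2, 0.3),
  c(0.4, 0.4, 0.5)
)

# Sum contributions across trees; exponentiate the total if requested.
predict_sketch <- function(contributions, is_exponentiated = FALSE) {
  total <- rowSums(contributions)
  if (is_exponentiated) exp(total) else total
}

preds <- predict_sketch(tree_contributions)
preds_exp <- predict_sketch(tree_contributions, is_exponentiated = TRUE)
```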


Method predict_raw()

Predict "raw" leaf values (without being multiplied by basis) for every sample in forest_dataset

Usage

Forest$predict_raw(forest_dataset)

Arguments

forest_dataset

ForestDataset object

Returns

Array of raw leaf predictions for each observation in forest_dataset, with each prediction having the dimensionality of the forest's leaf model. For a constant leaf model or a univariate leaf regression, this array is a vector whose length is the number of observations. For a multivariate leaf regression, this array is a matrix with one row per observation and one column per leaf model dimension.
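To make the distinction between "raw" leaf values and basis-multiplied predictions concrete, the sketch below combines a matrix of raw values with a basis of the same shape. Both matrices are made-up data, and treating the final prediction as the row-wise inner product of raw values and basis is an assumption drawn from the phrase "without being multiplied by basis".

```r
# Hypothetical raw leaf values: 4 observations, leaf dimension 2
# (matrix() fills column by column).
raw <- matrix(c(1.0, 0.5, -1.0, 2.0,
                0.2, 0.2,  0.4, 0.0), nrow = 4, ncol = 2)

# Hypothetical basis with the same shape as the raw values.
basis <- matrix(c(1, 1, 1, 1,
                  0, 1, 2, 3), nrow = 4, ncol = 2)

# Multiply element-wise, then sum over the leaf dimension for each observation.
fitted <- rowSums(raw * basis)
```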


Method set_root_leaves()

Set a constant predicted value for every tree in the ensemble. Stops program if any tree is more than a root node.

Usage

Forest$set_root_leaves(leaf_value)

Arguments

leaf_value

Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.


Method prepare_for_sampler()

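Set a constant predicted value for every tree in the ensemble and pass this initialization through to the supplied outcome and forest_model objects so that they are ready for sampling. Stops program if any tree is more than a root node.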

Usage

Forest$prepare_for_sampler(
  dataset,
  outcome,
  forest_model,
  leaf_model_int,
  leaf_value
)

Arguments

dataset

ForestDataset object storing the covariates and bases for a given forest

outcome

Outcome object storing the residual or partial residual

forest_model

ForestModel object storing tracking structures used in training / sampling

leaf_model_int

Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance).

leaf_value

Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.
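Because leaf_model_int is a bare integer code, a named lookup can make calling code easier to read. The sketch below records the four codes listed in the argument description. The vector leaf_model_codes and the helper leaf_model_code are hypothetical conveniences, not part of the package.

```r
# Integer codes as listed in the leaf_model_int argument description.
leaf_model_codes <- c(
  constant_gaussian = 0L,
  univariate_gaussian = 1L,
  multivariate_gaussian = 2L,
  log_linear_variance = 3L
)

# Hypothetical helper: translate a readable name into its integer code.
leaf_model_code <- function(name) {
  stopifnot(name %in% names(leaf_model_codes))
  leaf_model_codes[[name]]
}

leaf_model_int <- leaf_model_code("univariate_gaussian")
```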


Method adjust_residual()

Adjusts residual based on the predictions of a forest

This is typically run just once at the beginning of a forest sampling algorithm. After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.

Usage

Forest$adjust_residual(dataset, outcome, forest_model, requires_basis, add)

Arguments

dataset

ForestDataset object storing the covariates and bases for a given forest

outcome

Outcome object storing the residuals to be updated based on forest predictions

forest_model

ForestModel object storing tracking structures used in training / sampling

requires_basis

Whether or not a forest requires a basis for prediction

add

Whether forest predictions should be added to or subtracted from residuals
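The effect of the add flag can be sketched in base R: subtracting predictions removes the forest's contribution from the residual, and adding the same predictions restores it. The helper adjust_residual_sketch and the data below are hypothetical and operate on plain numeric vectors rather than Outcome and ForestDataset objects.

```r
# Hypothetical helper mirroring the add flag: add forest predictions to
# the residual, or subtract them from it.
adjust_residual_sketch <- function(residual, predictions, add = FALSE) {
  stopifnot(length(residual) == length(predictions))
  if (add) residual + predictions else residual - predictions
}

outcome <- c(2.0, 3.5, 1.0)
root_predictions <- rep(0.5, 3)  # constant root-node prediction per observation

# Subtract the root predictions, then add them back.
residual <- adjust_residual_sketch(outcome, root_predictions, add = FALSE)
restored <- adjust_residual_sketch(residual, root_predictions, add = TRUE)
```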


Method num_trees()

Return the number of trees in a Forest object

Usage

Forest$num_trees()

Returns

Tree count


Method output_dimension()

Return output dimension of trees in a Forest object

Usage

Forest$output_dimension()

Returns

Leaf node parameter size


Method is_constant_leaf()

Return constant leaf status of trees in a Forest object

Usage

Forest$is_constant_leaf()

Returns

TRUE if leaves are constant, FALSE otherwise


Method is_exponentiated()

Return exponentiation status of trees in a Forest object

Usage

Forest$is_exponentiated()

Returns

TRUE if leaf predictions must be exponentiated, FALSE otherwise


Method add_numeric_split_tree()

Add a numeric (i.e. X[,i] <= c) split to a given tree in the ensemble

Usage

Forest$add_numeric_split_tree(
  tree_num,
  leaf_num,
  feature_num,
  split_threshold,
  left_leaf_value,
  right_leaf_value
)

Arguments

tree_num

Index of the tree to be split

leaf_num

Leaf to be split

feature_num

Feature that defines the new split

split_threshold

Value that defines the cutoff of the new split

left_leaf_value

Value (or vector of values) to assign to the newly created left node

right_leaf_value

Value (or vector of values) to assign to the newly created right node
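The numeric split rule X[,i] <= c can be sketched in base R as follows: observations at or below the threshold go to the left node and the rest go to the right. The helper numeric_split_sketch is hypothetical and uses R's 1-based column indexing; it makes no claim about how the package indexes trees, leaves, or features.

```r
set.seed(1)
X <- matrix(runif(20), nrow = 10, ncol = 2)

# Hypothetical helper: route observations through one numeric split
# and return the leaf value each observation receives.
numeric_split_sketch <- function(X, feature, threshold, left_value, right_value) {
  goes_left <- X[, feature] <= threshold
  ifelse(goes_left, left_value, right_value)
}

leaf_values <- numeric_split_sketch(X, feature = 1, threshold = 0.5,
                                    left_value = -1, right_value = 1)
```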


Method get_tree_leaves()

Retrieve a vector of indices of leaf nodes for a given tree in the forest

Usage

Forest$get_tree_leaves(tree_num)

Arguments

tree_num

Index of the tree for which leaf indices will be retrieved


Method get_tree_split_counts()

Retrieve a vector of split counts for every training set variable in a given tree in the forest

Usage

Forest$get_tree_split_counts(tree_num, num_features)

Arguments

tree_num

Index of the tree for which split counts will be retrieved

num_features

Total number of features in the training set


Method get_forest_split_counts()

Retrieve a vector of split counts for every training set variable in the forest

Usage

Forest$get_forest_split_counts(num_features)

Arguments

num_features

Total number of features in the training set
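One plausible reason both split-count methods take num_features is that a feature never used in a split would otherwise be missing from the result. The sketch below shows the idea with base R's tabulate(), which pads counts with zeros up to a fixed length. The list split_features_by_tree is made-up data using 1-based feature indices, not output from the package.

```r
# Hypothetical record of the features split on in each tree
# (1-based feature indices); a root-only tree has no splits.
split_features_by_tree <- list(c(1, 3, 3), c(2), integer(0))
num_features <- 4

# Per-tree counts: tabulate() pads with zeros up to num_features.
tree_counts <- lapply(split_features_by_tree, tabulate, nbins = num_features)

# Forest-level counts are the element-wise sum of the per-tree counts.
forest_counts <- Reduce(`+`, tree_counts)
```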


Method tree_max_depth()

Maximum depth of a specific tree in the forest

Usage

Forest$tree_max_depth(tree_num)

Arguments

tree_num

Tree index within forest

Returns

Maximum leaf depth


Method average_max_depth()

Average the maximum depth of each tree in the forest

Usage

Forest$average_max_depth()

Returns

Average maximum depth
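The relationship between the two depth summaries can be sketched with made-up leaf depths: take the deepest leaf in each tree, then average those maxima across trees. The list leaf_depths_by_tree is hypothetical, and counting the root as depth 0 is an assumption.

```r
# Hypothetical leaf depths per tree, with the root counted as depth 0.
leaf_depths_by_tree <- list(c(1, 2, 2), c(0), c(1, 1))

# Maximum leaf depth of each tree, then the average across trees.
max_depths <- vapply(leaf_depths_by_tree, max, numeric(1))
avg_max_depth <- mean(max_depths)
```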