Forest Sampling Tracker API
A truly minimalist tree ensemble library only needs:
A representation of a decision tree
A container for grouping / storing ensembles of trees
In-memory access to / representation of training data
Routines / functions to construct the trees
Most algorithms for optimizing or sampling tree ensembles frequently perform the following operations:
Determine which leaf a training observation falls into for a decision tree (to compute its prediction and update the residual / outcome)
Evaluate potential split candidates for a leaf of a decision tree
With only the “minimalist” tools above, these two tasks proceed largely as follows:
For every observation in the dataset, traverse the tree (runtime depends on the tree topology, but in a fully balanced tree with \(k\) nodes, each traversal has time complexity \(O(\log (k))\)).
For every observation in the dataset, determine whether it falls into a given node and whether or not a proposed decision rule would be true for it
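The two untracked operations above can be sketched as follows. This is only an illustration: the `Node` struct, `FindLeaf`, `SplitCounts`, and `EvaluateSplitNaive` are hypothetical names invented for this sketch and are not part of the library, and the decision rule is assumed to be a simple numeric threshold.

```cpp
#include <cassert>
#include <vector>

// Hypothetical minimal tree node (not the library's representation).
// Nodes are stored in a vector, with the root at index 0.
struct Node {
  bool is_leaf;
  int split_feature;   // feature index used by the decision rule
  double split_value;  // rule is true when x[split_feature] <= split_value
  int left;            // child index when the rule is true
  int right;           // child index when the rule is false
};

// Operation 1: walk from the root to a leaf for a single observation.
int FindLeaf(const std::vector<Node>& tree, const std::vector<double>& x) {
  assert(!tree.empty());
  int node = 0;
  while (!tree[node].is_leaf) {
    const Node& n = tree[node];
    node = (x[n.split_feature] <= n.split_value) ? n.left : n.right;
  }
  return node;
}

// Number of observations sent to each side by a proposed rule.
struct SplitCounts {
  int left;
  int right;
};

// Operation 2: evaluate a proposed rule for `leaf` with a full pass over
// the data, re-traversing the tree for every row to test membership.
SplitCounts EvaluateSplitNaive(const std::vector<Node>& tree,
                               const std::vector<std::vector<double>>& X,
                               int leaf, int feature, double threshold) {
  SplitCounts counts{0, 0};
  for (const auto& x : X) {
    if (FindLeaf(tree, x) != leaf) continue;  // row is not in this leaf
    if (x[feature] <= threshold) {
      ++counts.left;
    } else {
      ++counts.right;
    }
  }
  return counts;
}
```

In this sketch, every split evaluation touches every row of the dataset and repeats a tree traversal for each one, even when the leaf of interest holds only a handful of observations.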
Both operations perform unnecessary computation, which can be avoided with some additional real-time tracking. Essentially, we want:
A mapping from dataset row index to leaf node id for every tree in an ensemble (so that we can skip the tree traversal during prediction)
A mapping from leaf node id to dataset row indices for every tree in an ensemble (so that we can skip the full pass through the training data at split evaluation)
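A minimal sketch of these two mappings is shown below, assuming every tree starts as a single root node with id 0. The `SimpleTracker` class and its methods are hypothetical and exist only to illustrate the idea; they are not the library's `ForestTracker` interface, and the nested containers were chosen for readability rather than performance.

```cpp
#include <cassert>
#include <functional>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical, simplified tracker holding both mappings for every tree.
class SimpleTracker {
 public:
  SimpleTracker(int num_trees, int num_observations)
      : row_to_leaf_(num_trees, std::vector<int>(num_observations, 0)),
        leaf_to_rows_(num_trees) {
    // Every row starts in the root node (id 0) of every tree.
    std::vector<int> all_rows(num_observations);
    std::iota(all_rows.begin(), all_rows.end(), 0);
    for (auto& leaf_map : leaf_to_rows_) leaf_map[0] = all_rows;
  }

  // Row index -> leaf node id: a lookup instead of a tree traversal.
  int LeafOf(int tree, int row) const { return row_to_leaf_[tree][row]; }

  // Leaf node id -> row indices: only the rows in that leaf are visited.
  const std::vector<int>& RowsIn(int tree, int leaf) const {
    return leaf_to_rows_[tree].at(leaf);
  }

  // Update both mappings when `leaf` is split into two new leaves.
  // `goes_left(row)` reports whether the decision rule is true for a row.
  void Split(int tree, int leaf, int left_id, int right_id,
             const std::function<bool(int)>& goes_left) {
    assert(left_id != right_id && left_id != leaf && right_id != leaf);
    std::vector<int> rows = std::move(leaf_to_rows_[tree].at(leaf));
    leaf_to_rows_[tree].erase(leaf);
    std::vector<int> left_rows, right_rows;
    for (int row : rows) {
      const bool left = goes_left(row);
      (left ? left_rows : right_rows).push_back(row);
      row_to_leaf_[tree][row] = left ? left_id : right_id;
    }
    leaf_to_rows_[tree][left_id] = std::move(left_rows);
    leaf_to_rows_[tree][right_id] = std::move(right_rows);
  }

 private:
  std::vector<std::vector<int>> row_to_leaf_;  // [tree][row] -> leaf id
  std::vector<std::unordered_map<int, std::vector<int>>> leaf_to_rows_;
};
```

With both mappings kept in sync at each split, a prediction update needs only `LeafOf`, and a split evaluation needs only the rows returned by `RowsIn` for the leaf being considered.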
Forest Tracker
The ForestTracker class is a wrapper around several implementations of the mappings discussed above.
class ForestTracker
“Superclass” wrapper around tracking data structures for forest sampling algorithms
Public Functions
ForestTracker(Eigen::MatrixXd &covariates, std::vector<FeatureType> &feature_types, int num_trees, int num_observations)
Construct a new ForestTracker object.
Parameters:
covariates – Matrix of covariate data
feature_types – Type of each feature (column) in covariates. This is represented by the enum StochTree::FeatureType
num_trees – Number of trees in an ensemble to be sampled
num_observations – Number of rows in covariates