Skip to contents

While out of sample evaluation and MCMC diagnostics on parametric BART components (i.e. σ2\sigma^2, the global error variance) are helpful, it’s important to be able to inspect the trees in a BART / BCF model (or a custom tree ensemble model). This vignette walks through some of the features stochtree provides to query and understand the forests / trees in a model.

To begin, we load the stochtree package.

Demo 1: Supervised Learning

Generate sample data where feature 10 is the only “important” feature.

# Generate the data
n <- 500
p_x <- 10
X <- matrix(runif(n*p_x), ncol = p_x)
f_XW <- (
    ((0 <= X[,10]) & (0.25 > X[,10])) * (-7.5) + 
    ((0.25 <= X[,10]) & (0.5 > X[,10])) * (-2.5) + 
    ((0.5 <= X[,10]) & (0.75 > X[,10])) * (2.5) + 
    ((0.75 <= X[,10]) & (1 > X[,10])) * (7.5)
)
noise_sd <- 1
y <- f_XW + rnorm(n, 0, 1)*noise_sd

# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- as.data.frame(X[test_inds,])
X_train <- as.data.frame(X[train_inds,])
W_test <- NULL
W_train <- NULL
y_test <- y[test_inds]
y_train <- y[train_inds]

Sampling and Analysis

Run BART.

num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
bart_params <- list(keep_gfr = T)
bart_model <- stochtree::bart(
    X_train = X_train, y_train = y_train, X_test = X_test, 
    num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
    params = bart_params
)

Inspect the MCMC samples

plot(bart_model$sigma2_global_samples, ylab="sigma^2")
abline(h=noise_sd^2,col="red",lty=2,lwd=2.5)

plot(rowMeans(bart_model$y_hat_test), y_test, 
     pch=16, cex=0.75, xlab = "pred", ylab = "actual")
abline(0,1,col="red",lty=2,lwd=2.5)

Check the variable split count in the last GFR sample

bart_model$mean_forests$get_forest_split_counts(10, p_x)
#>  [1] 28 26 21 23 33 23 25 21 26 39
bart_model$mean_forests$get_aggregate_split_counts(p_x)
#>  [1] 2857 2099 2366 2566 2907 2279 2650 2450 2787 3543

The split counts appear relatively uniform across features, so let’s dig deeper and look at individual trees, starting with the first tree in the last “grow-from-root” sample.

splits = bart_model$mean_forests$get_granular_split_counts(p_x)
splits[10,1,]
#>  [1] 0 0 0 0 0 0 0 0 0 1

This tree has a single split on the only “important” feature. Now, let’s look at the second tree.

splits[10,2,]
#>  [1] 0 0 0 0 0 0 0 0 0 2

This tree also only splits on the important feature.

splits[10,20,]
#>  [1] 0 0 0 0 1 0 1 0 0 1
splits[10,30,]
#>  [1] 0 0 0 0 0 0 0 0 0 1

We see that “later” trees are splitting on other features, but we also note that these trees are fitting an outcome that is already residualized many “relevant splits” made by trees 1 and 2.

Now, let’s inspect the first tree for this last GFR sample in more depth, following this scikit-learn vignette

forest_num <- 9
tree_num <- 0
nodes <- sort(bart_model$mean_forests$nodes(forest_num, tree_num))
for (nid in nodes) {
    if (bart_model$mean_forests$is_leaf_node(forest_num, tree_num, nid)) {
        node_depth <- bart_model$mean_forests$node_depth(forest_num, tree_num, nid)
        space_text <- rep("\t", node_depth)
        leaf_values <- bart_model$mean_forests$node_leaf_values(forest_num, tree_num, nid)
        cat(space_text, "node=", nid, " is a leaf node with value=", 
            format(leaf_values, digits = 3), "\n", sep = "")
    } else {
        node_depth <- bart_model$mean_forests$node_depth(forest_num, tree_num, nid)
        space_text <- rep("\t", node_depth)
        left <- bart_model$mean_forests$left_child_node(forest_num, tree_num, nid)
        feature <- bart_model$mean_forests$node_split_index(forest_num, tree_num, nid)
        threshold <- bart_model$mean_forests$node_split_threshold(forest_num, tree_num, nid)
        right <- bart_model$mean_forests$right_child_node(forest_num, tree_num, nid)
        cat(space_text, "node=", nid, " is a split node, which tells us to go to node ", 
            left, " if X[:, ", feature, "] <= ", format(threshold, digits = 3), 
            " else to node ", right, "\n", sep = "")
    }
}
#> node=0 is a split node, which tells us to go to node 1 if X[:, 9] <= 0.496 else to node 2
#>  node=1 is a leaf node with value=-0.296
#>  node=2 is a leaf node with value=0.341