Skip to contents

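This vignette demonstrates how to use the bcf() function for causal inference, i.e. estimating (possibly heterogeneous) treatment effects with Bayesian causal forests (Hahn, Murray, and Carvalho (2020)). To begin, we load the stochtree package.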

We also define several simple functions that configure the data generating processes used in this vignette.

# Categorical offset driven by column 5 of x: level 1 -> 2, level 2 -> -1,
# anything else -> -4.
g <- function(x) {
  level <- x[, 5]
  ifelse(level == 1, 2, ifelse(level == 2, -1, -4))
}

# Prognostic function built from g(x) and an x1-by-x3 interaction.
mu1 <- function(x) {
  interaction_term <- x[, 1] * x[, 3]
  1 + g(x) + interaction_term
}

# Prognostic function built from g(x) and an absolute-value kink in x3.
mu2 <- function(x) {
  kink_term <- 6 * abs(x[, 3] - 1)
  1 + g(x) + kink_term
}

# Homogeneous treatment effect: the constant 3 for every row of x.
tau1 <- function(x) {
  rep(3, times = nrow(x))
}

# Heterogeneous treatment effect that varies with x2 * x4.
tau2 <- function(x) {
  1 + 2 * x[, 2] * x[, 4]
}

Binary Treatment

Demo 1: Nonlinear Outcome Model, Heterogeneous Treatment Effect

We consider the following data generating process from Hahn, Murray, and Carvalho (2020):

\begin{equation*}
\begin{aligned}
y &= \mu(X) + \tau(X) Z + \epsilon\\
\epsilon &\sim N\left(0,\sigma^2\right)\\
\mu(X) &= 1 + g(X) + 6 \lvert X_3 - 1 \rvert\\
\tau(X) &= 1 + 2 X_2 X_4\\
g(X) &= \mathbb{I}(X_5=1) \times 2 - \mathbb{I}(X_5=2) \times 1 - \mathbb{I}(X_5=3) \times 4\\
s_{\mu} &= \sqrt{\mathbb{V}(\mu(X))}\\
\pi(X) &= 0.8 \Phi\left(\frac{3\mu(X)}{s_{\mu}} - \frac{X_1}{2}\right) + \frac{2U+1}{20}\\
X_1,X_2,X_3 &\sim N\left(0,1\right)\\
X_4 &\sim \text{Bernoulli}(1/2)\\
X_5 &\sim \text{Categorical}(1/3,1/3,1/3)\\
U &\sim \text{Uniform}\left(0,1\right)\\
Z &\sim \text{Bernoulli}\left(\pi(X)\right)
\end{aligned}
\end{equation*}

Simulation

We draw from the DGP defined above

# Simulate covariates: three Gaussian, one binary, one three-level categorical
n <- 500
snr <- 3
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)
x4 <- as.numeric(rbinom(n, 1, 0.5))
x5 <- as.numeric(sample(1:3, n, replace = TRUE))
X <- cbind(x1, x2, x3, x4, x5)
p <- ncol(X)

# Demo 1 specifies the nonlinear prognostic term 6|X3 - 1|, which is mu2
# (mu1 is the X1 * X3 interaction model)
mu_x <- mu2(X)
tau_x <- tau2(X)

# Propensity score (always strictly between 0.05 and 0.95), treatment
# assignment and outcome; the noise sd is sd(E_XZ) / snr
pi_x <- 0.8 * pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) +
  0.05 + runif(n) / 10
Z <- rbinom(n, 1, pi_x)
E_XZ <- mu_x + Z * tau_x
y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)

# Covariates as a data frame, with x4 and x5 stored as ordered factors
X <- as.data.frame(X)
X$x4 <- factor(X$x4, ordered = TRUE)
X$x5 <- factor(X$x5, ordered = TRUE)

# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
# seq_len() stays well-defined even if n were 0, unlike 1:n
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]

Sampling and Analysis

Warmstart

We first simulate from an ensemble model of $y \mid X$ using “warm-start” initialization samples (Krantsevich, He, and Hahn (2023)). This is the default in stochtree.

# Sampler settings: 10 warm-start (num_gfr) iterations, no burn-in,
# 1000 MCMC iterations
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
num_samples <- num_gfr + num_burnin + num_mcmc
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables
bcf_model_warmstart <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  pi_train = pi_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE
)

Inspect the BART samples that were initialized with an XBART warm-start

# Row-wise average of the mu_hat_test draws against the simulated
# prognostic values
plot(
  x = rowMeans(bcf_model_warmstart$mu_hat_test), y = mu_test,
  main = "Prognostic function", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the tau_hat_test draws against the simulated
# treatment effects
plot(
  x = rowMeans(bcf_model_warmstart$tau_hat_test), y = tau_test,
  main = "Treatment effect", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Sampled global variance, with the empirical variance of the simulated
# noise (y - E_XZ) drawn as a reference line
sigma_observed <- var(y - E_XZ)
plot_bounds <- range(c(bcf_model_warmstart$sigma2_samples, sigma_observed))
plot(
  bcf_model_warmstart$sigma2_samples,
  ylim = plot_bounds,
  main = "Global variance parameter", xlab = "Sample", ylab = "sigma^2"
)
abline(h = sigma_observed, col = "blue", lwd = 3, lty = 3)

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles of each row of tau_hat_test
test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.975)
# Share of test observations whose true effect lies inside its interval
cover <- test_lb <= tau_x[test_inds] & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 0.93
BART MCMC without Warmstart

Next, we simulate from this ensemble model without any warm-start initialization.

# Sampler settings: no warm-start (num_gfr) iterations, 1000 burn-in
# iterations, 1000 MCMC iterations
num_gfr <- 0
num_burnin <- 1000
num_mcmc <- 1000
num_samples <- num_gfr + num_burnin + num_mcmc
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables
bcf_model_root <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  pi_train = pi_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE
)

Inspect the BART samples after burnin

# Row-wise average of the mu_hat_test draws against the simulated
# prognostic values
plot(
  x = rowMeans(bcf_model_root$mu_hat_test), y = mu_test,
  main = "Prognostic function", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the tau_hat_test draws against the simulated
# treatment effects
plot(
  x = rowMeans(bcf_model_root$tau_hat_test), y = tau_test,
  main = "Treatment effect", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Sampled global variance, with the empirical variance of the simulated
# noise (y - E_XZ) drawn as a reference line
sigma_observed <- var(y - E_XZ)
plot_bounds <- range(c(bcf_model_root$sigma2_samples, sigma_observed))
plot(
  bcf_model_root$sigma2_samples,
  ylim = plot_bounds,
  main = "Global variance parameter", xlab = "Sample", ylab = "sigma^2"
)
abline(h = sigma_observed, col = "blue", lwd = 3, lty = 3)

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles of each row of tau_hat_test
test_lb <- apply(bcf_model_root$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_root$tau_hat_test, 1, quantile, probs = 0.975)
# Share of test observations whose true effect lies inside its interval
cover <- test_lb <= tau_x[test_inds] & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 0.95

Demo 2: Linear Outcome Model, Heterogeneous Treatment Effect

We consider the following data generating process from Hahn, Murray, and Carvalho (2020):

\begin{equation*}
\begin{aligned}
y &= \mu(X) + \tau(X) Z + \epsilon\\
\epsilon &\sim N\left(0,\sigma^2\right)\\
\mu(X) &= 1 + g(X) + X_1 X_3\\
\tau(X) &= 1 + 2 X_2 X_4\\
g(X) &= \mathbb{I}(X_5=1) \times 2 - \mathbb{I}(X_5=2) \times 1 - \mathbb{I}(X_5=3) \times 4\\
s_{\mu} &= \sqrt{\mathbb{V}(\mu(X))}\\
\pi(X) &= 0.8 \Phi\left(\frac{3\mu(X)}{s_{\mu}} - \frac{X_1}{2}\right) + \frac{2U+1}{20}\\
X_1,X_2,X_3 &\sim N\left(0,1\right)\\
X_4 &\sim \text{Bernoulli}(1/2)\\
X_5 &\sim \text{Categorical}(1/3,1/3,1/3)\\
U &\sim \text{Uniform}\left(0,1\right)\\
Z &\sim \text{Bernoulli}\left(\pi(X)\right)
\end{aligned}
\end{equation*}

Simulation

We draw from the DGP defined above

# Simulate covariates: three Gaussian, one binary, one three-level categorical
n <- 500
snr <- 3
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)
x4 <- as.numeric(rbinom(n, 1, 0.5))
x5 <- as.numeric(sample(1:3, n, replace = TRUE))
X <- cbind(x1, x2, x3, x4, x5)
p <- ncol(X)

# Demo 2 specifies the "linear" prognostic function with the X1 * X3
# interaction, which is mu1 (mu2 is the absolute-value model)
mu_x <- mu1(X)
tau_x <- tau2(X)

# Propensity score (always strictly between 0.05 and 0.95), treatment
# assignment and outcome; the noise sd is sd(E_XZ) / snr
pi_x <- 0.8 * pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) +
  0.05 + runif(n) / 10
Z <- rbinom(n, 1, pi_x)
E_XZ <- mu_x + Z * tau_x
y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)

# Covariates as a data frame, with x4 and x5 stored as ordered factors
X <- as.data.frame(X)
X$x4 <- factor(X$x4, ordered = TRUE)
X$x5 <- factor(X$x5, ordered = TRUE)

# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
# seq_len() stays well-defined even if n were 0, unlike 1:n
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]

Sampling and Analysis

Warmstart

We first simulate from an ensemble model of $y \mid X$ using “warm-start” initialization samples (Krantsevich, He, and Hahn (2023)). This is the default in stochtree.

# Sampler settings: 10 warm-start (num_gfr) iterations, no burn-in,
# 100 MCMC iterations
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
num_samples <- num_gfr + num_burnin + num_mcmc
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables
bcf_model_warmstart <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  pi_train = pi_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE
)

Inspect the BART samples that were initialized with an XBART warm-start

# Row-wise average of the mu_hat_test draws against the simulated
# prognostic values
plot(
  x = rowMeans(bcf_model_warmstart$mu_hat_test), y = mu_test,
  main = "Prognostic function", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the tau_hat_test draws against the simulated
# treatment effects
plot(
  x = rowMeans(bcf_model_warmstart$tau_hat_test), y = tau_test,
  main = "Treatment effect", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Sampled global variance, with the empirical variance of the simulated
# noise (y - E_XZ) drawn as a reference line
sigma_observed <- var(y - E_XZ)
plot_bounds <- range(c(bcf_model_warmstart$sigma2_samples, sigma_observed))
plot(
  bcf_model_warmstart$sigma2_samples,
  ylim = plot_bounds,
  main = "Global variance parameter", xlab = "Sample", ylab = "sigma^2"
)
abline(h = sigma_observed, col = "blue", lwd = 3, lty = 3)

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles of each row of tau_hat_test
test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.975)
# Share of test observations whose true effect lies inside its interval
cover <- test_lb <= tau_x[test_inds] & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 0.91
BART MCMC without Warmstart

Next, we simulate from this ensemble model without any warm-start initialization.

# Sampler settings: no warm-start (num_gfr) iterations, 100 burn-in
# iterations, 100 MCMC iterations
num_gfr <- 0
num_burnin <- 100
num_mcmc <- 100
num_samples <- num_gfr + num_burnin + num_mcmc
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables
bcf_model_root <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  pi_train = pi_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE
)

Inspect the BART samples after burnin

# Row-wise average of the mu_hat_test draws against the simulated
# prognostic values
plot(
  x = rowMeans(bcf_model_root$mu_hat_test), y = mu_test,
  main = "Prognostic function", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the tau_hat_test draws against the simulated
# treatment effects
plot(
  x = rowMeans(bcf_model_root$tau_hat_test), y = tau_test,
  main = "Treatment effect", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Sampled global variance, with the empirical variance of the simulated
# noise (y - E_XZ) drawn as a reference line
sigma_observed <- var(y - E_XZ)
plot_bounds <- range(c(bcf_model_root$sigma2_samples, sigma_observed))
plot(
  bcf_model_root$sigma2_samples,
  ylim = plot_bounds,
  main = "Global variance parameter", xlab = "Sample", ylab = "sigma^2"
)
abline(h = sigma_observed, col = "blue", lwd = 3, lty = 3)

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles of each row of tau_hat_test
test_lb <- apply(bcf_model_root$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_root$tau_hat_test, 1, quantile, probs = 0.975)
# Share of test observations whose true effect lies inside its interval
cover <- test_lb <= tau_x[test_inds] & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 0.73

Demo 3: Linear Outcome Model, Homogeneous Treatment Effect

We consider the following data generating process from Hahn, Murray, and Carvalho (2020):

\begin{equation*}
\begin{aligned}
y &= \mu(X) + \tau(X) Z + \epsilon\\
\epsilon &\sim N\left(0,\sigma^2\right)\\
\mu(X) &= 1 + g(X) + X_1 X_3\\
\tau(X) &= 3\\
g(X) &= \mathbb{I}(X_5=1) \times 2 - \mathbb{I}(X_5=2) \times 1 - \mathbb{I}(X_5=3) \times 4\\
s_{\mu} &= \sqrt{\mathbb{V}(\mu(X))}\\
\pi(X) &= 0.8 \Phi\left(\frac{3\mu(X)}{s_{\mu}} - \frac{X_1}{2}\right) + \frac{2U+1}{20}\\
X_1,X_2,X_3 &\sim N\left(0,1\right)\\
X_4 &\sim \text{Bernoulli}(1/2)\\
X_5 &\sim \text{Categorical}(1/3,1/3,1/3)\\
U &\sim \text{Uniform}\left(0,1\right)\\
Z &\sim \text{Bernoulli}\left(\pi(X)\right)
\end{aligned}
\end{equation*}

Simulation

We draw from the DGP defined above

# Simulate covariates: three Gaussian, one binary, one three-level categorical
n <- 500
snr <- 3
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)
x4 <- as.numeric(rbinom(n, 1, 0.5))
x5 <- as.numeric(sample(1:3, n, replace = TRUE))
X <- cbind(x1, x2, x3, x4, x5)
p <- ncol(X)

# Demo 3 specifies the "linear" prognostic function with the X1 * X3
# interaction, which is mu1 (mu2 is the absolute-value model), together
# with the constant treatment effect tau1
mu_x <- mu1(X)
tau_x <- tau1(X)

# Propensity score (always strictly between 0.05 and 0.95), treatment
# assignment and outcome; the noise sd is sd(E_XZ) / snr
pi_x <- 0.8 * pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) +
  0.05 + runif(n) / 10
Z <- rbinom(n, 1, pi_x)
E_XZ <- mu_x + Z * tau_x
y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)

# Covariates as a data frame, with x4 and x5 stored as ordered factors
X <- as.data.frame(X)
X$x4 <- factor(X$x4, ordered = TRUE)
X$x5 <- factor(X$x5, ordered = TRUE)

# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
# seq_len() stays well-defined even if n were 0, unlike 1:n
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]

Sampling and Analysis

Warmstart

We first simulate from an ensemble model of $y \mid X$ using “warm-start” initialization samples (Krantsevich, He, and Hahn (2023)). This is the default in stochtree.

# Sampler settings: 10 warm-start (num_gfr) iterations, no burn-in,
# 100 MCMC iterations
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
num_samples <- num_gfr + num_burnin + num_mcmc
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables
bcf_model_warmstart <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  pi_train = pi_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE
)

Inspect the BART samples that were initialized with an XBART warm-start

# Row-wise average of the mu_hat_test draws against the simulated
# prognostic values
plot(
  x = rowMeans(bcf_model_warmstart$mu_hat_test), y = mu_test,
  main = "Prognostic function", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the tau_hat_test draws against the simulated
# treatment effects
plot(
  x = rowMeans(bcf_model_warmstart$tau_hat_test), y = tau_test,
  main = "Treatment effect", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Sampled global variance, with the empirical variance of the simulated
# noise (y - E_XZ) drawn as a reference line
sigma_observed <- var(y - E_XZ)
plot_bounds <- range(c(bcf_model_warmstart$sigma2_samples, sigma_observed))
plot(
  bcf_model_warmstart$sigma2_samples,
  ylim = plot_bounds,
  main = "Global variance parameter", xlab = "Sample", ylab = "sigma^2"
)
abline(h = sigma_observed, col = "blue", lwd = 3, lty = 3)

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles of each row of tau_hat_test
test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.975)
# Share of test observations whose true effect lies inside its interval
cover <- test_lb <= tau_x[test_inds] & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 1
BART MCMC without Warmstart

Next, we simulate from this ensemble model without any warm-start initialization.

# Sampler settings: no warm-start (num_gfr) iterations, 100 burn-in
# iterations, 100 MCMC iterations
num_gfr <- 0
num_burnin <- 100
num_mcmc <- 100
num_samples <- num_gfr + num_burnin + num_mcmc
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables
bcf_model_root <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  pi_train = pi_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE
)

Inspect the BART samples after burnin

# Row-wise average of the mu_hat_test draws against the simulated
# prognostic values
plot(
  x = rowMeans(bcf_model_root$mu_hat_test), y = mu_test,
  main = "Prognostic function", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the tau_hat_test draws against the simulated
# treatment effects
plot(
  x = rowMeans(bcf_model_root$tau_hat_test), y = tau_test,
  main = "Treatment effect", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Sampled global variance, with the empirical variance of the simulated
# noise (y - E_XZ) drawn as a reference line
sigma_observed <- var(y - E_XZ)
plot_bounds <- range(c(bcf_model_root$sigma2_samples, sigma_observed))
plot(
  bcf_model_root$sigma2_samples,
  ylim = plot_bounds,
  main = "Global variance parameter", xlab = "Sample", ylab = "sigma^2"
)
abline(h = sigma_observed, col = "blue", lwd = 3, lty = 3)

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles of each row of tau_hat_test
test_lb <- apply(bcf_model_root$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_root$tau_hat_test, 1, quantile, probs = 0.975)
# Share of test observations whose true effect lies inside its interval
cover <- test_lb <= tau_x[test_inds] & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 1

Demo 4: Nonlinear Outcome Model, Heterogeneous Treatment Effect

We consider the following data generating process:

\begin{equation*}
\begin{aligned}
y &= \mu(X) + \tau(X) Z + \epsilon\\
\epsilon &\sim N\left(0,\sigma^2\right)\\
\mu(X) &= \begin{cases} -1.1 & \text{if } X_1 > X_2\\ 0.9 & \text{if } X_1 \leq X_2 \end{cases}\\
\tau(X) &= \frac{1}{1+\exp(-X_3)} + \frac{X_2}{10}\\
\pi(X) &= \Phi\left(\mu(X)\right)\\
Z &\sim \text{Bernoulli}\left(\pi(X)\right)\\
X_1,X_2,X_3 &\sim N\left(0,1\right)\\
X_4 &\sim N\left(X_2,1\right)
\end{aligned}
\end{equation*}

Simulation

We draw from the DGP defined above

# Simulate covariates: three standard Gaussians and x4 centered on x2
n <- 1000
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)
x4 <- rnorm(n, x2, 1)
X <- cbind(x1, x2, x3, x4)
p <- ncol(X)

# Prognostic function: -1.1 where x1 > x2 and 0.9 where x1 < x2
mu <- function(x) {
  -1 * (x[, 1] > x[, 2]) + 1 * (x[, 1] < x[, 2]) - 0.1
}
# Treatment effect: logistic in x3 plus a small linear term in x2
tau <- function(x) {
  1 / (1 + exp(-x[, 3])) + x[, 2] / 10
}
mu_x <- mu(X)
tau_x <- tau(X)
pi_x <- pnorm(mu_x)
Z <- rbinom(n, 1, pi_x)
E_XZ <- mu_x + Z * tau_x
# Noise scale: one eighth of the range of mu(X) + tau(X) * pi(X). The
# propensity vector pi_x is required here; a bare `pi` would silently pick up
# R's built-in constant 3.14159...
sigma <- diff(range(mu_x + tau_x * pi_x)) / 8
y <- E_XZ + sigma * rnorm(n)
X <- as.data.frame(X)

# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
# seq_len() stays well-defined even if n were 0, unlike 1:n
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]

Sampling and Analysis

Warmstart

We first simulate from an ensemble model of $y \mid X$ using “warm-start” initialization samples (Krantsevich, He, and Hahn (2023)). This is the default in stochtree.

# Sampler settings: 10 warm-start (num_gfr) iterations, no burn-in,
# 100 MCMC iterations
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
num_samples <- num_gfr + num_burnin + num_mcmc
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables
bcf_model_warmstart <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  pi_train = pi_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE
)

Inspect the BART samples that were initialized with an XBART warm-start

# Row-wise average of the mu_hat_test draws against the simulated
# prognostic values
plot(
  x = rowMeans(bcf_model_warmstart$mu_hat_test), y = mu_test,
  main = "Prognostic function", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the tau_hat_test draws against the simulated
# treatment effects
plot(
  x = rowMeans(bcf_model_warmstart$tau_hat_test), y = tau_test,
  main = "Treatment effect", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Sampled global variance, with the empirical variance of the simulated
# noise (y - E_XZ) drawn as a reference line
sigma_observed <- var(y - E_XZ)
plot_bounds <- range(c(bcf_model_warmstart$sigma2_samples, sigma_observed))
plot(
  bcf_model_warmstart$sigma2_samples,
  ylim = plot_bounds,
  main = "Global variance parameter", xlab = "Sample", ylab = "sigma^2"
)
abline(h = sigma_observed, col = "blue", lwd = 3, lty = 3)

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles of each row of tau_hat_test
test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.975)
# Share of test observations whose true effect lies inside its interval
cover <- test_lb <= tau_x[test_inds] & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 0.97
BART MCMC without Warmstart

Next, we simulate from this ensemble model without any warm-start initialization.

# Sampler settings: no warm-start (num_gfr) iterations, 100 burn-in
# iterations, 100 MCMC iterations
num_gfr <- 0
num_burnin <- 100
num_mcmc <- 100
num_samples <- num_gfr + num_burnin + num_mcmc
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables
bcf_model_root <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  pi_train = pi_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE
)

Inspect the BART samples after burnin

# Row-wise average of the mu_hat_test draws against the simulated
# prognostic values
plot(
  x = rowMeans(bcf_model_root$mu_hat_test), y = mu_test,
  main = "Prognostic function", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the tau_hat_test draws against the simulated
# treatment effects
plot(
  x = rowMeans(bcf_model_root$tau_hat_test), y = tau_test,
  main = "Treatment effect", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Sampled global variance, with the empirical variance of the simulated
# noise (y - E_XZ) drawn as a reference line
sigma_observed <- var(y - E_XZ)
plot_bounds <- range(c(bcf_model_root$sigma2_samples, sigma_observed))
plot(
  bcf_model_root$sigma2_samples,
  ylim = plot_bounds,
  main = "Global variance parameter", xlab = "Sample", ylab = "sigma^2"
)
abline(h = sigma_observed, col = "blue", lwd = 3, lty = 3)

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles of each row of tau_hat_test
test_lb <- apply(bcf_model_root$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_root$tau_hat_test, 1, quantile, probs = 0.975)
# Share of test observations whose true effect lies inside its interval
cover <- test_lb <= tau_x[test_inds] & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 0.97

Demo 5: Nonlinear Outcome Model, Heterogeneous Treatment Effect with Additive Random Effects

We augment the simulated example in Demo 1 with an additive random effect structure and show that the bcf() function can estimate and incorporate these effects into its forest sampling procedure.

Simulation

We draw from the augmented “demo 1” DGP

# Simulate covariates: three Gaussian, one binary, one three-level categorical
n <- 500
snr <- 3
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)
x4 <- as.numeric(rbinom(n, 1, 0.5))
x5 <- as.numeric(sample(1:3, n, replace = TRUE))
X <- cbind(x1, x2, x3, x4, x5)
p <- ncol(X)

# This demo augments Demo 1, whose stated DGP uses the nonlinear prognostic
# term 6|X3 - 1|, i.e. mu2 (mu1 is the X1 * X3 interaction model)
mu_x <- mu2(X)
tau_x <- tau2(X)

# Propensity score (always strictly between 0.05 and 0.95) and treatment
pi_x <- 0.8 * pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) +
  0.05 + runif(n) / 10
Z <- rbinom(n, 1, pi_x)
E_XZ <- mu_x + Z * tau_x

# Two alternating groups with group-specific intercept and slope;
# length.out = n keeps group_ids the same length as the data for odd n too
group_ids <- rep(c(1, 2), length.out = n)
rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow = 2, byrow = TRUE)
rfx_basis <- cbind(1, runif(n, -1, 1))
rfx_term <- rowSums(rfx_coefs[group_ids, ] * rfx_basis)
y <- E_XZ + rfx_term + rnorm(n, 0, 1) * (sd(E_XZ) / snr)

# Covariates as a data frame, with x4 and x5 stored as ordered factors
X <- as.data.frame(X)
X$x4 <- factor(X$x4, ordered = TRUE)
X$x5 <- factor(X$x5, ordered = TRUE)

# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
# seq_len() stays well-defined even if n were 0, unlike 1:n
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]
group_ids_test <- group_ids[test_inds]
group_ids_train <- group_ids[train_inds]
# drop = FALSE keeps the basis a matrix even when a split has a single row
rfx_basis_test <- rfx_basis[test_inds, , drop = FALSE]
rfx_basis_train <- rfx_basis[train_inds, , drop = FALSE]
rfx_term_test <- rfx_term[test_inds]
rfx_term_train <- rfx_term[train_inds]

Sampling and Analysis

Warmstart

Here we simulate only from the “warm-start” model (running root-MCMC BART with random effects is simply a matter of modifying the below code snippet by setting num_gfr <- 0 and num_mcmc > 0).

# Sampler settings: 100 warm-start (num_gfr) iterations, no burn-in,
# 500 MCMC iterations
num_gfr <- 100
num_burnin <- 0
num_mcmc <- 500
num_samples <- num_gfr + num_burnin + num_mcmc
# Group ids and the random effects basis are passed for both train and test.
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables
bcf_model_warmstart <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  pi_train = pi_train,
  group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE
)

Inspect the BART samples that were initialized with an XBART warm-start

# Row-wise average of the mu_hat_test draws against the simulated
# prognostic values
plot(
  x = rowMeans(bcf_model_warmstart$mu_hat_test), y = mu_test,
  main = "Prognostic function", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the tau_hat_test draws against the simulated
# treatment effects
plot(
  x = rowMeans(bcf_model_warmstart$tau_hat_test), y = tau_test,
  main = "Treatment effect", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the y_hat_test draws against the held-out outcomes
plot(
  x = rowMeans(bcf_model_warmstart$y_hat_test), y = y_test,
  main = "Outcome", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the rfx_preds_test draws against the simulated
# random effects terms
plot(
  x = rowMeans(bcf_model_warmstart$rfx_preds_test), y = rfx_term_test,
  main = "Random effects terms", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Sampled global variance, with the empirical variance of the simulated
# noise (y - E_XZ - rfx_term) drawn as a reference line
sigma_observed <- var(y - E_XZ - rfx_term)
plot_bounds <- range(c(bcf_model_warmstart$sigma2_samples, sigma_observed))
plot(
  bcf_model_warmstart$sigma2_samples,
  ylim = plot_bounds,
  main = "Global variance parameter", xlab = "Sample", ylab = "sigma^2"
)
abline(h = sigma_observed, col = "blue", lwd = 3, lty = 3)

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles of each row of tau_hat_test
test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.975)
# Share of test observations whose true effect lies inside its interval
cover <- test_lb <= tau_x[test_inds] & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 0.96

It is clear that causal inference is much more difficult in the presence of both strong covariate-dependent prognostic effects and strong group-level random effects. In this sense, proper prior calibration for all three of the $\mu$, $\tau$ and random effects models is crucial.

Demo 6: Nonlinear Outcome Model, Heterogeneous Treatment Effect using Different Features in the Prognostic and Treatment Forests

Here, we consider the case in which we might prefer to use only a subset of covariates in the treatment effect forest. Why might we want to do that? Well, in many cases it is plausible that some covariates (for example age, income, etc…) influence the outcome of interest in a causal problem, but do not moderate the treatment effect. In this case, we’d need to include these variables in the prognostic forest for deconfounding but we don’t necessarily need to include them in the treatment effect forest.

Simulation

We draw from a modified “demo 1” DGP

# Prognostic function: g(x) plus an x1-by-x3 interaction and linear terms
# in x2 and x3
mu <- function(x) {
  1 + g(x) + x[, 1] * x[, 3] - x[, 2] + 3 * x[, 3]
}
# Treatment effect: depends on x1 and x2 only
tau <- function(x) {
  1 - 2 * x[, 1] + 2 * x[, 2] + 1 * x[, 1] * x[, 2]
}

# Simulate ten covariates; x6 through x10 enter neither mu nor tau
n <- 500
snr <- 4
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)
x4 <- as.numeric(rbinom(n, 1, 0.5))
x5 <- as.numeric(sample(1:3, n, replace = TRUE))
x6 <- rnorm(n)
x7 <- rnorm(n)
x8 <- rnorm(n)
x9 <- rnorm(n)
x10 <- rnorm(n)
X <- cbind(x1, x2, x3, x4, x5, x6, x7, x8, x9, x10)
p <- ncol(X)
mu_x <- mu(X)
tau_x <- tau(X)

# Propensity score (always strictly between 0.05 and 0.95), treatment
# assignment and outcome; the noise sd is sd(E_XZ) / snr
pi_x <- 0.8 * pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) +
  0.05 + runif(n) / 10
Z <- rbinom(n, 1, pi_x)
E_XZ <- mu_x + Z * tau_x
y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)

# Covariates as a data frame, with x4 and x5 stored as ordered factors
X <- as.data.frame(X)
X$x4 <- factor(X$x4, ordered = TRUE)
X$x5 <- factor(X$x5, ordered = TRUE)

# Split data into test and train sets
test_set_pct <- 0.5
n_test <- round(test_set_pct * n)
n_train <- n - n_test
# seq_len() stays well-defined even if n were 0, unlike 1:n
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]

Sampling and Analysis

MCMC, full covariate set in $\tau(X)$

Here we simulate from the model with the original MCMC sampler, using all of the covariates in both the prognostic ($\mu(X)$) and treatment effect ($\tau(X)$) forests.

# Sampler settings: no warm-start (num_gfr) iterations, 1000 burn-in
# iterations, 1000 MCMC iterations
num_gfr <- 0
num_burnin <- 1000
num_mcmc <- 1000
num_samples <- num_gfr + num_burnin + num_mcmc
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables
bcf_model_mcmc <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  pi_train = pi_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE
)

Inspect the burned-in samples

# Row-wise average of the mu_hat_test draws against the simulated
# prognostic values
plot(
  x = rowMeans(bcf_model_mcmc$mu_hat_test), y = mu_test,
  main = "Prognostic function", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the tau_hat_test draws against the simulated
# treatment effects
plot(
  x = rowMeans(bcf_model_mcmc$tau_hat_test), y = tau_test,
  main = "Treatment effect", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the y_hat_test draws against the held-out outcomes
plot(
  x = rowMeans(bcf_model_mcmc$y_hat_test), y = y_test,
  main = "Outcome", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Sampled global variance, with the empirical variance of the simulated
# noise (y - E_XZ) drawn as a reference line
sigma_observed <- var(y - E_XZ)
plot_bounds <- range(c(bcf_model_mcmc$sigma2_samples, sigma_observed))
plot(
  bcf_model_mcmc$sigma2_samples,
  ylim = plot_bounds,
  main = "Global variance parameter", xlab = "Sample", ylab = "sigma^2"
)
abline(h = sigma_observed, col = "blue", lwd = 3, lty = 3)

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles of each row of tau_hat_test
test_lb <- apply(bcf_model_mcmc$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_mcmc$tau_hat_test, 1, quantile, probs = 0.975)
# Share of test observations whose true effect lies inside its interval
cover <- test_lb <= tau_x[test_inds] & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 0.792

And test set RMSE

# RMSE of the row-averaged treatment effect draws on the test set
test_tau_mean <- rowMeans(bcf_model_mcmc$tau_hat_test)
sqrt(mean((tau_test - test_tau_mean)^2))
#> [1] 1.512766
# RMSE of the row-averaged outcome draws on the test set
test_outcome_mean <- rowMeans(bcf_model_mcmc$y_hat_test)
sqrt(mean((y_test - test_outcome_mean)^2))
#> [1] 1.912132
MCMC, covariate subset in $\tau(X)$

Here we simulate from the model with the original MCMC sampler, using only covariates $X_1$ and $X_2$ in the treatment effect forest.

# Sampler settings: no warm-start (num_gfr) iterations, 1000 burn-in
# iterations, 1000 MCMC iterations
num_gfr <- 0
num_burnin <- 1000
num_mcmc <- 1000
num_samples <- num_gfr + num_burnin + num_mcmc
# keep_vars_tau restricts the treatment effect forest to x1 and x2.
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables
bcf_model_mcmc <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  pi_train = pi_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE,
  keep_vars_tau = c("x1", "x2")
)

Inspect the BART samples

# Row-wise average of the mu_hat_test draws against the simulated
# prognostic values
plot(
  x = rowMeans(bcf_model_mcmc$mu_hat_test), y = mu_test,
  main = "Prognostic function", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the tau_hat_test draws against the simulated
# treatment effects
plot(
  x = rowMeans(bcf_model_mcmc$tau_hat_test), y = tau_test,
  main = "Treatment effect", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the y_hat_test draws against the held-out outcomes
plot(
  x = rowMeans(bcf_model_mcmc$y_hat_test), y = y_test,
  main = "Outcome", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Sampled global variance, with the empirical variance of the simulated
# noise (y - E_XZ) drawn as a reference line
sigma_observed <- var(y - E_XZ)
plot_bounds <- range(c(bcf_model_mcmc$sigma2_samples, sigma_observed))
plot(
  bcf_model_mcmc$sigma2_samples,
  ylim = plot_bounds,
  main = "Global variance parameter", xlab = "Sample", ylab = "sigma^2"
)
abline(h = sigma_observed, col = "blue", lwd = 3, lty = 3)

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles of each row of tau_hat_test
test_lb <- apply(bcf_model_mcmc$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_mcmc$tau_hat_test, 1, quantile, probs = 0.975)
# Share of test observations whose true effect lies inside its interval
cover <- test_lb <= tau_x[test_inds] & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 0.836

And test set RMSE

# RMSE of the row-averaged treatment effect draws on the test set
test_mean <- rowMeans(bcf_model_mcmc$tau_hat_test)
sqrt(mean((tau_test - test_mean)^2))
#> [1] 1.330455
# RMSE of the row-averaged outcome draws on the test set
test_outcome_mean <- rowMeans(bcf_model_mcmc$y_hat_test)
sqrt(mean((y_test - test_outcome_mean)^2))
#> [1] 1.864776
Warmstart, full covariate set in $\tau(X)$

Here we simulate from the model with the warm-start sampler, using all of the covariates in both the prognostic ($\mu(X)$) and treatment effect ($\tau(X)$) forests.

# Sampler settings: 10 warm-start (num_gfr) iterations, no burn-in,
# 1000 MCMC iterations
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
num_samples <- num_gfr + num_burnin + num_mcmc
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables
bcf_model_warmstart <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  pi_train = pi_train,
  X_test = X_test, Z_test = Z_test, pi_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE
)

Inspect the BART samples that were initialized with an XBART warm-start

# Row-wise average of the mu_hat_test draws against the simulated
# prognostic values
plot(
  x = rowMeans(bcf_model_warmstart$mu_hat_test), y = mu_test,
  main = "Prognostic function", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the tau_hat_test draws against the simulated
# treatment effects
plot(
  x = rowMeans(bcf_model_warmstart$tau_hat_test), y = tau_test,
  main = "Treatment effect", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Row-wise average of the y_hat_test draws against the held-out outcomes
plot(
  x = rowMeans(bcf_model_warmstart$y_hat_test), y = y_test,
  main = "Outcome", xlab = "predicted", ylab = "actual"
)
abline(a = 0, b = 1, lwd = 3, lty = 3, col = "red")

# Sampled global variance, with the empirical variance of the simulated
# noise (y - E_XZ) drawn as a reference line
sigma_observed <- var(y - E_XZ)
plot_bounds <- range(c(bcf_model_warmstart$sigma2_samples, sigma_observed))
plot(
  bcf_model_warmstart$sigma2_samples,
  ylim = plot_bounds,
  main = "Global variance parameter", xlab = "Sample", ylab = "sigma^2"
)
abline(h = sigma_observed, col = "blue", lwd = 3, lty = 3)

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles taken across each row of the
# sampled test-set treatment effects
test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.975)
# Flag test observations whose simulated effect lies inside its interval,
# then report the covered share
cover <- tau_x[test_inds] >= test_lb & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 0.884

And test set RMSE

# RMSE of the row-wise mean sampled treatment effect against tau_test
test_tau_mean <- rowMeans(bcf_model_warmstart$tau_hat_test)
tau_err <- tau_test - test_tau_mean
sqrt(mean(tau_err^2))
#> [1] 1.478661
# RMSE of the row-wise mean sampled outcome prediction against y_test
test_outcome_mean <- rowMeans(bcf_model_warmstart$y_hat_test)
y_err <- y_test - test_outcome_mean
sqrt(mean(y_err^2))
#> [1] 1.866693
Warmstart, covariate subset in $\tau(X)$

Here we simulate from the model with the warm-start sampler, using only covariates $X_1$ and $X_2$ in the treatment effect forest.

# Sampler configuration: 10 warm-start iterations, no burn-in, 1000 MCMC
# iterations
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
num_samples <- num_gfr + num_burnin + num_mcmc
# Fit BCF on the training split and predict on the test split, keeping only
# x1 and x2 for the treatment effect forest (keep_vars_tau). Leaf scale
# sampling is switched off for both forests; FALSE is spelled out because `F`
# is an ordinary variable that can be reassigned, whereas `FALSE` is a
# reserved word.
bcf_model_warmstart <- bcf(
    X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
    X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
    num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
    sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE, 
    keep_vars_tau = c("x1", "x2")
)

Inspect the BART samples that were initialized with an XBART warm-start

# Each scatter compares the row-wise mean of the sampled test-set predictions
# with the corresponding reference values; the dotted red line is the
# identity line
plot(
    x = rowMeans(bcf_model_warmstart$mu_hat_test), y = mu_test,
    xlab = "predicted", ylab = "actual", main = "Prognostic function"
)
abline(a = 0, b = 1, col = "red", lty = 3, lwd = 3)

plot(
    x = rowMeans(bcf_model_warmstart$tau_hat_test), y = tau_test,
    xlab = "predicted", ylab = "actual", main = "Treatment effect"
)
abline(a = 0, b = 1, col = "red", lty = 3, lwd = 3)

plot(
    x = rowMeans(bcf_model_warmstart$y_hat_test), y = y_test,
    xlab = "predicted", ylab = "actual", main = "Outcome"
)
abline(a = 0, b = 1, col = "red", lty = 3, lwd = 3)

# Trace of the sampled global variance; the dotted blue line marks the
# empirical variance of the simulated noise (y - E_XZ)
sigma_observed <- var(y - E_XZ)
# Widen the y-axis so the reference line is always visible
plot_bounds <- range(c(bcf_model_warmstart$sigma2_samples, sigma_observed))
plot(
    bcf_model_warmstart$sigma2_samples,
    ylim = plot_bounds,
    ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter"
)
abline(h = sigma_observed, lty = 3, lwd = 3, col = "blue")

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles taken across each row of the
# sampled test-set treatment effects
test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.975)
# Flag test observations whose simulated effect lies inside its interval,
# then report the covered share
cover <- tau_x[test_inds] >= test_lb & tau_x[test_inds] <= test_ub
mean(cover)
#> [1] 0.864

And test set RMSE

# RMSE of the row-wise mean sampled treatment effect against tau_test
test_tau_mean <- rowMeans(bcf_model_warmstart$tau_hat_test)
tau_err <- tau_test - test_tau_mean
sqrt(mean(tau_err^2))
#> [1] 1.428011
# RMSE of the row-wise mean sampled outcome prediction against y_test
test_outcome_mean <- rowMeans(bcf_model_warmstart$y_hat_test)
y_err <- y_test - test_outcome_mean
sqrt(mean(y_err^2))
#> [1] 1.917739

Continuous Treatment

Demo 1: Nonlinear Outcome Model, Heterogeneous Treatment Effect

We consider the following data generating process:

\begin{equation*}
\begin{aligned}
y &= \mu(X) + \tau(X) Z + \epsilon\\
\epsilon &\sim N\left(0,\sigma^2\right)\\
\mu(X) &= 1 + 2 X_1 - \mathbb{1}\left(X_2 < 0\right) \times 4 + \mathbb{1}\left(X_2 \geq 0\right) \times 4 + 3 \left(\lvert X_3 \rvert - \sqrt{\frac{2}{\pi}} \right)\\
\tau(X) &= 1 + 2 X_4\\
X_1,X_2,X_3,X_4,X_5 &\sim N\left(0,1\right)\\
U &\sim \text{Uniform}\left(0,1\right)\\
\pi(X) &= \frac{\mu(X) - 1}{4} + 4 \left(U - \frac{1}{2}\right)\\
Z &\sim \mathcal{N}\left(\pi(X), 1\right)
\end{aligned}
\end{equation*}

Simulation

We draw from the DGP defined above

# Sample size and signal-to-noise ratio of the simulated data
n <- 500
snr <- 3

# Five independent standard normal covariates (the draw order is part of the
# random number stream, so it is kept as x1, ..., x5)
x1 <- rnorm(n)
x2 <- rnorm(n)
x3 <- rnorm(n)
x4 <- rnorm(n)
x5 <- rnorm(n)
X <- cbind(x1, x2, x3, x4, x5)
p <- ncol(X)

# Prognostic function mu(X) and treatment effect function tau(X)
mu_x <- 1 + 2 * x1 - 4 * (x2 < 0) + 4 * (x2 >= 0) +
    3 * (abs(x3) - sqrt(2 / pi))
tau_x <- 1 + 2 * x4

# Continuous treatment: pi_x plus standard normal noise
u <- runif(n)
pi_x <- ((mu_x - 1) / 4) + 4 * (u - 0.5)
Z <- pi_x + rnorm(n, mean = 0, sd = 1)

# Outcome: conditional mean E_XZ plus noise scaled to sd(E_XZ) / snr
E_XZ <- mu_x + Z * tau_x
y <- E_XZ + rnorm(n, mean = 0, sd = 1) * (sd(E_XZ) / snr)
X <- as.data.frame(X)

# Split data into test and train sets: 20% of the rows, drawn without
# replacement, form the test set; the remaining rows form the training set
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample.int(n, size = n_test, replace = FALSE))
train_inds <- setdiff(seq_len(n), test_inds)

# Covariates
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
# pi_x values passed to bcf() as pi_test / pi_train
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
# Treatment
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
# Outcome
y_test <- y[test_inds]
y_train <- y[train_inds]
# Simulated prognostic and treatment effect values, kept for evaluation
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]

Sampling and Analysis

Warmstart

We first simulate from an ensemble model of $y \mid X$ using “warm-start” initialization samples (Krantsevich, He, and Hahn (2023)). This is the default in stochtree.

# Sampler configuration: 10 warm-start iterations, no burn-in, 1000 MCMC
# iterations
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 1000
num_samples <- num_gfr + num_burnin + num_mcmc
# Fit BCF on the continuous-treatment training split and predict on the test
# split. Leaf scale sampling is switched off for both forests; FALSE is
# spelled out because `F` is an ordinary variable that can be reassigned,
# whereas `FALSE` is a reserved word.
bcf_model_warmstart <- bcf(
    X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
    X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
    num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
    sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE
)

Inspect the BART samples that were initialized with an XBART warm-start

# Each scatter compares the row-wise mean of the sampled test-set predictions
# with the simulated values; the dotted red line is the identity line
plot(
    x = rowMeans(bcf_model_warmstart$mu_hat_test), y = mu_test,
    xlab = "predicted", ylab = "actual", main = "Prognostic function"
)
abline(a = 0, b = 1, col = "red", lty = 3, lwd = 3)

plot(
    x = rowMeans(bcf_model_warmstart$tau_hat_test), y = tau_test,
    xlab = "predicted", ylab = "actual", main = "Treatment effect"
)
abline(a = 0, b = 1, col = "red", lty = 3, lwd = 3)

# Trace of the sampled global variance; the dotted blue line marks the
# empirical variance of the simulated noise (y - E_XZ)
sigma_observed <- var(y - E_XZ)
# Widen the y-axis so the reference line is always visible
plot_bounds <- range(c(bcf_model_warmstart$sigma2_samples, sigma_observed))
plot(
    bcf_model_warmstart$sigma2_samples,
    ylim = plot_bounds,
    ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter"
)
abline(h = sigma_observed, lty = 3, lwd = 3, col = "blue")

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles taken across each row of the
# sampled test-set treatment effects
test_lb <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_warmstart$tau_hat_test, 1, quantile, probs = 0.975)
# tau_test holds tau_x[test_inds]; flag the test observations whose simulated
# effect lies inside its interval, then report the covered share
cover <- tau_test >= test_lb & tau_test <= test_ub
mean(cover)
#> [1] 1
BART MCMC without Warmstart

Next, we simulate from this ensemble model without any warm-start initialization.

# Sampler configuration: no warm-start iterations, 1000 burn-in iterations,
# 1000 MCMC iterations
num_gfr <- 0
num_burnin <- 1000
num_mcmc <- 1000
num_samples <- num_gfr + num_burnin + num_mcmc
# Fit BCF on the continuous-treatment training split and predict on the test
# split. Leaf scale sampling is switched off for both forests; FALSE is
# spelled out because `F` is an ordinary variable that can be reassigned,
# whereas `FALSE` is a reserved word.
bcf_model_root <- bcf(
    X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train, 
    X_test = X_test, Z_test = Z_test, pi_test = pi_test, 
    num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, 
    sample_sigma_leaf_mu = FALSE, sample_sigma_leaf_tau = FALSE
)

Inspect the BART samples after burnin

# Each scatter compares the row-wise mean of the sampled test-set predictions
# with the simulated values; the dotted red line is the identity line
plot(
    x = rowMeans(bcf_model_root$mu_hat_test), y = mu_test,
    xlab = "predicted", ylab = "actual", main = "Prognostic function"
)
abline(a = 0, b = 1, col = "red", lty = 3, lwd = 3)

plot(
    x = rowMeans(bcf_model_root$tau_hat_test), y = tau_test,
    xlab = "predicted", ylab = "actual", main = "Treatment effect"
)
abline(a = 0, b = 1, col = "red", lty = 3, lwd = 3)

# Trace of the sampled global variance; the dotted blue line marks the
# empirical variance of the simulated noise (y - E_XZ)
sigma_observed <- var(y - E_XZ)
# Widen the y-axis so the reference line is always visible
plot_bounds <- range(c(bcf_model_root$sigma2_samples, sigma_observed))
plot(
    bcf_model_root$sigma2_samples,
    ylim = plot_bounds,
    ylab = "sigma^2", xlab = "Sample", main = "Global variance parameter"
)
abline(h = sigma_observed, lty = 3, lwd = 3, col = "blue")

Examine test set interval coverage

# Interval bounds: 2.5% and 97.5% quantiles taken across each row of the
# sampled test-set treatment effects
test_lb <- apply(bcf_model_root$tau_hat_test, 1, quantile, probs = 0.025)
test_ub <- apply(bcf_model_root$tau_hat_test, 1, quantile, probs = 0.975)
# tau_test holds tau_x[test_inds]; flag the test observations whose simulated
# effect lies inside its interval, then report the covered share
cover <- tau_test >= test_lb & tau_test <= test_ub
mean(cover)
#> [1] 0.96

References

Hahn, P Richard, Jared S Murray, and Carlos M Carvalho. 2020. “Bayesian Regression Tree Models for Causal Inference: Regularization, Confounding, and Heterogeneous Effects (with Discussion).” Bayesian Analysis 15 (3): 965–1056.
Krantsevich, Nikolay, Jingyu He, and P Richard Hahn. 2023. “Stochastic Tree Ensembles for Estimating Heterogeneous Effects.” In International Conference on Artificial Intelligence and Statistics, 6120–31. PMLR.