Lab: Causal Forests for Heterogeneous Treatment Effects
Estimate individualized treatment effects via the Wager-Athey causal forest: move beyond ATE to identify who benefits most and design targeting policies.
- Languages: Python, R, Stata
- Dataset: Simulated RCT with heterogeneous treatment effects
Overview
In this lab you will analyze a simulated randomized controlled trial where the treatment effect varies substantially across individuals. Rather than estimating a single average treatment effect (ATE), you will use causal forests to estimate the conditional average treatment effect (CATE) as a function of covariates, identify which variables drive treatment effect heterogeneity, and design an optimal treatment targeting policy.
What you will learn:
- How to estimate heterogeneous treatment effects with causal forests
- How to assess variable importance for treatment effect heterogeneity
- How to evaluate CATE estimates using calibration tests and RATE curves
- How to design and evaluate optimal targeting policies
- How to compare causal forests with simple subgroup analysis
Prerequisites: Familiarity with random forests and basic causal inference. Completion of the OLS and DML tutorial labs is recommended.
Step 1: Simulate an RCT with Heterogeneous Effects
We create an experiment in which the treatment effect depends on baseline risk, age, and education.
```r
# First-time setup: install.packages("grf")
library(grf)
set.seed(42)
n <- 4000
# Generate covariates: demographics and health characteristics
age <- runif(n, 25, 65)
income <- rlnorm(n, 10.5, 0.6)
educ <- pmin(pmax(rnorm(n, 14, 3), 8), 22) # Clip education to [8, 22]
health_score <- pmin(pmax(rnorm(n, 50, 15), 0), 100) # Clip health to [0, 100]
female <- rbinom(n, 1, 0.5)
risk <- rnorm(n) # Baseline risk factor
# Treatment assignment: randomized (RCT with 50% probability)
W <- rbinom(n, 1, 0.5)
# True CATE: effect varies by risk, age, and education
tau_true <- 4.5 + 3 * (risk > 0) - 0.1 * (age - 40) + 2 * (educ > 16)
# Baseline potential outcome under control
mu0 <- 50 + 0.5 * age + 0.3 * health_score - 2 * risk + rnorm(n, 0, 5)
# Observed outcome: treated units receive their individual effect
Y <- mu0 + W * tau_true
# Covariate matrix for the causal forest
X <- data.frame(age, income, educ, health_score, female, risk)
cat("True ATE:", mean(tau_true), "\n")
cat("True CATE range:", range(tau_true), "\n")
cat("Treatment rate:", mean(W), "\n")
```
Expected output:
| Variable | Mean | Std Dev | Min | Max |
|---|---|---|---|---|
| Y | ~90.5 | ~9.2 | ~59 | ~120 |
| W | 0.50 | 0.50 | 0 | 1 |
| age | 45.0 | 11.5 | 25.0 | 65.0 |
| income | 44,300 | 30,100 | 5,800 | 245,000 |
| educ | 14.0 | ~2.9 | 8.0 | 22.0 |
| health_score | 50.1 | ~15.0 | 0.0 | 100.0 |
True ATE: ~6.00
True CATE range: [~2.0, ~11.0]
Treatment rate: ~50%
Covariates that drive heterogeneity: age, risk, educ
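Because the DGP is fully specified, the true ATE and the CATE range can be checked by hand. A minimal sketch in base R, using only the distributions from the simulation code (risk is standard normal, age is uniform on [25, 65], education is normal with mean 14 and SD 3 before clipping):

```r
# Closed-form check of the Step 1 DGP; no simulation needed.
# tau = 4.5 + 3 * (risk > 0) - 0.1 * (age - 40) + 2 * (educ > 16)
p_risk <- 0.5                               # P(risk > 0) for risk ~ N(0, 1)
p_educ <- 1 - pnorm(16, mean = 14, sd = 3)  # P(educ > 16); clipping to [8, 22] leaves this unchanged
mean_age <- (25 + 65) / 2                   # E[age] for age ~ U(25, 65)
true_ate <- 4.5 + 3 * p_risk - 0.1 * (mean_age - 40) + 2 * p_educ
# Smallest effect: risk <= 0, age = 65, educ <= 16; largest: risk > 0, age = 25, educ > 16
cate_min <- 4.5 - 0.1 * (65 - 40)
cate_max <- 4.5 + 3 - 0.1 * (25 - 40) + 2
cat("True ATE:", round(true_ate, 3), "\n")
cat("True CATE range:", cate_min, cate_max, "\n")
```

The ATE works out to about 6.005. The endpoints 2 and 11 are only approached in a finite sample because age is continuous, so the simulated range sits just inside [2, 11].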
Step 2: Estimate the Average Treatment Effect
First, confirm the overall ATE before looking at heterogeneity.
```r
# Simple difference in means (unbiased in an RCT)
ate_simple <- mean(Y[W == 1]) - mean(Y[W == 0])
se_simple <- sqrt(var(Y[W == 1]) / sum(W == 1) + var(Y[W == 0]) / sum(W == 0))
cat("Difference in means:", ate_simple, "(SE:", se_simple, ")\n")
# Covariate-adjusted OLS (more precise due to reduced residual variance)
ols <- lm(Y ~ W + age + income + educ + health_score + female + risk)
cat("OLS-adjusted ATE:", coef(ols)["W"], "(SE:", summary(ols)$coefficients["W", "Std. Error"], ")\n")
cat("True ATE:", mean(tau_true), "\n")
```
Expected output:
| Estimator | ATE Estimate | SE | True ATE |
|---|---|---|---|
| Difference in means | ~6.00 | ~0.28 | ~6.00 |
| OLS-adjusted | ~6.00 | ~0.16 | ~6.00 |
Difference in means: ~6.00 (SE: ~0.28)
OLS-adjusted ATE: ~6.00 (SE: ~0.16)
True ATE: ~6.00
The covariate-adjusted estimator has a much smaller standard error because adjusting for strong predictors of the outcome (age, health_score, risk) reduces residual variance.
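The standard-error comparison can also be checked by repeated sampling. The sketch below is a small Monte Carlo exercise, not part of the lab's required code: it re-draws the Step 1 experiment 200 times (the replication count and the helper name `simulate_once` are arbitrary choices) and compares how much each estimator varies across draws.

```r
# Monte Carlo sketch: re-draw the Step 1 experiment and compare how much each ATE estimator varies
simulate_once <- function(n = 4000) {
  age <- runif(n, 25, 65)
  income <- rlnorm(n, 10.5, 0.6)
  educ <- pmin(pmax(rnorm(n, 14, 3), 8), 22)
  health_score <- pmin(pmax(rnorm(n, 50, 15), 0), 100)
  female <- rbinom(n, 1, 0.5)
  risk <- rnorm(n)
  W <- rbinom(n, 1, 0.5)
  tau <- 4.5 + 3 * (risk > 0) - 0.1 * (age - 40) + 2 * (educ > 16)
  Y <- 50 + 0.5 * age + 0.3 * health_score - 2 * risk + rnorm(n, 0, 5) + W * tau
  dim_est <- mean(Y[W == 1]) - mean(Y[W == 0])
  ols_est <- coef(lm(Y ~ W + age + income + educ + health_score + female + risk))[["W"]]
  c(dim = dim_est, ols = ols_est)
}
set.seed(1)
draws <- replicate(200, simulate_once())   # 2 x 200 matrix: one column per replication
cat("Mean estimate, difference in means:", mean(draws["dim", ]), "\n")
cat("Mean estimate, OLS-adjusted:", mean(draws["ols", ]), "\n")
cat("Sampling SD, difference in means:", sd(draws["dim", ]), "\n")
cat("Sampling SD, OLS-adjusted:", sd(draws["ols", ]), "\n")
```

Both estimators should center on the true ATE of about 6.0. For this DGP the adjusted estimator's sampling SD should come out near 0.16 and the difference in means near 0.28, though the exact values depend on the seed.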
Step 3: Estimate CATEs with a Causal Forest
```r
# Fit causal forest using the grf package
X_mat <- as.matrix(X)
cf <- causal_forest(X_mat, Y, W,
                    num.trees = 2000,   # Number of trees in the forest
                    min.node.size = 5,  # Target minimum number of observations per leaf
                    seed = 42)
# Extract out-of-bag CATE predictions for each observation
tau_hat <- predict(cf)$predictions
# Evaluate accuracy by comparing estimated vs. true CATEs (only possible in simulation)
cat("Correlation:", cor(tau_true, tau_hat), "\n")
cat("RMSE:", sqrt(mean((tau_true - tau_hat)^2)), "\n")
cat("Estimated ATE (mean of CATEs):", mean(tau_hat), "\n")
cat("True ATE:", mean(tau_true), "\n")
# Forest-based ATE with a valid standard error (doubly robust AIPW estimate)
ate_cf <- average_treatment_effect(cf)
cat("Forest ATE:", ate_cf[["estimate"]], " SE:", ate_cf[["std.err"]], "\n")
```
Expected output:
Correlation(true CATE, estimated CATE): ~0.85
RMSE: ~1.2
Estimated ATE (mean of CATEs): ~6.00
True ATE: ~6.00
| Summary | Estimated CATE | True CATE |
|---|---|---|
| Mean | ~6.00 | ~6.00 |
| Std Dev | ~2.5 | ~2.1 |
| Min | ~1.5 | ~2.0 |
| Max | ~10.5 | ~11.0 |
The correlation of ~0.85 indicates that the causal forest ranks individuals well by how much they benefit from treatment, even though individual-level estimates are noisy (an RMSE of ~1.2 against a true CATE standard deviation of about 2.1).
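grf builds each tree honestly: by default, one part of a tree's subsample chooses the splits and a disjoint part estimates the effects in the leaves. The sketch below illustrates only that idea with a single split (a "stump") in base R. It is not grf's algorithm: the splitting criterion (a size-weighted squared gap between leaf effect estimates), the simplified DGP without education or outcome covariates, and the helper names `diff_means` and `split_gain` are all assumptions made for illustration.

```r
# Minimal sketch of honest estimation with a single split (a causal "stump").
# NOT grf's algorithm: grf uses a gradient-based splitting rule and averages many trees.
set.seed(7)
n <- 4000
risk <- rnorm(n)
age <- runif(n, 25, 65)
W <- rbinom(n, 1, 0.5)
tau <- 4.5 + 3 * (risk > 0) - 0.1 * (age - 40)   # simplified DGP: education term dropped
Y <- 10 + rnorm(n, 0, 2) + W * tau               # simplified, low-noise outcome model
X <- cbind(risk = risk, age = age)

diff_means <- function(y, w) mean(y[w == 1]) - mean(y[w == 0])

# Assumed split criterion: size-weighted squared gap between the two leaves' effect estimates
split_gain <- function(x, y, w, thr) {
  left <- x <= thr
  share <- mean(left)
  share * (1 - share) * (diff_means(y[left], w[left]) - diff_means(y[!left], w[!left]))^2
}

# Half A: search over covariates and candidate thresholds to choose the split
half_a <- sample(n, n / 2)
best <- list(gain = -Inf)
for (v in colnames(X)) {
  for (thr in quantile(X[half_a, v], seq(0.1, 0.9, by = 0.05))) {
    g <- split_gain(X[half_a, v], Y[half_a], W[half_a], thr)
    if (g > best$gain) best <- list(gain = g, var = v, threshold = thr)
  }
}

# Half B: estimate leaf effects on observations that played no part in choosing the split
xb <- X[-half_a, best$var]
yb <- Y[-half_a]
wb <- W[-half_a]
left_effect <- diff_means(yb[xb <= best$threshold], wb[xb <= best$threshold])
right_effect <- diff_means(yb[xb > best$threshold], wb[xb > best$threshold])
cat("Split chosen on half A:", best$var, "<=", round(best$threshold, 2), "\n")
cat("Honest leaf effects from half B:", round(left_effect, 2), round(right_effect, 2), "\n")
```

Because half B played no part in choosing the split, its leaf estimates (near 4 and 7 here) are not inflated by the search over candidate splits; estimating them on half A would reuse the same noise that made the chosen split look good.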
A causal forest produces CATE estimates for each individual. How should you interpret these individual-level predictions?
Step 4: Variable Importance and Heterogeneity Drivers
```r
# Variable importance: depth-weighted count of how often each covariate is used for splitting
vi <- variable_importance(cf)
rownames(vi) <- colnames(X)
vi_sorted <- sort(vi[, 1], decreasing = TRUE)
barplot(rev(vi_sorted), horiz = TRUE, las = 1,
        main = "Variable Importance for Heterogeneity", xlab = "Importance")
# Compare estimated vs. true CATEs by risk subgroup
cat("\nCATEs by risk group:\n")
cat("Low risk:", mean(tau_hat[risk <= 0]), "(true:", mean(tau_true[risk <= 0]), ")\n")
cat("High risk:", mean(tau_hat[risk > 0]), "(true:", mean(tau_true[risk > 0]), ")\n")
# Compare estimated vs. true CATEs by age group
age_group <- cut(age, breaks = c(25, 35, 45, 55, 65), include.lowest = TRUE)
print(round(cbind(estimated = tapply(tau_hat, age_group, mean),
                  true = tapply(tau_true, age_group, mean)), 2))
```
Expected output: CATEs by age group
| Age Group | Estimated CATE | True CATE |
|---|---|---|
| 25–35 | ~7.5 | ~7.5 |
| 35–45 | ~6.5 | ~6.5 |
| 45–55 | ~5.5 | ~5.5 |
| 55–65 | ~4.5 | ~4.5 |
Estimated CATEs decline by about 1 per decade of age, matching the -0.1*(age - 40) term in the DGP: tau = 4.5 + 3*(risk > 0) - 0.1*(age - 40) + 2*(educ > 16).
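The true column of the age table follows directly from the DGP. Because age is independent of risk and education and enters the effect linearly, each group's average true CATE equals the effect at the group's midpoint age, with the risk and education terms replaced by their averages (3 * 0.5 and 2 * P(educ > 16)). A quick base R check:

```r
# Expected true CATE by age group, derived from the Step 1 DGP
p_educ <- 1 - pnorm(16, mean = 14, sd = 3)   # P(educ > 16), about 0.25
mid_age <- c("25-35" = 30, "35-45" = 40, "45-55" = 50, "55-65" = 60)
# Age is uniform within each group, so the group's mean age is its midpoint;
# risk and education are independent of age, so their terms average out
cate_by_age <- 4.5 + 3 * 0.5 + 2 * p_educ - 0.1 * (mid_age - 40)
print(round(cate_by_age, 2))
```

The four values come out at 7.5, 6.5, 5.5, and 4.5, matching the table.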
Step 5: Calibration and the RATE Curve
Evaluate whether the CATE estimates actually predict treatment effect heterogeneity.
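The grf documentation describes its calibration test as a best linear fit of the treatment effect using two regressors built from the forest: the mean forest prediction and each prediction's deviation from that mean (the "differential" forest prediction). A minimal base R sketch of that kind of regression follows. Its assumptions are labeled in the comments: the propensity is the known 0.5, a linear regression stands in for grf's outcome model, and `tau_pred` is a hypothetical set of predictions that rank people correctly but understate the spread of effects by half.

```r
# Sketch of a calibration regression in the style of grf's test_calibration().
# Assumptions: known propensity e = 0.5, a linear model standing in for E[Y | X],
# and hypothetical predictions tau_pred (not output from a fitted forest).
set.seed(5)
n <- 4000
age <- runif(n, 25, 65)
risk <- rnorm(n)
educ <- pmin(pmax(rnorm(n, 14, 3), 8), 22)
W <- rbinom(n, 1, 0.5)
tau <- 4.5 + 3 * (risk > 0) - 0.1 * (age - 40) + 2 * (educ > 16)
Y <- 50 + 0.5 * age - 2 * risk + rnorm(n, 0, 5) + W * tau   # health, income, female omitted

tau_pred <- 0.5 * tau + 3                    # right ranking, but only half the true spread
e <- 0.5                                     # known treatment probability in this RCT
m_hat <- fitted(lm(Y ~ age + risk + educ))   # stand-in estimate of E[Y | X]

resid_y <- Y - m_hat
mean_term <- (W - e) * mean(tau_pred)
diff_term <- (W - e) * (tau_pred - mean(tau_pred))
calib_fit <- lm(resid_y ~ 0 + mean_term + diff_term)
print(round(coef(calib_fit), 2))
```

A coefficient near 1 on the mean term says the average level of the predictions is right. The differential coefficient near 2 signals that the predictions' deviations from their mean should be roughly doubled; well-calibrated heterogeneity would give a value near 1.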
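```r
# Calibration test: coefficients near 1 on mean.forest.prediction and
# differential.forest.prediction indicate well-calibrated average and heterogeneous effects
calib <- test_calibration(cf)
print(calib)
# Calibration by quintile: mean predicted CATE vs. doubly robust effect estimate per quintile
quintile <- cut(tau_hat, quantile(tau_hat, seq(0, 1, 0.2)),
                include.lowest = TRUE, labels = paste0("Q", 1:5))
print(round(cbind(predicted = tapply(tau_hat, quintile, mean),
                  observed = tapply(get_scores(cf), quintile, mean),
                  true = tapply(tau_true, quintile, mean)), 2))
# RATE: rank individuals by *estimated* CATE (ranking by tau_true would score the oracle).
# For valid inference, estimate priorities on data not used to fit cf (see Exercise 2).
rate <- rank_average_treatment_effect(cf, tau_hat)
print(rate)
plot(rate, main = "TOC Curve (area = RATE)")
# Best linear projection: projects CATE onto covariates for interpretability
blp <- best_linear_projection(cf, X_mat)
print(blp)
```
Expected output: CATE calibration by quintile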
| Quintile | Predicted CATE | Observed Effect | True CATE |
|---|---|---|---|
| Q1 (lowest) | ~2.8 | ~3.0 | ~3.0 |
| Q2 | ~4.8 | ~5.0 | ~4.8 |
| Q3 | ~6.0 | ~5.8 | ~6.0 |
| Q4 | ~7.2 | ~7.0 | ~7.2 |
| Q5 (highest) | ~9.0 | ~9.2 | ~9.0 |
Calibration is good: predicted CATEs closely track both the observed and true effects across quintiles, confirming the forest identifies genuine heterogeneity.
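The quintile table can be built by hand, which makes the "Observed Effect" column concrete: within each quintile of predicted CATE, randomization lets a simple difference in means estimate the actual effect. In the sketch below, `tau_pred` is a hypothetical stand-in for forest predictions (true CATE plus pure noise), and the within-quintile difference in means replaces the doubly robust scores that grf would use.

```r
# Sketch of a calibration-by-quintile table built by hand (base R, no grf).
# Assumption: tau_pred is a hypothetical stand-in for forest CATE predictions.
set.seed(3)
n <- 4000
age <- runif(n, 25, 65)
risk <- rnorm(n)
educ <- pmin(pmax(rnorm(n, 14, 3), 8), 22)
W <- rbinom(n, 1, 0.5)
tau <- 4.5 + 3 * (risk > 0) - 0.1 * (age - 40) + 2 * (educ > 16)
Y <- 50 + 0.5 * age - 2 * risk + rnorm(n, 0, 5) + W * tau
tau_pred <- tau + rnorm(n, 0, 1.3)   # true CATE plus pure noise

# Group individuals into quintiles of the *predicted* CATE
quintile <- cut(tau_pred, breaks = quantile(tau_pred, probs = seq(0, 1, 0.2)),
                include.lowest = TRUE, labels = paste0("Q", 1:5))
# Observed effect: difference in means within each quintile (valid because W is randomized)
observed <- sapply(levels(quintile), function(q) {
  in_q <- quintile == q
  mean(Y[in_q & W == 1]) - mean(Y[in_q & W == 0])
})
calib_table <- data.frame(
  quintile = levels(quintile),
  predicted = as.vector(tapply(tau_pred, quintile, mean)),
  observed = as.vector(observed),
  true = as.vector(tapply(tau, quintile, mean))
)
print(calib_table, digits = 3)
```

Because these stand-in predictions contain pure noise, the extreme quintiles' predicted values overshoot the observed and true effects (Q5 too high, Q1 too low). That regression-to-the-mean pattern is exactly what a calibration check is meant to expose.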
Step 6: Optimal Targeting Policy
Use CATE estimates to design a policy that treats only those who benefit most.
```r
# Policy: treat the top 50% by estimated CATE
threshold <- median(tau_hat)
policy <- as.integer(tau_hat >= threshold)
# Oracle benchmark: treat the true top 50% (only possible in simulation)
oracle <- as.integer(tau_true >= median(tau_true))
# Evaluate policies using true CATEs (only possible in simulation)
cat("Average effect, targeted:", mean(tau_true[policy == 1]), "\n")
cat("Average effect, not targeted:", mean(tau_true[policy == 0]), "\n")
cat("Average effect, treating all:", mean(tau_true), "\n")
cat("Oracle benefit (true top 50%):", mean(tau_true[oracle == 1]), "\n")
cat("Share of true top 50% targeted:", mean(policy[oracle == 1]), "\n")
# Compare forest-based targeting with a simple rule (risk > 0)
simple_policy <- as.integer(risk > 0)
cat("\nSimple rule benefit:", mean(tau_true[simple_policy == 1]), "\n")
cat("Forest targeting benefit:", mean(tau_true[policy == 1]), "\n")
```
Expected output:
| Policy | Avg. Effect for Targeted | Avg. Effect for Not Targeted | Gain vs. Treating All |
|---|---|---|---|
| Oracle (true top 50%) | ~7.7 | ~4.3 | ~1.7 |
| Treat top 50% (forest) | below ~7.7 | above ~4.3 | below ~1.7 |
| Simple rule (risk > 0) | ~7.5 | ~4.5 | ~1.5 |
Average effect, targeted: below ~7.7 (no 50% policy can beat the oracle)
Average effect, not targeted: above ~4.3
Average effect, treating all: ~6.0
Oracle benefit (true top 50%): ~7.7
Share of true top 50% targeted: ~0.80
Simple rule benefit: ~7.5
Forest targeting benefit: same value as "Average effect, targeted"
The forest combines risk, age, and education, so in principle it can rank individuals better than a one-variable rule. In this DGP, however, risk is the largest single driver of heterogeneity and the simple rule already comes within about 0.2 of the oracle (~7.5 vs. ~7.7), so whether forest targeting beats it depends on how precisely the forest recovers the smaller age and education terms. Compare the two benefit lines in your own output.
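How much room targeting leaves can be computed directly from the DGP. The sketch below draws a large population (the size of 200,000 is an arbitrary choice to keep simulation noise small) and compares treating everyone, the risk > 0 rule, and the oracle that treats the true top 50%. No policy that treats half the population, forest-based or otherwise, can exceed the oracle's average effect.

```r
# Oracle benchmark for the Step 6 policies, computed from the DGP on a large simulated population
set.seed(11)
N <- 200000
age <- runif(N, 25, 65)
risk <- rnorm(N)
educ <- pmin(pmax(rnorm(N, 14, 3), 8), 22)
tau <- 4.5 + 3 * (risk > 0) - 0.1 * (age - 40) + 2 * (educ > 16)
oracle_top50 <- mean(tau[tau >= median(tau)])   # best achievable average effect when treating half
simple_rule <- mean(tau[risk > 0])               # the single-variable rule from Step 6
cat("Treat all:", round(mean(tau), 2), "\n")
cat("Simple rule (risk > 0):", round(simple_rule, 2), "\n")
cat("Oracle top 50%:", round(oracle_top50, 2), "\n")
```

The oracle comes out near 7.7 and the simple rule near 7.5, so the simple rule leaves only about 0.2 per treated individual for a forest-based policy to gain.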
You estimate CATEs using a causal forest on observational (non-experimental) data and find that older individuals have smaller treatment effects. A colleague suggests this finding could be driven by differential selection into treatment rather than true heterogeneity. How can you address this concern?
Exercises
- Compare ML methods. Estimate CATEs using a causal forest, a T-learner (separate random forests for treated and control), and a linear interaction model. Which achieves the lowest RMSE against the true CATEs?
- Out-of-sample validation. Split the data 50/50. Train the causal forest on the first half and evaluate its CATE predictions on the second half using the calibration test.
- Budget-constrained targeting. Suppose you can treat only 25% of the population. Design the optimal targeting policy and compute the expected average effect among those treated.
- Nonlinear heterogeneity. Modify the DGP so that the treatment effect has a U-shape in age. Does the causal forest capture this pattern?
Summary
In this lab you learned:
- Causal forests estimate the conditional average treatment effect (CATE) as a function of covariates, moving beyond the ATE
- The method uses honesty (one subsample determines the splits and a separate subsample estimates the effects within leaves), which, together with subsampling, supports asymptotically valid confidence intervals
- Variable importance reveals which covariates drive treatment effect heterogeneity
- Calibration tests and RATE curves assess whether estimated CATEs are predictive of actual treatment effect heterogeneity
- Optimal targeting policies assign treatment to individuals with the highest predicted CATEs, potentially improving welfare
- CATE estimates from observational data should be validated before informing policy, as heterogeneity may reflect differential selection rather than true effect variation