MethodAtlas
Lab · Tutorial · 7 min read · 120 minutes

Lab: Causal Forests for Heterogeneous Treatment Effects

Estimate individualized treatment effects via the Wager-Athey causal forest: move beyond ATE to identify who benefits most and design targeting policies.

Languages: Python, R, Stata
Dataset: Simulated RCT with heterogeneous treatment effects

Overview

In this lab you will analyze a simulated randomized controlled trial where the treatment effect varies substantially across individuals. Rather than estimating a single average treatment effect (ATE), you will use causal forests to estimate the conditional average treatment effect (CATE) as a function of covariates, identify which variables drive treatment effect heterogeneity, and design an optimal treatment targeting policy.

What you will learn:

  • How to estimate heterogeneous treatment effects with causal forests
  • How to assess variable importance for treatment effect heterogeneity
  • How to evaluate CATE estimates using calibration tests and RATE curves
  • How to design and evaluate optimal targeting policies
  • How to compare causal forests with simple subgroup analysis

Prerequisites: Familiarity with random forests and basic causal inference. Completion of the OLS and DML tutorial labs is recommended.


Step 1: Simulate an RCT with Heterogeneous Effects

We create an experiment where the treatment effect depends on baseline risk, age, and education.

# First-time setup: install.packages(c("grf"))
library(grf)

set.seed(42)
n <- 4000

# Generate covariates: demographics and health characteristics
age <- runif(n, 25, 65)
income <- rlnorm(n, 10.5, 0.6)
educ <- pmin(pmax(rnorm(n, 14, 3), 8), 22)       # Clip education to [8, 22]
health_score <- pmin(pmax(rnorm(n, 50, 15), 0), 100) # Clip health to [0, 100]
female <- rbinom(n, 1, 0.5)
risk <- rnorm(n)  # Baseline risk factor

# Treatment assignment: randomized (RCT with 50% probability)
W <- rbinom(n, 1, 0.5)

# True CATE: effect varies by risk, age, and education
tau_true <- 4.5 + 3 * (risk > 0) - 0.1 * (age - 40) + 2 * (educ > 16)
# Baseline potential outcome under control
mu0 <- 50 + 0.5 * age + 0.3 * health_score - 2 * risk + rnorm(n, 0, 5)
# Observed outcome under treatment assignment
Y <- mu0 + W * tau_true

# Covariate matrix for the causal forest
X <- data.frame(age, income, educ, health_score, female, risk)

cat("True ATE:", mean(tau_true), "\n")
cat("CATE range:", range(tau_true), "\n")
Requires: grf

Expected output:

Variable       Mean      Std Dev   Min       Max
Y              ~90.5     ~9.2
W              0.50      0.50      0         1
age            45.0      11.5      25.0      65.0
income         44,300    30,100    5,800     245,000
educ           14.0      2.9       8.0       22.0
health_score   50.1      15.0      0.0       100.0

True ATE: ~6.00
True CATE range: [~2.0, ~11.0]
Treatment rate: ~50%
Covariates that drive heterogeneity: age, risk, educ
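The True ATE and CATE range follow directly from the DGP formula, so they can be checked by hand. A minimal sketch in Python (standard library only; `true_cate` is a hypothetical helper name) evaluates the expectation analytically, assuming age ~ U(25, 65), P(risk > 0) = 0.5, and that clipping education to [8, 22] leaves P(educ > 16) unchanged:

```python
from statistics import NormalDist


def true_cate(age, educ, risk):
    """True CATE from the Step 1 DGP: 4.5 + 3*(risk > 0) - 0.1*(age - 40) + 2*(educ > 16)."""
    return 4.5 + 3.0 * (risk > 0) - 0.1 * (age - 40.0) + 2.0 * (educ > 16)


# P(educ > 16) for educ ~ N(14, 3); clipping to [8, 22] does not change this probability
p_educ_high = 1.0 - NormalDist(mu=14.0, sigma=3.0).cdf(16.0)

# E[tau] = 4.5 + 3*P(risk > 0) - 0.1*(E[age] - 40) + 2*P(educ > 16), with E[age] = 45
true_ate = 4.5 + 3.0 * 0.5 - 0.1 * (45.0 - 40.0) + 2.0 * p_educ_high

# The extreme CATEs sit at corners of the covariate space (any educ <= 16 and risk <= 0 work for the minimum)
cate_min = true_cate(age=65, educ=8, risk=-1)   # oldest, low education, low risk
cate_max = true_cate(age=25, educ=22, risk=1)   # youngest, high education, high risk

print(f"P(educ > 16) = {p_educ_high:.4f}")
print(f"True ATE     = {true_ate:.3f}")
print(f"CATE range   = [{cate_min:.1f}, {cate_max:.1f}]")
```

Because age is continuous, a simulated sample never hits age 25 or 65 exactly, so its observed range sits just inside these bounds.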

Step 2: Estimate the Average Treatment Effect

First, confirm the overall ATE before looking at heterogeneity.

# Simple difference in means (unbiased in an RCT) with the Neyman standard error
ate_simple <- mean(Y[W == 1]) - mean(Y[W == 0])
se_simple <- sqrt(var(Y[W == 1]) / sum(W == 1) + var(Y[W == 0]) / sum(W == 0))
cat("Difference in means:", ate_simple, "(SE:", se_simple, ")\n")

# Covariate-adjusted OLS (more precise because covariates absorb residual variance)
ols <- lm(Y ~ W + age + income + educ + health_score + female + risk)
cat("OLS-adjusted ATE:", coef(ols)["W"],
    "(SE:", summary(ols)$coefficients["W", "Std. Error"], ")\n")
cat("True ATE:", mean(tau_true), "\n")

Expected output:

Estimator             ATE estimate   SE      True ATE
Difference in means   ~6.00          ~0.28   ~6.00
OLS-adjusted          ~6.00          ~0.16   ~6.00

Difference in means: ~6.00 (SE: ~0.28)
OLS-adjusted ATE:    ~6.00 (SE: ~0.16)
True ATE:            ~6.00

The covariate-adjusted estimator has a much smaller standard error because adjusting for strong predictors of the outcome (age, health_score, risk) reduces residual variance.
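The unadjusted standard error comes from the Neyman formula sqrt(s1^2/n1 + s0^2/n0), so it grows with the raw outcome variance in each arm. A minimal Python sketch, assuming a standard-library re-implementation of the Step 1 DGP (`simulate_rct` and `diff_in_means` are hypothetical names, and Python's RNG will not reproduce R's draws), computes that benchmark:

```python
import math
import random
import statistics


def simulate_rct(n, seed):
    """Draw (W, Y) from a Python re-implementation of the Step 1 DGP.

    income and female are omitted because they do not enter Y.
    """
    rng = random.Random(seed)
    W, Y = [], []
    for _ in range(n):
        age = rng.uniform(25, 65)
        educ = min(max(rng.gauss(14, 3), 8), 22)
        health = min(max(rng.gauss(50, 15), 0), 100)
        risk = rng.gauss(0, 1)
        w = 1 if rng.random() < 0.5 else 0
        tau = 4.5 + 3 * (risk > 0) - 0.1 * (age - 40) + 2 * (educ > 16)
        mu0 = 50 + 0.5 * age + 0.3 * health - 2 * risk + rng.gauss(0, 5)
        W.append(w)
        Y.append(mu0 + w * tau)
    return W, Y


def diff_in_means(W, Y):
    """Difference in means with the Neyman (unequal-variance) standard error."""
    y1 = [y for w, y in zip(W, Y) if w == 1]
    y0 = [y for w, y in zip(W, Y) if w == 0]
    estimate = statistics.fmean(y1) - statistics.fmean(y0)
    se = math.sqrt(statistics.variance(y1) / len(y1) + statistics.variance(y0) / len(y0))
    return estimate, se


W, Y = simulate_rct(n=4000, seed=42)
ate, se = diff_in_means(W, Y)
print(f"Difference in means: {ate:.2f} (SE: {se:.3f})")
```

Under this DGP the SE should land near 0.28: most of the within-arm outcome variance comes from age, health_score, and the noise term, which is exactly the variance that covariate adjustment removes.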


Step 3: Estimate CATEs with a Causal Forest

# Fit causal forest using the grf package
X_mat <- as.matrix(X)

cf <- causal_forest(X_mat, Y, W,
                   num.trees = 2000,    # Number of trees in the forest
                   min.node.size = 5,   # Minimum obs per leaf
                   seed = 42)

# Extract out-of-bag CATE predictions for each observation
tau_hat <- predict(cf)$predictions

# Evaluate accuracy by comparing estimated vs. true CATEs
cat("Correlation:", cor(tau_true, tau_hat), "\n")
cat("RMSE:", sqrt(mean((tau_true - tau_hat)^2)), "\n")
cat("Estimated ATE:", mean(tau_hat), "\n")

# Forest-based ATE with valid confidence interval (uses influence function)
ate_cf <- average_treatment_effect(cf)
cat("Forest ATE:", ate_cf[1], " SE:", ate_cf[2], "\n")
Requires: grf

Expected output:

Correlation(true CATE, estimated CATE): ~0.85
RMSE:                                   ~1.2
Estimated ATE (mean of CATEs):          ~6.00
True ATE:                               ~6.00

Summary    Estimated CATE   True CATE
Mean       ~6.00            ~6.00
Std Dev    ~2.5             ~2.1
Min        ~1.5             ~2.0
Max        ~10.5            ~11.0

The correlation of ~0.85 indicates the causal forest successfully identifies who benefits more vs. less from treatment, even though individual-level estimates are noisy.
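For comparison with simple subgroup analysis, the sketch below (standard-library Python; all names hypothetical) builds a transparent baseline. It splits the sample into 16 cells using the sign of risk, 10-year age bands, and educ > 16. It then takes the difference in means within each cell and scores the result with the same correlation and RMSE metrics. The cell definition deliberately borrows the true DGP's variables and cutoffs, knowledge that real data never provides and that a causal forest has to discover on its own.

```python
import math
import random
import statistics
from collections import defaultdict


def simulate(n, seed):
    """Rows from a Python re-implementation of the Step 1 DGP (draws differ from R's)."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        age = rng.uniform(25, 65)
        educ = min(max(rng.gauss(14, 3), 8), 22)
        health = min(max(rng.gauss(50, 15), 0), 100)
        risk = rng.gauss(0, 1)
        w = 1 if rng.random() < 0.5 else 0
        tau = 4.5 + 3 * (risk > 0) - 0.1 * (age - 40) + 2 * (educ > 16)
        y = 50 + 0.5 * age + 0.3 * health - 2 * risk + rng.gauss(0, 5) + w * tau
        rows.append({"age": age, "educ": educ, "risk": risk, "w": w, "y": y, "tau": tau})
    return rows


def cell(row):
    """Hypothetical subgroup key: sign of risk x 10-year age band x (educ > 16)."""
    band = min(int((row["age"] - 25) // 10), 3)
    return (row["risk"] > 0, band, row["educ"] > 16)


def pearson(a, b):
    """Pearson correlation, written out to avoid requiring Python 3.10+."""
    ma, mb = statistics.fmean(a), statistics.fmean(b)
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    return cov / math.sqrt(sum((x - ma) ** 2 for x in a) * sum((y - mb) ** 2 for y in b))


rows = simulate(4000, seed=7)

# Collect outcomes per cell: index 0 holds controls, index 1 holds treated units
outcomes = defaultdict(lambda: ([], []))
for row in rows:
    outcomes[cell(row)][row["w"]].append(row["y"])

# The within-cell difference in means is the subgroup CATE estimate
cell_cate = {key: statistics.fmean(t) - statistics.fmean(c) for key, (c, t) in outcomes.items()}
tau_hat = [cell_cate[cell(row)] for row in rows]
tau_true = [row["tau"] for row in rows]

rmse = math.sqrt(statistics.fmean((h - t) ** 2 for h, t in zip(tau_hat, tau_true)))
print(f"Cells: {len(cell_cate)}, correlation: {pearson(tau_hat, tau_true):.2f}, RMSE: {rmse:.2f}")
```

Because these cells line up with the true DGP, the baseline can look competitive with the forest. A fairer comparison uses cells that miss the true cutoffs (for example, splitting education at 14 instead of 16).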

Concept Check

A causal forest produces CATE estimates for each individual. How should you interpret these individual-level predictions?


Step 4: Variable Importance and Heterogeneity Drivers

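# Variable importance: a depth-weighted count of how often each covariate is used for splitting
vi <- variable_importance(cf)
rownames(vi) <- colnames(X)
vi_sorted <- sort(vi[, 1], decreasing = TRUE)
barplot(vi_sorted, horiz = TRUE, las = 1,
        main = "Variable Importance for Heterogeneity", xlab = "Importance")

# Compare estimated vs. true CATEs by risk subgroup
cat("\nCATEs by risk group:\n")
cat("Low risk:", mean(tau_hat[risk <= 0]), "(true:", mean(tau_true[risk <= 0]), ")\n")
cat("High risk:", mean(tau_hat[risk > 0]), "(true:", mean(tau_true[risk > 0]), ")\n")

# Compare estimated vs. true CATEs by 10-year age group
age_group <- cut(age, breaks = c(25, 35, 45, 55, 65), include.lowest = TRUE)
print(cbind(estimated = tapply(tau_hat, age_group, mean),
            true = tapply(tau_true, age_group, mean)))

Expected output: CATEs by age group

Age group   Estimated CATE   True CATE
25–35       ~7.5             ~7.5
35–45       ~6.5             ~6.5
45–55       ~5.5             ~5.5
55–65       ~4.5             ~4.5

CATEs decline by about 1 per 10-year age band, matching the age term in the DGP: tau = 4.5 + 3*(risk > 0) - 0.1*(age - 40) + 2*(educ > 16). The same formula implies true mean CATEs of about 4.5 for risk <= 0 and about 7.5 for risk > 0, which the risk-group comparison should reproduce.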
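The True CATE values by age band and by risk group can be derived without simulation. Risk and education are drawn independently of age, and the age term is linear, so the mean CATE in a uniform age band is the formula evaluated at the band midpoint. A minimal Python sketch (function names are hypothetical):

```python
from statistics import NormalDist

# Probabilities implied by the Step 1 DGP
P_RISK_HIGH = 0.5                                     # P(risk > 0) for risk ~ N(0, 1)
P_EDUC_HIGH = 1.0 - NormalDist(14.0, 3.0).cdf(16.0)  # P(educ > 16), about 0.2525


def mean_cate_age_band(lo, hi):
    """Mean true CATE for age ~ U(lo, hi): the age term is linear, so use the band midpoint."""
    return 4.5 + 3.0 * P_RISK_HIGH - 0.1 * ((lo + hi) / 2 - 40.0) + 2.0 * P_EDUC_HIGH


def mean_cate_risk_group(high_risk):
    """Mean true CATE for risk <= 0 (False) or risk > 0 (True), averaging age over U(25, 65)."""
    return 4.5 + 3.0 * high_risk - 0.1 * (45.0 - 40.0) + 2.0 * P_EDUC_HIGH


for lo, hi in [(25, 35), (35, 45), (45, 55), (55, 65)]:
    print(f"age {lo}-{hi}: {mean_cate_age_band(lo, hi):.2f}")
print(f"risk <= 0: {mean_cate_risk_group(False):.2f}, risk > 0: {mean_cate_risk_group(True):.2f}")
```

The printed values (7.50, 6.50, 5.50, 4.50 by age band; 4.50 and 7.50 by risk group) are population means. Sample means in the lab will differ slightly.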


Step 5: Calibration and the RATE Curve

Evaluate whether the CATE estimates actually predict treatment effect heterogeneity.

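# Calibration test: coefficients near 1 indicate well-calibrated CATE estimates
calib <- test_calibration(cf)
print(calib)

# Calibration by quintile of predicted CATE: predicted vs. doubly robust "observed" effect
quintile <- cut(tau_hat, quantile(tau_hat, 0:5 / 5), include.lowest = TRUE,
                labels = paste0("Q", 1:5))
print(data.frame(predicted = tapply(tau_hat, quintile, mean),
                 observed = tapply(get_scores(cf), quintile, mean),
                 true = tapply(tau_true, quintile, mean)))

# RATE curve: rank held-out units by CATEs from a forest trained on the other half
train <- sample(n, n / 2)
cf_train <- causal_forest(X_mat[train, ], Y[train], W[train], seed = 42)
cf_eval <- causal_forest(X_mat[-train, ], Y[-train], W[-train], seed = 42)
rate <- rank_average_treatment_effect(cf_eval, predict(cf_train, X_mat[-train, ])$predictions)
plot(rate, main = "RATE Curve")

# Best linear projection: projects CATE onto covariates for interpretability
blp <- best_linear_projection(cf, X_mat)
print(blp)

Expected output: CATE calibration by quintile

Quintile       Predicted CATE   Observed effect   True CATE
Q1 (lowest)    ~2.8             ~3.0              ~3.0
Q2             ~4.8             ~5.0              ~4.8
Q3             ~6.0             ~5.8              ~6.0
Q4             ~7.2             ~7.0              ~7.2
Q5 (highest)   ~9.0             ~9.2              ~9.0

Calibration is good: predicted CATEs closely track both the observed and true effects across quintiles, confirming that the forest identifies genuine heterogeneity.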
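The quantity behind the RATE curve is easier to see with known effects. With target = "AUTOC", grf's RATE is the area under the targeting operator characteristic (TOC) curve. TOC(q) is the average effect among the top q fraction of units ranked by the priority score, minus the overall average effect. grf estimates those effects with doubly robust scores on the evaluation data. The sketch below instead plugs in known effects, which is only possible in a simulation, and approximates the area by averaging TOC over q = 1/n, 2/n, ..., 1 (function names and example values are hypothetical):

```python
import statistics


def toc_curve(priorities, effects):
    """TOC(k/n) = mean effect among the top-k units by priority minus the overall mean effect."""
    order = sorted(range(len(effects)), key=lambda i: priorities[i], reverse=True)
    overall = statistics.fmean(effects)
    toc, running = [], 0.0
    for k, i in enumerate(order, start=1):
        running += effects[i]
        toc.append(running / k - overall)
    return toc


def autoc(priorities, effects):
    """Discrete approximation of the area under the TOC curve."""
    return statistics.fmean(toc_curve(priorities, effects))


effects = [2.0, 4.0, 6.0, 8.0, 10.0]           # hypothetical true CATEs
print(autoc(effects, effects))                 # perfect ranking  → 2.0
print(autoc([-e for e in effects], effects))   # reversed ranking → -2.0
```

A priority score that carries no information about the effects gives an AUTOC near zero. This is why a RATE estimate significantly above zero is evidence that the CATE ranking is useful for targeting.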


Step 6: Optimal Targeting Policy

Use CATE estimates to design a policy that treats only those who benefit most.

# Policy: treat the top 50% by estimated CATE
threshold <- median(tau_hat)
policy <- as.integer(tau_hat >= threshold)

# Evaluate policy using true CATEs (only possible in simulation)
cat("Average effect, targeted:", mean(tau_true[policy == 1]), "\n")
cat("Average effect, not targeted:", mean(tau_true[policy == 0]), "\n")
cat("Average effect, treating all:", mean(tau_true), "\n")

# Oracle benchmark: treating the true top 50% is the best any 50% policy can do
oracle <- as.integer(tau_true >= median(tau_true))
cat("Oracle top-50% benefit:", mean(tau_true[oracle == 1]), "\n")
cat("Share of true top 50% targeted:", mean(policy[oracle == 1]), "\n")

# Compare forest-based targeting with a simple rule (risk > 0)
simple_policy <- as.integer(risk > 0)
cat("\nSimple rule benefit:", mean(tau_true[simple_policy == 1]), "\n")
cat("Forest targeting benefit:", mean(tau_true[policy == 1]), "\n")

Expected output:

Policy                   Avg. effect, targeted   Avg. effect, not targeted   Gain vs. treating all
Oracle (true top 50%)    ~7.7                    ~4.3                        ~1.7
Treat top 50% (forest)   below ~7.7              above ~4.3                  below ~1.7
Simple rule (risk > 0)   ~7.5                    ~4.5                        ~1.5

Average effect, targeted:        below ~7.7
Average effect, not targeted:    above ~4.3
Average effect, treating all:    ~6.0
Oracle top-50% benefit:          ~7.7
Share of true top 50% targeted:  ~0.80

Simple rule benefit:             ~7.5
Forest targeting benefit:        below ~7.7

The forest combines risk, age, and education, while the simple rule uses only risk, so the forest can rank individuals more finely. In this DGP, however, the risk > 0 split is the largest single source of heterogeneity and already delivers ~7.5, while even perfect targeting of the top 50% delivers only ~7.7. Any advantage of forest-based targeting over the simple rule is therefore at most about 0.2, and with noisy CATE estimates the two can come out close.
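How much room there is above the simple rule can be checked directly from the DGP. The sketch below (standard-library Python on a large simulated population; names are hypothetical, and draws differ from R's) compares three benchmarks: treating everyone, the oracle policy that treats the 50% with the largest true CATEs, and the risk > 0 rule. No policy that treats half the population can beat the oracle.

```python
import random
import statistics


def draw_cates(n, seed):
    """True CATEs and risk values from the Step 1 DGP (Python RNG, so draws differ from R)."""
    rng = random.Random(seed)
    taus, risks = [], []
    for _ in range(n):
        age = rng.uniform(25, 65)
        educ = min(max(rng.gauss(14, 3), 8), 22)
        risk = rng.gauss(0, 1)
        taus.append(4.5 + 3 * (risk > 0) - 0.1 * (age - 40) + 2 * (educ > 16))
        risks.append(risk)
    return taus, risks


def mean_effect_of_top(priorities, taus, share):
    """Average true CATE among the top `share` of units ranked by `priorities`."""
    k = int(round(share * len(taus)))
    order = sorted(range(len(taus)), key=lambda i: priorities[i], reverse=True)
    return statistics.fmean(taus[i] for i in order[:k])


taus, risks = draw_cates(100_000, seed=1)
treat_all = statistics.fmean(taus)
oracle = mean_effect_of_top(taus, taus, 0.5)                         # rank by the true CATE
simple = statistics.fmean(t for t, r in zip(taus, risks) if r > 0)   # treat if risk > 0
print(f"Treat all:            {treat_all:.2f}")
print(f"Oracle top 50%:       {oracle:.2f}")
print(f"Simple rule risk > 0: {simple:.2f}")
```

Calling `mean_effect_of_top(taus, taus, 0.25)` gives the corresponding oracle benchmark for the 25% budget in Exercise 3.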

Concept Check

You estimate CATEs using a causal forest on observational (non-experimental) data and find that older individuals have smaller treatment effects. A colleague suggests this finding could be driven by differential selection into treatment rather than true heterogeneity. How can you address this concern?


Exercises

  1. Compare ML methods. Estimate CATEs using a causal forest, a T-learner (separate random forests for treated and control), and a linear interaction model. Which achieves the lowest RMSE against the true CATEs?

  2. Out-of-sample validation. Split the data 50/50. Train the causal forest on the first half and evaluate CATE predictions on the second half using the calibration test.

  3. Budget-constrained targeting. Suppose you can only treat 25% of the population. Design the optimal targeting policy and compute the expected average effect.

  4. Nonlinear heterogeneity. Modify the DGP so that the treatment effect has a U-shape in age. Does the causal forest capture this pattern?


Summary

In this lab you learned:

  • Causal forests estimate the conditional average treatment effect (CATE) as a function of covariates, moving beyond the ATE
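  • The method uses honest estimation (separate samples for choosing splits and for estimating effects), which, together with subsampling, supports asymptotically valid confidence intervals
  • Variable importance, a depth-weighted count of how often the forest splits on each covariate, suggests which covariates drive treatment effect heterogeneity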
  • Calibration tests and RATE curves assess whether estimated CATEs are predictive of actual treatment effect heterogeneity
  • Optimal targeting policies assign treatment to individuals with the highest predicted CATEs, potentially improving welfare
  • CATE estimates from observational data should be validated before informing policy, as heterogeneity may reflect differential selection rather than true effect variation
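Honest estimation can be seen in miniature. The sketch below (standard-library Python; all names hypothetical) grows a single-split causal "stump" on risk. One half of the data chooses the threshold, and the other, disjoint half estimates the effect in each leaf, so the leaf estimates are not biased by the search that picked the split. The splitting score n_L * n_R * (tau_L - tau_R)^2 is a simple heterogeneity criterion in the spirit of causal trees, not grf's gradient-based splitting rule. A real causal forest repeats the idea over many subsampled trees and all covariates.

```python
import random
import statistics


def simulate(n, seed):
    """(risk, w, y) tuples from a Python re-implementation of the Step 1 DGP."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        age = rng.uniform(25, 65)
        educ = min(max(rng.gauss(14, 3), 8), 22)
        health = min(max(rng.gauss(50, 15), 0), 100)
        risk = rng.gauss(0, 1)
        w = 1 if rng.random() < 0.5 else 0
        tau = 4.5 + 3 * (risk > 0) - 0.1 * (age - 40) + 2 * (educ > 16)
        y = 50 + 0.5 * age + 0.3 * health - 2 * risk + rng.gauss(0, 5) + w * tau
        data.append((risk, w, y))
    return data


def effect(rows):
    """Difference between treated and control mean outcomes within a set of rows."""
    treated = [y for _, w, y in rows if w == 1]
    control = [y for _, w, y in rows if w == 0]
    return statistics.fmean(treated) - statistics.fmean(control)


def best_split(rows, candidates, min_leaf=200):
    """Return the threshold on risk that maximizes n_L * n_R * (tau_L - tau_R)^2."""
    best_t, best_score = None, float("-inf")
    for t in candidates:
        left = [r for r in rows if r[0] <= t]
        right = [r for r in rows if r[0] > t]
        if len(left) < min_leaf or len(right) < min_leaf:
            continue
        score = len(left) * len(right) * (effect(left) - effect(right)) ** 2
        if score > best_score:
            best_t, best_score = t, score
    return best_t


data = simulate(4000, seed=3)
split_half, est_half = data[:2000], data[2000:]   # honesty: disjoint samples
candidates = [c / 10 for c in range(-15, 16)]     # thresholds on risk from -1.5 to 1.5
threshold = best_split(split_half, candidates)
tau_left = effect([r for r in est_half if r[0] <= threshold])
tau_right = effect([r for r in est_half if r[0] > threshold])
print(f"Chosen split: risk <= {threshold:.1f}")
print(f"Honest leaf estimates: {tau_left:.2f} (left), {tau_right:.2f} (right)")
```

Because the true effect jumps by 3 at risk = 0, the chosen threshold should fall near zero, and the honest leaf estimates should differ by roughly that amount.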