| Title: | Balancing Confounder Distributions with Forest Energy Balancing |
|---|---|
| Description: | Estimates average treatment effects using kernel energy balancing with random forest similarity kernels. A multivariate random forest jointly models covariates, outcome, and treatment to build a similarity kernel between observations. This kernel is then used for energy balancing to create weights that control for confounding. The method is described in De and Huling (2025) <doi:10.48550/arXiv.2512.18069>. |
| Authors: | Jared Huling [aut, cre], Simion De [aut] |
| Maintainer: | Jared Huling <[email protected]> |
| License: | GPL (>= 3) |
| Version: | 0.1.1 |
| Built: | 2026-05-08 06:30:55 UTC |
| Source: | https://github.com/jaredhuling/forestbalance |
Estimates average treatment effects (ATE) using kernel energy balancing with random forest similarity kernels. A multivariate random forest jointly models covariates, treatment, and outcome to build a proximity kernel that captures confounding structure. Balancing weights are obtained via a closed-form kernel energy distance solution.
forest_balance is the primary interface. It fits the forest,
constructs the kernel, computes balancing weights, and estimates the ATE.
By default it uses K-fold cross-fitting and an adaptive leaf size to minimize
overfitting bias.
Adaptive min.node.size that scales with $n$ and $p$
K-fold cross-fitting to reduce kernel overfitting bias
Rcpp-accelerated leaf node extraction
Sparse kernel construction via a single tcrossprod
Conjugate gradient solver for large $n$ (avoids forming the kernel matrix entirely)
For more control, the pipeline can be run step by step:
get_leaf_node_matrix, leaf_node_kernel,
kernel_balance.
De, S. and Huling, J.D. (2025). Data adaptive covariate balancing for causal effect estimation for high dimensional data. arXiv preprint arXiv:2512.18069.
Useful links:
Report bugs at https://github.com/jaredhuling/forestBalance/issues
Computes standardized mean differences (SMD), effective sample sizes (ESS), and optionally the weighted energy distance for a given set of balancing weights. Can also assess balance on user-supplied nonlinear transformations of the covariates.
compute_balance(X, trt, weights, X.trans = NULL, energy.dist = TRUE)

## S3 method for class 'forest_balance_diag'
print(x, threshold = 0.1, ...)
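| Argument | Description |
|---|---|
| X | An $n \times p$ covariate matrix. |
| trt | A binary (0/1) vector of treatment assignments of length $n$. |
| weights | A numeric weight vector of length $n$. |
| X.trans | An optional matrix of nonlinear or transformed covariates (with $n$ rows) on which balance is also assessed. |
| energy.dist | Logical; if TRUE (the default), the weighted energy distance is computed. |
| x | A forest_balance_diag object. |
| threshold | SMD threshold for flagging imbalanced covariates. Default is 0.1, a standard threshold in the causal inference literature. |
| ... | Ignored. |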
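The standardized mean difference for covariate $j$ is defined as
$$\mathrm{SMD}_j = \frac{\bar{X}_{j,1}^{w} - \bar{X}_{j,0}^{w}}{s_j},$$
where $\bar{X}_{j,a}^{w}$ is the weighted mean of covariate $j$ in
group $a$ and $s_j$ is the pooled (unweighted) standard deviation.
The effective sample size for a group is
$$\mathrm{ESS} = \frac{\left(\sum_i w_i\right)^2}{\sum_i w_i^2},$$
with the sums taken over the units in that group, reported as a fraction of the group size.
The weighted energy distance is
$$E = 2\sum_{i}\sum_{k} p_i q_k \lVert X_i - X_k \rVert - \sum_{i}\sum_{i'} p_i p_{i'} \lVert X_i - X_{i'} \rVert - \sum_{k}\sum_{k'} q_k q_{k'} \lVert X_k - X_{k'} \rVert,$$
where $p$ and $q$ are the normalized treated and control weights, $i, i'$ index treated units, and $k, k'$ index control units.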
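As a rough illustration of the SMD and ESS diagnostics, the following base R sketch computes both for a single covariate. The helper names `weighted_smd` and `ess_fraction` are hypothetical and not part of the package. The pooling rule for the standard deviation and the Kish form of the effective sample size are assumptions about what compute_balance does internally.

```r
# Hypothetical helpers for illustration; not the package's implementation.
weighted_smd <- function(x, trt, w) {
  m1 <- weighted.mean(x[trt == 1], w[trt == 1])
  m0 <- weighted.mean(x[trt == 0], w[trt == 0])
  # Assumed pooling rule: square root of the average of the two
  # unweighted group variances.
  s <- sqrt((var(x[trt == 1]) + var(x[trt == 0])) / 2)
  abs(m1 - m0) / s
}

ess_fraction <- function(w) {
  # Kish effective sample size divided by the group size.
  sum(w)^2 / sum(w^2) / length(w)
}

set.seed(1)
n <- 200
x <- rnorm(n)
trt <- rbinom(n, 1, plogis(x))
w <- rep(1, n)  # uniform weights, i.e. no balancing

smd_raw <- weighted_smd(x, trt, w)
ess_treated <- ess_fraction(w[trt == 1])
```

With uniform weights the ESS fraction is exactly 1. Uneven weights pull it below 1, which is why it is a useful check on how much a set of balancing weights concentrates on a few units.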
An object of class "forest_balance_diag" containing:
Named vector of |SMD| for each covariate.
Maximum |SMD| across covariates.
Named vector of |SMD| for transformed covariates (if
X.trans was supplied), otherwise NULL.
Maximum |SMD| for transformed covariates, or
NA.
Weighted energy distance, or NA if not
computed.
Effective sample size for the treated group as a
fraction of the number of treated units.
Effective sample size for the control group as a
fraction of the number of control units.
Total sample size.
Number of treated units.
Number of control units.
The input x, invisibly. Called for its side effect of
printing balance diagnostics to the console.
n <- 500; p <- 10
X <- matrix(rnorm(n * p), n, p)
A <- rbinom(n, 1, plogis(0.5 * X[, 1]))
Y <- X[, 1] + rnorm(n)

fit <- forest_balance(X, A, Y)
bal <- compute_balance(X, A, fit$weights)
bal

# With nonlinear features
X.nl <- cbind(X[, 1]^2, X[, 1] * X[, 2])
colnames(X.nl) <- c("X1^2", "X1*X2")
bal2 <- compute_balance(X, A, fit$weights, X.trans = X.nl)
bal2
Fits a multivariate random forest that jointly models the relationship between covariates, treatment, and outcome, computes a random forest proximity kernel, and then uses kernel energy balancing to produce weights for estimating the average treatment effect (ATE). By default, K-fold cross-fitting is used to avoid overfitting bias from estimating the kernel on the same data used for treatment effect estimation.
forest_balance(
  X,
  A,
  Y,
  num.trees = 1000,
  min.node.size = NULL,
  estimand = c("ATE", "ATT", "ATC"),
  cross.fitting = TRUE,
  num.folds = 2,
  augmented = FALSE,
  mu.hat = NULL,
  scale.outcomes = TRUE,
  solver = c("auto", "direct", "cg", "bj"),
  tol = 1e-08,
  parallel = FALSE,
  ...
)
| Argument | Description |
|---|---|
| X | A numeric matrix or data frame of covariates ($n \times p$). |
| A | A binary (0/1) vector of treatment assignments. |
| Y | A numeric vector of outcomes. |
| num.trees | Number of trees to grow in the forest. Default is 1000. |
| min.node.size | Minimum number of observations per leaf node. If NULL (the default), it is set adaptively; see Details. |
| cross.fitting | Logical; if TRUE (the default), K-fold cross-fitting is used. |
| num.folds | Number of cross-fitting folds. Default is 2. Only used when cross.fitting = TRUE. |
| augmented | Logical; if TRUE, the augmented (doubly-robust) estimator is used. Default is FALSE. |
| mu.hat | Optional list with components |
| scale.outcomes | If |
| solver | Which linear solver to use for the balancing weights. |
| tol | Convergence tolerance for the CG solver. Default is 1e-08. |
| parallel | Logical or integer. If |
| ... | Additional arguments passed to multi_regression_forest. |
The method proceeds in three steps:
1. A multi_regression_forest is fit on covariates X with a bivariate response (A, Y). This jointly models the relationship between covariates, treatment assignment, and outcome.
2. The forest's leaf co-membership structure defines a proximity kernel: $K(i, j)$ is the proportion of trees where $i$ and $j$ share a leaf. Because the forest's splits are driven by both $A$ and $Y$, this kernel captures confounding structure.
3. kernel_balance computes balancing weights via the closed-form kernel energy distance solution. The ATE is then estimated using the Hajek (ratio) estimator with these weights.
Cross-fitting (default): For each fold $k$, the forest is
trained on all data except fold $k$, and the kernel for fold $k$
is built from the leaf assignments that this forest predicts for the
held-out observations. This breaks
the dependence between the kernel and the outcomes, reducing overfitting
bias. The final ATE is the average of the per-fold Hajek estimates (DML1).
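Augmented estimator: When augmented = TRUE, two
group-specific outcome models $\mu_1(x) = E[Y \mid X = x, A = 1]$ and
$\mu_0(x) = E[Y \mid X = x, A = 0]$ are fit, and the ATE is estimated via
the doubly-robust formula:
$$\hat{\tau} = \frac{1}{n}\sum_{i=1}^{n}\left[\hat{\mu}_1(X_i) - \hat{\mu}_0(X_i)\right] + \frac{\sum_{i: A_i = 1} w_i \left(Y_i - \hat{\mu}_1(X_i)\right)}{\sum_{i: A_i = 1} w_i} - \frac{\sum_{i: A_i = 0} w_i \left(Y_i - \hat{\mu}_0(X_i)\right)}{\sum_{i: A_i = 0} w_i}.$$
The first term is the regression-based estimate of the ATE; the remaining terms are weighted bias corrections. The estimator is consistent if either the kernel (balancing weights) or the outcome models are correctly specified. When combined with cross-fitting, the outcome models are automatically cross-fitted in lockstep with the kernel.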
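The arithmetic of an augmented estimate can be sketched in base R as follows. This is only an illustration under stated assumptions: the outcome models are plain linear regressions fit within each arm, the weights `w` are uniform placeholders rather than balancing weights, and the correction terms are Hajek-normalized. None of this is claimed to match the package internals.

```r
set.seed(1)
n <- 300
X <- matrix(rnorm(n * 2), n, 2)
A <- rbinom(n, 1, plogis(0.5 * X[, 1]))
Y <- X[, 1] + rnorm(n)   # true ATE = 0
w <- rep(1, n)           # placeholder weights, not balancing weights

dat <- data.frame(Y = Y, X1 = X[, 1], X2 = X[, 2])

# Assumed outcome models: one linear regression per treatment arm,
# each used to predict for all n observations.
mu1 <- predict(lm(Y ~ X1 + X2, data = dat[A == 1, ]), newdata = dat)
mu0 <- predict(lm(Y ~ X1 + X2, data = dat[A == 0, ]), newdata = dat)

# Regression-based estimate plus weighted residual corrections.
regression_term <- mean(mu1 - mu0)
correction <- weighted.mean((Y - mu1)[A == 1], w[A == 1]) -
  weighted.mean((Y - mu0)[A == 0], w[A == 0])
ate_aug <- regression_term + correction
```

Because least-squares residuals average to zero within each arm, the correction is numerically zero under uniform weights. It becomes informative once non-uniform balancing weights are supplied.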
Adaptive leaf size: The default min.node.size is set
adaptively via max(20, min(floor(n/200) + p, floor(n/50))). Larger
leaves produce smoother kernels that generalize better, while the cap at
n/50 prevents kernel degeneracy.
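The rule can be restated as a small helper to see how it behaves across sample sizes. The function name `adaptive_min_node_size` is hypothetical; it only evaluates the formula quoted above.

```r
# Hypothetical helper that evaluates the adaptive leaf-size formula.
adaptive_min_node_size <- function(n, p) {
  max(20, min(floor(n / 200) + p, floor(n / 50)))
}

adaptive_min_node_size(500, 10)     # 20: the floor of 20 binds
adaptive_min_node_size(10000, 10)   # 60: floor(n/200) + p binds
adaptive_min_node_size(2000, 100)   # 40: the n/50 cap binds
```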
An object of class "forest_balance" (a list) with the
following elements:
The estimated average treatment effect. When cross-fitting is used, this is the average of per-fold Hajek estimates (DML1).
The balancing weight vector (length $n$). When
cross-fitting is used, these are the concatenated per-fold weights.
Predictions of the treated-arm outcome model $\mu_1(X)$ (length $n$), or
NULL if augmented = FALSE.
Predictions of the control-arm outcome model $\mu_0(X)$ (length $n$), or
NULL if augmented = FALSE.
The forest proximity kernel (sparse matrix),
or NULL when cross-fitting or the CG solver is used.
The trained forest object. When cross-fitting is used, this is the last fold's forest.
The input data.
Total, treated, and control sample sizes.
The solver that was used ("direct" or "cg").
Logical indicating whether cross-fitting was used.
Logical indicating whether augmentation was used.
Number of folds (if cross-fitting was used).
Per-fold ATE estimates (if cross-fitting was used).
Fold assignments (if cross-fitting was used).
The object has print and summary methods. Use
summary.forest_balance for covariate balance diagnostics.
Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., Hansen, C., Newey, W. and Robins, J. (2018). Double/debiased machine learning for treatment and structural parameters. The Econometrics Journal, 21(1), C1–C68.
De, S. and Huling, J.D. (2025). Data adaptive covariate balancing for causal effect estimation for high dimensional data. arXiv preprint arXiv:2512.18069.
n <- 500
p <- 10
X <- matrix(rnorm(n * p), n, p)
A <- rbinom(n, 1, plogis(0.5 * X[, 1]))
Y <- X[, 1] + rnorm(n)  # true ATE = 0

# Default: cross-fitting with adaptive leaf size
result <- forest_balance(X, A, Y)
result

# Augmented (doubly-robust) estimator
result_aug <- forest_balance(X, A, Y, augmented = TRUE)

# Without cross-fitting
result_nocf <- forest_balance(X, A, Y, cross.fitting = FALSE)
Extracts the leaf node assignments for each observation across all trees
in a trained GRF forest, then computes the $n \times n$ proximity kernel
matrix where entry $(i, j)$ is the proportion of trees in which
observations $i$ and $j$ share a leaf node.
forest_kernel(forest, newdata = NULL)
| Argument | Description |
|---|---|
| forest | A trained forest object from the grf package. |
| newdata | A numeric matrix of observations. If NULL (the default), the training data is used. |
This is a convenience function that calls get_leaf_node_matrix
followed by leaf_node_kernel. If you need both the leaf matrix
and the kernel, it is more efficient to call them separately.
A symmetric numeric matrix of dimension $n \times n$.
library(grf)

n <- 100
p <- 5
X <- matrix(rnorm(n * p), n, p)
Y <- cbind(X[, 1] + rnorm(n), X[, 2] + rnorm(n))
forest <- multi_regression_forest(X, Y, num.trees = 50)
K <- forest_kernel(forest)
For each observation and each tree in a trained GRF forest, determines which
leaf node the observation falls into. The tree traversal is implemented in
C++ for speed, directly reading the forest's internal tree structure and
avoiding the overhead of get_tree and
get_leaf_node.
get_leaf_node_matrix(forest, newdata = NULL)
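| Argument | Description |
|---|---|
| forest | A trained forest object from the grf package (e.g., from multi_regression_forest). |
| newdata | A numeric matrix of observations to predict leaf membership for. Must have the same number of columns as the training data. If NULL (the default), the training data is used. |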
This function reads the internal tree vectors stored in a GRF forest object
(_child_nodes, _split_vars, _split_values,
_root_nodes, _send_missing_left) and traverses each tree in
compiled C++ code. This is dramatically faster than the per-observation
R-level loop in grf::get_leaf_node.
An integer matrix of dimension nrow(newdata) by
num.trees, where entry [i, b] is the internal node index
(1-based) of the leaf that observation i falls into in tree
b.
library(grf)

n <- 200
p <- 5
X <- matrix(rnorm(n * p), n, p)
Y <- cbind(X[, 1] + rnorm(n), X[, 2] + rnorm(n))
forest <- multi_regression_forest(X, Y, num.trees = 100)

# Leaf membership for training data
leaf_mat <- get_leaf_node_matrix(forest)

# Leaf membership for new data
X.test <- matrix(rnorm(50 * p), 50, p)
leaf_mat_test <- get_leaf_node_matrix(forest, X.test)
Computes balancing weights that minimize a kernelized energy distance between the weighted treated and control distributions and the overall sample. The weights are obtained via a closed-form solution to a linear system derived from the kernel energy distance objective.
kernel_balance(
  trt,
  kern = NULL,
  Z = NULL,
  leaf_matrix = NULL,
  num.trees = NULL,
  estimand = c("ATE", "ATT", "ATC"),
  solver = c("auto", "direct", "cg", "bj"),
  tol = 1e-08,
  maxiter = 2000L
)
| Argument | Description |
|---|---|
| trt | A binary (0/1) integer or numeric vector indicating treatment assignment (1 = treated, 0 = control). |
| kern | A symmetric $n \times n$ kernel matrix. |
| Z | Optional sparse indicator matrix from leaf_node_kernel_Z. |
| leaf_matrix | Optional integer matrix of leaf node assignments (observations x trees), as returned by get_leaf_node_matrix. |
| num.trees | Number of trees $B$ in the forest; needed whenever Z is supplied (see Details). |
| solver | Which linear solver to use. |
| tol | Convergence tolerance for iterative solvers. Default is 1e-08. |
| maxiter | Maximum iterations for iterative solvers. Default is 2000. |
The modified kernel used in the optimization is block-diagonal:
the treated–control cross-blocks are zero because its $(i, j)$ entry
vanishes whenever $A_i \neq A_j$. All solvers exploit this
structure by working on the treated and control blocks independently.
Solver requirements:
| Solver | Required inputs | Optional inputs |
|---|---|---|
| "direct" | kern (or Z + num.trees) | |
| "cg" | Z + num.trees | |
| "bj" | Z + num.trees + leaf_matrix | (falls back to "cg" if leaf_matrix is missing) |
The direct solver extracts sub-blocks of the kernel and solves via
sparse Cholesky. If only Z is provided, the kernel is formed as
$K = Z Z^\top / B$, which requires the time and memory to materialize the full $n \times n$ kernel.
The CG solver uses the factored representation $K = Z Z^\top / B$
to perform matrix–vector products without forming any kernel matrix.
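The idea behind the matrix-free product can be sketched in base R. The example below assumes the kernel has the factored form K = Z Z' / B and uses a small dense indicator matrix built from made-up leaf assignments. The package itself works with sparse matrices, and the linear system it actually solves is not reproduced here.

```r
set.seed(1)
n <- 6
B <- 3

# Made-up leaf assignments: each of the B trees puts every observation
# into one of two leaves.
leaf_matrix <- matrix(sample(1:2, n * B, replace = TRUE), n, B)

# Dense 0/1 indicator matrix Z with one block of two columns per tree.
Z <- do.call(cbind, lapply(seq_len(B), function(b) {
  outer(leaf_matrix[, b], 1:2, "==") * 1
}))

v <- rnorm(n)

# Explicit route: form the n x n kernel, then multiply.
Kv_explicit <- drop((tcrossprod(Z) / B) %*% v)

# Factored route: two thin products, never forming K.
Kv_factored <- drop(Z %*% crossprod(Z, v)) / B

all.equal(Kv_explicit, Kv_factored)
```

An iterative solver such as conjugate gradient only needs products of this kind, which is what makes the factored representation attractive when $n$ is large.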
The Block Jacobi solver ("bj") uses the first tree's leaf
partition (from leaf_matrix) to define a block-diagonal
preconditioner for CG. Each leaf block is a small dense system that is
cheap to factor.
Only 2 linear solves per block are needed (not 3) because the third right-hand side is a linear combination of the first two.
A list with the following elements:
A numeric vector of length $n$ containing the balancing
weights. Treated weights sum to and control weights sum to
.
The solver that was used.
De, S. and Huling, J.D. (2025). Data adaptive covariate balancing for causal effect estimation for high dimensional data. arXiv preprint arXiv:2512.18069.
library(grf)

n <- 200
p <- 5
X <- matrix(rnorm(n * p), n, p)
A <- rbinom(n, 1, plogis(X[, 1]))
Y <- X[, 1] + rnorm(n)

# --- Direct solver (using the kernel matrix) ---
forest <- multi_regression_forest(X, cbind(A, Y), num.trees = 500)
K <- forest_kernel(forest)
bal_direct <- kernel_balance(A, kern = K, solver = "direct")

# --- CG solver (using the Z matrix, avoids forming K) ---
# Step 1: extract leaf node assignments (n x B matrix)
leaf_mat <- get_leaf_node_matrix(forest, X)
# Step 2: build sparse indicator matrix Z such that K = Z Z' / B
Z <- leaf_node_kernel_Z(leaf_mat)
# Step 3: solve with CG (matrix-free, no kernel formed)
bal_cg <- kernel_balance(A, Z = Z, num.trees = 500, solver = "cg")

# Both solvers give the same weights
max(abs(bal_direct$weights - bal_cg$weights))

# Weighted ATE estimate
w <- bal_cg$weights
ate <- weighted.mean(Y[A == 1], w[A == 1]) - weighted.mean(Y[A == 0], w[A == 0])
Given a matrix of leaf node assignments (observations x trees), computes the
$n \times n$ kernel matrix where entry $(i, j)$ is the proportion of
trees in which observations $i$ and $j$ fall into the same leaf node.
leaf_node_kernel(leaf_matrix, sparse = TRUE)
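| Argument | Description |
|---|---|
| leaf_matrix | An integer matrix of dimension $n \times B$ (observations x trees) of leaf node assignments, such as the output of get_leaf_node_matrix. |
| sparse | Logical; if TRUE (the default), a sparse matrix is returned; otherwise a dense matrix is returned. |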
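For each tree $b$, the leaf assignment defines a sparse $n \times L_b$ indicator matrix $Z_b$, where $L_b$ is the number of leaves in tree $b$. This function stacks all $B$ indicator
matrices column-wise into a single sparse matrix $Z = [Z_1, \ldots, Z_B]$
of dimension
$n \times \sum_b L_b$. The kernel is then obtained via a single sparse
cross-product: $K = Z Z^\top / B$. Leaf ID remapping is done in C++ for
speed.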
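The construction can be traced on a toy example. The sketch below uses dense base R matrices and a hand-written leaf matrix purely for illustration; the package's own implementation is sparse and uses C++ for the leaf remapping.

```r
# Toy leaf matrix: 4 observations (rows) x 2 trees (columns).
leaf_matrix <- cbind(c(1L, 1L, 2L, 2L),
                     c(3L, 4L, 3L, 3L))
B <- ncol(leaf_matrix)

# One 0/1 indicator block per tree, stacked column-wise into Z.
Z <- do.call(cbind, lapply(seq_len(B), function(b) {
  ids <- leaf_matrix[, b]
  outer(ids, unique(ids), "==") * 1
}))

# Proximity kernel: fraction of trees in which two rows share a leaf.
K <- tcrossprod(Z) / B

K[1, 2]  # 0.5: same leaf in tree 1 only
K[3, 4]  # 1:   same leaf in both trees
K[2, 3]  # 0:   never in the same leaf
```

Each row of `Z` has exactly one nonzero entry per tree, so the diagonal of `K` is always 1.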
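A symmetric $n \times n$ matrix (sparse or dense depending on
sparse) where entry $(i, j)$ equals
$$K(i, j) = \frac{1}{B} \sum_{b=1}^{B} \mathbf{1}\{\ell_b(i) = \ell_b(j)\},$$
with $\ell_b(i)$ denoting the leaf of tree $b$ that contains observation $i$,
i.e., the fraction of trees where $i$ and $j$ share a leaf.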
library(grf)

n <- 100
p <- 5
X <- matrix(rnorm(n * p), n, p)
Y <- cbind(X[, 1] + rnorm(n), X[, 2] + rnorm(n))
forest <- multi_regression_forest(X, Y, num.trees = 50)

leaf_mat <- get_leaf_node_matrix(forest)
K <- leaf_node_kernel(leaf_mat)                        # sparse
K_dense <- leaf_node_kernel(leaf_mat, sparse = FALSE)  # dense
Returns the sparse indicator matrix $Z$ such that
the proximity kernel is $K = Z Z^\top / B$. This factored
representation can be passed to kernel_balance to enable the
CG solver, which avoids forming the full kernel.
leaf_node_kernel_Z(leaf_matrix)
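| Argument | Description |
|---|---|
| leaf_matrix | An integer matrix of dimension $n \times B$ (observations x trees) of leaf node assignments, such as the output of get_leaf_node_matrix. |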
A sparse dgCMatrix of dimension $n \times L$, where
$L$ is the total number of leaves across all trees.
Each row has exactly $B$ nonzero entries (one per tree).
library(grf)

n <- 100
p <- 5
X <- matrix(rnorm(n * p), n, p)
Y <- cbind(X[, 1] + rnorm(n), X[, 2] + rnorm(n))
forest <- multi_regression_forest(X, Y, num.trees = 50)

leaf_mat <- get_leaf_node_matrix(forest)
Z <- leaf_node_kernel_Z(leaf_mat)
Displays a concise summary of the forest balance fit, including the ATE estimate, sample sizes, effective sample sizes, and a brief covariate balance overview.
## S3 method for class 'forest_balance'
print(x, ...)
| Argument | Description |
|---|---|
| x | A forest_balance object. |
| ... | Ignored. |
The input x, invisibly. Called for its side effect of
printing a summary to the console.
Generates data from a design with nonlinear confounding, where covariates jointly influence both treatment assignment and the outcome through non-trivial functions. The true average treatment effect is known, allowing evaluation of estimator performance.
simulate_data(n = 500, p = 10, ate = 0, rho = -0.25, sigma = 1, dgp = 1)
| Argument | Description |
|---|---|
| n | Sample size. Default is 500. |
| p | Number of covariates. Must be at least 5 for |
| ate | True average treatment effect. Default is 0. |
| rho | Correlation parameter for the AR(1) covariance structure among covariates: $\Sigma_{jk} = \rho^{\lvert j - k \rvert}$. Default is -0.25. |
| sigma | Noise standard deviation for the outcome. Default is 1. |
| dgp | Integer selecting the data generating process. Default is 1. See Details. |
Both DGPs generate covariates from $N(0, \Sigma)$ where
$\Sigma_{jk} = \rho^{|j - k|}$.
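For readers who want to reproduce the covariate design outside the package, a minimal base R sketch follows. It assumes the standard AR(1) form, in which the covariance between covariates j and k is rho raised to the power |j - k|, and it draws the covariates through a Cholesky factor. Whether simulate_data uses exactly this construction is an assumption.

```r
set.seed(1)
n <- 500
p <- 10
rho <- -0.25

# Assumed AR(1) covariance: entry (j, k) equals rho^|j - k|.
Sigma <- rho^abs(outer(seq_len(p), seq_len(p), "-"))

# chol() returns upper-triangular R with t(R) %*% R = Sigma, so rows of
# (standard normal draws) %*% R have covariance Sigma.
X <- matrix(rnorm(n * p), n, p) %*% chol(Sigma)

round(cov(X)[1:3, 1:3], 2)  # sample covariance, close to Sigma[1:3, 1:3]
```

A negative `rho` such as the default -0.25 makes adjacent covariates negatively correlated and covariates two positions apart positively correlated.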
DGP 1 (default): Confounding through via a Beta density.
Propensity: where
is the Beta(2,4) density.
Outcome: .
DGP 2: Rich outcome surface with moderate confounding. Designed to
illustrate the benefit of the augmented estimator. Confounding operates
through and , while the outcome depends on
with interactions and nonlinearities.
Propensity: .
Outcome: .
A list with the following elements:
The covariate matrix.
Binary (0/1) treatment assignment vector.
Observed outcome vector.
True propensity scores $P(A = 1 \mid X)$.
The true ATE used in the simulation.
Sample size.
Number of covariates.
The DGP that was used.
dat1 <- simulate_data(n = 500, p = 10, dgp = 1)
dat2 <- simulate_data(n = 500, p = 20, dgp = 2)
Produces a detailed summary of the forest balance fit, including the ATE estimate, covariate balance diagnostics (SMD, ESS, energy distance), and kernel sparsity information.
## S3 method for class 'forest_balance'
summary(object, X.trans = NULL, threshold = 0.1, energy.dist = TRUE, ...)

## S3 method for class 'summary.forest_balance'
print(x, ...)
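| Argument | Description |
|---|---|
| object | A forest_balance object. |
| X.trans | An optional matrix of nonlinear or transformed covariates (with $n$ rows) on which balance is also assessed. |
| threshold | SMD threshold for flagging imbalanced covariates. Default is 0.1. |
| energy.dist | Logical; if TRUE (the default), the weighted energy distance is computed. |
| ... | Ignored. |
| x | A summary.forest_balance object. |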
Invisibly returns a list of class "summary.forest_balance"
containing:
The estimated ATE.
A forest_balance_diag object for the
weighted sample.
A forest_balance_diag object for the
unweighted sample.
Fraction of nonzero entries in the kernel matrix.
Total sample size.
Number of treated.
Number of control.
Number of trees in the forest.
The SMD threshold used for flagging imbalanced covariates.
The input x, invisibly. Called for its side effect of
printing a detailed balance summary to the console.
n <- 500; p <- 10
X <- matrix(rnorm(n * p), n, p)
A <- rbinom(n, 1, plogis(0.5 * X[, 1]))
Y <- X[, 1] + rnorm(n)

fit <- forest_balance(X, A, Y)
summary(fit)

# With nonlinear balance assessment
X.nl <- cbind(X[, 1]^2, X[, 1] * X[, 2])
colnames(X.nl) <- c("X1^2", "X1*X2")
summary(fit, X.trans = X.nl)