## wants loaded
## [1,] "CVXR" TRUE
## [2,] "tidyverse" TRUE
## [3,] "data.table" TRUE
## [4,] "ebal" TRUE
## [5,] "tictoc" TRUE
## [6,] "knitr" TRUE
# Document-wide knitr chunk defaults: figure size, echo the code, suppress
# warnings and messages, and cache chunk results.
knitr::opts_chunk$set(
  fig.width = 12,
  fig.height = 8,
  echo = TRUE,
  warning = FALSE,
  message = FALSE,
  cache = TRUE
)
CVXR is a disciplined convex programming (DCP) domain-specific language (DSL) that allows rapid prototyping of estimators that solve convex optimization problems. This writeup illustrates the implementation of Entropy Balancing (Hainmueller 2012) in CVXR.
Choose balancing weights \(w_i\) that solve
\[\begin{align*} \max_{\mathbf{w}} \; H(\mathbf{w}) &= - \sum_{i : D_i = 0} w_i \log w_i \\ \text{Balance constraints:} \quad & \sum_{i : D_i = 0} w_i c_{ri}(\mathbf{X}_i) = m_r \text{ with } r \in \{1, \dots, R\} \\ \text{'Proper' weights:} \quad & \sum_{i : D_i = 0} w_i = 1 \; \text{ and } \; w_i \geq 0 \; \forall \; \{i : D_i = 0\} \end{align*}\]
Hainmueller (2012) solves the dual of this problem in ebal. Here, we implement the math directly to illustrate the rapid translation of math into code.
We simulate data in which the covariates of treated observations have mean 0.5 and those of control observations have mean 0.
# %% # create a toy dataset (50 controls, 30 treated)
# Treatment indicator: the 50 controls (0) come first, then the 30 treated (1).
d <- rep(c(0, 1), times = c(50, 30))
# Three normal covariates per unit; treated rows are drawn with mean 0.5
# instead of 0, so the raw groups are imbalanced.
X <- rbind(
  replicate(3, rnorm(50, mean = 0)),
  replicate(3, rnorm(30, mean = 0.5))
)
colnames(X) <- paste0("x", 1:3)
# Columns: d, x1, x2, x3.
df <- data.table(d = d, X)
ebal
## Converged within tolerance
# %% # means in reweighted control group data match treated
rbind(
  # raw covariate means by treatment arm (d = 0 controls, d = 1 treated)
  df[, lapply(.SD, mean), by = d],
  # control-group covariate means reweighted by eb.out$w, labelled d = -1
  data.table(
    d = -1,
    df[d == 0, lapply(.SD, weighted.mean, w = eb.out$w), .SDcols = 2:4]
  )
)
## d x1 x2 x3
## 1: 0 0.08333 -0.09338 -0.0007558
## 2: 1 0.42659 0.17099 0.3293803
## 3: -1 0.39624 0.17093 0.3090740
The bottom row (d = -1) is the reweighted mean of the control Xs using the ebal weights.
CVXR
Solving for weights in CVXR involves writing the optimisation problem almost verbatim as code and running it. The entr ‘atom’ is an internal CVXR shortcut to evaluate \(-w_i \log w_i\). Several other similar atoms exist (for computing norms, and so on).
# balance with treatment means
# Target moments: covariate means (columns 2:4) among treated units.
b <- as.numeric(df[d == 1, lapply(.SD, mean), .SDcols = 2:4])
# control observations
# Covariate matrix of the control units (n_0 rows, 3 columns).
X <- as.matrix(df[d == 0, 2:4])
# %% manual entropy balancing
ω <- Variable(nrow(X)) # one weight per control unit (length n_0)
# entr() evaluates -w * log(w) elementwise, so the sum is the entropy H(w).
objective <- Maximize(sum(entr(ω)))
constraints <- list(
  ω >= 0,          # proper weights: non-negative ...
  sum(ω) == 1,     # ... and summing to one
  t(X) %*% ω == b  # balance: reweighted control means equal treated means
)
prob <- Problem(objective, constraints)
result <- solve(prob, solver = "MOSEK")
# Solved weights, extracted from the CVXR result.
ω_hat <- result$getValue(ω)
# %%
rbind(
  # raw covariate means by treatment arm (d = 0 controls, d = 1 treated)
  df[, lapply(.SD, mean), by = d],
  # control-group covariate means reweighted by the CVXR weights, labelled d = -1
  data.table(
    d = -1,
    df[d == 0, lapply(.SD, weighted.mean, w = ω_hat), .SDcols = 2:4]
  )
)
## d x1 x2 x3
## 1: 0 0.08333 -0.09338 -0.0007558
## 2: 1 0.42659 0.17099 0.3293803
## 3: -1 0.42659 0.17099 0.3293803
## wants loaded
## [1,] "causalsens" TRUE
## [2,] "glue" TRUE
## [3,] "lfe" TRUE
# Load the two LaLonde datasets (presumably shipped with causalsens, which is
# loaded above) and convert them to data.tables.
data(lalonde.exp)
data(lalonde.psid)
dtlalonde <- data.table(lalonde.exp)
dtpsid <- data.table(lalonde.psid)
# Column roles: outcome, treatment indicator, and every remaining column as a
# covariate.
y <- "re78"
w <- "treat"
x <- setdiff(colnames(dtpsid), c(y, w))
ebal
weights in regression
CVXR
#' Entropy-balancing treatment-effect estimate solved with CVXR
#'
#' Finds maximum-entropy weights for the control units such that the weighted
#' control means of the covariates `x` equal the treated-group means, then
#' runs a weighted regression of the outcome on the treatment indicator.
#'
#' @param df A data.table holding the outcome, treatment and covariates.
#'   NOTE: `df` is reordered by reference (via `setorderv`) so that control
#'   rows precede treated rows; the caller's object is affected.
#' @param y Name (string) of the outcome column.
#' @param w Name (string) of the treatment column; must contain only 0/1.
#' @param x Character vector of covariate column names to balance on.
#' @param solver Name of the CVXR solver to use. Defaults to "MOSEK", the
#'   solver that was previously hard-coded.
#' @return The first two entries of the second row of the `felm` coefficient
#'   table, i.e. the coefficient on the treatment indicator and its standard
#'   error (as adjusted by `robustify()`, which is defined outside this
#'   function -- presumably it swaps in robust standard errors; confirm).
eb_reg2 <- function(df, y, w, x, solver = "MOSEK") {
  stopifnot(
    is.character(y), length(y) == 1L,
    is.character(w), length(w) == 1L,
    is.character(x), length(x) > 0L,
    all(c(y, w, x) %in% names(df)),
    # anything other than 0/1 (incl. NA) would misalign weights and rows below
    all(df[[w]] %in% c(0, 1))
  )

  # Sort by treatment so that controls come first: the weight vector built
  # below is (control weights, treated weights) and must line up with rows.
  setorderv(df, w)

  # treated and control obs (plain logical masks instead of eval(parse(...)))
  is_ctrl <- df[[w]] == 0
  is_treat <- df[[w]] == 1
  ctrl <- df[is_ctrl]
  treat <- df[is_treat]
  if (nrow(ctrl) == 0L || nrow(treat) == 0L) {
    stop("`df` must contain both treated and control rows", call. = FALSE)
  }

  # balance constraints - means of all Xs in treated group
  b <- as.numeric(treat[, lapply(.SD, mean), .SDcols = x])
  X <- as.matrix(ctrl[, ..x]) # matrix of control obs

  # solve optimisation problem in CVXR
  omega <- Variable(nrow(ctrl)) # weights are of length n_0
  objective <- Maximize(sum(entr(omega))) # entropy objective fn
  constraints <- list(
    omega >= 0, sum(omega) == 1, # proper weights
    t(X) %*% omega == b # balance
  )
  prob <- Problem(objective, constraints)
  result <- solve(prob, solver = solver)
  # Fail loudly rather than carrying NA weights into the regression.
  if (!result$status %in% c("optimal", "optimal_inaccurate")) {
    stop(
      "entropy balancing problem was not solved (CVXR status: ",
      result$status, ")",
      call. = FALSE
    )
  }
  omega_hat <- as.numeric(result$getValue(omega))

  # line up weights (solved weights for control units and 1/n1 for treat
  # units, so that each arm's weights sum to one)
  n_treat <- nrow(treat)
  weights_all <- c(omega_hat, rep(1 / n_treat, n_treat))

  treatment <- df[[w]]
  outcome <- df[[y]]
  eff <- felm(outcome ~ treatment, weights = weights_all) |> robustify()
  summary(eff)$coefficients[2, 1:2]
}