5  Gaussian Example

Objectives

In this chapter, we generate a binary treatment \(a\in\{0,1\}\) (assigned in a well-balances manner) which affects two mediators, \(x_1\) and \(x_2\), as well as an outcome \(y\). The outcome is generated according to a linear model. We then compute causal effects using different methods (causal mediation analysis, as well as transport-based methods).

library(tidyverse)
library(tikzDevice)
library(mnormt)
library(fairadapt)
Codes for graphical parameters
library(extrafont, quietly = TRUE)
col_group <- c("#00A08A","#F2AD00", "#1b95e0")
colour_methods <- c(
    "OT" = "#CC79A7", "OT-M" = "#009E73",
    "skh" = "darkgray",
    "seq_1" = "#0072B2", "seq_2" = "#D55E00",
    "fairadapt" = "#9966FF"
  )
colGpe1 <- col_group[2]
colGpe0 <- col_group[1]
colGpet <- col_group[3]
loadfonts(device = "pdf", quiet = TRUE)
font_size <- 20
font_family <- "CMU Serif"

path <- "./figs/"
if (!dir.exists(path)) dir.create(path)

source("../scripts/utils.R")

\[ \definecolor{wongBlack}{RGB}{0,0,0} \definecolor{wongGold}{RGB}{230, 159, 0} \definecolor{wongLightBlue}{RGB}{86, 180, 233} \definecolor{wongGreen}{RGB}{0, 158, 115} \definecolor{wongYellow}{RGB}{240, 228, 66} \definecolor{wongBlue}{RGB}{0, 114, 178} \definecolor{wongOrange}{RGB}{213, 94, 0} \definecolor{wongPurple}{RGB}{204, 121, 167} \definecolor{colGpe1}{RGB}{0, 160, 138} \definecolor{colGpe0}{RGB}{242, 173, 0} \]

5.1 Data Generating Process

We want to simulate potential outcomes in a binary treatment setting, with covariate shift between treatment groups.

Let \(n=500\) denote the number of individuals (or unit), and let \(\boldsymbol{X}=(X_1,X_2)\) be drawn from bivariate normal distrubtions whose mean vectors and covariance matrices depend on the treatment assignment \(A\in\{0,1\}\).

For untreated individuals (\(A=\color{colGpe0}0\)) the covariates \(\boldsymbol{X}^{(0)} = (X_1{(0)}, X_2{(0)})\) are sampled from a \(\mathcal{N}(\mu_0, \Sigma_0)\), where \(\mu_0 = -1\), \(\Sigma_0 = \begin{pmatrix} 1 & r_0 \\ r_0 & 1 \end{pmatrix}\) with \(r_0 = 0.7\).

For treated individuals (\(A=\color{colGpe1}1\)), covariates \(\boldsymbol{X}{(1)} = (X_1{(1)}, X_2{(1)})\) follow a \(\mathcal{N}(\mu_1, \Sigma_1)\), where \(\mu_1 = +1\), \(\Sigma_1 = \begin{pmatrix} 1 & r_1 \\ r_1 & 1 \end{pmatrix}\) with \(r_1 = -0.5\).

The treatment assignment \(A\) is randomized with probability \(p_1 = 0.5\).

The potential outcomes are linear functions of the covariates: \[ \begin{aligned} Y(0) &= a_1 X_1 + a_2 X_2 + \varepsilon,\\ Y(1) &= a_1 X_1 + a_2 X_2 + a_0 + \varepsilon . \end{aligned} \]

where \(\varepsilon \sim \mathcal{N}(0, 1)\) and \(a_0 = 3\), \(a_1 = 2\), \(a_2 = -1.5\).

The observed outcome is \[Y = A \cdot Y(1) + (1 - A) \cdot Y(0).\]

Natural Indirect Effect

Here, with the current DGP, we have: \[ \begin{cases} \delta_i(0) = a_1(X_{1,i}(1) - X_{1,i}(0)) + a_2(X_{2,i}(1)-X_{2,i}(0))\\ \delta_i(1) = a_1(X_{1,i}(1) - X_{1,i}(0)) + a_2(X_{2,i}(1)-X_{2,i}(0)) \end{cases} \]

Average Natural Indirect Effect

Here, we have:

\[ \begin{cases} \bar{\delta}(0) = (a_1+a2)(\mu_1-\mu_0)\\ \bar{\delta}(1) = (a_1+a2)(\mu_1-\mu_0) \end{cases} \]

Natural Direct Effect

Here, we have:

\[ \begin{cases} \zeta_i{(0)} = a_0\\ \zeta_i{(1)} = a_0\\ \end{cases} \]

Total Causal Effect

Here, we have:

\[ \tau_i = a_1(X_{1,i}(1) - X_{1,i}(0)) + a_2(X_{2,i}(1)-X_{2,i}(0)) + a_0 \] and

\[ \bar{\delta}=(a_1+a_2)(\mu_1-\mu_0) + a0 \]

set.seed(12345)
# Parameters
n <- 500
mu0 <- -1
mu1 <- +1
r0 <- +.7
r1 <- -.5
a <- 1
a0 <-  3
a1 <-  2
a2 <-  -1.5
p1 <- .5
Mu0 <- rep(mu0, 2)
Mu1 <- rep(mu1, 2)
Sig0 <- matrix(c(1, r0, r0, 1), 2, 2)
Sig1 <- matrix(c(1, r1, r1, 1), 2, 2)

# Draw covariates
X0 <- rmnorm(n, mean = a * Mu0, varcov = Sig0)
X1 <- rmnorm(n, mean = a * Mu1, varcov = Sig1)
# Random noise
E <- rnorm(n)
# Binary treatment
A <- sample(0:1, size = n, replace = TRUE, prob = c(1 - p1, p1))

X <- X0
X[A==1, ] = X1[A==1, ]

df <- tibble(
  X1 = X[, 1],
  X2 = X[, 2],
  A = A,
  Y0 = a1 * X1 + a2 * X2 + E,
  Y1 = a1 * X1 + a2 * X2 + a0 + E,
  Y = A * Y1 + (1-A) * Y0
)

We define a function to wrap this DGP.

The gen_data() function.
#' @param n Number of units.
#' @param mu0 Mean of the two covariates in group 0.
#' @param mu1 Mean of the two covariates in group 1.
#' @param r0 Covariance of the two covariates in group 0.
#' @param r1 Covariance of the two covariates in group 1.
#' @parma a Shift parameter for the mean in both groups
#'  (default to 1: no shift). Larger values decreases overlapping.
gen_data <- function(n = 500,
                     mu0 = -1,
                     mu1 = +1,
                     r0 = +.7,
                     r1 = -.5,
                     a = 1,
                     seed = NULL) {
  
  if (!is.null(seed)) set.seed(seed)
  
  a0 <-  3
  a1 <-  2
  a2 <-  -1.5
  p1 <- .5
  Mu0 <- rep(mu0, 2)
  Mu1 <- rep(mu1, 2)
  Sig0 <- matrix(c(1, r0, r0, 1), 2, 2)
  Sig1 <- matrix(c(1, r1, r1, 1), 2, 2)
  # Draw covariates
  X0 <- rmnorm(n, mean = a * Mu0, varcov = Sig0)
  X1 <- rmnorm(n, mean = a * Mu1, varcov = Sig1)
  # Random noise
  E <- rnorm(n)
  # Binary treatment
  A <- sample(0:1, size = n, replace = TRUE, prob = c(1 - p1, p1))
  X <- X0
  X[A==1, ] = X1[A==1, ]
  df <- tibble(
    X1 = X[, 1],
    X2 = X[, 2],
    A = A,
    Y0 = a1 * X1 + a2 * X2 + E,
    Y1 = a1 * X1 + a2 * X2 + a0 + E,
    Y = A * Y1 + (1-A) * Y0
  )
  
  df
}

5.2 Counterfactuals

Let us build counterfactuals for individuals from group 0, and for individuals from group 1. We will consider the following methods:

  • Multivariate Optimal Transport (since we know the parameters of the two Gaussians),
  • Regularized transport using Sinkhorn algorithm,
  • Sequential Optimal Transport.

5.2.1 Optimal Transport

library(expm)

Given two collection of points \(\{\boldsymbol{x}_{0,1},\cdots,\boldsymbol{x}_{0,n_0}\}\) and \(\{\boldsymbol{x}_{1,1},\cdots,\boldsymbol{x}_{1,n_1}\}\) in \(\mathcal{X}_0\) and \(\mathcal{X}_1\), and a cost \(c:\mathcal{X}_0\times\mathcal{X}_1\to\mathbb{R}_+\), define the cost matrix, \(n_0\times n_1\), \(\boldsymbol{C}:=[C_{i,j}]\) where \(C_{i,j}=c(\boldsymbol{x}_{0,i},\boldsymbol{x}_{1,j})\). The optimal matching problem is [ {({n_0},{n_1})} ,,,,, ={i=1}^{n_0}{j=1}^{n_1}P{ij},C_{ij}, ] where \(\mathcal{U}(\boldsymbol{1}_{n_0},\boldsymbol{1}_{n_1})\) is the polytope [ {,P+^{n_0n_1}: P,{n_1}=,  P^_{n_0}= }. ]

Here, in the Gaussian case, the optimal transport map \(T(x)\) from \(\mathcal{N}(\boldsymbol{\mu}_0, \boldsymbol{\Sigma}_0)\) to \(\mathcal{N}(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1)\) is: \[T(x) = \boldsymbol{\mu}_1 + \boldsymbol{A}(x - \boldsymbol{\mu}_0)\] where: \[ \boldsymbol{A} = \boldsymbol{\Sigma}_0^{1/2} \left( \boldsymbol{\Sigma}_0^{1/2} \boldsymbol{\Sigma}_1 \boldsymbol{\Sigma}_0^{1/2} \right)^{-1/2} \boldsymbol{\Sigma}_0^{1/2} \]

We define the function compute_ot_map() to compute the optimal mapping.

#' Optimal transport mapping between two Gaussian distributions 
#'  (from \eqn{\mathcal{N}(\mu_{\text{source}}, \Sigma_{\text{source}})} to 
#'   \eqn{\mathcal{N}(\mu_{\text{target}}, \Sigma_{\text{target}})})
#'  
#' @param mu_source Mean vector of the source Gaussian.
#' @param sigma_source Covariance matrix of the source Gaussian.
#' @param mu_target Mean vector of the target Gaussian.
#' @param sigma_target Covariance matrix of the target Gaussian.
compute_ot_map <- function(mu_source, sigma_source, mu_target, sigma_target) {
  sqrt_sigma_source <- sqrtm(sigma_source)
  sqrt_sigma_source_inv <- solve(sqrt_sigma_source)
  
  inner <- sqrt_sigma_source %*% sigma_target %*% sqrt_sigma_source
  sqrt_inner <- sqrtm(inner)
  
  A <- sqrt_sigma_source_inv %*% sqrt_inner %*% sqrt_sigma_source_inv
  
  list(A = A, shift = mu_target - A %*% mu_source)
}

We also define the apply_ot_transport() function which uses a transport plan to transport individuals.

#' Function to apply the transport map to simulated data
#' 
#' @param X Observations to transport.
#' @param mapping Optimal transport mapping (from `compute_ot_map()`)?
apply_ot_transport <- function(X, mapping) {
  A <- mapping$A
  shift <- mapping$shift
  t(apply(X, 1, function(x) as.vector(shift + A %*% x)))
}

Since we generated the data, we know the exact transport plan to transport individuals from group 0 to group 1. We also know the exact transport plan to transport individuals from group 1 to group 0.

Sigma0 <- matrix(c(1, r0, r0, 1), 2, 2)
Sigma1 <- matrix(c(1, r1, r1, 1), 2, 2)
Mu0 <- rep(a * mu0, 2)
Mu1 <- rep(a * mu1, 2)
# Mapping from group 0 to group 1
ot_map_0_to_1 <- compute_ot_map(
  mu_source = Mu0, sigma_source = Sigma0, 
  mu_target = Mu1, sigma_target = Sigma1
)
# Mapping from group 1 to group 0
ot_map_1_to_0 <- compute_ot_map(
  mu_source = Mu1, sigma_source = Sigma1, 
  mu_target = Mu0, sigma_target = Sigma0
)

We apply the transport map to the untreated units (A = 0).

X0 <- as.matrix(df[df$A == 0, c("X1", "X2")])
X0_t <- apply_ot_transport(X = X0, mapping = ot_map_0_to_1)
colnames(X0_t) <- c(c("X1", "X2"))

And to the transport map to the treated units (A = 1).

X1 <- as.matrix(df[df$A == 1, c("X1", "X2")])
X1_t <- apply_ot_transport(X = X1, mapping = ot_map_1_to_0)
colnames(X1_t) <- c(c("X1", "X2"))

Let us visualize the transported individuals. First, we define the function draw_ellipse() which will allow us to plot the 95% confidence ellipse in both groups.

The draw_ellipse() function.
draw_ellipse <- function(mu, 
                         sigma, 
                         col = "black", 
                         lty = 1, 
                         lwd = 1, 
                         level = 0.95, 
                         ...) {
  
  angles <- seq(0, 2 * pi, length.out = 100)
  vals <- sqrt(
    qchisq(level, df = 2)) * t(chol(sigma)) %*% rbind(cos(angles), sin(angles)
    )
  lines(mu[1] + vals[1, ], mu[2] + vals[2, ], col = col, lty = lty, lwd = lwd, ...)
  
}

We isolate the observations from group 0 and from group 1.

# Prepare data for the plot
X0 <- df[df$A == 0, c("X1", "X2")]
X1 <- df[df$A == 1, c("X1", "X2")]

The initial points and the transported values are shown in Figure 5.1

Codes to create the Figure.
par(mar = c(2.1, 2.1, 2.1, 0.1), mfrow = c(1, 2))
x_lim <- c(-4, 4)
y_lim <- c(-4, 4)

# From 0 to 1
plot(X0, 
     pch = 16, 
     col = adjustcolor(colGpe0, alpha = .3), 
     xlim = x_lim, ylim = y_lim, 
     xlab = "", ylab = "",
     main = "OT: from A=0 to A=1",
     family = font_family
)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)
points(X0_t, col = adjustcolor(colGpet, alpha = .3), pch = 17)

# Add arrows from original to transported
arrows(
  x0 = X0$X1, y0 = X0$X2,
  x1 = X0_t[, 1], y1 = X0_t[, 2],
  length = 0.05, col = adjustcolor("gray", alpha = .3)
)

# True mean and covariance (scaled by 'a')
Mu0 <- rep(a * mu0, 2)
Mu1 <- rep(a * mu1, 2)
Sig0 <- matrix(c(1, r0, r0, 1), 2, 2)
Sig1 <- matrix(c(1, r1, r1, 1), 2, 2)

# Covariance of transported points (via OT map)
Sigma0_transport <- ot_map_0_to_1$A %*% Sig0 %*% t(ot_map_0_to_1$A)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)
draw_ellipse(Mu1, Sigma0_transport, col = colGpet, lty = 2)

# From 1 to 0
plot(
  X0, 
  pch = 16, 
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "OT: from A=1 to A=0",
  family = font_family
)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)
points(X1_t, col = adjustcolor(colGpet, alpha = .3), pch = 17)

# Add arrows from original to transported
arrows(
  x0 = X1$X1, y0 = X1$X2,
  x1 = X1_t[, 1], y1 = X1_t[, 2],
  length = 0.05, col = adjustcolor("gray", alpha = .3)
)

# Covariance of transported points (via OT map)
Sigma0_transport <- ot_map_1_to_0$A %*% Sig1 %*% t(ot_map_1_to_0$A)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)
draw_ellipse(Mu0, Sigma0_transport, col = colGpet, lty = 2)
Figure 5.1: 500 points in each group drawn from bivariates Gaussian distributions and transported values from group 0 to group 1 (left), and from group 1 to group 0 (right), using optimal transport.

5.2.2 Transport-based Many-to-1 Matching

Let us use optimal transport to perform transport-based optimal many-to-1 matching. We use optimal transport with a uniform source and target distribution. For each unit from the source group (group~0), we select the matched unit from the target group (group~1) as the one wit the maximum transported mass. This way, each unit in the source is matched to exactly one in the target.

The transport_many_to_one() function.
#' @param X_source Source characteristics
#' @param X_target Target characteristics
#' @param method Algorithm to use for transport
transport_many_to_one <- function(X_source, 
                                  X_target, 
                                  method = "shortsimplex") {
  n_source <- nrow(X_source)
  n_target <- nrow(X_target)
  
  # Uniform weights
  w_source <- rep(1 / n_source, n_source)
  w_target <- rep(1 / n_target, n_target)
  
  # Cost matrix
  cost <- as.matrix(dist(rbind(X_source, X_target)))
  cost <- cost[1:n_source, (n_source + 1):(n_source + n_target)]
  
  # Solve OT plan
  ot_plan <- transport::transport(
    w_source, w_target, costm = cost, method = method
  )
  
  # For each source unit, select the target with the highest mass
  best_match <- ot_plan |>
    dplyr::group_by(from) |>
    dplyr::slice_max(mass, n = 1, with_ties = FALSE) |>
    dplyr::ungroup()
  
  # Matched matrix
  X_matched <- X_target[best_match$to, , drop = FALSE]
  
  X_matched
}

Let us apply the transport_many_to_one() function to get the counterfactuals in both group.

X0 <- as.matrix(df[df$A == 0, c("X1", "X2")])
X1 <- as.matrix(df[df$A == 1, c("X1", "X2")])
X0_tmatch <- transport_many_to_one(X_source = X0, X_target = X1)
X1_tmatch <- transport_many_to_one(X_source = X1, X_target = X0)
Codes to create the Figure.
par(mar = c(2.1, 2.1, 2.1, 0.1), mfrow = c(1, 2))
x_lim <- c(-4, 4)
y_lim <- c(-4, 4)

# From 0 to 1
plot(
  X0, 
  pch = 16, 
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "Matching: from A=0 to A=1",
  family = font_family
)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)
points(X0_tmatch, col = adjustcolor(colGpet, alpha = .3), pch = 17)

# Add arrows from original to transported
arrows(
  x0 = X0[, "X1"], y0 = X0[, "X2"],
  x1 = X0_tmatch[, 1], y1 = X0_tmatch[, 2],
  length = 0.05, col = adjustcolor("gray", alpha = .3)
)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)


# From 1 to 0
plot(
  X0, 
  pch = 16, 
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "Matching: from A=1 to A=0",
  family = font_family
)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)
points(X1_tmatch, col = adjustcolor(colGpet, alpha = .3), pch = 17)

# Add arrows from original to transported
arrows(
  x0 = X1[, "X1"], y0 = X1[, "X2"],
  x1 = X1_tmatch[, 1], y1 = X1_tmatch[, 2],
  length = 0.05, col = adjustcolor("gray", alpha = .3)
)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)
Figure 5.2: 500 points in each group drawn from bivariates Gaussian distributions and transported values from group 0 to group 1 (left), and from group 1 to group 0 (right), using transport-based many-to-one matching.

5.2.3 Soft Matching

Let us now turn to optimal transport–based matching where each unit is matched to a weighted combination of units in the other group. This allows us to produce synthetic treated units by assigning weights that sum to 1.

X0 <- as.matrix(df[df$A == 0, c("X1", "X2")])
X1 <- as.matrix(df[df$A == 1, c("X1", "X2")])
n_0 <- nrow(X0)
n_1 <- nrow(X1)

Let us use uniform weights:

w_0 <- rep(1 / n_0, n_0)
w_1 <- rep(1 / n_1, n_1)

We use the Euclidean distance between units as costs.

cost <- as.matrix(dist(rbind(X0, X1)))
cost <- cost[1:n_0, (n_0 + 1):(n_0 + n_1)]

Let us solve the optimal transport problem to transport units from group~0 to group~1.

ot_plan <- transport::transport(
  w_0, w_1, costm = cost, method = "shortsimplex"
)

We can then use the mapping to create synthetic individuals in group~1. We use the weights given in the mapping to create the synthetic units as fractions of the other points.

ot_plan <- ot_plan |> group_by(from) |> mutate(weight = mass / sum(mass))
assignment <- rep(NA, n_0)

X0_sm <- ot_plan |> 
  left_join(
    as_tibble(X1) |> mutate(to = row_number()),
    by = "to"
  ) |> 
  mutate(
    across(colnames(X1), ~.x * weight)
  ) |> 
  group_by(from) |> 
  summarise(across(colnames(X1), ~ sum(.x))) |> 
  arrange(from) |> 
  select(-from) |> 
  as.matrix()

Let us create a function, soft_match() to wrap-up the previous codes.

The soft_match() function.
#' @param X_source Source characteristics
#' @param X_target Target characteristics
#' @param method Algorithm to use for transport
soft_match <- function(X_source, 
                       X_target, 
                       method = "shortsimplex") {
  
  n_source <- nrow(X_source)
  n_target <- nrow(X_target)
  
  # Uniform weights
  w_source <- rep(1 / n_source, n_source)
  w_target <- rep(1 / n_target, n_target)
  
  # Compute cost matrix (Euclidean distances)
  cost <- as.matrix(dist(rbind(X_source, X_target)))
  cost <- cost[1:n_source, (n_source + 1):(n_source + n_target)]
  
  # Solve OT problem
  ot_plan <- transport::transport(
    w_source, w_target, costm = cost, method = method
  )
  
  # Normalize weights per source unit
  ot_plan <- dplyr::group_by(ot_plan, from)
  ot_plan <- dplyr::mutate(ot_plan, weight = mass / sum(mass))
  
  # Build transported covariates
  X_target_tbl <- tibble::as_tibble(X_target) |>
    dplyr::mutate(to = dplyr::row_number())
  
  X_sm <- ot_plan |>
    dplyr::left_join(X_target_tbl, by = "to") |>
    dplyr::mutate(
      dplyr::across(colnames(X_target), ~ .x * weight)
    ) |>
    dplyr::group_by(from) |>
    dplyr::summarise(
      dplyr::across(colnames(X_target), sum),
      .groups = "drop"
    ) |>
    dplyr::arrange(from) |>
    dplyr::select(-from) |>
    as.matrix()
  
  X_sm
}

Let us do soft-matching to get counterfactuals for units from both group 0 and group 1.

X0_sm <- soft_match(X_source = X0 , X_target = X1)
X1_sm <- soft_match(X_source = X1 , X_target = X0)
Codes to create the Figure.
par(mar = c(2.1, 2.1, 2.1, 0.1), mfrow = c(1, 2))
x_lim <- c(-4, 4)
y_lim <- c(-4, 4)

# From 0 to 1
plot(
  X0, 
  pch = 16, 
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "SM: from A=0 to A=1",
  family = font_family
)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)
points(X0_sm, col = adjustcolor(colGpet, alpha = .3), pch = 17)

# Add arrows from original to transported
arrows(
  x0 = X0[, "X1"], y0 = X0[, "X2"],
  x1 = X0_sm[, 1], y1 = X0_sm[, 2],
  length = 0.05, col = adjustcolor("gray", alpha = .3)
)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)


# From 1 to 0
plot(
  X0, 
  pch = 16, 
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "SM: from A=1 to A=0",
  family = font_family
)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)
points(X1_sm, col = adjustcolor(colGpet, alpha = .3), pch = 17)

# Add arrows from original to transported
arrows(
  x0 = X1[, "X1"], y0 = X1[, "X2"],
  x1 = X1_sm[, 1], y1 = X1_sm[, 2],
  length = 0.05, col = adjustcolor("gray", alpha = .3)
)
# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)
Figure 5.3: 500 points in each group drawn from bivariates Gaussian distributions and transported values from group 0 to group 1 (left), and from group 1 to group 0 (right), using soft matching.

5.2.4 Entropic Regularization via Sinkhorn Algorithm

The optimal transport problem can be computationally intensive (not here with the small amount of observations, though). In such cases, one can consider Entropic regularization via Sinkhorn algorithm, associated to the “Matrix Scaling Problem” in Sinkhorn (1962). Entropic regularization modifies the previous problem by adding a Kullback-Leibler divergence (\(\mathrm{d}_{\mathrm{KL}}\)) term to the optimization goal \[ \min_{\boldsymbol{P}\in\mathcal{U}(\boldsymbol{1}_{n_0},\boldsymbol{1}_{n_1})} \Big\lbrace \langle \boldsymbol{P},\boldsymbol{C}\rangle +\gamma\cdot \mathrm{d}_{\mathrm{KL}}(\boldsymbol{P}||{\boldsymbol{1}_{n_0}}\otimes{\boldsymbol{1}_{n_1}}) \Big\rbrace \tag{5.1}\] where the Kullback-Leibler divergence term corresponds to the opposite of the discrete entropy of the coupling matrix \(\boldsymbol{P}\), \[ H(\boldsymbol{P}) := -\sum_{i,j}P_{i,j}(\log(P_{i,j}) - 1). \] see Proposition 4.3 in Peyré and Cuturi (2019).

Let us make sure that the observations are stored in matrices for each group.

X0 <- as.matrix(df[df$A == 0, c("X1", "X2")])
X1 <- as.matrix(df[df$A == 1, c("X1", "X2")])

The number of observations in each group:

n_0 <- nrow(X0)
n_1 <- nrow(X1)

We will use uniform weights:

w_0 <- rep(1 / n_0, n_0)
w_1 <- rep(1 / n_1, n_1)

We compute the pairwise cost matrix, which is simply the squared Euclidean distance in this example.

cost_mat <- as.matrix(dist(rbind(X0, X1)))^2
C <- cost_mat[1:n_0, (n_0 + 1):(n_0 + n_1)]

Then, we can apply the Sinkhorn algorithm to solve the problem.

skh_res <- T4transport::sinkhornD(
  D = C, p = 2, wx = w_0, wy = w_1, lambda = 0.1
)

We extract the transport plan:

ot_plan_skh <- skh_res$plan

We normalize the plan:

row_sums <- rowSums(ot_plan_skh)
ot_plan_skh <- sweep(ot_plan_skh, 1, row_sums, FUN = "/")

The transported version of X_1 is the barycentric projection:

X0_skh <- ot_plan_skh %*% X1

We wrap these steps in a function, transport_regul().

The transport_regul() function.
#' @param X_source Matrix of observations to transport from the source group.
#' @param X_target Matrix of observations from the target group.
#' @param gamma A regularization parameter (default to 0.1).
transport_regul <- function(X_source, 
                            X_target, 
                            gamma) {

  X_source <- as.matrix(X_source)
  X_target <- as.matrix(X_target)
  n_source <- nrow(X_source)
  n_target <- nrow(X_target)
  # Uniform weights
  w_source <- rep(1 / n_source, n_source)
  w_target <- rep(1 / n_target, n_target)
  
  # Pairwise squared Euclidean distance
  cost_mat <- as.matrix(dist(rbind(X_source, X_target)))^2
  C <- cost_mat[1:n_source, (n_source + 1):(n_source + n_target)]
  
  # Run Sinkhorn with entropic regularization gamma
  skh_res <- T4transport::sinkhornD(
    D = C, p = 2, wx = w_source, wy = w_target, lambda = gamma
  )

  # Extract and normalize plan
  ot_plan_skh <- skh_res$plan
  ot_plan_skh <- sweep(ot_plan_skh, 1, rowSums(ot_plan_skh), FUN = "/")
  
  ot_plan_skh %*% X_target
}

We consider the following values for \(\gamma\):

gammas <- c(0.1, 1, 5, 10, 100, 1000, 10000)

We transport the observations from group~0 to group~1:

X0_skh_l <- map(
  gammas, 
  ~transport_regul(
    X_source = X0, 
    X_target = X1, 
    gamma = .x
  )
)
names(X0_skh_l) <- as.character(gammas)

And from group~1 to group~0:

X1_skh_l <- map(
  gammas, 
  ~transport_regul(
    X_source = X1, 
    X_target = X0, 
    gamma = .x
  )
)
names(X1_skh_l) <- as.character(gammas)

We visualize the results in Figure 5.4.

Codes to create the Figure.
layout(matrix(seq_len(length(gammas) + 1), ncol = 4, byrow = TRUE))
par(mar = c(2.1, 2.1, 2.1, 0.1))
# par(mar = c(2.1, 2.1, 2.1, 0.1), mfrow = c(1, 2))
x_lim <- c(-4, 4)
y_lim <- c(-4, 4)

## OT----
# From 0 to 1
plot(X0, 
     pch = 16, 
     col = adjustcolor(colGpe0, alpha = .3), 
     xlim = x_lim, ylim = y_lim, 
     xlab = "", ylab = "",
     main = "OT: from A=0 to A=1"
)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)

# Add arrows from original to transported
arrows(
  x0 = X0[, "X1"], y0 = X0[, "X2"],
  x1 = X0_t[, 1], y1 = X0_t[, 2],
  length = 0.05, col = adjustcolor("gray", alpha = .3)
)
points(X0_t, col = adjustcolor(colGpet, alpha = .3), pch = 17)

# True mean and covariance (scaled by 'a')
Mu0 <- rep(a * mu0, 2)
Mu1 <- rep(a * mu1, 2)
Sig0 <- matrix(c(1, r0, r0, 1), 2, 2)
Sig1 <- matrix(c(1, r1, r1, 1), 2, 2)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)

## OT with Entropic Regularization----

for (gamma in gammas) {
  X0_skh <- X0_skh_l[[as.character(gamma)]]
  
  # From 0 to 1
  plot(
    X0, 
    pch = 16, 
    col = adjustcolor(colGpe0, alpha = .3), 
    xlim = x_lim, ylim = y_lim, 
    xlab = "", ylab = "",
    main = latex2exp::TeX(paste0("$\\gamma=", gamma, "$"))
  )
  title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2)
  points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)
  # Add arrows from original to transported
  arrows(
    x0 = X0[, "X1"], y0 = X0[, "X2"],
    x1 = X0_skh[, 1], y1 = X0_skh[, 2],
    length = 0.05, col = adjustcolor("gray", alpha = .3)
  )
  points(X0_skh[, "X1"], X0_skh[, "X2"], col = adjustcolor(colGpet, alpha = .3), pch = 17)
  
  # True mean and covariance (scaled by 'a')
  Mu0 <- rep(a * mu0, 2)
  Mu1 <- rep(a * mu1, 2)
  Sig0 <- matrix(c(1, r0, r0, 1), 2, 2)
  Sig1 <- matrix(c(1, r1, r1, 1), 2, 2)
  
  
  # Add ellipses
  draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
  draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)
}
Figure 5.4: 500 points in each group drawn from bivariates Gaussian distributions and transported values from group 0 to group 1, using optimal transport (first panel) and using entropic regularization and Sinkhorn Algorithm with different values of the regularization parameter \(\gamma\).

5.2.5 Sequential Transport

We will now transport individuals using sequential transport. The results are sensitive to the ordering within the sequence. We will consider both ordering to illustrate this.

  • a first marginal univariate optimal transport along the first dimension (\(X_1\)), then a conditional transport for the second dimension (\(X_2 \mid X_1)\): sequential_transport_12(),
  • a first marginal univariate optimal transport along the second dimension (\(X_2\)), then a conditional transport for the first dimension (\(X_1 \mid X_2)\): sequential_transport_21().
The sequential_transport_12() function.
#' Sequential transport from N(M_source, S_source) to N(M_target, S_target),
#' along X1, then X2 | X1
#'
#' @param X n x 2 matrix of source observations.
#' @param M_source Mean vector of the source distribution (length 2).
#' @param S_source Covariance matrix of the source distribution (2x2).
#' @param M_target Mean vector of the target distribution.
#' @param S_target Covariance matrix of the target distribution.
sequential_transport_12 <- function(X, 
                                    M_source, 
                                    S_source, 
                                    M_target, 
                                    S_target) {
  
  # marginal univariate transport along the first coordinate (X_1)
  T1x <- qnorm(
    pnorm(X[, 1], mean = M_source[1], sd = sqrt(S_source[1, 1])),
    mean = M_target[1], sd = sqrt(S_target[1, 1])
  )
  
  # conditional parameters for X_2 | X_1
  m_source <- M_source[2] + S_source[1, 2] / S_source[1, 1] * (X[, 1] - M_source[1])
  s_source <- S_source[2, 2] - S_source[1, 2]^2 / S_source[1, 1]
  
  m_target <- M_target[2] + S_target[1, 2] / S_target[1, 1] * (T1x - M_target[1])
  s_target <- S_target[2, 2] - S_target[1, 2]^2 / S_target[1, 1]
  
  # conditional transport for the second coordinate
  T2x <- qnorm(
    pnorm(X[, 2], mean = m_source, sd = sqrt(s_source)),
    mean = m_target, sd = sqrt(s_target)
  )
  
  cbind(T1x, T2x)
}
The sequential_transport_21() function.
#' Sequential transport from N(M_source, S_source) to N(M_target, S_target),
#' along X2, then X1 | X2
#'
#' @param X n x 2 matrix of source observations.
#' @param M_source Mean vector of the source distribution (length 2).
#' @param S_source Covariance matrix of the source distribution (2x2).
#' @param M_target Mean vector of the target distribution.
#' @param S_target Covariance matrix of the target distribution.
sequential_transport_21 <- function(X, M_source, S_source, M_target, S_target) {
  
  # marginal univariate transport along X_2
  T2x <- qnorm(
    pnorm(X[, 2], mean = M_source[2], sd = sqrt(S_source[2, 2])),
    mean = M_target[2], sd = sqrt(S_target[2, 2])
  )
  
  # conditional parameters for X_1 | X_2
  m_source <- M_source[1] + S_source[1, 2] / S_source[2, 2] * (X[, 2] - M_source[2])
  s_source <- S_source[1, 1] - S_source[1, 2]^2 / S_source[2, 2]
  
  m_target <- M_target[1] + S_target[1, 2] / S_target[2, 2] * (T2x - M_target[2])
  s_target <- S_target[1, 1] - S_target[1, 2]^2 / S_target[2, 2]
  
  # conditional transport for X1 | X_2
  T1x <- qnorm(
    pnorm(X[, 1], mean = m_source, sd = sqrt(s_source)),
    mean = m_target, sd = sqrt(s_target)
  )
  
  cbind(T1x, T2x)
}

We isolate the observations from group 0 and from group 1, and store them as matrices.

X0 <- as.matrix(df[df$A == 0, c("X1", "X2")])
X1 <- as.matrix(df[df$A == 1, c("X1", "X2")])

We then transport from group 0 to group group 1 with sequential transport, first transporting \(X_1\) then \(X_2 | X_1\).

X0_st_12 <- sequential_transport_12(
  X = X0, M_source = Mu0, S_source = Sig0, M_target = Mu1, S_target = Sig1
)

We do the same but for units in group 1 to group 0.

X1_st_12 <- sequential_transport_12(
  X = X1, M_source = Mu1, S_source = Sig1, M_target = Mu0, S_target = Sig0
)

We also transport from group 0 to group group 1 with sequential transport, first transporting \(X_1\) then \(X_2 | X_1\).

X0_st_21 <- sequential_transport_21(
  X = X0, M_source = Mu0, S_source = Sig0, M_target = Mu1, S_target = Sig1
)

We do the same but for units in group 1 to group 0.

X1_st_21 <- sequential_transport_21(
  X = X1, M_source = Mu1, S_source = Sig1, M_target = Mu0, S_target = Sig0
)

Again, we can visualize the results on a scatter plot (Figure 5.5).

Codes to create the Figure.
# Prepare data for the plot
X0 <- df[df$A == 0, c("X1", "X2")]
X1 <- df[df$A == 1, c("X1", "X2")]

par(mar = c(2.1, 2.1, 2.1, 0.1), mfrow = c(2,2))
x_lim <- c(-4, 4)
y_lim <- c(-4, 4)

# From 0 to 1, X1 then X2----
plot(
  X0, 
  pch = 16, 
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "",
  family = font_family
)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
title(main = "A=0 to A=1, X1 then X2", line=.5, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)
points(X0_st_12, col = adjustcolor(colGpet, alpha = .3), pch = 17)

# Add arrows from original to transported
arrows(
  x0 = X0$X1, y0 = X0$X2,
  x1 = X0_st_12[, 1], y1 = X0_st_12[, 2],
  length = 0.05, col = adjustcolor("gray", alpha = .3)
)

# True mean and covariance (scaled by 'a')
Mu0 <- rep(a * mu0, 2)
Mu1 <- rep(a * mu1, 2)
Sig0 <- matrix(c(1, r0, r0, 1), 2, 2)
Sig1 <- matrix(c(1, r1, r1, 1), 2, 2)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)


# From 1 to 0, X1 then X2----
plot(
  X0, 
  pch = 16, 
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "",
  family = font_family
)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
title(main = "A=1 to A=0, X1 then X2", line=.5, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)
points(X1_st_12, col = adjustcolor(colGpet, alpha = .3), pch = 17)

# Add arrows from original to transported
arrows(
  x0 = X1$X1, y0 = X1$X2,
  x1 = X1_st_12[, 1], y1 = X1_st_12[, 2],
  length = 0.05, col = adjustcolor("gray", alpha = .3)
)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)


# From A=0 to A=1, X2 then X1
plot(
  X0, 
  pch = 16, 
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "",
  family = font_family
)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
title(main = "A=0 to A=1, X2 then X1", line=.5, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)
points(X0_st_21, col = adjustcolor(colGpet, alpha = .3), pch = 17)

# Add arrows from original to transported
arrows(
  x0 = X0$X1, y0 = X0$X2,
  x1 = X0_st_21[, 1], y1 = X0_st_21[, 2],
  length = 0.05, col = adjustcolor("gray", alpha = .3)
)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)


# From A=1 to A=0, X2 then X1
plot(
  X0, 
  pch = 16, 
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "",
  family = font_family
)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
title(main = "A=1 to A=0, X2 then X1", line=.5, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)
points(X1_st_21, col = adjustcolor(colGpet, alpha = .3), pch = 17)

# Add arrows from original to transported
arrows(
  x0 = X1$X1, y0 = X1$X2,
  x1 = X1_st_21[, 1], y1 = X1_st_21[, 2],
  length = 0.05, col = adjustcolor("gray", alpha = .3)
)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)
Figure 5.5: 500 points in each group drawn from bivariates Gaussian distributions and transported values from group 0 to group 1 (left), and from group 1 to group 0 (right), using sequential optimal transport, first transporting \(X_1\), then \(X_2 \mid X_1\) (top), and first transporting \(X_2\), then \(X_1 \mid X_2\) (bottom).
Codes for the Figure in the paper
scale <- 1.42
filename <- "gaussian-transport-0to1"
tikz(paste0("figs/", filename, ".tex"), width = scale*2.2, height = scale*1)

layout(matrix(1:3, ncol = 3), width = c(1, rep(.9, 2)))
par(mar = c(2.1, 2.1, 2.1, 0.1))
x_lim <- c(-4, 4)
y_lim <- c(-4, 4)
cex_pts <- .5
alpha_arrows <- .2
length_arrow <- 0

# From 0 to 1, with OT
plot(X0, 
     pch = 16, cex = cex_pts,
     col = adjustcolor(colGpe0, alpha = .3), 
     xlim = x_lim, ylim = y_lim, 
     xlab = "", ylab = "",
     main = "",
     family = font_family,
     axes = FALSE
)
axis(1, at = -3:3)
axis(2, at = -3:3)

title(xlab = "X1", ylab="X2", line=2, cex.lab=1, family = font_family)
title(main = "OT", line=.5, cex.main = 1, family = font_family, font.main=1)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16, cex = cex_pts)
points(X0_t, col = adjustcolor(colour_methods[["OT"]], alpha = .3), pch = 17, cex = cex_pts)

# Add arrows from original to transported
arrows(
  x0 = X0$X1, y0 = X0$X2,
  x1 = X0_t[, 1], y1 = X0_t[, 2],
  length = length_arrow, col = adjustcolor("gray", alpha = alpha_arrows)
)

# True mean and covariance (scaled by 'a')
Mu0 <- rep(a * mu0, 2)
Mu1 <- rep(a * mu1, 2)
Sig0 <- matrix(c(1, r0, r0, 1), 2, 2)
Sig1 <- matrix(c(1, r1, r1, 1), 2, 2)

# Covariance of transported points (via OT map)
Sigma0_transport <- ot_map_0_to_1$A %*% Sig0 %*% t(ot_map_0_to_1$A)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)

par(mar = c(2.1, 1.1, 2.1, 0.1))

# From 0 to 1, X1 then X2----
plot(
  X0, 
  pch = 16, cex = cex_pts,
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "",
  family = font_family,
  axes = FALSE
)
axis(1, at = -3:3)

title(xlab = "X1", ylab="X2", line=2, cex.lab=1, family = font_family)
title(main = "ST(1)", line=.5, cex.main = 1, family = font_family, font.main=1)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16, cex = cex_pts)
points(X0_st_12, col = adjustcolor(colour_methods[["seq_1"]], alpha = .3), pch = 17, cex = cex_pts)

# Add arrows from original to transported
arrows(
  x0 = X0$X1, y0 = X0$X2,
  x1 = X0_st_12[, 1], y1 = X0_st_12[, 2],
  length = length_arrow, col = adjustcolor("gray", alpha = .4)
)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)


# From A=0 to A=1, X2 then X1----
plot(
  X0, 
  pch = 16, cex = cex_pts,
  col = adjustcolor(colGpe0, alpha = alpha_arrows), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "",
  family = font_family,
  axes = FALSE
)
axis(1, at = -3:3)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1, family = font_family)
title(main = "ST(2)", line=.5, cex.main = 1, family = font_family, font.main=1)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16, cex = cex_pts)
points(X0_st_21, col = adjustcolor(colour_methods[["seq_2"]], alpha = .3), pch = 17, cex = cex_pts)

# Add arrows from original to transported
arrows(
  x0 = X0$X1, y0 = X0$X2,
  x1 = X0_st_21[, 1], y1 = X0_st_21[, 2],
  length = length_arrow, col = adjustcolor("gray", alpha = alpha_arrows)
)

# Add ellipses
draw_ellipse(Mu0, Sig0, col = colGpe0, lty = 2)
draw_ellipse(Mu1, Sig1, col = colGpe1, lty = 2)

# legend(
#   "bottomright",
#   pch = c(16, 16, 17),
#   col = c(colGpe0, colGpe1, colGpet),
#   legend = c("$A=0$", "$A=1$", "Transp."),
#   bty = "n"
# )

dev.off()
plot_to_pdf(filename = filename, path = "./figs/", keep_tex = FALSE, crop = FALSE)

5.2.6 Illustration for a Single Unit

export_tikz <- FALSE
# Focus on a unit
i <- 11

X0 <- df[df$A == 0, c("X1", "X2")]
X1 <- df[df$A == 1, c("X1", "X2")]

if (export_tikz == TRUE) 
  tikz('figs/gaussian-1-transport.tex', width = 2, height = 2.2)

par(mar = c(2.1, 2.1, 1.8, 0.1))
x_lim <- c(-4, 4)
y_lim <- c(-4, 4)

# X1 then X2
plot(
  X0, 
  pch = 16, 
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "",
  family = font_family
)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
# title(main = "X1 then X2", line=.5, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16)

# Individual of interest
points(X0[i, ], col = adjustcolor(colGpe0, alpha = 1), pch = 15, cex = 1.5)
points(X0_t[i, 1], X0_t[i, 2], col = adjustcolor(colour_methods[["OT"]], alpha = 1), pch = 15, cex = 1.5)
points(X0_skh_l[["0.1"]][i, 1], X0_skh_l[["0.1"]][i, 2], col = adjustcolor(colour_methods[["skh"]], alpha = 1), pch = 15, cex = 1.5)
points(X0_st_12[i, 1], X0_st_12[i, 2], col = adjustcolor(colour_methods[["seq_1"]], alpha = 1), pch = 15, cex = 1.5)
points(X0_st_21[i, 1], X0_st_21[i, 2], col = adjustcolor(colour_methods[["seq_2"]], alpha = 1), pch = 15, cex = 1.5)

length_arrow <- 0.1
lwd_arrow <- 2
# OT
arrows(
  x0 = X0$X1[i], y0 = X0$X2[i],
  x1 = X0_t[i, 1], y1 = X0_t[i, 2],
  length = length_arrow, col = adjustcolor(colour_methods[["OT"]], alpha = 1),
  lwd = lwd_arrow, lty = 2
)
# Regularization
arrows(
  x0 = X0$X1[i], y0 = X0$X2[i],
  x1 = X0_skh_l[["0.1"]][i, 1], y1 = X0_skh_l[["0.1"]][i, 2],
  length = length_arrow, col = adjustcolor(colour_methods[["skh"]], alpha = 1),
  lwd = lwd_arrow, lty = 2
)
# Seq OT (1): X_1 first
points(X0_st_12[i, 1], X0$X2[i], col = adjustcolor(colour_methods[["seq_1"]], alpha = .5), pch = 16, cex = 1.5)
arrows(
  x0 = X0$X1[i], y0 = X0$X2[i],
  x1 = X0_st_12[i, 1], y1 = X0$X2[i],
  length = length_arrow, col = adjustcolor(colour_methods[["seq_1"]], alpha = 1),
  lwd = lwd_arrow
)
arrows(
  x0 = X0_st_12[i, 1], y0 = X0$X2[i],
  x1 = X0_st_12[i, 1], y1 = X0_st_12[i, 2],
  length = length_arrow, col = adjustcolor(colour_methods[["seq_1"]], alpha = 1),
  lwd = lwd_arrow
)

# Seq OT (2): X_2 first
points(X0$X1[i], X0_st_21[i,2], col = adjustcolor(colour_methods[["seq_2"]], alpha = .5), pch = 16, cex = 1.5)
arrows(
  x0 = X0$X1[i], y0 = X0$X2[i],
  x1 = X0$X1[i], y1 = X0_st_21[i,2],
  length = length_arrow, col = adjustcolor(colour_methods[["seq_2"]], alpha = 1),
  lwd = lwd_arrow
)
arrows(
  x0 = X0$X1[i], y0 = X0_st_21[i,2],
  x1 = X0_st_21[i, 1], y1 = X0_st_21[i, 2],
  length = length_arrow, col = adjustcolor(colour_methods[["seq_2"]], alpha = 1),
  lwd = lwd_arrow
)

legend(
  "topleft", 
  legend = c("$x_i$ (Obs.)", "OT", "Sinkhorn", "Seq. OT (1)", "Seq. OT (2)"), 
  col = c(colGpe0, colour_methods[c("OT", "skh", "seq_1", "seq_2")]), 
  pch = 15, pt.cex = 1.5, cex = 1,
  bty = "n"
)

if (export_tikz == TRUE) {
  dev.off()
  
  plot_to_pdf(filename = "gaussian-1-transport", path = "./figs/", keep_tex = FALSE, crop = TRUE)
}

5.3 Causal Effect

We first generate (again) some data, using the DGP presented in Section 13.1.

df <- gen_data(
  n = 500, 
  mu0 = -1, mu1 = +1, 
  r0 = +.7, r1 = -.5, a = 1, 
  seed = 12345
)

5.3.1 With Causal Mediation Analysis

Let us create a dataset, tb, with only the binary response (Y), the binary treatment (A), and the two covariates.

tb <- df[, c("Y", "A", "X1", "X2")]
A_name <- "A"
A_untreated <- 0
Y_name <- "Y"
med_mod_12 <- mediation::multimed(
  outcome = "Y", 
  med.main = "X1", 
  med.alt = "X2", 
  treat = "A", 
  data = df
)
med_mod_21 <- mediation::multimed(
  outcome = "Y", 
  med.main = "X2", 
  med.alt = "X1", 
  treat = "A", 
  data = df
)
Warning

We do not load the {mediation} package since it creates multiple conflicts with useful functions from tidyverse (including select()).

Let us retrieve \(\bar{\delta}(0)\) (average causal mediation effect for \(a=0\)):

delta_0_med <- mean((med_mod_12$d0.lb + med_mod_12$d0.ub) / 2) +
  mean((med_mod_21$d0.lb + med_mod_21$d0.ub) / 2)
delta_0_med
[1] 0.9085616

The total average effect, \(\bar{\tau}\):

(tot_effect_med <- med_mod_12$tau)
       A 
3.943122 

We can then deduct the average direct effect for \(a=1\), i.e., \(\bar{\zeta}(1)\):

(zeta_1_med <- tot_effect_med-delta_0_med)
      A 
3.03456 

Let us also retrieve \(\bar{\delta}(1)\) (average causal mediation effect for \(a=1\)) and \(\bar{\zeta}(0)\) (average direct effect for \(a=0\)):

delta_1_med <- mean((med_mod_12$d1.lb + med_mod_12$d1.ub) / 2) + 
  mean((med_mod_21$d1.lb + med_mod_21$d1.ub) / 2)
zeta_0_med <- tot_effect_med - delta_1_med
c(delta_1_med, zeta_0_med)
                  A 
0.6905634 3.2525583 

5.3.2 With Optimal Transport

We define a function, causal_effects_cf() to compute the causal effect of \(A\) on the outcome \(Y\), for the treated individuals.

#' Estimation of total causal effect using counterfactuals.
#' 
#' @param data_untreated Dataset with the untreated units only.
#' @param data_treated Dataset with the treated units only.
#' @param data_cf_untreated Counterfactuals for untreated had they been treated.
#' @param data_cf_treated Counterfactuals for treated had they been untreated.
#' @param Y_name Name of the column with the outcome variable.
#' @param A_name Name of the column with the treatment variable.
#' @param A_untreated Value of the treatment for the untreated units.
#' 
#' @returns A list:
#' - `delta_0_i`: \eqn{\delta_(0)}, individual causal mediation effects for 
#'   \eqn{a=0} (computed on untreated),
#' - `delta_0`: \eqn{\bar{\delta}(0)}, average causal mediation effect for 
#'   \eqn{a=0} (computed on untreated),
#' - `delta_1_i`: \eqn{\delta_(1)}, individual causal mediation effects for 
#'   \eqn{a=1} (computed on treated),
#' - `delta_1`: \eqn{\bar{\delta}(1)}, average causal mediation effect for 
#'   \eqn{a=1} (computed on treated),
#' - `zeta_0_i`: \eqn{\zeta_(0)}, individual causal mediation effects for 
#'   \eqn{a=0} (computed on treaded),
#' - `zeta_0`: \eqn{\bar{\zeta}(0)}, average causal mediation effect for 
#'   \eqn{a=0} (computed on treated),
#' - `zeta_1_i`: \eqn{\zeta_(1)}, individual causal mediation effects for 
#'   \eqn{a=1} (computed on untreaded),
#' - `zeta_1`: \eqn{\bar{\zeta}(1)}, average causal mediation effect for 
#'   \eqn{a=1} (computed on untreated),
#' - `tot_effect`: \eqb{\tau}: average total effect (\eqn{\bar{\delta}(0) + 
#'   \bar{\zeta}(1)}).
#'
#' @importFrom randomForest randomForest
#' @importFrom dplyr pull select
#' @importFrom stats predict
#' @md
causal_effects_cf <- function(data_untreated,
                              data_treated,
                              data_cf_untreated,
                              data_cf_treated,
                              Y_name,
                              A_name,
                              A_untreated) {
  
  n_untreated <- nrow(data_untreated)
  n_treated <- nrow(data_treated)
  
  # Outcome model for untreated
  mu_untreated_model <- randomForest(
    x = data_untreated |> dplyr::select(-!!Y_name, -!!A_name),
    y = pull(data_untreated, !!Y_name)
  )
  
  # Outcome model for treated
  mu_treated_model <- randomForest(
    x = data_treated |> dplyr::select(-!!Y_name, -!!A_name),
    y = pull(data_treated, !!Y_name)
  )
  
  # Observed outcome
  y_untreated_obs <- data_untreated |> pull(!!Y_name)
  y_treated_obs <- data_treated |> pull(!!Y_name)
  
  # Natural Indirect Effect, using predictions
  delta_0_i <- predict(mu_untreated_model, newdata = data_cf_untreated) -
    predict(mu_untreated_model)
  delta_0 <- mean(delta_0_i)
  delta_1_i <- predict(mu_treated_model) - 
    predict(mu_treated_model, newdata = data_cf_treated)
  delta_1 <- mean(delta_1_i)

  # Natural Indirect Effect, using observed variables
  delta_0_i_obs <- predict(mu_untreated_model, newdata = data_cf_untreated) - 
    y_untreated_obs
  delta_0_obs <- mean(delta_0_i_obs)
  delta_1_i_obs <- y_treated_obs - 
    predict(mu_treated_model, newdata = data_cf_treated)
  delta_1_obs <- mean(delta_1_i_obs)
  
  # Natural Direct Effect (only predictions)
  zeta_0_i <- predict(mu_treated_model, newdata = data_cf_treated) -
    predict(mu_untreated_model, newdata = data_cf_treated)
  zeta_0 <- mean(zeta_0_i)
  
  zeta_1_i <- predict(mu_treated_model, newdata = data_cf_untreated) - 
    predict(mu_untreated_model, newdata = data_cf_untreated)
  zeta_1 <- mean(zeta_1_i)
  
  # Total Causal Effect for treated
  tot_effect <- delta_0 + zeta_1  
  tot_effect_obs <- delta_0_obs + zeta_1
  
  
  list(
    delta_0_i = delta_0_i,
    delta_1_i = delta_1_i,
    zeta_0_i = zeta_0_i,
    zeta_1_i = zeta_1_i,
    delta_0_i_obs = delta_0_i_obs,
    delta_1_i_obs = delta_1_i_obs,
    delta_0 = delta_0,
    delta_1 = delta_1,
    zeta_0 = zeta_0,
    zeta_1 = zeta_1,
    delta_0_obs = delta_0_obs,
    delta_1_obs = delta_1_obs,
    tot_effect = tot_effect,
    tot_effect_obs = tot_effect_obs
  )
}

We use a random forest to estimate the outcome model.

library(randomForest)

We apply this function to our simulated dataset.

tb_untreated <- tb |> filter(!!sym(A_name) == !!A_untreated)
tb_treated <- tb |> filter(!!sym(A_name) != !!A_untreated)

causal_effects_ot <- causal_effects_cf(
  data_untreated = tb_untreated, 
  data_treated = tb_treated,
  data_cf_untreated = as_tibble(X0_t),
  data_cf_treated = as_tibble(X1_t),
  Y_name = Y_name, 
  A_name = A_name, 
  A_untreated = A_untreated
)

cbind(
  delta_0 = causal_effects_ot$delta_0,
  zeta_1 = causal_effects_ot$zeta_1,
  delta_1 = causal_effects_ot$delta_1,
  zeta_0 = causal_effects_ot$zeta_0,
  tot_effect = causal_effects_ot$tot_effect,
  tot_effect_obs = causal_effects_ot$tot_effect_obs
)
       delta_0   zeta_1   delta_1   zeta_0 tot_effect tot_effect_obs
[1,] 0.9422158 3.153563 0.2031006 3.781928   4.095779       4.083447

5.3.3 With Transport-based Matching

We apply the same function as that used with the counterfactuals obtained with optimal transport (causal_effects_cf()). However, here, we feed it with the counterfactuals obtained with the transport-based many-to-one matching (Section 5.2.2).

causal_effects_tmatch <- causal_effects_cf(
  data_untreated = tb_untreated, 
  data_treated = tb_treated,
  data_cf_untreated = as_tibble(X0_tmatch) |> magrittr::set_colnames(c("X1", "X2")),
  data_cf_treated = as_tibble(X1_tmatch) |> magrittr::set_colnames(c("X1", "X2")),
  Y_name = Y_name, 
  A_name = A_name, 
  A_untreated = A_untreated
)

5.3.4 With Entropic Regularization

We apply the same function as that used with the counterfactuals obtained with optimal transport (causal_effects_cf()). However, here, we feed it with the counterfactuals obtained with the regularization method (Section 5.2.4).

causal_effect_skh <- map2(
  X0_skh_l, X1_skh_l, 
  ~causal_effects_cf(
    data_untreated = tb_untreated, 
    data_treated = tb_treated,
    data_cf_untreated = as_tibble(.x) |> magrittr::set_colnames(c("X1", "X2")),
    data_cf_treated = as_tibble(.y) |> magrittr::set_colnames(c("X1", "X2")),
    Y_name = Y_name, 
    A_name = A_name, 
    A_untreated = A_untreated
  )
)
map(
  causal_effect_skh,
  ~tibble(
    delta_0 = .x$delta_0,
    zeta_1 = .x$zeta_1,
    delta_1 = .x$delta_1,
    zeta_0 = .x$zeta_0,
    tot_effect = .x$tot_effect,
    tot_effect_obs = .x$tot_effect_obs
  )
) |> 
  list_rbind(names_to = "gamma")
# A tibble: 7 × 7
  gamma delta_0 zeta_1 delta_1 zeta_0 tot_effect tot_effect_obs
  <chr>   <dbl>  <dbl>   <dbl>  <dbl>      <dbl>          <dbl>
1 0.1     0.793   3.12  0.0992   3.61       3.92           3.91
2 1       0.854   3.10  0.0181   3.88       3.95           3.96
3 5       0.834   3.14  0.214    3.68       3.98           3.97
4 10      0.830   3.16 -0.0158   3.92       3.99           3.98
5 100     0.825   3.24  0.380    3.46       4.07           4.06
6 1000    1.18    2.78  0.383    3.97       3.96           3.96
7 10000   1.18    2.74  0.495    3.84       3.93           3.92

5.3.5 With Sequential Optimal Transport

Again, we causal_effects_cf()), feeding it with the counterfactuals obtained with sequential transport (sec-cf-sot). For those where we first transport \(X_1\) and then \(X_2 \mid X_1\):

causal_effect_sot_12 <- causal_effects_cf(
  data_untreated = tb_untreated, 
  data_treated = tb_treated,
  data_cf_untreated = as_tibble(X0_st_12) |> magrittr::set_colnames(c("X1", "X2")),
  data_cf_treated = as_tibble(X1_st_12) |> magrittr::set_colnames(c("X1", "X2")),
  Y_name = Y_name, 
  A_name = A_name, 
  A_untreated = A_untreated
)

And for the counterfactuales obtained by sequential transport where we first transport \(X_2\) and then \(X_1 \mid X_2\):

causal_effect_sot_21 <- causal_effects_cf(
  data_untreated = tb_untreated, 
  data_treated = tb_treated,
  data_cf_untreated = as_tibble(X0_st_21) |> magrittr::set_colnames(c("X1", "X2")),
  data_cf_treated = as_tibble(X1_st_21) |> magrittr::set_colnames(c("X1", "X2")),
  Y_name = Y_name, 
  A_name = A_name, 
  A_untreated = A_untreated
)

5.3.6 Summary

Code
tribble(
  ~Method, ~Name, ~Value,
  "Theoretical", "delta(0)", (a1+a2) * (mu1-mu0),
  "Theoretical", "delta(1)", (a1+a2) * (mu1-mu0),
  "Theoretical", "zeta(0)", a0,
  "Theoretical", "zeta(1)", a0,
  "Theoretical", "tau", (a1+a2) * (mu1-mu0) + a0,
  #
  "Mediation", "delta(0)", delta_0_med,
  "Mediation", "delta(1)", delta_1_med,
  "Mediation", "zeta(0)", zeta_0_med,
  "Mediation", "zeta(1)", zeta_1_med,
  "Mediation", "tau", tot_effect_med,
  #
  "OT", "delta(0)", causal_effects_ot$delta_0,
  "OT", "delta(1)", causal_effects_ot$delta_1,
  "OT", "zeta(0)", causal_effects_ot$zeta_0,  
  "OT", "zeta(1)", causal_effects_ot$zeta_1,
  "OT", "tau", causal_effects_ot$tot_effect,
  #
  "OT (Obs)", "delta(0)", causal_effects_ot$delta_0_obs,
  "OT (Obs)", "delta(1)", causal_effects_ot$delta_1_obs,
  "OT (Obs)", "tau", causal_effects_ot$tot_effect_obs,
  #
  "Matching", "delta(0)", causal_effects_tmatch$delta_0,
  "Matching", "delta(1)", causal_effects_tmatch$delta_1,
  "Matching", "zeta(0)", causal_effects_tmatch$zeta_0,  
  "Matching", "zeta(1)", causal_effects_tmatch$zeta_1,
  "Matching", "tau", causal_effects_tmatch$tot_effect,
  #
  "Matching (Obs)", "delta(0)", causal_effects_tmatch$delta_0_obs,
  "Matching (Obs)", "delta(1)", causal_effects_tmatch$delta_1_obs,
  "Matching (Obs)", "tau", causal_effects_tmatch$tot_effect_obs,
  #
  "Sinkhorn (gamma=.1)", "delta(0)", causal_effect_skh[["0.1"]]$delta_0,
  "Sinkhorn (gamma=.1)", "delta(1)", causal_effect_skh[["0.1"]]$delta_1,
  "Sinkhorn (gamma=.1)", "zeta(0)", causal_effect_skh[["0.1"]]$zeta_0,  
  "Sinkhorn (gamma=.1)", "zeta(1)", causal_effect_skh[["0.1"]]$zeta_1,
  "Sinkhorn (gamma=.1)", "tau", causal_effect_skh[["0.1"]]$tot_effect,
  #
  "Sinkhorn (gamma=.1) (Obs)", "delta(0)", causal_effect_skh[["0.1"]]$delta_0_obs,
  "Sinkhorn (gamma=.1) (Obs)", "delta(1)", causal_effect_skh[["0.1"]]$delta_1_obs,
  "Sinkhorn (gamma=.1) (Obs)", "tau", causal_effect_skh[["0.1"]]$tot_effect_obs,
  #
  "SOT (1)", "delta(0)", causal_effect_sot_12$delta_0,
  "SOT (1)", "delta(1)", causal_effect_sot_12$delta_1,
  "SOT (1)", "zeta(0)", causal_effect_sot_12$zeta_0,  
  "SOT (1)", "zeta(1)", causal_effect_sot_12$zeta_1,
  "SOT (1)", "tau", causal_effect_sot_12$tot_effect,
  #
  "SOT (1) (Obs)", "delta(0)", causal_effect_sot_12$delta_0_obs,
  "SOT (1) (Obs)", "delta(1)", causal_effect_sot_12$delta_1_obs,
  "SOT (1) (Obs)", "tau", causal_effect_sot_12$tot_effect_obs,
  #
  "SOT (2)", "delta(0)", causal_effect_sot_21$delta_0,
  "SOT (2)", "delta(1)", causal_effect_sot_21$delta_1,
  "SOT (2)", "zeta(0)", causal_effect_sot_21$zeta_0,
  "SOT (2)", "zeta(1)", causal_effect_sot_21$zeta_1,
  "SOT (2)", "tau", causal_effect_sot_21$tot_effect,
  #
  "SOT (2) (Obs)", "delta(0)", causal_effect_sot_21$delta_0_obs,
  "SOT (2) (Obs)", "delta(1)", causal_effect_sot_21$delta_1_obs,
  "SOT (2) (Obs)", "tau", causal_effect_sot_21$tot_effect_obs
) |> 
  pivot_wider(names_from = "Name", values_from = "Value")
# A tibble: 12 × 6
   Method                    `delta(0)` `delta(1)` `zeta(0)` `zeta(1)`   tau
   <chr>                          <dbl>      <dbl>     <dbl>     <dbl> <dbl>
 1 Theoretical                    1         1           3         3     4   
 2 Mediation                      0.909     0.691       3.25      3.03  3.94
 3 OT                             0.942     0.203       3.78      3.15  4.10
 4 OT (Obs)                       0.930     0.208      NA        NA     4.08
 5 Matching                       0.862     0.0860      3.85      3.08  3.94
 6 Matching (Obs)                 0.858     0.0880     NA        NA     3.93
 7 Sinkhorn (gamma=.1)            0.793     0.0992      3.61      3.12  3.92
 8 Sinkhorn (gamma=.1) (Obs)      0.791     0.117      NA        NA     3.91
 9 SOT (1)                        1.01      0.116       3.90      3.14  4.15
10 SOT (1) (Obs)                  1.00      0.118      NA        NA     4.14
11 SOT (2)                        0.904     0.114       3.95      3.04  3.94
12 SOT (2) (Obs)                  0.889     0.122      NA        NA     3.93

Let us have a look at the distribution of the individual causal effects (Figure 5.6).

Codes to create the Figure.
plot_hist_effects <- function(x,
                              var_name,
                              tikz = FALSE,
                              fill = "red",
                              method_main = "",
                              x_lim = NULL,
                              effect_ylab = "",
                              print_method_main = TRUE,
                              print_x_axis = TRUE,
                              print_y_axis = TRUE) {
  
  # effect label (LaTeX)
  name_effect <- dplyr::case_when(
    stringr::str_detect(var_name, "^delta_0") ~ "$\\delta_i(0)$",
    stringr::str_detect(var_name, "^zeta_1") ~ "$\\zeta_i(1)$",
    stringr::str_detect(var_name, "^tot_effect") ~ "$\\tau_i(1)$",
    TRUE ~ "other"
  )
  if (tikz == FALSE) name_effect <- latex2exp::TeX(name_effect)
  
  # data
  data_plot <- if (var_name == "tot_effect") x[["delta_0_i"]] + x[["zeta_1_i"]] else x[[var_name]]
  
  # hist
  if (is.null(x_lim)) {
    hist(data_plot, main = "", xlab = "", ylab = "", family = font_family,
         col = fill, axes = FALSE)
  } else {
    hist(data_plot, main = "", xlab = "", ylab = "", family = font_family,
         col = fill, xlim = x_lim, axes = FALSE)
  }
  
  if (print_x_axis) axis(1, family = font_family)
  if (print_y_axis) axis(2, family = font_family)
  
  # main title = method (only first row)
  if (print_method_main && method_main != "") {
    title(main = method_main, cex.main = 1, family = font_family, font.main = 1)
  }
  
  # y-axis label = effect name (only first column)
  if (effect_ylab != "") {
    title(ylab = effect_ylab, line = 2, cex.lab = 1, family = font_family, font.main = 1)
  }
  
  abline(v = mean(data_plot), col = "darkred", lty = 2, lwd = 2)
  
  invisible(name_effect)
}

export_tikz <- FALSE
scale <- 1.475
file_name <- "gaussian-indiv-effects"
width_tikz <- 2.7*scale
height_tikz <- 1.5*scale
if (export_tikz == TRUE)
  tikz(paste0("figs/", file_name, ".tex"), width = width_tikz, height = height_tikz)

layout(
  t(matrix(1:(3*5), byrow = TRUE, ncol = 3)),
  widths  = c(1, rep(.72, 4)),   # 5 columns
  heights = c(1, rep(.9, 2))     # 3 rows
)

x_lim_list <- list(
  "delta_0_i"  = c(-3,  6),
  "zeta_1_i"   = c(-3,  8),
  "tot_effect" = c(-4, 14)
)

for (i in 1:4) {
  
  x <- dplyr::case_when(
    i == 1 ~ causal_effects_ot,
    # i == 2 ~ causal_effects_tmatch,
    i == 2 ~ causal_effect_skh[["0.1"]],
    i == 3 ~ causal_effect_sot_12,
    i == 4 ~ causal_effect_sot_21
  )
  
  method <- dplyr::case_when(
    i == 1 ~ "OT",
    # i == 2 ~ "OT-M",
    i == 2 ~ "SKH",
    i == 3 ~ "ST(1)",
    i == 4 ~ "ST(2)"
  )
  
  colour <- dplyr::case_when(
    i == 1 ~ colour_methods[["OT"]],
    # i == 2 ~ colour_methods[["OT-M"]],
    i == 2 ~ colour_methods[["skh"]],
    i == 3 ~ colour_methods[["seq_1"]],
    i == 4 ~ colour_methods[["seq_2"]]
  )
  
  for (var_name in c("delta_0_i", "zeta_1_i", "tot_effect")) {
    
    # margins after transpose:
    mar_bottom <- ifelse(var_name == "tot_effect", 2.1, .6)  # bottom row
    mar_left   <- ifelse(i == 1, 3.1, 2.1)                   # first column
    mar_top    <- ifelse(var_name == "delta_0_i", 2.1, .1)   # first row
    mar_right  <- .4
    
    par(mar = c(mar_bottom, mar_left, mar_top, mar_right))
    
    # effect label only in first column
    if (i == 1) {
      effect_label <- dplyr::case_when(
        stringr::str_detect(var_name, "^delta_0") ~ "$\\delta_i$",
        stringr::str_detect(var_name, "^zeta_1") ~ "$\\zeta_i$",
        stringr::str_detect(var_name, "^tot_effect") ~ "$\\tau_i$",
        TRUE ~ "other"
      )
      if (export_tikz == FALSE) effect_label <- latex2exp::TeX(effect_label)
    } else {
      effect_label <- ""
    }
    
    
    plot_hist_effects(
      x = x,
      var_name = var_name,
      tikz = export_tikz,
      fill = colour,
      method_main = method,
      x_lim = x_lim_list[[var_name]],
      effect_ylab = effect_label,
      print_method_main = (var_name == "delta_0_i"),      # first row only
      print_x_axis = (var_name == "tot_effect"),          # bottom row only
      print_y_axis = TRUE
    )
  }
}

if (export_tikz == TRUE) {
  dev.off()
  plot_to_pdf(
    filename = file_name, 
    path = "./figs/", keep_tex = FALSE, crop = T
  )
}
Figure 5.6: Distribution of individual direct effect (\(\delta_i(0)\)), indirect effect (\(\zeta_i(0)\)), and total causal effect (\(\tau_i\)) estimated with transport-based counterfactuals with optimal transport (OT), penalized transport (SKH), and sequential transport (ST).

5.3.7 Some Individuals

Let us have a look at some individuals. We focus on the one with the predicted \(\tau_i\) estimated using Optimal Transport which is the closest to the theoretical value, and the one that is the farthest. Let us get the theoretical values:

(tau_theo <- (a1 + a2) * (mu1 - mu0) + a0)
[1] 4
(tau_med <- tot_effect_med[[1]])
[1] 3.943122
(tau_ot <- causal_effects_ot$tot_effect)
[1] 4.095779
(tau_tmatch <- causal_effects_tmatch$tot_effect)
[1] 3.937629
(tau_skh <- causal_effect_skh[["0.1"]]$tot_effect)
[1] 3.915611
(tau_sot_12 <- causal_effect_sot_12$tot_effect)
[1] 4.15099
(tau_sot_21 <- causal_effect_sot_21$tot_effect)
[1] 3.94203

We create a table that contains the coordinates of individuals from group 0), their transported coordinates (using OT, and sequential transport), and their estimated values for \(\delta_i(0)\), \(\zeta_i(1)\), and \(\tau_i\), obtained with the different counterfactuals.

tb_indiv_0 <- 
  tibble(
    X1 = X0$X1,
    X2 = X0$X2,
    X1_t = X0_t[, 1], # with OT
    X2_t = X0_t[, 2], # idem
    X1_tmatch = X0_tmatch[, 1], # with transport-based matching
    X2_tmatch = X0_tmatch[, 2], # idem
    X1_skh = X0_skh_l[["0.1"]][, 1], # with Regularization
    X2_skh = X0_skh_l[["0.1"]][, 2], # idem
    X1_sot_12 = X0_st_12[, 1], # with Seq T (1)
    X2_sot_12 = X0_st_12[, 2], # idem
    X1_sot_21 = X0_st_21[, 1], # with Seq T (2)
    X2_sot_21 = X0_st_21[, 2], # idem
    # OT
    delta_0_i_ot = causal_effects_ot$delta_0_i,
    zeta_1_i_ot = causal_effects_ot$zeta_1_i,
    # Matching
    delta_0_i_tmatch = causal_effects_tmatch$delta_0_i,
    zeta_1_i_tmatch = causal_effects_tmatch$zeta_1_i,
    # Skh
    delta_0_i_skh = causal_effect_skh[["0.1"]]$delta_0_i,
    zeta_1_i_skh = causal_effect_skh[["0.1"]]$zeta_1_i,
    # ST (1)
    delta_0_i_sot_12 = causal_effect_sot_12$delta_0_i,
    zeta_1_i_sot_12 = causal_effect_sot_12$zeta_1_i,
    # ST (2)
    delta_0_i_sot_21 = causal_effect_sot_21$delta_0_i,
    zeta_1_i_sot_21 = causal_effect_sot_21$zeta_1_i
  ) |> 
  # Total causal effect
  mutate(
    tau_i_ot = delta_0_i_ot + zeta_1_i_ot,
    tau_i_tmatch = delta_0_i_tmatch + zeta_1_i_tmatch,
    tau_i_skh = delta_0_i_skh + zeta_1_i_skh,
    tau_i_sot_12 = delta_0_i_sot_12 + zeta_1_i_sot_12,
    tau_i_sot_21 = delta_0_i_ot + zeta_1_i_sot_21
  ) |> 
  # Distance to the theoretical value
  mutate(
    tau_i_ot_dist = abs(tau_i_ot - tau_theo),
    tau_i_tmatch_dist = abs(tau_i_tmatch - tau_theo),
    tau_i_skh_dist = abs(tau_i_skh - tau_theo),
    tau_i_sot_12_dist = abs(tau_i_sot_12 - tau_theo),
    tau_i_sot_21_dist = abs(tau_i_sot_21 - tau_theo)
  )

In that table, we identify the two untreated units of interest.

ind_closest <- order(tb_indiv_0$tau_i_ot_dist)[1]
ind_farthest <- rev(order(tb_indiv_0$tau_i_ot_dist))[1]

Their coordinates are shown in Table 5.1, as well as the coordinates of their counterfactuals.

Codes to create the Table.
format_num <- function(x) scales::number(x, accuracy = 0.1)

tb_indiv_0_short <- 
  tb_indiv_0 |> 
  select(-ends_with("dist")) |> 
  mutate(
    across(where(is.numeric),
           ~format_num(.x)
    )
  ) |> 
  mutate(
    obs_type = case_when(
      row_number() == ind_closest ~ "Closest",
      row_number() == ind_farthest ~ "Farthest",
      TRUE ~ "Other"
    )
  ) |> 
  filter(obs_type != "Other") |> 
  mutate(
    coord = str_c("(", X1, ", ", X2, ")"),
    coord_OT = str_c("(", X1_t, ", ", X2_t, ")"),
    coord_TM = str_c("(", X1_tmatch, ", ", X2_tmatch, ")"),
    coord_SKH = str_c("(", X1_skh, ", ", X2_skh, ")"),
    coord_ST1 = str_c("(", X1_sot_12, ", ", X2_sot_12, ")"),
    coord_ST2 = str_c("(", X1_sot_21, ", ", X2_sot_21, ")")
  ) |> 
  select(
    -X1, -X2, -X1_t, -X2_t, -X1_tmatch, -X2_tmatch,
    -X1_skh, -X2_skh,
    -X1_sot_12, -X2_sot_12, -X1_sot_21, -X2_sot_21
  ) |> 
  pivot_longer(
    cols = c(
      -obs_type, -coord, -coord_OT, -coord_TM, -coord_SKH, -coord_ST1, -coord_ST2
    )
  ) |> 
  mutate(
    type = case_when(
      str_detect(name, "^delta") ~ "delta",
      str_detect(name, "^zeta") ~ "zeta",
      str_detect(name, "^tau") ~ "tau",
      TRUE ~ NA_character_
    ),
    type = factor(
      type, 
      levels = c("delta", "zeta", "tau")
    ),
    method = case_when(
      str_detect(name, "_med$") ~ "CM",
      str_detect(name, "_ot$") ~ "OT",
      str_detect(name, "_tmatch$") ~ "OT-M",
      str_detect(name, "_skh$") ~ "SKH",
      str_detect(name, "_sot_12$") ~ "ST(1)",
      str_detect(name, "sot_21$") ~ "ST(2)",
      TRUE ~ ""
    )
  ) |> 
  select(-name) |> 
  pivot_wider(
    names_from = type, values_from = value
  )

# tb_indiv_0_short |>
#   select(obs_type, method, delta, zeta, tau) |>
#   pivot_longer(cols = c(delta, zeta, tau)) |>
#   pivot_wider(names_from = method, values_from = value) |>
#   arrange(name, obs_type) |>
#   select(-obs_type) |>
#   kableExtra::kbl(booktabs = TRUE, format = "latex")

tb_indiv_0_short |> 
  group_by(obs_type) |> 
  slice_head(n=1) |> 
  select(coord, coord_OT, coord_TM, coord_SKH, coord_ST1, coord_ST2) |> 
  kableExtra::kbl(booktabs = TRUE)
Table 5.1: Coordinates of two untreated units before and after transport.
obs_type coord coord_OT coord_TM coord_SKH coord_ST1 coord_ST2
Closest (-1.5, -1.4) (0.6, 0.9) (0.2, 0.6) (0.3, 0.6) (0.5, 1.2) (0.9, 0.6)
Farthest (-2.4, -3.8) (1.4, -1.6) (3.5, -1.5) (-0.6, -0.9) (-0.4, -0.4) (3.0, -1.8)

The estimation of the direct, indirect and total effects are reported in Table 5.2, depending on the method used to create the counterfactual.

Codes to create the Table.
tb_indiv_0_short |> 
  select(method, delta, zeta, tau) |> 
  kableExtra::kbl(booktabs = TRUE)
Table 5.2: Estimated values of \(\delta_i(0)\), \(\zeta_i(1)\), and \(\tau_i\) for the two individuals, depending on the transport method.
method delta zeta tau
OT 4.2 4.5 8.8
OT-M 4.1 7.7 11.8
SKH 1.0 3.4 4.4
ST(1) 1.1 3.3 4.4
ST(2) 3.9 7.6 11.9
OT 1.3 2.7 4.0
OT-M 0.4 3.2 3.5
SKH 0.5 3.2 3.7
ST(1) -0.4 3.7 3.3
ST(2) 1.4 3.5 4.9
Codes to create the Figure
export_tikz <- FALSE
file_name <- "gaussian-tau-two-indiv"
scale <- 1.45
width_tikz <- 3.3*scale
height_tikz <- 1.55*scale
if (export_tikz == TRUE)
  tikz(paste0("figs/", file_name, ".tex"), width = width_tikz, height = height_tikz)

# par(mar = c(2.1, 2.1, .1, .1), mfrow = c(1, 2))
layout(matrix(c(1, 2), nrow = 1, byrow = TRUE), widths = c(10,8.1))
par(mar = c(2.1, 2.1, .1, .1))
## Closest----
cex_pts <- .3
lwd_arrow <- 1.5
length_arrow <- .05
plot(
  X0, 
  pch = 16, 
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "",
  family = font_family,
  axes = FALSE,
  cex = cex_pts
)
axis(1, at = -3:3, labels = TRUE)
axis(2, at = -3:3, labels = TRUE)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16, cex = cex_pts)

# Individuals of interest
# Arrows to OT
arrows(
  x0 = X0$X1[c(ind_closest)],
  y0 = X0$X2[c(ind_closest)],
  x1 = X0_t[c(ind_closest), "X1"], 
  y1 = X0_t[c(ind_closest), "X2"],
  length = length_arrow, col = adjustcolor(colour_methods[["OT"]], alpha = 1),
  lwd = lwd_arrow, lty = 2
)
# Arrows to OT-Matching
arrows(
  x0 = X0$X1[c(ind_closest)],
  y0 = X0$X2[c(ind_closest)],
  x1 = X0_tmatch[c(ind_closest), "X1"], 
  y1 = X0_tmatch[c(ind_closest), "X2"],
  length = length_arrow, col = adjustcolor(colour_methods[["OT-M"]], alpha = 1),
  lwd = lwd_arrow, lty = 2
)
# Arrows to SKH
arrows(
  x0 = X0$X1[c(ind_closest)],
  y0 = X0$X2[c(ind_closest)],
  x1 = X0_skh_l[["0.1"]][c(ind_closest), "X1"], 
  y1 = X0_skh_l[["0.1"]][c(ind_closest), "X2"],
  length = length_arrow, col = adjustcolor(colour_methods[["skh"]], alpha = 1),
  lwd = lwd_arrow, lty = 2
)
# Seq OT (1): X_1 first
# points(
#   X0_st_12[c(ind_closest), 1], 
#   X0$X2[c(ind_closest)], 
#   col = adjustcolor(colour_methods[["seq_1"]], alpha = .5), pch = 16, cex = 1
# )
arrows(
  x0 = X0$X1[c(ind_closest)], 
  y0 = X0$X2[c(ind_closest)],
  x1 = X0_st_12[c(ind_closest), 1], 
  y1 = X0$X2[c(ind_closest)],
  length = length_arrow, col = adjustcolor(colour_methods[["seq_1"]], alpha = .5),
  lwd = lwd_arrow
)
arrows(
  x0 = X0_st_12[c(ind_closest), 1], 
  y0 = X0$X2[c(ind_closest)],
  x1 = X0_st_12[c(ind_closest), 1], 
  y1 = X0_st_12[c(ind_closest), 2],
  length = length_arrow, col = adjustcolor(colour_methods[["seq_1"]], alpha = .5),
  lwd = lwd_arrow
)
arrows(
  x0 = X0$X1[c(ind_closest)], 
  y0 = X0$X2[c(ind_closest)],
  x1 = X0$X1[c(ind_closest)], 
  y1 = X0_st_21[c(ind_closest),2],
  length = length_arrow, col = adjustcolor(colour_methods[["seq_2"]], alpha = .5),
  lwd = lwd_arrow
)
arrows(
  x0 = X0$X1[c(ind_closest)], 
  y0 = X0_st_21[c(ind_closest),2],
  x1 = X0_st_21[c(ind_closest), 1], 
  y1 = X0_st_21[c(ind_closest), 2],
  length = length_arrow, col = adjustcolor(colour_methods[["seq_2"]], alpha = .5),
  lwd = lwd_arrow
)

# Individuals
points(
  tb_indiv_0$X1[c(ind_closest)], 
  tb_indiv_0$X2[c(ind_closest)], 
  col = "black", pch = c(15), cex = 1
)
# Transported values for those individuals (OT)
points(
  X0_t[c(ind_closest), "X1"], 
  X0_t[c(ind_closest), "X2"], 
  col = colour_methods[["OT"]], pch = c(15), cex = 1
)
# With OT_based matching
points(
  X0_tmatch[c(ind_closest), "X1"], 
  X0_tmatch[c(ind_closest), "X2"], 
  col = colour_methods[["OT-M"]], pch = c(15), cex = 1
)
# With Sinkhorn
points(
  X0_skh_l[["0.1"]][c(ind_closest), "X1"], 
  X0_skh_l[["0.1"]][c(ind_closest), "X2"], 
  col = colour_methods[["skh"]], pch = c(15), cex = 1
)
# With Sequential transport (1)
points(
  X0_st_12[c(ind_closest), 1], 
  X0_st_12[c(ind_closest), 2], 
  col = adjustcolor(colour_methods[["seq_1"]], alpha = 1), pch = 15, cex = 1
)
# With Sequential transport (2)
points(
  X0_st_21[c(ind_closest), 1], 
  X0_st_21[c(ind_closest), 2], 
  col = adjustcolor(colour_methods[["seq_2"]], alpha = 1), pch = 15, cex = 1
)


if (export_tikz == FALSE) {
  lab_points_ot <- latex2exp::TeX("$\\tau_i^{OT}$")
  # round(tb_indiv_0$tau_i_ot[c(ind_closest)], 1)
  lab_points_tmatch <- latex2exp::TeX("$\\tau_i^{OT-M}$")
  # round(tb_indiv_0$tau_i_tmatch[c(ind_closest)], 1)
  lab_points_skh <- latex2exp::TeX("$\\tau_i^{SKH}$")
  # round(tb_indiv_0$tau_i_skh[c(ind_closest)], 1)
  lab_points_sot_12 <- latex2exp::TeX("$\\tau_i^{ST(1)}$")
  round(tb_indiv_0$tau_i_sot_12[c(ind_closest)], 1)
  lab_points_sot_21 <- latex2exp::TeX("$\\tau_i^{ST(2)}$")
  round(tb_indiv_0$tau_i_sot_21[c(ind_closest)], 1)
} else {
  lab_points_ot <- "OT"
  lab_points_tmatch <- "OT-M"
  lab_points_skh <- "SKH"
  lab_points_sot_12 <- "ST(1)"
  lab_points_sot_21 <- "ST(2)"
}
180 
4.9 
Codes to create the Figure
# \tau_i with OT
text(
  x = X0_t[c(ind_closest), "X1"] + 1, 
  y = X0_t[c(ind_closest), "X2"] + .5, 
  labels = lab_points_ot,
  col = colour_methods[["OT"]]
)
# \tau_i with OT-M
text(
  x = X0_tmatch[c(ind_closest), "X1"] -1.5, 
  y = X0_tmatch[c(ind_closest), "X2"] + .5, 
  labels = lab_points_tmatch,
  col = colour_methods[["OT-M"]]
)
# \tau_i with SKH
text(
  x = X0_skh_l[["0.1"]][c(ind_closest), 1] - .25, 
  y = X0_skh_l[["0.1"]][c(ind_closest), 2] - 1, 
  labels = lab_points_skh,
  col = colour_methods[["skh"]]
)
text(
  x = X0_st_12[c(ind_closest), 1], 
  y = X0_st_12[c(ind_closest), 2] + 1, 
  labels = lab_points_sot_12,
  col = colour_methods[["seq_1"]]
)
text(
  x = X0_st_21[c(ind_closest), 1] + 1, 
  y = X0_st_21[c(ind_closest), 2] - 1, 
  labels = lab_points_sot_21,
  col = colour_methods[["seq_2"]]
)

## Farthest----
par(mar = c(2.1, .1, .1, .1))
plot(
  X0, 
  pch = 16, 
  col = adjustcolor(colGpe0, alpha = .3), 
  xlim = x_lim, ylim = y_lim, 
  xlab = "", ylab = "",
  main = "",
  family = font_family,
  axes = FALSE,
  cex = cex_pts
)
axis(1, at = -3:3, labels = TRUE)
# axis(2, at = -3:3, labels = TRUE)
title(xlab = "X1", ylab="X2", line=2, cex.lab=1.2, family = font_family)
points(X1, col = adjustcolor(colGpe1, alpha = .3), pch = 16, cex = cex_pts)

# Individuals of interest
# Arrows to OT
arrows(
  x0 = X0$X1[c(ind_farthest)],
  y0 = X0$X2[c(ind_farthest)],
  x1 = X0_t[c(ind_farthest), "X1"], 
  y1 = X0_t[c(ind_farthest), "X2"],
  length = length_arrow, col = adjustcolor(colour_methods[["OT"]], alpha = 1),
  lwd = lwd_arrow, lty = 2
)
# Arrows to OT-Matching
arrows(
  x0 = X0$X1[c(ind_farthest)],
  y0 = X0$X2[c(ind_farthest)],
  x1 = X0_tmatch[c(ind_farthest), "X1"], 
  y1 = X0_tmatch[c(ind_farthest), "X2"],
  length = length_arrow, col = adjustcolor(colour_methods[["OT-M"]], alpha = 1),
  lwd = lwd_arrow, lty = 2
)
# Arrows to SKH
arrows(
  x0 = X0$X1[c(ind_farthest)],
  y0 = X0$X2[c(ind_farthest)],
  x1 = X0_skh_l[["0.1"]][c(ind_farthest), "X1"], 
  y1 = X0_skh_l[["0.1"]][c(ind_farthest), "X2"],
  length = length_arrow, col = adjustcolor(colour_methods[["skh"]], alpha = 1),
  lwd = lwd_arrow, lty = 2
)
# Seq OT (1): X_1 first
# points(
#   X0_st_12[c(ind_farthest), 1], 
#   X0$X2[c(ind_farthest)], 
#   col = adjustcolor(colour_methods[["seq_1"]], alpha = .5), pch = 16, cex = 1
# )
arrows(
  x0 = X0$X1[c(ind_farthest)], 
  y0 = X0$X2[c(ind_farthest)],
  x1 = X0_st_12[c(ind_farthest), 1], 
  y1 = X0$X2[c(ind_farthest)],
  length = length_arrow, col = adjustcolor(colour_methods[["seq_1"]], alpha = .5),
  lwd = lwd_arrow
)
arrows(
  x0 = X0_st_12[c(ind_farthest), 1], 
  y0 = X0$X2[c(ind_farthest)],
  x1 = X0_st_12[c(ind_farthest), 1], 
  y1 = X0_st_12[c(ind_farthest), 2],
  length = length_arrow, col = adjustcolor(colour_methods[["seq_1"]], alpha = .5),
  lwd = lwd_arrow
)
# Seq OT (2): X_2 first
# points(
#   X0$X1[c(ind_farthest)], 
#   X0_st_21[c(ind_farthest),2], 
#   col = adjustcolor(colour_methods[["seq_2"]], alpha = .5), pch = 16, cex = 1
# )
arrows(
  x0 = X0$X1[c(ind_farthest)], 
  y0 = X0$X2[c(ind_farthest)],
  x1 = X0$X1[c(ind_farthest)], 
  y1 = X0_st_21[c(ind_farthest),2],
  length = length_arrow, col = adjustcolor(colour_methods[["seq_2"]], alpha = .5),
  lwd = lwd_arrow
)
arrows(
  x0 = X0$X1[c(ind_farthest)], 
  y0 = X0_st_21[c(ind_farthest),2],
  x1 = X0_st_21[c(ind_farthest), 1], 
  y1 = X0_st_21[c(ind_farthest), 2],
  length = length_arrow, col = adjustcolor(colour_methods[["seq_2"]], alpha = .5),
  lwd = lwd_arrow
)

# Individuals
points(
  tb_indiv_0$X1[c(ind_farthest)], 
  tb_indiv_0$X2[c(ind_farthest)], 
  col = "black", pch = c(15), cex = 1
)
# Transported values for those individuals (OT)
points(
  X0_t[c(ind_farthest), "X1"], 
  X0_t[c(ind_farthest), "X2"], 
  col = colour_methods[["OT"]], pch = c(15), cex = 1
)
# With OT-Matching
points(
  X0_tmatch[c(ind_farthest), "X1"], 
  X0_tmatch[c(ind_farthest), "X2"], 
  col = colour_methods[["OT-M"]], pch = c(15), cex = 1
)
# With Sinkhorn
points(
  X0_skh_l[["0.1"]][c(ind_farthest), "X1"], 
  X0_skh_l[["0.1"]][c(ind_farthest), "X2"], 
  col = colour_methods[["skh"]], pch = c(15), cex = 1
)
# With Sequential transport (1)
points(
  X0_st_12[c(ind_farthest), 1], 
  X0_st_12[c(ind_farthest), 2], 
  col = adjustcolor(colour_methods[["seq_1"]], alpha = 1), pch = 15, cex = 1
)
# With Sequential transport (2)
points(
  X0_st_21[c(ind_farthest), 1], 
  X0_st_21[c(ind_farthest), 2], 
  col = adjustcolor(colour_methods[["seq_2"]], alpha = 1), pch = 15, cex = 1
)


if (export_tikz == FALSE) {
  lab_points_ot <- latex2exp::TeX("$\\tau_i^{OT}")
  lab_points_tmatch <- latex2exp::TeX("$\\tau_i^{OT-M}$")
  lab_points_skh <- latex2exp::TeX("$\\tau_i^{SKH}$")
  lab_points_sot_12 <- latex2exp::TeX("$\\tau_i^{ST(1)}$")
  lab_points_sot_21 <- latex2exp::TeX("$\\tau_i^{ST(2)}$")
} else {
  lab_points_ot <- "OT"
  lab_points_tmatch <- "OT-M"
  lab_points_skh <- "SKH"
  lab_points_sot_12 <- "ST(1)"
  lab_points_sot_21 <- "ST(2)"
}
# \tau with OT
text(
  x = X0_t[c(ind_farthest), "X1"], 
  y = X0_t[c(ind_farthest), "X2"] + .75, 
  labels = lab_points_ot,
  col = colour_methods[["OT"]]
)
# \tau with OT-Matching
text(
  x = X0_tmatch[c(ind_farthest), "X1"] - .25, 
  y = X0_tmatch[c(ind_farthest), "X2"] + 1, 
  labels = lab_points_tmatch,
  col = colour_methods[["OT-M"]]
)
text(
  x = X0_skh_l[["0.1"]][c(ind_farthest), "X1"] - 1.5, 
  y = X0_skh_l[["0.1"]][c(ind_farthest), "X2"], 
  labels = lab_points_skh,
  col = colour_methods[["skh"]]
)
text(
  x = X0_st_12[c(ind_farthest), 1], 
  y = X0_st_12[c(ind_farthest), 2] + 1, 
  labels = lab_points_sot_12,
  col = colour_methods[["seq_1"]]
)
text(
  x = X0_st_21[c(ind_farthest), 1], 
  y = X0_st_21[c(ind_farthest), 2] - 1, 
  labels = lab_points_sot_21,
  col = colour_methods[["seq_2"]]
)

Codes to create the Figure
if (export_tikz == TRUE) {
  dev.off()
  plot_to_pdf(filename = file_name, path = "./figs/", keep_tex = FALSE, crop = FALSE)
}

5.4 Monte-Carlo Simulations

Let us perform Monte-Carlo simulations to observe the stability of the previous estimations. We define a function, sim_f(), to perform three steps:

  1. Generate a sample from the DGP shown in Section 13.1,
  2. Build the counterfactuals using OT, entropy regularized transport, and Sequential transport as in Section 5.2,
  3. Compute the causal effects as in Section 5.3.
The sim_f() function.
sim_f <- function(n = 500,
                  mu0, 
                  mu1, 
                  r0, 
                  r1, 
                  a, 
                  seed = NULL) {
  
  if (!is.null(seed)) set.seed(seed)
  
  # 1. Generate data
  df <- gen_data(
    n = 500, 
    mu0 = mu0, mu1 = mu1, 
    r0 = r0, r1 = r1, a = a, 
    seed = seed
  )
  
  # 2. Building Counterfactuals
  
  ## With Optimal Transport
  # Transporting map for source: group 1, target: group 0 (careful here)
  Sigma0 <- matrix(c(1, r0, r0, 1), 2, 2)
  Sigma1 <- matrix(c(1, r1, r1, 1), 2, 2)
  Mu0 <- rep(a * mu0, 2)
  Mu1 <- rep(a * mu1, 2)
  
  # Mapping from group 0 to group 1
  ot_map_0_to_1 <- compute_ot_map(
    mu_source = Mu0, sigma_source = Sigma0, 
    mu_target = Mu1, sigma_target = Sigma1
  )
  # Mapping from group 1 to group 0
  ot_map_1_to_0 <- compute_ot_map(
    mu_source = Mu1, sigma_source = Sigma0, 
    mu_target = Mu0, sigma_target = Sigma0
  )  
  
  # Apply transport map to treated units (A = 1)
  X0 <- as.matrix(df[df$A == 0, c("X1", "X2")])
  X1 <- as.matrix(df[df$A == 1, c("X1", "X2")])
  X0_t <- apply_ot_transport(X = X0, mapping = ot_map_0_to_1)
  colnames(X0_t) <- c(c("X1", "X2"))
  X1_t <- apply_ot_transport(X = X1, mapping = ot_map_1_to_0)
  colnames(X1_t) <- c(c("X1", "X2"))
  
  # With OT-Matching
  X0_tmatch <- transport_many_to_one(X_source = X0, X_target = X1)
  X1_tmatch <- transport_many_to_one(X_source = X1, X_target = X0)
  
  ## With Entropy regularized transport
  # Transport from group 0 to group 1:
  X0_skh <- transport_regul(
    X_source = X0, 
    X_target = X1, 
    gamma = 0.1
  )
  # Transport from group 1 to group 0:
  X1_skh <- transport_regul(
    X_source = X1, 
    X_target = X0, 
    gamma = 0.1
  )
  
  ## With Sequential Transport
  # Transport from group 0 to group 1: X1 then X2 | X1
  X0_st_12 <- sequential_transport_12(
    X = X0, M_source = Mu0, S_source = Sigma0, M_target = Mu1, S_target = Sigma1
  )
  # Transport from group 1 to group 0: X1 then X2 | X1
  X1_st_12 <- sequential_transport_12(
    X = X1, M_source = Mu1, S_source = Sigma1, M_target = Mu0, S_target = Sigma0
  )
  # Transport from group 0 to group 1: X2 then X1 | X2
  X0_st_21 <- sequential_transport_21(
    X = X0, M_source = Mu0, S_source = Sigma0, M_target = Mu1, S_target = Sigma1
  )
  # Transport from group 1 to group 0: X2 then X1 | X2
  X1_st_21 <- sequential_transport_21(
    X = X1, M_source = Mu1, S_source = Sigma1, M_target = Mu0, S_target = Sigma0
  )
  
  # 3. Measuring Total Causal Effect
  tb <- df[, c("Y", "A", "X1", "X2")]
  A_name <- "A"
  A_untreated <- 0
  Y_name <- "Y"
  
  # Causal Mediation Analysis
  med_mod_12 <- mediation::multimed(
    outcome = "Y", 
    med.main = "X1", 
    med.alt = "X2", 
    treat = "A", 
    data = df
  )
  med_mod_21 <- mediation::multimed(
    outcome = "Y", 
    med.main = "X2", 
    med.alt = "X1", 
    treat = "A", 
    data = df
  )
  
  delta_0_med <- mean((med_mod_12$d0.lb + med_mod_12$d0.ub) / 2) +
    mean((med_mod_21$d0.lb + med_mod_21$d0.ub) / 2)
  delta_1_med <- mean((med_mod_12$d1.lb + med_mod_12$d1.ub) / 2) + 
    mean((med_mod_21$d1.lb + med_mod_21$d1.ub) / 2)
  tot_effect_med <- med_mod_12$tau
  zeta_0_med <- tot_effect_med - delta_1_med
  zeta_1_med <- tot_effect_med-delta_0_med
  
  # With OT counterfactuals
  tb_untreated <- tb |> filter(!!sym(A_name) == !!A_untreated)
  tb_treated <- tb |> filter(!!sym(A_name) != !!A_untreated)
  
  causal_effects_ot <- causal_effects_cf(
    data_untreated = tb_untreated, 
    data_treated = tb_treated,
    data_cf_untreated = as_tibble(X0_t),
    data_cf_treated = as_tibble(X1_t),
    Y_name = Y_name, 
    A_name = A_name, 
    A_untreated = A_untreated
  )
  
  # With OT-Matching counterfactuals
  causal_effects_tmatch <- causal_effects_cf(
    data_untreated = tb_untreated, 
    data_treated = tb_treated,
    data_cf_untreated = as_tibble(X0_tmatch),
    data_cf_treated = as_tibble(X1_tmatch),
    Y_name = Y_name, 
    A_name = A_name, 
    A_untreated = A_untreated
  )
  
  # With entropy regularized transport
  causal_effects_skh <- causal_effects_cf(
    data_untreated = tb_untreated, 
    data_treated = tb_treated,
    data_cf_untreated = as_tibble(X0_skh),
    data_cf_treated = as_tibble(X1_skh),
    Y_name = Y_name, 
    A_name = A_name, 
    A_untreated = A_untreated
  )
  
  # With Sequential Transport counterfactuals
  causal_effect_sot_12 <- causal_effects_cf(
    data_untreated = tb_untreated, 
    data_treated = tb_treated,
    data_cf_untreated = as_tibble(X0_st_12) |> magrittr::set_colnames(c("X1", "X2")),
    data_cf_treated = as_tibble(X1_st_12) |> magrittr::set_colnames(c("X1", "X2")),
    Y_name = Y_name, 
    A_name = A_name, 
    A_untreated = A_untreated
  )
  causal_effect_sot_21 <- causal_effects_cf(
    data_untreated = tb_untreated, 
    data_treated = tb_treated,
    data_cf_untreated = as_tibble(X0_st_21) |> magrittr::set_colnames(c("X1", "X2")),
    data_cf_treated = as_tibble(X1_st_21) |> magrittr::set_colnames(c("X1", "X2")),
    Y_name = Y_name, 
    A_name = A_name, 
    A_untreated = A_untreated
  )
  
  tibble(
    # Mediation
    delta_0_med = delta_0_med,
    delta_1_med = delta_1_med,
    zeta_0_med = zeta_0_med,
    zeta_1_med = zeta_1_med,
    tot_effect_med = tot_effect_med,
    # OT
    delta_0_ot = causal_effects_ot$delta_0,
    delta_1_ot = causal_effects_ot$delta_1,
    delta_0_ot_obs = causal_effects_ot$delta_0_obs,
    delta_1_ot_obs = causal_effects_ot$delta_1_obs,
    zeta_0_ot = causal_effects_ot$zeta_0,
    zeta_1_ot = causal_effects_ot$zeta_1,
    tot_effect_ot = causal_effects_ot$tot_effect,
    tot_effect_ot_obs = causal_effects_ot$tot_effect_obs,
    # OT-M
    delta_0_tmatch = causal_effects_tmatch$delta_0,
    delta_1_tmatch = causal_effects_tmatch$delta_1,
    delta_0_tmatch_obs = causal_effects_tmatch$delta_0_obs,
    delta_1_tmatch_obs = causal_effects_tmatch$delta_1_obs,
    zeta_0_tmatch = causal_effects_tmatch$zeta_0,
    zeta_1_tmatch = causal_effects_tmatch$zeta_1,
    tot_effect_tmatch = causal_effects_tmatch$tot_effect,
    tot_effect_tmatch_obs = causal_effects_tmatch$tot_effect_obs,
    # SKH
    delta_0_skh = causal_effects_skh$delta_0,
    delta_1_skh = causal_effects_skh$delta_1,
    delta_0_skh_obs = causal_effects_skh$delta_0_obs,
    delta_1_skh_obs = causal_effects_skh$delta_1_obs,
    zeta_0_skh = causal_effects_skh$zeta_0,
    zeta_1_skh = causal_effects_skh$zeta_1,
    tot_effect_skh = causal_effects_skh$tot_effect,
    tot_effect_skh_obs = causal_effects_skh$tot_effect_obs,
    # SOT 12
    delta_0_sot_12 = causal_effect_sot_12$delta_0,
    delta_1_sot_12 = causal_effect_sot_12$delta_1,
    delta_0_sot_12_obs = causal_effect_sot_12$delta_0_obs,
    delta_1_sot_12_obs = causal_effect_sot_12$delta_1_obs,
    zeta_0_sot_12 = causal_effect_sot_12$zeta_0,
    zeta_1_sot_12 = causal_effect_sot_12$zeta_1,
    tot_effect_sot_12 = causal_effect_sot_12$tot_effect,
    tot_effect_sot_12_obs = causal_effect_sot_12$tot_effect_obs,
    # SOT 21
    delta_0_sot_21 = causal_effect_sot_21$delta_0,
    delta_1_sot_21 = causal_effect_sot_21$delta_1,
    delta_0_sot_21_obs = causal_effect_sot_21$delta_0_obs,
    delta_1_sot_21_obs = causal_effect_sot_21$delta_1_obs,
    zeta_0_sot_21 = causal_effect_sot_21$zeta_0,
    zeta_1_sot_21 = causal_effect_sot_21$zeta_1,
    tot_effect_sot_21 = causal_effect_sot_21$tot_effect,
    tot_effect_sot_21_obs = causal_effect_sot_21$tot_effect_obs,
    n = n,
    seed = seed,
    mu0 = mu0,
    mu1 = mu1,
    r0 = r0,
    r1 = r1,
    a = a
  )
}

The simulations can be run in parallel, as follows.

# This chunk takes 3 minutes and 40 seconds to run
# (MB Pro 2023, Apple M2 Pro ship, 32 GB RAM).
# We do not evaluate when compiling the document.
# Instead, we load previously obtained results.
library(pbapply)
library(parallel)
ncl <- detectCores()-1
(cl <- makeCluster(ncl))

clusterEvalQ(cl, {
  library(tidyverse)
  library(mnormt)
  library(expm)
  library(randomForest)
}) |>
  invisible()

clusterExport(
  cl = cl, c(
    "gen_data", "compute_ot_map", "apply_ot_transport",
    "transport_regul", "transport_many_to_one",
    "sequential_transport_12", "sequential_transport_21",
    "causal_effects_cf", "sim_f"
  )
)

res_sim <- pbapply::pblapply(1:200, function(seed) {
  sim_f(n = 500, mu0 = -1, mu1 = +1, r0 = +.7, r1 = -.5, a = 1, seed = seed)
}, cl = cl)

stopCluster(cl)
res_sim <- list_rbind(res_sim)

save(res_sim, file = "../output/res_sim-gaussian-mc-a1.rda")

We load previously obtained results:

load("../output/res_sim-gaussian-mc-a1.rda")
Codes to create the Figure.
#' @parma metric_name Name of the metric (e.g., "tot_effect_ot")
#' @param method_label Label of the method (for title, e.g., "OT")
plot_hist_sim <- function(metric_name, 
                          method_label,
                          x_lim = NULL,
                          y_axis = TRUE,
                          export_tikz = FALSE) {
  # type <- match.arg(type)
  type <- case_when(
    str_detect(metric_name, "^delta") ~ "delta",
    str_detect(metric_name, "^zeta") ~ "zeta",
    str_detect(metric_name, "^tot_effect") & !str_detect(metric_name, "obs$") ~ "tau",
    str_detect(metric_name, "^tot_effect") & str_detect(metric_name, "obs$") ~ "theta",
    TRUE ~ NA_character_
  )
  group <- case_when(
    str_detect(metric_name, "_0_") ~ "0",
    str_detect(metric_name, "_1_") ~ "1",
    TRUE ~ ""
  )
  
  if (type == "delta") {
    title_lab <- paste0("$\\bar{\\delta}_", group, "$, ", method_label)
  } else if (type == "zeta") {
    title_lab <- paste0("$\\bar{\\zeta}_", group, "$, ", method_label)
  } else if (type == "tau") {
    title_lab <- paste0("$\\bar{\\tau}$, ", method_label)
  } else {
    title_lab <- paste0("$\\bar{\\theta}$, ", method_label)
  }
  
  if (export_tikz == FALSE) {
    title_lab <- latex2exp::TeX(title_lab)
  }
  
  if (group == "0") {
    fill_col <- colGpe0
  } else if (group == "1") {
    fill_col <- colGpe1
  } else {
    fill_col <- "gray"
  }
  
  if (is.null(x_lim)) x_lim <- range(res_sim |> pull(!!metric_name))
  
  hist(
    res_sim |> pull(!!metric_name), 
    main = title_lab,
    xlab = "",
    col = adjustcolor(fill_col, alpha = .5),
    xlim = x_lim,
    axes = FALSE
  )
  axis(1)
  if (y_axis == TRUE) axis(2)
  if (type == "delta") {
    abline(v = (a1+a2)*(mu1-mu0), col = "darkred", lty = 2, lwd = 2)
  } else if (type == "zeta") {
    abline(v = a0, col = "darkred", lty = 2, lwd = 2)
  } else if (type %in% c("tau", "theta")) {
    abline(v = (a1+a2)*(mu1-mu0) + a0, col = "darkred", lty = 2, lwd = 2)
  }
}


export_tikz <- FALSE
file_name <- "gaussian-hist-mc"

if (export_tikz == TRUE)
  tikz(paste0("figs/", file_name, ".tex"), width = 2.8, height = 2.8)

par(mar = c(2.1, 2.1, 2.1, .5), mfrow = c(6, 3))
# CM
x_lim <- c(0,5); plot_hist_sim(metric_name = "delta_0_med", method_label = "CM", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(-1,5); plot_hist_sim(metric_name = "zeta_1_med", method_label = "CM", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(3,5); plot_hist_sim(metric_name = "tot_effect_med", method_label = "CM", x_lim = x_lim, export_tikz = export_tikz)
# OT
x_lim <- c(0,5); plot_hist_sim(metric_name = "delta_0_ot", method_label = "OT", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(-1,5); plot_hist_sim(metric_name = "zeta_1_ot", method_label = "OT", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(3,5); plot_hist_sim(metric_name = "tot_effect_ot", method_label = "OT", x_lim = x_lim, export_tikz = export_tikz)
# OT-M
x_lim <- c(0,5); plot_hist_sim(metric_name = "delta_0_tmatch", method_label = "OT-M", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(-1,5); plot_hist_sim(metric_name = "zeta_1_tmatch", method_label = "OT-M", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(3,5); plot_hist_sim(metric_name = "tot_effect_tmatch", method_label = "OT-M", x_lim = x_lim, export_tikz = export_tikz)
# SKH
x_lim <- c(0,5); plot_hist_sim(metric_name = "delta_0_skh", method_label = "SKH", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(-1,5); plot_hist_sim(metric_name = "zeta_1_skh", method_label = "SKH", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(3,5); plot_hist_sim(metric_name = "tot_effect_skh", method_label = "SKH", x_lim = x_lim, export_tikz = export_tikz)
# ST(1)
x_lim <- c(0,5); plot_hist_sim(metric_name = "delta_0_sot_12", method_label = "ST (1)", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(-1,5); plot_hist_sim(metric_name = "zeta_1_sot_12", method_label = "ST (1)", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(3,5); plot_hist_sim(metric_name = "tot_effect_sot_12", method_label = "ST (1)", x_lim= x_lim, export_tikz = export_tikz)
# ST(2)
x_lim <- c(0,5); plot_hist_sim(metric_name = "delta_0_sot_21", method_label = "ST (2)", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(-1,5); plot_hist_sim(metric_name = "zeta_1_sot_21", method_label = "ST (2)", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(3,5); plot_hist_sim(metric_name = "tot_effect_sot_21", method_label = "ST (2)", x_lim= x_lim, export_tikz = export_tikz)

if (export_tikz == TRUE) {
  dev.off()
  plot_to_pdf(filename = file_name, path = "./figs/", keep_tex = FALSE, crop = TRUE)
}
Figure 5.7: Estimated values of \(\bar{\delta}(0)\), \(\bar{\zeta}(1)\), and \(\bar{\tau}\) across 200 Monte Carlo simulations using Causal Mediation (CM), Optimal Transport (OT), Entropy Regularized Transports using Sinkhorn (SKH), and Sequential Transport ST (1) (moving \(X_1\) first) and ST (2) (moving \(X_2\) first). The red vertical bar denotes the theoretical value.
Codes to create the Figure.
par(mar = c(2.1, 2.1, 2.1, .5), mfrow = c(6, 3))
# CM
x_lim <- c(0,5); plot_hist_sim(metric_name = "delta_1_med", method_label = "CM", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(-1,5); plot_hist_sim(metric_name = "zeta_0_med", method_label = "CM", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(3,5); plot_hist_sim(metric_name = "tot_effect_med", method_label = "CM", x_lim = x_lim, export_tikz = export_tikz)
# OT
x_lim <- c(0,5); plot_hist_sim(metric_name = "delta_1_ot", method_label = "OT", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(-1,5); plot_hist_sim(metric_name = "zeta_0_ot", method_label = "OT", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(3,5); plot_hist_sim(metric_name = "tot_effect_ot", method_label = "OT", x_lim = x_lim, export_tikz = export_tikz)
# OT-M
x_lim <- c(0,5); plot_hist_sim(metric_name = "delta_1_tmatch", method_label = "OT-M", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(-1,5); plot_hist_sim(metric_name = "zeta_0_tmatch", method_label = "OT-M", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(3,5); plot_hist_sim(metric_name = "tot_effect_tmatch", method_label = "OT-M", x_lim = x_lim, export_tikz = export_tikz)
# SKH
x_lim <- c(0,5); plot_hist_sim(metric_name = "delta_1_skh", method_label = "SKH", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(-1,5); plot_hist_sim(metric_name = "zeta_0_skh", method_label = "SKH", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(3,5); plot_hist_sim(metric_name = "tot_effect_skh", method_label = "SKH", x_lim = x_lim, export_tikz = export_tikz)
# ST(1)
x_lim <- c(0,5); plot_hist_sim(metric_name = "delta_1_sot_12", method_label = "ST (1)", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(-1,5); plot_hist_sim(metric_name = "zeta_0_sot_12", method_label = "ST (1)", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(3,5); plot_hist_sim(metric_name = "tot_effect_sot_12", method_label = "ST (1)", x_lim= x_lim, export_tikz = export_tikz)
# ST(2)
x_lim <- c(0,5); plot_hist_sim(metric_name = "delta_1_sot_21", method_label = "ST (2)", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(-1,5); plot_hist_sim(metric_name = "zeta_0_sot_21", method_label = "ST (2)", x_lim = x_lim, export_tikz = export_tikz)
x_lim <- c(3,5); plot_hist_sim(metric_name = "tot_effect_sot_21", method_label = "ST (2)", x_lim= x_lim, export_tikz = export_tikz)
Figure 5.8: Estimated values of \(\bar{\delta}(1)\), \(\bar{\zeta}(0)\), and \(\bar{\tau}\) across 200 Monte Carlo simulations using Causal Mediation (CM), Optimal Transport (OT), Entropy Regularized Transports using Sinkhorn (SKH), and Sequential Transport ST (1) (moving \(X_1\) first) and ST (2) (moving \(X_2\) first). The red vertical bar denotes the theoretical value.

The comparison of the estimates made using observed values for \(y\) (whenever observed) is shown in Figure 5.9.

Codes to create the Figure.
par(mar = c(2.1, 2.1, 2.1, .5), mfrow = c(6,2))
x_lim <- c(3,5)
plot_hist_sim(metric_name = "tot_effect_ot", method_label = "OT", x_lim = x_lim, export_tikz = export_tikz)
plot_hist_sim(metric_name = "tot_effect_tmatch", method_label = "OT-M", x_lim = x_lim, export_tikz = export_tikz)

plot_hist_sim(metric_name = "tot_effect_ot_obs", method_label = "OT", x_lim = x_lim, export_tikz = export_tikz)
plot_hist_sim(metric_name = "tot_effect_tmatch_obs", method_label = "OT-M", x_lim = x_lim, export_tikz = export_tikz)


plot_hist_sim(metric_name = "tot_effect_skh", method_label = "SKH", x_lim = x_lim, export_tikz = export_tikz)
plot_hist_sim(metric_name = "tot_effect_sot_12", method_label = "ST (1)", x_lim = x_lim, export_tikz = export_tikz)


plot_hist_sim(metric_name = "tot_effect_skh_obs", method_label = "SKH", x_lim = x_lim, export_tikz = export_tikz)
plot_hist_sim(metric_name = "tot_effect_sot_12_obs", method_label = "ST (1)", x_lim = x_lim, export_tikz = export_tikz)



plot_hist_sim(metric_name = "tot_effect_sot_21", method_label = "ST (2)", x_lim = x_lim, export_tikz = export_tikz)
plot.new()

plot_hist_sim(metric_name = "tot_effect_sot_21_obs", method_label = "ST (2)", x_lim = x_lim, export_tikz = export_tikz)
plot.new()
Figure 5.9: Estimated values of \(\bar{\tau}\) and \(\bar{\theta}\) across 200 Monte Carlo simulations using ptimal Transport (OT), and Sequential Transport ST (1) (moving \(X_1\) first) and ST (2) (moving \(X_2\) first). The red vertical bar denotes the theoretical value.

We also use violin plots, since exporting these histograms in a two-column format paper is not a good idea.

Codes to create the Figure.
export_pdf <- FALSE

p <- ggplot(
  data = res_sim |> 
    select(
      delta_0_med, delta_0_ot, delta_0_tmatch, 
      delta_0_skh, delta_0_sot_12, delta_0_sot_21,
      #
      zeta_1_med, zeta_1_ot, zeta_1_tmatch, 
      zeta_1_skh, zeta_1_sot_12, zeta_1_sot_21,
      #
      tot_effect_med, tot_effect_ot, tot_effect_tmatch, 
      tot_effect_skh, tot_effect_sot_12, tot_effect_sot_21
    ) |> 
    mutate(row = row_number()) |> 
    pivot_longer(cols = -row) |> 
    mutate(
      type = case_when(
        str_detect(name, "^delta") ~ "delta",
        str_detect(name, "^zeta") ~ "zeta",
        str_detect(name, "^tot_effect") & !str_detect(name, "obs$") ~ "tau",
        TRUE ~ NA_character_
      ),
      type = factor(
        type, 
        levels = c("delta", "zeta", "tau"),
        labels = c("$\\bar{\\delta}(0)$", "$\\bar{\\zeta}(1)$", 
                   "$\\bar{\\tau}$")
      ),
      method = case_when(
        str_detect(name, "_med$") ~ "CM",
        str_detect(name, "_ot$") ~ "OT",
        str_detect(name, "_tmatch$") ~ "OT-M",
        str_detect(name, "_skh$") ~ "SKH",
        str_detect(name, "_sot_12$") ~ "ST(1)",
        str_detect(name, "sot_21$") ~ "ST(2)",
        TRUE ~ ""
      ),
      method = factor(
        method, levels = rev(c("CM", "OT", "OT-M", "SKH", "ST(1)", "ST(2)"))
      )
    )
) +
  geom_violin(
    mapping = aes(x = value, y = method, fill = method),
    draw_quantiles = c(.25, .5, .75)) +
  labs(x = NULL, y = NULL)
Warning: The `draw_quantiles` argument of `geom_violin()` is deprecated as of ggplot2
4.0.0.
ℹ Please use the `quantiles.linetype` argument instead.
Codes to create the Figure.
if (export_pdf == TRUE) {
  p <- p + 
    facet_wrap(
      ~ type, scales = "free_x"
    )
} else {
  p <- p +
    facet_wrap(
      ~ type, scales = "free_x",
      labeller = as_labeller(latex2exp::TeX, default = label_parsed)
    )
}

p <- p +
  geom_vline(
    data = tibble(
      type = c("$\\bar{\\delta}(0)$", "$\\bar{\\zeta}(1)$", 
               "$\\bar{\\tau}$"), 
      val_theo = c(
        (a1+a2)*(mu1-mu0),
        a0,
        (a1+a2)*(mu1-mu0) + a0
      )
    ) |> 
      mutate(
        type = factor(
          type, 
          levels = c("$\\bar{\\delta}(0)$", "$\\bar{\\zeta}(1)$", 
                     "$\\bar{\\tau}$")
        )
      ),
    mapping = aes(xintercept = val_theo),
    colour = "darkred", linetype = "dashed", linewidth = 1
  ) +
  scale_fill_manual(
    NULL, 
    values = c(
      "CM" = "#56B4E9",
      "OT" = colour_methods[["OT"]], 
      "OT-M" = colour_methods[["OT-M"]], 
      "SKH" = colour_methods[["skh"]], 
      "ST(1)" =  colour_methods[["seq_1"]],
      "ST(2)" =  colour_methods[["seq_2"]]
    ),
    guide = "none"
  ) +
  theme_paper()

p

if (export_pdf == TRUE) {
  ggplot2_to_pdf(
    plot = p + theme(panel.spacing = unit(0.4, "lines")) +
      scale_x_continuous(
        labels = function(x) paste0("$", x, "$")
      ),
    filename = "gaussian-violin-mc", path = "figs/", 
    width = 3.4, height = 2,
    crop = TRUE
  )
  
  system(paste0("pdfcrop figs/gaussian-violin-mc.pdf figs/gaussian-violin-mc.pdf"))
}
Figure 5.10: Estimated values of \(\bar{\delta}(0)\), \(\bar{\zeta}(1)\), and \(\bar{\tau}\) across 200 Monte Carlo simulations using Causal Mediation (CM), Optimal Transport (OT), Entropy Regularized Transports using Sinkhorn (SKH), and Sequential Transport ST (1) (moving \(X_1\) first) and ST (2) (moving \(X_2\) first). The red vertical bar denotes the theoretical value.