5  Wrong Causal Assumptions

Objectives

In this chapter, we generate a dataset consisting of an outcome variable \(Y \in \{0,1\}\), two covariates \(\boldsymbol{X} = (X_1, X_2)\), and a sensitive attribute \(S \in \{0,1\}\). The two covariates are independently distributed. We then explore different causal assumptions: the correct assumption, a first incorrect assumption where \(X_1\) causes \(X_2\), and a second incorrect assumption where \(X_2\) causes \(X_1\). We then build various types of counterfactuals: (i) naive counterfactuals, where the sensitive attribute for individuals in the protected group (\(S=0\)) is changed to \(S=1\) while all other characteristics remain unchanged (ceteris paribus), (ii) multivariate optimal transport, and (iii) sequential transport under the three causal assumptions. The consequences of these assumptions are then examined on an individual basis, and then over the entire protected group.

Required packages and definition of colours.
# Required packages----
library(tidyverse)
library(ks)
library(expm)

# Graphs----
# Font family shared by the main text and the plot titles
font_main <- font_title <- "Times New Roman"
# Register the fonts known to {extrafont} with the graphics devices
# (`TRUE` is spelled out: `T` is an ordinary variable that can be reassigned)
extrafont::loadfonts(quiet = TRUE)
# Font faces and sizes for the graphs
face_text <- "plain"
face_title <- "plain"
size_title <- 14
size_text <- 11
legend_size <- 11

#' Custom ggplot2 theme for the graphs
#'
#' Starts from `theme_minimal()` and replaces its text-related elements with
#' the font family, faces and sizes defined in the graph settings.
#'
#' @return A ggplot2 theme object.
global_theme <- function() {
  # Text elements, built one by one before being plugged into the theme
  main_text <- element_text(
    family = font_main, size = size_text, face = face_text
  )
  legend_labels <- element_text(family = font_main, size = legend_size)
  axis_labels <- element_text(size = size_text, face = face_text)
  centred_title <- element_text(
    family = font_title, size = size_title, hjust = 0.5
  )
  centred_subtitle <- element_text(hjust = 0.5)

  overrides <- theme(
    text = main_text,
    legend.text = legend_labels,
    axis.text = axis_labels,
    plot.title = centred_title,
    plot.subtitle = centred_subtitle
  )
  theme_minimal() %+replace% overrides
}

# Seed of the random number generator
set.seed(2025)

# Named palette: the colours are looked up by name in the graphs
colours <- setNames(
  c(
    "#5BBCD6", "#FF0000", "#00A08A", "#F2AD00",
    "#046C9A", "#C93312", "#0B775E"
  ),
  c("0", "1", "A", "B", "with", "without", "2")
)

We load our small package to have access to the sequential transport function, as well as the function used to generate tikz pictures for causal graphs.

library(devtools)
Loading required package: usethis
load_all("../seqtransfairness/")
ℹ Loading seqtransfairness

Let \(S\in\{0,1\}\) be the sensitive attribute.

We generate two variables:

These variables are independent from one another.

set.seed(123) # set the seed for reproducible results
n <- 500 # number of observations drawn in each group

# group s = 0: X1 and X2 are drawn independently, uniformly on [0, 1]
X1_0 <- runif(n, 0, 1)
X2_0 <- runif(n, 0, 1)
# group s = 1: X1 and X2 are drawn independently, uniformly on [1, 2]
X1_1 <- runif(n, 1, 2)
X2_1 <- runif(n, 1, 2)

We simulate a binary response variable with a logistic model depending on the covariates \(\boldsymbol{X}\), with coefficients that differ between the two groups defined by \(S\).

# Drawing random binary response variable Y with logistic model for each group
# Linear predictors: the weights put on X1 and X2 differ between the two groups
eta_1 <- (X1_1 * .8 + X2_1 / 2 * 1.2) / 2
eta_0 <- (X1_0 * 1.2 + X2_0 / 2 * .8) / 2
# Inverse logit of the linear predictors: probabilities that Y = 1
p_1 <- exp(eta_1) / (1 + exp(eta_1))
p_0 <- exp(eta_0) / (1 + exp(eta_0))

# One Bernoulli draw per individual, with the individual's own probability
Y_0 <- rbinom(n, size = 1, prob = p_0)
Y_1 <- rbinom(n, size = 1, prob = p_1)

# Dataset of group S = 0
D_SXY_0 <- tibble(
  S = 0,
  X1 = X1_0,
  X2 = X2_0,
  Y = Y_0
)
# Dataset of group S = 1
D_SXY_1 <- tibble(
  S = 1,
  X1 = X1_1,
  X2 = X2_1,
  Y = Y_1
)

# Both groups stacked: the rows of group 0 come first, then those of group 1
D_SXY <- bind_rows(D_SXY_0, D_SXY_1)

Let us have a look at the distribution of the drawn probabilities (which we can observe here thanks to the simulation setup).

# Margins of the plotting region
par(mar = c(5.1, 4.1, 2.1, 2.1))
# Group 0: this histogram opens the plot and fixes the axes
hist(
  p_0, 
  col = colours[["A"]],
  breaks = 20,
  xlim = c(0, 1),
  ylim = c(0,100),
  xlab = "Drawn Probabilities",
  main = NA
)
# Group 1: drawn on top of the histogram of group 0 (add = TRUE)
hist(
  p_1, 
  col = colours[["B"]],
  breaks = 10, 
  add = TRUE
)
# Legend mapping each fill colour to its group
legend(
  "topleft", 
  legend = c("p_0", "p_1"), 
  fill = colours[c("A", "B")]
)
Figure 5.1: Histogram with probability values for both groups

We can visualize the estimated density of the two covariates in each group (\(S=0\) and \(S=1\)).

Codes used to create the Figure.
# Computation of smoothing parameters (bandwidth) for kernel density estimation
H0 <- Hpi(D_SXY_0[, c("X1", "X2")])
H1 <- Hpi(D_SXY_1[, c("X1", "X2")])

# Calculating multivariate densities in each group
f0_2d <- kde(D_SXY_0[, c("X1", "X2")], H = H0, xmin = c(-1, -1), xmax = c(5, 3))
f1_2d <- kde(D_SXY_1[, c("X1", "X2")], H = H1, xmin = c(-1, -1), xmax = c(5, 3))

par(mar = c(2, 2, 0, 0))
# Empty canvas (the two points are drawn without colour); axes added manually
plot(
  c(-1, 3), c(-1, 3),
  xlab = "", ylab = "", axes = FALSE, col = NA,
  xlim = c(-1, 3), ylim = c(-1, 3)
)
axis(1)
axis(2)
# Group S=0: observed values
points(
  D_SXY_0$X1, D_SXY_0$X2,
  col = scales::alpha(colours["A"], .4), pch = 19, cex = .3
)
# Group S=0: estimated density
# (the component `eval.points` is written in full: `$eval.point` only worked
# through partial matching of list names)
contour(
  f0_2d$eval.points[[1]],
  f0_2d$eval.points[[2]],
  f0_2d$estimate,
  col = scales::alpha(colours["A"], 1),
  add = TRUE
)
# Group S=1: observed values
points(
  D_SXY_1$X1, D_SXY_1$X2,
  col = scales::alpha(colours["B"], .4), pch = 19, cex = .3
)
# Group S=1: estimated density
contour(
  f1_2d$eval.points[[1]],
  f1_2d$eval.points[[2]],
  f1_2d$estimate,
  col = scales::alpha(colours["B"], 1),
  add = TRUE
)
Figure 5.2: Bivariate densities within each group (\(S=0\) on the left and \(S=1\) on the right), estimated with a Gaussian kernel.

5.1 Assumed Causal Models

In this simulation setup where the covariates \(X_1\) and \(X_2\) are independent, we will assume three causal models:

  1. Correct assumption: we will assume that there is no dependence between \(X_1\) and \(X_2\) (as shown in Figure 5.3)
  2. Wrong Dependence 1: we will assume that \(X_2\) depends on \(X_1\) (as shown in Figure 5.4)
  3. Wrong Dependence 2: we will assume that \(X_1\) depends on \(X_2\) (as shown in Figure 5.5).
# Names of the nodes, used as row and column names of the adjacency matrices
variables <- c("S", "X1", "X2", "Y")

5.1.1 Correct Assumption

# Adjacency matrix of the correct DAG (a 1 in row i, column j is an edge i -> j):
# S -> X1, S -> X2, S -> Y, X1 -> Y, X2 -> Y; no edge between X1 and X2
adj_indep_correct <- matrix(
  data = c(
    # S  X1 X2 Y
    0, 1, 1, 1, # S
    0, 0, 0, 1, # X1
    0, 0, 0, 1, # X2
    0, 0, 0, 0  # Y
  ),
  nrow = length(variables),
  byrow = TRUE,
  dimnames = list(variables, variables)
)
Codes to create Tikz from an adjacency matrix.
#' Add a tikz graph in a quarto HTML document
#'
#' Wraps the Tikz code in a `tikz` code chunk, knits that chunk as a child
#' document and returns the result as "as is" output.
#'
#' @param tikz_code Tikz code.
#' @param label Value of the chunk option `label`.
#' @param caption Value of the chunk option `fig-cap`.
#' @param echo Value of the chunk option `echo`.
#' @param code_fold Value of the chunk option `code-fold`.
#' @param fig_ext Value of the chunk option `fig-ext`.
#' @param code_summary Value of the chunk option `code-summary`.
#' @return The knitted chunk, wrapped with `knitr::asis_output()`.
add_tikz_graph <- function(tikz_code,
                           label,
                           caption = "Causal Graph",
                           echo = "true",
                           code_fold = "true",
                           fig_ext = "png",
                           code_summary = "Tikz code") {
  
  # In the glue template, single braces are interpolated as R expressions.
  # Literal braces must be doubled: this applies to the chunk engine
  # `{{tikz}}` as well as to `{{arrows}}`.
  res <- knitr::knit_child(
    text = glue::glue(r"(
             ```{{tikz}}
             #| echo: {echo}
             #| label: {label}
             #| fig-cap: {caption}
             #| fig-ext: {fig_ext}
             #| code-fold: {code_fold}
             #| code-summary: {code_summary}
             \usetikzlibrary{{arrows}}
             {tikz_code}
             ```)"
    ),
    quiet = TRUE
  )
  knitr::asis_output(res)
}

# Fill colour of each node of the causal graphs (xcolor mixing syntax)
colour_nodes <- setNames(
  c("red!30", "yellow!60", "yellow!60", "blue!30"),
  c("S", "X1", "X2", "Y")
)
# Then, in the document:
# `r add_tikz_graph(tikz_code = causal_graph_tikz(adj_indep_correct,colour_nodes), label = "fig-causal-graph-indep-correct", caption = "\"Assumed Causal Graph: Correct Model\"", echo = "false")`
Figure 5.3: Assumed Causal Graph: Correct Model

5.1.2 Wrong Assumption: \(X_1\) causes \(X_2\)

# Adjacency matrix of the first wrong DAG (a 1 in row i, column j is an edge
# i -> j): same edges as the correct DAG, plus X1 -> X2
adj_indep_inc_x1_then_x2 <- matrix(
  data = c(
    # S  X1 X2 Y
    0, 1, 1, 1, # S
    0, 0, 1, 1, # X1
    0, 0, 0, 1, # X2
    0, 0, 0, 0  # Y
  ),
  nrow = length(variables),
  byrow = TRUE,
  dimnames = list(variables, variables)
)
Figure 5.4: Assumed Causal Graph: Incorrectly Assumes that \(X_1\) causes \(X_2\)

5.1.3 Wrong Assumption: \(X_2\) causes \(X_1\)

# Adjacency matrix of the second wrong DAG (a 1 in row i, column j is an edge
# i -> j): same edges as the correct DAG, plus X2 -> X1
adj_indep_inc_x2_then_x1 <- matrix(
  data = c(
    # S  X1 X2 Y
    0, 1, 1, 1, # S
    0, 0, 0, 1, # X1
    0, 1, 0, 1, # X2
    0, 0, 0, 0  # Y
  ),
  nrow = length(variables),
  byrow = TRUE,
  dimnames = list(variables, variables)
)
Figure 5.5: Assumed Causal Graph: Incorrectly Assumes that \(X_2\) causes \(X_1\)

5.2 Counterfactuals

We now build counterfactuals for individuals from group \(S=0\). First, we get the counterfactuals using optimal transport.

5.2.1 Multivariate OT

To compute the counterfactual with multivariate optimal transport, we turn to python, to use the {POT: Python Optimal Transport} library.

We export the data into a CSV file:

write_csv(D_SXY, file = "../data/D_SXY_indep.csv")

And we use the {reticulate} R package to run python in this chapter.

library(reticulate)
# reticulate::install_miniconda(force = TRUE)
use_virtualenv("~/quarto-python-env", required = TRUE)
# py_install("POT")

Some libraries need to be loaded, including the {POT} library, called ot.

import ot
import pandas as pd
import numpy as np
import matplotlib.pyplot as pl
import ot.plot

The data with the factuals need to be loaded:

# Factual observations exported from R
tb = pd.read_csv('../data/D_SXY_indep.csv')
# Keep the sensitive attribute and the two covariates; the outcome is dropped
x_S = tb.drop(columns=['Y'])
x_S.head()
   S        X1        X2
0  0  0.287578  0.353606
1  0  0.788305  0.366441
2  0  0.408977  0.287100
3  0  0.883017  0.079973
4  0  0.940467  0.365454

We will also make some predictions for an additional point: \((s=0, x_1= 0, x_2=0.5)\).

# Additional individual (s=0, x1=0, x2=0.5), appended as the last row
# NOTE(review): the R section builds `new_obs` with X1 = .5 for the
# sequential transport and the figures; confirm which point is intended
new_point = pd.DataFrame({'S': [0], 'X1': [0], 'X2': [0.5]})
x_S = pd.concat([x_S, new_point], ignore_index=True)

# Covariates of each group, without the sensitive attribute
x_0 = x_S.loc[x_S['S'] == 0].drop(columns=['S'])
x_1 = x_S.loc[x_S['S'] == 1].drop(columns=['S'])

# Uniform weights: every observation carries the mass 1/n of its own group
n_0, n_1 = len(x_0), len(x_1)
w_0 = np.full(n_0, 1/n_0)
w_1 = np.full(n_1, 1/n_1)

Cost matrix between both distributions:

# Convert both groups to NumPy arrays (one row per individual)
x_0 = x_0.to_numpy()
x_1 = x_1.to_numpy()
# Pairwise cost between the individuals of group 0 (rows) and group 1 (columns)
# NOTE(review): no metric is given, so POT's default is used (presumably the
# squared Euclidean distance; confirm against the installed version)
C = ot.dist(x_0, x_1)

Transport plan: from 0 to 1

# Transport plan from group 0 (weights w_0) to group 1 (weights w_1) for cost C
pi_0_1 = ot.emd(w_0, w_1, C, numItermax=1e8)
# The transposed plan is used to go from group 1 to group 0
pi_1_0 = pi_0_1.T
pi_0_1.shape
(501, 500)
# Sanity check: multiplying the row sums of the plan by n_0 should give ones
sum_of_rows = np.sum(pi_0_1, axis=1)
sum_of_rows*n_0
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1.])
pi_1_0.shape
(500, 501)
# Same check for the transposed plan: row sums multiplied by n_1 should be ones
sum_of_rows = np.sum(pi_1_0, axis=1)
sum_of_rows*n_1
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1.])

The transported values:

# Transported values: for each individual, weighted average of the points of
# the other group, with weights given by its row of the rescaled plan.
# `*` and `@` have the same precedence: this is (n_0*pi_0_1) @ x_1
transformed_x_0 = n_0*pi_0_1@x_1
transformed_x_1 = n_1*pi_1_0@x_0

Let us build an array with the values of the counterfactuals (for both groups here, even if we will use the counterfactuals for group \(S=0\) only).

# Same rows as x_S (including the additional point), covariates only
counterfactual_x = x_S.drop(columns=['S'])
# Replace the factual values of each group by the transported ones
counterfactual_x[x_S['S'] == 0] = transformed_x_0
counterfactual_x[x_S['S'] == 1] = transformed_x_1

Lastly, we export this array in a CSV file:

# Export the counterfactuals (without the row index) so they can be read in R
csv_file_path = '../data/counterfactuals_ot_test_indep.csv'
counterfactual_x.to_csv(csv_file_path, index=False)

We can then go back to R.

counterfactuals_ot <- 
  read_csv("../data/../data/counterfactuals_ot_test_indep.csv")
Rows: 1001 Columns: 2
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
dbl (2): X1, X2

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
counterfactuals_ot
# A tibble: 1,001 × 2
      X1    X2
   <dbl> <dbl>
 1  1.29  1.39
 2  1.80  1.32
 3  1.45  1.26
 4  1.88  1.05
 5  1.88  1.36
 6  1.01  1.17
 7  1.57  1.51
 8  1.87  1.50
 9  1.58  1.94
10  1.48  1.35
# ℹ 991 more rows

Recall that we added an additional point. The last observation from this table corresponds to the counterfactual of \((s=0, x_1=0, x_2=0.5)\).

5.2.2 Sequential Transport

We rely on our function, seq_trans(), available in our small R package (type ?seq_trans in the console after loading the package).

We compute the counterfactuals for individuals from group \(S=0\) for each assumed causal graph.

# Sequential transport of the individuals with S = 0, under the correct DAG
trans_indep_correct <- seq_trans(
  data = D_SXY, adj = adj_indep_correct, s = "S", S_0 = 0, y = "Y"
)
Transporting  X1 
Transporting  X2 
# Sequential transport of the individuals with S = 0, wrongly assuming X1 -> X2
trans_indep_inc_x1_then_x2 <- seq_trans(
  data = D_SXY, adj = adj_indep_inc_x1_then_x2, s = "S", S_0 = 0, y = "Y"
)
Transporting  X1 
Transporting  X2 
# Sequential transport of the individuals with S = 0, wrongly assuming X2 -> X1
trans_indep_inc_x2_then_x1 <- seq_trans(
  data = D_SXY, adj = adj_indep_inc_x2_then_x1, s = "S", S_0 = 0, y = "Y"
)
Transporting  X2 
Transporting  X1 

Let us also get a counterfactual for the additional point \((s=0, x_1=0.5, x_2=0.5)\).

# Additional individual: (s = 0, x1 = 0.5, x2 = 0.5)
new_obs <- tibble(S = 0, X1 = 0.5, X2 = 0.5)

# Counterfactual of this individual under each assumed causal graph,
# obtained from the corresponding sequential transport fitted on D_SXY
## Correct DAG
new_obs_indep_correct <- seq_trans_new(
  x = trans_indep_correct,
  data = D_SXY,
  newdata = new_obs
)
## Wrong DAG: X1 -> X2
new_obs_indep_inc_x1_then_x2 <- seq_trans_new(
  x = trans_indep_inc_x1_then_x2,
  data = D_SXY,
  newdata = new_obs
)
## Wrong DAG: X2 -> X1
new_obs_indep_inc_x2_then_x1 <- seq_trans_new(
  x = trans_indep_inc_x2_then_x1,
  data = D_SXY,
  newdata = new_obs
)

5.3 Hypothetical Model

Assume the scores obtained from a logistic regression write:

\[ m(x_1,x_2,s)=\big(1+\exp\big[-\big((x_1+x_2)/2 + \boldsymbol{1}(s=1)\big)\big]\big)^{-1}. \]

#' Logistic regression
#'
#' Score of the hypothetical model: inverse logit of `(x1 + x2) / 2 + s`.
#'
#' @param x1 first numerical predictor
#' @param x2 second numerical predictor
#' @param s sensitive attribute (0/1)
#' @return A numeric vector of scores in [0, 1].
logistique_reg <- function(x1, x2, s) {
  eta <- (x1 + x2) / 2 + s
  # `plogis()` is the inverse logit. Unlike `exp(eta) / (1 + exp(eta))`, it
  # does not return NaN (Inf / Inf) when `eta` is large
  plogis(eta)
}

We can use this hypothetical model to make predictions given the values of \(x_1\) and \(x_2\). Let us make some on a grid.

# Regular grid on [-5, 5] x [-5, 5]
# (`length.out` is spelled out instead of relying on partial argument matching)
vx0 <- seq(-5, 5, length.out = 251)
data_grid <- expand.grid(x = vx0, y = vx0)
# Scores of the hypothetical model on the grid, for s = 0 and for s = 1
L0 <- logistique_reg(x1 = data_grid$x, x2 = data_grid$y, s = 0)
L1 <- logistique_reg(x1 = data_grid$x, x2 = data_grid$y, s = 1)
# as a grid: `x` varies fastest in `expand.grid()`, so rows index x1 and
# columns index x2
dlogistique0 <- matrix(L0, length(vx0), length(vx0))
dlogistique1 <- matrix(L1, length(vx0), length(vx0))

5.4 Visualization

Let us create a table with the coordinates of different points: the new individual and its new coordinates depending on the transport method.

# Each column is a point stored as (x1, x2, s): the two coordinates, then the
# value of the sensitive attribute used when predicting with the model
coords_indep <- tibble(
  # 1. (x1, x2): the new individual
  start = c(new_obs$X1, new_obs$X2, 0),
  #
  # 2. (T_1(x1), T_2(x2)), correct assumption
  correct = c(
    new_obs_indep_correct$X1, new_obs_indep_correct$X2, 1
  ),
  # 3. (T_1(x1), T_2(x2 | x1)), assuming X1 -> X2
  inc_x1_then_x2 = c(
    new_obs_indep_inc_x1_then_x2$X1, new_obs_indep_inc_x1_then_x2$X2, 1
  ),
  # 4. (T_1(x1 | x2), T_2(x2)), assuming X2 -> X1
  inc_x2_then_x1 = c(
    new_obs_indep_inc_x2_then_x1$X1, new_obs_indep_inc_x2_then_x1$X2, 1
  ),
  #
  # Intermediate points where only x1 has been transported
  # 5. (T_1(x1), x2), correct assumption
  correct_interm_x2 = c(new_obs_indep_correct$X1, new_obs$X2, 0.5),
  # 6. (T_1(x1), x2), assuming X1 -> X2
  inc_x1_then_x2_interm_x2 = c(new_obs_indep_inc_x1_then_x2$X1, new_obs$X2, 0.5),
  # 7. (T_1(x1), x2), assuming X2 -> X1
  inc_x2_then_x1_interm_x2 = c(new_obs_indep_inc_x2_then_x1$X1, new_obs$X2, 0.5),
  #
  # Intermediate points where only x2 has been transported
  # 8. (x1, T_2(x2)), correct assumption
  correct_interm_x1 = c(new_obs$X1, new_obs_indep_correct$X2, 0.5),
  # 9. (x1, T_2(x2)), assuming X1 -> X2
  inc_x1_then_x2_interm_x1 = c(new_obs$X1, new_obs_indep_inc_x1_then_x2$X2, 0.5),
  # 10. (x1, T_2(x2)), assuming X2 -> X1
  inc_x2_then_x1_interm_x1 = c(new_obs$X1, new_obs_indep_inc_x2_then_x1$X2, 0.5),
  # 11. T*(x1,x2): last row of the multivariate optimal transport counterfactuals
  ot = c(last(counterfactuals_ot$X1), last(counterfactuals_ot$X2), 1)
)

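We make “Predictions” by the hypothetical model for each point contained in coords_indep, given its characteristics \(x_1\) and \(x_2\) and the value of \(s\) stored with it (\(s=0\) for the new point, \(s=1\) for its counterfactuals, and \(0.5\) for the intermediate points).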

# Prediction of the hypothetical model for each point of `coords_indep`
# (each column holds x1, x2 and s, in that order).
# `vapply()` iterates over the columns directly, without the coercion to a
# matrix that `apply()` performs on a data frame, and it guarantees exactly
# one numeric value per point. The result is named after the columns.
predicted_val <- vapply(
  coords_indep,
  function(column) {
    logistique_reg(x1 = column[1], x2 = column[2], s = column[3])
  },
  FUN.VALUE = numeric(1)
)

Then, we add the naive prediction for the point, using its factual values of \(x_1\) and \(x_2\) while setting \(s=1\):

# Naive counterfactual: factual (x1, x2) of the new individual, with s set to 1
naive_prediction <- logistique_reg(
  x1 = coords_indep$start[1],
  x2 = coords_indep$start[2],
  s = 1
)
# The naive prediction is placed first, ahead of the other predictions
predicted_val <- c(naive = naive_prediction, predicted_val)

Then, we can plot all this information on a graph. The segments show the path from the new individual to its counterfactual under each of the three assumed causal graphs. The red square shows the value obtained with multivariate optimal transport.

Code
# Colour scale from colour of class 0 to class 1
colfunc <- colorRampPalette(c(colours["0"], colours["1"]))
scl <- scales::alpha(colfunc(9), .9)

par(mar = c(2, 2, 0, 0))
# Size of the points and of the text labels
CeX <- 1
# Group 0
## Level curves of the estimated density of (x1, x2) in group S=0
## (the component `eval.points` is written in full: `$eval.point` only worked
## through partial matching of list names)
contour(
  f0_2d$eval.points[[1]],
  f0_2d$eval.points[[2]],
  f0_2d$estimate,
  col = scales::alpha(colours["A"], .3),
  axes = FALSE, xlab = "", ylab = "",
  xlim = c(-1, 3), ylim = c(-1, 3)
)
# Group 1
## Level curves of the estimated density of (x1, x2) in group S=1
contour(
  f1_2d$eval.points[[1]],
  f1_2d$eval.points[[2]],
  f1_2d$estimate,
  col = scales::alpha(colours["B"], .3), add = TRUE
)
# Level curves of the model for s=0: (x1, x2) -> m(0, x1, x2)
contour(
  vx0, vx0, dlogistique0,
  levels = seq(0.05, 0.95, by = 0.05),
  col = scl,
  add = TRUE, lwd = 2
)
axis(1)
axis(2)

###
# Individual (S=0, x1=.5, x2=.5)
###
points(
  coords_indep$start[1], coords_indep$start[2],
  pch = 19, cex = CeX, col = "darkblue"
)
## Predicted value for the individual, based on factuals
## (looked up by name rather than by position in `predicted_val`)
text(
  coords_indep$start[1], coords_indep$start[2],
  paste0(round(predicted_val[["start"]] * 100, 1), "%"),
  pos = 1, cex = CeX, col = "darkblue"
)
Figure 5.6: In the background, level curves for \((x_1,x_2)\mapsto m(0,x_1,x_2)\). The blue point corresponds to the individual \((s,x_1,x_2)=(s=0,.5,.5)\) (predicted 62.2% by model \(m\)).
Code
# Colour of each type of point on the graph
colour_start <- "darkblue"
colour_correct <- "#CC79A7"
colour_inc_x1_then_x2 <- "darkgray"
colour_inc_x2_then_x1 <- "#56B4E9"
colour_ot <- "#C93312"
colour_naive <- "#000000"

# Size of the points and of the text labels
CeX <- 1
par(mar = c(2, 2, 0, 0))
# Group 0
## Level curves of the estimated density of (x1, x2) in group S=0
## (the component `eval.points` is written in full: `$eval.point` only worked
## through partial matching of list names)
contour(
  f0_2d$eval.points[[1]],
  f0_2d$eval.points[[2]],
  f0_2d$estimate,
  col = scales::alpha(colours["A"], .3),
  axes = FALSE, xlab = "", ylab = "",
  xlim = c(-1, 3), ylim = c(-1, 3)
)
# Group 1
## Level curves of the estimated density of (x1, x2) in group S=1
contour(
  f1_2d$eval.points[[1]],
  f1_2d$eval.points[[2]],
  f1_2d$estimate,
  col = scales::alpha(colours["B"], .3), add = TRUE
)

# Contour of estimates by the model for s=1
contour(
  vx0, vx0, dlogistique1,
  levels = seq(0.05, 0.95, by = 0.05),
  col = scl, lwd = 2,
  add = TRUE
)
axis(1)
axis(2)
###
# Individual (s=0, x1=.5, x2=.5)
###
points(
  coords_indep$start[1], coords_indep$start[2],
  pch = 19, cex = CeX, col = colour_start
)
## Predicted value for the individual, based on factuals
text(
  coords_indep$start[1], coords_indep$start[2],
  paste0(round(predicted_val[["start"]] * 100, 1), "%"),
  pos = 1, cex = CeX, col = colour_start
)

###
# Transported individual using correct DAG
###
points(
  coords_indep$correct[1], coords_indep$correct[2],
  pch = 19, cex = CeX, col = colour_correct
)
## New predicted value
text(
  coords_indep$correct[1] - 0.1, coords_indep$correct[2] + 0.15,
  paste0(round(predicted_val[["correct"]] * 100, 1), "%"),
  pos = 4, cex = CeX, col = colour_correct
)
## Path: x1 is moved first, then x2
segments(
  x0 = coords_indep$start[1], y0 = coords_indep$start[2],
  x1 = coords_indep$correct_interm_x2[1], y1 = coords_indep$correct_interm_x2[2],
  lwd = .8, col = colour_correct
)
segments(
  x0 = coords_indep$correct_interm_x2[1], y0 = coords_indep$correct_interm_x2[2],
  x1 = coords_indep$correct[1], y1 = coords_indep$correct[2],
  lwd = .8, col = colour_correct
)
## Intermediate point (white disc first to hide the segments, then the outline)
points(
  coords_indep$correct_interm_x2[1], coords_indep$correct_interm_x2[2],
  pch = 19, col = "white", cex = CeX
)
points(
  coords_indep$correct_interm_x2[1], coords_indep$correct_interm_x2[2],
  pch = 1, cex = CeX, col = colour_correct
)


###
# Transported individual assuming X2 depends on X1
###
points(
  coords_indep$inc_x1_then_x2[1], coords_indep$inc_x1_then_x2[2],
  pch = 19, cex = CeX, col = colour_inc_x1_then_x2
)
## Path: x1 is moved first, then x2
segments(
  x0 = coords_indep$start[1], y0 = coords_indep$start[2],
  x1 = coords_indep$inc_x1_then_x2[1], y1 = coords_indep$inc_x1_then_x2_interm_x2[2],
  lwd = .8, col = colour_inc_x1_then_x2
)
segments(
  x0 = coords_indep$inc_x1_then_x2[1], y0 = coords_indep$inc_x1_then_x2_interm_x2[2],
  x1 = coords_indep$inc_x1_then_x2[1], y1 = coords_indep$inc_x1_then_x2[2],
  lwd = .8, col = colour_inc_x1_then_x2
)
## Intermediate point
points(
  coords_indep$inc_x1_then_x2[1], coords_indep$inc_x1_then_x2_interm_x2[2],
  pch = 19, col = "white", cex = CeX
)
points(
  coords_indep$inc_x1_then_x2[1], coords_indep$inc_x1_then_x2_interm_x2[2],
  pch = 1, cex = CeX, col = colour_inc_x1_then_x2
)
## New predicted value
text(
  coords_indep$inc_x1_then_x2[1] + 0.6, coords_indep$inc_x1_then_x2[2],
  paste0(round(predicted_val[["inc_x1_then_x2"]] * 100, 1), "%"),
  pos = 2, cex = CeX, col = colour_inc_x1_then_x2
)


###
# Transported individual assuming X1 depends on X2
###
points(
  coords_indep$inc_x2_then_x1[1], coords_indep$inc_x2_then_x1[2],
  pch = 19, cex = CeX, col = colour_inc_x2_then_x1
)
## Path: x2 is moved first, then x1
segments(
  x0 = coords_indep$start[1], y0 = coords_indep$start[2],
  x1 = coords_indep$inc_x2_then_x1_interm_x1[1], y1 = coords_indep$inc_x2_then_x1[2],
  lwd = .8, col = colour_inc_x2_then_x1
)
segments(
  x0 = coords_indep$inc_x2_then_x1_interm_x1[1], y0 = coords_indep$inc_x2_then_x1[2],
  x1 = coords_indep$inc_x2_then_x1[1], y1 = coords_indep$inc_x2_then_x1[2],
  lwd = .8, col = colour_inc_x2_then_x1
)
## Intermediate point
points(
  coords_indep$inc_x2_then_x1_interm_x1[1], coords_indep$inc_x2_then_x1[2],
  pch = 19, col = "white", cex = CeX
)
points(
  coords_indep$inc_x2_then_x1_interm_x1[1], coords_indep$inc_x2_then_x1[2],
  pch = 1, cex = CeX, col = colour_inc_x2_then_x1
)
## New predicted value
text(
  coords_indep$inc_x2_then_x1[1] + 0.3, coords_indep$inc_x2_then_x1[2] - 0.1,
  paste0(round(predicted_val[["inc_x2_then_x1"]] * 100, 1), "%"),
  pos = 3, cex = CeX, col = colour_inc_x2_then_x1
)


###
# Transported individual with multivariate optimal transport
###
points(
  coords_indep$ot[1], coords_indep$ot[2],
  pch = 15, cex = CeX, col = colour_ot
)
segments(
  x0 = coords_indep$start[1], y0 = coords_indep$start[2],
  x1 = coords_indep$ot[1], y1 = coords_indep$ot[2],
  lwd = .8, col = colour_ot, lty = 2
)
## New predicted value
text(
  coords_indep$ot[1] - 0.1, coords_indep$ot[2] + 0.1,
  paste0(round(predicted_val[["ot"]] * 100, 1), "%"),
  pos = 4, cex = CeX, col = colour_ot
)

###
# New predicted value for (do(s=1), x1, x2), no transport
###
ry <- .09
plotrix::draw.circle(
  x = coords_indep$start[1] - .9 * ry * sqrt(2),
  y = coords_indep$start[2] - .9 * ry * sqrt(2),
  radius = ry * sqrt(2), border = colour_naive
)
text(
  coords_indep$start[1], coords_indep$start[2],
  paste0(round(predicted_val[["naive"]] * 100, 1), "%"),
  pos = 3, cex = CeX, col = colour_naive
)
legend(
  "topleft",
  legend = c(
    "Start point",
    "Naive",
    "Multivariate optimal transport",
    "Seq. T.: Correct DAG",
    latex2exp::TeX("Seq. T.: Assuming $X_1 \\rightarrow X_2$"),
    latex2exp::TeX("Seq. T.: Assuming $X_2 \\rightarrow X_1$")
  ),
  col = c(
    colour_start,
    colour_naive,
    colour_ot,
    colour_correct,
    colour_inc_x1_then_x2,
    colour_inc_x2_then_x1
  ),
  pch = c(19, 19, 15, 19, 19, 19),
  lty = c(NA, NA, 2, 1, 1, 1), bty = "n"
)
Figure 5.7: In the background, level curves for \(m(1,x_1,x_2)\). The dark blue dot represents an individual \((s,x_1,x_2)=(s=0, 0.5, 0.5)\) (predicted 62.2% by model \(m\), and 81.8% if \(s\) is set to 1 leaving \(x_1\) and \(x_2\) unchanged). The other dots represent counterfactuals \((s=1,x_1^\star,x_2^\star)\) according to the assumed causal graph where \(x_1\) and \(x_2\) are independent (correct DAG, predicted 92.5%), \(x_2\) depends on \(x_1\) (bottom right path, predicted 91.4%), \(x_1\) depends on \(x_2\) (top left path, predicted 92.9%). The red square shows the counterfactual obtained with optimal transport (with a predicted value by model \(m\) at 90.4%).

Let us now consider all the points from \(S=0\) and not a single one.

We estimate the densities of \((T(X_1), T(X_2))\) in each of the four configurations. We can then plot these estimated densities on top of those previously estimated on \((X_1, X_2)\) using the factuals.

We define a table with the counterfactuals obtained with multivariate optimal transport on the subgroup of \(S=0\):

# Counterfactuals obtained with multivariate optimal transport, restricted to
# the individuals of group S = 0
tb_transpoort_ot <- counterfactuals_ot |>
  # keep the rows matching D_SXY, i.e. remove the last observation (the new
  # point); `seq_len()` stays correct even for a table with zero rows
  slice(seq_len(nrow(D_SXY))) |>
  mutate(S = D_SXY$S) |>
  filter(S == 0) |>
  select(-S)
Codes to create the Figure.
# Bandwidths for the kernel density estimation of the transported values
H1_OT <- Hpi(tb_transpoort_ot)
H1_indep_correct <- Hpi(as_tibble(trans_indep_correct$transported))
H1_indep_inc_x1_then_x2 <- Hpi(as_tibble(trans_indep_inc_x1_then_x2$transported))
H1_indep_inc_x2_then_x1 <- Hpi(as_tibble(trans_indep_inc_x2_then_x1$transported))
# Estimated densities of the transported values, for each method
f1_2d_OT <- kde(tb_transpoort_ot, H = H1_OT, xmin = c(-5, -5), xmax = c(5, 5))
f1_2d_indep_correct <- kde(
  as_tibble(trans_indep_correct$transported),
  H = H1_indep_correct, xmin = c(-5, -5), xmax = c(5, 5)
)
f1_2d_indep_inc_x1_then_x2 <- kde(
  as_tibble(trans_indep_inc_x1_then_x2$transported),
  H = H1_indep_inc_x1_then_x2, xmin = c(-5, -5), xmax = c(5, 5)
)
f1_2d_indep_inc_x2_then_x1 <- kde(
  as_tibble(trans_indep_inc_x2_then_x1$transported),
  H = H1_indep_inc_x2_then_x1, xmin = c(-5, -5), xmax = c(5, 5)
)

# Axis limits shared by the four panels
x_lim <- c(-.5, 2.5)
y_lim <- c(-.25, 3)

# Plotting densities: one panel per transport method, on a 2 x 2 layout.
# The kde component `eval.points` is written in full throughout: `$eval.point`
# only worked through partial matching of list names.
par(mar = c(2, 2, 0, 0), mfrow = c(2, 2))

# Panel 1: multivariate optimal transport
# Group S=0
contour(
  f0_2d$eval.points[[1]], f0_2d$eval.points[[2]], f0_2d$estimate,
  col = colours["A"], axes = FALSE, xlab = "", ylab = "",
  xlim = x_lim, ylim = y_lim
)
axis(1)
axis(2)
# Group S=1
contour(
  f1_2d$eval.points[[1]], f1_2d$eval.points[[2]], f1_2d$estimate,
  col = colours["B"], add = TRUE
)

# Group S=0 transported with optimal transport
contour(
  f1_2d_OT$eval.points[[1]], f1_2d_OT$eval.points[[2]], f1_2d_OT$estimate,
  col = colour_ot, add = TRUE
)
legend(
  "topleft", legend = c("Obs S=0", "Obs S=1", "OT"),
  lty = 1,
  col = c(colours["A"], colours["B"], colour_ot), bty = "n"
)



# Panel 2: sequential transport, correct DAG
# Group S=0
contour(
  f0_2d$eval.points[[1]], f0_2d$eval.points[[2]], f0_2d$estimate,
  col = colours["A"], axes = FALSE, xlab = "", ylab = "",
  xlim = x_lim, ylim = y_lim
)
axis(1)
axis(2)
# Group S=1
contour(
  f1_2d$eval.points[[1]], f1_2d$eval.points[[2]], f1_2d$estimate,
  col = colours["B"], add = TRUE
)
# Group S=0 transported
contour(
  f1_2d_indep_correct$eval.points[[1]], f1_2d_indep_correct$eval.points[[2]],
  f1_2d_indep_correct$estimate,
  col = colour_correct, add = TRUE
)
legend(
  "topleft", legend = c("Obs S=0", "Obs S=1", "Seq. T: Correct DAG"),
  lty = 1,
  col = c(colours["A"], colours["B"], colour_correct), bty = "n"
)


# Panel 3: sequential transport, wrong DAG: assuming X2 depends on X1
# Group S=0
contour(
  f0_2d$eval.points[[1]], f0_2d$eval.points[[2]], f0_2d$estimate,
  col = colours["A"], axes = FALSE, xlab = "", ylab = "",
  xlim = x_lim, ylim = y_lim
)
axis(1)
axis(2)
# Group S=1
contour(
  f1_2d$eval.points[[1]], f1_2d$eval.points[[2]], f1_2d$estimate,
  col = colours["B"], add = TRUE
)
# Group S=0 transported
contour(
  f1_2d_indep_inc_x1_then_x2$eval.points[[1]],
  f1_2d_indep_inc_x1_then_x2$eval.points[[2]],
  f1_2d_indep_inc_x1_then_x2$estimate,
  col = colour_inc_x1_then_x2, add = TRUE
)
legend(
  "topleft", legend = c(
    "Obs S=0", "Obs S=1",
    latex2exp::TeX("Seq. T.: Assuming $X_1 \\rightarrow X_2$")
  ),
  lty = 1,
  col = c(colours["A"], colours["B"], colour_inc_x1_then_x2), bty = "n"
)

# Panel 4: sequential transport, wrong DAG: assuming X1 depends on X2
# Group S=0
contour(
  f0_2d$eval.points[[1]], f0_2d$eval.points[[2]], f0_2d$estimate,
  col = colours["A"], axes = FALSE, xlab = "", ylab = "",
  xlim = x_lim, ylim = y_lim
)
axis(1)
axis(2)
# Group S=1
contour(
  f1_2d$eval.points[[1]], f1_2d$eval.points[[2]], f1_2d$estimate,
  col = colours["B"], add = TRUE
)
# Group S=0 transported
contour(
  f1_2d_indep_inc_x2_then_x1$eval.points[[1]],
  f1_2d_indep_inc_x2_then_x1$eval.points[[2]],
  f1_2d_indep_inc_x2_then_x1$estimate,
  col = colour_inc_x2_then_x1, add = TRUE
)
legend(
  "topleft", legend = c(
    "Obs S=0", "Obs S=1",
    latex2exp::TeX("Seq. T.: Assuming $X_2 \\rightarrow X_1$")
  ),
  lty = 1,
  col = c(colours["A"], colours["B"], colour_inc_x2_then_x1), bty = "n"
)
Figure 5.8: Estimated densities of \((X_1, X_2)\) using the factual values or the counterfactual values.