8  Sequential Transport

Objectives

This chapter uses sequential transport (developed in our paper) to make counterfactual inference.

In the article, we use three methods to create counterfactuals:

  1. Fairadapt (Chapter 6)
  2. Multivariate optimal transport (Chapter 7)
  3. Sequential transport (the methodology we develop in the paper, presented in this chapter).
Display the setting codes
# Required packages----
library(tidyverse)
library(fairadapt)

# Graphs----
# Font family, face and size settings used by global_theme() below.
font_main <- font_title <- "Times New Roman"
# Register the fonts known to extrafont (`TRUE` spelled out: `T` is an
# ordinary variable that can be reassigned).
extrafont::loadfonts(quiet = TRUE)
face_text <- "plain"
face_title <- "plain"
size_title <- 14
size_text <- 11
legend_size <- 11

# Shared ggplot2 theme: theme_minimal() in which the text-related elements
# listed below are replaced, using the font settings defined at the top of the
# script (read when the function is called).
global_theme <- function() {
  body_text <- element_text(
    family = font_main, size = size_text, face = face_text
  )
  legend_label <- element_text(family = font_main, size = legend_size)
  axis_label <- element_text(size = size_text, face = face_text)
  # Title and subtitle are horizontally centred
  title_text <- element_text(
    family = font_title, size = size_title, hjust = 0.5
  )
  subtitle_text <- element_text(hjust = 0.5)

  overrides <- theme(
    text = body_text,
    legend.text = legend_label,
    axis.text = axis_label,
    plot.title = title_text,
    plot.subtitle = subtitle_text
  )
  theme_minimal() %+replace% overrides
}

# Seed: fix the state of the random number generator
set.seed(2025)

# Run the script functions/utils.R (which contains transport_function())
source("../functions/utils.R")

8.1 Load Data and Classifier

We load the dataset where the sensitive attribute (\(S\)) is the race, obtained in Chapter 4.3:

load("../data/df_race.rda")

We also load the dataset where the sensitive attribute is also the race, but where the target variable (\(Y\), ZFYA) is binary (1 if the student obtained a standardized first year average over the median, 0 otherwise). This dataset was saved in Chapter 5.5:

load("../data/df_race_c.rda")

We also need the predictions made by the classifier (see Chapter 5):

# Predictions on train/test sets
load("../data/pred_aware.rda")
load("../data/pred_unaware.rda")
# Predictions on the factuals, on the whole dataset
load("../data/pred_aware_all.rda")
load("../data/pred_unaware_all.rda")

8.2 Counterfactuals with Sequential Transport

We now turn to sequential transport (the methodology developed in our paper). We define a function, transport_function() (see functions/utils.R), to perform a fast sequential transport on a causal graph.

#' Sequential transport
#'
#' Builds grid-based transport maps from a source group (`S == S_0`) to the
#' remaining observations (`S != S_0`). Empirical c.d.f.s are evaluated at the
#' midpoints of a regular grid and empirical quantiles at evenly spaced
#' probability levels, both marginally and conditionally on the other
#' predictor. The returned functions look up the nearest grid cell.
#'
#' @param data dataset with three columns:
#'  - S: sensitive attribute
#'  - X1: first predictor, assumed to be causally linked to S
#'  - X2: second predictor, assumed to be causally linked to S and X1
#' @param S_0 value for the sensitive attribute for the source distribution
#' @param n_grid number of cells in each dimension (default to 15)
#' @param h small value added to extend the area covered by the grid (default
#'  to .2)
#' @param d neighborhood half-width used when conditioning: observations whose
#'  conditioning variable lies within `d` of the cell midpoint are kept
#'  (default to .5)
#' @return A list of four functions:
#'  - `Transport_x1(x1)`: transport of X1
#'  - `Transport_x2(x2)`: transport of X2
#'  - `Transport_x1_cond_x2(x1, x2)`: transport of X1 conditional on X2
#'  - `Transport_x2_cond_x1(x2, x1)`: transport of X2 conditional on X1
#'  Each function expects scalar inputs and reads the grids stored in the
#'  environment of this call.
transport_function <- function(data,
                               S_0,
                               n_grid = 15,
                               h = .2,
                               d = .5
                               ) {

  stopifnot(
    all(c("S", "X1", "X2") %in% names(data)),
    n_grid >= 1
  )

  # Subsets of the data: 0 for the source group (S == S_0), 1 for the target
  # group (all the other observations)
  D_SXY_0 <- data[data$S == S_0, ]
  D_SXY_1 <- data[data$S != S_0, ]

  if (nrow(D_SXY_0) == 0 || nrow(D_SXY_1) == 0) {
    stop(
      "Both the source group (S == S_0) and the target group (S != S_0) ",
      "must contain observations.",
      call. = FALSE
    )
  }

  cells <- seq_len(n_grid)

  # Coordinates of the cells of the grid on the source group
  vx1_0 <- seq(
    min(D_SXY_0$X1) - h, max(D_SXY_0$X1) + h, length.out = n_grid + 1
  )
  vx2_0 <- seq(
    min(D_SXY_0$X2) - h, max(D_SXY_0$X2) + h, length.out = n_grid + 1
  )
  # and middle point of the cells
  vx1_0_mid <- (vx1_0[cells + 1] + vx1_0[cells]) / 2
  vx2_0_mid <- (vx2_0[cells + 1] + vx2_0[cells]) / 2

  # Coordinates of the cells of the grid on the target group
  vx1_1 <- seq(
    min(D_SXY_1$X1) - h, max(D_SXY_1$X1) + h, length.out = n_grid + 1
  )
  vx2_1 <- seq(
    min(D_SXY_1$X2) - h, max(D_SXY_1$X2) + h, length.out = n_grid + 1
  )
  # and middle point of the cells
  vx1_1_mid <- (vx1_1[cells + 1] + vx1_1[cells]) / 2
  vx2_1_mid <- (vx2_1[cells + 1] + vx2_1[cells]) / 2

  # Creation of the grids for the conditional CDF and quantile functions,
  # initialised with NA values.
  # One grid for X1 and one for X2, on both subsets of the data.
  # Rows index the variable's own grid (or probability level), columns index
  # the cell of the conditioning variable.
  F1_0 <- F2_0 <- F1_1 <- F2_1 <- matrix(NA, n_grid, n_grid)
  Q1_0 <- Q2_0 <- Q1_1 <- Q2_1 <- matrix(NA, n_grid, n_grid)

  # Empirical CDF for X1 on the source group
  FdR1_0 <- Vectorize(function(x) mean(D_SXY_0$X1 <= x))
  f1_0 <- FdR1_0(vx1_0_mid)
  # Empirical CDF for X2 on the source group
  FdR2_0 <- Vectorize(function(x) mean(D_SXY_0$X2 <= x))
  f2_0 <- FdR2_0(vx2_0_mid)
  # Empirical CDF for X1 on the target group
  FdR1_1 <- Vectorize(function(x) mean(D_SXY_1$X1 <= x))
  f1_1 <- FdR1_1(vx1_1_mid)
  # Empirical CDF for X2 on the target group
  FdR2_1 <- Vectorize(function(x) mean(D_SXY_1$X2 <= x))
  f2_1 <- FdR2_1(vx2_1_mid)

  # Probability levels at which the quantiles are evaluated
  u <- cells / (n_grid + 1)
  # Empirical quantiles for X1 on the source group
  Qtl1_0 <- Vectorize(function(x) quantile(D_SXY_0$X1, x))
  q1_0 <- Qtl1_0(u)
  # Empirical quantiles for X2 on the source group
  Qtl2_0 <- Vectorize(function(x) quantile(D_SXY_0$X2, x))
  q2_0 <- Qtl2_0(u)
  # Empirical quantiles for X1 on the target group
  Qtl1_1 <- Vectorize(function(x) quantile(D_SXY_1$X1, x))
  q1_1 <- Qtl1_1(u)
  # Empirical quantiles for X2 on the target group
  Qtl2_1 <- Vectorize(function(x) quantile(D_SXY_1$X2, x))
  q2_1 <- Qtl2_1(u)

  # Compute the conditional c.d.f. and quantiles at each cell of the grid in
  # both groups. For cell i, the conditioning set is made of the observations
  # whose conditioning variable is closer than d to the i-th midpoint.
  for (i in cells) {
    # Source group: X2 given X1 in the neighborhood of cell i
    idx1_0 <- which(abs(D_SXY_0$X1 - vx1_0_mid[i]) < d)
    FdR2_0 <- Vectorize(function(x) mean(D_SXY_0$X2[idx1_0] <= x))
    F2_0[, i] <- FdR2_0(vx2_0_mid)
    Qtl2_0 <- Vectorize(function(x) quantile(D_SXY_0$X2[idx1_0], x))
    Q2_0[, i] <- Qtl2_0(u)

    # Source group: X1 given X2 in the neighborhood of cell i
    idx2_0 <- which(abs(D_SXY_0$X2 - vx2_0_mid[i]) < d)
    FdR1_0 <- Vectorize(function(x) mean(D_SXY_0$X1[idx2_0] <= x))
    F1_0[, i] <- FdR1_0(vx1_0_mid)
    Qtl1_0 <- Vectorize(function(x) quantile(D_SXY_0$X1[idx2_0], x))
    Q1_0[, i] <- Qtl1_0(u)

    # Target group: X2 given X1 in the neighborhood of cell i
    idx1_1 <- which(abs(D_SXY_1$X1 - vx1_1_mid[i]) < d)
    FdR2_1 <- Vectorize(function(x) mean(D_SXY_1$X2[idx1_1] <= x))
    F2_1[, i] <- FdR2_1(vx2_1_mid)
    Qtl2_1 <- Vectorize(function(x) quantile(D_SXY_1$X2[idx1_1], x))
    Q2_1[, i] <- Qtl2_1(u)

    # Target group: X1 given X2 in the neighborhood of cell i
    idx2_1 <- which(abs(D_SXY_1$X2 - vx2_1_mid[i]) < d)
    FdR1_1 <- Vectorize(function(x) mean(D_SXY_1$X1[idx2_1] <= x))
    F1_1[, i] <- FdR1_1(vx1_1_mid)
    Qtl1_1 <- Vectorize(function(x) quantile(D_SXY_1$X1[idx2_1], x))
    Q1_1[, i] <- Qtl1_1(u)
  }

  # Transport for X2: nearest source cell -> c.d.f. level -> nearest
  # probability level -> target quantile
  T2 <- function(x2) {
    i <- which.min(abs(vx2_0_mid - x2))
    p <- f2_0[i]
    i <- which.min(abs(u - p))
    x2star <- q2_1[i]
    x2star
  }

  # Transport for X1 (same scheme as for X2)
  T1 <- function(x1) {
    i <- which.min(abs(vx1_0_mid - x1))
    p <- f1_0[i]
    i <- which.min(abs(u - p))
    x1star <- q1_1[i]
    x1star
  }

  # Transport for X2 conditional on X1: the source c.d.f. is conditioned on
  # the cell of x1, the target quantile on the cell of the transported x1
  T2_cond_x1 <- function(x2, x1) {
    k0 <- which.min(abs(vx1_0_mid - x1))
    k1 <- which.min(abs(vx1_1_mid - T1(x1)))
    i <- which.min(abs(vx2_0_mid - x2))
    p <- F2_0[i, k0]
    i <- which.min(abs(u - p))
    x2star <- Q2_1[i, k1]
    x2star
  }

  # Transport for X1 conditional on X2: the source c.d.f. is conditioned on
  # the cell of x2, the target quantile on the cell of the transported x2
  T1_cond_x2 <- function(x1, x2) {
    k0 <- which.min(abs(vx2_0_mid - x2))
    k1 <- which.min(abs(vx2_1_mid - T2(x2)))
    i <- which.min(abs(vx1_0_mid - x1))
    p <- F1_0[i, k0]
    i <- which.min(abs(u - p))
    x1star <- Q1_1[i, k1]
    x1star
  }

  list(
    Transport_x1 = T1,
    Transport_x2 = T2,
    Transport_x1_cond_x2 = T1_cond_x2,
    Transport_x2_cond_x1 = T2_cond_x1
  )
}
Note

The transport_function() function returns a list of four functions: Transport_x1(), Transport_x2(), Transport_x1_cond_x2(), and Transport_x2_cond_x1(). These functions are closures: they keep access to the useful values of the grid (e.g., vx1_0_mid) defined in the environment of transport_function() and used in their bodies. Note that defining a global object named vx1_0_mid will not alter the object of the same name defined in the environment of transport_function(): R will call the vx1_0_mid from that environment and not the one that may be defined in the global environment.

Let us apply this function. Note that we use a grid with 500 cells in each dimension to speed up the computation of sequential transport (the estimation takes about 45 seconds on a standard computer). But first, we create a dataset, df_race_c_light, with the sensitive attribute and the two characteristics only:

# Keep only the sensitive attribute and the two predictors
df_race_c_light <- select(df_race_c, S, X1, X2)
# Row positions of each group
ind_white <- which(df_race_c_light$S == "White")
ind_black <- which(df_race_c_light$S == "Black")
# Estimate the transport maps, with Black individuals as the source group
seq_functions <- transport_function(
  data = df_race_c_light,
  S_0 = "Black",
  n_grid = 500
)

Let us extract the transport functions to transport \(X_1\) and \(X_2\):

# Marginal transport maps for X1 and X2
T_X1 <- seq_functions[["Transport_x1"]]
T_X2 <- seq_functions[["Transport_x2"]]

We also do the same with the transport of \(X_2\) conditional on \(X_1\):

# Transport map for X2 conditional on X1
T_X2_c_X1 <- seq_functions[["Transport_x2_cond_x1"]]

Now, we can apply these functions to the subset of Black individuals to sequentially transport \(X_1\) (UGPA) and then \(X_2\) (LSAT) conditional on the transported value of \(X_1\):

The values of \(X_1\) and \(X_2\) for Black individuals:

# Factual X1 and X2 of the Black individuals
a10 <- df_race_c_light[["X1"]][ind_black]
a20 <- df_race_c_light[["X2"]][ind_black]

The transported values:

# Transport X1 to group S=White
x1_star <- map_dbl(a10, \(x1) T_X1(x1))
# Transport X2|X1 to group S=White (the map takes x2 first, then x1)
x2_star <- map2_dbl(a20, a10, \(x2, x1) T_X2_c_X1(x2, x1))

We build a dataset with the sensitive attribute of Black individuals changed to White, and their characteristics changed to their transported characteristics:

# Counterfactual dataset: Black individuals, relabelled as White, with their
# transported characteristics. `id` is the row position in the full dataset.
df_counterfactuals_seq_black <- df_race_c_light |>
  mutate(id = row_number()) |>
  filter(S == "Black") |>
  mutate(S = "White") |>
  mutate(X1 = x1_star, X2 = x2_star)

We make predictions based on those counterfactuals obtained with sequential transport, on both models (the unaware model, and the aware model):

# Fitted classifiers
model_unaware <- pred_unaware$model
model_aware <- pred_aware$model

# Predictions of each classifier on the counterfactuals
pred_seq_unaware <- predict(
  model_unaware,
  newdata = df_counterfactuals_seq_black,
  type = "response"
)
pred_seq_aware <- predict(
  model_aware,
  newdata = df_counterfactuals_seq_black,
  type = "response"
)

# Counterfactuals together with the predictions of each classifier
counterfactuals_unaware_seq_black <- mutate(
  df_counterfactuals_seq_black,
  pred = pred_seq_unaware,
  type = "counterfactual"
)
counterfactuals_aware_seq_black <- mutate(
  df_counterfactuals_seq_black,
  pred = pred_seq_aware,
  type = "counterfactual"
)

We create a tibble with the factuals and the predictions by the aware model, and another with the predictions by the unaware model:

# Factual characteristics with the predictions of the aware model
factuals_aware <-
  tibble(S = df_race$S, X1 = df_race$X1, X2 = df_race$X2) |>
  mutate(pred = pred_aware_all, type = "factual")

# Factual characteristics with the predictions of the unaware model
factuals_unaware <-
  tibble(S = df_race$S, X1 = df_race$X1, X2 = df_race$X2) |>
  mutate(pred = pred_unaware_all, type = "factual")

Let us put in a single table the predictions made by the classifier (either aware or unaware) on black individuals based on their factual characteristics, and those made based on the counterfactuals:

# Factual and counterfactual predictions for Black individuals, aware model.
# `id` is the row position in the full dataset.
aware_seq_black <- bind_rows(
  factuals_aware |> mutate(id = row_number()) |> filter(S == "Black"),
  counterfactuals_aware_seq_black
)
# Same table for the unaware model: the counterfactual predictions must be
# those of the unaware classifier (not those of the aware classifier).
unaware_seq_black <- bind_rows(
  factuals_unaware |> mutate(id = row_number()) |> filter(S == "Black"),
  counterfactuals_unaware_seq_black
)
# Distribution of the predictions of the unaware model for Black individuals:
# factuals vs. counterfactuals (histogram on the density scale, with a
# density estimate on top).
ggplot(
  data = unaware_seq_black,
  mapping = aes(x = pred, fill = type)
) +
  geom_histogram(
    mapping = aes(y = after_stat(density)),
    alpha = 0.5, position = "identity", binwidth = 0.05
  ) +
  geom_density(alpha = 0.5) +
  labs(
    title = "Unaware model, S: Race - Black --> White",
    x = "Predictions for Y",
    y = "Density"
  ) +
  global_theme()
Figure 8.1: Unaware model, Sensitive: Race, Reference: White individuals
# Distribution of the predictions of the aware model for Black individuals:
# factuals vs. counterfactuals (histogram on the density scale, with a
# density estimate on top).
ggplot(aware_seq_black, aes(x = pred, fill = type)) +
  geom_histogram(
    aes(y = after_stat(density)),
    binwidth = 0.05,
    position = "identity",
    alpha = 0.5
  ) +
  geom_density(alpha = 0.5) +
  labs(
    title = "Aware model, S: Race - Black --> White",
    x = "Predictions for Y",
    y = "Density"
  ) +
  global_theme()
Figure 8.2: Aware model, Sensitive: Race, Reference: White individuals

8.3 Comparison for Three Individuals

Let us focus on the first three Black individuals of the dataset.

# First three Black individuals (factuals, unaware model)
factuals_unaware |>
  mutate(id = row_number()) |>
  filter(S == "Black") |>
  dplyr::slice(seq_len(3))
# A tibble: 3 × 6
  S        X1    X2  pred type       id
  <fct> <dbl> <dbl> <dbl> <chr>   <int>
1 Black   2.8    29 0.300 factual    24
2 Black   3.2    19 0.206 factual    40
3 Black   2.6    23 0.198 factual    51

Their characteristics after sequential transport (and the predicted value with the unaware model):

# Counterfactuals of the first three Black individuals (unaware model)
indiv_counterfactuals_unaware_seq <-
  counterfactuals_unaware_seq_black[seq_len(3), ]
indiv_counterfactuals_unaware_seq
# A tibble: 3 × 6
  S        X1    X2    id  pred type          
  <chr> <dbl> <dbl> <int> <dbl> <chr>         
1 White   3.2  37      24 0.491 counterfactual
2 White   3.5  28.5    40 0.381 counterfactual
3 White   3.1  31.5    51 0.379 counterfactual
# Factuals followed by counterfactuals for the three individuals
# (unaware model)
indiv_unaware_seq <- factuals_unaware |>
  mutate(id = row_number()) |>
  filter(S == "Black") |>
  dplyr::slice(1:3) |>
  bind_rows(indiv_counterfactuals_unaware_seq)
indiv_unaware_seq
# A tibble: 6 × 6
  S        X1    X2  pred type              id
  <chr> <dbl> <dbl> <dbl> <chr>          <int>
1 Black   2.8  29   0.300 factual           24
2 Black   3.2  19   0.206 factual           40
3 Black   2.6  23   0.198 factual           51
4 White   3.2  37   0.491 counterfactual    24
5 White   3.5  28.5 0.381 counterfactual    40
6 White   3.1  31.5 0.379 counterfactual    51

And with the aware model:

# Counterfactuals of the first three Black individuals (aware model)
indiv_counterfactuals_aware_seq <-
  counterfactuals_aware_seq_black[seq_len(3), ]
indiv_counterfactuals_aware_seq
# A tibble: 3 × 6
  S        X1    X2    id  pred type          
  <chr> <dbl> <dbl> <int> <dbl> <chr>         
1 White   3.2  37      24 0.507 counterfactual
2 White   3.5  28.5    40 0.418 counterfactual
3 White   3.1  31.5    51 0.413 counterfactual
# Factuals followed by counterfactuals for the three individuals (aware model)
indiv_aware_seq <- factuals_aware |>
  mutate(id = row_number()) |>
  filter(S == "Black") |>
  dplyr::slice(1:3) |>
  bind_rows(indiv_counterfactuals_aware_seq)
indiv_aware_seq
# A tibble: 6 × 6
  S        X1    X2   pred type              id
  <chr> <dbl> <dbl>  <dbl> <chr>          <int>
1 Black   2.8  29   0.133  factual           24
2 Black   3.2  19   0.0933 factual           40
3 Black   2.6  23   0.0882 factual           51
4 White   3.2  37   0.507  counterfactual    24
5 White   3.5  28.5 0.418  counterfactual    40
6 White   3.1  31.5 0.413  counterfactual    51

8.4 Counterfactual Demographic Parity

For the unaware model:

# Average difference between counterfactual and factual predictions for Black
# individuals (unaware model)
mean(
  counterfactuals_unaware_seq_black[["pred"]] -
    filter(factuals_unaware, S == "Black")[["pred"]]
)
[1] 0.1816625

For the aware model:

# Average difference between counterfactual and factual predictions for Black
# individuals (aware model)
mean(
  counterfactuals_aware_seq_black[["pred"]] -
    filter(factuals_aware, S == "Black")[["pred"]]
)
[1] 0.3722886

8.5 Saving Objects

# Write each table of counterfactuals to its own .rda file
save(
  list = "counterfactuals_unaware_seq_black",
  file = "../data/counterfactuals_unaware_seq_black.rda"
)
save(
  list = "counterfactuals_aware_seq_black",
  file = "../data/counterfactuals_aware_seq_black.rda"
)