7  Multivariate Optimal Transport

Objectives

This chapter uses multivariate optimal transport (De Lara et al. (2024)) to perform counterfactual inference.

In the article, we use three methods to create counterfactuals:

  1. Fairadapt (Chapter 6)
  2. Multivariate optimal transport (this chapter)
  3. Sequential transport (the methodology we develop in the paper, see Chapter 8).

Setup code:

# Required packages----
library(tidyverse)
library(fairadapt)

# Graphs----
font_main = font_title = 'Times New Roman'
extrafont::loadfonts(quiet = TRUE)
face_text='plain'
face_title='plain'
size_title = 14
size_text = 11
legend_size = 11

global_theme <- function() {
  theme_minimal() %+replace%
    theme(
      text = element_text(family = font_main, size = size_text, face = face_text),
      legend.text = element_text(family = font_main, size = legend_size),
      axis.text = element_text(size = size_text, face = face_text), 
      plot.title = element_text(
        family = font_title, 
        size = size_title, 
        hjust = 0.5
      ),
      plot.subtitle = element_text(hjust = 0.5)
    )
}

# Seed
set.seed(2025)

source("../functions/utils.R")

7.1 Load Data and Classifier

We load the dataset in which the sensitive attribute (\(S\)) is race, obtained in Chapter 4.3:

load("../data/df_race.rda")

We also load the dataset in which the sensitive attribute is again race, but where the target variable (\(Y\), ZFYA) is binary (1 if the student obtained a standardized first-year average above the median, 0 otherwise). This dataset was saved in Chapter 5.5:

load("../data/df_race_c.rda")

We also need the predictions made by the classifier (see Chapter 5):

# Predictions on train/test sets
load("../data/pred_aware.rda")
load("../data/pred_unaware.rda")
# Predictions on the factuals, on the whole dataset
load("../data/pred_aware_all.rda")
load("../data/pred_unaware_all.rda")

7.2 Counterfactuals with Multivariate Optimal Transport

We apply multivariate optimal transport (OT), following the methodology developed in De Lara et al. (2024). Note that with OT, it is not possible to handle new cases: counterfactuals are only computed on the train set.
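
For reference, the discrete problem solved in this chapter can be written in the standard Kantorovich form. The notation below is ours and is an assumption, not taken from De Lara et al. (2024): \(x_i\) are the \(n_0\) source points, \(y_j\) the \(n_1\) target points, both with uniform weights, and \(C_{ij}\) is the cost of moving \(x_i\) to \(y_j\).

```latex
\[
\pi^\star \in \operatorname*{arg\,min}_{\pi \in \mathbb{R}_+^{n_0 \times n_1}}
\sum_{i=1}^{n_0} \sum_{j=1}^{n_1} \pi_{ij} C_{ij}
\quad \text{subject to} \quad
\sum_{j=1}^{n_1} \pi_{ij} = \frac{1}{n_0}, \qquad
\sum_{i=1}^{n_0} \pi_{ij} = \frac{1}{n_1}.
\]
```

The plan \(\pi^\star\) only has entries for observed pairs \((x_i, y_j)\), so it does not directly say where an unseen point should be sent.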

The code is run in Python. We use the {reticulate} R package to call Python from this notebook.

library(reticulate)
# py_install("POT")

Some libraries need to be loaded (including POT, imported as ot):

import ot
import pandas as pd
import numpy as np
import matplotlib.pyplot as pl
import ot.plot

The data with the factuals need to be loaded:

df_aware = pd.read_csv('../data/factuals_aware.csv')
df_unaware = pd.read_csv('../data/factuals_unaware.csv')
x_S = df_aware.drop(columns=['pred', 'type'])
x_S.head()
       S   X1    X2
0  White  3.1  39.0
1  White  3.0  36.0
2  White  3.1  30.0
3  White  3.4  37.0
4  White  3.6  30.5
x_white = x_S[x_S['S'] == 'White']
x_white = x_white.drop(columns=['S'])
x_black = x_S[x_S['S'] == 'Black']
x_black = x_black.drop(columns=['S'])

n_white = len(x_white)
n_black = len(x_black)
# Uniform weights
w_white = (1/n_white)*np.ones(n_white)
w_black = (1/n_black)*np.ones(n_black)

Cost matrix between both distributions:

x_white = x_white.to_numpy()
x_black = x_black.to_numpy()
C = ot.dist(x_white, x_black)
pl.figure(1)
pl.plot(x_white[:, 0], x_white[:, 1], '+b', label='Source samples')
pl.plot(x_black[:, 0], x_black[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')
Figure 7.1: Source and target distributions
pl.figure(2)
pl.imshow(C, interpolation='nearest')
pl.title('Cost matrix C')
Figure 7.2: Cost matrix C
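
When no metric is given, `ot.dist` uses the squared Euclidean distance (`metric='sqeuclidean'`). The sketch below reproduces that computation with NumPy on a few made-up points. The function name `sq_euclidean_cost` and the toy values are hypothetical and only serve as an illustration.

```python
import numpy as np

def sq_euclidean_cost(x_source, x_target):
    """Pairwise squared Euclidean distances, shape (n_source, n_target)."""
    # Broadcasting builds every source/target difference at once
    diff = x_source[:, None, :] - x_target[None, :, :]
    return np.sum(diff ** 2, axis=2)

# Hypothetical toy points with two features, in the spirit of (X1, X2)
x_source = np.array([[3.1, 39.0], [3.0, 36.0], [3.4, 37.0]])
x_target = np.array([[2.7, 31.0], [2.6, 21.0]])

C_toy = sq_euclidean_cost(x_source, x_target)
print(C_toy.shape)   # (3, 2)
print(C_toy[0, 0])   # (3.1 - 2.7)^2 + (39 - 31)^2
```

In this toy example, almost all of the first entry comes from the second feature, whose values are much larger. With unscaled features, the cost is therefore mostly driven by the variable with the largest range.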

The transport plan from White to Black individuals, and its transpose for the reverse direction:

pi_white_black = ot.emd(w_white, w_black, C, numItermax=int(1e8))
pi_black_white = pi_white_black.T
pi_white_black.shape
(18285, 1282)
sum_of_rows = np.sum(pi_white_black, axis=1)
sum_of_rows*n_white
array([1., 1., 1., ..., 1., 1., 1.])
pi_black_white.shape
(1282, 18285)
sum_of_rows = np.sum(pi_black_white, axis=1)
sum_of_rows*n_black
array([1., 1., 1., ..., 1., 1., 1.])
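
The two checks above verify the marginal constraints of the plan: each row of `pi_white_black` sums to `1/n_white` and each row of its transpose sums to `1/n_black`. The sketch below illustrates these constraints on a one-dimensional toy case with unequal group sizes, where an optimal plan for a convex cost can be built by hand with the north-west corner rule on sorted points. It is a didactic substitute for `ot.emd`, not a description of what POT does internally. The helper `monotone_plan` and the data are hypothetical.

```python
import numpy as np

def monotone_plan(a, b):
    """North-west corner rule: couples sorted 1D samples with weights a and b."""
    a = np.array(a, dtype=float)  # copies, so the inputs are left untouched
    b = np.array(b, dtype=float)
    plan = np.zeros((len(a), len(b)))
    i = j = 0
    while i < len(a) and j < len(b):
        mass = min(a[i], b[j])    # move as much mass as both sides allow
        plan[i, j] = mass
        a[i] -= mass
        b[j] -= mass
        if a[i] <= 1e-12:         # source point i has no mass left
            i += 1
        if b[j] <= 1e-12:         # target point j is full
            j += 1
    return plan

x = np.array([1.0, 2.0, 3.0])     # sorted source points
y = np.array([1.5, 2.5])          # sorted target points
a = np.full(len(x), 1 / len(x))   # uniform weights
b = np.full(len(y), 1 / len(y))

plan = monotone_plan(a, b)
print(plan.sum(axis=1) * len(x))  # row sums rescaled: all 1 up to rounding
print(plan.sum(axis=0) * len(y))  # column sums rescaled: all 1 up to rounding
```

The middle source point splits its mass between the two targets, which is what happens whenever the two groups do not have the same size.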
pl.figure(3)
pl.imshow(pi_white_black, interpolation='nearest')
pl.title('OT matrix pi_white_black')

pl.figure(4)
ot.plot.plot2D_samples_mat(x_white, x_black, pi_white_black, c=[.5, .5, 1])
pl.plot(x_white[:, 0], x_white[:, 1], '+b', label='Source samples')
pl.plot(x_black[:, 0], x_black[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('OT matrix with samples')
Figure 7.3: OT matrix pi_white_black
transformed_x_white = n_white*pi_white_black@x_black

transformed_x_white.shape
(18285, 2)
transformed_x_white
array([[ 2.7, 31. ],
       [ 2.7, 28. ],
       [ 2.6, 21. ],
       ...,
       [ 3.9, 28. ],
       [ 2.5, 22. ],
       [ 3. , 19. ]])
transformed_x_black = n_black*pi_black_white@x_white
transformed_x_black.shape
(1282, 2)
transformed_x_black
array([[ 3.2       , 37.58851518],
       [ 3.28565491, 28.02103363],
       [ 2.95793273, 32.14022423],
       ...,
       [ 3.28597758, 33.        ],
       [ 2.65092152, 41.43910309],
       [ 2.75152858, 36.        ]])
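
The products `n_white*pi_white_black@x_black` and `n_black*pi_black_white@x_white` are barycentric projections: each individual's counterfactual is the weighted average of the points of the other group that receive its mass. A minimal NumPy sketch with a hand-written toy plan (all values are hypothetical):

```python
import numpy as np

# Toy plan between 3 source points and 2 target points (uniform marginals)
plan = np.array([
    [1 / 3, 0.0],
    [1 / 6, 1 / 6],
    [0.0, 1 / 3],
])
x_target = np.array([[2.0, 20.0], [4.0, 40.0]])
n_source = plan.shape[0]

# Rescaling each row by n_source turns it into weights that sum to 1
weights = n_source * plan
projected = weights @ x_target
print(weights.sum(axis=1))
print(projected)
```

A source point whose mass goes to a single target is mapped exactly onto that target, whereas a point whose mass is split lands between targets. This is a plausible reading of why `transformed_x_white` mostly contains values that look like observed points, while `transformed_x_black` contains averages.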
counterfactual_x = x_S.drop(columns=['S'])
counterfactual_x[x_S['S'] == 'White'] = transformed_x_white
counterfactual_x[x_S['S'] == 'Black'] = transformed_x_black
counterfactual_x.head()
    X1    X2
0  2.7  31.0
1  2.7  28.0
2  2.6  21.0
3  3.1  28.0
4  3.2  21.0
counterfactual_x.shape
(19567, 2)
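
The masked assignment above fills the White rows with `transformed_x_white` and the Black rows with `transformed_x_black`, so it relies on each transformed array keeping the row order of its group. A NumPy sketch of the same pattern, with hypothetical toy values:

```python
import numpy as np

s = np.array(["White", "Black", "White", "Black"])
x = np.array([[3.1, 39.0], [2.7, 31.0], [3.0, 36.0], [2.6, 21.0]])

# Hypothetical transported values, one row per group member, in group order
moved_white = np.array([[2.7, 31.0], [2.6, 21.0]])
moved_black = np.array([[3.2, 37.5], [3.0, 36.0]])

counterfactual = np.empty_like(x)
counterfactual[s == "White"] = moved_white  # fills rows 0 and 2
counterfactual[s == "Black"] = moved_black  # fills rows 1 and 3
print(counterfactual)
```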

Lastly, we export the results in a CSV file:

csv_file_path = '../data/counterfactuals_ot.csv'
counterfactual_x.to_csv(csv_file_path, index=False)

Let us get back to R, and load the results.

counterfactuals_ot <- read_csv('../data/counterfactuals_ot.csv')

We add the sensitive attribute to the dataset (Black individuals become White, and conversely):

S_star <- df_race_c |> 
  mutate(
    S_star = case_when(
      S == "Black" ~ "White",
      S == "White" ~ "Black",
      TRUE ~ "Error"
    )
  ) |> 
  pull("S_star")

counterfactuals_ot <- counterfactuals_ot |> 
  mutate(S = S_star)

7.2.1 Unaware Model

Let us make predictions with the unaware model on the counterfactuals obtained with OT:

model_unaware <- pred_unaware$model
pred_unaware_ot <- predict(
  model_unaware, newdata = counterfactuals_ot, type = "response"
)
counterfactuals_unaware_ot <- counterfactuals_ot |> 
  mutate(pred = pred_unaware_ot, type = "counterfactual")

We gather the initial characteristics (factuals) and the predictions made by the unaware model in a table:

factuals_unaware <- tibble(
  S = df_race$S,
  X1 = df_race$X1,
  X2 = df_race$X2,
  pred = pred_unaware_all,
  type = "factual"
)

We bind the factuals and counterfactuals with their respective predicted values in a single dataset:

unaware_ot <- bind_rows(
  # predicted values on factuals
  factuals_unaware |> mutate(type = "factual"), 
  # predicted values on counterfactuals obtained with OT
  counterfactuals_unaware_ot
)

Then, we can visualize the distribution of the values predicted by the unaware model within each group defined by the sensitive attribute.

unaware_ot_white <- unaware_ot |>  filter(S == "White") 
unaware_ot_black <- unaware_ot |>  filter(S == "Black")
ggplot(
  data = unaware_ot_black, 
  mapping = aes(x = pred, fill = type)
) +
  geom_histogram(
    mapping = aes(y = after_stat(density)), 
    alpha = 0.5, position = "identity", binwidth = 0.05
  ) +
  geom_density(alpha = 0.5) +
  labs(
    title = "Unaware model, Sensitive: Race, Reference: Black individual",
       x = "Predictions for Y",
       y = "Density"
  ) +
  global_theme()
Figure 7.4: Unaware model, Sensitive: Race, Reference: Black individuals
ggplot(
  data = unaware_ot_white, 
  mapping = aes(x = pred, fill = type)) +
  geom_histogram(
    mapping = aes(y = after_stat(density)), 
    alpha = 0.5, position = "identity", binwidth = 0.05
  ) +
  geom_density(alpha = 0.5) +
  labs(
    title = "Unaware model, Sensitive: Race, Reference: White",
       x = "Predictions for Y",
       y = "Density") +
  global_theme()
Figure 7.5: Unaware model, Sensitive: Race, Reference: White individuals
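
In the histograms above, `after_stat(density)` rescales the bar heights so that the total bar area equals 1. This makes groups of different sizes comparable and puts the bars on the same scale as `geom_density()`. A NumPy sketch of the same normalization on simulated values (hypothetical data, not the model predictions; bin placement may differ from that of ggplot2):

```python
import numpy as np

rng = np.random.default_rng(2025)
pred = rng.beta(2, 5, size=1000)       # simulated predictions in [0, 1]

edges = np.linspace(0.0, 1.0, 21)      # 20 bins of width 0.05
density, _ = np.histogram(pred, bins=edges, density=True)

# Bar area = height * width; the areas add up to 1
total_area = np.sum(density * np.diff(edges))
print(total_area)
```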

7.2.2 Aware Model

Let us make predictions with the aware model on the counterfactuals obtained with OT:

model_aware <- pred_aware$model
pred_aware_ot <- predict(
  model_aware, newdata = counterfactuals_ot, type = "response"
)
counterfactuals_aware_ot <- counterfactuals_ot |>  
  mutate(pred = pred_aware_ot, type = "counterfactual")
counterfactuals_aware_ot
# A tibble: 19,567 × 5
      X1    X2 S       pred type          
   <dbl> <dbl> <chr>  <dbl> <chr>         
 1   2.7  31   Black 0.141  counterfactual
 2   2.7  28   Black 0.120  counterfactual
 3   2.6  21   Black 0.0791 counterfactual
 4   3.1  28   Black 0.143  counterfactual
 5   3.2  21   Black 0.104  counterfactual
 6   3.3  27.5 Black 0.152  counterfactual
 7   2.4  29   Black 0.111  counterfactual
 8   2.3  29   Black 0.106  counterfactual
 9   3.3  22   Black 0.115  counterfactual
10   2.8  34   Black 0.171  counterfactual
# ℹ 19,557 more rows

We gather the initial characteristics (factuals) and the predictions made by the aware model in a table:

factuals_aware <- tibble(
  S = df_race$S,
  X1 = df_race$X1,
  X2 = df_race$X2,
  pred = pred_aware_all,
  type = "factual"
)

We bind the factuals and counterfactuals with their respective predicted values in a single dataset:

aware_ot <- bind_rows(
  factuals_aware, 
  counterfactuals_aware_ot
)

Then, we can visualize the distribution of the values predicted by the aware model within each group defined by the sensitive attribute.

aware_ot_white <- aware_ot |> filter(S == "White") 
aware_ot_black <- aware_ot |>  filter(S == "Black")
ggplot(
  data = aware_ot_black, 
  mapping = aes(x = pred, fill = type)
) +
  geom_histogram(
    mapping = aes(y = after_stat(density)), 
    alpha = 0.5, position = "identity", binwidth = 0.05
  ) +
  geom_density(alpha = 0.5) +
  labs(
    title = "Aware model, Sensitive: Race, Reference: Black individual",
       x = "Predictions for Y",
       y = "Density"
  ) +
  global_theme()
Figure 7.6: Aware model, Sensitive: Race, Reference: Black individuals
ggplot(
  data = aware_ot_white, 
  mapping = aes(x = pred, fill = type)) +
  geom_histogram(
    mapping = aes(y = after_stat(density)), 
    alpha = 0.5, position = "identity", binwidth = 0.05
  ) +
  geom_density(alpha = 0.5) +
  labs(
    title = "Aware model, Sensitive: Race, Reference: White",
       x = "Predictions for Y",
       y = "Density") +
  global_theme()
Figure 7.7: Aware model, Sensitive: Race, Reference: White individuals

7.3 Comparison for Two Individuals

Let us, again, focus on two individuals: 24 (Black) and 25 (White):

(indiv_factuals_unaware <- factuals_unaware[24:25, ])
# A tibble: 2 × 5
  S        X1    X2  pred type   
  <fct> <dbl> <dbl> <dbl> <chr>  
1 Black   2.8    29 0.300 factual
2 White   2.8    34 0.382 factual

The counterfactuals for those individuals, using the unaware model:

(indiv_counterfactuals_unaware_ot <- counterfactuals_unaware_ot[24:25, ])
# A tibble: 2 × 5
     X1    X2 S      pred type          
  <dbl> <dbl> <chr> <dbl> <chr>         
1  3.20  37.6 White 0.502 counterfactual
2  2.4   25   Black 0.203 counterfactual

Let us put the factuals and counterfactuals in a single table:

indiv_unaware_ot <- bind_rows(
  indiv_factuals_unaware |> mutate(id = c(24, 25)),
  indiv_counterfactuals_unaware_ot |> mutate(id = c(24, 25))
)
indiv_unaware_ot
# A tibble: 4 × 6
  S        X1    X2  pred type              id
  <chr> <dbl> <dbl> <dbl> <chr>          <dbl>
1 Black  2.8   29   0.300 factual           24
2 White  2.8   34   0.382 factual           25
3 White  3.20  37.6 0.502 counterfactual    24
4 Black  2.4   25   0.203 counterfactual    25

We compute the difference between the unaware model's prediction on the counterfactual and its prediction on the factual:

indiv_unaware_ot |> select(id , type, pred) |> 
  pivot_wider(names_from = type, values_from = pred) |> 
  mutate(diff = counterfactual - factual)
# A tibble: 2 × 4
     id factual counterfactual   diff
  <dbl>   <dbl>          <dbl>  <dbl>
1    24   0.300          0.502  0.202
2    25   0.382          0.203 -0.179

We do the same for the aware model:

indiv_aware_ot <- bind_rows(
  factuals_aware[c(24, 25),] |> mutate(id = c(24, 25)),
  counterfactuals_aware_ot[c(24, 25),] |> mutate(id = c(24, 25))
)
indiv_aware_ot
# A tibble: 4 × 6
  S        X1    X2   pred type              id
  <chr> <dbl> <dbl>  <dbl> <chr>          <dbl>
1 Black  2.8   29   0.133  factual           24
2 White  2.8   34   0.413  factual           25
3 White  3.20  37.6 0.515  counterfactual    24
4 Black  2.4   25   0.0898 counterfactual    25

The difference between the counterfactual and the factual for these two individuals, when using the aware model:

indiv_aware_ot |> select(id , type, pred) |> 
  pivot_wider(names_from = type, values_from = pred) |> 
  mutate(diff = counterfactual - factual)
# A tibble: 2 × 4
     id factual counterfactual   diff
  <dbl>   <dbl>          <dbl>  <dbl>
1    24   0.133         0.515   0.383
2    25   0.413         0.0898 -0.323
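
As a quick check of the arithmetic, the differences can be recomputed from the rounded predictions displayed in the tables above. This is only a sketch in Python, and the dictionary layout is our own choice.

```python
# Predictions copied from the tables above (rounded as displayed):
# id -> (factual, counterfactual)
preds = {
    "unaware": {24: (0.300, 0.502), 25: (0.382, 0.203)},
    "aware": {24: (0.133, 0.515), 25: (0.413, 0.0898)},
}

# Counterfactual minus factual, rounded to three decimals
diffs = {
    model: {i: round(cf - f, 3) for i, (f, cf) in rows.items()}
    for model, rows in preds.items()
}
print(diffs)
```

For individual 24 under the aware model, the rounded inputs give 0.382 whereas the table reports 0.383. The gap presumably comes from rounding in the displayed values, since the R code works with unrounded predictions.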

7.4 Counterfactual Demographic Parity

As with the counterfactuals obtained with fairadapt, we assume here that the reference group is “White individuals” (i.e., the group with the most individuals in the dataset). We focus on the minority, i.e., Black individuals. We consider here that the model is fair towards the minority class if:
\[ P(\hat{Y}_{S \leftarrow \text{White}} = 1 | S = \text{Black}, X_1, X_2) = P(\hat{Y} = 1 | S = \text{White}, X_1, X_2) \]
If the model is fair with respect to this criterion, the proportion of Black individuals predicted to have grades above the median should be the same as if they had been White.

For predictions made with the unaware model:

dp_unaware_ot <- mean(
  counterfactuals_unaware_ot |> filter(S == "White") |> pull("pred") - 
    factuals_unaware |> filter(S == "Black") |> pull("pred")
)
dp_unaware_ot
[1] 0.1821212

We do the same with the aware model:

dp_aware_ot <- mean(
  counterfactuals_aware_ot |> filter(S == "White") |> pull("pred") - 
    factuals_aware |> filter(S == "Black") |> pull("pred")
)
dp_aware_ot
[1] 0.3726591
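
Both quantities above are averages of paired differences: for each Black individual, the prediction on the counterfactual (treated as White) minus the prediction on the factual. The pairing relies on the filtered rows being in the same order in both tables. A NumPy sketch with hypothetical predictions follows; the function name `counterfactual_gap` is ours.

```python
import numpy as np

def counterfactual_gap(pred_factual, pred_counterfactual):
    """Mean paired difference; both inputs must list the same individuals in the same order."""
    pred_factual = np.asarray(pred_factual, dtype=float)
    pred_counterfactual = np.asarray(pred_counterfactual, dtype=float)
    if pred_factual.shape != pred_counterfactual.shape:
        raise ValueError("factual and counterfactual predictions must be paired")
    return float(np.mean(pred_counterfactual - pred_factual))

# Hypothetical predictions for three Black individuals
factual = [0.30, 0.10, 0.20]
counterfactual = [0.50, 0.40, 0.30]

gap = counterfactual_gap(factual, counterfactual)
print(round(gap, 3))
```

A value of 0 would mean no average change in prediction. The positive values obtained above (about 0.18 and 0.37) indicate that predictions for Black individuals increase on average once they are transported to the White group, more so with the aware model.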

7.5 Saving Objects

save(
  counterfactuals_unaware_ot,
  file = "../data/counterfactuals_unaware_ot.rda"
)
save(
  counterfactuals_aware_ot, 
  file = "../data/counterfactuals_aware_ot.rda"
)