10  Multivariate Optimal Transport

Objectives

This chapter uses multivariate optimal transport (De Lara et al. (2024)) to perform counterfactual inference. We obtain counterfactual values \(\boldsymbol{x}^\star\) for individuals from the protected group \(S=0\). Then, we use the aware and unaware classifiers \(m(\cdot)\) (see Chapter 7) to make new predictions \(m(s=1,\boldsymbol{x}^\star)\) for individuals in the protected group, i.e., observations in \(\mathcal{D}_0\).

In the article, we compare four approaches to create counterfactuals:

  1. Naive approach (Chapter 8)
  2. Fairadapt (Chapter 9)
  3. Multivariate optimal transport (this chapter)
  4. Sequential transport (the methodology we develop in the paper, see Chapter 11).
Required packages and definition of colours.
# Required packages----
library(tidyverse)
library(fairadapt)

# Graphs----
# Typography shared by all the figures of the chapter.
font_main <- font_title <- "Times New Roman"
extrafont::loadfonts(quiet = TRUE)
face_text <- "plain"
face_title <- "plain"
size_title <- 14
size_text <- 11
legend_size <- 11

#' ggplot2 theme used for every figure of this chapter.
#'
#' Starts from `theme_minimal()` and replaces (with `%+replace%`) the text,
#' legend text, axis text, title and subtitle elements using the typography
#' settings defined above. Titles are centred and the legend is placed at the
#' bottom.
#'
#' @return A ggplot2 theme object.
global_theme <- function() {
  theme_minimal() %+replace%
    theme(
      text = element_text(
        family = font_main, size = size_text, face = face_text
      ),
      legend.text = element_text(family = font_main, size = legend_size),
      axis.text = element_text(size = size_text, face = face_text),
      plot.title = element_text(
        family = font_title,
        size = size_title,
        # Apply the title face defined above (it was otherwise unused).
        face = face_title,
        hjust = 0.5
      ),
      plot.subtitle = element_text(hjust = 0.5),
      legend.position = "bottom"
    )
}

# Seed
set.seed(2025)

# Load the companion package from its local source directory
library(devtools)
load_all("../seqtransfairness/")

# Colour associated with each type of observation in the figures
colours_all <- c(
  "source" = "#00A08A",
  "reference" = "#F2AD00",
  "naive" = "gray",
  "fairadapt" = "#D55E00",
  "ot" = "#56B4E9"
)

10.1 Load Data and Classifier

We load the dataset where the sensitive attribute (\(S\)) is race, obtained in Chapter 6.3:

# Dataset with race as the sensitive attribute
load(file = "../data/df_race.rda")

We also load the dataset where the sensitive attribute is also race, but where the target variable (\(Y\), ZFYA) is binary (1 if the student obtained a standardized first-year average above the median, 0 otherwise). This dataset was saved in Chapter 7.5:

# Same data, with a binary target variable
load(file = "../data/df_race_c.rda")

We also need the predictions made by the classifier (see Chapter 7):

# Outputs of the aware and unaware classifiers on the train and test sets
load(file = "../data/pred_aware.rda")
load(file = "../data/pred_unaware.rda")
# Predictions made on the factuals of the whole dataset
load(file = "../data/pred_aware_all.rda")
load(file = "../data/pred_unaware_all.rda")

10.2 Counterfactuals with Multivariate Optimal Transport

We apply multivariate optimal transport (OT), following the methodology developed in De Lara et al. (2024). Note that with OT, it is not possible to handle new cases. Counterfactuals will only be calculated on the train set.

The code is run in Python. We use the {reticulate} R package to call Python from this notebook.

library(reticulate)
# Python virtual environment used to run the OT code (adjust the path to the
# local setup if needed)
reticulate::use_virtualenv(virtualenv = "~/quarto-python-env", required = TRUE)
# One-off installation steps, if required:
# reticulate::install_miniconda(force = TRUE)
# py_install("POT")

Some libraries need to be loaded (including POT, imported as `ot`):

import ot
import pandas as pd
import numpy as np
import matplotlib.pyplot as pl
import ot.plot

The data with the factuals need to be loaded:

# Load the factual observations
df_aware = pd.read_csv("../data/factuals_aware.csv")
df_unaware = pd.read_csv("../data/factuals_unaware.csv")
# Keep only the sensitive attribute and the covariates
x_S = df_aware.drop(["pred", "type", "id_indiv"], axis=1)
x_S.head()
       S   X1    X2
0  White  3.1  39.0
1  White  3.0  36.0
2  White  3.1  30.0
3  White  3.4  37.0
4  White  3.6  30.5
# Covariates of each group, without the sensitive attribute
x_white = x_S.loc[x_S['S'] == 'White'].drop(columns='S')
x_black = x_S.loc[x_S['S'] == 'Black'].drop(columns='S')

# Group sizes and uniform weights (mass 1/n on every observation)
n_white, n_black = len(x_white), len(x_black)
w_white = np.full(n_white, 1 / n_white)
w_black = np.full(n_black, 1 / n_black)

Cost matrix between both distributions:

# Work with plain NumPy arrays from now on
x_white, x_black = x_white.to_numpy(), x_black.to_numpy()
# Pairwise cost between every White and every Black individual
# (ot.dist uses the squared Euclidean metric by default)
C = ot.dist(x_white, x_black)
pl.figure(1)
# Each array has two columns (X1, X2): unpack them as x and y coordinates
pl.plot(*x_white.T, "+b", label="Source samples")
pl.plot(*x_black.T, "xr", label="Target samples")
pl.legend(loc=0)
pl.title("Source and target distributions")
Figure 10.1: Source and target distributions
# Heat map of the cost matrix (rows: White, columns: Black)
pl.figure(num=2)
pl.imshow(C, interpolation="nearest")
pl.title(label="Cost matrix C")
Figure 10.2: Cost matrix C

The transport plan: white –> black

# Exact optimal transport plan between the two empirical distributions.
# POT documents `numItermax` as an int, so pass the large bound as an int
# instead of the float 1e8.
pi_white_black = ot.emd(w_white, w_black, C, numItermax=int(1e8))
# Transposing the White -> Black plan gives the Black -> White plan
pi_black_white = pi_white_black.T
pi_white_black.shape
(18285, 1282)
# Marginal check: the rows of pi_white_black, multiplied by n_white, should
# all sum to 1 (each White individual sends its whole mass 1/n_white).
sum_of_rows = np.sum(pi_white_black, axis=1)
sum_of_rows*n_white
array([1., 1., 1., ..., 1., 1., 1.], shape=(18285,))
# Same check for the transposed (Black -> White) plan
pi_black_white.shape
(1282, 18285)
sum_of_rows = np.sum(pi_black_white, axis=1)
sum_of_rows*n_black
array([1., 1., 1., ..., 1., 1., 1.], shape=(1282,))
# Heat map of the transport plan (rows: White, columns: Black)
pl.figure(3)
pl.imshow(pi_white_black, interpolation="nearest")
pl.title("OT matrix pi_white_black")

# Both samples, linked according to the transport plan
pl.figure(4)
ot.plot.plot2D_samples_mat(x_white, x_black, pi_white_black, c=[0.5, 0.5, 1])
pl.plot(*x_white.T, "+b", label="Source samples")
pl.plot(*x_black.T, "xr", label="Target samples")
pl.legend(loc=0)
pl.title("OT matrix with samples")
Figure 10.3: OT matrix pi_white_black
# Barycentric projection: each row of n_white * pi_white_black sums to 1 (see
# the marginal check on pi_white_black), so every White individual is mapped
# to a weighted average of the Black individuals' covariates.
transformed_x_white = n_white*pi_white_black@x_black

transformed_x_white.shape
(18285, 2)
transformed_x_white
array([[ 2.7, 31. ],
       [ 2.7, 28. ],
       [ 2.6, 21. ],
       ...,
       [ 3.9, 28. ],
       [ 2.5, 22. ],
       [ 3. , 19. ]], shape=(18285, 2))
# Same projection in the other direction: every Black individual is mapped to
# a weighted average of the White individuals' covariates.
transformed_x_black = n_black*pi_black_white@x_white
transformed_x_black.shape
(1282, 2)
transformed_x_black
array([[ 3.2       , 37.58851518],
       [ 3.28565491, 28.02103363],
       [ 2.95793273, 32.14022423],
       ...,
       [ 3.28597758, 33.        ],
       [ 2.65092152, 41.43910309],
       [ 2.75152858, 36.        ]], shape=(1282, 2))
# Rebuild the covariates in the original row order: each row receives the
# transported values computed for its own group. Any other label in S (or a
# missing one) would leave factual values in place and export them as
# counterfactuals, so fail loudly instead.
assert x_S['S'].isin(['White', 'Black']).all(), "unexpected values in column S"
counterfactual_x = x_S.drop(columns=['S'])
counterfactual_x.loc[x_S['S'] == 'White', :] = transformed_x_white
counterfactual_x.loc[x_S['S'] == 'Black', :] = transformed_x_black
counterfactual_x.head()
    X1    X2
0  2.7  31.0
1  2.7  28.0
2  2.6  21.0
3  3.1  28.0
4  3.2  21.0
# One row per individual of the whole dataset (18285 White + 1282 Black)
counterfactual_x.shape
(19567, 2)
# Displays transformed_x_white again (same values as printed earlier)
transformed_x_white
array([[ 2.7, 31. ],
       [ 2.7, 28. ],
       [ 2.6, 21. ],
       ...,
       [ 3.9, 28. ],
       [ 2.5, 22. ],
       [ 3. , 19. ]], shape=(18285, 2))

Lastly, we export the results in a CSV file:

# Export the counterfactual covariates (row order preserved, no index column)
csv_file_path = "../data/counterfactuals_ot.csv"
counterfactual_x.to_csv(path_or_buf=csv_file_path, index=False)

Let us get back to R, and load the results.

# OT counterfactuals computed in Python; id_indiv is the row position, which
# matches the row order of the exported file.
counterfactuals_ot <- read_csv("../data/counterfactuals_ot.csv") |>
  mutate(id_indiv = seq_len(n()))

We add the sensitive attribute to the dataset (Black individuals become White, and conversely):

# Flipped sensitive attribute: Black <-> White. Any other value (including NA)
# is flagged with "Error".
S_star <- with(
  df_race_c,
  case_when(
    S == "Black" ~ "White",
    S == "White" ~ "Black",
    TRUE ~ "Error"
  )
)

# Keep the observed race in S_origin; S holds the counterfactual race.
# Both columns are aligned on row position with df_race_c.
counterfactuals_ot$S_origin <- df_race_c$S
counterfactuals_ot$S <- S_star
#' Dataset on which predictions are made when `source_group` is transported.
#'
#' Rows of the source group carry their OT counterfactual covariates, with the
#' flipped label in `S` and the observed label in `S_origin`. Rows of the
#' reference group are the observed factuals of `data`, without `Y`. Rows are
#' ordered by `id_indiv`, i.e., by their position in `data`.
#'
#' @param source_group Observed race of the transported individuals.
#' @param reference_group Observed race of the individuals kept as factuals.
#' @param counterfactuals OT counterfactuals (with `S_origin` and `id_indiv`).
#' @param data Factual dataset.
#' @return A tibble with one row per individual of `data`.
build_ot_counterfactuals <- function(source_group,
                                     reference_group,
                                     counterfactuals = counterfactuals_ot,
                                     data = df_race_c) {
  counterfactuals |>
    filter(S_origin == .env$source_group) |>
    bind_rows(
      data |>
        select(-Y) |>
        mutate(id_indiv = row_number(), S_origin = S) |>
        filter(S == .env$reference_group)
    ) |>
    arrange(id_indiv)
}

counterfactuals_ot_black <- build_ot_counterfactuals("Black", "White")
counterfactuals_ot_white <- build_ot_counterfactuals("White", "Black")

We consider Black individuals (minority group) to be the source group. Let us make predictions with the unaware model, then with the aware model, on the counterfactuals obtained with OT.

# Unaware classifier (see Chapter 7) applied to the data where Black
# individuals are transported. NOTE: every row, including the factual White
# rows, is labelled "counterfactual".
model_unaware <- pred_unaware$model
pred_unaware_ot_black <- predict(
  object = model_unaware,
  newdata = counterfactuals_ot_black,
  type = "response"
)
counterfactuals_unaware_ot_black <- mutate(
  counterfactuals_ot_black,
  pred = pred_unaware_ot_black,
  type = "counterfactual"
)

If, instead, the source group is White:

# Unaware classifier applied to the data where White individuals are
# transported
pred_unaware_ot_white <- predict(
  object = model_unaware,
  newdata = counterfactuals_ot_white,
  type = "response"
)
counterfactuals_unaware_ot_white <- mutate(
  counterfactuals_ot_white,
  pred = pred_unaware_ot_white,
  type = "counterfactual"
)

With the aware model, if Black is the source group:

# Aware classifier (see Chapter 7) applied to the data where Black individuals
# are transported
model_aware <- pred_aware$model
pred_aware_ot_black <- predict(
  object = model_aware,
  newdata = counterfactuals_ot_black,
  type = "response"
)
counterfactuals_aware_ot_black <- mutate(
  counterfactuals_ot_black,
  pred = pred_aware_ot_black,
  type = "counterfactual"
)

If, instead, the source group is White:

# Aware classifier applied to the data where White individuals are transported
pred_aware_ot_white <- predict(
  object = model_aware,
  newdata = counterfactuals_ot_white,
  type = "response"
)
counterfactuals_aware_ot_white <- mutate(
  counterfactuals_ot_white,
  pred = pred_aware_ot_white,
  type = "counterfactual"
)

Next, we can compare predicted scores for both types of models on the source group.

10.2.1 Unaware Model

The predicted values using the initial characteristics (the factuals), for the unaware model are stored in the object pred_unaware_all. We put in a table the initial characteristics (factuals) and the prediction made by the unaware model:

# Factuals with the predictions of the unaware model; for factuals the
# counterfactual label S coincides with the observed one (S_origin).
factuals_unaware <- with(
  df_race_c,
  tibble(
    S = S, S_origin = S, X1 = X1, X2 = X2, Y = Y,
    pred = pred_unaware_all, type = "factual"
  )
) |>
  mutate(id_indiv = row_number())

We bind together the predictions made with the observed values and those made with the counterfactual values.

# Factuals stacked on top of the OT counterfactuals, for each source group
unaware_ot_black <- bind_rows(
  mutate(factuals_unaware, S_origin = S),
  counterfactuals_unaware_ot_black
)

unaware_ot_white <- bind_rows(
  mutate(factuals_unaware, S_origin = S),
  counterfactuals_unaware_ot_white
)

Now, we can visualize the distribution of the values predicted by the unaware model within each group defined by the sensitive attribute.

Codes used to create the Figure.
local({
  # Colours of the groups; the names also set the order of the factor levels.
  palette <- c(
    "Black (Original)" = colours_all[["source"]],
    "Black -> White (Counterfactual)" = colours_all[["ot"]],
    "White (Original)" = colours_all[["reference"]]
  )
  plot_data <- mutate(
    unaware_ot_black,
    group = case_when(
      S_origin == "Black" & S == "Black" ~ "Black (Original)",
      S_origin == "Black" & S == "White" ~ "Black -> White (Counterfactual)",
      S_origin == "White" & S == "White" ~ "White (Original)"
    ),
    group = factor(group, levels = names(palette))
  )
  ggplot(plot_data, aes(x = pred, fill = group, colour = group)) +
    geom_histogram(
      aes(y = after_stat(density)),
      alpha = 0.5, colour = NA, position = "identity", binwidth = 0.05
    ) +
    geom_density(alpha = 0.3, linewidth = 1) +
    facet_wrap(~S) +
    scale_fill_manual(NULL, values = palette) +
    scale_colour_manual(NULL, values = palette) +
    labs(
      title = "Unaware model, Sensitive: Race, Reference: White individuals",
      x = "Predictions for Y",
      y = "Density"
    ) +
    global_theme() +
    theme(legend.position = "bottom")
})
Figure 10.4: Unaware model, Sensitive: Race, Reference: White individuals
Codes used to create the Figure.
local({
  # Colours of the groups; the names also set the order of the factor levels.
  palette <- c(
    "White (Original)" = colours_all[["source"]],
    "White -> Black (Counterfactual)" = colours_all[["ot"]],
    "Black (Original)" = colours_all[["reference"]]
  )
  plot_data <- mutate(
    unaware_ot_white,
    group = case_when(
      S_origin == "White" & S == "White" ~ "White (Original)",
      S_origin == "White" & S == "Black" ~ "White -> Black (Counterfactual)",
      S_origin == "Black" & S == "Black" ~ "Black (Original)"
    ),
    group = factor(group, levels = names(palette))
  )
  ggplot(plot_data, aes(x = pred, fill = group, colour = group)) +
    geom_histogram(
      aes(y = after_stat(density)),
      alpha = 0.5, colour = NA, position = "identity", binwidth = 0.05
    ) +
    geom_density(alpha = 0.3, linewidth = 1) +
    facet_wrap(~S) +
    scale_fill_manual(NULL, values = palette) +
    scale_colour_manual(NULL, values = palette) +
    labs(
      title = "Unaware model, Sensitive: Race, Reference: Black individuals",
      x = "Predictions for Y",
      y = "Density"
    ) +
    global_theme() +
    theme(legend.position = "bottom")
})
Figure 10.5: Unaware model, Sensitive: Race, Reference: Black individuals

Then, we focus on the distribution of predicted scores for the factuals and the counterfactuals of Black students (first figure), and of White students (second figure).

Codes used to create the Figure.
local({
  # Colours of the groups; the names also set the order of the factor levels.
  palette <- c(
    "Black (Original)" = colours_all[["source"]],
    "Black -> White (Counterfactual)" = colours_all[["ot"]],
    "White (Original)" = colours_all[["reference"]]
  )
  plot_data <- unaware_ot_black |>
    mutate(
      group = case_when(
        S_origin == "Black" & S == "Black" ~ "Black (Original)",
        S_origin == "Black" & S == "White" ~ "Black -> White (Counterfactual)",
        S_origin == "White" & S == "White" ~ "White (Original)"
      ),
      group = factor(group, levels = names(palette))
    ) |>
    # Keep Black students only: their factuals and their OT counterfactuals
    filter(S_origin == "Black")
  ggplot(plot_data, mapping = aes(x = pred, fill = group, colour = group)) +
    geom_histogram(
      mapping = aes(y = after_stat(density)),
      alpha = 0.5, position = "identity", binwidth = 0.05
    ) +
    geom_density(alpha = 0.5, linewidth = 1) +
    scale_fill_manual(NULL, values = palette) +
    scale_colour_manual(NULL, values = palette) +
    labs(
      title = "Predicted Scores for Minority Class\n Unaware model, Sensitive: Race, Reference: White individuals",
      x = "Predictions for Y",
      y = "Density"
    ) +
    global_theme()
})
Figure 10.6: Distribution of Predicted Scores for Minority Class (Black), Unaware model, Sensitive: Race, Reference: White individuals
Codes used to create the Figure.
local({
  # Order of the factor levels for the three possible groups
  group_levels <- c(
    "White (Original)", "White -> Black (Counterfactual)", "Black (Original)"
  )
  # Colours of the two groups left after filtering on White students
  palette <- c(
    "White (Original)" = colours_all[["source"]],
    "White -> Black (Counterfactual)" = colours_all[["ot"]]
  )
  plot_data <- unaware_ot_white |>
    mutate(
      group = case_when(
        S_origin == "White" & S == "White" ~ "White (Original)",
        S_origin == "White" & S == "Black" ~ "White -> Black (Counterfactual)",
        S_origin == "Black" & S == "Black" ~ "Black (Original)"
      ),
      group = factor(group, levels = group_levels)
    ) |>
    # Keep White students only: their factuals and their OT counterfactuals
    filter(S_origin == "White")
  ggplot(plot_data, mapping = aes(x = pred, fill = group, colour = group)) +
    geom_histogram(
      mapping = aes(y = after_stat(density)),
      alpha = 0.5, position = "identity", binwidth = 0.05
    ) +
    geom_density(alpha = 0.5, linewidth = 1) +
    scale_fill_manual(NULL, values = palette) +
    scale_colour_manual(NULL, values = palette) +
    labs(
      # White students form the majority group (Black is the minority one)
      title = "Predicted Scores for Majority Class\n Unaware model, Sensitive: Race, Reference: Black individuals",
      x = "Predictions for Y",
      y = "Density"
    ) +
    global_theme()
})
Figure 10.7: Distribution of Predicted Scores for Majority Class (White), Unaware model, Sensitive: Race, Reference: Black individuals

10.2.2 Aware Model

# Factuals with the predictions of the aware model; for factuals the
# counterfactual label S coincides with the observed one (S_origin).
factuals_aware <- with(
  df_race_c,
  tibble(
    S = S, S_origin = S, X1 = X1, X2 = X2, Y = Y,
    pred = pred_aware_all, type = "factual"
  )
) |>
  mutate(id_indiv = row_number())

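For each source group, we merge the factuals (factuals_aware) and the corresponding counterfactuals (counterfactuals_aware_ot_black or counterfactuals_aware_ot_white) in a single dataset.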

# Factuals stacked on top of the OT counterfactuals, for the aware model
aware_ot_black <- bind_rows(
  mutate(factuals_aware, S_origin = S),
  counterfactuals_aware_ot_black
)

aware_ot_white <- bind_rows(
  mutate(factuals_aware, S_origin = S),
  counterfactuals_aware_ot_white
)

Now, we can visualize the distribution of the values predicted by the aware model within each group defined by the sensitive attribute.

Codes used to create the Figure.
local({
  # Colours of the groups; the names also set the order of the factor levels.
  palette <- c(
    "Black (Original)" = colours_all[["source"]],
    "Black -> White (Counterfactual)" = colours_all[["ot"]],
    "White (Original)" = colours_all[["reference"]]
  )
  plot_data <- mutate(
    aware_ot_black,
    group = case_when(
      S_origin == "Black" & S == "Black" ~ "Black (Original)",
      S_origin == "Black" & S == "White" ~ "Black -> White (Counterfactual)",
      S_origin == "White" & S == "White" ~ "White (Original)"
    ),
    group = factor(group, levels = names(palette))
  )
  ggplot(plot_data, aes(x = pred, fill = group, colour = group)) +
    geom_histogram(
      aes(y = after_stat(density)),
      alpha = 0.5, colour = NA, position = "identity", binwidth = 0.05
    ) +
    geom_density(alpha = 0.3, linewidth = 1) +
    facet_wrap(~S) +
    scale_fill_manual(NULL, values = palette) +
    scale_colour_manual(NULL, values = palette) +
    labs(
      title = "Aware model, Sensitive: Race, Reference: White individuals",
      x = "Predictions for Y",
      y = "Density"
    ) +
    global_theme() +
    theme(legend.position = "bottom")
})
Figure 10.8: Aware model, Sensitive: Race, Reference: White individuals
Codes used to create the Figure.
local({
  # Colours of the groups; the names also set the order of the factor levels.
  palette <- c(
    "White (Original)" = colours_all[["source"]],
    "White -> Black (Counterfactual)" = colours_all[["ot"]],
    "Black (Original)" = colours_all[["reference"]]
  )
  plot_data <- mutate(
    aware_ot_white,
    group = case_when(
      S_origin == "White" & S == "White" ~ "White (Original)",
      S_origin == "White" & S == "Black" ~ "White -> Black (Counterfactual)",
      S_origin == "Black" & S == "Black" ~ "Black (Original)"
    ),
    group = factor(group, levels = names(palette))
  )
  ggplot(plot_data, aes(x = pred, fill = group, colour = group)) +
    geom_histogram(
      aes(y = after_stat(density)),
      alpha = 0.5, colour = NA, position = "identity", binwidth = 0.05
    ) +
    geom_density(alpha = 0.3, linewidth = 1) +
    facet_wrap(~S) +
    scale_fill_manual(NULL, values = palette) +
    scale_colour_manual(NULL, values = palette) +
    labs(
      title = "Aware model, Sensitive: Race, Reference: Black individuals",
      x = "Predictions for Y",
      y = "Density"
    ) +
    global_theme() +
    theme(legend.position = "bottom")
})
Figure 10.9: Aware model, Sensitive: Race, Reference: Black individuals

Then, we focus on the distribution of predicted scores for the factuals and the counterfactuals of Black students (first figure), and of White students (second figure).

Codes used to create the Figure.
local({
  # Colours of the groups; the names also set the order of the factor levels.
  palette <- c(
    "Black (Original)" = colours_all[["source"]],
    "Black -> White (Counterfactual)" = colours_all[["ot"]],
    "White (Original)" = colours_all[["reference"]]
  )
  plot_data <- aware_ot_black |>
    mutate(
      group = case_when(
        S_origin == "Black" & S == "Black" ~ "Black (Original)",
        S_origin == "Black" & S == "White" ~ "Black -> White (Counterfactual)",
        S_origin == "White" & S == "White" ~ "White (Original)"
      ),
      group = factor(group, levels = names(palette))
    ) |>
    # Keep Black students only: their factuals and their OT counterfactuals
    filter(S_origin == "Black")
  ggplot(plot_data, mapping = aes(x = pred, fill = group, colour = group)) +
    geom_histogram(
      mapping = aes(y = after_stat(density)),
      alpha = 0.5, position = "identity", binwidth = 0.05
    ) +
    geom_density(alpha = 0.5, linewidth = 1) +
    scale_fill_manual(NULL, values = palette) +
    scale_colour_manual(NULL, values = palette) +
    labs(
      # Predictions come from the aware model (aware_ot_black)
      title = "Predicted Scores for Minority Class\n Aware model, Sensitive: Race, Reference: White individuals",
      x = "Predictions for Y",
      y = "Density"
    ) +
    global_theme()
})
Figure 10.10: Distribution of Predicted Scores for Minority Class (Black), Aware model, Sensitive: Race, Reference: White individuals
Codes used to create the Figure.
local({
  # Order of the factor levels for the three possible groups
  group_levels <- c(
    "White (Original)", "White -> Black (Counterfactual)", "Black (Original)"
  )
  # Colours of the two groups left after filtering on White students
  palette <- c(
    "White (Original)" = colours_all[["source"]],
    "White -> Black (Counterfactual)" = colours_all[["ot"]]
  )
  plot_data <- aware_ot_white |>
    mutate(
      group = case_when(
        S_origin == "White" & S == "White" ~ "White (Original)",
        S_origin == "White" & S == "Black" ~ "White -> Black (Counterfactual)",
        S_origin == "Black" & S == "Black" ~ "Black (Original)"
      ),
      group = factor(group, levels = group_levels)
    ) |>
    # Keep White students only: their factuals and their OT counterfactuals
    filter(S_origin == "White")
  ggplot(plot_data, mapping = aes(x = pred, fill = group, colour = group)) +
    geom_histogram(
      mapping = aes(y = after_stat(density)),
      alpha = 0.5, position = "identity", binwidth = 0.05
    ) +
    geom_density(alpha = 0.5, linewidth = 1) +
    scale_fill_manual(NULL, values = palette) +
    scale_colour_manual(NULL, values = palette) +
    labs(
      # Aware model (aware_ot_white); White students form the majority group
      title = "Predicted Scores for Majority Class\n Aware model, Sensitive: Race, Reference: Black individuals",
      x = "Predictions for Y",
      y = "Density"
    ) +
    global_theme()
})
Figure 10.11: Distribution of Predicted Scores for Majority Class (White), Aware model, Sensitive: Race, Reference: Black individuals

10.3 Saving Objects

# Export the counterfactual datasets obtained with Black as the source group
save(
  list = "counterfactuals_unaware_ot_black",
  file = "../data/counterfactuals_unaware_ot_black.rda"
)
save(
  list = "counterfactuals_aware_ot_black",
  file = "../data/counterfactuals_aware_ot_black.rda"
)