This chapter uses multivariate optimal transport (De Lara et al. 2024) to perform counterfactual inference.
In the article, we use three methods to create counterfactuals:
Fairadapt (Chapter 6)
Multivariate optimal transport (this chapter)
Sequential transport (the methodology we develop in the paper, see Chapter 8).
Display the setting codes
# Required packages----
library (tidyverse)
library (fairadapt)
# Graphs----
font_main = font_title = 'Times New Roman'
extrafont::loadfonts(quiet = TRUE)
face_text= 'plain'
face_title= 'plain'
size_title = 14
size_text = 11
legend_size = 11
global_theme <- function () {
theme_minimal () %+replace%
theme (
text = element_text (family = font_main, size = size_text, face = face_text),
legend.text = element_text (family = font_main, size = legend_size),
axis.text = element_text (size = size_text, face = face_text),
plot.title = element_text (
family = font_title,
size = size_title,
hjust = 0.5
),
plot.subtitle = element_text (hjust = 0.5 )
)
}
# Seed
set.seed (2025 )
source ("../functions/utils.R" )
Load Data and Classifier
We load the dataset in which the sensitive attribute \(S\) is race, obtained in Chapter 4.3:
load ("../data/df_race.rda" )
We also load the dataset in which the sensitive attribute is again race, but where the target variable (\(Y\), ZFYA) is binary (1 if the student obtained a standardized first-year average above the median, 0 otherwise). This dataset was saved in Chapter 5.5:
load ("../data/df_race_c.rda" )
We also need the predictions made by the classifier (see Chapter 5 ):
# Predictions on train/test sets
load ("../data/pred_aware.rda" )
load ("../data/pred_unaware.rda" )
# Predictions on the factuals, on the whole dataset
load ("../data/pred_aware_all.rda" )
load ("../data/pred_unaware_all.rda" )
Counterfactuals with Multivariate Optimal Transport
We apply multivariate optimal transport (OT), following the methodology developed in De Lara et al. (2024). Note that the transport plan is estimated between the observed samples only, so it provides no mapping for new cases. Counterfactuals are therefore computed on the train set only.
The code is run in Python. We use the {reticulate} R package to call Python from this notebook.
library (reticulate)
# py_install("POT")
Some libraries need to be loaded (including POT, imported as ot):
import ot
import pandas as pd
import numpy as np
import matplotlib.pyplot as pl
import ot.plot
The data with the factuals need to be loaded:
df_aware = pd.read_csv('../data/factuals_aware.csv' )
df_unaware = pd.read_csv('../data/factuals_unaware.csv' )
x_S = df_aware.drop(columns= ['pred' , 'type' ])
x_S.head()
S X1 X2
0 White 3.1 39.0
1 White 3.0 36.0
2 White 3.1 30.0
3 White 3.4 37.0
4 White 3.6 30.5
x_white = x_S[x_S['S' ] == 'White' ]
x_white = x_white.drop(columns= ['S' ])
x_black = x_S[x_S['S' ] == 'Black' ]
x_black = x_black.drop(columns= ['S' ])
n_white = len (x_white)
n_black = len (x_black)
# Uniform weights
w_white = (1 / n_white)* np.ones(n_white)
w_black = (1 / n_black)* np.ones(n_black)
Cost matrix between both distributions:
x_white = x_white.to_numpy()
x_black = x_black.to_numpy()
C = ot.dist(x_white, x_black)
pl.figure(1 )
pl.plot(x_white[:, 0 ], x_white[:, 1 ], '+b' , label= 'Source samples' )
pl.plot(x_black[:, 0 ], x_black[:, 1 ], 'xr' , label= 'Target samples' )
pl.legend(loc= 0 )
pl.title('Source and target distributions' )
pl.figure(2 )
pl.imshow(C, interpolation= 'nearest' )
pl.title('Cost matrix C' )
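To make the content of the cost matrix explicit, here is a minimal sketch that builds pairwise squared Euclidean costs with NumPy only. To our understanding, squared Euclidean distance is the default metric of `ot.dist`, so `C` above should hold the same kind of quantity. The helper name `squared_euclidean_cost` and the toy values are hypothetical and are not taken from the dataset.

```python
import numpy as np


def squared_euclidean_cost(x_source, x_target):
    """Pairwise squared Euclidean distances, shape (n_source, n_target)."""
    # Broadcasting: (n_source, 1, d) - (1, n_target, d) -> (n_source, n_target, d)
    diff = x_source[:, None, :] - x_target[None, :, :]
    return np.sum(diff ** 2, axis=2)


# Toy values (hypothetical): columns play the role of X1 and X2
x_source = np.array([[3.0, 36.0], [3.4, 37.0], [3.6, 30.5]])
x_target = np.array([[2.7, 31.0], [2.6, 21.0]])

C_toy = squared_euclidean_cost(x_source, x_target)
print(C_toy.shape)  # one row per source point, one column per target point
```

Because the features enter the cost on their raw scales, the variable with the larger range (here X2) contributes most to each entry of the matrix.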
The transport plan from White to Black individuals (the reverse plan is its transpose):
pi_white_black = ot.emd(w_white, w_black, C, numItermax=int(1e8))
pi_black_white = pi_white_black.T
pi_white_black.shape
sum_of_rows = np.sum (pi_white_black, axis= 1 )
sum_of_rows* n_white
array([1., 1., 1., ..., 1., 1., 1.])
sum_of_rows = np.sum (pi_black_white, axis= 1 )
sum_of_rows* n_black
array([1., 1., 1., ..., 1., 1., 1.])
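The two checks above verify the marginal constraints of the plan: each row carries the mass of one source point and each column the mass of one target point. The sketch below illustrates the same constraints on a toy problem. It is not POT's solver: it relies on the assumption that both groups have the same size and uniform weights, in which case an optimal plan can be found among permutations, and it simply enumerates them. All names and values are hypothetical.

```python
import itertools

import numpy as np

# Toy one-dimensional samples (hypothetical), same size in both groups
x_source = np.array([[0.0], [1.0], [2.0]])
x_target = np.array([[2.5], [0.5], [1.5]])
n = len(x_source)

# Squared Euclidean cost matrix, shape (n, n)
C_toy = (x_source - x_target.T) ** 2

# Brute-force search over all assignments (only feasible for tiny n)
best_perm = min(
    itertools.permutations(range(n)),
    key=lambda perm: sum(C_toy[i, perm[i]] for i in range(n)),
)

# Turn the assignment into a transport plan with uniform mass 1/n
plan = np.zeros((n, n))
for i, j in enumerate(best_perm):
    plan[i, j] = 1.0 / n

print(plan.sum(axis=1) * n)  # row marginals, rescaled
print(plan.sum(axis=0) * n)  # column marginals, rescaled
```

In the chapter the two groups have different sizes, so the plan returned by `ot.emd` generally splits the mass of some points across several matches instead of being a rescaled permutation.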
pl.figure(3 )
pl.imshow(pi_white_black, interpolation= 'nearest' )
pl.title('OT matrix pi_white_black' )
pl.figure(4 )
ot.plot.plot2D_samples_mat(x_white, x_black, pi_white_black, c= [.5 , .5 , 1 ])
pl.plot(x_white[:, 0 ], x_white[:, 1 ], '+b' , label= 'Source samples' )
pl.plot(x_black[:, 0 ], x_black[:, 1 ], 'xr' , label= 'Target samples' )
pl.legend(loc= 0 )
pl.title('OT matrix with samples' )
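# Barycentric projection: each White individual is mapped to the
# plan-weighted average of the Black individuals
transformed_x_white = n_white * pi_white_black @ x_black
transformed_x_white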
array([[ 2.7, 31. ],
[ 2.7, 28. ],
[ 2.6, 21. ],
...,
[ 3.9, 28. ],
[ 2.5, 22. ],
[ 3. , 19. ]])
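# Barycentric projection: each Black individual is mapped to the
# plan-weighted average of the White individuals
transformed_x_black = n_black * pi_black_white @ x_white
transformed_x_black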
array([[ 3.2 , 37.58851518],
[ 3.28565491, 28.02103363],
[ 2.95793273, 32.14022423],
...,
[ 3.28597758, 33. ],
[ 2.65092152, 41.43910309],
[ 2.75152858, 36. ]])
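The two projections above multiply the plan by the group size so that each row becomes a set of weights summing to one; the transported point is then a weighted average of points from the other group. The following sketch works this out on a hand-written plan. The plan, the points, and the variable names are hypothetical and only serve to show the arithmetic.

```python
import numpy as np

# Three target points (hypothetical values for X1 and X2)
x_target = np.array([[3.0, 30.0], [3.4, 38.0], [2.6, 34.0]])
n_source = 2

# Hand-written plan: rows sum to 1/2 (source mass), columns to 1/3 (target mass)
plan = np.array([
    [1 / 3, 1 / 6, 0.0],
    [0.0, 1 / 6, 1 / 3],
])

# Rescaled rows are convex weights over the target points
weights = n_source * plan
print(weights.sum(axis=1))

# Barycentric projection of each source point
transported = n_source * plan @ x_target
print(transported)
```

When a point's mass is split across several matches, its image is an average and need not coincide with any observed individual. This is a plausible reason why the transported values for the smaller group are not round numbers, while those for the larger group are.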
counterfactual_x = x_S.drop(columns= ['S' ])
counterfactual_x[x_S['S' ] == 'White' ] = transformed_x_white
counterfactual_x[x_S['S' ] == 'Black' ] = transformed_x_black
X1 X2
0 2.7 31.0
1 2.7 28.0
2 2.6 21.0
3 3.1 28.0
4 3.2 21.0
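The mask-based assignment used to build `counterfactual_x` relies on the transported rows being in the same order as the rows selected by each mask. A minimal NumPy sketch of that pattern, with hypothetical toy values, is given below.

```python
import numpy as np

# Group labels and features in their original row order (hypothetical values)
s = np.array(["White", "Black", "White", "Black", "White"])
x = np.array([[3.1, 39.0], [2.8, 29.0], [3.0, 36.0], [2.4, 25.0], [3.4, 37.0]])

# Transported points, one row per individual, in within-group order
transformed_white = np.array([[2.7, 31.0], [2.7, 28.0], [3.1, 28.0]])
transformed_black = np.array([[3.2, 37.5], [3.3, 28.0]])

# Fill each group's rows with its transported values
counterfactual = np.empty_like(x)
counterfactual[s == "White"] = transformed_white
counterfactual[s == "Black"] = transformed_black
print(counterfactual)
```

Row `i` of the result is then the counterfactual of individual `i`, which is what later allows factuals and counterfactuals to be compared row by row.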
Lastly, we export the results to a CSV file:
csv_file_path = '../data/counterfactuals_ot.csv'
counterfactual_x.to_csv(csv_file_path, index= False )
Let us get back to R, and load the results.
counterfactuals_ot <- read_csv ('../data/counterfactuals_ot.csv' )
We add the sensitive attribute to the dataset (Black individuals become White, and conversely):
S_star <- df_race_c |>
mutate (
S_star = case_when (
S == "Black" ~ "White" ,
S == "White" ~ "Black" ,
TRUE ~ "Error"
)
) |>
pull ("S_star" )
counterfactuals_ot <- counterfactuals_ot |>
mutate (S = S_star)
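For readers following the Python side, the label swap above can be sketched as a simple lookup. This is an illustrative variant, not the chapter's code: the function name `swap_group` is hypothetical, and unlike the `case_when()` call, which labels unexpected values as "Error", this sketch raises an exception.

```python
def swap_group(label):
    """Return the counterfactual group label for a binary sensitive attribute."""
    swap = {"Black": "White", "White": "Black"}
    if label not in swap:
        # Assumption: fail loudly instead of propagating an "Error" label
        raise ValueError(f"Unexpected group label: {label!r}")
    return swap[label]


s = ["White", "Black", "White"]
s_star = [swap_group(value) for value in s]
print(s_star)
```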
Unaware Model
Let us make predictions with the unaware model on the counterfactuals obtained with OT:
model_unaware <- pred_unaware$ model
pred_unaware_ot <- predict (
model_unaware, newdata = counterfactuals_ot, type = "response"
)
counterfactuals_unaware_ot <- counterfactuals_ot |>
mutate (pred = pred_unaware_ot, type = "counterfactual" )
We gather in a table the initial characteristics (factuals) and the predictions made by the unaware model:
factuals_unaware <- tibble (
S = df_race$ S,
X1 = df_race$ X1,
X2 = df_race$ X2,
pred = pred_unaware_all,
type = "factual"
)
We bind the factuals and counterfactuals with their respective predicted values in a single dataset:
unaware_ot <- bind_rows (
# predicted values on factuals
factuals_unaware |> mutate (type = "factual" ),
# predicted values on counterfactuals obtained with OT
counterfactuals_unaware_ot
)
Then, we can visualize the distribution of the values predicted by the unaware model within each group defined by the sensitive attribute.
unaware_ot_white <- unaware_ot |> filter (S == "White" )
unaware_ot_black <- unaware_ot |> filter (S == "Black" )
ggplot (
data = unaware_ot_black,
mapping = aes (x = pred, fill = type)
) +
geom_histogram (
mapping = aes (y = after_stat (density)),
alpha = 0.5 , position = "identity" , binwidth = 0.05
) +
geom_density (alpha = 0.5 ) +
labs (
title = "Unaware model, Sensitive: Race, Reference: Black individual" ,
x = "Predictions for Y" ,
y = "Density"
) +
global_theme ()
ggplot (
data = unaware_ot_white,
mapping = aes (x = pred, fill = type)) +
geom_histogram (
mapping = aes (y = after_stat (density)),
alpha = 0.5 , position = "identity" , binwidth = 0.05
) +
geom_density (alpha = 0.5 ) +
labs (
title = "Unaware model, Sensitive: Race, Reference: White" ,
x = "Predictions for Y" ,
y = "Density" ) +
global_theme ()
Aware Model
Let us make predictions with the aware model on the counterfactuals obtained with OT:
model_aware <- pred_aware$ model
pred_aware_ot <- predict (
model_aware, newdata = counterfactuals_ot, type = "response"
)
counterfactuals_aware_ot <- counterfactuals_ot |>
mutate (pred = pred_aware_ot, type = "counterfactual" )
counterfactuals_aware_ot
# A tibble: 19,567 × 5
X1 X2 S pred type
<dbl> <dbl> <chr> <dbl> <chr>
1 2.7 31 Black 0.141 counterfactual
2 2.7 28 Black 0.120 counterfactual
3 2.6 21 Black 0.0791 counterfactual
4 3.1 28 Black 0.143 counterfactual
5 3.2 21 Black 0.104 counterfactual
6 3.3 27.5 Black 0.152 counterfactual
7 2.4 29 Black 0.111 counterfactual
8 2.3 29 Black 0.106 counterfactual
9 3.3 22 Black 0.115 counterfactual
10 2.8 34 Black 0.171 counterfactual
# ℹ 19,557 more rows
We gather in a table the initial characteristics (factuals) and the predictions made by the aware model:
factuals_aware <- tibble (
S = df_race$ S,
X1 = df_race$ X1,
X2 = df_race$ X2,
pred = pred_aware_all,
type = "factual"
)
We bind the factuals and counterfactuals with their respective predicted values in a single dataset:
aware_ot <- bind_rows (
factuals_aware,
counterfactuals_aware_ot
)
Then, we can visualize the distribution of the values predicted by the aware model within each group defined by the sensitive attribute.
aware_ot_white <- aware_ot |> filter (S == "White" )
aware_ot_black <- aware_ot |> filter (S == "Black" )
ggplot (
data = aware_ot_black,
mapping = aes (x = pred, fill = type)
) +
geom_histogram (
mapping = aes (y = after_stat (density)),
alpha = 0.5 , position = "identity" , binwidth = 0.05
) +
geom_density (alpha = 0.5 ) +
labs (
title = "Aware model, Sensitive: Race, Reference: Black individual" ,
x = "Predictions for Y" ,
y = "Density"
) +
global_theme ()
ggplot (
data = aware_ot_white,
mapping = aes (x = pred, fill = type)) +
geom_histogram (
mapping = aes (y = after_stat (density)),
alpha = 0.5 , position = "identity" , binwidth = 0.05
) +
geom_density (alpha = 0.5 ) +
labs (
title = "Aware model, Sensitive: Race, Reference: White" ,
x = "Predictions for Y" ,
y = "Density" ) +
global_theme ()
Comparison for Two Individuals
Let us, again, focus on two individuals: 24 (Black) and 25 (White):
(indiv_factuals_unaware <- factuals_unaware[24 : 25 , ])
# A tibble: 2 × 5
S X1 X2 pred type
<fct> <dbl> <dbl> <dbl> <chr>
1 Black 2.8 29 0.300 factual
2 White 2.8 34 0.382 factual
The counterfactuals for those individuals, using the unaware model:
(indiv_counterfactuals_unaware_ot <- counterfactuals_unaware_ot[24 : 25 , ])
# A tibble: 2 × 5
X1 X2 S pred type
<dbl> <dbl> <chr> <dbl> <chr>
1 3.20 37.6 White 0.502 counterfactual
2 2.4 25 Black 0.203 counterfactual
Let us put the factuals and counterfactuals in a single table:
indiv_unaware_ot <- bind_rows (
indiv_factuals_unaware |> mutate (id = c (24 , 25 )),
indiv_counterfactuals_unaware_ot |> mutate (id = c (24 , 25 ))
)
indiv_unaware_ot
# A tibble: 4 × 6
S X1 X2 pred type id
<chr> <dbl> <dbl> <dbl> <chr> <dbl>
1 Black 2.8 29 0.300 factual 24
2 White 2.8 34 0.382 factual 25
3 White 3.20 37.6 0.502 counterfactual 24
4 Black 2.4 25 0.203 counterfactual 25
We compute the difference between the unaware model's prediction on the counterfactual and its prediction on the factual:
indiv_unaware_ot |> select (id , type, pred) |>
pivot_wider (names_from = type, values_from = pred) |>
mutate (diff = counterfactual - factual)
# A tibble: 2 × 4
id factual counterfactual diff
<dbl> <dbl> <dbl> <dbl>
1 24 0.300 0.502 0.202
2 25 0.382 0.203 -0.179
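As a quick check of the arithmetic, the differences in the table above can be recomputed from the rounded predictions it reports. The sketch below uses only those displayed values; the dictionary layout is a hypothetical choice made for illustration.

```python
# Rounded predictions of the unaware model, as displayed in the tables above
preds = {
    24: {"factual": 0.300, "counterfactual": 0.502},
    25: {"factual": 0.382, "counterfactual": 0.203},
}

# Counterfactual minus factual, rounded to three decimals as in the output
diffs = {
    indiv: round(p["counterfactual"] - p["factual"], 3)
    for indiv, p in preds.items()
}
print(diffs)
```

A positive value means that the predicted probability increases when the individual is transported to the other group, and a negative value means that it decreases.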
We do the same for the aware model:
indiv_aware_ot <- bind_rows (
factuals_aware[c (24 , 25 ),] |> mutate (id = c (24 , 25 )),
counterfactuals_aware_ot[c (24 , 25 ),] |> mutate (id = c (24 , 25 ))
)
indiv_aware_ot
# A tibble: 4 × 6
S X1 X2 pred type id
<chr> <dbl> <dbl> <dbl> <chr> <dbl>
1 Black 2.8 29 0.133 factual 24
2 White 2.8 34 0.413 factual 25
3 White 3.20 37.6 0.515 counterfactual 24
4 Black 2.4 25 0.0898 counterfactual 25
The difference between the counterfactual and the factual for these two individuals, when using the aware model:
indiv_aware_ot |> select (id , type, pred) |>
pivot_wider (names_from = type, values_from = pred) |>
mutate (diff = counterfactual - factual)
# A tibble: 2 × 4
id factual counterfactual diff
<dbl> <dbl> <dbl> <dbl>
1 24 0.133 0.515 0.383
2 25 0.413 0.0898 -0.323
Counterfactual Demographic Parity
As with the counterfactuals obtained with fairadapt, we assume here that the reference group is “White individuals” (i.e., the group with the most individuals in the dataset). We focus on the minority, i.e., Black individuals. We consider here that the model is fair towards the minority class if:
\[
P(\hat{Y}_{S \leftarrow \text{White}} = 1 | S = \text{Black}, X_1, X_2) = P(\hat{Y} = 1 | S = \text{White}, X_1, X_2)
\]
If the model is fair with respect to this criterion, the proportion of Black individuals predicted to have grades above the median should be the same as if they had been White.
For predictions made with the unaware model:
dp_unaware_ot <- mean (
counterfactuals_unaware_ot |> filter (S == "White" ) |> pull ("pred" ) -
factuals_unaware |> filter (S == "Black" ) |> pull ("pred" )
)
dp_unaware_ot
We do the same with the aware model:
dp_aware_ot <- mean (
counterfactuals_aware_ot |> filter (S == "White" ) |> pull ("pred" ) -
factuals_aware |> filter (S == "Black" ) |> pull ("pred" )
)
dp_aware_ot
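The two quantities above are average gaps between the prediction on the counterfactual and the prediction on the factual, taken over Black individuals. A minimal Python sketch of the same computation is given below. The predicted probabilities are made-up values for four hypothetical individuals, not results from the chapter.

```python
import numpy as np

# Hypothetical predicted probabilities for four Black individuals
pred_factual = np.array([0.30, 0.25, 0.40, 0.35])
# Predictions for the same individuals after transport to the White group
pred_counterfactual = np.array([0.50, 0.45, 0.55, 0.50])

# Average individual-level gap; both vectors must be in the same row order
gap = np.mean(pred_counterfactual - pred_factual)
print(round(gap, 3))
```

A gap close to zero would indicate that, on average, predictions for these individuals do not change when they are transported to the reference group. The subtraction is element-wise, so it is only meaningful if row `i` refers to the same individual in both vectors.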
Saving Objects
save (
counterfactuals_unaware_ot,
file = "../data/counterfactuals_unaware_ot.rda"
)
save (
counterfactuals_aware_ot,
file = "../data/counterfactuals_aware_ot.rda"
)
De Lara, Lucas, Alberto González-Sanz, Nicholas Asher, Laurent Risser, and Jean-Michel Loubes. 2024. “Transport-Based Counterfactual Models.” Journal of Machine Learning Research 25 (136): 1–59.