3  Adult Dataset

Objectives

We apply the two methods presented in Chapter 1 to build counterfactuals for a categorical feature, as in Chapter 2.

The illustration is made on a second real-world dataset, the Adult Income dataset from the UCI Machine Learning Repository.

library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ readr     2.1.5
✔ forcats   1.0.0     ✔ stringr   1.5.1
✔ ggplot2   3.5.1     ✔ tibble    3.2.1
✔ lubridate 1.9.4     ✔ tidyr     1.3.1
✔ purrr     1.0.2     
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(devtools)
Loading required package: usethis
library(ggtern)
Registered S3 methods overwritten by 'ggtern':
  method           from   
  grid.draw.ggplot ggplot2
  plot.ggplot      ggplot2
  print.ggplot     ggplot2
--
Remember to cite, run citation(package = 'ggtern') for further info.
--

Attaching package: 'ggtern'

The following objects are masked from 'package:ggplot2':

    aes, annotate, ggplot, ggplot_build, ggplot_gtable, ggplotGrob,
    ggsave, layer_data, theme_bw, theme_classic, theme_dark,
    theme_gray, theme_light, theme_linedraw, theme_minimal, theme_void
# load the functions from our package to perform optimal transport on 
# compositional data
load_all("../../")
ℹ Loading transportsimplex
Definition of a ggplot2 theme.
#' Theme for ggplot2
#'
#' @param ... arguments passed to the theme function
#' @export
#' @importFrom ggplot2 element_rect element_text element_blank element_line unit
#'   rel
theme_paper <- function(...) {
  theme(
    text = element_text(family = "Times New Roman"),
    plot.background = element_rect(fill = "transparent", color = NA),
    legend.text = element_text(size = rel(1.1)),
    legend.title = element_text(size = rel(1.1)),
    legend.background = element_rect(
     fill = "transparent", linetype="solid", colour ="black"),
    legend.position = "bottom",
    legend.direction = "horizontal",
    legend.box = "vertical",
    legend.key = element_blank(),
    panel.spacing = unit(1, "lines"),
    plot.title = element_text(hjust = 0, size = rel(1.3), face = "bold"),
    plot.title.position = "plot",
    strip.background = element_rect(fill = NA, colour = NA),
    strip.text = element_text(size = rel(1.1))
  )
}

3.1 The Dataset

library(fairml)
data(adult)

We want to build counterfactual values of the categorical variable marital_status for women, had they been men. These are the raw categories:

adult |> pull(marital_status) |> table()

   Married-civ-spouse              Divorced         Never-married 
                14065                  4214                  9726 
            Separated               Widowed Married-spouse-absent 
                  939                   827                   370 
    Married-AF-spouse 
                   21 

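Let us regroup these categories into three: (i) married (civilian or armed forces spouse), (ii) never married, (iii) separated (divorced, separated, widowed, or married with an absent spouse).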

adult <- 
  adult |> 
  mutate(
    marital_status = case_when(
      marital_status %in% c(
        "Married-AF-spouse", "Married-civ-spouse"
      ) ~ "Married",
      marital_status %in% c(
        "Divorced", "Separated", "Widowed", "Married-spouse-absent"
      ) ~ "Separated",
      marital_status %in% c("Never-married") ~ "Never-married",
      TRUE ~ "error"
    ),
    marital_status = factor(marital_status)
  )
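The same regrouping can also be written with a named lookup vector, which keeps the whole mapping in one place. This is a minimal base-R sketch on a toy vector standing in for `adult$marital_status`; `regroup_map` is a hypothetical name, and with this approach unmatched categories become `NA` rather than `"error"`.

```r
# Hypothetical lookup vector: original category -> regrouped category
regroup_map <- c(
  "Married-AF-spouse"     = "Married",
  "Married-civ-spouse"    = "Married",
  "Divorced"              = "Separated",
  "Separated"             = "Separated",
  "Widowed"               = "Separated",
  "Married-spouse-absent" = "Separated",
  "Never-married"         = "Never-married"
)

# Toy input standing in for adult$marital_status
raw_status <- factor(c("Married-civ-spouse", "Divorced", "Never-married",
                       "Widowed", "Married-AF-spouse"))

# Index the lookup by the character labels; unname() drops the lookup names
regrouped <- factor(unname(regroup_map[as.character(raw_status)]))
print(table(regrouped))
```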

The global proportions in the dataset for each category of marital status:

prop.table(adult |> pull(marital_status) |> table())

      Married Never-married     Separated 
    0.4670115     0.3224587     0.2105298 

And if we compare by gender:

Codes to create the Table.
prop_marital <- 
  adult |> 
  count(sex, marital_status) |> 
  group_by(sex) |> 
  mutate(prop = n / sum(n)) |> 
  select(-n) |> 
  pivot_wider(names_from = marital_status, values_from = prop)
prop_marital |> 
  kableExtra::kbl(
    booktabs = TRUE, digits = 4,
  ) |> 
  kableExtra::kable_paper() |> 
  kableExtra::add_header_above(c(" " = 1, "Proportions" = 3))
Table 3.1: Proportions of each marital status for women and for men.
Proportions
sex Married Never-married Separated
Female 0.1525 0.4408 0.4067
Male 0.6180 0.2657 0.1164
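The row proportions of Table 3.1 can also be obtained in base R with `prop.table()` and `margin = 1`, which divides each row of a contingency table by its row total. A sketch on hypothetical toy data (not the Adult counts):

```r
# Hypothetical toy data: 4 women and 4 men
sex <- factor(c("Female", "Female", "Female", "Female",
                "Male", "Male", "Male", "Male"))
status <- factor(c("Married", "Never-married", "Separated", "Separated",
                   "Married", "Married", "Married", "Never-married"))

# Row-wise proportions: each row (one sex) sums to one
tab <- prop.table(table(sex, status), margin = 1)
print(round(tab, 4))
```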

3.2 Estimation of Scores

We fit four models, \(\widehat{m}^{(k)}(\mathbf{x}_j|\mathbf{x}_{-j})\), using a multinomial loss, yielding predicted scores \(\widehat{\mathbf{x}}_{j}^{(k)}\in\mathcal{S}_d\) for \(k \in \{1,2,3,4\}\). The sensitive attribute \(s\) (sex) is excluded from the predictors, and so is the income variable. The four models are:

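  • GAM-MLR (1): A multinomial model with splines for two continuous variables (age and hours_per_week) and a third, categorical, variable (occupation).
  • GAM-MLR (2): A multinomial model with additional variables (relationship, workclass, education, and splines on education_num and capital_gain).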
  • Random Forest: A classifier using all available variables.
  • Gradient Boosting Model: A GBM trained on all available variables.
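Whatever the model, a multinomial loss yields, for each observation, a vector of predicted probabilities that is positive and sums to one, i.e., a point of the simplex \(\mathcal{S}_d\). A quick check with `nnet::multinom()` on the built-in `iris` data, used here only as an illustrative stand-in for the Adult data (it also has \(d=3\) classes):

```r
library(nnet)

# Multinomial logit with a single continuous predictor (illustration only)
fit <- multinom(Species ~ Sepal.Length, data = iris, trace = FALSE)

# n x 3 matrix of predicted class probabilities (fitted values)
probs <- predict(fit, type = "probs")

# Each row is a composition: strictly positive entries summing to one
print(head(round(probs, 4)))
print(range(rowSums(probs)))
```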

The estimation of GAM-MLR (1):

library(splines)
require(nnet)
Loading required package: nnet
set.seed(123)
model_glm_1 <- multinom(
  marital_status ~ bs(age) + bs(hours_per_week) + occupation, 
  data = adult |> select(-sex, -income)
)
# weights:  63 (40 variable)
initial  value 33136.343851 
iter  10 value 25885.836959
iter  20 value 25131.762781
iter  30 value 24607.941350
iter  40 value 24439.499919
iter  50 value 24438.214321
final  value 24438.207100 
converged

The estimation of GAM-MLR (2):

model_glm_2 <- multinom(
  marital_status ~ bs(age) + bs(hours_per_week) + occupation + relationship +
    workclass + bs(education_num) + education + bs(capital_gain),
  data = adult |> select(-sex, -income)
)
# weights:  144 (94 variable)
initial  value 33136.343851 
iter  10 value 9899.654429
iter  20 value 9211.159339
iter  30 value 8206.013535
iter  40 value 7892.102459
iter  50 value 7614.929161
iter  60 value 7450.099601
iter  70 value 7386.503591
iter  80 value 7371.425003
iter  90 value 7367.283292
iter 100 value 7367.014720
final  value 7367.014720 
stopped after 100 iterations

The estimation of the random forest:

library(randomForest)
model_rf <- randomForest(
  marital_status ~ ., data = adult |> select(-sex, -income)
)
save(model_rf, file = "../output/model_rf_adult.rda")
load( "../output/model_rf_adult.rda")

The estimation of the gradient boosting model:

library(gbm)
library(caret)
model_gbm <- gbm(
  marital_status ~.,
  data = adult |> select(-sex, -income),
  distribution = "multinomial",
  cv.folds = 10,
  shrinkage = .01,
  n.minobsinnode = 10,
  n.trees = 2000
)
save(model_gbm, file = "../output/model_gbm_adult.rda")
load( "../output/model_gbm_adult.rda")

Now that we have estimated the four models, we can extract the estimated scores \(\widehat{\mathbf{x}}_{j}^{(k)}\in\mathcal{S}_d\):

scores_glm_1 <- predict(model_glm_1, type = "probs")
scores_glm_2 <- predict(model_glm_2, type = "probs")
scores_rf <- predict(model_rf, type = "prob")
scores_gbm <- predict.gbm(
  object = model_gbm,
  newdata = adult |> select(-sex, -income),
  n.trees = 2000,
  type = "response")
scores_gbm <- scores_gbm[ , , 1]
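Before transporting anything, it can be worth checking that each score matrix has rows that are nonnegative and sum to one, up to numerical error. A minimal sketch; `check_simplex()` is a hypothetical helper, not a function of our package:

```r
# Hypothetical helper: does each row of `scores` lie in the unit simplex?
check_simplex <- function(scores, tol = 1e-6) {
  scores <- as.matrix(scores)
  nonneg <- all(scores >= -tol)
  sums_to_one <- all(abs(rowSums(scores) - 1) < tol)
  nonneg && sums_to_one
}

# Toy example: the second row does not sum to one
toy <- rbind(c(0.2, 0.5, 0.3),
             c(0.2, 0.5, 0.4))
colnames(toy) <- c("Married", "Never-married", "Separated")
check_simplex(toy)                     # FALSE
check_simplex(toy[1, , drop = FALSE])  # TRUE
```

In this chapter, it could be applied in turn to `scores_glm_1`, `scores_glm_2`, `scores_rf`, and `scores_gbm`.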

Let us have a look at the predicted scores of four individuals (2 women, 2 men), for each model.

Codes to create the Table.
print_highlighted_obs <- function(scores) {
  ind_highlight <- c(1, 2, 5, 6)
  tb <- adult |>
    select(marital_status, sex) |> 
    bind_cols(scores) |> 
    slice(ind_highlight)
  row.names(tb) <- NULL
  tb
}

tb_four_indiv <- 
  print_highlighted_obs(scores_glm_1) |> mutate(model = "glm_1") |> 
  bind_rows(
    print_highlighted_obs(scores_glm_2) |> mutate(model = "glm_2")
  ) |> 
  bind_rows(
    print_highlighted_obs(scores_rf) |> mutate(model = "rf")
  ) |> 
  bind_rows(
    print_highlighted_obs(scores_gbm) |> mutate(model = "gbm")
  ) |> 
  mutate(
    model = factor(
      model, 
      levels = c("glm_1", "glm_2", "rf", "gbm"),
      labels = c("GAM-MLR(1)", "GAM-MLR(2)", "Random Forest", 
                 "Gradient Boosting Model")
    )
  ) |> 
  relocate(model, .before = marital_status)

tb_four_indiv |> select(-model) |> 
  kableExtra::kbl(
    booktabs = TRUE, digits = 4,
  ) |> 
  kableExtra::kable_paper() |> 
  kableExtra::add_header_above(c(" " = 2, "Predicted Scores" = 3)) |> 
  kableExtra::pack_rows(index = table(tb_four_indiv$model))
Table 3.2: Mappings from the marital_status categorical variable \(x\) to the compositional one \(\tilde{\mathbf{x}}\), for four individuals of the dataset.
Predicted Scores
marital_status sex Married Never-married Separated
GAM-MLR(1)
Never-married Male 0.3523 0.2059 0.4418
Married Male 0.5428 0.1054 0.3519
Married Female 0.3026 0.5988 0.0986
Married Female 0.5957 0.1710 0.2333
GAM-MLR(2)
Never-married Male 0.0052 0.6323 0.3625
Married Male 1.0000 0.0000 0.0000
Married Female 1.0000 0.0000 0.0000
Married Female 1.0000 0.0000 0.0000
Random Forest
Never-married Male 0.0052 0.7565 0.2383
Married Male 0.9944 0.0000 0.0056
Married Female 0.9652 0.0348 0.0000
Married Female 1.0000 0.0000 0.0000
Gradient Boosting Model
Never-married Male 0.0034 0.6210 0.3757
Married Male 0.9991 0.0003 0.0007
Married Female 0.9980 0.0018 0.0001
Married Female 0.9992 0.0005 0.0003

3.3 Optimal Transport in the Euclidean Representation

We can now apply Algorithm 1.1 with a Gaussian mapping in a Euclidean representation space to transport from observed \(\widehat{\mathbf{x}}_j|s=0\) (scores for women) to counterfactual \(\widehat{\mathbf{x}}_j|s=1\) (scores for men), in \(\mathcal{S}_d\).

Subset

We work on a random subset of 400 observations for the transport. Also, when the number of observations is too large, the returned weights tend to all be equal to 1…

set.seed(1234)
idx <- sample(1:nrow(adult),size = 400)
adult_subset <- adult[idx, ]

Let us isolate women and men:

ind_0 <- which(adult_subset$sex == "Female")
ind_1 <- which(adult_subset$sex == "Male")

Then, we create matrices for each model with the predicted scores for each category, i.e., the representation of the categorical variable marital_status in the unit simplex.

# GAM-MLR(1)
X0_glm_1 <- scores_glm_1[idx, ][ind_0, ]
X1_glm_1 <- scores_glm_1[idx, ][ind_1, ]
# GAM-MLR(2)
X0_glm_2 <- scores_glm_2[idx, ][ind_0,]
X1_glm_2 <- scores_glm_2[idx, ][ind_1,]
# RF
X0_rf <- scores_rf[idx, ][ind_0, ]
X1_rf <- scores_rf[idx, ][ind_1, ]
# GBM
X0_gbm <- scores_gbm[idx, ][ind_0, ]
X1_gbm <- scores_gbm[idx, ][ind_1, ]

For the random forest and the gradient boosting model, we add a tiny amount (\(10^{-7}\)) to scores exactly equal to zero and subtract the same amount from scores exactly equal to one, since the clr transform takes logarithms and is undefined at zero.

# RF
for(i in 1:3) X0_rf[which(X0_rf[, i] == 0), i] = .0000001
for(i in 1:3) X0_rf[which(X0_rf[, i] == 1), i] = 1-.0000001
for(i in 1:3) X1_rf[which(X1_rf[, i] == 0), i] = .0000001
for(i in 1:3) X1_rf[which(X1_rf[, i] == 1), i] = 1-.0000001
# GBM
for(i in 1:3) X0_gbm[which(X0_gbm[, i] == 0), i] = .0000001
for(i in 1:3) X0_gbm[which(X0_gbm[, i] == 1), i] = 1-.0000001
for(i in 1:3) X1_gbm[which(X1_gbm[, i] == 0), i] = .0000001
for(i in 1:3) X1_gbm[which(X1_gbm[, i] == 1), i] = 1-.0000001
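A single zero score sends a clr coordinate to \(-\infty\), which is why the replacement above is needed. The loops can also be written in vectorized form. The sketch below uses hypothetical helper names (`clr()`, `replace_extremes()`) and, unlike the loops above, also rescales each row so that it sums to one again:

```r
# clr transform of a matrix of compositions (one composition per row)
clr <- function(X) {
  logX <- log(X)
  logX - rowMeans(logX)
}

# Replace exact 0s and 1s by eps and 1 - eps, then rescale rows to sum to one
replace_extremes <- function(X, eps = 1e-7) {
  X[X == 0] <- eps
  X[X == 1] <- 1 - eps
  X / rowSums(X)
}

X <- rbind(c(1, 0, 0),
           c(0.2, 0.5, 0.3))

print(clr(X))                    # first row is not finite (Inf and NaN)
print(clr(replace_extremes(X)))  # all entries finite
```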

Then, we can apply Algorithm 1.1 to each set of predicted scores. We use the clr transform and Gaussian OT:

transp_glm_1 <- transport_simplex(X0 = X0_glm_1, X1 = X1_glm_1, n_interp = 31)
transp_glm_2 <- transport_simplex(X0 = X0_glm_2, X1 = X1_glm_2, n_interp = 31)
transp_rf <- transport_simplex(X0 = X0_rf, X1 = X1_rf, n_interp = 31)
transp_gbm <- transport_simplex(X0 = X0_gbm, X1 = X1_gbm, n_interp = 31)
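To make the Gaussian mapping more concrete, here is a minimal base-R sketch of a Gaussian optimal transport map between two samples of compositions, \(T(\mathbf{z}) = \mathbf{m}_1 + \mathbf{A}(\mathbf{z}-\mathbf{m}_0)\) with \(\mathbf{A} = \Sigma_0^{-1/2}(\Sigma_0^{1/2}\Sigma_1\Sigma_0^{1/2})^{1/2}\Sigma_0^{-1/2}\). Assumptions: we work in isometric log-ratio (ilr) coordinates rather than clr, because the clr covariance matrix is singular (clr coordinates sum to zero); all helper names are hypothetical; the data are simulated toy compositions; and `transport_simplex()` in our package may differ in its details.

```r
# Symmetric matrix power (e.g. square root) via eigendecomposition
msqrt <- function(S, power = 1/2) {
  e <- eigen(S, symmetric = TRUE)
  e$vectors %*% diag(e$values^power, nrow = length(e$values)) %*% t(e$vectors)
}

# Orthonormal basis of the clr hyperplane (normalized Helmert contrasts)
ilr_basis <- function(D) {
  V <- contr.helmert(D)
  sweep(V, 2, sqrt(colSums(V^2)), "/")
}

# Compositions (rows) -> ilr coordinates, and back to the simplex
to_ilr <- function(X, V) {
  logX <- log(X)
  (logX - rowMeans(logX)) %*% V
}
from_ilr <- function(Z, V) {
  E <- exp(Z %*% t(V))
  E / rowSums(E)
}

# Gaussian OT map estimated from two samples of compositions
gaussian_ot_simplex <- function(X0, X1) {
  V <- ilr_basis(ncol(X0))
  Z0 <- to_ilr(X0, V)
  Z1 <- to_ilr(X1, V)
  m0 <- colMeans(Z0)
  m1 <- colMeans(Z1)
  S0 <- cov(Z0)
  S1 <- cov(Z1)
  S0_half <- msqrt(S0)
  S0_inv_half <- msqrt(S0, power = -1/2)
  A <- S0_inv_half %*% msqrt(S0_half %*% S1 %*% S0_half) %*% S0_inv_half
  # T(z) = m1 + A (z - m0), applied row-wise (A is symmetric)
  Z0_t <- sweep(sweep(Z0, 2, m0) %*% A, 2, m1, "+")
  list(transported = from_ilr(Z0_t, V), Z0_t = Z0_t, S1 = S1, m1 = m1)
}

# Toy Dirichlet compositions standing in for women's (X0) and men's (X1) scores
set.seed(1)
rdirich <- function(n, alpha) {
  G <- matrix(rgamma(n * length(alpha), shape = alpha),
              ncol = length(alpha), byrow = TRUE)
  G / rowSums(G)
}
X0 <- rdirich(50, c(1, 4, 4))
X1 <- rdirich(60, c(6, 3, 1))

res <- gaussian_ot_simplex(X0, X1)
print(colMeans(res$transported))
```

By construction, the transported ilr coordinates have exactly the mean and covariance of the target sample, which is the defining property of the Gaussian OT map.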

We can then compare the proportion of each marital status category in the initial dataset with the average predicted score of each category for each model, and with the average transported predicted score for women.

Codes to create the Table.
# Proportions of each marital status category by gender in the dataset
prop_marital <- 
  adult |> 
  count(marital_status, sex) |> 
  group_by(sex) |> 
  mutate(prop = n / sum(n)) |> 
  select(-n) |> 
  pivot_wider(names_from = marital_status, values_from = prop) |> 
  mutate(type = "Categorical")

get_table_pred_transp <- function(scores, transp_scores) {
  # Average predicted scores for each marital status category by gender
  mean_scores_by_gender <- 
    adult_subset |> 
    select(marital_status, sex) |> 
    bind_cols(scores) |> 
    group_by(sex) |> 
    summarise(across(all_of(colnames(scores)), ~mean(.x))) |> 
    mutate(type = "Composition")
  
  # Average transported predicted scores of women for each category
  mean_transp_scores_women <- colMeans(transp_scores) |> 
    as_tibble_row() |>
    mutate(type = "Transported", sex = "Female -> Male")
  
  mean_scores_by_gender |> 
    bind_rows(mean_transp_scores_women)
}

tb_pred_transp_mean <-
  prop_marital |> mutate(model = "obs") |> 
  bind_rows(
    get_table_pred_transp(scores_glm_1[idx, ], transp_glm_1) |> 
      mutate(model = "glm_1")
  ) |> 
  bind_rows(
    get_table_pred_transp(scores_glm_2[idx, ], transp_glm_2) |> mutate(model = "glm_2")
  ) |> 
  bind_rows(
    get_table_pred_transp(scores_rf[idx, ], transp_rf) |> mutate(model = "rf")
  ) |> 
  bind_rows(
    get_table_pred_transp(scores_gbm[idx, ], transp_gbm) |> mutate(model = "gbm")
  ) |> 
  mutate(
    model = factor(
      model, 
      levels = c("obs", "glm_1", "glm_2", "rf", "gbm"),
      labels = c(
        "Observed Values",
        "GAM-MLR(1)", "GAM-MLR(2)", "Random Forest", 
        "Gradient Boosting Model"
      )
    )
  ) |> 
  relocate(model, .before = sex) |> 
  relocate(type, .after = model)

tb_pred_transp_mean |> select(-model) |> 
  kableExtra::kbl(
  booktabs = TRUE, digits = 4,
) |> 
  kableExtra::kable_paper() |> 
  kableExtra::add_header_above(c(" " = 2, "Marital Status" = 3)) |> 
  kableExtra::pack_rows(index = table(tb_pred_transp_mean$model))
Table 3.3
Marital Status
type sex Married Never-married Separated
Observed Values
Categorical Female 0.1525 0.4408 0.4067
Categorical Male 0.6180 0.2657 0.1164
GAM-MLR(1)
Composition Female 0.3524 0.3965 0.2510
Composition Male 0.5040 0.3104 0.1856
Transported Female -> Male 0.4908 0.3213 0.1879
GAM-MLR(2)
Composition Female 0.1333 0.4605 0.4062
Composition Male 0.5959 0.2732 0.1309
Transported Female -> Male 0.7760 0.0780 0.1461
Random Forest
Composition Female 0.1316 0.4871 0.3813
Composition Male 0.5929 0.2769 0.1302
Transported Female -> Male 0.4867 0.3337 0.1796
Gradient Boosting Model
Composition Female 0.1353 0.4604 0.4043
Composition Male 0.5950 0.2754 0.1296
Transported Female -> Male 0.5165 0.3137 0.1698

Optimal transport using the clr transformation and Gaussian optimal transport on the marital status scores in the Adult dataset, with two multinomial GAM models (GAM-MLR) to predict scores, a random forest, and a gradient boosting model. For observed values, the observed proportions of marital status categories are reported by gender. Then, for each model, the average predicted scores by gender for each category, computed on the subset of 400 observations, are shown (Composition). Lastly, the averages of the transported predicted scores for women are reported (Transported).

3.3.1 Visualization of Transported Categories

We can then show the counterfactuals on a ternary plot and draw the displacement interpolation when generating the counterfactuals (men) from the factuals (women).

First, we format the paths:

transp_glm_1_path <- interpolated(transp_glm_1) |> 
  list_rbind(names_to = "id_obs")
transp_glm_2_path <- interpolated(transp_glm_2) |> 
  list_rbind(names_to = "id_obs")
transp_rf_path <- interpolated(transp_rf) |> 
  list_rbind(names_to = "id_obs")
transp_gbm_path <- interpolated(transp_gbm) |> 
  list_rbind(names_to = "id_obs")
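The displacement interpolation moves each point continuously from its factual representation to its transported value. A minimal sketch for a single composition, assuming the path is linear in clr coordinates and uses 31 equally spaced steps (mirroring `n_interp = 31`); `interpolate_clr()` is a hypothetical helper, the input values are toy numbers, and `interpolated()` in our package may be implemented differently:

```r
# Hypothetical helper: path between compositions x0 and x1,
# linear in clr coordinates, mapped back to the simplex
interpolate_clr <- function(x0, x1, n_interp = 31) {
  clr <- function(x) log(x) - mean(log(x))
  s <- seq(0, 1, length.out = n_interp)
  # Row k: (1 - s_k) * clr(x0) + s_k * clr(x1)
  Z <- outer(1 - s, clr(x0)) + outer(s, clr(x1))
  E <- exp(Z)
  path <- E / rowSums(E)
  colnames(path) <- names(x0)
  path
}

# Toy factual and counterfactual compositions
x0 <- c(Married = 0.14, `Never-married` = 0.82, Separated = 0.04)
x1 <- c(Married = 0.25, `Never-married` = 0.71, Separated = 0.04)
path <- interpolate_clr(x0, x1)
print(path[c(1, 16, 31), ])
```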

Then, we can show, for each model, the representation in the simplex of the categorical variable marital_status, by gender (women in red and men in blue), as well as the displacement interpolation.

Codes to create the Figure.
scores_all <- as_tibble(scores_glm_1[idx, ]) |> 
  mutate(sex = adult_subset$sex, model = "glm_1") |> 
  bind_rows(
    as_tibble(scores_glm_2[idx, ]) |> 
      mutate(sex = adult_subset$sex, model = "glm_2")
  ) |> 
  bind_rows(
    as_tibble(as.data.frame(scores_rf[idx, ])) |> 
      mutate(sex = adult_subset$sex, model = "rf")
  ) |> 
  bind_rows(
    as_tibble(scores_gbm[idx, ]) |> 
      mutate(sex = adult_subset$sex, model = "gbm")
  ) |> 
  mutate(
    model = factor(
      model, 
      levels = c("obs", "glm_1", "glm_2", "rf", "gbm"),
      labels = c(
        "Observed Values",
        "GAM-MLR(1)", "GAM-MLR(2)", "Random Forest", 
        "Gradient Boosting Model"
      )
    )
  )

transp_all_path <- 
  transp_glm_1_path |> 
  mutate(
    sex = factor("Female", levels = levels(adult_subset$sex)),
    model = "glm_1"
  ) |> 
  bind_rows(
    transp_glm_2_path |> 
      mutate(
        sex = factor("Female", levels = levels(adult_subset$sex)),
        model = "glm_2"
      ) 
  ) |> 
  bind_rows(
    transp_rf_path |> 
      mutate(
        sex = factor("Female", levels = levels(adult_subset$sex)),
        model = "rf"
      ) 
  ) |> 
  bind_rows(
    transp_gbm_path |> 
      mutate(
        sex = factor("Female", levels = levels(adult_subset$sex)),
        model = "gbm"
      ) 
  ) |> 
  mutate(
    model = factor(
      model, 
      levels = c("obs", "glm_1", "glm_2", "rf", "gbm"),
      labels = c(
        "Observed Values",
        "GAM-MLR(1)", "GAM-MLR(2)", "Random Forest", 
        "Gradient Boosting Model"
      )
    )
  )

ggtern(
  data = scores_all,
  mapping = aes(x = Married, y = Separated, z = `Never-married`, colour = sex)
) +
  geom_point(size = .1) +
  geom_line(
    data = transp_all_path, 
    linewidth = .2, alpha = .8,
    mapping = aes(group = id_obs)
  ) +
  scale_colour_manual(values = c("Female" = "red", "Male" = "blue"), guide = "none") +
  facet_wrap(~model) +
  labs(x = "Married", y = "Separated", z = "Never-married") +
  theme_paper() +
  theme(
    tern.axis.title = element_text(size = rel(.8))
  )
Figure 3.1: Optimal Transport using clr transform. Points in red are compositions for women, whereas points in blue are compositions for men. The lines indicate the displacement interpolation when generating counterfactuals.

3.4 Optimal Transport in \(\mathcal{S}_3\)

Let us now use Algorithm 1.2 to create counterfactuals using matching in \(\mathcal{S}_3\).

To that end, we use the wasserstein_simplex() function from our package.

Note

This code takes a while to run. Here, we load results saved from a previous run.

mapping_glm_1 <- wasserstein_simplex(X0_glm_1, X1_glm_1)
mapping_glm_2 <- wasserstein_simplex(X0_glm_2, X1_glm_2)
mapping_rf <- wasserstein_simplex(X0_rf, X1_rf)
mapping_gbm <- wasserstein_simplex(X0_gbm, X1_gbm)

if (!dir.exists("../output/")) dir.create("../output/", recursive = TRUE)
save(
  mapping_glm_1, mapping_glm_2, mapping_rf, mapping_gbm,
  file = "../output/matching_adult.rda"
)
load("../output/matching_adult.rda")
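For intuition only, here is a brute-force sketch of the discrete matching problem behind Algorithm 1.2, in the special case of two groups of the same size with uniform weights, where an optimal plan can be found among permutations (a classical consequence of the Birkhoff-von Neumann theorem). Assumptions: the cost is the squared Aitchison distance (squared Euclidean distance between clr vectors), which may differ from the cost used in `wasserstein_simplex()`; all helper names are hypothetical; and exhaustive search is only feasible for a handful of points.

```r
# Squared Aitchison distance = squared Euclidean distance between clr vectors
clr <- function(x) log(x) - mean(log(x))
cost_matrix <- function(X0, X1) {
  C0 <- t(apply(X0, 1, clr))
  C1 <- t(apply(X1, 1, clr))
  D <- matrix(0, nrow(C0), nrow(C1))
  for (i in seq_len(nrow(C0))) for (j in seq_len(nrow(C1)))
    D[i, j] <- sum((C0[i, ] - C1[j, ])^2)
  D
}

# All permutations of 1..n, one per row (only sensible for very small n)
permutations <- function(n) {
  if (n == 1) return(matrix(1L, 1, 1))
  P <- permutations(n - 1)
  do.call(rbind, lapply(seq_len(n), function(k) {
    cbind(k, ifelse(P >= k, P + 1L, P))
  }))
}

# Optimal assignment by exhaustive search, returned as a transport plan
match_bruteforce <- function(X0, X1) {
  n <- nrow(X0)
  stopifnot(nrow(X1) == n)
  C <- cost_matrix(X0, X1)
  perms <- permutations(n)
  costs <- apply(perms, 1, function(p) sum(C[cbind(seq_len(n), p)]))
  best <- perms[which.min(costs), ]
  plan <- matrix(0, n, n)
  plan[cbind(seq_len(n), best)] <- 1 / n
  list(plan = plan, cost = min(costs), identity_cost = sum(diag(C)))
}

# Toy compositions: 4 women (X0) and 4 men (X1)
X0 <- rbind(c(0.6, 0.3, 0.1), c(0.1, 0.8, 0.1),
            c(0.2, 0.2, 0.6), c(0.3, 0.4, 0.3))
X1 <- rbind(c(0.1, 0.7, 0.2), c(0.5, 0.4, 0.1),
            c(0.3, 0.3, 0.4), c(0.2, 0.1, 0.7))
res <- match_bruteforce(X0, X1)
print(res$plan * nrow(X0))  # one weight of 1 per row: the matched man
```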

We extract the estimated weights for all the individuals:

M0_glm_1 <- mapping_glm_1$plan * nrow(X0_glm_1)
M0_glm_2 <- mapping_glm_2$plan * nrow(X0_glm_2)
M0_rf <- mapping_rf$plan * nrow(X0_rf)
M0_gbm <- mapping_gbm$plan * nrow(X0_gbm)
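The multiplication by the number of women is a rescaling. Assuming, as is usual for empirical OT with uniform weights, that the plan returned by `wasserstein_simplex()` has row sums \(1/n_0\) and column sums \(1/n_1\), multiplying it by \(n_0\) turns each row into nonnegative weights summing to one. A toy check with the independent coupling, a valid (if uninformative) plan:

```r
n0 <- 3
n1 <- 5
# Independent coupling: every pair (woman i, man j) gets mass 1 / (n0 * n1)
plan <- matrix(1 / (n0 * n1), nrow = n0, ncol = n1)

# Marginals of the plan: uniform over women (rows) and over men (columns)
print(rowSums(plan))  # 1/3 each
print(colSums(plan))  # 1/5 each

# Rescaled weights: each row now sums to one
M0 <- plan * n0
print(rowSums(M0))
```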

Let us focus on a single woman, \(\mathbf{x}_{0,i}\) with \(i=3\), i.e., the third woman in the subset:

i <- 3

For each model, we extract the representation of her marital_status characteristic in the simplex (i.e., the scores predicted by the \(k\)-th model, \(\widehat{m}^{(k)}(\mathbf{x}_j|\mathbf{x}_{-j})\)). Let us denote this composition by \(\mathbf{x}_{0,i}^{(k)}\).

indiv_i_glm_1 <- X0_glm_1[i, ]
indiv_i_glm_2 <- X0_glm_2[i, ]
indiv_i_rf <- X0_rf[i, ]
indiv_i_gbm <- X0_gbm[i, ]

We then extract the weights \(\mathbf{P}^{\star(k)}_i\):

weights_i_glm_1 <- M0_glm_1[i, ]
weights_i_glm_2 <- M0_glm_2[i, ]
weights_i_rf <- M0_rf[i, ]
weights_i_gbm <- M0_gbm[i, ]

Lastly, we compute, for our individual, its counterfactual \(T^\star(\mathbf{x}_{0,i})\), by simply computing the weighted average of the characteristics of the individuals from the other group.

cfact_i_glm_1 <- weights_i_glm_1 %*% X1_glm_1
cfact_i_glm_2 <- weights_i_glm_2 %*% X1_glm_2
cfact_i_rf <- weights_i_rf %*% X1_rf
cfact_i_gbm <- weights_i_gbm %*% X1_gbm
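Written out, each line above computes, for a given model (superscript \((k)\) dropped), a weighted average of the men's compositions \(\mathbf{x}_{1,j}\), \(j=1,\dots,n_1\), using the \(i\)-th row of the rescaled plan. Assuming the plan has uniform marginals, these weights are nonnegative and sum to one, so the counterfactual is a convex combination of compositions and therefore stays in \(\mathcal{S}_3\):

\[
T^\star(\mathbf{x}_{0,i}) = \sum_{j=1}^{n_1} P^{\star}_{i,j}\,\mathbf{x}_{1,j},
\qquad P^{\star}_{i,j} \ge 0, \quad \sum_{j=1}^{n_1} P^{\star}_{i,j} = 1 .
\]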

We can then plot (Figure 3.2) the representation of the woman of interest obtained with each model (red dot) and her counterfactual obtained by matching (blue square). We also plot, on the ternary plot, all the women and men. The size of the dots for men is proportional to their weights for the woman of interest.

Codes to create the Figure.
# Women
tb_plot_females <- 
  as_tibble(X0_glm_1) |> 
  mutate(sex = "Female", model = "glm_1") |> 
  bind_rows(
    as_tibble(X0_glm_2) |> 
      mutate(sex = "Female", model = "glm_2")
  ) |> 
  bind_rows(
    as_tibble(X0_rf) |> 
      mutate(sex = "Female", model = "rf")
  ) |> 
  bind_rows(
    as_tibble(X0_gbm) |> 
      mutate(sex = "Female", model = "gbm")
  ) |> 
  mutate(
    model = factor(
      model, 
      levels = c("obs", "glm_1", "glm_2", "rf", "gbm"),
      labels = c(
        "Observed Values",
        "GAM-MLR(1)", "GAM-MLR(2)", "Random Forest", 
        "Gradient Boosting Model"
      )
    )
  )

# Male individuals, with a column weights_i giving their weight used to 
# construct the counterfactual for indiv i
tb_plot_males <- 
  as_tibble(X1_glm_1) |> 
  mutate(
    sex = "Male", model = "glm_1", weights_i = weights_i_glm_1
  ) |> 
  bind_rows(
    as_tibble(X1_glm_2) |> 
      mutate(
        sex = "Male", model = "glm_2", weights_i = weights_i_glm_2
      )
  ) |> 
  bind_rows(
    as_tibble(X1_rf) |> 
      mutate(
        sex = "Male", model = "rf", weights_i = weights_i_rf
      )
  ) |> 
  bind_rows(
    as_tibble(X1_gbm) |> 
      mutate(
        sex = "Male", model = "gbm", weights_i = weights_i_gbm
      )
  ) |> 
  mutate(
    model = factor(
      model, 
      levels = c("obs", "glm_1", "glm_2", "rf", "gbm"),
      labels = c(
        "Observed Values",
        "GAM-MLR(1)", "GAM-MLR(2)", "Random Forest", 
        "Gradient Boosting Model"
      )
    )
  )

indiv_i <- 
  as_tibble_row(indiv_i_glm_1) |> mutate(sex = "Female", model = "glm_1") |> 
  bind_rows(
    as_tibble_row(indiv_i_glm_2) |> mutate(sex = "Female", model = "glm_2")
  ) |> 
  bind_rows(
    as_tibble_row(indiv_i_rf) |> mutate(sex = "Female", model = "rf")
  ) |> 
  bind_rows(
    as_tibble_row(indiv_i_gbm) |> mutate(sex = "Female", model = "gbm")
  ) |> 
  mutate(
    model = factor(
      model, 
      levels = c("obs", "glm_1", "glm_2", "rf", "gbm"),
      labels = c(
        "Observed Values",
        "GAM-MLR(1)", "GAM-MLR(2)", "Random Forest", 
        "Gradient Boosting Model"
      )
    )
  )

cfact_indiv_i <- 
  as_tibble(cfact_i_glm_1) |> mutate(sex = "Male", model = "glm_1") |> 
  bind_rows(
    as_tibble(cfact_i_glm_2) |> mutate(sex = "Male", model = "glm_2")
  ) |> 
  bind_rows(
    as_tibble(cfact_i_rf) |> mutate(sex = "Male", model = "rf")
  ) |> 
  bind_rows(
    as_tibble(cfact_i_gbm) |> mutate(sex = "Male", model = "gbm")
  ) |> 
  mutate(
    model = factor(
      model, 
      levels = c("obs", "glm_1", "glm_2", "rf", "gbm"),
      labels = c(
        "Observed Values",
        "GAM-MLR(1)", "GAM-MLR(2)", "Random Forest", 
        "Gradient Boosting Model"
      )
    )
  )

ggtern(
  mapping = aes(x = Married, y = Separated, z = `Never-married`, colour = sex)
) +
  geom_point(
    data = tb_plot_females,
    size = .1,
    alpha = .6
  ) +
  geom_point(
    data = tb_plot_males,
    mapping = aes(size = weights_i),
    alpha = .5
  ) +
  geom_point(data = indiv_i, size = 3, colour = "white") +
  geom_point(data = indiv_i, size = 2) +
  geom_point(data = cfact_indiv_i, size = 3, colour = "white", shape = 15) +
  geom_point(data = cfact_indiv_i, size = 3, shape = 15) +
  facet_wrap(~ model) +
  labs(x = "Married", y = "Separated", z = "Never Married") +
  scale_colour_manual(
    values = c("Female" = "red", "Male" = "blue"), guide = "none"
  ) +
  scale_size_continuous(range = c(0, 1), guide = "none") +
  theme_paper() +
  theme(
    tern.axis.title = element_text(size = rel(.8))
  )
Figure 3.2: Empirical matching of a woman \(\mathbf{x}_{0,i}^{(k)}\) (big red dot) with men (blue dots). The size of the blue dots is proportional to the weights \(\mathbf{P}^\star_i\). The counterfactual obtained with matching, \(T^\star(\mathbf{x}_{0,i})\), is shown as a blue square.

To finish, let us look more closely at the \(i\)-th woman, whose counterfactual is shown on the ternary plot. Her observed value of the marital_status variable is Married:

adult_subset[ind_0[i], "marital_status"]
[1] Married
Levels: Married Never-married Separated

Using the GAM-MLR(1) model, we obtained the following composition:

X0_glm_1[i, ]
      Married Never-married     Separated 
   0.14317730    0.81513857    0.04168414 

The ten men with the largest weights for her, obtained using Algorithm 1.2, are:

ind_close_points <- tail(order(weights_i_glm_1), 10)
X1_glm_1[ind_close_points, ] |> 
  as_tibble() |> 
  mutate(
    marital_status = adult_subset$marital_status[ind_1[ind_close_points]],
    index = ind_close_points,
    weights_i = weights_i_glm_1[ind_close_points]) |> 
  arrange(desc(weights_i)) |> 
  kableExtra::kbl(booktabs = TRUE) |> 
  kableExtra::kable_paper()
Married Never-married Separated marital_status index weights_i
0.2041504 0.7544468 0.0414028 Never-married 130 0.4813495
0.2381338 0.7089583 0.0529079 Never-married 89 0.3645661
0.1720608 0.7923851 0.0355542 Married 119 0.1544133
0.0187499 0.9619327 0.0193173 Never-married 26 0.0000035
0.0050026 0.9826104 0.0123870 Never-married 200 0.0000034
0.7732104 0.0600182 0.1667714 Separated 110 0.0000034
0.7790189 0.0590491 0.1619320 Married 72 0.0000034
0.6202558 0.1271203 0.2526239 Married 70 0.0000033
0.6417728 0.1617917 0.1964355 Separated 257 0.0000033
0.5991270 0.1735119 0.2273610 Married 44 0.0000033

The counterfactual version for the composition is:

cfact_i_glm_1
       Married Never-married  Separated
[1,] 0.2114987     0.7438497 0.04465168

The counterfactual categorical value would thus be:

colnames(cfact_i_glm_1)[which.max(cfact_i_glm_1)]
[1] "Never-married"
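For many individuals at once, the same rule (pick the category with the largest score) can be applied row-wise with `max.col()`. A sketch on a toy score matrix; `ties.method = "first"` mirrors `which.max()`, which returns the first maximum:

```r
# Toy counterfactual compositions, one individual per row
scores <- rbind(c(0.21, 0.74, 0.05),
                c(0.60, 0.10, 0.30),
                c(0.20, 0.30, 0.50))
colnames(scores) <- c("Married", "Never-married", "Separated")

# Column index of the largest score in each row, then its category label
cfact_category <- colnames(scores)[max.col(scores, ties.method = "first")]
print(cfact_category)
```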

For comparison, using Gaussian OT, the counterfactual would be:

transp_glm_1 |> slice(i)
# A tibble: 1 × 3
  Married `Never-married` Separated
    <dbl>           <dbl>     <dbl>
1   0.250           0.709    0.0412