6  Grid Search

Display the required packages.
library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ readr     2.1.5
✔ forcats   1.0.0     ✔ stringr   1.5.1
✔ ggplot2   3.5.1     ✔ tibble    3.2.1
✔ lubridate 1.9.3     ✔ tidyr     1.3.1
✔ purrr     1.0.2     
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
Display the required packages.
library(randomForest)
randomForest 4.7-1.1
Type rfNews() to see new features/changes/bug fixes.

Attaching package: 'randomForest'

The following object is masked from 'package:dplyr':

    combine

The following object is masked from 'package:ggplot2':

    margin
Display the required packages.
library(future)
library(binom)
library(locfit)
locfit 1.5-9.9   2024-03-01

Attaching package: 'locfit'

The following object is masked from 'package:purrr':

    none

6.1 Data

We use data from the UCI Machine Learning Repository (Yeh 2016) on customers’ default payments in Taiwan. This dataset contains \(n = 30{,}000\) instances and 23 numeric features. The outcome variable, corresponding to the observed default payment in the next month, is positive in 22.12% of cases.

Let us load the data.

# Read the raw file: its first line is skipped and the second one provides
# the column names. Then shorten the outcome's name (25th column) to `d`,
# drop the identifier column and store the categorical predictors as factors.
data_credit <- read.csv2(
  file = "data/default_credit.csv",
  header = TRUE,
  skip = 1
)
names(data_credit)[25] <- "d"
data_credit <- data_credit |>
  select(-ID) |>
  as_tibble() |>
  mutate(across(
    c("SEX", "EDUCATION", "MARRIAGE",
      "PAY_0", "PAY_2", "PAY_3", "PAY_4", "PAY_5", "PAY_6"),
    as.factor
  ))

Following the methodology outlined in Subasi A. (2019), we employ the Synthetic Minority Over-sampling Technique (SMOTE) at a rate of 200% to rebalance the data. To do so, we use the smote() function from {performanceEstimation}.

library(performanceEstimation)
# smote() needs the outcome as a factor; perc.over = 2 requests the 200%
# oversampling rate of the minority class described above.
set.seed(123)
data_credit$d <- as.factor(data_credit$d)
new_df <- performanceEstimation::smote(d ~ ., data_credit, perc.over = 2)

data_credit <- new_df
# Map the factor outcome back to numeric 0/1 through its labels rather than
# its internal level codes (1/2). Assumes the labels are "0"/"1", as for the
# raw outcome column.
data_credit$d <- as.numeric(as.character(data_credit$d))

Let us save this dataset for later use in the simulations.

# Persist the rebalanced dataset (without row names).
data_credit |>
  write.csv(file = "data/data_credit_smote.csv", row.names = FALSE)

We split the data into two sets:

  1. a training set on which we will train random forests (with 50% of observations)
  2. the remaining data which will be further split into a calibration set and a test set.
# Draw half of the row indices at random for training; every other row
# goes to the remaining set (later split into calibration and test).
set.seed(1234)
ind_train <- sample(
  seq_len(nrow(data_credit)),
  size = nrow(data_credit) / 2
)
tb_train <- slice(data_credit, ind_train)
tb_rest <- slice(data_credit, -ind_train)

Let us save those files for later use in the simulations.

# Persist both subsets (without row names) for the simulations.
tb_train |>
  write.csv(file = "data/data_credit_smote_train.csv", row.names = FALSE)
tb_rest |>
  write.csv(file = "data/data_credit_smote_rest.csv", row.names = FALSE)

6.2 Grid With Hyperparameters

We conduct a grid search to find the set of hyperparameters that minimizes a criterion:

  1. the out-of-bag mean squared error (MSE) for the regressor
  2. the error rate for the classifier.
# Cartesian product of candidate hyperparameter values.
# Note: the mtry upper bound is half of ncol(tb_train), a count that
# includes the outcome column `d`.
grid_params <- expand_grid(
  num_trees = c(100, 300, 500),
  mtry = seq(from = 1, to = ncol(tb_train) / 2),
  nodesize = c(5, 10, 15, 20)
)
grid_params
# A tibble: 144 × 3
   num_trees  mtry nodesize
       <dbl> <int>    <dbl>
 1       100     1        5
 2       100     1       10
 3       100     1       15
 4       100     1       20
 5       100     2        5
 6       100     2       10
 7       100     2       15
 8       100     2       20
 9       100     3        5
10       100     3       10
# ℹ 134 more rows

6.3 Helper Functions

Let us define some helper functions to compute the MSE, the accuracy (with a probability threshold of .5) and the AUC (using {pROC}).

#' Mean Squared Error
#'
#' @param pred Numeric vector of predictions.
#' @param obs Numeric vector of observed values.
#' @return The average of the squared differences `pred - obs`.
mse_function <- function(pred, obs) {
  squared_errors <- (pred - obs)^2
  mean(squared_errors)
}

#' Accuracy with a probability threshold (0.5 by default)
#'
#' @param pred Predicted scores, numeric or a factor whose labels are
#'   numbers (converted through its labels, not its level codes).
#' @param obs Observed binary outcome coded 0/1.
#' @param threshold Scores strictly above this value are classified as 1.
#' @return Proportion of correctly classified observations.
accuracy_function <- function(pred, obs, threshold = 0.5) {
  # Previously the cut-off was hard-coded to 0.5, silently ignoring
  # `threshold`.
  pred_class <- as.numeric(as.character(pred)) > threshold
  mean(pred_class == obs)
}

#' AUC (no threshold)
#'
#' @param pred Predicted scores.
#' @param obs Observed binary outcome.
#' @return The AUC object returned by `pROC::auc()`.
auc_function <- function(pred, obs) {
  # Namespace-qualified: {pROC} is not attached by this chapter's library()
  # calls, so a bare `auc()` would not be found.
  pROC::auc(obs, pred)
}

6.4 Estimations

Let us now go through the hyperparameter grid.

6.4.1 Regression

# Fit one regression forest per row of grid_params, in parallel, and record
# its out-of-bag (OOB) MSE. One core is left to the main session, but at
# least one worker is always used.
nb_cores <- max(1, future::availableCores() - 1)
plan(multisession, workers = nb_cores)
progressr::with_progress({
  p <- progressr::progressor(steps = nrow(grid_params))
  mse_oob_rf_reg <- furrr::future_map(
    .x = seq_len(nrow(grid_params)),
    .f = \(i) {
      rf <- randomForest(
        d ~ .,
        data = tb_train,
        mtry = grid_params$mtry[i],
        nodesize = grid_params$nodesize[i],
        ntree = grid_params$num_trees[i]
      )

      # For regression, rf$predicted already holds each training
      # observation's average prediction over the trees for which it was
      # out of bag, so no extra predict(predict.all = TRUE) pass over
      # tb_train (nor keep.inbag) is needed.
      mse_oob <- mse_function(pred = rf$predicted, obs = tb_train |> pull(d))

      # Progress bar
      p()

      tibble(
        mtry = grid_params$mtry[i],
        nodesize = grid_params$nodesize[i],
        num_trees = grid_params$num_trees[i],
        mse_oob = mse_oob
      )
    },
    # randomForest() draws random numbers inside the workers: seed = TRUE
    # gives each element its own parallel-safe RNG stream so the results
    # are reproducible (seed = FALSE made them depend on scheduling).
    .options = furrr::furrr_options(seed = TRUE)
  )
})

Let us order the computed out-of-bag MSE by ascending values:

# Stack the one-row results and rank configurations by OOB MSE (best first).
best_params_rf_reg <- arrange(list_rbind(mse_oob_rf_reg), mse_oob)

We save the results for later use.

# Persist the ranked grid (without row names), then display it.
best_params_rf_reg |>
  write.csv(file = "output/best_params_rf_reg.csv", row.names = FALSE)
best_params_rf_reg
    mtry nodesize num_trees   mse_oob
1     12        5       500 0.1161355
2     11        5       500 0.1167286
3     10        5       500 0.1167594
4     11        5       300 0.1168635
5     10        5       300 0.1170744
6      9        5       300 0.1172085
7      9        5       500 0.1172791
8     12        5       300 0.1173561
9      8        5       500 0.1175019
10     8        5       300 0.1181500
11     7        5       500 0.1182748
12     6        5       500 0.1184314
13     7        5       300 0.1187873
14     6        5       300 0.1194305
15     5        5       500 0.1195314
16    11        5       100 0.1195434
17    12        5       100 0.1198150
18     9        5       100 0.1198567
19    10        5       100 0.1201343
20     4        5       500 0.1206271
21     7        5       100 0.1206874
22     5        5       300 0.1206969
23    12       10       500 0.1208994
24    11       10       500 0.1210198
25    11       10       300 0.1211552
26     8        5       100 0.1211788
27    12       10       300 0.1212391
28    10       10       500 0.1213147
29     4        5       300 0.1214632
30     9       10       500 0.1217760
31    10       10       300 0.1220816
32     9       10       300 0.1222290
33     6        5       100 0.1222443
34     8       10       300 0.1222482
35     8       10       500 0.1225779
36     7       10       500 0.1226872
37     3        5       500 0.1227032
38     5        5       100 0.1230636
39     3        5       300 0.1230863
40    11       10       100 0.1232123
41    12       10       100 0.1234719
42     4        5       100 0.1234766
43     7       10       300 0.1236268
44     6       10       500 0.1236749
45    10       10       100 0.1238468
46     6       10       300 0.1241986
47    12       15       500 0.1243726
48     8       10       100 0.1245897
49     5       10       500 0.1248403
50    11       15       500 0.1248550
51     9       10       100 0.1248996
52    11       15       300 0.1250312
53    12       15       300 0.1251028
54    10       15       500 0.1252682
55     5       10       300 0.1253665
56     7       10       100 0.1255997
57     3        5       100 0.1257361
58     9       15       500 0.1258641
59     6       10       100 0.1259584
60    10       15       300 0.1259758
61     4       10       500 0.1263621
62     9       15       300 0.1264203
63    12       15       100 0.1265025
64     8       15       500 0.1265721
65     2        5       500 0.1266220
66     8       15       300 0.1267140
67     4       10       300 0.1267876
68    10       15       100 0.1268396
69    11       15       100 0.1269556
70     2        5       300 0.1271676
71     7       15       500 0.1273421
72     7       15       300 0.1276667
73     5       10       100 0.1279446
74    12       20       500 0.1280429
75     9       15       100 0.1280750
76    12       20       300 0.1281329
77     6       15       500 0.1281506
78    11       20       300 0.1281588
79     3       10       500 0.1281836
80    10       20       500 0.1282266
81    11       20       500 0.1282600
82     6       15       300 0.1282876
83     3       10       300 0.1285083
84     8       15       100 0.1289850
85     5       15       500 0.1290990
86     9       20       500 0.1291231
87     5       15       300 0.1291516
88    10       20       300 0.1291917
89     4       10       100 0.1292167
90     2        5       100 0.1296366
91     8       20       500 0.1298449
92     9       20       300 0.1299142
93     7       15       100 0.1302080
94     6       15       100 0.1302180
95     8       20       300 0.1302658
96    11       20       100 0.1303368
97    12       20       100 0.1303956
98     7       20       500 0.1305237
99     4       15       500 0.1306045
100    7       20       300 0.1307034
101    4       15       300 0.1314075
102    3       10       100 0.1314194
103   10       20       100 0.1314446
104    9       20       100 0.1314797
105    6       20       500 0.1316348
106    6       20       300 0.1318650
107    5       15       100 0.1322313
108    2       10       500 0.1323110
109    8       20       100 0.1323435
110    7       20       100 0.1327727
111    2       10       300 0.1328086
112    5       20       500 0.1329164
113    4       15       100 0.1329842
114    3       15       500 0.1330115
115    5       20       300 0.1330483
116    3       15       300 0.1333896
117    6       20       100 0.1334507
118    4       20       500 0.1343606
119    5       20       100 0.1346723
120    4       20       300 0.1348001
121    3       15       100 0.1352052
122    2       10       100 0.1355645
123    3       20       500 0.1366326
124    4       20       100 0.1369212
125    3       20       300 0.1369492
126    2       15       500 0.1370411
127    2       15       300 0.1371923
128    3       20       100 0.1389632
129    2       15       100 0.1396311
130    2       20       500 0.1408191
131    2       20       300 0.1411691
132    2       20       100 0.1429509
133    1        5       300 0.1517564
134    1        5       500 0.1517581
135    1        5       100 0.1533601
136    1       10       500 0.1552974
137    1       10       300 0.1556152
138    1       10       100 0.1567145
139    1       15       300 0.1579995
140    1       15       500 0.1585141
141    1       20       500 0.1600574
142    1       20       300 0.1609307
143    1       15       100 0.1613020
144    1       20       100 0.1623087

6.4.2 Classification

# Fit one classification forest per row of grid_params, in parallel, and
# record its out-of-bag (OOB) error rate. One core is left to the main
# session, but at least one worker is always used.
# (The result keeps the name `mse_oob_rf_classif` for downstream code, even
# though it stores error rates.)
nb_cores <- max(1, future::availableCores() - 1)
plan(multisession, workers = nb_cores)
progressr::with_progress({
  p <- progressr::progressor(steps = nrow(grid_params))
  mse_oob_rf_classif <- furrr::future_map(
    .x = seq_len(nrow(grid_params)),
    .f = \(i) {
      rf <- randomForest(
        as.factor(d) ~ .,
        data = tb_train,
        mtry = grid_params$mtry[i],
        nodesize = grid_params$nodesize[i],
        ntree = grid_params$num_trees[i]
      )

      # Row k of rf$err.rate holds the error rates of the forest made of the
      # first k trees; column 1 is the overall OOB error. Take the last row,
      # i.e. the full forest.
      err_oob <- rf$err.rate[rf$ntree, 1]

      # Progress bar
      p()

      tibble(
        mtry = grid_params$mtry[i],
        nodesize = grid_params$nodesize[i],
        num_trees = grid_params$num_trees[i],
        err_oob = err_oob
      )
    },
    # randomForest() draws random numbers inside the workers: seed = TRUE
    # gives each element its own parallel-safe RNG stream so the results
    # are reproducible (seed = FALSE made them depend on scheduling).
    .options = furrr::furrr_options(seed = TRUE)
  )
})

Let us order the computed out-of-bag error rate by ascending values:

# Stack the one-row results and rank configurations by OOB error rate
# (best first).
best_params_rf_classif <- arrange(list_rbind(mse_oob_rf_classif), err_oob)

We save the results for later use.

# Persist the ranked grid (without row names), then display it.
best_params_rf_classif |>
  write.csv(file = "output/best_params_rf_classif.csv", row.names = FALSE)
best_params_rf_classif
    mtry nodesize num_trees   err_oob
1     12        5       300 0.1567209
2     11        5       500 0.1580126
3     10        5       500 0.1580556
4      9        5       300 0.1585723
5     10        5       300 0.1586584
6     11        5       300 0.1591751
7      9        5       500 0.1595195
8     12        5       500 0.1596056
9      7        5       500 0.1596917
10     8        5       500 0.1600792
11     6        5       300 0.1614570
12     6        5       500 0.1621459
13     8        5       300 0.1627056
14    12        5       100 0.1631361
15     7        5       300 0.1634375
16     5        5       500 0.1643417
17    10        5       100 0.1652458
18     9        5       100 0.1653320
19    11       10       300 0.1661069
20     5        5       300 0.1664944
21     6        5       100 0.1665375
22     4        5       500 0.1666236
23    12       10       500 0.1672264
24    11        5       100 0.1673125
25     4        5       300 0.1673555
26    10       10       500 0.1677000
27     7        5       100 0.1677430
28    11       10       500 0.1677861
29    12       10       300 0.1682167
30     8        5       100 0.1685180
31     9       10       300 0.1692930
32     8       10       500 0.1692930
33     9       10       500 0.1693791
34     3        5       500 0.1698958
35     5        5       100 0.1706708
36     8       10       300 0.1712736
37     3        5       300 0.1715319
38    10       10       300 0.1717902
39     7       10       500 0.1721777
40     7       10       300 0.1730388
41    11       10       100 0.1731680
42     6       10       300 0.1731680
43     6       10       500 0.1732972
44    12       10       100 0.1736416
45     4        5       100 0.1742444
46    10       10       100 0.1749333
47    11       15       500 0.1754930
48    11       15       300 0.1758805
49    12       15       300 0.1760958
50     9       15       500 0.1762249
51    10       15       500 0.1762680
52     9       10       100 0.1763110
53    12       15       500 0.1763541
54     8       10       100 0.1766124
55     5       10       500 0.1766555
56     5       10       300 0.1770860
57    10       15       300 0.1772152
58     2        5       500 0.1777749
59     3        5       100 0.1778180
60     4       10       500 0.1778180
61     7       10       100 0.1779041
62     6       10       100 0.1781624
63     9       15       300 0.1785499
64     2        5       300 0.1794110
65     8       15       300 0.1799707
66     8       15       500 0.1800138
67     7       15       500 0.1801429
68     4       10       300 0.1804443
69     7       15       300 0.1809610
70    12       15       100 0.1820804
71    12       20       500 0.1821235
72     5       10       100 0.1822957
73    11       15       100 0.1823818
74     4       10       100 0.1826401
75    10       15       100 0.1827693
76     6       15       300 0.1829846
77     3       10       500 0.1831568
78     2        5       100 0.1833290
79    12       20       300 0.1835443
80    11       20       300 0.1835874
81     6       15       500 0.1838887
82    11       20       500 0.1838887
83     5       15       500 0.1840610
84     8       15       100 0.1841471
85     9       20       300 0.1843193
86     7       15       100 0.1844054
87     9       20       500 0.1844915
88     3       10       300 0.1850512
89    10       20       300 0.1851373
90     9       15       100 0.1852665
91    10       20       500 0.1855679
92     4       15       500 0.1862998
93    12       20       100 0.1864721
94     8       20       500 0.1865582
95    11       20       100 0.1867304
96     8       20       300 0.1873332
97     5       15       300 0.1876345
98     6       15       100 0.1877637
99     7       20       500 0.1881512
100    7       20       300 0.1883665
101    9       20       100 0.1884095
102    3       10       100 0.1885818
103    2       10       500 0.1897443
104    4       15       300 0.1900887
105   10       20       100 0.1901317
106    8       20       100 0.1905192
107    6       20       500 0.1905192
108    5       15       100 0.1906484
109    4       15       100 0.1915095
110    7       20       100 0.1919401
111    6       20       300 0.1921553
112    3       15       500 0.1928012
113    2       10       300 0.1929303
114    3       15       300 0.1930164
115    5       20       500 0.1933178
116    5       20       300 0.1940067
117    6       20       100 0.1955136
118    4       20       500 0.1959442
119    3       15       100 0.1970206
120    4       20       300 0.1973650
121    2       10       100 0.1974942
122    5       20       100 0.1983983
123    4       20       100 0.1990011
124    2       15       300 0.2003789
125    3       20       500 0.2010678
126    2       15       500 0.2013692
127    3       20       300 0.2017567
128    3       20       100 0.2037802
129    2       15       100 0.2047275
130    2       20       300 0.2083441
131    2       20       500 0.2083441
132    2       20       100 0.2106260
133    1        5       500 0.2295703
134    1        5       100 0.2300870
135    1        5       300 0.2349092
136    1       10       300 0.2349522
137    1       10       500 0.2386980
138    1       10       100 0.2398174
139    1       15       300 0.2410230
140    1       15       500 0.2412813
141    1       15       100 0.2456730
142    1       20       500 0.2462757
143    1       20       100 0.2473952
144    1       20       300 0.2490743