fit_best() takes results from tuning many models and fits the workflow configuration associated with the best performance to the training set.
Usage
# S3 method for workflow_set
fit_best(x, metric = NULL, eval_time = NULL, ...)
Arguments
- x
A workflow_set object that has been evaluated with workflow_map(). Note that the workflow set must have been fitted with the control option save_workflow = TRUE.
- metric
A character string giving the metric to rank results by.
- eval_time
A single numeric time point where dynamic event time metrics should be chosen (e.g., the time-dependent ROC curve, etc.). The values should be consistent with the values used to create x. The NULL default will automatically use the first evaluation time used by x.
- ...
Additional options to pass to tune::fit_best().
Details
This function is a shortcut for the steps needed to fit the numerically optimal configuration in a fitted workflow set. It ranks the results, extracts the tuning result for the best-ranked workflow, and then calls fit_best() again (itself a wrapper) on that tuning result.
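The steps described above can be sketched in R as follows. This is a minimal illustration, not the package's actual implementation: it assumes the exported workflowsets functions rank_results() and extract_workflow_set_result() plus tune::fit_best(), and the helper name fit_best_sketch() is hypothetical. Calls are namespace-qualified so the function can be defined without attaching any packages.

```r
# Hypothetical helper mirroring what fit_best() does for a workflow set.
# Not part of workflowsets; shown only to illustrate the three steps.
fit_best_sketch <- function(wf_set, metric = NULL, eval_time = NULL) {
  # 1. Rank all workflows; select_best = TRUE keeps only the best
  #    configuration within each workflow. Row 1 is the overall winner.
  rankings <- workflowsets::rank_results(
    wf_set,
    rank_metric = metric,
    eval_time = eval_time,
    select_best = TRUE
  )

  # 2. Extract the tuning (or resampling) result of the top-ranked workflow.
  tune_res <- workflowsets::extract_workflow_set_result(
    wf_set,
    id = rankings$wflow_id[1]
  )

  # 3. Refit that workflow's best configuration on the full training set.
  #    This requires the results to have been created with save_workflow = TRUE.
  tune::fit_best(tune_res, metric = metric, eval_time = eval_time)
}
```

Calling fit_best_sketch(chi_features_res_new) on the object built in the Examples should, under these assumptions, select the same workflow as fit_best(chi_features_res_new).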
Note
The package supplies two pre-generated workflow sets, two_class_set and chi_features_set, and associated sets of model fits, two_class_res and chi_features_res.
The two_class_* objects are based on a binary classification problem using the two_class_dat data from the modeldata package. The six models use either a bare formula or a basic recipe with recipes::step_YeoJohnson() as the preprocessor, combined with a decision tree, logistic regression, or MARS model specification. See ?two_class_set for source code.
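A quick way to explore the classification objects is to print the workflow set and rank the pre-computed results. This sketch assumes that roc_auc is among the metrics stored in two_class_res (ROC AUC is one of tune's default classification metrics); select_best = TRUE keeps only the best configuration of each workflow.

```r
library(workflowsets)

# Six preprocessor/model combinations
two_class_set

# Rank the stored resampling results by ROC AUC (assumed to be present)
ranked <- rank_results(two_class_res, rank_metric = "roc_auc", select_best = TRUE)

# Show one row per workflow for the ranking metric
ranked[ranked$.metric == "roc_auc", c("wflow_id", "mean", "rank")]
```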
The chi_features_* objects are based on a regression problem using the Chicago data from the modeldata package. Each of the three models uses a linear regression model specification, paired with one of three recipes of varying complexity. The objects are meant to approximate the sequence of models built in Section 1.3 of Kuhn and Johnson (2019). See ?chi_features_set for source code.
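The three Chicago workflows differ only in their recipes, which can be checked directly. The sketch below uses extract_preprocessor() to pull out the most complex recipe; the workflow IDs are taken from the printed workflow set in the Examples. Loading recipes is assumed only so that the recipe prints in its usual summary form.

```r
library(workflowsets)
library(recipes)

# Workflow IDs: one linear regression per recipe
chi_features_set$wflow_id

# Inspect the preprocessor of the most complex workflow
pca_recipe <- extract_preprocessor(chi_features_set, id = "plus_pca_lm")
pca_recipe

# The stored results can be ranked without refitting anything
rank_results(chi_features_res, rank_metric = "rmse")
```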
Examples
library(workflowsets)
library(tune)
library(modeldata)
library(rsample)
data(Chicago)
Chicago <- Chicago[1:1195,]
time_val_split <-
sliding_period(
Chicago,
date,
"month",
lookback = 38,
assess_stop = 1
)
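To see what this resampling scheme produces, the split object can be inspected on its own. With lookback = 38, each analysis set spans the current month plus the 38 preceding months, and assess_stop = 1 assesses on the following month. The n = 1 column in the rank_results() output suggests that the first 1,195 days yield a single time-based validation split; the sketch below is a way to confirm this, not a required step.

```r
library(rsample)
library(modeldata)

data(Chicago)
Chicago <- Chicago[1:1195, ]

# Same scheme as the example: 39-month analysis window, 1-month assessment
time_val_split <- sliding_period(
  Chicago,
  date,
  "month",
  lookback = 38,
  assess_stop = 1
)

# Number of resamples produced
nrow(time_val_split)

# Date coverage of the analysis and assessment sets in the first split
first_split <- time_val_split$splits[[1]]
range(analysis(first_split)$date)
range(assessment(first_split)$date)
```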
chi_features_set
#> # A workflow set/tibble: 3 × 4
#> wflow_id info option result
#> <chr> <list> <list> <list>
#> 1 date_lm <tibble [1 × 4]> <opts[0]> <list [0]>
#> 2 plus_holidays_lm <tibble [1 × 4]> <opts[0]> <list [0]>
#> 3 plus_pca_lm <tibble [1 × 4]> <opts[0]> <list [0]>
chi_features_res_new <-
chi_features_set %>%
# note: must set `save_workflow = TRUE` to use `fit_best()`
option_add(control = control_grid(save_workflow = TRUE)) %>%
# evaluate with resamples
workflow_map(resamples = time_val_split, grid = 21, seed = 1, verbose = TRUE)
#> i No tuning parameters. `fit_resamples()` will be attempted
#> i 1 of 3 resampling: date_lm
#> → A | warning: prediction from rank-deficient fit; consider predict(., rankdeficient="NA")
#> There were issues with some computations A: x1
#> There were issues with some computations A: x1
#>
#> ✔ 1 of 3 resampling: date_lm (241ms)
#> i No tuning parameters. `fit_resamples()` will be attempted
#> i 2 of 3 resampling: plus_holidays_lm
#> → A | warning: prediction from rank-deficient fit; consider predict(., rankdeficient="NA")
#> ✔ 2 of 3 resampling: plus_holidays_lm (205ms)
#> i 3 of 3 tuning: plus_pca_lm
#> → A | warning: prediction from rank-deficient fit; consider predict(., rankdeficient="NA")
#> ✔ 3 of 3 tuning: plus_pca_lm (886ms)
chi_features_res_new
#> # A workflow set/tibble: 3 × 4
#> wflow_id info option result
#> <chr> <list> <list> <list>
#> 1 date_lm <tibble [1 × 4]> <opts[3]> <rsmp[+]>
#> 2 plus_holidays_lm <tibble [1 × 4]> <opts[3]> <rsmp[+]>
#> 3 plus_pca_lm <tibble [1 × 4]> <opts[3]> <tune[+]>
# sort models by performance metrics
rank_results(chi_features_res_new)
#> # A tibble: 12 × 9
#> wflow_id .config .metric mean std_err n preprocessor model rank
#> <chr> <chr> <chr> <dbl> <dbl> <int> <chr> <chr> <int>
#> 1 plus_pca_… Prepro… rmse 0.586 NA 1 recipe line… 1
#> 2 plus_pca_… Prepro… rsq 0.989 NA 1 recipe line… 1
#> 3 plus_pca_… Prepro… rmse 0.590 NA 1 recipe line… 2
#> 4 plus_pca_… Prepro… rsq 0.988 NA 1 recipe line… 2
#> 5 plus_pca_… Prepro… rmse 0.591 NA 1 recipe line… 3
#> 6 plus_pca_… Prepro… rsq 0.988 NA 1 recipe line… 3
#> 7 plus_pca_… Prepro… rmse 0.594 NA 1 recipe line… 4
#> 8 plus_pca_… Prepro… rsq 0.989 NA 1 recipe line… 4
#> 9 plus_holi… Prepro… rmse 0.646 NA 1 recipe line… 5
#> 10 plus_holi… Prepro… rsq 0.986 NA 1 recipe line… 5
#> 11 date_lm Prepro… rmse 0.733 NA 1 recipe line… 6
#> 12 date_lm Prepro… rsq 0.982 NA 1 recipe line… 6
# fit the numerically optimal configuration to the training set
chi_features_wf <- fit_best(chi_features_res_new)
chi_features_wf
#> ══ Workflow [trained] ════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: linear_reg()
#>
#> ── Preprocessor ──────────────────────────────────────────────────────────
#> 5 Recipe Steps
#>
#> • step_date()
#> • step_holiday()
#> • step_dummy()
#> • step_zv()
#> • step_pca()
#>
#> ── Model ─────────────────────────────────────────────────────────────────
#>
#> Call:
#> stats::lm(formula = ..y ~ ., data = data)
#>
#> Coefficients:
#> (Intercept) temp_min temp
#> 5.067e+02 -4.811e-04 6.885e-02
#> temp_max temp_change dew
#> 9.511e-04 NA -5.110e-02
#> humidity pressure pressure_change
#> 2.516e-02 6.921e-01 2.230e-02
#> wind wind_max gust
#> -1.642e-02 1.409e-04 3.146e-03
#> gust_max percip percip_max
#> 7.870e-03 -7.111e+00 2.199e-01
#> weather_rain weather_snow weather_cloud
#> -6.168e-01 -2.689e-01 -9.951e-02
#> weather_storm Blackhawks_Away Blackhawks_Home
#> 2.603e-01 -1.245e-01 -1.114e-01
#> Bulls_Away Bulls_Home Bears_Away
#> 9.407e-02 1.833e-01 3.306e-01
#> Bears_Home WhiteSox_Away WhiteSox_Home
#> 3.531e-01 -5.198e-01 NA
#> Cubs_Away Cubs_Home date_year
#> NA NA -2.638e-01
#> date_LaborDay date_NewYearsDay date_ChristmasDay
#> 5.166e-01 -1.275e+01 -1.308e+01
#> date_dow_Mon date_dow_Tue date_dow_Wed
#> 1.232e+01 1.345e+01 1.348e+01
#> date_dow_Thu date_dow_Fri date_dow_Sat
#> 1.325e+01 1.281e+01 9.855e-01
#> date_month_Feb date_month_Mar date_month_Apr
#> 4.218e-02 3.897e-01 5.472e-01
#> date_month_May date_month_Jun date_month_Jul
#> 2.842e-01 9.032e-01 3.897e-01
#> date_month_Aug date_month_Sep date_month_Oct
#> 4.855e-01 1.588e-01 6.197e-01
#> date_month_Nov date_month_Dec PC1
#> -4.350e-01 -8.359e-01 2.979e-02
#> PC2 PC3
#> 1.225e-01 -1.722e-01
#>
# to select the optimal configuration based on a specific metric:
fit_best(chi_features_res_new, metric = "rmse")
#> ══ Workflow [trained] ════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: linear_reg()
#>
#> ── Preprocessor ──────────────────────────────────────────────────────────
#> 5 Recipe Steps
#>
#> • step_date()
#> • step_holiday()
#> • step_dummy()
#> • step_zv()
#> • step_pca()
#>
#> ── Model ─────────────────────────────────────────────────────────────────
#>
#> Call:
#> stats::lm(formula = ..y ~ ., data = data)
#>
#> Coefficients:
#> (Intercept) temp_min temp
#> 5.067e+02 -4.811e-04 6.885e-02
#> temp_max temp_change dew
#> 9.511e-04 NA -5.110e-02
#> humidity pressure pressure_change
#> 2.516e-02 6.921e-01 2.230e-02
#> wind wind_max gust
#> -1.642e-02 1.409e-04 3.146e-03
#> gust_max percip percip_max
#> 7.870e-03 -7.111e+00 2.199e-01
#> weather_rain weather_snow weather_cloud
#> -6.168e-01 -2.689e-01 -9.951e-02
#> weather_storm Blackhawks_Away Blackhawks_Home
#> 2.603e-01 -1.245e-01 -1.114e-01
#> Bulls_Away Bulls_Home Bears_Away
#> 9.407e-02 1.833e-01 3.306e-01
#> Bears_Home WhiteSox_Away WhiteSox_Home
#> 3.531e-01 -5.198e-01 NA
#> Cubs_Away Cubs_Home date_year
#> NA NA -2.638e-01
#> date_LaborDay date_NewYearsDay date_ChristmasDay
#> 5.166e-01 -1.275e+01 -1.308e+01
#> date_dow_Mon date_dow_Tue date_dow_Wed
#> 1.232e+01 1.345e+01 1.348e+01
#> date_dow_Thu date_dow_Fri date_dow_Sat
#> 1.325e+01 1.281e+01 9.855e-01
#> date_month_Feb date_month_Mar date_month_Apr
#> 4.218e-02 3.897e-01 5.472e-01
#> date_month_May date_month_Jun date_month_Jul
#> 2.842e-01 9.032e-01 3.897e-01
#> date_month_Aug date_month_Sep date_month_Oct
#> 4.855e-01 1.588e-01 6.197e-01
#> date_month_Nov date_month_Dec PC1
#> -4.350e-01 -8.359e-01 2.979e-02
#> PC2 PC3
#> 1.225e-01 -1.722e-01
#>
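Because fit_best() returns an ordinary trained workflow, it can be used directly with predict(). The self-contained sketch below repeats the example pipeline, also shows rank_results() with select_best = TRUE (one configuration per workflow), and then predicts on the last seven rows of the training data. Those rows are used purely for illustration; in practice new_data would be held-out data. Expect the same rank-deficiency warnings seen above.

```r
library(workflowsets)
library(tune)
library(modeldata)
library(rsample)

data(Chicago)
Chicago <- Chicago[1:1195, ]
time_val_split <- sliding_period(Chicago, date, "month", lookback = 38, assess_stop = 1)

# Re-run the workflow set, saving workflows so fit_best() can refit them
res <- chi_features_set %>%
  option_add(control = control_grid(save_workflow = TRUE)) %>%
  workflow_map(resamples = time_val_split, grid = 21, seed = 1)

# Best configuration within each workflow, ranked by RMSE
best_per_wflow <- rank_results(res, rank_metric = "rmse", select_best = TRUE)
best_per_wflow

# Refit the overall winner on the training set
best_fit <- fit_best(res, metric = "rmse")

# A trained workflow applies its recipe and model in one predict() call
preds <- predict(best_fit, new_data = tail(Chicago, 7))
preds
```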