rm(list = ls())
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(tidymodels))
library(kknn)
tidymodels_prefer()
<- read_rds("./datasets/all_plays.rds")
all_plays
set.seed(20220520)
<- initial_split(all_plays, 0.75, strata = play_type)
split_pbp
<- training(split_pbp)
train_data <- testing(split_pbp) test_data
47 workflow实现多模型比较
注意
这部分内容主要是几个综合性的机器学习和预测建模R包的介绍,更多的信息,可参考机器学习合集
前面给大家介绍了使用tidymodels
搞定二分类资料的模型评价和比较。
简介的语法、统一的格式、优雅的操作,让人欲罢不能!
但是太费事儿了,同样的流程来了4遍,那要是选择10个模型,就得来10遍!无聊,非常的无聊。
所以个大家介绍简便方法,不用重复写代码,一次搞定多个模型!使用tidymodels
中的workflow
即可。
这个工作流
的概念我们也介绍过:tidymodels工作流:workflow
47.1 加载数据和R包
首先还是加载数据和R包,和前面的一模一样的操作,数据也没变。
47.2 数据预处理
<- recipe(play_type ~ ., data = train_data) %>%
pbp_rec step_rm(half_seconds_remaining,yards_gained, game_id) %>%
step_string2factor(posteam, defteam) %>%
step_corr(all_numeric(), threshold = 0.7) %>%
step_center(all_numeric()) %>%
step_zv(all_predictors())
47.3 选择模型
直接选择4个模型,你想选几个都是可以的。
<- logistic_reg(mode = "classification",engine = "glm")
lm_mod <- nearest_neighbor(mode = "classification", engine = "kknn")
knn_mod <- rand_forest(mode = "classification", engine = "ranger")
rf_mod <- decision_tree(mode = "classification",engine = "rpart") tree_mod
47.4 选择重抽样方法
set.seed(20220520)
<- vfold_cv(train_data, v = 10)
folds
folds## # 10-fold cross-validation
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [62082/6899]> Fold01
## 2 <split [62083/6898]> Fold02
## 3 <split [62083/6898]> Fold03
## 4 <split [62083/6898]> Fold04
## 5 <split [62083/6898]> Fold05
## 6 <split [62083/6898]> Fold06
## 7 <split [62083/6898]> Fold07
## 8 <split [62083/6898]> Fold08
## 9 <split [62083/6898]> Fold09
## 10 <split [62083/6898]> Fold10
47.5 构建workflow
这一步就是不用重复写代码的关键,把所有模型和数据预处理步骤自动连接起来。
library(workflowsets)
<- workflow_set(list(rec = pbp_rec),
four_mods list(lm = lm_mod,
knn = knn_mod,
rf = rf_mod,
tree = tree_mod
),cross = T
)
four_mods## # A workflow set/tibble: 4 × 4
## wflow_id info option result
## <chr> <list> <list> <list>
## 1 rec_lm <tibble [1 × 4]> <opts[0]> <list [0]>
## 2 rec_knn <tibble [1 × 4]> <opts[0]> <list [0]>
## 3 rec_rf <tibble [1 × 4]> <opts[0]> <list [0]>
## 4 rec_tree <tibble [1 × 4]> <opts[0]> <list [0]>
47.6 运行模型
首先是一些运行过程中的参数设置:
<- control_resamples(save_pred = T, verbose = T) keep_pred
然后就是运行4个模型(目前一直是在训练集中),我们给它加速一下:
library(doParallel)
<- makePSOCKcluster(12) # 加速,用12个线程
cl registerDoParallel(cl)
<- four_mods %>%
four_fits workflow_map("fit_resamples",
seed = 0520,
verbose = T,
resamples = folds,
control = keep_pred
)
1 of 4 resampling: rec_lm
i 1 of 4 resampling: rec_lm (26.6s)
✔ 2 of 4 resampling: rec_knn
i 2 of 4 resampling: rec_knn (3m 44.1s)
✔ 3 of 4 resampling: rec_rf
i 3 of 4 resampling: rec_rf (1m 10.9s)
✔ 4 of 4 resampling: rec_tree
i 4 of 4 resampling: rec_tree (4.5s)
✔
#saveRDS(four_fits,file="datasets/four_fits.rds")
stopCluster(cl)
four_fits
需要很长时间!大家笔记本如果内存不够可能会失败哦~
47.7 查看结果
查看模型在训练集中的表现:
collect_metrics(four_fits)
## # A tibble: 8 × 9
## wflow_id .config preproc model .metric .estimator mean n std_err
## <chr> <chr> <chr> <chr> <chr> <chr> <dbl> <int> <dbl>
## 1 rec_lm Preprocessor1_M… recipe logi… accura… binary 0.724 10 1.91e-3
## 2 rec_lm Preprocessor1_M… recipe logi… roc_auc binary 0.781 10 1.88e-3
## 3 rec_knn Preprocessor1_M… recipe near… accura… binary 0.671 10 7.31e-4
## 4 rec_knn Preprocessor1_M… recipe near… roc_auc binary 0.716 10 1.28e-3
## 5 rec_rf Preprocessor1_M… recipe rand… accura… binary 0.732 10 1.48e-3
## 6 rec_rf Preprocessor1_M… recipe rand… roc_auc binary 0.799 10 1.90e-3
## 7 rec_tree Preprocessor1_M… recipe deci… accura… binary 0.720 10 1.97e-3
## 8 rec_tree Preprocessor1_M… recipe deci… roc_auc binary 0.704 10 2.01e-3
查看每一个预测结果,这个就不运行了,毕竟好几万行,太多了。。。
collect_predictions(four_fits)
47.8 可视化结果
直接可视化4个模型的结果,感觉比ROC曲线更好看,还给出了可信区间。
这个图可以自己用ggplot2
语法修改。
%>% autoplot(metric = "roc_auc")+theme_bw() four_fits
47.9 选择最好的模型用于测试集
选择表现最好的应用于测试集:
<- last_fit(rf_mod,pbp_rec,split_pbp)
rand_res saveRDS(rand_res,file = "./datasets/rand_res.rds")
查看在测试集的模型表现:
<- readRDS(file = "./datasets/rand_res.rds") rand_res
collect_metrics(rand_res) # test 中的模型表现
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.731 Preprocessor1_Model1
## 2 roc_auc binary 0.799 Preprocessor1_Model1
使用其他指标查看模型表现:
<- metric_set(accuracy, mcc, f_meas, j_index)
metricsets
collect_predictions(rand_res) %>%
metricsets(truth = play_type, estimate = .pred_class)
## # A tibble: 4 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.731
## 2 mcc binary 0.440
## 3 f_meas binary 0.774
## 4 j_index binary 0.438
可视化结果,喜闻乐见的混淆矩阵:
collect_predictions(rand_res) %>%
conf_mat(play_type,.pred_class) %>%
autoplot()
喜闻乐见的ROC曲线:
collect_predictions(rand_res) %>%
roc_curve(play_type,.pred_pass) %>%
autoplot()
还有非常多曲线和评价指标可选,大家可以看我之前的介绍推文~