22  XGBoost

xgboost(Extreme Gradient Boosting),极限梯度提升,是基于梯度提升树(gradient boosting decision tree,GBDT)实现的集成(ensemble)算法,本质上还是一种提升(boosting)算法,但是把速度和效率提升到最强,所以加了Extreme

xgboost的一些特性包括:

关于更多它的背景知识,大家可以参考集成算法类型一章。下面我们看下它在R语言中的使用。

22.1 准备数据

先用自带数据演示一下简单的使用方法。

首先是加载数据,这是一个二分类数据,其中label是结果变量,数值型,使用0和1表示,其余是预测变量。

library(xgboost)

rm(list = ls())
# load data
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
train <- agaricus.train
test <- agaricus.test

str(train)
## List of 2
##  $ data :Formal class 'dgCMatrix' [package "Matrix"] with 6 slots
##   .. ..@ i       : int [1:143286] 2 6 8 11 18 20 21 24 28 32 ...
##   .. ..@ p       : int [1:127] 0 369 372 3306 5845 6489 6513 8380 8384 10991 ...
##   .. ..@ Dim     : int [1:2] 6513 126
##   .. ..@ Dimnames:List of 2
##   .. .. ..$ : NULL
##   .. .. ..$ : chr [1:126] "cap-shape=bell" "cap-shape=conical" "cap-shape=convex" "cap-shape=flat" ...
##   .. ..@ x       : num [1:143286] 1 1 1 1 1 1 1 1 1 1 ...
##   .. ..@ factors : list()
##  $ label: num [1:6513] 1 0 0 1 0 0 0 1 0 0 ...
str(test)
## List of 2
##  $ data :Formal class 'dgCMatrix' [package "Matrix"] with 6 slots
##   .. ..@ i       : int [1:35442] 0 2 7 11 13 16 20 22 27 31 ...
##   .. ..@ p       : int [1:127] 0 83 84 806 1419 1603 1611 2064 2064 2701 ...
##   .. ..@ Dim     : int [1:2] 1611 126
##   .. ..@ Dimnames:List of 2
##   .. .. ..$ : NULL
##   .. .. ..$ : chr [1:126] "cap-shape=bell" "cap-shape=conical" "cap-shape=convex" "cap-shape=flat" ...
##   .. ..@ x       : num [1:35442] 1 1 1 1 1 1 1 1 1 1 ...
##   .. ..@ factors : list()
##  $ label: num [1:1611] 0 1 0 0 0 0 1 0 1 0 ...

class(train$data)
## [1] "dgCMatrix"
## attr(,"package")
## [1] "Matrix"

xgboost对数据格式是有要求的,可以看到traintest都是列表,其中包含了预测变量data和结果变量label,其中label是使用0和1表示的。data部分是稀疏矩阵的形式(dgCMatrix)。

# 查看数据维度,6513行,126列
dim(train$data)
## [1] 6513  126

table(train$label)
## 
##    0    1 
## 3373 3140

22.2 拟合模型

接下来就是使用训练数据拟合模型:

model <- xgboost(data = train$data, label = train$label, 
                 max.depth = 2, # 树的最大深度
                 eta = 1, # 学习率
                 nrounds = 2,
                 nthread = 2, # 使用的CPU线程数
                 objective = "binary:logistic")
## [1]  train-logloss:0.233376 
## [2]  train-logloss:0.136658

其中的参数nround在这里表示最终模型中树的数量,objective是目标函数。

除了使用列表传入数据,也支持R语言中的密集矩阵(matrix),比如:

model <- xgboost(data = as.matrix(train$data), label = train$label, 
                 max.depth = 2, 
                 eta = 1, 
                 nrounds = 2,
                 nthread = 2, 
                 objective = "binary:logistic")
## [1]  train-logloss:0.233376 
## [2]  train-logloss:0.136658

开发者最推荐的格式还是特别为xgboost设计的xgb.DMatrix格式。

# 建立xgb.DMatrix
dtrain <- xgb.DMatrix(data = train$data, label = train$label)

xx <- xgboost(data = dtrain, # 这样就不用单独传入label了
              max.depth = 2, 
              eta = 1, 
              nrounds = 2,
              nthread = 2, 
              objective = "binary:logistic"
              )
## [1]  train-logloss:0.233376 
## [2]  train-logloss:0.136658

这个格式也是xgboost特别设计的,有助于更好更快的进行计算,还可以支持更多的信息传入。可以通过getinfo获取其中的元素:

head(getinfo(dtrain, "label"))
## [1] 1 0 0 1 0 0

22.3 新数据预测

拟合模型后,就可以对新数据进行预测了:

# predict
pred <- predict(model, test$data)
head(pred)
## [1] 0.28583017 0.92392391 0.28583017 0.28583017 0.05169873 0.92392391

range(pred)
## [1] 0.01072847 0.92392391

我们的任务是一个二分类的,但是xgboost的预测结果是概率,并不是直接的类别。我们需要自己转换一下,比如规定概率大于0.5就是类别1,小于等于0.5就是类别0(这个阈值其实也可以当做超参数调整的)。

pred_label <- ifelse(pred > 0.5,1,0)
table(pred_label)
## pred_label
##   0   1 
## 826 785

# 混淆矩阵
table(test$label, pred_label)
##    pred_label
##       0   1
##   0 813  22
##   1  13 763

全对,准确率100%。

除此之外还提供一个xgb.cv()用于实现交叉验证的建模,使用方法与xgboost()一致:

cv.res <- xgb.cv(data = train$data, label = train$label, 
                 nrounds = 2,
                 objective = "binary:logistic",
                 nfold = 10 # 交叉验证的折数
                 )
## [1]  train-logloss:0.439673+0.000200 test-logloss:0.439961+0.000873 
## [2]  train-logloss:0.299550+0.000206 test-logloss:0.299903+0.001440
min(cv.res$evaluation_log)
## [1] 0.0001998477

22.4 控制输出日志

xgboost()xgb.train()有参数verbose可以控制输出日志的多少,默认是verbose = 1,输出性能指标结果。xgboost()xgb.train()的简单封装,xgb.train()是训练xgboost模型的高级接口。

如果是0,则是没有任何输出:

xx <- xgboost(data = train$data, label = train$label, objective = "binary:logistic"
        ,nrounds = 2
        ,verbose = 0
        )

如果是2,会输出性能指标结果和其他信息(这里没显示):

xx <- xgboost(data = train$data, label = train$label, objective = "binary:logistic"
        ,nrounds = 2
        ,verbose = 2
        )
## [1]  train-logloss:0.439409 
## [2]  train-logloss:0.299260

xgb.cv()verbose参数只有TRUEFALSE

22.5 变量重要性

xgboost中变量的重要性是这样计算的:

我们如何在xgboost中定义特性的重要性?在xgboost中,每次分割都试图找到最佳特征和分割点(splitting point)来优化目标。我们可以计算每个节点上的增益,它是所选特征的贡献。最后,我们对所有的树进行研究,总结每个特征的贡献,并将其视为重要性。如果特征的数量很大,我们也可以在绘制图之前对特征进行聚类。

查看变量重要性:

importance_matrix <- xgb.importance(model = model)
importance_matrix
##                    Feature       Gain     Cover Frequency
##                     <char>      <num>     <num>     <num>
## 1:               odor=none 0.67615470 0.4978746       0.4
## 2:         stalk-root=club 0.17135376 0.1920543       0.2
## 3:       stalk-root=rooted 0.12317236 0.1638750       0.2
## 4: spore-print-color=green 0.02931918 0.1461960       0.2

可视化变量重要性:

xgb.plot.importance(importance_matrix)

或者ggplot2版本:

xgb.ggplot.importance(importance_matrix)

22.6 查看树的信息

可以把学习好的树打印出来,查看具体情况:

xgb.dump(model, with_stats = T)
##  [1] "booster[0]"                                                       
##  [2] "0:[f28<0.5] yes=1,no=2,missing=1,gain=4000.53101,cover=1628.25"   
##  [3] "1:[f55<0.5] yes=3,no=4,missing=3,gain=1158.21204,cover=924.5"     
##  [4] "3:leaf=1.71217716,cover=812"                                      
##  [5] "4:leaf=-1.70044053,cover=112.5"                                   
##  [6] "2:[f108<0.5] yes=5,no=6,missing=5,gain=198.173828,cover=703.75"   
##  [7] "5:leaf=-1.94070864,cover=690.5"                                   
##  [8] "6:leaf=1.85964918,cover=13.25"                                    
##  [9] "booster[1]"                                                       
## [10] "0:[f59<0.5] yes=1,no=2,missing=1,gain=832.544983,cover=788.852051"
## [11] "1:[f28<0.5] yes=3,no=4,missing=3,gain=569.725098,cover=768.389709"
## [12] "3:leaf=0.78471756,cover=458.936859"                               
## [13] "4:leaf=-0.968530357,cover=309.45282"                              
## [14] "2:leaf=-6.23624468,cover=20.462389"

还可以可视化树:

xgb.plot.tree(model = model)

这个图展示了2棵树的分支过程,因为我们设置了nround=2,所以结果就是只有2棵树。

如果树的数量非常多的时候,这样每棵树看过来并不是很直观,通常xgboost虽然不如随机森林需要的树多,但是几十棵总是要的,所以xgboost提供了一种能把所有的树结合在一起展示的方法。

# 多棵树展示在一起
xgb.plot.multi.trees(model = model,fill=TRUE)
## Column 2 ['No'] of item 2 is missing in item 1. Use fill=TRUE to fill with NA (NULL for list columns), or use.names=FALSE to ignore column names. use.names='check' (default from v1.12.2) emits this message and proceeds as if use.names=FALSE for  backwards compatibility. See news item 5 in v1.12.2 for options to control this message.

这幅图就是把上面那张图的信息整合到了一起,大家仔细对比下图中的数字就会发现信息是一样的哦。

除了以上方法可以检查树的信息外,还可以通过查看树的深度来检查树的结构。

bst <- xgboost(data = train$data, label = train$label, max.depth = 15,
                 eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
                 min_child_weight = 50,verbose = 0)

xgb.plot.deepness(model = bst)

这两幅图的横坐标都是树的深度,上面的图纵坐标是叶子的数量,展示了每层深度中的叶子数量。下面的图纵坐标是每片叶子的归一化之后的加权覆盖。

从图中可以看出树的深度在5之后,叶子的数量就很少了,这提示我们为了防止过拟合,可以把树的深度控制在5以内。

22.7 保存加载模型

保存加载训练好的模型:

# 保存
xgb.save(model, "xgboost.model")

# 加载
xgb.load("xgboost.model")
aa <- predict(model, test$data)

在R语言中除了直接使用xgboost这个R包实现之外,还有许多综合性的R包都可以实现xgboost算法,并支持超参数调优等更多任务,比如carettidymodelsmlr3

22.8 参数解释

xgboost本身是基于梯度提升树(GBDT)实现的集成算法,所以它的参数整体来说可以分为三个部分:集成算法本身,用于集成的弱评估器(决策树),以及应用中的其他过程。

下面是一些参数的详细介绍:

  • nrounds:最大迭代次数(最终模型中树的数量)。
  • early_stopping_rounds:一个正整数,表示在验证集中经过K次训练如果模型表现还是没有提高就停止训练。
  • print_every_n:如果verbose>0,这个参数表示每多少次迭代(多少棵树)打印一次日志信息。

paramsxgb.train()中最重要的参数了,params接受一个列表,列表内包含超多参数,这些参数主要分为3大类,也是我们调参需要重点关注的参数:

  1. 通用参数
  • booster:提升器类型,gbtree(默认)或者gblinear。多数情况下都是gbtree的效果更好,但是如果你的预测变量和结果变量呈现明显的线性关系,可能gblinear更好,但也不是绝对的,开发者建议都试一下。
  1. booster相关的参数 2.1 tree booster相关的参数
    • eta:学习率η,每棵树在最终解中的贡献,决定迭代速度,默认为0.3,范围是[0,1]。一般会选择比较小的值,比如0.01,0.001等。
    • gamma:在进行分支时所需要的最小的目标函数减少量,如果大于这个值,则继续分支,非常重要的参数,控制树的规模,默认是0,范围是[0,inf]
    • max_depth:单个树的最大深度。非常重要,通常max_depth和gamma只调一个即可
    • min_child_weight:对树进行提升时使用的最小权重,默认为1。叶子节点的二阶导数之和
    • subsample:子样本数据占整个数据的比例,也就是每次重抽样的比例,默认值为1(100%)。
    • colsample_bytree:建立树时随机抽取的特征数量,用一个比率表示,默认值为1(使用100%的特征)。比较重要
    • lambda:L2正则化的比例,默认是1,也就是lasso。
    • alpha:L1正则化的比例,默认是0。
    • … 2.2 linear booster相关的参数
  2. 任务相关的参数
  • objective:指定任务类型和目标函数,支持自定义函数,默认的有以下类型,主要是回归、分类、生存、排序等:
    • reg:squarederror:均方根误差(默认值)。
    • reg:squaredlogerror:均方根对数误差。
    • reg:logistic:logistic函数。
    • reg:pseudohubererror:Pseudo Huber损失函数。
    • binary:logistic:二分类逻辑回归,输出概率值。
    • binary:logitraw:二分类逻辑回归,输出logistic转换之前的值。
    • binary:hinge:二分类hinge loss,输出0或者1。支持向量机的损失函数
    • count:poisson:计数数据的泊松回归
    • survival:cox:右删失生存数据的cox回归,返回风险比HR。
    • survival:aft:加速失效模型。
  • base_score:叶子权重
  • eval_metric:验证集的性能指标,回归任务默认是rmse,二分类默认是错分率error。

以下是一些其他需要注意的点:

  • eta/subsample/nrounds通常并不是提高模型表现的,主要是控制调参时间的。
  • gamma是通过降低训练集的表现来防止过拟合的(让训练集和测试集的模型表现更接近),如果太大,也会降低测试集的表现
  • 先通过网格搜索调整nroundseta,然后使用gammamax_depth`看是否过拟合,再剪枝
  • xgboost可以自动处理缺失值,不需要预处理,参数missing
  • scale_pos_weight用于处理类不平衡的权重

22.9 超参数调优

下面我们使用印第安人糖尿病数据集,演示下如何对xgboost进行超参数调优,由于目前xgboost包里面并没有专门的调优参数,所以还是需要借助其他R包实现(或者自己写循环),我这里借助了caret

rm(list = ls())
library(MASS)
library(xgboost)

load(file = "datasets/pimadiabetes.rdata")

# 结果变量改成1和0表示
pimadiabetes$diabetes <- ifelse(pimadiabetes$diabetes == "pos",1,0)
str(pimadiabetes)
## 'data.frame':    768 obs. of  9 variables:
##  $ pregnant: num  6 1 8 1 0 5 3 10 2 8 ...
##  $ glucose : num  148 85 183 89 137 116 78 115 197 125 ...
##  $ pressure: num  72 66 64 66 40 ...
##  $ triceps : num  35 29 22.9 23 35 ...
##  $ insulin : num  202.2 64.6 217.1 94 168 ...
##  $ mass    : num  33.6 26.6 23.3 28.1 43.1 ...
##  $ pedigree: num  0.627 0.351 0.672 0.167 2.288 ...
##  $ age     : num  50 31 32 21 33 30 26 29 53 54 ...
##  $ diabetes: num  0 1 0 1 0 1 0 1 0 0 ...

# 按照7:3的比例划分训练集、测试集
set.seed(502)
ind <- sample(1:nrow(pimadiabetes), size = 0.7*nrow(pimadiabetes))
pima.train <- pimadiabetes[ind,]
pima.test <- pimadiabetes[-ind,]

dim(pima.train)
## [1] 537   9
dim(pima.test)
## [1] 231   9

训练集有537行,9列,测试集有231行,9列,其中diabetes列是结果变量,1表示有糖尿病,0表示没有糖尿病。

下面我们使用默认参数拟合模型,看看模型效果。顺便学习下如果准备这些参数。

注意,所有的预测变量都需要是数值型(这和我们前面介绍过的xgboost输入数据的格式有关,矩阵需要都是数值型的),所以分类变量需要进行一些转换,比如哑变量、独热编码等。

# 选择参数的值
param <- list(objective = "binary:logistic", # 二分类
              booster = "gbtree",
              eval_metric = "error",
              eta = 0.3,
              max_depth = 3,
              subsample = 1,
              colsample_bytree = 1,
              gamma = 0.5)

# 准备预测变量和结果变量
x <- as.matrix(pima.train[, 1:8])
y <- pima.train$diabetes

# 放进专用的格式中
train.mat <- xgb.DMatrix(data = x, label = y)
train.mat
## xgb.DMatrix  dim: 537 x 8  info: label  colnames: yes

这样参数和数据就都准备好了,下面开始训练即可。xgboost()xgb.train()的简单封装,xgb.train()是训练xgboost模型的高级接口。xgboost模型的参数非常多,详情参考上面的介绍。

set.seed(1)
xgb.fit <- xgb.train(params = param, 
                     data = train.mat, 
                     nrounds = 100)

有了这个结果后你可以查看变量重要性,查看每棵树的信息,得出预测类别的概率,画出ROC曲线等,详情请参考前面的部分,这里就不再重复演示了。

下面就是对这些参数进行调整,我们就使用caret进行演示。

caret作为R语言中经典的机器学习综合性R包,使用起来非常简单,我们也写过非常详细的系列教程了,公众号后台回复caret即可获取caret系列推文合集。

library(caret)

# 选择参数范围
grid <- expand.grid(nrounds = c(75, 100),
                    colsample_bytree = 1,
                    min_child_weight = 1,
                    eta = c(0.01, 0.1, 0.3),
                    gamma = c(0.5, 0.25),
                    subsample = 0.5,
                    max_depth = c(2, 3))

# 一些控制参数,重抽样方法选择5折交叉验证
cntrl <- trainControl(method = "cv",
                      number = 5,
                      verboseIter = F,
                      returnData = F,
                      returnResamp = "final")

# 开始调优
set.seed(1)
train.xgb <- train(x = pima.train[, 1:8],
                   y = pima.train$diabetes,
                   trControl = cntrl,
                   tuneGrid = grid,
                   method = "xgbTree")
## [19:59:21] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:21] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:21] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:29] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
#train.xgb # 太长不展示了

结果中给出了最优的超参数:nrounds = 75, max_depth = 2, eta = 0.1, gamma = 0.5, colsample_bytree = 1, min_child_weight = 1, subsample = 0.5。

这个结果可以探索可视化的地方非常多,比如:查看不同超参数对模型性能的影响:

plot(train.xgb)

也是支持ggplot2的。

ggplot(train.xgb)

更多方法大家可以探索我们的caret合集。

22.10 重新拟合模型

接下来就是使用最优的超参数重新拟合模型。

# 选择最优的参数值
param <- list(objective = "binary:logistic",
              booster = "gbtree",
              eval_metric = "error",
              eta = 0.1,
              max_depth = 2,
              subsample = 0.5,
              colsample_bytree = 1,
              gamma = 0.5)


# 拟合模型
set.seed(1)
xgb.fit <- xgb.train(params = param, 
                     data = train.mat, 
                     nrounds = 75)

22.11 模型评价

画个ROC曲线,先计算一下训练集的预测概率,再画ROC曲线即可,没有任何难度:

pred_train <- predict(xgb.fit, newdata = train.mat)
head(pred_train)
## [1] 0.8459527 0.5674704 0.2664627 0.6891936 0.8845208 0.9655565

library(ROCR)
pred <- prediction(pred_train, pima.train$diabetes)
perf <- performance(pred, "tpr", "fpr")
auc <- round(performance(pred, "auc")@y.values[[1]],digits = 4)

plot(perf, 
     main = paste("ROC curve (", "AUC = ",auc,")"), 
     col = 2, 
     lwd = 2)
abline(0,1, lty = 2, lwd = 2)

AUC值达到了0.9以上。

公众号后台回复ROC即可获取ROC曲线合集,回复最佳截点即可获取ROC曲线的最佳截点合集。

计算混淆矩阵等请参考前面的部分,无非就是把概率转换为硬类别而已。

22.12 测试集

首先需要把测试集的格式转换一下。

# 放在专用的格式中
test.mat <- xgb.DMatrix(data = as.matrix(pima.test[, 1:8]), 
                        label = pima.test$diabetes)
# 预测测试集的概率
pred_test <- predict(xgb.fit, newdata = test.mat)
head(pred_test)
## [1] 0.1269337 0.9513396 0.9778093 0.4253268 0.4961042 0.6491649

# 绘制ROC曲线
library(ROCR)
pred <- prediction(pred_test, pima.test$diabetes)
perf <- performance(pred, "tpr", "fpr")
auc <- round(performance(pred, "auc")@y.values[[1]],digits = 4)

plot(perf, 
     main = paste("ROC curve (", "AUC = ",auc,")"), 
     col = 2, 
     lwd = 2)
abline(0,1, lty = 2, lwd = 2)

easy!但是有点过拟合了,可以尝试下不同的超参数再试下。

有些指标是基于预测概率的,有些指标是基于预测列别的,xgboost只能给出预测概率,我们自己转换一下即可计算各种基于类别的指标了。

22.13 参考资料

  1. 帮助文档
  2. https://blog.csdn.net/weixin_43217641/article/details/126599474
  3. 精通机器学习基于R
  4. 官方文档:https://xgboost.readthedocs.io/en/latest/R-package/xgboostPresentation.html
  5. https://www.r-bloggers.com/2016/03/an-introduction-to-xgboost-r-package/