library(xgboost)
rm(list = ls())
# load data
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
<- agaricus.train
train <- agaricus.test
test
str(train)
## List of 2
## $ data :Formal class 'dgCMatrix' [package "Matrix"] with 6 slots
## .. ..@ i : int [1:143286] 2 6 8 11 18 20 21 24 28 32 ...
## .. ..@ p : int [1:127] 0 369 372 3306 5845 6489 6513 8380 8384 10991 ...
## .. ..@ Dim : int [1:2] 6513 126
## .. ..@ Dimnames:List of 2
## .. .. ..$ : NULL
## .. .. ..$ : chr [1:126] "cap-shape=bell" "cap-shape=conical" "cap-shape=convex" "cap-shape=flat" ...
## .. ..@ x : num [1:143286] 1 1 1 1 1 1 1 1 1 1 ...
## .. ..@ factors : list()
## $ label: num [1:6513] 1 0 0 1 0 0 0 1 0 0 ...
str(test)
## List of 2
## $ data :Formal class 'dgCMatrix' [package "Matrix"] with 6 slots
## .. ..@ i : int [1:35442] 0 2 7 11 13 16 20 22 27 31 ...
## .. ..@ p : int [1:127] 0 83 84 806 1419 1603 1611 2064 2064 2701 ...
## .. ..@ Dim : int [1:2] 1611 126
## .. ..@ Dimnames:List of 2
## .. .. ..$ : NULL
## .. .. ..$ : chr [1:126] "cap-shape=bell" "cap-shape=conical" "cap-shape=convex" "cap-shape=flat" ...
## .. ..@ x : num [1:35442] 1 1 1 1 1 1 1 1 1 1 ...
## .. ..@ factors : list()
## $ label: num [1:1611] 0 1 0 0 0 0 1 0 1 0 ...
class(train$data)
## [1] "dgCMatrix"
## attr(,"package")
## [1] "Matrix"
22 XGBoost
xgboost
(Extreme Gradient Boosting),极限梯度提升,是基于梯度提升树(gradient boosting decision tree,GBDT)实现的集成(ensemble)算法,本质上还是一种提升(boosting)算法,但是把速度和效率提升到最强,所以加了Extreme
。
xgboost
的一些特性包括:
- 速度快效率高:默认会借助
OpenMP
进行并行计算 - 核心代码使用C++实现,速度快,易分享;
- 正则化:可以使用正则化技术避免过度拟合;
- 交叉验证:内部会进行交叉验证;
- 缺失值处理:可以处理缺失值,不需要提前插补;
- 适用于多种任务类型:支持回归分类排序等,还支持用户自定义的目标函数;
关于更多它的背景知识,大家可以参考集成算法类型一章。下面我们看下它在R语言中的使用。
22.1 准备数据
先用自带数据演示一下简单的使用方法。
首先是加载数据,这是一个二分类数据,其中label
是结果变量,数值型,使用0和1表示,其余是预测变量。
xgboost
对数据格式是有要求的,可以看到train
和test
都是列表,其中包含了预测变量data
和结果变量label
,其中label
是使用0和1表示的。data
部分是稀疏矩阵的形式(dgCMatrix
)。
# 查看数据维度,6513行,126列
dim(train$data)
## [1] 6513 126
table(train$label)
##
## 0 1
## 3373 3140
22.2 拟合模型
接下来就是使用训练数据拟合模型:
<- xgboost(data = train$data, label = train$label,
model max.depth = 2, # 树的最大深度
eta = 1, # 学习率
nrounds = 2,
nthread = 2, # 使用的CPU线程数
objective = "binary:logistic")
## [1] train-logloss:0.233376
## [2] train-logloss:0.136658
其中的参数nround
在这里表示最终模型中树的数量,objective
是目标函数。
除了使用列表传入数据,也支持R语言中的密集矩阵(matrix
),比如:
<- xgboost(data = as.matrix(train$data), label = train$label,
model max.depth = 2,
eta = 1,
nrounds = 2,
nthread = 2,
objective = "binary:logistic")
## [1] train-logloss:0.233376
## [2] train-logloss:0.136658
开发者最推荐的格式还是特别为xgboost
设计的xgb.DMatrix
格式。
# 建立xgb.DMatrix
<- xgb.DMatrix(data = train$data, label = train$label)
dtrain
<- xgboost(data = dtrain, # 这样就不用单独传入label了
xx max.depth = 2,
eta = 1,
nrounds = 2,
nthread = 2,
objective = "binary:logistic"
)## [1] train-logloss:0.233376
## [2] train-logloss:0.136658
这个格式也是xgboost
特别设计的,有助于更好更快的进行计算,还可以支持更多的信息传入。可以通过getinfo
获取其中的元素:
head(getinfo(dtrain, "label"))
## [1] 1 0 0 1 0 0
22.3 新数据预测
拟合模型后,就可以对新数据进行预测了:
# predict
<- predict(model, test$data)
pred head(pred)
## [1] 0.28583017 0.92392391 0.28583017 0.28583017 0.05169873 0.92392391
range(pred)
## [1] 0.01072847 0.92392391
我们的任务是一个二分类的,但是xgboost
的预测结果是概率,并不是直接的类别。我们需要自己转换一下,比如规定概率大于0.5就是类别1,小于等于0.5就是类别0(这个阈值其实也可以当做超参数调整的)。
<- ifelse(pred > 0.5,1,0)
pred_label table(pred_label)
## pred_label
## 0 1
## 826 785
# 混淆矩阵
table(test$label, pred_label)
## pred_label
## 0 1
## 0 813 22
## 1 13 763
全对,准确率100%。
除此之外还提供一个xgb.cv()
用于实现交叉验证的建模,使用方法与xgboost()
一致:
<- xgb.cv(data = train$data, label = train$label,
cv.res nrounds = 2,
objective = "binary:logistic",
nfold = 10 # 交叉验证的折数
)## [1] train-logloss:0.439673+0.000200 test-logloss:0.439961+0.000873
## [2] train-logloss:0.299550+0.000206 test-logloss:0.299903+0.001440
min(cv.res$evaluation_log)
## [1] 0.0001998477
22.4 控制输出日志
xgboost()
和xgb.train()
有参数verbose
可以控制输出日志的多少,默认是verbose = 1
,输出性能指标结果。xgboost()
是xgb.train()
的简单封装,xgb.train()
是训练xgboost模型的高级接口。
如果是0,则是没有任何输出:
<- xgboost(data = train$data, label = train$label, objective = "binary:logistic"
xx nrounds = 2
,verbose = 0
, )
如果是2,会输出性能指标结果和其他信息(这里没显示):
<- xgboost(data = train$data, label = train$label, objective = "binary:logistic"
xx nrounds = 2
,verbose = 2
,
)## [1] train-logloss:0.439409
## [2] train-logloss:0.299260
xgb.cv()
的verbose
参数只有TRUE
和FALSE
。
22.5 变量重要性
xgboost
中变量的重要性是这样计算的:
我们如何在xgboost中定义特性的重要性?在xgboost中,每次分割都试图找到最佳特征和分割点(splitting point)来优化目标。我们可以计算每个节点上的增益,它是所选特征的贡献。最后,我们对所有的树进行研究,总结每个特征的贡献,并将其视为重要性。如果特征的数量很大,我们也可以在绘制图之前对特征进行聚类。
查看变量重要性:
<- xgb.importance(model = model)
importance_matrix
importance_matrix## Feature Gain Cover Frequency
## <char> <num> <num> <num>
## 1: odor=none 0.67615470 0.4978746 0.4
## 2: stalk-root=club 0.17135376 0.1920543 0.2
## 3: stalk-root=rooted 0.12317236 0.1638750 0.2
## 4: spore-print-color=green 0.02931918 0.1461960 0.2
可视化变量重要性:
xgb.plot.importance(importance_matrix)
或者ggplot2
版本:
xgb.ggplot.importance(importance_matrix)
22.6 查看树的信息
可以把学习好的树打印出来,查看具体情况:
xgb.dump(model, with_stats = T)
## [1] "booster[0]"
## [2] "0:[f28<0.5] yes=1,no=2,missing=1,gain=4000.53101,cover=1628.25"
## [3] "1:[f55<0.5] yes=3,no=4,missing=3,gain=1158.21204,cover=924.5"
## [4] "3:leaf=1.71217716,cover=812"
## [5] "4:leaf=-1.70044053,cover=112.5"
## [6] "2:[f108<0.5] yes=5,no=6,missing=5,gain=198.173828,cover=703.75"
## [7] "5:leaf=-1.94070864,cover=690.5"
## [8] "6:leaf=1.85964918,cover=13.25"
## [9] "booster[1]"
## [10] "0:[f59<0.5] yes=1,no=2,missing=1,gain=832.544983,cover=788.852051"
## [11] "1:[f28<0.5] yes=3,no=4,missing=3,gain=569.725098,cover=768.389709"
## [12] "3:leaf=0.78471756,cover=458.936859"
## [13] "4:leaf=-0.968530357,cover=309.45282"
## [14] "2:leaf=-6.23624468,cover=20.462389"
还可以可视化树:
xgb.plot.tree(model = model)
这个图展示了2棵树的分支过程,因为我们设置了nround=2
,所以结果就是只有2棵树。
如果树的数量非常多的时候,这样每棵树看过来并不是很直观,通常xgboost
虽然不如随机森林需要的树多,但是几十棵总是要的,所以xgboost
提供了一种能把所有的树结合在一起展示的方法。
# 多棵树展示在一起
xgb.plot.multi.trees(model = model,fill=TRUE)
## Column 2 ['No'] of item 2 is missing in item 1. Use fill=TRUE to fill with NA (NULL for list columns), or use.names=FALSE to ignore column names. use.names='check' (default from v1.12.2) emits this message and proceeds as if use.names=FALSE for backwards compatibility. See news item 5 in v1.12.2 for options to control this message.
这幅图就是把上面那张图的信息整合到了一起,大家仔细对比下图中的数字就会发现信息是一样的哦。
除了以上方法可以检查树的信息外,还可以通过查看树的深度来检查树的结构。
<- xgboost(data = train$data, label = train$label, max.depth = 15,
bst eta = 1, nthread = 2, nround = 30, objective = "binary:logistic",
min_child_weight = 50,verbose = 0)
xgb.plot.deepness(model = bst)
这两幅图的横坐标都是树的深度,上面的图纵坐标是叶子的数量,展示了每层深度中的叶子数量。下面的图纵坐标是每片叶子的归一化之后的加权覆盖。
从图中可以看出树的深度在5之后,叶子的数量就很少了,这提示我们为了防止过拟合,可以把树的深度控制在5以内。
22.7 保存加载模型
保存加载训练好的模型:
# 保存
xgb.save(model, "xgboost.model")
# 加载
xgb.load("xgboost.model")
<- predict(model, test$data) aa
在R语言中除了直接使用xgboost
这个R包实现之外,还有许多综合性的R包都可以实现xgboost算法,并支持超参数调优等更多任务,比如caret
、tidymodels
、mlr3
。
22.8 参数解释
xgboost
本身是基于梯度提升树(GBDT)实现的集成算法,所以它的参数整体来说可以分为三个部分:集成算法本身,用于集成的弱评估器(决策树),以及应用中的其他过程。
下面是一些参数的详细介绍:
- nrounds:最大迭代次数(最终模型中树的数量)。
- early_stopping_rounds:一个正整数,表示在验证集中经过K次训练如果模型表现还是没有提高就停止训练。
- print_every_n:如果verbose>0,这个参数表示每多少次迭代(多少棵树)打印一次日志信息。
params
是xgb.train()
中最重要的参数了,params
接受一个列表,列表内包含超多参数,这些参数主要分为3大类,也是我们调参需要重点关注的参数:
- 通用参数
booster
:提升器类型,gbtree
(默认)或者gblinear
。多数情况下都是gbtree
的效果更好,但是如果你的预测变量和结果变量呈现明显的线性关系,可能gblinear
更好,但也不是绝对的,开发者建议都试一下。
- booster相关的参数 2.1 tree booster相关的参数
eta
:学习率η,每棵树在最终解中的贡献,决定迭代速度,默认为0.3,范围是[0,1]。一般会选择比较小的值,比如0.01,0.001等。gamma
:在进行分支时所需要的最小的目标函数减少量,如果大于这个值,则继续分支,非常重要的参数,控制树的规模,默认是0,范围是[0,inf]max_depth
:单个树的最大深度。非常重要,通常max_depth和gamma只调一个即可min_child_weight
:对树进行提升时使用的最小权重,默认为1。叶子节点的二阶导数之和subsample
:子样本数据占整个数据的比例,也就是每次重抽样的比例,默认值为1(100%)。colsample_bytree
:建立树时随机抽取的特征数量,用一个比率表示,默认值为1(使用100%的特征)。比较重要lambda
:L2正则化的比例,默认是1,也就是lasso。alpha
:L1正则化的比例,默认是0。- … 2.2 linear booster相关的参数
- …
- 任务相关的参数
objective
:指定任务类型和目标函数,支持自定义函数,默认的有以下类型,主要是回归、分类、生存、排序等:reg:squarederror
:均方根误差(默认值)。reg:squaredlogerror
:均方根对数误差。reg:logistic
:logistic函数。reg:pseudohubererror
:Pseudo Huber损失函数。binary:logistic
:二分类逻辑回归,输出概率值。binary:logitraw
:二分类逻辑回归,输出logistic转换之前的值。binary:hinge
:二分类hinge loss,输出0或者1。支持向量机的损失函数count:poisson
:计数数据的泊松回归survival:cox
:右删失生存数据的cox回归,返回风险比HR。survival:aft
:加速失效模型。- …
base_score
:叶子权重eval_metric
:验证集的性能指标,回归任务默认是rmse,二分类默认是错分率error。
以下是一些其他需要注意的点:
eta
/subsample
/nrounds
通常并不是提高模型表现的,主要是控制调参时间的。gamma
是通过降低训练集的表现来防止过拟合的(让训练集和测试集的模型表现更接近),如果太大,也会降低测试集的表现- 先通过网格搜索调整
nrounds
和eta
,然后使用gamma
或者
max_depth`看是否过拟合,再剪枝 xgboost
可以自动处理缺失值,不需要预处理,参数missing
scale_pos_weight
用于处理类不平衡的权重
22.9 超参数调优
下面我们使用印第安人糖尿病数据集,演示下如何对xgboost
进行超参数调优,由于目前xgboost
包里面并没有专门的调优参数,所以还是需要借助其他R包实现(或者自己写循环),我这里借助了caret
。
rm(list = ls())
library(MASS)
library(xgboost)
load(file = "datasets/pimadiabetes.rdata")
# 结果变量改成1和0表示
$diabetes <- ifelse(pimadiabetes$diabetes == "pos",1,0)
pimadiabetesstr(pimadiabetes)
## 'data.frame': 768 obs. of 9 variables:
## $ pregnant: num 6 1 8 1 0 5 3 10 2 8 ...
## $ glucose : num 148 85 183 89 137 116 78 115 197 125 ...
## $ pressure: num 72 66 64 66 40 ...
## $ triceps : num 35 29 22.9 23 35 ...
## $ insulin : num 202.2 64.6 217.1 94 168 ...
## $ mass : num 33.6 26.6 23.3 28.1 43.1 ...
## $ pedigree: num 0.627 0.351 0.672 0.167 2.288 ...
## $ age : num 50 31 32 21 33 30 26 29 53 54 ...
## $ diabetes: num 0 1 0 1 0 1 0 1 0 0 ...
# 按照7:3的比例划分训练集、测试集
set.seed(502)
<- sample(1:nrow(pimadiabetes), size = 0.7*nrow(pimadiabetes))
ind <- pimadiabetes[ind,]
pima.train <- pimadiabetes[-ind,]
pima.test
dim(pima.train)
## [1] 537 9
dim(pima.test)
## [1] 231 9
训练集有537行,9列,测试集有231行,9列,其中diabetes
列是结果变量,1表示有糖尿病,0表示没有糖尿病。
下面我们使用默认参数拟合模型,看看模型效果。顺便学习下如果准备这些参数。
注意,所有的预测变量都需要是数值型(这和我们前面介绍过的xgboost
输入数据的格式有关,矩阵需要都是数值型的),所以分类变量需要进行一些转换,比如哑变量、独热编码等。
# 选择参数的值
<- list(objective = "binary:logistic", # 二分类
param booster = "gbtree",
eval_metric = "error",
eta = 0.3,
max_depth = 3,
subsample = 1,
colsample_bytree = 1,
gamma = 0.5)
# 准备预测变量和结果变量
<- as.matrix(pima.train[, 1:8])
x <- pima.train$diabetes
y
# 放进专用的格式中
<- xgb.DMatrix(data = x, label = y)
train.mat
train.mat## xgb.DMatrix dim: 537 x 8 info: label colnames: yes
这样参数和数据就都准备好了,下面开始训练即可。xgboost()
是xgb.train()
的简单封装,xgb.train()
是训练xgboost模型的高级接口。xgboost模型的参数非常多,详情参考上面的介绍。
set.seed(1)
<- xgb.train(params = param,
xgb.fit data = train.mat,
nrounds = 100)
有了这个结果后你可以查看变量重要性,查看每棵树的信息,得出预测类别的概率,画出ROC曲线等,详情请参考前面的部分,这里就不再重复演示了。
下面就是对这些参数进行调整,我们就使用caret
进行演示。
caret
作为R语言中经典的机器学习综合性R包,使用起来非常简单,我们也写过非常详细的系列教程了,公众号后台回复caret即可获取caret
系列推文合集。
library(caret)
# 选择参数范围
<- expand.grid(nrounds = c(75, 100),
grid colsample_bytree = 1,
min_child_weight = 1,
eta = c(0.01, 0.1, 0.3),
gamma = c(0.5, 0.25),
subsample = 0.5,
max_depth = c(2, 3))
# 一些控制参数,重抽样方法选择5折交叉验证
<- trainControl(method = "cv",
cntrl number = 5,
verboseIter = F,
returnData = F,
returnResamp = "final")
# 开始调优
set.seed(1)
<- train(x = pima.train[, 1:8],
train.xgb y = pima.train$diabetes,
trControl = cntrl,
tuneGrid = grid,
method = "xgbTree")
## [19:59:21] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:21] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:21] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:22] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:23] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:24] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:25] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:26] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:27] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:28] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
## [19:59:29] WARNING: src/c_api/c_api.cc:935: `ntree_limit` is deprecated, use `iteration_range` instead.
#train.xgb # 太长不展示了
结果中给出了最优的超参数:nrounds = 75, max_depth = 2, eta = 0.1, gamma = 0.5, colsample_bytree = 1, min_child_weight = 1, subsample = 0.5。
这个结果可以探索可视化的地方非常多,比如:查看不同超参数对模型性能的影响:
plot(train.xgb)
也是支持ggplot2
的。
ggplot(train.xgb)
更多方法大家可以探索我们的caret
合集。
22.10 重新拟合模型
接下来就是使用最优的超参数重新拟合模型。
# 选择最优的参数值
<- list(objective = "binary:logistic",
param booster = "gbtree",
eval_metric = "error",
eta = 0.1,
max_depth = 2,
subsample = 0.5,
colsample_bytree = 1,
gamma = 0.5)
# 拟合模型
set.seed(1)
<- xgb.train(params = param,
xgb.fit data = train.mat,
nrounds = 75)
22.11 模型评价
画个ROC曲线,先计算一下训练集的预测概率,再画ROC曲线即可,没有任何难度:
<- predict(xgb.fit, newdata = train.mat)
pred_train head(pred_train)
## [1] 0.8459527 0.5674704 0.2664627 0.6891936 0.8845208 0.9655565
library(ROCR)
<- prediction(pred_train, pima.train$diabetes)
pred <- performance(pred, "tpr", "fpr")
perf <- round(performance(pred, "auc")@y.values[[1]],digits = 4)
auc
plot(perf,
main = paste("ROC curve (", "AUC = ",auc,")"),
col = 2,
lwd = 2)
abline(0,1, lty = 2, lwd = 2)
AUC值达到了0.9以上。
公众号后台回复ROC即可获取ROC曲线合集,回复最佳截点即可获取ROC曲线的最佳截点合集。
计算混淆矩阵等请参考前面的部分,无非就是把概率转换为硬类别而已。
22.12 测试集
首先需要把测试集的格式转换一下。
# 放在专用的格式中
<- xgb.DMatrix(data = as.matrix(pima.test[, 1:8]),
test.mat label = pima.test$diabetes)
# 预测测试集的概率
<- predict(xgb.fit, newdata = test.mat)
pred_test head(pred_test)
## [1] 0.1269337 0.9513396 0.9778093 0.4253268 0.4961042 0.6491649
# 绘制ROC曲线
library(ROCR)
<- prediction(pred_test, pima.test$diabetes)
pred <- performance(pred, "tpr", "fpr")
perf <- round(performance(pred, "auc")@y.values[[1]],digits = 4)
auc
plot(perf,
main = paste("ROC curve (", "AUC = ",auc,")"),
col = 2,
lwd = 2)
abline(0,1, lty = 2, lwd = 2)
easy!但是有点过拟合了,可以尝试下不同的超参数再试下。
有些指标是基于预测概率的,有些指标是基于预测列别的,xgboost只能给出预测概率,我们自己转换一下即可计算各种基于类别的指标了。
22.13 参考资料
- 帮助文档
- https://blog.csdn.net/weixin_43217641/article/details/126599474
- 精通机器学习基于R
- 官方文档:https://xgboost.readthedocs.io/en/latest/R-package/xgboostPresentation.html
- https://www.r-bloggers.com/2016/03/an-introduction-to-xgboost-r-package/