阅读 60

【Lesson1】R 机器学习流程及案例实现

R 机器学习流程及案例实现

一直在学习机器学习的项目;学的断断续续。近期需要完成一些数据建模与分析,将机器学习重新整理了一遍。这篇文章主要是介绍R数据科学中,构建机器学习模型的流程。为了更适合无基础的人快速了解整个流程框架,本文省去机器学习模型的原理及公式部分,如果需要了解,请戳 Here

在看完本文以后,让你们能够对机器学习模型有一个基本认识,然后根据现有数据去构建一个机器学习模型及其需要的步骤与预期结果,最后可以对自己的进行操作练习与实现。

机器学习-流程

根据Max Kuhn 的Caret文章,进行总结,一般的机器学习流程主要分为以下过程。

image.png

将Data分成Train与Test两部分。主要花费的精力是在Train数据集上,因为需要找到一个合适的模型来拟合Train数据,对模型参数进行不断调整,达到该数据的最优。同时还需要考虑resampling,至于为什么要resample,其实就是:针对本数据模型的 可以达到0.99,但是只适用于本数据,不能外推,所以the goal is not to “predict” the data you have in hand, but to develop a model that will predict new datasets.
有时候,变量较多,或者变量会存在相关系,那么就会涉及到变量的处理,Pre-processing(这也是一个相当麻烦的过程)。

1.数据拆分Train与Test数据集
2.Train数据集模型选择与调参
3.模型预测Test数据集

在上述模型调整好以后,嗯,那我们可以对Test数据进行预测了。看下模型预测效果。这里预测的效果优越是需要根据预测变量类型来选择不同的评估指标,主要分为分类与回归两种。然后绘制相应的RMSE曲线或者ROC曲线,来展示模型的预测性能。

当然了,在医学上机器学习应用远不止于此,还需探究变量间的关联性,称之为explanation ML,在后面篇幅会介绍。。

案例操作

下面以caret举例,Caret包的优点:主要使用train函数,集中多个模型。其中函数中定义了模型与调节参数,所以只要替换模型与参数,即可调用不同模型。因此省去了因运行不同模型而学习不同的packages。另外对于预测变量不管是分类变量还是连续性变量,Caret都可以构建。
本次操作利用pdp包里面的pima数据集进行演示。该数据收集了
392例女性糖尿病患者的临床指标,包括年龄,血糖,胰岛素及血压等指标。主要是通过临床指标预测患者是否患糖尿病。

1. 数据拆分

将pima数据进行预处理,丢弃NA,glucose转成分类变量(glucose > 149=="High")。然后利用createDataPartition()将数据分成train(80%)与test (20%)两个部分。

library(tidyverse)
library(caret)
library(pdp)
### get data
data(pima)
df=pima %>% na.omit() %>% as.tbl() %>% 
  mutate(glucose=as.factor(ifelse(glucose>143,"High","Low")))
### splitdata
set.seed(13)
samp = createDataPartition(df$diabetes, p = 0.8, list = FALSE)
train = df[samp,]
test = df[-samp,]

2. 模型构建

这里使用train()函数,因变量为diabetes,自变量默认选择全部,需要提前使用trainControl()设置resampling方法,里面涉及"boot", "cv", "LOOCV", "LGOCV"等一系列方法,这里我们设置为5-fold cross validation--method = "cv", number = 5
因为diabetes是二分类变量,我们采用gbm算法,然后用AUC来评估训练模型的优越性。

myControl = trainControl(method = "cv", 
                         classProbs=T,
                         number = 5,
                         summaryFunction=prSummary,
                         verboseIter = FALSE)
set.seed(12)
model_gbm = train(diabetes ~ ., 
                  data = train,
                  method = "gbm",
                  trControl = myControl,
                  verbose = F,
                  #tuneGrid = gbm.grid,
                  metric = "ROC")

需要提示的是,这里为了减少运行时间,并没有进行tuning 参数调节。gbm模型主要涉及三个参数,可以把参数放入gird,然后一个一个测试,得出每个参数对应调节下的AUC值,根据最大的AUC,选择对应的模型参数。当然如果不设置grid,train会自动选择最适参数。

gbm.grid <- expand.grid(interaction.depth = c(1,2,8),
                         n.trees = c(50, 100, 150, 200, 250, 300),
                         shrinkage = 0.1,
                         n.minobsinnode = 20)
 head(gbm.grid)
 
 model_gbm = train(diabetes ~ ., 
                  data = train,
                  method = "gbm",
                  trControl = myControl,
                  verbose = F,
                  tuneGrid = gbm.grid,
                  metric = "ROC")

接下来,我们看下model_gbm,这里面储存了我们所要的信息。gbm最合适参数


image.png

3. 模型预测

### Predict
pred = predict(model_gbm,newdata=test)
confusionMatrix(pred,test$diabetes)
Confusion Matrix and Statistics

          Reference
Prediction neg pos
       neg  47   9
       pos   5  17
                                          
               Accuracy : 0.8205          
                 95% CI : (0.7172, 0.8983)
    No Information Rate : 0.6667          
    P-Value [Acc > NIR] : 0.001942        
                                          
                  Kappa : 0.58            
                                          
 Mcnemar's Test P-Value : 0.422678        
                                          
            Sensitivity : 0.9038          
            Specificity : 0.6538          
         Pos Pred Value : 0.8393          
         Neg Pred Value : 0.7727          
             Prevalence : 0.6667          
         Detection Rate : 0.6026          
   Detection Prevalence : 0.7179          
      Balanced Accuracy : 0.7788          
                                          
       'Positive' Class : neg      

4. 变量重要性与解释

这里显示, "insulin" "glucose" 与 "mass" 对模型结果影响较大。具体怎么样的影响需要借助于边际效应的关系。pdp-案例:Explaining Black-Box Machine Learning Models - Code Part 1: tabular data + caret + iml

 varImp(model_gbm)
 plot(varImp(model_gbm))
image.png

4. 多个模型比较

有时候需要多个模型放在一起比较。

set.seed(12)
model_gbm = train(diabetes ~ ., 
                  data = train,
                  method = "gbm",
                  trControl = myControl,
                  verbose = F,
                  #tuneGrid = gbm.grid,
                  metric = "ROC")
model_svm = train(diabetes ~ ., 
                 data=train,
                 method = "svmRadial",
                 trControl = myControl,
                 tuneLength = 8,
                 metric = "ROC")
                                   
model_rda = train(diabetes ~ ., 
                 data=train,
                 method = "rda", 
                 trControl = myControl,
                 tuneLength = 4,
                 metric = "ROC")
                                    
# compare all
all=resamples(list(GBM = model_gbm,SVM=model_svm,RDA = model_rda))
summary(all)
Call:
summary.resamples(object = all)

Models: GBM, SVM, RDA 
Number of resamples: 5 

AUC 
         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
GBM 0.8499955 0.8508692 0.8611407 0.8696634 0.8868533 0.8994585    0
SVM 0.8300370 0.8355535 0.8563194 0.8584288 0.8608459 0.9093879    0
RDA 0.8252053 0.8387715 0.8963407 0.8772405 0.9124427 0.9134421    0

F 
         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
GBM 0.7804878 0.8235294 0.8297872 0.8193452 0.8314607 0.8314607    0
SVM 0.8043478 0.8089888 0.8181818 0.8208631 0.8222222 0.8505747    0
RDA 0.7380952 0.8048780 0.8181818 0.8073135 0.8275862 0.8478261    0

Precision 
         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
GBM 0.7500000 0.7872340 0.7872340 0.7876843 0.8000000 0.8139535    0
SVM 0.7400000 0.7659574 0.7708333 0.7763243 0.7826087 0.8222222    0
RDA 0.7380952 0.7800000 0.7826087 0.7851408 0.8000000 0.8250000    0

Recall 
         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
GBM 0.7619048 0.8333333 0.8809524 0.8571429 0.8809524 0.9285714    0
SVM 0.8571429 0.8571429 0.8809524 0.8714286 0.8809524 0.8809524    0
RDA 0.7380952 0.7857143 0.8571429 0.8333333 0.8571429 0.9285714    0

模型AUC

可以看出AUC最大的为gbm模型0.8739。

# ROC
# Build custom AUC function to extract AUC
# from the caret model object
library(pROC) 
test_roc = function(model, data) {
  roc(data$diabetes,
      predict(model, data, type = "prob")[, "pos"])
  
}

# Examine results for test set
model_list = list(GBM = model_gbm,SVM=model_svm,RDA = model_rda)

model_list_roc = model_list %>%
  map(test_roc, data = test)

model_list_roc %>%
  map(auc)

# plot
df_roc=c()
for (i in 1:length(model_list)) {
  a=test_roc(model_list[[i]],test)
  b=tibble(tpr=a$sensitivities,
           fpr=1-a$specificities,
           model=names(model_list)[i])
  
  df_roc=rbind(df_roc,b)
}

ggplot(data=df_roc,aes(x = fpr, y = tpr, group = model)) +
  geom_line(aes(color = model), size = 1) +
  geom_abline(intercept = 0, slope = 1, 
              color = "gray", size = 1)+
  labs(title = ("ROC Curves for all models"),
       x="False Positive Rate (1 - Specificity)",
       y="True Positive Rate (Sensivity or Recall)")

image.png

结语

这是Caret的使用,后续会介绍如何使用Tidymodel,将更简化操作,输入输出步骤。
未完待续。

Caret 参考

  1. Caret resampling介绍
  2. Caret基础介绍-Rebecca
  3. A Brief Introduction to caret 变量为连续性
  4. Caret Tune 参数 循环设置
  5. Kaggle Caret 实战
  6. Data Science and Predictive Analytics
  7. Evaluating Model Performance by Building Cross-Validation from Scratch【为什么要resampling 】

next

Using XGBoost with Tidymodels 结合Caret
Caret 案例Machine Learning for Insurance Claims
Caret 预测Amesing huose-多个caret模型
Predict the Residential Sale Price of Properties in Ames
Multivariate Adaptive Regression Splines
Ames housing prediction
Tidymodels: tidy machine learning in R

pdp

pdp-案例:Explaining Black-Box Machine Learning Models - Code Part 1: tabular data + caret + iml
Chapter 5: Model-Agnostic Methods
Shining a light on the “Black Box” of machine learning
Gradient Boosting Machines
Partial dependence plots for tidymodels-based xgboost
【VIP】--Variable importance plots: an introduction to vip
【pdp】: An R Package for
VIP: Classification of Student Success with Caret

Handling Class Imbalance data

主要两种,1.resample方法增加精度。2.采用PROC评估。

  1. 【Weighting and sampling】-Handling Class Imbalance with R and Caret - An Introduction
  2. 【PROC】-Handling Class Imbalance with R and Caret - Caveats when using the AUC

Tidymodel with R

https://www.tidymodels.org/learn/
https://www.tmwr.org/
https://algotech.netlify.app/blog/tidymodels/

作者:jamesjin63

原文链接:https://www.jianshu.com/p/99de269db4a4

文章分类
后端
版权声明:本站是系统测试站点,无实际运营。本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 XXXXXXo@163.com 举报,一经查实,本站将立刻删除。
相关推荐