大家好,在100天搞定机器学习|Day63 彻底掌握 LightGBM一文中,我介绍了LightGBM 的模型原理和一个极简实例。最近我发现Huggingface与Streamlit好像更配,所以就开发了一个简易的 LightGBM 可视化调参的小工具,旨在让大家可以更深入地理解 LightGBM

网址:

https://huggingface.co/spaces/beihai/LightGBM-parameter-tuning

我只随便放了几个参数,调整这些参数可以实时看到模型评估指标的变化。代码我也放到文章中了,大家有好的优化思路可以留言。下面就详细介绍一下实现过程:

LightGBM 的参数

在完成模型构建之后,必须对模型的效果进行评估,根据评估结果来继续调整模型的参数、特征或者算法,以达到满意的结果。

LightGBM,有核心参数,学习控制参数,IO参数,目标参数,度量参数,网络参数,GPU参数,模型参数,这里我常修改的便是核心参数,学习控制参数,度量参数等。

Control Parameters 含义 用法
max_depth 树的最大深度 当模型过拟合时,可以考虑首先降低 max_depth
min_data_in_leaf 叶子可能具有的最小记录数 默认20,过拟合时用
feature_fraction 例如 为0.8时,意味着在每次迭代中随机选择80%的参数来建树 boosting 为 random forest 时用
bagging_fraction 每次迭代时用的数据比例 用于加快训练速度和减小过拟合
early_stopping_round 如果一次验证数据的一个度量在最近的early_stopping_round 回合中没有提高,模型将停止训练 加速分析,减少过多迭代
lambda 指定正则化 0~1
min_gain_to_split 描述分裂的最小 gain 控制树的有用的分裂
max_cat_group 在 group 边界上找到分割点 当类别数量很多时,找分割点很容易过拟合时

CoreParameters 含义 用法
Task 数据的用途 选择 train 或者 predict
application 模型的用途 选择 regression: 回归时,binary: 二分类时,multiclass: 多分类时
boosting 要用的算法 gbdt, rf: random forest, dart: Dropouts meet Multiple Additive Regression Trees, goss: Gradient-based One-Side Sampling
num_boost_round 迭代次数 通常 100+
learning_rate 如果一次验证数据的一个度量在最近的 early_stopping_round 回合中没有提高,模型将停止训练 常用 0.1, 0.001, 0.003…
num_leaves 默认 31
device cpu 或者 gpu
metric mae: mean absolute error , mse: mean squared error , binary_logloss: loss for binary classification , multi_logloss: loss for multi classification

Faster Speed better accuracy over-fitting
将 max_bin 设置小一些 用较大的 max_bin max_bin 小一些
num_leaves 大一些 num_leaves 小一些
用 feature_fraction 来做 sub-sampling 用 feature_fraction
用 bagging_fraction 和 bagging_freq 设定 bagging_fraction 和 bagging_freq
training data 多一些 training data 多一些
用 save_binary 来加速数据加载 直接用 categorical feature 用 gmin_data_in_leaf 和 min_sum_hessian_in_leaf
用 parallel learning 用 dart 用 lambda_l1, lambda_l2 ,min_gain_to_split 做正则化
num_iterations 大一些,learning_rate 小一些 用 max_depth 控制树的深度

模型评估指标

以分类模型为例,常见的模型评估指标有一下几种:

混淆矩阵

混淆矩阵是能够比较全面的反映模型的性能,从混淆矩阵能够衍生出很多的指标来。

ROC曲线

ROC曲线,全称The Receiver Operating Characteristic Curve,译为受试者操作特性曲线。这是一条以不同阈值 下的假正率FPR为横坐标,不同阈值下的召回率Recall为纵坐标的曲线。让我们衡量模型在尽量捕捉少数类的时候,误伤多数类的情况如何变化的。

AUC

AUC(Area Under the ROC Curve)指标是在二分类问题中,模型评估阶段常被用作最重要的评估指标来衡量模型的稳定性。ROC曲线下的面积称为AUC面积,AUC面积越大说明ROC曲线越靠近左上角,模型越优;

Streamlit 实现

Streamlit我就不再多做介绍了,老读者应该都特别熟悉了。就再列一下之前开发的几个小东西:

核心代码如下,完整代码我放到Github,欢迎大家给个Star

https://github.com/tjxj/visual-parameter-tuning-with-streamlit

  1. from definitions import *
  2. st.set_option('deprecation.showPyplotGlobalUse', False)
  3. st.sidebar.subheader("请选择模型参数:sunglasses:")
  4. # 加载数据
  5. breast_cancer = load_breast_cancer()
  6. data = breast_cancer.data
  7. target = breast_cancer.target
  8. # 划分训练数据和测试数据
  9. X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)
  10. # 转换为Dataset数据格式
  11. lgb_train = lgb.Dataset(X_train, y_train)
  12. lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
  13. # 模型训练
  14. params = {'num_leaves': num_leaves, 'max_depth': max_depth,
  15. 'min_data_in_leaf': min_data_in_leaf,
  16. 'feature_fraction': feature_fraction,
  17. 'min_data_per_group': min_data_per_group,
  18. 'max_cat_threshold': max_cat_threshold,
  19. 'learning_rate':learning_rate,'num_leaves':num_leaves,
  20. 'max_bin':max_bin,'num_iterations':num_iterations
  21. }
  22. gbm = lgb.train(params, lgb_train, num_boost_round=2000, valid_sets=lgb_eval, early_stopping_rounds=500)
  23. lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
  24. probs = gbm.predict(X_test, num_iteration=gbm.best_iteration) # 输出的是概率结果
  25. fpr, tpr, thresholds = roc_curve(y_test, probs)
  26. st.write('------------------------------------')
  27. st.write('Confusion Matrix:')
  28. st.write(confusion_matrix(y_test, np.where(probs > 0.5, 1, 0)))
  29. st.write('------------------------------------')
  30. st.write('Classification Report:')
  31. report = classification_report(y_test, np.where(probs > 0.5, 1, 0), output_dict=True)
  32. report_matrix = pd.DataFrame(report).transpose()
  33. st.dataframe(report_matrix)
  34. st.write('------------------------------------')
  35. st.write('ROC:')
  36. plot_roc(fpr, tpr)

上传Huggingface

Huggingface 前一篇文章(腾讯的这个算法,我搬到了网上,随便玩!)我已经介绍过了,这里就顺便再讲一下步骤吧。

step1:注册Huggingface账号

step2:创建Space,SDK记得选择Streamlit

step3:克隆新建的space代码,然后将改好的代码push上去

  1. git lfs install
  2. git add .
  3. git commit -m "commit from $beihai"
  4. git push

push的时候会让输入用户名(就是你的注册邮箱)和密码,解决git总输入用户名和密码的问题:git config --global credential.helper store

push完成就大功告成了,回到你的space页对应项目,就可以看到效果了。

机器学习系列:LightGBM 可视化调参的更多相关文章

  1. 工程能力UP | LightGBM的调参干货教程与并行优化

    这是个人在竞赛中对LGB模型进行调参的详细过程记录,主要包含下面六个步骤: 大学习率,确定估计器参数n_estimators/num_iterations/num_round/num_boost_ro ...

  2. 贪玩ML系列之CIFAR-10调参

    调参方法:网格调参 tf.layers.conv2d()中的padding参数 取值“same”,表示当filter移出边界时,给空位补0继续计算.该方法能够更多的保留图像边缘信息.当图片较小(如CI ...

  3. LightGBM调参笔记

    本文链接:https://blog.csdn.net/u012735708/article/details/837497031. 概述在竞赛题中,我们知道XGBoost算法非常热门,是很多的比赛的大杀 ...

  4. 调参、最优化、ml算法(未完成)

    最优化方法 调参方法 ml算法 梯度下降gd grid search lr 梯度上升 随机梯度下降 pca 随机梯度下降sgd  贝叶斯调参 lda 牛顿算法   knn 拟牛顿算法   kmeans ...

  5. 【转载】 自动化机器学习(AutoML)之自动贝叶斯调参

    原文地址: https://blog.csdn.net/linxid/article/details/81189154 ---------------------------------------- ...

  6. LightGBM 调参方法(具体操作)

     sklearn实战-乳腺癌细胞数据挖掘(博主亲自录制视频) https://study.163.com/course/introduction.htm?courseId=1005269003& ...

  7. 【Python机器学习实战】决策树与集成学习(七)——集成学习(5)XGBoost实例及调参

    上一节对XGBoost算法的原理和过程进行了描述,XGBoost在算法优化方面主要在原损失函数中加入了正则项,同时将损失函数的二阶泰勒展开近似展开代替残差(事实上在GBDT中叶子结点的最优值求解也是使 ...

  8. 自动调参库hyperopt+lightgbm 调参demo

    在此之前,调参要么网格调参,要么随机调参,要么肉眼调参.虽然调参到一定程度,进步有限,但仍然很耗精力. 自动调参库hyperopt可用tpe算法自动调参,实测强于随机调参. hyperopt 需要自己 ...

  9. python 机器学习中模型评估和调参

    在做数据处理时,需要用到不同的手法,如特征标准化,主成分分析,等等会重复用到某些参数,sklearn中提供了管道,可以一次性的解决该问题 先展示先通常的做法 import pandas as pd f ...

随机推荐

  1. UOJ191口胡

    UOJ191,你失败的原因只有一个:你没有强制在线. 首先这个序列末位加加减减很烦,于是换成操作树,这样就变成查询链的信息了. 注意到一个向量 \((x_1,y_1)\) 比 \((x_2,y_2)\ ...

  2. 解析ansible远程管理客户端【win终端为例】

    一.前提: 1.1.windows机器开启winrm服务,并设置成允许远程连接状态 具体操作命令如下 set-executionpolicy remotesigned winrm quickconfi ...

  3. 报错 ——Error evaluating expression 'id != null id > 0'.

    Exception in thread "main" org.apache.ibatis.exceptions.PersistenceException: ### Error qu ...

  4. 活用Windows Server 2008系统的几种安全功能

    与传统操作系统相比,Win2008系统的安全防范功能更加强大,安全防护能力自然也是高人一等,我们只要在平时善于使用该系统新增的各项安全防范功能,完全可以实现更高级别的安全保护目的.现在,本文就为大家贡 ...

  5. kubernetes证书过期处理

    rancher中文文档:http://docs.rancher.cn/ k8s中文文档:https://kubernetes.io/zh/docs 一.修改kubeadm 源码 增加证书到100年 $ ...

  6. java线程池之newFixedThreadPool定长线程池

    newFixedThreadPool 创建一个定长线程池,可控制线程最大并发数,超出的线程会在队列中等待. 线程池的作用: 线程池作用就是限制系统中执行线程的数量.     根 据系统的环境情况,可以 ...

  7. Azure DevOps (六) 通过FTP上传流水线制品到Linux服务器

    上一篇我们实现了把流水线的制品保存到azure的流水线制品仓库里去,本篇我们会开始研究azure的发布流水线. 本篇要研究的是把流水线仓库的制品发布到任意一台公网的linux服务器上去,所以我们先研究 ...

  8. 什么是Java序列化,如何实现Java序列化?或者请解释Serializable接口的作用?

    象序列化的目标是将对象保存到磁盘中,或允许在网络中直接传输对象,对象序列化机制允许把内存中的Java对象转换成平台无关的二进制流,从而允许把这种二进制流持久保存在磁盘上,通过网络将这种二进制流传输到另 ...

  9. 什么是 rabbitmq?

    采用 AMQP 高级消息队列协议的一种消息队列技术,最大的特点就是消费并不需要确保提供方存在,实现了服务之间的高度解耦

  10. JavaScript对不同数据结构的常见循环

    var obj1 = { title : 'tom and jetty', author : 'pecool' } function Book(){} Book.prototype.price = 2 ...