sklearn的class_weight设置为'balanced'的计算方法
分类的时候,当不同类别的样本量差异很大时,很容易影响分类结果,因此要么每个类别的数据量大致相同,要么就要进行校正。
sklearn的做法可以是加权,加权就要涉及到class_weight和sample_weight,当不设置class_weight参数时,默认值是所有类别的权值为1。
在python中:
# class_weight的传参
class_weight : {dict, 'balanced'}, optional
Set the parameter C of class i to class_weight[i]*C for
SVC. If not given, all classes are supposed to have
weight one. The "balanced" mode uses the values of y to automatically
adjust weights inversely proportional to class frequencies as
``n_samples / (n_classes * np.bincount(y))``
# 当使用字典时,其形式为:Weights associated with classes in the form ``{class_label: weight}``,比如:{0: 1, 1: 1}表示类0的权值为1,类1的权值为1. # sample_weight的传参
sample_weight : array-like, shape (n_samples,)
Per-sample weights. Rescale C per sample. Higher weights
force the classifier to put more emphasis on these points.
1. 在:from sklearn.utils.class_weight import compute_class_weight 里面可以看到计算的源代码。
2. 除了通过字典形式传入权重参数,还可以设置的是:class_weight = 'balanced',例如使用SVM分类:
clf = SVC(kernel = 'linear', class_weight='balanced', decision_function_shape='ovr')
clf.fit(X_train, y_train)
3. 那么'balanced'的计算方法是什么呢?看例子:
import numpy as np y = [0,0,0,0,0,0,0,0,1,1,1,1,1,1,2,2] #标签值,一共16个样本 a = np.bincount(y) # array([8, 6, 2], dtype=int64) 计算每个类别的样本数量
aa = 1/a #倒数 array([0.125 , 0.16666667, 0.5 ])
print(aa) from sklearn.utils.class_weight import compute_class_weight
class_weight = 'balanced'
classes = np.array([0, 1, 2]) #标签类别
weight = compute_class_weight(class_weight, classes, y)
print(weight) # [0.66666667 0.88888889 2.66666667] print(0.66666667*8) #5.33333336
print(0.88888889*6) #5.33333334
print(2.66666667*2) #5.33333334
# 这三个值非常接近
# 'balanced'计算出来的结果很均衡,使得惩罚项和样本量对应
可以看出计算出来的值,乘以样本量之后,三个类别的数字很接近,我想的是:个人觉得惩罚项就用样本量的倒数未尝不可,因为乘以样本量都是1,相当于'balanced'这里是多乘以了一个常数
4. 真正的魔法到了:还记得上面所给出的python中,当class_weight为'balanced'时的计算公式吗?
# weight_ = n_samples / (n_classes * np.bincount(y))``
# 这里
# n_samples为16
# n_classes为3
# np.bincount(y)实际上就是每个类别的样本数量
于是:
print(16/(3*8)) #输出 0.6666666666666666
print(16/(3*6)) #输出 0.8888888888888888
print(16/(3*2)) #输出 2.6666666666666665
是不是跟计算出来的权值一样?这就是class_weight设置为'balanced'时的计算方法了。
5. 当然,需要说明一下传入字典时的情形
import numpy as np y = [0,0,0,0,0,0,0,0,1,1,1,1,1,1,2,2] #标签值,一共16个样本 from sklearn.utils.class_weight import compute_class_weight
class_weight = {0:1,1:3,2:5} # {class_label_1:weight_1, class_label_2:weight_2, class_label_3:weight_3}
classes = np.array([0, 1, 2]) #标签类别
weight = compute_class_weight(class_weight, classes, y)
print(weight) # 输出:[1. 3. 5.],也就是字典中设置的值
参考:
https://blog.csdn.net/go_og/article/details/81281387
https://www.zhihu.com/question/265420166/answer/293896934
sklearn的class_weight设置为'balanced'的计算方法的更多相关文章
- sklearn逻辑回归(Logistic Regression)类库总结
class sklearn.linear_model.LogisticRegression(penalty=’l2’, dual=False, tol=0.0001, C=1.0, fit_inter ...
- sklearn逻辑回归(Logistic Regression,LR)调参指南
python信用评分卡建模(附代码,博主录制) https://study.163.com/course/introduction.htm?courseId=1005214003&utm_ca ...
- 逻辑回归原理_挑战者飞船事故和乳腺癌案例_Python和R_信用评分卡(AAA推荐)
sklearn实战-乳腺癌细胞数据挖掘(博客主亲自录制视频教程) https://study.163.com/course/introduction.htm?courseId=1005269003&a ...
- XGBoost、LightGBM、Catboost总结
sklearn集成方法 bagging 常见变体(按照样本采样方式的不同划分) Pasting:直接从样本集里随机抽取的到训练样本子集 Bagging:自助采样(有放回的抽样)得到训练子集 Rando ...
- XGBoost、LightGBM的详细对比介绍
sklearn集成方法 集成方法的目的是结合一些基于某些算法训练得到的基学习器来改进其泛化能力和鲁棒性(相对单个的基学习器而言)主流的两种做法分别是: bagging 基本思想 独立的训练一些基学习器 ...
- CART决策树和随机森林
CART 分裂规则 将现有节点的数据分裂成两个子集,计算每个子集的gini index 子集的Gini index: \(gini_{child}=\sum_{i=1}^K p_{ti} \sum_{ ...
- Python解决数据样本类别分布不均衡问题
所谓不平衡指的是:不同类别的样本数量差异非常大. 数据规模上可以分为大数据分布不均衡和小数据分布不均衡.大数据分布不均衡:例如拥有1000万条记录的数据集中,其中占比50万条的少数分类样本便于属于这种 ...
- 【机器学习基础】逻辑回归——LogisticRegression
LR算法作为一种比较经典的分类算法,在实际应用和面试中经常受到青睐,虽然在理论方面不是特别复杂,但LR所牵涉的知识点还是比较多的,同时与概率生成模型.神经网络都有着一定的联系,本节就针对这一算法及其所 ...
- (原创)(四)机器学习笔记之Scikit Learn的Logistic回归初探
目录 5.3 使用LogisticRegressionCV进行正则化的 Logistic Regression 参数调优 一.Scikit Learn中有关logistics回归函数的介绍 1. 交叉 ...
随机推荐
- 禁用 Ubuntu 18.04 Files 的 Type Ahead search 功能
. . . . . Ubuntu 的文件浏览器(Files)提供了一个搜索的功能,叫做“Type Ahead search”.即我们在文件浏览器中输入某个文件的名字时,Files 并不是将焦点定位在某 ...
- 恋恋山城 Jean de Florette (1986) 男人的野心 / 弗洛莱特的若望 / 让·德·弗罗莱特 / 水源 下一部 甘泉,玛侬
<让·德·弗洛莱特>电影剧本 文/[法]马赛尔·巴涅尔译/苏原 编者按:<让·德·弗洛莱特>和<甘泉,玛侬>是根据法国著名作家马赛尔·巴涅尔的同名小说改编的电影.马 ...
- centos7安装rsync及两台机器进行文件同步
安装及配置 yum -y install rsync #启动rsync服务 systemctl start rsyncd.service systemctl enable rsyncd.service ...
- oracle python操作 增删改查
oracle删除 删除表内容 truncate table new_userinfo; 删除表 drop table new_userinfo; 1.首先,python链接oracle数据库需要配置好 ...
- 1.2 lvm镜像卷
镜像能够分配物理分区的多个副本,从而提高数据的可用性.当某个磁盘发生故障并且其物理分区变为不可用时,您仍然可以访问可用磁盘上的镜像数据.LVM 在逻辑卷内执行镜像. 系统版本: # cat /etc ...
- 【笔试题】Overloading in Java
笔试题 Overloading in Java Question 1 以下程序的输出结果为( ). public class Test { public int getData() { return ...
- mysql 允许在唯一索引的字段中出现多个null值
线上问题:org.springframework.dao.DuplicateKeyException: PreparedStatementCallback; SQL [update fl_table ...
- 【转帖】HBase简介(梳理知识)
HBase简介(梳理知识) https://www.cnblogs.com/muhongxin/p/9471445.html 一. 简介 hbase是bigtable的开源山寨版本.是建立的hdf ...
- C基础 stack 设计
前言 - stack 设计思路 先说说设计 stack 结构的原由. 以前我们再释放查找树的时候多数用递归的后续遍历去释放. 其内部隐含了运行时的函数栈, 有些语言中存在爆栈风险. 所以想运用显示栈来 ...
- LeetCode 5216. 统计元音字母序列的数目(Java)DP
5216. 统计元音字母序列的数目 给你一个整数 n,请你帮忙统计一下我们可以按下述规则形成多少个长度为 n 的字符串: 字符串中的每个字符都应当是小写元音字母('a', 'e', 'i', 'o', ...