学习笔记26— roc曲线(python)
一、概念:
准确率(Accuracy), 精确率(Precision), 召回率(Recall)和F1-Measure
机器学习(ML), 自然语言处理(NLP), 信息检索(IR)等领域, 评估(Evaluation)是一个必要的工作, 而其评价指标往往有如下几点: 准确率(Accuracy), 精确率(Precision), 召回率(Recall) 和 F1-Measure.(注:相对来说,IR 的 ground truth 很多时候是一个 Ordered List, 而不是一个 Bool 类型的 Unordered Collection,在都找到的情况下,排在第三名还是第四名损失并不是很大,而排在第一名和第一百名,虽然都是“找到了”,但是意义是不一样的,因此更多可能适用于 MAP 之类评估指标.)
本文将简单介绍其中几个概念. 中文中这几个评价指标翻译各有不同, 所以一般情况下推荐使用英文.
现在我先假定一个具体场景作为例子.
假如某个班级有男生 80 人, 女生20人, 共计 100 人. 目标是找出所有女生. 现在某人挑选出 50 个人, 其中 20 人是女生, 另外还错误的把 30 个男生也当作女生挑选出来了. 作为评估者的你需要来评估(evaluation)下他的工作
首先我们可以计算准确率(accuracy), 其定义是: 对于给定的测试数据集,分类器正确分类的样本数与总样本数之比. 也就是损失函数是0-1损失时测试数据集上的准确率[1].
这样说听起来有点抽象,简单说就是,前面的场景中,实际情况是那个班级有男的和女的两类,某人(也就是定义中所说的分类器)他又把班级中的人分为男女两类. accuracy 需要得到的是此君分正确的人占总人数的比例. 很容易,我们可以得到:他把其中70(20女+50男)人判定正确了, 而总人数是100人,所以它的 accuracy 就是70 %(70 / 100).
由准确率,我们的确可以在一些场合,从某种意义上得到一个分类器是否有效,但它并不总是能有效的评价一个分类器的工作. 举个例子, google 抓取了 argcv 100个页面,而它索引中共有10,000,000个页面, 随机抽一个页面,分类下, 这是不是 argcv 的页面呢?如果以 accuracy 来判断我的工作,那我会把所有的页面都判断为"不是 argcv 的页面", 因为我这样效率非常高(return false, 一句话), 而 accuracy 已经到了99.999%(9,999,900/10,000,000), 完爆其它很多分类器辛辛苦苦算的值, 而我这个算法显然不是需求期待的, 那怎么解决呢?这就是 precision, recall 和 f1-measure 出场的时间了.
再说 precision, recall 和 f1-measure 之前, 我们需要先需要定义 TP, FN, FP, TN 四种分类情况.
按照前面例子, 我们需要从一个班级中的人中寻找所有女生, 如果把这个任务当成一个分类器的话, 那么女生就是我们需要的, 而男生不是, 所以我们称女生为"正类", 而男生为"负类".
相关(Relevant), 正类 | 无关(NonRelevant), 负类 | |
---|---|---|
被检索到(Retrieved) | true positives (TP 正类判定为正类, 例子中就是正确的判定"这位是女生") | false positives (FP 负类判定为正类,"存伪", 例子中就是分明是男生却判断为女生, 当下伪娘横行, 这个错常有人犯) |
未被检索到(Not Retrieved) | false negatives (FN 正类判定为负类,"去真", 例子中就是, 分明是女生, 这哥们却判断为男生--梁山伯同学犯的错就是这个) | true negatives (TN 负类判定为负类, 也就是一个男生被判断为男生, 像我这样的纯爷们一准儿就会在此处) |
或者
可以很容易看出, 所谓 TRUE/FALSE 表示从结果是否分对了, Positive/Negative 表示我们认为的是"是"还是"不是".
通过这张表, 我们可以很容易得到这几个值:
- TP=20
- FP=30
- FN=0
- TN=50
精确率(precision)的公式是P=TPTP+FPP = \frac{TP}{TP+FP}P=TP+FPTP, 它计算的是所有"正确被检索的结果(TP)"占所有"实际被检索到的(TP+FP)"的比例.
在例子中就是希望知道此君得到的所有人中, 正确的人(也就是女生)占有的比例. 所以其 precision 也就是40%(20女生/(20女生+30误判为女生的男生)).
召回率(recall)的公式是R=TPTP+FNR = \frac{TP}{TP+FN}R=TP+FNTP, 它计算的是所有"正确被检索的结果(TP)"占所有"应该检索到的结果(TP+FN)"的比例.
在例子中就是希望知道此君得到的女生占本班中所有女生的比例, 所以其 recall 也就是100%(20女生/(20女生+ 0 误判为男生的女生))
F1值就是精确值和召回率的调和均值, 也就是
2F1=1P+1R\frac{2}{F_1} = \frac{1}{P} + \frac{1}{R} F12=P1+R1
调整下也就是
F1=2PRP+R=2TP2TP+FP+FNF_1 = \frac{2PR}{P+R} = \frac{2TP}{2TP + FP + FN} F1=P+R2PR=2TP+FP+FN2TP
例子中 F1-measure 也就是约为 57.143%(2∗0.4∗10.4+1\frac{2 * 0.4 * 1}{0.4 + 1}0.4+12∗0.4∗1).
需要说明的是, 有人[3]列了这样个公式, 对非负实数β\betaβ
Fβ=(β2+1)∗PRβ2∗P+RF_\beta = (\beta^2 + 1) * \frac{PR}{\beta^2*P+R}Fβ=(β2+1)∗β2∗P+RPR
将 F-measure 一般化.
F1-measure 认为精确率和召回率的权重是一样的, 但有些场景下, 我们可能认为精确率会更加重要, 调整参数 β\betaβ , 使用 Fβ_\betaβ-measure 可以帮助我们更好的 evaluate 结果.
注意:召回率就是tpr(TP/(TP+FN)) (命中率),fpr(FP/(FP+TN)):特异性(specificity),TNR(TN/(FP+TN)):敏感性
参考链接:http://mlwiki.org/index.php/ROC_Analysis#ROC_Analysis
https://blog.argcv.com/articles/1036.c
http://www.cnblogs.com/pinard/p/5993450.html
1、简单(源码):
print(__doc__)
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from scipy import interp
# Import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
# Binarize the output
y = label_binarize(y, classes=[0, 1, 2])
n_classes = y.shape[1]
# Add noisy features to make the problem harder
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
# shuffle and split training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5,
random_state=0)
# Learn to predict each class against the other
classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True,
random_state=random_state))
y_score = classifier.fit(X_train, y_train).decision_function(X_test)
# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
#Plot of a ROC curve for a specific class
plt.figure()
lw = 2
plt.plot(fpr[2], tpr[2], color='darkorange',
lw=lw, label='ROC curve (area = %0.2f)' % roc_auc[2])
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
2、复杂(源码):
print(__doc__) import numpy as np
from scipy import interp
import matplotlib.pyplot as plt
from itertools import cycle from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import StratifiedKFold # #############################################################################
# Data IO and generation # Import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
X, y = X[y != 2], y[y != 2]
n_samples, n_features = X.shape # Add noisy features
random_state = np.random.RandomState(0)
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)] # #############################################################################
# Classification and ROC analysis # Run classifier with cross-validation and plot ROC curves
cv = StratifiedKFold(n_splits=6)
classifier = svm.SVC(kernel='linear', probability=True,
random_state=random_state) tprs = []
aucs = []
mean_fpr = np.linspace(0, 1, 100) i = 0
for train, test in cv.split(X, y):
probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
# Compute ROC curve and area the curve
fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
tprs.append(interp(mean_fpr, fpr, tpr))
tprs[-1][0] = 0.0
roc_auc = auc(fpr, tpr)
aucs.append(roc_auc)
plt.plot(fpr, tpr, lw=1, alpha=0.3,
label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc)) i += 1
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
label='Chance', alpha=.8) mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)
plt.plot(mean_fpr, mean_tpr, color='b',
label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
lw=2, alpha=.8) std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
label=r'$\pm$ 1 std. dev.') plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
Best threshold:
- we know the slope of the accuracy line: it's 1
- the best classifier for this slope is the 6th one
- threshold value θ
- so we take the score obtained on the 6th record
- and use it as the threshold value θ
- i.e. predict positive if θ⩾0.54
- if we check, we see that indeed we have accuracy = 0.7
学习笔记26— roc曲线(python)的更多相关文章
- python3.4学习笔记(二十六) Python 输出json到文件,让json.dumps输出中文 实例代码
python3.4学习笔记(二十六) Python 输出json到文件,让json.dumps输出中文 实例代码 python的json.dumps方法默认会输出成这种格式"\u535a\u ...
- python3.4学习笔记(二十五) Python 调用mysql redis实例代码
python3.4学习笔记(二十五) Python 调用mysql redis实例代码 #coding: utf-8 __author__ = 'zdz8207' #python2.7 import ...
- python3.4学习笔记(二十四) Python pycharm window安装redis MySQL-python相关方法
python3.4学习笔记(二十四) Python pycharm window安装redis MySQL-python相关方法window安装redis,下载Redis的压缩包https://git ...
- python3.4学习笔记(二十二) python 在字符串里面插入指定分割符,将list中的字符转为数字
python3.4学习笔记(二十二) python 在字符串里面插入指定分割符,将list中的字符转为数字在字符串里面插入指定分割符的方法,先把字符串变成list然后用join方法变成字符串str=' ...
- openresty 学习笔记番外篇:python的一些扩展库
openresty 学习笔记番外篇:python的一些扩展库 要写一个可以使用的python程序还需要比如日志输出,读取配置文件,作为守护进程运行等 读取配置文件 使用自带的ConfigParser模 ...
- openresty 学习笔记番外篇:python访问RabbitMQ消息队列
openresty 学习笔记番外篇:python访问RabbitMQ消息队列 python使用pika扩展库操作RabbitMQ的流程梳理. 客户端连接到消息队列服务器,打开一个channel. 客户 ...
- python学习笔记(三)---python关键字及其用法
转载出处:https://www.cnblogs.com/ECJTUACM-873284962/p/7576959.html 前言 最近在学习Java Sockst的时候遇到了一些麻烦事,我觉得我很有 ...
- [原创]java WEB学习笔记26:MVC案例完整实践(part 7)---修改的设计和实现
本博客为原创:综合 尚硅谷(http://www.atguigu.com)的系统教程(深表感谢)和 网络上的现有资源(博客,文档,图书等),资源的出处我会标明 本博客的目的:①总结自己的学习过程,相当 ...
- python学习笔记:安装boost python库以及使用boost.python库封装
学习是一个累积的过程.在这个过程中,我们不仅要学习新的知识,还需要将以前学到的知识进行回顾总结. 前面讲述了Python使用ctypes直接调用动态库和使用Python的C语言API封装C函数, C+ ...
随机推荐
- Django模糊查询
https://blog.csdn.net/liuweiyuxiang/article/details/71104613 def search(request): searchtype = reque ...
- P2044 [NOI2012]随机数生成器
洛咕原题 正常的矩乘题. 但是,计算过程中会爆long long. 所以,我们要用快速(龟速)乘来解决. 快速乘,也就是把快速幂稍作修改.乘法被分成若干个加法,以时间为代价解决精度问题. #inclu ...
- c#简单案例--单位转换器
经过几天学习,写出了一个简单的winform应用程序,贴出源码,以备不时之需. 软件启动后的界面如下图所示: 如图,该程序由6个label.8个comboBox.8个textBox和4个button组 ...
- iframe跨域问题:Uncaught DOMException: Blocked a frame with origin解决方法
在前后端分离的情况下,前台页面将后台页面加载在预留的iframe中:但是遇到了iframe和主窗口双滚动条的情况,由此引申出来了问题: 只保留单个滚动条,那么就要让iframe的高度自适应,而从主页面 ...
- gdb远程debug A syntax error in expression, near `variable)'.
今天调试有个linux环境的应用时,gdb提示A syntax error in expression, near `variable)'.,最后经查,gdb版本过低(比如7.2)或者源代码不匹配所致 ...
- oracle 12.2 linux/solaris正式发布
oracle 12.2 linux/solaris正式发布,可以从http://www.oracle.com/technetwork/database/enterprise-edition/downl ...
- shell 调试脚本设置
set -x 脚本部分内容 set +x
- ScheduledTheadPool线程池的使用
ScheduledTheadPool线程池的特点在于可以延迟执行任务,也可以周期性执行任务. 创建线程池 ScheduledExecutorService scheduled = Executors. ...
- centos设置中文输入法无效的解决办法
安装 im-chooser: yum install im-chooser 回到当前使用的普通用户,设置 ibus 输入法为默认输入系统: imsettings-switch ibus
- javbus爬虫-老司机你值得拥有
# 起因 有个朋友叫我帮忙写个爬虫,爬取javbus5上面所有的详情页链接,也就是所有的https://www.javbus5.com/SRS-055这种链接, 我一看,嘿呀,这是司机的活儿啊,我绝对 ...