# -*- coding: utf-8 -*-
"""
Created on Mon Sep 10 11:21:27 2018 @author: zhen
"""
from sklearn.datasets import fetch_mldata
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_recall_curve
import matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.ensemble import RandomForestClassifier mnist = fetch_mldata('MNIST original', data_home='D:/AnalyseData学习资源库/人工智能开发/分类评估/资料/test_data_home') x, y = mnist['data'], mnist['target']
some_digit = x[36000] #获取第36000行数据 some_digit_image = some_digit.reshape(28, 28) plt.imshow(some_digit_image, cmap=matplotlib.cm.binary,
interpolation='nearest', vmin=0, vmax=1)
plt.axis('off')
plt.show() x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]
shuffle_index = np.random.permutation(60000) x_train, y_train = x_train[shuffle_index], y_train[shuffle_index] y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5) sgd_clf = SGDClassifier(loss='log', random_state=42, max_iter=1000, tol=1e-4)
sgd_clf.fit(x_train, y_train_5) result = sgd_clf.predict([some_digit]) print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='accuracy'))
print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='precision'))
print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='recall')) sgd_clf.fit(x_train, y_train_5) y_scores = sgd_clf.decision_function([some_digit]) threshold = 0
y_some_digit_pred = (y_scores > threshold) threshold = 200000
y_some_digit_pred = (y_scores > threshold) # cv 数据集划分的个数
y_scores = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3, method='decision_function') precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores) def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
plt.plot(thresholds, precisions[:-1], 'b--',label='Precision')
plt.plot(thresholds, recalls[:-1], 'r--', label='Recall')
plt.xlabel("Threshold")
plt.legend(loc='upper left')
plt.ylim([0, 1])
plt.show() def plot_roc_curve(fpr, tpr, label=None):
plt.plot(fpr, tpr, linewidth=2, label='roc')
plt.plot([0, 1], [0, 1], 'k--', label='mid')
plt.legend(loc='lower right')
# plt.axes([0, 1, 0, 1]) : 前两个参数表示坐标原点的位置,后两个表示x,y轴的长度
plt.xlabel('fpr')
plt.ylabel('tpr')
plt.show() plot_precision_recall_vs_threshold(precisions, recalls, thresholds) fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
plot_roc_curve(fpr, tpr) print(roc_auc_score(y_train_5, y_scores)) forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, x_train, y_train_5, cv=3, method='predict_proba')
y_scores_forest = y_probas_forest[:, 1]
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)
plt.plot(fpr, tpr, 'b:', label='SGD')
plt.plot(fpr_forest, tpr_forest, label='Random Forest')
plt.legend(loc='lower right')
plt.show() print(roc_auc_score(y_train_5, y_scores_forest))

          

总结:正向准确率和召回率在整体上成反比,可知在使用相同数据集,相同验证方式的情况下,随机森林要优于随机梯度下降!

评估指标【交叉验证&ROC曲线】的更多相关文章

  1. 【分类模型评判指标 二】ROC曲线与AUC面积

    转自:https://blog.csdn.net/Orange_Spotty_Cat/article/details/80499031 略有改动,仅供个人学习使用 简介 ROC曲线与AUC面积均是用来 ...

  2. 【机器学习】--模型评估指标之混淆矩阵,ROC曲线和AUC面积

    一.前述 怎么样对训练出来的模型进行评估是有一定指标的,本文就相关指标做一个总结. 二.具体 1.混淆矩阵 混淆矩阵如图:  第一个参数true,false是指预测的正确性.  第二个参数true,p ...

  3. 评价指标的局限性、ROC曲线、余弦距离、A/B测试、模型评估的方法、超参数调优、过拟合与欠拟合

    1.评价指标的局限性 问题1 准确性的局限性 准确率是分类问题中最简单也是最直观的评价指标,但存在明显的缺陷.比如,当负样本占99%时,分类器把所有样本都预测为负样本也可以获得99%的准确率.所以,当 ...

  4. 评估指标:ROC,AUC,Precision、Recall、F1-score

    一.ROC,AUC ROC(Receiver Operating Characteristic)曲线和AUC常被用来评价一个二值分类器(binary classifier)的优劣 . ROC曲线一般的 ...

  5. 召回率、AUC、ROC模型评估指标精要

    混淆矩阵 精准率/查准率,presicion 预测为正的样本中实际为正的概率 召回率/查全率,recall 实际为正的样本中被预测为正的概率 TPR F1分数,同时考虑查准率和查全率,二者达到平衡,= ...

  6. PR曲线,ROC曲线,AUC指标等,Accuracy vs Precision

    作为机器学习重要的评价指标,标题中的三个内容,在下面读书笔记里面都有讲: http://www.cnblogs.com/charlesblc/p/6188562.html 但是讲的不细,不太懂.今天又 ...

  7. 从TP、FP、TN、FN到ROC曲线、miss rate、行人检测评估

    从TP.FP.TN.FN到ROC曲线.miss rate.行人检测评估 想要在行人检测的evaluation阶段要计算miss rate,就要从True Positive Rate讲起:miss ra ...

  8. [机器学习] 性能评估指标(精确率、召回率、ROC、AUC)

    混淆矩阵 介绍这些概念之前先来介绍一个概念:混淆矩阵(confusion matrix).对于 k 元分类,其实它就是一个k x k的表格,用来记录分类器的预测结果.对于常见的二元分类,它的混淆矩阵是 ...

  9. 机器学习 - 案例 - 样本不均衡数据分析 - 信用卡诈骗 ( 标准化处理, 数据不均处理, 交叉验证, 评估, Recall值, 混淆矩阵, 阈值 )

    案例背景 银行评判用户的信用考量规避信用卡诈骗 ▒ 数据 数据共有 31 个特征, 为了安全起见数据已经向了模糊化处理无法读出真实信息目标 其中数据中的 class 特征标识为是否正常用户 (0 代表 ...

随机推荐

  1. C#语言介绍

    C#(读作“See Sharp”)是一种简单易用的新式编程语言,不仅面向对象,还类型安全. C# 源于 C 语言系列,C.C++.Java 和 JavaScript 程序员很快就可以上手使用. C# ...

  2. Eclipse4JavaEE安装SpringBoot

    第一步:下载SpringBoot SpringBoot官网下载链接 第二步:在Eclipse里进行安装 打开Eclipse,菜单栏Help ->Install New Software,进入下图 ...

  3. 设计模式 | 工厂方法模式(factory method)

    定义: 定义一个用于创建对象的接口,让子类决定实例化哪一个类.工厂方法使一个类的实例化延迟到其子类. 结构:(书中图,侵删) 一个工厂的抽象接口 若干个具体的工厂类 一个需要创建对象的抽象接口 若干个 ...

  4. Linux新手随手笔记1.8

    配置网卡服务 将网卡的配置文件,保存成模板,叫做会话. nmcli命令查看网卡信息.nmcli是一款基于命令行的网络配置工具 只有一个网卡信息,下面我们再添加一个. 公司:静态IP地址 家庭:DHCP ...

  5. PHP全栈学习笔记2

    php概述 什么是php,PHP语言的优势,PHP5的新特性,PHP的发展趋势,PHP的应用领域. PHP是超文本预处理器,是一种服务器端,跨平台,HTML嵌入式的脚本语言,具有c语言,Java语言, ...

  6. 3. VIM 系列 - 遇见你的第一个插件

    目录 1. 插件管理利器 vim-plug 1.1 安装插件管理器 1.2 配置插件管理器 1.3 安装插件 1.4 更新插件 1.5 回滚插件 1.6 卸载插件 1. 插件管理利器 vim-plug ...

  7. PHP内核之旅-4.可变长度的字符串

    PHP 内核之旅系列 PHP内核之旅-1.生命周期 PHP内核之旅-2.SAPI中的Cli PHP内核之旅-3.变量 PHP内核之旅-4.字符串 PHP内核之旅-5.强大的数组 PHP内核之旅-6.垃 ...

  8. springboot+aop切点记录请求和响应信息

    本篇主要分享的是springboot中结合aop方式来记录请求参数和响应的数据信息:这里主要讲解两种切入点方式,一种方法切入,一种注解切入:首先创建个springboot测试工程并通过maven添加如 ...

  9. javaScript设计模式之面向对象编程(object-oriented programming,OOP)(一)

    面试的时候,总会被问到,你对javascript面向对象的理解? 面向对象编程(object-oriented programming,OOP)是一种程序设计范型.它讲对象作为程序的设计基本单元,讲程 ...

  10. MySQL视图简介与操作

    1.准备工作 在MySQL数据库中创建两张表balance(余额表)和customer(客户表)并插入数据. create table customer( id int(10) primary key ...