用sklearn的DecisionTreeClassifer训练模型,然后用roc_auc_score计算模型的auc。代码如下

clf = DecisionTreeClassifier(criterion='gini', max_depth=6, min_samples_split=10, min_samples_leaf=2)
clf.fit(X_train, y_train)
y_pred = clf.predict_proba(X_test)
roc_auc = roc_auc_score(y_test, y_pred)

报错信息如下

/Users/wgg/anaconda/lib/python2.7/site-packages/sklearn/metrics/ranking.pyc in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
297 check_consistent_length(y_true, y_score)
298 y_true = column_or_1d(y_true)
--> 299 y_score = column_or_1d(y_score)
300 assert_all_finite(y_true)
301 assert_all_finite(y_score) /Users/wgg/anaconda/lib/python2.7/site-packages/sklearn/utils/validation.pyc in column_or_1d(y, warn)
560 return np.ravel(y)
561
--> 562 raise ValueError("bad input shape {0}".format(shape))
563
564 ValueError: bad input shape (900, 2)

目测是你的y_pred出了问题,你的y_pred是(900, 2)的array,也就是有两列。

因为predict_proba返回的是两列。predict_proba的用法参考这里

简而言之,你上面的代码改成这样就可以了。

y_pred = clf.predict_proba(X_test)[:, 1]
roc_auc = roc_auc_score(y_test, y_pred)

原文:http://sofasofa.io/forum_main_post.php?postid=1001678

sklearn里计算roc_auc_score,报错ValueError: bad input shape的更多相关文章

  1. 标记编码报错ValueError: bad input shape ()

    <Python机器学习经典实例>2.9小节中,想自己动手实践汽车特征评估质量,所以需要对数据进行预处理,其中代码有把字符串标记编码为对应的数字,如下代码 input_data = ['vh ...

  2. keras 报错 ValueError: Tensor conversion requested dtype int32 for Tensor with dtype float32: 'Tensor("embedding_1/random_uniform:0", shape=(5001, 128), dtype=float32)'

    在服务器上训练并保存模型,复制到本地之后load_model()报错: ValueError: Tensor conversion requested dtype int32 for Tensor w ...

  3. matplotlib.pyplot import报错: ValueError: _getfullpathname: embedded null character in path

    Environment: Windows 10, Anaconda 3.6 matplotlib 2.0 import matplotlib.pyplot 报错: ValueError: _getfu ...

  4. 安装 r 里的 igraph 报错

    转载来源:http://genek.tv/article/40 1186 0 0 安装 r 里的 igraph 报错: foreign-graphml.c: In function ‘igraph_w ...

  5. dbfread报错ValueError错误解决方法

    问题 我在用dbfread处理.dbf数据的时候出现了报错 ValueError("could not convert string to float: b'.'",) 然后查找. ...

  6. moviepy音视频剪辑VideoClip类fl_image方法image_func报错ValueError: assignment destination is read-only解决办法

    ☞ ░ 前往老猿Python博文目录 ░ moviepy音视频剪辑模块的视频剪辑基类VideoClip的fl_image方法用于进行对剪辑帧数据进行变换. 调用语法:fl_image(self, im ...

  7. Linux部署Django:报错 nohup: ignoring input and appending output to ‘nohup.out’

    一.部署 Django 到远程 Linux 服务器 利用 xshell 通过 ssh 连接到 Linux服务器,常规的启动命令是 python3 manage.py runserver 但是,关闭 x ...

  8. tensorflow-TFRecord报错ValueError: Protocol message Feature has no "feature" field.

    编写代码用TFRecord数据结构存储数据集信息是报错:ValueError: Protocol message Feature has no "feature" field.或和 ...

  9. datetime.strptime格式转换报错ValueError

    今天遇到一个报错:ValueError: time data '2018-10-10(Wednesday) AM0:50' does not match format '%Y-%m-%d(%A) %p ...

随机推荐

  1. DB2新建编目及删除编目

    场景:在添加一个新数据库的连接时,需要先建立此数据库的编目信息 新建: 1.获取数据库IP.端口.数据库名称 2.打开DB2客户端的“DB2命令窗口” 3.按以下命令执行 db2 catalog tc ...

  2. Windows10 临时将线程绑定至指定CPU的方法

    本文首发:https://www.somata.work/2019/WindowsThreadBind.html 将线程绑定至指定CPU,这个应该时很多管理员需要了解认知的操作了吧,这样可以在一定程度 ...

  3. centos7 修改内核文件 网卡名称为标准名称eth0

    在开机安装系统之前按TAB键后输入标记信息后安装系统就可以变成标准网卡接口eth0 或eth1

  4. python - django (session)

    # """ # Session # 是存在服务端的键值对 # Session 必须依赖Cookie 存储Session: · 在服务器生成随机字符串 · 生成一个和上面随 ...

  5. vue router.beforeEach(),详解

    outer.beforeEach()一般用来做一些进入页面的限制. 比如没有登录, 就不能进入某些页面,只有登录了之后才有权限查看某些页面...说白了就是路由拦截.第一步 规定进入路由需不需要权限 @ ...

  6. nginx动静分离配置

    动静分离: 所谓动静分离指的是当访问静态资源时,路由到一台静态资源服务器,当访问是非静态资源时,路由到另外一台服务器 静态资源配置: 如配置如下location 表示url为  /static/*.x ...

  7. javascript权威指南第17章 错误异常处理

    function TestTryCatch(){ try { } catch (error) { //error 类型如下 Error EvalError RangeError ReferenceEr ...

  8. 010——C#选择文件路径

    (一)具体教程查看:011——C#创建ECXEL文件(附教程) (二)代码:foldPath 就是获取到的文件路径 private void button1_Click(object sender, ...

  9. bootstrap富文本编辑

    先把设定富文本框架 <div class="form-group"> <label class="col-sm-2 control-label" ...

  10. K-D Tree学习笔记

    用途 做各种二维三维四维偏序等等. 代替空间巨大的树套树. 数据较弱的时候水分. 思想 我们发现平衡树这种东西功能强大,然而只能做一维上的询问修改,显得美中不足. 于是我们尝试用平衡树的这种二叉树结构 ...