之前一直不清楚Top1和Top5是什么,其实搞清楚了很简单,就是两种衡量指标,其中,Top1就是普通的Accuracy,Top5比Top1衡量标准更“严格”,

具体来讲,比如一共需要分10类,每次分类器的输出结果都是10个相加为1的概率值,Top1就是这十个值中最大的那个概率值对应的分类恰好正确的频率,而Top5则是在十个概率值中从大到小排序出前五个,然后看看这前五个分类中是否存在那个正确分类,再计算频率。Pytorch实现如下:

def evaluteTop1(model, loader):
model.eval() correct = 0
total = len(loader.dataset) for x,y in loader:
x,y = x.to(device), y.to(device)
with torch.no_grad():
logits = model(x)
pred = logits.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()
#correct += torch.eq(pred, y).sum().item()
return correct / total def evaluteTop5(model, loader):
model.eval()
correct = 0
total = len(loader.dataset)
for x, y in loader:
x,y = x.to(device),y.to(device)
with torch.no_grad():
logits = model(x)
maxk = max((1,5))
        y_resize = y.view(-1,1)
_, pred = logits.topk(maxk, 1, True, True)
correct += torch.eq(pred, y_resize).sum().float().item()
return correct / total

注意:y_resize = y.view(-1,1)是非常关键的一步,在correct的运算中,关键就是要pred和y_resize维度匹配,而原来的y是[128],128是batch大小;

pred的维度则是[128,10],假设这里是CIFAR10十分类;因此必须把y转化成[128,1]这种维度,但是不能直接是y.view(128,1),因为遍历整个数据集的时候,

最后一个batch大小并不是128,所以view()里面第一个size就设为-1未知,而确保第二个size是1就行

topk函数的具体用法参见https://blog.csdn.net/u014264373/article/details/86525621

Pytorch实现Top1准确率和Top5准确率的更多相关文章

  1. 【猫狗数据集】使用top1和top5准确率衡量模型

    数据集下载地址: 链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw提取码:2xq4 创建数据集:https://www.cnblogs.com/xi ...

  2. 深度学习基础系列(二)| 常见的Top-1和Top-5有什么区别?

    在深度学习过程中,会经常看见各成熟网络模型在ImageNet上的Top-1准确率和Top-5准确率的介绍,如下图所示: 那Top-1 Accuracy和Top-5 Accuracy是指什么呢?区别在哪 ...

  3. 基础网络之EfficientNet

    摘要: 一般情况下,我们都会根据当前的硬件资源来设计相应的卷积神经网络,如果资源升级,可以将模型结构放大以获取更好精度.我们系统地研究模型缩放并验证网络深度,宽度和分辨率之间的平衡以得到更好的性能表现 ...

  4. [开发技巧]·TopN指标计算方法

    [开发技巧]·TopN指标计算方法 ​ 1.概念介绍 在图片分类的中经常可以看到Top-1,Top-5等TopN准确率(或者时错误率). 那这个TopN是什么意思呢?首先Top-1准确率最好理解,就是 ...

  5. 我的Keras使用总结(4)——Application中五款预训练模型学习及其应用

    本节主要学习Keras的应用模块 Application提供的带有预训练权重的模型,这些模型可以用来进行预测,特征提取和 finetune,上一篇文章我们使用了VGG16进行特征提取和微调,下面尝试一 ...

  6. pytorch识别CIFAR10:训练ResNet-34(准确率80%)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com CNN的层数越多,能够提取到的特征越丰富,但是简单地增加卷积层数,训练时会导致梯度弥散或梯度爆炸. 何 ...

  7. 混淆矩阵、准确率、精确率/查准率、召回率/查全率、F1值、ROC曲线的AUC值

    准确率.精确率(查准率).召回率(查全率).F1值.ROC曲线的AUC值,都可以作为评价一个机器学习模型好坏的指标(evaluation metrics),而这些评价指标直接或间接都与混淆矩阵有关,前 ...

  8. 深度学习(PYTORCH)-3.sphereface-pytorch.lfw_eval.py详解

    pytorch版本sphereface的原作者地址:https://github.com/clcarwin/sphereface_pytorch 由于接触深度学习不久,所以花了较长时间来阅读源码,以下 ...

  9. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

随机推荐

  1. SAP的春天回来么?

    作为一个财务出身的码农,经常会关注在财务和编程的交叉领域,新兴的细分领域有:德勤的财务机器人,RPA机器人,FINTECH等等. 但是非要说一个便是sap.如果呈把用友成立之年算作sap元年,1988 ...

  2. hiho #1474 拆字游戏(dfs,记录状态)

    #1474 : 拆字游戏 时间限制:10000ms 单点时限:1000ms 内存限制:256MB 描述 小Kui喜欢把别人的名字拆开来,比如“螺”就可以拆成“虫田糸”,小Kui的语文学的不是很好,于是 ...

  3. eclipse svn同步过滤掉某些不需要同步的文件

    注:这里说的svn是eclipse里svn插件 默认情况下,我们在点击svn同步时,总是会把一些不需要的目录和文件也给同步了,这样我觉得很晃眼睛,所以在这里说下怎么去去掉不想同步的文件 1.默认同步下 ...

  4. python-加密算法

    #!/usr/bin/python3 # coding:utf-8 # Auther:AlphaPanda # Description: 使用hashlib模块的md5和sha系列加密算法对字符串进行 ...

  5. vs 2019 调试web项目 浏览器

  6. Eclipse 开发环境修改及MAVEN配置

    Eclipse集成Maven配置 默认为 修改为所用版本 选择maven软件所在目录 勾选 默认连接仓库为 修改为

  7. ASP.NET通过反射生成sql语句

    最近对接一个接口,需要通过xml序列化成实体后添加额外信息后批量插入数据库,需要手动拼sql.因为涉及多张表,拼凑很麻烦而且容易出错,所以写了两个工具方法来生成sql,先写到博客里面,以便以后不时之需 ...

  8. Spring配置文件beans标签报错问题解决

    因为有很多配置是复制过来的,附带的很多注释的格式会导致报错,所以可以要试试把注释去掉,只有配置文件的话可能就不会报错了.

  9. JavaWeb_(MVC)管理员后台商品查询demo

    MVC分层实现管理员后台商品查询 MVC层即model view controller Model(模型):模型代表着核心的业务逻辑和数据(不要理解成Model只是实体类) View(视图):视图应该 ...

  10. 当 LAST_INSERT_ID() 带有参数时# 清空重来

    [root@yejr.me]> truncate table t; # 插入1条新记录[root@yejr.me]> insert into t select 0,rand()*1024; ...