nn.CrossEntropyLoss()这个损失函数和我们普通说的交叉熵还是有些区别。

$x$是模型生成的结果,$class$是数据对应的label

$loss(x,class)=-log(\frac{exp(x[class])}{\sum_j exp(x[j])})=-x[class]+log(\sum_j exp(x[j]))$

nn.CrossEntropyLoss()的使用方式参见如下代码

import torch
import torch.nn as nn # 表示模型的输出output(B,C)格式,B是batch,C是类别
output = torch.randn(2, 3, requires_grad = True) #batch_size设置为2,3分类
# 表示数据的标签label(B)格式,B是batch,其中的数值是位于[0,C-1]
label = torch.empty(2, dtype=torch.long).random_(3) # 0 - 2, 任意选取一个分类
print(output)
'''
tensor([[-1.1313, 0.5944, -1.5735],
[ 1.2037, -1.0548, -0.9253]], requires_grad=True)
'''
print(label)#tensor([0, 2])
loss = nn.CrossEntropyLoss()
#先对每个训练样本求损失,而后再求平均损失
print ('loss :', loss(output, label))#loss : tensor(2.1565, grad_fn=<NllLossBackward>)

pytorch中的nn.CrossEntropyLoss()的更多相关文章

  1. PyTorch 中,nn 与 nn.functional 有什么区别?

    作者:infiniteft链接:https://www.zhihu.com/question/66782101/answer/579393790来源:知乎著作权归作者所有.商业转载请联系作者获得授权, ...

  2. pytorch中torch.nn构建神经网络的不同层的含义

    主要是参考这里,写的很好PyTorch 入门实战(四)--利用Torch.nn构建卷积神经网络 卷积层nn.Con2d() 常用参数 in_channels:输入通道数 out_channels:输出 ...

  3. 交叉熵的数学原理及应用——pytorch中的CrossEntropyLoss()函数

    分类问题中,交叉熵函数是比较常用也是比较基础的损失函数,原来就是了解,但一直搞不懂他是怎么来的?为什么交叉熵能够表征真实样本标签和预测概率之间的差值?趁着这次学习把这些概念系统学习了一下. 首先说起交 ...

  4. PyTorch中ReLU的inplace

    0 - inplace 在pytorch中,nn.ReLU(inplace=True)和nn.LeakyReLU(inplace=True)中存在inplace字段.该参数的inplace=True的 ...

  5. PyTorch 中 weight decay 的设置

    先介绍一下 Caffe 和 TensorFlow 中 weight decay 的设置: 在 Caffe 中, SolverParameter.weight_decay 可以作用于所有的可训练参数, ...

  6. 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())

    学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...

  7. pytorch 中的重要模块化接口nn.Module

    torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络 nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法 对于自己 ...

  8. [转载]Pytorch中nn.Linear module的理解

    [转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...

  9. Pytorch中nn.Dropout2d的作用

    Pytorch中nn.Dropout2d的作用 首先,关于Dropout方法,这篇博文有详细的介绍.简单来说, 我们在前向传播的时候,让某个神经元的激活值以一定的概率p停止工作,这样可以使模型泛化性更 ...

随机推荐

  1. 基于python的种子搜索网站(三)项目部署

    项目部署过程 系统要求:ubuntu 16.04(或以上) 环境搭建和配置,必须严格按照以下步骤来安装部署!如有问题可以咨询(weixin:java2048) 安装部分 安装nginx sudo ap ...

  2. C语言中表达n次方

    C语言中表达n次方可以用pow函数. 函数原型:double pow(double x, double y) 功    能:计算x^y的值 返 回 值:计算结果 举例: double a; a = p ...

  3. spf13-vim 显示neocomplete requires ...th Lua support

    安装spf13-vim的时候下载了许多插件,neocomplete应该是比较重要的一个,毕竟自动补全.但是在使用时却一直有:neocomplete requires ...th Lua support ...

  4. Python基础-day01-6

    算数运算符 计算机,顾名思义就是负责进行 数学计算 并且 存储计算结果 的电子设备 目标 算术运算符的基本使用 01. 算数运算符 算数运算符是 运算符的一种 是完成基本的算术运算使用的符号,用来处理 ...

  5. Spring cloud ——EurekaServer

    Eureka作为服务注册与发现的组件,Eureka2.0已经闭源了,但是本教程还是以Eureka为核心进行展开. 1.三个模块 Spring Cloud Eureka是Spring Cloud Net ...

  6. JS---案例:完整的轮播图---重点!

    案例:完整的轮播图 思路: 分5部分做 1. 获取所有要用的元素 2. 做小按钮,点击移动图标部分 3. 做右边焦点按钮,点击移动图片,小按钮颜色一起跟着变 (克隆了第一图到第六图,用索引liObj. ...

  7. CSS 选择器、字体/文本、背景

    CSS的基本使用 直接写在标签内 <p style="color: red; font-size: 40px;">段落</p> 写在 style 标签内 & ...

  8. 腾讯云推出一站式 DevOps 解决方案 —— CODING DevOps

    在产业互联网的大背景下,如何将人工智能.大数据等前沿技术与实体产业相结合,推动传统企业转型升级,已经成为每一个企业不得不思考的问题.落后的软件研发能力已经拖慢了中国大量企业的数字化转型进程. 为了满足 ...

  9. git 版本检出checkout的方法笔记

    想检出指定版本,比如回退版本,将代码检出到老代码 git checkout 版本号 git reflog git checkout  标签名 1.git log 查看版本信息,复制版本号,执行git ...

  10. 磁盘修复 mount: wrong fs type running e2fsck

    当服务器或PC机器的硬盘在使用一段时间后,会出现无法使用正常进行使用: 1. 当将文件系统挂载到指定的目录的时候,会出现mount 失败,如下图: [root@template ~]# mount / ...