pytorch中的nn.CrossEntropyLoss()
nn.CrossEntropyLoss()这个损失函数和我们普通说的交叉熵还是有些区别。
$x$是模型生成的结果,$class$是数据对应的label
$loss(x,class)=-log(\frac{exp(x[class])}{\sum_j exp(x[j])})=-x[class]+log(\sum_j exp(x[j]))$
nn.CrossEntropyLoss()的使用方式参见如下代码
import torch
import torch.nn as nn # 表示模型的输出output(B,C)格式,B是batch,C是类别
output = torch.randn(2, 3, requires_grad = True) #batch_size设置为2,3分类
# 表示数据的标签label(B)格式,B是batch,其中的数值是位于[0,C-1]
label = torch.empty(2, dtype=torch.long).random_(3) # 0 - 2, 任意选取一个分类
print(output)
'''
tensor([[-1.1313, 0.5944, -1.5735],
[ 1.2037, -1.0548, -0.9253]], requires_grad=True)
'''
print(label)#tensor([0, 2])
loss = nn.CrossEntropyLoss()
#先对每个训练样本求损失,而后再求平均损失
print ('loss :', loss(output, label))#loss : tensor(2.1565, grad_fn=<NllLossBackward>)
pytorch中的nn.CrossEntropyLoss()的更多相关文章
- PyTorch 中,nn 与 nn.functional 有什么区别?
作者:infiniteft链接:https://www.zhihu.com/question/66782101/answer/579393790来源:知乎著作权归作者所有.商业转载请联系作者获得授权, ...
- pytorch中torch.nn构建神经网络的不同层的含义
主要是参考这里,写的很好PyTorch 入门实战(四)--利用Torch.nn构建卷积神经网络 卷积层nn.Con2d() 常用参数 in_channels:输入通道数 out_channels:输出 ...
- 交叉熵的数学原理及应用——pytorch中的CrossEntropyLoss()函数
分类问题中,交叉熵函数是比较常用也是比较基础的损失函数,原来就是了解,但一直搞不懂他是怎么来的?为什么交叉熵能够表征真实样本标签和预测概率之间的差值?趁着这次学习把这些概念系统学习了一下. 首先说起交 ...
- PyTorch中ReLU的inplace
0 - inplace 在pytorch中,nn.ReLU(inplace=True)和nn.LeakyReLU(inplace=True)中存在inplace字段.该参数的inplace=True的 ...
- PyTorch 中 weight decay 的设置
先介绍一下 Caffe 和 TensorFlow 中 weight decay 的设置: 在 Caffe 中, SolverParameter.weight_decay 可以作用于所有的可训练参数, ...
- 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())
学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...
- pytorch 中的重要模块化接口nn.Module
torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络 nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法 对于自己 ...
- [转载]Pytorch中nn.Linear module的理解
[转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...
- Pytorch中nn.Dropout2d的作用
Pytorch中nn.Dropout2d的作用 首先,关于Dropout方法,这篇博文有详细的介绍.简单来说, 我们在前向传播的时候,让某个神经元的激活值以一定的概率p停止工作,这样可以使模型泛化性更 ...
随机推荐
- 基于python的种子搜索网站(三)项目部署
项目部署过程 系统要求:ubuntu 16.04(或以上) 环境搭建和配置,必须严格按照以下步骤来安装部署!如有问题可以咨询(weixin:java2048) 安装部分 安装nginx sudo ap ...
- C语言中表达n次方
C语言中表达n次方可以用pow函数. 函数原型:double pow(double x, double y) 功 能:计算x^y的值 返 回 值:计算结果 举例: double a; a = p ...
- spf13-vim 显示neocomplete requires ...th Lua support
安装spf13-vim的时候下载了许多插件,neocomplete应该是比较重要的一个,毕竟自动补全.但是在使用时却一直有:neocomplete requires ...th Lua support ...
- Python基础-day01-6
算数运算符 计算机,顾名思义就是负责进行 数学计算 并且 存储计算结果 的电子设备 目标 算术运算符的基本使用 01. 算数运算符 算数运算符是 运算符的一种 是完成基本的算术运算使用的符号,用来处理 ...
- Spring cloud ——EurekaServer
Eureka作为服务注册与发现的组件,Eureka2.0已经闭源了,但是本教程还是以Eureka为核心进行展开. 1.三个模块 Spring Cloud Eureka是Spring Cloud Net ...
- JS---案例:完整的轮播图---重点!
案例:完整的轮播图 思路: 分5部分做 1. 获取所有要用的元素 2. 做小按钮,点击移动图标部分 3. 做右边焦点按钮,点击移动图片,小按钮颜色一起跟着变 (克隆了第一图到第六图,用索引liObj. ...
- CSS 选择器、字体/文本、背景
CSS的基本使用 直接写在标签内 <p style="color: red; font-size: 40px;">段落</p> 写在 style 标签内 & ...
- 腾讯云推出一站式 DevOps 解决方案 —— CODING DevOps
在产业互联网的大背景下,如何将人工智能.大数据等前沿技术与实体产业相结合,推动传统企业转型升级,已经成为每一个企业不得不思考的问题.落后的软件研发能力已经拖慢了中国大量企业的数字化转型进程. 为了满足 ...
- git 版本检出checkout的方法笔记
想检出指定版本,比如回退版本,将代码检出到老代码 git checkout 版本号 git reflog git checkout 标签名 1.git log 查看版本信息,复制版本号,执行git ...
- 磁盘修复 mount: wrong fs type running e2fsck
当服务器或PC机器的硬盘在使用一段时间后,会出现无法使用正常进行使用: 1. 当将文件系统挂载到指定的目录的时候,会出现mount 失败,如下图: [root@template ~]# mount / ...