Pytorch里的CrossEntropyLoss详解
在使用Pytorch时经常碰见这些函数cross_entropy,CrossEntropyLoss, log_softmax, softmax。看得我头大,所以整理本文以备日后查阅。
首先要知道上面提到的这些函数一部分是来自于torch.nn,而另一部分则来自于torch.nn.functional(常缩写为F)。二者函数的区别可参见 知乎:torch.nn和funtional函数区别是什么?
下面是对与cross entropy有关的函数做的总结:
torch.nn | torch.nn.functional (F) |
---|---|
CrossEntropyLoss | cross_entropy |
LogSoftmax | log_softmax |
NLLLoss | nll_loss |
下面将主要介绍torch.nn.functional中的函数为主,torch.nn中对应的函数其实就是对F里的函数进行包装以便管理变量等操作。
在介绍cross_entropy之前先介绍两个基本函数:
log_softmax
这个很好理解,其实就是log和softmax合并在一起执行。
nll_loss
该函数的全程是negative log likelihood loss,函数表达式为
\]
例如假设\(x=[1,2,3], class=2\),那额\(f(x,class)=-x[2]=-3\)
cross_entropy
交叉熵的计算公式为:
\]
其中\(p\)表示真实值,在这个公式中是one-hot形式;\(q\)是预测值,在这里假设已经是经过softmax后的结果了。
仔细观察可以知道,因为\(p\)的元素不是0就是1,而且又是乘法,所以很自然地我们如果知道1所对应的index,那么就不用做其他无意义的运算了。所以在pytorch代码中target不是以one-hot形式表示的,而是直接用scalar表示。所以交叉熵的公式(m表示真实类别)可变形为:
\]
仔细看看,是不是就是等同于log_softmax和nll_loss两个步骤。
所以Pytorch中的F.cross_entropy会自动调用上面介绍的log_softmax和nll_loss来计算交叉熵,其计算方式如下:
\]
代码示例
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randint(5, (3,), dtype=torch.int64)
>>> loss = F.cross_entropy(input, target)
>>> loss.backward()
Pytorch里的CrossEntropyLoss详解的更多相关文章
- pytorch之nn.Conv1d详解
转自:https://blog.csdn.net/sunny_xsc1994/article/details/82969867,感谢分享 pytorch之nn.Conv1d详解
- 全网最全的Windows下Anaconda2 / Anaconda3里Python语言实现定时发送微信消息给好友或群里(图文详解)
不多说,直接上干货! 缘由: (1)最近看到情侣零点送祝福,感觉还是很浪漫的事情,相信有很多人熬夜为了给爱的人送上零点祝福,但是有时等着等着就睡着了或者时间并不是卡的那么准就有点强迫症了,这是也许程序 ...
- 【小白学PyTorch】11 MobileNet详解及PyTorch实现
文章来自微信公众号[机器学习炼丹术].我是炼丹兄,欢迎加我微信好友交流学习:cyx645016617. @ 目录 1 背景 2 深度可分离卷积 2.2 一般卷积计算量 2.2 深度可分离卷积计算量 2 ...
- pytorch nn.LSTM()参数详解
输入数据格式:input(seq_len, batch, input_size)h0(num_layers * num_directions, batch, hidden_size)c0(num_la ...
- Pytorch Bi-LSTM + CRF 代码详解
久闻LSTM + CRF的效果强大,最近在看Pytorch官网文档的时候,看到了这段代码,前前后后查了很多资料,终于把代码弄懂了.我希望在后来人看这段代码的时候,直接就看我的博客就能完全弄懂这段代码. ...
- Yii 框架里数据库操作详解-[增加、查询、更新、删除的方法 'AR模式']
public function getMinLimit () { $sql = "..."; $result = yii::app()->db-& ...
- 扩展运算符及其在vuex的辅助函数里的应用详解
一.扩展运算符 <1>为什么扩展运算符会诞生? 因为箭头函数没有arguments,所以才有了扩展运算符 <2>在箭头函数里 ...
- pytorch BiLSTM+CRF代码详解 重点
一. BILSTM + CRF介绍 https://www.jianshu.com/p/97cb3b6db573 1.介绍 基于神经网络的方法,在命名实体识别任务中非常流行和普遍. 如果你不知道Bi- ...
- 【小白学PyTorch】12 SENet详解及PyTorch实现
文章来自微信公众号[机器学习炼丹术].我是炼丹兄,有什么问题都可以来找我交流,近期建立了微信交流群,也在朋友圈抽奖赠书十多本了.我的微信是cyx645016617,欢迎各位朋友. 参考目录: @ 目录 ...
随机推荐
- 回顾servlet生命周期(代码测试),读取初始化参数
servlet生命周期 为简洁,本例使用注解方式来测试,代码部分很简单,只需要新建一个serlet,继承自HttpServlet,重写init,doGet,doPost,destory方法即可,使用注 ...
- redis 初步认识三(设置登录密码)
1.cmd 2.cd C:\Program Files\Redis 3.redis-cli.exe -h 127.0.0.1 -a 123456
- 采用ADM2483磁隔离器让RS485接口更简单更安全
采用ADM2483磁隔离器让RS485接口更简单更安全 摘要:本文介绍RS485的特点及应用,指出了普通RS485接口易损坏的问题,针对存在的问题介绍了以ADM2483为核心的磁隔离解决方案. 关键词 ...
- vc图像合成
本程序下载地址: 上一篇讲述了tiff格式图片拆分成多张图片, 这篇博客讲述如何把多张任意格式的图片合成为一张图片. 图像合成仍然需要借助Cximage图像库,合成函数为Mixfrom, 函数原型为: ...
- [Spark]如何设置使得spark程序不输出 INFO级别的内容
Spark程序在运行的时候,总是输出很多INFO级别内容 查看了网上的一些文章,进行了试验. 发现在 /etc/spark/conf 目录下,有一个 log4j.properties.template ...
- oracle:TNS:监听程序无法分发客户机连接
挂上vpn的时候,PL/SQL连接到oracle的时候,显示ORA-12518:监听程序无法分发客户机连接.如下图: 一.[问题描述] 最近,在系统高峰期的时候,会提示如上的错误,致使无法连接到服务器 ...
- 查看crontab运行状态
cron服务是linux的内置服务,但它不会开机自动启动.可以用以下命令启动和停止服务: /sbin/service crond start/sbin/service crond stop/sbin/ ...
- Django模板语言初识
一.Django框架简介 1.MVC框架 MVC,全名是Model View Controller,是软件工程中的一种软件架构模式,把软件系统分为三个基本部分:模型(Model).视图(View)和控 ...
- JS判断浏览器是否支持触屏事件
var hasTouch=function(){ var touchObj={}; touchObj.isSupportTouch = "ontouchend" in docume ...
- placeholder效果
<!DOCTYPE HTML> <html lang="en-US"> <head> <meta charset="UT ...