深度学习 (DeepLearning) 基础 [2]---神经网络常用的损失函数


Introduce

在上一篇“深度学习 (DeepLearning) 基础 [1]---监督学习和无监督学习”中我们介绍了监督学习和无监督学习相关概念。本文主要介绍神经网络常用的损失函数。

以下均为个人学习笔记,若有错误望指出。


神经网络常用的损失函数

pytorch损失函数封装在torch.nn中。

损失函数反映了模型预测输出与真实值的区别,模型训练的过程即让损失函数不断减小,最终得到可以拟合预测训练样本的模型。

note:由于PyTorch神经网络模型训练过程中每次传入一个mini-batch的数据,因此pytorch内置损失函数的计算出来的结果如果没有指定reduction参数,则默认对mini-batch取平均

以下对几个常用的损失函数以及其应用场景做一个简单总结。(以下损失函数的公式均代表单个min-batch的损失,且假设x为神经网络的预测输出,y为样本的真实值,xi为一个mini-batch中第i个样本的预测输出,yi同理,n为一个批量mini-batch的大小

  • nn.L1Loss(L1损失,也称平均绝对误差MAE):计算模型输出x与目标y之间差的绝对值。常用于回归任务。

\[loss(x,y) = {1\over n}\sum|x_i-y_i|
\]

'''代码示例'''
loss_func = torch.nn.L1Loss(reduction='mean')
'''note:
reduction=None 啥也不干
reduction='mean' 返回loss和的平均值
reduction='mean' 返回loss的和。
不指定即默认mean。
'''
  • nn.MSELoss(L2损失,也称均方误差MSE):计算模型输出x与目标y之间差的平方的均值,均方差。常用于回归任务。

\[loss(x,y) = {1\over n}\sum(x_i-y_i)^2
\]

'''代码示例'''
loss_func = torch.nn.MSELoss(reduction='mean')
# note: reduction同上。
  • nn.BCELoss(二进制交叉熵损失):计算模型输出x与目标y之间的交叉熵。(我对于交叉熵的理解,交叉熵为相对熵(即KL散度,用来衡量两个分布的差异程度)中的一项,最小化两个分布的差异,即最小化相对熵,由相对熵公式,由于真实分布是确定的,那么最小化相对熵就是最小化交叉熵,而最小化交叉熵的目标就是寻找一个预测分布尽可能逼近真实分布,这和我们模型的训练目标是一致的,即让模型预测逼近样本真实值,参考链接)常用于二分类任务。

\[loss(x,y) = {1\over n}\sum-w_i[y_i*logx_i + (1-y_i)*log(1-x_i)]
\]

'''代码示例'''
loss_func = torch.nn.BCELoss(weight=None, reduction='mean')
# note:
# weight为长度为n的tensor,用来指定一个batch中各样本占有的权重,如公式中的wi,不指定默认为各样本权重均为1。
# reduction同上。 # 用的时候需要在该层前面加上 Sigmoid 函数。
  • nn.NLLLoss(负对数似然损失):将神经网络输出的隶属各个类的概率向量x与对应真实标签向量(个人理解应该是one-hot向量吧)相差再相加,最后再取负。如果不取负的话,应该是loss值越大预测标签越接近真实标签,取负的话反过来,越小则越接近真实标签,符合loss函数的定义。常用于多分类任务。 以下公式假设节点xi属于第j类,x[j]为预测的x属于第j类的概率,且w[j]为第j类的权重

\[loss(x,class) = {1\over n}\sum -w[j]*x[j]
\]

'''代码示例'''
loss_func = torch.nn.NLLLoss(weight=None, reduction='mean')
# note:
# weight同上,如公式中的w代表各个类在损失中占有的权重,即类的重要程度,若不赋予权重w,则各类同等重要,上述公式中的w[class]去掉。
# reduction同上。
  • nn.CrossEntropyLoss (交叉熵损失):如上述二进制交叉熵所示,随着预测的概率分布越来越接近实际标签,交叉熵会逐渐减小。pytorch将nn.LogSoftmax()和nn.NLLLoss()组合到nn.CrossEntropyLoss(),即调用nn.CrossEntropyLoss() 底层会调用上述两个函数,可以理解为 CrossEntropyLoss = LogSoftmax + NLLLoss。因此一般多分类任务都常用交叉熵损失。 以下label_i代表节点xi的真实标签,c为总的标签数。

\[loss(x,class) = {1 \over n}\sum-w[label_i]log{exp(x_i[label_i])\over \sum_{j=1}^cexp(x[j])} = {1 \over n}\sum w[label_i](-x_i[label_i]+log(\sum_{j=1}^c)exp(x[j]))
\]

'''代码示例'''
loss_func = torch.nn.CrossEntropyLoss(weight=None,reduction='mean') # note:
# weight同nn.NLLLoss。
# reduction同上。

本文参考-1

本文参考-2

Pytorch_第六篇_深度学习 (DeepLearning) 基础 [2]---神经网络常用的损失函数的更多相关文章

  1. Pytorch_第七篇_深度学习 (DeepLearning) 基础 [3]---梯度下降

    深度学习 (DeepLearning) 基础 [3]---梯度下降法 Introduce 在上一篇"深度学习 (DeepLearning) 基础 [2]---神经网络常用的损失函数" ...

  2. Pytorch_第八篇_深度学习 (DeepLearning) 基础 [4]---欠拟合、过拟合与正则化

    深度学习 (DeepLearning) 基础 [4]---欠拟合.过拟合与正则化 Introduce 在上一篇"深度学习 (DeepLearning) 基础 [3]---梯度下降法" ...

  3. Pytorch_第五篇_深度学习 (DeepLearning) 基础 [1]---监督学习与无监督学习

    深度学习 (DeepLearning) 基础 [1]---监督学习与无监督学习 Introduce 学习了Pytorch基础之后,在利用Pytorch搭建各种神经网络模型解决问题之前,我们需要了解深度 ...

  4. Pytorch_第十篇_卷积神经网络(CNN)概述

    卷积神经网络(CNN)概述 Introduce 卷积神经网络(convolutional neural networks),简称CNN.卷积神经网络相比于人工神经网络而言更适合于图像识别.语音识别等任 ...

  5. Coursera深度学习(DeepLearning.ai)编程题&笔记

    因为是Jupyter Notebook的形式,所以不方便在博客中展示,具体可在我的github上查看. 第一章 Neural Network & DeepLearning week2 Logi ...

  6. (zhuan) 126 篇殿堂级深度学习论文分类整理 从入门到应用

    126 篇殿堂级深度学习论文分类整理 从入门到应用 | 干货 雷锋网 作者: 三川 2017-03-02 18:40:00 查看源网址 阅读数:66 如果你有非常大的决心从事深度学习,又不想在这一行打 ...

  7. 2020年深度学习DeepLearning技术实战班

    深度学习DeepLearning核心技术实战2020年01月03日-06日 北京一.深度学习基础和基本思想二.深度学习基本框架结构 1,Tensorflow2,Caffe3,PyTorch4,MXNe ...

  8. 深度学习DeepLearning核心技术理论与实践

    深度学习DeepLearning核心技术开发与应用时间地点:2019年11月01日-04日(北京) 联系人杨老师  电话(同微信)17777853361

  9. 2020年12月18号--21号 人工智能(深度学习DeepLearning)python、TensorFlow技术实战

    深度学习DeepLearning(Python)实战培训班 时间地点: 2020 年 12 月 18 日-2020 年 12 月 21日 (第一天报到 授课三天:提前环境部署 电脑测试) 一.培训方式 ...

随机推荐

  1. day18 装饰器(下)+迭代器+生成器

    目录 一.有参装饰器 1 前提 2 如何使用有参装饰器 3 有参装饰器模板 4 修正装饰器 二.迭代器 1 什么是迭代器 2 为什么要有迭代器 3 如何用迭代器 3.1 可迭代对象 3.2 可迭代对象 ...

  2. 方正璞华Java面试总结(武汉)

    方正璞华Java面试总结(武汉) 现在社会急缺复合型人才,计算机与日语的结合,具备这两种能力的人不愁工作,最后他们大多到的也是日企,甚至到日本去工作.至今为止接触的日企有光庭.方正璞华.先锋·商泰.英 ...

  3. java IO流 (一) File类的使用

    1.File类的理解* 1. File类的一个对象,代表一个文件或一个文件目录(俗称:文件夹)* 2. File类声明在java.io包下* 3. File类中涉及到关于文件或文件目录的创建.删除.重 ...

  4. 一张PDF了解JDK10 GC调优秘籍-附PDF下载

    目录 简介 Java参数类型 Large Pages JIT调优 总结 简介 今天我们讲讲JDK10中的JVM GC调优参数,JDK10中JVM的参数总共有1957个,其中正式的参数有658个. 其实 ...

  5. vue + echart 实现中国地图 和 省市地图(可切换省份)

    一.中国地图 1.先导入echarts,然后再main.js里引入echarts // 引入echartsimport echarts from 'echarts'Vue.prototype.$ech ...

  6. layui弹窗里面 session过期 后跳转到登录页面

    1.在登录页面添加 <script> $(function () { if (top != window) { layer.msg("登录失效", {icon: 5}) ...

  7. 题解 CF1372C

    题目 传送门 题意 给你一个 \(1\) 到 \(n\) 的排列. 定义特殊交换为:选择一段区间\([l,r]\) ,使得此段区间上的数交换后都不在原来的位置. 问最少多少次可以将此排列变成升序的. ...

  8. NameBeta - 多家比价以节省咱的域名注册成本

    共收录 1584 种顶级域名,汇集互联网上 29 家知名域名注册商,每日更新价格信息 有的域名还可以查出到期时间点我前往官网 NameSilo1美元优惠码:whatz

  9. 亚马逊如何使用二次验证码/虚拟MFA/两步验证/谷歌验证器?

    一般点账户名——设置——安全设置中开通虚拟MFA两步验证 具体步骤见链接  亚马逊如何使用二次验证码/虚拟MFA/两步验证/谷歌验证器? 二次验证码小程序于谷歌身份验证器APP的优势 1.无需下载ap ...

  10. vue------反响代理

    //测试项目 https://i.cnblogs.com/Files.aspx