一、BCELoss 二分类损失函数

输入维度为(n, ), 输出维度为(n, )

如果说要预测二分类值为1的概率,则建议用该函数!

输入比如是3维,则每一个应该是在0——1区间内(随意通常配合sigmoid函数使用),举例如下:

import torch
import torch.nn as nn

m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3,requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
output.backward() input,target,output 返回值:
(tensor([-0.8728, 0.3632, -0.0547], requires_grad=True),
tensor([1., 0., 0.]),
tensor(0.9264, grad_fn=<BinaryCrossEntropyBackward>)) m(input)结果为:
tensor([0.2947, 0.5898, 0.4863]) 计算output = (1 * ln 0.2947+(1-1)*ln(1-0.2947) + 0*ln0.5898 + (1-0)*ln(1-0.5898) + 0*ln0.4863 + (1-0)*ln(1-0.4863)) / 3 = 0.9264

二、nn.CrossEntropyLoss 交叉熵损失函数

输入维度(batch_size, feature_dim)

输出维度  (batch_size, 1)

X_input = torch.tensor[ [2.8883, 0.1760, 1.0774],

          [1.1216, -0.0562, 0.0660],

          [-1.3939, -0.0967, 0.5853]]

y_target = torch.tensor([1,2,0])

loss_func = nn.CrossEntropyLoss()

loss = loss_func(X_input, y_target)

计算流程:第一,x先softmax再log,得到x_hat  第二,y转0-1编码[1,2,0] 转[[0,1,0], [0,0,1], [1,0,0]] 再与x_hat相乘,取负取平均值

思考问题:多标签的分类任务中,怎么使用损失函数呢,是拆分是多个二分类问题呢,还是不用拆分直接用BCE呢(https://blog.csdn.net/rosefun96/article/details/88058708,参考:BCE 可以应用到多标签的分类任务中)?有什么区别呢?

pytorch 损失函数(nn.BCELoss 和 nn.CrossEntropyLoss)(思考多标签分类问题)的更多相关文章

  1. pytorch中文文档-torch.nn常用函数-待添加-明天继续

    https://pytorch.org/docs/stable/nn.html 1)卷积层 class torch.nn.Conv2d(in_channels, out_channels, kerne ...

  2. [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList

    1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...

  3. pytorch中文文档-torch.nn.init常用函数-待添加

    参考:https://pytorch.org/docs/stable/nn.html torch.nn.init.constant_(tensor, val) 使用参数val的值填满输入tensor ...

  4. 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())

    在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...

  5. Pytorch本人疑问(1) torch.nn和torch.nn.functional之间的区别

    在写代码时发现我们在定义Model时,有两种定义方法: torch.nn.Conv2d()和torch.nn.functional.conv2d() 那么这两种方法到底有什么区别呢,我们通过下述代码看 ...

  6. 从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系

    从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系 relu多种实现之间的关系 relu 函数在 pytorch 中总共有 3 次出现: torc ...

  7. pytorch 损失函数

    pytorch损失函数: http://blog.csdn.net/zhangxb35/article/details/72464152?utm_source=itdadao&utm_medi ...

  8. Multi label 多标签分类问题(Pytorch,TensorFlow,Caffe)

    适用场景:一个输入对应多个label,或输入类别间不互斥 调用函数: 1. Pytorch使用torch.nn.BCEloss 2. Tensorflow使用tf.losses.sigmoid_cro ...

  9. Pytorch的默认初始化分布 nn.Embedding.weight初始化分布

    一.nn.Embedding.weight初始化分布 nn.Embedding.weight随机初始化方式是标准正态分布  ,即均值$\mu=0$,方差$\sigma=1$的正态分布. 论据1——查看 ...

随机推荐

  1. 解决无法访问 Github

    可以正常使用Google,但无法打开Github. 查阅了一些资料,发现需要在hosts文件中添加映射. 在hosts文件中加入两行 140.82.113.4 github.com 140.82.11 ...

  2. hi-nginx-java并发性能一窥

    欲知hi-nginx-java的并发性能,用jmeter进行测试便知一二. 设定用户数为100000,循环次数为100,ramp-up perio为2: 请求地址为http://localhost/t ...

  3. spark-submit提交python脚本过程记录

    最近刚学习spark,用spark-submit命令提交一个python脚本,一开始老报错,所以打算好好整理一下用spark-submit命令提交python脚本的过程.先看一下spark-submi ...

  4. day05-类型转换和变量

    1.类型转换概念 java是强类型语言,所以有些运算的时候,需要用到类型转换 类型转换原则:低-->高,byte,short,char-->int-->long-->float ...

  5. 三:redis启动后的基础知识

    Redis启动后的杂项基础知识 1.单进进程 单进程模型来处理客户端的请求.对读写等事件的响应是通过对epoll函数的包装来做到的.Redis的实际处理速度完全依靠主进程的执行效率       Epo ...

  6. IAR_STM32_BootLoader

    1.STM32 Bootloader与APP IROM中可以分成两个区域,起始代码运行地址为0x08000000,这是基本固定的,可以将IROM的0x08000000 ~ 0x08002000这8KB ...

  7. 4.Spring Boot web开发

    1.创建一个web模块 (1).创建SpringBoot应用,选中我们需要的模块: (2).SpringBoot已经默认将这些场景配置好了,只需要在配置文件中指定少量配置就可以运行起来 (3).自己编 ...

  8. 02、MyBatis XML 全局配置文件

    MyBatis-全局配置文件 在MyBatis中全局配置文件有着重要的地位,里面有9类行为信息;如果我们要想将MyBatis运用的熟练,配置全局配置文件是必不可少的步骤,所以我们一定要啃下这一块硬骨头 ...

  9. ASP.NET Core管道详解[2]: HttpContext本质论

    ASP.NET Core请求处理管道由一个服务器和一组有序排列的中间件构成,所有中间件针对请求的处理都在通过HttpContext对象表示的上下文中进行.由于应用程序总是利用服务器来完成对请求的接收和 ...

  10. sqlilab less1-less10

    less-1 参数被单引号包裹,加单引号,闭合后绕过 less-2 参数没有被包裹,直接带入查询,不需要闭合 less-3 参数被 ('$id') 包裹,需要将他闭合 less-4 参数被小括号和双引 ...