参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/functional/#_1

class torch.nn.Softmax(input, dim)

或:

torch.nn.functional.softmax(input, dim)

对n维输入张量运用Softmax函数,将张量的每个元素缩放到(0,1)区间且和为1。Softmax函数定义如下:

参数:

  dim:指明维度,dim=0表示按列计算;dim=1表示按行计算。默认dim的方法已经弃用了,最好声明dim,否则会警告:

UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.

shape:

  • 输入:(N, L)
  • 输出:(N, L)

返回结果是一个与输入维度dim相同的张量,每个元素的取值范围在(0,1)区间。

例子:

import torch

from torch import nn
from torch import autograd m = nn.Softmax()
input = autograd.Variable(torch.randn(, ))
print(input)
print(m(input))

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
tensor([[ 0.2854, 0.1708, 0.4308],
[-0.1983, 2.0705, 0.1549]])
test.py:: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
print(m(input))
tensor([[0.3281, 0.2926, 0.3794],
[0.0827, 0.7996, 0.1177]])

可见默认按行计算,即dim=1

更明显的例子:

import torch

import torch.nn.functional as F

x= torch.Tensor( [ [,,,],[,,,],[,,,]])

y1= F.softmax(x, dim = ) #对每一列进行softmax
print(y1) y2 = F.softmax(x,dim =) #对每一行进行softmax
print(y2) x1 = torch.Tensor([,,,])
print(x1) y3 = F.softmax(x1,dim=) #一维时使用dim=,使用dim=1报错
print(y3)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py
tensor([[0.3333, 0.3333, 0.3333, 0.3333],
[0.3333, 0.3333, 0.3333, 0.3333],
[0.3333, 0.3333, 0.3333, 0.3333]])
tensor([[0.0321, 0.0871, 0.2369, 0.6439],
[0.0321, 0.0871, 0.2369, 0.6439],
[0.0321, 0.0871, 0.2369, 0.6439]])
tensor([., ., ., .])
tensor([0.0321, 0.0871, 0.2369, 0.6439])

因为列的值相同,所以按列计算时每一个所占的比重都是0.3333;行都是[1,2,3,4],所以按行计算,比重结果都为[0.0321, 0.0871, 0.2369, 0.6439]

一维使用dim=1报错:

RuntimeError: Dimension out of range (expected to be in range of [-, ], but got )

torch.nn.functional中softmax的作用及其参数说明的更多相关文章

  1. 从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系

    从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系 relu多种实现之间的关系 relu 函数在 pytorch 中总共有 3 次出现: torc ...

  2. PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx

    PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx 在写 PyTorch 代码时,我们会发现一些功能重复的操作,比如卷积.激活.池化等操作.这些操作分别可 ...

  3. [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList

    1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...

  4. Pytorch本人疑问(1) torch.nn和torch.nn.functional之间的区别

    在写代码时发现我们在定义Model时,有两种定义方法: torch.nn.Conv2d()和torch.nn.functional.conv2d() 那么这两种方法到底有什么区别呢,我们通过下述代码看 ...

  5. pytorch torch.nn.functional实现插值和上采样

    interpolate torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', ali ...

  6. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  7. Pytorch中pad函数toch.nn.functional.pad()的用法

    padding操作是给图像外围加像素点. 为了实际说明操作过程,这里我们使用一张实际的图片来做一下处理. 这张图片是大小是(256,256),使用pad来给它加上一个黑色的边框.具体代码如下: 1 2 ...

  8. torch.nn 的本质

    torch.nn 的本质 PyTorch 提供了各种优雅设计的 modules 和类 torch.nn,torch.optim,Dataset 和 DataLoader 来帮助你创建并训练神经网络.为 ...

  9. 到底什么是TORCH.NN?

    该教程是在notebook上运行的,而不是脚本,下载notebook文件. PyTorch提供了设计优雅的模块和类:torch.nn, torch.optim, Dataset, DataLoader ...

随机推荐

  1. bat文件传递参数

    %*是表示命令行传过来的参数,%1表示第一个参数,%2表示第二个参数,以此类推.如执行C:/>hello.bat hello world, %1取出来就是hello %2取出来就是world h ...

  2. Linux配置防火墙端口 8080端口

    1.查看防火墙状态,哪些端口开放了 /etc/init.d/iptables status 2.配置防火墙 vi /etc/sysconfig/iptables   ################# ...

  3. SAP WM 有无保存WM Level历史库存的Table?

    SAP WM 有无保存WM Level历史库存的Table? 前日下班回家的路上,收到一个前客户内部顾问同行发过来的微信,问我在SAP系统里哪个表是用来存储WM Level历史库存的. 这个问题问住了 ...

  4. Android星球效果实现

    在项目中看着这个旋转效果挺炫的,就抽取出来做个记录.主要是使用CarrouselLayout 稍微修改 CarrouselLayout代码Demo下载z地址:GitHub https://github ...

  5. leetcode-83.删除排序链表中的重复元素

    leetcode-83.删除排序链表中的重复元素 Points 链表 题意 给定一个排序链表,删除所有重复的元素,使得每个元素只出现一次. 示例 1: 输入: 1->1->2 输出: 1- ...

  6. OSWatcher使用过程中小问题解决方法

    本文介绍一下在使用OSWatcher过程当中遇到的两个问题的解决方法.如有更好的方法,敬请留言. 1:OSWatcher在配置文件里面设置了参数OSW_COMPRESSION为gzip后,OSWatc ...

  7. [20181130]如何猜测那些值存在hash冲突.txt

    [20181130]如何猜测那些值存在hash冲突.txt --//今年6月份开始kerrycode的1个帖子提到子查询结果缓存在哈希表中情况:--//链接:http://www.cnblogs.co ...

  8. 前后端分离djangorestframework——解析渲染组件

    解析器 解析器的作用就是服务端接收客户端传过来的数据,把数据解析成自己想要的数据类型的过程,本质就是对请求体中的数据进行解析 Accept是告诉对方我能解析什么样的数据,通常也可以表示我想要什么样的数 ...

  9. 关于SqlServer数据表操作

    --修改表字段长度alter table Tbl_Count_User_Ref ALTER COLUMN CountName nvarchar(500);新增字段alter table 表名 add ...

  10. mysql 数据库 命令行的操作——对表和字段的操作

    一.对表的操作 1.查看所有表 show tables: 2.创建表 create table 表名(字段1 类型1 约束1 ,字段2 类型2 约束2): 3.修改表的名字 rename table ...