Pytorch本人疑问(1) torch.nn和torch.nn.functional之间的区别
在写代码时发现我们在定义Model时,有两种定义方法:
torch.nn.Conv2d()和torch.nn.functional.conv2d()
那么这两种方法到底有什么区别呢,我们通过下述代码看出差别,先拿torch.nn.Conv2d
- torch.nn.Conv2d
class Conv2d(_ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros'):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
super(Conv2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias, padding_mode) def forward(self, input):
return self.conv2d_forward(input, self.weight)
- torch.nn.functional.conv2d
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
groups=1):
if input is not None and input.dim() != 4:
raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(input.dim()))
f = _ConvNd(_pair(stride), _pair(padding), _pair(dilation), False,
_pair(0), groups, torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic,torch.backends.cudnn.enabled)
return f(input, weight, bias)
对比上述代码我们可以发现,torch.nn.Conv2d是一个类,而torch.nn.functiona.conv2d()是一个函数,并且torch.nn.Conv2d中的forward()函数
是由torch.nn.functiona.conv2d()实现的(在Module类中有一个__call__实现了forward的调用)所以他们在功能的使用上并没有什么区别,但是我
们有了一个疑问,为什么要有着两个功能一样的方法呢?
其实主要的原因在乎我们构建计算图的时候,有些操作不需要进行体现在计算图中的,例如ReLu层,池化层。但是像卷积层、全连接层还是
需要体现在计算图中的。如果所有的层我们都用torch.nn.functional来定义,那么我们需要将卷积层和全连接层中的weights、bias全部手动写入
计算图中去,这样是非常不方便的。如果我们全部使用类的方式来构建计算图,这样即使是非常简单的操作都需要构建类,这样是写代码的效率是
非常低的。所以我们将卷积层、全连接层使用类的方式来进行定义,将池化和激活操作使用函数的方式进行使用,这样使我们更方便的构建计算图。
Pytorch本人疑问(1) torch.nn和torch.nn.functional之间的区别的更多相关文章
- Pytorch本人疑问(2)model.train()和model.eval()的区别
我们在训练时如果使用了BN层和Dropout层,我们需要对model进行标识: model.train():在训练时使用BN层和Dropout层,对模型进行更改. model.eval():在评价时将 ...
- [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList
1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...
- 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())
在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...
- 从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系
从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系 relu多种实现之间的关系 relu 函数在 pytorch 中总共有 3 次出现: torc ...
- pytorch记录:seq2seq例子看看这torch怎么玩的
先看看简单例子: import torch import torch.autograd as autograd import torch.nn as nn import torch.nn.functi ...
- PyTorch - torch.eq、torch.ne、torch.gt、torch.lt、torch.ge、torch.le
PyTorch - torch.eq.torch.ne.torch.gt.torch.lt.torch.ge.torch.le 参考:https://flyfish.blog.csdn.net/art ...
- 【Pytorch】关于torch.matmul和torch.bmm的输出tensor数值不一致问题
发现 对于torch.matmul和torch.bmm,都能实现对于batch的矩阵乘法: a = torch.rand((2,3,10))b = torch.rand((2,2,10))### ma ...
- PyTorch 中,nn 与 nn.functional 有什么区别?
作者:infiniteft链接:https://www.zhihu.com/question/66782101/answer/579393790来源:知乎著作权归作者所有.商业转载请联系作者获得授权, ...
- pytorch(11)模型创建步骤与nn.Module
模型创建与nn.Module 网络模型创建步骤 nn.Module graph LR 模型 --> 模型创建 模型创建 --> 构建网络层 构建网络层 --> id[卷积层,池化层, ...
随机推荐
- 【资源分享】Gmod-Expression2 - 自定义像素画生成
*作者:BUI* 可自定义制作属于你的像素画(默认为Sans) 第77行的COLOR可编辑你想要的颜色(RGB值) 1,2,3,4分别代表第77行所定义的颜色(0代表不显示) 视频地址:传送链接 @n ...
- touch命令修改时间
实例[rhel7]: [root@localhost test]# stat 1.txt 文件:"1.txt" 大小:0 块:0 IO 块:4096 普通空文件设备:fd00h/6 ...
- Layui自定义模块的使用方式
为什么要自定义模块呢?好处很多.比如可以大量重用代码...... 根据layui官方的文档说明.首先第一步是要确定你要扩展的模块名称 现在做的是登录功能.因此扩展模块名叫 login 使用layui ...
- UVA 11464 偶数矩阵(递推 | 进制)
题目链接:https://vjudge.net/problem/UVA-11464 一道比较好的题目. 思路如下: 如果我们枚举每一个数字“变”还是“不变”,那么需要枚举$2^{255}$种情况,很显 ...
- JS-内置对象和方法
1.Array数组对象unshift( ) 数组开头增加功能:给数组开头增加一个或多个 参数:一个或多个 返回值:数组的长度 原数组发生改变 shift( ) 数组开头删除一项功能 ...
- 03hive_DDL数据定义
一. DDL数据定义 创建数据库 1)create database db_hive; 2)避免要创建的数据库已经存在错误,增加 if not exists 判断. create database i ...
- python学习之matplotlib绘制动图(FuncAnimation()参数)
1.函数FuncAnimation(fig,func,frames,init_func,interval,blit)是绘制动图的主要函数,其参数如下: a.fig 绘制动图的画布名称 b.func自定 ...
- 在linux下安装java(centos和ubuntu)
在本地测试环境安装插件,发现还得用到java,虽说是个程序员,可是没用过java啊,哎,但是插件得用啊,怎么办啊?自己装呗 一.自己的系统:CentOS 7 1.查看CentOS自带JDK是否已安装. ...
- Leet Code 9.回文数
判断一个整数是否是回文数. 题解 普通解法:将整数转为字符串,然后对字符串做判断. ///简单粗暴,看看就行 class Solution { public boolean isPalindrome( ...
- 通过WMI获取网卡MAC地址、硬盘序列号、主板序列号、CPU ID、BIOS序列号
转载:https://www.cnblogs.com/tlduck/p/5132738.html #define _WIN32_DCOM #include<iostream> #inclu ...