PyTorch之前向传播函数自动调用forward
参考:1. pytorch学习笔记(九):PyTorch结构介绍
2.pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解
3.Pytorch入门学习(三):Neural Networks
4.forward
神经网络的典型处理如下所示:
1. 定义可学习参数的网络结构(堆叠各层和层的设计);
2. 数据集输入;
3. 对输入进行处理(由定义的网络层进行处理),主要体现在网络的前向传播;
4. 计算loss ,由Loss层计算;
5. 反向传播求梯度;
6. 根据梯度改变参数值,最简单的实现方式(SGD)为:
weight = weight - learning_rate * gradient
下面是利用PyTorch定义深度网络层(Op)示例:
- class FeatureL2Norm(torch.nn.Module):
- def __init__(self):
- super(FeatureL2Norm, self).__init__()
- def forward(self, feature):
- epsilon = 1e-6
- # print(feature.size())
- # print(torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).size())
- norm = torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).unsqueeze(1).expand_as(feature)
- return torch.div(feature,norm)
- class FeatureRegression(nn.Module):
- def __init__(self, output_dim=6, use_cuda=True):
- super(FeatureRegression, self).__init__()
- self.conv = nn.Sequential(
- nn.Conv2d(225, 128, kernel_size=7, padding=0),
- nn.BatchNorm2d(128),
- nn.ReLU(inplace=True),
- nn.Conv2d(128, 64, kernel_size=5, padding=0),
- nn.BatchNorm2d(64),
- nn.ReLU(inplace=True),
- )
- self.linear = nn.Linear(64 * 5 * 5, output_dim)
- if use_cuda:
- self.conv.cuda()
- self.linear.cuda()
- def forward(self, x):
- x = self.conv(x)
- x = x.view(x.size(0), -1)
- x = self.linear(x)
- return x
由上例代码可以看到,不论是在定义网络结构还是定义网络层的操作(Op),均需要定义forward函数,下面看一下PyTorch官网对PyTorch的forward方法的描述:
那么调用forward方法的具体流程是什么样的呢?具体流程是这样的:
以一个Module为例:
1. 调用module的call方法
2. module的call里面调用module的forward方法
3. forward里面如果碰到Module的子类,回到第1步,如果碰到的是Function的子类,继续往下
4. 调用Function的call方法
5. Function的call方法调用了Function的forward方法。
6. Function的forward返回值
7. module的forward返回值
8. 在module的call进行forward_hook操作,然后返回值
上述中“调用module的call方法”是指nn.Module 的__call__方法。定义__call__方法的类可以当作函数调用,具体参考Python的面向对象编程。也就是说,当把定义的网络模型model当作函数调用的时候就自动调用定义的网络模型的forward方法。nn.Module 的__call__方法部分源码如下所示:
- def __call__(self, *input, **kwargs):
- result = self.forward(*input, **kwargs)
- for hook in self._forward_hooks.values():
- #将注册的hook拿出来用
- hook_result = hook(self, input, result)
- ...
- return result
可以看到,当执行model(x)的时候,底层自动调用forward方法计算结果。具体示例如下:
- class Function:
- def __init__(self):
- ...
- def forward(self, inputs):
- ...
- return outputs
- def backward(self, grad_outs):
- ...
- return grad_ins
- def _backward(self, grad_outs):
- hooked_grad_outs = grad_outs
- for hook in hook_in_outputs:
- hooked_grad_outs = hook(hooked_grad_outs)
- grad_ins = self.backward(hooked_grad_outs)
- hooked_grad_ins = grad_ins
- for hook in hooks_in_module:
- hooked_grad_ins = hook(hooked_grad_ins)
- return hooked_grad_ins
model = LeNet()
y = model(x)
如上则调用网络模型定义的forward方法。
PyTorch之前向传播函数自动调用forward的更多相关文章
- pytorch 调用forward 的具体流程
forward方法的具体流程: 以一个Module为例:1. 调用module的call方法2. module的call里面调用module的forward方法3. forward里面如果碰到Modu ...
- 根据判断PC浏览器类型和手机屏幕像素自动调用不同CSS的代码
1.媒体查询方法在 css 里面这样写 -------------------- @media screen and (min-width: 320px) and (max-width: 480px) ...
- paintEvent(QPaintEvent*)是系统自动调用的
qt中函数paintEvent(QPaintEvent*)是被系统自动调用. paintEvent(QPaintEvent *)函数是QWidget类中的虚函数,用于ui的绘制,会在多种情况下被其他函 ...
- QT5.3无法自动调用incomingConnection函数的问题(4.7没有这个问题)
最近将qt4.7的一个工程移到5.3,遇到了几个麻烦事,主要是这个incomingConnection监听后无法自动调用的问题,在4.7上是完全没有问题的,到了5.3就不行,网上也查了下,网友们都是放 ...
- 如果浏览器自动调用quirks模式打开的话
(从已经死了一次又一次终于挂掉的百度空间人工抢救出来的,发表日期 2014-03-21) 则肯定你的html的声明,没有写好. 今天遇到几个,前面莫名其妙的多了个空格(在网页上看源码是多空格,复制到n ...
- PHP中 对象自动调用的方法:__set()、__get()、__tostring()
总结: (1)__get($property_name):获取私有属性$name值时,此对象会自动调用该方法,将属性name值传给参数$property_name,通过这个方法的内部 执行,返回我们传 ...
- C++构造函数的自动调用(调用一个父类的构造函数,有显性调用最好,否则就默认调用无参数的构造函数)——哲学思想:不调用怎么初始化父类的成员数据和VMT?
我总是记不住构造函数的特点,关键还是没有领会那个哲学思想:父类的构造函数一方面要初始化它自己的成员数据,另一方面也要建立它自己的VMT呀!心里默念一百遍:一定调用父类构造函数,一定调用父类构造函数,一 ...
- Object之魔术函数__toString() 直接输出对象引用时自动调用
__toString()是快速获取对象的字符串信息的便捷方式 在直接输出对象引用时自动调用的方法. __toString()的作用 当我们调试程序时,需要知道是否得出正确的数据.比如打印一个对象时,看 ...
- QT5.3无法自动调用incomingConnection函数的问题
最近将qt4.7的一个工程移到5.3,遇到了几个麻烦事,主要是这个incomingConnection监听后无法自动调用的问题,在4.7上是完全没有问题的,到了5.3就不行,网上也查了下,网友们都是放 ...
随机推荐
- locationManager 回调方法不调用问题?
当locationManager都设置好了后开始定位服务后回调方法didUpdateToLocation不调用 [_locationManager setDelegate:self]; [_locat ...
- opencv 图像基本操作
目录:读取图像,获取属性信息,图像ROI,图像通道的拆分和合并 1. 读取图像 像素值返回:直接使用坐标即可获得, 修改像素值:直接通过坐标进行赋值 能用矩阵操作,便用,使用numpy中的array ...
- 错觉-Info:视错觉与UI元素间的可能
ylbtech-错觉-Info:视错觉与UI元素间的可能 1.返回顶部 1. 视觉原理在当下红火的机械视觉中是必不可少的,那在我们日常工作的UI产品设计中又有什么可能性的呢?今天,我从“视错觉”这个角 ...
- pytest 用 @pytest.mark.usefixtures("fixtureName")装饰类,可以让执行每个case前,都执行一遍指定的fixture
conftest.py import pytest import uuid @pytest.fixture() def declass(): print("declass:"+st ...
- var与let循环中经典问题
循环1: 下面代码运行结果是输出10 <script> var a =[]; for(var i = 0;i<10;i++){ a[i] = function(){ consol ...
- li设置多选和取消选择的样式、输入数据类型判断
li设置多选和取消选择的样式: $('li').click(function(){ if($(this).hasClass('active')) {$(this).removeClass('activ ...
- Directx11 教程(1) 基本的windows应用程序框架(1)
原文:Directx11 教程(1) 基本的windows应用程序框架(1) 在vs2010中,建立一个新的win32工程,名字是: myTutorialD3D11, 注意:同时勾选Cr ...
- ESP8266 支持浮点运算吗?
ESP8266 支持浮点运算吗? 可以说支持,也可以说不支持. 说不支持的原因是因为 ESP8266 内部没有 FPU,无法使用硬件计算. 说支持的意思是可以使用软件进行浮点运算,但是会很慢很慢,如果 ...
- oracle如何穿过防火墙连接数据库
这个问题只会在WIN平台出现,UNIX平台会自动解决. 解决方法: 在服务器端的SQLNET.ORA应类似 SQLNET.AUTHENTICATION_SERVICES= (NTS) NAMES.DI ...
- zend studio打开文件提示unsupported character encoding
zend studio打开文件提示unsupported character encoding,是文件的编码方式错误. 有可能是PHP代码中,charset={CHARSET} ,用了变量的形式调用编 ...