pytorch 中的重要模块化接口nn.Module
torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法
对于自己定义的网络,需要注意以下几点:
1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W
代码示例:
- class LeNet(nn.Module):
- def __init__(self):
- # nn.Module的子类函数必须在构造函数中执行父类的构造函数
- super(LeNet, self).__init__() # 等价与nn.Module.__init__()
- # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
- # 当调用self.conv1(input)的时候,就会调用该类的forward函数
- self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
- self.conv2 = nn.Conv2d(6, 16, (5, 5))
- self.fc1 = nn.Linear(256, 120)
- self.fc2 = nn.Linear(120, 84)
- self.fc3 = nn.Linear(84, 10)
- def forward(self, x):
- # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
- x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
- # input:(10, 6, 12, 12) output:(10,6,4,4)
- x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
- # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
- x = x.view(x.size()[0], -1)
- # 全连接
- x = F.relu(self.fc1(x))
- x = F.relu(self.fc2(x))
- x = F.relu(self.fc3(x))
- # 返回值也是一个Variable对象
- return x
- def output_name_and_params(net):
- for name, parameters in net.named_parameters():
- print('name: {}, param: {}'.format(name, parameters))
- if __name__ == '__main__':
- net = LeNet()
- print('net: {}'.format(net))
- params = net.parameters() # generator object
- print('params: {}'.format(params))
- output_name_and_params(net)
- input_image = torch.FloatTensor(10, 1, 28, 28)
- # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
- # 这可以从forward中每一步的执行结果可以看出
- input_image = Variable(input_image)
- output = net(input_image)
- print('output: {}'.format(output))
- print('output.size: {}'.format(output.size()))
pytorch 中的重要模块化接口nn.Module的更多相关文章
- Pytorch中RoI pooling layer的几种实现
Faster-RCNN论文中在RoI-Head网络中,将128个RoI区域对应的feature map进行截取,而后利用RoI pooling层输出7*7大小的feature map.在pytorch ...
- [转载]Pytorch中nn.Linear module的理解
[转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...
- PyTorch 中,nn 与 nn.functional 有什么区别?
作者:infiniteft链接:https://www.zhihu.com/question/66782101/answer/579393790来源:知乎著作权归作者所有.商业转载请联系作者获得授权, ...
- 『PyTorch x TensorFlow』第八弹_基本nn.Module层函数
『TensorFlow』网络操作API_上 『TensorFlow』网络操作API_中 『TensorFlow』网络操作API_下 之前也说过,tf 和 t 的层本质区别就是 tf 的是层函数,调用即 ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- 『PyTorch』第十二弹_nn.Module和nn.functional
大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...
- pytorch中torch.nn构建神经网络的不同层的含义
主要是参考这里,写的很好PyTorch 入门实战(四)--利用Torch.nn构建卷积神经网络 卷积层nn.Con2d() 常用参数 in_channels:输入通道数 out_channels:输出 ...
- Pytorch中Module,Parameter和Buffer的区别
下文都将torch.nn简写成nn Module: 就是我们常用的torch.nn.Module类,你定义的所有网络结构都必须继承这个类. Buffer: buffer和parameter相对,就是指 ...
- Pytorch中nn.Conv2d的用法
Pytorch中nn.Conv2d的用法 nn.Conv2d是二维卷积方法,相对应的还有一维卷积方法nn.Conv1d,常用于文本数据的处理,而nn.Conv2d一般用于二维图像. 先看一下接口定义: ...
随机推荐
- JAVA与C#的区别
Java和C#都是编程的语言,它们是两个不同方向的两种语言 相同点: 他们都是面向对象的语言,也就是说,它们都能实现面向对象的思想(封装,继承,多态) 区别: 1.c#中的命名空间是namespace ...
- Effective Java 第三版——83. 明智谨慎地使用延迟初始化
Tips 书中的源代码地址:https://github.com/jbloch/effective-java-3e-source-code 注意,书中的有些代码里方法是基于Java 9 API中的,所 ...
- 【Vegas原创】SQLServer2008防小人利器:审核/审计功能
小人见怪不怪,世界上最可怕的就是会技术的小人,防不胜防! sa密码泄露也就算了,关键是人家也可以前台攻击,直接把你弄的没辙! 在诅咒这种小人的同时,除了加强服务器安全管理,密码策略等,SQL Serv ...
- 关于Java 软件工程师应该知道或掌握的技术栈
鄙人星云,今天突然想写这么一篇需要持续更新的文章,主要目的用于总结当前最流行的技术和工具,方便自己也方便他人. 更新时间:2018-10-23 09:26:19 码农职业路径图 码农入门职业路径图 J ...
- springmvc(五) 数据回显与自定义异常处理器
这章讲解一下springmvc的数据回显和自定义异常处理器的使用,两个都很简单 --WH 一.数据回显技术 Springmvc默认支持对pojo类型的数据回显,默认不支持简单类型的数据回显 1.1.什 ...
- [lvs]lvs的三种模式
回顾了下lvs的三种模式的调度机制 1.lvs的dr模式中的arp的抑制,eth用自己口arp回应. 2.keepalive是否直接操作rs? 不直接操作, 只操作dr(配lvs) 3.tunnel模 ...
- Visual Studio编辑器“智能提示(IntelliSense)”异常的解决方案
1.删除工程中的 .suo 文件. 2.重启vs
- vs code 设置问题
现已取消 .vue 文件与 HTML 的默认关联,需要手动配置.vue 文件里不能使用div + Tab 键快速生成 html 代码 "emmet.syntaxProfiles" ...
- C语言 · 猜算式
题目:猜算式 看下面的算式: □□ x □□ = □□ x □□□ 它表示:两个两位数相乘等于一个两位数乘以一个三位数. 如果没有限定条件,这样的例子很多. 但目前的限定是:这9个方块,表示1~9的9 ...
- javaEE中config.properties文件乱码解决办法
http://jingyan.baidu.com/article/ed2a5d1f3381d709f6be17f8.html ————————————————————————————————————— ...