pytorch 中的重要模块化接口nn.Module
torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法
对于自己定义的网络,需要注意以下几点:
1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W
代码示例:
class LeNet(nn.Module):
def __init__(self):
# nn.Module的子类函数必须在构造函数中执行父类的构造函数
super(LeNet, self).__init__() # 等价与nn.Module.__init__() # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
# 当调用self.conv1(input)的时候,就会调用该类的forward函数
self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
self.conv2 = nn.Conv2d(6, 16, (5, 5))
self.fc1 = nn.Linear(256, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
# F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
# input:(10, 6, 12, 12) output:(10,6,4,4)
x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
# 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
x = x.view(x.size()[0], -1)
# 全连接
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x)) # 返回值也是一个Variable对象
return x def output_name_and_params(net):
for name, parameters in net.named_parameters():
print('name: {}, param: {}'.format(name, parameters)) if __name__ == '__main__':
net = LeNet()
print('net: {}'.format(net))
params = net.parameters() # generator object
print('params: {}'.format(params))
output_name_and_params(net) input_image = torch.FloatTensor(10, 1, 28, 28) # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
# 这可以从forward中每一步的执行结果可以看出
input_image = Variable(input_image) output = net(input_image)
print('output: {}'.format(output))
print('output.size: {}'.format(output.size()))
pytorch 中的重要模块化接口nn.Module的更多相关文章
- Pytorch中RoI pooling layer的几种实现
Faster-RCNN论文中在RoI-Head网络中,将128个RoI区域对应的feature map进行截取,而后利用RoI pooling层输出7*7大小的feature map.在pytorch ...
- [转载]Pytorch中nn.Linear module的理解
[转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...
- PyTorch 中,nn 与 nn.functional 有什么区别?
作者:infiniteft链接:https://www.zhihu.com/question/66782101/answer/579393790来源:知乎著作权归作者所有.商业转载请联系作者获得授权, ...
- 『PyTorch x TensorFlow』第八弹_基本nn.Module层函数
『TensorFlow』网络操作API_上 『TensorFlow』网络操作API_中 『TensorFlow』网络操作API_下 之前也说过,tf 和 t 的层本质区别就是 tf 的是层函数,调用即 ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- 『PyTorch』第十二弹_nn.Module和nn.functional
大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...
- pytorch中torch.nn构建神经网络的不同层的含义
主要是参考这里,写的很好PyTorch 入门实战(四)--利用Torch.nn构建卷积神经网络 卷积层nn.Con2d() 常用参数 in_channels:输入通道数 out_channels:输出 ...
- Pytorch中Module,Parameter和Buffer的区别
下文都将torch.nn简写成nn Module: 就是我们常用的torch.nn.Module类,你定义的所有网络结构都必须继承这个类. Buffer: buffer和parameter相对,就是指 ...
- Pytorch中nn.Conv2d的用法
Pytorch中nn.Conv2d的用法 nn.Conv2d是二维卷积方法,相对应的还有一维卷积方法nn.Conv1d,常用于文本数据的处理,而nn.Conv2d一般用于二维图像. 先看一下接口定义: ...
随机推荐
- 关于Android开发中Arm、X86和Mips(草稿)
一.架构 1.Arm架构 是一个32位精简指令集(RISC)处理器架构,其广泛地使用在许多嵌入式系统设计. 2.X86架构 是一个intel通用计算机系列的标准编号缩写,也标识一套通用的计算机指令集合 ...
- kafka注册异常
问题描述: kafka注册异常,提示brokers id已经被注册过 -- ::,] FATAL [Kafka Server ], Fatal error during KafkaServer sta ...
- MATLAB 统计元素出现的次数
可以使用 hist 函数: A = [1 2 8 8 1 8 2 1 8 2 1]; count = hist(A,unique(A)) count的结果与unique(A)对应.
- 搭建web之 服务器鉴权失败,请确认服务器已启用密码鉴权并且账号密码正确?
实例化时,登录过程中出现 服务器鉴权失败! 这是由于密码错误所致! 第一种情况:原始随机密码 第一种情况,你没有修改密码,则可以直接查找原始密码: 过程详见官网 使用密码登录的前提条件 密码: 若用户 ...
- exp导出数据时丢表
友军发来消息,说使用exp导出某个schema的数据的时候,发现有些表没有导出来.因为一直没有使用exp的习惯,就使用exp\expdp再次导出一次,分析二者的日志,发现exp的确有些表没有导出. 问 ...
- R8500 MPv2 版本 刷梅林改版固件
由于R8500折腾起来比较繁琐.并且国内的koolshare上已经有人释出梅林改版移植的固件,主要是***更方便了,所以把R8500刷成了梅林固件,这是我第一次用上梅林固件. 刷机整个过程参考了下面的 ...
- vim:放弃hjkl
vim放弃使用hjkl,可以加快文本的编辑速度,不信,看我摘录的文章:http://vimcasts.org/blog/2013/02/habit-breaking-habit-making/ Wor ...
- vivado和modelsim联合调试仿真
vivado和modelsim联合调试仿真 0赞 发表于 2017/5/10 19:10:59 阅读(881) 评论(0) 使用vivado和modelsim联合调试仿真时,在破解完modelsim后 ...
- vs code 快捷键中英文对照
常用 General 按 Press 功能 Function Ctrl + Shift + P,F1 显示命令面板 Show Command Palette Ctrl + P 快速打开 Quick O ...
- Python定期删除文件、整理文件夹
1.根据传入的参数,文件所在目录,匹配文件的正则表达式,过期天数进行删除,这些可写在配置文件del_file.conf. del_file3.py #!/usr/bin/env python # en ...