简单了解pytorch的forward
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torch class Net(nn.Module): # 需要继承这个类
def __init__(self):
super(Net, self).__init__()
# 建立了两个卷积层,self.conv1, self.conv2,注意,这些层都是不包含激活函数的
self.conv1 = nn.Conv2d(1, 6, 5) # 1 input image channel, 6 output channels, 5x5 square convolution kernel
self.conv2 = nn.Conv2d(6, 16, 5)
# 三个全连接层
self.fc1 = nn.Linear(16 * 5 * 5, 120) # an affine operation: y = Wx + b
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x): # 注意,2D卷积层的输入data维数是 batchsize*channel*height*width
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) # Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.conv2(x)), 2) # If the size is a square you can only specify a single number
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x) print(x)
print('y=--------')
return x def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features net = Net()
# create your optimizer
optimizer = optim.SGD(net.parameters(), lr = 0.01)
num_iteations = 20
input = Variable(torch.randn(2, 1, 32, 32))
print('input=',input)
#target = Variable(torch.Tensor([5],dtype=torch.long))
target = Variable(torch.LongTensor([5,7]))
# in your training loop:
for i in range(num_iteations):
optimizer.zero_grad() # zero the gradient buffers,如果不归0的话,gradients会累加 output = net(input) # 这里就体现出来动态建图了,你还可以传入其他的参数来改变网络的结构
criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
loss.backward() # 得到grad,i.e.给Variable.grad赋值
optimizer.step() # Does the update,i.e. Variable.data -= learning_rate*Variable.grad
这里是给出的一个代码。
init只是规定了conv的输入通道数量、输出通道数量和卷积核尺寸。
然后在神经网络中,充当卷积层的是forward部分。
input = Variable(torch.randn(2, 1, 32, 32)) #batchsize,channel,height,width
target = Variable(torch.LongTensor([5,7])) #我希望两个神经网络,第一个等于5,第二个等于7.当然随便两个数。(不代表5*7维矩阵呀)
简单了解pytorch的forward的更多相关文章
- 超简单!pytorch入门教程(五):训练和测试CNN
我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...
- 超简单!pytorch入门教程(四):准备图片数据集
在训练神经网络之前,我们必须有数据,作为资深伸手党,必须知道以下几个数据提供源: 一.CIFAR-10 CIFAR-10图片样本截图 CIFAR-10是多伦多大学提供的图片数据库,图片分辨率压缩至32 ...
- 超简单!pytorch入门教程(三):构造一个小型CNN
torch.nn只接受mini-batch的输入,也就是说我们输入的时候是必须是好几张图片同时输入. 例如:nn. Conv2d 允许输入4维的Tensor:n个样本 x n个色彩频道 x 高度 x ...
- pytorch 调用forward 的具体流程
forward方法的具体流程: 以一个Module为例:1. 调用module的call方法2. module的call里面调用module的forward方法3. forward里面如果碰到Modu ...
- 超简单!pytorch入门教程(一):Tensor
http://www.jianshu.com/p/5ae644748f21 二.pytorch的基石--Tensor张量 其实标量,向量,矩阵它们三个也是张量,标量是零维的张量,向量是一维的张量,矩阵 ...
- 超简单!pytorch入门教程(二):Autograd
一.autograd自动微分 autograd是专门为了BP算法设计的,所以这autograd只对输出值为标量的有用,因为损失函数的输出是一个标量.如果y是一个向量,那么backward()函数就会失 ...
- PyTorch之前向传播函数自动调用forward
参考:1. pytorch学习笔记(九):PyTorch结构介绍 2.pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解 3.Pytorch入门 ...
- 机器翻译注意力机制及其PyTorch实现
前面阐述注意力理论知识,后面简单描述PyTorch利用注意力实现机器翻译 Effective Approaches to Attention-based Neural Machine Translat ...
- pytorch中检测分割模型中图像预处理探究
Object Detection and Classification using R-CNNs 目标检测:数据增强(Numpy+Pytorch) - 主要探究检测分割模型数据增强操作有哪些? - 检 ...
随机推荐
- Java面向对象 第2节 Scanner 类和格式化输出printf
§Scanner 类 java.util.Scanner 是 Java5 的新特征,我们可以通过 Scanner 类来获取用户的输入. 1.创建 Scanner 对象的基本语法:Scanner s = ...
- 各CF-based tracker中output_sigma_factor取值
现有的各CF-Based tracker中理想高斯响应中output_sigma_factor的取值情况 默认output_sigma = target_sz*output_sigma_factor; ...
- Linux之文件(目录)默认权限、特殊权限与隐藏权限
文件默认权限 从Linux之用户组.文件权限详解了解到文件与目录的基本权限管理,文件在创建时如果不指定具体的权限,那么系统会给它分配一个默认的权限,这个默认权限就是umask. vbird@Ubunt ...
- 学习vue容易忽视的细节
1.对于自定义标签名(组件名称),Vue.js 不强制要求遵循 W3C 规则 (小写,并且包含一个短杠),尽管遵循这个规则比较好.HTML 特性是不区分大小写的.所以,当使用的不是字符串模板,came ...
- 如何系统学习知识图谱-15年+IT老兵的经验分享
一.前言 就IT而言,胖子哥算是老兵,可以去猝死的年纪,按照IT江湖猿龄的规矩,也算是到了耳顺之年:而就人工智能而言,胖子哥还是新人,很老的新人,深度学习.语音识别.人脸识别,知识图谱,逐个的学习了一 ...
- 【ELK】之Centos6.9_x64安装elasticsearch6.2.1
1.下载elasticsearch6.2.1 wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-6.2.1 ...
- 围绕Buganizer的产品流程
做技术的一定知道缺陷跟踪系统(bug系统),更不用说做测试的了,不过普遍都认为这系统是用来记录bug的,其实在google内部,这套系统是产品/项目围绕的核心.Google Buganizer扩展了类 ...
- 前端-JavaScript1-5——JavaScript之变量的类型
5.1 概述 基本类型5种 number 数字类型 string 字符串类型 undefined undefined类型,变量未定义时的值,这个值自 ...
- [转]jvm调优-命令大全(jps jstat jmap jhat jstack jinfo)
运用jvm自带的命令可以方便的在生产监控和打印堆栈的日志信息帮忙我们来定位问题!虽然jvm调优成熟的工具已经有很多:jconsole.大名鼎鼎的VisualVM,IBM的Memory Analyzer ...
- 关于分布式uuid的一点设想
在一次公开课上,听别人讲过全局分布式uuid的设计,听过twitter的snowflake的设计.也听过,如果使用单独的计数器服务,不可能每次都保存当前计数器到文本,自己想到应该可以每隔一些数,例如1 ...