『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上

# Author : Hellcat
# Time : 2018/2/11 import torch as t
import torch.nn as nn
import torch.nn.functional as F class LeNet(nn.Module):
def __init__(self):
super(LeNet,self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10) def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
x = F.max_pool2d(F.relu(self.conv2(x)),2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x if __name__ == "__main__":
net = LeNet() # #########训练网络#########
from torch import optim
# 初始化Loss函数 & 优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) for epoch in range(2):
running_loss = 0.0
for step, data in enumerate(trainloader, 0): # step为训练次数, trainloader包含batch的数据和标签
inputs, labels = data
inputs, labels = t.autograd.Variable(inputs), t.autograd.Variable(labels) # 梯度清零
optimizer.zero_grad() # forward
outputs = net(inputs)
# backward
loss = loss_fn(outputs, labels)
loss.backward()
# update
optimizer.step() running_loss += loss.data[0]
if step % 2000 == 1999:
print("[{0:d}, {1:5d}] loss: {2:3f}".format(epoch+1, step+1, running_loss/2000))
running_loss = 0.
print("Finished Training")

这是使用LeNet分类cifar_10的例子,数据处理部分由于不是重点,没有列上来,主要是对使用torch分类有一个直观理解,

初始化网络

初始化Loss函数 & 优化器

进入step循环:

  梯度清零

  向前传播

  计算本次Loss

  向后传播

  更新参数

由于pytorch的网络是class,所以在不考虑持久化的情况下,后续处理都不是太难,值得一提的是预测函数,我们直接net(Variable(test_data))即可,输出是概率分布的Variable,我们只要调用:

_, predict = t.max(test_out, 1)

即可,这是因为当指定了dim时,torch.max会融合max和argmax的功能,

>> a = torch.randn(4, 4)
>> a
    
0.0692  0.3142  1.2513 -0.5428
0.9288  0.8552 -0.2073  0.6409
1.0695 -0.0101 -2.4507 -1.2230
0.7426 -0.7666  0.4862 -0.6628
torch.FloatTensor of size 4x4]
    
>>> torch.max(a, 1)
(
1.2513
0.9288
1.0695
0.7426
[torch.FloatTensor of size 4]
,
2
0
0
0
[torch.LongTensor of size 4]
)

其他torch的高级功能没有使用到,本篇的目的是对于torch神经网络基本的使用有个理解。

『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下的更多相关文章

  1. 『MXNet』第四弹_Gluon自定义层

    一.不含参数层 通过继承Block自定义了一个将输入减掉均值的层:CenteredLayer类,并将层的计算放在forward函数里, from mxnet import nd, gluon from ...

  2. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上

    总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...

  3. 『PyTorch』第三弹重置_Variable对象

    『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...

  4. 『TensorFlow』第七弹_保存&载入会话_霸王回马

    首更: 由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe ...

  5. 关于『进击的Markdown』:第四弹

    关于『进击的Markdown』:第四弹 建议缩放90%食用 美人鱼(Mermaid)悄悄的来,又悄悄的走,挥一挥匕首,不留一个活口 又是漫漫画图路... 女士们先生们,大家好!  我们要接受Markd ...

  6. 关于『HTML』:第三弹

    关于『HTML』:第三弹 建议缩放90%食用 盼望着, 盼望着, 第三弹来了, HTML基础系列完结了!! 一切都像刚睡醒的样子(包括我), 欣欣然张开了眼(我没有) 敬请期待Markdown语法系列 ...

  7. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  8. 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法

    在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...

  9. 『PyTorch』第五弹_深入理解autograd_下:函数扩展&高阶导数

    一.封装新的PyTorch函数 继承Function类 forward:输入Variable->中间计算Tensor->输出Variable backward:均使用Variable 线性 ...

随机推荐

  1. Java SE 基础知识(二)

    1. 类由两大部分构成:属性和方法.属性一般用名词来表示,方法一般用动词来表示. 2. 如果一个java源文件中定义了多个类,那么这些类中最多只能有一个类是public的,可以都不是public的. ...

  2. DBMS_OUTPUT.PUT_LINE()方法的简单介绍

    1.最基本的DBMS_OUTPUT.PUT_LINE()方法. 随便在什么地方,只要是BEGIN和END之间,就可以使用DBMS_OUTPUT.PUT_LINE(output);然而这会有一个问题,就 ...

  3. curl 7.52.1 for Windows

    curl是利用URL语法在命令行方式下工作的开源文件传输工具.它被广泛应用在Unix.多种Linux发行版中,并且有DOS和Win32.Win64下的移植版本. 这个工具对于在运维.持续集成和批处理场 ...

  4. Github使用教程(二)------ Github客户端使用方法

    在上一篇教程中,我们简单介绍了Github网站的各个部分,相信大家对Github网站也有了一个初步的了解(/(ㄒoㄒ)/~~可是还是不会用怎么办),不要着急,我们今天先讲解一下Github for w ...

  5. 03:requests与BeautifulSoup结合爬取网页数据应用

    1.1 爬虫相关模块命令回顾 1.requests模块 1. pip install requests 2. response = requests.get('http://www.baidu.com ...

  6. 2018-2019-1 20189218《Linux内核原理与分析》第五周作业

    系统调用的三层机制 用户态.内核态和中断 用户态.较低的执行级别,只能访问一部分内存,只能执行一部分指令. 内核态.高级执行级别,可以访问任意物理内存,可以执行特权指令. 中断.系统从用户态进入内核态 ...

  7. Java继承相关知识总结

    Java继承的理解 一.概念: 一个新类从已有的类那里获得其已有的属性和方法,这种现象叫类的继承 这个新类称为子类,或派生类,已有的那个类叫做父类,或基类 继承的好处:代码得到极大的重用.形成一种类的 ...

  8. Python3基础 str capitalize 返回新字符串,第一个字母大写

             Python : 3.7.0          OS : Ubuntu 18.04.1 LTS         IDE : PyCharm 2018.2.4       Conda ...

  9. 【转载】浅谈JavaScript,let和var定义变量的区别

    了解JS与ES5与ES6区别 JS语言 JavaScript一种动态类型.弱类型.基于原型的客户端脚本语言,用来给HTML网页增加动态功能. 动态: 在运行时确定数据类型.变量使用之前不需要类型声明, ...

  10. UVa 10201 Adventures in Moving - Part IV

    https://vjudge.net/problem/UVA-10201 题意: 给出到达终点的距离和每个加油站的距离和油费,初始油箱里有100升油,计算到达终点时油箱内剩100升油所需的最少花费. ...