卷积神经网络

在之前的文章里,对28 X 28的图像,我们是通过把它展开为长度为784的一维向量,然后送进全连接层,训练出一个分类模型.这样做主要有两个问题

  1. 图像在同一列邻近的像素在这个向量中可能相距较远。它们构成的模式可能难以被模型识别。
  2. 对于大尺寸的输入图像,使用全连接层容易造成模型过大。假设输入是高和宽均为1000像素的彩色照片(含3个通道)。即使全连接层输出个数仍是256,该层权重参数的形状是\(3,000,000\times 256\),按照参数为float,占用4字节计算,它占用了大约3000000 X 256 X4bytes=3000000kb=3000M=3G的内存或显存。

很显然,通过使用卷积操作可以有效的改善这两个问题.关于卷积操作,池化操作等,参见置顶文章https://www.cnblogs.com/sdu20112013/p/10149529.html

LENET

lenet是比较早期提出来的一个神经网络,其结构如下图所示.

 

LeNet的结构比较简单,就是2次重复的卷积激活池化后面接三个全连接层.卷积层的卷积核用的5 X 5,池化用的窗口大小为2 X 2,步幅为2.

对我们的输入(28 x 28)来说,卷积层得到的输出shape为[batch,16,4,4],在送入全连接层前,要reshape成[batch,16x4x4].可以理解为通过卷积,对没一个样本,我们

都提取出来了16x4x4=256个特征.这些特征用来识别图像里的空间模式,比如线条和物体局部.

全连接层块含3个全连接层。它们的输出个数分别是120、84和10,其中10为输出的类别个数。

  1. net0 = nn.Sequential(
  2. nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
  3. nn.Sigmoid(),
  4. nn.MaxPool2d(2, 2), # kernel_size, stride
  5. nn.Conv2d(6, 16, 5),
  6. nn.Sigmoid(),
  7. nn.MaxPool2d(2, 2)
  8. )
  9. batch_size=64
  10. X = torch.randn((batch_size,1,28,28))
  11. out=net0(X)
  12. print(out.shape)

输出

  1. torch.Size([64, 16, 4, 4])

这就是上面我们说的"对我们的输入(28 x 28)来说,卷积层得到的输出shape为[batch,16,4,4]"的由来.

模型定义

至此,我们可以给出LeNet的定义:

  1. class LeNet(nn.Module):
  2. def __init__(self):
  3. super(LeNet, self).__init__()
  4. self.conv = nn.Sequential(
  5. nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
  6. nn.Sigmoid(),
  7. nn.MaxPool2d(2, 2), # kernel_size, stride
  8. nn.Conv2d(6, 16, 5),
  9. nn.Sigmoid(),
  10. nn.MaxPool2d(2, 2)
  11. )
  12. self.fc = nn.Sequential(
  13. nn.Linear(16*4*4, 120),
  14. nn.Sigmoid(),
  15. nn.Linear(120, 84),
  16. nn.Sigmoid(),
  17. nn.Linear(84, 10)
  18. )
  19. def forward(self, img):
  20. feature = self.conv(img)
  21. output = self.fc(feature.view(img.shape[0], -1))
  22. return output

forward()中,在输入全连接层之前,要先feature.view(img.shape[0], -1)做一次reshape.

我们用gpu来做训练,所以要把net的参数都存储在显存上:

  1. net = LeNet().cuda()

数据加载

  1. import torch
  2. from torch import nn
  3. import sys
  4. sys.path.append("..")
  5. import learntorch_utils
  6. batch_size,num_workers=64,4
  7. train_iter,test_iter = learntorch_utils.load_data(batch_size,num_workers)

load_data定义于learntorch_utils.py,如下:

  1. def load_data(batch_size,num_workers):
  2. mnist_train = torchvision.datasets.FashionMNIST(root='/home/sc/disk/keepgoing/learn_pytorch/Datasets/FashionMNIST',
  3. train=True, download=True,
  4. transform=transforms.ToTensor())
  5. mnist_test = torchvision.datasets.FashionMNIST(root='/home/sc/disk/keepgoing/learn_pytorch/Datasets/FashionMNIST',
  6. train=False, download=True,
  7. transform=transforms.ToTensor())
  8. train_iter = torch.utils.data.DataLoader(
  9. mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
  10. test_iter = torch.utils.data.DataLoader(
  11. mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
  12. return train_iter,test_iter

定义损失函数

l = nn.CrossEntropyLoss()

定义优化器

opt = torch.optim.Adam(net.parameters(),lr=0.01)

定义评估函数

  1. def test():
  2. acc_sum = 0
  3. batch = 0
  4. for X,y in test_iter:
  5. X,y = X.cuda(),y.cuda()
  6. y_hat = net(X)
  7. acc_sum += (y_hat.argmax(dim=1) == y).float().sum().item()
  8. batch += 1
  9. print('acc:%f' % (acc_sum/(batch*batch_size)))

训练

  • 前向传播
  • 计算loss
  • 梯度清空,反向传播
  • 更新参数
  1. num_epochs=5
  2. def train():
  3. for epoch in range(num_epochs):
  4. train_l_sum,batch=0,0
  5. for X,y in train_iter:
  6. X,y = X.cuda(),y.cuda() #把tensor放到显存
  7. y_hat = net(X) #前向传播
  8. loss = l(y_hat,y) #计算loss,nn.CrossEntropyLoss中会有softmax的操作
  9. opt.zero_grad()#梯度清空
  10. loss.backward()#反向传播,求出梯度
  11. opt.step()#根据梯度,更新参数
  12. train_l_sum += loss.item()
  13. batch += 1
  14. print('epoch %d,train_loss %f' % (epoch + 1,train_l_sum/(batch*batch_size)))
  15. test()

输出如下:

  1. epoch 1,train_loss 0.011750
  2. acc:0.799064
  3. epoch 2,train_loss 0.006442
  4. acc:0.855195
  5. epoch 3,train_loss 0.005401
  6. acc:0.857584
  7. epoch 4,train_loss 0.004946
  8. acc:0.874602
  9. epoch 5,train_loss 0.004631
  10. acc:0.874403

从头学pytorch(十四):lenet的更多相关文章

  1. 从头学pytorch(十五):AlexNet

    AlexNet AlexNet是2012年提出的一个模型,并且赢得了ImageNet图像识别挑战赛的冠军.首次证明了由计算机自动学习到的特征可以超越手工设计的特征,对计算机视觉的研究有着极其重要的意义 ...

  2. 从头学pytorch(十九):批量归一化batch normalization

    批量归一化 论文地址:https://arxiv.org/abs/1502.03167 批量归一化基本上是现在模型的标配了. 说实在的,到今天我也没搞明白batch normalize能够使得模型训练 ...

  3. 从头学pytorch(十二):模型保存和加载

    模型读取和存储 总结下来,就是几个函数 torch.load()/torch.save() 通过python的pickle完成序列化与反序列化.完成内存<-->磁盘转换. Module.s ...

  4. 从头学pytorch(十六):VGG NET

    VGG AlexNet在Lenet的基础上增加了几个卷积层,改变了卷积核大小,每一层输出通道数目等,并且取得了很好的效果.但是并没有提出一个简单有效的思路. VGG做到了这一点,提出了可以通过重复使⽤ ...

  5. 从头学pytorch(十八):GoogLeNet

    GoogLeNet GoogLeNet和vgg分别是2014的ImageNet挑战赛的冠亚军.GoogLeNet则做了更加大胆的网络结构尝试,虽然深度只有22层,但大小却比AlexNet和VGG小很多 ...

  6. HDU 6467 简单数学题 【递推公式 && O(1)优化乘法】(广东工业大学第十四届程序设计竞赛)

    传送门:http://acm.hdu.edu.cn/showproblem.php?pid=6467 简单数学题 Time Limit: 4000/2000 MS (Java/Others)    M ...

  7. HDU 6464 免费送气球 【权值线段树】(广东工业大学第十四届程序设计竞赛)

    传送门:http://acm.hdu.edu.cn/showproblem.php?pid=6464 免费送气球 Time Limit: 2000/1000 MS (Java/Others)    M ...

  8. HDU 6470 Count 【矩阵快速幂】(广东工业大学第十四届程序设计竞赛 )

    题目传送门:http://acm.hdu.edu.cn/showproblem.php?pid=6470 Count Time Limit: 6000/3000 MS (Java/Others)    ...

  9. HDU 6467.简单数学题-数学题 (“字节跳动-文远知行杯”广东工业大学第十四届程序设计竞赛)

    简单数学题 Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 65536/65536 K (Java/Others)Total Submi ...

随机推荐

  1. Activiti 工作流入门指南

    概览 如我们的介绍部分所述,Activiti目前分为两大类: Activiti Core Activiti Cloud 如果你想上手Activiti的核心是否遵循了新的运行时API的入门指南:Acti ...

  2. 洛谷P1981 表达式求值 题解 栈/中缀转后缀

    题目链接:https://www.luogu.org/problem/P1981 这道题目就是一道简化的中缀转后缀,因为这里比较简单,只有加号(+)和乘号(*),所以我们只需要开一个存放数值的栈就可以 ...

  3. 不通过DataRow,直接往DataTable中添加新行DataTable.LoadDataRow(object[],bool)

    DataTable dtver = new DataTable();                dtver.Columns.Add("VERSION");            ...

  4. H3C NAT ALG

  5. H3C 在接口上应用ACL

  6. Python--day40线程理论

    1,进程:

  7. HDU 6621"K-th Closest Distance"(二分+主席树)

    传送门 •题意 有 $m$ 次询问,每次询问求 $n$ 个数中, $[L,R]$ 区间距 $p$ 第 $k$ 近的数与 $p$ 差值的绝对值: •题解 二分答案,假设当前二分的答案为 $x$,那么如何 ...

  8. JAVA核心知识点--打包 FatJar 方法小结

    目录 什么是 FatJar 三种打包方法 1. 非遮蔽方法(Unshaded) 2. 遮蔽方法(Shaded) 3. 嵌套方法(Jar of Jars) 小结 参考阅读 原文地址:https://yq ...

  9. SDOI2019热闹又尴尬的聚会

    P5361 [SDOI2019]热闹又尴尬的聚会 出题人用脚造数据系列 只要将\(p\)最大的只求出来,\(q\)直接随便rand就能过 真的是 我们说说怎么求最大的\(p\),这个玩意具有很明显的单 ...

  10. error:cannot load file (code:5555h);bootauto.ini

    最近发现有的网友在使用Ghost XP盘安装系统的时候,选择一键ghost到C盘出现下面的错误: error:cannot load file (code:5555h);bootauto.ini(或b ...