我们先介绍下pytorch中的cnn网络

学过深度卷积网络的应该都非常熟悉这张demo图(LeNet):

先不管怎么训练,我们必须先构建出一个CNN网络,很快我们写了一段关于这个LeNet的代码,并进行注释:

  1. # coding=utf-8
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch.autograd import Variable
  6.  
  7. class Net(nn.Module):
  8. # 定义Net的初始化函数,这个函数定义了该神经网络的基本结构
  9. def __init__(self):
  10. super(Net, self).__init__() # 复制并使用Net的父类的初始化方法,即先运行nn.Module的初始化函数
  11. self.conv1 = nn.Conv2d(1, 6, 5) # 定义conv1函数的是图像卷积函数:输入为图像(1个频道,即灰度图),输出为 6张特征图, 卷积核为5x5正方形
  12. self.conv2 = nn.Conv2d(6, 16, 5) # 定义conv2函数的是图像卷积函数:输入为6张特征图,输出为16张特征图, 卷积核为5x5正方形
  13. self.fc1 = nn.Linear(16 * 5 * 5, 120) # 定义fc1(fullconnect)全连接函数1为线性函数:y = Wx + b,并将16*5*5个节点连接到120个节点上。
  14. self.fc2 = nn.Linear(120, 84) # 定义fc2(fullconnect)全连接函数2为线性函数:y = Wx + b,并将120个节点连接到84个节点上。
  15. self.fc3 = nn.Linear(84, 10) # 定义fc3(fullconnect)全连接函数3为线性函数:y = Wx + b,并将84个节点连接到10个节点上。
  16.  
  17. # 定义该神经网络的向前传播函数,该函数必须定义,一旦定义成功,向后传播函数也会自动生成(autograd)
  18. def forward(self, x):
  19. x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) # 输入x经过卷积conv1之后,经过激活函数ReLU,使用2x2的窗口进行最大池化Max pooling,然后更新到x。
  20. x = F.max_pool2d(F.relu(self.conv2(x)), 2) # 输入x经过卷积conv2之后,经过激活函数ReLU,使用2x2的窗口进行最大池化Max pooling,然后更新到x。
  21. x = x.view(-1, self.num_flat_features(x)) # view函数将张量x变形成一维的向量形式,总特征数并不改变,为接下来的全连接作准备。
  22. x = F.relu(self.fc1(x)) # 输入x经过全连接1,再经过ReLU激活函数,然后更新x
  23. x = F.relu(self.fc2(x)) # 输入x经过全连接2,再经过ReLU激活函数,然后更新x
  24. x = self.fc3(x) # 输入x经过全连接3,然后更新x
  25. return x
  26.  
  27. # 使用num_flat_features函数计算张量x的总特征量(把每个数字都看出是一个特征,即特征总量),比如x是4*2*2的张量,那么它的特征总量就是16。
  28. def num_flat_features(self, x):
  29. size = x.size()[1:] # 这里为什么要使用[1:],是因为pytorch只接受批输入,也就是说一次性输入好几张图片,那么输入数据张量的维度自然上升到了4维。【1:】让我们把注意力放在后3维上面
  30. num_features = 1
  31. for s in size:
  32. num_features *= s
  33. return num_features
  34.  
  35. net = Net()
  36.  
  37. # 以下代码是为了看一下我们需要训练的参数的数量
  38. print (net)
  39. params = list(net.parameters())
  40.  
  41. k = 0
  42. for i in params:
  43. l = 1
  44. print ("该层的结构:" + str(list(i.size())))
  45. for j in i.size():
  46. l *= j
  47. print ("参数和:" + str(l))
  48. k = k + l
  49.  
  50. print ("总参数和:" + str(k))

注意:torch.nn只接受mini-batch的输入,也就是说我们输入的时候是必须是好几张图片同时输入。

例如:nn. Conv2d 允许输入4维的Tensor:n个样本 x n个色彩频道 x 高度 x 宽度。

这段代码运行 (运行于pytorch0.4版本) 效果如下:

  1. Net(
  2. (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  3. (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  4. (fc1): Linear(in_features=400, out_features=120, bias=True)
  5. (fc2): Linear(in_features=120, out_features=84, bias=True)
  6. (fc3): Linear(in_features=84, out_features=10, bias=True)
  7. )
  8. 该层的结构:[6, 1, 5, 5]
  9. 参数和:150
  10. 该层的结构:[6]
  11. 参数和:6
  12. 该层的结构:[16, 6, 5, 5]
  13. 参数和:2400
  14. 该层的结构:[16]
  15. 参数和:16
  16. 该层的结构:[120, 400]
  17. 参数和:48000
  18. 该层的结构:[120]
  19. 参数和:120
  20. 该层的结构:[84, 120]
  21. 参数和:10080
  22. 该层的结构:[84]
  23. 参数和:84
  24. 该层的结构:[10, 84]
  25. 参数和:840
  26. 该层的结构:[10]
  27. 参数和:10
  28. 总参数和:61706

参考链接:https://www.jianshu.com/p/cde4a33fa129

pytorch写一个LeNet网络的更多相关文章

  1. 自己动手写一个iOS 网络请求库的三部曲[转]

    代码示例:https://github.com/johnlui/Swift-On-iOS/blob/master/BuildYourHTTPRequestLibrary 开源项目:Pitaya,适合大 ...

  2. 07_利用pytorch的nn工具箱实现LeNet网络

    07_利用pytorch的nn工具箱实现LeNet网络 目录 一.引言 二.定义网络 三.损失函数 四.优化器 五.数据加载和预处理 六.Hub模块简介 七.总结 pytorch完整教程目录:http ...

  3. 网络编程—【自己动手】用C语言写一个基于服务器和客户端(TCP)!

    如果想要自己写一个服务器和客户端,我们需要掌握一定的网络编程技术,个人认为,网络编程中最关键的就是这个东西--socket(套接字). socket(套接字):简单来讲,socket就是用于描述IP地 ...

  4. 1、pytorch写的第一个Linear模型(原始版,不调用nn.Modules模块)

    参考: https://github.com/Iallen520/lhy_DL_Hw/blob/master/PyTorch_Introduction.ipynb 模拟一个回归模型,y = X * w ...

  5. 如何使用 libtorch 实现 LeNet 网络?

    如何使用 libtorch 实现 LeNet 网络? LeNet 网络论文地址: http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf

  6. 基于LeNet网络的中文验证码识别

    基于LeNet网络的中文验证码识别 由于公司需要进行了中文验证码的图片识别开发,最近一段时间刚忙完上线,好不容易闲下来就继上篇<基于Windows10 x64+visual Studio2013 ...

  7. Pytorch写CNN

    用Pytorch写了两个CNN网络,数据集用的是FashionMNIST.其中CNN_1只有一个卷积层.一个全连接层,CNN_2有两个卷积层.一个全连接层,但训练完之后的准确率两者差不多,且CNN_1 ...

  8. 如何写一个简单的http服务器

    最近几天用C++写了一个简单的HTTP服务器,作为学习网络编程和Linux环境编程的练手项目,这篇文章记录我在写一个HTTP服务器过程中遇到的问题和学习到的知识. 服务器的源代码放在Github. H ...

  9. 一起写一个JSON解析器

    [本篇博文会介绍JSON解析的原理与实现,并一步一步写出来一个简单但实用的JSON解析器,项目地址:SimpleJSON.希望通过这篇博文,能让我们以后与JSON打交道时更加得心应手.由于个人水平有限 ...

随机推荐

  1. java keystore

    JAVA有一个keystore用来存放私钥和证书,该文件是伴随JDK默认存在的,路径默认是/lib/security/cacerts,默认密码是changeit,实际上空密码也可以直接访问 其中cac ...

  2. ubuntu安装php+mysql+apche

    步骤一,安装apache2 ? sudo apt-get install apache2 安装完成. 运行如下命令重启下: ? sudo /etc/init.d/apache2 restart 在浏览 ...

  3. Visio制图之垮职能流程图

    Visio制图中常用的一种就是带有不同职能,不同阶段的流程关系图. 下面是根据实际生产情况制作的一张“软件生产流程关系图”,供参考.

  4. 设置VMware10开机自启动并同时启动虚拟机镜像系统

    首先,进入VMware Workstation的安装目录 C:\Program Files (x86)\VMware\VMware Workstation

  5. tomcat优化,java查看

    java堆空间分为  新生代 ,老年代 , 持久代 各自有各自的垃圾回收算法 eden区:新生的对象存放在这经常被回收 from  .to  存活区 在老年代,回收的频率不是很高 jdk8 就没有持久 ...

  6. php5.6安装redis各个版本地址集合

    igbinary扩展 http://windows.php.net/downloads/pecl/releases/igbinary/2.0.1/ redis扩展 http://windows.php ...

  7. Vijos1983 NOIP2015Day2T3 运输计划 transport LCA

    题目链接Vijos 题目链接UOJ 该博客在博客园的链接 转载一个大佬的题解: 点击这里->大佬题解 下面谈谈我的感悟: 当然写代码也是写的很艰辛: 我力劝C++的同胞们,这题卡常数,Dfs党会 ...

  8. 爬虫2 urllib3 爬取30张百度图片

    import urllib3 import re # 下载百度首页页面的所有图片 # 1. 找到目标数据 # page_url = 'http://image.baidu.com/search/ind ...

  9. ES6 系列之 defineProperty 与 proxy

    ,, ; ; ; } ; }); }; ; }); } });

  10. sqlmap 1.3.2.16 帮助信息

    Usage: sqlmap.py [options] 选项: -h, --help 显示基本帮助信息并退出 -hh 显示高级帮助消息并退出 --version 显示程序的版本号并退出 -v VERBO ...