卷积神经网络CNN实战:MINST手写数字识别——网络定义
本文基于python-pytorch框架,编写CNN网络,并采用CNN手写数字数据集训练、测试网络。
网络的构建
以LeNet-5 网络为例
类定义
首先先了解一下网络的最基本框架
- 一般而言,首先创建一个类
class
,创建时,继承nn.Module
父类,注意,在该类的构造函数中__init__
中,显示的调用其父类的构造函数super(...).__init__()
- 网络的结构,例如卷积层、线性层等,一般在其构造函数中定义。
- 对于一些不带参数的网络结构,也可以在forward方法中直接调用,而不定义,但不推荐。
- 每一个网络类必须显示的定义
forward
方法,编写程序时需要在该函数中编写运算,实现对输入张量(tensor)的运算,并最后给予返回值;通常forward
函数的返回值也为张量(tensor)。
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
''' Definition of Network Structure '''
def forward(self, x):
# x = f(x)
return x
卷积模块/特征提取器
self.features = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2), # 28x28x1 -> 28x28x6
nn.Tanh(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5, stride=1), # 14x14x6 -> 10x10x16
nn.Tanh(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
卷积层
nn.Conv2d
是 PyTorch 中用于定义二维卷积层的一个模块。其参数配置的含义如下:
in_channels (1): 输入图像的通道数。对于单通道的灰度图像,值为 1;对于 RGB 图像,则为 3。
out_channels (6): 卷积层输出的通道数,即卷积核的数量。每个卷积核会产生一个输出通道。这里设置为 6,意味着该卷积层会生成 6 个特征图(feature maps)。
kernel_size (5): 卷积核的尺寸。这里的
5
表示使用 5x5 的卷积核。这是一个正方形的卷积核大小,但也可以设置成不相等的高度和宽度,比如(5, 3)
。stride (1): 卷积操作的步幅。步幅决定了卷积核在输入图像上滑动的速度。步幅为 1 表示卷积核每次滑动一个像素。
padding (2): 输入图像的边界填充。填充用于控制输出特征图的空间尺寸,通常用于保持特征图的尺寸不变或减少尺寸。这里的
2
表示在每一边填充 2 个像素。
综上,nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2)
的设置表示从单通道的输入图像中提取 6 个特征图,每个特征图通过一个 5x5 大小的卷积核生成,卷积操作的步幅为 1,并且在输入图像的每一边填充 2 个像素以控制输出尺寸。
激活函数
nn.Tanh()
是 PyTorch 中的一个激活函数层,它实现了双曲正切函数(tanh)。激活函数在神经网络中用于引入非线性特性,从而使网络能够学习更复杂的模式和特征。
双曲正切函数的数学形式是:
\]
该函数的输出范围是 ([-1, 1]),具有以下特点:
非线性:tanh 是一个非线性函数,使得神经网络能够处理更复杂的任务。
输出范围:输出值范围在 -1 到 1 之间,零中心化,使得均值接近零,这可以帮助加快训练过程和提高模型的收敛速度。
梯度:tanh 函数的导数在输入接近 0 时最大(为 1),而在输入较大或较小时梯度逐渐变小,这意味着它在大输入值时可能会遇到梯度消失的问题。
在 PyTorch 中,nn.Tanh()
可以作为模型的一个层来应用于网络的前向传播中:
nn.MaxPool2d
是 PyTorch 中的一个池化层,用于对二维输入数据进行最大池化操作。池化操作常用于卷积神经网络(CNN)中,以减少特征图的空间尺寸,减少计算量和过拟合的风险,并提高模型的鲁棒性。
池化层
nn.MaxPool2d(kernel_size=2, stride=2)
的参数含义如下:
kernel_size (2): 池化窗口的尺寸。这里的
2
表示池化窗口是 2x2 的正方形区域。池化操作将在每个 2x2 的区域内选取最大值。stride (2): 池化窗口在输入特征图上滑动的步幅。步幅为 2 表示池化窗口每次滑动 2 个像素。这意味着池化操作会将特征图的空间尺寸缩小为原来的一半。
具体功能:
- 最大池化:在池化窗口(2x2 区域)内选择最大值,并用该最大值替代整个窗口区域中的所有值。这样,池化层能够保留特征图中的重要信息,同时减少空间尺寸。
作用:
降维:通过池化操作减少特征图的空间尺寸(宽度和高度),从而减少计算量和内存消耗。
提高鲁棒性:池化操作可以使网络对位置的微小变化更具鲁棒性,因为它只保留局部区域的最大值。
防止过拟合:通过减少特征图的尺寸,可以降低模型的复杂性,从而减少过拟合的风险。
解释:
对于一个 4x4 的输入特征图,使用 2x2 的池化窗口和步幅为 2,池化操作会将特征图缩小为 2x2 的尺寸,每个池化窗口选择区域内的最大值。例如,在 2x2 的池化窗口内,[[1, 2], [5, 6]]
会变成 6
,依此类推。
总之,nn.MaxPool2d
是卷积神经网络中常用的池化层,用于减少特征图的尺寸和计算复杂度,同时保留重要的特征信息。
线性层/分类器
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(16 * 5 * 5, 120), # 全连接层1
nn.Tanh(),
nn.Linear(120, 84), # 全连接层2
nn.Tanh(),
nn.Linear(84, 10) # 输出层
)
展平函数
nn.Flatten()
是 PyTorch 中用于将多维张量展平成一维张量的模块。这个操作通常在卷积层(Convolutional Layers)和线性层(Linear Layers)之间使用,以便将卷积层输出的多维特征图转换成适合于线性层处理的一维特征向量。
nn.Flatten()
的主要作用是将输入张量从多维转换为一维。例如,对于形状为 (N, C, H, W)
的输入张量,使用 nn.Flatten()
后,输出的张量将变为形状为 (N, C * H * W)
的一维张量。
为什么在卷积层和线性层之间使用 nn.Flatten()
卷积层的输出是多维的:卷积层生成的输出通常是一个四维张量,表示批量的特征图,其中包含多个通道的二维特征图。为了将这些特征图传递到线性层,必须将其展平成一维张量,因为线性层要求输入为一维特征向量。
线性层的输入是一维的:线性层(也称全连接层)只能处理一维的输入数据。通过
nn.Flatten()
,可以将卷积层的多维输出展平为一维,从而可以将其作为线性层的输入。连接卷积和线性层:卷积层通常用于提取特征,而线性层则用于对这些特征进行分类或回归等任务。在这些任务中,线性层处理的是扁平化的特征向量,因此需要将卷积层的输出展平。
总结
nn.Flatten()
在卷积神经网络(CNN)的前向传播中充当了重要的角色,它将卷积层的多维特征图展平为线性层所需的一维特征向量。这使得卷积层提取的复杂特征可以被线性层进一步处理,从而完成分类、回归等任务。
nn.Linear 的基本概念
nn.Linear
是 PyTorch 中的一个模块,用于实现线性变换,也称为全连接层(Fully Connected Layer,FC Layer)。它将输入的特征通过一个线性变换映射到输出特征。
参数解释
输入特征的数量 (
in_features
):- 在
nn.Linear(16 * 5 * 5, 120)
中,第一个参数16 * 5 * 5
表示输入特征的数量。 16
通常表示通道数,5 * 5
是特征图的高度和宽度。这里的计算表示输入的特征总数为16 * 5 * 5 = 400
。- 卷积层的输出经过
Flatten
层处理后,变为一维向量。
- 在
输出特征的数量 (
out_features
):- 第二个参数
120
表示输出特征的数量。在这个例子中,模型将输出一个长度为 120 的一维向量。
- 第二个参数
线性层的功能
线性层的功能可以用数学公式表示为:
\]
- 其中 \(y\) 是输出,\(A\) 是权重矩阵,\(x\) 是输入,\(b\) 是偏差项。
在使用 nn.Linear(16 * 5 * 5, 120)
时,PyTorch 会自动创建一个形状为 (120, 400)
的权重矩阵和一个形状为 (120,)
的偏差向量。权重和偏差会在训练过程中进行学习和优化。
总结
nn.Linear(16 * 5 * 5, 120)
用于定义一个线性层,它将输入的 400 维特征向量映射到 120 维输出向量。- 此层通常用于将卷积层提取的特征连接到分类器或其他层,以形成完整的神经网络架构。
代码汇总
import torch.nn as nn
# LeNet-5
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2), # 28x28x1 -> 28x28x6
nn.Tanh(),
nn.MaxPool2d(kernel_size=2, stride=2), # 池化层
nn.Conv2d(6, 16, kernel_size=5, stride=1), # 14x14x6 -> 10x10x16
nn.Tanh(),
nn.MaxPool2d(kernel_size=2, stride=2) # 池化层
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(16 * 5 * 5, 120), # 全连接层1
nn.Tanh(),
nn.Linear(120, 84), # 全连接层2
nn.Tanh(),
nn.Linear(84, 10) # 输出层
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
卷积神经网络CNN实战:MINST手写数字识别——网络定义的更多相关文章
- 卷积神经网络应用于tensorflow手写数字识别(第三版)
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...
- MINST手写数字识别(三)—— 使用antirectifier替换ReLU激活函数
这是一个来自官网的示例:https://github.com/keras-team/keras/blob/master/examples/antirectifier.py 与之前的MINST手写数字识 ...
- keras和tensorflow搭建DNN、CNN、RNN手写数字识别
MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- 第三节,CNN案例-mnist手写数字识别
卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...
- MINST手写数字识别(一)—— 全连接网络
这是一个简单快速入门教程——用Keras搭建神经网络实现手写数字识别,它大部分基于Keras的源代码示例 minst_mlp.py. 1.安装依赖库 首先,你需要安装最近版本的Python,再加上一些 ...
- MINST手写数字识别(二)—— 卷积神经网络(CNN)
今天我们的主角是keras,其简洁性和易用性简直出乎David 9我的预期.大家都知道keras是在TensorFlow上又包装了一层,向简洁易用的深度学习又迈出了坚实的一步. 所以,今天就来带大家写 ...
- 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)
主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...
- MindSpore手写数字识别初体验,深度学习也没那么神秘嘛
摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
随机推荐
- .NET 日志系统-3 结构化日志和集中日志服务
.NET 日志系统-3 结构化日志和集中日志服务 系列文章 认识.NET 日志系统 https://www.cnblogs.com/ZYPLJ/p/17663487.html .NET 认识日志系统- ...
- python基础-元组tuple( )
元组的定义和操作 元组的特性: 元素数量 支持多个 元素类型 任意 下标索引 支持 重复元素 支持 可修改性 不支持 数据有序 是 使用场景 不可修改.可重复的 一批数据记录场景 # 定义元组 ...
- 基于Bootstrap Blazor开源的.NET通用后台权限管理系统
前言 今天大姚给大家分享一个基于Bootstrap Blazor开源的.NET通用后台权限管理系统,后台管理页面兼容所有主流浏览器,完全响应式布局(支持电脑.平板.手机等所有主流设备),可切换至 Bl ...
- github中的子模块(git submodule)
git中支持引用另外一个开源库,并且可以指定依赖的分支或者提交记录号. 比如fltk-rs 库的fltk-sys模块依赖了库 cfltk 并指明了依赖的提交是 8a56507 甚至可以嵌套,毕竟库自身 ...
- Python 潮流周刊#59:Polars 1.0 发布了,PyCon US 2024 演讲视频也发布了(摘要)
本周刊由 Python猫 出品,精心筛选国内外的 250+ 信息源,为你挑选最值得分享的文章.教程.开源项目.软件工具.播客和视频.热门话题等内容.愿景:帮助所有读者精进 Python 技术,并增长职 ...
- 树莓派4B-GPIO控制舵机转动
树莓派4B-GPIO控制舵机转动 硬件需求: 树莓派 舵机 杜邦线 舵机 什么是舵机? 舵机(servomotor)是一种简化版本的伺服电机,是位置伺服的驱动器,能够通过输入PWM信号控制旋转角度,具 ...
- 你使用过 Vuex 吗?
Vuex 是一个专为 Vue.js 应用程序开发的状态(全局数据)管理模式.每一个 Vuex 应用的核心就是 store(仓库)."store" 基本上就是一个容器,它包含着你的应 ...
- 动手学深度学习——CNN应用demo
CNN应用demo CNN实现简单的手写数字识别 import torch import torch.nn.functional as F from torchvision import datase ...
- springboot项目分层
springboot项目分层 一般的项目模块中都有DAO.Entity.Service.Controller层. Entity层:实体层 数据库在项目中的类 Entity层是实体层,也就是所谓的mod ...
- Unity 2023/Unity 6编辑器文字模糊的解决方案
这是从2023.1开始就有的问题了.本质原因是Unity不知道哪个天才决定的在编辑器文字上使用了SDF渲染. 2023.1因为缺乏选项导致几乎不可用:2023.2加了一个锐度选项:后来在论坛里被众人喷 ...