Pytorch入门中 —— 搭建网络模型
本节内容参照小土堆的pytorch
入门视频教程,主要通过查询文档的方式讲解如何搭建卷积神经网络。学习时要学会查询文档,这样会比直接搜索良莠不齐的博客更快、更可靠。讲解的内容主要是pytorch
核心包中TORCH.NN
中的内容(nn
是Neural Netwark
的缩写)。
通常,我们定义的神经网络模型会继承torch.nn.Module
类,该类为我们定义好了神经网络骨架。
卷积层
对于图像处理来说,我们通常使用二维卷积,即使用torch.nn.Conv2d
类:
创建该类时,我们通常只需要传入以下几个参数,其他不常用参数入门时可以不做了解,使用默认值即可,以后需要时再查询文档:
in_channels (int):输入数据的通道数,图片通常为3
out_channels (int):输出数据的通道数,也是卷积核的个数
kernel_size (int or tuple):卷积核大小,传入int表示正方形,传入tuple代表高和宽
stride (int or tuple, optional):卷积操作的步长,传入int代表横向和纵向步长相同,默认为1
padding (int, tuple or str, optional):填充厚度,传入int代表上下左右四个边填充厚度相同,默认为0,即不填充
padding_mode (string, optional):填充模式,默认为'zeros',即0填充
卷积操作后输出的张量的高和宽计算公式如下:
其中input
和output
中的N
代表BatchSize
,C
代表通道数,他们不影响H
和W
的计算。在保持dilation
为默认值1
的情况下,计算公式可简化为如下:
\]
\]
池化层
常用的二维最大池化定义在torch.nn.MaxPool2d
类中:
创建该类时,我们通常只需要传入以下几个参数,其他不常用参数入门时可以不做了解,使用默认值即可,以后需要时再查询文档:
kernel_size:池化操作时的窗口大小
stride:池化操作时的步长,默认为kernel_size
padding:每个边的填充厚度(0填充)
池化操作后输出的张量的高和宽计算公式与卷积操作后的计算公式相同。
非线性激活
常见的ReLU
激活定义在torch.nn.ReLU
类中:
参数inplace
代表是否将ouput
直接修改在input
中。
线性层
线性层的定义在torch.nn.Linear
类中:
创建线性层使用的参数如下:
in_features:输入特征大小
out_features:输出特征大小
bias:是否添加偏置,默认为True
模型搭建示例
下图是一个CIFAR10
数据集上的分类模型,下面将根据图片进行模型代码的编写。
1.由于CIFAR10
数据集中图片为3*32*32
,所以图中模型的输入为3通道,高宽都为32的张量。
2.使用 5*5
的卷积核进行卷积操作,得到通道数为32
,高和宽为32
的张量。因此我们可以推出该卷积层的参数如下:
in_channels = 3
out_channels = 32
kernel_size = 5
stride = 1
padding = 2
注:将 Hin = 32,Hout = 32 以及kernal_size[0] = 5
三个参数带入:
\]
有:
\]
发现stride[0] = 1
和padding[0] = 2
可以使得等式成立。同理可以得到stride[1] = 1
和padding[1] = 2
。
3.使用2*2
的核进行最大池化操作,得到通道数为32
,高和宽为16
的张量。可以推出该池化层的参数如下:
kernel_size = 2
stride = 2
padding = 0
注:stride
和padding
推导方式与2中相同。
4.使用 5*5
的卷积核进行卷积操作,得到通道数为32
,高和宽为16
的张量。因此我们可以推出该卷积层的参数如下:
in_channels = 32
out_channels = 32
kernel_size = 5
stride = 1
padding = 2
5.使用2*2
的核进行最大池化操作,得到通道数为32
,高和宽为8
的张量。可以推出该池化层的参数如下:
kernel_size = 2
stride = 2
padding = 0
6.使用 5*5
的卷积核进行卷积操作,得到通道数为64
,高和宽为8
的张量。因此我们可以推出该卷积层的参数如下:
in_channels = 32
out_channels = 64
kernel_size = 5
stride = 1
padding = 2
7.使用2*2
的核进行最大池化操作,得到通道数为64
,高和宽为4
的张量。可以推出该池化层的参数如下:
kernel_size = 2
stride = 2
padding = 0
8.将64*4*4
的张量进行展平操作得到长为1024
的向量。
9.将长为1024
的向量进行线性变换得到长为64
的向量(隐藏层),可以推出该线性层的参数如下:
in_features:1024
out_features:64
10.将长为64
的向量进行线性变换得到长为10
的向量,可以推出该线性层的参数如下:
in_features:64
out_features:10
因此,模型代码如下:
from torch import nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 5, padding=2)
self.max_pool1 = nn.MaxPool2d(2)
self.conv2 = nn.Conv2d(32, 32, 5, padding=2)
self.max_pool2 = nn.MaxPool2d(2)
self.conv3 = nn.Conv2d(32, 64, 5, padding=2)
self.max_pool3 = nn.MaxPool2d(2)
self.flatten = nn.Flatten()
self.linear1 = nn.Linear(1024, 64)
self.linear2 = nn.Linear(64, 10)
# 必须覆盖该方法,该方法会在实例像函数一样调用时被调用,后面会有示例
def forward(self, x):
x = self.conv1(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = self.max_pool2(x)
x = self.conv3(x)
x = self.max_pool3(x)
x = self.flatten(x)
x = self.linear1(x)
x = self.linear2(x)
return x
sequential
使用torch.nn.sequential
可以简化模型的搭建代码,他是一个顺序存放Module
的容器。当sequential
执行时,会按照Module
在构造函数中的先后顺序依次调用,前面Module
的输出会作为后面Module
的输入。
使用sequential
,上一节的代码可以简化为:
from torch import nn
class MyModel(nn.Module):
def __init__(self):
super.__init__(MyModel, self)
self.model = nn.Sequential(
nn.Conv2d(3, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, padding=2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(1024, 64),
nn.Linear(64, 10)
)
def forward(self, x):
x = self.model(x)
return x
损失函数、反向传播以及优化器
上面两节我们已经将CIFAR10
的分类模型搭建好,但还需要进行训练后才能用来预测分类。训练模型时,会用损失函数来衡量模型的好坏,并利用反向传播来求梯度,然后利用优化器对模型参数进行梯度下降,多次循环往复以训练出最优的模型。
模型训练代码如下:
import torch
from torch.optim import SGD
import torchvision
from torch.utils.data import DataLoader
from cifar10_model import MyModel
from torch import nn
from torch.utils import tensorboard
def train():
# 获取 cifar10 数据集
root = "./dataset"
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_cifar10 = torchvision.datasets.CIFAR10(root=root, train=True,
transform=transform,
download=True)
# 创建dataloader
train_dataloader = DataLoader(dataset=train_cifar10, batch_size=64,
shuffle=True,
num_workers=16)
# 创建模型
model = MyModel()
# 创建交叉熵损失函数
loss = nn.CrossEntropyLoss()
# 创建优化器,传入需要更新的参数,以及学习率
optim = SGD(model.parameters(), lr=0.01)
# 创建 SummaryWriter
writer = tensorboard.SummaryWriter("logs")
# 写入模型图,随机生成一个输入
writer.add_graph(model, torch.randn(64, 3, 32, 32))
for epoch in range(20):
loss_temp = 0.0
for batch_num, batch_data in enumerate(train_dataloader):
images, targets = batch_data
# 像调用方法一样调用实例
outputs = model(images)
loss_res = loss(outputs, targets)
loss_temp = loss_res
# 清空前一次计算的梯度
optim.zero_grad()
# 反向传播求梯度
loss_res.backward()
# 更新参数
optim.step()
# 记录每个epoch之后的loss
writer.add_scalar("Loss/train", loss_temp, epoch)
writer.close()
if __name__ == "__main__":
train()
模型图如下:
损失函数随训练周期的下降情况如下:
Pytorch入门中 —— 搭建网络模型的更多相关文章
- Pytorch入门随手记
Pytorch入门随手记 什么是Pytorch? Pytorch是Torch到Python上的移植(Torch原本是用Lua语言编写的) 是一个动态的过程,数据和图是一起建立的. tensor.dot ...
- Pytorch入门下 —— 其他
本节内容参照小土堆的pytorch入门视频教程. 现有模型使用和修改 pytorch框架提供了很多现有模型,其中torchvision.models包中有很多关于视觉(图像)领域的模型,如下图: 下面 ...
- pytorch入门2.1构建回归模型初体验(模型构建)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- Pytorch入门——手把手教你MNIST手写数字识别
MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...
- 架构师入门:搭建双注册中心的高可用Eureka架构(基于项目实战)
本文的案例是基于 架构师入门:搭建基本的Eureka架构(从项目里抽取) 改写的. 在上文里,我们演示Eureka客户端调用服务的整个流程,在这部分里我们将在架构上有所改进.大家可以想象下,在上文里案 ...
- Spring Cloud 入门教程 - 搭建配置中心服务
简介 Spring Cloud 提供了一个部署微服务的平台,包括了微服务中常见的组件:配置中心服务, API网关,断路器,服务注册与发现,分布式追溯,OAuth2,消费者驱动合约等.我们不必先知道每个 ...
- ArcGIS API for Silverlight/ 开发入门 环境搭建
Silverlight/ 开发入门 环境搭建1 Silverlight SDK下载ArcGIS API for Microsoft Silverlight/WPF ,需要注册一个ESRI Gloab ...
- Spring MVC+Spring+Mybatis+MySQL(IDEA)入门框架搭建
目录 Spring MVC+Spring+Mybatis+MySQL(IDEA)入门框架搭建 0.项目准备 1.数据持久层Mybatis+MySQL 1.1 MySQL数据准备 1.2 Mybatis ...
- 基于flask的轻量级webapi开发入门-从搭建到部署
基于flask的轻量级webapi开发入门-从搭建到部署 注:本文的代码开发工作均是在python3.7环境下完成的. 关键词:python flask tornado webapi 在python虚 ...
随机推荐
- 浏览器 Proxy SwitchyOmega 插件设置代理访问内网服务器
使用Proxy SwitchyOmega 插件通过代理 直接访问到内网网站 一.使用场景 如下图所示,如果在电脑的网络设置中开启代理,每次更换代理就需要进入这里设置改变代理.且我们可能回需求到两个网页 ...
- [loj3256]火灾
将问题差分,即求$\sum_{i=1}^{r}S_{i}(t)-\sum_{i=1}^{l-1}S_{i}(t)$,由于两者类似,不妨考虑前者 构造矩阵$A_{i,j}=S_{j}(i)-S_{j}( ...
- [noi1779]D
先离散,然后将黑的看成1,白的看成-1,对整个序列差分,所有区间建为$(l,r+1)$的无向边,并标上-1和1,每一个点的前缀和即为该点的值 考虑什么情况下能够使得所有点都是0:当且仅当每一个点的度数 ...
- 消息抽象层设计和实现-OSS.DataFlow
前面已经介绍了消息生产消费中间类库(OSS.DataFlow)的简单使用,这篇主要介绍内部的设计实现.主要内容包含: 1. 消息生产消费的抽象设计. 2. 具体使用示例 一. 消息生产消费的抽象设计. ...
- k8s-Pod污点与容忍
目录 Pod污点与容忍 大白话先解释一下污点与容忍 为什么要用污点和容忍? 官方解释 Taints参数 标记污点 容忍污点 取消所有节点污点 Pod污点与容忍 大白话先解释一下污点与容忍 污点:被打上 ...
- 【Spring】(1)-- 概述
Spring框架 -- 概述 2019-07-07 22:40:42 by冲冲 1. Spring的概念 ① Spring框架的关键词:开源框架.轻量级框架.JavaEE/J2EE开发框架.企业级 ...
- UI自动化测试:App的WebView页面中,当搜索栏无搜索按钮时处理方法
一.遇到的问题 在做移动端的UI自动化测试时,经常会遇到上图所示的搜索框,这里有个麻烦就是搜索框没有"搜索"按钮,UI自动化测试时不能确认搜索. 要解决这个问题,我们可以通过 dr ...
- gantt甘特图可拖拽、编辑(vue、react都可用 highcharts)
前言 Excel功能强大,应用广泛.随着web应用的兴起和完善,用户的要求也越来越高.很多Excel的功能都搬到了sass里面.恨不得给他们做个Excel出来...程序员太难了... 去年我遇到了 ...
- vip视频解析保存
无广告通用:https://vip.52jiexi.top/?url= 腾讯直解 无广告解析:https://jx.lfeifei.cn/?url= 无广告解析:https://api.steak51 ...
- banner.txt
Spring Boot Version: ${spring-boot.version} __----~~~~~~~~~~~------___ . . ~~//====...... __--~ ~~ - ...