《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

Basic Convolution Neural Network

1、全连接网络

线性层串行—全连接网络

每一个输入和输出都有权重--全连接层

全连接网络在处理图像时,直接将每一行像素拼接成向量,丧失了图像的空间结构

2、CNN结构

CNN在处理图像时,保留了图像的空间结构信息

卷积神经网络:卷积运算(特征提取)à转换成向量à全连接网络(分类)

3、卷积过程

1×28×28是C(channle)×W(width)×H(Hight),就是通道数×图像宽度×图像高度

 

①单通道卷积(矩阵数乘)

②三通道卷积

③N通道卷积

每一个卷积核的通道数量 = 输入的通道数量

卷积核的个数 = 输出的通道数量

 4、下采样(subsampling)---Max Pooling

下采样的目的是减少特征图像的数据量,降低运算需求。在下采样过程中,通道数(Channel)保持不变,图像的宽度和高度发生改变

5、全连接层

先将原先多维的卷积结果通过全连接层转为一维的向量,再通过多层全连接层将原向量转变为可供输出的向量。

卷积和下采样都是在特征提取

全连接层才是分类

6、CNN

①卷积操作

Pytorch输入数据必须是小批量数据,设置batch_size

需要确定的值:输入的通道(in_channels)、输出的通道(out_channels)、卷积核的大小(kernel_size:3x3)

②Padding,向外填充

③Stride—步长

有效降低图像的宽度和高度

④下采样:Max Pooling Layer

默认Stride=2

⑤整体结构

⑥用CPU或GPU进行模型的训练和测试

torch.device

完整代码

  1. import torch
  2. from torchvision import transforms
  3. from torchvision import datasets
  4. from torch.utils.data import DataLoader
  5. import torch.nn.functional as F
  6. import torch.optim as optim
  7.  
  8. # prepare dataset
  9.  
  10. batch_size = 64
  11. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
  12.  
  13. train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
  14. train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
  15. test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
  16. test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
  17.  
  18. # design model using class
  19.  
  20. class Net(torch.nn.Module):
  21. def __init__(self):
  22. super(Net, self).__init__()
  23. self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
  24. self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
  25. self.pooling = torch.nn.MaxPool2d(2)
  26. self.fc = torch.nn.Linear(320, 10)
  27.  
  28. def forward(self, x):
  29. # flatten data from (n,1,28,28) to (n, 784)
  30. batch_size = x.size(0)
  31. x = F.relu(self.pooling(self.conv1(x)))
  32. x = F.relu(self.pooling(self.conv2(x)))
  33. x = x.view(batch_size, -1) # -1 此处自动算出的是320
  34. x = self.fc(x)
  35.  
  36. return x
  37.  
  38. model = Net()
  39. ## Device—选择是用GPU还是用CPU训练
  40. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  41. model.to(device)
  42.  
  43. # construct loss and optimizer
  44. criterion = torch.nn.CrossEntropyLoss()
  45. optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
  46.  
  47. # training cycle forward, backward, update
  48.  
  49. def train(epoch):
  50. running_loss = 0.0
  51. for batch_idx, data in enumerate(train_loader, 0):
  52. inputs, target = data
  53. inputs, target = inputs.to(device), target.to(device)
  54. optimizer.zero_grad()
  55.  
  56. outputs = model(inputs)
  57. loss = criterion(outputs, target)
  58. loss.backward()
  59. optimizer.step()
  60.  
  61. running_loss += loss.item()
  62. if batch_idx % 300 == 299:
  63. print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
  64. running_loss = 0.0
  65.  
  66. def test():
  67. correct = 0
  68. total = 0
  69. with torch.no_grad():
  70. for data in test_loader:
  71. images, labels = data
  72. images, labels = images.to(device), labels.to(device)
  73. outputs = model(images)
  74. _, predicted = torch.max(outputs.data, dim=1)
  75. total += labels.size(0)
  76. correct += (predicted == labels).sum().item()
  77. print('accuracy on test set: %d %% ' % (100*correct/total))
  78.  
  79. if __name__ == '__main__':
  80. for epoch in range(10):
  81. train(epoch)
  82. test()

运行结果

Pytorch实战学习(六):基础CNN的更多相关文章

  1. Java学习 (六)基础篇 类型转换

    类型转换 由于Java是强类型语言,所以要进行有些运算的时候,需要用到类型转换 字节大小(容量)-> 低--------------------------------------------- ...

  2. 深度学习之PyTorch实战(1)——基础学习及搭建环境

    最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...

  3. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  4. Spring实战第六章学习笔记————渲染Web视图

    Spring实战第六章学习笔记----渲染Web视图 理解视图解析 在之前所编写的控制器方法都没有直接产生浏览器所需的HTML.这些方法只是将一些数据传入到模型中然后再将模型传递给一个用来渲染的视图. ...

  5. 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码

    PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...

  6. 参考《深度学习之PyTorch实战计算机视觉》PDF

    计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...

  7. 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化

    上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...

  8. Docker虚拟化实战学习——基础篇(转)

    Docker虚拟化实战学习——基础篇 2018年05月26日 02:17:24 北纬34度停留 阅读数:773更多 个人分类: Docker   Docker虚拟化实战和企业案例演练 深入剖析虚拟化技 ...

  9. Pytorch_第六篇_深度学习 (DeepLearning) 基础 [2]---神经网络常用的损失函数

    深度学习 (DeepLearning) 基础 [2]---神经网络常用的损失函数 Introduce 在上一篇"深度学习 (DeepLearning) 基础 [1]---监督学习和无监督学习 ...

  10. PyTorch 实战:计算 Wasserstein 距离

    PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...

随机推荐

  1. 2022.2.1最新版本的IDEA

          一.下载破解工具.激活码 激活工具下载链接:https://note.youdao.com/s/1ANz2F3o   6G5NXCPJZB-eyJsaWNlbnNlSWQiOiI2RzVO ...

  2. javaEE(单元测试、反射、动态代理、xml)

    单元测试 最小的功能单元编写测试代码,java针对方法,检查方法的正确性 JUnit单元测试框架 @Test注解 public class A { @Test public void a(){ ... ...

  3. Zstack迁移实战记录1

    https://blog.csdn.net/weixin_43767046/article/details/113748775 这段时间除了那个重度烤机测试(上面链接),还在做另一件事,想再做一个服务 ...

  4. CF1625D.Binary Spiders

    \(\text{Problem}\) 大概就是给出 \(n\) 个数和 \(m\),要从中选最多的数使得两两异或值大于等于 \(m\) 输出方案 \(\text{Solution}\) 一开始的想法很 ...

  5. [COCI2015-2016#2] VUDU

    题目传送门 思路 这是一种简单的树状数组解法. 我们设偏移值表示 \(a_i\) 与目标平均数 \(p\) 的差值,显然,一个区间若能满足条件,需要满足此区间的偏移值之和 \(\ge 0\). 看到区 ...

  6. 梅毒感染者能否应用TNF抑制剂

    对于伴发的未经控制的任何严重感染,都不适合使用TNF抑制剂.在1998年国际上首个TNF抑制剂获批治疗类风湿关节炎(RA)以来,这就是广大临床医生和风湿性疾病患者的共识.在临床实践中,需要权衡药物的利 ...

  7. git添加多账户(附带tortoiseGit多账号使用)

    近期想在公司电脑上开发自己项目,但是电脑上已经配置过一个gitlab账户了,现在想要把自己的git账户也加进来,方便代码控制. 因为git用的比较少,还不太熟悉,都是网上找资料,边看边学边做,如有不对 ...

  8. vue浏览器全屏 非全屏切换 按esc退出全屏

    methods: { // 全屏 点击按钮        allping(){        this.status = true;        this.enterFullscreen();    ...

  9. PostgreSQL 实现快速删除一个用户

    一.具体方法 一般情况下直接执行 drop role xxx; 就可以把这个用户删除.但是很多时候会因为用户有依赖而报错. 二.权限依赖 postgres=# create role test wit ...

  10. 【Java-01-1】java基础-基本语法(1)(基本输入输出,计算)

    1.基本输出语句 /* * java * 多行注释 */ //java单行注释 public class _01_HelloWorld { public static void main(String ...