googleNet网络结构

输入网络: 由4个分支网络构成

第一分支: 由1x1的卷积构成

第二分支: 由1x1的卷积,3x3的卷积构成

第三分支: 由1x1的卷积, 5x5的卷积构成

第四分支: 由3x3的最大值池化, 1x1的卷积构成

  1. import torch
  2. from torch import nn
  3. from torch.nn import functional as F
  4.  
  5. class BasicConv2d(nn.Module):
  6. def __init__(self, in_channels, out_channels, **kwargs):
  7. super(BasicConv2d, self).__init__()
  8. self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) # 构造卷积层
  9. self.bn = nn.BatchNorm2d(out_channels, eps=0.001) # 构造标准化
  10.  
  11. def forward(self, x):
  12. x = self.conv(x) # 进行卷积操作
  13. x = self.bn(x) # 进行标准化操作
  14. x = F.relu(x) # 进行激活层操作
  15.  
  16. return x
  17.  
  18. class Inception(nn.Module):
  19. def __init__(self, in_channels, pool_features):
  20. super(Inception, self).__init__()
  21. self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) # 1x1的卷积操作
  22.  
  23. self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) # 进行卷积操作
  24. self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)
  25.  
  26. self.branch3x3db1_1 = BasicConv2d(in_channels, 64, kernel_size=1)
  27. self.branch3x3db1_2 = BasicConv2d(64, 96, kernel_size=3)
  28. self.branch3x3db1_3 = BasicConv2d(96, 96, kernel_size=3)
  29.  
  30. self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)
  31.  
  32. def forward(self, x):
  33. branch1x1 = self.branch1x1(x)
  34.  
  35. branch5x5 = self.branch5x5_1(x)
  36. branch5x5 = self.branch5x5_2(branch5x5)
  37.  
  38. branch3x3db1_1 = self.branch3x3db1_1(x)
  39. branch3x3db1_2 = self.branch3x3db1_2(branch3x3db1_1)
  40. branch3x3db1_3 = self.branch3x3db1_3(branch3x3db1_2)
  41.  
  42. branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
  43. branch_pool = self.branch_pool(branch_pool)
  44. # 进行卷积的叠加操作
  45. outputs = [branch1x1, branch5x5, branch3x3db1_3, branch_pool]
  46. outputs = torch.cat(outputs, dim=1)
  47.  
  48. return outputs

pytorch-googleNet的更多相关文章

  1. GoogLeNet网络的Pytorch实现

    1.文章原文地址 Going deeper with convolutions 2.文章摘要 我们提出了一种代号为Inception的深度卷积神经网络,它在ILSVRC2014的分类和检测任务上都取得 ...

  2. 从头学pytorch(十八):GoogLeNet

    GoogLeNet GoogLeNet和vgg分别是2014的ImageNet挑战赛的冠亚军.GoogLeNet则做了更加大胆的网络结构尝试,虽然深度只有22层,但大小却比AlexNet和VGG小很多 ...

  3. Pytorch1.0入门实战二:LeNet、AleNet、VGG、GoogLeNet、ResNet模型详解

    LeNet 1998年,LeCun提出了第一个真正的卷积神经网络,也是整个神经网络的开山之作,称为LeNet,现在主要指的是LeNet5或LeNet-5,如图1.1所示.它的主要特征是将卷积层和下采样 ...

  4. 深度学习框架PyTorch一书的学习-第六章-实战指南

    参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter6-实战指南 希望大家直接到上面的网址去查看代码,下面是本人的笔记 将上面地 ...

  5. Keras vs. PyTorch in Transfer Learning

    We perform image classification, one of the computer vision tasks deep learning shines at. As traini ...

  6. 经典的卷积神经网络及其Pytorch代码实现

    1.LeNet LeNet是指LeNet-5,它是第一个成功应用于数字识别的卷积神经网络.在MNIST数据集上,可以达到99.2%的准确率.LeNet-5模型总共有7层,包括两个卷积层,两个池化层,两 ...

  7. pytorch基础学习(一)

    在炼丹师的路上越走越远,开始入手pytorch框架的学习,越炼越熟吧... 1. 张量的创建和操作 创建为初始化矩阵,并初始化 a = torch.empty(, ) #创建一个5*3的未初始化矩阵 ...

  8. [深度学习] pytorch学习笔记(1)(数据类型、基础使用、自动求导、矩阵操作、维度变换、广播、拼接拆分、基本运算、范数、argmax、矩阵比较、where、gather)

    一.Pytorch安装 安装cuda和cudnn,例如cuda10,cudnn7.5 官网下载torch:https://pytorch.org/ 选择下载相应版本的torch 和torchvisio ...

  9. 目标检测Object Detection概述(Tensorflow&Pytorch实现)

    1999:SIFT 2001:Cascades 2003:Bag of Words 2005:HOG 2006:SPM/SURF/Region Covariance 2007:PASCAL VOC 2 ...

  10. Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易

    近日,PyTorch 社区发布了一个深度学习工具包 PyTorchHub, 帮助机器学习工作者更快实现重要论文的复现工作.PyTorchHub 由一个预训练模型仓库组成,专门用于提高研究工作的复现性以 ...

随机推荐

  1. CentOS如何安装MySQL8.0、创建用户并授权的详细步骤

    # 安装相关软件 yum install -y gcc gcc-c++ openssl openssl-devel ncurses ncurses-devel make cmake # 获取MySQL ...

  2. python初始化定长列表

    >>> lst = ['x' for n in range(5)] >>> print(lst) ['x', 'x', 'x', 'x', 'x'] >> ...

  3. 处理器拦截器(HandlerInterceptor)详解(转)

    简介 SpringWebMVC的处理器拦截器,类似于Servlet开发中的过滤器Filter,用于处理器进行预处理和后处理. 应用场景 1.日志记录,可以记录请求信息的日志,以便进行信息监控.信息统计 ...

  4. 【Day3】1.正则表达式

    1.正则表达式 2.案例 关闭贪婪模式

  5. shell数组处理

    linux shell在编程方面比windows 批处理强大太多,无论是在循环.运算.已经数据类型方面都是不能比较的. 下面是个人在使用时候,对它在数组方面一些操作进行的总结.   1.数组定义   ...

  6. List去重比较

    import java.util.ArrayList; import java.util.HashSet; import java.util.LinkedHashSet; import java.ut ...

  7. Mac上的redis安装与jedis入门

    Redis 是一个开源(BSD许可)的,内存中的数据结构存储系统,它可以用作数据库.缓存和消息中间件 安装与配置 (1) https://redis.io/download下载redis stable ...

  8. linux下devel软件包作用

    devel 包主要是供开发用,至少包括以下2个东西: 头文件 链接库 有的还含有开发文档或演示代码. 以 glib 和 glib-devel 为例: 如果你安装基于 glib 开发的程序,只需要安装 ...

  9. Jmeter设置集合点(并发测试)

    什么是集合点? 让所有请求在不满足条件的时候处于等待状态. 如何实现? 使用jmeter中的同步计时器Synchronizing Timer来实现 集合点的位置 因为集合点是在取样器sampler(例 ...

  10. 计数dp做题笔记

    YCJS 3924 饼干 Description 给定一个长度为\(n\)的序列,序列每个元素的取值为\([1,x]\),现在给定\(q\)个区间,求在所有取值方案中,区间最小值的最大值的期望是多少? ...