Tutorial on GoogleNet based image classification 

2018-06-26 15:50:29

本文旨在通过案例来学习 GoogleNet 及其 Inception 结构的定义。针对这种复杂模型的保存以及读取。

1. GoogleNet 的结构:

  1. class Inception(nn.Module):
  2. def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
  3. super(Inception, self).__init__()
  4. # 1x1 conv branch
  5. self.b1 = nn.Sequential(
  6. nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
  7. nn.BatchNorm2d(kernel_1_x),
  8. nn.ReLU(True),
  9. )
  10.  
  11. # 1x1 conv -> 3x3 conv branch
  12. self.b2 = nn.Sequential(
  13. nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
  14. nn.BatchNorm2d(kernel_3_in),
  15. nn.ReLU(True),
  16. nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
  17. nn.BatchNorm2d(kernel_3_x),
  18. nn.ReLU(True),
  19. )
  20.  
  21. # 1x1 conv -> 5x5 conv branch
  22. self.b3 = nn.Sequential(
  23. nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
  24. nn.BatchNorm2d(kernel_5_in),
  25. nn.ReLU(True),
  26. nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
  27. nn.BatchNorm2d(kernel_5_x),
  28. nn.ReLU(True),
  29. nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
  30. nn.BatchNorm2d(kernel_5_x),
  31. nn.ReLU(True),
  32. )
  33.  
  34. # 3x3 pool -> 1x1 conv branch
  35. self.b4 = nn.Sequential(
  36. nn.MaxPool2d(3, stride=1, padding=1),
  37. nn.Conv2d(in_planes, pool_planes, kernel_size=1),
  38. nn.BatchNorm2d(pool_planes),
  39. nn.ReLU(True),
  40. )
  41.  
  42. def forward(self, x):
  43. y1 = self.b1(x)
  44. y2 = self.b2(x)
  45. y3 = self.b3(x)
  46. y4 = self.b4(x)
  47. return torch.cat([y1,y2,y3,y4], 1)
  1. class GoogLeNet(nn.Module):
  2. def __init__(self):
  3. super(GoogLeNet, self).__init__()
  4. self.pre_layers = nn.Sequential(
  5. nn.Conv2d(3, 192, kernel_size=3, padding=1),
  6. nn.BatchNorm2d(192),
  7. nn.ReLU(True),
  8. )
  9.  
  10. self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
  11. self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
  12.  
  13. self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
  14.  
  15. self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
  16. self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
  17. self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
  18. self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
  19. self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
  20.  
  21. self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
  22. self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
  23.  
  24. self.avgpool = nn.AvgPool2d(8, stride=1)
  25. self.linear = nn.Linear(1024, 10)
  26.  
  27. def forward(self, x):
  28. x = self.pre_layers(x)
  29. x = self.a3(x)
  30. x = self.b3(x)
  31. x = self.max_pool(x)
  32. x = self.a4(x)
  33. x = self.b4(x)
  34. x = self.c4(x)
  35. x = self.d4(x)
  36. x = self.e4(x)
  37. x = self.max_pool(x)
  38. x = self.a5(x)
  39. x = self.b5(x)
  40. x = self.avgpool(x)
  41. x = x.view(x.size(0), -1)
  42. x = self.linear(x)
  43. return x

2. 保存和加载模型:

  1. # 保存和加载整个模型
  2. torch.save(model_object, 'model.pkl')
  3. model = torch.load('model.pkl')
  4.  
  5. # 仅保存和加载模型参数(推荐使用)
  6. torch.save(model_object.state_dict(), 'params.pkl')
  7. model_object.load_state_dict(torch.load('params.pkl'))

Tutorial on GoogleNet based image classification --- focus on Inception module and save/load models的更多相关文章

  1. A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python)

    A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python) MACHINE LEARNING PYTHON  ...

  2. 图像分类之特征学习ECCV-2010 Tutorial: Feature Learning for Image Classification

    ECCV-2010 Tutorial: Feature Learning for Image Classification Organizers Kai Yu (NEC Laboratories Am ...

  3. Codeforces Round #591 (Div. 2, based on Technocup 2020 Elimination Round 1) C. Save the Nature【枚举二分答案】

    https://codeforces.com/contest/1241/problem/C You are an environmental activist at heart but the rea ...

  4. Codeforces Round #591 (Div. 2, based on Technocup 2020 Elimination Round 1) C. Save the Nature

    链接: https://codeforces.com/contest/1241/problem/C 题意: You are an environmental activist at heart but ...

  5. How to Build Android Applications Based on FFmpeg by An Example

    This is a follow up post of the previous blog How to Build FFmpeg for Android.  You can read the pre ...

  6. 解读(GoogLeNet)Going deeper with convolutions

    (GoogLeNet)Going deeper with convolutions Inception结构 目前最直接提升DNN效果的方法是increasing their size,这里的size包 ...

  7. [论文阅读]Going deeper with convolutions(GoogLeNet)

    本文采用的GoogLenet网络(代号Inception)在2014年ImageNet大规模视觉识别挑战赛取得了最好的结果,该网络总共22层. Motivation and High Level Co ...

  8. Node.js NPM Tutorial: Create, Publish, Extend & Manage

    A module in Node.js is a logical encapsulation of code in a single unit. It's always a good programm ...

  9. Plant Leaves Classification植物叶子分类:基于孪生网络的小样本学习方法

    目录 Abstract Introduction PROPOSED CNN STRUCTURE INITIAL CNN ANALYSIS EXPERIMENTAL STRUCTURE AND ALGO ...

随机推荐

  1. html5-垂直定位

    *{    padding: 0px;    margin: 0px; }#div2{    background: green;    padding: 15px;    width: 200px; ...

  2. sql server 将两列的值合并到另一列

    select top 100 t2.FullName, * from Subject,(select id, isnull(first_name,'') +isnull(middle_name,'') ...

  3. Java多线程-----volatile关键字详解

       volatile原理     Java语言提供了一种稍弱的同步机制,即volatile变量,用来确保将变量的更新操作通知到其他线程.当把变量声明为volatile类型后, 编译器与运行时都会注意 ...

  4. 仿照admin实现一个自定义的增删改查的组件

    1.首先,创建三个项目,app01,app02,stark,在settings里边记得配置.然后举例:在app01的model里边写表,用的db.sqlite3,所以数据库不用再settings里边配 ...

  5. 苹果企业版签名分发相关问题,蒲公英签名,fir.im分发,安装ipa设置信任

    苹果企业版签名分发相关问题,蒲公英签名,fir.im分发,安装ipa设置信任蒲公英 - 高效安全的内测应用发布.管理平台https://www.pgyer.com/app/signature分发版 2 ...

  6. Ajax 知识

    Ajax 为什么要有ajax技术?    传统的web应用,一个简单的操作就要加载整个页面.浪费资源. Ajax  即“Asynchronous Javascript And XML”(异步JavaS ...

  7. java加载配置文件信息

    #基金数据存放根目录fund_save_root_path=E:/fundCrawling #龙虎榜数据存放根目录long_hu_root_path=E:/longHuCrawling #巨潮数据存放 ...

  8. JS中常见原生DOM操作API

    摘自:https://blog.csdn.net/hj7jay/article/details/53389522 几种对象 Node Node是一个接口,中文叫节点,很多类型的DOM元素都是继承于它, ...

  9. String小案例(**)、包装类型和普通数据类型的转换(拆装箱)

    ###String用法: package StringTest; /**功能: * 判断Java文件名是否正确,判断邮箱格式是否正确 * 其中:合法的文件名应该以.java结尾 * 合法的邮箱名至少包 ...

  10. Prometheus监控学习笔记之教程推荐

    最近学习K8S和基于容器的监控,发现了如下的教程质量不错,记录下来以备参考 1. K8S最佳实战(包括了K8S的Prometheus监控和EFK日志搜集) https://jimmysong.io/k ...