Tutorial on GoogleNet based image classification 

2018-06-26 15:50:29

本文旨在通过案例来学习 GoogleNet 及其 Inception 结构的定义。针对这种复杂模型的保存以及读取。

1. GoogleNet 的结构:

 class Inception(nn.Module):
def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
super(Inception, self).__init__()
# 1x1 conv branch
self.b1 = nn.Sequential(
nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
nn.BatchNorm2d(kernel_1_x),
nn.ReLU(True),
) # 1x1 conv -> 3x3 conv branch
self.b2 = nn.Sequential(
nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
nn.BatchNorm2d(kernel_3_in),
nn.ReLU(True),
nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
nn.BatchNorm2d(kernel_3_x),
nn.ReLU(True),
) # 1x1 conv -> 5x5 conv branch
self.b3 = nn.Sequential(
nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
nn.BatchNorm2d(kernel_5_in),
nn.ReLU(True),
nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
nn.BatchNorm2d(kernel_5_x),
nn.ReLU(True),
nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
nn.BatchNorm2d(kernel_5_x),
nn.ReLU(True),
) # 3x3 pool -> 1x1 conv branch
self.b4 = nn.Sequential(
nn.MaxPool2d(3, stride=1, padding=1),
nn.Conv2d(in_planes, pool_planes, kernel_size=1),
nn.BatchNorm2d(pool_planes),
nn.ReLU(True),
) def forward(self, x):
y1 = self.b1(x)
y2 = self.b2(x)
y3 = self.b3(x)
y4 = self.b4(x)
return torch.cat([y1,y2,y3,y4], 1)
class GoogLeNet(nn.Module):
def __init__(self):
super(GoogLeNet, self).__init__()
self.pre_layers = nn.Sequential(
nn.Conv2d(3, 192, kernel_size=3, padding=1),
nn.BatchNorm2d(192),
nn.ReLU(True),
) self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) self.max_pool = nn.MaxPool2d(3, stride=2, padding=1) self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) self.avgpool = nn.AvgPool2d(8, stride=1)
self.linear = nn.Linear(1024, 10) def forward(self, x):
x = self.pre_layers(x)
x = self.a3(x)
x = self.b3(x)
x = self.max_pool(x)
x = self.a4(x)
x = self.b4(x)
x = self.c4(x)
x = self.d4(x)
x = self.e4(x)
x = self.max_pool(x)
x = self.a5(x)
x = self.b5(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.linear(x)
return x

2. 保存和加载模型:

# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl') # 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

Tutorial on GoogleNet based image classification --- focus on Inception module and save/load models的更多相关文章

  1. A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python)

    A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python) MACHINE LEARNING PYTHON  ...

  2. 图像分类之特征学习ECCV-2010 Tutorial: Feature Learning for Image Classification

    ECCV-2010 Tutorial: Feature Learning for Image Classification Organizers Kai Yu (NEC Laboratories Am ...

  3. Codeforces Round #591 (Div. 2, based on Technocup 2020 Elimination Round 1) C. Save the Nature【枚举二分答案】

    https://codeforces.com/contest/1241/problem/C You are an environmental activist at heart but the rea ...

  4. Codeforces Round #591 (Div. 2, based on Technocup 2020 Elimination Round 1) C. Save the Nature

    链接: https://codeforces.com/contest/1241/problem/C 题意: You are an environmental activist at heart but ...

  5. How to Build Android Applications Based on FFmpeg by An Example

    This is a follow up post of the previous blog How to Build FFmpeg for Android.  You can read the pre ...

  6. 解读(GoogLeNet)Going deeper with convolutions

    (GoogLeNet)Going deeper with convolutions Inception结构 目前最直接提升DNN效果的方法是increasing their size,这里的size包 ...

  7. [论文阅读]Going deeper with convolutions(GoogLeNet)

    本文采用的GoogLenet网络(代号Inception)在2014年ImageNet大规模视觉识别挑战赛取得了最好的结果,该网络总共22层. Motivation and High Level Co ...

  8. Node.js NPM Tutorial: Create, Publish, Extend & Manage

    A module in Node.js is a logical encapsulation of code in a single unit. It's always a good programm ...

  9. Plant Leaves Classification植物叶子分类:基于孪生网络的小样本学习方法

    目录 Abstract Introduction PROPOSED CNN STRUCTURE INITIAL CNN ANALYSIS EXPERIMENTAL STRUCTURE AND ALGO ...

随机推荐

  1. STL容器之vector

    [1]模板类vector 模板类vector可理解为广义数组.广义数组,即与类型无关的数组,具有与数组相同的所有操作. 那么,你或许要问:既然C++语言本身已提供了一个序列式容器array,为什么还要 ...

  2. Linux基础命令---arping

    arping arping指令用于发送arp请求到一个相邻的主机,在指定网卡上发送arp请求指定地址,源地址使用-s指定.该指令可以直径ping MAC地址,找出哪些地址被哪些电脑使用了. 此命令的适 ...

  3. 51Nod 1212 无向图最小生成树 (路径压缩)

    N个点M条边的无向连通图,每条边有一个权值,求该图的最小生成树.   Input 第1行:2个数N,M中间用空格分隔,N为点的数量,M为边的数量.(2 <= N <= 1000, 1 &l ...

  4. 使用Fiddler测试WebApi接口

    Fiddler是好用的WebApi调试工具之一,它能记录所有客户端和服务器的http和https请求,允许你监视,设置断点,甚至修改输入输出数据,Fiddler 是以代理web服务器的形式工作的,使用 ...

  5. 清明 DAY2

    数论 数论是研究整数性质的东西 也就是 lim   π(x)=x/ ln x (x->无穷) 证明: ∵ p|ab ∴ ab有因子p 设 a=p1k1p2k2......prkr      b= ...

  6. APIClound 弹出层 Frame

    JS api.openFrame({ name: 'showPic', url: './showPic.html', rect: { // x: api.pageParam.marginBottom, ...

  7. eclipse 的版本及下载地址

    eclipse 的各个版本号: 版本号 代号 代号名 发布日期 Eclipse 3.1 IO 木卫一,伊奥 2005 Eclipse 3.2 Callisto 木卫四,卡里斯托 2006 Eclips ...

  8. maven 项目 查询部分关心的字段

    <?xml version="1.0" encoding="UTF-8" ?> <!DOCTYPE configuration PUBLIC ...

  9. Android Camera2 预览,拍照,人脸检测并实时展现

    https://www.jianshu.com/p/5414ba2b5508 背景     最近需要做一个人脸检测并实时预览的功能.就是边检测人脸,边在预览界面上框出来.     当然本人并不是专门做 ...

  10. PHP 变量类型的强制转换 & 创建空对象

    PHP 在变量定义中不需要(或不支持)明示的类型定义:变量类型是根据使用该变量的上下文所决定的. 也就是说,如果把一个字符串值赋给变量 var,var 就成了一个字符串.如果又把一个整型值赋给 var ...