Smiling & Weeping

                  ---- 祝你想我

                    在平静的湖面

                   不止在失控的雪山之前

说明:Inception Module

1. 卷积核超参数选择困难,自动找到卷积的最佳组合

2. 1x1卷积核,不同通道的信息融合。使用1x1卷积核可以调节通道数量,可以显著降低计算量

3. Inception Module由四个分支组成,要分清哪些是在init里定义的,那些是在forward里调用的。4个分支在dim=1(channels)上进行concatenate

  1 import torch
2 import torch.nn as nn
3 from torchvision import transforms
4 from torchvision import datasets
5 from torch.utils.data import DataLoader
6 import torch.nn.functional as F
7 import torch.optim as optim
8 import matplotlib.pyplot as plt
9
10 batch_size = 64
11 # 归一化,均值和方差
12 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
13
14 train_dataset = datasets.MNIST(root='../dataset/mnist', train=True, download=True, transform=transform)
15 train_loader = DataLoader(train_dataset, shuffle=True,batch_size=batch_size)
16 test_dataset = datasets.MNIST(root='../dataset/mnist', train=False, download=True, transform=transform)
17 test_loader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size)
18
19 # design model using class
20 class InceptionA(nn.Module):
21 def __init__(self, in_channels):
22 super(InceptionA, self).__init__()
23 self.branch1x1 = nn.Conv2d(in_channels, 16, kernel_size=1)
24
25 self.branch5x5_1 = nn.Conv2d(in_channels, 16, kernel_size=1)
26 self.branch5x5_2 = nn.Conv2d(16, 24, kernel_size=5, padding=2)
27
28 self.branch3x3_1 = nn.Conv2d(in_channels, 16, kernel_size=1)
29 self.branch3x3_2 = nn.Conv2d(16, 24, kernel_size=3, padding=1)
30 self.branch3x3_3 = nn.Conv2d(24, 24, kernel_size=3, padding=1)
31
32 self.branch_pool = nn.Conv2d(in_channels, 24, kernel_size=1)
33
34 def forward(self, x):
35 branch1x1 = self.branch1x1(x)
36
37 branch5x5 = self.branch5x5_1(x)
38 branch5x5 = self.branch5x5_2(branch5x5)
39
40 branch3x3 = self.branch3x3_1(x)
41 branch3x3 = self.branch3x3_2(branch3x3)
42 branch3x3 = self.branch3x3_3(branch3x3)
43
44 branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
45 branch_pool = self.branch_pool(branch_pool)
46
47 outputs = [branch1x1, branch5x5, branch3x3, branch_pool]
48 return torch.cat(outputs, dim=1)
49
50 class Net(nn.Module):
51 def __init__(self):
52 super(Net, self).__init__()
53 self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
54 self.conv2 = nn.Conv2d(88, 20, kernel_size=5) # 88=24*3+16
55
56 self.incep1 = InceptionA(in_channels=10) # conv1 中的10对应
57 self.incep2 = InceptionA(in_channels=20) # conv2 中的20对应
58
59 self.mp = nn.MaxPool2d(2)
60 self.fc = nn.Linear(1408, 10)
61
62 def forward(self, x):
63 in_size = x.size(0)
64 x = F.relu(self.mp(self.conv1(x)))
65 x = self.incep1(x)
66 x = F.relu(self.mp(self.conv2(x)))
67 x = self.incep2(x)
68 x = x.view(in_size, -1)
69 x = self.fc(x)
70
71 return x
72
73 model = Net()
74 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
75 model.to(device)
76
77 # 定义优化器 和 损失
78 criterion = torch.nn.CrossEntropyLoss()
79 optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
80 # print(model.parameters())
81
82 def train(epoch):
83 run_loss = 0.0
84 for batch_idx, data in enumerate(train_loader, 0):
85 inputs, target = data
86 inputs, target = inputs.to(device), target.to(device)
87 optimizer.zero_grad()
88
89 outputs = model(inputs)
90 loss = criterion(outputs, target)
91 loss.backward()
92 optimizer.step()
93
94 run_loss += loss.item()
95 if batch_idx%300 == 299:
96 print('[%d %5d] loss: %.3f' % (epoch+1, batch_idx+1, run_loss/300))
97 run_loss = 0.0
98
99 def test():
100 correct = 0
101 total = 0
102 with torch.no_grad():
103 for data in test_loader:
104 images, labels = data
105 images, labels = images.to(device), labels.to(device)
106 outputs = model(images)
107 _, prediction = torch.max(outputs.data, dim=1)
108 total += labels.size(0)
109 correct += (prediction == labels).sum().item()
110 print('accuracy on test set: %d %%' % (100*correct/total))
111 return correct/total
112
113 epoch_list = []
114 acc_list = []
115 for epoch in range(10):
116 train(epoch)
117 acc = test()
118 epoch_list.append(epoch)
119 acc_list.append(acc)
120
121 plt.plot(epoch_list, acc_list)
122 plt.ylabel('accuracy')
123 plt.xlabel('epoch')
124 plt.show()
125
126 class DatasetSubmissionMNIST(torch.utils.data.Dataset):
127 def __init__(self, file_path, transform=None):
128 self.data = pd.read_csv(file_path)
129 self.transform = transform
130
131 def __len__(self):
132 return len(self.data)
133
134 def __getitem__(self, index):
135 image = self.data.iloc[index].values.astype(np.uint8).reshape((28, 28, 1))
136
137
138 if self.transform is not None:
139 image = self.transform(image)
140
141 return image
142
143 transform = transforms.Compose([
144 transforms.ToPILImage(),
145 transforms.ToTensor(),
146 transforms.Normalize(mean=(0.5,), std=(0.5,))
147 ])
148
149 submissionset = DatasetSubmissionMNIST('/kaggle/input/digit-recognizer/test.csv', transform=transform)
150 submissionloader = torch.utils.data.DataLoader(submissionset, batch_size=batch_size, shuffle=False)
151
152 submission = [['ImageId', 'Label']]
153
154 with torch.no_grad():
155 model.eval()
156 image_id = 1
157
158 for images in submissionloader:
159 images = images.cuda()
160 log_ps = model(images)
161 ps = torch.exp(log_ps)
162 top_p, top_class = ps.topk(1, dim=1)
163
164 for prediction in top_class:
165 submission.append([image_id, prediction.item()])
166 image_id += 1
167
168 print(len(submission) - 1)
169 import csv
170
171 with open('submission.csv', 'w') as submissionFile:
172 writer = csv.writer(submissionFile)
173 writer.writerows(submission)
174
175 print('Submission Complete!')
176 # submission.to_csv('/kaggle/working/submission.csv', index=False)

文章到此结束,我们下次再见

-- 大地的芳香来自每一株野草的献祭

CNN --Inception Module的更多相关文章

  1. Tutorial on GoogleNet based image classification --- focus on Inception module and save/load models

    Tutorial on GoogleNet based image classification  2018-06-26 15:50:29 本文旨在通过案例来学习 GoogleNet 及其 Incep ...

  2. 【深度学习】Pytorch 学习笔记

    目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...

  3. 【转】CNN卷积神经网络_ GoogLeNet 之 Inception(V1-V4)

    http://blog.csdn.net/diamonjoy_zone/article/details/70576775 参考: 1. Inception[V1]: Going Deeper with ...

  4. CNN卷积神经网络_深度残差网络 ResNet——解决神经网络过深反而引起误差增加的根本问题,Highway NetWork 则允许保留一定比例的原始输入 x。(这种思想在inception模型也有,例如卷积是concat并行,而不是串行)这样前面一层的信息,有一定比例可以不经过矩阵乘法和非线性变换,直接传输到下一层,仿佛一条信息高速公路,因此得名Highway Network

    from:https://blog.csdn.net/diamonjoy_zone/article/details/70904212 环境:Win8.1 TensorFlow1.0.1 软件:Anac ...

  5. AI:IPPR的数学表示-CNN结构进化(Alex、ZF、Inception、Res、InceptionRes)

    前言: 文章:CNN的结构分析-------:  文章:历年ImageNet冠军模型网络结构解析-------: 文章:GoogleLeNet系列解读-------: 文章:DNN结构演进Histor ...

  6. 经典分类CNN模型系列其五:Inception v2与Inception v3

    经典分类CNN模型系列其五:Inception v2与Inception v3 介绍 Inception v2与Inception v3被作者放在了一篇paper里面,因此我们也作为一篇blog来对其 ...

  7. 【机器学习】彻底搞懂CNN

    之前通过各种博客视频学习CNN,总是对参数啊原理啊什么的懵懵懂懂..这次上课终于弄明白了,O(∩_∩)O~ 上世纪科学家们发现了几个视觉神经特点,视神经具有局部感受眼,一整张图的识别由多个局部识别点构 ...

  8. AndrewNG Deep learning课程笔记 - CNN

    参考, An Intuitive Explanation of Convolutional Neural Networks http://www.hackcv.com/index.php/archiv ...

  9. 图像分类(三)GoogLenet Inception_v3:Rethinking the Inception Architecture for Computer Vision

    Inception V3网络(注意,不是module了,而是network,包含多种Inception modules)主要是在V2基础上进行的改进,特点如下: 将滤波器尺寸(Filter Size) ...

  10. 详解卷积神经网络(CNN)

    详解卷积神经网络(CNN) 详解卷积神经网络CNN 概揽 Layers used to build ConvNets 卷积层Convolutional layer 池化层Pooling Layer 全 ...

随机推荐

  1. 7张图揭晓RocketMQ存储设计的精髓

    ​简介: RocketMQ 作为一款基于磁盘存储的中间件,具有无限积压能力,并提供高吞吐.低延迟的服务能力,其最核心的部分必然是它优雅的存储设计. 存储概述 RocketMQ 存储的文件主要包括 Co ...

  2. 使用友盟+的APM服务实现对移动端APP的性能监控

    ​简介: 对于信息系统服务,一般我们的重点监控对象都是核心的后端服务,通常会采用一些主流的APM(Application Performance Management)框架进行监控.告警.分析.那么对 ...

  3. 快速界定故障:Socket Tracer网络监控实践

    ​ 简介: Socket Tracer定位是传输层(Socket&TCP)的指标采集工具,通过补齐网络监控的这部分盲区,来达到快速界定网络问题的目标. ​ 作者 | 四忌 来源 | 阿里技术公 ...

  4. Serverless 工程实践 | 细数 Serverless 的配套服务

    ​简介: 上文说到云计算的十余年发展让整个互联网行业发生了翻天覆地的变化,Serverless 作为云计算的产物,或者说是云计算在某个时代的表现,被很多人认为是真正意义上的云计算,关于"Se ...

  5. 举例useContext性能低下的样例,同时推荐用什么方法改进

    在React中,useContext 是一种非常方便的全局状态管理工具,它可以让我们在组件之间共享状态,而不需要通过层层传递 props.然而,当我们在一个大型的 React 应用中过度使用 useC ...

  6. uniapp+vue3聊天室|uni-app+vite4+uv-ui跨端仿微信app聊天语音/朋友圈

    原创研发uniapp+vue3+pinia2跨三端仿微信app聊天模板Uniapp-Wechat. uni-vue3-wchat基于uni-app+vue3+pinia2+uni-ui+uv-ui等技 ...

  7. 【git】建立分支

    1.git clone现有的项目 git clone git@github.com:zhangshengdong/Zflask.git 2.建立关联 git remote add origin git ...

  8. 解决Host key verification failed.(亲测有效)

    哈喽哇,今天在访问远程服务器的时候,出现了一个小问题. 原因:之前ssh联系过服务器,重置服务器后,再次连接服务器,就会出这个问题. 一.发现问题 问题如下图代码: $ ssh root@108.61 ...

  9. 02.go-admin IDE配置配置命令启动方式讲解笔记

    目录 go-admin版本 视频地址 一.代码地址 二.在线文档 三.首次配置需要初始化数据库资源信息(已初始化过数据库的,跳过此步) 配置数据库迁移 五.配置启动项目,用goland IDE进行启动 ...

  10. 那什么是URL、URI、URN?

    URI Uniform Resource Identifier 统一资源标识符 URL Uniform Resource Locator 统一资源定位符 URN Uniform Resource Na ...