AlexNet

上图是论文的网络的结构图,包括5个卷积层和3个全连接层,作者还特别强调,depth的重要性,少一层结果就会变差,所以这种超参数的调节可真是不简单.

激活函数

首先讨论的是激活函数,作者选择的不是\(f(x)=\mathrm{tanh}(x)=(1+e^{-x})^{-1}\),而是ReLUs ( Rectified Linear Units)——\(f(x)=\max (0, x)\), 当然,作者考虑的问题是比赛的那个数据集,其网络的收敛速度为:



接下来,作者讨论了标准化的问题,说ReLUs是不需要进行这一步的,论文中的那句话我感觉理解的怪怪的:

ReLUs have the desirable property that they do not require input normalization to prevent them fromsaturating.

饱和?

作者说,也可以对ReLUs进行扩展,使得其更有泛化性,把多个核进行标准化处理:



\(i\)表示核的顺序,\(a_{x,y}^i\)则是其值, 说实话,这部分也没怎么弄懂.

然后是关于池化层的部分,一般的池化层的核是不用重叠的,作者这部分也考虑进去了.

防止过拟合

为了防止过拟合,作者提出了他的几点经验.

增加数据

这个数据不是简单的多找点数据,而是通过一些变换使得数据增加.

比如对图片进行旋转,以及PCA提主成分,改变score等.

Dropout

多个模型,进行综合评价是防止过拟合的好方法,但是训练网络不易,dropout, 即让隐层的神经元以一定的概率输出为0来,所以每一次训练,网络的结构实际上都是不一样的,但是整个网络是共享参数的,所以可以一次性训练多个模型?

细节

batch size: 128

momentum: 0.9

weight decay: 0.0005

一般的随机梯度下降好像是没有weight decay这一部分的,但是作者说,实验中这个的选择还是蛮有效的.

代码

"""
epochs: 50
lr: 0.001
batch_size = 128
在训练集上的正确率达到了97%,
在测试集上的正确率为83%.
"""
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import os class AlexNet(nn.Module): def __init__(self, output_size=10):
super(AlexNet, self).__init__()
self.conv1 = nn.Sequential( # 3 x 227 x 227
nn.Conv2d(3, 96, 11, 4, 0), # 3通道 输出96通道 卷积核为11 x 11 滑动为4 不补零
nn.BatchNorm2d(96),
nn.ReLU()
)
self.conv2 = nn.Sequential( # 96 x 55 x 55
nn.Conv2d(48, 128, 5, 1, 2),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(3, 2)
)
self.conv3 = nn.Sequential( # 256 x 27 x 27
nn.Conv2d(256, 192, 3, 1, 1),
nn.BatchNorm2d(192),
nn.ReLU(),
nn.MaxPool2d(3, 2)
)
self.conv4 = nn.Sequential( # 384 x 13 x 13
nn.Conv2d(192, 192, 3, 1, 1),
nn.BatchNorm2d(192),
nn.ReLU()
)
self.conv5 = nn.Sequential( # 384 x 13 x 13
nn.Conv2d(192, 128, 3, 1, 1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(3, 2)
)
self.dense = nn.Sequential(
nn.Linear(9216, 4096),
nn.BatchNorm1d(4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, 4096),
nn.BatchNorm1d(4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, output_size)
) def forward(self, input):
x = self.conv1(input)
x1, x2 = x[:, :48, :, :], x[:, 48:, :, :] # 拆分
x1 = self.conv2(x1)
x2 = self.conv2(x2)
x = torch.cat((x1, x2), 1) # 合并
x1 = self.conv3(x)
x2 = self.conv3(x)
x1 = self.conv4(x1)
x2 = self.conv4(x2)
x1 = self.conv5(x1)
x2 = self.conv5(x2)
x = torch.cat((x1, x2), 1)
x = x.view(-1, 9216)
output = self.dense(x)
return output class Train: def __init__(self, lr=0.001, momentum=0.9, weight_decay=0.0005):
self.net = AlexNet()
self.criterion = nn.CrossEntropyLoss()
self.opti = torch.optim.SGD(self.net.parameters(),
lr=lr, momentum=momentum,
weight_decay=weight_decay)
self.generate_path() def gpu(self):
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
print("Let'us use %d GPUs" % torch.cuda.device_count())
self.net = nn.DataParallel(self.net)
self.net = self.net.to(self.device) def generate_path(self):
"""
生成保存数据的路径
:return:
"""
try:
os.makedirs('./paras')
os.makedirs('./logs')
os.makedirs('./images')
except FileExistsError as e:
pass
name = self.net.__class__.__name__
paras = os.listdir('./paras')
self.para_path = "./paras/{0}{1}.pt".format(
name,
len(paras)
)
logs = os.listdir('./logs')
self.log_path = "./logs/{0}{1}.txt".format(
name,
len(logs)
) def log(self, strings):
"""
运行日志
:param strings:
:return:
"""
# a 往后添加内容
with open(self.log_path, 'a', encoding='utf8') as f:
f.write(strings) def save(self):
"""
保存网络参数
:return:
"""
torch.save(self.net.state_dict(), self.para_path) def derease_lr(self, multi=10):
"""
降低学习率
:param multi:
:return:
"""
self.opti.param_groups()[0]['lr'] /= multi def train(self, trainloder, epochs=50):
data_size = len(trainloder) * trainloder.batch_size
for epoch in range(epochs):
running_loss = 0.
acc_count = 0.
if (epoch + 1) % 10 is 0:
self.derease_lr()
self.log(
"learning rate change!!!\n"
)
for i, data in enumerate(trainloder):
imgs, labels = data
imgs = imgs.to(self.device)
labels = labels.to(self.device)
out = self.net(imgs)
loss = self.criterion(out, labels)
_, pre = torch.max(out, 1) #判断是否判断正确
acc_count += (pre == labels).sum().item() #加总对的个数 self.opti.zero_grad()
loss.backward()
self.opti.step() running_loss += loss.data if (i+1) % 10 is 0:
strings = "epoch {0:<3} part {1:<5} loss: {2:<.7f}\n".format(
epoch, i, running_loss * 50
)
self.log(strings)
running_loss = 0.
self.log(
"Accuracy of the network on %d train images: %d %%\n" %(
data_size, acc_count / data_size * 100
)
)
self.save()
class Test:

    def __init__(self, classes, path=0):
self.net = AlexNet()
self.classes = classes
self.load(path) def load(self, path=0):
if isinstance(path, int):
name = self.net.__class__.__name__
path = "./paras/{0}{1}.pt".format(
name, path
)
#加载参数, map_location 因为是用GPU训练的, 保存的是是GPU的模型
#如果需要在cpu的情况下测试, 选择map_location="cpu".
self.net.load_state_dict(torch.load(path, map_location="cpu"))
self.net.eval() def showimgs(self, imgs, labels):
n = imgs.size(0)
pres = self.__call__(imgs)
n = max(n, 7)
fig, axs = plt.subplots(n)
for i, ax in enumerate(axs):
img = imgs[i].numpy().transpose((1, 2, 0))
img = img / 2 + 0.5
label = self.classes[labels[i]]
pre = self.classes[pres[i]]
ax.set_title("{0}|{1}".format(
label, pre
))
ax.plot(img)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.tight_layout()
plt.show() def acc_test(self, testloader):
data_size = len(testloader) * testloader.batch_size
acc_count = 0.
for (imgs, labels) in testloader:
pre = self.__call__(imgs)
acc_count += (pre == labels).sum().item()
return acc_count / data_size def __call__(self, imgs):
out = self.net(imgs)
_, pre = torch.max(out, 1)
return pre

AlexNet: ImageNet Classification with Deep Convolutional Neural Networks的更多相关文章

  1. AlexNet论文翻译-ImageNet Classification with Deep Convolutional Neural Networks

    ImageNet Classification with Deep Convolutional Neural Networks 深度卷积神经网络的ImageNet分类 Alex Krizhevsky ...

  2. 《ImageNet Classification with Deep Convolutional Neural Networks》 剖析

    <ImageNet Classification with Deep Convolutional Neural Networks> 剖析 CNN 领域的经典之作, 作者训练了一个面向数量为 ...

  3. ImageNet Classification with Deep Convolutional Neural Networks(译文)转载

    ImageNet Classification with Deep Convolutional Neural Networks Alex Krizhevsky, Ilya Sutskever, Geo ...

  4. 中文版 ImageNet Classification with Deep Convolutional Neural Networks

    ImageNet Classification with Deep Convolutional Neural Networks 摘要 我们训练了一个大型深度卷积神经网络来将ImageNet LSVRC ...

  5. Understanding the Effective Receptive Field in Deep Convolutional Neural Networks

    Understanding the Effective Receptive Field in Deep Convolutional Neural Networks 理解深度卷积神经网络中的有效感受野 ...

  6. Deep learning_CNN_Review:A Survey of the Recent Architectures of Deep Convolutional Neural Networks——2019

    CNN综述文章 的翻译 [2019 CVPR] A Survey of the Recent Architectures of Deep Convolutional Neural Networks 翻 ...

  7. Image Scaling using Deep Convolutional Neural Networks

    Image Scaling using Deep Convolutional Neural Networks This past summer I interned at Flipboard in P ...

  8. 深度卷积神经网络用于图像缩放Image Scaling using Deep Convolutional Neural Networks

    This past summer I interned at Flipboard in Palo Alto, California. I worked on machine learning base ...

  9. [论文阅读] ImageNet Classification with Deep Convolutional Neural Networks(传说中的AlexNet)

    这篇文章使用的AlexNet网络,在2012年的ImageNet(ILSVRC-2012)竞赛中获得第一名,top-5的测试误差为15.3%,相比于第二名26.2%的误差降低了不少. 本文的创新点: ...

随机推荐

  1. act.四级

    act的词源是do, 干着或干了的事情也可以叫act.action: doing sth; act: n. action, v. do; activity: busy, energetic, or占据 ...

  2. Hadoop fs.copyToLocalFile错误

    fs.copyToLocalFile(new Path("/study1/1.txt"), new Path("C:/Users/Administrator/Deskto ...

  3. flink-----实时项目---day06-------1. 获取窗口迟到的数据 2.双流join(inner join和left join(有点小问题)) 3 订单Join案例(订单数据接入到kafka,订单数据的join实现,订单数据和迟到数据join的实现)

    1. 获取窗口迟到的数据 主要流程就是给迟到的数据打上标签,然后使用相应窗口流的实例调用sideOutputLateData(lateDataTag),从而获得窗口迟到的数据,进而进行相关的计算,具体 ...

  4. 商业爬虫学习笔记day6

    一. 正则解析数据 解析百度新闻中每个新闻的title,url,检查每个新闻的源码可知道,其title和url都位于<a></a>标签中,因为里面参数的具体形式不一样,同一个正 ...

  5. 【swift】长按事件绑定,平移滑动事件+坐标获取

    为何把这两个事件归类在一起? 我后来才明白,iOS有一个手势事件(UiGestureRecognizer) 事件里有7个功能,不过我只试过前两个,也就是标题的这两个(长按.平移滑动) UILongPr ...

  6. webservice--cxf和spring结合

    服务端: 实体: package entity; import java.util.Date; /*** 实体 */ public class Pojo { //温度 private String d ...

  7. Java 8实现BASE64编解码

    Java一直缺少BASE64编码 API,以至于通常在项目开发中会选用第三方的API实现.但是,Java 8实现了BASE64编解码API,它包含到java.util包.下面我会对Java 8的BAS ...

  8. springboot热部署与监控

    一.热部署 添加依赖+Ctrl+F9 <dependency> <groupId>org.springframework.boot</groupId> <ar ...

  9. Bash shell(六)-管道命令

    就如同前面所说的, bash 命令执行的时候有输出的数据会出现! 那么如果这群数据必需要经过几道手续之后才能得到我们所想要的格式,应该如何来设定? 这就牵涉到管线命令的问题了 (pipe) ,管线命令 ...

  10. 【Spring Framework】Spring注解设置Bean的初始化、销毁方法的方式

    bean的生命周期:创建---初始化---销毁. Spring中声明的Bean的初始化和销毁方法有3种方式: @Bean的注解的initMethod.DestroyMethod属性 bean实现Ini ...