AlexNet

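The network in the paper (its Figure 2) has 5 convolutional layers followed by 3 fully connected layers. The authors particularly stress the importance of depth: removing even a single layer makes the results worse (they report about 2% lower top-1 performance when any of the middle convolutional layers is removed), so tuning this kind of hyperparameter is far from easy.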

Activation function

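First, the activation function. Instead of the standard saturating nonlinearities \(f(x)=\tanh(x)\) or \(f(x)=(1+e^{-x})^{-1}\), the authors use ReLUs (Rectified Linear Units), \(f(x)=\max(0, x)\), which make training much faster. Of course, the speed-up is measured on the authors' particular setup: in the paper's Figure 1, a four-layer convolutional network on CIFAR-10 reaches a 25% training error rate six times faster with ReLUs than with tanh units.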



Next, the authors discuss normalization and state that ReLUs do not need it. I found this sentence from the paper a little odd:

ReLUs have the desirable property that they do not require input normalization to prevent them from saturating.

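Saturating? Saturation refers to units such as tanh or the sigmoid: for large \(|x|\) their output flattens out and the gradient approaches 0, so learning slows down. A ReLU does not saturate for positive inputs, so, as the paper puts it, if at least some training examples produce a positive input to a ReLU, learning will happen in that neuron.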
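To make "saturating" concrete, here is a small pure-Python sketch comparing the derivatives of tanh, the logistic sigmoid, and ReLU at a few inputs. The helper names are my own, not from the paper. The tanh and sigmoid gradients shrink toward 0 as \(x\) grows, while the ReLU gradient stays at 1 for every positive input.

```python
import math

def tanh_grad(x):
    # d/dx tanh(x) = 1 - tanh(x)^2, which tends to 0 as |x| grows
    return 1.0 - math.tanh(x) ** 2

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x):
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)), at most 0.25
    s = sigmoid(x)
    return s * (1.0 - s)

def relu_grad(x):
    # ReLU f(x) = max(0, x): slope 1 for x > 0, 0 otherwise (0 chosen at x == 0)
    return 1.0 if x > 0 else 0.0

for x in [0.5, 2.0, 5.0, 10.0]:
    print(f"x={x:5.1f}  tanh'={tanh_grad(x):.6f}  "
          f"sigmoid'={sigmoid_grad(x):.6f}  relu'={relu_grad(x):.1f}")
```

Note that ReLU's gradient is exactly 0 for negative inputs, which is why the paper's phrasing requires that at least some examples give the neuron a positive input.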

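The authors also add a normalization step after the ReLU, which they found aids generalization. This local response normalization divides each activity by a term computed from the activities of neighboring kernels at the same position:

\[
b_{x,y}^i = \frac{a_{x,y}^i}{\left(k + \alpha \sum_{j=\max(0,\, i-n/2)}^{\min(N-1,\, i+n/2)} \left(a_{x,y}^j\right)^2\right)^{\beta}}
\]

Here \(i\) is the index of the kernel, \(a_{x,y}^i\) is the activity of kernel \(i\) at position \((x,y)\) after the ReLU, \(b_{x,y}^i\) is the normalized activity, \(N\) is the total number of kernels in the layer, and the sum runs over \(n\) adjacent kernel maps. The paper uses \(k=2\), \(n=5\), \(\alpha=10^{-4}\), \(\beta=0.75\). To be honest, I didn't fully understand this part either.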
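A minimal pure-Python sketch of local response normalization for the activities of all kernels at a single position \((x, y)\) may help. The function `local_response_norm` is a hypothetical helper, not code from the paper; its defaults are the paper's hyperparameters, and \(n/2\) is taken as integer division so that \(n=5\) covers kernels \(i-2\) to \(i+2\).

```python
def local_response_norm(a, k=2.0, n=5, alpha=1e-4, beta=0.75):
    """Normalize the activities a[0], ..., a[N-1] of N kernels at one position (x, y).

    a[i] is divided by (k + alpha * sum of a[j]**2 over the window) ** beta,
    where the window is kernels i - n//2 ... i + n//2, clipped to [0, N - 1].
    """
    N = len(a)
    b = []
    for i in range(N):
        lo = max(0, i - n // 2)
        hi = min(N - 1, i + n // 2)
        sq_sum = sum(a[j] ** 2 for j in range(lo, hi + 1))
        b.append(a[i] / (k + alpha * sq_sum) ** beta)
    return b

print(local_response_norm([1.0, 2.0, 3.0, 4.0, 5.0]))
```

PyTorch ships this as `nn.LocalResponseNorm(size, alpha, beta, k)`, but its documented formula divides `alpha` by `size`, so reproducing the paper with `size=5` would need `alpha=5e-4`. The code later in this post uses `BatchNorm2d` instead of this normalization.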

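Next comes pooling. Pooling windows usually do not overlap, but the authors use overlapping pooling: a 3 × 3 window with stride 2. Compared with non-overlapping 2 × 2 pooling with stride 2, they report that this lowers the top-1 and top-5 error rates by 0.4% and 0.3%, and that models with overlapping pooling are slightly harder to overfit.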
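A one-dimensional sketch shows what "overlapping" means: with window z = 3 and stride s = 2, adjacent windows share one element. The helpers `max_pool_1d` and `pool_output_size` are hypothetical names for illustration; the size helper reproduces the 55 → 27 → 13 spatial sizes noted in the code comments later in this post.

```python
def pool_output_size(width, z=3, s=2):
    """Output width of a pooling layer with window z and stride s, no padding."""
    return (width - z) // s + 1

def max_pool_1d(values, z=3, s=2):
    """1-D max pooling; when z > s, neighbouring windows overlap."""
    return [max(values[start:start + z])
            for start in range(0, len(values) - z + 1, s)]

# overlapping (z=3, s=2): windows [1, 3, 2] and [2, 5, 4] share the element 2
print(max_pool_1d([1, 3, 2, 5, 4]))            # -> [3, 5]
# non-overlapping (z=2, s=2) for comparison: windows [1, 3] and [2, 5]
print(max_pool_1d([1, 3, 2, 5], z=2, s=2))     # -> [3, 5]
# spatial sizes through three 3 x 3 / stride-2 pooling layers
print([pool_output_size(w) for w in (55, 27, 13)])  # -> [27, 13, 6]
```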

Reducing overfitting

To reduce overfitting, the authors describe two techniques from their experience.

Data augmentation

This does not mean simply collecting more data; instead, the existing images are transformed to enlarge the dataset artificially.

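The paper uses two forms. The first extracts random 224 × 224 patches (and their horizontal reflections) from the 256 × 256 images. The second alters the intensities of the RGB channels: PCA is performed on the RGB pixel values of the training set, and multiples of the principal components are added to each image, with magnitudes proportional to the corresponding eigenvalues times a random variable drawn from a Gaussian with mean 0 and standard deviation 0.1.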
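A minimal NumPy sketch of both augmentations follows, assuming images are float arrays in height × width × channel layout. The helper names `fit_rgb_pca`, `pca_color_augment`, and `random_crop_flip` are hypothetical. The PCA is fit once on the RGB pixels of the training images, and one set of Gaussian coefficients is drawn per image, so the same color shift is added to every pixel of that image.

```python
import numpy as np

def fit_rgb_pca(images):
    """images: array (num_images, H, W, 3). Returns the eigenvalues and
    eigenvectors (as columns) of the 3 x 3 covariance of RGB pixel values."""
    pixels = images.reshape(-1, 3)
    cov = np.cov(pixels, rowvar=False)       # 3 x 3 covariance matrix
    eigvals, eigvecs = np.linalg.eigh(cov)   # column i of eigvecs is p_i
    return eigvals, eigvecs

def pca_color_augment(image, eigvals, eigvecs, std=0.1, rng=None):
    """Add [p1, p2, p3] @ [a1*l1, a2*l2, a3*l3] to every pixel, a_i ~ N(0, std)."""
    rng = np.random.default_rng() if rng is None else rng
    alphas = rng.normal(0.0, std, size=3)    # drawn once per image
    delta = eigvecs @ (alphas * eigvals)     # RGB shift, shape (3,)
    return image + delta                     # broadcast over H x W

def random_crop_flip(image, size=224, rng=None):
    """Random size x size patch, horizontally reflected with probability 0.5."""
    rng = np.random.default_rng() if rng is None else rng
    h, w, _ = image.shape
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    patch = image[top:top + size, left:left + size]
    if rng.random() < 0.5:
        patch = patch[:, ::-1]               # flip along the width axis
    return patch
```

For a dataset the size of ImageNet, the covariance would in practice be estimated from a sample of pixels rather than by stacking every image in memory.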

Dropout

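Combining the predictions of many different models is a very effective way to reduce test error, but training several large networks is expensive. Dropout instead sets the output of each hidden neuron to 0 with some probability (0.5 in the paper, applied in the first two fully connected layers). Each time an input is presented, the network therefore samples a different architecture, but all of these architectures share the same weights, so in effect many models are trained at once. At test time all neurons are used and their outputs are multiplied by 0.5.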
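A minimal sketch of dropout as the paper describes it, on a plain Python list of activations (the function `dropout` and its arguments are my own naming). During training each activation is zeroed independently with probability `p`; at test time nothing is dropped and the outputs are multiplied by `1 - p`, which is 0.5 in the paper.

```python
import random

def dropout(activations, p=0.5, training=True, rng=random):
    """Paper-style dropout on a list of floats."""
    if training:
        # each neuron is dropped independently, so every pass samples a sub-network
        return [0.0 if rng.random() < p else a for a in activations]
    # test time: keep every neuron and scale by the keep probability (1 - p),
    # so the expected output matches what the next layer saw during training
    return [a * (1.0 - p) for a in activations]

print(dropout([1.0, 2.0, 3.0, 4.0], p=0.5, rng=random.Random(0)))
print(dropout([1.0, 2.0, 3.0, 4.0], p=0.5, training=False))
```

PyTorch's `nn.Dropout`, used in the code later in this post, implements the equivalent "inverted" variant: during training it scales the surviving activations by `1 / (1 - p)`, and in evaluation mode it passes inputs through unchanged. That is why the model must be switched with `eval()` before testing.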

Details

batch size: 128

momentum: 0.9

weight decay: 0.0005

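Plain stochastic gradient descent has no weight decay term, but the authors found that this small amount of weight decay was important for the model to learn: it is not merely a regularizer, it also reduces the training error.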
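The paper's update rule for a weight \(w\) with momentum variable \(v\), learning rate \(\epsilon\), and batch \(D_i\) is \(v_{i+1} = 0.9\,v_i - 0.0005\,\epsilon\,w_i - \epsilon \left\langle \frac{\partial L}{\partial w}\Big|_{w_i} \right\rangle_{D_i}\) and \(w_{i+1} = w_i + v_{i+1}\). Below is a direct pure-Python transcription for lists of floats; `sgd_step` is a hypothetical helper name, and the default `lr=0.01` is the paper's initial learning rate.

```python
def sgd_step(w, v, grad, lr=0.01, momentum=0.9, weight_decay=0.0005):
    """One update of the AlexNet rule; w, v, grad are equal-length lists of floats.

    grad holds the batch-averaged gradient dL/dw evaluated at w.
    """
    new_v = [momentum * vi - weight_decay * lr * wi - lr * gi
             for wi, vi, gi in zip(w, v, grad)]
    new_w = [wi + vi for wi, vi in zip(w, new_v)]
    return new_w, new_v

# even with a zero gradient, weight decay shrinks the weight slightly
w, v = sgd_step([1.0], [0.0], [0.0], lr=0.1)
print(w, v)
```

`torch.optim.SGD`, used in the code below, instead adds `weight_decay * w` to the gradient and accumulates the momentum buffer before multiplying by the learning rate. With a constant learning rate the two forms produce the same updates; they differ only right after the learning rate is changed.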

Code

```python
"""
epochs: 50
lr: 0.001
batch_size = 128
Accuracy on the training set reached 97%;
accuracy on the test set was 83%.
"""
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


class AlexNet(nn.Module):

    def __init__(self, output_size=10):
        super(AlexNet, self).__init__()
        # The paper splits the network across two GPUs. Here the split is
        # emulated with grouped convolutions (groups=2): each half of the
        # channels gets its own kernels, exactly like two separate branches.
        # BatchNorm is used in place of the paper's local response normalization.
        self.conv1 = nn.Sequential(  # input: 3 x 227 x 227
            nn.Conv2d(3, 96, 11, 4, 0),  # 3 -> 96 channels, 11 x 11 kernel, stride 4, no padding
            nn.BatchNorm2d(96),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(  # input: 96 x 55 x 55 (two groups of 48)
            nn.Conv2d(96, 256, 5, 1, 2, groups=2),  # 48 -> 128 channels per group
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2)  # overlapping pooling: window 3, stride 2
        )
        self.conv3 = nn.Sequential(  # input: 256 x 27 x 27, connected across both groups
            nn.Conv2d(256, 384, 3, 1, 1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            nn.MaxPool2d(3, 2)
        )
        self.conv4 = nn.Sequential(  # input: 384 x 13 x 13 (two groups of 192)
            nn.Conv2d(384, 384, 3, 1, 1, groups=2),
            nn.BatchNorm2d(384),
            nn.ReLU()
        )
        self.conv5 = nn.Sequential(  # input: 384 x 13 x 13 (two groups of 192)
            nn.Conv2d(384, 256, 3, 1, 1, groups=2),  # 192 -> 128 channels per group
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2)  # output: 256 x 6 x 6 = 9216 features
        )
        self.dense = nn.Sequential(
            nn.Linear(9216, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, output_size)
        )

    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = torch.flatten(x, 1)  # 256 x 6 x 6 -> 9216
        return self.dense(x)


class Train:

    def __init__(self, lr=0.001, momentum=0.9, weight_decay=0.0005):
        self.net = AlexNet()
        self.criterion = nn.CrossEntropyLoss()
        self.generate_path()  # uses the class name, so call it before DataParallel wrapping
        self.gpu()  # move the model first, then build the optimizer from its parameters
        self.opti = torch.optim.SGD(self.net.parameters(),
                                    lr=lr, momentum=momentum,
                                    weight_decay=weight_decay)

    def gpu(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if torch.cuda.device_count() > 1:
            print("Let's use %d GPUs" % torch.cuda.device_count())
            self.net = nn.DataParallel(self.net)
        self.net = self.net.to(self.device)

    def generate_path(self):
        """
        Create the output directories and pick unused file names
        for the saved parameters and the log.
        """
        for folder in ('./paras', './logs', './images'):
            os.makedirs(folder, exist_ok=True)
        name = self.net.__class__.__name__
        paras = os.listdir('./paras')
        self.para_path = "./paras/{0}{1}.pt".format(name, len(paras))
        logs = os.listdir('./logs')
        self.log_path = "./logs/{0}{1}.txt".format(name, len(logs))

    def log(self, strings):
        """
        Append a message to the training log.
        """
        # mode 'a' appends to the end of the file
        with open(self.log_path, 'a', encoding='utf8') as f:
            f.write(strings)

    def save(self):
        """
        Save the network parameters.
        """
        # unwrap DataParallel so the keys have no "module." prefix
        # and Test can load them into a plain AlexNet
        net = self.net.module if isinstance(self.net, nn.DataParallel) else self.net
        torch.save(net.state_dict(), self.para_path)

    def decrease_lr(self, multi=10):
        """
        Divide the learning rate by `multi`.
        """
        for group in self.opti.param_groups:
            group['lr'] /= multi

    def train(self, trainloader, epochs=50):
        data_size = len(trainloader.dataset)
        self.net.train()
        for epoch in range(epochs):
            running_loss = 0.
            acc_count = 0.
            if (epoch + 1) % 10 == 0:  # divide the learning rate by 10 every 10 epochs
                self.decrease_lr()
                self.log("learning rate change!!!\n")
            for i, (imgs, labels) in enumerate(trainloader):
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)
                out = self.net(imgs)
                loss = self.criterion(out, labels)
                _, pre = torch.max(out, 1)  # predicted class of each image
                acc_count += (pre == labels).sum().item()  # count correct predictions
                self.opti.zero_grad()
                loss.backward()
                self.opti.step()
                running_loss += loss.item()
                if (i + 1) % 10 == 0:
                    strings = "epoch {0:<3} part {1:<5} loss: {2:<.7f}\n".format(
                        epoch, i, running_loss / 10  # mean loss over the last 10 batches
                    )
                    self.log(strings)
                    running_loss = 0.
            self.log(
                "Accuracy of the network on %d train images: %d %%\n" % (
                    data_size, acc_count / data_size * 100
                )
            )
            self.save()


class Test:

    def __init__(self, classes, path=0):
        self.net = AlexNet()
        self.classes = classes
        self.load(path)

    def load(self, path=0):
        if isinstance(path, int):
            name = self.net.__class__.__name__
            path = "./paras/{0}{1}.pt".format(name, path)
        # The parameters were saved from a model trained on the GPU;
        # map_location="cpu" lets them be loaded and tested on a CPU.
        self.net.load_state_dict(torch.load(path, map_location="cpu"))
        self.net.eval()  # disable dropout and use BatchNorm running statistics

    def showimgs(self, imgs, labels):
        pres = self(imgs)
        n = min(imgs.size(0), 7)  # show at most 7 images
        fig, axs = plt.subplots(1, n, squeeze=False)
        for i, ax in enumerate(axs[0]):
            img = imgs[i].numpy().transpose((1, 2, 0))  # C x H x W -> H x W x C
            img = img / 2 + 0.5  # assumes the images were normalized with mean 0.5, std 0.5
            label = self.classes[labels[i].item()]
            pre = self.classes[pres[i].item()]
            ax.set_title("{0}|{1}".format(label, pre))
            ax.imshow(img)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        plt.tight_layout()
        plt.show()

    def acc_test(self, testloader):
        data_size = len(testloader.dataset)
        acc_count = 0.
        for (imgs, labels) in testloader:
            pre = self(imgs)
            acc_count += (pre == labels).sum().item()
        return acc_count / data_size

    def __call__(self, imgs):
        with torch.no_grad():
            out = self.net(imgs)
        _, pre = torch.max(out, 1)
        return pre
```
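As a sanity check on the `9216` inputs of the first fully connected layer, the spatial size can be traced through the layers with the usual formula \(\lfloor (W + 2P - K)/S \rfloor + 1\). This sketch assumes a 227 × 227 input; the layer list is transcribed by hand from the `AlexNet` class above, and `conv_output_size` is a hypothetical helper name.

```python
def conv_output_size(width, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling layer: floor((W + 2P - K) / S) + 1."""
    return (width + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding) for each spatial layer of the AlexNet class above
layers = [
    ("conv1", 11, 4, 0),
    ("conv2", 5, 1, 2),
    ("pool", 3, 2, 0),
    ("conv3", 3, 1, 1),
    ("pool", 3, 2, 0),
    ("conv4", 3, 1, 1),
    ("conv5", 3, 1, 1),
    ("pool", 3, 2, 0),
]

size = 227
for name, k, s, p in layers:
    size = conv_output_size(size, k, s, p)
    print(f"{name:<6} -> {size} x {size}")

channels = 256  # conv5 outputs 2 x 128 channels
print("flattened features:", channels * size * size)  # -> flattened features: 9216
```

Whatever dataset is used, the images therefore have to be resized to 227 × 227 before they reach the network; otherwise the flattened feature count no longer matches the 9216 expected by the first `nn.Linear`.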
