AlexNet: ImageNet Classification with Deep Convolutional Neural Networks

上图是论文的网络的结构图,包括5个卷积层和3个全连接层,作者还特别强调,depth的重要性,少一层结果就会变差,所以这种超参数的调节可真是不简单.
激活函数
首先讨论的是激活函数,作者选择的不是\(f(x)=\mathrm{tanh}(x)=(1+e^{-x})^{-1}\),而是ReLUs ( Rectified Linear Units)——\(f(x)=\max (0, x)\), 当然,作者考虑的问题是比赛的那个数据集,其网络的收敛速度为:

接下来,作者讨论了标准化的问题,说ReLUs是不需要进行这一步的,论文中的那句话我感觉理解的怪怪的:
ReLUs have the desirable property that they do not require input normalization to prevent them fromsaturating.
饱和?
作者说,也可以对ReLUs进行扩展,使得其更有泛化性,把多个核进行标准化处理:

\(i\)表示核的顺序,\(a_{x,y}^i\)则是其值, 说实话,这部分也没怎么弄懂.
然后是关于池化层的部分,一般的池化层的核是不用重叠的,作者这部分也考虑进去了.
防止过拟合
为了防止过拟合,作者提出了他的几点经验.
增加数据
这个数据不是简单的多找点数据,而是通过一些变换使得数据增加.
比如对图片进行旋转,以及PCA提主成分,改变score等.
Dropout
多个模型,进行综合评价是防止过拟合的好方法,但是训练网络不易,dropout, 即让隐层的神经元以一定的概率输出为0来,所以每一次训练,网络的结构实际上都是不一样的,但是整个网络是共享参数的,所以可以一次性训练多个模型?
细节
batch size: 128
momentum: 0.9
weight decay: 0.0005
一般的随机梯度下降好像是没有weight decay这一部分的,但是作者说,实验中这个的选择还是蛮有效的.
代码
"""
epochs: 50
lr: 0.001
batch_size = 128
在训练集上的正确率达到了97%,
在测试集上的正确率为83%.
"""
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import os
class AlexNet(nn.Module):
def __init__(self, output_size=10):
super(AlexNet, self).__init__()
self.conv1 = nn.Sequential( # 3 x 227 x 227
nn.Conv2d(3, 96, 11, 4, 0), # 3通道 输出96通道 卷积核为11 x 11 滑动为4 不补零
nn.BatchNorm2d(96),
nn.ReLU()
)
self.conv2 = nn.Sequential( # 96 x 55 x 55
nn.Conv2d(48, 128, 5, 1, 2),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(3, 2)
)
self.conv3 = nn.Sequential( # 256 x 27 x 27
nn.Conv2d(256, 192, 3, 1, 1),
nn.BatchNorm2d(192),
nn.ReLU(),
nn.MaxPool2d(3, 2)
)
self.conv4 = nn.Sequential( # 384 x 13 x 13
nn.Conv2d(192, 192, 3, 1, 1),
nn.BatchNorm2d(192),
nn.ReLU()
)
self.conv5 = nn.Sequential( # 384 x 13 x 13
nn.Conv2d(192, 128, 3, 1, 1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(3, 2)
)
self.dense = nn.Sequential(
nn.Linear(9216, 4096),
nn.BatchNorm1d(4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, 4096),
nn.BatchNorm1d(4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, output_size)
)
def forward(self, input):
x = self.conv1(input)
x1, x2 = x[:, :48, :, :], x[:, 48:, :, :] # 拆分
x1 = self.conv2(x1)
x2 = self.conv2(x2)
x = torch.cat((x1, x2), 1) # 合并
x1 = self.conv3(x)
x2 = self.conv3(x)
x1 = self.conv4(x1)
x2 = self.conv4(x2)
x1 = self.conv5(x1)
x2 = self.conv5(x2)
x = torch.cat((x1, x2), 1)
x = x.view(-1, 9216)
output = self.dense(x)
return output
class Train:
def __init__(self, lr=0.001, momentum=0.9, weight_decay=0.0005):
self.net = AlexNet()
self.criterion = nn.CrossEntropyLoss()
self.opti = torch.optim.SGD(self.net.parameters(),
lr=lr, momentum=momentum,
weight_decay=weight_decay)
self.generate_path()
def gpu(self):
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
print("Let'us use %d GPUs" % torch.cuda.device_count())
self.net = nn.DataParallel(self.net)
self.net = self.net.to(self.device)
def generate_path(self):
"""
生成保存数据的路径
:return:
"""
try:
os.makedirs('./paras')
os.makedirs('./logs')
os.makedirs('./images')
except FileExistsError as e:
pass
name = self.net.__class__.__name__
paras = os.listdir('./paras')
self.para_path = "./paras/{0}{1}.pt".format(
name,
len(paras)
)
logs = os.listdir('./logs')
self.log_path = "./logs/{0}{1}.txt".format(
name,
len(logs)
)
def log(self, strings):
"""
运行日志
:param strings:
:return:
"""
# a 往后添加内容
with open(self.log_path, 'a', encoding='utf8') as f:
f.write(strings)
def save(self):
"""
保存网络参数
:return:
"""
torch.save(self.net.state_dict(), self.para_path)
def derease_lr(self, multi=10):
"""
降低学习率
:param multi:
:return:
"""
self.opti.param_groups()[0]['lr'] /= multi
def train(self, trainloder, epochs=50):
data_size = len(trainloder) * trainloder.batch_size
for epoch in range(epochs):
running_loss = 0.
acc_count = 0.
if (epoch + 1) % 10 is 0:
self.derease_lr()
self.log(
"learning rate change!!!\n"
)
for i, data in enumerate(trainloder):
imgs, labels = data
imgs = imgs.to(self.device)
labels = labels.to(self.device)
out = self.net(imgs)
loss = self.criterion(out, labels)
_, pre = torch.max(out, 1) #判断是否判断正确
acc_count += (pre == labels).sum().item() #加总对的个数
self.opti.zero_grad()
loss.backward()
self.opti.step()
running_loss += loss.data
if (i+1) % 10 is 0:
strings = "epoch {0:<3} part {1:<5} loss: {2:<.7f}\n".format(
epoch, i, running_loss * 50
)
self.log(strings)
running_loss = 0.
self.log(
"Accuracy of the network on %d train images: %d %%\n" %(
data_size, acc_count / data_size * 100
)
)
self.save()
class Test:
def __init__(self, classes, path=0):
self.net = AlexNet()
self.classes = classes
self.load(path)
def load(self, path=0):
if isinstance(path, int):
name = self.net.__class__.__name__
path = "./paras/{0}{1}.pt".format(
name, path
)
#加载参数, map_location 因为是用GPU训练的, 保存的是是GPU的模型
#如果需要在cpu的情况下测试, 选择map_location="cpu".
self.net.load_state_dict(torch.load(path, map_location="cpu"))
self.net.eval()
def showimgs(self, imgs, labels):
n = imgs.size(0)
pres = self.__call__(imgs)
n = max(n, 7)
fig, axs = plt.subplots(n)
for i, ax in enumerate(axs):
img = imgs[i].numpy().transpose((1, 2, 0))
img = img / 2 + 0.5
label = self.classes[labels[i]]
pre = self.classes[pres[i]]
ax.set_title("{0}|{1}".format(
label, pre
))
ax.plot(img)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.tight_layout()
plt.show()
def acc_test(self, testloader):
data_size = len(testloader) * testloader.batch_size
acc_count = 0.
for (imgs, labels) in testloader:
pre = self.__call__(imgs)
acc_count += (pre == labels).sum().item()
return acc_count / data_size
def __call__(self, imgs):
out = self.net(imgs)
_, pre = torch.max(out, 1)
return pre
AlexNet: ImageNet Classification with Deep Convolutional Neural Networks的更多相关文章
- AlexNet论文翻译-ImageNet Classification with Deep Convolutional Neural Networks
ImageNet Classification with Deep Convolutional Neural Networks 深度卷积神经网络的ImageNet分类 Alex Krizhevsky ...
- 《ImageNet Classification with Deep Convolutional Neural Networks》 剖析
<ImageNet Classification with Deep Convolutional Neural Networks> 剖析 CNN 领域的经典之作, 作者训练了一个面向数量为 ...
- ImageNet Classification with Deep Convolutional Neural Networks(译文)转载
ImageNet Classification with Deep Convolutional Neural Networks Alex Krizhevsky, Ilya Sutskever, Geo ...
- 中文版 ImageNet Classification with Deep Convolutional Neural Networks
ImageNet Classification with Deep Convolutional Neural Networks 摘要 我们训练了一个大型深度卷积神经网络来将ImageNet LSVRC ...
- Understanding the Effective Receptive Field in Deep Convolutional Neural Networks
Understanding the Effective Receptive Field in Deep Convolutional Neural Networks 理解深度卷积神经网络中的有效感受野 ...
- Deep learning_CNN_Review:A Survey of the Recent Architectures of Deep Convolutional Neural Networks——2019
CNN综述文章 的翻译 [2019 CVPR] A Survey of the Recent Architectures of Deep Convolutional Neural Networks 翻 ...
- Image Scaling using Deep Convolutional Neural Networks
Image Scaling using Deep Convolutional Neural Networks This past summer I interned at Flipboard in P ...
- 深度卷积神经网络用于图像缩放Image Scaling using Deep Convolutional Neural Networks
This past summer I interned at Flipboard in Palo Alto, California. I worked on machine learning base ...
- [论文阅读] ImageNet Classification with Deep Convolutional Neural Networks(传说中的AlexNet)
这篇文章使用的AlexNet网络,在2012年的ImageNet(ILSVRC-2012)竞赛中获得第一名,top-5的测试误差为15.3%,相比于第二名26.2%的误差降低了不少. 本文的创新点: ...
随机推荐
- 自然语言式parsing
got NUM(1) Is NUM(1) an expr? Is NUM(1) a term? Is NUM(1) a number? is_term got -(-) -(-) was back i ...
- 15. Linux提取RPM包文件(cpio命令)详解
在讲解如何从 RPM 包中提取文件之前,先来系统学习一下 cpio 命令.cpio 命令用于从归档包中存入和读取文件,换句话说,cpio 命令可以从归档包中提取文件(或目录),也可以将文件(或目录)复 ...
- JavaScript小数、百分数的转换
百分数转化为小数 function toPoint(percent){ var str=percent.replace("%",""); str= str/10 ...
- 关于stm32不常用的中断,如何添加, 比如timer10 timer11等
首先可以从keil中找到 比如找到定时器11的溢出中断,如上图是26 然后,配置定时器11 溢出中断的时候,我就在:下面填上这个变量. 之后要写中断服务函数,也就是发生中断后要跳转到的函数. 需要知道 ...
- Shell学习(五)—— awk命令详解
一.awk简介 awk是一个非常好用的数据处理工具,相对于sed常常作用于一整个行的处理,awk则比较倾向于一行当中分成数个[字段]处理,因此,awk相当适合处理小型的数据数据处理.awk是一种报 ...
- Linux学习 - 文件系统属性chattr权限
change file attributes on 啊linux file system 1 功能 可以防止误操作 2 chattr命令格式 chattr [+-=] [选项] 文件或目录名 + 增加 ...
- Linux基础命令---mget获取ftp文件
mget 使用lftp登录mftp服务器之后,可以使用mget指令从服务器获取文件.mget指令可以使用通配符,而get指令则不可以. 1.语法 mget [-E] [-a] [- ...
- @Order注解使用
注解@Order或者接口Ordered的作用是定义Spring IOC容器中Bean的执行顺序的优先级,而不是定义Bean的加载顺序,Bean的加载顺序不受@Order或Ordered接口的影响: @ ...
- SQL查询:并集、差集、交集
新建两个表进行测试: test_a ID name 1 曹操 2 郭嘉 3 孙权 4 周瑜 test_b ID name 1 刘备 2 关羽 3 张飞 4 孙权 5 周瑜 1.UNION形成并集 UN ...
- 【React】组件书写记录
时钟组件: 组件形式:数字时钟 https://blog.csdn.net/hahahahahahahaha__1/article/details/80688920 Javascript获取时间方法: ...