使用MxNet新接口Gluon提供的预训练模型进行微调
1. 导入各种包
from mxnet import gluon
import mxnet as mx
from mxnet.gluon import nn
from mxnet import ndarray as nd
import matplotlib.pyplot as plt
import cv2
from mxnet import image
from mxnet import autograd
2. 导入数据
我使用cifar10这个数据集,使用gluon自带的模块下载到本地并且为了配合后面的网络,我将大小调整到224*224
def transform(data, label):
data = image.imresize(data, 224, 224)
return data.astype('float32'), label.astype('float32')
cifar10_train = gluon.data.vision.CIFAR10(root='./',train=True, transform=transform)
cifar10_test = gluon.data.vision.CIFAR10(root='./',train=False, transform=transform)
batch_size = 64
train_data = gluon.data.DataLoader(cifar10_train, batch_size, shuffle=True)
test_data = gluon.data.DataLoader(cifar10_test, batch_size, shuffle=False)
3. 加载预训练模型
gluon提供的很多预训练模型,我选择一个简单的模型AlexNet
首先下载AlexNet模型和模型参数
使用下面的代码会获取AlexNet的模型并且加载预训练好的模型参数,但是鉴于网络的原因,我提前下好了
alexnet = mx.gluon.model_zoo.vision.alexnet(pretrained=True)#如果pretrained值为True,则会下载预训练参数,否则是空模型
获取模型并从本地加载参数
alexnet = mx.gluon.model_zoo.vision.alexnet()
alexnet.load_params('alexnet-44335d1f.params',ctx=mx.gpu())
看下AlexNet网络结构,发现分为两部分,features,classifier,而features正好是需要的
print(alexnet)
AlexNet(
(features): HybridSequential(
(0): Conv2D(64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
(1): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(2): Conv2D(192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(3): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(4): Conv2D(384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): Conv2D(256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): Conv2D(256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(8): Flatten
)
(classifier): HybridSequential(
(0): Dense(4096, Activation(relu))
(1): Dropout(p = 0.5)
(2): Dense(4096, Activation(relu))
(3): Dropout(p = 0.5)
(4): Dense(1000, linear)
)
)
4. 组合新的网络
截取想要的features,并且固定参数。这样防止训练的时候把预训练好的参数给搞坏了
featuresnet = alexnet.features
for _, w in featuresnet.collect_params().items():
w.grad_req = 'null'
自己定义后面的网络,因为数据集是10类,就把最后的输出从1000改成了10。
def Classifier():
net = nn.HybridSequential()
net.add(nn.Dense(4096, activation="relu"))
net.add(nn.Dropout(.5))
net.add(nn.Dense(4096, activation="relu"))
net.add(nn.Dropout(.5))
net.add(nn.Dense(10))
return net
接着需要把两部分组合起来,并且对第二部分机进行初始化
net = nn.HybridSequential()
with net.name_scope():
net.add(featuresnet)
net.add(Classifier())
net[1].collect_params().initialize(init=mx.init.Xavier(),ctx=mx.gpu())
net.hybridize()
5. 训练
最后就是训练了,看看效果如何
#定义准确率函数
def accuracy(output, label):
return nd.mean(output.argmax(axis=1)==label).asscalar()
def evaluate_accuracy(data_iterator, net, ctx=mx.gpu()):
acc = 0.
for data, label in data_iterator:
data = data.transpose([0,3,1,2])
data = data/255
output = net(data.as_in_context(ctx))
acc += accuracy(output, label.as_in_context(ctx))
return acc / len(data_iterator)
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(
net.collect_params(), 'sgd', {'learning_rate': 0.01})
for epoch in range(1):
train_loss = 0.
train_acc = 0.
test_acc = 0.
for data, label in train_data:
label = label.as_in_context(mx.gpu())
data = data.transpose([0,3,1,2])
data = data/255
with autograd.record():
output = net(data.as_in_context(mx.gpu()))
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(batch_size)
train_loss += nd.mean(loss).asscalar()
train_acc += accuracy(output, label)
test_acc = evaluate_accuracy(test_data, net)
print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
epoch, train_loss/len(train_data),
train_acc/len(train_data),test_acc))
Epoch 0. Loss: 1.249197, Train acc 0.558764, Test acc 0.696756
使用MxNet新接口Gluon提供的预训练模型进行微调的更多相关文章
- MXNet的新接口Gluon
为什么要开发Gluon的接口 在MXNet中我们可以通过Sybmol模块来定义神经网络,并组通过Module模块提供的一些上层API来简化整个训练过程.那MXNet为什么还要重新开发一套Python的 ...
- MxNet新前端Gluon模型转换到Symbol
1. 导入各种包 from mxnet import gluon from mxnet.gluon import nn import matplotlib.pyplot as plt from mxn ...
- Paddle预训练模型应用工具PaddleHub
Paddle预训练模型应用工具PaddleHub 本文主要介绍如何使用飞桨预训练模型管理工具PaddleHub,快速体验模型以及实现迁移学习.建议使用GPU环境运行相关程序,可以在启动环境时,如下图所 ...
- 预训练模型时代:告别finetune, 拥抱adapter
NLP论文解读 原创•作者 |FLIPPED 研究背景 随着计算算力的不断增加,以transformer为主要架构的预训练模型进入了百花齐放的时代.BERT.RoBERTa等模型的提出为NLP相关问题 ...
- 【转载】最强NLP预训练模型!谷歌BERT横扫11项NLP任务记录
本文介绍了一种新的语言表征模型 BERT--来自 Transformer 的双向编码器表征.与最近的语言表征模型不同,BERT 旨在基于所有层的左.右语境来预训练深度双向表征.BERT 是首个在大批句 ...
- PyTorch-网络的创建,预训练模型的加载
本文是PyTorch使用过程中的的一些总结,有以下内容: 构建网络模型的方法 网络层的遍历 各层参数的遍历 模型的保存与加载 从预训练模型为网络参数赋值 主要涉及到以下函数的使用 add_module ...
- 【翻译】OpenVINO Pre-Trained 预训练模型介绍
OpenVINO 系列软件包预训练模型介绍 本文翻译自 Intel OpenVINO 的 "Overview of OpenVINO Toolkit Pre-Trained Models& ...
- dropzonejs中文翻译手册 DropzoneJS是一个提供文件拖拽上传并且提供图片预览的开源类库.
http://wxb.github.io/dropzonejs.com.zh-CN/dropzonezh-CN/ 由于项目需要,完成一个web的图片拖拽上传,也就顺便学习和了解了一下前端的比较新的技术 ...
- 微信小程序语音识别服务搭建全过程解析(https api开放,支持新接口mp3录音、老接口silk录音)
silk v3(或新录音接口mp3)录音转olami语音识别和语义处理的api服务(ubuntu16.04服务器上实现) 重要的写在前面 重要事项一: 所有相关更新,我优先更新到我个人博客中,其它地方 ...
随机推荐
- 即时作图新工具—ProcessOn【推荐】
www.processon.com 推荐这个在线作图网站:免费登录,云端存储,面向对象,最重要的是提供了丰富模板! 在线软件的人气会越来越高,这是趋势啊~
- Beta阶段项目复审
复审人:王李焕 六指神功:http://www.cnblogs.com/teamworkers/ wt.dll:http://www.cnblogs.com/TeamOf/ 六个核桃:http://w ...
- 团队作业8----第二次项目冲刺(beta阶段)5.22
Day4-05.22 1.每日会议 会议内容: 1.帮助新成员进一步了解项目. 2.陈鑫龙说明昨日任务的完成情况. 3.组长林乔桦安排今日的任务. 讨论照片: 2.任务分配情况: 每个人的工作分配表: ...
- 201521123039《java程序设计》第五周学习总结
1. 本周学习总结 1.1 尝试使用思维导图总结有关多态与接口的知识点. 2. 书面作业 代码阅读:Child压缩包内源代码 1.1 com.parent包中Child.java文件能否编译通过?哪句 ...
- 201521123064 《Java程序设计》第4周学习总结
1. 本章学习总结 1.1 尝试使用思维导图总结有关继承的知识点. 1.2 使用常规方法总结其他上课内容. ① 以上周PTA实验"形状"为例,Circle类和Rectangle类中 ...
- 201521123071《Java程序设计》第1周学习总结
1. 本章学习总结 通过本周的学习,对java的一些语法以及java的发展史有了一定的基础认识,也了解了JDK的安装,以及环境变量定义和配置等知识.还有对码云,Markdown等的使用,大大方便了我们 ...
- 201521123047《Java程序设计》第1周学习总结
1. 本章学习总结 学习到了jdk,jvm,jre之间的关系,下载并安装了jdk,学会设置path变量,初步学会建立简单的java程序,并执行成功.初步学会notepad++,eclipse的操作.学 ...
- 《JAVA程序设计》第11周学习总结
1. 本章学习总结 1.1 以你喜欢的方式(思维导图或其他)归纳总结多线程相关内容. synchronized方法/代码块 wait().notify()用法,生产者消费者例子 lock.condit ...
- 201521123056 《Java程序设计》第10周学习总结
1. 本周学习总结 1.1 以你喜欢的方式(思维导图或其他)归纳总结异常与多线程相关内容. 2. 书面作业 本次PTA作业题集异常.多线程 1. finally 题目4-2 1.1 截图你的提交结果( ...
- Eclipse rap 富客户端开发总结(15) :rap如何使用js
1. 把输入的字符串当 javascript 执行 try { RWT.getResponse().getWriter().println("alert('123');"); } ...