上一篇文章我们聊的是使用预训练网络中的一种方法,特征提取,今天我们讨论另外一种方法,微调模型,这也是迁移学习的一种方法。

微调模型

为什么需要微调模型?我们猜测和之前的实验,我们有这样的共识,数据量越少,网络的特征节点越多,会越容易导致过拟合,这当然不是我们所希望的,但对于那些预先训练好的模型,还有可能最终无法很好的完成所要做的工作,因此我们还需要对其更改,基于此原因,我们需要做的就是拿来一个训练好的模型,更改其中更加抽象的层,即网络后面的层,然后再采用新的分类器,这样可以比较好的解决上面所提出的过拟合问题了。

进行微调网络的步骤是:

  1. 在已经训练好的网络(基网络)基础上,添加自定义的层;

  2. 冻结基网络并训练新添加的层;

  3. 冻结基网络的一部分层,另一部分可训练;

  4. 联合训练解冻的这些层和添加的部分。

我们上一篇提到的方法就可以完成前两个步骤,接下来我们看如何解决后两个步骤。这里我们还要更明确一下调整的层数如果过多会带来什么问题:随着可变层数的增多,过拟合的风险会随之加大。还要明确调整网络中识别像素和线条的层不如调整识别耳朵的层更有效,因为不论是识别猫还是桌子识别线条的方法层更通用。

完成这项任务所需要写的代码也是很简单的,就是设置模型是可训练的,然后遍历网络的每一层,针对每一层分别设置是否是可训练的,直到 layer_name 层,前面的层都是不可训练的:

conv_base.trainable = True
set_trainable = False
for layer in conv_base.layers:
if layer.name == 'layer_name':
set_trainable = True
if set_trainable:
layer.trainable = True
else:
layer.trainable = False

这里是关键部分代码,老规矩,最后将给出全部代码,我们先来看看结果:

需要注意一下这里的数据,在开始的时候不稳定,迅速爬升,因此纵坐标的数据没有那么好,但我们仔细看一下后期的数据,训练精度和验证精度都在百分之九十到百分之百,验证精度一直有一些波动,是网络的一些噪声引起的,我不想去强制让它们那么漂亮了,一是因为训练时间会比较长,而是因为我觉得没有特别大的必要,波动的最高点和最低点都在可接受的范围内,应该把关注点放在更重要的问题上去。

基于本篇文章和上一篇文章,我们做个小结:

  1. 计算机视觉领域中,卷积神经网络的表现非常不错,并且在数据集较小的情况下,表现让人是非常优秀的。

  2. 数据增强是很好的避免过拟合的方法,过拟合产生的主要原因可能是数据量太少或者是参数过多。

  3. 特征提取可以比较好的将现有的神经网络应用于小型数据集,还可以使用微调的方式进行优化。

我们看看代码吧,这里还有一个建议,如果可能尽量使用 GPU 去做网络模型的训练,CPU 在现阶段处理这些问题会有点力不从心,耗时较长,读者也可以考虑减少一些数据量加快速度,但要避免过拟合,请读者心中记住此类问题,在遇到问题的时候是一个方向(当然,笔者是非常惨的,没有好用的 GPU,因此等待数据画图截图是非常痛苦的一件事):

#!/usr/bin/env python3

import os
import time

import matplotlib.pyplot as plt
from keras import layers
from keras import models
from keras import optimizers
from keras.applications import VGG16
from keras.preprocessing.image import ImageDataGenerator


def cat():
base_dir = '/Users/renyuzhuo/Desktop/cat/dogs-vs-cats-small'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

train_datagen = ImageDataGenerator(
rescale=1. / 255,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')

test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
train_dir,
target_size=(150, 150),
batch_size=20,
class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
validation_dir,
target_size=(150, 150),
batch_size=20,
class_mode='binary')

# 定义密集连接分类器
conv_base = VGG16(weights='imagenet',
include_top=False,
input_shape=(150, 150, 3))
conv_base.trainable = True
set_trainable = False
for layer in conv_base.layers:
if layer.name == 'block5_conv1':
set_trainable = True
if set_trainable:
layer.trainable = True
else:
layer.trainable = False
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu', input_dim=4 * 4 * 512))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1, activation='sigmoid'))

conv_base.summary()

# 对模型进行配置
model.compile(loss='binary_crossentropy',
optimizer=optimizers.RMSprop(lr=1e-5),
metrics=['acc'])

# 对模型进行训练
history = model.fit_generator(
train_generator,
steps_per_epoch=100,
epochs=100,
validation_data=validation_generator,
validation_steps=50)

# 画图
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.show()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()


if __name__ == "__main__":
time_start = time.time()
cat()
time_end = time.time()
print('Time Used: ', time_end - time_start)

本文首发自公众号:RAIS

AI:拿来主义——预训练网络(二)的更多相关文章

  1. AI:拿来主义——预训练网络(一)

    我们已经训练过几个神经网络了,识别手写数字,房价预测或者是区分猫和狗,那随之而来就有一个问题,这些训练出的网络怎么用,每个问题我都需要重新去训练网络吗?因为程序员都不太喜欢做重复的事情,因此答案肯定是 ...

  2. CNN基础二:使用预训练网络提取图像特征

    上一节中,我们采用了一个自定义的网络结构,从头开始训练猫狗大战分类器,最终在使用图像增强的方式下得到了82%的验证准确率.但是,想要将深度学习应用于小型图像数据集,通常不会贸然采用复杂网络并且从头开始 ...

  3. 学习TensorFlow,调用预训练好的网络(Alex, VGG, ResNet etc)

    视觉问题引入深度神经网络后,针对端对端的训练和预测网络,可以看是特征的表达和任务的决策问题(分类,回归等).当我们自己的训练数据量过小时,往往借助牛人已经预训练好的网络进行特征的提取,然后在后面加上自 ...

  4. 从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史(转载)

    转载 https://zhuanlan.zhihu.com/p/49271699 首发于深度学习前沿笔记 写文章   从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 张 ...

  5. 预训练语言模型的前世今生 - 从Word Embedding到BERT

    预训练语言模型的前世今生 - 从Word Embedding到BERT 本篇文章共 24619 个词,一个字一个字手码的不容易,转载请标明出处:预训练语言模型的前世今生 - 从Word Embeddi ...

  6. pytorch预训练

    Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...

  7. zz从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史

    从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 Bert最近很火,应该是最近最火爆的AI进展,网上的评价很高,那么Bert值得这么高的评价吗?我个人判断是值得.那为什么 ...

  8. 【猫狗数据集】使用预训练的resnet18模型

    数据集下载地址: 链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw提取码:2xq4 创建数据集:https://www.cnblogs.com/xi ...

  9. 预训练中Word2vec,ELMO,GPT与BERT对比

    预训练 先在某个任务(训练集A或者B)进行预先训练,即先在这个任务(训练集A或者B)学习网络参数,然后存起来以备后用.当我们在面临第三个任务时,网络可以采取相同的结构,在较浅的几层,网络参数可以直接加 ...

随机推荐

  1. 【Java杂货铺】JVM#Class类结构

    代码编译的结果从本地机器码转为字节码,是储存格式发展的一小步,却是编程语言的一大步.--<深入理解Java虚拟机> 计算机只认识0和1.所以我们写的编程语言只有转义成二进制本地机器码才能让 ...

  2. iOS自定义弹出视图、收音机APP、图片涂鸦、加载刷新、文件缓存等源码

    iOS精选源码 一款优秀的 聆听夜空FM 源码 zhPopupController 简单快捷弹出自定义视图 WHStoryMaker搭建美图(贴纸,涂鸦,文字,滤镜) iOS cell高度自适应 有加 ...

  3. 量化预测质量之分类报告 sklearn.metrics.classification_report

    classification_report的调用为:classification_report(y_true, y_pred, labels=None, target_names=None, samp ...

  4. 83)PHP,配置文件功能

    首选配置文件应该在  我们的应用application目录中,这样针对每一应用,都有自己的配置文件. 我觉得配置文件的名字很有意思,首先是  名字.config.php 格式就是 return arr ...

  5. C++类和对象到底是什么意思?

    C++ 是一门面向对象的编程语言,理解 C++,首先要理解类(Class)和对象(Object)这两个概念. C++ 中的类(Class)可以看做C语言中结构体(Struct)的升级版.结构体是一种构 ...

  6. 微信小程序开发-易源API的调用

    起因:在开发一款旅游类微信小程序时,需要接入大量的景点信息,此时可以选择自己新建数据库导入数据并读取,但是对于我来说,因为只有一个人,数据库还涉及到需要维护方面,选择调用已有API. 过程:首先查阅微 ...

  7. python使用geopandas和shapely处理shp文件

    一.环境搭建 所需库:geopandas (以及前置库)  doc:http://geopandas.org/ shapely(以及前置库)  doc: 二.数据预处理 1.将shp文件进行切片 2. ...

  8. jquery框架概览(一)

    参照jQuery 2.0.3版本(http://files.cnblogs.com/files/snoy/jquery-2.0.3.js")来进行的源码分析 从代码的最外层可以看到是一个II ...

  9. sshd启动故障“Failed to start OpenSSH Server daemon ”解决方法

  10. 性能分析之工具使用——cpu、io 、mem【工具分析】

    nmon nmon 是一种在aix 与各种 Linux 操作系统上广泛使 用的监控与与分析工具,他主要记录以下内容: • cpu 占用率 • 内存使用情况 • 磁盘I/O 速度.传输和读写比率 • 文 ...