使用VGG16网络进行迁移学习

使用在ImageNet数据上预训练的VGG16网络模型对猫狗数据集进行分类识别。

1.预训练网络

预训练网络是一个保存好的,已经在大型数据集上训练好的卷积神经网络。

如果这个数据集足够大且通用,那么预训练网络学习到的模型参数可以有效的对图片进行特征提取。即使新问题与原本的数据完全不同,但学习到的特征提取方法依然可以在不同的问题之间进行移植,进而可以在全新的数据集上提取到有效的特征。对这些有效的高级特征进行分类可以大大提高模型分类的准确率。

迁移学习主要适用于已有数据相对较少的情况,如果拥有的数据量足够大,即使不需要迁移学习也能够得到非常高的准确率。

2.如何使用与训练网络

2.1载入图像并创建数据集

首先,读入猫狗数据集中的图片。(实现过程的详细说明在Tensorflow学习笔记No.5中,这里不再赘述)

  1. 1 import tensorflow as tf
  2. 2 import numpy as np
  3. 3 import pandas as pd
  4. 4 import matplotlib.pyplot as plt
  5. 5 %matplotlib inline
  6. 6 import pathlib
  7. 7 import random
  8. 8
  9. 9 data_root = pathlib.Path('../input/cat-and-dog/training_set/training_set')
  10. 10
  11. 11 all_image_path = list(data_root.glob('*/*.jpg'))
  12. 12 random.shuffle(all_image_path)
  13. 13 image_count = len(all_image_path)
  14. 14
  15. 15 label_name = sorted([item.name for item in data_root.glob('*')])
  16. 16 name_to_indx = dict((name, indx) for indx, name in enumerate(label_name))
  17. 17
  18. 18 all_image_path = [str(path) for path in all_image_path]
  19. 19 all_image_label = [name_to_indx[pathlib.Path(p).parent.name] for p in all_image_path]
  20. 20
  21. 21 def load_pregrosess_image(path, label):
  22. 22 image = tf.io.read_file(path)
  23. 23 image = tf.image.decode_jpeg(image, channels = 3)
  24. 24 image = tf.image.resize(image, [256, 256])
  25. 25 image = tf.cast(image, tf.float32)
  26. 26 image = image / 255
  27. 27 return image, label
  28. 28
  29. 29 train_image_ds = tf.data.Dataset.from_tensor_slices((all_image_path, all_image_label))
  30. 30
  31. 31 AUTOTUNE = tf.data.experimental.AUTOTUNE
  32. 32 dataset = train_image_ds.map(load_pregrosess_image, num_parallel_calls = AUTOTUNE)
  33. 33
  34. 34 BATCHSIZE = 16
  35. 35 train_count = int(image_count * 0.8)
  36. 36 test_count = image_count - train_count
  37. 37
  38. 38 train_dataset = dataset.take(train_count)
  39. 39 test_dataset = dataset.skip(train_count)
  40. 40
  41. 41 train_dataset = train_dataset.shuffle(train_count).repeat().batch(BATCHSIZE)
  42. 42 test_dataset = test_dataset.repeat().batch(BATCHSIZE)

2.2加载与训练网络并构建网络模型

与训练的网络由两个部分构成,训练好的卷积基和训练好的分类器。我们需要使用训练好的卷积基来提取特征,并使用自定义的分类器对自己的数据集进行分类识别。

如下图所示:

训练过程中,我们仅仅对自定义的分类器进行训练,而不训练预训练好的卷积基部分。

预训练的卷积基可以非常好的提取图像的某些特征,在训练过程中,由于分类器是一个全新的没有训练过的分类器,在训练初期会产生很大的loss值,由于数据量较少,如果不对预训练的卷积基进行冻结(不更新参数)处理,产生的loss值经梯度传递会对预训练的卷积基造成非常大的影响,且由于可训练数据较少儿难以恢复,所以只对自定义的分类器进行训练,而不训练卷积基。

首先从tf.keras.applications中创建一个预训练VGG16的卷积基。

  1. 1 cov_base = tf.keras.applications.VGG16(weights = 'imagenet', include_top = False)

weight是我们要使用的模型权重,我们使用经imagenet训练过的模型的权重信息进行迁移学习。

include_top是指,是否使用预训练的分类器。在迁移学习过程中我们使用自定义的分类器,所以参数为False。

然后我们对创建好的卷积基进行冻结处理,冻结所有的可训练参数。

  1. 1 cov_base.trainable = False

使用keras.Sequential()创建网络模型。

  1. 1 model = tf.keras.Sequential()
  2. 2 model.add(cov_base)
  3. 3 model.add(tf.keras.layers.GlobalAveragePooling2D())
  4. 4 model.add(tf.keras.layers.Dense(512, activation = 'relu'))
  5. 5 model.add(tf.keras.layers.Dense(1, activation = 'sigmoid'))

在模型中加入卷积基和自定义的分类器。

模型结构如下图所示:

我们得到了一个可训练参数仅为263,169的预训练VGG16网络模型。

2.3使用自定义数据训练分类器

此时模型已经搭建完毕,我们使用之前处理好的数据对它进行训练。

  1. 1 model.compile(optimizer = 'adam',
  2. 2 loss = 'binary_crossentropy',
  3. 3 metrics = ['acc']
  4. 4 )
  5. 5
  6. 6 history = model.fit(train_dataset,
  7. 7 steps_per_epoch = train_count // BATCHSIZE,
  8. 8 epochs = 10,
  9. 9 validation_data = test_dataset,
  10. 10 validation_steps = test_count // BATCHSIZE
  11. 11 )
  12. 12
  13. 13 plt.plot(history.epoch, history.history.get('acc'), label = 'acc')
  14. 14 plt.plot(history.epoch, history.history.get('val_acc'), label = 'acc')

训练结果如下图所示:

模型在训练集和测试机上的正确率均达到了94%左右,而且仅仅经过了10个epoch就达到了这样的效果,足以看出迁移学习在小规模数据上的优势。

3.微调

虽然使用预训练网络可以轻易的达到94%左右的正确率,但是,如果我们还想继续提高这个正确率该怎样进行调整呢?

所谓微调,是冻结卷积基底部的卷积层,共同训练新添加的分类器和卷积基顶部的部分卷积层。

根据卷积神经网络提取特征的原理我们不难发现,越底层的卷积层提取到的图像特征越抽象越细小,而顶层的卷积层提取到的特征更大,更加的接近我们能直接观察到的数据特征,由于我们需要训练的数据和预训练时使用的数据不尽相同,所以越顶层的卷积层提取到的特征与我们所需要的特征差别越大。所以,我们只冻结底部的卷积层,将顶部的卷积层与训练好的分类器共同训练,会得到更好的拟合效果。

只有分类器以及训练好了,才能微调卷积基的顶部卷积层,否则由于训练初期的误差很大,会将卷积层之前学习到的参数破坏掉。

所以我们对卷积基进行解冻,并只对底部的卷积进行冻结。

  1. 1 cov_base.trainable = True
  2. 2 for layers in cov_base.layers[:-3]:
  3. 3 layers.trainable = False

然后将模型继续进行训练。

  1. 1 model.compile(optimizer = tf.keras.optimizers.Adam(lr = 0.0001),
  2. 2 loss = 'binary_crossentropy',
  3. 3 metrics = ['acc']
  4. 4 )
  5. 5
  6. 6 history = model.fit(train_dataset,
  7. 7 steps_per_epoch = train_count // BATCHSIZE,
  8. 8 epochs = 20,
  9. 9 initial_epoch = 10,
  10. 10 validation_data = test_dataset,
  11. 11 validation_steps = test_count // BATCHSIZE
  12. 12 )
  13. 13
  14. 14 plt.plot(history.epoch, history.history.get('acc'), label = 'acc')
  15. 15 plt.plot(history.epoch, history.history.get('val_acc'), label = 'acc')

注意将学习率调小,以便尽可能的达到loss的极小值点。

得到的结果如下图所示:

模型再训练集上达到了近乎100%的准确率,在测试集上也达到了96%左右准确率,微调的效果还是较为明显的。

那么关于迁移学习的介绍到这里就结束了o(* ̄▽ ̄*)o,后续会更新更多内容。

Tensorflow学习笔记No.8的更多相关文章

  1. Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

    简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...

  2. Tensorflow学习笔记2019.01.22

    tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...

  3. Tensorflow学习笔记2019.01.03

    tensorflow学习笔记: 3.2 Tensorflow中定义数据流图 张量知识矩阵的一个超集. 超集:如果一个集合S2中的每一个元素都在集合S1中,且集合S1中可能包含S2中没有的元素,则集合S ...

  4. TensorFlow学习笔记之--[compute_gradients和apply_gradients原理浅析]

    I optimizer.minimize(loss, var_list) 我们都知道,TensorFlow为我们提供了丰富的优化函数,例如GradientDescentOptimizer.这个方法会自 ...

  5. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  6. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

  7. tensorflow学习笔记(4)-学习率

    tensorflow学习笔记(4)-学习率 首先学习率如下图 所以在实际运用中我们会使用指数衰减的学习率 在tf中有这样一个函数 tf.train.exponential_decay(learning ...

  8. tensorflow学习笔记(3)前置数学知识

    tensorflow学习笔记(3)前置数学知识 首先是神经元的模型 接下来是激励函数 神经网络的复杂度计算 层数:隐藏层+输出层 总参数=总的w+b 下图为2层 如下图 w为3*4+4个   b为4* ...

  9. tensorflow学习笔记(2)-反向传播

    tensorflow学习笔记(2)-反向传播 反向传播是为了训练模型参数,在所有参数上使用梯度下降,让NN模型在的损失函数最小 损失函数:学过机器学习logistic回归都知道损失函数-就是预测值和真 ...

  10. tensorflow学习笔记(1)-基本语法和前向传播

    tensorflow学习笔记(1) (1)tf中的图 图中就是一个计算图,一个计算过程.                                       图中的constant是个常量 计 ...

随机推荐

  1. RXJAVA之概述

    RXjava是一个异步和基于事件的程序库.RXjava的核心理念是编程风格的的变化,从传统的命令式程序改变到函数响应式编程. RXjava的基本概念: Observable:发射源,即对象产生的地方. ...

  2. golang 协程学习

    协程数据传递问题 func TestGoroutineData(t *testing.T) { var wg sync.WaitGroup wg.Add(1) i := 0 go func(j int ...

  3. HTML+CSS实现大盒子在小盒子的展示范围内进行滚动展示

    HTML+CSS实现大盒子在小盒子的展示范围内进行滚动展示 1.效果展示: 2.主要代码:样式: overflow:auto; 3.如果想要消除对应的滚动条: .out::-webkit-scroll ...

  4. OSI和TCP/IP参考模型

    分层思想: 分层模型是一种开发网络协议的设计方法. 把节点之间的通讯这个复杂的问题,分成了若干个简单的小问题逐一解决. 把网络相邻节点之间通过接口进行通信,下层为上层提供服务.当网络发生故障,很容易确 ...

  5. myisamchk是用来做什么的?MyISAM Static和MyISAM Dynamic有什么区别?

    myisamchk是用来做什么的? 它用来压缩MyISAM[歌1] 表,这减少了磁盘或内存使用. MyISAM Static和MyISAM Dynamic有什么区别? 在MyISAM Static上的 ...

  6. SCI-HUB打不开了?附SCIHUB最新下载方式

    写在前面: 今天给大家推荐一个文献下载工具包:飞鸟科研助手 www.flybird.cc输入flybird.cc同样可以访问,存书签不失联!强调下:flybird.cc 读研之前,在一家NGS生殖应用 ...

  7. Django_项目开始

    如何初始Django运行环境? 1. 安装python 2. 创建Django项目专用的虚拟环境 http://www.cnblogs.com/2bjiujiu/p/7365876.html 3.进入 ...

  8. Python numpy总结(3)——常用函数用法

    1,np.ceil(x, y) 限制元素范围,进一法,即向上取整. x 表示输入的数据  y float类型 表示每个元素的上限. a = np.array([-1.7, -1.5, -0.2, 0. ...

  9. http协议和chrome浏览器

    http协议和Chrome抓包工具 什么是http和https协议: HTTP协议:全称是HyperText Transfer Protocol,中文意思是超文本传输协议,是一种发布和接收HTML页面 ...

  10. 关于继承、封装、多态、抽象和接口(Java)

    1.继承:    通过扩展一个已有的类,并继承该类的属性和行为,来创建一个新的类.已有的称为父类,新的类称为子类(父类派生子类,子类继承父类). (1)继承的优点:①代码的可重用性: ②父类的属性的方 ...