Fine-Tuning微调原理

如何在只有60000张图片的Fashion-MNIST训练数据集中训练模型。ImageNet,这是学术界使用最广泛的大型图像数据集,它拥有1000多万幅图像和1000多个类别的对象。然而,我们经常处理的数据集的大小通常比第一个大,但比第二个小。

假设我们想在图像中识别不同种类的椅子,然后将购买链接推给用户。一种可行的方法是先找到一百张常见的椅子,每把椅子取一千张不同角度的图像,然后在采集到的图像数据集上训练分类模型。虽然这个数据集可能比时尚MNIST大,但是示例的数量仍然不到ImageNet的十分之一。这可能导致适用于ImageNet的复杂模型在此数据集上过度拟合。同时,由于数据量有限,最终训练出的模型精度可能达不到实际要求。

为了解决上述问题,一个显而易见的解决办法就是收集更多的数据。然而,收集和标记数据会消耗大量的时间和金钱。例如,为了收集ImageNet的数据集,研究人员花费了数百万美元的研究经费。尽管近年来,数据采集成本大幅下降,但成本仍然不容忽视。

另一种解决方案是应用转移学习将从源数据集学习的知识迁移到目标数据集。例如,虽然ImageNet中的图像大多与椅子无关,但是在这个数据集上训练的模型可以提取更一般的图像特征,这些特征可以帮助识别边缘、纹理、形状和对象组成。这些相似的特征对于识别椅子同样有效。

在本节中,我们将介绍迁移学习中的一种常用技术:微调。如图13.2.1所示,微调包括以下四个步骤:

在源数据集(例如ImageNet数据集)上预训练神经网络模型,即源模型。

建立一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数,输出层除外。我们假设这些模型参数包含从源数据集学习到的知识,这些知识将同样适用于目标数据集。我们还假设源模型的输出层与源数据集的标签密切相关,因此不在目标模型中使用。

将输出大小为目标数据集类别数的输出层添加到目标模型中,并随机初始化该层的模型参数。

在目标数据集上训练目标模型,例如椅子数据集。我们将从头开始训练输出层,同时根据源模型的参数对所有剩余层的参数进行微调。

Fig. 1.  Fine tuning.

1. Hot Dog Recognition

我们将使用一个具体的例子来练习:热狗识别。我们将基于一个小的数据集,对在ImageNet数据集上训练的ResNet模型进行微调。这个小数据集包含数千张图像,其中一些包含热狗。我们将使用通过微调获得的模型来识别图像是否包含热狗。

首先,导入实验所需的软件包和模块。Gluon的model_zoo package提供了一个通用的预训练模型。如果你想获得更多的计算机视觉的预先训练模型,你可以使用GluonCV工具箱。

%matplotlib inline

from d2l import mxnet as d2l

from mxnet import gluon, init, np, npx

from mxnet.gluon import nn

import os

npx.set_np()

1.1. Obtaining the Dataset

我们使用的热狗数据集来自在线图像,包含1400个热狗的正面图片和其他食物的相同数量的负面图片。1000个各种课程的图像用于训练,其余的用于测试。

我们首先下载压缩数据集,得到两个文件夹hotdog/train和hotdog/test。这两个文件夹都有hotdog和not hotdog类别子文件夹,每个子文件夹都有相应的图像文件。

#@save

d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL+'hotdog.zip',

'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')

Downloading ../data/hotdog.zip from http://d2l-data.s3-accelerate.amazonaws.com/hotdog.zip...

我们创建两个ImageFolderDataset实例,分别读取训练数据集和测试数据集中的所有图像文件。

train_imgs = gluon.data.vision.ImageFolderDataset(

os.path.join(data_dir, 'train'))

test_imgs = gluon.data.vision.ImageFolderDataset(

os.path.join(data_dir, 'test'))

前8个正面示例和最后8个负面图像如下所示。如您所见,图像的大小和纵横比各不相同。

hotdogs = [train_imgs[i][0] for i in range(8)]

not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]

d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

在训练过程中,我们首先从图像中裁剪出一个大小和纵横比随机的随机区域,然后将该区域缩放到一个高度和宽度为224像素的输入。在测试过程中,我们将图像的高度和宽度缩放到256像素,然后裁剪高宽为224像素的中心区域作为输入。此外,我们规范化三个RGB(红色、绿色和蓝色)颜色通道的值。从每个值中减去信道所有值的平均值,然后将结果除以信道所有值的标准差,以产生输出。

# We specify the mean and variance of the three RGB channels to normalize the

# image channel

normalize = gluon.data.vision.transforms.Normalize(

[0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

train_augs = gluon.data.vision.transforms.Compose([

gluon.data.vision.transforms.RandomResizedCrop(224),

gluon.data.vision.transforms.RandomFlipLeftRight(),

gluon.data.vision.transforms.ToTensor(),

normalize])

test_augs = gluon.data.vision.transforms.Compose([

gluon.data.vision.transforms.Resize(256),

gluon.data.vision.transforms.CenterCrop(224),

gluon.data.vision.transforms.ToTensor(),

normalize])

1.2. Defining and Initializing the Model

我们使用ResNet-18作为源模型,ResNet-18是在ImageNet数据集上预先训练的。这里,我们指定pretrained=True以自动下载和加载预先训练的模型参数。第一次使用时,需要从互联网上下载模型参数。

pretrained_net = gluon.model_zoo.vision.resnet18_v2(pretrained=True)

预先训练的源模型实例包含两个成员变量:features和output。前者包含模型的所有层,输出层除外,后者是模型的输出层。这一划分的主要目的是促进除输出层之外的所有层的模型参数的微调。源模型的成员变量输出如下所示。作为一个完全连接的层,它将ResNet最终的全局平均池层输出转换为ImageNet数据集上的1000个类输出。

pretrained_net.output

Dense(512 -> 1000, linear)

然后构建一个新的神经网络作为目标模型。它的定义方式与预先训练的源模型相同,但最终输出数量等于目标数据集中的类别数。在下面的代码中,目标模型实例finetune_net的成员变量特征中的模型参数初始化为源模型对应层的模型参数。由于特征中的模型参数是通过对ImageNet数据集的预训练得到的,所以它是足够好的。因此,我们通常只需要使用较小的学习速率来“微调”这些参数。相比之下,成员变量输出中的模型参数是随机初始化的,通常需要更大的学习速率才能从头开始学习。假设训练实例中的学习率为 η,学习率为10η,更新成员变量输出中的模型参数。

finetune_net = gluon.model_zoo.vision.resnet18_v2(classes=2)

finetune_net.features = pretrained_net.features

finetune_net.output.initialize(init.Xavier())

# The model parameters in output will be updated using a learning rate ten

# times greater

finetune_net.output.collect_params().setattr('lr_mult', 10)

1.3. Fine Tuning the Model

我们首先定义了一个训练函数train_fine_tuning,它使用了微调,因此可以多次调用它。

def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5):

train_iter = gluon.data.DataLoader(

train_imgs.transform_first(train_augs), batch_size, shuffle=True)

test_iter = gluon.data.DataLoader(

test_imgs.transform_first(test_augs), batch_size)

ctx = d2l.try_all_gpus()

net.collect_params().reset_ctx(ctx)

net.hybridize()

loss = gluon.loss.SoftmaxCrossEntropyLoss()

trainer = gluon.Trainer(net.collect_params(), 'sgd', {

'learning_rate': learning_rate, 'wd': 0.001})

d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, ctx)

我们将训练器实例中的学习率设置为一个较小的值,如0.01,以便对预训练中获得的模型参数进行微调。基于前面的设置,我们将使用10倍以上的学习率从头开始训练目标模型的输出层参数。

train_fine_tuning(finetune_net, 0.01)

loss 0.518, train acc 0.890, test acc 0.927

634.3 examples/sec on [gpu(0), gpu(1)]

为了进行比较,我们定义了一个相同的模型,但将其所有模型参数初始化为随机值。由于整个模型需要从头开始训练,所以我们可以使用更大的学习率。

scratch_net = gluon.model_zoo.vision.resnet18_v2(classes=2)

scratch_net.initialize(init=init.Xavier())

train_fine_tuning(scratch_net, 0.1)

loss 0.371, train acc 0.839, test acc 0.784

706.5 examples/sec on [gpu(0), gpu(1)]

正如您所看到的,由于参数的初始值更好,微调后的模型往往在同一时代获得更高的精度。

2. Summary

  • Transfer learning migrates the knowledge learned from the source dataset to the target dataset. Fine tuning is a common technique for transfer learning.
  • The target model replicates all model designs and their parameters on the source model, except the output layer, and fine-tunes these parameters based on the target dataset. In contrast, the output layer of the target model needs to be trained from scratch.
  • Generally, fine tuning parameters use a smaller learning rate, while training the output layer from scratch can use a larger learning rate.

Fine-Tuning微调原理的更多相关文章

  1. L23模型微调fine tuning

    resnet185352 链接:https://pan.baidu.com/s/1EZs9XVUjUf1MzaKYbJlcSA 提取码:axd1 9.2 微调 在前面的一些章节中,我们介绍了如何在只有 ...

  2. (原)caffe中fine tuning及使用snapshot时的sh命令

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/5946041.html 参考网址: http://caffe.berkeleyvision.org/tu ...

  3. Fine Tuning

    (转载自:WikiPedia) Fine tuning is a process to take a network model that has already been trained for a ...

  4. DL开源框架Caffe | 模型微调 (finetune)的场景、问题、技巧以及解决方案

    转自:http://blog.csdn.net/u010402786/article/details/70141261 前言 什么是模型的微调?   使用别人训练好的网络模型进行训练,前提是必须和别人 ...

  5. FineTuning机制的分析

    FineTuning机制的分析 为什么用FineTuning 使用别人训练好的网络模型进行训练,前提是必须和别人用同一个网络,因为参数是根据网络而来的.当然最后一层是可以修改的,因为我们的数据可能并没 ...

  6. [转载]关于Pretrain、Fine-tuning

    [转载]关于Pretrain.Fine-tuning 这两种tricks的意思其实就是字面意思,pre-train(预训练)和fine -tuning(微调) 来源:https://blog.csdn ...

  7. 【原创】TextCNN原理详解(一)

    ​ 最近一直在研究textCNN算法,准备写一个系列,每周更新一篇,大致包括以下内容: TextCNN基本原理和优劣势 TextCNN代码详解(附Github链接) TextCNN模型实践迭代经验总结 ...

  8. (原)torch中微调某层参数

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6221664.html 参考网址: https://github.com/torch/nn/issues ...

  9. TorchVision Faster R-CNN 微调,实战 Kaggle 小麦检测

    本文将利用 TorchVision Faster R-CNN 预训练模型,于 Kaggle: 全球小麦检测 上实践迁移学习中的一种常用技术:微调(fine tuning). 本文相关的 Kaggle ...

随机推荐

  1. Android Hook框架adbi的分析(2)--- inline Hook的实现

    本文博客地址:http://blog.csdn.net/qq1084283172/article/details/74452308 一. Android Hook框架adbi源码中inline Hoo ...

  2. SpringBoot JPA + 分页 + 单元测试SpringBoot JPA条件查询

    application.properties 新增数据库链接必须的参数 spring.jpa.properties.hibernate.hbm2ddl.auto=update 表示会自动更新表结构,所 ...

  3. vmware vpshere 安装完的必备工作

    1:例如:vCenter计算机地址为:192.168.0.200, 访问地址:https://192.168.0.200,安装证书: 参考教程:https://blog.csdn.net/cooljs ...

  4. 中文NER的那些事儿2. 多任务,对抗迁移学习详解&代码实现

    第一章我们简单了解了NER任务和基线模型Bert-Bilstm-CRF基线模型详解&代码实现,这一章按解决问题的方法来划分,我们聊聊多任务学习,和对抗迁移学习是如何优化实体识别中边界模糊,垂直 ...

  5. C++基础——文件逐行读取与字符匹配

    技术背景 用惯了python,对其他语言就比较的生疏.但是python很多时候在性能上比较受局限,这里尝试通过C++来实现一个文件IO的功能,看看是否能够比python的表现更好一些.关于python ...

  6. 技能Get·解决MSSQL Where查询中文数据存在但查不出来

    阅文时长 | 0.33分钟 字数统计 | 294.4字符 主要内容 | 1.引言&背景 2.声明与参考资料 『技能Get·解决MSSQL Where查询中文数据存在但查不出来』 编写人 | S ...

  7. java中基本数据类型、包装类及字符串之间的相互转换

    基本数据类型:不支持面向对象的编程机制(没有属性和方法),即不支持面向对象,之所以提供8中基本数据类型,是为了方便常规数据的处理. 包装类:通过包装类可以将基本数据类型的值包装为引用数据类型的对象,使 ...

  8. [bug] Mysql 对实体 "characterEncoding" 的引用必须以 ';' 分隔符结尾。

    参考 https://blog.csdn.net/cherrycheng_/article/details/51251441?

  9. [DB] 大数据集群安装

    学习要点 体系架构.原理 多做练习.试验 装虚拟机 网络模式:仅主机模式 software selection:development tools, GUI network & host na ...

  10. Docker------Linux安装Docker

    1.添加yum源 yum install epel-release –y yum clean all yum list 2.安装并运行Docker yum install docker-io –y s ...