深度学习模型的调优,首先需要对各方面进行评估,主要包括定义函数、模型在训练集和测试集拟合效果、交叉验证、激活函数和优化算法的选择等。

那如何对我们自己的模型进行判断呢?——通过模型训练跑代码,我们可以分别从训练集和测试集上看到这个模型造成的损失大小(loss),还有它的精确率(accuracy)。

目录

  前言

  1、定义模型函数

  2、交叉验证(Cross-validation)

  3、优化算法

  4、激活函数(activation)

  5、dropout

  6、early stopping

  模型训练实战案例

前言

  1. 最初可从分析数据集的特征,选择适合的函数以及优化器。
  2. 其次,模型在训练集上的效果。至少模型在训练集上得到理想的效果,模型优化方向:激活函数(activation function)、学习率(learning rate)。
  3. 再者,模型在训练集上得到好的效果之后,对测试集上的拟合程度更为重要,因为我们训练之后目的就是去预测并达到好的结果。模型优化方向:正则化(regularization)、丢弃参数(dropout)、提前停止训练(early stopping)。

以下介绍具体的方法和过程,图文结合,一目了然,这样方便学习记忆!

1、定义模型函数

衡量方法通常用到真实f与预测f*的方差(variance)和偏差(bias),最后对数据集的拟合程度可分为4种情况。如下:

bias/variance low high
low 模型过于复杂
high 模型过于简单 ×

这好比数据模型预测对准靶心后的偏移和分散程度,期望达到的效果就是(low,low),模型的偏差相当于与目标的偏离程度,而方差就是数据之间的分散程度。

定义一个模型函数后结果会遇到(low,high)和(high,low),显然,如果是(high,high)说明函数完全与模型不匹配,可重新考虑其他函数模型。那出现其他情况如何评估这个函数呢?

  1. 小偏差,大方差:所谓模型过拟合,在训练数据上得到好的效果,而在测试集上效果并不满意。优化方向:增加数据量;数据正则化处理。增加数据可以提高模型的鲁棒性,不被特殊数据影响整个模型的偏向;正则化是另一种方法,为了减小variance,但直接影响到bias,需要权衡。
  2. 大偏差,小方差:模型过于简单,在训练上没有得到好的效果。优化:增加模型参数(特征),更好去拟合数据。

2、交叉验证(Cross-validation)

在训练一个模型时候,通常会将数据分为:训练集,测试集,开放集(小型训练集)。这里交叉验证是在训练集上进行的,选择最优模型。

假设一个划分为:

training-set testing-set
90% 10%

这里进行十折交叉验证,每一轮训练去9份数据作为训练集,1份作为测试集,并且每一轮的训练集与测试集对换,实现了所有数据都作为样本训练,所得到的模型避免了过拟合与低拟

合的问题。

K折交叉验证类似。

优点:

(1)每一轮训练中几乎所有的样本数据用于训练模型,这样最接近原始样本的分布,评估所得的结果比较可靠。

(2)训练模型过程中较少随机特殊因素会影响实验数据,鲁棒性更好。

3、优化算法

模型优化算法选择根据要训练的模型以及数据,选择合适的算法,常用优化算法有:Gradient descent,Stochastic gradient descent,Adagrad,Adam,RMSprop。

前两种算法原理好理解,这里给出后面三种的算法原理如图,具体就不介绍了,写代码时候通常直接指定算法就行,算法原理理解就OK啦。

Adagrad

RMSprop

Adam

4、激活函数(activation)

选择合适的激活函数对模型训练也有很大的影响,需要适应训练任务,比如任务是二分类、多分类或许更新参数梯度问题等。常用的有:sigmoid、tanh、relu、softmax

sigmoid

功能特点:平滑函数,连续可导,适合二分类,存在梯度消失问题。

tanh

功能特点:与sigmoid相同的缺点,存在梯度消失,梯度下降的速度变慢。一般用在二分问题输出层,不在隐藏层中使用。

relu

功能特点:ReLU在神经网络中使用最广泛的激活函数。根据图像x负半轴的特点,relu不会同时激活所有的神经元,如果输入值是负的,ReLU函数会转换为0,而神经元不被激活。这意味着,在一段时间内,只有少量的神经元被激活,神经网络的这种稀疏性使其变得高效且易于计算。

softmax

功能特点:又称归一化指数函数。它是二分类函数sigmoid在多分类上的推广,目的是将多分类的结果以概率的形式展现出来。softmax通常在分类器的输出层使用。

在模型训练任务中激活函数通过指定选择,我们理解原理后帮助选择正确的函数,具体无需自己实现。

5、dropout

它解决深度神经网络的过拟合(overfitting)梯度消失(gradient vanishing)问题,简单理解就是,当模型在训练集上得到较好的效果,而在测试集效果并不乐观,此时使用dropout对训练时候的参数进行优化调整(减少训练时候的参数),在学习过程中通过将隐含层的部分权重或输出随机归零,降低节点间的相互依赖性,使得模型在测试集上得到较好的结果的一种方法。相当于运用新的神经网络结构训练模型,在模型训练时候可在每一层指定设计。

6、early stopping

在训练时候观察模型在training-set和testing-set上的损失(loss),我们想要得到的模型是在测试时候损失误差更小,训练时候在训练集上可能不是最好的效果,所以需要提前停止保证了模型预测得到较好的结果。理解如图:

 模型训练实战案例

(X_train, Y_train), (X_test, Y_test) = load_data()
model = keras.Sequential()
model.add(Dense(input_dim=28 * 28, units=690,
activation='relu')) # tanh activation:Sigmoid、tanh、ReLU、LeakyReLU、pReLU、ELU、maxout、softmax
model.add(Dense(units=690, activation='relu'))
model.add(Dense(units=690, activation='relu')) # tanh model.add(Dense(units=10, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam',
metrics=['accuracy']) # loss:mse,categorical_crossentropy,optimizer: rmsprop 或 adagrad、SGD(此处推荐)
model.fit(X_train, Y_train, batch_size=100, epochs=20)
result = model.evaluate(X_train, Y_train, batch_size=10000)
print('Train ACC:', result[1])
result = model.evaluate(X_test, Y_test, batch_size=10000)
print('Test ACC:', result[1])

案例中深度学习模型调优总结:

前三层sigmoid,输出层softmax,损失函数categorical_crossentropy,优化器SGD,训练集上达到86%
前三层换为relu,隐藏层加10层,其他不变,训练集上达到99.97%,测试集上95.9%,不加10层,Train-acc:100%,Test-acc:95.54%
优化器adam,训练速度提高,99.98%和96.7%
上面模型调优的理论方法,理解之后进行实战演练,可参考上一篇文章《mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)》,自行根据理论方法对模型进行调优,体验一下简单深度学习模型训练和模型调优的过程!

我的博客园:https://www.cnblogs.com/chenzhenhong/p/13437132.html

深度学习模型调优方法(Deep Learning学习记录)的更多相关文章

  1. 深度学习笔记之关于总结、展望、参考文献和Deep Learning学习资源(五)

    不多说,直接上干货! 十.总结与展望 1)Deep learning总结 深度学习是关于自动学习要建模的数据的潜在(隐含)分布的多层(复杂)表达的算法.换句话来说,深度学习算法自动的提取分类需要的低层 ...

  2. 【deep learning学习笔记】注释yusugomori的DA代码 --- dA.h

    DA就是“Denoising Autoencoders”的缩写.继续给yusugomori做注释,边注释边学习.看了一些DA的材料,基本上都在前面“转载”了.学习中间总有个疑问:DA和RBM到底啥区别 ...

  3. (转)分布式深度学习系统构建 简介 Distributed Deep Learning

    HOME ABOUT CONTACT SUBSCRIBE VIA RSS   DEEP LEARNING FOR ENTERPRISE Distributed Deep Learning, Part ...

  4. 英特尔深度学习框架BigDL——a distributed deep learning library for Apache Spark

    BigDL: Distributed Deep Learning on Apache Spark What is BigDL? BigDL is a distributed deep learning ...

  5. [置顶] Deep Learning 学习笔记

    一.文章来由 好久没写原创博客了,一直处于学习新知识的阶段.来新加坡也有一个星期,搞定签证.入学等杂事之后,今天上午与导师确定了接下来的研究任务,我平时基本也是把博客当作联机版的云笔记~~如果有写的不 ...

  6. 【deep learning学习笔记】Recommending music on Spotify with deep learning

    主要内容: Spotify是个类似酷我音乐的音乐站点.做个性化音乐推荐和音乐消费.作者利用deep learning结合协同过滤来做音乐推荐. 详细内容: 1. 协同过滤 基本原理:某两个用户听的歌曲 ...

  7. paper 149:Deep Learning 学习笔记(一)

     1. 直接上手篇 台湾李宏毅教授写的,<1天搞懂深度学习> slideshare的链接: http://www.slideshare.net/tw_dsconf/ss-62245351? ...

  8. 【deep learning学习笔记】注释yusugomori的RBM代码 --- 头文件

    百度了半天yusugomori,也不知道他是谁.不过这位老兄写了deep learning的代码,包括RBM.逻辑回归.DBN.autoencoder等,实现语言包括c.c++.java.python ...

  9. Neural Networks and Deep Learning学习笔记ch1 - 神经网络

    近期開始看一些深度学习的资料.想学习一下深度学习的基础知识.找到了一个比較好的tutorial,Neural Networks and Deep Learning,认真看完了之后觉得收获还是非常多的. ...

随机推荐

  1. Java多线程编程基础知识汇总

    多线程简介 多任务   现代操作系统(Windows.Linux.MacOS)都可以执行多任务,多任务就是同时运行多个任务.例如在我们的计算机上,一般都同时跑着多个程序,例如浏览器,视频播放器,音乐播 ...

  2. Python 之 Restful API设计规范

    理解RESTful架构 Restful API设计指南 理解RESTful架构 越来越多的人开始意识到,网站即软件,而且是一种新型的软件. 这种"互联网软件"采用客户端/服务器模式 ...

  3. hihoCoder 1037 数字三角形 最详细的解题报告

    题目来源:hihoCoder 1037 数字三角形 解题思路:请好好看看 提示一.提示二.提示三 具体算法(java版,可以直接AC) import java.util.Scanner; public ...

  4. POJ 1057 File Mapping 最详细的解题报告

    题目来源:POJ 1057 File Mapping 题目大意:像我的电脑那样显示文件夹和文件信息,其中在同一级目录内,文件夹排在文件的前面并且文件夹的顺序不变,同一级目录中文件按字母序排列.文件以‘ ...

  5. mysql中DDL库和表的管理

    #DDL /* 数据定义语言 库和表的管理 一.库的管理 创建.修改.删除 二.表的管理 创建.修改.删除 创建:create 修改:alter 删除:drop */ #一.库的管理 #1.库的创建 ...

  6. Elasticsearch从入门到放弃:再聊搜索

    在前文中我们曾经聊过搜索文档的方法,Elasticsearch 一般适用于读多写少的场景,因此我们需要更多的关注读操作. Elasticsearch 提供的 Search API 可以分为 URI S ...

  7. oracle终止用户会话

    1.创建两个测试用户进行实验 执行命令如下: create user test1 identified by 1; create user test2 identified by 1; grant d ...

  8. 010.Nginx正反代理

    一 Nginx代理 1.1 Nginx代理概述 nginx是一款自由的.开源的.高性能的HTTP服务器和反向代理服务器.同时也是一个IMAP.POP3.SMTP代理服务器.nginx可以作为一个HTT ...

  9. Linux内核功能介绍及如何使用保护您的网页安全

    在本文中,我们快速浏览了Linux内核的许可流程,并向您展示了如何使用它们来保护您的网页或应用安全 传统上,Linux内核通过以下两类来区分其进程: 特权进程:这些进程使用户可以绕过所有内核权限检查. ...

  10. java8之Stream流处理

    简介 Stream 流处理,首先要澄清的是 java8 中的 Stream 与 I/O 流 InputStream 和 OutputStream 是完全不同的概念. Stream 机制是针对集合迭代器 ...