[深度学习] tf.keras入门4-过拟合和欠拟合
过拟合和欠拟合
简单来说过拟合就是模型训练集精度高,测试集训练精度低;欠拟合则是模型训练集和测试集训练精度都低。
官方文档地址为 https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit
过拟合和欠拟合
以IMDB dataset为例,对于过拟合和欠拟合,不同模型的测试集和验证集损失函数图如下:
baseline模型结构为:10000-16-16-1
smaller_model模型结构为:10000-4-4-1
bigger_model模型结构为:10000-512-512-1
造成过拟合的原因通常是参数过多或者数据较少,欠拟合往往是训练次数不够。
解决方法
正则化
正则化简单来说就是稀疏化参数,使得模型参数较少。类似于降维。
正则化参考: https://blog.csdn.net/jinping_shi/article/details/52433975
tf.keras通常在损失函数后添加正则项,l1正则化和l2正则化。
l2_model = keras.models.Sequential([
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu),
keras.layers.Dense(1, activation=tf.nn.sigmoid)
])
l2_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
l2_model_history = l2_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
dropout
Dropout将在训练过程中每次更新参数时按一定概率(rate)随机断开输入神经元,使得比例为rate的神经元不被训练。
具体见: https://yq.aliyun.com/articles/68901
dpt_model = keras.models.Sequential([
keras.layers.Dense(16, activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dropout(0.3), #百分之30的神经元失效
keras.layers.Dense(16, activation=tf.nn.relu),
keras.layers.Dropout(0.7), #百分之70的神经元失效
keras.layers.Dense(1, activation=tf.nn.sigmoid)
])
dpt_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy','binary_crossentropy'])
dpt_model_history = dpt_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
总结
常用防止过拟合的方法有:
- 增加数据量
- 减少网络结构参数
- 正则化
- dropout
- 数据扩增data-augmentation
- 批标准化
[深度学习] tf.keras入门4-过拟合和欠拟合的更多相关文章
- [深度学习] tf.keras入门3-回归
目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...
- [深度学习] tf.keras入门5-模型保存和载入
目录 设置 基于checkpoints的模型保存 通过ModelCheckpoint模块来自动保存数据 手动保存权重 整个模型保存 总体代码 模型可以在训练中或者训练完成后保存.具体文档参考:http ...
- [深度学习] tf.keras入门2-分类
目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...
- [深度学习] tf.keras入门1-基本函数介绍
目录 构建一个简单的模型 序贯(Sequential)模型 网络层的构造 模型训练和参数评价 模型训练 模型的训练 tf.data的数据集 模型评估和预测 基本模型的建立 网络层模型 模型子类函数构建 ...
- 深度学习:Keras入门(一)之基础篇
1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...
- 深度学习:Keras入门(一)之基础篇【转】
本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...
- 深度学习:Keras入门(一)之基础篇(转)
转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...
- 深度学习:Keras入门(二)之卷积神经网络(CNN)
说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...
- 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】
本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...
随机推荐
- POJ3280 Cheapest Palindrome (区间DP)
dp[i][j]表示将字符串子区间[i,j]转化为回文字符串的最小成本. 1 #include<cstdio> 2 #include<algorithm> 3 #include ...
- FluentValidation 验证(二):WebApi 中使用 注入服务
比如你要验证用户的时候判断一下这个用户名称在数据库是否已经存在了,这时候FluentValidation 就需要注入查询数据库 只需要注入一下就可以了 public class Login3Reque ...
- TomCat之安装
TomCat 之安装(伪分布式版本) 本次安装是使用的伪分布式的安装(即一台机器安装两个tomcat) 1.通过scp导入tomcat安装包 2.解压缩成俩个文件 3.修改第一个tomcat的配置文件 ...
- SpringBoot 2.5.5整合轻量级的分布式日志标记追踪神器TLog
TLog能解决什么痛点 随着微服务盛行,很多公司都把系统按照业务边界拆成了很多微服务,在排错查日志的时候.因为业务链路贯穿着很多微服务节点,导致定位某个请求的日志以及上下游业务的日志会变得有些困难. ...
- react.js+easyui 做一个简单的商品表
效果图: import React from 'react'; import { Form, FormField, Layout,DataList,LayoutPanel,Panel, Lab ...
- 『现学现忘』Git基础 — 37、标签tag(二)
目录 5.共享标签 6.删除标签 7.修改标签指定提交的代码 8.标签在.git目录中的位置 9.本文中所使用到的命令 提示:接上一篇文章内容. 5.共享标签 默认情况下,git push 命令并不会 ...
- golang开发一个简单的grpc
0.1.索引 https://waterflow.link/articles/1665674508275 1.什么是grpc 在 gRPC 中,客户端应用程序可以直接调用不同机器上的服务器应用程序上的 ...
- 【单元测试】Junit 4(二)--eclipse配置Junit+Junit基础注解
1.0 前言 前面我们介绍了白盒测试方法,后面我们来介绍一下Junit 4,使用的是eclipse(用IDEA的小伙伴可以撤了) 1.1 配置Junit 4 1.1.1 安装包 我们需要三个jar ...
- vue3 页面跳转
需要引入 useRouter import {useRouter} from "vue-router"; 然后声明对象 代码如下 export default { setup() ...
- uniapp/微信小程序 项目day03
一.商品列表 1.1 获取数据 首先能够进入商品列表的途径 传的数据有 了解了这个之后就可以开始了,先创建分支 创建编译模式,并分配初试数据 这个时候就可以获取数据了 需要的数据 所以在发起请求之前需 ...