目录

构建一个简单的模型

序贯(Sequential)模型

网络层的构造

模型训练和参数评价

模型训练

模型的训练

tf.data的数据集

模型评估和预测

基本模型的建立

网络层模型

模型子类函数构建

回调函数Callbacks

模型保存和载入

网络参数保存Weights only

配置参数保存Configuration only

完整模型保存


目前keras API 已经整合到 tensorflow最新版本1.9.0 中,在tensorflow中通过tf.keras就可以调用keras。

import tensorflow as tf
from tensorflow import keras

官方教程为:https://tensorflow.google.cn/guide/keras

tf.keras可以调用所有的keras编译代码,但是有两个限制:

  1. 版本问题,需要通过tf.keras.version确认版本。
  2. 模型保存问题,tf.keras默认使用 checkpoint format格式,而keras模型的保存格式HDF5需要借用函数save_format='h5'

构建一个简单的模型

序贯(Sequential)模型

序贯模型就是是多个网络层的线性堆叠,比如多层感知机,BP神经网络。

tf.keras构建一个简单的全连通网络(即多层感知器)代码如下:

#建立序贯模型
model = keras.Sequential()
#添加全连接层,节点数为64,激活函数为relu函数,dense表示标准的一维全连接层
model.add(keras.layers.Dense(64, activation='relu'))
#添加全连接层,节点数为64,激活函数为relu函数
model.add(keras.layers.Dense(64, activation='relu'))
#添加输出层,输出节点数为10
model.add(keras.layers.Dense(10, activation='softmax'))

其中激活函数详细信息见keras官方文档http://keras-cn.readthedocs.io/en/latest/other/activations/

网络层的构造

通常在tf.keras中,网络层的构造参数主要有以下几个:

  1. 激活函数activation function,默认是没有激活函数的。
  2. 参数初始化,默认通过正态分布初始化(Glorot uniform)
  3. 参数正则化,包括权值初始化和偏置的初始化。
#参数调整
#建立一个sigmoid层
layers.Dense(64, activation='sigmoid')
#或者
layers.Dense(64, activation=tf.sigmoid) #权重L1正则化
layers.Dense(64, kernel_regularizer=keras.regularizers.l1(0.01))
#偏置L2正则化
layers.Dense(64, bias_regularizer=keras.regularizers.l2(0.01)) #权重正交矩阵的随机数初始化
layers.Dense(64, kernel_initializer='orthogonal')
#偏置常数初始化
layers.Dense(64, bias_initializer=keras.initializers.constant(2.0))

模型训练和参数评价

模型训练

模型建立后,通过compile模块确定模型的训练参数(tf.keras.Model.compile)

tf.keras.Model.compile有三个主要参数:

  1. 优化器optimizer:通过tf.train模块调用优化器,可用的优化器类型见:http://keras-cn.readthedocs.io/en/latest/other/optimizers/
  2. 损失函数loss:通过tf.keras.losses模块调用损失函数,可用的损失函数类型见:http://keras-cn.readthedocs.io/en/latest/other/objectives/
  3. 模型评估方法metrics:通过tf.keras.metrics调用评估参数,可用的模型评估方法见:http://keras-cn.readthedocs.io/en/latest/other/metrics/

具体例子如下:

# 配置均方误差回归模型
model.compile(optimizer=tf.train.AdamOptimizer(0.01),
loss='mse', # 均方差
metrics=['mae']) # 平均绝对误差 # 配置分类模型
model.compile(optimizer=tf.train.RMSPropOptimizer(0.01),
loss=keras.losses.categorical_crossentropy, #多类的对数损失
metrics=[keras.metrics.categorical_accuracy]) #多分类问题,所有预测值上的平均正确率

模型的训练

对于小数据集,使用numpy数组,通过tf.keras.Model.fit模块来训练和评估模型。

import numpy as np
#输入数据(1000,32)
data = np.random.random((1000, 32))
#输入标签(1000,10)
labels = np.random.random((1000, 10))
#模型训练
model.fit(data, labels, epochs=10, batch_size=32)

tf.keras.Model.fit模块有三个重要的参数:

  1. 训练轮数epochs:epochs指的就是训练过程中数据将被训练多少轮,一个epoch指的是当一个完整的数据集通过了神经网络一次并且返回了一次。
  2. 批训练大小batch_size:基本上现在的梯度下降都是基于mini-batch的,即将一个完整数据分为batch_size个批次进行训练。详见http://keras-cn.readthedocs.io/en/latest/for_beginners/concepts/#epochs
  3. 验证集validation_data:通常一个模型训练,评估要有训练集,验证集和测试集。验证集就是模型调参时用来评估模型的数据集。

tf.data的数据集

对于大型数据集,常常通过tf.data模块来调用数据,详见https://tensorflow.google.cn/guide/datasets

# 数据实例化
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)
dataset = dataset.repeat() #模型训练,steps_per_epoch表示每次训练的数据大小类似与batch_size
model.fit(dataset, epochs=10, steps_per_epoch=30)

模型评估和预测

通过 tf.keras.Model.evaluate 和tf.keras.Model.predict可以实现模型的评估和预测。

model.evaluate(x, y, batch_size=32)
model.evaluate(dataset, steps=30) model.predict(x, batch_size=32)
model.predict(dataset, steps=30)

基本模型的建立

网络层模型

通过f.keras.Sequential 可以实现各种的复杂模型,如:

  1. 多输入模型;
  2. 多输出模型;
  3. 参数共享层模型;
  4. 残差网络模型。

具体例子如下:

#输入参数
inputs = keras.Input(shape=(32,)) #网络层的构建
x = keras.layers.Dense(64, activation='relu')(inputs)
x = keras.layers.Dense(64, activation='relu')(x)
#预测
predictions = keras.layers.Dense(10, activation='softmax')(x) #模型实例化
model = keras.Model(inputs=inputs, outputs=predictions) #模型构建
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy']) #模型训练
model.fit(data, labels, batch_size=32, epochs=5)

模型子类函数构建

通常通过tf.keras.Model构建模型结构, __init__方法初始化模型,call方法进行参数传递。如下所示:

class MyModel(keras.Model):
#模型结构确定
def __init__(self, num_classes=10):
super(MyModel, self).__init__(name='my_model')
self.num_classes = num_classes
#网络层的定义
self.dense_1 = keras.layers.Dense(32, activation='relu')
self.dense_2 = keras.layers.Dense(num_classes, activation='sigmoid')
#参数调用
def call(self, inputs):
#前向传播过程确定
x = self.dense_1(inputs)
return self.dense_2(x) def compute_output_shape(self, input_shape):
#输出参数确定
shape = tf.TensorShape(input_shape).as_list()
shape[-1] = self.num_classes
return tf.TensorShape(shape) #模型初始化
model = MyModel(num_classes=10) #模型构建
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy']) #模型训练
model.fit(data, labels, batch_size=32, epochs=5)

回调函数Callbacks

回调函数是一组在训练的特定阶段被调用的函数集,你可以使用回调函数来观察训练过程中网络内部的状态和统计信息。通过传递回调函数列表到模型fit()中,即可在给定的训练阶段调用该函数集中的函数。详见:http://keras-cn.readthedocs.io/en/latest/other/callbacks/。主要回调函数有:

  1. tf.keras.callbacks.ModelCheckpoint:模型保存
  2. tf.keras.callbacks.LearningRateScheduler:学习率调整
  3. tf.keras.callbacks.EarlyStopping:中断训练
  4. tf.keras.callbacks.TensorBoard:tensorboard的使用

模型保存和载入

tf.keras有两种模型保存方式

网络参数保存Weights only

#模型保存为tensorflow默认格式
model.save_weights('./my_model') #载入模型
model.load_weights('my_model') #模型保存为keras默认格式,包含其他优化参数
model.save_weights('my_model.h5', save_format='h5') #载入模型
model.load_weights('my_model.h5')

配置参数保存Configuration only

保存一个没有模型参数只有配置参数的模型, Keras支持 JSON和YAML序列化格式:

# 模型保存
json_string = model.to_json()
yaml_string = model.to_yaml()
#模型载入
fresh_model = keras.models.from_json(json_string)
fresh_model = keras.models.from_yaml(yaml_string)

完整模型保存

将原来模型所用信息进行保存:

#模型建立
model = keras.Sequential([
keras.layers.Dense(10, activation='softmax', input_shape=(32,)),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, targets, batch_size=32, epochs=5) #保存为keras格式文件
model.save('my_model.h5') # 模型载入
model = keras.models.load_model('my_model.h5')

[深度学习] tf.keras入门1-基本函数介绍的更多相关文章

  1. [深度学习] tf.keras入门2-分类

    目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...

  2. [深度学习] tf.keras入门4-过拟合和欠拟合

    过拟合和欠拟合 简单来说过拟合就是模型训练集精度高,测试集训练精度低:欠拟合则是模型训练集和测试集训练精度都低. 官方文档地址为 https://tensorflow.google.cn/tutori ...

  3. [深度学习] tf.keras入门5-模型保存和载入

    目录 设置 基于checkpoints的模型保存 通过ModelCheckpoint模块来自动保存数据 手动保存权重 整个模型保存 总体代码 模型可以在训练中或者训练完成后保存.具体文档参考:http ...

  4. [深度学习] tf.keras入门3-回归

    目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...

  5. 深度学习:Keras入门(一)之基础篇

    1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...

  6. 深度学习:Keras入门(一)之基础篇【转】

    本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...

  7. 深度学习:Keras入门(一)之基础篇(转)

    转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...

  8. 深度学习:Keras入门(二)之卷积神经网络(CNN)

    说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...

  9. 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】

    本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...

随机推荐

  1. oracle 12C 《服务器、客户端安装》

    oracle 12C <服务器.客户端安装> 1.下载database和client database和client下载地址:http://www.oracle.com/technetwo ...

  2. MybatisPlus生成主键策略方法

    MybatisPlus生成主键策略方法 全局id生成策略[因为是全局id所以不推荐] SpringBoot集成Mybatis-Plus 在yaml配置文件中添加MP配置 mybatis-plus: g ...

  3. java中list集合怎么判断是否为空

    首先看下面代码 @RequestMapping("/getCatlist") public String getCatlist(HttpSession session,HttpSe ...

  4. Vue学习之--------组件在Vue脚手架中的使用(代码实现)(2022/7/24)

    文章目录 1.第一步编写组件 1.1 编写一个 展示学校的组件 1.2 定义一个展示学生的信息组件 2.第二步引入组件 3.制作一个容器 4.使用Vue接管 容器 5.实际效果 6.友情提示: 7.项 ...

  5. Spring Cloud 整合 nacos 实现动态配置中心

    上一篇文章讲解了Spring Cloud 整合 nacos 实现服务注册与发现,nacos除了有服务注册与发现的功能,还有提供动态配置服务的功能.本文主要讲解Spring Cloud 整合nacos实 ...

  6. 最长不下降子序列(线段树优化dp)

    最长不下降子序列 题目大意: 给定一个长度为 N 的整数序列:A\(_{1}\),A\(_{2}\),⋅⋅⋅,A\(_{N}\). 现在你有一次机会,将其中连续的 K 个数修改成任意一个相同值. 请你 ...

  7. 【云原生 · Kubernetes】Kubernetes容器云平台部署与运维

    [题目1]Deployment管理 在master节点/root目录下编写yaml文件nginx-deployment.yaml,具体要求如下: (1)Deployment名称:nginx-deplo ...

  8. CPU TLB原理 [转载好文]

    首先,我们知道MMU的作用是把虚拟地址转换成物理地址.虚拟地址和物理地址的映射关系存储在页表中,而现在页表又是分级的.64位系统常见的配置是4级页表,就以4级页表为例说明.分别是PGD.PUD.PMD ...

  9. 在Java Web中setContentType与setCharacterEncoding中设置字符编码格式的区别

    在Java Web中setContentType与setCharacterEncoding中设置字符编码格式的区别 通用解释 setCharacterEncoding只是设置字符的编码方式 setCo ...

  10. win7使用onedrive右键托盘图标中文不显示问题

    前言 win7 用的 onedrive不能在微软官网下载,用不了,所以需要下载 win7可以使用的版本. onedrive_for_win7.exe 解决问题 重启电脑解决 其他 我看贴吧说是文本放大 ...