目录

构建一个简单的模型

序贯(Sequential)模型

网络层的构造

模型训练和参数评价

模型训练

模型的训练

tf.data的数据集

模型评估和预测

基本模型的建立

网络层模型

模型子类函数构建

回调函数Callbacks

模型保存和载入

网络参数保存Weights only

配置参数保存Configuration only

完整模型保存


目前keras API 已经整合到 tensorflow最新版本1.9.0 中,在tensorflow中通过tf.keras就可以调用keras。

import tensorflow as tf
from tensorflow import keras

官方教程为:https://tensorflow.google.cn/guide/keras

tf.keras可以调用所有的keras编译代码,但是有两个限制:

  1. 版本问题,需要通过tf.keras.version确认版本。
  2. 模型保存问题,tf.keras默认使用 checkpoint format格式,而keras模型的保存格式HDF5需要借用函数save_format='h5'

构建一个简单的模型

序贯(Sequential)模型

序贯模型就是是多个网络层的线性堆叠,比如多层感知机,BP神经网络。

tf.keras构建一个简单的全连通网络(即多层感知器)代码如下:

#建立序贯模型
model = keras.Sequential()
#添加全连接层,节点数为64,激活函数为relu函数,dense表示标准的一维全连接层
model.add(keras.layers.Dense(64, activation='relu'))
#添加全连接层,节点数为64,激活函数为relu函数
model.add(keras.layers.Dense(64, activation='relu'))
#添加输出层,输出节点数为10
model.add(keras.layers.Dense(10, activation='softmax'))

其中激活函数详细信息见keras官方文档http://keras-cn.readthedocs.io/en/latest/other/activations/

网络层的构造

通常在tf.keras中,网络层的构造参数主要有以下几个:

  1. 激活函数activation function,默认是没有激活函数的。
  2. 参数初始化,默认通过正态分布初始化(Glorot uniform)
  3. 参数正则化,包括权值初始化和偏置的初始化。
#参数调整
#建立一个sigmoid层
layers.Dense(64, activation='sigmoid')
#或者
layers.Dense(64, activation=tf.sigmoid) #权重L1正则化
layers.Dense(64, kernel_regularizer=keras.regularizers.l1(0.01))
#偏置L2正则化
layers.Dense(64, bias_regularizer=keras.regularizers.l2(0.01)) #权重正交矩阵的随机数初始化
layers.Dense(64, kernel_initializer='orthogonal')
#偏置常数初始化
layers.Dense(64, bias_initializer=keras.initializers.constant(2.0))

模型训练和参数评价

模型训练

模型建立后,通过compile模块确定模型的训练参数(tf.keras.Model.compile)

tf.keras.Model.compile有三个主要参数:

  1. 优化器optimizer:通过tf.train模块调用优化器,可用的优化器类型见:http://keras-cn.readthedocs.io/en/latest/other/optimizers/
  2. 损失函数loss:通过tf.keras.losses模块调用损失函数,可用的损失函数类型见:http://keras-cn.readthedocs.io/en/latest/other/objectives/
  3. 模型评估方法metrics:通过tf.keras.metrics调用评估参数,可用的模型评估方法见:http://keras-cn.readthedocs.io/en/latest/other/metrics/

具体例子如下:

# 配置均方误差回归模型
model.compile(optimizer=tf.train.AdamOptimizer(0.01),
loss='mse', # 均方差
metrics=['mae']) # 平均绝对误差 # 配置分类模型
model.compile(optimizer=tf.train.RMSPropOptimizer(0.01),
loss=keras.losses.categorical_crossentropy, #多类的对数损失
metrics=[keras.metrics.categorical_accuracy]) #多分类问题,所有预测值上的平均正确率

模型的训练

对于小数据集,使用numpy数组,通过tf.keras.Model.fit模块来训练和评估模型。

import numpy as np
#输入数据(1000,32)
data = np.random.random((1000, 32))
#输入标签(1000,10)
labels = np.random.random((1000, 10))
#模型训练
model.fit(data, labels, epochs=10, batch_size=32)

tf.keras.Model.fit模块有三个重要的参数:

  1. 训练轮数epochs:epochs指的就是训练过程中数据将被训练多少轮,一个epoch指的是当一个完整的数据集通过了神经网络一次并且返回了一次。
  2. 批训练大小batch_size:基本上现在的梯度下降都是基于mini-batch的,即将一个完整数据分为batch_size个批次进行训练。详见http://keras-cn.readthedocs.io/en/latest/for_beginners/concepts/#epochs
  3. 验证集validation_data:通常一个模型训练,评估要有训练集,验证集和测试集。验证集就是模型调参时用来评估模型的数据集。

tf.data的数据集

对于大型数据集,常常通过tf.data模块来调用数据,详见https://tensorflow.google.cn/guide/datasets

# 数据实例化
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)
dataset = dataset.repeat() #模型训练,steps_per_epoch表示每次训练的数据大小类似与batch_size
model.fit(dataset, epochs=10, steps_per_epoch=30)

模型评估和预测

通过 tf.keras.Model.evaluate 和tf.keras.Model.predict可以实现模型的评估和预测。

model.evaluate(x, y, batch_size=32)
model.evaluate(dataset, steps=30) model.predict(x, batch_size=32)
model.predict(dataset, steps=30)

基本模型的建立

网络层模型

通过f.keras.Sequential 可以实现各种的复杂模型,如:

  1. 多输入模型;
  2. 多输出模型;
  3. 参数共享层模型;
  4. 残差网络模型。

具体例子如下:

#输入参数
inputs = keras.Input(shape=(32,)) #网络层的构建
x = keras.layers.Dense(64, activation='relu')(inputs)
x = keras.layers.Dense(64, activation='relu')(x)
#预测
predictions = keras.layers.Dense(10, activation='softmax')(x) #模型实例化
model = keras.Model(inputs=inputs, outputs=predictions) #模型构建
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy']) #模型训练
model.fit(data, labels, batch_size=32, epochs=5)

模型子类函数构建

通常通过tf.keras.Model构建模型结构, __init__方法初始化模型,call方法进行参数传递。如下所示:

class MyModel(keras.Model):
#模型结构确定
def __init__(self, num_classes=10):
super(MyModel, self).__init__(name='my_model')
self.num_classes = num_classes
#网络层的定义
self.dense_1 = keras.layers.Dense(32, activation='relu')
self.dense_2 = keras.layers.Dense(num_classes, activation='sigmoid')
#参数调用
def call(self, inputs):
#前向传播过程确定
x = self.dense_1(inputs)
return self.dense_2(x) def compute_output_shape(self, input_shape):
#输出参数确定
shape = tf.TensorShape(input_shape).as_list()
shape[-1] = self.num_classes
return tf.TensorShape(shape) #模型初始化
model = MyModel(num_classes=10) #模型构建
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy']) #模型训练
model.fit(data, labels, batch_size=32, epochs=5)

回调函数Callbacks

回调函数是一组在训练的特定阶段被调用的函数集,你可以使用回调函数来观察训练过程中网络内部的状态和统计信息。通过传递回调函数列表到模型fit()中,即可在给定的训练阶段调用该函数集中的函数。详见:http://keras-cn.readthedocs.io/en/latest/other/callbacks/。主要回调函数有:

  1. tf.keras.callbacks.ModelCheckpoint:模型保存
  2. tf.keras.callbacks.LearningRateScheduler:学习率调整
  3. tf.keras.callbacks.EarlyStopping:中断训练
  4. tf.keras.callbacks.TensorBoard:tensorboard的使用

模型保存和载入

tf.keras有两种模型保存方式

网络参数保存Weights only

#模型保存为tensorflow默认格式
model.save_weights('./my_model') #载入模型
model.load_weights('my_model') #模型保存为keras默认格式,包含其他优化参数
model.save_weights('my_model.h5', save_format='h5') #载入模型
model.load_weights('my_model.h5')

配置参数保存Configuration only

保存一个没有模型参数只有配置参数的模型, Keras支持 JSON和YAML序列化格式:

# 模型保存
json_string = model.to_json()
yaml_string = model.to_yaml()
#模型载入
fresh_model = keras.models.from_json(json_string)
fresh_model = keras.models.from_yaml(yaml_string)

完整模型保存

将原来模型所用信息进行保存:

#模型建立
model = keras.Sequential([
keras.layers.Dense(10, activation='softmax', input_shape=(32,)),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, targets, batch_size=32, epochs=5) #保存为keras格式文件
model.save('my_model.h5') # 模型载入
model = keras.models.load_model('my_model.h5')

[深度学习] tf.keras入门1-基本函数介绍的更多相关文章

  1. [深度学习] tf.keras入门2-分类

    目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...

  2. [深度学习] tf.keras入门4-过拟合和欠拟合

    过拟合和欠拟合 简单来说过拟合就是模型训练集精度高,测试集训练精度低:欠拟合则是模型训练集和测试集训练精度都低. 官方文档地址为 https://tensorflow.google.cn/tutori ...

  3. [深度学习] tf.keras入门5-模型保存和载入

    目录 设置 基于checkpoints的模型保存 通过ModelCheckpoint模块来自动保存数据 手动保存权重 整个模型保存 总体代码 模型可以在训练中或者训练完成后保存.具体文档参考:http ...

  4. [深度学习] tf.keras入门3-回归

    目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...

  5. 深度学习:Keras入门(一)之基础篇

    1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...

  6. 深度学习:Keras入门(一)之基础篇【转】

    本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...

  7. 深度学习:Keras入门(一)之基础篇(转)

    转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...

  8. 深度学习:Keras入门(二)之卷积神经网络(CNN)

    说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...

  9. 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】

    本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...

随机推荐

  1. AspNetCore中 使用 Grpc 简单Demo

    为什么要用Grpc 跨语言进行,调用服务,获取跨服务器调用等 目前我的需要使用 我的抓取端是go 写的 查询端用 Net6 写的 导致很多时候 我需要把一些临时数据写入到 Redis 在两个服务器进行 ...

  2. 【MySQL】01_运算符、函数

    运算符 运算符是保留字或主要用于 SQL 语句的 WHERE 子句 中的字符,用于执行操作,例如:比较和算术运算. 这些运算符用于指定 SQL 语句中的条件,并用作语句中多个条件的连词. 常见运算符有 ...

  3. 前端监控系列4 | SDK 体积与性能优化实践

    背景 字节各类业务拥有众多用户群,作为字节前端性能监控 SDK,自身若存在性能问题,则会影响到数以亿计的真实用户的体验.所以此类 SDK 自身的性能在设计之初,就必须达到一个非常极致的水准. 与此同时 ...

  4. 【原创】在RT1050 LittleVgl GUI中嵌入中文输入法框架

    时隔一年多终于又冒泡了,哎,随着工作越来越忙,自己踏实坐下来写点东西真是越来越费劲,这篇文章也是准备了好久好久才打算发表出来(不瞒大家,东西做完好久了,文章憋了一年了,当真"高产" ...

  5. How to install the Package Controller

    How to install the Package Controller? https://packagecontrol.io/installation INSTALLATION Use one o ...

  6. Azure DevOps Server 入门实践与安装部署

    一,引言 最近一段时间,公司希望在自己的服务器上安装本地版的 Azure DevOps Service(Azure DevOps Server),用于项目内的测试,学习.本着学习的目的,我也就开始学习 ...

  7. Mysql InnoDB多版本并发控制MVCC

    参考书籍<mysql是怎样运行的> 系列文章目录和关于我 一丶为什么需要事务隔离级别 mysql是一个客户端/服务断软件,对于同一个服务器来说,可以有多个客户端进行连接,每一个客户端进行连 ...

  8. C++ 中指针常量、指向常量的指针、引用类型的常量

    命题1. 在C++ 中 const T a 与 T const a 是一样的, 表示a是一个T类型的常量. 测试: 一. 形参定义为引用类型的常量 在函数传参时,形参若定义为 const T& ...

  9. C#程序自启动

    在窗体加载事件里面加入下述代码: //设置开机自启动 RegistryKey registryKey = Registry.CurrentUser.OpenSubKey ("SOFTWARE ...

  10. .NET刷算法

    BFS模板-宽度优先搜索(Breadth First Search) 1.模板 /// <summary> /// BFS遍历 /// </summary> /// <p ...