keras建模的3种方式——序列模型、函数模型、子类模型
1 前言
keras是Google公司于2016年发布的以tensorflow为后端的用于深度学习网络训练的高阶API,因接口设计非常人性化,深受程序员的喜爱。
keras建模有3种实现方式——序列模型、函数模型、子类模型。本文以MNIST手写数字为例,用3种建模方式实现。
关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类
笔者工作空间如下:
2 序列模型
sequential.py
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Sequential
from keras.models import load_model
from keras.layers import Dense
#载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y
#序列模型
def DNN(train_x,train_y,valid_x,valid_y):
#创建模型
model=Sequential()
model.add(Dense(64,input_dim=784,activation='relu'))
model.add(Dense(128,activation='relu'))
model.add(Dense(10,activation='softmax'))
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=2,validation_data=(valid_x,valid_y))
#保存模型
model.save('sequential.h5')
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DNN(train_x,train_y,valid_x,valid_y)
model=load_model('sequential.h5') #下载模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2) #评估模型
print('test_loss:',pre[0],'- test_acc:',pre[1])
运行结果
Epoch 98/100
- 1s - loss: 4.8694e-04 - acc: 1.0000 - val_loss: 0.1331 - val_acc: 0.9776
Epoch 99/100
- 1s - loss: 4.7432e-04 - acc: 1.0000 - val_loss: 0.1336 - val_acc: 0.9778
Epoch 100/100
- 1s - loss: 4.6462e-04 - acc: 1.0000 - val_loss: 0.1343 - val_acc: 0.9774
test_loss: 0.13972217990085484 - test_acc: 0.9768999993801117
3 函数模型
fun_model.py
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.models import load_model
from keras.layers import Input,Dense
#载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y
#函数模型
def DNN(train_x,train_y,valid_x,valid_y):
#创建模型
inputs=Input(shape=(784,))
x=Dense(64,activation='relu')(inputs)
x=Dense(128,activation='relu')(x)
output=Dense(10,activation='softmax')(x)
model=Model(input=inputs,output=output)
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=2,validation_data=(valid_x,valid_y))
#保存模型
model.save('fun_model.h5')
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
DNN(train_x,train_y,valid_x,valid_y)
model=load_model('fun_model.h5') #下载模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2) #评估模型
print('test_loss:',pre[0],'- test_acc:',pre[1])
4 子类模型
class_model.py
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import Dense
#载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y
#子类模型
class DNN(Model):
def __init__(self):
super(DNN,self).__init__()
#初始化网络结构
self.dense1=Dense(64,input_dim=784,activation='relu')
self.dense2=Dense(128,activation='relu')
self.dense3=Dense(10,activation='softmax')
def call(self,inputs): #回调顺序
x=self.dense1(inputs)
x=self.dense2(x)
x=self.dense3(x)
return x
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
model=DNN()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=2,validation_data=(valid_x,valid_y))
#查看网络结构
model.summary()
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2) #评估模型
print('test_loss:',pre[0],'- test_acc:',pre[1])
5 注意事项
(1)只有序列模型和函数模型能够保存模型,子类模型不能保存模型,即不能调用 model.save()
(2)子类模型中,model.summary() 得放在 model.fit() 之后,否则会报错
ValueError: This model has not yet been built. Build the model first by calling build() or calling fit() with some data. Or specify input_shape or batch_input_shape in the first layer for automatic build.
(3)若想自定义学习率,可以引入优化器对象,如下:
from keras.optimizers import Adam
....
model.compile(optimizer=Adam(lr=0.001),loss='categorical_crossentropy',metrics=['accuracy'])
(4)常用损失函数
mse #均方差(回归)
mae #绝对误差(回归)
binary_crossentropy #二值交叉熵(二分类,逻辑回归)
categorical_crossentropy #交叉熵(多分类)
(5)model.fit( ) 和 model.evaluate( ) 中,属性 verbose 表示打印训练或评估信息是否详细
- 0:不打印进度和结果
- 1:打印进度和结果
Epoch 100/100
55000/55000 [==============================] - 1s 9us/step - loss: 6.0211e-05 - acc: 1.0000 - val_loss: 0.1405 - val_acc: 0.9766
10000/10000 [==============================] - 0s 26us/step
- 2:只打印结果
Epoch 100/100
- 1s - loss: 4.6462e-04 - acc: 1.0000 - val_loss: 0.1343 - val_acc: 0.9774
声明:本文转自keras建模的3种方式——序列模型、函数模型、子类模型
keras建模的3种方式——序列模型、函数模型、子类模型的更多相关文章
- python 零散记录(五) import的几种方式 序列解包 条件和循环 强调getattr内建函数
用import关键字导入模块的几种方式: #python是自解释的,不必多说,代码本身就是人可读的 import xxx from xxx import xxx from xxx import xx1 ...
- 增加收入的 6 种方式(很多公司的模型是:一份时间卖多次。比如网易、腾讯。个人赚取收入的本质是:出售时间)good
个人赚取收入的本质是:出售时间.从这个角度出发,下面的公式可以描述个人收入: 个人收入 = 每天可售时间数量 * 单位时间价格 * 单位时间出售次数 在这个公式里,有三个要素: 每天可出售的时间数量 ...
- Keras框架下的保存模型和加载模型
在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...
- Android-创建启动线程的两种方式
方式一:成为Thread的子类,然后在Thread的子类.start 缺点:存在耦合度(因为线程任务run方法里面的业务逻辑 和 线程启动耦合了) 缺点:Cat extends Thread {} 后 ...
- SpringBoot集成Mybatis实现多表查询的两种方式(基于xml)
下面将在用户和账户进行一对一查询的基础上进行介绍SpringBoot集成Mybatis实现多表查询的基于xml的两种方式. 首先我们先创建两个数据库表,分别是user用户表和account账户表 ...
- 【Keras篇】---Keras初始,两种模型构造方法,利用keras实现手写数字体识别
一.前述 Keras 适合快速体验 ,keras的设计是把大量内部运算都隐藏了,用户始终可以用theano或tensorflow的语句来写扩展功能并和keras结合使用. 二.安装 Pip insta ...
- keras embeding设置初始值的两种方式
随机初始化Embedding from keras.models import Sequential from keras.layers import Embedding import numpy a ...
- Keras中间层输出的两种方式,即特征图可视化
训练好的模型,想要输入中间层的特征图,有两种方式: 1. 通过model.get_layer的方式.创建新的模型,输出为你要的层的名字. 创建模型,debug状态可以看到模型中,base_model/ ...
- Windows10-UWP中设备序列显示不同XAML的三种方式[3]
阅读目录: 概述 DeviceFamily-Type文件夹 DeviceFamily-Type扩展 InitializeComponent重载 结论 概述 Windows10-UWP(Universa ...
- 三种方式实现观察者模式 及 Spring中的事件编程模型
观察者模式可以说是众多设计模式中,最容易理解的设计模式之一了,观察者模式在Spring中也随处可见,面试的时候,面试官可能会问,嘿,你既然读过Spring源码,那你说说Spring中运用的设计模式吧, ...
随机推荐
- css - 伪元素清除浮动
.clearfix:after{ content:""; /*设置内容为空*/ height:0; /*高度为0*/ line-height:0; /*行高为0*/ display ...
- Oracle12c On 银河麒麟v10SP3 的安装过程
Oracle12c On 银河麒麟的安装过程 学习官网资料 下载最新版的preinstall文件 https://yum.oracle.com/repo/OracleLinux/OL8/appstre ...
- [转帖]没 K8s 用不了 Chaos Mesh?试试 Chaosd
https://cn.pingcap.com/blog/cannot-use-chaosmesh-without-k8s-then-try-chaosd Chaosd 是什么? 相信大家对 Chaos ...
- charles如何抓取https请求
我们都知道charles下载安装后只能抓取http请求,要想抓取https请求需要下载安装证书 下面介绍pc端和移动端的配置方法 一.pc端(win) 1.打开charles,点击help>SS ...
- Skia 编译及踩坑实践
本文要点 •了解并入门 Skia.OpenGL 和 Vulkan •了解 Skia 在后端渲染上的坑点 前言 Skia 是什么 Skia 是一个开源 2D 图形库,提供可跨各种硬件和软件平台工作的通用 ...
- 【AIGC】只要10秒,AI生成IP海报,解放双手!!!
看完这篇文章,你将学会以下价值连城的内容 1.云端部署(配置不行的小伙伴看)+ 云端模型放置位置 2.本地部署(配置达标的小伙伴看) 3.运用SD训练IP的流程和技巧(LoRA篇) 4.运用SD稳定生 ...
- 一种读取亿级doris数据库的方法
工作中,常常需要将线上doris同步至集市.读取doris数据同读取常规mysql基本相同.如果数据行小于千万,比较简单的方式直接单节点连接.读取和存储.Python示例如下: def get_dat ...
- 解决node与npm版本不一致,出现npm WARN npm npm does not support Node.js v15.14.0
出现node与npm版本不一致 今天我升级了node之后,出现的了如下信息 npm WARN npm You should probably upgrade to a newer version of ...
- Vue基础系统文章07---webpack安装和配置与打包
1.当前web开发困境 a.文件依赖关系错综复杂 b.静态资源请求效率低 c.模块化支持不友好 d.浏览器对高级js兼容性低 例如:模块代码实现隔行换色 1)在新建空白文件夹中运行:npm init ...
- openim支持十万超级大群
钉钉:根据相关监管要求,新建普通群人数上限调整为500人,不支持群人数扩容. 企业微信:内部群聊人数最多支持2000人,群个数无上限.全员群人数最多支持10000人.企业微信用户创建的外部群人数最多支 ...