Keras函数式 API
用Keras定义网络模型有两种方式,
之前我们介绍了Sequential顺序模型,今天我们来接触一下 Keras 的函数式API模型。
函数式API:全连接网络
from keras.layers import Input, Dense
from keras.models import Model # 这部分返回一个张量
inputs = Input(shape=(784,)) # 层的实例是可调用的,它以张量为参数,并且返回一个张量
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x) # 这部分创建了一个包含输入层和三个全连接层的模型
model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, labels,batch_size=32, epochs=5) # 开始训练
多输入多输出模型
主要负责用函数式API来实现它
主要输入接收新闻标题本身,即一个整数序列(每个证书编码一个词),这些整数在1到10000之间(10000个词的词汇表),且序列长度为100个词
from keras.layers import Input, Embedding, LSTM, Dense
from keras.models import Model # 标题输入:接收一个含有 100 个整数的序列,每个整数在 1 到 10000 之间。
# 注意我们可以通过传递一个 "name" 参数来命名任何层。
main_input = Input(shape=(100,), dtype='int32', name='main_input') # Embedding 层将输入序列编码为一个稠密向量的序列,
# 每个向量维度为 512。
x = Embedding(output_dim=512, input_dim=10000, input_length=100)(main_input) # LSTM 层把向量序列转换成单个向量,
# 它包含整个序列的上下文信息
lstm_out = LSTM(32)(x)
在这里,我们插入辅助损失,即使在模型主损失很高的情况下,LSTM层和Embedding层都能被平稳地训练。
auxiliary_output = Dense(1, activation='sigmoid', name='aux_output')(lstm_out)
此时,我们将辅助输入数据与 LSTM 层的输出连接起来,输入到模型中:
auxiliary_input = Input(shape=(5,), name='aux_input')
x = keras.layers.concatenate([lstm_out, auxiliary_input]) # 堆叠多个全连接网络层
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x) # 最后添加主要的逻辑回归层
main_output = Dense(1, activation='sigmoid', name='main_output')(x)
然后定义一个具有两个输入和两个输出的模型:
model = Model(inputs=[main_input, auxiliary_input], outputs=[main_output, auxiliary_output])
现在编译模型,并给辅助损失分配一个 0.2 的权重。如果要为不同的输出指定不同的 loss_weights
或 loss
,可以使用列表或字典。 在这里,我们给 loss
参数传递单个损失函数,这个损失将用于所有的输出。
model.compile(optimizer='rmsprop', loss='binary_crossentropy',
loss_weights=[1., 0.2])
我们可以通过输入数组和目标数组的列表来训练模型:
model.fit([headline_data, additional_data], [labels, labels],
epochs=50, batch_size=32)
由于输入和输出均被命名了(在定义时传递了一个 name 参数),我们也可以通过以下方式编译模型:
model.compile(optimizer='rmsprop',
loss={'main_output': 'binary_crossentropy', 'aux_output': 'binary_crossentropy'},
loss_weights={'main_output': 1., 'aux_output': 0.2}) # 然后使用以下方式训练:
model.fit({'main_input': headline_data, 'aux_input': additional_data},
{'main_output': labels, 'aux_output': labels},
epochs=50, batch_size=32)
共享网络层
函数API的另一个用途是使用共享网络层的模型。
比如我们想建立一个模型来分辨两条推文是否来自同一个人,实现这个目标的方法是:将两条推文编码层两个向量,连接向量,然后添加逻辑回归层;这将输出推文来自通一个作者的概率。模型将接受一对对正负表示的推特数据。
太难了,我理解不了。以后这条博客慢慢更新。
Keras函数式 API的更多相关文章
- keras函数式编程(多任务学习,共享网络层)
https://keras.io/zh/ https://keras.io/zh/getting-started/functional-api-guide/ https://github.com/ke ...
- 手写数字识别——利用keras高层API快速搭建并优化网络模型
在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...
- 【小白学PyTorch】21 Keras的API详解(下)池化、Normalization层
文章来自微信公众号:[机器学习炼丹术].作者WX:cyx645016617. 参考目录: 目录 1 池化层 1.1 最大池化层 1.2 平均池化层 1.3 全局最大池化层 1.4 全局平均池化层 2 ...
- TensorFlow 1.4利用Keras+Estimator API进行训练和预测
Tensorflow 1.4中,Keras作为作为核心模块可以直接通过tf.keas进行调用,但是考虑到keras对tfrecords文件进行操作比较麻烦,而将keras模型转成tensorflow中 ...
- 【小白学PyTorch】21 Keras的API详解(上)卷积、激活、初始化、正则
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑答疑解惑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx6450 ...
- 小白如何学习PyTorch】25 Keras的API详解(下)缓存激活,内存输出,并发解决
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑答疑解惑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx6450 ...
- Keras高层API之Metrics
在tf.keras中,metrics其实就是起到了一个测量表的作用,即测量损失或者模型精度的变化.metrics的使用分为以下四步: step1:Build a meter acc_meter = m ...
- 深度学习框架: Keras官方中文版文档正式发布
今年 1 月 12 日,Keras 作者 François Chollet 在推特上表示因为中文读者的广泛关注,他已经在 GitHub 上展开了一个 Keras 中文文档项目.而昨日,Françoi ...
- 3.keras实现-->高级的深度学习最佳实践
一.不用Sequential模型的解决方案:keras函数式API 1.多输入模型 简单的问答模型 输入:问题 + 文本片段 输出:回答(一个词) from keras.models import M ...
随机推荐
- 51Nod 1509 加长棒(隔板法)
http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1509 思路: 直接去解可行的方法有点麻烦,所以应该用总的方法去减去不可行 ...
- TimerPickerDialog 中 onTimeSet 执行两次的问题
开发android小闹钟的程序时,在添加闹钟时闹钟列表中总是出现两个相同的闹钟. btnAddAlarm.setOnClickListener(new View.OnClickListener() { ...
- 递推-练习1--noi1760 菲波那契数列(2)
递推-练习1--noi1760 菲波那契数列(2) 一.心得 二.题目 1760:菲波那契数列(2) 总时间限制: 1000ms 内存限制: 65536kB 描述 菲波那契数列是指这样的数列: 数 ...
- JAVA异常处理分析高级进界(下)
既然Throwable是异常处理机制的核心,那么,我们就来分析下它的源码来看看它是如何实现的. 进行分析前,我们可以先想想如果让我们实现一个异常处理机制,我们需要它做什么? 1. 发生异常终止程序执行 ...
- hdu4686矩阵快速幂
花了一个多小时终于ac了,有时候真的是需要冷静一下重新打一遍才行. 这题就是 |aod(n)| = |1 ax*bx ax*by ay*bx ...
- UVA-1614 Hell on the Markets(贪心+推理) (有待补充)
题目大意:一个整数序列a,1≤a[i]≤i.问能否通过在一些元素前加上负号,使得整个序列和为0. 题目分析:贪心.贪心策略:每次都先选最大的元素加负号(或保留,不加负号). 贪心依据:对于1≤a[i] ...
- Less开发指南(三)- 代码文件跟踪调试
案例背景:在大型网站中,css样式划分为多个模块文件,如reset.css,layout.css,skin.css等等(颗粒化越小,样式重用率越高),页面需要的时候引入它们即可! 回到less项目中这 ...
- iOS UI-静态单元格与动态单元格
- IOS-网络(小文件下载)
// // ViewController.m // IOS_0131_小文件下载 // // Created by ma c on 16/1/31. // Copyright © 2016年 博文科技 ...
- vue.js-读取/发送数据
<!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...