用Keras定义网络模型有两种方式,

之前我们介绍了Sequential顺序模型,今天我们来接触一下 Keras 的函数式API模型

函数式API:全连接网络

  1. from keras.layers import Input, Dense
  2. from keras.models import Model
  3.  
  4. # 这部分返回一个张量
  5. inputs = Input(shape=(784,))
  6.  
  7. # 层的实例是可调用的,它以张量为参数,并且返回一个张量
  8. x = Dense(64, activation='relu')(inputs)
  9. x = Dense(64, activation='relu')(x)
  10. predictions = Dense(10, activation='softmax')(x)
  11.  
  12. # 这部分创建了一个包含输入层和三个全连接层的模型
  13. model = Model(inputs=inputs, outputs=predictions)
  14. model.compile(optimizer='rmsprop',
  15. loss='categorical_crossentropy',
  16. metrics=['accuracy'])
  17. model.fit(data, labelsbatch_size=32, epochs=5) # 开始训练

多输入多输出模型

主要负责用函数式API来实现它

主要输入接收新闻标题本身,即一个整数序列(每个证书编码一个词),这些整数在1到10000之间(10000个词的词汇表),且序列长度为100个词

  1. from keras.layers import Input, Embedding, LSTM, Dense
  2. from keras.models import Model
  3.  
  4. # 标题输入:接收一个含有 100 个整数的序列,每个整数在 1 到 10000 之间。
  5. # 注意我们可以通过传递一个 "name" 参数来命名任何层。
  6. main_input = Input(shape=(100,), dtype='int32', name='main_input')
  7.  
  8. # Embedding 层将输入序列编码为一个稠密向量的序列,
  9. # 每个向量维度为 512。
  10. x = Embedding(output_dim=512, input_dim=10000, input_length=100)(main_input)
  11.  
  12. # LSTM 层把向量序列转换成单个向量,
  13. # 它包含整个序列的上下文信息
  14. lstm_out = LSTM(32)(x)

在这里,我们插入辅助损失,即使在模型主损失很高的情况下,LSTM层和Embedding层都能被平稳地训练。

  1. auxiliary_output = Dense(1, activation='sigmoid', name='aux_output')(lstm_out)

此时,我们将辅助输入数据与 LSTM 层的输出连接起来,输入到模型中:

  1. auxiliary_input = Input(shape=(5,), name='aux_input')
  2. x = keras.layers.concatenate([lstm_out, auxiliary_input])
  3.  
  4. # 堆叠多个全连接网络层
  5. x = Dense(64, activation='relu')(x)
  6. x = Dense(64, activation='relu')(x)
  7. x = Dense(64, activation='relu')(x)
  8.  
  9. # 最后添加主要的逻辑回归层
  10. main_output = Dense(1, activation='sigmoid', name='main_output')(x)

然后定义一个具有两个输入和两个输出的模型:

  1. model = Model(inputs=[main_input, auxiliary_input], outputs=[main_output, auxiliary_output])

现在编译模型,并给辅助损失分配一个 0.2 的权重。如果要为不同的输出指定不同的 loss_weights 或 loss,可以使用列表或字典。 在这里,我们给 loss 参数传递单个损失函数,这个损失将用于所有的输出。

  1. model.compile(optimizer='rmsprop', loss='binary_crossentropy',
  2. loss_weights=[1., 0.2])

我们可以通过输入数组和目标数组列表来训练模型:

  1. model.fit([headline_data, additional_data], [labels, labels],
  2. epochs=50, batch_size=32)

由于输入和输出均被命名了(在定义时传递了一个 name 参数),我们也可以通过以下方式编译模型:

  1. model.compile(optimizer='rmsprop',
  2. loss={'main_output': 'binary_crossentropy', 'aux_output': 'binary_crossentropy'},
  3. loss_weights={'main_output': 1., 'aux_output': 0.2})
  4.  
  5. # 然后使用以下方式训练:
  6. model.fit({'main_input': headline_data, 'aux_input': additional_data},
  7. {'main_output': labels, 'aux_output': labels},
  8. epochs=50, batch_size=32)

共享网络层

函数API的另一个用途是使用共享网络层的模型。

比如我们想建立一个模型来分辨两条推文是否来自同一个人,实现这个目标的方法是:将两条推文编码层两个向量,连接向量,然后添加逻辑回归层;这将输出推文来自通一个作者的概率。模型将接受一对对正负表示的推特数据。

太难了,我理解不了。以后这条博客慢慢更新。

Keras函数式 API的更多相关文章

  1. keras函数式编程(多任务学习,共享网络层)

    https://keras.io/zh/ https://keras.io/zh/getting-started/functional-api-guide/ https://github.com/ke ...

  2. 手写数字识别——利用keras高层API快速搭建并优化网络模型

    在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...

  3. 【小白学PyTorch】21 Keras的API详解(下)池化、Normalization层

    文章来自微信公众号:[机器学习炼丹术].作者WX:cyx645016617. 参考目录: 目录 1 池化层 1.1 最大池化层 1.2 平均池化层 1.3 全局最大池化层 1.4 全局平均池化层 2 ...

  4. TensorFlow 1.4利用Keras+Estimator API进行训练和预测

    Tensorflow 1.4中,Keras作为作为核心模块可以直接通过tf.keas进行调用,但是考虑到keras对tfrecords文件进行操作比较麻烦,而将keras模型转成tensorflow中 ...

  5. 【小白学PyTorch】21 Keras的API详解(上)卷积、激活、初始化、正则

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑答疑解惑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx6450 ...

  6. 小白如何学习PyTorch】25 Keras的API详解(下)缓存激活,内存输出,并发解决

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑答疑解惑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx6450 ...

  7. Keras高层API之Metrics

    在tf.keras中,metrics其实就是起到了一个测量表的作用,即测量损失或者模型精度的变化.metrics的使用分为以下四步: step1:Build a meter acc_meter = m ...

  8. 深度学习框架: Keras官方中文版文档正式发布

    今年 1 月 12 日,Keras 作者 François Chollet‏ 在推特上表示因为中文读者的广泛关注,他已经在 GitHub 上展开了一个 Keras 中文文档项目.而昨日,Françoi ...

  9. 3.keras实现-->高级的深度学习最佳实践

    一.不用Sequential模型的解决方案:keras函数式API 1.多输入模型 简单的问答模型 输入:问题 + 文本片段 输出:回答(一个词) from keras.models import M ...

随机推荐

  1. dynamic load javascript file.

    $.ajax({ url : ("js/public/" + window.localStorage.getItem("lang") + ".js&q ...

  2. js中对象的一些特性,JSON,scroll家族

    一.js中对象的一些特性 对象的动态特性 1.当对象有这个属性时,会对属性的值重写 2.当对象没有这个属性时,会为对象创建一个新属性,并赋值 获得对象的属性的方式 为元素设置DOM0级事件 二.JSO ...

  3. MVP框架模式

    一.基本概念 MVP是Model-View-Presenter的简称,即模型-视图-表现层的缩写.MVP是由MVC模式进化而来的,MVP改进了MVC中的控制器过于臃肿的问题.与MVC一样,MVP将应用 ...

  4. 分享6款优秀的 AR/VR 开源库

    今天,为大家推荐几款优秀的 AR/VR 开源库,希望能对大家有所帮助~ 1.AR.js AR.js 是一款应用于 Web 的高效增强现实(AR)库,基于 three.js + jsartoolkit5 ...

  5. SQL语言的增删改查

    select(查), update(改), delete(删), insert into(增)   select * from table_name 获取表中所有字段 select id, name, ...

  6. UVA-673 Parentheses Balance(栈)

    题目大意:给一个括号串,看是否匹配. 题目分析:一开始用区间DP写的,超时了... 注意:空串合法. 代码如下: # include<iostream> # include<cstd ...

  7. Page.TryUpdateModel 方法

    使用来自值提供程序的值更新指定的模型实例. 使用来自值提供程序的值更新指定的模型实例. 命名空间:   System.Web.UI程序集:  System.Web(System.Web.dll 中) ...

  8. 让FireFox支持window.event属性

    场景描述: 在用户行为采集的过程中,需要侦听window下的event对象,根据事件类型做相应的过滤处理,但在firefox下window.event是未定义的: 问题分析: 要想获取event属性共 ...

  9. Java——基本语法

    body, table{font-family: 微软雅黑; font-size: 10pt} table{border-collapse: collapse; border: solid gray; ...

  10. (转)Web系统大规模并发——电商秒杀与抢购

    电商的秒杀和抢购,对我们来说,都不是一个陌生的东西.然而,从技术的角度来说,这对于Web系统是一个巨大的考验.当一个Web系统,在一秒钟内收到数以万计甚至更多请求时,系统的优化和稳定至关重要.这次我们 ...