导入Keras函数模型

假设使用Keras的函数API开始定义一个简单的MLP:

  1. from keras.models import Model
  2. from keras.layers import Dense, Input
  3.  
  4. inputs = Input(shape=(100,))
  5. x = Dense(64, activation='relu')(inputs)
  6. predictions = Dense(10, activation='softmax')(x)
  7. model = Model(inputs=inputs, outputs=predictions)
  8. model.compile(loss='categorical_crossentropy',optimizer='sgd', metrics=['accuracy'])

在Keras,有几种保存模型的方法。可以将整个模型(模型定义、权重和训练配置)存储为HDF5文件,仅存储模型配置(作为JSON或YAML文件)或仅存储权重(作为HDF5文件):

  1. model.save('full_model.h5') # save everything in HDF5 format
  2.  
  3. model_json = model.to_json() # save just the config. replace with "to_yaml" for YAML serialization
  4. with open("model_config.json", "w") as f:
  5. f.write(model_json)
  6.  
  7. model.save_weights('model_weights.h5') # save just the weights.
如果你决定保存完整的模型,那么将能够访问模型的训练配置,否则将不访问。因此,如果想在导入之后在DL4J中进一步训练模型,请记住这一点,并使用model.save(...)来持久化模型。

载加Keras模型

将完整模型加载回DL4J(假设它在类路径上):

  1. String fullModel = new ClassPathResource("full_model.h5").getFile().getPath();
  2. ComputationGraph model = KerasModelImport.importKerasModelAndWeights(fullModel);

万一没有编译Keras模型,它就不会有一个训练配置。在这种情况下,需要显式地告诉模型导入忽略训练配置,方法是将enforceTrainingConfig标志设置为false,如下所示:

  1. ComputationGraph model = KerasModelImport.importKerasModelAndWeights(fullModel, false);

若要仅从JSON加载模型配置,请按如下使用KerasModelImport

  1. String modelJson = new ClassPathResource("model_config.json").getFile().getPath();
  2. ComputationGraphConfiguration modelConfig = KerasModelImport.importKerasModelConfiguration(modelJson)

如果另外还想加载模型权重与配置,那么需要做:

  1. String modelWeights = new ClassPathResource("model_weights.h5").getFile().getPath();
  2. MultiLayerNetwork network = KerasModelImport.importKerasModelAndWeights(modelJson, modelWeights)
在后面两种情况下,将不读取训练配置。

KerasModel

Github:KerasModel.java - 从Keras(函数API)模型或序列模型配置构建计算图

KerasModel(建议)

  1. public KerasModel(KerasModelBuilder modelBuilder)
  2. throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException
  1. // 函数API模型的构建器模式构造器
  2. 参数 modelBuilder 构建器对象
  3. 抛出 IOException IO 异常
  4. 抛出 InvalidKerasConfigurationException 无效的 Keras 配置
  5. 抛出 UnsupportedKerasConfigurationException 不支持的 Keras 配置

getComputationGraphConfiguration(不推荐)

  1. public ComputationGraphConfiguration getComputationGraphConfiguration()
  2. throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
  1. // 来自模型配置(JSON或YAML)、训练配置(JSON)、权重和“训练模式”布尔指示符的(函数 API)模型的构造器。当内置在训练模式时,某些不支持的配置(例如,未知的正则化器)将抛出异常。当强制TrainingConfig= false时,这些将生成警告,但将被忽略。
  2. 参数 modelJson 模型配置JSON 字符串
  3. 参数 modelYaml 模型配置 YAML 字符串
  4. 参数 enforceTrainingConfig 是否实施训练相关配置
  5. 抛出 IOException IO 异常
  6. 抛出 InvalidKerasConfigurationException 无效的 Keras 配置
  7. 抛出 UnsupportedKerasConfigurationException 不支持的 Keras 配置

getComputationGraph

  1. public ComputationGraph getComputationGraph()
  2. throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
  1. // 从这个Keras模型配置构建计算图并导入权重
  2. 返回 ComputationGraph

getComputationGraph

  1. public ComputationGraph getComputationGraph(boolean importWeights)
  2. throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
  1. // 从这个Keras模型配置构建计算图并(可选的)导入权重。
  2. 参数 importWeights 是否导入权重
  3. 返回 ComputationGraph

DL4J中文文档/Keras模型导入/函数模型的更多相关文章

  1. ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人[中文文档]

    ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人[中文文档] 简介 简单地说就是该有的都有了,但是总体跑起来效果还不好. 还在开发中,它工作的效果还不好.但是你可以直 ...

  2. 【Chromium中文文档】进程模型

    进程模型 转载请注明出处:https://ahangchen.gitbooks.io/chromium_doc_zh/content/zh//General_Architecture/Process_ ...

  3. Keras官方中文文档:Keras安装和配置指南(Windows)

    这里需要说明一下,笔者不建议在Windows环境下进行深度学习的研究,一方面是因为Windows所对应的框架搭建的依赖过多,社区设定不完全:另一方面,Linux系统下对显卡支持.内存释放以及存储空间调 ...

  4. Django 1.10中文文档-第一个应用Part2-模型和管理站点

    本教程继续Part1.我们将设置数据库,创建您的第一个模型,并快速介绍Django的自动生成的管理网站. 数据库设置 现在,编辑mysite/settings.py.它是一个用模块级别变量表示Djan ...

  5. Django 1.10中文文档-执行查询

    Django 1.10中文文档: https://github.com/jhao104/django-chinese-doc 只要创建好 数据模型, Django 会自动为生成一套数据库抽象的API, ...

  6. Django 1.10中文文档-第一个应用Part5-测试

    本教程上接教程Part4. 前面已经建立一个网页投票应用,现在将为它创建一些自动化测试. 自动化测试简介 什么是自动化测试 测试是检查你的代码是否正常运行的行为.测试也分为不同的级别.有些测试可能是用 ...

  7. Apache Spark 2.2.0 中文文档

    Apache Spark 2.2.0 中文文档 - 快速入门 | ApacheCN Geekhoo 关注 2017.09.20 13:55* 字数 2062 阅读 13评论 0喜欢 1 快速入门 使用 ...

  8. 一、neo4j中文文档-入门指南

    目录 neo4j中文文档-入门指南 Neo4j v4.4 neo4j **Cypher ** 开始使用 Neo4j 1. 安装 Neo4j 2. 文档 图数据库概念 1. 示例图 2.节点 3. 节点 ...

  9. Knockout中文开发指南(完整版API中文文档) 目录索引

    a, .tree li > span { padding: 4pt; border-radius: 4px; } .tree li a { color:#46cfb0; text-decorat ...

随机推荐

  1. Educational Codeforces Round 53 E. Segment Sum(数位DP)

    Educational Codeforces Round 53 E. Segment Sum 题意: 问[L,R]区间内有多少个数满足:其由不超过k种数字构成. 思路: 数位DP裸题,也比较好想.由于 ...

  2. Linux 上的Tomcat配置输入域名直接访问项目

    申请的域名备案通过了,域名是在阿里云上面的买的,一块钱,当初买服务器是买着来玩玩的. 既然申请的域名已经备案通过了,也配置了域名解析 ,服务器上也装了Tomcat,部署了web项目,下面来配置下通过域 ...

  3. Python关于File学习过程

    一.首先,认识下文件 文本文件和二进制文件的差异和区别 进行个总结: 计算机内的文件广义上来说,只有二进制文件 狭义上来讲分为两大类:二进制文件和文本文件. 先说数据的产生(即写操作) 文本文件的所有 ...

  4. mongodb 安装配置及简单使用

    步骤一: 下载网址:https://www.mongodb.com/download-center/community 根据自己的环境下载 步骤二: 安装过程只需要默认即可,需要注意的是连接工具“mo ...

  5. Centos7 - mysql 5.5.62 tar.gz 方式安装

    安装准备 Mariadb 去除 由于CentOS7自带的是 Mariadb, 所以先来删除他吧... 1. 查找版本 # rpm -qa|grep mariadb 执行命令后会出现类似 MariaDB ...

  6. manage.py migrate 报错

    第一个提示,setting里面的 STATICFILES_DIRS = (  os.path.join(BASE_DIR,'static')) 第二行的后面加','解决,这样可以被识别是tuple. ...

  7. delphi设置鼠标图形

    //Screen.Cursor := crHourGlass;//忙 //Screen.Cursor := crDefault;//不忙时

  8. java+web文件的上传和下载代码

    一般10M以下的文件上传通过设置Web.Config,再用VS自带的FileUpload控件就可以了,但是如果要上传100M甚至1G的文件就不能这样上传了.我这里分享一下我自己开发的一套大文件上传控件 ...

  9. 连接服务器VNC

    1,启动vnc vncserver 2,提示输入密码 3,Would you like to enter a view-only password (y/n)?  选择n 4,会生成一个端口号 5, ...

  10. Pytorch-拼接与拆分

    引言 本篇介绍tensor的拼接与拆分. 拼接与拆分 cat stack split chunk cat numpy中使用concat,在pytorch中使用更加简写的 cat 完成一个拼接 两个向量 ...