本文将会介绍如何利用Keras来实现模型的保存、读取以及加载。

  本文使用的模型为解决IRIS数据集的多分类问题而设计的深度神经网络(DNN)模型,模型的结构示意图如下:

具体的模型参数可以参考文章:Keras入门(一)搭建深度神经网络(DNN)解决多分类问题

模型保存

  Keras使用HDF5文件系统来保存模型。模型保存的方法很容易,只需要使用save()方法即可。

  以Keras入门(一)搭建深度神经网络(DNN)解决多分类问题中的DNN模型为例,整个模型的变量为model,我们设置模型共训练10次,在原先的代码中加入Python代码即可保存模型:

  1. # save model
  2. print("Saving model to disk \n")
  3. mp = "E://logs/iris_model.h5"
  4. model.save(mp)

保存的模型文件(iris_model.h5)如下:

模型读取

  保存后的iris_model.h5以HDF5文件系统的形式储存,在我们使用Python读取h5文件里面的数据之前,我们先用HDF5的可视化工具HDFView来查看里面的数据:

  我们感兴趣的是这个模型中的各个神经层之间的连接权重及偏重,也就是上图中的红色部分,model_weights里面包含了各个神经层之间的连接权重及偏重,分别位于dense_1,dense_2,dense_3中。蓝色部分为dense_3/dense_3/kernel:0的数据,即最后输出层的连接权重矩阵。

  有了对模型参数的直观认识,我们要做的下一步工作就是读取各个神经层之间的连接权重及偏重。我们使用Python的h5py这个模块来这个iris_model.h5这个文件。关于h5py的快速入门指南,可以参考文章:h5py快速入门指南

  使用以下Python代码可以读取各个神经层之间的连接权重及偏重数据:

  1. import h5py
  2. # 模型地址
  3. MODEL_PATH = 'E://logs/iris_model.h5'
  4. # 获取每一层的连接权重及偏重
  5. print("读取模型中...")
  6. with h5py.File(MODEL_PATH, 'r') as f:
  7. dense_1 = f['/model_weights/dense_1/dense_1']
  8. dense_1_bias = dense_1['bias:0'][:]
  9. dense_1_kernel = dense_1['kernel:0'][:]
  10. dense_2 = f['/model_weights/dense_2/dense_2']
  11. dense_2_bias = dense_2['bias:0'][:]
  12. dense_2_kernel = dense_2['kernel:0'][:]
  13. dense_3 = f['/model_weights/dense_3/dense_3']
  14. dense_3_bias = dense_3['bias:0'][:]
  15. dense_3_kernel = dense_3['kernel:0'][:]
  16. print("第一层的连接权重矩阵:\n%s\n"%dense_1_kernel)
  17. print("第一层的连接偏重矩阵:\n%s\n"%dense_1_bias)
  18. print("第二层的连接权重矩阵:\n%s\n"%dense_2_kernel)
  19. print("第二层的连接偏重矩阵:\n%s\n"%dense_2_bias)
  20. print("第三层的连接权重矩阵:\n%s\n"%dense_3_kernel)
  21. print("第三层的连接偏重矩阵:\n%s\n"%dense_3_bias)

输出的结果如下:

  1. 读取模型中...
  2. 第一层的连接权重矩阵:
  3. [[ 0.04141677 0.03080632 -0.02768146 0.14334357 0.06242227]
  4. [-0.41209617 -0.77948487 0.5648218 -0.699587 -0.19246106]
  5. [ 0.6856315 0.28241938 -0.91930366 -0.07989818 0.47165248]
  6. [ 0.8655262 0.72175753 0.36529952 -0.53172135 0.26573092]]
  7. 第一层的连接偏重矩阵:
  8. [-0.16441862 -0.02462054 -0.14060321 0. -0.14293939]
  9. 第二层的连接权重矩阵:
  10. [[ 0.39296603 0.01864707 0.12538083 0.07935872 0.27940807 -0.4565802 ]
  11. [-0.34312084 0.6446907 -0.92546445 -0.00538039 0.95466876 -0.32819661]
  12. [-0.7593299 -0.07227057 0.20751365 0.40547106 0.35726753 0.8884158 ]
  13. [-0.48096 0.11294878 -0.29462305 -0.410536 -0.23620337 -0.72703975]
  14. [ 0.7666149 -0.41720924 0.29576775 -0.6328017 0.43118536 0.6589351 ]]
  15. 第二层的连接偏重矩阵:
  16. [-0.1899569 0. -0.09710662 -0.12964155 -0.26443157 0.6050924 ]
  17. 第三层的连接权重矩阵:
  18. [[-0.44450542 0.09977101 0.12196152]
  19. [ 0.14334357 0.18546402 -0.23861367]
  20. [-0.7284191 0.7859063 -0.878823 ]
  21. [ 0.0876545 0.51531947 0.09671918]
  22. [-0.7964963 -0.16435687 0.49531657]
  23. [ 0.8645698 0.4439873 0.24599855]]
  24. 第三层的连接偏重矩阵:
  25. [ 0.39192322 -0.1266532 -0.29631865]

值得注意的是,我们得到的这些矩阵的数据类型都是numpy.ndarray。

  OK,既然我们已经得到了各个神经层之间的连接权重及偏重的数据,那我们能做什么呢?当然是去做一些有趣的事啦,那就是用我们自己的方法来实现新数据的预测向量(softmax函数作用后的向量)。so, really?

  新的输入向量为[6.1, 3.1, 5.1, 1.1],使用以下Python代码即可输出新数据的预测向量:

  1. import h5py
  2. import numpy as np
  3. # 模型地址
  4. MODEL_PATH = 'E://logs/iris_model.h5'
  5. # 获取每一层的连接权重及偏重
  6. print("读取模型中...")
  7. with h5py.File(MODEL_PATH, 'r') as f:
  8. dense_1 = f['/model_weights/dense_1/dense_1']
  9. dense_1_bias = dense_1['bias:0'][:]
  10. dense_1_kernel = dense_1['kernel:0'][:]
  11. dense_2 = f['/model_weights/dense_2/dense_2']
  12. dense_2_bias = dense_2['bias:0'][:]
  13. dense_2_kernel = dense_2['kernel:0'][:]
  14. dense_3 = f['/model_weights/dense_3/dense_3']
  15. dense_3_bias = dense_3['bias:0'][:]
  16. dense_3_kernel = dense_3['kernel:0'][:]
  17. # 模拟每个神经层的计算,得到该层的输出
  18. def layer_output(input, kernel, bias):
  19. return np.dot(input, kernel) + bias
  20. # 实现ReLU函数
  21. relu = np.vectorize(lambda x: x if x >=0 else 0)
  22. # 实现softmax函数
  23. def softmax_func(arr):
  24. exp_arr = np.exp(arr)
  25. arr_sum = np.sum(exp_arr)
  26. softmax_arr = exp_arr/arr_sum
  27. return softmax_arr
  28. # 输入向量
  29. unkown = np.array([[6.1, 3.1, 5.1, 1.1]], dtype=np.float32)
  30. # 第一层的输出
  31. print("模型计算中...")
  32. output_1 = layer_output(unkown, dense_1_kernel, dense_1_bias)
  33. output_1 = relu(output_1)
  34. # 第二层的输出
  35. output_2 = layer_output(output_1, dense_2_kernel, dense_2_bias)
  36. output_2 = relu(output_2)
  37. # 第三层的输出
  38. output_3 = layer_output(output_2, dense_3_kernel, dense_3_bias)
  39. output_3 = softmax_func(output_3)
  40. # 最终的输出的softmax值
  41. np.set_printoptions(precision=4)
  42. print("最终的预测值向量为: %s"%output_3)

其输出的结果如下:

  1. 读取模型中...
  2. 模型计算中...
  3. 最终的预测值向量为: [[0.0242 0.6763 0.2995]]

  额,这个输出的预测值向量会是我们的DNN模型的预测值向量吗?这时候,我们就需要回过头来看看Keras入门(一)搭建深度神经网络(DNN)解决多分类问题中的代码了,注意,为了保证数值的可比较性,笔者已经将DNN模型的训练次数改为10次了。让我们来看看原来代码的输出结果吧:

  1. Using model to predict species for features:
  2. [[6.1 3.1 5.1 1.1]]
  3. Predicted softmax vector is:
  4. [[0.0242 0.6763 0.2995]]
  5. Predicted species is:
  6. Iris-versicolor

Yes,两者的预测值向量完全一致!因此,我们用自己的方法也实现了这个DNN模型的预测功能,棒!

模型加载

  当然,在实际的使用中,我们不需要再用自己的方法来实现模型的预测功能,只需使用Keras给我们提供好的模型导入功能(keras.models.load_model())即可。使用以下Python代码即可加载模型

  1. # 模型的加载及使用
  2. from keras.models import load_model
  3. print("Using loaded model to predict...")
  4. load_model = load_model("E://logs/iris_model.h5")
  5. np.set_printoptions(precision=4)
  6. unknown = np.array([[6.1, 3.1, 5.1, 1.1]], dtype=np.float32)
  7. predicted = load_model.predict(unknown)
  8. print("Using model to predict species for features: ")
  9. print(unknown)
  10. print("\nPredicted softmax vector is: ")
  11. print(predicted)
  12. species_dict = {v: k for k, v in Class_dict.items()}
  13. print("\nPredicted species is: ")
  14. print(species_dict[np.argmax(predicted)])

输出结果如下:

  1. Using loaded model to predict...
  2. Using model to predict species for features:
  3. [[6.1 3.1 5.1 1.1]]
  4. Predicted softmax vector is:
  5. [[0.0242 0.6763 0.2995]]
  6. Predicted species is:
  7. Iris-versicolor

总结

  本文主要介绍如何利用Keras来实现模型的保存、读取以及加载。

  本文将不再给出完整的Python代码,如需完整的代码,请参考Github地址:https://github.com/percent4/Keras_4_multiclass.

注意:本人现已开通微信公众号: Python爬虫与算法(微信号为:easy_web_scrape), 欢迎大家关注哦~~

Keras入门(二)模型的保存、读取及加载的更多相关文章

  1. keras模型的保存与重新加载

    # 模型保存JSON文件 model_json = model.to_json() with open('model.json', 'w') as file: file.write(model_jso ...

  2. pyspider 示例二 升级完整版绕过懒加载,直接读取图片

    pyspider 示例二 升级完整版绕过懒加载,直接读取图片,见[升级写法处] #!/usr/bin/env python # -*- encoding: utf-8 -*- # Created on ...

  3. esri-leaflet入门教程(5)- 动态要素加载

    esri-leaflet入门教程(5)- 动态要素加载 by 李远祥 在上一章节中已经说明了esr-leaflet是如何加载ArcGIS Server提供的各种服务,这些都是服务本身来决定的,API脚 ...

  4. Linux内核启动代码分析二之开发板相关驱动程序加载分析

    Linux内核启动代码分析二之开发板相关驱动程序加载分析 1 从linux开始启动的函数start_kernel开始分析,该函数位于linux-2.6.22/init/main.c  start_ke ...

  5. DB数据源之SpringBoot+MyBatis踏坑过程(二)手工配置数据源与加载Mapper.xml扫描

    DB数据源之SpringBoot+MyBatis踏坑过程(二)手工配置数据源与加载Mapper.xml扫描 liuyuhang原创,未经允许进制转载  吐槽之后应该有所改了,该方式可以作为一种过渡方式 ...

  6. spark SQL (四)数据源 Data Source----Parquet 文件的读取与加载

    spark SQL Parquet 文件的读取与加载 是由许多其他数据处理系统支持的柱状格式.Spark SQL支持阅读和编写自动保留原始数据模式的Parquet文件.在编写Parquet文件时,出于 ...

  7. 基于 Koa平台Node.js开发的KoaHub.js的控制器,模型,帮助方法自动加载

    koahub-loader koahub-loader是基于 Koa平台Node.js开发的KoaHub.js的koahub-loader控制器,模型,帮助方法自动加载 koahub loader I ...

  8. Unity3d-WWW实现图片资源显示以及保存和本地加载

    本文固定连接:http://blog.csdn.net/u013108312/article/details/52712844 WWW实现图片资源显示以及保存和本地加载 using UnityEngi ...

  9. tensorflow 模型保存后的加载路径问题

    import tensorflow as tf #保存模型 saver = tf.train.Saver() saver.save(sess, "e://code//python//test ...

随机推荐

  1. pycharm License server激活

    2018-11-15 pycharm License server激活有效:https://idea.ouyanglol.com/

  2. Appium日志乱码终结指北

    缘起 最近Android,IOS自动化多开群控都搞好了,但是Appium中的log 显示中文乱码问题像个苍蝇一样,看着感觉特别难受,挥之不去,抚之不平.论坛搜索了一下,很多帖子都反映过这个问题,但是都 ...

  3. 【笔记】css基于box的一行时垂直方向居中,多行平均居中,多出部分还省略号代替

    题目很长,其实他就是这样的: 看标题,一行的时候是这样的,在行中间 标题文字多的时候是这样的,变成2行,超出部分用省略号: 但是为了更好的兼容性,没有使用flex,使用的是box布局. 核心代码就是这 ...

  4. Educational Codeforces Round 25

    A 题意:给你一个01的字符串,0是个分界点,0把这个字符串分成(0的个数+1)个部分,分别求出这几部分1的个数.例如110011101 输出2031,100输出100,1001输出101 代码: # ...

  5. VSCode插件开发全攻略(一)概览

    文章索引 VSCode插件开发全攻略(一)概览 VSCode插件开发全攻略(二)HelloWord VSCode插件开发全攻略(三)package.json详解 VSCode插件开发全攻略(四)命令. ...

  6. Android 视频播放器 (三):使用NBPlayer播放直播视频

    一.前言 在 Android 音视频开发学习思路 中,我们不断的学习和了解音视频相关的知识,随着知识点不断的学习,我们现在应该做的事情,就是将知识点不断的串联起来.这样才能得到更深层次的领悟.通过整理 ...

  7. Kali学习笔记14:SMB扫描、SMTP扫描

    SMB(Server Message Block)协议,服务消息块协议. 最开始是用于微软的一种消息传输协议,因为颇受欢迎,现在已经成为跨平台的一种消息传输协议. 同时也是微软历史上出现安全问题最多的 ...

  8. Kali学习笔记5:被动信息收集工具集

    1.Shodan SHODAN搜索引擎不像百度谷歌等,它们爬取的是网页,而SHODAN搜索的是设备. 物联网使用过程中,通常容易出现安全问题,如果被黑客用SHODAN搜索到,后果不堪设想. 网站:ht ...

  9. 浅谈static关键字的四种用法

    1.修饰成员变量 在一个person类中,一个成员变量例如 String name,当new2个person()对象时候,这2个对象在堆的位置是不同的,给name赋值张三.李四,这两个对象的name是 ...

  10. vim常用命令行备忘总结

    一 窗口切换 1 :sp    水平切换当前窗口 2 :vsp 垂直切换当前窗口 3 :clo 关闭活动窗口 4 : on 只保留活动窗口 5 : ctrl + w  在窗口间循环切换  ctrl + ...