在某些任务中,我们需要针对不同的情况训练多个不同的神经网络模型,这时候,在测试阶段,我们就需要调用多个预训练好的模型分别来进行预测。

弄明白了如何调用单个模型,其实调用多个模型也就顺理成章。我们只需要建立多个图,然后每个图导入一个模型,再针对每个图创建一个会话,分别进行预测即可。

  1. import tensorflow as tf
  2. import numpy as np
  3. # 建立两个 graph
  4. g1 = tf.Graph()
  5. g2 = tf.Graph()
  6. # 为每个 graph 建创建一个 session
  7. sess1 = tf.Session(graph=g1)
  8. sess2 = tf.Session(graph=g2)
  9. X_1 = None
  10. tst_1 = None
  11. yhat_1 = None
  12. X_2 = None
  13. tst_2 = None
  14. yhat_2 = None
  15. def load_model(sess):
  16. """
  17. Loading the pre-trained model and parameters.
  18. """
  19. global X_1, tst_1, yhat_1
  20. with sess1.as_default():
  21. with sess1.graph.as_default():
  22. modelpath = r'F:/resnet/model/new0.25-0.35/'
  23. saver = tf.train.import_meta_graph(modelpath + 'model-10.meta')
  24. saver.restore(sess1, tf.train.latest_checkpoint(modelpath))
  25. graph = tf.get_default_graph()
  26. X_1 = graph.get_tensor_by_name("X:0")
  27. tst_1 = graph.get_tensor_by_name("tst:0")
  28. yhat_1 = graph.get_tensor_by_name("tanh:0")
  29. print('Successfully load the model_1!')
  30. def load_model_2():
  31. """
  32. Loading the pre-trained model and parameters.
  33. """
  34. global X_2, tst_2, yhat_2
  35. with sess2.as_default():
  36. with sess2.graph.as_default():
  37. modelpath = r'F:/resnet/model/new0.25-0.352/'
  38. saver = tf.train.import_meta_graph(modelpath + 'model-10.meta')
  39. saver.restore(sess2, tf.train.latest_checkpoint(modelpath))
  40. graph = tf.get_default_graph()
  41. X_2 = graph.get_tensor_by_name("X:0")
  42. tst_2 = graph.get_tensor_by_name("tst:0")
  43. yhat_2 = graph.get_tensor_by_name("tanh:0")
  44. print('Successfully load the model_2!')
  45. def test_1(txtdata):
  46. """
  47. Convert data to Numpy array which has a shape of (-1, 41, 41, 41, 3).
  48. Test a single axample.
  49. Arg:
  50. txtdata: Array in C.
  51. Returns:
  52. The normal of a face.
  53. """
  54. global X_1, tst_1, yhat_1
  55. data = np.array(txtdata)
  56. data = data.reshape(-1, 41, 41, 41, 3)
  57. output = sess1.run(yhat_1, feed_dict={X_1: data, tst_1: True}) # (100, 3)
  58. output = output.reshape(-1, 1)
  59. ret = output.tolist()
  60. return ret
  61. def test_2(txtdata):
  62. """
  63. Convert data to Numpy array which has a shape of (-1, 41, 41, 41, 3).
  64. Test a single axample.
  65. Arg:
  66. txtdata: Array in C.
  67. Returns:
  68. The normal of a face.
  69. """
  70. global X_2, tst_2, yhat_2
  71. data = np.array(txtdata)
  72. data = data.reshape(-1, 41, 41, 41, 3)
  73. output = sess2.run(yhat_2, feed_dict={X_2: data, tst_2: True}) # (100, 3)
  74. output = output.reshape(-1, 1)
  75. ret = output.tolist()
  76. return ret

最后,本程序只是为了说明问题,抛砖引玉,代码有很多冗余之处,不要模仿!

获取更多精彩,请关注「seniusen」!

TensorFlow 同时调用多个预训练好的模型的更多相关文章

  1. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 直接调用 C++ 接口实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过直 ...

  2. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 间接调用 Python 实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过 ...

  3. TensorFlow 调用预训练好的模型—— Python 实现

    1. 准备预训练好的模型 TensorFlow 预训练好的模型被保存为以下四个文件 data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如 ...

  4. 【猫狗数据集】使用预训练的resnet18模型

    数据集下载地址: 链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw提取码:2xq4 创建数据集:https://www.cnblogs.com/xi ...

  5. tensorflow如何正确加载预训练词向量

    使用预训练词向量和随机初始化词向量的差异还是挺大的,现在说一说我使用预训练词向量的流程. 一.构建本语料的词汇表,作为我的基础词汇 二.遍历该词汇表,从预训练词向量中提取出该词对应的词向量 三.初始化 ...

  6. tensorflow 使用预训练好的模型的一部分参数

    vars = tf.global_variables() net_var = [var for var in vars if 'bi-lstm_secondLayer' not in var.name ...

  7. 深度学习tensorflow实战笔记 用预训练好的VGG-16模型提取图像特征

    1.首先就要下载模型结构 首先要做的就是下载训练好的模型结构和预训练好的模型,结构地址是:点击打开链接 模型结构如下: 文件test_vgg16.py可以用于提取特征.其中vgg16.npy是需要单独 ...

  8. 从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史(转载)

    转载 https://zhuanlan.zhihu.com/p/49271699 首发于深度学习前沿笔记 写文章   从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 张 ...

  9. 【译】深度双向Transformer预训练【BERT第一作者分享】

    目录 NLP中的预训练 语境表示 语境表示相关研究 存在的问题 BERT的解决方案 任务一:Masked LM 任务二:预测下一句 BERT 输入表示 模型结构--Transformer编码器 Tra ...

随机推荐

  1. Android学习笔记_3_四种布局

    Android布局是应用界面开发的重要一环,在Android中,共有四种布局方式, 分别是:FrameLayout( 帧布局 ).LinearLayout (线性布局).TableLayout(表格布 ...

  2. MR中简单实现自定义的输入输出格式

    import java.io.DataOutput; import java.io.IOException; import java.util.HashMap; import java.util.Ma ...

  3. vue使用v-for循环,动态修改element-ui的el-switch

    在使用element-ui的el-switch中,因为要用v-for循环,一直没有成功,后来仔细查看文档,发现可以这样写 <el-switch v-for="(item, key) i ...

  4. 史上最简单的 SpringCloud 教程 | 第十四篇: 服务注册(consul)

    转载请标明出处: 原文首发于:https://www.fangzhipeng.com/springcloud/2017/07/12/sc14-consul/ 本文出自方志朋的博客 这篇文章主要介绍 s ...

  5. 使用unity3D生成项目(Easy Movie Texture)运行出现的问题

    运行后,首先报的错需要改  -fno-objc-arc 编译后出现的新的错.   需要将   CustomVideoPlayer.mm _lastFrameTimestamp = _curFrameT ...

  6. JS继续学习记录(一)

    JS继续学习记录(一) 总感觉自己的js code写的还算可以,但是又深知好像只知道一些皮毛,所以打算仔细记录一下js晋级学习过程,日日往复 先记录一下自己目前对js的了解吧(20180828) js ...

  7. Java秒杀系统方案优化 高性能高并发实战(已完成)

    1:商品列表 2:商品详情判断是否可以开始秒杀,未开始不显示秒杀按钮显示倒计时,开始显示秒杀按钮,同时会显示验证码输入框以及验证码图片,当点击秒杀按钮的时候会首先判断验证码是否正确,如果正确会返回一个 ...

  8. ETO的公开赛T5《猎杀蓝色空间号》题解

    这道题别看题面这么长,其实题意很简单 就是让你求从起点开始的最长合法区间 合法的要求有两个:兜圈子和直飞 且这两个条件相互独立 (也就是说兜圈子的末尾不会对下面可能出现的直飞造成影响) 举个例子: 1 ...

  9. 解决 Android sdk content loader 0%

    第一次遇到这种情况,真的很头痛,没办法 ,是问题就要解决,在网上找了一些方法,归纳了下来. 方法一(关闭后重启): 遇到Eclipse右下角一直显示“Android sdk content loade ...

  10. Python学习笔记:第2天while循环 运算符 格式化输出 编码

    目录 1. while循环 continue.break和else语句 2. 格式化输出 3. 运算符 3.1 算数运算 3.2 比较运算符 3.3 赋值运算符 3.4 逻辑运算符 3.5 成员运算符 ...