1. 准备预训练好的模型

  • TensorFlow 预训练好的模型被保存为以下四个文件

  • data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如下所示,为简单起见只保留了一个模型。
  1. model_checkpoint_path: "/home/senius/python/c_python/test/model-40"
  2. all_model_checkpoint_paths: "/home/senius/python/c_python/test/model-40"

2. 导入模型图、参数值和相关变量

  1. import tensorflow as tf
  2. import numpy as np
  3. sess = tf.Session()
  4. X = None # input
  5. yhat = None # output
  6. def load_model():
  7. """
  8. Loading the pre-trained model and parameters.
  9. """
  10. global X, yhat
  11. modelpath = r'/home/senius/python/c_python/test/'
  12. saver = tf.train.import_meta_graph(modelpath + 'model-40.meta')
  13. saver.restore(sess, tf.train.latest_checkpoint(modelpath))
  14. graph = tf.get_default_graph()
  15. X = graph.get_tensor_by_name("X:0")
  16. yhat = graph.get_tensor_by_name("tanh:0")
  17. print('Successfully load the pre-trained model!')
  • 通过 saver.restore 我们可以得到预训练的所有参数值,然后再通过 graph.get_tensor_by_name 得到模型的输入张量和我们想要的输出张量。

3. 运行前向传播过程得到预测值

  1. def predict(txtdata):
  2. """
  3. Convert data to Numpy array which has a shape of (-1, 41, 41, 41 3).
  4. Test a single example.
  5. Arg:
  6. txtdata: Array in C.
  7. Returns:
  8. Three coordinates of a face normal.
  9. """
  10. global X, yhat
  11. data = np.array(txtdata)
  12. data = data.reshape(-1, 41, 41, 41, 3)
  13. output = sess.run(yhat, feed_dict={X: data}) # (-1, 3)
  14. output = output.reshape(-1, 1)
  15. ret = output.tolist()
  16. return ret
  • 通过 feed_dict 喂入测试数据,然后 run 输出的张量我们就可以得到预测值。

4. 测试

  1. load_model()
  2. testdata = np.fromfile('/home/senius/python/c_python/test/04t30t00.npy', dtype=np.float32)
  3. testdata = testdata.reshape(-1, 41, 41, 41, 3) # (150, 41, 41, 41, 3)
  4. testdata = testdata[0:2, ...] # the first two examples
  5. txtdata = testdata.tolist()
  6. output = predict(txtdata)
  7. print(output)
  8. # [[-0.13345889747142792], [0.5858198404312134], [-0.7211828231811523],
  9. # [-0.03778800368309021], [0.9978875517845154], [0.06522832065820694]]
  • 本例输入是一个三维网格模型处理后的 [41, 41, 41, 3] 的数据,输出一个表面法向量坐标 (x, y, z)。

获取更多精彩,请关注「seniusen」!

TensorFlow 调用预训练好的模型—— Python 实现的更多相关文章

  1. tensorflow 使用预训练好的模型的一部分参数

    vars = tf.global_variables() net_var = [var for var in vars if 'bi-lstm_secondLayer' not in var.name ...

  2. 学习TensorFlow,调用预训练好的网络(Alex, VGG, ResNet etc)

    视觉问题引入深度神经网络后,针对端对端的训练和预测网络,可以看是特征的表达和任务的决策问题(分类,回归等).当我们自己的训练数据量过小时,往往借助牛人已经预训练好的网络进行特征的提取,然后在后面加上自 ...

  3. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 间接调用 Python 实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过 ...

  4. 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 直接调用 C++ 接口实现

    现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过直 ...

  5. TensorFlow 同时调用多个预训练好的模型

    在某些任务中,我们需要针对不同的情况训练多个不同的神经网络模型,这时候,在测试阶段,我们就需要调用多个预训练好的模型分别来进行预测. 调用单个预训练好的模型请点击此处 弄明白了如何调用单个模型,其实调 ...

  6. 【猫狗数据集】使用预训练的resnet18模型

    数据集下载地址: 链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw提取码:2xq4 创建数据集:https://www.cnblogs.com/xi ...

  7. ubuntu16.04 使用tensorflow object detection训练自己的模型

    一.构建自己的数据集 1.格式必须为jpg.jpeg或png. 2.在models/research/object_detection文件夹下创建images文件夹,在images文件夹下创建trai ...

  8. 深度学习tensorflow实战笔记 用预训练好的VGG-16模型提取图像特征

    1.首先就要下载模型结构 首先要做的就是下载训练好的模型结构和预训练好的模型,结构地址是:点击打开链接 模型结构如下: 文件test_vgg16.py可以用于提取特征.其中vgg16.npy是需要单独 ...

  9. Tensorflow加载预训练模型和保存模型(ckpt文件)以及迁移学习finetuning

    转载自:https://blog.csdn.net/huachao1001/article/details/78501928 使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我 ...

随机推荐

  1. Mybatis自动生成的BO对象继承公共父类(BO中过滤掉公共属性)

    使用mybatis的代码生成工具:mybatis-generator,如果自动生成的BO都有公共的属性,则可以指定这些BO继承父类(父类中定义公共属性) 1.定义父类 注意:属性public,不要使用 ...

  2. Vue教程:简介(一)

    前言 用了这么久的vue了,但是一直没有时间写个系列文章,现在抽一定时间总结下vue的知识点. 首先,Vue 不支持 IE8 及以下版本,因为 Vue 使用了 IE8 无法模拟的 ECMAScript ...

  3. bitmap和drawable的相互转化以及setImageResource(),setImageDrawable(),setImageBitmap()

    从本地获取drawable图片:getResources().getDrawable(R.drawable.**) 获取bitmap:Bitmap b=BitmapFactory().decodeRe ...

  4. excel导入到java/导出到excel

    package com.test.order.config; import com.test.order.domain.HavalDO; import org.apache.poi.ss.usermo ...

  5. 【学时总结&模板时间】◆学时·10 & 模板·3◆ AC自动机

    ◇学时·10 & 模板·3◇ AC自动机 跟着高中上课……讲AC自动机的扩展运用.然而连KMP.trie字典树都不怎么会用的我一脸懵逼<(_ _)> 花一上午自学了一下AC自动机 ...

  6. 使用第三方工具连接docker数据库

    一.背景 ​ 为了把测试环境迁移至docker上,我在centos7上安装了docker,具体安装方法可参考<CentOS7下安装docker>本文不再论述.有些同学可能会有疑问,为什么要 ...

  7. 关于python的GIL

    转自依云在知乎上的回答,链接为https://www.zhihu.com/question/27245271/answer/462975593 侵删. python的多线程,其实不是真的多线程,它会通 ...

  8. React中的全选反选问题

    全选反选问题 1.在state里维护一个数组,例如showArr:[] 2.绑定点击事件的时候将当前这个当选按钮的index加进来 <span className='arrow' onClick ...

  9. Mysqldump自定义导出n条记录

    很多时候DBA需要导出部分记录至开发.测试环境,因数据量需求较小,如果原库的记录多,且表数量也多,在用mysqldump命令导出时可以添加一个where参数,自定义导出n条记录,而不必全量导出. 示例 ...

  10. CacheManager源码分析

    计算rdd的某个分区是从RDD的iterator()方法开始的,我们从这个方法进入 然后我们进入getOrCompute()方法中看看是如何进行读取数据或计算的 getOrElseUpdate()方方 ...