当我们把训练好的tensorflow训练图拿来进行预测时,会有多个训练时生成的节点,这些节点是不必要的,我们需要在预测的时候进行删除。

下面以bert的图为例,进行优化

    def optimize_graph(self, checkpoint_file, model_config):
import json
tf = self.import_tf()
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True) init_checkpoint = checkpoint_file with tf.gfile.GFile(model_config, 'r') as f:
bert_config = modeling.BertConfig.from_dict(json.load(f)) input_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_ids')
input_mask = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_mask')
input_type_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_type_ids') import contextlib
jit_scope = contextlib.suppress with jit_scope():
input_tensors = [input_ids, input_mask, input_type_ids]
model = modeling.BertModel(
config=bert_config,
is_training=False,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=input_type_ids,
use_one_hot_embeddings=False) tvars = tf.trainable_variables() (assignment_map, initialized_variable_names
) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) # get output tensor
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
reader = tf.train.NewCheckpointReader(init_checkpoint)
output_weights = reader.get_tensor('output_weights')
output_bias = reader.get_tensor('output_bias')
output_layers = model.get_pooled_output()
pooled = tf.nn.softmax(tf.nn.bias_add(tf.matmul(output_layers, output_weights, transpose_b=True),
output_bias))
pooled = tf.identity(pooled, 'final_encodes') output_tensors = [pooled]
tmp_g = tf.get_default_graph().as_graph_def() # write graph to file
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
tmp_g = tf.graph_util.convert_variables_to_constants(sess, tmp_g, [n.name[:-2] for n in output_tensors])
dtypes = [n.dtype for n in input_tensors]
tmp_g = optimize_for_inference(
tmp_g,
[n.name[:-2] for n in input_tensors],
[n.name[:-2] for n in output_tensors],
[dtype.as_datatype_enum for dtype in dtypes],
False) import tempfile
tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=r'optimize').name
with tf.gfile.GFile(tmp_file, 'wb') as f:
f.write(tmp_g.SerializeToString()) return tmp_file

返回一个gfile类型的文件,我们可以像原来导入模型文件时,恢复图,不过这个图是优化过的。

tensorflow 优化图的更多相关文章

  1. TensorFlow的图切割模块——Graph Partitioner

    背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 在经过TensorFlow的Placer策略模块调整之后,下一步就是根据Pla ...

  2. 现代英特尔® 架构上的 TensorFlow* 优化——正如去年参加Intel AI会议一样,Intel自己提供了对接自己AI CPU优化版本的Tensorflow,下载链接见后,同时可以基于谷歌官方的tf版本直接编译生成安装包

    现代英特尔® 架构上的 TensorFlow* 优化 转自:https://software.intel.com/zh-cn/articles/tensorflow-optimizations-on- ...

  3. TensorFlow从0到1之TensorFlow优化器(13)

    高中数学学过,函数在一阶导数为零的地方达到其最大值和最小值.梯度下降算法基于相同的原理,即调整系数(权重和偏置)使损失函数的梯度下降. 在回归中,使用梯度下降来优化损失函数并获得系数.本节将介绍如何使 ...

  4. TensorFlow优化器及用法

    TensorFlow优化器及用法 函数在一阶导数为零的地方达到其最大值和最小值.梯度下降算法基于相同的原理,即调整系数(权重和偏置)使损失函数的梯度下降. 在回归中,使用梯度下降来优化损失函数并获得系 ...

  5. TensorFlow优化器浅析

    本文基于tensorflow-v1.15分支,简单分析下TensorFlow中的优化器. optimizer = tf.train.GradientDescentOptimizer(learning_ ...

  6. tensorflow优化器-【老鱼学tensorflow】

    tensorflow中的优化器主要是各种求解方程的方法,我们知道求解非线性方程有各种方法,比如二分法.牛顿法.割线法等,类似的,tensorflow中的优化器也只是在求解方程时的各种方法. 比较常用的 ...

  7. tensorflow:图(Graph)的核心数据结构与通用函数(Utility function)

    Tensorflow一些常用基本概念与函数(2) 1. 图(Graph)的核心数据结构 tf.Graph.__init__:建立一个空图: tf.Graph.as_default():一个将某图设置为 ...

  8. Tensorflow 优化学习

    # coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data pr ...

  9. tensorflow 框架图

随机推荐

  1. weat!!团队

    摘要: 团队名称:weat!! 团队成员:刘波 崔和杰 简介: 刘波:性别男,爱好:动漫,徒步旅行.在组内负责程序编写这一部分. 优点:认真负责,不懂就会去问. 崔和杰:性别男,爱好:篮球.在组内负责 ...

  2. Android Studio自定义组合控件

    在Android的开发中,为了能够服用代码,会把有一定共有特点的控件组合在一起定义成一个自定义组合控件. 本文就详细讲述这一过程.虽然这样的View的组合有一个粒度的问题.粒度太大了无法复用,粒度太小 ...

  3. mac终端的命令都失效的解决方法

    step1. 在terminal里面输入: export PATH="/usr/bin:/bin:/usr/sbin:/sbin:/usr/local/bin:/usr/X11/bin&qu ...

  4. java基础-day27

    第04天 java基础加强 今日内容介绍 u Xml的综合案例 u 注解 u 类的加载 u 动态代理 第1章   注解 1.1  注解概述 l  什么是注解:Annotation注解,是一种代码级别的 ...

  5. kafka参数

    转载地址http://debugo.com/kafka-params/ ############################# System ########################### ...

  6. C++ 中的异常机制分析

    C++异常机制概述 异常处理是C++的一项语言机制,用于在程序中处理异常事件.异常事件在C++中表示为异常对象.异常事件发生时,程序使用throw关键字抛出异常表达式,抛出点称为异常出现点,由操作系统 ...

  7. openresty + lua 4、openresty kafka

    kafka 官网: https://kafka.apache.org/quickstart zookeeper 官网:https://zookeeper.apache.org/ kafka 运行需要 ...

  8. JSON知识介绍

    JSON资料整理   目录 1.什么是json 2.json语法规则 3.json基础结构 4.json基础示例 5.JSON和XML比较 6. .NET操作JSON 原始方式 通用方式 内置方式 契 ...

  9. cxgrid动态多表头

    function TForm15.CreateBand(View: TcxGridDBBandedTableView;  BandCaption, ParentBandCaption: String) ...

  10. Android-AndroidStudio加载工程方式-gradle文件夹

    例如:在其他地方,其他工作人员哪里的OpenGateDemo工程是OK的, 然后Copy到李四的电脑上运行是报错,其实所有的错误都和gradle有关: 第一步,李四电脑运行OpenGateDemo工程 ...