tensorflow 优化图
当我们把训练好的tensorflow训练图拿来进行预测时,会有多个训练时生成的节点,这些节点是不必要的,我们需要在预测的时候进行删除。
下面以bert的图为例,进行优化
def optimize_graph(self, checkpoint_file, model_config):
import json
tf = self.import_tf()
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True) init_checkpoint = checkpoint_file with tf.gfile.GFile(model_config, 'r') as f:
bert_config = modeling.BertConfig.from_dict(json.load(f)) input_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_ids')
input_mask = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_mask')
input_type_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_type_ids') import contextlib
jit_scope = contextlib.suppress with jit_scope():
input_tensors = [input_ids, input_mask, input_type_ids]
model = modeling.BertModel(
config=bert_config,
is_training=False,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=input_type_ids,
use_one_hot_embeddings=False) tvars = tf.trainable_variables() (assignment_map, initialized_variable_names
) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) # get output tensor
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
reader = tf.train.NewCheckpointReader(init_checkpoint)
output_weights = reader.get_tensor('output_weights')
output_bias = reader.get_tensor('output_bias')
output_layers = model.get_pooled_output()
pooled = tf.nn.softmax(tf.nn.bias_add(tf.matmul(output_layers, output_weights, transpose_b=True),
output_bias))
pooled = tf.identity(pooled, 'final_encodes') output_tensors = [pooled]
tmp_g = tf.get_default_graph().as_graph_def() # write graph to file
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
tmp_g = tf.graph_util.convert_variables_to_constants(sess, tmp_g, [n.name[:-2] for n in output_tensors])
dtypes = [n.dtype for n in input_tensors]
tmp_g = optimize_for_inference(
tmp_g,
[n.name[:-2] for n in input_tensors],
[n.name[:-2] for n in output_tensors],
[dtype.as_datatype_enum for dtype in dtypes],
False) import tempfile
tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=r'optimize').name
with tf.gfile.GFile(tmp_file, 'wb') as f:
f.write(tmp_g.SerializeToString()) return tmp_file
返回一个gfile类型的文件,我们可以像原来导入模型文件时,恢复图,不过这个图是优化过的。
tensorflow 优化图的更多相关文章
- TensorFlow的图切割模块——Graph Partitioner
背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 在经过TensorFlow的Placer策略模块调整之后,下一步就是根据Pla ...
- 现代英特尔® 架构上的 TensorFlow* 优化——正如去年参加Intel AI会议一样,Intel自己提供了对接自己AI CPU优化版本的Tensorflow,下载链接见后,同时可以基于谷歌官方的tf版本直接编译生成安装包
现代英特尔® 架构上的 TensorFlow* 优化 转自:https://software.intel.com/zh-cn/articles/tensorflow-optimizations-on- ...
- TensorFlow从0到1之TensorFlow优化器(13)
高中数学学过,函数在一阶导数为零的地方达到其最大值和最小值.梯度下降算法基于相同的原理,即调整系数(权重和偏置)使损失函数的梯度下降. 在回归中,使用梯度下降来优化损失函数并获得系数.本节将介绍如何使 ...
- TensorFlow优化器及用法
TensorFlow优化器及用法 函数在一阶导数为零的地方达到其最大值和最小值.梯度下降算法基于相同的原理,即调整系数(权重和偏置)使损失函数的梯度下降. 在回归中,使用梯度下降来优化损失函数并获得系 ...
- TensorFlow优化器浅析
本文基于tensorflow-v1.15分支,简单分析下TensorFlow中的优化器. optimizer = tf.train.GradientDescentOptimizer(learning_ ...
- tensorflow优化器-【老鱼学tensorflow】
tensorflow中的优化器主要是各种求解方程的方法,我们知道求解非线性方程有各种方法,比如二分法.牛顿法.割线法等,类似的,tensorflow中的优化器也只是在求解方程时的各种方法. 比较常用的 ...
- tensorflow:图(Graph)的核心数据结构与通用函数(Utility function)
Tensorflow一些常用基本概念与函数(2) 1. 图(Graph)的核心数据结构 tf.Graph.__init__:建立一个空图: tf.Graph.as_default():一个将某图设置为 ...
- Tensorflow 优化学习
# coding: utf-8 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data pr ...
- tensorflow 框架图
随机推荐
- uwsgi_read_timeout超时处理
最近发现一服务器一个奇怪的现象: Django的视图函数在浏览器一个请求的情况下,竟然做了两个请求的函数处理.不可思议,找了几天也不知道为什么, 只发现只要用uwsgi_read_timeout之后, ...
- 堆操作,malloc
PS:堆空间缺省值都是cd,栈空间缺省值都是cc 内存有四区:栈.全局(静态).常量.除此以外的空间暂时不能随意使用,但是通过malloc函数申请就可以使用了. 利用malloc申请一个int变量,注 ...
- 转载:Java项目读取配置文件时,FileNotFoundException 系统找不到指定的文件,System.getProperty("user.dir")的理解
唉,读取个文件,也就是在项目里面去获得配置文件的目录,然后,变成文件,有事没事,总是出个 FileNotFoundException 系统找不到指定的文件,气死人啦. 还有就是:System.get ...
- jQuery插件初级练习3
<!DOCTYPE html><html> <head> <meta charset="UTF-8"> <title>& ...
- Android-Java-死锁
死锁:程序不往下执行了,程序又没有结束,就一直卡在哪里: 在使用synchronized的时候要避免死锁,synchronized嵌套就可能会引发死锁,需要严格的检查代码,排除死锁发生的可能: 特意演 ...
- ASP.NET自定义错误页并返回正确的500、404状态码
在项目中,我们常常需要自定义错误页面,但往往返回的状态码都变成了200,对SEO很不友好.我尝试过在百度上寻找解决方案,但找到的资料中说的方法都试过了,发现都是无法返回正确的状态码的. 最后,只好自已 ...
- java多线程面试题整理及答案(2018年)
1) 什么是线程? 线程是操作系统能够进行运算调度的最小单位,它被包含在进程之中,是进程中的实际运作单位.程序员可以通过它进行多处理器编程,你可以使用多线程对 运算密集型任务提速.比如,如果一个线程完 ...
- js中两种for循环的使用
针对两种for循环的使用 1. for in循环的使用环境 可用在字符串.数组.对象中, 需注意:其中遍历对象得到的是每个key 的value值 2. for 变量递加的方式 ...
- PICE(2):JDBCStreaming - gRPC-JDBC Service
在一个akka-cluster环境里,从数据调用的角度上,JDBC数据库与集群中其它节点是脱离的.这是因为JDBC数据库不是分布式的,不具备节点位置透明化特性.所以,JDBC数据库服务器必须通过服务方 ...
- Swift 里 Set(二)概览
类图  Set 是一个结构体,持有另一个结构体_Variant. 最终所有的元素存储在一个叫做__RawSetStorage的类里. 内存布局  结构体分配在栈上,和__RawSetStorage ...