训练代码:

# coding: utf-8
from __future__ import print_function
from __future__ import division import tensorflow as tf
import numpy as np
import argparse def dense_to_one_hot(input_data, class_num):
data_num = input_data.shape[0]
index_offset = np.arange(data_num) * class_num
labels_one_hot = np.zeros((data_num, class_num))
labels_one_hot.flat[index_offset + input_data.ravel()] = 1
return labels_one_hot def build_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, required=True)
parser.add_argument('--model_dir', type=str, required=True)
args = parser.parse_args()
return args p = build_parser()
origin = np.genfromtxt(p.data_path, delimiter=',') data = origin[:, 0:2]
labels = origin[:, 2] learning_rate = 0.001
training_epochs = 5000
display_step = 1 n_features = 2
n_class = 2
x = tf.placeholder(tf.float32, [None, n_features], "input")
y = tf.placeholder(tf.float32, [None, n_class]) W = tf.Variable(tf.zeros([n_features, n_class]), name="w")
b = tf.Variable(tf.zeros([n_class]), name="b") scores = tf.nn.xw_plus_b(x, W, b, name='scores')
pred_proba = tf.nn.softmax(scores, name="pred_proba") cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) saver = tf.train.Saver()
tf.add_to_collection('pred_proba', pred_proba)
init = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init)
for epoch in range(training_epochs):
result_pred_proba, _, c = sess.run([pred_proba, optimizer, cost],
feed_dict={x: data, y: dense_to_one_hot(labels.astype(int), 2)})
if epoch % 100 == 0:
print(c)
builder = tf.saved_model.builder.SavedModelBuilder(p.model_dir)
inputs = {'input': tf.saved_model.utils.build_tensor_info(x)}
outputs = {'pred_proba': tf.saved_model.utils.build_tensor_info(pred_proba)}
signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')
builder.add_meta_graph_and_variables(sess, ['test_saved_model'], {'test_signature': signature})
builder.save()

推理代码:

# coding: utf-8
from __future__ import print_function
from __future__ import division import tensorflow as tf
import numpy as np
import argparse def build_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, required=True)
args = parser.parse_args()
return args p = build_parser() with tf.Session() as sess:
signature_key = 'test_signature'
input_key = 'input'
output_key = 'pred_proba' meta_graph_def = tf.saved_model.loader.load(sess, ['test_saved_model'], p.model_dir)
signature = meta_graph_def.signature_def
x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name
x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)
r = sess.run(y, feed_dict={x: np.array([[0.6211, 5]])})
print(r)
print(0 if r[0][0] > r[0][1] else 1)

参考资料

TensorFlow saved_model 模块

tensorflow SavedModelBuilder用法的更多相关文章

  1. Tensorflow Summary用法

    本文转载自:https://www.cnblogs.com/lyc-seu/p/8647792.html Tensorflow Summary用法 tensorboard 作为一款可视化神器,是学习t ...

  2. 第一节,TensorFlow基本用法

    一 TensorFlow安装 TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tsnsor(张量)意味着N维数组,Flow(流)意味着基 ...

  3. tensorflow add_to_collection用法

    训练代码: # coding: utf-8 from __future__ import print_function from __future__ import division import t ...

  4. tensorflow基本用法个人笔记

    综述   TensorFlow程序分为构建阶段和执行阶段.通过构建一个图.执行这个图来得到结果. 构建图   创建源op,源op不需要任何输入,例如常量constant,源op的输出被传递给其他op做 ...

  5. Tensorflow学习笔记——Summary用法

    tensorboard 作为一款可视化神器,可以说是学习tensorflow时模型训练以及参数可视化的法宝. 而在训练过程中,主要用到了tf.summary()的各类方法,能够保存训练过程以及参数分布 ...

  6. (转)TensorFlow 入门

        TensorFlow 入门 本文转自:http://www.jianshu.com/p/6766fbcd43b9 字数3303 阅读904 评论3 喜欢5 CS224d-Day 2: 在 Da ...

  7. 统计学习方法:罗杰斯特回归及Tensorflow入门

    作者:桂. 时间:2017-04-21  21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...

  8. 芝麻HTTP:TensorFlow基础入门

    本篇内容基于 Python3 TensorFlow 1.4 版本. 本节内容 本节通过最简单的示例 -- 平面拟合来说明 TensorFlow 的基本用法. 构造数据 TensorFlow 的引入方式 ...

  9. tensorflow 学习日志

    Windows安装anaconda 和 TensorFlow anaconda : https://zhuanlan.zhihu.com/p/25198543        anaconda 使用与说 ...

随机推荐

  1. SqlServer2005 查询 第五讲 top

    今天我们来说sql命令中得参数top top top[ 最前面若干个记录,专属于SqlServer2005的语法,不可移植到其他库.oracle中是用rownum<6来实现输出前5行记录.] 下 ...

  2. windows 激活工具链接

    链接:https://pan.baidu.com/s/1FphGFZhhLp01akGTDWjW2A  密码:f9t7

  3. MySql——创建数据表,查询数据,排序查询数据

    参考资料:<Mysql必知必会> 创建数据表 在学习前首先创建数据表和插入数据.如何安装mysql可以看看上个博客https://www.cnblogs.com/lbhym/p/11675 ...

  4. 微信小程序(mpvue) wx.openSetting 无法调起设置页面

    在开发过程有个需要保存图片/视频到设备相册的业务,so easy~   巴啦啦撸下来了完整功能, wx.saveVideoToPhotosAlbum 会自动调起用户授权,美滋滋~~   btu.... ...

  5. 发送大数据时,PDU的问题?

    昨天发现通过 Ice发送请求传递一个大块数据时,当请求的体积大于1.2M后,直接抛出异常Connection Lost,对方peer或是断开了.通过防火墙配置排查,以及对同一网络同一机器的php服务p ...

  6. Java的Arrays类 基本用法

    初识Java的Arrays类 Arrays类包括很多用于操作数组的静态方法(例如排序和搜索),且静态方法可以通过类名Arrays直接调用.用之前需要导入Arrays类: import java.uti ...

  7. 小白学习python第一天,Pycharm破解与用法(持续更新)

    目录 Pycharm安装与破解及汉化 Pycharm安装 Pycharm破解 Pycharm汉化 Pycharm使用 添加作者.时间等信息 补充 @ Pycharm安装与破解及汉化 本人最近开始找到了 ...

  8. Flask入门学习——自定义一个url转换器

          我们知道,flask的url规则是可以添加变量部分的,这个参数变量是写在尖括号里的,比如:/item/<id>/,如果需要指出参数的类型要符合<converter:vai ...

  9. word is too tall: try to use less letters, smaller font or bigger background 报错 java程序 验证码不显示

    验证码不现实问题爆发在测试站,还好只是个测试站,有时间让我慢慢研究此问题. 具体的情况是这样的: 下午三点多,突然测试人员跟我说,测试站后台的验证码不现实了,也就无法登陆了 通过询问,是中午吃饭前还是 ...

  10. Java之Retry重试机制详解

    应用中需要实现一个功能: 需要将数据上传到远程存储服务,同时在返回处理成功情况下做其他操作.这个功能不复杂,分为两个步骤:第一步调用远程的Rest服务上传数据后对返回的结果进行处理:第二步拿到第一步结 ...