tensorflow add_to_collection用法
训练代码:
# coding: utf-8
from __future__ import print_function
from __future__ import division
import tensorflow as tf
import numpy as np
import argparse
def dense_to_one_hot(input_data, class_num):
data_num = input_data.shape[0]
index_offset = np.arange(data_num) * class_num
labels_one_hot = np.zeros((data_num, class_num))
labels_one_hot.flat[index_offset + input_data.ravel()] = 1
return labels_one_hot
def build_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, required=True)
parser.add_argument('--model_path', type=str, required=True)
args = parser.parse_args()
return args
p = build_parser()
origin = np.genfromtxt(p.data_path, delimiter=',')
data = origin[:, 0:2]
labels = origin[:, 2]
learning_rate = 0.001
training_epochs = 5000
display_step = 1
n_features = 2
n_class = 2
x = tf.placeholder(tf.float32, [None, n_features], "input")
y = tf.placeholder(tf.float32, [None, n_class])
W = tf.Variable(tf.zeros([n_features, n_class]), name="w")
b = tf.Variable(tf.zeros([n_class]), name="b")
scores = tf.nn.xw_plus_b(x, W, b, name='scores')
pred_proba = tf.nn.softmax(scores, name="pred_proba")
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
saver = tf.train.Saver()
tf.add_to_collection('pred_proba', pred_proba)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for epoch in range(training_epochs):
result_pred_proba, _, c = sess.run([pred_proba, optimizer, cost],
feed_dict={x: data, y: dense_to_one_hot(labels.astype(int), 2)})
if epoch % 100 == 0:
print(c)
saver.save(sess, p.model_path)
print("Optimization Finished!")
推理代码:
# coding: utf-8
from __future__ import print_function
from __future__ import division
import tensorflow as tf
import numpy as np
import argparse
def build_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, required=True)
args = parser.parse_args()
return args
p = build_parser()
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph(p.model_path + ".meta")
new_saver.restore(sess, p.model_path)
pred_proba = tf.get_collection('pred_proba')[0]
graph = tf.get_default_graph()
input_x = graph.get_operation_by_name('input').outputs[0]
r = sess.run(pred_proba, feed_dict={input_x: np.array([[0.6211,5]])})
print(r)
print(0 if r[0][0] > r[0][1] else 1)
参考资料
tensorflow add_to_collection用法的更多相关文章
- Tensorflow Summary用法
本文转载自:https://www.cnblogs.com/lyc-seu/p/8647792.html Tensorflow Summary用法 tensorboard 作为一款可视化神器,是学习t ...
- 第一节,TensorFlow基本用法
一 TensorFlow安装 TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tsnsor(张量)意味着N维数组,Flow(流)意味着基 ...
- tensorflow SavedModelBuilder用法
训练代码: # coding: utf-8 from __future__ import print_function from __future__ import division import t ...
- tensorflow基本用法个人笔记
综述 TensorFlow程序分为构建阶段和执行阶段.通过构建一个图.执行这个图来得到结果. 构建图 创建源op,源op不需要任何输入,例如常量constant,源op的输出被传递给其他op做 ...
- Tensorflow学习笔记——Summary用法
tensorboard 作为一款可视化神器,可以说是学习tensorflow时模型训练以及参数可视化的法宝. 而在训练过程中,主要用到了tf.summary()的各类方法,能够保存训练过程以及参数分布 ...
- (转)TensorFlow 入门
TensorFlow 入门 本文转自:http://www.jianshu.com/p/6766fbcd43b9 字数3303 阅读904 评论3 喜欢5 CS224d-Day 2: 在 Da ...
- 统计学习方法:罗杰斯特回归及Tensorflow入门
作者:桂. 时间:2017-04-21 21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...
- 芝麻HTTP:TensorFlow基础入门
本篇内容基于 Python3 TensorFlow 1.4 版本. 本节内容 本节通过最简单的示例 -- 平面拟合来说明 TensorFlow 的基本用法. 构造数据 TensorFlow 的引入方式 ...
- tensorflow 学习日志
Windows安装anaconda 和 TensorFlow anaconda : https://zhuanlan.zhihu.com/p/25198543 anaconda 使用与说 ...
随机推荐
- 用大写字母输入 Linux 命令,实现以 sudo 用户权限运行
我们知道,一些 Linux 命令是要通过 sudo 权限才能运行的,这需要我们每次使用这些命令时在前面加一个 sudo ,十分繁琐.今天给大家介绍一个好用的工具 SUDO ,它只需要我们用大写字母键入 ...
- Ubuntu清空回收站
ubuntu 回收站的具体位置:$HOME/.local/share/Trash/ 执行如下命令清空回收站: sudo rm -fr $HOME/.local/share/Trash/files/ 如 ...
- Python3.7.1学习(七)mysql中pymysql模块详解(一)
pymysql是纯用Python操作MySQL的模块,其使用方法和MySQLdb几乎相同.此次介绍mysql以及在python中如何用pymysql操作数据库, 以及在mysql中存储过程, 触发器以 ...
- 《JAVA 程序员面试宝典(第四版)》之传递与引用篇
废话开场白 这个周末突然很想创业,为什么呢?原因很简单,我周围的同学或者说玩的比较好的朋友都发达了,前一个月和一个两年前还睡在一张床上的朋友,他现在已经在深圳买房买车了,没错是在深圳买 ...
- Day01第一天 Python基础一
变量 就是将一些运算的中间结果暂时存在内存中,以便后续代码的调用. >命名规则: 1,只能以字母,数字,下划线自由组合,且,不能以数字开头.2,不能是 Python 中的关键字.3,要具有可 ...
- java基础开发环境安装(全)
一.jdk安装(可以根据自己习惯选择合适安装路径) 1.jdk1.8下载地址:https://pan.baidu.com/s/1O9JQlFJ9cpkGCQL35cm_7g 提取码:pe2g 2.jd ...
- 使用 Rsync 从 Windows 同步数据到 Linux
为什么要使用 rsync 从 Windows 到 linux 进行同步? 我们经常会面临这种的情况,项目使用 Windows 开发,最终部署在 Linux 上,但有时想要进行测试.维护.迭代版本时操作 ...
- jQuery简单面试题
干货 | jQuery经典面试题及答案精选 面试题来啦! 毫无疑问,JavaScript是一门如此有用,但总是被低估的一门语言. 在 jQuery 粉墨登场之前,我们曾经会写出冗长的JavaScrip ...
- 20191010-3 alpha week 1/2 Scrum立会报告+燃尽图 01
此作业要求参见https://edu.cnblogs.com/campus/nenu/2019fall/homework/8746 一.小组情况 组长:迟俊文 组员:宋晓丽 梁梦瑶 韩昊 刘信鹏 队名 ...
- 在idea中使用git
在idea中使用git 1. 在idea中配置git 安装好IntelliJ IDEA后,如果Git安装在默认路径下,那么idea会自动找到git的位置,如果更改了Git的安装位置则需要手动配置下 ...