TensorFlow常用的示例一般都是生成模型和测试模型写在一起,每次更换测试数据都要重新训练,过于麻烦,

以下采用先生成并保存本地模型,然后后续程序调用测试。

示例一:线性回归预测

make.py

import tensorflow as tf
import numpy as np def train_model(): # prepare the data
x_data = np.random.rand(100).astype(np.float32)
print (x_data)
y_data = x_data * 0.1 + 0.2
print (y_data) # define the weights
W = tf.Variable(tf.random_uniform([1], -20.0, 20.0), dtype=tf.float32, name='w')
b = tf.Variable(tf.random_uniform([1], -10.0, 10.0), dtype=tf.float32, name='b')
y = W * x_data + b # define the loss
loss = tf.reduce_mean(tf.square(y - y_data))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss) # save model
saver = tf.train.Saver(max_to_keep=4) with tf.Session() as sess: sess.run(tf.global_variables_initializer())
print ("------------------------------------------------------")
print ("before the train, the W is %6f, the b is %6f" % (sess.run(W), sess.run(b))) for epoch in range(300):
if epoch % 10 == 0:
print ("------------------------------------------------------")
print ("after epoch %d, the loss is %6f" % (epoch, sess.run(loss)))
print ("the W is %f, the b is %f" % (sess.run(W), sess.run(b)))
saver.save(sess, "model/my-model", global_step=epoch)
print ("save the model")
sess.run(train_step)
print ("------------------------------------------------------") train_model()

test.py

import tensorflow as tf
import numpy as np def load_model():
with tf.Session() as sess:
saver = tf.train.import_meta_graph('model/my-model-290.meta')
saver.restore(sess, tf.train.latest_checkpoint("model/"))
print (sess.run('w:0'))
print (sess.run('b:0'))
load_model()

示例二:卷积神经网络

make.py

import tensorflow as tf
import numpy as np
import os
os.mkdir("model1")
def load_data(resultpath): datapath = os.path.join(resultpath, "data10_4.npz")
if os.path.exists(datapath):
data = np.load(datapath)
X, Y = data["X"], data["Y"]
else:
X = np.array(np.arange(30720)).reshape(10, 32, 32, 3)
Y = [0, 0, 1, 1, 2, 2, 3, 3, 2, 0]
X = X.astype('float32')
Y = np.array(Y)
np.savez(datapath, X=X, Y=Y)
print('Saved dataset to dataset.npz.')
print('X_shape:{}\nY_shape:{}'.format(X.shape, Y.shape))
return X, Y def define_model(x): x_image = tf.reshape(x, [-1, 32, 32, 3])
print (x_image.shape) def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial, name="w") def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial, name="b") def conv3d(x, W):
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') def max_pool_2d(x):
return tf.nn.max_pool(x, ksize=[1, 3, 3, 1], strides=[1, 3, 3, 1], padding='SAME') with tf.variable_scope("conv1"): # [-1,32,32,3]
weights = weight_variable([3, 3, 3, 32])
biases = bias_variable([32])
conv1 = tf.nn.relu(conv3d(x_image, weights) + biases)
pool1 = max_pool_2d(conv1) # [-1,11,11,32] with tf.variable_scope("conv2"):
weights = weight_variable([3, 3, 32, 64])
biases = bias_variable([64])
conv2 = tf.nn.relu(conv3d(pool1, weights) + biases)
pool2 = max_pool_2d(conv2) # [-1,4,4,64] with tf.variable_scope("fc1"):
weights = weight_variable([4 * 4 * 64, 128]) # [-1,1024]
biases = bias_variable([128])
fc1_flat = tf.reshape(pool2, [-1, 4 * 4 * 64])
fc1 = tf.nn.relu(tf.matmul(fc1_flat, weights) + biases)
fc1_drop = tf.nn.dropout(fc1, 0.5) # [-1,128] with tf.variable_scope("fc2"):
weights = weight_variable([128, 4])
biases = bias_variable([4])
fc2 = tf.matmul(fc1_drop, weights) + biases # [-1,4] return fc2 def train_model(): x = tf.placeholder(tf.float32, shape=[None, 32, 32, 3], name="x")
y_ = tf.placeholder('int64', shape=[None], name="y_") initial_learning_rate = 0.001
y_fc2 = define_model(x)
y_label = tf.one_hot(y_, 4, name="y_labels") loss_temp = tf.losses.softmax_cross_entropy(onehot_labels=y_label, logits=y_fc2)
cross_entropy_loss = tf.reduce_mean(loss_temp) train_step = tf.train.AdamOptimizer(learning_rate=initial_learning_rate, beta1=0.9, beta2=0.999,
epsilon=1e-08).minimize(cross_entropy_loss) correct_prediction = tf.equal(tf.argmax(y_fc2, 1), tf.argmax(y_label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # save model
saver = tf.train.Saver(max_to_keep=4)
tf.add_to_collection("predict", y_fc2) with tf.Session() as sess: sess.run(tf.global_variables_initializer())
print ("------------------------------------------------------")
X, Y = load_data("model1/")
X = np.multiply(X, 1.0 / 255.0)
for epoch in range(190): if epoch % 10 == 0:
print ("------------------------------------------------------") train_accuracy = accuracy.eval(feed_dict={x: X, y_: Y})
train_loss = cross_entropy_loss.eval(feed_dict={x: X, y_: Y}) print ("after epoch %d, the loss is %6f" % (epoch, train_loss))
print ("after epoch %d, the acc is %6f" % (epoch, train_accuracy)) saver.save(sess, "model1/my-model", global_step=epoch)
print ("save the model") train_step.run(feed_dict={x: X, y_: Y}) print ("------------------------------------------------------") train_model()

test.py

import tensorflow as tf
import numpy as np
import os def load_model(): # prepare the test data
X = np.array(np.arange(6144, 12288)).reshape(2, 32, 32, 3)
Y = [3, 1]
Y = np.array(Y)
X = X.astype('float32')
X = np.multiply(X, 1.0 / 255.0)
with tf.Session() as sess: # load the meta graph and weights
saver = tf.train.import_meta_graph('model1/my-model-180.meta')
saver.restore(sess, tf.train.latest_checkpoint("model1/")) # get weights
graph = tf.get_default_graph()
fc2_w = graph.get_tensor_by_name("fc2/w:0")
fc2_b = graph.get_tensor_by_name("fc2/b:0") print ("------------------------------------------------------")
print (sess.run(fc2_w))
print ("#######################################")
print (sess.run(fc2_b))
print ("------------------------------------------------------") input_x = graph.get_operation_by_name("x").outputs[0] feed_dict = {"x:0":X, "y_:0":Y}
y = graph.get_tensor_by_name("y_labels:0")
yy = sess.run(y, feed_dict)
print (yy)
print ("the answer is: ", sess.run(tf.argmax(yy, 1)))
print ("------------------------------------------------------") pred_y = tf.get_collection("predict")
pred = sess.run(pred_y, feed_dict)[0]
print (pred, '\n') pred = sess.run(tf.argmax(pred, 1))
print ("the predict is: ", pred)
print ("------------------------------------------------------") load_model()

TensorFlow笔记四:从生成和保存模型 -> 调用使用模型的更多相关文章

  1. go微服务框架kratos学习笔记四(kratos warden-quickstart warden-direct方式client调用)

    目录 go微服务框架kratos学习笔记四(kratos warden-quickstart warden-direct方式client调用) warden direct demo-server gr ...

  2. tensorflow笔记:模型的保存与训练过程可视化

    tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 ...

  3. (四) tensorflow笔记:常用函数说明

    tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 ...

  4. tensorflow笔记(四)之MNIST手写识别系列一

    tensorflow笔记(四)之MNIST手写识别系列一 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7436310.html ...

  5. tensorflow笔记之滑动平均模型

    tensorflow使用tf.train.ExponentialMovingAverage实现滑动平均模型,在使用随机梯度下降方法训练神经网络时候,使用这个模型可以增强模型的鲁棒性(robust),可 ...

  6. Keras学习笔记二:保存本地模型和调用本地模型

    使用深度学习模型时当然希望可以保存下训练好的模型,需要的时候直接调用,不再重新训练 一.保存模型到本地 以mnist数据集下的AutoEncoder 去噪为例.添加: file_path=" ...

  7. ThinkPHP 学习笔记 ( 四 ) 数据库操作之关联模型 ( RelationMondel ) 和高级模型 ( AdvModel )

    一.关联模型 ( RelationMondel ) 1.数据查询 ① HAS_ONE 查询 创建两张数据表评论表和文章表: tpk_comment , tpk_article .评论和文章的对应关系为 ...

  8. SpringMVC 学习笔记(四) 处理模型数据

    Spring MVC 提供了下面几种途径输出模型数据: – ModelAndView: 处理方法返回值类型为 ModelAndView时, 方法体就可以通过该对象加入模型数据 – Map及Model: ...

  9. tensorflow笔记:使用tf来实现word2vec

    (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 (四) tensorflow笔 ...

随机推荐

  1. (Mac)centos 6.5安装 JDK+mysql

    为了把自己的网站放到外网,购买了阿里云的centos 6.5服务器,以下是安装 JDK 一.JDK安装: 方法一: 1.创建目录,命令行:(这里可以不加sudo) sudo mkdir /jdk 2. ...

  2. Linux(Centos) 搭建ReviewBoard

    一.官方安装手册 reviewboard 的安装用户手册:猛击这里 二.常用安装步骤 2.1.安装httpd,+ mod_wsgi, fastcgi, or mod_python yum -y int ...

  3. [NOIP2013] 提高组 洛谷P1979 华容道

    题目描述 [问题描述] 小 B 最近迷上了华容道,可是他总是要花很长的时间才能完成一次.于是,他想到用编程来完成华容道:给定一种局面, 华容道是否根本就无法完成,如果能完成, 最少需要多少时间. 小 ...

  4. 洛谷 P 2756 飞行员配对方案问题

    题目背景 第二次世界大战时期.. 题目描述 英国皇家空军从沦陷国征募了大量外籍飞行员.由皇家空军派出的每一架飞机都需要配备在航行技能和语言上能互相配合的2 名飞行员,其中1 名是英国飞行员,另1名是外 ...

  5. 【CF1073B】Vasya and Books(模拟)

    题意:给你一个栈里书的编号,每次能捞出栈顶的一本书,每次询问捞出某本编号的书需要捞几次 n<=2e5 思路: #include<cstdio> #include<cstring ...

  6. 从无序序列中求这个序列排序后邻点间最大差值的O(n)算法

    标题可能比较绕口,简单点说就是给你一个无序数列A={a1,a2,a3……an},如果你把这个序列排序后变成序列B,求序列B中相邻两个元素之间相差数值的最大值. 注意:序列A的元素的大小在[1,2^31 ...

  7. JSTL获取Session的ID与获取文件的真实路径与项目名称

    今天在测试集群配置的时候想到session共享,因此想要获取sessionID,可以通过下面方法: ${pageContext.session.id} 获取文件的真实路径: <%=request ...

  8. MSP430 G2553 寄存器列表与引脚功能

    USCI_B0 USCI_B0 发送缓冲器UCB0TXBUF 06Fh USCI_B0 接收缓冲器UCB0RXBUF 06Eh USCI_B0 状态UCB0STAT 06Dh USCI B0 I2C ...

  9. 《手把手教你学C语言》学习笔记(1)---C语言的特点

    学习C语言的原因,主要是需要使用C语言编程,我用故我学,应该是最主要的原因了. C语言的定位:C语言严格意义上只能算是中级语言,是面向过程编程语言的集大成者,虽然这种语言有很多的问题,但总体而言是瑕不 ...

  10. selenium题

    一.selenium中如何判断元素是否存在? 首先selenium里面是没有这个方法的,判断元素存在需要自己写一个方法了. 元素存在有几种形式,一种是页面有多个元素属性重复的,这种直接操作会报错的:还 ...