# The model graph (variables, ops and the `init` op) is defined before this point.
saver = tf.train.Saver()  # create the saver

with tf.Session() as sess:
    sess.run(init)  # initialise the model variables

    # ... model training code goes here; once training has finished,
    # the model can be saved as shown below ...

    # If the target path does not exist yet it is created automatically.
    saver.save(sess, save_path="model/linear")
saver = tf.train.Saver()

with tf.Session() as sess:
    # Running the initialiser here is optional: the restore below overwrites
    # whatever values the initialisation assigned.
    sess.run(init)
    saver.restore(sess, save_path="model/linear")
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Linear-regression demo (TensorFlow 1.x graph mode): fit z = w * X + b to
# noisy samples, reusing a saved checkpoint when one is available and
# training + saving otherwise.

# saver.save()/restore() take a checkpoint *prefix*, not a single file name.
MODEL_DIR = "model/"
CHECKPOINT_PREFIX = MODEL_DIR + "linear"

# Synthetic training set: y = 5x + 10 plus uniform noise in [-5, 5).
train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# Show the raw samples before any fitting.
plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

# Graph definition.
X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)
w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')
z = tf.multiply(X, w) + b
cost = tf.reduce_mean(tf.square(Y - z))  # mean squared error

learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2
saver = tf.train.Saver()


def _plot_fit(w_, b_):
    """Draw the fitted line (green) over the training samples (red dots)."""
    plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
    plt.grid(True)
    plt.show()


def _train(sess):
    """Run per-sample gradient descent for `training_epochs` epochs.

    Every `display_step` epochs the current loss is printed.

    Returns:
        The mean squared error over the whole training set after the
        last epoch.
    """
    full_batch = {X: train_x, Y: train_y}
    for epoch in range(training_epochs):
        for (x, y) in zip(train_x, train_y):
            sess.run(optimizer, feed_dict={X: x, Y: y})
        if epoch % display_step == 0:
            # Evaluate on the full training set; feeding the leftover loop
            # variables would report the error of the last sample only.
            loss = sess.run(cost, feed_dict=full_batch)
            print('Iter: ', epoch, ' Loss: ', loss)
    # Computed explicitly so the caller always gets a value, even if no
    # epoch hit a display step.
    return sess.run(cost, feed_dict=full_batch)


if __name__ == '__main__':
    with tf.Session() as sess:
        sess.run(init)
        # Checking only that the directory exists would attempt a restore
        # from an empty or unrelated "model/" folder and fail;
        # latest_checkpoint() yields None when no checkpoint is recorded there.
        checkpoint = tf.train.latest_checkpoint(MODEL_DIR)
        if checkpoint is not None:
            saver.restore(sess, checkpoint)
            w_, b_ = sess.run([w, b])
            print(" Finished ")
            print("W: ", w_, " b: ", b_)
        else:
            loss = _train(sess)
            # Variables do not depend on the placeholders, so no feed is needed.
            w_, b_ = sess.run([w, b])
            saver.save(sess, CHECKPOINT_PREFIX)
            print(" Finished ")
            print("W: ", w_, " b: ", b_, " loss: ", loss)
        _plot_fit(w_, b_)
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# Print every tensor stored in the saved checkpoint together with its value.
modeldir = 'model/'
# The saver.save() calls in this file write the checkpoint under the prefix
# "model/linear", so that same prefix must be passed here; a "linear.cpkt"
# name does not match anything they produce.
print_tensors_in_checkpoint_file(modeldir + 'linear', tensor_name=None, all_tensors=True)
# Passing a dict stores each variable in the checkpoint under the given key
# ('weight_' for w, 'bias_' for b).
saver = tf.train.Saver(var_list={'weight_': w, 'bias_': b})
# Alternative: pass the variables as a plain list instead of a dict.
# saver = tf.train.Saver([w, b])
