用Tensorflow搭建神经网络的一般步骤如下:

① 导入模块

② 创建模型变量和占位符

③ 建立模型

④ 定义loss函数

⑤ 定义优化器(optimizer), 使 loss 达到最小

⑥ 引入激活函数, 即添加非线性因素 (线性回归问题跳过此步骤)

⑦ 训练模型

⑧ 检验模型

⑨ 使用模型预测数据

⑩ 保存模型

⑪ 使用Tensorboard的可视化功能

下面以一个简单的线性回归问题为例:

首先是训练模型的代码: train_model.py

 # ① 导入模块
import tensorflow as tf # ② 创建模型的变量和占位符
W = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)
x = tf.placeholder(tf.float32, name="input_x")
y = tf.placeholder(tf.float32, name="input_y") # ③建立模型
linear_model = W*x + b
# 如果是矩阵相乘,可以写成:
# linear_model = tf.matmul(x, W)+b # matmul表示矩阵相乘 # ④ 定义loss函数
loss = tf.reduce_sum(tf.square(linear_model - y)) # ⑤ 定义优化器(optimizer), 使 loss 达到最小
learning_rate=0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
train = optimizer.minimize(loss) # ⑥ 引入激活函数, 即添加非线性因素。(线性回归问题跳过此步骤) # ⑦ 训练模型
# 假设模型是y=2x+1
x_train = [1, 2, 3, 4]
y_train = [3, 5, 7, 9] init = tf.global_variables_initializer() # 添加用于初始化变量的节点
sess = tf.Session()
sess.run(init) # 运行初始化操作
for step in range(1000):
sess.run(train, {x: x_train, y: y_train}) '''
第⑦步和第⑩步可以合并为:
for step in xrange(1000000):
sess.run(train, {x: x_train, y: y_train})
if step % 1000 == 0:
saver.save(sess, 'my-model', global_step=step)
''' # ⑧ 检验模型
curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
'''
W: [ 2.00000167] b: [ 0.99999553] loss: 1.29603e-11
''' # ⑨ 使用模型预测数据
x_predict = [-1, 0, 1, 2]
predicted_values=sess.run(linear_model, feed_dict={x:x_predict})
# 注意这么一种写法: predicted_values = [(W*x + b).eval(session=sess) for x in x_predict]
print("result:", predicted_values)
'''
result: [-1.0000062 0.99999553 2.99999714 4.99999905]
''' # ⑩ 保存模型
tf.add_to_collection("predict_network", linear_model)
saver = tf.train.Saver()
saver_path=saver.save(sess, "save/model.ckpt") # ⑪ 使用Tensorboard的可视化功能
# 定义保存日志的路径
path = "log" # 也可写成: path = "./log"
writer=tf.summary.FileWriter(path, sess.graph) sess.close()

然后是载入模型的代码: restore_model.py

 import tensorflow as tf

 with tf.Session() as sess:
new_saver=tf.train.import_meta_graph("save/model.ckpt.meta")
new_saver.restore(sess,"save/model.ckpt")
# print(tf.get_collection("predict_network"))
restored_y=tf.get_collection("predict_network")[0] # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可 graph=tf.get_default_graph()
restored_x=graph.get_operation_by_name("input_x").outputs[0] predict_data = [-2, 3, 4]
predicted_result = sess.run(restored_y, feed_dict={restored_x:predict_data}) print("result:", predicted_result) # result: [-3.00000787 7.00000048 9.00000191]

用Tensorflow搭建神经网络的一般步骤的更多相关文章

  1. (转)一文学会用 Tensorflow 搭建神经网络

    一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...

  2. 一文学会用 Tensorflow 搭建神经网络

    http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...

  3. Tensorflow 搭建神经网络及tensorboard可视化

    1. session对话控制 matrix1 = tf.constant([[3,3]]) matrix2 = tf.constant([[2],[2]]) product = tf.matmul(m ...

  4. kaggle赛题Digit Recognizer:利用TensorFlow搭建神经网络(附上K邻近算法模型预测)

    一.前言 kaggle上有传统的手写数字识别mnist的赛题,通过分类算法,将图片数据进行识别.mnist数据集里面,包含了42000张手写数字0到9的图片,每张图片为28*28=784的像素,所以整 ...

  5. Tensorflow搭建神经网络及使用Tensorboard进行可视化

    创建神经网络模型 1.构建神经网络结构,并进行模型训练 import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt ...

  6. tensorflow搭建神经网络

    最简单的神经网络 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt date = np.linspa ...

  7. tensorflow搭建神经网络基本流程

    定义添加神经层的函数 1.训练的数据2.定义节点准备接收数据3.定义神经层:隐藏层和预测层4.定义 loss 表达式5.选择 optimizer 使 loss 达到最小 然后对所有变量进行初始化,通过 ...

  8. 基于tensorflow搭建一个神经网络

    一,tensorflow的简介 Tensorflow是一个采用数据流图,用于数值计算的 开源软件库.节点在图中表示数字操作,图中的线 则表示在节点间相互联系的多维数据数组,即张量 它灵活的架构让你可以 ...

  9. Tensorflow学习:(二)搭建神经网络

    一.神经网络的实现过程 1.准备数据集,提取特征,作为输入喂给神经网络       2.搭建神经网络结构,从输入到输出       3.大量特征数据喂给 NN,迭代优化 NN 参数       4.使 ...

随机推荐

  1. StreamReader 和 StreamWriter 简单调用

    /* ######### ############ ############# ## ########### ### ###### ##### ### ####### #### ### ####### ...

  2. 论文笔记:Auto-ReID: Searching for a Part-aware ConvNet for Person Re-Identification

    Auto-ReID: Searching for a Part-aware ConvNet for Person Re-Identification 2019-03-26 15:27:10 Paper ...

  3. C、C++中的static和extern关键字

    1.首先,关于声明和定义的区别 这种写法(函数原型后加;号表示结束的写法)只能叫函数声明而不能叫函数定义,只有带函数体的声明才叫定义,比如下面 只有分配存储空间的变量声明才叫变量定义,其实函数也是一样 ...

  4. ES6解构过程添加一个默认值和赋值一个新的值

    const info = { name: 'xiaobe', } const { name: nickName = '未知' } = info; 其中nickName是解构过程中新声明的一个变量,并且 ...

  5. Valotile关键字详解

    在了解valotile关键字之前.我们先来了解其他相关概念. 1.1  java内存模型: 不同的平台,内存模型是不一样的,我们可以把内存模型理解为在特定操作协议下,对特定的内存或高速缓存进行读写访问 ...

  6. “妄”眼欲穿-CSS之flex布局和边框阴影

    妄:狂妄: 不会的东西只有怀着一颗狂妄的心,假装能把它看穿吧. 作为一个什么都不会的小白,为了学习(zb),特别在拿来主义之后写一些对于某些css布局的总结,进一步加深对知识的记忆.知识是人类的共同财 ...

  7. (简单)华为P9plus VIE-AL00的usb调试模式在哪里开启的经验

    每次我们使用pc接通安卓手机的时候,如果手机没有开启Usb调试模式,pc则没能成功检测到我们的手机,有时,我们使用的一些功能比较强的的app比如之前我们使用的一个app引号精灵,老版本就需要打开Usb ...

  8. (转)如何在maven的pom.xml中添加本地jar包

    转载自: https://www.cnblogs.com/lixuwu/p/5855031.html 1 maven本地仓库认识 maven本地仓库中的jar目录一般分为三层:图中的1 2 3分别如下 ...

  9. 在线批量将gps经纬度坐标转换为百度经纬度坐标

    1.首先打开百度api示例页面: 在浏览器地址栏中输入:http://developer.baidu.com/map/jsdemo.htm#a5_3 2.修改代码 如下图,将需要批量转换的坐标,按规则 ...

  10. Ubuntu - apt -commands

    1. install sudo apt install [软件名] sudo apt-get install [软件名]Tab补全,可以使用sudo apt upgrade 升级apt, 也可以通过s ...