用Tensorflow搭建神经网络的一般步骤如下:

① 导入模块

② 创建模型变量和占位符

③ 建立模型

④ 定义loss函数

⑤ 定义优化器(optimizer), 使 loss 达到最小

⑥ 引入激活函数, 即添加非线性因素 (线性回归问题跳过此步骤)

⑦ 训练模型

⑧ 检验模型

⑨ 使用模型预测数据

⑩ 保存模型

⑪ 使用Tensorboard的可视化功能

下面以一个简单的线性回归问题为例:

首先是训练模型的代码: train_model.py

 # ① 导入模块
import tensorflow as tf # ② 创建模型的变量和占位符
W = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)
x = tf.placeholder(tf.float32, name="input_x")
y = tf.placeholder(tf.float32, name="input_y") # ③建立模型
linear_model = W*x + b
# 如果是矩阵相乘,可以写成:
# linear_model = tf.matmul(x, W)+b # matmul表示矩阵相乘 # ④ 定义loss函数
loss = tf.reduce_sum(tf.square(linear_model - y)) # ⑤ 定义优化器(optimizer), 使 loss 达到最小
learning_rate=0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
train = optimizer.minimize(loss) # ⑥ 引入激活函数, 即添加非线性因素。(线性回归问题跳过此步骤) # ⑦ 训练模型
# 假设模型是y=2x+1
x_train = [1, 2, 3, 4]
y_train = [3, 5, 7, 9] init = tf.global_variables_initializer() # 添加用于初始化变量的节点
sess = tf.Session()
sess.run(init) # 运行初始化操作
for step in range(1000):
sess.run(train, {x: x_train, y: y_train}) '''
第⑦步和第⑩步可以合并为:
for step in xrange(1000000):
sess.run(train, {x: x_train, y: y_train})
if step % 1000 == 0:
saver.save(sess, 'my-model', global_step=step)
''' # ⑧ 检验模型
curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
'''
W: [ 2.00000167] b: [ 0.99999553] loss: 1.29603e-11
''' # ⑨ 使用模型预测数据
x_predict = [-1, 0, 1, 2]
predicted_values=sess.run(linear_model, feed_dict={x:x_predict})
# 注意这么一种写法: predicted_values = [(W*x + b).eval(session=sess) for x in x_predict]
print("result:", predicted_values)
'''
result: [-1.0000062 0.99999553 2.99999714 4.99999905]
''' # ⑩ 保存模型
tf.add_to_collection("predict_network", linear_model)
saver = tf.train.Saver()
saver_path=saver.save(sess, "save/model.ckpt") # ⑪ 使用Tensorboard的可视化功能
# 定义保存日志的路径
path = "log" # 也可写成: path = "./log"
writer=tf.summary.FileWriter(path, sess.graph) sess.close()

然后是载入模型的代码: restore_model.py

 import tensorflow as tf

 with tf.Session() as sess:
new_saver=tf.train.import_meta_graph("save/model.ckpt.meta")
new_saver.restore(sess,"save/model.ckpt")
# print(tf.get_collection("predict_network"))
restored_y=tf.get_collection("predict_network")[0] # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可 graph=tf.get_default_graph()
restored_x=graph.get_operation_by_name("input_x").outputs[0] predict_data = [-2, 3, 4]
predicted_result = sess.run(restored_y, feed_dict={restored_x:predict_data}) print("result:", predicted_result) # result: [-3.00000787 7.00000048 9.00000191]

用Tensorflow搭建神经网络的一般步骤的更多相关文章

  1. (转)一文学会用 Tensorflow 搭建神经网络

    一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...

  2. 一文学会用 Tensorflow 搭建神经网络

    http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...

  3. Tensorflow 搭建神经网络及tensorboard可视化

    1. session对话控制 matrix1 = tf.constant([[3,3]]) matrix2 = tf.constant([[2],[2]]) product = tf.matmul(m ...

  4. kaggle赛题Digit Recognizer:利用TensorFlow搭建神经网络(附上K邻近算法模型预测)

    一.前言 kaggle上有传统的手写数字识别mnist的赛题,通过分类算法,将图片数据进行识别.mnist数据集里面,包含了42000张手写数字0到9的图片,每张图片为28*28=784的像素,所以整 ...

  5. Tensorflow搭建神经网络及使用Tensorboard进行可视化

    创建神经网络模型 1.构建神经网络结构,并进行模型训练 import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt ...

  6. tensorflow搭建神经网络

    最简单的神经网络 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt date = np.linspa ...

  7. tensorflow搭建神经网络基本流程

    定义添加神经层的函数 1.训练的数据2.定义节点准备接收数据3.定义神经层:隐藏层和预测层4.定义 loss 表达式5.选择 optimizer 使 loss 达到最小 然后对所有变量进行初始化,通过 ...

  8. 基于tensorflow搭建一个神经网络

    一,tensorflow的简介 Tensorflow是一个采用数据流图,用于数值计算的 开源软件库.节点在图中表示数字操作,图中的线 则表示在节点间相互联系的多维数据数组,即张量 它灵活的架构让你可以 ...

  9. Tensorflow学习:(二)搭建神经网络

    一.神经网络的实现过程 1.准备数据集,提取特征,作为输入喂给神经网络       2.搭建神经网络结构,从输入到输出       3.大量特征数据喂给 NN,迭代优化 NN 参数       4.使 ...

随机推荐

  1. WinForm 设置窗体启动位置在活动屏幕右下角

    WinForm 设置窗体启动位置在活动屏幕右下角 在多屏幕环境下, 默认使用鼠标所在的屏幕 1. 设置窗体的 StartPosition 为 FormStartPosition.Manual. 2. ...

  2. [math]本博客已经支持书写数学公式

    本博客已经支持mathjax格式公式 使用方法 使用方法单美元符号加单行公式. 使用方法双美元符号加多行公式. 展示 单行公式:\(x^2+2x+1=0\) 多行公式:\[x=\frac{{-b}\p ...

  3. 自动化pip安装

    其实正确安装python3.6后,在安装目录里就有pip.exe文件,只不过用的时候,要进入pip的安装目录下进行安装numpy等. 如进入这个目录, D:\Program Files\Python\ ...

  4. eclipse报错:Multiple annotations found at this line: - String cannot be resolved to a type解决方法实测

    Multiple annotations found at this line:- String cannot be resolved to a type- The method getContext ...

  5. 跨域获取后台日期-ASP

    最近所有的计划都被打乱,生活节奏也有些控制不住,所以在自己还算清醒的时候,把之前一个小功能写下来,对其它人也有些帮助. 需求前景:需要用AJAX跨域获取后台服务器日期. 1.分析需求: 在这个需求中, ...

  6. selenium+Headless Chrome实现不弹出浏览器自动化登录

    目前由于phantomjs已经不维护了,而新版的Chrome(59+)推出了Headless模式,对爬虫来说尤其是定时任务的爬虫截屏之类的是一大好事. 不过按照网络上的一些方法来写的话,会报下面的错误 ...

  7. js 延时等待

    //延时器,2秒后执行函数 function test(){ alert("aaaa"); } setTimeout(function () { test(); }, ); //或 ...

  8. centos7 基础命令

    一: linux基础 (1) 查看服务器的IP信息 ip add showifconfig (2) 操作网卡命令(重启网络和启用网卡) systemctl restart networksystemc ...

  9. react项目中实现悬浮(hover)在按钮上时在旁边显示提示

    <i className={classNames({ 'device-icon': true, 'camera-icon': true, 'camera-icon-hover-show-intr ...

  10. Windows Server 2012 R2 英文版安装中文语言包教程

    Windows Server 是云操作系统的主要组成部分. 有了 Windows Server,再加上云操作系统内的开发者技术,您就可以构建现代业务应用程序. 现代业务应用程序通常涵盖内部部署资源和公 ...