用Tensorflow搭建神经网络的一般步骤
用Tensorflow搭建神经网络的一般步骤如下:
① 导入模块
② 创建模型变量和占位符
③ 建立模型
④ 定义loss函数
⑤ 定义优化器(optimizer), 使 loss 达到最小
⑥ 引入激活函数, 即添加非线性因素 (线性回归问题跳过此步骤)
⑦ 训练模型
⑧ 检验模型
⑨ 使用模型预测数据
⑩ 保存模型
⑪ 使用Tensorboard的可视化功能
下面以一个简单的线性回归问题为例:
首先是训练模型的代码: train_model.py
# ① 导入模块
import tensorflow as tf # ② 创建模型的变量和占位符
W = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)
x = tf.placeholder(tf.float32, name="input_x")
y = tf.placeholder(tf.float32, name="input_y") # ③建立模型
linear_model = W*x + b
# 如果是矩阵相乘,可以写成:
# linear_model = tf.matmul(x, W)+b # matmul表示矩阵相乘 # ④ 定义loss函数
loss = tf.reduce_sum(tf.square(linear_model - y)) # ⑤ 定义优化器(optimizer), 使 loss 达到最小
learning_rate=0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
train = optimizer.minimize(loss) # ⑥ 引入激活函数, 即添加非线性因素。(线性回归问题跳过此步骤) # ⑦ 训练模型
# 假设模型是y=2x+1
x_train = [1, 2, 3, 4]
y_train = [3, 5, 7, 9] init = tf.global_variables_initializer() # 添加用于初始化变量的节点
sess = tf.Session()
sess.run(init) # 运行初始化操作
for step in range(1000):
sess.run(train, {x: x_train, y: y_train}) '''
第⑦步和第⑩步可以合并为:
for step in xrange(1000000):
sess.run(train, {x: x_train, y: y_train})
if step % 1000 == 0:
saver.save(sess, 'my-model', global_step=step)
''' # ⑧ 检验模型
curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
'''
W: [ 2.00000167] b: [ 0.99999553] loss: 1.29603e-11
''' # ⑨ 使用模型预测数据
x_predict = [-1, 0, 1, 2]
predicted_values=sess.run(linear_model, feed_dict={x:x_predict})
# 注意这么一种写法: predicted_values = [(W*x + b).eval(session=sess) for x in x_predict]
print("result:", predicted_values)
'''
result: [-1.0000062 0.99999553 2.99999714 4.99999905]
''' # ⑩ 保存模型
tf.add_to_collection("predict_network", linear_model)
saver = tf.train.Saver()
saver_path=saver.save(sess, "save/model.ckpt") # ⑪ 使用Tensorboard的可视化功能
# 定义保存日志的路径
path = "log" # 也可写成: path = "./log"
writer=tf.summary.FileWriter(path, sess.graph) sess.close()
然后是载入模型的代码: restore_model.py
import tensorflow as tf with tf.Session() as sess:
new_saver=tf.train.import_meta_graph("save/model.ckpt.meta")
new_saver.restore(sess,"save/model.ckpt")
# print(tf.get_collection("predict_network"))
restored_y=tf.get_collection("predict_network")[0] # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可 graph=tf.get_default_graph()
restored_x=graph.get_operation_by_name("input_x").outputs[0] predict_data = [-2, 3, 4]
predicted_result = sess.run(restored_y, feed_dict={restored_x:predict_data}) print("result:", predicted_result) # result: [-3.00000787 7.00000048 9.00000191]
用Tensorflow搭建神经网络的一般步骤的更多相关文章
- (转)一文学会用 Tensorflow 搭建神经网络
一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...
- 一文学会用 Tensorflow 搭建神经网络
http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...
- Tensorflow 搭建神经网络及tensorboard可视化
1. session对话控制 matrix1 = tf.constant([[3,3]]) matrix2 = tf.constant([[2],[2]]) product = tf.matmul(m ...
- kaggle赛题Digit Recognizer:利用TensorFlow搭建神经网络(附上K邻近算法模型预测)
一.前言 kaggle上有传统的手写数字识别mnist的赛题,通过分类算法,将图片数据进行识别.mnist数据集里面,包含了42000张手写数字0到9的图片,每张图片为28*28=784的像素,所以整 ...
- Tensorflow搭建神经网络及使用Tensorboard进行可视化
创建神经网络模型 1.构建神经网络结构,并进行模型训练 import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt ...
- tensorflow搭建神经网络
最简单的神经网络 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt date = np.linspa ...
- tensorflow搭建神经网络基本流程
定义添加神经层的函数 1.训练的数据2.定义节点准备接收数据3.定义神经层:隐藏层和预测层4.定义 loss 表达式5.选择 optimizer 使 loss 达到最小 然后对所有变量进行初始化,通过 ...
- 基于tensorflow搭建一个神经网络
一,tensorflow的简介 Tensorflow是一个采用数据流图,用于数值计算的 开源软件库.节点在图中表示数字操作,图中的线 则表示在节点间相互联系的多维数据数组,即张量 它灵活的架构让你可以 ...
- Tensorflow学习:(二)搭建神经网络
一.神经网络的实现过程 1.准备数据集,提取特征,作为输入喂给神经网络 2.搭建神经网络结构,从输入到输出 3.大量特征数据喂给 NN,迭代优化 NN 参数 4.使 ...
随机推荐
- WinForm 设置窗体启动位置在活动屏幕右下角
WinForm 设置窗体启动位置在活动屏幕右下角 在多屏幕环境下, 默认使用鼠标所在的屏幕 1. 设置窗体的 StartPosition 为 FormStartPosition.Manual. 2. ...
- [math]本博客已经支持书写数学公式
本博客已经支持mathjax格式公式 使用方法 使用方法单美元符号加单行公式. 使用方法双美元符号加多行公式. 展示 单行公式:\(x^2+2x+1=0\) 多行公式:\[x=\frac{{-b}\p ...
- 自动化pip安装
其实正确安装python3.6后,在安装目录里就有pip.exe文件,只不过用的时候,要进入pip的安装目录下进行安装numpy等. 如进入这个目录, D:\Program Files\Python\ ...
- eclipse报错:Multiple annotations found at this line: - String cannot be resolved to a type解决方法实测
Multiple annotations found at this line:- String cannot be resolved to a type- The method getContext ...
- 跨域获取后台日期-ASP
最近所有的计划都被打乱,生活节奏也有些控制不住,所以在自己还算清醒的时候,把之前一个小功能写下来,对其它人也有些帮助. 需求前景:需要用AJAX跨域获取后台服务器日期. 1.分析需求: 在这个需求中, ...
- selenium+Headless Chrome实现不弹出浏览器自动化登录
目前由于phantomjs已经不维护了,而新版的Chrome(59+)推出了Headless模式,对爬虫来说尤其是定时任务的爬虫截屏之类的是一大好事. 不过按照网络上的一些方法来写的话,会报下面的错误 ...
- js 延时等待
//延时器,2秒后执行函数 function test(){ alert("aaaa"); } setTimeout(function () { test(); }, ); //或 ...
- centos7 基础命令
一: linux基础 (1) 查看服务器的IP信息 ip add showifconfig (2) 操作网卡命令(重启网络和启用网卡) systemctl restart networksystemc ...
- react项目中实现悬浮(hover)在按钮上时在旁边显示提示
<i className={classNames({ 'device-icon': true, 'camera-icon': true, 'camera-icon-hover-show-intr ...
- Windows Server 2012 R2 英文版安装中文语言包教程
Windows Server 是云操作系统的主要组成部分. 有了 Windows Server,再加上云操作系统内的开发者技术,您就可以构建现代业务应用程序. 现代业务应用程序通常涵盖内部部署资源和公 ...