用Tensorflow搭建神经网络的一般步骤
用Tensorflow搭建神经网络的一般步骤如下:
① 导入模块
② 创建模型变量和占位符
③ 建立模型
④ 定义loss函数
⑤ 定义优化器(optimizer), 使 loss 达到最小
⑥ 引入激活函数, 即添加非线性因素 (线性回归问题跳过此步骤)
⑦ 训练模型
⑧ 检验模型
⑨ 使用模型预测数据
⑩ 保存模型
⑪ 使用Tensorboard的可视化功能
下面以一个简单的线性回归问题为例:
首先是训练模型的代码: train_model.py
# ① 导入模块
import tensorflow as tf # ② 创建模型的变量和占位符
W = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)
x = tf.placeholder(tf.float32, name="input_x")
y = tf.placeholder(tf.float32, name="input_y") # ③建立模型
linear_model = W*x + b
# 如果是矩阵相乘,可以写成:
# linear_model = tf.matmul(x, W)+b # matmul表示矩阵相乘 # ④ 定义loss函数
loss = tf.reduce_sum(tf.square(linear_model - y)) # ⑤ 定义优化器(optimizer), 使 loss 达到最小
learning_rate=0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
train = optimizer.minimize(loss) # ⑥ 引入激活函数, 即添加非线性因素。(线性回归问题跳过此步骤) # ⑦ 训练模型
# 假设模型是y=2x+1
x_train = [1, 2, 3, 4]
y_train = [3, 5, 7, 9] init = tf.global_variables_initializer() # 添加用于初始化变量的节点
sess = tf.Session()
sess.run(init) # 运行初始化操作
for step in range(1000):
sess.run(train, {x: x_train, y: y_train}) '''
第⑦步和第⑩步可以合并为:
for step in xrange(1000000):
sess.run(train, {x: x_train, y: y_train})
if step % 1000 == 0:
saver.save(sess, 'my-model', global_step=step)
''' # ⑧ 检验模型
curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
'''
W: [ 2.00000167] b: [ 0.99999553] loss: 1.29603e-11
''' # ⑨ 使用模型预测数据
x_predict = [-1, 0, 1, 2]
predicted_values=sess.run(linear_model, feed_dict={x:x_predict})
# 注意这么一种写法: predicted_values = [(W*x + b).eval(session=sess) for x in x_predict]
print("result:", predicted_values)
'''
result: [-1.0000062 0.99999553 2.99999714 4.99999905]
''' # ⑩ 保存模型
tf.add_to_collection("predict_network", linear_model)
saver = tf.train.Saver()
saver_path=saver.save(sess, "save/model.ckpt") # ⑪ 使用Tensorboard的可视化功能
# 定义保存日志的路径
path = "log" # 也可写成: path = "./log"
writer=tf.summary.FileWriter(path, sess.graph) sess.close()
然后是载入模型的代码: restore_model.py
import tensorflow as tf with tf.Session() as sess:
new_saver=tf.train.import_meta_graph("save/model.ckpt.meta")
new_saver.restore(sess,"save/model.ckpt")
# print(tf.get_collection("predict_network"))
restored_y=tf.get_collection("predict_network")[0] # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可 graph=tf.get_default_graph()
restored_x=graph.get_operation_by_name("input_x").outputs[0] predict_data = [-2, 3, 4]
predicted_result = sess.run(restored_y, feed_dict={restored_x:predict_data}) print("result:", predicted_result) # result: [-3.00000787 7.00000048 9.00000191]
用Tensorflow搭建神经网络的一般步骤的更多相关文章
- (转)一文学会用 Tensorflow 搭建神经网络
一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day ...
- 一文学会用 Tensorflow 搭建神经网络
http://www.jianshu.com/p/e112012a4b2d 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码 ...
- Tensorflow 搭建神经网络及tensorboard可视化
1. session对话控制 matrix1 = tf.constant([[3,3]]) matrix2 = tf.constant([[2],[2]]) product = tf.matmul(m ...
- kaggle赛题Digit Recognizer:利用TensorFlow搭建神经网络(附上K邻近算法模型预测)
一.前言 kaggle上有传统的手写数字识别mnist的赛题,通过分类算法,将图片数据进行识别.mnist数据集里面,包含了42000张手写数字0到9的图片,每张图片为28*28=784的像素,所以整 ...
- Tensorflow搭建神经网络及使用Tensorboard进行可视化
创建神经网络模型 1.构建神经网络结构,并进行模型训练 import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt ...
- tensorflow搭建神经网络
最简单的神经网络 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt date = np.linspa ...
- tensorflow搭建神经网络基本流程
定义添加神经层的函数 1.训练的数据2.定义节点准备接收数据3.定义神经层:隐藏层和预测层4.定义 loss 表达式5.选择 optimizer 使 loss 达到最小 然后对所有变量进行初始化,通过 ...
- 基于tensorflow搭建一个神经网络
一,tensorflow的简介 Tensorflow是一个采用数据流图,用于数值计算的 开源软件库.节点在图中表示数字操作,图中的线 则表示在节点间相互联系的多维数据数组,即张量 它灵活的架构让你可以 ...
- Tensorflow学习:(二)搭建神经网络
一.神经网络的实现过程 1.准备数据集,提取特征,作为输入喂给神经网络 2.搭建神经网络结构,从输入到输出 3.大量特征数据喂给 NN,迭代优化 NN 参数 4.使 ...
随机推荐
- StreamReader 和 StreamWriter 简单调用
/* ######### ############ ############# ## ########### ### ###### ##### ### ####### #### ### ####### ...
- 论文笔记:Auto-ReID: Searching for a Part-aware ConvNet for Person Re-Identification
Auto-ReID: Searching for a Part-aware ConvNet for Person Re-Identification 2019-03-26 15:27:10 Paper ...
- C、C++中的static和extern关键字
1.首先,关于声明和定义的区别 这种写法(函数原型后加;号表示结束的写法)只能叫函数声明而不能叫函数定义,只有带函数体的声明才叫定义,比如下面 只有分配存储空间的变量声明才叫变量定义,其实函数也是一样 ...
- ES6解构过程添加一个默认值和赋值一个新的值
const info = { name: 'xiaobe', } const { name: nickName = '未知' } = info; 其中nickName是解构过程中新声明的一个变量,并且 ...
- Valotile关键字详解
在了解valotile关键字之前.我们先来了解其他相关概念. 1.1 java内存模型: 不同的平台,内存模型是不一样的,我们可以把内存模型理解为在特定操作协议下,对特定的内存或高速缓存进行读写访问 ...
- “妄”眼欲穿-CSS之flex布局和边框阴影
妄:狂妄: 不会的东西只有怀着一颗狂妄的心,假装能把它看穿吧. 作为一个什么都不会的小白,为了学习(zb),特别在拿来主义之后写一些对于某些css布局的总结,进一步加深对知识的记忆.知识是人类的共同财 ...
- (简单)华为P9plus VIE-AL00的usb调试模式在哪里开启的经验
每次我们使用pc接通安卓手机的时候,如果手机没有开启Usb调试模式,pc则没能成功检测到我们的手机,有时,我们使用的一些功能比较强的的app比如之前我们使用的一个app引号精灵,老版本就需要打开Usb ...
- (转)如何在maven的pom.xml中添加本地jar包
转载自: https://www.cnblogs.com/lixuwu/p/5855031.html 1 maven本地仓库认识 maven本地仓库中的jar目录一般分为三层:图中的1 2 3分别如下 ...
- 在线批量将gps经纬度坐标转换为百度经纬度坐标
1.首先打开百度api示例页面: 在浏览器地址栏中输入:http://developer.baidu.com/map/jsdemo.htm#a5_3 2.修改代码 如下图,将需要批量转换的坐标,按规则 ...
- Ubuntu - apt -commands
1. install sudo apt install [软件名] sudo apt-get install [软件名]Tab补全,可以使用sudo apt upgrade 升级apt, 也可以通过s ...