TensorFlow+实战Google深度学习框架学习笔记(5)----神经网络训练步骤
一、TensorFlow实战Google深度学习框架学习
1、步骤:
1、定义神经网络的结构和前向传播的输出结果。
2、定义损失函数以及选择反向传播优化的算法。
3、生成会话(session)并且在训练数据上反复运行反向传播优化算法。
2、代码:
来源:https://blog.csdn.net/longji/article/details/69472310
import tensorflow as tf
from numpy.random import RandomState # 1. 定义神经网络的参数,输入和输出节点
batch_size = 8
w1= tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2= tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
y_= tf.placeholder(tf.float32, shape=(None, 1), name='y-input') # 2. 定义前向传播过程,损失函数及反向传播算法
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)
cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy) # 3. 生成模拟数据集
rdm = RandomState(1)
X = rdm.rand(128,2)
Y = [[int(x1+x2 < 1)] for (x1, x2) in X] # 4. 创建一个会话来运行TensorFlow程序
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op) # 输出目前(未经训练)的参数取值。
print("w1:", sess.run(w1))
print("w2:", sess.run(w2))
print("\n") # 训练模型。
STEPS = 5000
for i in range(STEPS):
start = (i * batch_size) % 128
end = (i * batch_size) % 128 + batch_size
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy)) # 输出训练后的参数取值。
print("\n")
print("w1:", sess.run(w1))
print("w2:", sess.run(w2)) ''' 输出形式
w1: [[-0.81131822 1.48459876 0.06532937]
[-2.4427042 0.0992484 0.59122431]]
w2: [[-0.81131822]
[ 1.48459876]
[ 0.06532937]] After 0 training step(s), cross entropy on all data is 0.0674925
After 1000 training step(s), cross entropy on all data is 0.0163385
After 2000 training step(s), cross entropy on all data is 0.00907547
After 3000 training step(s), cross entropy on all data is 0.00714436
After 4000 training step(s), cross entropy on all data is 0.00578471 w1: [[-1.9618274 2.58235407 1.68203783]
[-3.46817183 1.06982327 2.11789012]]
w2: [[-1.82471502]
[ 2.68546653]
[ 1.41819513]]
'''
二、莫烦大大的神经网络训练步骤:
1、def add_layer()
添加神经网络层:
import tensorflow as tf
#输入、输入大小、输出大小、激活函数 def add_layer( inputs, in_size, out_size ,activation_function=None) : #weight初始化时生成一个随机变量矩阵比0矩阵效果要好 Weights = tf.Variable( tf.random_normal ( [in_size, out_size])) #biases初始值最好也不要都为0,则biases值全部等于0.1 biases = tf.Variable( tf.zeros([1,out_size]) + 0.1) #相当于Y_predict Wx_plus_b = tf.matmul ( inputs,Weights ) +biases
#如果为线性则outputs不用改变,如果不为线性则用激活函数
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_functions ( Wx_plus_b)
reeturn outputs
2、建立神经网络
#定义数据 x_data = np.linspace ( -1,1,300) [:,np.newaxis]
noise = np.random.normal ( 0,0.05 , x_data.shape)
y_data = np.square( x_data) - 0.5 +noise #建立第一层layer
#一个输入层、一个隐藏层、一个输出层
#输入层:输入多少data就多少个神经元,这里的x只有一个特征属性,则输入层有1个神经元
#隐藏层:自己定义10个
#输出层:输出y只有1个输出 #None表示无论给多少个样本都可以
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None,1] ) #add_layer为上面自己建立的函数,这里建立隐藏层
l1 = add_layer( xs , 1 ,10 ,activation_function = tf.nn.relu)
#输出层
predition = add_layer( l1 ,10 , 1,activation_function = None) #算损失函数 , reduction_indices =[1] 按行求和
loss = tf.reduce_mean ( tf.square ( ys -prediction ),
reduction_indices =[1] ) #选择一个优化器,选择:梯度下降,需要给定一个学习率为0.1,通常要小于1
#优化器以0.1的学习效率要减少loss函数,使下一次结果更好
train_step = tf.train.GradientDecentOptimizer( 0.1).minimize (loss) #初始所有变量 init = tf.initialize_all_variables () sess = tf.Session() sess.run(init)
#重复学习1000次
for i in range(1000):
sess.run( train_step , feed_dict = {xs:x_data,ys:y_data})
#每50次打印loss
if i % 50 == 0:
print(sess.run(loss,feed_dict={x:x_data,ys:y_data})
TensorFlow+实战Google深度学习框架学习笔记(5)----神经网络训练步骤的更多相关文章
- [Tensorflow实战Google深度学习框架]笔记4
本系列为Tensorflow实战Google深度学习框架知识笔记,仅为博主看书过程中觉得较为重要的知识点,简单摘要下来,内容较为零散,请见谅. 2017-11-06 [第五章] MNIST数字识别问题 ...
- 1 如何使用pb文件保存和恢复模型进行迁移学习(学习Tensorflow 实战google深度学习框架)
学习过程是Tensorflow 实战google深度学习框架一书的第六章的迁移学习环节. 具体见我提出的问题:https://www.tensorflowers.cn/t/5314 参考https:/ ...
- 学习《TensorFlow实战Google深度学习框架 (第2版) 》中文PDF和代码
TensorFlow是谷歌2015年开源的主流深度学习框架,目前已得到广泛应用.<TensorFlow:实战Google深度学习框架(第2版)>为TensorFlow入门参考书,帮助快速. ...
- TensorFlow实战Google深度学习框架5-7章学习笔记
目录 第5章 MNIST数字识别问题 第6章 图像识别与卷积神经网络 第7章 图像数据处理 第5章 MNIST数字识别问题 MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会 ...
- TensorFlow实战Google深度学习框架-人工智能教程-自学人工智能的第二天-深度学习
自学人工智能的第一天 "TensorFlow 是谷歌 2015 年开源的主流深度学习框架,目前已得到广泛应用.本书为 TensorFlow 入门参考书,旨在帮助读者以快速.有效的方式上手 T ...
- 实现迁徙学习-《Tensorflow 实战Google深度学习框架》代码详解
为了实现迁徙学习,首先是数据集的下载 #利用curl下载数据集 curl -o flower_photos.tgz http://download.tensorflow.org/example_ima ...
- 2 (自我拓展)部署花的识别模型(学习tensorflow实战google深度学习框架)
kaggle竞赛的inception模型已经能够提取图像很好的特征,后续训练出一个针对当前图片数据的全连接层,进行花的识别和分类.这里见书即可,不再赘述. 书中使用google参加Kaggle竞赛的i ...
- TensorFlow实战第三课(可视化、加速神经网络训练)
matplotlib可视化 构件图形 用散点图描述真实数据之间的关系(plt.ion()用于连续显示) # plot the real data fig = plt.figure() ax = fig ...
- TensorFlow实战Google深度学习框架10-12章学习笔记
目录 第10章 TensorFlow高层封装 第11章 TensorBoard可视化 第12章 TensorFlow计算加速 第10章 TensorFlow高层封装 目前比较流行的TensorFlow ...
随机推荐
- 前端开发—jQuery
jquery简介 jQuery是一个轻量级的.兼容多浏览器的JavaScript库. jQuery使用户能够更方便地处理HTML Document.Events.实现动画效果.方便地进行Ajax交互, ...
- AOJ 2224 Save your cats( 最小生成树 )
链接:传送门 题意:有个女巫把猫全部抓走放在一个由 n 个木桩(xi,yi),m 个篱笆(起点终点木桩的编号)围成的法术领域内,我们必须用圣水才能将篱笆打开,然而圣水非常贵,所以我们尽量想降低花费来解 ...
- N1-1 - 树 - Minimum Depth of Binary Tree
题目描述: Given a binary tree, find its minimum depth.The minimum depth is the number of nodes along the ...
- 4.2、Ansible常用模块
1.command:命令模块,默认模块,用于在远程执行命令,不支持变量.ansible 192.168.139.128 -a 'date' 2.cron:计划任务模块:ansible 192.168. ...
- docker数据卷的使用 -v --volumes--from
总结一下docker数据管理的三种方法: 1.普通的挂在数据: -v docker run -v /father/path:/child/path-v 参数会把当前系统的文件目录/father/pa ...
- Linux 密码的暴力破解
Linux 的密码的介绍 两个文件 1 . /etc/passwd 2 . /etc/shadow ## 关于/etc/shadow 文件的介绍 1 . 第一个字段是用户名 2 . 第二字字段是加密的 ...
- maven引入jsp相关依赖
<!--引入Servlet开始--> <dependency> <groupId>javax.servlet</groupId> <artifac ...
- 转载:手游安全破“黑”行动:向黑产业链说NO
目前的手游市场已被称为红海.从业界认为的2013年的“手游元年”至今,手游发展可谓是既经历了市场的野蛮生长,也有百家争鸣的战国时代.如今,手游市场竞争已趋白热化,增长放缓.但移动互联网的发展大势之下, ...
- Android4.0设置界面改动总结(二)
今年1月份的时候.有和大家分享给予Android4.0+系统设置的改动:Android4.0设置界面改动总结 时隔半年.回头看看那个时候的改动.事实上是有非常多问题的,比方说: ①.圆角Item会影响 ...
- poj - 1159 - Palindrome(滚动数组dp)
题意:一个长为N的字符串( 3 <= N <= 5000).问最少插入多少个字符使其变成回文串. 题目链接:http://poj.org/problem?id=1159 -->> ...