TF随笔-8
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 10 09:35:04 2017
@author: myhaspl@myhaspl.com,http://blog.csdn.net/myhaspl
"""
#逻辑或
import tensorflow as tf
batch_size=10
w1=tf.Variable(tf.random_normal([2,6],stddev=1,seed=1))
w2=tf.Variable(tf.random_normal([6,1],stddev=1,seed=1))
b=tf.Variable(tf.zeros([6]),tf.float32)
x=tf.placeholder(tf.float32,shape=(None,2),name="x")
y=tf.placeholder(tf.float32,shape=(None,1),name="y")
h=tf.matmul(x,w1)+b
yo=tf.matmul(h,w2)
#损失函数计算差异平均值
cross_entropy=tf.reduce_mean(tf.abs(y-yo))
#反向传播
train_step=tf.train.AdamOptimizer(0.05).minimize(cross_entropy)
#生成样本
x_=[[0.,0.],[0.,1.],[1.,0.],[1.,1.]]
y_=[[0.],[1.],[1.],[1.]]
b_=tf.zeros([6])
with tf.Session() as sess:
#初始化变量
init_op=tf.global_variables_initializer()
sess.run(init_op)
print sess.run(w1)
print sess.run(w2)
#设定训练轮数
TRAINCOUNT=500
for i in range(TRAINCOUNT):
#开始训练
sess.run(train_step,feed_dict={x:x_,y:y_})
if i%10==0:
total_cross_entropy=sess.run(cross_entropy,feed_dict={x:x_,y:y_})
print("%d 次训练之后,损失:%g"%(i+1,total_cross_entropy))
print(sess.run(w1))
print(sess.run(w2))
#生成测试样本,仅进行前向传播验证:
testyo=sess.run(yo,feed_dict={x:[[0.,1.],[1.,1.]]})
myout=[int(testout>0.5) for testout in testyo]
print myout
(Initialize initial 1st moment vector)
v_0 <- 0 (Initialize initial 2nd moment vector)
t <- 0 (Initialize timestep)
The update rule for variable with gradient g uses an optimization described at the end of section2 of the paper:
t <- t + 1
lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
m_t <- beta1 * m_{t-1} + (1 - beta1) * g
v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g
variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
The default value of 1e-8 for epsilon might not be a good default in general. For example, when training an Inception network on ImageNet a current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the formulation just before Section 2.1 of the Kingma and Ba paper rather than the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon hat" in the paper.
The sparse implementation of this algorithm (used when the gradient is an IndexedSlices object, typically because of tf.gather or an embedding lookup in the forward pass) does apply momentum to variable slices even if they were not used in the forward pass (meaning they have a gradient equal to zero). Momentum decay (beta1) is also applied to the entire momentum accumulator. This means that the sparse behavior is equivalent to the dense behavior (in contrast to some momentum implementations which ignore momentum unless a variable slice was actually used).
TF随笔-8的更多相关文章
- TF随笔-13
import tensorflow as tf a=tf.constant(5) b=tf.constant(3) res1=tf.divide(a,b) res2=tf.div(a,b) with ...
- TF随笔-11
#!/usr/bin/env python2 # -*- coding: utf-8 -*- import tensorflow as tf my_var=tf.Variable(0.) step=t ...
- TF随笔-10
#!/usr/bin/env python# -*- coding: utf-8 -*-import tensorflow as tf x = tf.constant(2)y = tf.constan ...
- TF随笔-9
计算累加 #!/usr/bin/env python2 # -*- coding: utf-8 -*-"""Created on Mon Jul 24 08:25:41 ...
- TF随笔-7
求平均值的函数 reduce_mean axis为1表示求行 axis为0表示求列 >>> xxx=tf.constant([[1., 10.],[3.,30.]])>> ...
- tf随笔-6
import tensorflow as tfx=tf.constant([-0.2,0.5,43.98,-23.1,26.58])y=tf.clip_by_value(x,1e-10,1.0)ses ...
- tf随笔-5
# -*- coding: utf-8 -*-import tensorflow as tfw1=tf.Variable(tf.random_normal([2,6],stddev=1))w2=tf. ...
- TF随笔-4
>>> import tensorflow as tf>>> a=tf.constant([[1,2],[3,4]])>>> b=tf.const ...
- TF随笔-3
>>> import tensorflow as tf>>> node1 = tf.constant(3.0, dtype=tf.float32)>>& ...
随机推荐
- 20145310 《Java程序设计》第8周学习总结
20145310 <Java程序设计>第8周学习总结 教材学习内容总结 本周主要进行第十四章和第十五章的学习. 第十四章 NIO使用频道(channel)来衔接数据节点,对数据区的标记提供 ...
- 20145327 《Java程序设计》第三周学习总结
20145327 <Java程序设计>第三周学习总结 教材学习内容总结 对象:存在的具体实体,具有明确的状态和行为. 类:具有相同属性和行为的一组对象的集合,用于组合各个对象所共有操作和属 ...
- 1970年1月1日(00:00:00 GMT)Unix 时间戳(Unix Timestamp)
转载自(http://jm.ncxyol.com/post-88.html) 今天在看Python API时,看到time模块: The epoch is the point where the ...
- linux按照进程名杀掉进程
1.按照进程名杀掉进程 ps -ef | grep sftp | grep mysql |grep -v grep | awk '{print("kill -9 ", ...
- 在centos 6.9下Protocol Buffers数据传输及存储协议的使用(python)
我们知道Protocol Buffers是Google定义的一种跨语言.跨平台.可扩展的数据传输及存储的协议,因为将字段协议分别放在传输两端,传输数据中只包含数据本身,不需要包含字段说明,所以传输数据 ...
- "ImportError: cannot import name OVSLegacyKernelSwitch"
My VirtualMachine OS is Ubuntu14.04, Linux. Recently I want to install the mininet in my OS, and I u ...
- python ConfigParse模块(转)
最近写程序要用到配置文件,那么配置文件的解析就很重要了,下文转自chinaunix 一.ConfigParser简介 ConfigParser 是用来读取配置文件的包.配置文件的格式如下:中括号“[ ...
- javaScript tips —— z-index 对事件机制的影响
demo // DOM结构 class App extends React.Component { componentDidMount() { const div1 = document.getEle ...
- 在网页链接中打开qq或者微信
打开微信: 先说第一种,大家知道,在自己的微信资料里有个二维码,别人扫描后可以查看你的资料添加你,把二维码扫描后,得到的地址是:http://weixin.qq.com/r/ykzexmzEPzFAr ...
- JavaScript Browser 对象 实例
使用JavaScript来访问和控制浏览器对象实例. Window 对象 弹出一个警告框 弹出一个带折行的警告框 弹出一个确认框,并提醒访客点击的内容 弹出一个提示框 点击一个按钮时,打开一个新窗口 ...