TensorFlow基础(二)实现神经网络
(1)前向传播算法
神经网络的前向传播算法主要构成部分:
1.神经网络的输入;
2.神经网络的连接结构;神经网络是由神经元(节点)构成的
3.每个神经元中的参数。
(2)TensorFlow随机数生成函数
函数名称 | 随机数分布 | 主要参数 |
tf.random_normal | 正态分布 | 平均值、标准差、取值类型 |
tf.truncated_normal | 正态分布,如果随机出来的值偏离平均值超过2个标准差,那么这个数将被重新随机 | 平均值、标准差、取值类型 |
tf.random_uniform | 均匀分布 | 最小、最大取值,取值类型 |
tf.random_gamma | Gamma分布 | 形状参数alpha、尺度参数beta、取值类型 |
(3)TensorFlow常数生成函数
函数名称 | 功能 | 样例 |
tf.zeros | 产生全0的数组 | tf.zeros([2,3],int32)-->[[0,0,0],[0,0,0]] |
tf.ones | 产生全1的数组 | tf.ones([2,3],int32)-->[[1,1,1],[1,1,1]] |
tf.fill | 产生一个全部为给定数字的数组 | tf.fill([2,3],9)-->[[9,9,9],[9,9,9]] |
tf.constant | 产生一个给定值的常量 | tf.constant([1,2,3])-->[1,2,3] |
(4)完整神经网络Python代码
# -*- coding: utf-8 -*- import tensorflow as tf
from numpy.random import RandomState #定义训练数据batch的大小
batch_size = 8
"""
#定义神经网络的参数
random_normal:随机生成函数,
随机数分布正太分布,2x3矩阵,标准差为1
随机数种子序号1,种子序号相同,产生的随机数相同
"""
w1=tf.Variable(tf.random_normal((2,3),stddev=1,seed=1))
w2=tf.Variable(tf.random_normal((3,1),stddev=1,seed=1)) """
定义placeholder作为存放输入数据的地方
参数:数据类型,数据维度
"""
x=tf.placeholder(tf.float32, shape=(None,2), name="x-input")
y_=tf.placeholder(tf.float32, shape=(None,1), name="y-input") """
matmul():矩阵乘法函数
"""
a=tf.matmul(x,w1)
y=tf.matmul(a,w2) """
使用sigmoid函数将y转换为0-1之间的数值,转换后的y代表预测是正样本的概率,1-y代表预测是负样本的概率。
"""
y=tf.sigmoid(y) #定义损失函数来刻画预测值与真实值的差距
cross_entropy=-tf.reduce_mean(y_*tf.log(tf.clip_by_value(y,1e-10,1.0))+(1-y)*tf.log(tf.clip_by_value(1-y,1e-10,1.0)))
#定义学习率
learning_rate=0.001
#定义反向传播算法优化神经网络的参数
train_step=tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy) #通过随机数生成一个模拟数据集
rdm=RandomState(1)
dataset_size = 128
X=rdm.rand(dataset_size,2)
print(X)
#0表示负样本,1表示正样本
Y=[[int(x1+x2<1)] for (x1,x2) in X]
print(Y) #创建一个会话来运行Tensorflow程序
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
print(sess.run(w1))
print(sess.run(w2)) STEPS=5000
for i in range(STEPS):
#每次选取batch_size个样本进行训练
#%取余
#min()函数取最小值
start = (i*batch_size) % dataset_size
end = min(start+batch_size,dataset_size) #通过选取的样本训练神经网络并更新参数
sess.run(train_step, feed_dict={x: X[start:end],y_ :Y[start:end]})
if i%1000 == 0:
total_cross_entropy=sess.run(cross_entropy,feed_dict={x: X, y_: Y} )
print("After %d training step(s), cross entropy on all data is %g" %(i, total_cross_entropy))
print(sess.run(w1))
print(sess.run(w2))
TensorFlow基础(二)实现神经网络的更多相关文章
- TensorFlow基础二(Shape)
首先说明tf中tensor有两种shape,分别为static (inferred) shape和dynamic (true) shape,其中static shape用于构建图,由创建这个tenso ...
- TensorFlow学习笔记——深层神经网络的整理
维基百科对深度学习的精确定义为“一类通过多层非线性变换对高复杂性数据建模算法的合集”.因为深层神经网络是实现“多层非线性变换”最常用的一种方法,所以在实际中可以认为深度学习就是深度神经网络的代名词.从 ...
- TensorFlow基础剖析
TensorFlow基础剖析 一.概述 TensorFlow 是一个使用数据流图 (Dataflow Graph) 表达数值计算的开源软件库.它使 用节点表示抽象的数学计算,并使用 OP 表达计算的逻 ...
- TensorFlow基础
TensorFlow基础 SkySeraph 2017 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站点:www.skyseraph.com Over ...
- 机器学习与Tensorflow(2)——神经网络及Tensorflow实现
神经网络算法以及Tensorflow的实现 一.多层向前神经网络(Multilayer Feed-Forward Neural Network) 多层向前神经网络由三部分组成:输入层(input la ...
- NO.2:自学tensorflow之路------BP神经网络编程
引言 在上一篇博客中,介绍了各种Python的第三方库的安装,本周将要使用Tensorflow完成第一个神经网络,BP神经网络的编写.由于之前已经介绍过了BP神经网络的内部结构,本文将直接介绍Tens ...
- 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门
2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...
- Python全栈开发【基础二】
Python全栈开发[基础二] 本节内容: Python 运算符(算术运算.比较运算.赋值运算.逻辑运算.成员运算) 基本数据类型(数字.布尔值.字符串.列表.元组.字典) 其他(编码,range,f ...
- Bootstrap <基础二十九>面板(Panels)
Bootstrap 面板(Panels).面板组件用于把 DOM 组件插入到一个盒子中.创建一个基本的面板,只需要向 <div> 元素添加 class .panel 和 class .pa ...
- Bootstrap <基础二十八>列表组
列表组.列表组件用于以列表形式呈现复杂的和自定义的内容.创建一个基本的列表组的步骤如下: 向元素 <ul> 添加 class .list-group. 向 <li> 添加 cl ...
随机推荐
- .net托管资源与非托管资源
在项目当中用到的资源分为托管资源和非托管资源,托管资源无非就是什么int.string.datatime之类,托管资源不需要人为去管理,.net framework中有专门针对托管资源的管理机制(GC ...
- MYSQL连接字符串参数解析(解释)
被迫转到MySQL数据库,发现读取数据库时,tinyint类型的值都被转化为boolean了,这样大于1的值都丢失,变成true了.查阅资料MySQL中无Boolean类型,都是存储为tinyint了 ...
- group by 语句
user E_book go 这样的程序会出错,因为play没有使用sum,所以要分组. group by play 有函数的和没有函数的表一起使用要用 GROUP BY .AVG 求平均值,只能与数 ...
- 阿里云、青云、腾讯云服务器,Mysql数据库,Redis等产品性能对比
阿里云.青云.腾讯云服务器,Mysql数据库,Redis等产品都使用过,对比维度很多就不一一放出.直接放结论吧:买的腾讯(金融专区)服务器,Mysql(TDSql)把所有项目转到腾讯云,但是没有用腾讯 ...
- BestCoder Round #92
这里是逢比赛必挂的智障选手ysf…… 不知道是因为自己菜还是心态不好……也许是后者吧,毕竟每次打比赛的时候都会很着急.lrd说我打比赛的功利性太强,想想确实是这样. 昨天打完之后自觉身败名裂没敢写出来 ...
- Git Bash Here常用命令以及使用步骤
1.首先,要clone项目代码: git clone 链接地址 2.更新代码: git pull 3.添加修改过的文件.文件夹: git add 修改过的文件,文件夹 4.提交并注释: git com ...
- Java Struts2 (四)
一.contextMap中的数据操作 root根:List 元素1 元素2 元素3 元素4 元素5 contextMap:Map key value application Map key value ...
- ASPF(Application Specific Packet Filter)
ASPF ASPF(Application Specific Packet Filter)是针对应用层的包过滤,其原理是检测通过设备的报文的应用层协议信息,记录临时协商的数据连接,使得某些在安全策略中 ...
- QQ 聊天机器人小薇 2.0.0 发布!
本次发布主要加入了支持讨论组聊天,并增强了稳定性.另外,官方小薇 QQ 机器人已经下线,大家要体验的话请 自建私服~ 简介 XiaoV(小薇)是一个用 Java 写的 QQ 聊天机器人 Web 服务, ...
- RocketMQ读书笔记2——生产者
[生产者的不同写入策略] 生产者向消息队列里写入数据,不同的业务需要生产者采用不同的写入策略: 同步发送.异步发送.延迟发送.发送事务消息等. [DefaultMQProduce示例] public ...