使用mnist数据集进行神经网络的构建

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('data/', one_hot=True)

这个神经网络共有三层。输入层有n个1*784的矩阵,第一层有256个神经元,第二层有128个神经元,输出层是一个十分类的结果。对w1、b1、w2、b2以及输出层的参数进行随机初始化

# NETWORK TOPOLOGIES
n_input = 784
n_hidden_1 = 256
n_hidden_2 = 128
n_classes = 10 # INPUTS AND OUTPUTS
x = tf.placeholder("float", [None, n_input])
y = tf.placeholder("float", [None, n_classes]) # NETWORK PARAMETERS
stddev = 0.1
weights = {
'w1': tf.Variable(tf.random_normal([n_input, n_hidden_1], stddev=stddev)),
'w2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2], stddev=stddev)),
'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes], stddev=stddev))
}
biases = {
'b1': tf.Variable(tf.random_normal([n_hidden_1])),
'b2': tf.Variable(tf.random_normal([n_hidden_2])),
'out': tf.Variable(tf.random_normal([n_classes]))
}
print ("NETWORK READY")

开始进行前向传播

def multilayer_perceptron(_X, _weights, _biases):
layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(_X, _weights['w1']), _biases['b1']))
layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, _weights['w2']), _biases['b2']))
return (tf.matmul(layer_2, _weights['out']) + _biases['out'])

用前向传播函数算出预测值;算出损失值(此处使用交叉熵);构造梯度下降最优构造器;算出精度;

# PREDICTION
pred = multilayer_perceptron(x, weights, biases) # LOSS AND OPTIMIZER
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
optm = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(corr, "float")) # INITIALIZER
init = tf.global_variables_initializer()
print ("FUNCTIONS READY")

定义迭代次数;使用以上定义好的神经网络函数

training_epochs = 20
batch_size = 100
display_step = 4
# LAUNCH THE GRAPH
sess = tf.Session()
sess.run(init)
# OPTIMIZE
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/batch_size)
# ITERATION
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
feeds = {x: batch_xs, y: batch_ys}
sess.run(optm, feed_dict=feeds)
avg_cost += sess.run(cost, feed_dict=feeds)
avg_cost = avg_cost / total_batch
# DISPLAY
if (epoch+1) % display_step == 0:
print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
feeds = {x: batch_xs, y: batch_ys}
train_acc = sess.run(accr, feed_dict=feeds)
print ("TRAIN ACCURACY: %.3f" % (train_acc))
feeds = {x: mnist.test.images, y: mnist.test.labels}
test_acc = sess.run(accr, feed_dict=feeds)
print ("TEST ACCURACY: %.3f" % (test_acc))
print ("OPTIMIZATION FINISHED")

tensorflow学习笔记六----------神经网络的更多相关文章

  1. TensorFlow学习笔记——深层神经网络的整理

    维基百科对深度学习的精确定义为“一类通过多层非线性变换对高复杂性数据建模算法的合集”.因为深层神经网络是实现“多层非线性变换”最常用的一种方法,所以在实际中可以认为深度学习就是深度神经网络的代名词.从 ...

  2. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  3. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  4. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  5. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  6. tensorflow学习笔记(3)前置数学知识

    tensorflow学习笔记(3)前置数学知识 首先是神经元的模型 接下来是激励函数 神经网络的复杂度计算 层数:隐藏层+输出层 总参数=总的w+b 下图为2层 如下图 w为3*4+4个   b为4* ...

  7. tensorflow学习笔记(1)-基本语法和前向传播

    tensorflow学习笔记(1) (1)tf中的图 图中就是一个计算图,一个计算过程.                                       图中的constant是个常量 计 ...

  8. tensorflow学习笔记——自编码器及多层感知器

    1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...

  9. tensorflow学习笔记——VGGNet

    2014年,牛津大学计算机视觉组(Visual Geometry Group)和 Google DeepMind 公司的研究员一起研发了新的深度卷积神经网络:VGGNet ,并取得了ILSVRC201 ...

随机推荐

  1. linux-ssh加密与https安全-9

    非对称加密算法:RSA,DSA/DSS 对称加密算法:AES,RC4,3DES HASH算法:MD5,SHA1,SHA256 hash就是找到一种数据内容和数据存放地址之间的映射关系 (1) 文件校验 ...

  2. matlab中画三维图形

    这里主要讲述两个方法用matlab画三维图形: 1.mesh函数 先看一个简单的例子: x = ::; y = ::; [X, Y] = meshgrid(x, y); Z = zeros(,); Z ...

  3. UVa 11212 Editing a Book (IDA* && 状态空间搜索)

    题意:你有一篇n(2≤n≤9)个自然段组成的文章,希望将它们排列成1,2,…,n.可以用Ctrl+X(剪切)和Ctrl+V(粘贴)快捷键来完成任务.每次可以剪切一段连续的自然段,粘贴时按照顺序粘贴.注 ...

  4. sh_10_分隔线模块

    sh_10_分隔线模块 def print_line(char, times): """打印单行分隔线 :param char: 分隔字符 :param times: 重 ...

  5. (57)Linux驱动开发之三Linux字符设备驱动

    1.一般情况下,对每一种设备驱动都会定义一个软件模块,这个工程模块包含.h和.c文件,前者定义该设备驱动的数据结构并声明外部函数,后者进行设备驱动的具体实现. 2.典型的无操作系统下的逻辑开发程序是: ...

  6. java list去重方式,以及效率问题

    之前面试被问到关于java如何去重的问题,当时没怎么留意,今天刚好项目中用到了,所以记录一下. 实体类: /** * 用户类 */ class User{ private String usernam ...

  7. SQL Server服务起不了

    转载 MSSQLSERVER服务无法启动的解决方案   1.IP地址配置不正确: 打开 Microsoft SQL Server 2005配置工具下的SQL Server Configuration ...

  8. NBU5240备份系统还原数据库---Windows版

    NBU5240是一个基于系统文件和多种数据库备份的灾备系统,灵活性比较高.下面具体记录如何利用该系统的备份文件进行数据库还原.(基于业务场景) 公司某业务部门突然发现前台系统数据有异常,已经是几天前的 ...

  9. 微信小程序 视频 组件

    video 组件 视频组件 相关的api :wx.createVideoContext 支持的格式: 支持的编码格式 video 组件的属性: src:类型 字符串 必填 要播放视频的资源地址 (支持 ...

  10. 3,、maven修改jar包下载为国内镜像下载地址

    maven 默认的中央仓库是在国外的服务器,下载速度慢,有时候稍不注意就下载出错 通常我将maven的中央仓库修改为阿里云的地址,下载速度很快体验非常好 修改conf下的setting.xml文件 在 ...