import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets("data/",one_hot = True)
#导入Tensorflwo和mnist数据集 #构建输入层
x = tf.placeholder(tf.float32,[None,784],name='X')
y = tf.placeholder(tf.float32,[None,10],name='Y') #隐藏层神经元数量
H1_NN = 256 #第一层神经元数量
W1 = tf.Variable(tf.random_normal([784,H1_NN])) #权重
b1 = tf.Variable(tf.zeros([H1_NN])) #偏置项
Y1 = tf.nn.relu(tf.matmul(x,W1)+b1) #第一层输出
W2 = tf.Variable(tf.random_normal([H1_NN,10]))#权重
b2 = tf.Variable(tf.zeros(10))#偏置项 forward = tf.matmul(Y1,W2)+b2 #定义前向传播
pred = tf.nn.softmax(forward) #激活函数输出 #损失函数
#loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),
# reduction_indices=1))
#(log(0))超出范围报错 loss_function = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(logits=forward,labels=y)) #训练参数
train_epochs = 50 #训练次数
batch_size = 50 #每次训练多少个样本
total_batch = int(mnist.train.num_examples/batch_size) #随机抽取样本
display_step = 1 #训练情况输出
learning_rate = 0.01 #学习率 #优化器
opimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function) #准确率函数
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #记录开始训练时间
from time import time
startTime = time()
#初始化变量
sess =tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
#训练
for epoch in range(train_epochs):
for batch in range(total_batch):
xs,ys = mnist.train.next_batch(batch_size)#读取批次数据
sess.run(opimizer,feed_dict={x:xs,y:ys})#执行批次数据训练 #total_batch个批次训练完成后,使用验证数据计算误差与准确率
loss,acc=sess.run([loss_function,accuracy],
feed_dict={x:mnist.validation.images,
y:mnist.validation.labels})
#输出训练情况
if(epoch+1) % display_step == 0:
print("Train Epoch:",'%02d' % (epoch + 1),
"Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc))
duration = time()-startTime
print("Trian Finshed takes:","{:.2f}".format(duration))#显示预测耗时 #由于pred预测结果是one_hot编码格式,所以需要转换0~9数字
prediction_resul = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images}) prediction_resul[0:10] #模型评估
accu_test = sess.run(accuracy,
feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Accuray:",accu_test) compare_lists = prediction_resul == np.argmax(mnist.test.labels,1)
print(compare_lists)
err_lists = [i for i in range(len(mnist.test.labels)) if compare_lists[i] == False]
print(err_lists,len(err_lists)) index_list = []
def print_predct_errs(labels,#标签列表
perdiction):#预测值列表
count = 0
compare_lists = (perdiction == np.argmax(labels,1))
err_lists = [i for i in range(len(labels)) if compare_lists[i] == False]
for x in err_lists:
index_list.append(x)
print("index="+str(x)+
"标签值=",np.argmax(labels[x]),
"预测值=",perdiction[x])
count = count+1
print("总计:",count)
return index_list print_predct_errs(mnist.test.labels,prediction_resul) def plot_images_labels_prediction(images,labels,prediction,index,num=25):
fig = plt.gcf() # 获取当前图片
fig.set_size_inches(10,12)
if num>=25:
num=25 #最多显示25张图片
for i in range(0,num):
ax = plt.subplot(5,5, i+1) #获取当前要处理的子图 ax.imshow(np.reshape(images[index],(28,28)),cmap='binary')#显示第index个图像
title = 'label=' + str(np.argmax(labels[index]))#构建该图上要显示的title
if len(prediction)>0:
title += 'predict= '+str(prediction[index]) ax.set_title(title,fontsize=10)
ax.set_xticks([])
ax.set_yticks([])
index += 1
plt.show() plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_resul,index=index_list[100])

单纯记录一下个人代码,很基础的一个MNIST手写识别使用Tensorflwo实现,算是入门的Hello world 了,有些奇怪的问题暂时没有解决 训练次数调成40 在训练到第35次左右发生了梯度爆炸,原因未知,损失函数要使用带softmax那个,不然也会发生梯度爆炸

使用tensorflow实现mnist手写识别(单层神经网络实现)的更多相关文章

  1. 基于tensorflow的MNIST手写识别

    这个例子,是学习tensorflow的人员通常会用到的,也是基本的学习曲线中的一环.我也是! 这个例子很简单,这里,就是简单的说下,不同的tensorflow版本,相关的接口函数,可能会有不一样哟.在 ...

  2. 基于tensorflow实现mnist手写识别 (多层神经网络)

    标题党其实也不多,一个输入层,三个隐藏层,一个输出层 老样子先上代码 导入mnist的路径很长,现在还记不住 import tensorflow as tf import tensorflow.exa ...

  3. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  4. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  5. tensorflow笔记(四)之MNIST手写识别系列一

    tensorflow笔记(四)之MNIST手写识别系列一 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7436310.html ...

  6. tensorflow笔记(五)之MNIST手写识别系列二

    tensorflow笔记(五)之MNIST手写识别系列二 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7455233.html ...

  7. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  8. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  9. Tensorflow编程基础之Mnist手写识别实验+关于cross_entropy的理解

    好久没有静下心来写点东西了,最近好像又回到了高中时候的状态,休息不好,无法全心学习,恶性循环,现在终于调整的好一点了,听着纯音乐突然非常伤感,那些曾经快乐的大学时光啊,突然又慢慢的一下子出现在了眼前, ...

随机推荐

  1. 更改win系统的鼠标样式

    一.找一个你心仪的鼠标样式(.cur文件),并放到 C:\Windows\Cursors 目录下 二.打开,控制面板 -> 硬件和声音 -> 鼠标 ,如下图: 三.浏览鼠标目录,找到你存放 ...

  2. 记录使用git submodule时踩的坑

    在使用git子模块的时候踩了一个坑 在使用git submodule updata --init --recursive命令,即递归更新子模块并初始化时碰到了一个问题: 经过一段不短时间的排查,发现问 ...

  3. Pandas Learning

    Panda Introduction Pandas 是基于 NumPy 的一个很方便的库,不论是对数据的读取.处理都非常方便.常用于对csv,json,xml等格式数据的读取和处理. Pandas定义 ...

  4. 028、HTML 标签3表单标签插入组件

    内容:表单标签插入组件(经常使用)############################################################## form表单标签和input组件 < ...

  5. html简单介绍(一)

    什么是html HTML 是用来描述网页的一种语言.HTML 指的是超文本标记语言 (Hyper Text Markup Language)HTML 不是一种编程语言,而是一种标记语言 (markup ...

  6. MySQL数据库常用操作和技巧

    MySQL数据库可以说是DBA们最常见和常用的数据库之一,MySQL的广泛应用,也使更多的人加入到学习它的行列之中.下面是老MySQL DBA总结的MySQL数据库最常见和最常使用的一些经验和技巧,分 ...

  7. Scala学习之路 (二)使用IDEA开发Scala

    目前Scala的开发工具主要有两种:Eclipse和IDEA,这两个开发工具都有相应的Scala插件,如果使用Eclipse,直接到Scala官网下载即可http://scala-ide.org/do ...

  8. Python第三方模块--requests简单使用

    1.requests简介 requests是什么?python语言编写的,基于urllib的第三方模块 与urllib有什么关系?urllib是python的内置模块,比urllib更加简洁和方便使用 ...

  9. 代码段:js表单提交检测

    市面上当然有很多成型的框架,比如jquery的validation插件各种吧.现在工作地,由于前端童鞋也没用这些个插件.根据业务的要求,自己就在代码里写了个简单的表单提交的检测代码(php的也写了一个 ...

  10. php判断一个数组是否为另一个数组子集的方法

    原文地址http://www.jbxue.com/article/14703.html // 快速的判断$a数组是否是$b数组的子集  $a = array(135,138);  $b = array ...