import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#载入数据集
mnist = input_data.read_data_sets("/data/stu05/mnist_data",one_hot=True)
 
 
Extracting /data/stu05/mnist_data/train-images-idx3-ubyte.gz
Extracting /data/stu05/mnist_data/train-labels-idx1-ubyte.gz
Extracting /data/stu05/mnist_data/t10k-images-idx3-ubyte.gz
Extracting /data/stu05/mnist_data/t10k-labels-idx1-ubyte.gz
 
#每个批次的大小
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size
#定义两个placeholder,None=100,28*28=784,即100行,784列
x = tf.placeholder(tf.float32,[None,784])
#0-9个输出标签
y = tf.placeholder(tf.float32,[None,10])
#创建一个简单的神经网络,只有输入层和输出层
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([1,10]))
#softmax函数转化为概率值
prediction = tf.nn.softmax(tf.matmul(x,W)+b)
#二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
#初始化变量
init = tf.global_variables_initializer()
#tf.equal()比较函数大小是否相同,相同为True,不同为false;tf.argmax():求y=1在哪个位置,求概率最大在哪个位置
#argmax返回一维张量中最大的值所在的位置,结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#求准确率
#cast转化类型,将布尔型转化为32位浮点型,True=1.0,False=0.0;再求平均值
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
with tf.Session() as sess:
    sess.run(init)
    #将所有图片训练21次
    for epoch in range(21):
        #训练一次所有的图片
        for batch in range(n_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            #feed_dict传入训练集的图片和标签
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        #传入测试集的图片和标签
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter"+str(epoch)+",Testing Accuracy:"+str(acc))
 
 
 
Iter0,Testing Accuracy:0.8303
Iter1,Testing Accuracy:0.8708
Iter2,Testing Accuracy:0.8821
Iter3,Testing Accuracy:0.8885
Iter4,Testing Accuracy:0.8941
Iter5,Testing Accuracy:0.8973
Iter6,Testing Accuracy:0.9001
Iter7,Testing Accuracy:0.9013
Iter8,Testing Accuracy:0.9038
Iter9,Testing Accuracy:0.9048
Iter10,Testing Accuracy:0.9068
Iter11,Testing Accuracy:0.9068
Iter12,Testing Accuracy:0.9084
Iter13,Testing Accuracy:0.9094
Iter14,Testing Accuracy:0.9097
Iter15,Testing Accuracy:0.9107
Iter16,Testing Accuracy:0.9118
Iter17,Testing Accuracy:0.9116
Iter18,Testing Accuracy:0.9127
Iter19,Testing Accuracy:0.9136
Iter20,Testing Accuracy:0.9146
 
 
 
 
 
 
 

MNIST数据集分类简单版本的更多相关文章

  1. 6.MNIST数据集分类简单版本

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 载入数据集 mnist = i ...

  2. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  3. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  4. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  5. 深度学习(一)之MNIST数据集分类

    任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...

  6. Tensorflow学习教程------普通神经网络对mnist数据集分类

    首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...

  7. 神经网络MNIST数据集分类tensorboard

    今天分享同样数据集的CNN处理方式,同时加上tensorboard,可以看到清晰的结构图,迭代1000次acc收敛到0.992 先放代码,注释比较详细,变量名字看单词就能知道啥意思 import te ...

  8. 卷积神经网络应用于MNIST数据集分类

    先贴代码 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...

  9. MNIST数据集

    一.MNIST数据集分类简单版本 import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data # ...

随机推荐

  1. 在Ubuntu16.04上使用rz上传文件,XXX was skipped

    原本想把hadoop-2.8.5.tar.gz上传到/usr/local/src文件夹下,报错,was skipped 如下图: 换个文件夹位置,更换到本用户文件夹下,可以上传,说明是对文件夹操作权限 ...

  2. 局部变量和static变量的区别

    static int a ; int b; scanf_s("%d %d",&a,&b); 01374212 lea eax,[b] 01374215 push e ...

  3. Linux安装tomcat服务器

    1.下载tomcat(区分windows和Linux,以tar.gz为后缀名的是Linux操作系统使用的). 官网下载地址:http://test.m.xiaoyuanhao.com/micro/ap ...

  4. CentOS7下源码包方式安装Erlang

    1.官网上下载源码包:OTP 19.1 Source File 2.把源码放在source目录中 , 解压 :tar -zxvf otp_src_19.1.tar.gz [或者 直接下载 rpm包 e ...

  5. Visual Studio OpenCV 开发环境配置

    因为VS配置OpenCV好多新手都很难一次配置成功,而且OpenCV库每新建一个项目都要配置很是麻烦,所以今天就给大家介绍一个“一劳永逸”的方法. 注:理论上只要VS和OpenCV是版本兼容的,该方法 ...

  6. Python基础入门-函数实战登录功能

    ''' 函数实战: .加法计算器 .过滤器 .登录功能实战 ''' def add(a,b): return a+b def login_order(): return 'asdfasdfdasfad ...

  7. .NET Framework各版本特性一览

    https://msdn.microsoft.com/en-us/library/bb822049.aspx .NET Framework version CL version Features In ...

  8. PostgreSQL 速查、备忘手册 | PostgreSQL Quick Find and Tutorial

    PostgreSQL 速查.备忘手册 作者:汪嘉霖 这是一个你可能需要的一个备忘手册,此手册方便你快速查询到你需要的常见功能.有时也有一些曾经被使用过的高级功能.如无特殊说明,此手册仅适用于 Linu ...

  9. vitamio MediaController总是显示在底部的问题

    前面一直用腾讯的x5 tas来播放视频,但是体验效果不好,不能设置播放页,无法获取用户对视频的学习情况,百度了下,发现好多人在使用vitamio,最新版本是5.0的,下载可能要花费点时间,官网上竟然没 ...

  10. Visual Assist X破解安装及设置

    本文提供的插件版本为Visual Assist X 10.9.2248,支持Visual Studio 2010~2017各版本,本人亲测均可正常使用. 一. 插件下载: 点击下载链接,找到对应软件下 ...