1. import numpy as np
  2. import tensorflow as tf
  3. from tensorflow.examples.tutorials.mnist import input_data
  4. def add_layer(inputs,in_size,out_size,activation_function=None):
  5. W=tf.Variable(tf.random_normal([in_size,out_size]))
  6. b=tf.Variable(tf.zeros([1,out_size])+0.01)
  7. Z=tf.matmul(inputs,W)+b
  8. if activation_function is None:
  9. out_puts=Z
  10. else:
  11. out_puts=activation_function(Z)
  12. return out_puts
  13. if __name__=="__main__":
  14. MINST=input_data.read_data_sets("./",one_hot=True)
  15. learning_rate=0.05
  16. batch_size=128
  17. n_epochs=10
  18. X=tf.placeholder(tf.float32,[batch_size,784])
  19. Y=tf.placeholder(tf.float32,[batch_size,10])
  20. L1=add_layer(X,784,1000,tf.nn.relu)
  21. prediction=add_layer(L1,1000,10)
  22. entropy=tf.nn.softmax_cross_entropy_with_logits(labels=Y,logits=prediction)
  23. loss=tf.reduce_mean(entropy)
  24. optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
  25. init=tf.global_variables_initializer()
  26. with tf.Session() as sess:
  27. sess.run(init)
  28. n_batches=int(MINST.train.num_examples/batch_size)
  29. for i in range(n_epochs):
  30. for j in range(n_batches):
  31. X_batch,Y_batch=MINST.train.next_batch(batch_size=batch_size)
  32. _,loss_=sess.run([optimizer,loss],feed_dict={
  33. X:X_batch,
  34. Y:Y_batch
  35. })
  36. if j == 0:
  37. print("Loss of epochs[{0}] batch[{1}]: {2}".format(i, j, loss_))
  38.  
  39. # test the model
  40. n_batches = int(MINST.test.num_examples / batch_size)
  41. total_correct_preds = 0
  42. for i in range(n_batches):
  43. X_batch, Y_batch = MINST.test.next_batch(batch_size)
  44. preds = sess.run(prediction, feed_dict={X: X_batch, Y: Y_batch})
  45. correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(Y_batch, 1))
  46. accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))
  47.  
  48. total_correct_preds += sess.run(accuracy)
  49.  
  50. print("Accuracy {0}".format(total_correct_preds / MINST.test.num_examples))

我们不做卷积。直接将x输入到网络中去。最后用softmax作为激活函数

大概结构,我这里没法上传,等我回去在传。

使用一层神经网络训练mnist数据集的更多相关文章

  1. TensorFlow初探之简单神经网络训练mnist数据集(TensorFlow2.0代码)

    from __future__ import print_function from tensorflow.examples.tutorials.mnist import input_data #加载 ...

  2. mxnet卷积神经网络训练MNIST数据集测试

    mxnet框架下超全手写字体识别—从数据预处理到网络的训练—模型及日志的保存 import numpy as np import mxnet as mx import logging logging. ...

  3. 使用caffe训练mnist数据集 - caffe教程实战(一)

    个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...

  4. 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集

    上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...

  5. Python实现bp神经网络识别MNIST数据集

    title: "Python实现bp神经网络识别MNIST数据集" date: 2018-06-18T14:01:49+08:00 tags: [""] cat ...

  6. TensorFlow——CNN卷积神经网络处理Mnist数据集

    CNN卷积神经网络处理Mnist数据集 CNN模型结构: 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5*5,步长为1,卷积核:32个 第一层池化:池化视野2*2,步长为2 第二层卷积 ...

  7. deep_learning_LSTM长短期记忆神经网络处理Mnist数据集

    1.RNN(Recurrent Neural Network)循环神经网络模型 详见RNN循环神经网络:https://www.cnblogs.com/pinard/p/6509630.html 2. ...

  8. TensorFlow——LSTM长短期记忆神经网络处理Mnist数据集

    1.RNN(Recurrent Neural Network)循环神经网络模型 详见RNN循环神经网络:https://www.cnblogs.com/pinard/p/6509630.html 2. ...

  9. TensorFlow 训练MNIST数据集(2)—— 多层神经网络

    在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...

随机推荐

  1. Octave及其工具包的安装

    Octave的安装: http://blog.sina.com.cn/s/blog_1358485f70102wmpa.html http://www.gnu.org/software/octave/ ...

  2. inotify-tools命令使用讲解

    inotify-tools 是为linux下inotify文件监控工具提供的一套c的开发接口库函数,同时还提供了一系列的命令行工具,这些工具可以用来监控文件系统的事件. inotify-tools是用 ...

  3. 【SqlServer】SQL Server的常用函数

    字符串函数 SubString():用于截取指定字符串的方法.该方法有三个参数:参数1:用于指定要操作的字符串.参数2:用于指定要截取的字符串的起始位置,起始值为 1 .参数3:用于指定要截取的长度. ...

  4. 【java】switch case支持的6种数据类型

    switch表达式后面的数据类型只能是byte,short,char,int四种整形类型,枚举类型和java.lang.String类型(从java 7才允许),不能是boolean类型. 在网上看到 ...

  5. solr开发从查询结果集中获取对象数据

    solrJ从查询结果集中获取对象数据. 方案一:自定义转换方式 /** * * SolrDocument与实体类转换 [测试通过] * * @author pudongping * * @param ...

  6. JsonPath小结

    在查看DHC Assertions 模块说明的时候,无意间发现assert模块中JsonBody使用了 JSON Path ,兴趣使然,看了下,发现是类似解析xml用到的 XPath.通过路径来获取j ...

  7. Knockout: radio选项切换引发click事件的一点总结

    1.场景:如下图,当选择定期存款时,输入框右边出现红色的必输项星号,当选择活期存款时,不再出现该星号. 2.思路一:不使用knockout,直接用click事件,就可以实现这个需求,代码如下: < ...

  8. 常用代码之五:RequireJS, 一个Define需要且只能有一个返回值/对象,一个JS文件里只能放一个Define.

    RequireJS 介绍说一个JS文件里只能放一个Define,这个众所周知,不提. 关于Define,它需要有一个返回值/对象,且只能有一个返回值/对象,这一点却是好多帖子没有提到的,但又非常重要的 ...

  9. docker-compose教程(安装,使用, 快速入门)

    1.Compose介绍Docker Compose是一个用来定义和运行复杂应用的Docker工具.一个使用Docker容器的应用,通常由多个容器组成.使用Docker Compose不再需要使用she ...

  10. CentOS 安装Mosquitto及测试

    系统信息,阿里云服务器 安装工具 yum install gcc gcc-c++ yum install openssl-devel yum install c-ares-devel yum inst ...