python,tensorflow,CNN实现mnist数据集的训练与验证正确率
1.工程目录
2.导入data和input_data.py
链接:https://pan.baidu.com/s/1EBNyNurBXWeJVyhNeVnmnA
提取码:4nnl
3.CNN.py
import tensorflow as tf
import matplotlib.pyplot as plt
import input_data mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print('MNIST ready') n_input = 784
n_output = 10 weights = {
'wc1': tf.Variable(tf.truncated_normal([3, 3, 1, 64], stddev=0.1)),
'wc2': tf.Variable(tf.truncated_normal([3, 3, 64, 128], stddev=0.1)),
'wd1': tf.Variable(tf.truncated_normal([7*7*128, 1024], stddev=0.1)),
'wd2': tf.Variable(tf.truncated_normal([1024, n_outpot], stddev=0.1)),
}
biases = {
'bc1': tf.Variable(tf.random_normal([64], stddev=0.1)),
'bc2': tf.Variable(tf.random_normal([128], stddev=0.1)),
'bd1': tf.Variable(tf.random_normal([1024], stddev=0.1)),
'bd2': tf.Variable(tf.random_normal([n_outpot], stddev=0.1)),
} def conv_basic(_input, _w, _b, _keepratio):
_input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])
_conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')
_conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))
_pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
_pool_dr1 = tf.nn.dropout(_pool1, _keepratio)
_conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')
_conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))
_pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
_pool_dr2 = tf.nn.dropout(_pool2, _keepratio)
_densel = tf.reshape(_pool_dr2, [-1, _w['wd1'].get_shape().as_list()[0]])
_fc1 = tf.nn.relu(tf.add(tf.matmul(_densel, _w['wd1']), _b['bd1']))
_fc_dr1 = tf.nn.dropout(_fc1, _keepratio)
_out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])
out = {
'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool_dr1': _pool_dr1,
'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'densel': _densel,
'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out
}
return out print('CNN READY') x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output])
keepratio = tf.placeholder(tf.float32) _pred = conv_basic(x, weights, biases, keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(_pred, y))
optm = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)
_corr = tf.equal(tf.argmax(_pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32))
init = tf.global_variables_initializer() print('GRAPH READY') sess = tf.Session()
sess.run(init)
training_epochs = 15
batch_size = 16
display_step = 1 for epoch in range(training_epochs):
avg_cost = 0.
total_batch = 10
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(optm, feed_dict={x: batch_xs, y: batch_ys, keepratio: 0.7})
avg_cost += sess.run(cost, feed_dict={x: batch_xs, y: batch_ys, keepratio: 1.0})/total_batch if epoch % display_step == 0:
print('Epoch: %03d/%03d cost: %.9f' % (epoch, training_epochs, avg_cost))
train_acc = sess.run(accr, feed_dict={x: batch_xs, y: batch_ys, keepratio: 1.})
print('Training accuracy: %.3f' % (train_acc)) res_dict = {'weight': sess.run(weights), 'biases': sess.run(biases)} import pickle
with open('res_dict.pkl', 'wb') as f:
pickle.dump(res_dict, f, pickle.HIGHEST_PROTOCOL)
4.test.py
import pickle
import numpy as np def load_file(path, name):
with open(path+''+name+'.pkl', 'rb') as f:
return pickle.load(f) res_dict = load_file('', 'res_dict')
print(res_dict['weight']['wc1']) index = 0 import input_data
mnist = input_data.read_data_sets('data/', one_hot=True) test_image = mnist.test.images
test_label = mnist.test.labels import tensorflow as tf def conv_basic(_input, _w, _b, _keepratio):
_input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])
_conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')
_conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))
_pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
_pool_dr1 = tf.nn.dropout(_pool1, _keepratio)
_conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')
_conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))
_pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
_pool_dr2 = tf.nn.dropout(_pool2, _keepratio)
_densel = tf.reshape(_pool_dr2, [-1, _w['wd1'].shape[0]])
_fc1 = tf.nn.relu(tf.add(tf.matmul(_densel, _w['wd1']), _b['bd1']))
_fc_dr1 = tf.nn.dropout(_fc1, _keepratio)
_out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])
out = {
'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool_dr1': _pool_dr1,
'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'densel': _densel,
'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out
}
return out n_input = 784
n_output = 10 x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output]) keepratio = tf.placeholder(tf.float32) _pred = conv_basic(x, res_dict['weight'], res_dict['biases'], keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(_pred, y)) _corr = tf.equal(tf.argmax(_pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32)) init = tf.global_variables_initializer() sess = tf.Session()
sess.run(init)
training_epochs = 1
batch_size = 1
display_step = 1 for epoch in range(training_epochs):
avg_cost = 0.
total_batch = 10
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size) if epoch % display_step == 0:
print('_pre:', np.argmax(sess.run(_pred, feed_dict={x: batch_xs, keepratio: 1. })))
print('answer:', np.argmax(batch_ys))
python,tensorflow,CNN实现mnist数据集的训练与验证正确率的更多相关文章
- tensorflow中使用mnist数据集训练全连接神经网络-学习笔记
tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...
- tensorflow读取本地MNIST数据集
tensorflow读取本地MNIST数据集 数据放入文件夹(不要解压gz): >>> import tensorflow as tf >>> from tenso ...
- Ubuntu14.04+caffe+cuda7.5 环境搭建以及MNIST数据集的训练与测试
Ubuntu14.04+caffe+cuda 环境搭建以及MNIST数据集的训练与测试 一.ubuntu14.04的安装: ubuntu的安装是一件十分简单的事情,这里给出一个参考教程: http:/ ...
- 十折交叉验证10-fold cross validation, 数据集划分 训练集 验证集 测试集
机器学习 数据挖掘 数据集划分 训练集 验证集 测试集 Q:如何将数据集划分为测试数据集和训练数据集? A:three ways: 1.像sklearn一样,提供一个将数据集切分成训练集和测试集的函数 ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- 学习TensorFlow,邂逅MNIST数据集
如果说"Hello Word!"是程序员的第一个程序,那么MNIST数据集,毫无疑问是机器学习者第一个训练的数据集,本文将使用Google公布的TensorFLow来学习训练MNI ...
- mnist的格式说明,以及在python3.x和python 2.x读取mnist数据集的不同
有一个关于mnist的一个事例可以参考,我觉得写的很好:http://www.cnblogs.com/x1957/archive/2012/06/02/2531503.html #!/usr/bin/ ...
- Python Tensorflow CNN 识别验证码
Python+Tensorflow的CNN技术快速识别验证码 文章来源于: https://www.jianshu.com/p/26ff7b9075a1 验证码处理的流程是:验证码分析和处理—— te ...
- TensorFlow CNN 测试CIFAR-10数据集
本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50738311 1 CIFAR-10 数 ...
随机推荐
- 总结day25 ---- udp 初识, 和tcp 进阶
前情提要 一: tcp 和udp 的区别 # tcp # # 面向连接的 可靠的 全双工的 流式传输 # # 面向连接 :同一时刻只能和一个客户端通信 # # 三次握手.四次挥手 # # 可靠的 :数 ...
- es6学习 1
块级作用域绑定 一 var 声明及变量提升(Hoisting)机制 在函数作用域或全局作用域中通过 var 声明的变量,无论实际上是在哪里声明的,都会被当成在当前作用域顶部声明的变量,这就是我们常说的 ...
- CentOS 7 分区方案
通常系统盘都会选择性能较好SSD,一般在500G左右,这里就以500G硬盘为例,以下为CentOS 自动分区方案: 分区应该按照实际服务器用途而定,自动分区方案将 /home 空间分配太多了,多数情况 ...
- trace跟踪代码运行
System.Diagnostics命名空间中. 1.Trace.Assert(a==12,"等于就不输出,不等于弹出对话框"); 名称 描述 Assert(Boolean ...
- Mac下运行git报错"xcrun: error: invalid active developer path .."
错误:xcrun: error: invalid active developer path (/Library/Developer/CommandLineTools), missing xcrun ...
- mono for android 第三课--页面布局(转)
对于C#程序员来说布局不是什么难事,可是对于我这个新手在mono for android 中布局还是有点小纠结的,不会没关系.慢慢学习.好吧我们开始简单的布局.在之前我们拖拽的控件都是自动的去布局,也 ...
- 不好意思啊,我上周到今天不到10天时间,用纯C语言写了一个小站!想拍砖的就赶紧拿出来拍啊
花10天时间用C语言做了个小站 http://tieba.yunxunmi.com/index.html 简称: 云贴吧 不好意思啊,我上周到今天不到10天时间,用纯C语言写了一个小站!想拍砖的就赶紧 ...
- (转)Memcached 之 .NET(C#)实例分析
一:Memcached的安装 step1. 下载memcache(http://jehiah.cz/projects/memcached-win32)的windows稳定版(这里我下载了memcach ...
- 模块和处理程序之通过HttpModule和HttpHandler拦截入站HTTP请求执行指定托管代码模块
1.简介 大多数情况下,作为一个asp.net web开发对整个web应用程序的控制是十分有限的,我们的控制往往只能做到对应用程序(高层面)的基本控制.但是,很多时候,我们需要能够低级层面进行交互,例 ...
- 测试Servlet生命周期例子程序
写一个类TestLifeCycleServlet,生成构造器TestLifeCycleServlet();重写HttpServlet的doGet();重写GenericServlet的destroy( ...