Python/TensorFlow: CNN training and validation accuracy on the MNIST dataset
1. Project directory
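The directory screenshot from the original post did not survive extraction; below is a plausible layout inferred from the paths the code uses (the root folder name is an assumption):

CNN_mnist/
├── CNN.py         # trains the network and saves res_dict.pkl
├── test.py        # reloads res_dict.pkl and runs predictions
├── input_data.py  # MNIST download/loader helper
└── data/          # MNIST .gz files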
2. Download data and input_data.py
Link: https://pan.baidu.com/s/1EBNyNurBXWeJVyhNeVnmnA
Extraction code: 4nnl
3. CNN.py
import tensorflow as tf
import matplotlib.pyplot as plt
import input_data

mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print('MNIST ready')

n_input = 784
n_output = 10

weights = {
'wc1': tf.Variable(tf.truncated_normal([3, 3, 1, 64], stddev=0.1)),
'wc2': tf.Variable(tf.truncated_normal([3, 3, 64, 128], stddev=0.1)),
'wd1': tf.Variable(tf.truncated_normal([7*7*128, 1024], stddev=0.1)),
    'wd2': tf.Variable(tf.truncated_normal([1024, n_output], stddev=0.1)),
}
biases = {
'bc1': tf.Variable(tf.random_normal([64], stddev=0.1)),
'bc2': tf.Variable(tf.random_normal([128], stddev=0.1)),
'bd1': tf.Variable(tf.random_normal([1024], stddev=0.1)),
    'bd2': tf.Variable(tf.random_normal([n_output], stddev=0.1)),
}

def conv_basic(_input, _w, _b, _keepratio):
    # Reshape the flat 784-vector input into a batch of 28x28 grayscale images
    _input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])
    # Conv layer 1: 3x3 kernels, 64 filters, ReLU, 2x2 max pooling, dropout
    _conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')
    _conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))
    _pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    _pool_dr1 = tf.nn.dropout(_pool1, _keepratio)
    # Conv layer 2: 3x3 kernels, 128 filters, ReLU, 2x2 max pooling, dropout
    _conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')
    _conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))
    _pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    _pool_dr2 = tf.nn.dropout(_pool2, _keepratio)
    # Flatten to [batch, 7*7*128] and apply the two fully connected layers
    _densel = tf.reshape(_pool_dr2, [-1, _w['wd1'].get_shape().as_list()[0]])
    _fc1 = tf.nn.relu(tf.add(tf.matmul(_densel, _w['wd1']), _b['bd1']))
    _fc_dr1 = tf.nn.dropout(_fc1, _keepratio)
    _out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])
    out = {
        'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool_dr1': _pool_dr1,
        'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'densel': _densel,
        'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out
    }
    return out

print('CNN READY')

x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output])
keepratio = tf.placeholder(tf.float32)

_pred = conv_basic(x, weights, biases, keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=_pred, labels=y))
optm = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)
_corr = tf.equal(tf.argmax(_pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32))
init = tf.global_variables_initializer()

print('GRAPH READY')

sess = tf.Session()
sess.run(init)
training_epochs = 15
batch_size = 16
display_step = 1

for epoch in range(training_epochs):
    avg_cost = 0.
    # Only 10 mini-batches per epoch, so this is a quick demo run; use
    # int(mnist.train.num_examples / batch_size) to cover the full training set
    total_batch = 10
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        sess.run(optm, feed_dict={x: batch_xs, y: batch_ys, keepratio: 0.7})
        avg_cost += sess.run(cost, feed_dict={x: batch_xs, y: batch_ys, keepratio: 1.0}) / total_batch
    if epoch % display_step == 0:
        print('Epoch: %03d/%03d cost: %.9f' % (epoch, training_epochs, avg_cost))
        train_acc = sess.run(accr, feed_dict={x: batch_xs, y: batch_ys, keepratio: 1.})
        print('Training accuracy: %.3f' % (train_acc))

# Persist the learned parameters as NumPy arrays so test.py can rebuild the net
res_dict = {'weight': sess.run(weights), 'biases': sess.run(biases)}

import pickle

with open('res_dict.pkl', 'wb') as f:
    pickle.dump(res_dict, f, pickle.HIGHEST_PROTOCOL)
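The loop above only reports accuracy on the last training mini-batch. A minimal sketch, assuming the graph is still live in `sess`, for also measuring test-set accuracy (evaluated in chunks to keep memory modest; the chunk count of 20 is an assumption, not part of the original script):

# Sketch (assumption): average accuracy over the whole test set in chunks
n_chunks = 20
chunk = mnist.test.num_examples // n_chunks
test_acc = 0.
for i in range(n_chunks):
    xs = mnist.test.images[i * chunk:(i + 1) * chunk]
    ys = mnist.test.labels[i * chunk:(i + 1) * chunk]
    test_acc += sess.run(accr, feed_dict={x: xs, y: ys, keepratio: 1.0}) / n_chunks
print('Test accuracy: %.3f' % test_acc)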
4. test.py
import pickle
import numpy as np

def load_file(path, name):
    with open(path + '' + name + '.pkl', 'rb') as f:
        return pickle.load(f)

res_dict = load_file('', 'res_dict')
print(res_dict['weight']['wc1'])

index = 0

import input_data

mnist = input_data.read_data_sets('data/', one_hot=True)
test_image = mnist.test.images
test_label = mnist.test.labels

import tensorflow as tf

def conv_basic(_input, _w, _b, _keepratio):
    _input_r = tf.reshape(_input, shape=[-1, 28, 28, 1])
    _conv1 = tf.nn.conv2d(_input_r, _w['wc1'], strides=[1, 1, 1, 1], padding='SAME')
    _conv1 = tf.nn.relu(tf.nn.bias_add(_conv1, _b['bc1']))
    _pool1 = tf.nn.max_pool(_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    _pool_dr1 = tf.nn.dropout(_pool1, _keepratio)
    _conv2 = tf.nn.conv2d(_pool_dr1, _w['wc2'], strides=[1, 1, 1, 1], padding='SAME')
    _conv2 = tf.nn.relu(tf.nn.bias_add(_conv2, _b['bc2']))
    _pool2 = tf.nn.max_pool(_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    _pool_dr2 = tf.nn.dropout(_pool2, _keepratio)
    # _w['wd1'] is a NumPy array here, so .shape[0] gives the flattened size directly
    _densel = tf.reshape(_pool_dr2, [-1, _w['wd1'].shape[0]])
    _fc1 = tf.nn.relu(tf.add(tf.matmul(_densel, _w['wd1']), _b['bd1']))
    _fc_dr1 = tf.nn.dropout(_fc1, _keepratio)
    _out = tf.add(tf.matmul(_fc_dr1, _w['wd2']), _b['bd2'])
    out = {
        'input_r': _input_r, 'conv1': _conv1, 'pool1': _pool1, 'pool_dr1': _pool_dr1,
        'conv2': _conv2, 'pool2': _pool2, 'pool_dr2': _pool_dr2, 'densel': _densel,
        'fc1': _fc1, 'fc_dr1': _fc_dr1, 'out': _out
    }
    return out

n_input = 784
n_output = 10

x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output])
keepratio = tf.placeholder(tf.float32)

# Rebuild the graph, feeding the saved NumPy weights in as constants
_pred = conv_basic(x, res_dict['weight'], res_dict['biases'], keepratio)['out']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=_pred, labels=y))
_corr = tf.equal(tf.argmax(_pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32))

init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)
training_epochs = 1
batch_size = 1
display_step = 1

for epoch in range(training_epochs):
    avg_cost = 0.
    total_batch = 10
    for i in range(total_batch):
        # batch_size is 1, so each batch holds a single image/label pair
        # (note: samples come from the training split, as in the original post)
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        if epoch % display_step == 0:
            print('_pre:', np.argmax(sess.run(_pred, feed_dict={x: batch_xs, keepratio: 1.})))
            print('answer:', np.argmax(batch_ys))
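The loop above only prints per-sample predictions. Since the script already defines the `accr` node and loads `test_image`/`test_label`, a minimal sketch for scoring the restored network on part of the test set (the slice of 1000 images is an assumption to keep memory modest):

# Sketch (assumption): accuracy of the restored weights on 1000 test images
test_acc = sess.run(accr, feed_dict={x: test_image[:1000], y: test_label[:1000], keepratio: 1.0})
print('Test accuracy on 1000 samples: %.3f' % test_acc)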