"""
Created on Tue Nov 27 17:31:35 2018
@author: zhen
"""
import tensorflow as tf
import numpy as np
import math
import time
import cifar10
import cifar10_input

max_steps = 1000
batch_size = 128
# Default path of the downloaded CIFAR-10 dataset
data_dir = "C:/Users/zhen/.spyder-py3/cifar/cifar-10/cifar-10-batches/cifar-10-batches-bin"

# Create a weight variable initialized from a truncated normal distribution and,
# if wl is given, add its L2 weight-decay term to the "losses" collection
def variable_with_weight_losses(shape, stddev, wl):
    var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
    if wl is not None:
        weight_loss = tf.multiply(tf.nn.l2_loss(var), wl, name='weight_loss')
        tf.add_to_collection("losses", weight_loss)
    return var

# Download the dataset
cifar10.maybe_download_and_extract()
# Load the (distorted, i.e. augmented) training data
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)
# Generate the test data
images_test, labels_test = cifar10_input.inputs(eval_data=True, data_dir=data_dir, batch_size=batch_size)

image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

# First convolutional layer
weight_1 = variable_with_weight_losses(shape=[5, 5, 3, 64], stddev=5e-2, wl=0.0)
kernel_1 = tf.nn.conv2d(image_holder, filter=weight_1, strides=[1, 1, 1, 1], padding='SAME')
bias_1 = tf.Variable(tf.constant(0.0, shape=[64]))
# Convolution followed by ReLU activation
conv_1 = tf.nn.relu(tf.nn.bias_add(kernel_1, bias_1))
# Max pooling
pool_1 = tf.nn.max_pool(conv_1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
norm_1 = tf.nn.lrn(pool_1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

# Second convolutional layer
weight_2 = variable_with_weight_losses(shape=[5, 5, 64, 64], stddev=5e-2, wl=0.0)
kernel_2 = tf.nn.conv2d(norm_1, weight_2, [1, 1, 1, 1], padding='SAME')
bias_2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv_2 = tf.nn.relu(tf.nn.bias_add(kernel_2, bias_2))
norm_2 = tf.nn.lrn(conv_2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool_2 = tf.nn.max_pool(norm_2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')

# First fully connected layer
reshape = tf.reshape(pool_2, [batch_size, -1])
dim = reshape.get_shape()[1].value
weight_3 = variable_with_weight_losses(shape=[dim, 384], stddev=0.04, wl=0.004)
bias_3 = tf.Variable(tf.constant(0.1, shape=[384]))
local_3 = tf.nn.relu(tf.matmul(reshape, weight_3) + bias_3)

# Second fully connected layer
weight_4 = variable_with_weight_losses(shape=[384, 192], stddev=0.04, wl=0.004)
bias_4 = tf.Variable(tf.constant(0.1, shape=[192]))
local_4 = tf.nn.relu(tf.matmul(local_3, weight_4) + bias_4)

# Output layer
weight_5 = variable_with_weight_losses(shape=[192, 10], stddev=1/192.0, wl=0.0)
bias_5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local_4, weight_5), bias_5)

# Total loss: mean cross-entropy plus all L2 weight-decay terms collected in "losses"
def loss(logits, labels):
    labels = tf.cast(labels, tf.int64)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=labels,
        name="cross_entropy_per_example"
    )
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name="cross_entropy")
    tf.add_to_collection("losses", cross_entropy_mean)
    return tf.add_n(tf.get_collection("losses"), name="total_loss")

loss = loss(logits=logits, labels=label_holder)
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
tf.train.start_queue_runners()

# Training loop
for step in range(max_steps):
    start_time = time.time()
    image_batch, label_batch = sess.run([images_train, labels_train])
    _, loss_value = sess.run([train_op, loss], feed_dict={image_holder: image_batch, label_holder: label_batch})
    duration = time.time() - start_time
    if step % 10 == 0:
        examples_per_sec = batch_size / duration
        sec_per_batch = float(duration)
        format_str = "step %d, loss =%.2f (%.1f examples/sec; %.3f sec/batch"
        print(format_str % (step, loss_value, examples_per_sec, sec_per_batch))

# Evaluate the model on the test set
num_examples = 10000
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
    image_batch, label_batch = sess.run([images_test, labels_test])
    predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch, label_holder: label_batch})
    true_count += np.sum(predictions)
    step += 1

precision = true_count / total_sample_count
print("precision @ 1 = %.3f" % precision)

Training output:

Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.
step 0, loss =4.68 (19.0 examples/sec; 6.734 sec/batch
step 10, loss =3.58 (62.1 examples/sec; 2.062 sec/batch
step 20, loss =3.09 (62.5 examples/sec; 2.047 sec/batch
step 30, loss =2.77 (62.5 examples/sec; 2.047 sec/batch
step 40, loss =2.48 (62.5 examples/sec; 2.047 sec/batch
step 50, loss =2.36 (62.5 examples/sec; 2.047 sec/batch
step 60, loss =2.13 (60.2 examples/sec; 2.125 sec/batch
step 70, loss =1.95 (63.0 examples/sec; 2.031 sec/batch
step 80, loss =2.01 (62.1 examples/sec; 2.062 sec/batch
step 90, loss =1.90 (63.5 examples/sec; 2.016 sec/batch
step 100, loss =1.93 (62.5 examples/sec; 2.047 sec/batch
step 110, loss =1.96 (62.1 examples/sec; 2.062 sec/batch
step 120, loss =1.92 (62.3 examples/sec; 2.055 sec/batch
step 130, loss =1.81 (63.5 examples/sec; 2.016 sec/batch
step 140, loss =1.86 (59.8 examples/sec; 2.141 sec/batch
step 150, loss =1.88 (64.0 examples/sec; 2.000 sec/batch
step 160, loss =1.87 (62.5 examples/sec; 2.047 sec/batch
step 170, loss =1.73 (49.6 examples/sec; 2.578 sec/batch
step 180, loss =1.86 (62.1 examples/sec; 2.062 sec/batch
step 190, loss =1.71 (62.5 examples/sec; 2.047 sec/batch
step 200, loss =1.63 (63.0 examples/sec; 2.031 sec/batch
step 210, loss =1.63 (63.5 examples/sec; 2.016 sec/batch
step 220, loss =1.67 (62.1 examples/sec; 2.063 sec/batch
step 230, loss =1.72 (62.5 examples/sec; 2.047 sec/batch
step 240, loss =1.76 (62.1 examples/sec; 2.062 sec/batch
step 250, loss =1.67 (61.6 examples/sec; 2.078 sec/batch
step 260, loss =1.67 (62.5 examples/sec; 2.047 sec/batch
step 270, loss =1.59 (63.0 examples/sec; 2.031 sec/batch
step 280, loss =1.55 (62.5 examples/sec; 2.047 sec/batch
step 290, loss =1.64 (62.5 examples/sec; 2.047 sec/batch
step 300, loss =1.63 (62.1 examples/sec; 2.062 sec/batch
step 310, loss =1.49 (62.1 examples/sec; 2.062 sec/batch
step 320, loss =1.49 (62.5 examples/sec; 2.047 sec/batch
step 330, loss =1.61 (62.1 examples/sec; 2.062 sec/batch
step 340, loss =1.55 (61.1 examples/sec; 2.094 sec/batch
step 350, loss =1.63 (62.5 examples/sec; 2.047 sec/batch
step 360, loss =1.75 (61.6 examples/sec; 2.078 sec/batch
step 370, loss =1.54 (61.1 examples/sec; 2.094 sec/batch
step 380, loss =1.66 (61.6 examples/sec; 2.078 sec/batch
step 390, loss =1.66 (62.1 examples/sec; 2.062 sec/batch
step 400, loss =1.74 (62.1 examples/sec; 2.062 sec/batch
step 410, loss =1.60 (61.6 examples/sec; 2.078 sec/batch
step 420, loss =1.64 (62.5 examples/sec; 2.047 sec/batch
step 430, loss =1.59 (61.1 examples/sec; 2.094 sec/batch
step 440, loss =1.64 (59.8 examples/sec; 2.141 sec/batch
step 450, loss =1.67 (62.5 examples/sec; 2.047 sec/batch
step 460, loss =1.35 (60.7 examples/sec; 2.109 sec/batch
step 470, loss =1.45 (63.5 examples/sec; 2.016 sec/batch
step 480, loss =1.47 (62.5 examples/sec; 2.047 sec/batch
step 490, loss =1.37 (61.6 examples/sec; 2.078 sec/batch
step 500, loss =1.64 (63.0 examples/sec; 2.031 sec/batch
step 510, loss =1.58 (64.0 examples/sec; 2.000 sec/batch
step 520, loss =1.36 (63.5 examples/sec; 2.016 sec/batch
step 530, loss =1.30 (61.6 examples/sec; 2.078 sec/batch
step 540, loss =1.49 (62.5 examples/sec; 2.047 sec/batch
step 550, loss =1.46 (62.5 examples/sec; 2.047 sec/batch
step 560, loss =1.58 (63.0 examples/sec; 2.031 sec/batch
step 570, loss =1.46 (63.5 examples/sec; 2.016 sec/batch
step 580, loss =1.49 (64.5 examples/sec; 1.984 sec/batch
step 590, loss =1.30 (64.0 examples/sec; 2.000 sec/batch
step 600, loss =1.39 (64.5 examples/sec; 1.984 sec/batch
step 610, loss =1.62 (63.0 examples/sec; 2.031 sec/batch
step 620, loss =1.41 (62.1 examples/sec; 2.062 sec/batch
step 630, loss =1.29 (62.5 examples/sec; 2.047 sec/batch
step 640, loss =1.42 (63.5 examples/sec; 2.016 sec/batch
step 650, loss =1.36 (63.0 examples/sec; 2.031 sec/batch
step 660, loss =1.46 (63.5 examples/sec; 2.016 sec/batch
step 670, loss =1.26 (63.0 examples/sec; 2.031 sec/batch
step 680, loss =1.64 (62.1 examples/sec; 2.062 sec/batch
step 690, loss =1.39 (63.0 examples/sec; 2.031 sec/batch
step 700, loss =1.32 (61.6 examples/sec; 2.078 sec/batch
step 710, loss =1.36 (61.6 examples/sec; 2.078 sec/batch
step 720, loss =1.51 (62.1 examples/sec; 2.062 sec/batch
step 730, loss =1.48 (63.5 examples/sec; 2.016 sec/batch
step 740, loss =1.34 (61.1 examples/sec; 2.094 sec/batch
step 750, loss =1.44 (61.1 examples/sec; 2.094 sec/batch
step 760, loss =1.34 (60.7 examples/sec; 2.109 sec/batch
step 770, loss =1.46 (61.1 examples/sec; 2.094 sec/batch
step 780, loss =1.46 (60.7 examples/sec; 2.109 sec/batch
step 790, loss =1.42 (61.1 examples/sec; 2.094 sec/batch
step 800, loss =1.40 (63.0 examples/sec; 2.031 sec/batch
step 810, loss =1.46 (61.6 examples/sec; 2.078 sec/batch
step 820, loss =1.32 (62.1 examples/sec; 2.062 sec/batch
step 830, loss =1.46 (62.5 examples/sec; 2.047 sec/batch
step 840, loss =1.27 (64.0 examples/sec; 2.000 sec/batch
step 850, loss =1.38 (62.5 examples/sec; 2.047 sec/batch
step 860, loss =1.30 (63.0 examples/sec; 2.031 sec/batch
step 870, loss =1.18 (63.0 examples/sec; 2.031 sec/batch
step 880, loss =1.39 (62.5 examples/sec; 2.047 sec/batch
step 890, loss =1.17 (63.5 examples/sec; 2.016 sec/batch
step 900, loss =1.27 (62.1 examples/sec; 2.062 sec/batch
step 910, loss =1.38 (60.7 examples/sec; 2.109 sec/batch
step 920, loss =1.64 (60.2 examples/sec; 2.125 sec/batch
step 930, loss =1.45 (60.7 examples/sec; 2.109 sec/batch
step 940, loss =1.39 (61.6 examples/sec; 2.078 sec/batch
step 950, loss =1.40 (63.5 examples/sec; 2.016 sec/batch
step 960, loss =1.32 (62.1 examples/sec; 2.063 sec/batch
step 970, loss =1.32 (63.0 examples/sec; 2.031 sec/batch
step 980, loss =1.28 (61.6 examples/sec; 2.078 sec/batch
step 990, loss =1.20 (63.5 examples/sec; 2.016 sec/batch

Result:

Analysis:

  CIFAR-10 is a richer and more complex dataset than MNIST, so 10-class classification on CIFAR-10 is considerably harder: overall precision and recall come out noticeably lower. Moderately increasing the number of training iterations and the size of the convolution kernels helps, pushing accuracy to roughly 80%; for still higher accuracy, more training data is needed.
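
  As a rough illustration of the tweaks mentioned above, the snippet below modifies two hyperparameters of the script: it trains for more steps and enlarges the first-layer kernels. The concrete values (3000 steps, 7x7 kernels) are assumptions chosen for illustration, not settings that were actually run here.

# Hypothetical tweaks, not the settings used in the run above:
# train longer and give the first convolutional layer a larger receptive field.
max_steps = 3000  # the original run used 1000 steps

weight_1 = variable_with_weight_losses(shape=[7, 7, 3, 64], stddev=5e-2, wl=0.0)
kernel_1 = tf.nn.conv2d(image_holder, filter=weight_1, strides=[1, 1, 1, 1], padding='SAME')

  Since distorted_inputs already augments the 50,000 CIFAR-10 training images with random crops and flips, simply training for more steps is often the cheapest way to close most of the remaining gap.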
