1,数据集简介

  SVHN(Street View House Number)Dateset 来源于谷歌街景门牌号码,原生的数据集1也就是官网的 Format 1 是一些原始的未经处理的彩色图片,如下图所示(不含有蓝色的边框),下载的数据集含有 PNG 的图像和 digitStruct.mat  的文件,其中包含了边框的位置信息,这个数据集每张图片上有好几个数字,适用于 OCR 相关方向。

  这里采用 Format2, Format2 将这些数字裁剪成32x32的大小,如图所示,并且数据是 .mat 文件。

    

2,数据处理

  数据集含有两个变量 X 代表图像, 训练集 X 的 shape 是  (32,32,3,73257) 也就是(width, height, channels, samples),  tensorflow 的张量需要 (samples, width, height, channels),所以需要转换一下,由于直接调用 cifar 10 的网络模型,数据只需要先做个归一化,所有像素除于255就 OK,另外原始数据 0 的标签是 10,这里要转化成 0,并提供 one_hot 编码。

  1. #!/usr/bin/env python2
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Thu Jan 19 09:55:36 2017
  5.  
  6. @author: cheers
  7. """
  8.  
  9. import scipy.io as sio
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12.  
  13. image_size = 32
  14. num_labels = 10
  15.  
  16. def display_data():
  17. print 'loading Matlab data...'
  18. train = sio.loadmat('train_32x32.mat')
  19. data=train['X']
  20. label=train['y']
  21. for i in range(10):
  22. plt.subplot(2,5,i+1)
  23. plt.title(label[i][0])
  24. plt.imshow(data[...,i])
  25. plt.axis('off')
  26. plt.show()
  27.  
  28. def load_data(one_hot = False):
  29.  
  30. train = sio.loadmat('train_32x32.mat')
  31. test = sio.loadmat('test_32x32.mat')
  32.  
  33. train_data=train['X']
  34. train_label=train['y']
  35. test_data=test['X']
  36. test_label=test['y']
  37.  
  38. train_data = np.swapaxes(train_data, 0, 3)
  39. train_data = np.swapaxes(train_data, 2, 3)
  40. train_data = np.swapaxes(train_data, 1, 2)
  41. test_data = np.swapaxes(test_data, 0, 3)
  42. test_data = np.swapaxes(test_data, 2, 3)
  43. test_data = np.swapaxes(test_data, 1, 2)
  44.  
  45. test_data = test_data / 255.
  46. train_data =train_data / 255.
  47.  
  48. for i in range(train_label.shape[0]):
  49. if train_label[i][0] == 10:
  50. train_label[i][0] = 0
  51.  
  52. for i in range(test_label.shape[0]):
  53. if test_label[i][0] == 10:
  54. test_label[i][0] = 0
  55.  
  56. if one_hot:
  57. train_label = (np.arange(num_labels) == train_label[:,]).astype(np.float32)
  58. test_label = (np.arange(num_labels) == test_label[:,]).astype(np.float32)
  59.  
  60. return train_data,train_label, test_data,test_label
  61.  
  62. if __name__ == '__main__':
  63. load_data(one_hot = True)
  64. display_data()

3,TFearn 训练

注意 ImagePreprocessing 对数据做了 0 均值化。网络结构也比较简单,直接调用 TFlearn 的 cifar10 例子。

  1. from __future__ import division, print_function, absolute_import
  2.  
  3. import tflearn
  4. from tflearn.data_utils import shuffle, to_categorical
  5. from tflearn.layers.core import input_data, dropout, fully_connected
  6. from tflearn.layers.conv import conv_2d, max_pool_2d
  7. from tflearn.layers.estimator import regression
  8. from tflearn.data_preprocessing import ImagePreprocessing
  9. from tflearn.data_augmentation import ImageAugmentation
  10.  
  11. # Data loading and preprocessing
  12. import svhn_data as SVHN
  13. X, Y, X_test, Y_test = SVHN.load_data(one_hot = True)
  14. X, Y = shuffle(X, Y)
  15.  
  16. # Real-time data preprocessing
  17. img_prep = ImagePreprocessing()
  18. img_prep.add_featurewise_zero_center()
  19. img_prep.add_featurewise_stdnorm()
  20.  
  21. # Convolutional network building
  22. network = input_data(shape=[None, 32, 32, 3],
  23. data_preprocessing=img_prep)
  24. network = conv_2d(network, 32, 3, activation='relu')
  25. network = max_pool_2d(network, 2)
  26. network = conv_2d(network, 64, 3, activation='relu')
  27. network = conv_2d(network, 64, 3, activation='relu')
  28. network = max_pool_2d(network, 2)
  29. network = fully_connected(network, 512, activation='relu')
  30. network = dropout(network, 0.5)
  31. network = fully_connected(network, 10, activation='softmax')
  32. network = regression(network, optimizer='adam',
  33. loss='categorical_crossentropy',
  34. learning_rate=0.001)
  35.  
  36. # Train using classifier
  37. model = tflearn.DNN(network, tensorboard_verbose=0)
  38. model.fit(X, Y, n_epoch=15, shuffle=True, validation_set=(X_test, Y_test),
  39. show_metric=True, batch_size=96, run_id='svhn_cnn')

训练结果:

  1. Training Step: 11452 | total loss: 0.68217 | time: 7.973s
  2. | Adam | epoch: 015 | loss: 0.68217 - acc: 0.9329 -- iter: 72576/73257
  3. Training Step: 11453 | total loss: 0.62980 | time: 7.983s
  4. | Adam | epoch: 015 | loss: 0.62980 - acc: 0.9354 -- iter: 72672/73257
  5. Training Step: 11454 | total loss: 0.58649 | time: 7.994s
  6. | Adam | epoch: 015 | loss: 0.58649 - acc: 0.9356 -- iter: 72768/73257
  7. Training Step: 11455 | total loss: 0.53254 | time: 8.005s
  8. | Adam | epoch: 015 | loss: 0.53254 - acc: 0.9421 -- iter: 72864/73257
  9. Training Step: 11456 | total loss: 0.49179 | time: 8.016s
  10. | Adam | epoch: 015 | loss: 0.49179 - acc: 0.9416 -- iter: 72960/73257
  11. Training Step: 11457 | total loss: 0.45679 | time: 8.027s
  12. | Adam | epoch: 015 | loss: 0.45679 - acc: 0.9433 -- iter: 73056/73257
  13. Training Step: 11458 | total loss: 0.42026 | time: 8.038s
  14. | Adam | epoch: 015 | loss: 0.42026 - acc: 0.9469 -- iter: 73152/73257
  15. Training Step: 11459 | total loss: 0.38929 | time: 8.049s
  16. | Adam | epoch: 015 | loss: 0.38929 - acc: 0.9491 -- iter: 73248/73257
  17. Training Step: 11460 | total loss: 0.35542 | time: 9.928s
  18. | Adam | epoch: 015 | loss: 0.35542 - acc: 0.9542 | val_loss: 0.40315 - val_acc: 0.9085 -- iter: 73257/73257

TFlearn——(2)SVHN的更多相关文章

  1. TFlearn——(1)notMNIST

    1, 数据集简介    notMNIST, 看名字就知道,跟MNIST脱不了干系,其实就是升级版的MNIST,含有 A-J 10个类别的艺术印刷体字符,字符的形状各异,噪声更多,难度比 MNIST 要 ...

  2. (转)【重磅】无监督学习生成式对抗网络突破,OpenAI 5大项目落地

    [重磅]无监督学习生成式对抗网络突破,OpenAI 5大项目落地 [新智元导读]"生成对抗网络是切片面包发明以来最令人激动的事情!"LeCun前不久在Quroa答问时毫不加掩饰对生 ...

  3. 数十种TensorFlow实现案例汇集:代码+笔记(转)

    转:https://www.jiqizhixin.com/articles/30dc6dd9-39cd-406b-9f9e-041f5cbf1d14 这是使用 TensorFlow 实现流行的机器学习 ...

  4. torchvision库简介(翻译)

    部分跟新于:4.24日    torchvision 0.2.2.post3 torchvision是独立于pytorch的关于图像操作的一些方便工具库. torchvision的详细介绍在:http ...

  5. 神经网络中embedding层作用——本质就是word2vec,数据降维,同时可以很方便计算同义词(各个word之间的距离),底层实现是2-gram(词频)+神经网络

    Embedding tflearn.layers.embedding_ops.embedding (incoming, input_dim, output_dim, validate_indices= ...

  6. 两种开源聊天机器人的性能测试(二)——基于tensorflow的chatbot

    http://blog.csdn.net/hfutdog/article/details/78155676 开源项目链接:https://github.com/dennybritz/chatbot-r ...

  7. 深度学习方法(十):卷积神经网络结构变化——Maxout Networks,Network In Network,Global Average Pooling

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 最近接下来几篇博文会回到神经网络结构 ...

  8. TensorFlow学习笔记(六)循环神经网络

    一.循环神经网络简介 循环神经网络的主要用途是处理和预测序列数据.循环神经网络刻画了一个序列当前的输出与之前信息的关系.从网络结构上,循环神经网络会记忆之前的信息,并利用之前的信息影响后面节点的输出. ...

  9. 深度学习之 TensorFlow(一):基础库包的安装

    1.TensorFlow 简介:TensorFlow 是谷歌公司开发的深度学习框架,也是目前深度学习的主流框架之一. 2.TensorFlow 环境的准备: 本人使用 macOS,Python 版本直 ...

随机推荐

  1. docker 配置远程访问证书验证

    centos7 生成证书 工具:openssl #cd /etc/docker   (docker的证书一般放这) #openssl genrsa -aes256 -passout pass:密码   ...

  2. ubuntu中vim的设置

    问题:刚安装的VIM中,backspace不能删除字符,且上下左右箭头没反应. 解决方法: sudo vi  /etc/vim/vimrc.tiny 修改 set compatible为set noc ...

  3. 大型运输行业实战_day10_1_自定义事务管理类

    1.创建事务管理类  TransactionManager.java package com.day02.sation.transaction; import com.day02.sation.uti ...

  4. intellij idea 的常见配置

    1.视图配置 配置好后如下图:   2.修改字体大小 3.编码修改 4.行号显示 5.控制台字体大小调整 File->Settings->Editor->Colors & F ...

  5. select语法图

  6. crsf 跨站请求伪造

    [crsf 跨站请求伪造] CSRF(Cross-site request forgery),中文名称:跨站请求伪造.核心为利用浏览器帮助提交cookie.采用随机数方可防御.估计大部小站均无CSRF ...

  7. ccf认证模拟题之三---最大的矩形

    问题描述 在横轴上放了n个相邻的矩形,每个矩形的宽度是1,而第i(1 ≤ i ≤ n)个矩形的高度是hi.这n个矩形构成了一个直方图.例如,下图中六个矩形的高度就分别是3, 1, 6, 5, 2, 3 ...

  8. DOS 命令集锦——最常用命令

    一. 常用命令: cd 改变当前目录   sys 制作DOS系统盘 (电脑入门到精通网 www.58116.cn) copy 拷贝文件  del 删除文件 deltree 删除目录树    dir 列 ...

  9. 条款2:尽量以const, enum, inline替换#define

    原因: 1. 追踪困难,由于在编译期已经替换,在记号表中没有. 2. 由于编译期多处替换,可能导致目标代码体积稍大. 3. define没有作用域,如在类中定义一个常量不行. 做法: 可以用const ...

  10. Boost 库uuid 的使用

    UUID 简介 通用唯一识别码(英语:Universally Unique Identifier,简称UUID)是一种软件建构的标准,亦为开放软件基金会组织在分布式计算环境领域的一部分. uuid 版 ...