MNIST 被喻为深度学习中的Hello World示例,由Yann LeCun等大神组织收集的一个手写数字的数据集,有60000个训练集和10000个验证集,是个非常适合初学者入门的训练集。这个网站也提供了业界对这个数据集的各种算法的尝试结果,也能看出机器学习的算法的演进史,从早期的线性逻辑回归到K-means,再到两层神经网络,到多层神经网络,再到最近的卷积神经网络,随着的算法模型的改善,错误率也不断下降,所以目前这个数据集的错误率已经可以控制在0.2%左右,基本和人类识别的能力相当了。

这篇文章的例子我们会用一个更加有趣点的数据集 notMNIST,和MNIST不同的是它是一个各种形态的字母的数据集合,总共有a~j 10个字母组成,字母a相对应的一些图片如下:

在这个例子中,我们会使用TensorFlow和sklearn等库,对数据集进行一系列处理,最终使用逻辑回归模型来进行机器学习并且预测。

1. 准备环境

安装Python2.7和pip

Python2.7的官方网站:https://www.python.org/getit/

pip是Python Package Index,通过pip可以非常方便的查找安装其他软件,安装pip的方法如下:https://pip.pypa.io/en/stable/installing/

安装TensorFlow

$ pip install tensorflow

2. 下载数据

 import需要的第三方库

# These are all the modules we'll be using later. Make sure you can import them
# before proceeding further.
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import tarfile
from IPython.display import display, Image
from scipy import ndimage
from sklearn.linear_model import LogisticRegression
from six.moves.urllib.request import urlretrieve
from six.moves import cPickle as pickle

# Config the matplotlib backend as plotting inline in IPython
%matplotlib inline

首先,我们会下载数据集到本地电脑。所有的图片都是28*28像素的图片,标示为"A"到"J"(10个分类)。整个数据集合包含大概50000个训练数据和19000个测试数据,所以这样规模的数据集合可以在大多数电脑上较快的完成训练。训练数据文件名是notMNIST_large.tar.gz,测试数据文件名是notMNIST_small.tar.gz。

url = 'http://commondatastorage.googleapis.com/books1000/'
last_percent_reported = None
data_root = '.' # Change me to store data elsewhere

def download_progress_hook(count, blockSize, totalSize):
"""A hook to report the progress of a download. This is mostly intended for users with
slow internet connections. Reports every 5% change in download progress.
"""
global last_percent_reported
percent = int(count * blockSize * 100 / totalSize)

if last_percent_reported != percent:
if percent % 5 == 0:
sys.stdout.write("%s%%" % percent)
sys.stdout.flush()
else:
sys.stdout.write(".")
sys.stdout.flush() last_percent_reported = percent def maybe_download(filename, expected_bytes, force=False):
"""Download a file if not present, and make sure it's the right size."""
dest_filename = os.path.join(data_root, filename)
if force or not os.path.exists(dest_filename):
print('Attempting to download:', filename)
filename, _ = urlretrieve(url + filename, dest_filename, reporthook=download_progress_hook)
print('\nDownload Complete!')
statinfo = os.stat(dest_filename)
if statinfo.st_size == expected_bytes:
print('Found and verified', dest_filename)
else:
raise Exception(
'Failed to verify ' + dest_filename + '. Can you get to it with a browser?')
return dest_filename

train_filename = maybe_download('notMNIST_large.tar.gz', 247336696)
test_filename = maybe_download('notMNIST_small.tar.gz', 8458043)

解压数据集合,会产生一系列标记从A到J的目录。

num_classes = 10
np.random.seed(133)

def maybe_extract(filename, force=False):
root = os.path.splitext(os.path.splitext(filename)[0])[0] # remove .tar.gz
if os.path.isdir(root) and not force:
# You may override by setting force=True.
print('%s already present - Skipping extraction of %s.' % (root, filename))
else:
print('Extracting data for %s. This may take a while. Please wait.' % root)
tar = tarfile.open(filename)
sys.stdout.flush()
tar.extractall(data_root)
tar.close()
data_folders = [
os.path.join(root, d) for d in sorted(os.listdir(root))
if os.path.isdir(os.path.join(root, d))]
if len(data_folders) != num_classes:
raise Exception(
'Expected %d folders, one per class. Found %d instead.' % (
num_classes, len(data_folders)))
print(data_folders)
return data_folders train_folders = maybe_extract(train_filename)
test_folders = maybe_extract(test_filename)

输出如下:

./notMNIST_large already present - Skipping extraction of ./notMNIST_large.tar.gz.
['./notMNIST_large/A', './notMNIST_large/B', './notMNIST_large/C', './notMNIST_large/D', './notMNIST_large/E', './notMNIST_large/F', './notMNIST_large/G', './notMNIST_large/H', './notMNIST_large/I', './notMNIST_large/J']
./notMNIST_small already present - Skipping extraction of ./notMNIST_small.tar.gz.
['./notMNIST_small/A', './notMNIST_small/B', './notMNIST_small/C', './notMNIST_small/D', './notMNIST_small/E', './notMNIST_small/F', './notMNIST_small/G', './notMNIST_small/H', './notMNIST_small/I', './notMNIST_small/J']

3. 加载数据

验证数据集,查看一下A目录里面前20个数据图片

fn = os.listdir("notMNIST_small/A/")
for file in fn[:20]:
path = 'notMNIST_small/A/' + file
display(Image(path))

现在我们将要把图片数据转换成为像素,并且对数据进行Zero Mean是数据更加正则化,整个数据集会被加载到一个3D数组中(图片index,x,y)。如果有些图片不能读取,我们就直接忽略掉。

由于可能不能一次性将所有数据读取到内存中,我们会对每个目录图片分别处理,并将处理完的数据存储到对应的pickle文件中。

image_size = 28  # Pixel width and height.
pixel_depth = 255.0 # Number of levels per pixel. def load_letter(folder, min_num_images):
"""Load the data for a single letter label."""
image_files = os.listdir(folder)
dataset = np.ndarray(shape=(len(image_files), image_size, image_size),
dtype=np.float32)
print(folder)
num_images = 0
for image in image_files:
image_file = os.path.join(folder, image)
try:
image_data = (ndimage.imread(image_file).astype(float) -
pixel_depth / 2) / pixel_depth
if image_data.shape != (image_size, image_size):
raise Exception('Unexpected image shape: %s' % str(image_data.shape))
dataset[num_images, :, :] = image_data
num_images = num_images + 1
except IOError as e:
print('Could not read:', image_file, ':', e, '- it\'s ok, skipping.') dataset = dataset[0:num_images, :, :]
if num_images < min_num_images:
raise Exception('Many fewer images than expected: %d < %d' %
(num_images, min_num_images)) print('Full dataset tensor:', dataset.shape)
print('Mean:', np.mean(dataset))
print('Standard deviation:', np.std(dataset))
return dataset def maybe_pickle(data_folders, min_num_images_per_class, force=False):
dataset_names = []
for folder in data_folders:
set_filename = folder + '.pickle'
dataset_names.append(set_filename)
if os.path.exists(set_filename) and not force:
# You may override by setting force=True.
print('%s already present - Skipping pickling.' % set_filename)
else:
print('Pickling %s.' % set_filename)
dataset = load_letter(folder, min_num_images_per_class)
try:
with open(set_filename, 'wb') as f:
pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
except Exception as e:
print('Unable to save data to', set_filename, ':', e) return dataset_names train_datasets = maybe_pickle(train_folders, 45000)
test_datasets = maybe_pickle(test_folders, 1800)

输入如下:

notMNIST_large/A.pickle already present - Skipping pickling.
notMNIST_large/B.pickle already present - Skipping pickling.
notMNIST_large/C.pickle already present - Skipping pickling.
notMNIST_large/D.pickle already present - Skipping pickling.
notMNIST_large/E.pickle already present - Skipping pickling.
notMNIST_large/F.pickle already present - Skipping pickling.
notMNIST_large/G.pickle already present - Skipping pickling.
notMNIST_large/H.pickle already present - Skipping pickling.
notMNIST_large/I.pickle already present - Skipping pickling.
notMNIST_large/J.pickle already present - Skipping pickling.
notMNIST_small/A.pickle already present - Skipping pickling.
notMNIST_small/B.pickle already present - Skipping pickling.
notMNIST_small/C.pickle already present - Skipping pickling.
notMNIST_small/D.pickle already present - Skipping pickling.
notMNIST_small/E.pickle already present - Skipping pickling.
notMNIST_small/F.pickle already present - Skipping pickling.
notMNIST_small/G.pickle already present - Skipping pickling.
notMNIST_small/H.pickle already present - Skipping pickling.
notMNIST_small/I.pickle already present - Skipping pickling.
notMNIST_small/J.pickle already present - Skipping pickling. 验证数据,我们从A.pickle中随机挑选了一个数据
# index 0 should be all As, 1 = all Bs, etc.
pickle_file = train_datasets[0] # With would automatically close the file after the nested block of code
with open(pickle_file, 'rb') as f: # unpickle
letter_set = pickle.load(f) # pick a random image index
sample_idx = np.random.randint(len(letter_set)) # extract a 2D slice
sample_image = letter_set[sample_idx, :, :]
plt.figure() # display it
plt.imshow(sample_image)

4. 准备训练数据、验证数据和测试数据

我们将pickle文件读取出来进行合并,生成了对应的训练数据(Training),验证数据(Validation)和测试数据(Testing)。

def make_arrays(nb_rows, img_size):
if nb_rows:
dataset = np.ndarray((nb_rows, img_size, img_size), dtype=np.float32)
labels = np.ndarray(nb_rows, dtype=np.int32)
else:
dataset, labels = None, None
return dataset, labels def merge_datasets(pickle_files, train_size, valid_size=0):
num_classes = len(pickle_files)
valid_dataset, valid_labels = make_arrays(valid_size, image_size)
train_dataset, train_labels = make_arrays(train_size, image_size)
vsize_per_class = valid_size // num_classes
tsize_per_class = train_size // num_classes start_v, start_t = 0, 0
end_v, end_t = vsize_per_class, tsize_per_class
end_l = vsize_per_class+tsize_per_class
for label, pickle_file in enumerate(pickle_files):
try:
with open(pickle_file, 'rb') as f:
letter_set = pickle.load(f)
# let's shuffle the letters to have random validation and training set
np.random.shuffle(letter_set)
if valid_dataset is not None:
valid_letter = letter_set[:vsize_per_class, :, :]
valid_dataset[start_v:end_v, :, :] = valid_letter
valid_labels[start_v:end_v] = label
start_v += vsize_per_class
end_v += vsize_per_class train_letter = letter_set[vsize_per_class:end_l, :, :]
train_dataset[start_t:end_t, :, :] = train_letter
train_labels[start_t:end_t] = label
start_t += tsize_per_class
end_t += tsize_per_class
except Exception as e:
print('Unable to process data from', pickle_file, ':', e)
raise return valid_dataset, valid_labels, train_dataset, train_labels train_size = 200000
valid_size = 10000
test_size = 10000 valid_dataset, valid_labels, train_dataset, train_labels = merge_datasets(
train_datasets, train_size, valid_size)
_, _, test_dataset, test_labels = merge_datasets(test_datasets, test_size) print('Training:', train_dataset.shape, train_labels.shape)
print('Validation:', valid_dataset.shape, valid_labels.shape)
print('Testing:', test_dataset.shape, test_labels.shape)
Training: (200000, 28, 28) (200000,)
Validation: (10000, 28, 28) (10000,)
Testing: (10000, 28, 28) (10000,)

随后将数据随机排列

def randomize(dataset, labels):
permutation = np.random.permutation(labels.shape[0])
shuffled_dataset = dataset[permutation,:,:]
shuffled_labels = labels[permutation]
return shuffled_dataset, shuffled_labels
train_dataset, train_labels = randomize(train_dataset, train_labels)
test_dataset, test_labels = randomize(test_dataset, test_labels)
valid_dataset, valid_labels = randomize(valid_dataset, valid_labels)

将数据保存到notMNIST.pickle文件

pickle_file = 'notMNIST.pickle'

try:
f = open(pickle_file, 'wb')
save = {
'train_dataset': train_dataset,
'train_labels': train_labels,
'valid_dataset': valid_dataset,
'valid_labels': valid_labels,
'test_dataset': test_dataset,
'test_labels': test_labels,
}
pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
f.close()
except Exception as e:
print('Unable to save data to', pickle_file, ':', e)
raise

去除数据集中重复的数据

import time

def check_overlaps(images1, images2):
images1.flags.writeable=False
images2.flags.writeable=False
start = time.clock()
hash1 = set([hash(image1.data) for image1 in images1])
hash2 = set([hash(image2.data) for image2 in images2])
all_overlaps = set.intersection(hash1, hash2)
return all_overlaps, time.clock()-start r, execTime = check_overlaps(train_dataset, test_dataset)
print('Number of overlaps between training and test sets: {}. Execution time: {}.'.format(len(r), execTime)) r, execTime = check_overlaps(train_dataset, valid_dataset)
print('Number of overlaps between training and validation sets: {}. Execution time: {}.'.format(len(r), execTime)) r, execTime = check_overlaps(valid_dataset, test_dataset)
print('Number of overlaps between validation and test sets: {}. Execution time: {}.'.format(len(r), execTime))
Number of overlaps between training and test sets: 1153. Execution time: 0.951144.
Number of overlaps between training and validation sets: 952. Execution time: 1.014579.
Number of overlaps between validation and test sets: 55. Execution time: 0.088879.

5. 训练模型

我们使用逻辑回归模型来进行训练,来看看最终的准确度如何?

samples, width, height = train_dataset.shape
X_train = np.reshape(train_dataset,(samples,width*height))
y_train = train_labels # Prepare testing data
samples, width, height = test_dataset.shape
X_test = np.reshape(test_dataset,(samples,width*height))
y_test = test_labels # Import
from sklearn.linear_model import LogisticRegression # Instantiate
lg = LogisticRegression(multi_class='multinomial', solver='lbfgs', random_state=42, verbose=1, max_iter=1000, n_jobs=-1) # Fit
lg.fit(X_train, y_train) # Predict
y_pred = lg.predict(X_test) # Score
from sklearn import metrics
metrics.accuracy_score(y_test, y_pred)
[Parallel(n_jobs=-1)]: Done   1 out of   1 | elapsed:  4.6min finished
0.90059999999999996

大概花费了5分钟的时间,训练出来的回归模型准确率达到了90%,不错的尝试了!

深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型的更多相关文章

  1. 深度学习实践系列(2)- 搭建notMNIST的深度神经网络

    如果你希望系统性的了解神经网络,请参考零基础入门深度学习系列,下面我会粗略的介绍一下本文中实现神经网络需要了解的知识. 什么是深度神经网络? 神经网络包含三层:输入层(X).隐藏层和输出层:f(x) ...

  2. 深度学习实践系列(3)- 使用Keras搭建notMNIST的神经网络

    前期回顾: 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型 深度学习实践系列(2)- 搭建notMNIST的深度神经网络 在第二篇系列中,我们使用了TensorFlow搭建了第一个深度 ...

  3. 【转】RHadoop实践系列之一:Hadoop环境搭建

    RHadoop实践系列之一:Hadoop环境搭建 RHadoop实践系列文章,包含了R语言与Hadoop结合进行海量数据分析.Hadoop主要用来存储海量数据,R语言完成MapReduce 算法,用来 ...

  4. PyTorch深度学习实践——反向传播

    反向传播 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 目录 反向传播 笔记 作业 笔记 在之前课程中介绍的线性 ...

  5. Nagios学习实践系列——基本安装篇

    开篇介绍 最近由于工作需要,学习研究了一下Nagios的安装.配置.使用,关于Nagios的介绍,可以参考我上篇随笔Nagios学习实践系列——产品介绍篇 实验环境 操作系统:Red Hat Ente ...

  6. Nagios学习实践系列——配置研究[监控当前服务器]

    其实上篇Nagios学习实践系列——基本安装篇只是安装了Nagios基本组件,虽然能够打开主页,但是如果不配置相关配置文件文件,那么左边菜单很多页面都打不开,相当于只是一个空壳子.接下来,我们来学习研 ...

  7. Nagios学习实践系列

    其实上篇Nagios学习实践系列--基本安装篇只是安装了Nagios基本组件,虽然能够打开主页,但是如果不配置相关配置文件文件,那么左边菜单很多页面都打不开,相当于只是一个空壳子.接下来,我们来学习研 ...

  8. 深度学习基础系列(九)| Dropout VS Batch Normalization? 是时候放弃Dropout了

    Dropout是过去几年非常流行的正则化技术,可有效防止过拟合的发生.但从深度学习的发展趋势看,Batch Normalizaton(简称BN)正在逐步取代Dropout技术,特别是在卷积层.本文将首 ...

  9. 深度学习基础系列(五)| 深入理解交叉熵函数及其在tensorflow和keras中的实现

    在统计学中,损失函数是一种衡量损失和错误(这种损失与“错误地”估计有关,如费用或者设备的损失)程度的函数.假设某样本的实际输出为a,而预计的输出为y,则y与a之间存在偏差,深度学习的目的即是通过不断地 ...

随机推荐

  1. HDU5873

    Football Games Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/65536 K (Java/Others)To ...

  2. [JQuery]Jquery对象和dom对象

    jquery对象是jquery包装dom对象后产生的对象,它们都只能使用各自的方法. 1.定义变量时,通过$来区分: var $variable = jquery对象: var variable = ...

  3. [转载] A successful Git branching model/GIT分支管理是一门艺术

    转载自:http://www.cnblogs.com/baiyw/p/3303125.html 英文原文:http://www.nvie.com/posts/a-successful-git-bran ...

  4. 为什么Java可以跨平台,而其他语言不行

    你好 我是大福 你现在看的是大福笔记 今天复习了Java语言的概述 内容包括Java 语言的历史.语言特点及平台版本 JRE和JDK的区别 这篇文章的主题是总结下对Java语言特点中的跨平台原理. 在 ...

  5. experss框架—基础认识

    express简介: Express是一个简洁.灵活的node.js Web应用开发框架, 它提供一系列强大的功能,比如:模板解析.静态文件服务.中间件.路由控制等等,并且还可以使用插件或整合其他模块 ...

  6. 腾讯优图及知脸(ZKface)人脸比对接口测试(python)

    一.腾讯优图 1.开发者地址:http://open.youtu.qq.com/welcome/developer 2.接入流程:按照开发者页面的接入流程接入之后,创建应用即可获得所需的AppID.S ...

  7. Spring中LocalSessionFactoryBean与SessionFactory

    相信不少人多纠结LocalSessionFactoryBean与SessionFactory到底是什么关系,怎么去进行关联的,正如图所示: transactonManager有一个对sessionFa ...

  8. B. Pasha and Phone

    B. Pasha and Phone time limit per test 1 second memory limit per test 256 megabytes input standard i ...

  9. [Kafka] - Kafka 安装介绍

    Kafka是由LinkedIn公司开发的,之后贡献给Apache基金会,成为Apache的一个顶级项目,开发语言为Scala.提供了各种不同语言的API,具体参考Kafka的cwiki页面: Kafk ...

  10. PowerPoint超链接字体颜色修改、怎么去掉超链接下划线

    经常在做PPT幻灯片时会遇到这样一个问题,给文字加超链接后发现链接的颜色是蓝色的,而且还带有下划线,这种效果与主题的色彩搭配简直是太影响美观效果了.有没有什么办法可以去掉PPT中的超链接下划线?再将超 ...