一、使用urllib下载cifar-10数据集,并读取再存为图片(TensorFlow v1.14.0)

 # -*- coding:utf-8 -*-
__author__ = 'Leo.Z' import sys
import os # 给定url下载文件
def download_from_url(url, dir=''):
_file_name = url.split('/')[-1]
_file_path = os.path.join(dir, _file_name) # 打印下载进度
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' %
(_file_name, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush() # 如果不存在dir,则创建文件夹
if not os.path.exists(dir):
print("Dir is not exsit,Create it..")
os.makedirs(dir) if not os.path.exists(_file_path):
print("Start downloading..")
# 开始下载文件
import urllib
urllib.request.urlretrieve(url, _file_path, _progress)
else:
print("File already exists..") return _file_path # 使用tarfile解压缩
def extract(filepath, dest_dir):
if os.path.exists(filepath) and not os.path.exists(dest_dir):
import tarfile
tarfile.open(filepath, 'r:gz').extractall(dest_dir) if __name__ == '__main__':
FILE_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
FILE_DIR = 'cifar10_dir/' loaded_file_path = download_from_url(FILE_URL, FILE_DIR)
extract(loaded_file_path)

 按BATCH_SIZE读取二进制文件中的图片数据,并存放为jpg:

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z' # Tensorflow Version:1.14.0 import os import tensorflow as tf
from PIL import Image BATCH_SIZE = 128 def read_cifar10(filenames):
label_bytes = 1
height = 32
width = 32
depth = 3
image_bytes = height * width * depth record_bytes = label_bytes + image_bytes # lamda函数体
# def load_transform(x):
# # Convert these examples to dense labels and processed images.
# per_record = tf.reshape(tf.decode_raw(x, tf.uint8), [record_bytes])
# return per_record # tf v1.14.0版本的FixedLengthRecordDataset(filename_list,bin_data_len)
datasets = tf.data.FixedLengthRecordDataset(filenames=filenames, record_bytes=record_bytes)
# 是否打乱数据
# datasets.shuffle()
# 重复几轮epoches
datasets = datasets.shuffle(buffer_size=BATCH_SIZE).repeat(2).batch(BATCH_SIZE) # 使用map,也可使用lamda(注意,后面使用迭代器的时候这里转换为uint8没用,后面还得转一次,否则会报错)
# datasets.map(load_transform)
# datasets.map(lamda x : tf.reshape(tf.decode_raw(x, tf.uint8), [record_bytes])) # 创建一起迭代器tf v1.14.0
iter = tf.compat.v1.data.make_one_shot_iterator(datasets)
# 获取下一条数据(label+image的二进制数据1+32*32*3长度的bytes)
rec = iter.get_next()
# 这里转uint8才生效,在map中转貌似有问题?
rec = tf.decode_raw(rec, tf.uint8) label = tf.cast(tf.slice(rec, [0, 0], [BATCH_SIZE, label_bytes]), tf.int32) # 从第二个字节开始获取图片二进制数据大小为32*32*3
depth_major = tf.reshape(
tf.slice(rec, [0, label_bytes], [BATCH_SIZE, image_bytes]),
[BATCH_SIZE, depth, height, width])
# 将维度变换顺序,变为[H,W,C]
image = tf.transpose(depth_major, [0, 2, 3, 1]) # 返回获取到的label和image组成的元组
return (label, image) def get_data_from_files(data_dir):
# filenames一共5个,从data_batch_1.bin到data_batch_5.bin
# 读入的都是训练图像
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in range(1, 6)]
# 判断文件是否存在
for f in filenames:
if not tf.io.gfile.exists(f):
raise ValueError('Failed to find file: ' + f) # 获取一张图片数据的数据,格式为(label,image)
data_tuple = read_cifar10(filenames)
return data_tuple if __name__ == "__main__": # 获取label和type的对应关系
label_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
name_list = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
label_map = dict(zip(label_list, name_list)) with tf.compat.v1.Session() as sess:
batch_data = get_data_from_files('cifar10_dir/cifar-10-batches-bin')
# 在之前的旧版本中,因为使用了filename_queue,所以要使用start_queue_runners进行数据填充
# 1.14.0由于没有使用filename_queue所以不需要
# threads = tf.train.start_queue_runners(sess=sess) sess.run(tf.compat.v1.global_variables_initializer())
# 创建一个文件夹用于存放图片
if not os.path.exists('cifar10_dir/raw'):
os.mkdir('cifar10_dir/raw') # 存放30张,以index-typename.jpg命名,例如1-frog.jpg
for i in range(30):
# 获取一个batch的数据,BATCH_SIZE
# batch_data中包含一个batch的image和label
batch_data_tuple = sess.run(batch_data)
# 打印(128, 1)
print(batch_data_tuple[0].shape)
# 打印(128, 32, 32, 3)
print(batch_data_tuple[1].shape) # 每个batch存放第一张图片作为实验
Image.fromarray(batch_data_tuple[1][0]).save("cifar10_dir/raw/{index}-{type}.jpg".format(
index=i, type=label_map[batch_data_tuple[0][0][0]]))

简要代码流程图:

[深度学习] 各种下载深度学习数据集方法(In python)的更多相关文章

  1. ReLeQ:一种自动强化学习的神经网络深度量化方法

    ReLeQ:一种自动强化学习的神经网络深度量化方法     ReLeQ:一种自动强化学习的神经网络深度量化方法ReLeQ: An Automatic Reinforcement Learning Ap ...

  2. 腾讯优图&港科大提出一种基于深度学习的非光流 HDR 成像方法

    目前最好的高动态范围(HDR)成像方法通常是先利用光流将输入图像对齐,随后再合成 HDR 图像.然而由于输入图像存在遮挡和较大运动,这种方法生成的图像仍然有很多缺陷.最近,腾讯优图和香港科技大学的研究 ...

  3. 【深度学习系列】迁移学习Transfer Learning

    在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...

  4. OpenGL学习脚印:深度測试(depth testing)

    写在前面 上一节我们使用AssImp载入了3d模型,效果已经令人激动了.可是绘制效率和场景真实感还存在不足,接下来我们还是要保持耐心,继续学习一些高级主题,等学完后面的高级主题,我们再次来改进我们载入 ...

  5. OpenCV 学习笔记 04 深度估计与分割——GrabCut算法与分水岭算法

    1 使用普通摄像头进行深度估计 1.1 深度估计原理 这里会用到几何学中的极几何(Epipolar Geometry),它属于立体视觉(stereo vision)几何学,立体视觉是计算机视觉的一个分 ...

  6. (zhuan) 深度学习全网最全学习资料汇总之模型介绍篇

    This blog from : http://weibo.com/ttarticle/p/show?id=2309351000224077630868614681&u=5070353058& ...

  7. 深度强化学习day01初探强化学习

    深度强化学习 基本概念 强化学习 强化学习(Reinforcement Learning)是机器学习的一个重要的分支,主要用来解决连续决策的问题.强化学习可以在复杂的.不确定的环境中学习如何实现我们设 ...

  8. 小菜学习设计模式(三)—工厂方法(Factory Method)模式

    前言 设计模式目录: 小菜学习设计模式(一)—模板方法(Template)模式 小菜学习设计模式(二)—单例(Singleton)模式 小菜学习设计模式(三)—工厂方法(Factory Method) ...

  9. VC++/MFC(VC6)开发技术精品学习资料下载汇总

    工欲善其事,必先利其器,VC开发MFC Windows程序,Visual C++或Visual Studio是必须的,恩,这里都给你总结好了,拿去吧:VC/MFC开发必备Visual C++.Visu ...

随机推荐

  1. SolidWorks学习笔记2草图

    几何约束 显示和隐藏约束 单个直线的约束 绘制一个直线,点击左侧的中的水平或者竖直,, 如果要删除改约束,右键绿色的小矩形,相关被约束的对象变成分红,点击删除即可. 两个对象之间的约束 点击一个对象, ...

  2. 2019 徐州icpc网络赛 E. XKC's basketball team

    题库链接: https://nanti.jisuanke.com/t/41387 题目大意 给定n个数,与一个数m,求ai右边最后一个至少比ai大m的数与这个数之间有多少个数 思路 对于每一个数,利用 ...

  3. [转] Python中的装饰器(decorator)

    想理解Python的decorator首先要知道在Python中函数也是一个对象,所以你可以 将函数复制给变量 将函数当做参数 返回一个函数 函数在Python中和变量的用法一样也是一等公民,也就是高 ...

  4. coredump产生的几种可能情况

    coredump产生的几种可能情况 造成程序coredump的原因有很多,这里总结一些比较常用的经验吧: 1,内存访问越界 a) 由于使用错误的下标,导致数组访问越界. b) 搜索字符串时,依靠字符串 ...

  5. DOS sqlcmd

    C:\>sqlcmd -? Microsoft (R) SQL Server 命令行工具版本 12.0.2000.8 NT版权所有 (c) 2014 Microsoft.保留所有权利. 用法: ...

  6. Oracle网络相关概念与常用配置文件

    监听器(Listener) 监听器是Oracle基于服务端的一种网络服务,主要用于监听客户端向数据库服务器提出的链接请求. 本地服务名(Tnsname) Oracle客户端与服务器端的链接是通过客户端 ...

  7. Idea中通过Git将代码同步到GitHub

    一.Idea中配置Git 点击IntelliJ IDEA->Preferences...->Version Control->Git->Path to Git executab ...

  8. JAVA基础--JAVA API集合框架(其他集合类,集合原理)

    一.ArrayList介绍 1.ArrayList介绍 ArrayList它是List接口的真正的实现类.也是我们开发中真正需要使用集合容器对象. ArrayList类,它是List接口的实现.肯定拥 ...

  9. Hive 教程(三)-DDL基础

    DDL,Hive Data Definition Language,数据定义语言: 通俗理解就是数据库与库表相关的操作,本文总结一下基本方法 hive 数据仓库配置 hive 数据仓库默认位置在 hd ...

  10. Ruby学习中(哈希变量/python的字典, 简单的类型转换)

    一. 哈希变量(相当于Python中的字典) 详情参看:https://www.runoob.com/ruby/ruby-hash.html 1.值得注意的 (1). 创建Hash时需注意 # 创建一 ...