TFRecord 使用
TFRecord 生成
import os
import xmltodict
import tensorflow as tf
import numpy as np
# Input/output locations for the Pascal VOC 2012 conversion.
# Raw strings keep every backslash literal; the previous non-raw literals
# relied on invalid escape sequences such as "\V", which Python flags with a
# warning. The resulting path values are unchanged.
dir_path = r'F:\数据存储\VOCdevkit\VOC2012\Annotations'
# Annotation file names; each one is parsed by the conversion loop below.
dirs = os.listdir(dir_path)
imgs_dir = r"F:\数据存储\VOCdevkit\VOC2012\JPEGImages"
out_path = r'F:\数据存储\VOCdevkit\voc2012.tfrecord'
# Label names. The class id written to a record is the index of the object's
# name in this list, so "background" occupies index 0.
classes = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]
# Session shared by get_and_resize_img for decoding and resizing images.
sess = tf.Session()
def get_and_resize_img(img_file):
    """Read a JPEG from ``imgs_dir`` and resize it to 224x224.

    Args:
        img_file: File name of the image, relative to ``imgs_dir``.

    Returns:
        A tuple ``(resized_img_str, w_scale, h_scale, shape_new)``:
        the raw uint8 bytes of the resized image, the horizontal and vertical
        scale factors (new size divided by original size), and the shape of
        the resized array.
    """
    # Build the decode/resize ops once and feed the path through a
    # placeholder. Creating the ops on every call would add new nodes to the
    # default graph for each image, making the graph grow with the dataset.
    ops = getattr(get_and_resize_img, '_ops', None)
    if ops is None:
        path_ph = tf.placeholder(tf.string, shape=[])
        decoded = tf.image.decode_jpeg(tf.read_file(path_ph))
        resized = tf.image.resize_images(decoded, [224, 224], method=0)
        ops = (path_ph, tf.shape(decoded), resized)
        get_and_resize_img._ops = ops
    path_ph, shape_op, resized_op = ops

    # One run yields both the original shape and the resized pixels, so the
    # JPEG is decoded only once per image.
    shape_old, resized_img = sess.run(
        [shape_op, resized_op],
        feed_dict={path_ph: imgs_dir + '/' + img_file},
    )
    old_height, old_width = int(shape_old[0]), int(shape_old[1])

    # resize_images produces floats; cast back to uint8 before serialising.
    resized_img = np.asarray(resized_img, dtype='uint8')
    resized_img_str = resized_img.tobytes()
    shape_new = resized_img.shape

    # Axis 0 is height and axis 1 is width, for both the old and new shapes.
    w_scale = shape_new[1] / old_width
    h_scale = shape_new[0] / old_height
    return resized_img_str, w_scale, h_scale, shape_new
# Convert every VOC annotation file into one tf.train.Example and append it to
# the TFRecord file at out_path.
writer = tf.python_io.TFRecordWriter(out_path)
i = 0
try:
    for i, file in enumerate(dirs, start=1):
        with open(dir_path + '/' + file) as xml_txt:
            doc = xmltodict.parse(xml_txt.read())
        img_file_name = os.path.splitext(file)[0]
        resized_img_str, w_scale, h_scale, shape = get_and_resize_img(img_file_name + '.jpg')

        # xmltodict returns a single mapping when there is one <object> and a
        # list when there are several. Normalise to a list so both cases share
        # one code path; isinstance also covers plain dict and OrderedDict.
        objects = doc['annotation'].get('object') or []
        if isinstance(objects, dict):
            objects = [objects]

        img_obtain_classes = []
        y_mins = []
        x_mins = []
        y_maxes = []
        x_maxes = []
        for one_object in objects:
            if one_object['name'] not in classes:
                continue
            box = one_object['bndbox']
            img_obtain_classes.append(classes.index(one_object['name']))
            # Boxes are rescaled to the resized image. float() also tolerates
            # coordinate strings that are not whole numbers.
            y_mins.append(float(h_scale * float(box['ymin'])))
            x_mins.append(float(w_scale * float(box['xmin'])))
            y_maxes.append(float(h_scale * float(box['ymax'])))
            x_maxes.append(float(w_scale * float(box['xmax'])))

        example = tf.train.Example(features=tf.train.Features(feature={
            'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_file_name.encode('utf8')])),
            'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
            'classes': tf.train.Feature(int64_list=tf.train.Int64List(value=img_obtain_classes)),
            # One entry per kept object in each of the four box lists.
            'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)),
            'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
            'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
            'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
            'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
        }))
        writer.write(example.SerializeToString())
finally:
    # Close the output file and the session even if a file fails to convert.
    writer.close()
    sess.close()
print('ok')
TFRecord 读取
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
# import sys
#
# sys.path.append("..")
# Label names indexed by the class id stored in the TFRecord file.
# The writer script derives ids from a list that starts with "background",
# so the same entry must be present here for id -> name lookups to line up.
classes = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]
# 'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_file_name])),
# 'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
# 'classes': tf.train.Feature(int64_list=tf.train.Int64List(value=np.array(img_obtain_classes))),
# 'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)), # 各个 object 的 ymin
# 'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
# 'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
# 'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
# 'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
def _parse_record(example_proto):
    """Parse one serialized ``tf.train.Example`` into a dict of tensors.

    The schema mirrors the keys written by the generation script: a scalar
    string file name, a 3-element int64 shape, variable-length class ids and
    box coordinates, and the raw image bytes as a scalar string.
    """
    scalar_string = tf.FixedLenFeature([], tf.string)
    schema = {
        'filename': scalar_string,
        'shape': tf.FixedLenFeature([3], tf.int64),
        'classes': tf.VarLenFeature(tf.int64),
    }
    # The four box lists have one float per object, so their length varies.
    for box_key in ('y_mins', 'x_mins', 'y_maxes', 'x_maxes'):
        schema[box_key] = tf.VarLenFeature(tf.float32)
    schema['encoded'] = tf.FixedLenFeature((), tf.string)
    return tf.parse_single_example(example_proto, features=schema)
def read_test(input_file, num_records=2):
    """Print and display the first records of a VOC TFRecord file.

    Args:
        input_file: Path of the TFRecord file to read.
        num_records: Maximum number of records to show (default 2).
    """
    # Read the TFRecord file through the dataset API.
    dataset = tf.data.TFRecordDataset(input_file).map(_parse_record)
    iterator = dataset.make_initializable_iterator()
    # Create the get_next op once. Calling get_next() inside the loop would
    # add a new op to the graph on every iteration.
    next_record = iterator.get_next()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        for _ in range(num_records):
            try:
                features = sess.run(next_record)
            except tf.errors.OutOfRangeError:
                # The file holds fewer records than requested.
                break
            name = features['filename'].decode()
            shape = features['shape']
            # Named class_ids so the module-level ``classes`` list stays visible.
            class_ids = features['classes']
            y_mins = features['y_mins']
            x_mins = features['x_mins']
            y_maxes = features['y_maxes']
            x_maxes = features['x_maxes']
            img_data = features['encoded']
            print(len(img_data))
            print('=======')
            print("shape", shape)
            print("name", name)
            print("classes", class_ids.values)
            print("y_mins", y_mins.values)
            print("x_mins", x_mins.values)
            print("y_maxes", y_maxes.values)
            print("x_maxes", x_maxes.values)
            # Rebuild the image array from its raw bytes and stored shape.
            image_data = np.frombuffer(img_data, dtype=np.uint8).reshape(shape)
            print("img_data", image_data)
            # Show the image.
            plt.imshow(image_data)
            plt.show()
# Raw string: keeps the Windows backslashes literal without relying on
# invalid escape sequences; the path value is unchanged.
read_test(r'F:\数据存储\VOCdevkit\voc2012.tfrecord')
尺寸不固定矩阵的存储和读取
import json
import jieba
import tensorflow as tf
# Load the vocabulary; tokens are looked up in all_words_word2id when the
# records are built.
with open('../data_save/words_info.txt', 'r', encoding='utf-8') as file:
    dic = json.load(file)
all_words_word2id = dic["all_words_word2id"]
# One stop word per line. Strip only the newline so that a final line without
# a trailing newline does not lose its last character.
with open('./stop_words.txt', encoding='utf-8') as f:
    stop_words = {line.rstrip('\n') for line in f}
print('停用词读取完毕,共{n}个单词'.format(n=len(stop_words)))
# Raw strings keep the Windows backslashes literal without relying on invalid
# escape sequences; the path values are unchanged.
dir_path = r'F:\数据存储\新闻语料\news2016zh_train.json'
dir_path_test = r'F:\数据存储\新闻语料\news2016zh_valid.json'
out_path = r'F:\数据存储\新闻语料\news2016zh_train_new.tfrecord'
def getCutSequnce(line):
    """Segment *line* with jieba and return the tokens that are kept.

    Tokens found in ``stop_words`` and the URL fragments 'www', 'com' and
    'http' are dropped; the remaining tokens keep their original order.
    """
    url_fragments = ['www', 'com', 'http']
    return [
        token
        for token in jieba.cut(line, cut_all=False)
        if not (token in stop_words or token in url_fragments)
    ]
# Convert the first 10000 lines of the news corpus (one JSON object per line)
# into records holding the word ids of the title and of the content.
writer = tf.python_io.TFRecordWriter(out_path)
i = 0
try:
    with open(dir_path, encoding='utf-8') as txt:
        for i, one_dic in enumerate(txt, start=1):
            if i > 10000:
                break
            if (i % 1000) == 0:
                print(i)
            if not one_dic.strip():
                # A blank line is not valid JSON; skip it instead of crashing.
                continue
            one_dic_json = json.loads(one_dic)
            title = one_dic_json['title']
            content = one_dic_json['content']
            # Skip overly long articles and entries with an empty field.
            if len(content) > 3000:
                continue
            if not title or not content:
                continue
            # Words missing from the vocabulary are dropped.
            title_list_index = [
                all_words_word2id[word]
                for word in getCutSequnce(title)
                if word in all_words_word2id
            ]
            content_list_index = [
                all_words_word2id[word]
                for word in getCutSequnce(content)
                if word in all_words_word2id
            ]
            example = tf.train.Example(features=tf.train.Features(feature={
                'title': tf.train.Feature(int64_list=tf.train.Int64List(value=title_list_index)),
                'content': tf.train.Feature(int64_list=tf.train.Int64List(value=content_list_index))
            }))
            writer.write(example.SerializeToString())
finally:
    # Close the writer so the output file is flushed and released.
    writer.close()
import tensorflow as tf
import numpy as np
def _parse_record(example_proto):
    """Parse one serialized ``tf.train.Example`` holding title/content ids.

    Both features are variable-length int64 lists, so each is parsed as a
    ``tf.VarLenFeature``.
    """
    schema = {key: tf.VarLenFeature(dtype=tf.int64) for key in ('title', 'content')}
    return tf.parse_single_example(example_proto, features=schema)
def read_test(input_file, num_records=5):
    """Print the content ids of the first records of a news TFRecord file.

    Args:
        input_file: Path of the TFRecord file to read.
        num_records: Maximum number of records to print (default 5).
    """
    # Read the TFRecord file through the dataset API.
    dataset = tf.data.TFRecordDataset(input_file).map(_parse_record)
    iterator = dataset.make_initializable_iterator()
    # Create the get_next op once. Calling get_next() inside the loop would
    # add a new op to the graph on every iteration.
    next_record = iterator.get_next()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        for _ in range(num_records):
            try:
                features = sess.run(next_record)
            except tf.errors.OutOfRangeError:
                # The file holds fewer records than requested.
                break
            content = features['content']
            print("xx", content)
            # 'content' is parsed as a sparse value; the ids themselves are in
            # .values, so that is the array whose shape is reported.
            print("xx", content.values.shape)
# Raw string: keeps the Windows backslashes literal without relying on
# invalid escape sequences; the path value is unchanged.
read_test(r'F:\数据存储\新闻语料\news2016zh_train_new.tfrecord')
统计数据条数
import tensorflow as tf
def total_sample(file_name):
    """Return the number of records stored in the TFRecord file *file_name*."""
    # Iterate over the raw records and count them without parsing any.
    return sum(1 for _ in tf.python_io.tf_record_iterator(file_name))
# Raw string: keeps the Windows backslashes literal without relying on
# invalid escape sequences; the path value is unchanged.
result = total_sample(r'F:\数据存储\新闻语料\news2016zh_train_new.tfrecord')
print(result)
TFRecord 使用的更多相关文章
- Tensorflow 处理libsvm格式数据生成TFRecord (parse libsvm data to TFRecord)
#写libsvm格式 数据 write libsvm #!/usr/bin/env python #coding=gbk # ================================= ...
- 学习笔记TF016:CNN实现、数据集、TFRecord、加载图像、模型、训练、调试
AlexNet(Alex Krizhevsky,ILSVRC2012冠军)适合做图像分类.层自左向右.自上向下读取,关联层分为一组,高度.宽度减小,深度增加.深度增加减少网络计算量. 训练模型数据集 ...
- [TFRecord格式数据]利用TFRecords存储与读取带标签的图片
利用TFRecords存储与读取带标签的图片 原创文章,转载请注明出处~ 觉得有用的话,欢迎一起讨论相互学习~Follow Me TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是 ...
- 深度学习原理与框架-Tfrecord数据集的读取与训练(代码) 1.tf.train.batch(获取batch图片) 2.tf.image.resize_image_with_crop_or_pad(图片压缩) 3.tf.train.per_image_stand..(图片标准化) 4.tf.train.string_input_producer(字符串入队列) 5.tf.TFRecord(读
1.tf.train.batch(image, batch_size=batch_size, num_threads=1) # 获取一个batch的数据 参数说明:image表示输入图片,batch_ ...
- 深度学习原理与框架-Tfrecord数据集的制作 1.tf.train.Examples(数据转换为二进制) 3.tf.image.encode_jpeg(解码图片加码成jpeg) 4.tf.train.Coordinator(构建多线程通道) 5.threading.Thread(建立单线程) 6.tf.python_io.TFR(TFR读入器)
1. 配套使用: tf.train.Examples将数据转换为二进制,提升IO效率和方便管理 对于int类型 : tf.train.Examples(features=tf.train.Featur ...
- 3. Tensorflow生成TFRecord
1. Tensorflow高效流水线Pipeline 2. Tensorflow的数据处理中的Dataset和Iterator 3. Tensorflow生成TFRecord 4. Tensorflo ...
- TFRecord文件的读写
前言在跑通了官网的mnist和cifar10数据之后,笔者尝试着制作自己的数据集,并保存,读入,显示. TensorFlow可以支持cifar10的数据格式, 也提供了标准的TFRecord 格式,而 ...
- 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练
将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...
- tfrecord
制作自己的TFRecord数据集,读取,显示及代码详解 http://blog.csdn.net/miaomiaoyuan/article/details/56865361
- 3 TFRecord样例程序实战
将图片数据写入Record文件 # 定义函数转化变量类型. def _int64_feature(value): return tf.train.Feature(int64_list=tf.train ...
随机推荐
- UOJ185 ZJOI2016 小星星 容斥、树形DP
传送门 先考虑一个暴力的DP:设\(f_{i,j,S}\)表示点\(i\)映射到了图中的点\(j\),且点\(i\)所在子树的所有点映射到了图中的集合\(S\)时的映射方案数,转移暴力地枚举子集即可, ...
- 记录:拷贝gitblit里的项目使用git命令clone、pull、push等,出现一直在加载,卡住没反应的问题
俺想克隆别人gitblit里的其中一个版本库(俺在别人gitblit有权限) 懂得git的道友们,都应该知道克隆一个公共项目,随便找到,打开git终端,输入git clone 项目地址就行了 到了俺这 ...
- windows10 iis浏览wcf报404.3错误
报错:HTTP错误404.3-Not Found 由于扩展配置问题而无法提供您请求的页面.如果该页面是脚本,请添加处理程序.如果应下载文件,请添加MIME映射. 解决步骤如下: 控制面板->打开 ...
- 未能加载文件或程序集system.web.extensions解决方法
发现未能加载文件或程序集的错误,这是由于我的机器上没有安装Ajax的原因.问题解决后,整理如下:表现:1."System.Web.Extensions, Version=1.0.61025.0, Cu ...
- python3-使用requests模拟登录网易云音乐
# -*- coding: utf-8 -*- from Crypto.Cipher import AES import base64 import random import codecs impo ...
- 四 python中关于OOP的常用术语
抽象/实现 抽象指对现实世界问题和实体的本质表现,行为和特征建模,建立一个相关的子集,可以用于 绘程序结构,从而实现这种模型.抽象不仅包括这种模型的数据属性,还定义了这些数据的接口. 对某种抽象的实现 ...
- vue-cli3 一直运行 /sockjs-node/info
首先 sockjs-node 是一个JavaScript库,提供跨浏览器JavaScript的API,创建了一个低延迟.全双工的浏览器和web服务器之间通信通道. 服务端:sockjs-node(ht ...
- Vivado debug异常现象
前言 bit文件和ltx文件的信号位宽不匹配问题.用了dont_touch等属性没用... WARNING: [Labtools 27-1972] Mismatch between the desig ...
- Jmeter学习笔记(十三)——xpath断言
1.什么是XPath断言 XPath即为XML路径语言,它是一种用来确定XML(标准通用标记语言的子集)文档中某部分位置的语言.XPath基于XML的树状结构,提供在数据结构树中找寻节点的能力. Ap ...
- xcode 手动管理内存 的相关知识点总结
一.XCode4.2以后支持自动释放内存ARC xcode自4.2以后就支持自动释放内存了,但有时我们还是想手动管理内存,这如何处理呢. 很简单,想要取消自动释放,只要在 Build Setting ...