TensorFlow(十八):从零开始训练图片分类模型
(一):进入GitHub下载模型--》下载地址
因为我们需要slim模块,所以将包中的slim文件夹复制出来使用。
(1):在slim中新建images文件夹存放图片集
(2):新建model文件夹用来放模型
(3):在datasets文件夹中新建myimages.py文件
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset. The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_flowers.py
""" from __future__ import absolute_import
from __future__ import division
from __future__ import print_function import os
import tensorflow as tf from datasets import dataset_utils slim = tf.contrib.slim _FILE_PATTERN = 'image_%s_*.tfrecord' SPLITS_TO_SIZES = {'train': 3500, 'test': 500} # 这里根据自己的训练集内容进行修改 _NUM_CLASSES = 5 _ITEMS_TO_DESCRIPTIONS = {
'image': 'A color image of varying size.',
'label': 'A single integer between 0 and 4',
} def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
"""Gets a dataset tuple with instructions for reading flowers. Args:
split_name: A train/validation split name.
dataset_dir: The base directory of the dataset sources.
file_pattern: The file pattern to use when matching the dataset sources.
It is assumed that the pattern contains a '%s' string so that the split
name can be inserted.
reader: The TensorFlow reader type. Returns:
A `Dataset` namedtuple. Raises:
ValueError: if `split_name` is not a valid train/validation split.
"""
if split_name not in SPLITS_TO_SIZES:
raise ValueError('split name %s was not recognized.' % split_name) if not file_pattern:
file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name) # Allowing None in the signature so that dataset_factory can use the default.
if reader is None:
reader = tf.TFRecordReader keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
'image/class/label': tf.FixedLenFeature(
[], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
} items_to_handlers = {
'image': slim.tfexample_decoder.Image(),
'label': slim.tfexample_decoder.Tensor('image/class/label'),
} decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers) labels_to_names = None
if dataset_utils.has_labels(dataset_dir):
labels_to_names = dataset_utils.read_label_file(dataset_dir) return slim.dataset.Dataset(
data_sources=file_pattern,
reader=reader,
decoder=decoder,
num_samples=SPLITS_TO_SIZES[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
num_classes=_NUM_CLASSES,
labels_to_names=labels_to_names)
myimages.py
(4):修改dataset_factory.py
from datasets import myimages datasets_map = {
'cifar10': cifar10,
'flowers': flowers,
'imagenet': imagenet,
'mnist': mnist,
'myimages':myimages, # 这一句为添加的内容
}
添加的内容
(二):对图片进行处理,生成tfrecord格式的文件。
import tensorflow as tf
import os
import random
import math
import sys #验证集数量
_NUM_TEST = 500
#随机种子
_RANDOM_SEED = 0
#数据块数目
_NUM_SHARDS = 5
#数据集路径
DATASET_DIR = "C:/Users/FELIX/Desktop/tensor_study/slim/images/"
#标签文件名字
LABELS_FILENAME = ''.join([DATASET_DIR,'labels.txt']) #定义tfrecord文件的路径+名字
def _get_dataset_filename(dataset_dir, split_name, shard_id):
output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, _NUM_SHARDS)
return os.path.join(dataset_dir, output_filename) #判断tfrecord文件是否存在
def _dataset_exists(dataset_dir):
for split_name in ['train', 'test']:
for shard_id in range(_NUM_SHARDS):
#定义tfrecord文件的路径+名字
output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
if not tf.gfile.Exists(output_filename):
return False
return True #获取所有文件以及分类
def _get_filenames_and_classes(dataset_dir):
#数据目录
directories = []
#分类名称
class_names = []
for filename in os.listdir(dataset_dir):
#合并文件路径
path = os.path.join(dataset_dir, filename)
#判断该路径是否为目录
if os.path.isdir(path):
#加入数据目录
directories.append(path)
#加入类别名称
class_names.append(filename) photo_filenames = []
#循环每个分类的文件夹
for directory in directories:
for filename in os.listdir(directory):
path = os.path.join(directory, filename)
#把图片加入图片列表
photo_filenames.append(path) return photo_filenames, class_names def int64_feature(values):
if not isinstance(values, (tuple, list)):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) def bytes_feature(values):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) def image_to_tfexample(image_data, image_format, class_id):
#Abstract base class for protocol messages.
return tf.train.Example(features=tf.train.Features(feature={
'image/encoded': bytes_feature(image_data),
'image/format': bytes_feature(image_format),
'image/class/label': int64_feature(class_id),
})) def write_label_file(labels_to_class_names, dataset_dir,filename=LABELS_FILENAME):
labels_filename = os.path.join(dataset_dir, filename)
with tf.gfile.Open(labels_filename, 'w') as f:
for label in labels_to_class_names:
class_name = labels_to_class_names[label]
f.write('%d:%s\n' % (label, class_name)) #把数据转为TFRecord格式
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
assert split_name in ['train', 'test']
#计算每个数据块有多少数据
num_per_shard = int(len(filenames) / _NUM_SHARDS)
with tf.Graph().as_default():
with tf.Session() as sess:
for shard_id in range(_NUM_SHARDS):
#定义tfrecord文件的路径+名字
output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
#每一个数据块开始的位置
start_ndx = shard_id * num_per_shard
#每一个数据块最后的位置
end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
for i in range(start_ndx, end_ndx):
try:
sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i+1, len(filenames), shard_id))
sys.stdout.flush()
#读取图片
image_data = tf.gfile.FastGFile(filenames[i], 'rb').read() # 这里一定要rb否则会出现编码错误
#获得图片的类别名称
class_name = os.path.basename(os.path.dirname(filenames[i]))
#找到类别名称对应的id
class_id = class_names_to_ids[class_name]
#生成tfrecord文件
example = image_to_tfexample(image_data, b'jpg', class_id)
tfrecord_writer.write(example.SerializeToString())
except IOError as e:
print("Could not read:",filenames[i])
print("Error:",e)
print("Skip it\n") sys.stdout.write('\n')
sys.stdout.flush() if __name__ == '__main__':
#判断tfrecord文件是否存在
if _dataset_exists(DATASET_DIR):
print('tfcecord文件已存在')
else:
#获得所有图片以及分类
photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)
#把分类转为字典格式,类似于{'house': 3, 'flower': 1, 'plane': 4, 'guitar': 2, 'animal': 0}
class_names_to_ids = dict(zip(class_names, range(len(class_names)))) #把数据切分为训练集和测试集
random.seed(_RANDOM_SEED)
random.shuffle(photo_filenames)
training_filenames = photo_filenames[_NUM_TEST:]
testing_filenames = photo_filenames[:_NUM_TEST] #数据转换
_convert_dataset('train', training_filenames, class_names_to_ids, DATASET_DIR)
_convert_dataset('test', testing_filenames, class_names_to_ids, DATASET_DIR) #输出labels文件
labels_to_class_names = dict(zip(range(len(class_names)), class_names))
write_label_file(labels_to_class_names, DATASET_DIR)
生成tfrecord
(三):新建批处理文件,开始训练模型
python C:/Users/FELIX/Desktop/tensor_study/slim/train_image_classifier.py ^
--train_dir=C:/Users/FELIX/Desktop/tensor_study/slim/model ^
--dataset_name=myimages ^
--dataset_split_name=train ^
--dataset_dir=C:/Users/FELIX/Desktop/tensor_study/slim/images ^
--batch_size=10 ^
--max_number_of_steps=10000 ^
--model_name=inception_v3 ^
pause 注释:
第一行表示运行训练文件,路径为全路径
第二行表示模型存放位置
第三行为创建的myimages文件名
第四行为使用的训练集
第五行为数据集所在的位置
第六行为批次大小,默认为32,看个人GPU,我用10
第七行为训练次数,默认无限次
第八行为使用模型名称
批处理文件
TensorFlow(十八):从零开始训练图片分类模型的更多相关文章
- TensorFlow(十七):训练自己的图片分类模型
(一)下载inception-v3--见TensorFlow(十四) (二)准备训练用的图片集,因为我没有图片集,所以写了个自动抓取百度图片的脚本-见抓取百度图片 (三)创建retrain.py文件, ...
- 使用tensorflow的retrain.py训练图片分类器
参考 https://hackernoon.com/creating-insanely-fast-image-classifiers-with-mobilenet-in-tensorflow-f030 ...
- 用C++调用tensorflow在python下训练好的模型(centos7)
本文主要参考博客https://blog.csdn.net/luoyexuge/article/details/80399265 [1] bazel安装参考:https://blog.csdn.net ...
- NLP(十八)利用ALBERT提升模型预测速度的一次尝试
前沿 在文章NLP(十七)利用tensorflow-serving部署kashgari模型中,笔者介绍了如何利用tensorflow-serving部署来部署深度模型模型,在那篇文章中,笔者利用k ...
- PyTorch ImageNet 基于预训练六大常用图片分类模型的实战
微调 Torchvision 模型 在本教程中,我们将深入探讨如何对 torchvision 模型进行微调和特征提取,所有这些模型都已经预先在1000类的Imagenet数据集上训练完成.本教程将深入 ...
- 用Pytorch训练MNIST分类模型
本次分类问题使用的数据集是MNIST,每个图像的大小为\(28*28\). 编写代码的步骤如下 载入数据集,分别为训练集和测试集 让数据集可以迭代 定义模型,定义损失函数,训练模型 代码 import ...
- Tensorflow 使用slim框架下的分类模型进行分类
Tensorflow的slim框架可以写出像keras一样简单的代码来实现网络结构(虽然现在keras也已经集成在tf.contrib中了),而且models/slim提供了类似之前说过的object ...
- 【emWin】例程十八:jpeg图片显示
说明:1.将文件拷入SD卡内即可在指定位置绘制jpeg图片文件,不必加载到储存器. 由于jpeg格式文件显示时需要进行解压缩,耗用动态内存,iCore3所有模块受emwin缓存的限制,jpeg ...
- 源码分析——迁移学习Inception V3网络重训练实现图片分类
1. 前言 近些年来,随着以卷积神经网络(CNN)为代表的深度学习在图像识别领域的突破,越来越多的图像识别算法不断涌现.在去年,我们初步成功尝试了图像识别在测试领域的应用:将网站样式错乱问题.无线领域 ...
随机推荐
- Vue使用指南(三)
组件 '''1.根组件:new Vue()创建的组件,一般不明确自身的模板,模板就采用挂载点2.局部组件: local_component = {}2.全局组件: Vue.component({})' ...
- oracle 生成随机日期+时间
oracle 生成随机日期+时间 SELECT to_date(TRUNC(DBMS_RANDOM.VALUE(to_number(to_char(to_date('20110101','yyyymm ...
- <a>的javascript+jquery编程实例之删除(定位节点与事件绑定)
相关jquery方法 parent(), remove() //上传图片 article_create.js article_edit.js function uploadAttachment() { ...
- 测试winform程序到树莓派运行
啥也不说了,都在下图中了.winform可以在树莓派上跑了
- Python查看模块
1.查看Python所有内置模块 按以下链接打开,每个模块有介绍,可以选择不同的版本 https://docs.python.org/3.6/library/index.html 2.查看Python ...
- Axios使用拦截器全局处理请求重试
Axios拦截器 Axios提供了拦截器的接口,让我们能够全局处理请求和响应.Axios拦截器会在Promise的then和catch调用前拦截到. 请求拦截示例 axios.interceptors ...
- 通透理解viewport
摘自:https://blog.csdn.net/u014787301/article/details/44466697 在移动设备上进行网页的重构或开发,首先得搞明白的就是移动设备上的viewpor ...
- Java 之 缓冲流
一.缓冲流概述 缓冲流,也叫高效流,是对四个 FileXXX 流的增强,所有也有四个流,按照类型分类: 字节缓冲流:BufferedInputStream,BufferedOutputStream 字 ...
- iOS 如何判断一个点在某个指定区域中
在iOS 开发中会遇到 判断位置的情况 iOS 自己都有函数实现的这些功能. 判断一个点是否在这个rect区域中 bool CGRectContainsPoint(CGRect rect,CGPoin ...
- centos 7.0 读写ntfs分区
wget -O /etc/yum.repos.d/epel.repo http://mirrors.aliyun.com/repo/epel-7.repo yum install ntfs-3g 查看 ...