Tensorflow2 自定义数据集图片完成图片分类任务
对于自定义数据集的图片任务,通用流程一般分为以下几个步骤:
Load data
Train-Val-Test
Build model
Transfer Learning
其中大部分精力会花在数据的准备和预处理上,本文用一种较为通用的数据处理手段,并通过手动构建,简单模型, 层数较深的resnet网络,和基于VGG19的迁移学习。
你可以通过这个例子,快速搭建网络,并训练处一个较为满意的结果。
1. Load data
数据集来自Pokemon的5分类数据, 每一种的图片数量为200多张,是一个较小型的数据集。
官方项目链接:
https://www.pyimagesearch.com/2018/04/16/keras-and-convolutional-neural-networks-cnns/
1.1 数据集介绍
Pokemon文件夹中包含5个子文件,其中每个子文件夹名为对应的类别名。文件夹中包含有png, jpeg的图片文件。
1.2 解题思路
由于文件夹中没有划分,训练集和测试集,所以需要构建一个csv文件读取所有的文件,及其类别
shuffle数据集以后,划分Train_val_test
对数据进行预处理, 数据标准化,数据增强, 可视化处理
"""python
# 创建数字编码表
import os
import glob
import random
import csv
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import time
def load_csv(root, filename, name2label):
"""
将分散在各文件夹中的图片, 转换为图片和label对应的一个dataset文件, 格式为csv
:param root: 文件路径(每个子文件夹中的文件属于一类)
:param filename: 文件名
:param name2label: 类名编码表 {'类名1':0, '类名2':1..}
:return: images, labels
"""
# 判断是否csv文件已经生成
if not os.path.exists(os.path.join(root, filename)): # join-将路径与文件名何为一个路径并返回(没有会生成新路径)
images = [] # 存的是文件路径
for name in name2label.keys():
# pokemon\pikachu\00000001.png
# glob.glob() 利用通配符检索路径内的文件,类似于正则表达式
images += glob.glob(os.path.join(root, name, '*')) # png, jpg, jpeg
print(name2label)
print(len(images), images)
random.shuffle(images)
with open(os.path.join(root, filename), 'w', newline='') as f:
writer = csv.writer(f)
for img in images:
name = img.split(os.sep)[1] # os.sep 表示分隔符 window-'\\' , linux-'/'
label = name2label[name] # 0, 1, 2..
# 'pokemon\\bulbasaur\\00000000.png', 0
writer.writerow([img, label]) # 如果不设定newline='', 2个数据会分为2行写
print('write into csv file:', filename)
# 读取现有文件
images, labels = [], []
with open(os.path.join(root, filename)) as f:
reader = csv.reader(f)
for row in reader:
# 'pokemon\\bulbasaur\\00000000.png', 0
img, label = row
label = int(label) # str-> int
images.append(img)
labels.append(label)
assert len(images) == len(labels)
return images, labels
def load_pokemon(root, mode='train'):
"""
# 创建数字编码表
:param root: root path
:param mode: train, valid, test
:return: images, labels, name2label
"""
name2label = {} # {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
for name in sorted(os.listdir(os.path.join(root))):
# sorted() 是为了复现结果的一致性
# os.listdir - 返回路径下的所有文件(文件夹,文件)列表
if not os.path.isdir(os.path.join(root, name)): # 是否为文件夹且是否存在
continue
# 每个类别编码一个数字
name2label[name] = len(name2label)
# 读取label
images, labels = load_csv(root, 'images.csv', name2label)
# 划分数据集 [6:2:2]
if mode == 'train':
images = images[:int(0.6 * len(images))]
labels = labels[:int(0.6 * len(labels))] # len(images) == len(labels)
elif mode == 'valid':
images = images[int(0.6 * len(images)):int(0.8 * len(images))]
labels = labels[int(0.6 * len(labels)):int(0.8 * len(labels))]
else:
images = images[int(0.8 * len(images)):]
labels = labels[int(0.8 * len(labels)):]
return images, labels, name2label
# imagenet 数据集均值, 方差
img_mean = tf.constant([0.485, 0.456, 0.406]) # 3 channel
img_std = tf.constant([0.229, 0.224, 0.225])
def normalization(x, mean=img_mean, std=img_std):
# [224, 224, 3]
x = (x - mean) / std
return x
def denormalization(x, mean=img_mean, std=img_std):
x = x * std + mean
return x
def preprocess(x, y):
# x: path, y: label
x = tf.io.read_file(x) # 2进制
# x = tf.image.decode_image(x)
x = tf.image.decode_jpeg(x, channels=3) # RGBA
x = tf.image.resize(x, [244, 244])
# data augmentation
# x = tf.image.random_flip_up_down(x)
x = tf.image.random_flip_left_right(x)
x = tf.image.random_crop(x, [224, 224, 3]) # 模型缩减比例不宜过大,否则会增大训练难度
x = tf.cast(x, dtype=tf.float32) / 255. # unit8 -> float32
# U[0,1] -> N(0,1) # 提高训练准确度
x = normalization(x)
y = tf.convert_to_tensor(y)
return x, y
def main():
images, labels, name2label = load_pokemon('pokemon', 'train')
print('images:', len(images), images)
print('labels:', len(labels), labels)
# print(name2label)
# .map()函数要位于.batch()之前, 否则 x=tf.io.read_file()会一次读取一个batch的图片,从而报错
db = tf.data.Dataset.from_tensor_slices((images, labels)).map(preprocess).shuffle(1000).batch(32)
# tf.summary()
# 提供了各类方法(支持各种多种格式)用于保存训练过程中产生的数据(比如loss_value、accuracy、整个variable),
# 这些数据以日志文件的形式保存到指定的文件夹中。
# 数据可视化:而tensorboard可以将tf.summary()
# 记录下来的日志可视化,根据记录的数据格式,生成折线图、统计直方图、图片列表等多种图。
# tf.summary()
# 通过递增的方式更新日志,这让我们可以边训练边使用tensorboard读取日志进行可视化,从而实时监控训练过程。
writer = tf.summary.create_file_writer('logs')
for step, (x, y) in enumerate(db):
with writer.as_default():
x = denormalization(x)
tf.summary.image('img', x, step=step, max_outputs=9) # STEP:默认选项,指的是横轴显示的是训练迭代次数
time.sleep(5)
if __name__ == '__main__':
main()
"""
2. 构建模型进行训练
2.1 自定义小型网络
由于数据集数量较少,大型网络的训练中往往会出现过拟合情况,这里就定义了一个2层卷积的小型网络。
引入early_stopping回调函数后,3个epoch没有较大变化的情况下,模型训练的准确率为0.8547
"""
# 1. 自定义小型网络
model = keras.Sequential([
layers.Conv2D(16, 5, 3),
layers.MaxPool2D(3, 3),
layers.ReLU(),
layers.Conv2D(64, 5, 3),
layers.MaxPool2D(2, 2),
layers.ReLU(),
layers.Flatten(),
layers.Dense(64),
layers.ReLU(),
layers.Dense(5)
])
model.build(input_shape=(None, 224, 224, 3))
model.summary()
early_stopping = EarlyStopping(
monitor='val_loss',
patience=3,
min_delta=0.001
)
model.compile(optimizer=optimizers.Adam(lr=1e-3),
loss=losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
callbacks=[early_stopping])
model.evaluate(db_test)
"""
2.2 自定义的Resnet网络
resnet 网络对于层次较深的网络的可训练型提升很大,主要是通过一个identity layer保证了深层次网络的训练效果不会弱于浅层网络。
其他文章中有详细介绍resnet的搭建,这里就不做赘述, 这里构建了一个resnet18网络, 准确率0.7607。
"""
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
class ResnetBlock(keras.Model):
def __init__(self, channels, strides=1):
super(ResnetBlock, self).__init__()
self.channels = channels
self.strides = strides
self.conv1 = layers.Conv2D(channels, 3, strides=strides,
padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
self.bn1 = keras.layers.BatchNormalization()
self.conv2 = layers.Conv2D(channels, 3, strides=1,
padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
self.bn2 = keras.layers.BatchNormalization()
if strides != 1:
self.down_conv = layers.Conv2D(channels, 1, strides=strides, padding='valid')
self.down_bn = tf.keras.layers.BatchNormalization()
def call(self, inputs, training=None):
residual = inputs
x = self.conv1(inputs)
x = tf.nn.relu(x)
x = self.bn1(x, training=training)
x = self.conv2(x)
x = tf.nn.relu(x)
x = self.bn2(x, training=training)
# 残差连接
if self.strides != 1:
residual = self.down_conv(inputs)
residual = tf.nn.relu(residual)
residual = self.down_bn(residual, training=training)
x = x + residual
x = tf.nn.relu(x)
return x
class ResNet(keras.Model):
def __init__(self, num_classes, initial_filters=16, **kwargs):
super(ResNet, self).__init__(**kwargs)
self.stem = layers.Conv2D(initial_filters, 3, strides=3, padding='valid')
self.blocks = keras.models.Sequential([
ResnetBlock(initial_filters * 2, strides=3),
ResnetBlock(initial_filters * 2, strides=1),
# layers.Dropout(rate=0.5),
ResnetBlock(initial_filters * 4, strides=3),
ResnetBlock(initial_filters * 4, strides=1),
ResnetBlock(initial_filters * 8, strides=2),
ResnetBlock(initial_filters * 8, strides=1),
ResnetBlock(initial_filters * 16, strides=2),
ResnetBlock(initial_filters * 16, strides=1),
])
self.final_bn = layers.BatchNormalization()
self.avg_pool = layers.GlobalMaxPool2D()
self.fc = layers.Dense(num_classes)
def call(self, inputs, training=None):
# print('x:',inputs.shape)
out = self.stem(inputs, training = training)
out = tf.nn.relu(out)
# print('stem:',out.shape)
out = self.blocks(out, training=training)
# print('res:',out.shape)
out = self.final_bn(out, training=training)
# out = tf.nn.relu(out)
out = self.avg_pool(out)
# print('avg_pool:',out.shape)
out = self.fc(out)
# print('out:',out.shape)
return out
def main():
num_classes = 5
resnet18 = ResNet(5)
resnet18.build(input_shape=(None, 224, 224, 3))
resnet18.summary()
if __name__ == '__main__':
main()
"""
"""
# 2.resnet18训练, 图片数量较小,训练结果不是特别好
# resnet = ResNet(5) # 0.7607
# resnet.build(input_shape=(None, 224, 224, 3))
# resnet.summary()
"""
2.3 VGG19迁移学习
迁移学习利用了数据集之间的相似性,对于数据集数量较少的时候,训练效果会远优于其他。
在训练过程中,使用include_top=False, 去掉最后分类的基层Dense, 重新构建并训练就可以了。准确率0.9316
"""
# 3. VGG19迁移学习,迁移学习利用数据集之间的相似性, 结果远好于其他2种
# 为了方便,这里仍然使用resnet命名
net = tf.keras.applications.VGG19(weights='imagenet', include_top=False, pooling='max' )
net.trainable = False
resnet = keras.Sequential([
net,
layers.Dense(5)
])
resnet.build(input_shape=(None, 224, 224, 3)) # 0.9316
resnet.summary()
early_stopping = EarlyStopping(
monitor='val_loss',
patience=3,
min_delta=0.001
)
resnet.compile(optimizer=optimizers.Adam(lr=1e-3),
loss=losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
resnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
callbacks=[early_stopping])
resnet.evaluate(db_test)
"""
附录:
train_scratch.py 代码
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
from tensorflow.keras.callbacks import EarlyStopping
tf.random.set_seed(22)
np.random.seed(22)
assert tf.__version__.startswith('2.')
# 设置GPU显存按需分配
# gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus:
# try:
# # Currently, memory growth needs to be the same across GPUs
# for gpu in gpus:
# tf.config.experimental.set_memory_growth(gpu, True)
# logical_gpus = tf.config.experimental.list_logical_devices('GPU')
# print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
# except RuntimeError as e:
# # Memory growth must be set before GPUs have been initialized
# print(e)
from pokemon import load_pokemon, normalization
from resnet import ResNet
def preprocess(x, y):
# x: 图片的路径,y:图片的数字编码
x = tf.io.read_file(x)
x = tf.image.decode_jpeg(x, channels=3) # RGBA
# 图片缩放
# x = tf.image.resize(x, [244, 244])
# 图片旋转
# x = tf.image.rot90(x,2)
# 随机水平翻转
x = tf.image.random_flip_left_right(x)
# 随机竖直翻转
# x = tf.image.random_flip_up_down(x)
# 图片先缩放到稍大尺寸
x = tf.image.resize(x, [244, 244])
# 再随机裁剪到合适尺寸
x = tf.image.random_crop(x, [224, 224, 3])
# x: [0,255]=> -1~1
x = tf.cast(x, dtype=tf.float32) / 255.
x = normalization(x)
y = tf.convert_to_tensor(y)
y = tf.one_hot(y, depth=5)
return x, y
batchsz = 32
# create train db
images1, labels1, table = load_pokemon('pokemon', 'train')
db_train = tf.data.Dataset.from_tensor_slices((images1, labels1))
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
# create validation db
images2, labels2, table = load_pokemon('pokemon', 'valid')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
# create test db
images3, labels3, table = load_pokemon('pokemon', mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)
# 1. 自定义小型网络
# resnet = keras.Sequential([
# layers.Conv2D(16, 5, 3),
# layers.MaxPool2D(3, 3),
# layers.ReLU(),
# layers.Conv2D(64, 5, 3),
# layers.MaxPool2D(2, 2),
# layers.ReLU(),
# layers.Flatten(),
# layers.Dense(64),
# layers.ReLU(),
# layers.Dense(5)
# ]) # 0.8547
# 2.resnet18训练, 图片数量较小,训练结果不是特别好
# resnet = ResNet(5) # 0.7607
# resnet.build(input_shape=(None, 224, 224, 3))
# resnet.summary()
# 3. VGG19迁移学习,迁移学习利用数据集之间的相似性, 结果远好于其他2种
net = tf.keras.applications.VGG19(weights='imagenet', include_top=False, pooling='max' )
net.trainable = False
resnet = keras.Sequential([
net,
layers.Dense(5)
])
resnet.build(input_shape=(None, 224, 224, 3)) # 0.9316
resnet.summary()
early_stopping = EarlyStopping(
monitor='val_loss',
patience=3,
min_delta=0.001
)
resnet.compile(optimizer=optimizers.Adam(lr=1e-3),
loss=losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
resnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
callbacks=[early_stopping])
resnet.evaluate(db_test)
"""
Tensorflow2 自定义数据集图片完成图片分类任务的更多相关文章
- Android实现自定义带文字和图片的Button
Android实现自定义带文字和图片的Button 在Android开发中经常会需要用到带文字和图片的button,下面来讲解一下常用的实现办法. 一.用系统自带的Button实现 最简单的一种办法就 ...
- 重新想象 Windows 8.1 Store Apps (92) - 其他新特性: CoreDispatcher, 日历, 自定义锁屏系列图片
[源码下载] 重新想象 Windows 8.1 Store Apps (92) - 其他新特性: CoreDispatcher, 日历, 自定义锁屏系列图片 作者:webabcd 介绍重新想象 Win ...
- Inno Setup技巧[界面]自定义安装向导小图片宽度
原文 blog.sina.com.cn/s/blog_5e3cc2f30100cj7e.html 英文版中安装向导右上角小图片的大小为55×55,汉化版中为55×51.如果图片超过规定的宽度将会被压 ...
- 转:【译】Asp.net MVC 利用自定义RouteHandler来防止图片盗链
[译]Asp.net MVC 利用自定义RouteHandler来防止图片盗链 你曾经注意过在你服务器请求日志中多了很多对图片资源的请求吗?这可能是有人在他们的网站中盗链了你的图片所致,这会占用你 ...
- xpadder教程:自定义设置游戏手柄的图片
关于xpadder设置按键的教程,网上已经很多,我就不凑这个热闹了.这里介绍的是如何自定义设置手柄的图片,就是按钮的背景图,如下图所示: 步骤: 1)准备一张背景图 注意:格式必须是24位色的BMP位 ...
- 第十三节:HttpHander扩展及应用(自定义扩展名、图片防盗链)
一. 自定义扩展名 1. 前言 凡是实现了IHttpHandler接口的类均为Handler类,HttpHandler是一个HTTP请求的真正处理中心,在HttpHandler容器中,ASP.NET ...
- torch_13_自定义数据集实战
1.将图片的路径和标签写入csv文件并实现读取 # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0 def load_csv(self,file ...
- Scaled-YOLOv4 快速开始,训练自定义数据集
代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...
- CSS3实现鼠标移动到图片上图片变大
CSS3实现鼠标移动到图片上图片变大(缓慢变大,有过渡效果,放大的过程是有动画过渡的,这个过渡的时间可以自定义 <!DOCTYPE html><html> <head&g ...
随机推荐
- MySQL死锁系列-常见加锁场景分析
在上一篇文章<锁的类型以及加锁原理>主要总结了 MySQL 锁的类型和模式以及基本的加锁原理,今天我们就从原理走向实战,分析常见 SQL 语句的加锁场景.了解了这几种场景,相信小伙伴们也能 ...
- 上传应用至Google Play 后被重新签名,怎么获取最新的签名信息
基本签名信息在Google Play 上都能查看到. 快速解决Google+登录和facebook登录的办法: 不用改包名重新创建应用,不用重新打包,不要删除自己的keystore文件,不要重新创建k ...
- 认证(Authentication)和授权(Authorization)总结
身份认证是验证你的身份,一旦通过验证,即启用授权.你所拥有的身份可以进行哪些操作都是由授权规定.例如,任何银行客户都可以创建一个账户(如用户名),并使用该账户登录该银行的网上服务,但银行的授权政策必须 ...
- Android Studio常见对话框(普通对话框、单选对话框、多选对话框、进度条对话框、消息对话框、自定义对话框)
Android Studio常见对话框(普通对话框.单选对话框.多选对话框.进度条对话框.消息对话框.自定义对话框) 1.普通对话框 2.单选对话框 3.多选对话框 4.进度条对话框 5.消息对话框 ...
- 分布式 ID 的 9 种生成方式
为什么要用分布式ID? 在说分布式ID的具体实现之前,我们来简单分析一下为什么用分布式ID?分布式ID应该满足哪些特征? 什么是分布式ID? 拿MySQL数据库举个栗子: 在我们业务数据量不大的时候, ...
- C#中值类型,引用类型,字符串类型的区别(内存图解)
如果用图片来解释值类型,引用类型和字符串类型(引用类型的一种)的区别的话 值类型: 引用类型: string类型:
- Qcom rampdump解析工具使用
使用如下命令获取qcom工具: ljj@ljj-ThinkCentre-M83:~/git/qcom_tools/ramdump$ git clone git://codeaurora.org/qui ...
- Java实现 LeetCode 581 最短无序连续子数组(从两遍搜索找两个指针)
581. 最短无序连续子数组 给定一个整数数组,你需要寻找一个连续的子数组,如果对这个子数组进行升序排序,那么整个数组都会变为升序排序. 你找到的子数组应是最短的,请输出它的长度. 示例 1: 输入: ...
- Java实现 蓝桥杯VIP 算法训练 入学考试
问题描述 辰辰是个天资聪颖的孩子,他的梦想是成为世界上最伟大的医师.为此,他想拜附近最有威望的医师为师.医师为了判断他的资质,给他出了一个难题.医师把他带到一个到处都是草药的山洞里对他说:" ...
- Java实现寻找最小的k个数
1 问题描述 有n个整数,请找出其中最小的k个数,要求时间复杂度尽可能低. 2 解决方案 2.1 全部排序法 先对这n个整数进行快速排序,在依次输出前k个数. package com.liuzhen. ...