1、准备数据

cifar2数据集为cifar10数据集的子集,只包括前两种类别airplane和automobile。

训练集有airplane和automobile图片各5000张,测试集有airplane和automobile图片各1000张。

cifar2任务的目标是训练一个模型来对飞机airplane和机动车automobile两种图片进行分类。

我们准备的Cifar2数据集的文件结构如下所示。

在tensorflow中准备图片数据的常用方案有两种,第一种是使用tf.keras中的ImageDataGenerator工具构建图片数据生成器。

第二种是使用tf.data.Dataset搭配tf.image中的一些图片处理方法构建数据管道。

第一种方法更为简单,其使用范例可以参考以下文章。

https://zhuanlan.zhihu.com/p/67466552

第二种方法是TensorFlow的原生方法,更加灵活,使用得当的话也可以获得更好的性能。

我们此处介绍第二种方法。

import tensorflow as tf
from tensorflow.keras import datasets,layers,models BATCH_SIZE = 100 def load_image(img_path,size = (32,32)):
label = tf.constant(1,tf.int8) if tf.strings.regex_full_match(img_path,".*/automobile/.*") \
else tf.constant(0,tf.int8)
img = tf.io.read_file(img_path)
img = tf.image.decode_jpeg(img) #注意此处为jpeg格式
img = tf.image.resize(img,size)/255.0
return(img,label) # 使用并行化预处理num_parallel_calls 和预存数据prefetch来提升性能
ds_train = tf.data.Dataset.list_files("./data/cifar2/train/*/*.jpg") \
.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
.shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
.prefetch(tf.data.experimental.AUTOTUNE) ds_test = tf.data.Dataset.list_files("./data/cifar2/test/*/*.jpg") \
.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
.batch(BATCH_SIZE) \
.prefetch(tf.data.experimental.AUTOTUNE)

for x,y in ds_train.take(1):
print(x.shape,y.shape)

(100, 32, 32, 3) (100,)

2、定义模型

使用Keras接口有以下3种方式构建模型:使用Sequential按层顺序构建模型,使用函数式API构建任意结构模型,继承Model基类构建自定义模型。

此处选择使用函数式API构建模型。

tf.keras.backend.clear_session() #清空会话

inputs = layers.Input(shape=(32,32,3))
x = layers.Conv2D(32,kernel_size=(3,3))(inputs)
x = layers.MaxPool2D()(x)
x = layers.Conv2D(64,kernel_size=(5,5))(x)
x = layers.MaxPool2D()(x)
x = layers.Dropout(rate=0.1)(x)
x = layers.Flatten()(x)
x = layers.Dense(32,activation='relu')(x)
outputs = layers.Dense(1,activation = 'sigmoid')(x) model = models.Model(inputs = inputs,outputs = outputs) model.summary()

3、训练模型

训练模型通常有3种方法,内置fit方法,内置train_on_batch方法,以及自定义训练循环。此处我们选择最常用也最简单的内置fit方法。

import datetime

logdir = "./data/keras_model/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1) model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss=tf.keras.losses.binary_crossentropy,
metrics=["accuracy"]
) history = model.fit(ds_train,epochs= 10,validation_data=ds_test,
callbacks = [tensorboard_callback],workers = 4)
Epoch 1/10
100/100 [==============================] - 2205s 22s/step - loss: 0.4632 - accuracy: 0.7786 - val_loss: 0.3375 - val_accuracy: 0.8620
Epoch 2/10
100/100 [==============================] - 11s 110ms/step - loss: 0.3346 - accuracy: 0.8565 - val_loss: 0.2617 - val_accuracy: 0.8965
Epoch 3/10
100/100 [==============================] - 11s 111ms/step - loss: 0.2687 - accuracy: 0.8883 - val_loss: 0.2183 - val_accuracy: 0.9165
Epoch 4/10
100/100 [==============================] - 11s 110ms/step - loss: 0.2171 - accuracy: 0.9128 - val_loss: 0.1811 - val_accuracy: 0.9280
Epoch 5/10
100/100 [==============================] - 11s 114ms/step - loss: 0.1860 - accuracy: 0.9268 - val_loss: 0.1798 - val_accuracy: 0.9265
Epoch 6/10
100/100 [==============================] - 11s 112ms/step - loss: 0.1646 - accuracy: 0.9358 - val_loss: 0.1818 - val_accuracy: 0.9260
Epoch 7/10
100/100 [==============================] - 11s 113ms/step - loss: 0.1443 - accuracy: 0.9426 - val_loss: 0.1740 - val_accuracy: 0.9290
Epoch 8/10
100/100 [==============================] - 11s 113ms/step - loss: 0.1301 - accuracy: 0.9469 - val_loss: 0.1635 - val_accuracy: 0.9325
Epoch 9/10
100/100 [==============================] - 11s 112ms/step - loss: 0.1096 - accuracy: 0.9585 - val_loss: 0.1758 - val_accuracy: 0.9315
Epoch 10/10
100/100 [==============================] - 11s 113ms/step - loss: 0.0961 - accuracy: 0.9628 - val_loss: 0.1595 - val_accuracy: 0.9415

4、评估模型

# %load_ext tensorboard
# %tensorboard --logdir ./data/keras_model
from tensorboard import notebook
notebook.list()
# 在tensorboard中查看模型
notebook.start("--logdir ./data/keras_model")

或者我们自己绘图:首先我们构造数据

import pandas as pd
dfhistory = pd.DataFrame(history.history)
dfhistory.index = range(1,len(dfhistory) + 1)
dfhistory.index.name = 'epoch'
dfhistory

然后绘制:

%matplotlib inline
%config InlineBackend.figure_format = 'svg' import matplotlib.pyplot as plt def plot_metric(history, metric):
train_metrics = history.history[metric]
val_metrics = history.history['val_'+metric]
epochs = range(1, len(train_metrics) + 1)
plt.plot(epochs, train_metrics, 'bo--')
plt.plot(epochs, val_metrics, 'ro-')
plt.title('Training and validation '+ metric)
plt.xlabel("Epochs")
plt.ylabel(metric)
plt.legend(["train_"+metric, 'val_'+metric])
plt.show()
plot_metric(history,"loss")
plot_metric(history,"accuracy")

评估模型:

# 可以使用evaluate对数据进行评估
val_loss,val_accuracy = model.evaluate(ds_test,workers=4)
print(val_loss,val_accuracy)

20/20 [==============================] - 2s 80ms/step - loss: 0.1595 - accuracy: 0.9415

0.15954092144966125 0.9415000081062317

5、使用模型

可以使用model.predict(ds_test)进行预测。

也可以使用model.predict_on_batch(x_test)对一个批量进行预测。

model.predict(ds_test)
array([[1.1052408e-01],
[3.4282297e-02],
[2.7046111e-04],
...,
[2.7544077e-03],
[3.4654222e-04],
[9.9993896e-01]], dtype=float32)
for x,y in ds_test.take(1):
print(model.predict_on_batch(x[0:20]))
[[9.8728174e-01]
[2.0267103e-02]
[9.0806475e-03]
[9.9996555e-01]
[4.5376007e-02]
[1.2818890e-03]
[1.8698535e-03]
[2.2900696e-03]
[8.6169255e-01]
[6.2768459e-06]
[1.2383183e-02]
[4.3949869e-02]
[7.9778886e-01]
[9.9822074e-01]
[9.9993134e-01]
[8.6685091e-02]
[3.7480664e-02]
[9.9652690e-01]
[9.2210865e-01]
[1.6160560e-03]]

6、保存模型

推荐使用TensorFlow原生方式保存模型。

# 保存权重,该方式仅仅保存权重张量
model.save_weights('./data/tf_model_weights.ckpt',save_format = "tf")
# 保存模型结构与模型参数到文件,该方式保存的模型具有跨平台性便于部署 model.save('./data/tf_model_savedmodel', save_format="tf")
print('export saved model.') model_loaded = tf.keras.models.load_model('./data/tf_model_savedmodel')
model_loaded.evaluate(ds_test)

参考:

开源电子书地址:https://lyhue1991.github.io/eat_tensorflow2_in_30_days/

GitHub 项目地址:https://github.com/lyhue1991/eat_tensorflow2_in_30_days

【tensorflow2.0】处理图片数据-cifar2分类的更多相关文章

  1. 【tensorflow2.0】数据管道dataset

    如果需要训练的数据大小不大,例如不到1G,那么可以直接全部读入内存中进行训练,这样一般效率最高. 但如果需要训练的数据很大,例如超过10G,无法一次载入内存,那么通常需要在训练的过程中分批逐渐读入. ...

  2. colab上基于tensorflow2.0的BERT中文多分类

    bert模型在tensorflow1.x版本时,也是先发布的命令行版本,随后又发布了bert-tensorflow包,本质上就是把相关bert实现封装起来了. tensorflow2.0刚刚在2019 ...

  3. 【tensorflow2.0】处理结构化数据-titanic生存预测

    1.准备数据 import numpy as np import pandas as pd import matplotlib.pyplot as plt import tensorflow as t ...

  4. 【tensorflow2.0】处理时间序列数据

    国内的新冠肺炎疫情从发现至今已经持续3个多月了,这场起源于吃野味的灾难给大家的生活造成了诸多方面的影响. 有的同学是收入上的,有的同学是感情上的,有的同学是心理上的,还有的同学是体重上的. 那么国内的 ...

  5. [TensorFlow2.0]-手写神经网络实现鸢尾花分类

    本人人工智能初学者,现在在学习TensorFlow2.0,对一些学习内容做一下笔记.笔记中,有些内容理解可能较为肤浅.有偏差等,各位在阅读时如有发现问题,请评论或者邮箱(右侧边栏有邮箱地址)提醒. 若 ...

  6. Google工程师亲授 Tensorflow2.0-入门到进阶

    第1章 Tensorfow简介与环境搭建 本门课程的入门章节,简要介绍了tensorflow是什么,详细介绍了Tensorflow历史版本变迁以及tensorflow的架构和强大特性.并在Tensor ...

  7. TensorFlow2.0(11):tf.keras建模三部曲

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  8. 一文上手Tensorflow2.0之tf.keras(三)

    系列文章目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU ...

  9. tensorflow2.0学习笔记

    今天我们开始学习tensorflow2.0,用一种简单和循循渐进的方式,带领大家亲身体验深度学习.学习的目录如下图所示: 1.简单的神经网络学习过程 1.1张量生成 1.2常用函数 1.3鸢尾花数据读 ...

随机推荐

  1. 【线上测试之后的应用】基于MySQL+MHA+Haproxy构建高可用负载均衡数据库集群(详解)

    这里我们先介绍一下MHA是什么,其次就是它的应用与测试,同时为了大家呈现了数据备份案例,最后总结了使用情况以及注意事项和解决办法 一.MHA 概述 MHA(Master High Availabili ...

  2. sf-git机制

    为什么要专门写一篇关于sf科技公司的GIT管理机制呢?因为本周经历了两天的学习和考试,刚开始没在意,因为之前公司也用的GIT,所以没怎么看视频,就看了文档,练习考试时候才发现并非以前的那种git流程, ...

  3. 分布式图数据库 Nebula Graph 的 Index 实践

    导读 索引是数据库系统中不可或缺的一个功能,数据库索引好比是书的目录,能加快数据库的查询速度,其实质是数据库管理系统中一个排序的数据结构.不同的数据库系统有不同的排序结构,目前常见的索引实现类型如 B ...

  4. 智慧港口——基于二三维一体化GIS的港口可视化监管平台

    “智慧港口”是以现代化基础设施设备为基础,以云计算.大数据.物联网.移动互联网.智能控制等新一代信息技术与港口运输业务深度融合为核心,以港口运输组织服务创新为动力,以完善的体制机制.法律法规.标准规范 ...

  5. TP5使用Redis处理电商秒杀

    本篇文章介绍了ThinkPHP使用Redis实现电商秒杀的处理方法,具有一定的参考价值,希望对学习ThinkPHP的朋友有帮助! TP5使用Redis处理电商秒杀 1.首先在TP5中创建抢购活动所需要 ...

  6. DIY 作品 及 维修 不定时更新

    手机电池DIY充电宝 2块,优质手机电池加一个升压ic ,焊上一个 usb 母头.比买的强多了. 还能调压,最高调到24V 可以带白光焊台. 更换手机 SIM/SD 卡二合一 贴上高温胶带,吹下来. ...

  7. [Python] iupdatable包:日志模块使用介绍

    一.说明 日志模块是对 logging 模块的单例封装 特点: 可同时向控制台和文件输出日志,并可选择关闭其中一种方式的输出: 集成colorlog,实现根据日志等级不同,控制台输出日志颜色不同: 灵 ...

  8. webpack的require.context()实现路由“去中心化”管理

    最近在开发一个大型vue单页面应用的时候,项目最初是将所有的路由写在一个router.js的文件里. const router = new Router({ mode: "history&q ...

  9. vs2017 dlib19.3 opencv3.41 C++ 环境配置 人脸特征点识别

    身为一个.net程序员经过两天的采坑终于把人脸特征检测的项目跑通了,然后本文将以dlib项目中人脸特征检测工程为例,讲解dlib与opencv 在vs2017 C++ 项目中的编译与运行路径配置. 1 ...

  10. SOFARPC模式下的Consul注册中心

    Consul大家不陌生,就是和Zookeeper.Nacos一伙的,能够作为微服务基础架构的注册中心,算是比较成熟的组件,和Springcloud集成顺滑, 考虑到Eureka已经停止更新,所以有必要 ...