数据的读取

import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator class TransferModel(object): def __init__(self):
#标准化和数据增强
self.train_generator = ImageDataGenerator(rescale=1.0/255.0)
self.test_generator = ImageDataGenerator(rescale=1.0/255.0)
#指定训练集数据和测试集数据目录
self.train_dir = "./data/train"
self.test_dir = "./data/test"
self.image_size = (224,224)
self.batch_size = 32 def get_loacl_data(self):
'''
读取本地的图片数据以及类别
:return:
'''
train_gen = self.train_generator.flow_from_directory(self.train_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
test_gen = self.test_generator.flow_from_directory(self.test_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True) return train_gen,test_gen if __name__ == '__main__':
tm = TransferModel()
train_gen,test_gen = tm.get_loacl_data()
print(train_gen)

  迁移学习完整代码

import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.python.keras.applications.vgg16 import VGG16, preprocess_input
import numpy as np class TransferModel(object): def __init__(self): # 定义训练和测试图片的变化方法,标准化以及数据增强
self.train_generator = ImageDataGenerator(rescale=1.0 / 255.0)
self.test_generator = ImageDataGenerator(rescale=1.0 / 255.0) # 指定训练数据和测试数据的目录
self.train_dir = "./data/train"
self.test_dir = "./data/test" # 定义图片训练相关网络参数
self.image_size = (224, 224)
self.batch_size = 32 # 定义迁移学习的基类模型
# 不包含VGG当中3个全连接层的模型加载并且加载了参数
# vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
self.base_model = VGG16(weights='imagenet', include_top=False) self.label_dict = {
'0': '汽车',
'1': '恐龙',
'2': '大象',
'3': '花',
'4': '马'
} def get_local_data(self):
"""
读取本地的图片数据以及类别
:return: 训练数据和测试数据迭代器
"""
# 使用flow_from_derectory
train_gen = self.train_generator.flow_from_directory(self.train_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
test_gen = self.test_generator.flow_from_directory(self.test_dir,
target_size=self.image_size,
batch_size=self.batch_size,
class_mode='binary',
shuffle=True)
return train_gen, test_gen def refine_base_model(self):
"""
微调VGG结构,5blocks后面+全局平均池化(减少迁移学习的参数数量)+两个全连接层
:return:
"""
# 1、获取原notop模型得出
# [?, ?, ?, 512]
x = self.base_model.outputs[0] # 2、在输出后面增加我们结构
# [?, ?, ?, 512]---->[?, 1 * 1 * 512]
x = keras.layers.GlobalAveragePooling2D()(x) # 3、定义新的迁移模型
x = keras.layers.Dense(1024, activation=tf.nn.relu)(x)
y_predict = keras.layers.Dense(5, activation=tf.nn.softmax)(x) # model定义新模型
# VGG 模型的输入, 输出:y_predict
transfer_model = keras.models.Model(inputs=self.base_model.inputs, outputs=y_predict) return transfer_model def freeze_model(self):
"""
冻结VGG模型(5blocks)
冻结VGG的多少,根据你的数据量
:return:
"""
# self.base_model.layers 获取所有层,返回层的列表
for layer in self.base_model.layers:
layer.trainable = False def compile(self, model):
"""
编译模型
:return:
"""
model.compile(optimizer=keras.optimizers.Adam(),
loss=keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
return None def fit_generator(self, model, train_gen, test_gen):
"""
训练模型,model.fit_generator()不是选择model.fit()
:return:
"""
# 每一次迭代准确率记录的h5文件
modelckpt = keras.callbacks.ModelCheckpoint('./ckpt/transfer_{epoch:02d}-{val_acc:.2f}.h5',
monitor='val_acc',
save_weights_only=True,
save_best_only=True,
mode='auto',
period=1) model.fit_generator(train_gen, epochs=3, validation_data=test_gen, callbacks=[modelckpt]) return None def predict(self, model):
"""
预测类别
:return:
""" # 加载模型,transfer_model
model.load_weights("./ckpt/transfer_02-0.93.h5") # 读取图片,处理
image = load_img("./1.jpg", target_size=(224, 224))
image.show()
image = img_to_array(image)
# print(image.shape)
# 四维(224, 224, 3)---》(1, 224, 224, 3)
img = image.reshape([1, image.shape[0], image.shape[1], image.shape[2]])
# print(img)
# model.predict() # 预测结果进行处理
image = preprocess_input(img)
predictions = model.predict(image)
print(predictions)
res = np.argmax(predictions, axis=1)
print("所预测的类别是:",self.label_dict[str(res[0])]) if __name__ == '__main__':
tm = TransferModel()
# 训练
# train_gen, test_gen = tm.get_local_data()
# # print(train_gen)
# # for data in train_gen:
# # print(data[0].shape, data[1].shape)
# # print(tm.base_model.summary())
# model = tm.refine_base_model()
# # print(model)
# tm.freeze_model()
# tm.compile(model)
#
# tm.fit_generator(model, train_gen, test_gen) # 测试
model = tm.refine_base_model() tm.predict(model)

  

TensorFlow keras 迁移学习的更多相关文章

  1. 『TensorFlow』迁移学习

    完全版见github:TransforLearning 零.迁移学习 将一个领域的已经成熟的知识应用到其他的场景中称为迁移学习.用神经网络的角度来表述,就是一层层网络中每个节点的权重从一个训练好的网络 ...

  2. tensorflow实现迁移学习

    此例程出自<TensorFlow实战Google深度学习框架>6.5.2小节 卷积神经网络迁移学习. 数据集来自http://download.tensorflow.org/example ...

  3. 吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  4. ML.NET 示例:图像分类模型训练-首选API(基于原生TensorFlow迁移学习)

    ML.NET 版本 API 类型 状态 应用程序类型 数据类型 场景 机器学习任务 算法 Microsoft.ML 1.5.0 动态API 最新 控制台应用程序和Web应用程序 图片文件 图像分类 基 ...

  5. TensorFlow从1到2(九)迁移学习

    迁移学习基本概念 迁移学习是这两年比较火的一个话题,主要原因是在当前的机器学习中,样本数据的获取是成本最高的一块.而迁移学习可以有效的把原有的学习经验(对于模型就是模型本身及其训练好的权重值)带入到新 ...

  6. 深度学习应用系列(二) | 如何使用keras进行迁移学习,以训练和识别自己的图片集

    本文的keras后台为tensorflow,介绍如何利用预编译的模型进行迁移学习,以训练和识别自己的图片集. 官网 https://keras.io/applications/ 已经介绍了各个基于Im ...

  7. 深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)

    一.迁移学习的概念 什么是迁移学习呢?迁移学习可以由下面的这张图来表示: 这张图最左边表示了迁移学习也就是把已经训练好的模型和权重直接纳入到新的数据集当中进行训练,但是我们只改变之前模型的分类器(全连 ...

  8. 常用深度学习框——Caffe/ TensorFlow / Keras/ PyTorch/MXNet

    常用深度学习框--Caffe/ TensorFlow / Keras/ PyTorch/MXNet 一.概述 近几年来,深度学习的研究和应用的热潮持续高涨,各种开源深度学习框架层出不穷,包括Tenso ...

  9. 用tensorflow迁移学习猫狗分类

    笔者这几天在跟着莫烦学习TensorFlow,正好到迁移学习(至于什么是迁移学习,看这篇),莫烦老师做的是预测猫和老虎尺寸大小的学习.作为一个有为的学生,笔者当然不能再预测猫啊狗啊的大小啦,正好之前正 ...

随机推荐

  1. java面试基础篇-List

    一.ArrayList: 底层为数组实现,线程不安全,查询,修改快,增加删除慢, 数据结构:数组以0为下标依次连续进行存储 数组查询元素:根据下标查询就行 数组增加元素:如果需要给index为10的位 ...

  2. spring 事务源码赏析(一)

    在本系列中,我们会分析:1.spring是如何开启事务的.2.spring是如何在不影响业务代码的情况下织入事务逻辑的.3.spirng事务是如何找到相应的的业务代码的.4.spring事务的传播行为 ...

  3. WEB缓存系统之varnish状态引擎

    前文我们聊了下varnish的VCL配置以及语法特点,怎样去编译加载varnish的vcl配置,以及命令行管理工具varnishadm怎么去连接varnish管理接口进行管理varnish,回顾请参考 ...

  4. MySQL手工注入进阶篇——突破过滤危险字符问题

    当我们在进行手工注入时,有时候会发现咱们构造的危险字符被过滤了,接下来,我就教大家如何解决这个问题.下面是我的实战过程.这里使用的是墨者学院的在线靶场.咱们直接开始. 第一步,判断注入点. 通过测试发 ...

  5. MATLAB实现一个EKF-2D-SLAM(已开源)

    1. SLAM问题定义 同时定位与建图(SLAM)的本质是一个估计问题,它要求移动机器人利用传感器信息实时地对外界环境结构进行估计,并且估算出自己在这个环境中的位置,Smith 和Cheeseman在 ...

  6. STM32F103ZET6串口通信

    1.电平标准 根据通讯使用的电平标准不同,串口通讯可分为TTL标准和RS-232标准,如下表: 从图中可以看到,TTL电平标准使用5V表示高电平,使用0V表示低电平.在R232电平标准中,为了增加串口 ...

  7. vscode 保存自动 格式化eslint 代码

    在网上找了很多种方法,大多都没有成功  一下是一种成功的 配置方法: 1) First, you need to install the ESLint extension in the VS code ...

  8. Halo博客的搭建

    今日主题:搭建一个私人博客 好多朋友和我说,能不能弄一个简单的私人博客啊,我说行吧,今天给你们一份福利啦! 搭建一个私人博客,就可以在自己的电脑上写博客了 Halo Halo 是一款现代化的个人独立博 ...

  9. Java内存可见性volatile

    概述 JMM规范指出,每一个线程都有自己的工作内存(working memory),当变量的值发生变化时,先更新自己的工作内存,然后再拷贝到主存(main memory),这样其他线程就能读取到更新后 ...

  10. github的学习使用以及将自己开发的app传上去。

    主要参考的网址如下: https://www.cnblogs.com/sdcs/p/8270029.html https://www.cnblogs.com/sjhsszl/p/8708471.htm ...