[深度学习]Keras利用VGG进行迁移学习模板
# -*- coding: UTF-8 -*-
import keras
from keras import Model
from keras.applications import VGG16
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.layers import Flatten, Dense, Dropout, GlobalAveragePooling2D
from keras.models import load_model
from keras.preprocessing import image
from PIL import ImageFile
import numpy as np
import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator
from datetime import datetime
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
ImageFile.LOAD_TRUNCATED_IMAGES = True
EPOCHS = 30
BATCH_SIZE = 16
DATA_TRAIN_PATH = 'D:/data/train'
def Train():
#-------------准备数据--------------------------
#数据集目录应该是 train/LabelA/1.jpg train/LabelB/1.jpg这样
gen = ImageDataGenerator(rescale=1. / 255)
train_generator = gen.flow_from_directory(DATA_TRAIN_PATH, (224,224)), shuffle=False,
batch_size=BATCH_SIZE, class_mode='categorical')
#-------------加载VGG模型并且添加自己的层----------------------
#这里自己添加的层需要不断调整超参数来提升结果,输出类别更改softmax层即可
#参数说明:inlucde_top:是否包含最上方的Dense层,input_shape:输入的图像大小(width,height,channel)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = base_model.output
x=Flatten()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(1, activation='sigmoid')(x)
predictions = Dense(2, activation='softmax')(x)
model = Model(input=base_model.input, output=predictions)
#-----------控制需要FineTune的层数,不FineTune的就直接冻结
for layer in base_model.layers:
layer.trainable = False
#----------编译,设置优化器,损失函数,性能指标
model.compile(optimizer='rmsprop',
loss='binary_crossentropy', metrics=['accuracy'])
#----------设置tensorboard,用来观察acc和loss的曲线---------------
tbCallBack = TensorBoard(log_dir='./logs/' + TIMESTAMP, # log 目录
histogram_freq=0, # 按照何等频率(epoch)来计算直方图,0为不计算
batch_size=16, # 用多大量的数据计算直方图
write_graph=True, # 是否存储网络结构图
write_grads=True, # 是否可视化梯度直方图
write_images=True, # 是否可视化参数
embeddings_freq=0,
embeddings_layer_names=None,
embeddings_metadata=None)
#---------设置自动保存点,acc最好的时候就会自动保存一次,会覆盖之前的存档---------------
checkpoint = ModelCheckpoint(filepath='HatNewModel.h5', monitor='acc', mode='auto', save_best_only='True')
#----------开始训练---------------------------------------------
model.fit_generator(generator=train_generator,
epochs=EPOCHS,
callbacks=[tbCallBack,checkpoint],
verbose=2
)
#-------------预测单个图像--------------------------------------
def Predict(imgPath):
model = load_model(SAVE_MODEL_NAME)
img = image.load_img(imgPath, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
res = model.predict(x)
print(np.argmax(res, axis=1)[0])
以上运行环境:
Keras2.1.4
Tensorflow-gpu 1.5
CUDA9.0
cudnn7.0
python3.5
[深度学习]Keras利用VGG进行迁移学习模板的更多相关文章
- 深度学习原理与框架-Alexnet(迁移学习代码) 1.sys.argv[1:](控制台输入的参数获取第二个参数开始) 2.tf.split(对数据进行切分操作) 3.tf.concat(对数据进行合并操作) 4.tf.variable_scope(指定w的使用范围) 5.tf.get_variable(构造和获得参数) 6.np.load(加载.npy文件)
1. sys.argv[1:] # 在控制台进行参数的输入时,只使用第二个参数以后的数据 参数说明:控制台的输入:python test.py what, 使用sys.argv[1:],那么将获得w ...
- 吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习
import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...
- 基于深度学习和迁移学习的识花实践——利用 VGG16 的深度网络结构中的五轮卷积网络层和池化层,对每张图片得到一个 4096 维的特征向量,然后我们直接用这个特征向量替代原来的图片,再加若干层全连接的神经网络,对花朵数据集进行训练(属于模型迁移)
基于深度学习和迁移学习的识花实践(转) 深度学习是人工智能领域近年来最火热的话题之一,但是对于个人来说,以往想要玩转深度学习除了要具备高超的编程技巧,还需要有海量的数据和强劲的硬件.不过 Tens ...
- 【深度学习系列】迁移学习Transfer Learning
在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...
- 深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)
一.迁移学习的概念 什么是迁移学习呢?迁移学习可以由下面的这张图来表示: 这张图最左边表示了迁移学习也就是把已经训练好的模型和权重直接纳入到新的数据集当中进行训练,但是我们只改变之前模型的分类器(全连 ...
- TensorFlow迁移学习的识别花试验
最近学习了TensorFlow,发现一个模型叫vgg16,然后搭建环境跑了一下,觉得十分神奇,而且准确率十分的高.又上了一节选修课,关于人工智能,老师让做一个关于人工智能的试验,于是觉得vgg16很不 ...
- 迁移学习(Transformer),面试看这些就够了!(附代码)
1. 什么是迁移学习 迁移学习(Transformer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中.迁移学习是通过从已学习的相 ...
- PyTorch专栏(五):迁移学习
专栏目录: 第一章:PyTorch之简介与下载 PyTorch简介 PyTorch环境搭建 第二章:PyTorch之60分钟入门 PyTorch入门 PyTorch自动微分 PyTorch神经网络 P ...
- 迁移学习( Transfer Learning )
在传统的机器学习的框架下,学习的任务就是在给定充分训练数据的基础上来学习一个分类模型:然后利用这个学习到的模型来对测试文档进行分类与预测.然而,我们看到机器学习算法在当前的Web挖掘研究中存在着一个关 ...
- 【迁移学习】2010-A Survey on Transfer Learning
资源:http://www.cse.ust.hk/TL/ 简介: 一个例子: 关于照片的情感分析. 源:比如你之前已经搜集了大量N种类型物品的图片进行了大量的人工标记(label),耗费了巨大的人力物 ...
随机推荐
- JS---HelloWorld
1.功能效果图 2.代码实现 <!DOCTYPE html> <html> <head> <meta charset="utf-8"> ...
- 如何用Virtualbox搭建一个虚拟机
序言 各位好啊,我是会编程的蜗牛,作为java开发者,我们肯定会接触Linux服务器,除了使用云服务搭建Linux服务器外,我们一般也可以在自己的电脑上安装虚拟机来搭建Linux服务器用于各种功能的验 ...
- centos7 安装RabbitMQ3.6.15 以及各种报错
成功图镇楼 各个版本之间的差异不大,安装前要确保rabbitmq 的版本和 elang的版本一致.预防各种错乱. 注意点:(重要!!重要!!) * 同时安装的时候最好确保rabbitmq和erlang ...
- 我要手撕mybatis源码
传统的JDBC编程中的一般操作: 1.注册数据库驱动类,指定数据库的URL地址.数据库用户名.密码等连接信息 2.通过DriverManager打开数据库连接 3.通过数据库连接创建Statement ...
- 8.websocket slef概念
self代表当前用户客户端与服务端的连接对象,比如两客户端发来了两个连接,我们可以把两个连接放在一起 # 定义全局变量 CONN_List = [] class LiveConsumer(Websoc ...
- CF Round #829 题解 (Div. 2)
F 没看所以摆了 . 看拜月教教主 LHQ 在群里代打恰钱 /bx 目录 A. Technical Support (*800) B. Kevin and Permutation (*800) C. ...
- 变量的复制&传递
变量的复制 变量的类型 可以分为基本数据类型(Null.Undefined.Number.String.Boolean)和引用类型(Funtion.Object.Array) 基本数据类型是按照值访问 ...
- Debian11管理员手册
1 用户与群组数据库 用户清单通常保存在 /etc/passwd 文件内,把哈希编码后的密码保存在 /etc/shadow 文件内.这两个文件都是纯文本档,以简单的格式保存,可以用文本编辑器读取与修改 ...
- java反序列化漏洞cc_link_one
CC-LINK-one 前言 这里也正式进入的java的反序列化漏洞了,简单介绍一下CC是什么借用一些官方的解释:Apache Commons是Apache软件基金会的项目,曾经隶属于Jakarta项 ...
- Seata 1.5.2 源码学习(Client端)
在上一篇中通过阅读Seata服务端的代码,我们了解到TC是如何处理来自客户端的请求的,今天这一篇一起来了解一下客户端是如何处理TC发过来的请求的.要想搞清楚这一点,还得从GlobalTransacti ...