[深度学习]Keras利用VGG进行迁移学习模板
# -*- coding: UTF-8 -*-
import keras
from keras import Model
from keras.applications import VGG16
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.layers import Flatten, Dense, Dropout, GlobalAveragePooling2D
from keras.models import load_model
from keras.preprocessing import image
from PIL import ImageFile
import numpy as np
import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator
from datetime import datetime
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())
ImageFile.LOAD_TRUNCATED_IMAGES = True
EPOCHS = 30
BATCH_SIZE = 16
DATA_TRAIN_PATH = 'D:/data/train'
def Train():
#-------------准备数据--------------------------
#数据集目录应该是 train/LabelA/1.jpg train/LabelB/1.jpg这样
gen = ImageDataGenerator(rescale=1. / 255)
train_generator = gen.flow_from_directory(DATA_TRAIN_PATH, (224,224)), shuffle=False,
batch_size=BATCH_SIZE, class_mode='categorical')
#-------------加载VGG模型并且添加自己的层----------------------
#这里自己添加的层需要不断调整超参数来提升结果,输出类别更改softmax层即可
#参数说明:inlucde_top:是否包含最上方的Dense层,input_shape:输入的图像大小(width,height,channel)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = base_model.output
x=Flatten()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(1, activation='sigmoid')(x)
predictions = Dense(2, activation='softmax')(x)
model = Model(input=base_model.input, output=predictions)
#-----------控制需要FineTune的层数,不FineTune的就直接冻结
for layer in base_model.layers:
layer.trainable = False
#----------编译,设置优化器,损失函数,性能指标
model.compile(optimizer='rmsprop',
loss='binary_crossentropy', metrics=['accuracy'])
#----------设置tensorboard,用来观察acc和loss的曲线---------------
tbCallBack = TensorBoard(log_dir='./logs/' + TIMESTAMP, # log 目录
histogram_freq=0, # 按照何等频率(epoch)来计算直方图,0为不计算
batch_size=16, # 用多大量的数据计算直方图
write_graph=True, # 是否存储网络结构图
write_grads=True, # 是否可视化梯度直方图
write_images=True, # 是否可视化参数
embeddings_freq=0,
embeddings_layer_names=None,
embeddings_metadata=None)
#---------设置自动保存点,acc最好的时候就会自动保存一次,会覆盖之前的存档---------------
checkpoint = ModelCheckpoint(filepath='HatNewModel.h5', monitor='acc', mode='auto', save_best_only='True')
#----------开始训练---------------------------------------------
model.fit_generator(generator=train_generator,
epochs=EPOCHS,
callbacks=[tbCallBack,checkpoint],
verbose=2
)
#-------------预测单个图像--------------------------------------
def Predict(imgPath):
model = load_model(SAVE_MODEL_NAME)
img = image.load_img(imgPath, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
res = model.predict(x)
print(np.argmax(res, axis=1)[0])
以上运行环境:
Keras2.1.4
Tensorflow-gpu 1.5
CUDA9.0
cudnn7.0
python3.5
[深度学习]Keras利用VGG进行迁移学习模板的更多相关文章
- 深度学习原理与框架-Alexnet(迁移学习代码) 1.sys.argv[1:](控制台输入的参数获取第二个参数开始) 2.tf.split(对数据进行切分操作) 3.tf.concat(对数据进行合并操作) 4.tf.variable_scope(指定w的使用范围) 5.tf.get_variable(构造和获得参数) 6.np.load(加载.npy文件)
1. sys.argv[1:] # 在控制台进行参数的输入时,只使用第二个参数以后的数据 参数说明:控制台的输入:python test.py what, 使用sys.argv[1:],那么将获得w ...
- 吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习
import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...
- 基于深度学习和迁移学习的识花实践——利用 VGG16 的深度网络结构中的五轮卷积网络层和池化层,对每张图片得到一个 4096 维的特征向量,然后我们直接用这个特征向量替代原来的图片,再加若干层全连接的神经网络,对花朵数据集进行训练(属于模型迁移)
基于深度学习和迁移学习的识花实践(转) 深度学习是人工智能领域近年来最火热的话题之一,但是对于个人来说,以往想要玩转深度学习除了要具备高超的编程技巧,还需要有海量的数据和强劲的硬件.不过 Tens ...
- 【深度学习系列】迁移学习Transfer Learning
在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...
- 深度学习趣谈:什么是迁移学习?(附带Tensorflow代码实现)
一.迁移学习的概念 什么是迁移学习呢?迁移学习可以由下面的这张图来表示: 这张图最左边表示了迁移学习也就是把已经训练好的模型和权重直接纳入到新的数据集当中进行训练,但是我们只改变之前模型的分类器(全连 ...
- TensorFlow迁移学习的识别花试验
最近学习了TensorFlow,发现一个模型叫vgg16,然后搭建环境跑了一下,觉得十分神奇,而且准确率十分的高.又上了一节选修课,关于人工智能,老师让做一个关于人工智能的试验,于是觉得vgg16很不 ...
- 迁移学习(Transformer),面试看这些就够了!(附代码)
1. 什么是迁移学习 迁移学习(Transformer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中.迁移学习是通过从已学习的相 ...
- PyTorch专栏(五):迁移学习
专栏目录: 第一章:PyTorch之简介与下载 PyTorch简介 PyTorch环境搭建 第二章:PyTorch之60分钟入门 PyTorch入门 PyTorch自动微分 PyTorch神经网络 P ...
- 迁移学习( Transfer Learning )
在传统的机器学习的框架下,学习的任务就是在给定充分训练数据的基础上来学习一个分类模型:然后利用这个学习到的模型来对测试文档进行分类与预测.然而,我们看到机器学习算法在当前的Web挖掘研究中存在着一个关 ...
- 【迁移学习】2010-A Survey on Transfer Learning
资源:http://www.cse.ust.hk/TL/ 简介: 一个例子: 关于照片的情感分析. 源:比如你之前已经搜集了大量N种类型物品的图片进行了大量的人工标记(label),耗费了巨大的人力物 ...
随机推荐
- sql面试50题------(1-10)
文章目录 1.查询课程编号'01'比课程编号'02'成绩高的所有学生学号 2.查询平均成绩大于60分得学生的学号和平均成绩 3.查询所有学生的学号,姓名,选课数,总成绩 4.查询姓"猴&qu ...
- 齐博x1如何开启自定义标签模板功能
为安全起见,同时也为了避免用户随意添加风格导致默认模板不协调,系统默认关闭了类似V系列的自定义修改模板功能.如下图所示,默认是关闭的 你如果需要启用的话,把下面的代码,参考下图导进去后,就可以增加一个 ...
- 39.BasicAuthentication认证
BasicAuthentication认证介绍 BasicAuthentication使用HTTP基本的认证机制 通过用户名/密码的方式验证,通常用于测试工作,尽量不要线上使用 用户名和密码必须在HT ...
- 【原创】在RT1050 LittleVgl GUI中嵌入中文输入法框架
时隔一年多终于又冒泡了,哎,随着工作越来越忙,自己踏实坐下来写点东西真是越来越费劲,这篇文章也是准备了好久好久才打算发表出来(不瞒大家,东西做完好久了,文章憋了一年了,当真"高产" ...
- 空链接的作用以及<a href="#"></a>和<a href="javascript:;"></a>的区别
空链接的作用以及<a href="#"></a>和<a href="javascript:;"></a>的区别在 ...
- Nginx配置-1
1.绑定nginx到指定cpu [root@nginx conf.d]# vim /apps/nginx/conf/nginx.conf worker_processes 2; worker_cpu_ ...
- Azure DevOps Server 入门实践与安装部署
一,引言 最近一段时间,公司希望在自己的服务器上安装本地版的 Azure DevOps Service(Azure DevOps Server),用于项目内的测试,学习.本着学习的目的,我也就开始学习 ...
- Rock18框架之整体框架介绍
1. 总体框架图 2.框架能解决哪些问题? 问题1: 自动化设备包含龙门架.机械手.伺服.步进等电机.IO控制.定位及纠偏.界面展示等部分.其中硬件(伺服.IO等)是需要更换的,硬件的更换不影响整套系 ...
- jmeter时间戳
时间戳这东西,在jmeter中会经常用到,自己又总是记不住,做个记录. jmeter自带的时间戳函数: ① ${__time(yyyy-MM-dd,)} ,对应时间示例:2022-09-24 ② $ ...
- Codeforces Round #786 (Div. 3) 补题记录
小结: A,B,F 切,C 没写 1ll 对照样例才发现,E,G 对照样例过,D 对照样例+看了其他人代码(主要急于看后面的题,能调出来的但偷懒了. CF1674A Number Transforma ...