基础_模型迁移_CBIR_augmentation
在之前我们做过这样的研究:5图分类CBIR问题
import numpy as np
from keras.datasets import mnist
import gc
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.applications.vgg16 import VGG16
from keras.optimizers import SGD
from keras.utils.data_utils import get_file
import cv2
import h5py as h5py
import numpy as np
import os
import math
from matplotlib import pyplot as plt
#全局变量
RATIO = 0.2
train_dir = 'D:/dl4cv/datesets/littleCBIR/'
#根据分类总数确定one-hot总类
NUM_DENSE = 5
#训练总数
epochs = 10
def tran_y(y):
y_ohe = np.zeros(NUM_DENSE)
y_ohe[y] = 1
return y_ohe
#根据Ratio获得训练和测试数据集的图片地址和标签
##生成数据集,本例先验3**汽车、4**恐龙、5**大象、6**花、7**马
def get_files(file_dir, ratio):
'''
Args:
file_dir: file directory
Returns:
list of images and labels
'''
image_list = []
label_list = []
for file in os.listdir(file_dir):
if file[0:1]=='3':
image_list.append(file_dir + file)
label_list.append(0)
elif file[0:1]=='4':
image_list.append(file_dir + file)
label_list.append(1)
elif file[0:1]=='5':
image_list.append(file_dir + file)
label_list.append(2)
elif file[0:1]=='6':
image_list.append(file_dir + file)
label_list.append(3)
else:
image_list.append(file_dir + file)
label_list.append(4)
print('数据集导入完毕')
#图片list和标签list
#hstack 水平(按列顺序)把数组给堆叠起来
image_list = np.hstack(image_list)
label_list = np.hstack(label_list)
temp = np.array([image_list, label_list])
temp = temp.transpose()
np.random.shuffle(temp)
all_image_list = temp[:, 0]
all_label_list = temp[:, 1]
n_sample = len(all_label_list)
#根据比率,确定训练和测试数量
n_val = math.ceil(n_sample*ratio) # number of validation samples
n_train = n_sample - n_val # number of trainning samples
tra_images = []
val_images = []
#按照0-n_train为tra_images,后面位val_images的方式来排序
for index in range(n_train):
image = cv2.imread(all_image_list[index])
#灰度,然后缩放
image = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY)
image = cv2.resize(image,(48,48))#到底在这个地方修改,还是在后面修改,需要做具体实验
tra_images.append(image)
tra_labels = all_label_list[:n_train]
tra_labels = [int(float(i)) for i in tra_labels]
for index in range(n_val):
image = cv2.imread(all_image_list[n_train+index])
#灰度,然后缩放
image = cv2.cvtColor(image,cv2.COLOR_RGB2GRAY)
image = cv2.resize(image,(32,32))
val_images.append(image)
val_labels = all_label_list[n_train:]
val_labels = [int(float(i)) for i in val_labels]
return np.array(tra_images),np.array(tra_labels),np.array(val_images),np.array(val_labels)
# colab+VGG要求至少48像素在现有数据集上,已经能够完成不错情况
ishape=48
#(X_train, y_train), (X_test, y_test) = mnist.load_data()
#获得数据集
#X_train, y_train, X_test, y_test = get_files(train_dir, RATIO)
#保持数据
##np.savez("D:\\dl4cv\\datesets\\littleCBIR.npz",X_train=X_train,y_train=y_train,X_test=X_test,y_test=y_test)
#读取数据
path='littleCBIR.npz'
#https://github.com/jsxyhelu/GOCW/raw/master/littleCBIR.npz
path = get_file(path,origin='https://github.com/jsxyhelu/GOCW/raw/master/littleCBIR.npz')
f = np.load(path)
X_train, y_train = f['X_train'], f['y_train']
X_test, y_test = f['X_test'], f['y_test']
X_train = [cv2.cvtColor(cv2.resize(i, (ishape, ishape)), cv2.COLOR_GRAY2BGR) for i in X_train]
X_train = np.concatenate([arr[np.newaxis] for arr in X_train]).astype('float32')
X_train /= 255.0
X_test = [cv2.cvtColor(cv2.resize(i, (ishape, ishape)), cv2.COLOR_GRAY2BGR) for i in X_test]
X_test = np.concatenate([arr[np.newaxis] for arr in X_test]).astype('float32')
X_test /= 255.0
y_train_ohe = np.array([tran_y(y_train[i]) for i in range(len(y_train))])
y_test_ohe = np.array([tran_y(y_test[i]) for i in range(len(y_test))])
y_train_ohe = y_train_ohe.astype('float32')
y_test_ohe = y_test_ohe.astype('float32')
model_vgg = VGG16(include_top = False, weights = 'imagenet', input_shape = (ishape, ishape, 3))
#for i, layer in enumerate(model_vgg.layers):
# if i<20:
for layer in model_vgg.layers:
layer.trainable = False
model = Flatten()(model_vgg.output)
model = Dense(4096, activation='relu', name='fc1')(model)
model = Dense(4096, activation='relu', name='fc2')(model)
model = Dropout(0.5)(model)
model = Dense(NUM_DENSE, activation = 'softmax', name='prediction')(model)
model_vgg_pretrain = Model(model_vgg.input, model, name = 'vgg16_pretrain')
#model_vgg_pretrain.summary()
print("vgg准备完毕\n")
sgd = SGD(lr = 0.05, decay = 1e-5)
model_vgg_pretrain.compile(loss = 'categorical_crossentropy', optimizer = sgd, metrics = ['accuracy'])
print("vgg开始训练\n")
log = model_vgg_pretrain.fit(X_train, y_train_ohe, validation_data = (X_test, y_test_ohe), epochs = epochs, batch_size = 64)
score = model_vgg_pretrain.evaluate(X_test, y_test_ohe, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
plt.figure('acc')
plt.subplot(2, 1, 1)
plt.plot(log.history['acc'],'r--',label='Training Accuracy')
plt.plot(log.history['val_acc'],'r-',label='Validation Accuracy')
plt.legend(loc='best')
plt.xlabel('Epochs')
plt.axis([0, epochs, 0.5, 1])
plt.figure('loss')
plt.subplot(2, 1, 2)
plt.plot(log.history['loss'],'b--',label='Training Loss')
plt.plot(log.history['val_loss'],'b-',label='Validation Loss')
plt.legend(loc='best')
plt.xlabel('Epochs')
plt.axis([0, epochs, 0, 1])
plt.show()
os.system("pause")
log = model_vgg_pretrain.fit_generator(img_generator.flow(X_train,y_train_ohe, batch_size= 128), steps_per_epoch = 400, epochs=10,validation_data=(X_test, y_test_ohe),workers=4)
# Install the PyDrive wrapper & import libraries.
# This only needs to be done once in a notebook.
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Authenticate and create the PyDrive client.
# This only needs to be done once in a notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
# Create & upload a text file.
uploaded = drive.CreateFile()
uploaded.SetContentFile('5type4cbirMODEL.h5')
uploaded.Upload()
print('Uploaded file with ID {}'.format(uploaded.get('id')))

# Install the PyDrive wrapper & import libraries.
# This only needs to be done once per notebook.
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Authenticate and create the PyDrive client.
# This only needs to be done once per notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
#根据文件名进行下载
file_id = '1qjxAm_QiXdSqBmyIoPl3bfnyLNJxwKo9'
downloaded = drive.CreateFile({'id': file_id})
print('Downloaded content "{}"'.format(downloaded.GetContentString()))
附件列表
基础_模型迁移_CBIR_augmentation的更多相关文章
- 使用 Azure PowerShell 将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager
以下步骤演示了如何使用 Azure PowerShell 命令将基础结构即服务 (IaaS) 资源从经典部署模型迁移到 Azure Resource Manager 部署模型. 也可根据需要通过 Az ...
- 老李分享: 并行计算基础&编程模型与工具 1
老李分享: 并行计算基础&编程模型与工具 在当前计算机应用中,对高速并行计算的需求是广泛的,归纳起来,主要有三种类型的应用需求: 计算密集(Computer-Intensive)型应用,如 ...
- 算法基础_递归_求杨辉三角第m行第n个数字
问题描述: 算法基础_递归_求杨辉三角第m行第n个数字(m,n都从0开始) 解题源代码(这里打印出的是杨辉三角某一层的所有数字,没用大数,所以有上限,这里只写基本逻辑,要符合题意的话,把循环去掉就好) ...
- 规划将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager
尽管 Azure 资源管理器提供了许多精彩功能,但请务必计划迁移,以确保一切顺利进行. 花时间进行规划可确保执行迁移活动时不会遇到问题. Note 以下指导的主要参与者为 Azure 客户顾问团队,以 ...
- 有关从经典部署模型迁移到 Azure Resource Manager 部署模型的常见问题
此迁移计划是否影响 Azure 虚拟机上运行的任何现有服务或应用程序? 不可以. VM(经典)是公开上市的完全受支持的服务. 你可以继续使用这些资源来拓展你在 Azure 上的足迹. 如果我近期不打算 ...
- 桥接模式_NAT模式_仅主机模式_模型图.ziw
2017年1月12日, 星期四 桥接模式_NAT模式_仅主机模式_模型图 null
- 使用 Azure CLI 将 IaaS 资源从经典部署模型迁移到 Azure Resource Manager 部署模型
以下步骤演示如何使用 Azure 命令行接口 (CLI) 命令将基础结构即服务 (IaaS) 资源从经典部署模型迁移到 Azure Resource Manager 部署模型. 本文中的操作需要 Az ...
- Flutter实战视频-移动电商-05.Dio基础_引入和简单的Get请求
05.Dio基础_引入和简单的Get请求 博客地址: https://jspang.com/post/FlutterShop.html#toc-4c7 第三方的http请求库叫做Dio https:/ ...
- Flutter实战视频-移动电商-08.Dio基础_伪造请求头获取数据
08.Dio基础_伪造请求头获取数据 上节课代码清楚 重新编写HomePage这个动态组件 开始写请求的方法 请求数据 .但是由于我们没加请求的头 所以没有返回数据 451就是表示请求错错误 创建请求 ...
随机推荐
- bowtie2 Linux安装
目前最新版本为2.3.2,网址为:https://sourceforge.net/projects/bowtie-bio/files/bowtie2/2.3.2 安装分为简单的下载可执行文件和源编译安 ...
- 排名前10的vue前端UI框架框架值得你掌握
参考:https://juejin.im/post/5b34faeef265da59645b188e muse-ui 框架: https://juejin.im/entry/582974eb8ac24 ...
- 4.无监督学习--K-means聚类
K-means方法及其应用 1.K-means聚类算法简介: k-means算法以k为参数,把n个对象分成k个簇,使簇内具有较高的相似度,而簇间的相似度较低.主要处理过程包括: 1.随机选择k个点作为 ...
- 06 str() bytes() 编码转换
x = str() #创建字符串#转换成字符串,字节,编码 m = bytes()#创建字节#转换成字节,字符串,要编程什么编码类型的字节 a = "李露" b1 = bytes( ...
- jQuery选择器--:selected和:checked
:selected 概述 匹配所有选中的option元素 <!DOCTYPE html> <html> <head> <meta charset=" ...
- 20155228 2016-2017-2 《Java程序设计》第1周学习总结
20155228 2016-2017-2 <Java程序设计>第1周学习总结 教材学习内容总结 这部分内容是以教材为基础,根据个人的理解来描述的,有的地方的理解和表述可能不规范甚至不正确, ...
- MySQL 查询表中某字段值重复的数据
MySQL中,查询表(dat_bill_2018_11)中字段(product_id)值重复的记录: ; 说明:先用GROUP BY 对 product_id 进行分组,同时使用COUNT(*)进行统 ...
- 运用kNN算法识别潜在续费商家
背景与目标 Youzan 是一家SAAS公司,服务于数百万商家,帮助互联网时代的生意人私有化顾客资产.拓展互联网客群.提高经营效率.现在,该公司希望能够从商家的交易数据中,挖掘出有强烈续费倾向的商家, ...
- 20165305 Linux安装及学习
一.虚拟机的安装 在根据老师所给的<基于VirtualBox虚拟机安装Ubuntu图文教程>的时候,我发现虚拟化处于被禁用状态,于是我在网上查找了一下解决办法,在我将bios中虚拟化设置为 ...
- tomcat9.0 配置账户
原文见: http://blog.csdn.net/guochunyang/article/details/51820066 tomcat9.0 管理页面如:http://192.168.2.10 ...