auto-keras 测试保存导入模型
- # coding:utf-8
- import time
- import matplotlib.pyplot as plt
- from autokeras import ImageClassifier
- # 保存和导入模型方法
- from autokeras.utils import pickle_to_file,pickle_from_file
- from keras.engine.saving import load_model
- from keras.utils import plot_model
- from scipy.misc import imresize
- import numpy as np
- import pandas as pd
- import random
- import os
- from sklearn.model_selection import train_test_split
- from sklearn.metrics import accuracy_score
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
- # 导入图片的函数
- def read_img(path):
- nameList = os.listdir(path)
- n = len(nameList)
- # indexImg,columnImg = plt.imread(path+'/'+nameList[0]).shape
- x_train = np.zeros([n,28,28,1]);y_train=[]
- for i in range(n):
- x_train[i,:,:,0] = imresize(plt.imread(path+'/'+nameList[i]),[28,28])
- y_train.append(np.int(nameList[i].split('.')[1]))
- return x_train,y_train
- x_train,y_train = read_img('./dataset')
- y_train = pd.DataFrame(y_train)
- n = len(y_train[y_train.iloc[:,0]==2])
- x_train = np.array(x_train)
- x_wzp = np.random.choice(y_train[y_train.iloc[:,0]==1].index.tolist(),n,replace=False)
- x_train_w = x_train[x_wzp,:].copy()
- x_train_l = x_train[y_train[y_train.iloc[:,0]==2].index.tolist()].copy()
- x_train = np.concatenate([x_train_w,x_train_l],axis=0)
- print(x_train.shape)
- y_train = y_train.iloc[-208:,:].copy()
- # 对两组数据进行洗牌
- index = random.sample(range(len(y_train)),len(y_train))
- index = np.array(index)
- y_train = y_train.iloc[index,:]
- # y_train.plot()
- # plt.show()
- x_train = x_train[index,:,:,:]
- # x_train,x_test,y_train,y_test = train_test_split(x_train,y_train,test_size=0.2)
- # print(x_train.shape,y_train.shape,x_test.shape,y_test.shape)
- # y_test = y_test.values.reshape(-1)
- y_train = y_train.values.reshape(-1)
- # 数据测试
- '''
- print(y_train)
- for i in range(5):
- n = i*20
- img = x_train[n,:,:,:].reshape((28,28))
- print(y_train[n])
- plt.figure()
- plt.imshow(img,cmap='gray')
- plt.xticks([])
- plt.yticks([])
- plt.show()
- '''
- if __name__=='__main__':
- start = time.time()
- # 模型构建
- model = ImageClassifier(verbose=True)
- # 搜索网络模型
- model.fit(x_train,y_train,time_limit=1*60)
- # 验证最优模型
- model.final_fit(x_train,y_train,x_train,y_train,retrain=True)
- # 给出评估结果
- score = model.evaluate(x_train,y_train)
- # 识别结果
- y_predict = model.predict(x_train)
- # y_pred = np.argmax(y_predict,axis=1)
- # 精确度
- accuracy = accuracy_score(y_train,y_predict)
- # 打印出score与accuracy
- print('score:',score,' accuracy:',accuracy)
- print(y_predict,y_train)
- model_dir = r'./trainer/new_auto_learn_Model.h5'
- model_img = r'./trainer/imgModel_ST.png'
- # 保存可视化模型
- # model.load_searcher().load_best_model().produce_keras_model().save(model_dir)
- pickle_to_file(model,model_dir)
- # 加载模型
- # automodel = load_model(model_dir)
- # models = pickle_from_file(model_dir)
- # 输出模型 structure 图
- # plot_model(automodel, to_file=model_img)
- end = time.time()
- print('time:',end-start)
auto-keras 测试保存导入模型的更多相关文章
- Keras读取保存的模型时, 产生错误[ValueError: Unknown activation function:relu6]
Solution: from keras.utils.generic_utils import CustomObjectScope with CustomObjectScope({'relu6': k ...
- Python机器学习笔记:深入理解Keras中序贯模型和函数模型
先从sklearn说起吧,如果学习了sklearn的话,那么学习Keras相对来说比较容易.为什么这样说呢? 我们首先比较一下sklearn的机器学习大致使用流程和Keras的大致使用流程: skl ...
- 使用Keras基于RCNN类模型的卫星/遥感地图图像语义分割
遥感数据集 1. UC Merced Land-Use Data Set 图像像素大小为256*256,总包含21类场景图像,每一类有100张,共2100张. http://weegee.vision ...
- 从3dmax中导入模型到UDK Editor(供个人备忘)
笔记从3dmax中导入模型到UDK Editor 1) 在3dmax中导出 2) 选择FBX格式,保存 3) 在UDK中打开content browser,自己选个pac ...
- keras中保存自定义层和loss
在keras中保存模型有几种方式: (1):使用callbacks,可以保存训练中任意的模型,或选择最好的模型 logdir = './callbacks' if not os.path.exists ...
- Torch 7 load saved model failed, 加载保存的模型失败
Torch 7 load saved model failed, 加载保存的模型失败: 可以尝试下面的解决方案:
- 3dmax导入模型,解决贴图不显示的问题
在3dmax中导入模型数据后,经常出现贴图不显示的情况,效果如下图: 解决方法: 1.怀疑是贴图文件的路径设置有误.快捷键 shift+T打开“资源追踪”界面,重新设置贴图的正确路径(这里如果快捷键无 ...
- TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model
TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...
- thinkphp3.2 控制器导入模型
方法一: public function index(){ $Member = new MemberModel(); $money = $Member->Money(); print_r($mo ...
随机推荐
- android gradle打包常见问题及解决方案
背景: 问题: Q1: UNEXPECTED TOP-LEVEL ERROR: java.lang.OutOfMemoryError: Java heap space at com.android.d ...
- python爬虫:爬取网站视频
python爬取百思不得姐网站视频:http://www.budejie.com/video/ 新建一个py文件,代码如下: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 1 ...
- [C/C++] 友元函数和友元类
A---友元函数: class Data{ public: ... friend int f(int &m);//友元函数 ... } 友元函数是可以直接访问类的私有成员的非成员函数.它是定义 ...
- 【数据库】Sql Server备份还原脚本
USE master RESTORE DATABASE 新建的没有任何数据的数据库名 FROM DISK = 'e:\数据库备份文件.bak' WITH MOVE '原来的逻辑名称' TO 'e:\新 ...
- bzoj3546[ONTAK2010]Life of the Party
题意是裸的二分图关键点(必然在二分图最大匹配中出现的点).比较经典的做法在cyb15年的论文里有: 前几天写jzoj5007的时候脑补了一种基于最小割可行边的做法:考虑用最大流求解二分图匹配.如果某个 ...
- POJ2724:Purifying Machine——题解
http://poj.org/problem?id=2724 描述迈克是奶酪工厂的老板.他有2^N个奶酪,每个奶酪都有一个00 ... 0到11 ... 1的二进制数.为了防止他的奶酪免受病毒侵袭,他 ...
- [学习笔记]FFT——快速傅里叶变换
大力推荐博客: 傅里叶变换(FFT)学习笔记 一.多项式乘法: 我们要明白的是: FFT利用分治,处理多项式乘法,达到O(nlogn)的复杂度.(虽然常数大) FFT=DFT+IDFT DFT: 本质 ...
- 【简单算法】17.字符串转整数(atoi)
题目: 实现 atoi,将字符串转为整数. 在找到第一个非空字符之前,需要移除掉字符串中的空格字符.如果第一个非空字符是正号或负号,选取该符号,并将其与后面尽可能多的连续的数字组合起来,这部分字符即为 ...
- POJ3164:Command Network(有向图的最小生成树)
Command Network Time Limit: 1000MS Memory Limit: 131072K Total Submissions: 20766 Accepted: 5920 ...
- Spring 容器AOP的实现原理——动态代理
参考:http://wiki.jikexueyuan.com/project/ssh-noob-learning/dynamic-proxy.html(from极客学院) 一.介绍 Spring的动态 ...