1. # coding:utf-8
  2. import time
  3. import matplotlib.pyplot as plt
  4. from autokeras import ImageClassifier
  5. # 保存和导入模型方法
  6. from autokeras.utils import pickle_to_file,pickle_from_file
  7.  
  8. from keras.engine.saving import load_model
  9. from keras.utils import plot_model
  10. from scipy.misc import imresize
  11. import numpy as np
  12. import pandas as pd
  13. import random
  14. import os
  15. from sklearn.model_selection import train_test_split
  16. from sklearn.metrics import accuracy_score
  17.  
  18. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  19. # 导入图片的函数
  20.  
  21. def read_img(path):
  22. nameList = os.listdir(path)
  23. n = len(nameList)
  24. # indexImg,columnImg = plt.imread(path+'/'+nameList[0]).shape
  25. x_train = np.zeros([n,28,28,1]);y_train=[]
  26. for i in range(n):
  27. x_train[i,:,:,0] = imresize(plt.imread(path+'/'+nameList[i]),[28,28])
  28. y_train.append(np.int(nameList[i].split('.')[1]))
  29. return x_train,y_train
  30.  
  31. x_train,y_train = read_img('./dataset')
  32. y_train = pd.DataFrame(y_train)
  33. n = len(y_train[y_train.iloc[:,0]==2])
  34.  
  35. x_train = np.array(x_train)
  36.  
  37. x_wzp = np.random.choice(y_train[y_train.iloc[:,0]==1].index.tolist(),n,replace=False)
  38.  
  39. x_train_w = x_train[x_wzp,:].copy()
  40. x_train_l = x_train[y_train[y_train.iloc[:,0]==2].index.tolist()].copy()
  41. x_train = np.concatenate([x_train_w,x_train_l],axis=0)
  42.  
  43. print(x_train.shape)
  44.  
  45. y_train = y_train.iloc[-208:,:].copy()
  46.  
  47. # 对两组数据进行洗牌
  48. index = random.sample(range(len(y_train)),len(y_train))
  49. index = np.array(index)
  50. y_train = y_train.iloc[index,:]
  51. # y_train.plot()
  52. # plt.show()
  53. x_train = x_train[index,:,:,:]
  54.  
  55. # x_train,x_test,y_train,y_test = train_test_split(x_train,y_train,test_size=0.2)
  56. # print(x_train.shape,y_train.shape,x_test.shape,y_test.shape)
  57. # y_test = y_test.values.reshape(-1)
  58. y_train = y_train.values.reshape(-1)
  59.  
  60. # 数据测试
  61.  
  62. '''
  63. print(y_train)
  64. for i in range(5):
  65. n = i*20
  66. img = x_train[n,:,:,:].reshape((28,28))
  67. print(y_train[n])
  68. plt.figure()
  69. plt.imshow(img,cmap='gray')
  70. plt.xticks([])
  71. plt.yticks([])
  72. plt.show()
  73. '''
  74.  
  75. if __name__=='__main__':
  76. start = time.time()
  77. # 模型构建
  78. model = ImageClassifier(verbose=True)
  79. # 搜索网络模型
  80. model.fit(x_train,y_train,time_limit=1*60)
  81. # 验证最优模型
  82. model.final_fit(x_train,y_train,x_train,y_train,retrain=True)
  83. # 给出评估结果
  84. score = model.evaluate(x_train,y_train)
  85. # 识别结果
  86. y_predict = model.predict(x_train)
  87. # y_pred = np.argmax(y_predict,axis=1)
  88. # 精确度
  89. accuracy = accuracy_score(y_train,y_predict)
  90. # 打印出score与accuracy
  91. print('score:',score,' accuracy:',accuracy)
  92. print(y_predict,y_train)
  93. model_dir = r'./trainer/new_auto_learn_Model.h5'
  94. model_img = r'./trainer/imgModel_ST.png'
  95.  
  96. # 保存可视化模型
  97. # model.load_searcher().load_best_model().produce_keras_model().save(model_dir)
  98. pickle_to_file(model,model_dir)
  99. # 加载模型
  100. # automodel = load_model(model_dir)
  101. # models = pickle_from_file(model_dir)
  102. # 输出模型 structure 图
  103. # plot_model(automodel, to_file=model_img)
  104.  
  105. end = time.time()
  106. print('time:',end-start)

  

auto-keras 测试保存导入模型的更多相关文章

  1. Keras读取保存的模型时, 产生错误[ValueError: Unknown activation function:relu6]

    Solution: from keras.utils.generic_utils import CustomObjectScope with CustomObjectScope({'relu6': k ...

  2. Python机器学习笔记:深入理解Keras中序贯模型和函数模型

     先从sklearn说起吧,如果学习了sklearn的话,那么学习Keras相对来说比较容易.为什么这样说呢? 我们首先比较一下sklearn的机器学习大致使用流程和Keras的大致使用流程: skl ...

  3. 使用Keras基于RCNN类模型的卫星/遥感地图图像语义分割

    遥感数据集 1. UC Merced Land-Use Data Set 图像像素大小为256*256,总包含21类场景图像,每一类有100张,共2100张. http://weegee.vision ...

  4. 从3dmax中导入模型到UDK Editor(供个人备忘)

    笔记从3dmax中导入模型到UDK Editor 1)      在3dmax中导出 2)      选择FBX格式,保存 3)      在UDK中打开content browser,自己选个pac ...

  5. keras中保存自定义层和loss

    在keras中保存模型有几种方式: (1):使用callbacks,可以保存训练中任意的模型,或选择最好的模型 logdir = './callbacks' if not os.path.exists ...

  6. Torch 7 load saved model failed, 加载保存的模型失败

    Torch 7 load saved model failed, 加载保存的模型失败: 可以尝试下面的解决方案:  

  7. 3dmax导入模型,解决贴图不显示的问题

    在3dmax中导入模型数据后,经常出现贴图不显示的情况,效果如下图: 解决方法: 1.怀疑是贴图文件的路径设置有误.快捷键 shift+T打开“资源追踪”界面,重新设置贴图的正确路径(这里如果快捷键无 ...

  8. TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

      TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...

  9. thinkphp3.2 控制器导入模型

    方法一: public function index(){ $Member = new MemberModel(); $money = $Member->Money(); print_r($mo ...

随机推荐

  1. android gradle打包常见问题及解决方案

    背景: 问题: Q1: UNEXPECTED TOP-LEVEL ERROR: java.lang.OutOfMemoryError: Java heap space at com.android.d ...

  2. python爬虫:爬取网站视频

    python爬取百思不得姐网站视频:http://www.budejie.com/video/ 新建一个py文件,代码如下: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 1 ...

  3. [C/C++] 友元函数和友元类

    A---友元函数: class Data{ public: ... friend int f(int &m);//友元函数 ... } 友元函数是可以直接访问类的私有成员的非成员函数.它是定义 ...

  4. 【数据库】Sql Server备份还原脚本

    USE master RESTORE DATABASE 新建的没有任何数据的数据库名 FROM DISK = 'e:\数据库备份文件.bak' WITH MOVE '原来的逻辑名称' TO 'e:\新 ...

  5. bzoj3546[ONTAK2010]Life of the Party

    题意是裸的二分图关键点(必然在二分图最大匹配中出现的点).比较经典的做法在cyb15年的论文里有: 前几天写jzoj5007的时候脑补了一种基于最小割可行边的做法:考虑用最大流求解二分图匹配.如果某个 ...

  6. POJ2724:Purifying Machine——题解

    http://poj.org/problem?id=2724 描述迈克是奶酪工厂的老板.他有2^N个奶酪,每个奶酪都有一个00 ... 0到11 ... 1的二进制数.为了防止他的奶酪免受病毒侵袭,他 ...

  7. [学习笔记]FFT——快速傅里叶变换

    大力推荐博客: 傅里叶变换(FFT)学习笔记 一.多项式乘法: 我们要明白的是: FFT利用分治,处理多项式乘法,达到O(nlogn)的复杂度.(虽然常数大) FFT=DFT+IDFT DFT: 本质 ...

  8. 【简单算法】17.字符串转整数(atoi)

    题目: 实现 atoi,将字符串转为整数. 在找到第一个非空字符之前,需要移除掉字符串中的空格字符.如果第一个非空字符是正号或负号,选取该符号,并将其与后面尽可能多的连续的数字组合起来,这部分字符即为 ...

  9. POJ3164:Command Network(有向图的最小生成树)

    Command Network Time Limit: 1000MS   Memory Limit: 131072K Total Submissions: 20766   Accepted: 5920 ...

  10. Spring 容器AOP的实现原理——动态代理

    参考:http://wiki.jikexueyuan.com/project/ssh-noob-learning/dynamic-proxy.html(from极客学院) 一.介绍 Spring的动态 ...