测试本地mnist数据集

图片只用500张,450张做train与50张test,

代码如下:

  

  1. # conding:utf-8
  2. import os
  3. os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
  4. import numpy as np
  5. from sklearn.metrics import accuracy_score
  6. import matplotlib.pyplot as plt
  7. from autokeras.image_supervised import ImageClassifier
  8. from keras.models import load_model
  9. from keras.utils import plot_model
  10. import time
  11.  
  12. # 数据准备
  13. x_train =np.zeros((4500,28,28,1))
  14. x_test =np.zeros((500,28,28,1))
  15. y_train=[]
  16. y_test=[]
  17. start = time.time()
  18. for i in range(0,10):
  19. for j in range(1,501):
  20. if j < 451: #将数据保存到训练数据中
  21. x_train[(j-1)+(i*450),:,:,0]=plt.imread('./data/%d/%d_%d.bmp'%(i,i,j)) #reshape 可以降维也就是矩阵变化
  22. y_train.append(i) #append 是读进来的数据进行存储的意思
  23. else: #保存到预测数据中
  24. x_test[(i*50)+(j-452),:,:,0]=plt.imread('./data/%d/%d_%d.bmp'%(i,i,j))
  25. y_test.append(i)
  26. y_t = np.array(y_test).reshape(-1,1)
  27. print(x_train.shape)
  28. # x_train = np.array(x_train).reshape(450,28,28,1)
  29. y_train = np.array(y_train)
  30.  
  31. # x_test = np.array(x_test).reshape(50,28,28,1)
  32. y_test = np.array(y_test)
  33. print(y_test.shape)
  34.  
  35. if __name__ == '__main__':
  36. model_dir = r'./models/autoTf_model.h5'
  37. model_img = r'./model_img/atuoTf_model.png'
  38.  
  39. # 使用图片识别器
  40. clf = ImageClassifier(verbose=True)
  41. # 给其训练数据和标签,训练的最长时间可以设定,假设为1分钟,autokers会不断找寻最优的网络模型
  42. clf.fit(x_train, y_train, time_limit=5 * 60)
  43. # 找到最优模型后,再最后进行一次训练和验证
  44. clf.final_fit(x_train,y_train,x_test, y_test, retrain=True)
  45. # 给出评估结果
  46. y = clf.evaluate(x_test, y_test)
  47. # 输出我么测试数据
  48. y_pred = clf.predict(x_test)
  49. # accuracy精确度
  50. accuracy = accuracy_score(y_test,y_pred)
  51.  
  52. print("evaluate:", y,'\n','accuracy:',accuracy)
  53. #保存可视化模型
  54. clf.load_searcher().load_best_model().produce_keras_model().save(model_dir)
  55. # 加载可视化模型
  56. autoModel = load_model(model_dir)
  57. # y_predict = autoModel.predict(x_test)
  58. # print(y_predict.shape)
  59. # 可视化模型画图
  60. plot_model(autoModel,to_file=model_img)
  61.  
  62. #计时
  63. end = time.time()
  64. print(start-end)

测试结果:accuracy值为0.93而已,然而我用时5分钟(min),自己建立的网络可以达到 0.948,不过自动搜索网络感觉还行吧,有待提高。毕竟是自动的。能到达这种精度是非常厉害的了。如果给他 更多时间的话估计能上0.98了。下面我们来看一下可视化网络吧。

autoKeras入门的更多相关文章

  1. 深度学习应用系列(三)| autokeras使用入门

    我们在构建自己的神经网络模型时,往往会基于预编译模型上进行迁移学习.但不同的训练数据.不同的场景下,各个模型表现不一,需要投入大量的精力进行调参,耗费相当多的时间才能得到自己满意的模型. 而谷歌近期推 ...

  2. autoKeras Windows 的入门测试

    在测试中分析一下ide的效果,在pycharm中测试的时候老师提示内存溢出,而且跑autoKeras的cnn时确实消耗很大空间.但是同样的电脑,换了vscode进行测试的时候没有问题.我也不知道什么回 ...

  3. Angular2入门系列教程7-HTTP(一)-使用Angular2自带的http进行网络请求

    上一篇:Angular2入门系列教程6-路由(二)-使用多层级路由并在在路由中传递复杂参数 感觉这篇不是很好写,因为涉及到网络请求,如果采用真实的网络请求,这个例子大家拿到手估计还要自己写一个web ...

  4. ABP入门系列(1)——学习Abp框架之实操演练

    作为.Net工地搬砖长工一名,一直致力于挖坑(Bug)填坑(Debug),但技术却不见长进.也曾热情于新技术的学习,憧憬过成为技术大拿.从前端到后端,从bootstrap到javascript,从py ...

  5. Oracle分析函数入门

    一.Oracle分析函数入门 分析函数是什么?分析函数是Oracle专门用于解决复杂报表统计需求的功能强大的函数,它可以在数据中进行分组然后计算基于组的某种统计值,并且每一组的每一行都可以返回一个统计 ...

  6. Angular2入门系列教程6-路由(二)-使用多层级路由并在在路由中传递复杂参数

    上一篇:Angular2入门系列教程5-路由(一)-使用简单的路由并在在路由中传递参数 之前介绍了简单的路由以及传参,这篇文章我们将要学习复杂一些的路由以及传递其他附加参数.一个好的路由系统可以使我们 ...

  7. Angular2入门系列教程5-路由(一)-使用简单的路由并在在路由中传递参数

    上一篇:Angular2入门系列教程-服务 上一篇文章我们将Angular2的数据服务分离出来,学习了Angular2的依赖注入,这篇文章我们将要学习Angualr2的路由 为了编写样式方便,我们这篇 ...

  8. Angular2入门系列教程4-服务

    上一篇文章 Angular2入门系列教程-多个组件,主从关系 在编程中,我们通常会将数据提供单独分离出来,以免在编写程序的过程中反复复制粘贴数据请求的代码 Angular2中提供了依赖注入的概念,使得 ...

  9. wepack+sass+vue 入门教程(三)

    十一.安装sass文件转换为css需要的相关依赖包 npm install --save-dev sass-loader style-loader css-loader loader的作用是辅助web ...

随机推荐

  1. 【jQuery】 选择器

    [jQuery] 选择器 资料: w3school  http://www.w3school.com.cn/jquery/jquery_ref_selectors.asp 1. 标签选择器 : $(& ...

  2. 预装win8的笔记本如何重装win7

    测试电脑联想T440. 开机按F1,然后Enter,进入Bios设置. 先关闭Secure Boot,然后设置为Legacy Boot. 之后才能设置U盘为第一启动盘. 进入老毛桃的PE系统,使用Di ...

  3. Python 常见的字符串操作

    1.strip.lstrip和rstrip 描述: 用于移除字符串左右两边.左边.右边指定的字符(默认为空白符,例如:/n, /r, /t, ' ')或字符序列. 语法: str.strip([cha ...

  4. ByteArrayInputStream/ByteArrayOutputStream 学习

    ByteArrayInputStream: byte[] buff = new byte[1024]; ByteArrayInputStream bAIM = new ByteArrayInputSt ...

  5. 权限管理UML设计草图

    PS:  最近闲来无事,打算整一个权限管理模块.然而UML我只会看不会设计,现在的草图都是边学边做的,现在发出来,希望前辈们指点一二!先拜谢了! 搞开发也有2年多快三年了,我感觉自己基本上还是一个菜鸟 ...

  6. poj1789 Truck History最小生成树

    Truck History Time Limit: 2000MS   Memory Limit: 65536K Total Submissions: 20768   Accepted: 8045 De ...

  7. poj3026(bfs+prim)最小生成树

    The Borg is an immensely powerful race of enhanced humanoids from the delta quadrant of the galaxy. ...

  8. Mybatis学习系列(三)动态SQL

    在mapper配置文件中,有时需要根据查询条件选择不同的SQL语句,或者将一些使用频率高的SQL语句单独配置,在需要使用的地方引用.Mybatis的一个特性:动态SQL,来解决这个问题. mybati ...

  9. 解决Mysql错误Too many connections的方法

    MySQL数据库 Too many connections出现这种错误明显就是 mysql_connect 之后忘记 mysql_close:当大量的connect之后,就会出现Too many co ...

  10. vue-cli配置jquery 以及jquery第三方插件

    只使用jquery: 1.  cnpm install jquery --save 2.   cnpm install @types/jquery --save-dev (不使用ts的不需要安装此声明 ...