本文地址:https://www.cnblogs.com/tujia/p/13862357.html

系列文章:

【0】TensorFlow光速入门-序

【1】TensorFlow光速入门-tensorflow开发基本流程

【2】TensorFlow光速入门-数据预处理(得到数据集)

【3】TensorFlow光速入门-训练及评估

【4】TensorFlow光速入门-保存模型及加载模型并使用

【5】TensorFlow光速入门-图片分类完整代码

【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

【7】TensorFlow光速入门-总结

一、导入需要的包

import tensorflow as tf
from tensorflow import keras
import numpy as np

二、初始化模型并配置神经网络层

model = keras.Sequential([
# 展平数据,输入类型要和数据集保持一致,我这里是100*100的灰图
keras.layers.Flatten(input_shape=(100, 100, 1)),
# 第二层是神经元
keras.layers.Dense(128, activation='relu'),
# 第三层的参数很重要,2表示分两类,如果要分5类就传5,10类就传10
keras.layers.Dense(2, activation='softmax')
])

注:如果是图片分类,这三层网络是固定搭配,需要注意的是,input_shape要和数据集数据保持一致,第三层分几类就传几;其他模型的层选择和配置,我们后面再慢慢了解

三、模型编译

model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

注:同样,图片分类的优化器、损失函数及指标也是固定搭配,其他类型的模型我们后面再慢慢了解

四、训练

model.fit(ds, epochs=100, steps_per_epoch=10)

注1:ds 是上一节准备好的数据集;epochs 代表要训练多少次,steps_per_epoch 代表每一次分几步训练;因为我准备的数据比较少,所以设置的训练100次。数据多的话,不用训练那么多次。

注2:使用 ZipDataset 类型的数据集时,steps_per_epoch 参数为必填,其他情况,根据自己的情况可以不传。

五、评估(评估训练效果)

test_loss, test_acc = model.evaluate(ds, verbose=2, steps=10)

注1:正常情况下,训练要用训练集,评估要用测试集。因为偷懒的原故,这里我都用的同一个数据集。

注2:使用 ZipDataset 类型的数据集时,steps 参数为必填,其他情况,根据自己的情况可以不传。

六、预测

预测即使用的意思,评估通过的模型,可以直接使用了

predictions = model.predict(ds, steps=10)
label = np.argmax(predictions[0])
print(label_names[label])

注:这里批量预测,对整个数据集都进行预测,正式使用的时候,一般只预测一张图片就可以了,下一节会说。

重点 Api :

keras.Sequential        https://tensorflow.google.cn/api_docs/python/tf/keras/Sequential

model.compile            https://tensorflow.google.cn/api_docs/python/tf/keras/Sequential#compile

model.fit                       https://tensorflow.google.cn/api_docs/python/tf/keras/Model#fit

model.evaluate            https://tensorflow.google.cn/api_docs/python/tf/keras/Model#evaluate

model.predict               https://tensorflow.google.cn/api_docs/python/tf/keras/Model#predict

至此,我们的图片分类模型已经训练好了。可以使用了模型来做图片分类预测了。

下一节,让我们来说一下,怎么保存这个训练好的模型。以及如何加载已保存的模型并使用:

【4】TensorFlow光速入门-保存模型及加载模型并使用

本文链接:https://www.cnblogs.com/tujia/p/13862357.html


完。

【3】TensorFlow光速入门-训练及评估的更多相关文章

  1. 【6】TensorFlow光速入门-python模型转换为tfjs模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862365.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  2. 【0】TensorFlow光速入门-序

    本文地址:https://www.cnblogs.com/tujia/p/13863181.html 序言: 对于我这么一个技术渣渣来说,想学习TensorFlow机器学习,实在是太难了: 百度&qu ...

  3. 【1】TensorFlow光速入门-tensorflow开发基本流程

    本文地址:https://www.cnblogs.com/tujia/p/13862339.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  4. 【2】TensorFlow光速入门-数据预处理(得到数据集)

    本文地址:https://www.cnblogs.com/tujia/p/13862351.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  5. 【4】TensorFlow光速入门-保存模型及加载模型并使用

    本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  6. 【5】TensorFlow光速入门-图片分类完整代码

    本文地址:https://www.cnblogs.com/tujia/p/13862364.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...

  7. Tensorflow高速入门2--实现手写数字识别

    Tensorflow高速入门2–实现手写数字识别 环境: 虚拟机ubuntun16.0.4 Tensorflow 版本号:0.12.0(仅使用cpu下) Tensorflow安装见: http://b ...

  8. TensorFlow学习——入门篇

    本文主要通过一个简单的 Demo 介绍 TensorFlow 初级 API 的使用方法,因为自己也是初学者,因此本文的目的主要是引导刚接触 TensorFlow 或者 机器学习的同学,能够从第一步开始 ...

  9. 音频标签化3:igor-8m项目的训练、评估与测试

    上一节介绍了youtube-8m项目,这个项目以youtube-8m dataset(简称8m-dataset)样本集为基础,进行训练.评估与测试.youtube-8m设计用于视频特征样本,但实际也适 ...

随机推荐

  1. Spring Boot学习(一)初识Spring Boot

    Spring Boot 概述 Spring Boot 是所有基于 Spring 开发的项目的起点.Spring Boot 的设计是为了让你尽可能快的跑起来 Spring 应用程序并且尽可能减少你的配置 ...

  2. 一台电脑配置多个GigHub账号

    换了新的公司,原来的公司用SVN(比较老了),自己平时用码云(Gitee),新公司使用GitHub.前天通知我注册GitHub账号,但是并未通知用户名的事情(要求用自己的名字),原来的GitHub账号 ...

  3. Oracle复习(复习精简版v1.0)

    自己没记不住的,超基础Oracle知识,新手可以看一下. 大多数例子是用scott用户中的emp表完成 排序:order by 列名    desc是降序,默认是升序: update 表名 set 列 ...

  4. Vue编写的页面部署到springboot网站项目中出现页面加载不全问题

    问题描述: 在用Vue脚手架 编写出一个页面之后, 部署到后台项目中, 因为做的是一个页面 按理来说 怎么都能够在服务器上运行 , 我也在自己的node环境测试 , 在同学的springboot上运行 ...

  5. Spring Cloud系列(二):Eureka应用详解

    一.注册中心 1.注册中心演变过程 2.注册中心必备功能 ① 服务的上线 ② 服务的下线 ③ 服务的剔除 ④ 服务的查询 ⑤ 注册中心HA ⑥ 注册中心节点数据同步 ⑦ 服务信息的存储,比如mysql ...

  6. Mybatis 插件原理解析

    SqlSessionFactory 是 MyBatis 核心类之一,其重要功能是创建 MyBatis 的核心接口 SqlSession.MyBatis 通过 SqlSessionFactoryBuil ...

  7. 【奇淫巧技】XSS绕过技巧

    XSS记录 1.首先是弹窗函数: alert(1) prompt(1) confirm(1)eval() 2.然后是字符的编码和浏览器的解析机制: 要讲编码绕过,首先我们要理解浏览器的解析过程,浏览器 ...

  8. c++中 预编译头文件PCH

    转载:https://blog.csdn.net/lovemysea/article/details/74858430 一.预编译头文件使用经验: 如果预编译头文件被正确使用时,它确实大大提高我们编程 ...

  9. Codeforces Global Round 11 A~D题解

    A.Avoiding Zero 题目链接:https://codeforces.ml/contest/1427 题目大意:给定一个数组a1,a2...,an,要求找出一个a重排后的数组b1,b2,.. ...

  10. 【Jenkins】远程调用jenkins进行构建方式!

    前提:jenkins支持远程调用(具体设置自行百度)1.在我的个人中心--configure--API TOKEN--如果没有,则添加一个token,并生成,再复制并记录下来2.在你的job上面加上你 ...