【小白学PyTorch】15 TF2实现一个简单的服装分类任务
【新闻】:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测、医学图像、时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会。微信:cyx645016617.
参考目录:
0 为什么学TF
之前的15节课的pytorch的学习,应该是让不少朋友对PyTorch有了一个全面而深刻的认识了吧 (如果你认真跑代码了并且认真看文章了的话) 。
大家都会比较Tensorflow2和pytorch之间孰优孰劣,但是我们也并不是非要二者选一,两者都是深度学习的工具,其实我们或多或少应该了解一些比较好。 就好比,PyTorch是冲锋枪,TensorFlow是步枪,在上战场前,我们可以选择带上冲锋枪还是步枪,但是在战场上,可能手中的枪支没有子弹了,你只能在地上随便捡了一把枪。 很多时候,用Pytorch还是Tensorflow的选择权不在自己。
此外,了解了TensorFlow,大家才能更好的理解PyTorch和TF究竟有什么区别。我见过有的大佬是TF和PyTorch一起用在一个项目中,数据读取用PyTorch然后模型用TF构建。
总之,大家有时间有精力的话,顺便学学TF也不亏,更何况TF2.0现在已经优化了很多。本系列预计用3节课来简单的入门一下Tensorflow2.
和PyTorch的第一课一样,我们直接做一个简单的小实战。MNIST手写数字分类,Fashion MNIST时尚服装分类。
1 Tensorflow的安装
安装TensorFlow的方法很简单,就是在控制台执行:
pip install tensorflow --user
这里的--user
是赋予这个命令执行权限的,一般我都会带上。
2 数据集构建
# keras是TF的高级API,用起来更加的方便,一般也是用keras。
import tensorflow as tf
from tensorflow import keras
import numpy as np
导入需要用到的库函数. 正如torchvision.datasets
中一样,keras.datasets
中也封装了一些常用的数据集。
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print('train_images shape:',train_images.shape)
print('train_labels shape:',train_labels.shape)
print('test_images shape:',test_images.shape)
print('test_labels shape:',test_labels.shape)
输出结果是:
训练数据集中有60000个样本,每一个样本和MNIST手写数字大小是一样的,是\(28\times 28\)大小的,然后每一个样本有一个标签,这个标签和MNIST也是一样的,是从0到9,是一个十分类任务。
来看一下这些类别有哪些:
标签 | 类别 | 标签 | 类别 |
---|---|---|---|
0 | T-shirt | 5 | Sandal |
1 | Trouser | 6 | Shirt |
2 | Pullover | 7 | Sneaker |
3 | Dress | 8 | Bag |
4 | Coat | 9 | Ankle boot |
这里学学单词吧:
- T-shirt就是T型的衬衫,就是短袖,我感觉前面没有扣子的那种也叫T-shirt;
- Shirt就是长袖的那种衬衫;
- Trouser是裤子;
- pullover是毛衣,套头毛衣,就是常说的卫衣吧感觉;
- dress连衣裙;
- coat是外套;
- sandal是凉鞋;
- sneaker是运动鞋;
- ankle boot是短靴,是到脚踝的那种靴子;
- 这里补充一个吧,sweater,是毛线衣,运动衫,这个和pullover有些类似,个人感觉主要的区分在于运动系列的可以叫做sweater,其他的毛衣卫衣是pullover。
运动短袖T-shirt+运动卫衣sweater是我秋天去健身房的穿搭。
2 预处理
这里不做图像增强之类的了,上面的数据中,图像像素值是从0到255的,我们要把这些标准化成0到1的范围。
train_images = train_images / 255.0
test_images = test_images / 255.0
3 构建模型
# 模型搭建
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
这就是一个用keras构建简单模型的例子:
keras.layers.Flatten
是把\(28\times 28\)的二维度拉平成一个维度,因为这里是直接用全连接层而不是卷积层进行处理的;- 后面跟上两个全连接层
keras.layers.Dense()
就行了。我们可以发现,这个全连接层的参数和PyTorch是有一些区别的:- PyTorch的全连接层需要一个输入神经元数量和输出数量
torch.nn.Linear(5,10)
,而keras中的Dense是不需要输入参数的keras.layers.Dense(10)
; - keras中的激活层直接封装在了Dense函数里面,所以不需要像PyTorch一样单独写一个
nn.ReLU()
了。
- PyTorch的全连接层需要一个输入神经元数量和输出数量
4 优化器
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
定义优化器和损失函数,在keras中叫做对模型进行编译compile(在C语言中,在运行代码之前都需要对代码进行编译嘛)。损失函数和优化器还有metric衡量指标的设置都在模型的编译函数中设置完成。
上面使用Adam作为优化器,然后损失函数用了交叉熵,然后衡量模型性能的使用了准确率Accuracy。
5 训练与预测
model.fit(train_images, train_labels, epochs=10)
这就是训练过程,相比PyTorch而言,更加的简单简洁,但是不像PyTorch那样灵活。
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('\nTest accuracy:', test_acc)
这个.evaluate
方法是对模型的验证集进行验证的,因为本次任务中并没有对训练数据再划分出验证集,所以这里直接使用测试数据了。
大家应该能理解训练集、验证集和测试集的用途和区别吧,我在第二课讲过这个内容,在此不多加赘述。
predictions = model.predict(test_images)
这个.predict
方法才是用在测试集上,进行未知标签样本的类别推理的。
本次内容到此为止,大家应该对keras和tensorflow有一个直观浅显的认识了。当然tensorflow也有一套类似于PyTorch中的dataset,dataloader的那样自定义的数据集加载器的方法,在后续内容中会深入浅出的学一下。
【小白学PyTorch】15 TF2实现一个简单的服装分类任务的更多相关文章
- 【小白学PyTorch】1 搭建一个超简单的网络
文章目录: 目录 1 任务 2 实现思路 3 实现过程 3.1 引入必要库 3.2 创建训练集 3.3 搭建网络 3.4 设置优化器 3.5 训练网络 3.6 测试 1 任务 首先说下我们要搭建的网络 ...
- 【小白学PyTorch】20 TF2的eager模式与求导
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- 学习用node.js建立一个简单的web服务器
一.建立简单的Web服务器涉及到Node.js的一些基本知识点: 1.请求模块 在Node.js中,系统提供了许多有用的模块(当然你也可以用JavaScript编写自己的模块,以后的章节我们将详细讲解 ...
- 【小白学PyTorch】10 pytorch常见运算详解
参考目录: 目录 1 矩阵与标量 2 哈达玛积 3 矩阵乘法 4 幂与开方 5 对数运算 6 近似值运算 7 剪裁运算 这一课主要是讲解PyTorch中的一些运算,加减乘除这些,当然还有矩阵的乘法这些 ...
- 【小白学PyTorch】21 Keras的API详解(上)卷积、激活、初始化、正则
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑答疑解惑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx6450 ...
- 【小白学PyTorch】18 TF2构建自定义模型
[机器学习炼丹术]的炼丹总群已经快满了,要加入的快联系炼丹兄WX:cyx645016617 参考目录: 目录 1 创建自定义网络层 2 创建一个完整的CNN 2.1 keras.Model vs ke ...
- 【小白学PyTorch】16 TF2读取图片的方法
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.NLP等多个学术交流分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx645016617. 参考 ...
- 【小白学PyTorch】19 TF2模型的存储与载入
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...
- 【小白学PyTorch】5 torchvision预训练模型与数据集全览
文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...
随机推荐
- springboot+themeleaf+bootstrap访问静态资源/无法访问静态资源/图片
在网页HTML上访问静态资源的正确写法例: 1.<img src="../../static/bootstarp/img/2.jpg" th:src="@{ ...
- MySQL添加外键报错 - referencing column 'xx' and referenced column 'xx' in foreign key constraint 'xx' are incompatible
MySQL给两个表添加外键时,报错 翻译意思是:外键约束“xx”中的引用列“xx”和引用列“xx”不兼容 说明两个表关联的列数据类型不一致,比如:varchar 与 int,或者 int无符号 与 i ...
- HTTP系列之跨域资源共享机制(CORS)介绍
前言 本文将继续解析详解HTTP系列1中的请求/ 响应报文的首部字段,今天带来的跨域资源共享(CORS)机制,具体内容包括CORS的原理.流程.实战,希望能给大家带来收获! CORS简介 跨域资源共享 ...
- Spring Cloud系列(一):微服务架构简介
一.微服务概述 1.微服务是什么 微服务架构的核心就是服务的拆分,把传统的单体式应用,根据一定的维度(比如业务)拆分为一个一个的服务,每一个服务都有自身特定的功能,又都能够独立的部署,甚至可以拥有自己 ...
- new Map()详细介绍与对比
说明: Map结构提供了“值—值”的对应,是一种更完善的Hash结构实现.如果你需要“键值对”的数据结构,Map比Object更合适.它类似于对象,也是键值对的集合,但是“键”的范围不限于字符串, ...
- MySQL常用指令,java,php程序员,数据库工程师必备。程序员小冰常用资料整理
MySQL常用指令,java,php程序员,数据库工程师必备.程序员小冰常用资料整理 MySQL常用指令(备查) 最常用的显示命令: 1.显示数据库列表. show databases; 2.显示库中 ...
- playable
探索TimelinePlayableAPI,让Timeline为所欲为 https://blog.csdn.net/qq826364410/article/details/80534892 Playa ...
- android Studio(3.1) 常用快捷键
说 明 快捷键 全部保存 Ctrl + S 最大话/最小化编辑器 Ctrl + Shift + F12 搜索内容(包括代码和菜单) 按两次Shift 查找 Ctrl + F 查找下一个 F3 查找上 ...
- 深入了解Redis【二】对象及数据结构综述
引言 Redis中每个键值对都是由对象组成: 键总是一个字符串对象(string) 值可以是字符串对象(string).列表对象(list).哈希对象(hash).集合对象(set).有序集合对象(z ...
- pyqt 设置QTabWidget标签页不可选
pyqt 设置QTabWidget标签页不可选 for i in range(1,7): self.tabWidget.setTabEnabled(i,False)i-对应标签页的位数