[Pytorch框架] 4.3 fastai
import fastai
from fastai import *
from fastai.vision import *
import torch
print(torch.__version__)
print(fastai.__version__)
1.0.0
1.0.45
4.3 fastai
4.3.1 fastai介绍
fastai库
fastai将训练一个准确的神经网络变得十分简单。fastai库是基于他的创始人Jeremy Howard 等人开发的 Deep Learning 课程深度学习的研究,为计算机视觉、文本、表格数据、时间序列、协同过滤等常见深度学习应用提供单一、一致界面的深度学习库,可以做到开箱即用。这意味着,如果你已经学会用fastai创建实用的计算机视觉(CV)模型,那你就可以用同样的方法创建自然语言处理(NLP)模型,或是其他模型。
fastai 是目前把易用性和功能都做到了极致的深度学习框架,正如Jeremy所说的:如果一个深度学习框架需要写个教程给你,那它的易用性还不够好。Jeremy 说这话,不是为了夸自己,因为他甚至做了个 MOOC 出来。他自己评价说目前 fastai 的易用性依然不算成功。但在我看来它的门槛极低,你可以很轻易用几句话写个图片分类模型出来,人人都能立即上手,你甚至不需要知道深度学习的理论。
fast.ai课程
上面说到了课程,这里对fast.ai的课程做一个简单的介绍:
课程是由kaggle赛事老司机,连续两年冠军Jeremy Howard 和 Rachel Tomas 联合创办,旨在让更多人能接受深度学习的课程,而且是完全免费!真的是业界良心,这两年深度学习火了起来,国内有培训机构推出收费课程了,教学水平参差不齐。而Jeremy和Rachel推出的课程,恰恰提现了他们的教育理念:Make deep learning uncool ! (让深度学习变得没那么高大上)
Fast.ai给人的印象一直很“接地气”:
- 研究如何快速、可靠地把最先进的深度学习应用于实际问题。
- 提供Fast.ai库,它不仅是让新手快速构建深度学习实现的工具包,也是提供最佳实践的一个强大而便捷的资源。
- 课程内容简洁易懂,以便尽可能多的人从研究成果和软件中收益。
Github
这个官方的Github包含了fastai的所有内容 https://github.com/fastai
4.3.2 fastai实践
MNIST
我们还是以最简单的MNIST来入手看看fastai都为我们做了什么
# 使用fastai内置的MNIST数据集,这里会从fastai的服务器下载
path = untar_data(URLs.MNIST_SAMPLE)
URLs.MNIST_SAMPLE 只提供了3和7 两个分类的数据,这个是用来做演示的,我们正好也做个演示
这里如果下载很慢的话,那么我们可以手动进行操作(建议这样,比程序下载快很多而且稳定)
#进入我们用户目录,创建以下的目录
mkdir -p ~/.fastai/data
cd ~/.fastai/data
# 下载解压
wget -c http://files.fast.ai/data/examples/mnist_sample.tgz
tar -zxvf mnist_sample.tgz
完成后重新执行上面的命令即可
#使用ImageDataBunch从刚才的目录中将读入数据
data = ImageDataBunch.from_folder(path)
# 可以看一下data里面有什么?
data
ImageDataBunch;
Train: LabelList (12396 items)
x: ImageItemList
Image (3, 28, 28),Image (3, 28, 28),Image (3, 28, 28),Image (3, 28, 28),Image (3, 28, 28)
y: CategoryList
7,7,7,7,7
Path: /Users/tant/.fastai/data/mnist_sample;
Valid: LabelList (2038 items)
x: ImageItemList
Image (3, 28, 28),Image (3, 28, 28),Image (3, 28, 28),Image (3, 28, 28),Image (3, 28, 28)
y: CategoryList
7,7,7,7,7
Path: /Users/tant/.fastai/data/mnist_sample;
Test: None
# 使用cnn_learner来创建一个learn,这里模型我们选择resnet18,使用的计量方法是accuracy准确率
learn =create_cnn(data, models.resnet18, metrics=accuracy)
#可以直接使用train_ds来访问数据集里面的数据
img,label = data.train_ds[0]
print(label)
img
7
#或者我们直接使用show_batch方法,连标签都给我们自动生成好了
data.show_batch(rows=3, figsize=(6,6))
这里也是直接下载PyTorch官方提供的resnet18与训练模型
wget -P /Users/tant/.torch/models/ https://download.pytorch.org/models/resnet18-5c106cde.pth
# 使用learn的fit方法就可以进行训练了,训练一遍
learn.fit(1)
Total time: 02:21
epoch train_loss valid_loss accuracy 1 0.130960 0.086702 0.969087
经过上面的训练,你一定会很纳闷:
- 没有告诉模型类别有几个
- 没有指定任务迁移之后接续的几个层次的数量、大小、激活函数
- 没有告诉网络损失函数是什么
我几乎没有提供任何的信息,网络就开始训练了?
对,不需要。
因为 fastai 根据你输入的上述“数据”、“模型结构”和“损失度量”信息,自动帮你把这些闲七杂八的事情默默搞定了。
下面再介绍一些训练的高级用法
#从新生成一个数据集
learn2 =create_cnn(data, models.resnet18, metrics=accuracy,callback_fns=ShowGraph)
这里我们使用fit_one_cycle方法。
fit_one_cycle使用的是一种周期性学习率,从较小的学习率开始学习,缓慢提高至较高的学习率,然后再慢慢下降,周而复始,每个周期的长度略微缩短,在训练的最后部分,允许学习率比之前的最小值降得更低。这不仅可以加速训练,还有助于防止模型落入损失平面的陡峭区域,使模型更倾向于寻找更平坦的极小值,从而缓解过拟合现象。
learn2.fit_one_cycle(1)
Total time: 02:21
epoch train_loss valid_loss accuracy 1 0.167809 0.118627 0.956330
我们使用内置ShowGraph的方法直接打印训练的状态,如果我们需要更详细的状态,可以直接调用一下的方法:
# 学习率的变更
learn2.recorder.plot_lr()
#损失
learn2.recorder.plot_losses()
# 我们也可以使用lr_find()找到损失仍在明显改善最高学习率
learn2.lr_find()
learn2.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
Min numerical gradient: 6.31E-07
4.3.3 fastai文档翻译
由于fastai的中文资料很少而且目前官方只提供英文的文档,所以如果谁有兴趣一起翻译的话可以联系我,如果人数够了的话可以组个团队一起翻译。
[Pytorch框架] 4.3 fastai的更多相关文章
- PyTorch框架+Python 3面向对象编程学习笔记
一.CNN情感分类中的面向对象部分 sparse.py super(Embedding, self).__init__() 表示需要父类初始化,即要运行父类的_init_(),如果没有这个,则要自定义 ...
- 手写数字识别 卷积神经网络 Pytorch框架实现
MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...
- 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)
我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...
- 全面解析Pytorch框架下模型存储,加载以及冻结
最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...
- 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())
学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...
- 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())
在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...
- 【chainer框架】【pytorch框架】
教程: https://bennix.github.io/ https://bennix.github.io/blog/2017/12/14/chain_basic/ https://bennix.g ...
- 搭建 pytorch框架
Pytorch 发布了1.0,对windows的支持效果更好,因此,今天试了一下安装Pytorch.安装速度确实很快,安装也很方便. 进入pytorch的官网,选择对应的版本 根据版本输入相应命令 注 ...
- NLP使用pytorch框架,pytorch安装
pytorch的安装方法及出现问题的解决方案: 安装pytorch,使用pip 安装,在运行代码的时候会报错,但是导包的时候不会报错,因此要采用conda的方式安装 1.找到miniconda的网 ...
- PyToune:一款类Keras的PyTorch框架
PyToune is a Keras-like framework for PyTorch and handles much of the boilerplating code needed to t ...
随机推荐
- 重写org.springframework.beans.BeanUtils的copyProperties方法,能在实体映射的时候把纯数字格式的日期转格式
就是在拷贝的时候加个正则的校验,如果是纯数字的日期 就转成yyyy-MM-dd HH:mm:ss的格式原本想直接用注解在实体转格式,但是那样实体会变成日期格式,所以放弃了,直接重写拷贝的方法比较简单 ...
- eclipse微服务续,Hystrix+Gateway+Config配置管理中心+Bus动态刷新配置
Hystrix延迟和容错库 Gateway微服务网关 Config配置管理中心 Bus动态刷新配置 四.Hystrix延迟和容错库 SpringCloud默认已为Feign整合了hystrix,所以添 ...
- Unity_飞机大战记录总结
记录步骤:win+R→PSR.exe 一.竖屏设置 分辨率设为9:16 二.主控脚本 添加一个空节点,命名"游戏主控" 新建游戏的主控脚本,命名为MyGame.cs,方便管理(即, ...
- 【BUUCTF】ACTF2020 新生赛Include1 write up
查看源代码+抓包都没有发现什么信息,只有这两个东东 <meta charset="utf8"> Can you find out the flag? <meta ...
- 华大单片机HC32L13X软件设计时候要注意的事项
1.系统启动时默认设置主频为内部4MHz; 2.调试超低功耗程序或者把SWD端口复用为GPIO功能都会把芯片的SWD功能关掉,仿真器将会与芯片失去连接,建议在main函数开始后加上1到2秒的延时,仿真 ...
- java 程序运行机制
java 程序运行同时拥有 编译型语言和解释型语言的特点 程序运行流程: 源程序 .java文件 --> Java 编译器--> 字节码 .class 文件 --> 类装饰器 --& ...
- SHELL-反弹shell
什么是shell? 在我们深入了解发送和接收 shell 的复杂性之前,了解 shell 实际上是什么很重要.用最简单的术语来说,shell 就是我们在与命令行环境 (CLI) 交互时使用的工具.换句 ...
- 认识内存和Cache
认识内存和Cache 操作系统学习笔记,如有错误,还望指出. 我们有什么问题 什么是内存? 什么是Cache? 为什么需要Cache? 程序的局部性原理 这是个前置芝士点. 定义: 程序的局部性原理是 ...
- 《爆肝整理》保姆级系列教程-玩转Charles抓包神器教程(11)-Charles如何模拟弱网环境
1.前言 张三:"我写的软件好奇怪啊,在网络好的时候一点问题也没有,但是信号差的时候明显卡顿,看来我只能一直蹲在卫生间.电梯或者地铁(信号差)调bug了". Charles:&qu ...
- 手机号码归属地的自动查询.py(亲测有效)
import requests url = "http://m.ip138.com/sj.asp?mobile=" kv = {'user-agent':'Mozilla/5.0' ...