将TVM集成到PyTorch上

随着TVM不断展示出对深度学习执行效率的改进,很明显PyTorch将从直接利用编译器堆栈中受益。PyTorch的主要宗旨是提供无缝且强大的集成,而这不会妨碍用户。为此,PyTorch现在具有基于TVM的官方后端torch_tvm

用法很简单:

import
torch_tvm

torch_tvm.enable()

PyTorch将尝试在其JIT编译过程中,将所有可能的运算符转换为已知的Relay运算符。

背景

与许多其他ML框架不同,PyTorch公开了一个渴望执行的编程接口。这种编程风格避免了基于图的元编程,而专注于以Python方式直接控制n维数组(张量)。因此,该框架最初非常适合模型的试验和开发,但不适用于自动性能优化或部署。为了利用优化的编译器技术,PyTorch引入了一些较大的更改来解决此问题。

PyTorch 1.0引入了PyTorch IR,PyTorch专用的中间表示形式,用于类似于Relay的模型。可以通过模型跟踪将PyTorch程序转换为IR,该跟踪记录模型或Python的子集TorchScript的执行。新的TVM后端将PyTorch的IR降低到了Relay,并能够透明地提高PyTorch的性能,而无需用户参与。

整合与结果

为了支持Relay,PyTorch JIT添加了两个功能:自定义转换过程和自定义子图解释器。

当torch_tvm启用时,可以转换到中继PyTorch IR的子图Expr旨意被标记为继电器兼容。由于PyTorch IR并不总是包含形状信息,因此在调用之前,无法以有用的方式编译任何子图。

在用户调用期间,PyTorch JIT运行时将确定输入形状信息,并使用新的Relay C ++构建系统编译先前标记的子图。根据输入形状来缓存编译,以供后续运行。可以在README中找到更多详细信息。

torch_tvm建立了一个连续的基准测试系统,该系统正在监视ResNet18在CPU上的性能。对于各种ResNet型号,TVM的性能都是默认PyTorch
JIT后端的两倍以上。在AWS c5n.4xlarge实例上使用16个线程实现的每秒迭代次数(越大越好)。

这些结果令人鼓舞,该项目将继续致力于,在更多模型上提高CPU推理速度。

未来的工作

现在,PyTorch JIT进行了大量工作来查找其IR的纯功能子集,以馈送到Relay。这避免了将别名和控制流信息映射到中继的需要,但这不是必需的。将更多的PyTorch IR映射到Relay可能会取得性能上的胜利,这是该项目的目标。PyTorch IR在开发过程中正在迅速变化,因此必须谨慎进行。

将做更多的工作来确保PyTorch和TVM代码之间的切换是有效的。这包括统一线程模型,分配器以及减少与将输入复制到TVM相关的开销。

解析

如果已经编写了PyTorch模型,最简单的入门方法就是使用torch.jit.trace以下方法

import
torch_tvm

from
your_model import model, inputs

torch_tvm.enable(opt_level=3)

iters
= 100

warmup
= 10

#
Ensure your model is in eval mode and also turn off gradients.

with
torch.no_grad():

# Use
tuned parameters for better performance.

with
autotvm.apply_history_best("test/autotvm_tuning.log"):

# This is where all the compilation
happens.

trace_tvm = torch.jit.trace(model, inputs)

# Warmup

for _ in range(warmup):

_ =
trace_tvm(*inputs)

# Benchmark

start = time.time()

for _ in range(iters):

_ = trace_tvm(*inputs)

tvm_time = time.time() - start

print("Took {}s to run {}
iters".format(tvm_time, iters))

这段代码大部分来自Benchmarks.py。请注意,用于AVX2 LLVM编译的调整参数位于存储库test/文件夹中。

如果更直接使用Relay,可以通过(隐式)跟踪或TorchScript,直接从PyTorch函数中提取表达式:

def
add(a, b, c):

return a + b + c

#
via tracing

relay_graph
= torch_tvm.to_relay(add, inputs)

@torch.jit.script

def
mul(a, b, c):

return a * b * c

#
via script

relay_graph
= torch_tvm.to_relay(mul, inputs)

将TVM集成到PyTorch上的更多相关文章

  1. 将TVM集成到PyTorch

    将TVM集成到PyTorch 随着TVM不断展示出对深度学习执行效率的改进,很明显PyTorch将从直接利用编译器堆栈中受益.PyTorch的主要宗旨是提供无缝且强大的集成,而这不会妨碍用户.PyTo ...

  2. 如何在TVM上集成Codegen(上)

    如何在TVM上集成Codegen(上) 许多常用的深度学习内核,或者提供DNNL或TensorRT等框架和图形引擎,让用户以某种方式描述他们的模型,从而获得高性能.此外,新兴的深度学习加速器也有自己的 ...

  3. [转载]PyTorch上的contiguous

    [转载]PyTorch上的contiguous 来源:https://zhuanlan.zhihu.com/p/64551412 这篇文章写的非常好,我这里就不复制粘贴了,有兴趣的同学可以去看原文,我 ...

  4. 在Pytorch上使用稀疏矩阵

    在Pytorch上使用稀疏矩阵 最近在写一个NLP的小项目,用到了Pytorch做神经网络模型.但是众所周知NLP的一个特点就是特征矩阵是稀疏矩阵,当时处理稀疏矩阵用的是scipy.sparse,现在 ...

  5. TVM 优化 ARM GPU 上的移动深度学习

    TVM 优化 ARM GPU 上的移动深度学习 随着深度学习的巨大成功,将深度神经网络部署到移动设备的需求正在迅速增长.与桌面平台上所做的类似,在移动设备中使用 GPU 既有利于推理速度,也有利于能源 ...

  6. TVM在ARM GPU上优化移动深度学习

    TVM在ARM GPU上优化移动深度学习 随着深度学习的巨大成功,将深度神经网络部署到移动设备的需求正在迅速增长.与在台式机平台上所做的类似,在移动设备中使用GPU可以提高推理速度和能源效率.但是,大 ...

  7. 【微服务专题之】.Net6下集成消息队列上-RabbitMQ

    ​ 微信公众号:趣编程ACE关注可了解更多的.NET日常实战开发技巧,如需源码 请公众号后台留言 源码;[如果觉得本公众号对您有帮助,欢迎关注] .Net中RabbitMQ的使用 [微服务专题之].N ...

  8. Liferay7 BPM门户开发之45: 集成Activiti文件上传部署流程BPMN模型

    开发文件上传,部署流程模板. 首先,开发jsp页面,deploy.jsp <%@ include file="/init.jsp" %> <h3>${RET ...

  9. iOS支付宝,微信,银联支付集成封装(上)

    一.集成支付宝支付 支付宝集成官方教程https://docs.open.alipay.com/204/105295/ 支付宝集成官方demo https://docs.open.alipay.com ...

随机推荐

  1. 织梦DedeCMS自定义表单限制IP24小时只能提交多少次

    方法1.打开plus/diy.php,找到一下代码, if(!is_array($diyform)) { showmsg('自定义表单不存在', '-1'); exit(); } 然后再在以下代码后面 ...

  2. 缓冲区溢出分析第09课:MS06-040漏洞研究——深入挖掘

    前言 经过前两次的分析,我们已经对Netapi32.dll文件中所包含的漏洞成功地实现了利用.在系统未打补丁之前,这确实是一个非常严重的漏洞,那么打了补丁之后,这个动态链接库是不是就安全了呢?答案是否 ...

  3. hdu3987 最小割边数

    题意:      是让你求最小割之后问最小割的最少边数是多少,因为最小割不是唯一的,所以存在最小边数的问法.思路:      两个方法,一个是先一遍最大流,然后把割边全都改成流量1,其他的全都改成流量 ...

  4. POJ3160强连通+spfa最长路(不错)

    题意:       给你一个有向图,每个点上有一个权值,可正可负,然后给你一些链接关系,让你找到一个起点,从起点开始走,走过的边可以在走,但是拿过权值的点就不能再拿了,问最多能拿到多少权值? 思路: ...

  5. SSRF(服务端请求伪造)漏洞

    目录 SSRF SSRF漏洞的挖掘 SSRF漏洞利用 SSRF漏洞防御 SSRF SSRF(Server-Side Request Forgery,服务器端请求伪造)漏洞,是一种由攻击者构造请求,由服 ...

  6. pandas(10):数据增删改

    目录 一.对索引进行操作 1 操作索引值df.rename() 二.指定数据替换.修改df.replace() 三.特殊值--缺失值处理 四.新增行列 1 直接赋值添加新列 2 df.assign() ...

  7. Redis(附Win10版本 和可视化工具)

    启动服务端 通过win+r,cmd 运行命令行然后输入如下指令: G: cd software cd G:\software\redis-64.3.0.503 redis-server.exe 这样就 ...

  8. vue-router的几种用法

    1.全局路由守卫 router.beforeEach((to, from, next) => { // ... }) 当一个导航触发时,全局前置守卫按照创建顺序调用.守卫是异步解析执行,此时导航 ...

  9. phpstudy2018 开启目录浏览

    废话不多说直接开始 一.打开 vhosts-ini 配置文件 二.加入以下内容  注意填写自己的网站根目录 <Directory "你自己的网站根目录"> Option ...

  10. ES系列(五):获取单条数据get处理过程实现

    前面讲的都是些比较大的东西,即框架层面的东西.今天咱们来个轻松点的,只讲一个点:如题,get单条记录的es查询实现. 1. get语义说明 get是用于搜索单条es的数据,是根据主键id查询数据方式. ...