TVM 源码阅读PASS — VectorizeLoop
本文地址:https://www.cnblogs.com/wanger-sjtu/p/17501119.html
VectorizeLoop这个PASS就是对标记为ForKind::kVectorized
的For
循环做向量化处理,并对For循环中的语句涉及到的变量,替换为Ramp
,以便于在Codegen的过程中生成相关的向量化运算的指令。
VectorizeLoop这个PASS的入口函数如下,只有在打开enable_vectorize=true
的情况下载才会被启用,否则VectorizeSkipper
会把ForKind::kVectorized
的For
循环替换为普通循环。
Pass VectorizeLoop(bool enable_vectorize) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
if (enable_vectorize) {
n->body = LoopVectorizer()(std::move(n->body));
} else {
n->body = VectorizeSkipper()(std::move(n->body));
}
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {});
}
下面就以UT中的几个例子,介绍一下源码实现。
vectorize_loop
dtype = "int64"
n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, n) as i:
with ib.for_range(0, 4, kind="vectorize") as j:
A[i*4+j] += tvm.tir.const(1, A.dtype)
stmt = ib.get()
assert isinstance(stmt.body, tvm.tir.For)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
上面的这个代码完成的是,向量加法,长度为4n的向量A,对每个元素+1。
# before
for (i, 0, n) {
vectorized (j, 0, 4) {
A[((i*4) + j)] = (A[((i*4) + j)] + 1f)
}
}
# after
for (i, 0, n) {
A[ramp((i*4), 1, 4)] = (A[ramp((i*4), 1, 4)] + x4(1f))
}
可以看到在经过VectorizeLoop
的PASS以后,内层的循环消掉了,替换成为了一个Ramp的向量指令,这个在CPU中会被替换为SIMD指令(neon,AVX等)
PASS流程
在向量化的处理的PASS中是在LoopVectorizer中处理的,处理For循环部分。
class LoopVectorizer : public StmtMutator {
public:
Stmt VisitStmt_(const ForNode* op) final {
if (op->kind == ForKind::kVectorized) {
ICHECK(is_zero(op->min));
auto* extent_as_int = op->extent.as<IntImmNode>();
if (!extent_as_int || extent_as_int->value < 1) {
LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
}
return Vectorizer(op->loop_var, static_cast<int>(extent_as_int->value))(op->body);
} else {
return StmtMutator::VisitStmt_(op);
}
}
};
当遇到需要向量化的节点时,首先记录循环变量和范围,这个在后续替换相应的Load和Store操作为Ramp时用到。然后就到了Vectorizer部分,遍历For循环体,修改相应的stmt。
Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp(0, 1, var_lanes);
}
在Vectorizer中对不同的PrimExpr
、Stmt
做了重载。这里不逐一介绍,就以上面的向量加计算,介绍一下用到的函数以及流程。
首先看一下这里的上面sch的For的循环内的计算逻辑:
A[((i*4) + j)] = (A[((i*4) + j)] + 1f)
因为TVM中,Stmt的表达可以视为一个DSL的语言,访问的时候也是按照深度优先的策略遍历的AST,这里把上面的计算过程简单表示为一个AST的语法树,然后再分析一下流程中调用的各个函数是如何处理的。
从上面的AST的示意图可以看出来,对于上面的sch,依次访问了BufferStoreNode
、Add
Mul
、BufferLoadNode
等。这里就以这几个Node的处理介绍一下向量化的过程。
所谓向量化的过程就是把这个标记为kVectorized
的标量循环操作映射到向量化的操作,对于上面的例子来说就是把所有关于j
的访问映射为RampNode,以便于后续处理可以正确生成相应的指令。
BufferStoreNode
BufferStoreNode
中有三部分:
- buffer——写入的buffer
- value——待写入的值或者表达式
- indices——写入buffer的坐标
这里的目的就是修改value
和indices
中的内容。
对于indices
,是在这里完成的。最终通过MapHelper
依次访问了indices
的表达式。
auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); };
Array<PrimExpr> indices = op->indices.Map(fmutate);
对于value
则是直接遍历。
PrimExpr value = this->VisitExpr(op->value);
AddNode
对于AddNode
和SubNode
都会走到AddSubVec
这个模板函数。
这个函数里面首先会遍历左右表达式,
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
if (lanes != 1) {
const RampNode* b_ramp = b.as<RampNode>();
const RampNode* a_ramp = a.as<RampNode>();
if (a.dtype().lanes() == 1 && b_ramp) {
return Ramp(fcompute(a, b_ramp->base),
fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes);
}
if (b.dtype().lanes() == 1 && a_ramp) {
return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
}
}
return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
如果遍历之后没有变化,就直接返回了。而对于这里的我们需要计算的是
((i*4) + j)
j
是需要向量化的坐标。i*4
是没有变化的。遍历以后a
没变化,b
变成了T.Ramp(0, 1, 4)
这时候lanes=4
,会走到第一个if
分支,返回的是新构造的RampNode
T.Ramp(i * 4, 1, 4)
其他的分支也类似。比如:
A[i * 4 + j] + T.float32(1)
// --- after ---
A[i * 4:i * 4 + 4] T.float32(1)
这里会把a、b broadcast为一个向量再做计算。
VarNode
对于这里的VarNode判断就比较简单了,如果匹配到的是需要向量化的变量,就返回构造函数中构造的RampNode
,否则就返回。其他的操作,暂时略过。
Var var = GetRef<Var>(op);
if (var.same_as(var_)) {
return ramp_;
}
// ...
else {
return std::move(var);
}
MulNode
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
if (lanes != 1) {
const RampNode* b_ramp = b.as<RampNode>();
const RampNode* a_ramp = a.as<RampNode>();
if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
return Ramp(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
}
if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) {
return Ramp(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
}
}
return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
return BinaryVec<Mul>(op);
这里的处理逻辑与Add基本一致。只是在计算RampNode的时候有点区别。
TVM 源码阅读PASS — VectorizeLoop的更多相关文章
- 【原】FMDB源码阅读(二)
[原]FMDB源码阅读(二) 本文转载请注明出处 -- polobymulberry-博客园 1. 前言 上一篇只是简单地过了一下FMDB一个简单例子的基本流程,并没有涉及到FMDB的所有方方面面,比 ...
- Rpc框架dubbo-client(v2.6.3) 源码阅读(二)
接上一篇 dubbo-server 之后,再来看一下 dubbo-client 是如何工作的. dubbo提供者服务示例, 其结构是这样的!dubbo://192.168.11.6:20880/com ...
- caffe中batch norm源码阅读
1. batch norm 输入batch norm层的数据为[N, C, H, W], 该层计算得到均值为C个,方差为C个,输出数据为[N, C, H, W]. <1> 形象点说,均值的 ...
- mxnet源码阅读笔记之include
写在前面 mxnet代码的规范性比Caffe2要好,看起来核心代码量也小很多,但由于对dmlc其它库的依赖太强,代码的独立性并不好.依赖的第三方库包括: cub dlpack dmlc-core go ...
- go 中 select 源码阅读
深入了解下 go 中的 select 前言 1.栗子一 2.栗子二 3.栗子三 看下源码实现 1.不存在 case 2.select 中仅存在一个 case 3.select 中存在两个 case,其 ...
- 【原】FMDB源码阅读(三)
[原]FMDB源码阅读(三) 本文转载请注明出处 —— polobymulberry-博客园 1. 前言 FMDB比较优秀的地方就在于对多线程的处理.所以这一篇主要是研究FMDB的多线程处理的实现.而 ...
- 【原】FMDB源码阅读(一)
[原]FMDB源码阅读(一) 本文转载请注明出处 —— polobymulberry-博客园 1. 前言 说实话,之前的SDWebImage和AFNetworking这两个组件我还是使用过的,但是对于 ...
- 【原】AFNetworking源码阅读(六)
[原]AFNetworking源码阅读(六) 本文转载请注明出处 —— polobymulberry-博客园 1. 前言 这一篇的想讲的,一个就是分析一下AFSecurityPolicy文件,看看AF ...
- 【原】AFNetworking源码阅读(五)
[原]AFNetworking源码阅读(五) 本文转载请注明出处 —— polobymulberry-博客园 1. 前言 上一篇中提及到了Multipart Request的构建方法- [AFHTTP ...
- 【原】AFNetworking源码阅读(四)
[原]AFNetworking源码阅读(四) 本文转载请注明出处 —— polobymulberry-博客园 1. 前言 上一篇还遗留了很多问题,包括AFURLSessionManagerTaskDe ...
随机推荐
- 常用脚本学习手册——Bat脚本
常用脚本学习手册--Bat脚本 我们在日常工作中常常会遇到一些需要重复进行的工作,又或者我们的项目在转交客户时需要去简化配置过程 这时我们就需要使用到一些自动化部署操作,我们常常会采用脚本来完成这部分 ...
- kubernetes(k8s)部署 Metrics Server 资源
资源使用指标,例如容器 CPU 和内存使用率,可通过 Metrics API 在 Kubernetes 中获得.这些指标可以直接被用户访问,比如使用 kubectl top 命令行,或者被集群中的控制 ...
- python标准模块之subprocess
subprocess --- 子进程管理 源代码: Lib/subprocess.py 写在前面: 感觉也就这俩有用: subprocess.run() subprocess.Popen() w下 ...
- 逍遥自在学C语言 | 位运算符的基础用法
前言 一.人物简介 第一位闪亮登场,有请今后会一直教我们C语言的老师 -- 自在. 第二位上场的是和我们一起学习的小白程序猿 -- 逍遥. 二.构成和表达方式 位运算符是一组用于在二进制数之间进行操作 ...
- 前后端分离 nginx 的配置
前端 nginx # 添加头部信息 proxy_send_timeout 30; # 后端服务器连接超时时间 proxy_read_timeout 30; # 后端服务器数据回传时间 proxy_co ...
- C51笔记-郭天祥-第二章 从点灯大师开始
第2章 Keil软件的使用及流水灯设计 Keil的用法:用Keil建立工程: 工程配置: C51单片机程序软件仿真.单步.全速.断点设置和变量查看等: 用一个完整的C51程序操控LED亮灭: 调用库 ...
- 一文理解TS泛型
当我们在编写 TypeScript 代码时,经常会遇到需要通用(Generic)的情况,这时候,泛型就是我们的好帮手了.在本篇文章中,我们将深入介绍 TypeScript 泛型的概念以及如何使用. 什 ...
- YOLO3论文中文版
文章目录 YOLO3论文中文版 摘要 1.引言 2. 解决方案 2.1 边界框预测 2.2 类预测 2.3 多尺度预测 2.4 特征提取器 2.5 训练 3.我们的做法 4. 失败的尝试 5.这一切意 ...
- [Pytorch框架] PyTorch 中文手册
PyTorch 中文手册 书籍介绍 这是一本开源的书籍,目标是帮助那些希望和使用PyTorch进行深度学习开发和研究的朋友快速入门. 由于本人水平有限,在写此教程的时候参考了一些网上的资料,在这里对他 ...
- Python + 超级鹰 识别图形验证码
前言: 一.下载 1.进入官网:http://www.chaojiying.com/,注册完成后,进行登录 2.点击开发文档,点击Python语言示例 3.进行示例下载 4.解压后的文件 注:关注公众 ...