本文主要介绍了 TVM 框架中的 IR 架构设计,包括 Relay 层、IRModule 层,编译期和执行期的表示,预测部署的使用方式等。
1|1 TVM 中的 IR 是什么,架构设计上分几层?
解答:TVM 的整体结构图如下:
概念上,分为两层:上层为面向前端组网的 Relay IR, 下层为面向 LLVM 的底层 IR。
但从设计实现上,底层通过 Object 元类实现统一的 AST Node 表示,借助一个 IRModule 贯穿上下层。个人理解,TVM 的 IR 实现上其实只有一层,只是封装后在直观概念上分为上下层。
- IRModule 里持有的是
BaseFunction列表 - 上层
relay::Funtion继承自BaseFunction官方解释:
relay::Function对应于一个 end2end 的模型。可以简单理解为一个支持控制流、递归、以及复杂数据结构的计算图。
- 下层
tir::PrimFunc也继承自BaseFunction官方解释:
tir::PrimFunc包含了一些底层 threading、vector/tensor 的 “指令”。通常为模型 layers 中的一个 Op 执行单元
- 在编译阶段,一个
relay::Function可能会被lower成多个tir::PrimFunc。
1|2 TVM 架构上主要包含了哪些核心模块和概念?
从编译流程上来看,涉及的核心数据结构有两个:
IRModule:包含relay::Function和tir::PrimFunc- 此部分也是 Pass 策略的输入输出单元,即
IRModule→pass→IRModule - 传送门:TVM 的 Relay IR 设计
- 此部分也是 Pass 策略的输入输出单元,即
runtime::Module:经过lowering之后,可执行期的基本单元,包含很多runtime::PackedFunc(可以理解为 KernelFunc
编译时的 Pass 策略主要在IRModule数据结构层面进行,分为两方面:
ruled-base:包括relay/transform和tir/transform- 前者多为上层 “图” 结构上 Pass 优化,比如常量折叠,fusion
- 后者多为下层偏向编译器方面的 Pass 优化,比如 prefetch 注入,unrollLoop
search-based:包括auto-schedule和auto-tvm
在前后端交互上,TVM 将所有的核心数据结构都暴露到了 Python 前端,易用性和灵活性极强:
- 所有的核心对象都可以通过 Python API 直接构造和操作,比如
IRModule - 支持在前端自定义组合 pass 和 transformation
- 通过 TVM 的 API 直接操作 IR,支持 Python 端写 pass
1|3 IRMoule 是什么样的?
- IRModule 通过 IRModuleNode 管理元信息
- 核心成员:
- Functions
- 表示计算的函数单元,如 Conv、log
- Function 内部有通过 params、body 关联 Var
- 概念上,对应与 AST 的 Module
- Global_var
import tvmfrom tvm import relayimport numpy as np\# step 1: modelingm,n = 4, 2x = relay.var("x", shape=(m,n), dtype='float32')out = relay.nn.softmax(x)net = relay.Function(\[x\], out)\# step 2: build and loweringmodule = tvm.IRModule.from\_expr(net) lib = relay.build(module, "llvm") \# step 3: input tensor datactx = tvm.cpu(0)x\_t = tvm.nd.array(np.random.uniform(size=\[m,n\]).astype('float32'), ctx) runtime = tvm.contrib.graph\_runtime.GraphModule(lib\["default"\](ctx))runtime.set\_input("x", x\_t)runtime.run()print(runtime.get\_output(0)) \# print(net.body)'''fn (%x: Tensor\[(4, 2), float32\]) {nn.softmax(%x)}'''\# print(module)'''def @main(%x: Tensor\[(4, 2), float32\]) {nn.softmax(%x)}'''
- Functions
1|4 Relay 的 pass 是如何实现和管理的?
概念上讲,TVM 可以看做是分两层的:Relay 层和 tir 层,通过 IRModule 来贯穿。在 Pass 优化上,TVM 也进行了两层的设计:
- 上层基于 “图” 的优化
这部分很类似 Paddle 的 pass,主要通过对 AST 的分析,应用一些上层的 pass 策略,主要包括:- 常量折叠、DSE、Layout 转换、scaling 因子折叠
- 最后会应用 fuse pass。比如将一个 MobileNet 表示成很多 conv2d-relu 的 “段”
- pass 的定义见
relay/transform
- 下层基于 “target” 的优化
这部分 pass 主要涉及 lowering 到 target 时采取的优化策略,如如何生成高效执行conv2d-relu的代码。主要包括:- Prefetch 语句注入、VectorizeLoop、UnrollLoop、RemoveNoOp
- SkipAssert、ThreadSync、HoistIfThenElse 等
- 此部分 pass 有的可以直接复用底层编译器的 pass,如 LLVM、CUDA C 等编译器。因此 TVM 主要关注和 ML 相关、且底层编译器未考虑到的场景
TVM 的 pass 是通过遍历 AST修改Node来实现(类似 paddle 的动转静),通过TVM_REGISTER_GLOBAL注册和暴露支持的 pass。
对于开发者来讲,TVM 是如何新增一个 Pass 呢?
TVM 官方给出了一个常量折叠 Pass 的文档。由于 TVM 的 IR 比较像 AST,因此 pass 的新增主要包括如下几个步骤:
- 需要一个
AST Traversers用于确定哪些 node 是需要修改。常量折叠 pass实现了
ConstantChecker,通过 map 结构的memo_记录哪些 node 是常量 node。这里只涉及两个 node 的函数重载:ConstantNode 和 TupleNode
需要一个
Expression Mutators用于修改和替换满足条件的 node。在常量折叠 pass 中,只有三种 node 涉及折叠:LetNode、TupleItemGetNode 和 CallNode,因此也需要重载这三个函数即可
TVM 的 pass 设计思想和架构,可以更多的参考Pass Infrastructure文档介绍。整体上借鉴了很多 LLVM 的 pass 设计思想。目标很明确,旨在实现如下效果:可以灵活地排布 Optimization 单元,支持用户随意地进行 pass piplines 定制
- 提供友好地 pass budug 体验
- 避免用户去手动处理 pass 之间的依赖
- 简化开发者新增 pass 的流程,支持在 python 端写 pass
TVM Pass 实现上,可以分为三大类:
- Module-Level Pass
- 利用全局信息进行优化,可以删减
Function,如 DSE pass - 核心 pass 函数是
PackedFunc类型,因此支持 python、C++ 去写 pass
- 利用全局信息进行优化,可以删减
- Funtion-Level Pass
- 对 Module 中的每个
Function进行优化,只有局部信息 - 不允许删减
Function - 如公共子表达式替换、vectorization
- 对 Module 中的每个
- Sequential-Level Pass
- 顺序执行一系列的 pass
FusionPass 的基本原理:
- 会先将 IRModule 转为 Graph
1|5 TVM 中的 auto-tvm 的角色是什么?
上面我们介绍的 TVM 的 pass 都是 rule-based 的,意味着开发者在新增 pass 时,其实是只要匹配什么样的模式,然后替换成什么样的模式。
这导致两个问题:
- pass 的数量会很受限
- pass 都需要预定义后才能支持
auto-tvm 会先定义一些粒度比较小的优化策略,TVM 会启发式组合应用、评估这些策略带来的提升,最后使用最佳的组合策略,以实现 auto。
1|6 Relay 结构是执行期的结构么?
解答:Relay 的解释器(Interpreter)可以执行 relay 的表达式,但不适合生产环境部署时使用。原因是:
- 解释器是通过遍历 AST 来执行程序,遍历过程是很低效的。
- 无法友好支持动态代码。比如动态 schduling、动态 Tensor shape、还有控制流。解释器提供了简单的实现方案,但无法高效地编译和优化
静态的代码优点:graphs 是固定的,方便大刀阔斧地进行优化,比如内存静态分配,最佳的内存复用等。
TVM 也使用了 graph runtime 技术——提供了一种快速执行机制,但仅支持部分 Relay 的 programs
因此,Relay 引入了 Virtual Machine,旨在取得部署、执行 Relay programs 时,性能与灵活性之间的平衡。
从用户的角度,可以通过relay.crete_executor(kind, ctx, target)接口来创建不同的执行器:
kind取值为:graph、vm、debug- 统一实现了
evalutae(expr, *args)接口
前置知识:VM
- 传统的 VM 主要操作部分 scalar 和大量低阶 instructions
- 对于 ML,主要是 Tensor,以及部分的高阶 instructions
- 耗时集中在计算密集型 Op 的调用,如 GEMM 和 Conv
- 设计的核心点是:指令集的选择、指令表示
- op-code 和 data payload
TVM 中的 VM 的指令集的设计:
- 偏向 high-level 的设计,尽量与 Relay 层的 operation 相呼应
- AllocTenor、If、Goto
- 核心的三种 object 对象:
- NDArray、ADT 和 Closure,分别用于表示 Tensor、tuple/list、closure data。
- 栈(Stack)和状态(State)
- 栈帧用于标记当前的函数调用
- 每个函数的寄存器都是在连续空间上申请的
- dispatch loop
- VM 实现了 switch 和 goto
TVM 的 VM compiler 设计:
- 作用:将 Relay 的 IR 编译成字节码序列,即
tvm::relay::Module→tvm::relay::vm::Executable→tvm::relay::vm::Function→tvm::relay::vm::VirtualMachine
TVM 的 VM 对序列化和反序列化的支持:
- Graph Runtime 方案中序列化的结果是:
- 权重参数保存为
.weight文件 - graph 保存为
.json文件 - 计算 kernel 保存为
.so库
- 权重参数保存为
- VM 方案中序列化的结果为:
- Relay 的 object 文件
.o文件 - 计算 kernel 保存为
.so库
- Relay 的 object 文件
1|7 TVM 的 Runtime 模块是什么样的?
解答:先看一个用户侧使用的接口样例:
import tvm\# Example runtime execution program in python, with type annotatedmod: tvm.runtime.Module = tvm.runtime.load\_module("compiled\_artifact.so")arr: tvm.runtime.NDArray = tvm.nd.array(\[1, 2, 3\], ctx=tvm.gpu(0))fun: tvm.runtime.PackedFunc = mod\["addone"\]fun(a)print(a.asnumpy())
Runtime 时期的三大核心概念:
runtime.Module:封装编译 DSO 的核心单元,包含了很多PackedFunc,可以根据name获取函数runtime.PackedFunc:后端生成的函数,对应于 DL 中的 KernelFuncruntime.NDArray:封装了执行期的 Tensor 概念
1|8 TVM 的 target 过程做了什么事情?
TVM在lower到target时,会将 IRModule emit 到后端编译器去in-memory地生成可执行代码。
个人理解,target 的过程涉及到编译,这对框架要求很高,在大多数场景下,这个过程应该是超级轻量级的,速度应该越快越好。
通过本地编译安装和试用 TVM,发现 target 的过程超级快,几乎瞬发返回可执行函数。
1|9 TVM 中编译执行和预测部署是什么样的?
解答:首先需要进行网络的定义:
import tvmimport numpy as npn = 12A = te.placeholder((n,), name="A") \# TensorB = te.compute(A.shape, lambda \*i: A(\*i) + 1.0, name="B") \# TensorC = te.compute(A.shape, lambda \*i: A(\*i) - 1.0, name="C") \# Tensors = te.create\_scheduleC\[B.op, C.op\]) \# scheduleadd\_func = tvm.build(s, \[A, B, C\], "llvm", name="add") \# compile\# prepare datactx = tvm.cpu(0)a\_t = tvm.nd.array(np.random.uniform(size=nn).astype(A.type), ctx)b\_t = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)c\_t = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx)add\_func(a\_t, b\_t, c\_t)
对于预测部署,可以将计算逻辑编译为 DSO:
from tvm.contrib import cc\# serializationadd\_func.save('./add\_kernel.o')cc.create\_shared('./for\_infer.so', \['./add\_kernel.o'\])\# load for inferencem = tvm.runtime.load\_module('./for\_infer.so')add\_func = m\['add'\] # load add kernel funcadd\_func(a\_t, b\_t, c\_t) # infer
对于 model 的序列化和加载的例子:
\# Resnet18 workloadresnet18\_mod, resnet18\_params = relay.testing.resnet.get\_workload(num\_layers=18)\# buildwith relay.build\_config(opt\_level=3):\_, resnet18\_lib, \_ = relay.build\_module.build(resnet18\_mod, "cuda", params=resnet18\_params)\# export libraryfile\_name = "./deploy.so"resnet18\_lib.export\_library(file\_name)\# load it backloaded\_lib = tvm.runtime.load\_module(file\_name)#inferdata = np.random.uniform(-1, 1, size=input\_shape(mod)).astype("float32")ctx = tvm.gpu()gmod = graph\_runtime.GraphModule(loaded\_lib\["default"\](ctx))gmod.set\_input("data", data)gmod.run()out = gmod.get\_output(0).asnumpy()
1|10 TVM 中对训练是如何支持的?
TVM 支持训练包括如下几个核心模块
- 自动微分 auto-diff
TVM 中提供了grads = te.gradient(out, inputs)接口,实现反向梯度的自动求导。但目前仍然是只是一个实现性功能
1|11 TVM 的动态 shape 是如何实现的?
解答:理解 TVM 的动态 shape 实现机制,首先我们先看下:从用户的角度,动态 shape 怎么使用。
import tvmfrom tvm import teimport numpy as np# 组网n, m = te.size_var("n"), te.size_var("m")A = te.placeholder((n,m), name="A")k = te.reduce_axis((0, m), "k")B = te.compute((n,),lambda i:te.sum(A[i,k], axis=k), name="B")# 编译s = te.create_schedule(B.op)net = tvm.build(s, [A, B, n, m])# 执行def run(n, m):ctx = tvm.cpu(0)a = tvm.nd.array(np.random.uniform(size=[n,m]).astype(A.dtype), ctx)b = tvm.nd.array(np.zeros((n,)).astype(A.dtype), ctx)return net(a, b, n, m)run(4, 6)run(10, 16)
TVM 提供了便捷的 debug 机制,可以直接打印查看中间编译的函数代码:
print(str(tvm.lower(s, [A, B])))print(m.get_source())
