TVM 中的 Schedule 原语
备注
单击 此处 下载完整的示例代码
作者:Ziheng Jiang
TVM 是一种用于高效构建内核的领域特定语言。
本教程展示了如何通过 TVM 提供的各种原语来调度计算。
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import numpy as np
计算相同结果的方法众多,然而,不同的方法会导致局部性和性能各异,因此 TVM 要求用户借助 Schedule 执行计算。
Schedule 是一组计算转换,可用于转换程序中的循环计算。
# 声明变量,供之后使用
n = te.var("n")
m = te.var("m")
Schedule 可由算子列表创建,它默认以行优先的方式串行计算张量。
# 声明一个矩阵元素乘法
A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")
s = te.create_schedule([C.op])
# lower 会将计算从定义转换为实际可调用的函数。
# 使用参数 `simple_mode=True` 会返回一个可读的类 C 的语句,这里用它来打印 schedule 结果。
print(tvm.lower(s, [A, B, C], simple_mode=True))
输出结果:
@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
C: Buffer(C_2: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
buffer_map = {A_1: A, B_1: B, C_1: C}
preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_3: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_4: int32], type="auto"), C_1: C_3: Buffer(C_2, float32, [m, n], [stride_2, stride_5: int32], type="auto")} {
for (i: int32, 0, m) {
for (j: int32, 0, n) {
C[((i*stride_2) + (j*stride_5))] = (A[((i*stride) + (j*stride_3))]*B[((i*stride_1) + (j*stride_4))])
}
}
}
一个 Schedule 由多个 Stage 组成,一个 Stage 代表一个操作的 schedule。每个 stage 的调度都有多种方法:
split
split
可根据 factor
将指定 axis 拆分为两个 axis。
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] * 2, name="B")
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=32)
print(tvm.lower(s, [A, B], simple_mode=True))
输出结果:
@main = primfn(A_1: handle, B_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
buffer_map = {A_1: A, B_1: B}
preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [m], [stride_1], type="auto")} {
for (i.outer: int32, 0, floordiv((m + 31), 32)) {
for (i.inner: int32, 0, 32) {
if @tir.likely((((i.outer*32) + i.inner) < m), dtype=bool) {
let cse_var_1: int32 = ((i.outer*32) + i.inner)
B[(cse_var_1*stride_1)] = (A[(cse_var_1*stride)]*2f32)
}
}
}
}
也可用 nparts
来拆分 axis,它拆分 axis 的方式与 factor
相反。
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i], name="B")
s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], nparts=32)
print(tvm.lower(s, [A, B], simple_mode=True))
输出结果:
@main = primfn(A_1: handle, B_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
buffer_map = {A_1: A, B_1: B}
preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [m], [stride_1], type="auto")} {
for (i.outer: int32, 0, 32) {
for (i.inner: int32, 0, floordiv((m + 31), 32)) {
if @tir.likely(((i.inner + (i.outer*floordiv((m + 31), 32))) < m), dtype=bool) {
B[((i.inner + (i.outer*floordiv((m + 31), 32)))*stride_1)] = A[((i.inner + (i.outer*floordiv((m + 31), 32)))*stride)]
}
}
}
}