矩阵乘法分块
备注
单击 此处 下载完整的示例代码
本教程概述了如何用 TVM 在 VTA 设计上有效地映射矩阵乘法。推荐先学习 简单矩阵乘法 教程。
本教程演示 TVM 调度优化,将大型神经网络算子分解为更小的块,使得可以在有限的硬件加速器资源内实现计算。
RPC 设置
首先对 Pynq 的 FPGA 进行编程,并构建其 RPC runtime。
from __future__ import absolute_import, print_function
import os
import tvm
from tvm import te
import vta
import numpy as np
from tvm import rpc
from tvm.contrib import utils
from vta.testing import simulator
# 从 3rdparty/vta-hw/config/vta_config.json 文件加载 VTA 参数
env = vta.get_env()
# 从 OS 环境中读取 Pynq RPC 主机 IP 地址和端口号
host = os.environ.get("VTA_RPC_HOST", "192.168.2.99")
port = int(os.environ.get("VTA_RPC_PORT", "9091"))
# 在 Pynq 上配置比特流和 runtime 系统,匹配 vta_config.json 文件指定的 VTA 配置。
if env.TARGET == "pynq":
# 确保 TVM 是用 RPC=1 编译的
assert tvm.runtime.enabled("rpc")
remote = rpc.connect(host, port)
# 重新配置 JIT runtime
vta.reconfig_runtime(remote)
# 用预编译的 VTA 比特流对 FPGA 进行编程。
# 可以通过传递比特流文件的路径而非 None,用自定义比特流对 FPGA 进行编程。
vta.program_fpga(remote, bitstream=None)
# 在模拟模式下,本地托管 RPC 服务器。
elif env.TARGET in ["sim", "tsim"]:
remote = rpc.LocalSession()
计算声明
第一步,描述矩阵乘法计算。
将矩阵乘法定义为全连接层中可以找到的计算,由其 batch size、输入通道和输出通道定义。它们必须是 VTA 张量 shape 的整数倍:分别为 BATCH
、BLOCK_IN
和 BLOCK_OUT
。
在矩阵乘法中添加了额外的算子,这些算子对输出进行移位和裁剪,从而模拟定点卷积,然后进行校正线性激活。下面描述全连接层的 TVM 数据流图:
由于这种计算太大,无法一次全部放入 VTA 的芯片缓冲区。因此,在调度阶段,依靠计算分块策略将计算分解为可管理的块。
# 全连接层维度:1024 x 1024
batch_size = 1
in_channels = 1024
out_channels = 1024
assert batch_size % env.BATCH == 0
assert in_channels % env.BLOCK_IN == 0
assert out_channels % env.BLOCK_OUT == 0
# 推导平铺的输入张量 shape
data_shape = (batch_size // env.BATCH, in_channels // env.BLOCK_IN, env.BATCH, env.BLOCK_IN)
weight_shape = (
out_channels // env.BLOCK_OUT,
in_channels // env.BLOCK_IN,
env.BLOCK_OUT,
env.BLOCK_IN,
)
output_shape = (batch_size // env.BATCH, out_channels // env.BLOCK_OUT, env.BATCH, env.BLOCK_OUT)
num_ops = in_channels * out_channels * batch_size * 2
# Reduction 轴
ic = te.reduce_axis((0, in_channels // env.BLOCK_IN), name="ic")
ic_tns = te.reduce_axis((0, env.BLOCK_IN), name="ic_tns")
# 输入占位符张量
data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
weight = te.placeholder(weight_shape, name="weight", dtype=env.wgt_dtype)
# 复制缓冲区
data_buf = te.compute(data_shape, lambda *i: data(*i), "data_buf")
weight_buf = te.compute(weight_shape, lambda *i: weight(*i), "weight_buf")
# 声明矩阵乘法计算
res_gemm = te.compute(
output_shape,
lambda bo, co, bi, ci: te.sum(
data_buf[bo, ic, bi, ic_tns].astype(env.acc_dtype)
* weight_buf[co, ic, ci, ic_tns].astype(env.acc_dtype),
axis=[ic, ic_tns],
),
name="res_gem",
)
# 为定点归一化添加移位阶段
res_shr = te.compute(output_shape, lambda *i: res_gemm(*i) >> env.INP_WIDTH, name="res_shr")
# 在(0,输入最大值)之间应用裁剪
inp_max = (1 << (env.INP_WIDTH - 1)) - 1
res_max = te.compute(output_shape, lambda *i: tvm.te.max(res_shr(*i), 0), "res_max")
res_min = te.compute(output_shape, lambda *i: tvm.te.min(res_max(*i), inp_max), "res_min")
# 返回结果前,对输入数据类型进行类型转换
res = te.compute(output_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name="res")