编译 TensorFlow 模型
备注
单击 此处 下载完整的示例代码
本文介绍了如何用 TVM 部署 TensorFlow 模型。
首先安装 TensorFlow Python 模块(可参考 https://www.tensorflow.org/install)。
# 导入 tvm 和 relay
import tvm
from tvm import te
from tvm import relay
# 导入 os 和 numpy
import numpy as np
import os.path
# 导入 TensorFlow
import tensorflow as tf
# 让 TensorFlow 将 GPU 内存限制为实际需要的内存,而非占用所有可用的内存。
# https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
# 本教程这样做,对 sphinx-gallery 更友好。
gpus = tf.config.list_physical_devices("GPU")
if gpus:
try:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
print("tensorflow will use experimental.set_memory_growth(True)")
except RuntimeError as e:
print("experimental.set_memory_growth option is not available: {}".format(e))
try:
tf_compat_v1 = tf.compat.v1
except ImportError:
tf_compat_v1 = tf
# TensorFlow 实用函数
import tvm.relay.testing.tf as tf_testing
# 模型相关文件的基本位置
repo_base = "https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/"
# 测试图像
img_name = "elephant-299.jpg"
image_url = os.path.join(repo_base, img_name)
教程
参考 docs/frontend/tensorflow.md,获取 TensorFlow 中各种模型的更多信息。
model_name = "classify_image_graph_def-with_shapes.pb"
model_url = os.path.join(repo_base, model_name)
# 图像标签图
map_proto = "imagenet_2012_challenge_label_map_proto.pbtxt"
map_proto_url = os.path.join(repo_base, map_proto)
# 可读的标签文本
label_map = "imagenet_synset_to_human_label_map.txt"
label_map_url = os.path.join(repo_base, label_map)
# target 设置
# 用下面这些注释为 cuda 构建
# target = tvm.target.Target("cuda", host="llvm")
# layout = "NCHW"
# dev = tvm.cuda(0)
target = tvm.target.Target("llvm", host="llvm")
layout = None
dev = tvm.cpu(0)
下载所需文件
下载上述列出的文件:
from tvm.contrib.download import download_testdata
img_path = download_testdata(image_url, img_name, module="data")
model_path = download_testdata(model_url, model_name, module=["tf", "InceptionV1"])
map_proto_path = download_testdata(map_proto_url, map_proto, module="data")
label_path = download_testdata(label_map_url, label_map, module="data")
导入模型
从 protobuf 文件创建 TensorFlow 计算图定义。
with tf_compat_v1.gfile.GFile(model_path, "rb") as f:
graph_def = tf_compat_v1.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name="")
# 调用函数将计算图定义导入默认计算图。
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
# 给计算图添加 shape
with tf_compat_v1.Session() as sess:
graph_def = tf_testing.AddShapesToGraphDef(sess, "softmax")
解码图像
备注
TensorFlow 前端导入不支持 JpegDecode 等预处理操作。 JpegDecode 被绕过(只返回源节点),因此我们只向 TVM 提供解码后的帧。
from PIL import Image
image = Image.open(img_path).resize((299, 299))
x = np.array(image)