前言

如果你有一个训练好的Pytorch模型,需要将他部署到嵌入式设备上(这里仅限NVIDIA生产的AGX系列),并且想利用Tensorrt实现模型优化,利用python脚本执行优化后的模型进行推理。那么可以参考这篇文章的思路。
本文的硬件和软件环境如下:

  • 硬件:NVIDIA Jetson AGX Orin (32GB),环境配置参考链接
  • jetpack: 5.0.2
  • cuda: 11.4
  • cudnn: 8.1.4.50
  • opencv: 4.5.4
  • tensorrt: 8.4.1.5 (C版本库,用于对pytorch模型进行优化加速)
  • torch: 1.11.0
  • torchvision: 0.12.0
  • tensorrt: 8.4.1.5 (python库,用于对优化后的engine文件进行解析和反序列化)

准备工作

首先需要整理好一个尽量干净的模型测试脚本,主要包括五部分内容:

# 数据集加载和初始化
dataset = Dataset(path='...', type='test', agrs={...})
dataloader = DataLoader(dataset, batch_size=1)

# 模型初始化
model = Net(args)

# 模型参数载入
state_dict = torch.load('model.ckpt')
model.load_state_dict(state_dict['model'])

# 测试
Pred, Label = [], []
for i, batch_data in enumurate(dataloader):
	batch_x, batch_y = batch_data
	batch_pred = model(batch_x)
    Pred.append(batch_pred)
    Label.append(batch_y)
    
# 评估
performance = evaluation(Pred, Label)

准备好测试脚本和少部分测试数据后,先保证该脚本可以在当前环境内正常执行并输出正确的结果(和主机上的结果相比不能有太大误差)

ONNX模型生成

要对Pytorch的模型进行Tensorrt加速,首先要生成ONNX模型,该模型存储了静态图结构和所有的参数信息,可以编写 make_onnx.py 脚本实现torch到onnx的转换。

# 初始化模型
model = YourNet().cuda().eval()

# 载入模型参数
state_dict = torch.load('model.ckpt')
model.load_state_dict(state_dict['model'])

# 随机初始化输入(与dataset中数据的尺寸和数量相对应,尽量避免用字典构建模型输入):本例中的模型输入有多个,并且尺寸各不相同
imgs = torch.randn(1, 3, 3, 384, 768).cuda().type(torch.float32)
proj_matrices = torch.randn(1, 2, 4, 4).cuda().type(torch.float32)
depth_values = torch.randn(1, 2).cuda().type(torch.float32)
edge = torch.randn(1, 3, 384, 768).cuda().type(torch.float32)

# 测试下随机输入是否可以跑通model的forward
outputs = model(imgs, proj_matrices, depth_values, edge)

# 输出包括以下内容:本例中的输出也有多个
# outputs_stage1, outputs_stage2, outputs_stage3, depth_min, depth, photometric_confidence, pc_stage1, pc_stage2, pc_stage3

# 导出torch模型为onnx格式
torch.onnx.export(
    model,  # 模型
    (imgs, proj_matrices, depth_values, edge),  # 输入
    'model.onnx',  # onnx文件名称
    verbose=True,  # 打印过程
    input_names=['imgs', 'proj_matrices', 'depth_values', 'edge'], # 输入变量名称
    output_names=[
        'outputs_stage1',
        'outputs_stage2',
        'outputs_stage3',
        'depth_min',
        'depth',
        'photometric_confidence',
        'pc_stage1',
        'pc_stage2',
        'pc_stage3'
    ],  # 输出变量名称
    opset_version=11,  # 默认选 11
    dynamic_axes={
        'imgs': {0: 'batch_size'},
        'proj_matrices': {0: 'batch_size'},
        'edge': {0: 'batch_size'},
        'outputs_stage1': {0: 'batch_size'},
        'outputs_stage2': {0: 'batch_size'},
        'outputs_stage3': {0: 'batch_size'},
        'depth_min': {0: 'batch_size'},
        'depth': {0: 'batch_size'},
        'photometric_confidence': {0: 'batch_size'},
        'pc_stage1': {0: 'batch_size'},
        'pc_stage2': {0: 'batch_size'},
        'pc_stage3': {0: 'batch_size'}
    }  # 动态维度为第一维 batch, 但其实如果推理的时候都是单样本推理的话,这里可以不设置
)

print('done!')

是不是看起来很简单,只需要 torch.onnx.export 一个函数就可以完成转换了,完成后会生成一个后缀为.onnx的文件。
但是痛苦的是,你可能一段时间内都无法正确打印出最后的 done! 。因为对于一些比较成熟的网络结构,onnx原生支持了很多可用的算子,但是如果是你自己魔改的或者完全自己定义的网络结构,很有可能很多操作是没有onnx算子支持的(甚至有些onnx支持的,tensorrt却不支持),所以要对模型进行非常多的修改和调整。

一些基本的修改策略:

以下策略部分引用自腾讯课堂《tensorRT从零起步高性能部署》(侵删):

  1. 模型中尽量杜绝 if 条件判断和 for 循环等操作,目的是为了保证模型的结构完全静态,不因任何输入的变换而改变,模型是固定的(这十分重要!)
  2. 对于任何用到 shapesize返回值的参数时,例如:tensor.view(tensor.size(0), -1)这类操作,避免直接使用 tensor.size 的返回值,而是加上int转换: tensor.view(int(tensor.size(0)), -1),断开节点跟踪
  3. 对于 nn.Upsamplenn.functional.interpolate 函数,使用 scale_factor 指定倍率,而不是使用 size 参数指定大小
  4. 对于 reshapeview操作时,-1的指定请放到 batch 维度。其他维度可以计算出来即可。batch 维度禁止指定为大于-1的明确数字
  5. torch.onnx.export 指定 dynamic_axes 参数,并且只指定 batch 维度,禁止其他动态
  6. 避免使用 inplace 操作,例如 y[…, 0:2] = y[…, 0:2] * 2 - 0.5
  7. 尽量少的出现5个维度,例如ShuffleNet Module,可以考虑合并wh避免出现5维
  8. 尽量把让后处理部分在onnx模型中实现,降低后处理复杂度

除此之外,我在torch转onnx的过程中还遇到了以下这些问题,仅在这里提供一些解决方案

双线性网格采样:

由于onnx不支持 torch.nn.functional.grid_sample() 算子,所以用下面的函数做平替,两者的结果经过测试有一些差别,但是不大,个人觉得在可接受范围内:

def bilinear_grid_sample(im,
                       grid,
                       align_corners=False):
  """Given an input and a flow-field grid, computes the output using input
  values and pixel locations from grid. Supported only bilinear interpolation
  method to sample the input pixels.

  Args:
      im (torch.Tensor): Input feature map, shape (N, C, H, W)
      grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
      align_corners (bool): If set to True, the extrema (-1 and 1) are
          considered as referring to the center points of the input’s
          corner pixels. If set to False, they are instead considered as
          referring to the corner points of the input’s corner pixels,
          making the sampling more resolution agnostic.

  Returns:
      torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
  """
  n, c, h, w = int(im.shape[0]), int(im.shape[1]), int(im.shape[2]), int(im.shape[3]) 
  gn, gh, gw, _ = int(grid.shape[0]), int(grid.shape[1]), int(grid.shape[2]), int(grid.shape[3])
  # assert n == gn

  x = grid[:, :, :, 0]
  y = grid[:, :, :, 1]

  if align_corners:
      x = ((x + 1) / 2) * (w - 1)
      y = ((y + 1) / 2) * (h - 1)
  else:
      x = ((x + 1) * w - 1) / 2
      y = ((y + 1) * h - 1) / 2

  x = x.view(gn, gh*gw)
  y = y.view(gn, gh*gw)

  x0 = torch.floor(x).long().cuda()
  y0 = torch.floor(y).long().cuda()
  x1 = x0 + 1
  y1 = y0 + 1

  wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
  wb = ((x1 - x) * (y - y0)).unsqueeze(1)
  wc = ((x - x0) * (y1 - y)).unsqueeze(1)
  wd = ((x - x0) * (y - y0)).unsqueeze(1)

  # Apply default for grid_sample function zero padding
  # 这里由于Tensorrt不支持Pad操作,所以对原始的pad操作进行了修改,后面也会提到
  # im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
  pad_tensor_1 = torch.zeros((1, 8, 96, 1)).cuda()
  pad_tensor_2 = torch.zeros((1, 8, 1, 194)).cuda()
  im_padded = torch.concat([pad_tensor_1, im, pad_tensor_1], axis=3)
  im_padded = torch.concat([pad_tensor_2, im_padded, pad_tensor_2], axis=2)
  # 两个维度的padding要分两次concat才能实现,需要提前计算好维度

  padded_h = h + 2
  padded_w = w + 2
  # save points positions after padding
  x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

  # Clip coordinates to padded image size
  x0 = torch.where(x0 < 0, torch.tensor(0).cuda(), x0)
  x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1).cuda(), x0)
  x1 = torch.where(x1 < 0, torch.tensor(0).cuda(), x1)
  x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1).cuda(), x1)
  y0 = torch.where(y0 < 0, torch.tensor(0).cuda(), y0)
  y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1).cuda(), y0)
  y1 = torch.where(y1 < 0, torch.tensor(0).cuda(), y1)
  y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1).cuda(), y1)

  im_padded = im_padded.view(n, c, (h+2)*(w+2))


  x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
  x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
  x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
  x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

  Ia = torch.gather(im_padded, 2, x0_y0)
  Ib = torch.gather(im_padded, 2, x0_y1)
  Ic = torch.gather(im_padded, 2, x1_y0)
  Id = torch.gather(im_padded, 2, x1_y1)

  return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)

然后在原始的模型中把grid_sample函数替换掉即可:

# out = torch.nn.functional.grid_sample(feature, grid.float().view(-1, num_depth * height, width, 2), mode='bilinear', padding_mode='zeros')
out = bilinear_grid_sample(src_fea, grid.float().view(-1, num_depth * height, width, 2))

Einstein求和

torch.einsum 也是onnx不支持的算子,自己写平替(注意:einsum输入的字符串不同,需要写不同的操作来代替):

# 代替 torch.einsum('bdhn,bdhm->bhnm', query, key)
def einsum_bdhn_bdhm(a, b):
    a = a.unsqueeze(4)
    b = b.unsqueeze(3)
    out = a * b
    out = torch.sum(out, dim=1)
    return out

# 代替 torch.einsum('bhnm,bdhm->bdhn', prob, value)
def einsum_bhnm_bdhm(a, b):
    a = a.unsqueeze(1)
    b = b.unsqueeze(3)
    out = a * b
    out = torch.sum(out, dim=-1)
    return out

求Attention时对图片的切块和重组

对一张图片需要做切块计算局部self-attention的时候,常常会用到 einops 库里面的 rearrange, repeat 函数,但是这些操作onnx没有给予直接支持(后来发现可能不是这里的错误,不过尽量还是用原生的torch操作构建模型),所以需要对切块和重组的操作进行深入理解,之后自己实现:

# 代替 from einops import rearrange, repeat

# B: Batch
# C: 通道数
# D: depth, 图片深度
# H: height, 图片高
# W: weight, 图片宽

# output = repeat(input, 'B C (D d) (H h) (W w) -> (B D H W) (d h w C) V', d=self.patch, h=self.patch, w=self.patch, V=1)

# 需要切块的维度进行切块
output = input.reshape(B, C, int(D//self.patch), self.patch, int(H//self.patch), self.patch, int(W//self.patch), self.patch)
# 重组维度顺序
output = output.permute(0, 2, 4, 6, 3, 5, 7, 1)
# 合并维度
output = output.reshape(B*int(D//self.patch)*int(H//self.patch)*int(W//self.patch), C*self.patch**3)
# 扩充维度
ref_cpy = ref_cpy.unsqueeze(-1)
# (B, C, D, H, W, V) -> (X, C, V)

# 下面类似:
# src_cpy = rearrange(src, 'B C (D d) (H h) (W w) V -> (B D H W) (d h w C) V', d=self.patch, h=self.patch, w=self.patch)
src_cpy = src.reshape(B, C, int(D//self.patch), self.patch, int(H//self.patch), self.patch, int(W//self.patch), self.patch, V)
src_cpy = src_cpy.permute(0, 2, 4, 6, 3, 5, 7, 1, 8)
src_cpy = src_cpy.reshape(B*int(D//self.patch)*int(H//self.patch)*int(W//self.patch), C*self.patch**3, V)

求逆

求逆算子onnx和tensorrt都不支持,没有解决。但是因为我们模型的求逆实际上是对输入的一部分做预操作,不在模型中间起作用,因此我们把它挪到了dataset里面做预处理,避免了在模型层面解决这个问题。

Padding操作

虽然onnx支持Padding操作,但是Tensorrt不支持(指单独的F.pad操作,并非conv层或者poolling层的padding操作),所以以开始可以不考虑对其做修改,但是当onnx转trt时,又会遇到这个问题,所以建议直接修改好。
Padding操作特指 torch.nn.functional.pad 操作,可以用预定义的0填充张量与待padding的张量进行拼接实现padding。

# 举例:
# 原始的 im 维度是 [1, 8, 94, 192]
# 需要在最后两个维度padding成 [1, 8, 96, 194]
# 通常可以用下面的 F.pad 实现
# im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)

# 预定义0张量
pad_tensor_1 = torch.zeros((1, 8, 96, 1)).cuda()
pad_tensor_2 = torch.zeros((1, 8, 1, 194)).cuda()

# 填充第三个维度
im_padded = torch.concat([pad_tensor_1, im, pad_tensor_1], axis=3)
# 填充第四个维度
im_padded = torch.concat([pad_tensor_2, im_padded, pad_tensor_2], axis=2)

ONNX模型测试

当你已经可以成功导出onnx模型之后,为了保险起见,可以写一个测试onnx模型的脚本,主要用来观察onnx模型输出的结果是否和pytorch一致:

import onnxruntime as ort

# 创建onnx session
sess = ort.InferenceSession('model.onnx', None)

# 加载数据集
args = parser.parse_args()
test_dataset = Data_Set(args)

# 获取一个数据样本
sample = test_dataset[0]

# 执行测试(一个样本缺少batch维度,用np.newaxis补充)
output = sess.run([], 
            {'imgs': sample['imgs'][np.newaxis, :],
             'proj_matrices': sample['proj_matrices'][np.newaxis, :],
             'depth_values': sample['depth_values'][np.newaxis, :],
             'edge': sample['edge'][np.newaxis, :]})

如果这里的结果和torch的结果能够对齐,则可以进行后续步骤。

ONNX模型进行Tensorrt优化

这里的原理想要深入了解的话可以看Tensorrt的官方文档,可以简单理解为对模型的结构进行了一定的优化,使其在实际推理的过程中速度有一定的提升。这里我们可以利用Tensorrt的C版本库自己写onnx转trt的脚本,但是对C++编程有一定的要求,还要熟悉相关的API。方便起见,可以直接使用trtexec可执行程序对onnx模型进行自动优化只需一条命令行即可:

trtexec --onnx=model.onnx --saveEngine=model.trt

其实trtexec有很多参数可以设置,可配置性很强,可以参考链接进行学习,这里不做详细介绍,但是提供几个Tips:

  1. 如果没有定义trtexec的快捷方式,可能无法在任意路径下执行该命令,因此需要对其定义快捷方式。首先 find / -name 'trtexec*' 找到系统中 trtexec 所在的位置,通常是 /usr/.../tensorrt/bin/trtexec 。找到之后可以使用vim打开 ~/.bashrc 文件,在最后面添加一行 alias trtexec=/usr/.../tensorrt/bin/trtexec 。然后在命令行执行 source ~/.bashrc 。这样,该命令就可以在任意路径下执行了。
  2. 输出文件的后缀可以为 .trt 也可以为 .engine ,两者使用起来通常没有区别。
  3. --maxBatch--explicitBatch 参数通常都默认被指定为 --maxBatch 的值,其大小默认为1,即两者均不指定的话,模型一次推理只接受一个样本的输入。
  4. 模型优化默认是启用tf32的精度,--noTF32 可以禁用此精度,--fp16 表示启用fp16精度(模型复杂的话误差会很大,基本不可用), --int8 表示启用int8精度(基本不可用)。不设置这些参数的话默认是以tf32的精度进行优化,基本输出不会有太大误差。

看到这里,貌似onnx转trt也很容易,只需要用trtexec工具就可以了,但是依然有一些模型结构上的问题,会在这一步反映出来,导致 Failed!,我这里列举一些我碰到的问题:

Reshape算子产生 segmentation fault

在onnx转trt这一步,通常报错都会提示出在图的某一个节点,这大大方便了我们对错误进行定位,但是依然要依赖于我们对模型的熟悉程度,要能够判断出这个节点对应的哪一行原始代码,才能够做一些尝试和修改。如何对onnx模型进行可视化呢?可以借助这个开源网页工具netron,可以直接打开本地的onnx文件,方便的浏览模型结构和节点,也可以用菜单栏里的 find 操作寻找报错的节点,如下图:
搜索节点Reshape_326,点击后即可自动定位

根据报错节点的位置,我们要推测他出现在Python代码的哪个部位,从而判断错误原因。

Reshape 的维度变化不能一次操作太多,最好逐个维度进行拆分和组合,如下样例所示:

# 获取 src 的初始维度
B, C, D, H, W, V = int(src.shape[0]), int(src.shape[1]), int(src.shape[2]), int(src.shape[3]), int(src.shape[4]), int(src.shape[5])

# 目标 (B, C, D, H, W, V) -> (X, C, V)

# 初始的操作流程
src_cpy = src.reshape(B, C, int(D//self.patch), self.patch, int(H//self.patch), self.patch, int(W//self.patch), self.patch, V)
src_cpy = src_cpy.permute(0, 3, 5, 7, 4, 6, 8, 1, 2)
src_cpy = src_cpy.reshape(B*int(D//self.patch)*int(H//self.patch)*int(W//self.patch), C*self.patch**3, V)

# 改进后的操作流程

# V 这个维度先放到前面
src_cpy = src.permute(0, 1, 5, 2, 3, 4)
# 拆分 D 这个维度
src_cpy = src_cpy.reshape(B, C, V, int(D//self.patch), self.patch, H, W)
# 拆分后的 D 维度放在 B 的后面
src_cpy = src_cpy.permute(0, 3, 4, 5, 6, 1, 2)
# 合并拆分后的 D 和 B, 把 V 放到最后
src_cpy = src_cpy.reshape(B*int(D//self.patch), self.patch, H, W, C, V)

# 拆分 H 维度
src_cpy = src_cpy.reshape(B*int(D//self.patch), self.patch, int(H//self.patch), self.patch, W, C, V)
# 拆分后的 H 维度放在 B*int(D//self.patch) 的后面
src_cpy = src_cpy.permute(0, 2, 1, 3, 4, 5, 6)
# 合并 B 、D 、H 
src_cpy = src_cpy.reshape(B*int(D//self.patch)*int(H//self.patch), self.patch, self.patch, W, C, V)

# 拆分 W 维度
src_cpy = src_cpy.reshape(B*int(D//self.patch)*int(H//self.patch), self.patch, self.patch, int(W//self.patch), self.patch, C, V)
# 拆分后的 W 维度放在 B*int(D//self.patch)*int(H//self.patch) 的后面
src_cpy = src_cpy.permute(0, 3, 1, 2, 4, 5, 6)
# 合并 B 、D 、H 、W, 合并拆分的 3 个 patch 维度
src_cpy = src_cpy.reshape(B*int(D//self.patch)*int(H//self.patch)*int(W//self.patch), C*self.patch**3, V)

上面这个例子其实是前一节提到的 rearrangerepeat 函数的修改版做法,虽然在导出onnx时成功了,但是在转trt时又出现了问题,所以又对其作了修改,所以说做torch转trt的模型优化遇到的问题往往是永无止境的。

无中生有的If节点

当我解决完所有onnx和trt报错之后,模型依然抛出一个问题,如图所示:
无中生有的  节点以及不固定的输出维度

根据报错节点位置,我找到了紧接着输出层的 If 节点,我翻找代码发现,模型在上面的 Add 节点之后就已经结束了,应该直接输出 depth ,不知道中间的一串结构从何而来。最后查阅资料发现,有人提供了一些解决方案。问题出在,最后一个输出层额外添加了一个 squeeze() 的操作:

depth = self.PropagationNet(depth, edge).squeeze(1)

所以我把这个操作用view替换,并且调整到了 PropagationNet 模块内部:

class PropagationNet(nn.Module):
    def __init__(self, base_channels):
        super(PropagationNet, self).__init__()

        self.conv = nn.Sequential(
            ConvBnReLU(3, base_channels),
            ConvBnReLU(base_channels, base_channels),
            ConvBnReLU(base_channels, base_channels),
            nn.Conv2d(base_channels, 1, 3, padding=1, bias=False)
        )

    def forward(self, depth, img):
        x = self.conv(img)
        out = depth + x
        h, w = int(out.shape[2]), int(out.shape[3])
        # 直接在这里对 Add 后的输出用view的方式添加一个维度,避免输出后再使用squeeze
        out = out.view(1, h, w)
        return out

这个问题目前不清楚核心原因是什么,因为网络其他地方的squeeze并没有报错,唯独这里出问题,暂时用这种方式解决了。

理论上到这里,torch模型完成Tensorrt优化已经完成了,并且已经生成了后缀为 .trt 或者 .engine 的trt模型,下面将介绍如何在python脚本中使用它。

使用python加载trt模型

这里提供一个解析trt文件或者engine文件的类,用于加载模型,并且定义了 __call__(inputs) 方法便于直接对输入数据进行测试。

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

class TrtModel(object):
    
    def __init__(self,engine_path,max_batch_size=1,dtype=np.float32):
        
        self.engine_path = engine_path
        self.dtype = dtype
        self.logger = trt.Logger(trt.Logger.ERROR)
        self.runtime = trt.Runtime(self.logger)
        self.engine = self.load_engine(self.runtime, self.engine_path)
        self.max_batch_size = max_batch_size
        self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers()
        self.context = self.engine.create_execution_context()

    @staticmethod
    def load_engine(trt_runtime, engine_path):
        trt.init_libnvinfer_plugins(None, "")             
        with open(engine_path, 'rb') as f:
            engine_data = f.read()
        engine = trt_runtime.deserialize_cuda_engine(engine_data)
        return engine
    
    def allocate_buffers(self):
        
        inputs = []
        outputs = []
        bindings = []
        stream = cuda.Stream()
        
        for binding in self.engine:
            size = trt.volume(self.engine.get_binding_shape(binding)) * self.max_batch_size
            host_mem = cuda.pagelocked_empty(size, self.dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            
            bindings.append(int(device_mem))

            if self.engine.binding_is_input(binding):
                inputs.append(HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem))
        
        return inputs, outputs, bindings, stream
       
            
    def __call__(self,imgs:np.ndarray,proj_matrices:np.ndarray,depth_values:np.ndarray,edge:np.ndarray,batch_size=1):
        
        imgs = imgs.cpu().numpy()
        proj_matrices = proj_matrices.cpu().numpy()
        depth_values = depth_values.cpu().numpy()
        edge = edge.cpu().numpy()
        
        np.copyto(self.inputs[0].host, imgs.ravel())
        np.copyto(self.inputs[1].host, proj_matrices.ravel())
        np.copyto(self.inputs[2].host, depth_values.ravel())
        np.copyto(self.inputs[3].host, edge.ravel())

        for inp in self.inputs:
            cuda.memcpy_htod_async(inp.device, inp.host, self.stream)
        
        self.context.execute_async(batch_size=batch_size, bindings=self.bindings, stream_handle=self.stream.handle)
        for out in self.outputs:
            cuda.memcpy_dtoh_async(out.host, out.device, self.stream) 
            
        self.stream.synchronize()
        return [out.host.reshape(batch_size,-1) for out in self.outputs]

然后类似pytorch模型的测试脚本一样,写一个用于trt模型的测试脚本:

# trt模型的文件路径
file = '.../model.trt'

# 加载并初始化trt模型
model = TrtModel(file)

# 数据集加载和初始化
dataset = Dataset(path='...', type='test', agrs={...})
dataloader = DataLoader(dataset, batch_size=1)

# 测试
Pred, Label = [], []
for i, batch_data in enumurate(dataloader):
	batch_x, batch_y = batch_data
	batch_pred = model(batch_x)
    Pred.append(batch_pred)
    Label.append(batch_y)
    
# 评估
performance = evaluation(Pred, Label)

这里有几个关键点需要注意:

  1. 测试的时候,输入数据是通过Dataloader进行加载的,所以数据已经被放在了cuda上,但是不符合trt模型从文件读取数据到cpu,再放到cuda上的逻辑,所以在 __call__() 方法的最开始,对输入数据进行了一些类型变换,以及存储区域的转移
# 转移数据到cpu并转为np.float32
imgs = imgs.cpu().numpy()
proj_matrices = proj_matrices.cpu().numpy()
depth_values = depth_values.cpu().numpy()
edge = edge.cpu().numpy()

# 将数据copy到它所分配好的Host memory位置
np.copyto(self.inputs[0].host, imgs.ravel())
np.copyto(self.inputs[1].host, proj_matrices.ravel())
np.copyto(self.inputs[2].host, depth_values.ravel())
np.copyto(self.inputs[3].host, edge.ravel())
  1. __call__() 方法的输出是存在一个list里的(多个输出),但是没有对输出的维度进行处理(输出是以一维数组的形式存在连续的内存空间里的),所以事后还要对每一个输出的维度进行调整。这里需要特别注意的是:输出的顺序依赖于图结构的顺序而不是pytorch模型里写的变量顺序!
# 对于pytorch模型,输出是这样排列的:

outputs = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"], sample_cuda["edge"])
# outputs_stage1, outputs_stage2, outputs_stage3, depth_min, depth, photometric_confidence, pc_stage1, pc_stage2, pc_stage3
outputs_dict = {
    'stage1': {'depth': outputs[0], 'photometric_confidence': outputs[6]},
    'stage2': {'depth': outputs[1], 'photometric_confidence': outputs[7]},
    'stage3': {'depth': outputs[2], 'photometric_confidence': outputs[8]},
    'depth': outputs[4],
    'depth_min': outputs[3],
    'photometric_confidence': outputs[5]
}

# 而对于trt模型,输出是这样排列的:
outputs = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"], sample_cuda["edge"])

outputs_dict = {
    'stage1': {'depth': torch.from_numpy(outputs[1].reshape(1, 96, 192)).to(device), 'photometric_confidence': torch.from_numpy(outputs[2].reshape(1, 96, 192)).to(device)},
    'stage2': {'depth': torch.from_numpy(outputs[3].reshape(1, 96, 192)).to(device), 'photometric_confidence': torch.from_numpy(outputs[4].reshape(1, 96, 192)).to(device)},
    'stage3': {'depth': torch.from_numpy(outputs[5].reshape(1, 96, 192)).to(device), 'photometric_confidence': torch.from_numpy(outputs[7].reshape(1, 96, 192)).to(device)},
    'depth': torch.from_numpy(outputs[8].reshape(1, 384, 768)).to(device),
    'depth_min': torch.from_numpy(outputs[0].reshape(1, 1, 1)).to(device),
    'photometric_confidence': torch.from_numpy(outputs[6].reshape(1, 96, 192)).to(device)
}

到这里,一套完整的Tensorrt加速和python部署流程就结束了,你可以通过对比pytorch模型和trt模型两个测试脚本的结果判断模型是否精确,以及速度是否提升。最终实际部署的时候可以去掉后面 evaluation 的部分,直接将模型的输出进行存储或流转。


本站由 困困鱼 使用 Stellar 创建。