Tensorrt的安装、转化、以及推理

发布于:2025-03-28 ⋅ 阅读:(25) ⋅ 点赞:(0)

1、Tensorrt的安装:

        1)下载地址:一般下载GA版本到本地,EA为试用版,下载TAR包,这种安装最简单

TensorRT Download | NVIDIA Developerhttps://developer.nvidia.com/tensorrt/download        2)根据安装指南:Installation Guide :: NVIDIA Deep Learning TensorRT Documentation

根据12345678安装;

其中第4步为添加环境变量:

cat ~/.bashrc

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:*****、TensorRT-${version}/lib

source ~/.bashrc

        3)验证安装是否成功

import tensorrt as trt

print(trt.__version__)

print(trt.__file__)

二、onnx转tensorrt的engine操作

import onnx
import tensorrt as trt

# 加载 ONNX 模型
onnx_model = onnx.load('./model.onnx')
onnx.checker.check_model(onnx_model)

# 创建 TensorRT 日志记录器
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# 创建 TensorRT 构建器
builder = trt.Builder(TRT_LOGGER)

# 创建网络定义
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

# 创建 ONNX 解析器
parser = trt.OnnxParser(network, TRT_LOGGER)

# 解析 ONNX 模型
if not parser.parse(onnx_model.SerializeToString()):
    for error in range(parser.num_errors):
        print(parser.get_error(error))
    raise RuntimeError("Failed to parse ONNX model")

# 设置构建器配置
config = builder.create_builder_config()
# 使用 set_memory_pool_limit 方法设置工作空间大小
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB
config.set_flag(trt.BuilderFlag.FP16)  # 使用 FP16 精度

# 构建序列化的网络
serialized_engine = builder.build_serialized_network(network, config)
if serialized_engine is None:
    raise RuntimeError("Failed to build serialized TensorRT network")

# 创建 TensorRT 运行时
runtime = trt.Runtime(TRT_LOGGER)

# 反序列化引擎
engine = runtime.deserialize_cuda_engine(serialized_engine)

if engine is None:
    raise RuntimeError("Failed to deserialize TensorRT engine")

# 保存 TensorRT 引擎
with open('model.engine', 'wb') as f:
    f.write(engine.serialize())

三、使用engine的推理代码: