fix tensor rt engine

This commit is contained in:
zxsanny
2025-03-28 14:50:43 +02:00
committed by Alex Bezdieniezhnykh
parent 5b89a21b36
commit 06a23525a6
16 changed files with 272 additions and 94 deletions
+18 -16
View File
@@ -1,46 +1,48 @@
import re
import struct
import subprocess
from pathlib import Path
from typing import List, Tuple
import json
import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
from inference.onnx_engine import InferenceEngine
import pycuda.autoinit # required for automatically initialize CUDA, do not remove.
from onnx_engine import InferenceEngine
class TensorRTEngine(InferenceEngine):
def __init__(self, model_path: str, batch_size: int = 4, **kwargs):
self.model_path = model_path
def __init__(self, model_bytes: bytes, batch_size: int = 4, **kwargs):
self.batch_size = batch_size
try:
logger = trt.Logger(trt.Logger.WARNING)
with open(model_path, 'rb') as f:
metadata_len = int.from_bytes(f.read(4), byteorder='little', signed=True)
metadata_bytes = f.read(metadata_len)
try:
self.metadata = json.loads(metadata_bytes)
print(f"Model metadata: {json.dumps(self.metadata, indent=2)}")
except json.JSONDecodeError:
print(f"Failed to parse metadata: {metadata_bytes}")
self.metadata = {}
engine_data = f.read()
metadata_len = struct.unpack("<I", model_bytes[:4])[0]
try:
self.metadata = json.loads(model_bytes[4:4 + metadata_len])
self.class_names = self.metadata['names']
print(f"Model metadata: {json.dumps(self.metadata, indent=2)}")
except json.JSONDecodeError as err:
print(f"Failed to parse metadata")
return
engine_data = model_bytes[4 + metadata_len:]
runtime = trt.Runtime(logger)
self.engine = runtime.deserialize_cuda_engine(engine_data)
if self.engine is None:
raise RuntimeError(f"Failed to load TensorRT engine from {model_path}")
raise RuntimeError(f"Failed to load TensorRT engine!")
self.context = self.engine.create_execution_context()
# input
self.input_name = self.engine.get_tensor_name(0)
engine_input_shape = self.engine.get_tensor_shape(self.input_name)
if engine_input_shape[0] != -1:
self.batch_size = engine_input_shape[0]
self.input_shape = [
batch_size if engine_input_shape[0] == -1 else engine_input_shape[0],
self.batch_size,
engine_input_shape[1], # Channels (usually fixed at 3 for RGB)
1280 if engine_input_shape[2] == -1 else engine_input_shape[2], # Height
1280 if engine_input_shape[3] == -1 else engine_input_shape[3] # Width