mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 08:56:32 +00:00
[AZ-180] Refactor inference and engine factory for improved model handling
- Updated the autopilot state to reflect the current task status as in progress.
- Refactored the inference module to streamline model downloading and conversion, replacing the `download_model` method with a more efficient `load_source` method.
- Introduced asynchronous model building in the inference module to improve performance during model conversion.
- Enhanced the engine factory with a new method for building and caching models, improving error handling and logging during upload.
- Added calibration-cache handling in the Jetson TensorRT engine for better resource management.

Made-with: Cursor
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
import os
|
||||
import tempfile
|
||||
from loader_http_client cimport LoaderHttpClient, LoadResult
|
||||
|
||||
|
||||
@@ -29,9 +27,28 @@ class EngineFactory:
|
||||
def get_source_filename(self):
    """Return the remote filename of this factory's source artifact.

    Base implementation returns None, meaning the factory has no
    downloadable source; subclasses override to supply a filename
    for load_source() to fetch.
    """
    return None
|
||||
|
||||
def load_source(self, LoaderHttpClient loader_client, str models_dir):
    """Download this factory's source artifact via the loader client.

    Args:
        loader_client: client used to fetch the resource.
        models_dir: local directory the resource is loaded into.

    Returns:
        The raw resource bytes, or None when get_source_filename()
        reports that this factory has no source artifact.

    Raises:
        RuntimeError: if the loader reports an error. (RuntimeError is a
        subclass of Exception, so existing callers catching Exception
        still work; it replaces the previous bare `Exception` per
        standard exception-hygiene guidance.)
    """
    cdef LoadResult res
    filename = self.get_source_filename()
    if filename is None:
        # No source artifact for this engine type — nothing to load.
        return None
    res = loader_client.load_big_small_resource(filename, models_dir)
    if res.err is not None:
        raise RuntimeError(res.err)
    return res.data
|
||||
|
||||
def build_from_source(self, onnx_bytes, loader_client, models_dir):
    """Convert source bytes (e.g. ONNX) into engine bytes.

    Base implementation is abstract-by-convention: subclasses that
    support source conversion override this; others inherit the raise.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError(f"{type(self).__name__} does not support building from source")
|
||||
|
||||
def build_and_cache(self, bytes source_bytes, LoaderHttpClient loader_client, str models_dir):
    """Build an engine from source bytes, then best-effort upload it.

    The converted engine is uploaded so later runs can skip the (slow)
    conversion step. An upload failure is only logged — the locally
    built engine is still returned so inference can proceed.

    Returns:
        The converted engine bytes.
    """
    cdef LoadResult res
    import constants_inf

    engine_bytes, engine_filename = self.build_from_source(
        source_bytes, loader_client, models_dir
    )

    res = loader_client.upload_big_small_resource(
        engine_bytes, engine_filename, models_dir
    )
    if res.err is not None:
        # Caching is opportunistic; a failed upload must not block inference.
        constants_inf.log(f"Failed to upload converted model: {res.err}")

    return engine_bytes
|
||||
|
||||
|
||||
class OnnxEngineFactory(EngineFactory):
|
||||
def create(self, model_bytes: bytes):
|
||||
@@ -83,34 +100,7 @@ class JetsonTensorRTEngineFactory(TensorRTEngineFactory):
|
||||
return TensorRTEngine.get_engine_filename("int8")
|
||||
|
||||
def build_from_source(self, onnx_bytes, LoaderHttpClient loader_client, str models_dir):
    """Convert ONNX bytes into an INT8 TensorRT engine on Jetson.

    Fetches the optional INT8 calibration cache first, runs the
    conversion, and always removes the temporary cache file afterwards.

    Returns:
        Tuple of (engine_bytes, engine_filename).
    """
    cdef str calib_cache_path
    from engines.jetson_tensorrt_engine import JetsonTensorRTEngine
    from engines.tensorrt_engine import TensorRTEngine

    # May be None when no calibration cache is available (it is optional).
    calib_cache_path = self._download_calib_cache(loader_client, models_dir)
    try:
        engine_bytes = TensorRTEngine.convert_from_source(onnx_bytes, calib_cache_path)
        return engine_bytes, TensorRTEngine.get_engine_filename("int8")
    finally:
        if calib_cache_path is not None:
            try:
                os.unlink(calib_cache_path)
            except OSError:
                # Best-effort cleanup of the temp file; os.unlink only
                # raises OSError, so don't swallow unrelated errors.
                pass
|
||||
|
||||
def _download_calib_cache(self, LoaderHttpClient loader_client, str models_dir):
    """Download the INT8 calibration cache into a temporary file.

    The cache is optional, so every failure path logs and returns None
    rather than raising — conversion then proceeds without calibration.

    Returns:
        Path of the temp file holding the cache, or None when the cache
        is unavailable or the download fails. The caller is responsible
        for deleting the returned file.
    """
    cdef LoadResult res
    import constants_inf
    try:
        res = loader_client.load_big_small_resource(
            constants_inf.INT8_CALIB_CACHE_FILE, models_dir
        )
        if res.err is not None:
            constants_inf.log(f"INT8 calibration cache not available: {res.err}")
            return None
        fd, path = tempfile.mkstemp(suffix=".cache")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(res.data)
        except Exception:
            # Fixed: previously a failed write leaked the temp file on
            # disk (the outer handler returned None without unlinking).
            try:
                os.unlink(path)
            except OSError:
                pass
            raise
        constants_inf.log("INT8 calibration cache downloaded")
        return path
    except Exception as e:
        # Boundary catch is deliberate: calibration is best-effort.
        constants_inf.log(f"INT8 calibration cache download failed: {str(e)}")
        return None
|
||||
engine_bytes = JetsonTensorRTEngine.convert_from_source(onnx_bytes, loader_client, models_dir)
|
||||
return engine_bytes, TensorRTEngine.get_engine_filename("int8")
|
||||
|
||||
Reference in New Issue
Block a user