mirror of
https://github.com/azaion/detections.git
synced 2026-04-23 05:06:33 +00:00
[AZ-180] Refactor inference and engine factory for improved model handling
- Updated the autopilot state to reflect the current task status as in progress.
- Refactored the inference module to streamline model downloading and conversion processes, replacing the download_model method with a more efficient load_source method.
- Introduced asynchronous model building in the inference module to enhance performance during model conversion.
- Enhanced the engine factory to include a new method for building and caching models, improving error handling and logging during the upload process.
- Added calibration cache handling in the Jetson TensorRT engine for better resource management.

Made-with: Cursor
This commit is contained in:
@@ -4,8 +4,8 @@
|
|||||||
flow: existing-code
|
flow: existing-code
|
||||||
step: 8
|
step: 8
|
||||||
name: New Task
|
name: New Task
|
||||||
status: not_started
|
status: in_progress
|
||||||
sub_step: 0
|
sub_step: 1 — Gather Feature Description
|
||||||
retry_count: 0
|
retry_count: 0
|
||||||
|
|
||||||
## Cycle Notes
|
## Cycle Notes
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from loader_http_client cimport LoaderHttpClient, LoadResult
|
from loader_http_client cimport LoaderHttpClient, LoadResult
|
||||||
|
|
||||||
|
|
||||||
@@ -29,9 +27,28 @@ class EngineFactory:
|
|||||||
def get_source_filename(self):
|
def get_source_filename(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def load_source(self, LoaderHttpClient loader_client, str models_dir):
|
||||||
|
cdef LoadResult res
|
||||||
|
filename = self.get_source_filename()
|
||||||
|
if filename is None:
|
||||||
|
return None
|
||||||
|
res = loader_client.load_big_small_resource(filename, models_dir)
|
||||||
|
if res.err is not None:
|
||||||
|
raise Exception(res.err)
|
||||||
|
return res.data
|
||||||
|
|
||||||
def build_from_source(self, onnx_bytes, loader_client, models_dir):
|
def build_from_source(self, onnx_bytes, loader_client, models_dir):
|
||||||
raise NotImplementedError(f"{type(self).__name__} does not support building from source")
|
raise NotImplementedError(f"{type(self).__name__} does not support building from source")
|
||||||
|
|
||||||
|
def build_and_cache(self, bytes source_bytes, LoaderHttpClient loader_client, str models_dir):
|
||||||
|
cdef LoadResult res
|
||||||
|
import constants_inf
|
||||||
|
engine_bytes, engine_filename = self.build_from_source(source_bytes, loader_client, models_dir)
|
||||||
|
res = loader_client.upload_big_small_resource(engine_bytes, engine_filename, models_dir)
|
||||||
|
if res.err is not None:
|
||||||
|
constants_inf.log(f"Failed to upload converted model: {res.err}")
|
||||||
|
return engine_bytes
|
||||||
|
|
||||||
|
|
||||||
class OnnxEngineFactory(EngineFactory):
|
class OnnxEngineFactory(EngineFactory):
|
||||||
def create(self, model_bytes: bytes):
|
def create(self, model_bytes: bytes):
|
||||||
@@ -83,34 +100,7 @@ class JetsonTensorRTEngineFactory(TensorRTEngineFactory):
|
|||||||
return TensorRTEngine.get_engine_filename("int8")
|
return TensorRTEngine.get_engine_filename("int8")
|
||||||
|
|
||||||
def build_from_source(self, onnx_bytes, LoaderHttpClient loader_client, str models_dir):
|
def build_from_source(self, onnx_bytes, LoaderHttpClient loader_client, str models_dir):
|
||||||
cdef str calib_cache_path
|
from engines.jetson_tensorrt_engine import JetsonTensorRTEngine
|
||||||
from engines.tensorrt_engine import TensorRTEngine
|
from engines.tensorrt_engine import TensorRTEngine
|
||||||
calib_cache_path = self._download_calib_cache(loader_client, models_dir)
|
engine_bytes = JetsonTensorRTEngine.convert_from_source(onnx_bytes, loader_client, models_dir)
|
||||||
try:
|
|
||||||
engine_bytes = TensorRTEngine.convert_from_source(onnx_bytes, calib_cache_path)
|
|
||||||
return engine_bytes, TensorRTEngine.get_engine_filename("int8")
|
return engine_bytes, TensorRTEngine.get_engine_filename("int8")
|
||||||
finally:
|
|
||||||
if calib_cache_path is not None:
|
|
||||||
try:
|
|
||||||
os.unlink(calib_cache_path)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _download_calib_cache(self, LoaderHttpClient loader_client, str models_dir):
|
|
||||||
cdef LoadResult res
|
|
||||||
import constants_inf
|
|
||||||
try:
|
|
||||||
res = loader_client.load_big_small_resource(
|
|
||||||
constants_inf.INT8_CALIB_CACHE_FILE, models_dir
|
|
||||||
)
|
|
||||||
if res.err is not None:
|
|
||||||
constants_inf.log(f"INT8 calibration cache not available: {res.err}")
|
|
||||||
return None
|
|
||||||
fd, path = tempfile.mkstemp(suffix=".cache")
|
|
||||||
with os.fdopen(fd, "wb") as f:
|
|
||||||
f.write(res.data)
|
|
||||||
constants_inf.log("INT8 calibration cache downloaded")
|
|
||||||
return path
|
|
||||||
except Exception as e:
|
|
||||||
constants_inf.log(f"INT8 calibration cache download failed: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|||||||
@@ -1,5 +1,39 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
from engines.tensorrt_engine cimport TensorRTEngine
|
from engines.tensorrt_engine cimport TensorRTEngine
|
||||||
|
from loader_http_client cimport LoaderHttpClient, LoadResult
|
||||||
|
|
||||||
|
|
||||||
cdef class JetsonTensorRTEngine(TensorRTEngine):
|
cdef class JetsonTensorRTEngine(TensorRTEngine):
|
||||||
|
@staticmethod
|
||||||
|
def convert_from_source(bytes onnx_model, LoaderHttpClient loader_client, str models_dir):
|
||||||
|
cdef str calib_cache_path
|
||||||
|
calib_cache_path = JetsonTensorRTEngine._download_calib_cache(loader_client, models_dir)
|
||||||
|
try:
|
||||||
|
return TensorRTEngine.convert_from_source(onnx_model, calib_cache_path)
|
||||||
|
finally:
|
||||||
|
if calib_cache_path is not None:
|
||||||
|
try:
|
||||||
|
os.unlink(calib_cache_path)
|
||||||
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _download_calib_cache(LoaderHttpClient loader_client, str models_dir):
|
||||||
|
cdef LoadResult res
|
||||||
|
import constants_inf
|
||||||
|
try:
|
||||||
|
res = loader_client.load_big_small_resource(
|
||||||
|
constants_inf.INT8_CALIB_CACHE_FILE, models_dir
|
||||||
|
)
|
||||||
|
if res.err is not None:
|
||||||
|
constants_inf.log(f"INT8 calibration cache not available: {res.err}")
|
||||||
|
return None
|
||||||
|
fd, path = tempfile.mkstemp(suffix=".cache")
|
||||||
|
with os.fdopen(fd, "wb") as f:
|
||||||
|
f.write(res.data)
|
||||||
|
constants_inf.log("INT8 calibration cache downloaded")
|
||||||
|
return path
|
||||||
|
except Exception as e:
|
||||||
|
constants_inf.log(f"INT8 calibration cache download failed: {str(e)}")
|
||||||
|
return None
|
||||||
|
|||||||
+5
-19
@@ -66,28 +66,14 @@ cdef class Inference:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
cdef bytes download_model(self, str filename):
|
cdef _build_engine_async(self, bytes source_bytes, str models_dir):
|
||||||
models_dir = constants_inf.MODELS_FOLDER
|
|
||||||
self.ai_availability_status.set_status(AIAvailabilityEnum.DOWNLOADING)
|
|
||||||
res = self.loader_client.load_big_small_resource(filename, models_dir)
|
|
||||||
if res.err is not None:
|
|
||||||
raise Exception(res.err)
|
|
||||||
return <bytes>res.data
|
|
||||||
|
|
||||||
cdef convert_and_upload_model(self, bytes source_bytes, str models_dir):
|
|
||||||
try:
|
try:
|
||||||
self.ai_availability_status.set_status(AIAvailabilityEnum.CONVERTING)
|
self.ai_availability_status.set_status(AIAvailabilityEnum.CONVERTING)
|
||||||
engine_bytes, engine_filename = engine_factory.build_from_source(source_bytes, self.loader_client, models_dir)
|
engine_bytes = engine_factory.build_and_cache(source_bytes, self.loader_client, models_dir)
|
||||||
|
|
||||||
self.ai_availability_status.set_status(AIAvailabilityEnum.UPLOADING)
|
|
||||||
res = self.loader_client.upload_big_small_resource(engine_bytes, engine_filename, models_dir)
|
|
||||||
if res.err is not None:
|
|
||||||
self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>f"Failed to upload converted model: {res.err}")
|
|
||||||
|
|
||||||
self._converted_model_bytes = engine_bytes
|
self._converted_model_bytes = engine_bytes
|
||||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
|
self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str> str(e))
|
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>str(e))
|
||||||
self._converted_model_bytes = <bytes>None
|
self._converted_model_bytes = <bytes>None
|
||||||
finally:
|
finally:
|
||||||
self.is_building_engine = <bint>False
|
self.is_building_engine = <bint>False
|
||||||
@@ -123,12 +109,12 @@ cdef class Inference:
|
|||||||
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>"No engine available and no source to build from")
|
self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>"No engine available and no source to build from")
|
||||||
return
|
return
|
||||||
|
|
||||||
source_bytes = self.download_model(source_filename)
|
source_bytes = engine_factory.load_source(self.loader_client, models_dir)
|
||||||
|
|
||||||
if engine_factory.has_build_step:
|
if engine_factory.has_build_step:
|
||||||
self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>"Cached engine not found, converting from source")
|
self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>"Cached engine not found, converting from source")
|
||||||
self.is_building_engine = <bint>True
|
self.is_building_engine = <bint>True
|
||||||
thread = Thread(target=self.convert_and_upload_model, args=(source_bytes, models_dir))
|
thread = Thread(target=self._build_engine_async, args=(source_bytes, models_dir))
|
||||||
thread.daemon = True
|
thread.daemon = True
|
||||||
thread.start()
|
thread.start()
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user