Component: Inference Engine
Overview
Real-time object detection inference subsystem supporting ONNX Runtime and TensorRT backends. Processes video streams with batched inference, custom NMS, and live visualization.
Pattern: Strategy pattern (InferenceEngine ABC) + pipeline orchestrator
Upstream: Core, Security, API & CDN (for model download)
Downstream: None (end-user facing; processes video input)
Modules
- inference/dto: Detection, Annotation, AnnotationClass data classes
- inference/onnx_engine: InferenceEngine ABC + OnnxEngine implementation
- inference/tensorrt_engine: TensorRTEngine implementation with CUDA memory management + ONNX converter
- inference/inference: Video processing pipeline (preprocess → infer → postprocess → draw)
- start_inference: Entry point that downloads the model, initializes the engine, and runs it on a video
Internal Interfaces
InferenceEngine (ABC)
InferenceEngine.__init__(model_path: str, batch_size: int = 1, **kwargs)
InferenceEngine.get_input_shape() -> Tuple[int, int]
InferenceEngine.get_batch_size() -> int
InferenceEngine.run(input_data: np.ndarray) -> List[np.ndarray]
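A minimal sketch of the strategy interface above, using only what the listed signatures state. FakeEngine and describe() are hypothetical illustrations, not part of the module. Treating get_input_shape() as (height, width) is an assumption, and the real backends return NumPy arrays (and, per the notes below, take model bytes rather than a path), which this stdlib-only sketch replaces with plain lists.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Tuple


class InferenceEngine(ABC):
    """Strategy interface: each backend implements the same calls."""

    def __init__(self, model_path: str, batch_size: int = 1, **kwargs: Any) -> None:
        self.model_path = model_path
        self.batch_size = batch_size

    @abstractmethod
    def get_input_shape(self) -> Tuple[int, int]:
        """Return the model input size, assumed here to be (height, width)."""

    @abstractmethod
    def get_batch_size(self) -> int:
        """Return the number of frames per run() call."""

    @abstractmethod
    def run(self, input_data: Any) -> List[Any]:
        """Run one preprocessed batch and return the raw output tensors."""


class FakeEngine(InferenceEngine):
    """Hypothetical backend used only to show the contract; reports no detections."""

    def __init__(self, model_path: str, batch_size: int = 1, **kwargs: Any) -> None:
        super().__init__(model_path, batch_size, **kwargs)
        self.input_shape: Tuple[int, int] = kwargs.get("input_shape", (1280, 1280))

    def get_input_shape(self) -> Tuple[int, int]:
        return self.input_shape

    def get_batch_size(self) -> int:
        return self.batch_size

    def run(self, input_data: Any) -> List[Any]:
        # A single output whose first axis is the batch; every frame has zero detections.
        return [[[] for _ in input_data]]


def describe(engine: InferenceEngine) -> str:
    # Caller code depends only on the ABC, so backends can be swapped freely.
    height, width = engine.get_input_shape()
    return f"{type(engine).__name__}: {width}x{height}, batch={engine.get_batch_size()}"
```

Because Inference only calls these methods, OnnxEngine and TensorRTEngine can be exchanged without touching the pipeline code.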
OnnxEngine (extends InferenceEngine)
Constructor takes model_bytes rather than the model_path declared on the ABC. Uses CUDAExecutionProvider + CPUExecutionProvider.
TensorRTEngine (extends InferenceEngine)
Constructor takes model_bytes: bytes. Additional static methods:
TensorRTEngine.get_gpu_memory_bytes(device_id=0) -> int
TensorRTEngine.get_engine_filename(device_id=0) -> str | None
TensorRTEngine.convert_from_onnx(onnx_model: bytes) -> bytes | None
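The engine file naming (azaion.cc_{major}.{minor}_sm_{sm_count}.engine) and the 90% GPU-memory builder workspace described for these helpers can be expressed as pure functions once the device properties are known. This is a sketch: engine_filename(), engine_filename_from_props(), builder_workspace_bytes() and the props dictionary are hypothetical, the real static methods query the GPU themselves, and sm_count is presumably the streaming-multiprocessor count.

```python
from typing import Mapping, Optional


def engine_filename(major: int, minor: int, sm_count: int) -> str:
    """GPU-architecture-specific engine name: compute capability plus SM count."""
    return f"azaion.cc_{major}.{minor}_sm_{sm_count}.engine"


def engine_filename_from_props(props: Optional[Mapping[str, int]]) -> Optional[str]:
    # Mirrors the `str | None` return type: no device information means no file name.
    if props is None:
        return None
    return engine_filename(props["major"], props["minor"], props["sm_count"])


def builder_workspace_bytes(gpu_memory_bytes: int, fraction: float = 0.9) -> int:
    """Workspace limit for the TensorRT builder: 90% of total GPU memory."""
    return int(gpu_memory_bytes * fraction)
```

In TensorRT 8.4 and later, a limit like this is typically applied with config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, ...); which call convert_from_onnx() actually uses is not stated here.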
Inference
Inference(engine: InferenceEngine, confidence_threshold, iou_threshold)
Inference.preprocess(frames: list) -> np.ndarray
Inference.postprocess(batch_frames, batch_timestamps, output) -> list[Annotation]
Inference.process(video: str) -> None
Inference.draw(annotation: Annotation) -> None
Inference.remove_overlapping_detections(detections) -> list[Detection]
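Inference.process() keeps every 4th frame of the video and groups the kept frames to the engine batch size before preprocessing and inference. The generator below sketches only that sampling and batching step. It assumes the first kept frame is index 0 and that a short final batch is still processed rather than dropped; neither detail is confirmed by the source, and sample_and_batch() and read_frames() are hypothetical names.

```python
from typing import Any, Iterable, Iterator, List, Tuple, TypeVar

Frame = TypeVar("Frame")


def sample_and_batch(
    frames: Iterable[Frame], batch_size: int, step: int = 4
) -> Iterator[Tuple[List[int], List[Frame]]]:
    """Keep every `step`-th frame and yield (frame_indices, frames) batches."""
    indices: List[int] = []
    batch: List[Frame] = []
    for index, frame in enumerate(frames):
        if index % step != 0:
            continue  # with step=4, 3 of every 4 frames are skipped (25% sampling)
        indices.append(index)
        batch.append(frame)
        if len(batch) == batch_size:
            yield indices, batch
            indices, batch = [], []
    if batch:
        yield indices, batch  # assumption: the short final batch is still processed


def read_frames(path: str) -> Iterator[Any]:
    """Yield BGR frames from a video file via cv2.VideoCapture (requires OpenCV)."""
    import cv2  # imported lazily so sample_and_batch() works without OpenCV

    capture = cv2.VideoCapture(path)
    try:
        while True:
            ok, frame = capture.read()
            if not ok:
                break
            yield frame
    finally:
        capture.release()
```

For example, `sample_and_batch(read_frames("tests/ForAI_test.mp4"), engine.get_batch_size())` would yield batches ready for preprocess(); the frame indices stand in for the timestamps that postprocess() receives.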
Data Access Patterns
- Model bytes loaded by caller (start_inference via ApiClient.load_big_small_resource)
- Video input via cv2.VideoCapture (file path)
- No disk writes during inference
Implementation Details
- Video processing: Every 4th frame processed (25% frame sampling), batched to engine batch size
- Preprocessing: cv2.dnn.blobFromImage (1/255 scale, model input size, BGR→RGB)
- Postprocessing: Raw detections filtered by confidence, coordinates normalized to [0,1], custom NMS applied
- Custom NMS: Pairwise IoU comparison. Keeps higher confidence; ties broken by lower class ID.
- TensorRT: Async CUDA execution (memcpy_htod_async → execute_async_v3 → synchronize → memcpy_dtoh)
- TensorRT shapes: Default 1280×1280 input, 300 max detections, 6 values per detection (x1,y1,x2,y2,conf,cls)
- ONNX conversion: TensorRT builder with 90% GPU memory workspace, FP16 if supported
- Engine filename: GPU-architecture-specific, azaion.cc_{major}.{minor}_sm_{sm_count}.engine
- start_inference flow: ApiClient → load encrypted TensorRT model (big/small split) → decrypt → TensorRTEngine → Inference.process()
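The postprocessing steps listed in this section (confidence filter, normalization to [0,1], custom NMS that keeps the higher confidence and breaks ties by lower class ID) can be sketched for one frame's output of up to 300 rows of (x1, y1, x2, y2, conf, cls). This assumes the coordinates are in model-input pixels, that the confidence test is inclusive, and that suppression applies when IoU exceeds iou_threshold. The Detection fields and function names are hypothetical stand-ins for inference/dto and Inference.remove_overlapping_detections(), not the original code.

```python
from dataclasses import dataclass
from typing import List, Sequence, Set


@dataclass
class Detection:
    # Hypothetical field names; inference/dto.Detection may differ.
    x1: float
    y1: float
    x2: float
    y2: float
    confidence: float
    class_id: int


def iou(a: Detection, b: Detection) -> float:
    """Intersection over union of two axis-aligned boxes."""
    inter_w = max(0.0, min(a.x2, b.x2) - max(a.x1, b.x1))
    inter_h = max(0.0, min(a.y2, b.y2) - max(a.y1, b.y1))
    inter = inter_w * inter_h
    union = (a.x2 - a.x1) * (a.y2 - a.y1) + (b.x2 - b.x1) * (b.y2 - b.y1) - inter
    return inter / union if union > 0 else 0.0


def remove_overlapping_detections(dets: List[Detection], iou_threshold: float) -> List[Detection]:
    """Pairwise suppression: of two boxes with IoU above the threshold, keep the
    higher-confidence one; on equal confidence keep the lower class ID."""
    removed: Set[int] = set()
    for i, a in enumerate(dets):
        if i in removed:
            continue
        for j in range(i + 1, len(dets)):
            if j in removed or iou(a, dets[j]) <= iou_threshold:
                continue
            b = dets[j]
            if (a.confidence, -a.class_id) >= (b.confidence, -b.class_id):
                removed.add(j)
            else:
                removed.add(i)  # a loses, so stop comparing it
                break
    return [d for k, d in enumerate(dets) if k not in removed]


def postprocess_frame(
    rows: Sequence[Sequence[float]],
    input_w: int,
    input_h: int,
    confidence_threshold: float,
    iou_threshold: float,
) -> List[Detection]:
    """rows: up to 300 (x1, y1, x2, y2, conf, cls) tuples in input-pixel units."""
    kept = [
        Detection(x1 / input_w, y1 / input_h, x2 / input_w, y2 / input_h, conf, int(cls))
        for x1, y1, x2, y2, conf, cls in rows
        if conf >= confidence_threshold  # unused slots (assumed conf 0) fall below any positive threshold
    ]
    return remove_overlapping_detections(kept, iou_threshold)
```

Unlike classic NMS, this pairwise pass does not sort first, so chains of overlaps can resolve differently depending on input order; sorting `kept` by descending (confidence, -class_id) before suppression would make it equivalent to standard greedy NMS.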
Caveats
- start_inference.get_engine_filename() duplicates TensorRTEngine.get_engine_filename()
- Video path hardcoded in start_inference (tests/ForAI_test.mp4)
- inference/dto has its own AnnotationClass, duplicated from dto/annotationClass
- cv2.imshow display requires a GUI environment; it won't work headless
- TensorRT batch_size attribute used before assignment if the engine input shape has a dynamic batch dimension (potential NameError)
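One way to rule out the batch_size caveat is to bind the value on every branch when reading the engine's input shape. The sketch below assumes an NCHW input whose dynamic dimensions are reported as -1 (the way TensorRT marks them) and falls back to the documented 1280×1280 default; resolve_input_dims() is a hypothetical helper, not the existing TensorRTEngine code.

```python
from typing import Sequence, Tuple

DYNAMIC = -1  # TensorRT reports dynamic dimensions as -1


def resolve_input_dims(
    engine_shape: Sequence[int],
    requested_batch: int = 1,
    default_hw: Tuple[int, int] = (1280, 1280),
) -> Tuple[int, int, int]:
    """Return (batch_size, height, width) for an NCHW input shape.

    batch_size is bound on both branches, so it can never be read before assignment.
    """
    batch, _channels, height, width = engine_shape
    batch_size = requested_batch if batch == DYNAMIC else batch
    if DYNAMIC in (height, width):
        height, width = default_hw
    return batch_size, height, width
```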
Dependency Graph
graph TD
inference_dto[inference/dto] --> inference_inference[inference/inference]
inference_onnx[inference/onnx_engine] --> inference_inference
inference_onnx --> inference_trt[inference/tensorrt_engine]
inference_trt --> start_inference
inference_inference --> start_inference
constants --> start_inference
api_client --> start_inference
security --> start_inference
Logging Strategy
Print statements for metadata, download progress, timing. cv2.imshow for visual output.