mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-23 00:16:35 +00:00
Refactor constants management to use Pydantic BaseModel for configuration
- Replaced module-level path variables in constants.py with a structured Pydantic Config class. - Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure. - Fixed bugs related to image processing and model saving. - Enhanced test infrastructure to accommodate the new configuration approach. This refactor improves code maintainability and clarity by centralizing configuration management.
This commit is contained in:
@@ -0,0 +1,85 @@
|
||||
# Component: Inference Engine
|
||||
|
||||
## Overview
|
||||
Real-time object detection inference subsystem supporting ONNX Runtime and TensorRT backends. Processes video streams with batched inference, custom NMS, and live visualization.
|
||||
|
||||
**Pattern**: Strategy pattern (InferenceEngine ABC) + pipeline orchestrator
|
||||
**Upstream**: Core, Security, API & CDN (for model download)
|
||||
**Downstream**: None (end-user facing — processes video input)
|
||||
|
||||
## Modules
|
||||
- `inference/dto` — Detection, Annotation, AnnotationClass data classes
|
||||
- `inference/onnx_engine` — InferenceEngine ABC + OnnxEngine implementation
|
||||
- `inference/tensorrt_engine` — TensorRTEngine implementation with CUDA memory management + ONNX converter
|
||||
- `inference/inference` — Video processing pipeline (preprocess → infer → postprocess → draw)
|
||||
- `start_inference` — Entry point: downloads model, initializes engine, runs on video
|
||||
|
||||
## Internal Interfaces
|
||||
|
||||
### InferenceEngine (ABC)
|
||||
```python
|
||||
InferenceEngine.__init__(model_path: str, batch_size: int = 1, **kwargs)
|
||||
InferenceEngine.get_input_shape() -> Tuple[int, int]
|
||||
InferenceEngine.get_batch_size() -> int
|
||||
InferenceEngine.run(input_data: np.ndarray) -> List[np.ndarray]
|
||||
```
|
||||
|
||||
### OnnxEngine (extends InferenceEngine)
|
||||
Constructor takes `model_bytes` (not path). Uses CUDAExecutionProvider + CPUExecutionProvider.
|
||||
|
||||
### TensorRTEngine (extends InferenceEngine)
|
||||
Constructor takes `model_bytes: bytes`. Additional static methods:
|
||||
```python
|
||||
TensorRTEngine.get_gpu_memory_bytes(device_id=0) -> int
|
||||
TensorRTEngine.get_engine_filename(device_id=0) -> str | None
|
||||
TensorRTEngine.convert_from_onnx(onnx_model: bytes) -> bytes | None
|
||||
```
|
||||
|
||||
### Inference
|
||||
```python
|
||||
Inference(engine: InferenceEngine, confidence_threshold, iou_threshold)
|
||||
Inference.preprocess(frames: list) -> np.ndarray
|
||||
Inference.postprocess(batch_frames, batch_timestamps, output) -> list[Annotation]
|
||||
Inference.process(video: str) -> None
|
||||
Inference.draw(annotation: Annotation) -> None
|
||||
Inference.remove_overlapping_detections(detections) -> list[Detection]
|
||||
```
|
||||
|
||||
## Data Access Patterns
|
||||
- Model bytes loaded by caller (start_inference via ApiClient.load_big_small_resource)
|
||||
- Video input via cv2.VideoCapture (file path)
|
||||
- No disk writes during inference
|
||||
|
||||
## Implementation Details
|
||||
- **Video processing**: Every 4th frame processed (25% frame sampling), batched to engine batch size
|
||||
- **Preprocessing**: cv2.dnn.blobFromImage (1/255 scale, model input size, BGR→RGB)
|
||||
- **Postprocessing**: Raw detections filtered by confidence, coordinates normalized to [0,1], custom NMS applied
|
||||
- **Custom NMS**: Pairwise IoU comparison. Keeps higher confidence; ties broken by lower class ID.
|
||||
- **TensorRT**: Async CUDA execution (memcpy_htod_async → execute_async_v3 → synchronize → memcpy_dtoh)
|
||||
- **TensorRT shapes**: Default 1280×1280 input, 300 max detections, 6 values per detection (x1,y1,x2,y2,conf,cls)
|
||||
- **ONNX conversion**: TensorRT builder with 90% GPU memory workspace, FP16 if supported
|
||||
- **Engine filename**: GPU-architecture-specific: `azaion.cc_{major}.{minor}_sm_{sm_count}.engine`
|
||||
- **start_inference flow**: ApiClient → load encrypted TensorRT model (big/small split) → decrypt → TensorRTEngine → Inference.process()
|
||||
|
||||
## Caveats
|
||||
- `start_inference.get_engine_filename()` duplicates `TensorRTEngine.get_engine_filename()`
|
||||
- Video path hardcoded in `start_inference` (`tests/ForAI_test.mp4`)
|
||||
- `inference/dto` has its own AnnotationClass — duplicated from `dto/annotationClass`
|
||||
- cv2.imshow display requires a GUI environment — won't work headless
|
||||
- TensorRT `batch_size` attribute used before assignment if engine input shape has dynamic batch — potential NameError
|
||||
|
||||
## Dependency Graph
|
||||
```mermaid
|
||||
graph TD
|
||||
inference_dto[inference/dto] --> inference_inference[inference/inference]
|
||||
inference_onnx[inference/onnx_engine] --> inference_inference
|
||||
inference_onnx --> inference_trt[inference/tensorrt_engine]
|
||||
inference_trt --> start_inference
|
||||
inference_inference --> start_inference
|
||||
constants --> start_inference
|
||||
api_client --> start_inference
|
||||
security --> start_inference
|
||||
```
|
||||
|
||||
## Logging Strategy
|
||||
Print statements for metadata, download progress, timing. cv2.imshow for visual output.
|
||||
Reference in New Issue
Block a user