Mirror of https://github.com/azaion/ai-training.git, synced 2026-04-22 15:06:34 +00:00
Refactor constants management to use Pydantic BaseModel for configuration
- Replaced module-level path variables in constants.py with a structured Pydantic Config class.
- Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure.
- Fixed bugs related to image processing and model saving.
- Enhanced test infrastructure to accommodate the new configuration approach.

This refactor improves code maintainability and clarity by centralizing configuration management.
# Module: start_inference
## Purpose
Entry point for running inference on video files using a TensorRT engine. Downloads the encrypted model from the API/CDN, initializes the engine, and processes video.
## Public Interface
| Function | Signature | Returns | Description |
|----------|-----------|---------|-------------|
| `get_engine_filename` | `(device_id=0) -> str \| None` | Engine filename | Generates GPU-specific engine filename (duplicate of TensorRTEngine.get_engine_filename) |
`__main__` block: Creates ApiClient, downloads encrypted TensorRT model (split big/small), initializes TensorRTEngine, runs Inference on a test video.
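The `__main__` orchestration described above can be sketched as follows. This is a self-contained illustration, not the project's code: the `Fake*` classes and the `decrypt` helper are stand-ins for `ApiClient`, `TensorRTEngine`, `Inference`, and the security module, whose exact signatures are assumptions based on this document.

```python
# Self-contained sketch of the start_inference __main__ flow.
# All classes here are stand-ins; only the call order and the
# threshold values mirror what this document describes.

class FakeApiClient:
    """Stand-in for ApiClient: yields the reassembled encrypted engine."""
    def load_big_small_resource(self, name: str) -> bytes:
        big, small = b"BIGPART-", b"SMALLPART"  # local big part + API small part
        return big + small                       # reassembled ciphertext

def decrypt(blob: bytes, key: bytes) -> bytes:
    """Stand-in for decryption with the model encryption key."""
    return blob  # identity here; the real code decrypts

class FakeEngine:
    """Stand-in for TensorRTEngine built from decrypted engine bytes."""
    def __init__(self, engine_bytes: bytes) -> None:
        self.engine_bytes = engine_bytes

class FakeInference:
    """Stand-in for Inference with the thresholds used by start_inference."""
    def __init__(self, engine, confidence_threshold=0.5, iou_threshold=0.3):
        self.engine = engine
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold

    def run(self, video_path: str) -> str:
        return f"processed {video_path}"

def main() -> str:
    client = FakeApiClient()
    raw = decrypt(client.load_big_small_resource("model"), key=b"model-key")
    engine = FakeEngine(raw)
    inference = FakeInference(engine, confidence_threshold=0.5, iou_threshold=0.3)
    return inference.run("tests/ForAI_test.mp4")

print(main())  # -> processed tests/ForAI_test.mp4
```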
## Internal Logic
- **Model download flow**: ApiClient → `load_big_small_resource` → reassembles from local big part + API-downloaded small part → decrypts with model encryption key → raw engine bytes.
- **Inference setup**: TensorRTEngine initialized from decrypted bytes, Inference configured with confidence_threshold=0.5, iou_threshold=0.3.
- **Video source**: Hardcoded to `tests/ForAI_test.mp4`.
- **get_engine_filename()**: Duplicates `TensorRTEngine.get_engine_filename()` — generates `azaion.cc_{major}.{minor}_sm_{sm_count}.engine` based on CUDA device compute capability and SM count.
## Dependencies
- `constants` — config file paths
- `api_client` — ApiClient, ApiCredentials for model download
- `cdn_manager` — CDNManager, CDNCredentials (imported, but the CDN is managed by api_client)
- `inference/inference` — Inference pipeline
- `inference/tensorrt_engine` — TensorRTEngine
- `security` — model encryption key
- `utils` — Dotdict
- `pycuda.driver` (external) — CUDA device queries
- `yaml` (external)
## Consumers
None (entry point).
## Data Models
None.
## Configuration
- Confidence threshold: 0.5
- IoU threshold: 0.3
- Video path: `tests/ForAI_test.mp4` (hardcoded)
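These values are hard-coded in the module. Purely as an illustration of how they could be grouped (the commit centralizes configuration in a Pydantic Config class, but that class covers paths, not these values), here is a minimal sketch using a plain dataclass; the field names are assumptions, not the project's actual schema:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class InferenceConfig:
    """Illustrative grouping of the hard-coded start_inference settings.
    Field names are hypothetical; only the values come from the document."""
    confidence_threshold: float = 0.5
    iou_threshold: float = 0.3
    video_path: str = "tests/ForAI_test.mp4"

cfg = InferenceConfig()
print(cfg.confidence_threshold, cfg.iou_threshold)  # -> 0.5 0.3
```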
## External Integrations
- Azaion API + CDN for model download
- TensorRT GPU inference
- OpenCV video capture and display
## Security
- Model is downloaded encrypted (split big/small) and decrypted locally
- Uses hardware-bound and model encryption keys
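To make the split-and-encrypt scheme concrete, here is a minimal round-trip sketch. The SHA-256 counter-mode keystream is a toy stand-in for the real cipher, and the split point, key handling, and function names are all assumptions for illustration:

```python
import hashlib
from itertools import count

def _keystream(key: bytes, n: int) -> bytes:
    """Toy keystream (SHA-256 in counter mode) standing in for the real cipher."""
    out = b""
    for i in count():
        if len(out) >= n:
            break
        out += hashlib.sha256(key + i.to_bytes(8, "big")).digest()
    return out[:n]

def xor_cipher(data: bytes, key: bytes) -> bytes:
    """Symmetric XOR-with-keystream: the same call encrypts and decrypts."""
    return bytes(a ^ b for a, b in zip(data, _keystream(key, len(data))))

def split_big_small(blob: bytes, small_size: int = 4):
    """Split ciphertext into a big local part and a small API-hosted part."""
    return blob[:-small_size], blob[-small_size:]

# Round trip: encrypt -> split -> reassemble -> decrypt.
engine_bytes = b"raw-tensorrt-engine-bytes"
key = b"model-encryption-key"   # the real key comes from the security module
big, small = split_big_small(xor_cipher(engine_bytes, key))
recovered = xor_cipher(big + small, key)
print(recovered == engine_bytes)  # -> True
```

The design point this illustrates: neither part alone decrypts to anything useful, so the big part can sit on the CDN/local disk while the small part gates access through the API.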
## Tests
None.