mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 12:56:35 +00:00
Refactor constants management to use Pydantic BaseModel for configuration
- Replaced module-level path variables in constants.py with a structured Pydantic Config class.
- Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure.
- Fixed bugs related to image processing and model saving.
- Enhanced test infrastructure to accommodate the new configuration approach.

This refactor improves code maintainability and clarity by centralizing configuration management.
@@ -0,0 +1,50 @@
# Module: inference/onnx_engine
## Purpose
Defines the abstract `InferenceEngine` base class and the `OnnxEngine` implementation for running ONNX model inference with GPU acceleration.
## Public Interface
### InferenceEngine (ABC)
| Method | Signature | Description |
|--------|-----------|-------------|
| `__init__` | `(model_path: str, batch_size: int = 1, **kwargs)` | Abstract constructor |
| `get_input_shape` | `() -> Tuple[int, int]` | Returns (height, width) of the model input |
| `get_batch_size` | `() -> int` | Returns the batch size |
| `run` | `(input_data: np.ndarray) -> List[np.ndarray]` | Runs inference, returns output tensors |
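
The abstract interface in the table above can be sketched as a Python ABC. The method names and signatures follow the table; the docstrings are illustrative, not taken from the source:

```python
from abc import ABC, abstractmethod
from typing import List, Tuple

import numpy as np


class InferenceEngine(ABC):
    """Abstract base class for inference engines (sketch based on the table above)."""

    @abstractmethod
    def __init__(self, model_path: str, batch_size: int = 1, **kwargs):
        ...

    @abstractmethod
    def get_input_shape(self) -> Tuple[int, int]:
        """Return (height, width) of the model input."""

    @abstractmethod
    def get_batch_size(self) -> int:
        """Return the batch size."""

    @abstractmethod
    def run(self, input_data: np.ndarray) -> List[np.ndarray]:
        """Run inference and return the output tensors."""
```

Because every method is abstract, instantiating `InferenceEngine` directly raises `TypeError`; concrete engines such as `OnnxEngine` must implement all four methods.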
### OnnxEngine (extends InferenceEngine)
| Method | Signature | Description |
|--------|-----------|-------------|
| `__init__` | `(model_bytes, batch_size: int = 1, **kwargs)` | Loads ONNX model from bytes, creates InferenceSession with CUDA+CPU providers |
| `get_input_shape` | `() -> Tuple[int, int]` | Returns (height, width) from the model input shape |
| `get_batch_size` | `() -> int` | Returns the batch size (from the model shape or the constructor argument) |
| `run` | `(input_data: np.ndarray) -> List[np.ndarray]` | Runs the ONNX inference session |
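
`get_input_shape` reads (height, width) out of the ONNX model's input shape. Assuming the typical NCHW layout for vision models, the extraction can be sketched as a standalone helper (the name `input_hw` is illustrative; ONNX dims may be ints or symbolic strings):

```python
from typing import List, Tuple, Union

Dim = Union[int, str]  # ONNX shape dims are ints, or symbolic names like "batch"


def input_hw(shape: List[Dim]) -> Tuple[int, int]:
    """Extract (height, width) from an NCHW input shape such as [1, 3, 640, 640]."""
    h, w = shape[-2], shape[-1]
    if not (isinstance(h, int) and isinstance(w, int)):
        # spatial dims left dynamic in the model cannot be resolved here
        raise ValueError(f"Model input has dynamic spatial dims: {shape}")
    return h, w
```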
## Internal Logic
- Uses ONNX Runtime with `CUDAExecutionProvider` (primary) and `CPUExecutionProvider` (fallback).
- Reads model metadata to extract class names from the custom metadata map.
- If the model input shape has a fixed batch dimension (not -1), it overrides the constructor `batch_size`.
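
The provider ordering and the batch-dimension override described above can be sketched as follows; `resolve_batch_size` is a hypothetical helper name, not taken from the source:

```python
from typing import Optional, Union

# ONNX Runtime tries providers in order: CUDA first, CPU as fallback.
PROVIDERS = ["CUDAExecutionProvider", "CPUExecutionProvider"]


def resolve_batch_size(model_batch_dim: Optional[Union[int, str]], requested: int) -> int:
    """A fixed batch dim in the model wins; dynamic dims (-1 or symbolic) keep the requested size."""
    if isinstance(model_batch_dim, int) and model_batch_dim > 0:
        return model_batch_dim
    return requested
```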
## Dependencies
- `onnxruntime` (external) — ONNX inference runtime
- `numpy` (external)
- `abc`, `typing` (stdlib)
## Consumers
`inference/inference`, `inference/tensorrt_engine` (inherits `InferenceEngine`), `train` (imports `OnnxEngine`)
## Data Models
None.
## Configuration
None.
## External Integrations
- ONNX Runtime GPU execution (CUDA)
## Security
None.
## Tests
None.