Mirror of https://github.com/azaion/ai-training.git (synced 2026-04-22 21:56:36 +00:00)
Module: inference/onnx_engine
Purpose
Defines the abstract InferenceEngine base class and the OnnxEngine implementation for running ONNX model inference with GPU acceleration.
Public Interface
InferenceEngine (ABC)
| Method | Signature | Description |
|---|---|---|
| `__init__` | `(model_path: str, batch_size: int = 1, **kwargs)` | Abstract constructor |
| `get_input_shape` | `() -> Tuple[int, int]` | Returns `(height, width)` of the model input |
| `get_batch_size` | `() -> int` | Returns the batch size |
| `run` | `(input_data: np.ndarray) -> List[np.ndarray]` | Runs inference and returns the output tensors |
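From the interface above, the abstract base class presumably looks like the following. This is a minimal sketch reconstructed from the method table; the docstrings and the exact method bodies are assumptions, not the module's actual source:

```python
from abc import ABC, abstractmethod
from typing import List, Tuple

import numpy as np


class InferenceEngine(ABC):
    """Abstract base for inference backends (e.g. ONNX, TensorRT)."""

    @abstractmethod
    def __init__(self, model_path: str, batch_size: int = 1, **kwargs):
        ...

    @abstractmethod
    def get_input_shape(self) -> Tuple[int, int]:
        """Return (height, width) of the model input."""

    @abstractmethod
    def get_batch_size(self) -> int:
        """Return the batch size used for inference."""

    @abstractmethod
    def run(self, input_data: np.ndarray) -> List[np.ndarray]:
        """Run inference and return the raw output tensors."""
```

Concrete engines such as `OnnxEngine` (below) and `inference/tensorrt_engine` implement all four methods; the ABC itself cannot be instantiated.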
OnnxEngine (extends InferenceEngine)
| Method | Signature | Description |
|---|---|---|
| `__init__` | `(model_bytes, batch_size: int = 1, **kwargs)` | Loads the ONNX model from bytes and creates an `InferenceSession` with CUDA and CPU providers |
| `get_input_shape` | `() -> Tuple[int, int]` | Returns `(height, width)` from the model input shape |
| `get_batch_size` | `() -> int` | Returns the batch size (from the model shape or the constructor argument) |
| `run` | `(input_data: np.ndarray) -> List[np.ndarray]` | Runs the ONNX inference session |
Internal Logic
- Uses ONNX Runtime with `CUDAExecutionProvider` (primary) and `CPUExecutionProvider` (fallback).
- Reads model metadata to extract class names from the custom metadata map.
- If the model input shape has a fixed batch dimension (not -1), it overrides the constructor `batch_size`.
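The last two points can be sketched as small pure helpers. The function names (`resolve_batch_size`, `class_names_from_metadata`), the `"names"` metadata key, and the comma-separated encoding of class names are illustrative assumptions, not confirmed by the module's source:

```python
from typing import Dict, List, Optional, Sequence, Union


def resolve_batch_size(input_shape: Sequence[Union[int, str, None]],
                       requested: int) -> int:
    """If the model's first input dimension is a fixed positive integer,
    it overrides the batch_size passed to the constructor; a dynamic
    dimension (-1, None, or a symbolic name) keeps the requested value."""
    dim = input_shape[0] if input_shape else None
    if isinstance(dim, int) and dim > 0:
        return dim        # fixed batch dimension baked into the model
    return requested      # dynamic batch dimension: honor the constructor arg


def class_names_from_metadata(metadata: Dict[str, str],
                              key: str = "names") -> Optional[List[str]]:
    """Pull class names out of a custom metadata map (hypothetical key
    and encoding, shown here as a comma-separated string)."""
    raw = metadata.get(key)
    if raw is None:
        return None
    return [name.strip() for name in raw.split(",") if name.strip()]
```

In ONNX Runtime, the inputs' shapes come from `session.get_inputs()[0].shape` and the metadata map from `session.get_modelmeta().custom_metadata_map`.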
Dependencies
- `onnxruntime` (external): ONNX inference runtime
- `numpy` (external)
- `abc`, `typing` (stdlib)
Consumers
inference/inference, inference/tensorrt_engine (inherits InferenceEngine), train (imports OnnxEngine)
Data Models
None.
Configuration
None.
External Integrations
- ONNX Runtime GPU execution (CUDA)
Security
None.
Tests
None.