mirror of
https://github.com/azaion/ai-training.git
synced 2026-06-21 21:31:11 +00:00
Add core functionality for API client, CDN management, and data augmentation
- Introduced `ApiClient` for handling API interactions, including file uploads and downloads. - Implemented `CDNManager` for managing CDN operations with AWS S3. - Added `Augmentator` class for image augmentation, including bounding box corrections and transformations. - Created utility functions for annotation conversion and dataset visualization. - Established a new rules file for sound notifications during human input requests. These additions enhance the system's capabilities for data handling and user interaction, laying the groundwork for future features. Simplify autopilot state file to minimal current-step pointer; add execution safety rule to cursor-meta; remove Completed Steps/Key Decisions/Retry Log/Blockers from state template and all references.
This commit is contained in:
@@ -0,0 +1,47 @@
|
||||
import abc
|
||||
from typing import List, Tuple
|
||||
import numpy as np
|
||||
import onnxruntime as onnx
|
||||
|
||||
|
||||
class InferenceEngine(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def __init__(self, model_path: str, batch_size: int = 1, **kwargs):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_input_shape(self) -> Tuple[int, int]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_batch_size(self) -> int:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self, input_data: np.ndarray) -> List[np.ndarray]:
|
||||
pass
|
||||
|
||||
|
||||
class OnnxEngine(InferenceEngine):
|
||||
def __init__(self, model_bytes, batch_size: int = 1, **kwargs):
|
||||
self.batch_size = batch_size
|
||||
self.session = onnx.InferenceSession(model_bytes, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
|
||||
self.model_inputs = self.session.get_inputs()
|
||||
self.input_name = self.model_inputs[0].name
|
||||
self.input_shape = self.model_inputs[0].shape
|
||||
if self.input_shape[0] != -1:
|
||||
self.batch_size = self.input_shape[0]
|
||||
model_meta = self.session.get_modelmeta()
|
||||
print("Metadata:", model_meta.custom_metadata_map)
|
||||
self.class_names = eval(model_meta.custom_metadata_map["names"])
|
||||
pass
|
||||
|
||||
def get_input_shape(self) -> Tuple[int, int]:
|
||||
shape = self.input_shape
|
||||
return shape[2], shape[3]
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
return self.batch_size
|
||||
|
||||
def run(self, input_data: np.ndarray) -> List[np.ndarray]:
|
||||
return self.session.run(None, {self.input_name: input_data})
|
||||
Reference in New Issue
Block a user