mirror of
https://github.com/azaion/gps-denied-desktop.git
synced 2026-04-22 22:06:36 +00:00
initial structure implemented
docs -> _docs
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
from .base import ModelManagerBase
|
||||
from .model_manager import ModelManager
|
||||
|
||||
__all__ = ["ModelManagerBase", "ModelManager"]
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Any
|
||||
import numpy as np
|
||||
|
||||
from models.config import ModelConfig
|
||||
|
||||
|
||||
class ModelManagerBase(ABC):
|
||||
@abstractmethod
|
||||
async def load_model(self, config: ModelConfig) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def unload_model(self, model_name: str) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_model(self, model_name: str) -> Optional[Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def run_inference(
|
||||
self, model_name: str, inputs: dict[str, np.ndarray]
|
||||
) -> dict[str, np.ndarray]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def warmup_model(
|
||||
self, model_name: str, iterations: int = 3
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_loaded_models(self) -> list[str]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_info(self, model_name: str) -> Optional[dict]:
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
from typing import Optional, Any
|
||||
import numpy as np
|
||||
|
||||
from .base import ModelManagerBase
|
||||
from models.config import ModelConfig
|
||||
|
||||
|
||||
class ModelManager(ModelManagerBase):
|
||||
async def load_model(self, config: ModelConfig) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
async def unload_model(self, model_name: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_model(self, model_name: str) -> Optional[Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def run_inference(
|
||||
self, model_name: str, inputs: dict[str, np.ndarray]
|
||||
) -> dict[str, np.ndarray]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def warmup_model(
|
||||
self, model_name: str, iterations: int = 3
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_loaded_models(self) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_model_info(self, model_name: str) -> Optional[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user