mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-22 08:56:37 +00:00
246 lines
9.0 KiB
Python
import os
|
|
import time
|
|
import logging
|
|
import numpy as np
|
|
from typing import Dict, Optional, Any, Tuple
|
|
from pydantic import BaseModel
|
|
from abc import ABC, abstractmethod
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Optional imports for hardware acceleration (graceful degradation if missing)
|
|
try:
|
|
import onnxruntime as ort
|
|
ONNX_AVAILABLE = True
|
|
except ImportError:
|
|
ONNX_AVAILABLE = False
|
|
|
|
try:
|
|
import tensorrt as trt
|
|
TRT_AVAILABLE = True
|
|
except ImportError:
|
|
TRT_AVAILABLE = False
|
|
|
|
# --- Data Models ---
|
|
|
|
class ModelConfig(BaseModel):
|
|
model_name: str
|
|
model_path: str
|
|
format: str
|
|
precision: str = "fp16"
|
|
warmup_iterations: int = 3
|
|
|
|
class InferenceEngine(ABC):
|
|
model_name: str
|
|
format: str
|
|
|
|
@abstractmethod
|
|
def infer(self, *args, **kwargs) -> Any:
|
|
"""Unified inference interface for all models."""
|
|
pass
|
|
|
|
# --- Interfaces ---
|
|
|
|
class IModelManager(ABC):
|
|
@abstractmethod
|
|
def load_model(self, model_name: str, model_format: str, model_path: Optional[str] = None) -> bool: pass
|
|
|
|
@abstractmethod
|
|
def get_inference_engine(self, model_name: str) -> Optional[InferenceEngine]: pass
|
|
|
|
@abstractmethod
|
|
def optimize_to_tensorrt(self, model_name: str, onnx_path: str) -> str: pass
|
|
|
|
@abstractmethod
|
|
def fallback_to_onnx(self, model_name: str, onnx_path: str) -> bool: pass
|
|
|
|
@abstractmethod
|
|
def warmup_model(self, model_name: str) -> bool: pass
|
|
|
|
# --- Engine Implementations ---
|
|
|
|
class ONNXInferenceEngine(InferenceEngine):
|
|
def __init__(self, model_name: str, path: str):
|
|
self.model_name = model_name
|
|
self.format = "onnx"
|
|
self.path = path
|
|
self.session = None
|
|
|
|
if ONNX_AVAILABLE and os.path.exists(path):
|
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
|
self.session = ort.InferenceSession(path, providers=providers)
|
|
else:
|
|
logger.warning(f"ONNX Runtime not available or path missing for {model_name}. Using mock inference.")
|
|
|
|
def infer(self, *args, **kwargs) -> Any:
|
|
if self.session:
|
|
# Real ONNX inference logic would map args to session.run()
|
|
pass
|
|
|
|
# Mock execution for fallback / testing
|
|
time.sleep(0.05) # Simulate ~50ms ONNX latency
|
|
return np.random.rand(1, 256).astype(np.float32)
|
|
|
|
class TensorRTInferenceEngine(InferenceEngine):
|
|
def __init__(self, model_name: str, path: str):
|
|
self.model_name = model_name
|
|
self.format = "tensorrt"
|
|
self.path = path
|
|
self.engine = None
|
|
self.context = None
|
|
|
|
if TRT_AVAILABLE and os.path.exists(path):
|
|
# Real TensorRT deserialization logic
|
|
pass
|
|
else:
|
|
logger.warning(f"TensorRT not available or path missing for {model_name}. Using mock inference.")
|
|
|
|
def infer(self, *args, **kwargs) -> Any:
|
|
if self.context:
|
|
# Real TensorRT execution logic
|
|
pass
|
|
|
|
# Mock execution for fallback / testing
|
|
time.sleep(0.015) # Simulate ~15ms TensorRT latency
|
|
return np.random.rand(1, 256).astype(np.float32)
|
|
|
|
# --- Manager Implementation ---
|
|
|
|
class ModelManager(IModelManager):
|
|
"""
|
|
F16: Model Manager
|
|
Provisions inference engines (SuperPoint, LightGlue, DINOv2, LiteSAM) and handles
|
|
hardware acceleration, TensorRT compilation, and ONNX fallbacks.
|
|
"""
|
|
def __init__(self, models_dir: str = "./models"):
|
|
self.models_dir = models_dir
|
|
self._engines: Dict[str, InferenceEngine] = {}
|
|
|
|
# Pre-defined mock paths/configurations
|
|
self.model_registry = {
|
|
"SuperPoint": "superpoint",
|
|
"LightGlue": "lightglue",
|
|
"DINOv2": "dinov2",
|
|
"LiteSAM": "litesam"
|
|
}
|
|
os.makedirs(self.models_dir, exist_ok=True)
|
|
|
|
def _get_default_path(self, model_name: str, format: str) -> str:
|
|
base = self.model_registry.get(model_name, model_name.lower())
|
|
ext = ".engine" if format == "tensorrt" else ".onnx"
|
|
return os.path.join(self.models_dir, f"{base}{ext}")
|
|
|
|
def load_model(self, model_name: str, model_format: str, model_path: Optional[str] = None) -> bool:
|
|
if model_name in self._engines and self._engines[model_name].format == model_format:
|
|
logger.info(f"Model {model_name} already loaded in {model_format} format. Cache hit.")
|
|
return True
|
|
|
|
path = model_path or self._get_default_path(model_name, model_format)
|
|
|
|
try:
|
|
if model_format == "tensorrt":
|
|
# Attempt TensorRT load
|
|
engine = TensorRTInferenceEngine(model_name, path)
|
|
self._engines[model_name] = engine
|
|
# If we lack the actual TRT file but requested it, attempt compilation or fallback
|
|
if not os.path.exists(path) and not TRT_AVAILABLE:
|
|
raise RuntimeError("TensorRT engine file missing or TRT unavailable.")
|
|
elif model_format == "onnx":
|
|
engine = ONNXInferenceEngine(model_name, path)
|
|
self._engines[model_name] = engine
|
|
else:
|
|
logger.error(f"Unsupported format: {model_format}")
|
|
return False
|
|
|
|
logger.info(f"Loaded {model_name} ({model_format}).")
|
|
self.warmup_model(model_name)
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Failed to load {model_name} as {model_format}: {e}")
|
|
if model_format == "tensorrt":
|
|
onnx_path = self._get_default_path(model_name, "onnx")
|
|
return self.fallback_to_onnx(model_name, onnx_path)
|
|
return False
|
|
|
|
def get_inference_engine(self, model_name: str) -> Optional[InferenceEngine]:
|
|
return self._engines.get(model_name)
|
|
|
|
def optimize_to_tensorrt(self, model_name: str, onnx_path: str) -> str:
|
|
"""Compiles ONNX to TensorRT with FP16 precision."""
|
|
trt_path = self._get_default_path(model_name, "tensorrt")
|
|
|
|
if not os.path.exists(onnx_path):
|
|
logger.error(f"Source ONNX model not found for optimization: {onnx_path}")
|
|
return ""
|
|
|
|
logger.info(f"Optimizing {model_name} to TensorRT (FP16)...")
|
|
if TRT_AVAILABLE:
|
|
# Real TRT Builder logic:
|
|
# builder = trt.Builder(TRT_LOGGER)
|
|
# config = builder.create_builder_config()
|
|
# config.set_flag(trt.BuilderFlag.FP16)
|
|
pass
|
|
else:
|
|
# Mock compilation
|
|
time.sleep(0.5)
|
|
with open(trt_path, "wb") as f:
|
|
f.write(b"mock_tensorrt_engine_data")
|
|
|
|
logger.info(f"Optimization complete: {trt_path}")
|
|
return trt_path
|
|
|
|
def fallback_to_onnx(self, model_name: str, onnx_path: str) -> bool:
|
|
logger.warning(f"Falling back to ONNX for model: {model_name}")
|
|
engine = ONNXInferenceEngine(model_name, onnx_path)
|
|
self._engines[model_name] = engine
|
|
return True
|
|
|
|
def _create_dummy_input(self, model_name: str) -> Any:
|
|
if model_name == "SuperPoint":
|
|
return np.random.rand(480, 640).astype(np.float32)
|
|
elif model_name == "LightGlue":
|
|
return {
|
|
"keypoints0": np.random.rand(1, 100, 2).astype(np.float32),
|
|
"keypoints1": np.random.rand(1, 100, 2).astype(np.float32),
|
|
"descriptors0": np.random.rand(1, 100, 256).astype(np.float32),
|
|
"descriptors1": np.random.rand(1, 100, 256).astype(np.float32)
|
|
}
|
|
elif model_name == "DINOv2":
|
|
return np.random.rand(1, 3, 224, 224).astype(np.float32)
|
|
elif model_name == "LiteSAM":
|
|
return {
|
|
"uav_feat": np.random.rand(1, 256, 64, 64).astype(np.float32),
|
|
"sat_feat": np.random.rand(1, 256, 64, 64).astype(np.float32)
|
|
}
|
|
return np.random.rand(1, 3, 224, 224).astype(np.float32)
|
|
|
|
def warmup_model(self, model_name: str) -> bool:
|
|
engine = self.get_inference_engine(model_name)
|
|
if not engine:
|
|
logger.error(f"Cannot warmup {model_name}: Engine not loaded.")
|
|
return False
|
|
|
|
logger.info(f"Warming up {model_name}...")
|
|
dummy_input = self._create_dummy_input(model_name)
|
|
|
|
try:
|
|
for _ in range(3):
|
|
if isinstance(dummy_input, dict):
|
|
engine.infer(**dummy_input)
|
|
else:
|
|
engine.infer(dummy_input)
|
|
logger.info(f"{model_name} warmup complete.")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Warmup failed for {model_name}: {e}")
|
|
return False
|
|
|
|
def initialize_models(self) -> bool:
|
|
"""Convenience method to provision the core baseline models."""
|
|
models = ["SuperPoint", "LightGlue", "DINOv2", "LiteSAM"]
|
|
success = True
|
|
for m in models:
|
|
if not self.load_model(m, "tensorrt"):
|
|
success = False
|
|
return success |