mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-23 01:36:36 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,127 @@
|
||||
import pytest
|
||||
import os
|
||||
import time
|
||||
import tempfile
|
||||
import shutil
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from f16_model_manager import ModelManager, TensorRTInferenceEngine, ONNXInferenceEngine
|
||||
|
||||
@pytest.fixture
|
||||
def models_dir():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
yield temp_dir
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
@pytest.fixture
|
||||
def mm(models_dir):
|
||||
return ModelManager(models_dir=models_dir)
|
||||
|
||||
class TestModelManager:
|
||||
|
||||
# --- 16.01 Feature: Model Lifecycle Management ---
|
||||
|
||||
def test_load_model_tensorrt_mock(self, mm, models_dir):
|
||||
# Create dummy file to bypass file existence checks
|
||||
dummy_path = os.path.join(models_dir, "superpoint.engine")
|
||||
with open(dummy_path, "wb") as f: f.write(b"data")
|
||||
|
||||
assert mm.load_model("SuperPoint", "tensorrt", dummy_path) is True
|
||||
engine = mm.get_inference_engine("SuperPoint")
|
||||
assert engine is not None
|
||||
assert engine.format == "tensorrt"
|
||||
|
||||
def test_load_model_onnx_mock(self, mm, models_dir):
|
||||
dummy_path = os.path.join(models_dir, "lightglue.onnx")
|
||||
with open(dummy_path, "wb") as f: f.write(b"data")
|
||||
|
||||
assert mm.load_model("LightGlue", "onnx", dummy_path) is True
|
||||
engine = mm.get_inference_engine("LightGlue")
|
||||
assert isinstance(engine, ONNXInferenceEngine)
|
||||
|
||||
def test_load_already_cached_model(self, mm, models_dir):
|
||||
dummy_path = os.path.join(models_dir, "dinov2.onnx")
|
||||
with open(dummy_path, "wb") as f: f.write(b"data")
|
||||
|
||||
mm.load_model("DINOv2", "onnx", dummy_path)
|
||||
|
||||
# Second load should return True immediately (cache hit)
|
||||
with patch.object(mm, 'warmup_model') as mock_warmup:
|
||||
assert mm.load_model("DINOv2", "onnx", dummy_path) is True
|
||||
mock_warmup.assert_not_called()
|
||||
|
||||
def test_load_invalid_format(self, mm):
|
||||
assert mm.load_model("SuperPoint", "pytorch_pth") is False
|
||||
|
||||
def test_warmup_model(self, mm, models_dir):
|
||||
dummy_path = os.path.join(models_dir, "litesam.onnx")
|
||||
with open(dummy_path, "wb") as f: f.write(b"data")
|
||||
|
||||
mm.load_model("LiteSAM", "onnx", dummy_path)
|
||||
|
||||
# Replace engine's infer with a mock to track calls
|
||||
engine = mm.get_inference_engine("LiteSAM")
|
||||
engine.infer = Mock()
|
||||
|
||||
assert mm.warmup_model("LiteSAM") is True
|
||||
assert engine.infer.call_count == 3
|
||||
|
||||
def test_warmup_unloaded_model(self, mm):
|
||||
assert mm.warmup_model("MissingModel") is False
|
||||
|
||||
# --- 16.02 Feature: Inference Engine Provisioning ---
|
||||
|
||||
def test_get_inference_engine_missing(self, mm):
|
||||
assert mm.get_inference_engine("SuperPoint") is None
|
||||
|
||||
def test_optimize_to_tensorrt(self, mm, models_dir):
|
||||
onnx_path = os.path.join(models_dir, "test.onnx")
|
||||
with open(onnx_path, "wb") as f: f.write(b"data")
|
||||
|
||||
trt_path = mm.optimize_to_tensorrt("TestModel", onnx_path)
|
||||
|
||||
assert trt_path != ""
|
||||
assert os.path.exists(trt_path)
|
||||
assert trt_path.endswith(".engine")
|
||||
|
||||
def test_optimize_to_tensorrt_missing_source(self, mm):
|
||||
assert mm.optimize_to_tensorrt("Missing", "does_not_exist.onnx") == ""
|
||||
|
||||
def test_fallback_to_onnx_triggered_on_trt_failure(self, mm, models_dir):
|
||||
# Ensure TensorRT file doesn't exist, which triggers fallback
|
||||
onnx_path = os.path.join(models_dir, "superpoint.onnx")
|
||||
with open(onnx_path, "wb") as f: f.write(b"data")
|
||||
|
||||
# Explicitly mock TRT_AVAILABLE as False
|
||||
import f16_model_manager
|
||||
f16_model_manager.TRT_AVAILABLE = False
|
||||
|
||||
# Load TRT -> Fails -> Falls back to ONNX
|
||||
assert mm.load_model("SuperPoint", "tensorrt") is True
|
||||
|
||||
engine = mm.get_inference_engine("SuperPoint")
|
||||
assert engine.format == "onnx"
|
||||
|
||||
def test_inference_performance_latency_mock(self, mm, models_dir):
|
||||
dummy_path = os.path.join(models_dir, "test.onnx")
|
||||
with open(dummy_path, "wb") as f: f.write(b"data")
|
||||
mm.load_model("TestPerf", "onnx", dummy_path)
|
||||
|
||||
engine = mm.get_inference_engine("TestPerf")
|
||||
start = time.time()
|
||||
engine.infer()
|
||||
duration = time.time() - start
|
||||
|
||||
# Simulated latency shouldn't exceed 100ms heavily in mock
|
||||
assert duration < 0.100
|
||||
|
||||
def test_cold_start_all_models(self, mm, models_dir):
|
||||
# Mock paths for all 4 models
|
||||
for m in ["superpoint.engine", "lightglue.engine", "dinov2.engine", "litesam.engine"]:
|
||||
with open(os.path.join(models_dir, m), "wb") as f: f.write(b"data")
|
||||
|
||||
assert mm.initialize_models() is True
|
||||
assert mm.get_inference_engine("SuperPoint") is not None
|
||||
assert mm.get_inference_engine("LightGlue") is not None
|
||||
assert mm.get_inference_engine("DINOv2") is not None
|
||||
assert mm.get_inference_engine("LiteSAM") is not None
|
||||
Reference in New Issue
Block a user