Files
gps-denied-onboard/test_f16_model_manager.py
T
Denys Zaitsev d7e1066c60 Initial commit
2026-04-03 23:25:54 +03:00

127 lines
4.8 KiB
Python

import pytest
import os
import time
import tempfile
import shutil
from unittest.mock import Mock, patch
from f16_model_manager import ModelManager, TensorRTInferenceEngine, ONNXInferenceEngine
@pytest.fixture
def models_dir():
    """Provide a fresh, empty scratch directory for model files.

    The directory is created with ``tempfile.mkdtemp`` before the test body
    runs and removed recursively with ``shutil.rmtree`` once pytest resumes
    the fixture for teardown.
    """
    scratch = tempfile.mkdtemp()
    try:
        yield scratch
    finally:
        shutil.rmtree(scratch)
@pytest.fixture
def mm(models_dir):
    """Return a ModelManager rooted at the per-test ``models_dir`` scratch directory."""
    manager = ModelManager(models_dir=models_dir)
    return manager
class TestModelManager:
    """Tests for ModelManager model lifecycle and inference-engine provisioning.

    Each test receives a fresh ``ModelManager`` (``mm``) rooted at a temporary
    ``models_dir``. Model files are placeholder bytes: the tests only need the
    paths to exist so the manager's file-existence checks pass.
    """

    @staticmethod
    def _write_dummy(directory, filename):
        """Create a placeholder model file ``directory/filename`` and return its path."""
        path = os.path.join(directory, filename)
        with open(path, "wb") as f:
            f.write(b"data")
        return path

    # --- 16.01 Feature: Model Lifecycle Management ---

    def test_load_model_tensorrt_mock(self, mm, models_dir):
        """Loading a TensorRT model registers an engine reporting the tensorrt format."""
        dummy_path = self._write_dummy(models_dir, "superpoint.engine")
        assert mm.load_model("SuperPoint", "tensorrt", dummy_path) is True
        engine = mm.get_inference_engine("SuperPoint")
        assert engine is not None
        assert engine.format == "tensorrt"

    def test_load_model_onnx_mock(self, mm, models_dir):
        """Loading an ONNX model registers an ONNXInferenceEngine."""
        dummy_path = self._write_dummy(models_dir, "lightglue.onnx")
        assert mm.load_model("LightGlue", "onnx", dummy_path) is True
        engine = mm.get_inference_engine("LightGlue")
        assert isinstance(engine, ONNXInferenceEngine)

    def test_load_already_cached_model(self, mm, models_dir):
        """Re-loading an already loaded model succeeds without warming it up again."""
        dummy_path = self._write_dummy(models_dir, "dinov2.onnx")
        # Check the first load as well, so a setup failure is not misreported
        # as a caching problem.
        assert mm.load_model("DINOv2", "onnx", dummy_path) is True
        # The second load should be a cache hit: True, with no warmup call.
        with patch.object(mm, "warmup_model") as mock_warmup:
            assert mm.load_model("DINOv2", "onnx", dummy_path) is True
            mock_warmup.assert_not_called()

    def test_load_invalid_format(self, mm):
        """An unsupported model format is rejected with False."""
        assert mm.load_model("SuperPoint", "pytorch_pth") is False

    def test_warmup_model(self, mm, models_dir):
        """Warming up a loaded model runs the engine's infer() exactly three times."""
        dummy_path = self._write_dummy(models_dir, "litesam.onnx")
        # Without this check a failed load would surface below as an
        # AttributeError on None instead of a clear assertion failure.
        assert mm.load_model("LiteSAM", "onnx", dummy_path) is True
        # Replace the engine's infer with a mock so the calls can be counted.
        engine = mm.get_inference_engine("LiteSAM")
        engine.infer = Mock()
        assert mm.warmup_model("LiteSAM") is True
        assert engine.infer.call_count == 3

    def test_warmup_unloaded_model(self, mm):
        """Warming up a model that was never loaded returns False."""
        assert mm.warmup_model("MissingModel") is False

    # --- 16.02 Feature: Inference Engine Provisioning ---

    def test_get_inference_engine_missing(self, mm):
        """Requesting an engine for an unloaded model returns None."""
        assert mm.get_inference_engine("SuperPoint") is None

    def test_optimize_to_tensorrt(self, mm, models_dir):
        """Converting an ONNX file yields an existing ``.engine`` file path."""
        onnx_path = self._write_dummy(models_dir, "test.onnx")
        trt_path = mm.optimize_to_tensorrt("TestModel", onnx_path)
        assert trt_path != ""
        assert os.path.exists(trt_path)
        assert trt_path.endswith(".engine")

    def test_optimize_to_tensorrt_missing_source(self, mm, models_dir):
        """Converting a non-existent ONNX file returns an empty path."""
        # Use a path inside the per-test directory so the result does not
        # depend on the current working directory's contents.
        missing_path = os.path.join(models_dir, "does_not_exist.onnx")
        assert mm.optimize_to_tensorrt("Missing", missing_path) == ""

    def test_fallback_to_onnx_triggered_on_trt_failure(self, mm, models_dir):
        """With TensorRT unavailable, a tensorrt load falls back to the ONNX model."""
        # Only the ONNX file is created; no .engine file exists, which should
        # push the manager onto its ONNX fallback path.
        self._write_dummy(models_dir, "superpoint.onnx")
        # Patch the module-level flag for this test only. A bare assignment
        # would persist for the rest of the session and make other tests
        # order-dependent.
        with patch("f16_model_manager.TRT_AVAILABLE", False):
            # Load TRT -> fails -> falls back to ONNX.
            assert mm.load_model("SuperPoint", "tensorrt") is True
            engine = mm.get_inference_engine("SuperPoint")
        assert engine.format == "onnx"

    def test_inference_performance_latency_mock(self, mm, models_dir):
        """A single inference call on the mock engine completes in under 100 ms."""
        dummy_path = self._write_dummy(models_dir, "test.onnx")
        assert mm.load_model("TestPerf", "onnx", dummy_path) is True
        engine = mm.get_inference_engine("TestPerf")
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock, which can jump and is coarse on some platforms.
        start = time.perf_counter()
        engine.infer()
        duration = time.perf_counter() - start
        # The mock's simulated latency must stay below 100 ms.
        assert duration < 0.100

    def test_cold_start_all_models(self, mm, models_dir):
        """initialize_models() succeeds and provides an engine for all four models."""
        for filename in (
            "superpoint.engine",
            "lightglue.engine",
            "dinov2.engine",
            "litesam.engine",
        ):
            self._write_dummy(models_dir, filename)
        assert mm.initialize_models() is True
        for model_name in ("SuperPoint", "LightGlue", "DINOv2", "LiteSAM"):
            # The message names the model whose engine is missing.
            assert mm.get_inference_engine(model_name) is not None, model_name