"""AZ-278 — `LightGlueRuntime` AC suite (E-CC-HELPERS / R14 fix).

Covers the 7 ACs from `_docs/02_tasks/todo/AZ-278_lightglue_runtime.md`.
"""

from __future__ import annotations

import ast
import threading
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pytest

from gps_denied_onboard._types.matching import CorrespondenceSet, KeypointSet
from gps_denied_onboard.helpers import (
    LightGlueConcurrentAccessError,
    LightGlueRuntime,
    LightGlueRuntimeError,
)

# ---------------------------------------------------------------------------
# Test doubles — deterministic stub engines.


@dataclass
class _DeterministicStubEngine:
    """Stub matcher engine whose output is a pure function of its inputs.

    It pairs keypoints by index and never looks at descriptors, so tests can
    predict every correspondence exactly.
    """

    # Dimension reported through the ``descriptor_dim`` property.
    expected_dim: int = 256
    # When set, ``forward`` waits on this event before doing any work.
    block_event: threading.Event | None = None

    @property
    def descriptor_dim(self) -> int:
        return self.expected_dim

    def forward(self, features_a: KeypointSet, features_b: KeypointSet) -> CorrespondenceSet:
        """Pair the leading keypoints of both sets index-by-index.

        Only the first ``min(len_a, len_b)`` keypoints of each side are used.
        Each correspondence row is the float64 ``features_a`` keypoint followed
        by the float64 ``features_b`` keypoint. Scores ramp linearly from 0.5 to
        0.95 across the rows.
        """
        gate = self.block_event
        if gate is not None:
            # Lets test_ac4 park the first caller inside forward() while a
            # second thread races the runtime's concurrency guard.
            gate.wait()

        pair_count = min(features_a.keypoints.shape[0], features_b.keypoints.shape[0])
        left_half = features_a.keypoints[:pair_count].astype(np.float64)
        right_half = features_b.keypoints[:pair_count].astype(np.float64)
        return CorrespondenceSet(
            correspondences=np.hstack((left_half, right_half)),
            scores=np.linspace(0.5, 0.95, num=pair_count, dtype=np.float64),
        )


def _make_keypoints(n: int = 5, seed: int = 0, dim: int = 256) -> KeypointSet:
    """Build a reproducible ``KeypointSet`` of *n* points with *dim*-wide descriptors.

    Coordinates are drawn uniformly from [0, 1000). Descriptors are drawn from a
    standard normal. Both are cast to float32, and the same *seed* always yields
    identical arrays.
    """
    generator = np.random.default_rng(seed)
    # Draw order matters: coordinates first, then descriptors, from one stream.
    coord_shape = (n, 2)
    desc_shape = (n, dim)
    coords = generator.uniform(0, 1000, size=coord_shape).astype(np.float32)
    descs = generator.standard_normal(desc_shape).astype(np.float32)
    return KeypointSet(keypoints=coords, descriptors=descs)


# ---------------------------------------------------------------------------
# AC-1: single-pair match returns non-empty correspondences.
def test_ac1_single_pair_match() -> None:
    # Arrange
    runtime = LightGlueRuntime(_DeterministicStubEngine())
    a = _make_keypoints(n=10, seed=1)
    b = _make_keypoints(n=10, seed=2)

    # Act
    result = runtime.match(a, b)

    # Assert
    assert isinstance(result, CorrespondenceSet)
    assert result.correspondences.shape == (10, 4)
    assert result.scores.shape == (10,)


# ---------------------------------------------------------------------------
# AC-2: batch of 3 pairs returns 3 ordered results.


def test_ac2_batch_match_preserves_order() -> None:
    # Arrange
    runtime = LightGlueRuntime(_DeterministicStubEngine())
    pairs_a = [_make_keypoints(n=5, seed=i) for i in range(3)]
    pairs_b = [_make_keypoints(n=5, seed=i + 100) for i in range(3)]

    # Act
    results = runtime.match_batch(pairs_a, pairs_b)

    # Assert
    assert len(results) == 3
    for idx, (pair_a, pair_b, result) in enumerate(zip(pairs_a, pairs_b, results, strict=True)):
        # The stub echoes each pair's keypoints into the correspondence columns,
        # so a reordered batch shows up as a mismatch against the wrong pair.
        # The message goes through err_msg: a bare `(call, message)` tuple
        # would silently discard it.
        np.testing.assert_array_equal(
            result.correspondences[:, :2],
            pair_a.keypoints.astype(np.float64),
            err_msg=f"batch result {idx} lost input order",
        )
        np.testing.assert_array_equal(
            result.correspondences[:, 2:],
            pair_b.keypoints.astype(np.float64),
            err_msg=f"batch result {idx} lost input order",
        )


# ---------------------------------------------------------------------------
# AC-3: descriptor-dim mismatch raises with both dims.


def test_ac3_descriptor_dim_mismatch() -> None:
    # Arrange — engine expects 256, we feed 128.
    runtime = LightGlueRuntime(_DeterministicStubEngine(expected_dim=256))
    a = _make_keypoints(n=5, dim=128)
    b = _make_keypoints(n=5, dim=128)

    # Act / Assert
    with pytest.raises(LightGlueRuntimeError, match=r"256.*128|128.*256"):
        runtime.match(a, b)


# ---------------------------------------------------------------------------
# AC-4: concurrent access raises LightGlueConcurrentAccessError in second thread.


def test_ac4_concurrent_access_rejected() -> None:
    """A second caller is rejected while the first is still inside ``forward()``.

    The first worker signals ``entered`` once it is inside the engine. Only then
    is the second worker started, so the race is deterministic rather than
    depending on a fixed sleep.
    """
    # Arrange — `entered` fires when worker one reaches forward(). `release`
    # keeps it parked there until worker two has had its chance to race.
    entered = threading.Event()
    release = threading.Event()

    class _SignallingEngine(_DeterministicStubEngine):
        def forward(self, features_a: KeypointSet, features_b: KeypointSet) -> CorrespondenceSet:
            entered.set()
            return super().forward(features_a, features_b)

    runtime = LightGlueRuntime(_SignallingEngine(block_event=release))
    a = _make_keypoints(n=3, seed=1)
    b = _make_keypoints(n=3, seed=2)
    results: list[CorrespondenceSet | Exception] = []

    def worker() -> None:
        # Record the outcome instead of raising, so the main thread can
        # inspect what each caller got.
        try:
            results.append(runtime.match(a, b))
        except Exception as exc:
            results.append(exc)

    # Daemon threads: a broken guard must fail the test, not hang the session.
    first = threading.Thread(target=worker, daemon=True)
    second = threading.Thread(target=worker, daemon=True)

    # Act
    first.start()
    try:
        assert entered.wait(timeout=2.0), "first caller never reached forward()"
        second.start()
        second.join(timeout=2.0)  # must NOT block — the guard raises immediately
        second_blocked = second.is_alive()
    finally:
        # Always let the first caller finish, even if an assertion above failed.
        release.set()
    first.join(timeout=2.0)

    # Assert — exactly one success and one LightGlueConcurrentAccessError.
    assert not second_blocked, "second caller blocked instead of being rejected"
    assert not first.is_alive(), "first caller did not finish after release"
    assert len(results) == 2
    successes = [r for r in results if isinstance(r, CorrespondenceSet)]
    failures = [r for r in results if isinstance(r, LightGlueConcurrentAccessError)]
    assert len(successes) == 1, f"expected exactly one success, got results={results!r}"
    assert len(failures) == 1, f"expected exactly one concurrent-access error, got {results!r}"


# ---------------------------------------------------------------------------
# AC-5: construction-time guard.


def test_ac5_construction_with_none_engine_raises() -> None:
    # Act / Assert
    with pytest.raises(LightGlueRuntimeError, match="engine_handle"):
        LightGlueRuntime(engine_handle=None)  # type: ignore[arg-type]


# ---------------------------------------------------------------------------
# AC-6: no upward imports.
def _is_components_module(name: str) -> bool:
    """Return True if *name* is ``gps_denied_onboard.components`` or a submodule of it."""
    prefix = "gps_denied_onboard.components"
    return name == prefix or name.startswith(prefix + ".")


def _import_from_candidates(node: ast.ImportFrom, package: str) -> list[str]:
    """Return the absolute dotted names an ``ImportFrom`` node may bind.

    Relative imports (``node.level > 0``) are resolved against *package*, the
    dotted package that contains the parsed module. Both the source module and
    each ``module.alias`` candidate are returned, so that
    ``from gps_denied_onboard import components`` yields
    ``gps_denied_onboard.components``.
    """
    if node.level:
        package_parts = package.split(".")
        # Level 1 is the package itself; each extra level strips one component.
        base = ".".join(package_parts[: len(package_parts) - (node.level - 1)])
        module = ".".join(part for part in (base, node.module) if part)
    else:
        module = node.module or ""
    return [module, *(f"{module}.{alias.name}" for alias in node.names)]


def test_ac6_no_upward_imports() -> None:
    # Arrange
    module_path = (
        Path(__file__).resolve().parents[2]
        / "src"
        / "gps_denied_onboard"
        / "helpers"
        / "lightglue_runtime.py"
    )
    tree = ast.parse(module_path.read_text(encoding="utf-8"), filename=str(module_path))
    # Package containing lightglue_runtime.py; used to resolve relative imports.
    package = "gps_denied_onboard.helpers"

    # Act — catch absolute, `from pkg import components`, and relative forms.
    forbidden: list[str] = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            forbidden.extend(
                alias.name for alias in node.names if _is_components_module(alias.name)
            )
        elif isinstance(node, ast.ImportFrom):
            forbidden.extend(
                name
                for name in _import_from_candidates(node, package)
                if _is_components_module(name)
            )

    # Assert — R14 structural fix: no components.* imports.
    assert not forbidden, (
        f"lightglue_runtime must not import components.*: {sorted(set(forbidden))}"
    )


# ---------------------------------------------------------------------------
# AC-7: determinism downstream of the engine.


def test_ac7_determinism_byte_equal_outputs() -> None:
    # Arrange
    runtime = LightGlueRuntime(_DeterministicStubEngine())
    a = _make_keypoints(n=8, seed=42)
    b = _make_keypoints(n=8, seed=43)

    # Act
    r1 = runtime.match(a, b)
    r2 = runtime.match(a, b)

    # Assert
    np.testing.assert_array_equal(r1.correspondences, r2.correspondences)
    np.testing.assert_array_equal(r1.scores, r2.scores)


# ---------------------------------------------------------------------------
# Additional guards.
def test_construction_with_bad_descriptor_dim_raises() -> None:
    """An engine reporting a zero descriptor dimension is rejected at construction."""
    # Arrange
    zero_dim_engine = _DeterministicStubEngine(expected_dim=0)

    # Act / Assert
    with pytest.raises(LightGlueRuntimeError, match="descriptor_dim"):
        LightGlueRuntime(zero_dim_engine)


def test_descriptor_dim_accessor() -> None:
    """``descriptor_dim()`` reports the dimension exposed by the wrapped engine."""
    # Arrange
    engine = _DeterministicStubEngine(expected_dim=128)

    # Act
    reported = LightGlueRuntime(engine).descriptor_dim()

    # Assert
    assert reported == 128


def test_match_batch_length_mismatch_raises() -> None:
    """``match_batch`` refuses input lists of different lengths."""
    # Arrange — one left-hand set against two right-hand sets.
    runtime = LightGlueRuntime(_DeterministicStubEngine())
    a_list = [_make_keypoints(n=3, seed=1)]
    b_list = [_make_keypoints(n=3, seed=seed) for seed in (2, 3)]

    # Act / Assert
    with pytest.raises(LightGlueRuntimeError, match="equal length"):
        runtime.match_batch(a_list, b_list)