mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-22 22:46:36 +00:00
e4ba7bced3
GlobalPlaceRecognition already implements AnyLoc-VLAD-DINOv2 (existing code). This change makes the sprint 1 GPR technology selection explicit:
- Expand class docstring with selection rationale vs NetVLAD / SP+LG
- Document INT8 quantization as known-broken for ViT on Jetson
- Reference design doc §2.3 and stage2 backlog
- Add two marker tests asserting 4096-d descriptor + DINOv2 engine name

No behavioral change — existing Mock/TRT path unchanged.

Ref: docs/superpowers/specs/2026-04-18-oss-stack-tech-audit-design.md §2.3

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
146 lines
5.4 KiB
Python
146 lines
5.4 KiB
Python
"""Tests for Global Place Recognition (F08)."""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from gps_denied.core.gpr import GlobalPlaceRecognition
|
|
from gps_denied.core.models import ModelManager
|
|
from gps_denied.schemas.gpr import TileCandidate
|
|
|
|
|
|
@pytest.fixture
def gpr():
    """Yield a GlobalPlaceRecognition whose "flight_123" index has been loaded.

    The index path is a placeholder (presumably absent on disk); see
    test_load_index_missing_file_falls_back for how load_index treats a
    missing file.
    """
    recognizer = GlobalPlaceRecognition(ModelManager())
    recognizer.load_index("flight_123", "dummy_path.faiss")
    return recognizer
|
|
|
|
def test_compute_location_descriptor(gpr):
    """A blank frame yields a 4096-d descriptor with unit Euclidean norm."""
    blank_frame = np.zeros((200, 200, 3), dtype=np.uint8)

    descriptor = gpr.compute_location_descriptor(blank_frame)

    assert descriptor.shape == (4096,)
    # Descriptors are compared by similarity, so they must be L2-normalised.
    descriptor_norm = np.linalg.norm(descriptor)
    assert np.isclose(descriptor_norm, 1.0)
|
|
|
|
def test_retrieve_candidate_tiles(gpr):
    """Single-frame retrieval returns exactly top_k TileCandidates with non-negative scores."""
    query_frame = np.zeros((200, 200, 3), dtype=np.uint8)

    candidates = gpr.retrieve_candidate_tiles(query_frame, top_k=5)

    assert len(candidates) == 5
    assert all(isinstance(tile, TileCandidate) for tile in candidates)
    # Collect scores into a list so a failure shows every value.
    scores = [tile.similarity_score for tile in candidates]
    assert min(scores) >= 0.0
|
|
|
|
def test_retrieve_candidate_tiles_for_chunk(gpr):
    """Chunk retrieval returns top_k candidates in descending similarity order (GPR-03)."""
    frames = [np.zeros((200, 200, 3), dtype=np.uint8) for _ in range(5)]

    candidates = gpr.retrieve_candidate_tiles_for_chunk(frames, top_k=3)

    assert len(candidates) == 3
    # Ensure the whole list is sorted descending (GPR-03), not just the first
    # pair, so a mis-ordered tail is also caught.
    scores = [c.similarity_score for c in candidates]
    assert scores == sorted(scores, reverse=True)
|
|
|
|
|
|
# ---------------------------------------------------------------
|
|
# GPR-01: Real Faiss index with file path
|
|
# ---------------------------------------------------------------
|
|
|
|
def test_load_index_missing_file_falls_back(tmp_path):
    """GPR-01: non-existent index path → numpy fallback, still usable."""
    # GlobalPlaceRecognition / ModelManager come from the module-level imports.
    g = GlobalPlaceRecognition(ModelManager())
    missing_index = tmp_path / "nonexistent.index"
    # Precondition: tmp_path is per-test, so the index file must not exist.
    assert not missing_index.exists()

    ok = g.load_index("f1", str(missing_index))
    assert ok is True
    assert g._is_loaded is True

    # Should still answer queries
    img = np.zeros((200, 200, 3), dtype=np.uint8)
    cands = g.retrieve_candidate_tiles(img, top_k=3)
    assert len(cands) == 3
|
|
|
|
|
|
def test_load_index_not_loaded_returns_empty():
    """query_database before load_index → empty list (no crash)."""
    # GlobalPlaceRecognition / ModelManager come from the module-level imports.
    g = GlobalPlaceRecognition(ModelManager())
    # Seeded generator keeps the query descriptor reproducible across runs.
    desc = np.random.default_rng(0).random(4096, dtype=np.float32)
    matches = g.query_database(desc, top_k=5)
    assert matches == []
|
|
|
|
|
|
# ---------------------------------------------------------------
|
|
# GPR-03: Ranking is deterministic (sorted by similarity)
|
|
# ---------------------------------------------------------------
|
|
|
|
def test_rank_candidates_sorted(gpr):
    """rank_candidates must return descending similarity order."""
    from gps_denied.schemas import GPSPoint
    from gps_denied.schemas.satellite import TileBounds

    dummy_bounds = TileBounds(
        nw=GPSPoint(lat=49.1, lon=32.0), ne=GPSPoint(lat=49.1, lon=32.1),
        sw=GPSPoint(lat=49.0, lon=32.0), se=GPSPoint(lat=49.0, lon=32.1),
        center=GPSPoint(lat=49.05, lon=32.05), gsd=0.6,
    )
    # Supplied out of similarity order on purpose, so returning the input
    # unchanged would fail the test. TileCandidate is the module-level import.
    cands = [
        TileCandidate(
            tile_id=tile_id, gps_center=GPSPoint(lat=49, lon=32),
            bounds=dummy_bounds, similarity_score=score, rank=rank,
        )
        for tile_id, score, rank in (("a", 0.3, 3), ("b", 0.9, 1), ("c", 0.6, 2))
    ]

    # list() lets the result be inspected more than once even if it is an iterator.
    ranked = list(gpr.rank_candidates(cands))

    # An empty result is trivially "sorted"; reject it so the check is meaningful.
    assert ranked, "rank_candidates returned no candidates"
    scores = [c.similarity_score for c in ranked]
    assert scores == sorted(scores, reverse=True)
    # The highest-similarity input ("b", 0.9) must be ranked first.
    assert ranked[0].tile_id == "b"
|
|
|
|
|
|
def test_descriptor_is_l2_normalised(gpr):
    """DINOv2 descriptor returned by compute_location_descriptor is unit-norm."""
    # A seeded generator makes the random image, and so the test, reproducible.
    # integers() has an exclusive upper bound: 256 covers the full uint8 range.
    rng = np.random.default_rng(0)
    img = rng.integers(0, 256, size=(200, 200, 3), dtype=np.uint8)
    desc = gpr.compute_location_descriptor(img)
    assert np.isclose(np.linalg.norm(desc), 1.0, atol=1e-5)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# AnyLoc baseline markers — document sprint 1 GPR technology selection.
|
|
# GlobalPlaceRecognition IS the AnyLoc-VLAD-DINOv2 baseline (see class docstring).
|
|
# Ref: docs/superpowers/specs/2026-04-18-oss-stack-tech-audit-design.md §2.3
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_anyloc_baseline_descriptor_dim_is_4096(gpr):
    """AnyLoc-VLAD-DINOv2 baseline produces 4096-d descriptors (ViT-base + VLAD)."""
    # A seeded generator keeps the input image reproducible across runs.
    # integers() has an exclusive upper bound: 256 covers the full uint8 range.
    rng = np.random.default_rng(0)
    img = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
    desc = gpr.compute_location_descriptor(img)
    assert desc.shape == (4096,), (
        f"AnyLoc-VLAD-DINOv2 expects 4096-d descriptor, got {desc.shape}. "
        "If you changed this, update the baseline reference in "
        "docs/superpowers/specs/2026-04-18-oss-stack-tech-audit-design.md §2.3"
    )
|
|
|
|
|
|
def test_anyloc_baseline_uses_dinov2_engine():
    """Sprint 1 GPR baseline must call DINOv2 via ModelManager."""
    from unittest.mock import MagicMock

    # Mock ModelManager: whatever engine is requested, its infer() returns a
    # 4096-d vector, so compute_location_descriptor can run against the mock.
    # GlobalPlaceRecognition is the module-level import.
    mm = MagicMock()
    mm.get_inference_engine.return_value.infer.return_value = np.ones(4096, dtype=np.float32)
    gpr_local = GlobalPlaceRecognition(mm)
    gpr_local.compute_location_descriptor(np.zeros((224, 224, 3), dtype=np.uint8))

    # AnyLoc == DINOv2 + VLAD. Engine name must be "DINOv2".
    # assert_called_with inspects only the most recent get_inference_engine call.
    mm.get_inference_engine.assert_called_with("DINOv2")
|