mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-22 18:21:16 +00:00
55ef732b96
- Create components/gpr/faiss_gpr.py with 269 LOC (verbatim copy + module docstring) - Inline numpy fallback kept as specified (Phase 4 VPR-03 owns the split) - Update components/gpr/__init__.py: barrel-export GlobalPlaceRecognition (impl), IGlobalPlaceRecognition (protocol), _faiss, _FAISS_AVAILABLE - Replace core/gpr.py with re-export shim preserving all public names
270 lines
11 KiB
Python
270 lines
11 KiB
Python
"""Faiss-backed GlobalPlaceRecognition with inline numpy fallback.
|
||
|
||
Phase 4 (VPR-03) may split the numpy fallback into a sibling module.
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
import os
|
||
from abc import ABC, abstractmethod
|
||
from typing import Dict, List
|
||
|
||
import numpy as np
|
||
|
||
from gps_denied.core.models import IModelManager
|
||
from gps_denied.schemas import GPSPoint
|
||
from gps_denied.schemas.gpr import DatabaseMatch, TileCandidate
|
||
from gps_denied.schemas.satellite import TileBounds
|
||
|
||
logger = logging.getLogger(__name__)

# Faiss is an optional dependency (present on Jetson, or wherever faiss-cpu is
# installed). Probe for it once at import time and record the outcome so the
# rest of the module can branch on `_FAISS_AVAILABLE` instead of re-importing.
try:
    import faiss as _faiss  # type: ignore
except ImportError:
    _faiss = None  # type: ignore
    _FAISS_AVAILABLE = False
    logger.info("Faiss not available — using numpy L2 fallback for GPR")
else:
    _FAISS_AVAILABLE = True
    logger.info("Faiss available — real index search enabled")
|
||
|
||
|
||
class IGlobalPlaceRecognition(ABC):
    """Abstract interface for global place recognition (GPR) components.

    Concrete implementations must provide every method below; the class
    cannot be instantiated directly.
    """

    @abstractmethod
    def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int) -> List[TileCandidate]:
        """Return candidate tiles for a single image, at most ``top_k`` of them."""

    @abstractmethod
    def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Compute the descriptor vector for a single image."""

    @abstractmethod
    def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
        """Look up the ``top_k`` database entries closest to ``descriptor``."""

    @abstractmethod
    def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
        """Return ``candidates`` in ranked order."""

    @abstractmethod
    def load_index(self, flight_id: str, index_path: str) -> bool:
        """Load the descriptor index at ``index_path`` for ``flight_id``.

        Returns a boolean success flag.
        """

    @abstractmethod
    def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int) -> List[TileCandidate]:
        """Return candidate tiles for a chunk (list) of images."""

    @abstractmethod
    def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
        """Compute one descriptor vector representing a chunk of images."""
|
||
|
||
|
||
class GlobalPlaceRecognition(IGlobalPlaceRecognition):
    """AnyLoc-VLAD-DINOv2 coarse localisation component — sprint 1 GPR baseline.

    GPR-01: load_index() tries to open a real Faiss .index file; falls back to
            a NumPy L2 mock when the file is missing or Faiss is not installed.
    GPR-02: Descriptor computed via DINOv2 engine (TRT FP16 on Jetson, Mock on
            dev/CI). INT8 quantization is disabled — broken for ViT on Jetson
            (NVIDIA/TRT#4348, facebookresearch/dinov2#489).
    GPR-03: Candidates ranked by descriptor similarity (L2 → converted to [0,1]).

    Selected over NetVLAD (deprecated, −2.4% R@1 on MSLS 2024) and SuperPoint+
    LightGlue (unvalidated for cross-view UAV↔satellite gap at sprint 1).
    Stage 2 evaluation: SP+LG+FAISS per _docs/03_backlog/stage2_ideas/.
    Long-term target: EigenPlaces (ICCV 2023) — cleaner ONNX export.

    Ref: docs/superpowers/specs/2026-04-18-oss-stack-tech-audit-design.md §2.3
    """

    _DIM = 4096  # DINOv2 VLAD descriptor dimension
    _MOCK_DB_SIZE = 1000  # number of random vectors in the dev/CI mock index

    def __init__(self, model_manager: IModelManager):
        """Store the model manager and start with no index loaded.

        Args:
            model_manager: Provider of inference engines; queried for the
                "DINOv2" engine when descriptors are computed.
        """
        self.model_manager = model_manager

        # Index storage — exactly one of these is populated after load_index():
        # a Faiss index OR a numpy descriptor matrix.
        self._faiss_index = None  # faiss.IndexFlatIP or similar
        self._np_descriptors: np.ndarray | None = None  # (N, DIM) fallback
        self._metadata: Dict[int, dict] = {}
        self._is_loaded = False

    # ------------------------------------------------------------------
    # GPR-02: Descriptor extraction via DINOv2
    # ------------------------------------------------------------------

    def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Run DINOv2 inference and return an L2-normalised descriptor.

        The norm is clamped to 1e-12 so an all-zero engine output does not
        cause a division by zero.
        """
        engine = self.model_manager.get_inference_engine("DINOv2")
        desc = engine.infer(image)
        norm = np.linalg.norm(desc)
        return desc / max(norm, 1e-12)

    def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
        """Mean-aggregate per-frame DINOv2 descriptors for a chunk.

        Returns a zero vector of length ``_DIM`` for an empty chunk; otherwise
        the re-normalised mean of the per-image descriptors.
        """
        if not chunk_images:
            return np.zeros(self._DIM, dtype=np.float32)
        descs = [self.compute_location_descriptor(img) for img in chunk_images]
        agg = np.mean(descs, axis=0)
        return agg / max(np.linalg.norm(agg), 1e-12)

    # ------------------------------------------------------------------
    # GPR-01: Index loading
    # ------------------------------------------------------------------

    def load_index(self, flight_id: str, index_path: str) -> bool:
        """Load a Faiss descriptor index from disk (GPR-01).

        Falls back to a NumPy random-vector mock when:
        - `index_path` does not exist, OR
        - Faiss is not installed (dev/CI without faiss-cpu), OR
        - reading the index or its companion metadata raises.

        The new state is committed only after everything has been read
        successfully, and the backend that is not in use is cleared. This
        keeps a failed or repeated load from leaving a stale Faiss index
        paired with mock metadata.

        Args:
            flight_id: Identifier used for logging only.
            index_path: Path to the Faiss ``.index`` file.

        Returns:
            Always ``True`` (the mock fallback cannot fail).
        """
        logger.info("Loading GPR index for flight=%s path=%s", flight_id, index_path)

        # Try real Faiss load ------------------------------------------------
        if _FAISS_AVAILABLE and os.path.isfile(index_path):
            try:
                index = _faiss.read_index(index_path)
                metadata = self._load_metadata(index_path, index.ntotal)
            except Exception as exc:
                # Deliberate best-effort: any failure degrades to the mock.
                logger.warning("Faiss load failed (%s) — falling back to numpy mock", exc)
            else:
                self._faiss_index = index
                self._np_descriptors = None
                self._metadata = metadata
                self._is_loaded = True
                logger.info("Faiss index loaded: %d vectors", index.ntotal)
                return True

        # NumPy mock fallback ------------------------------------------------
        logger.info("GPR: using numpy mock index (dev/CI mode)")
        db_size = self._MOCK_DB_SIZE
        vecs = np.random.rand(db_size, self._DIM).astype(np.float32)
        norms = np.linalg.norm(vecs, axis=1, keepdims=True)
        self._faiss_index = None  # drop any index left over from an earlier load
        self._np_descriptors = vecs / np.maximum(norms, 1e-12)
        self._metadata = self._generate_stub_metadata(db_size)
        self._is_loaded = True
        return True

    @staticmethod
    def _load_metadata(index_path: str, n_vectors: int) -> Dict[int, dict]:
        """Read the ``<index stem>_meta.json`` companion file of an index.

        When the companion file is absent, stub metadata for ``n_vectors``
        entries is generated instead. Nested ``gps_center`` / ``bounds`` dicts
        are converted to ``GPSPoint`` / ``TileBounds`` objects.
        """
        meta_path = os.path.splitext(index_path)[0] + "_meta.json"
        if not os.path.isfile(meta_path):
            return GlobalPlaceRecognition._generate_stub_metadata(n_vectors)

        with open(meta_path, encoding="utf-8") as f:
            raw = json.load(f)

        # JSON object keys are always strings; index ids are ints.
        metadata = {int(k): v for k, v in raw.items()}
        for entry in metadata.values():
            if isinstance(entry.get("gps_center"), dict):
                entry["gps_center"] = GPSPoint(**entry["gps_center"])
            bounds_d = entry.get("bounds")
            if isinstance(bounds_d, dict):
                for corner in ("nw", "ne", "sw", "se", "center"):
                    if isinstance(bounds_d.get(corner), dict):
                        bounds_d[corner] = GPSPoint(**bounds_d[corner])
                entry["bounds"] = TileBounds(**bounds_d)
        return metadata

    @staticmethod
    def _placeholder_bounds() -> TileBounds:
        """Build the fixed placeholder tile bounds used when none are known."""
        return TileBounds(
            nw=GPSPoint(lat=49.1, lon=32.0),
            ne=GPSPoint(lat=49.1, lon=32.1),
            sw=GPSPoint(lat=49.0, lon=32.0),
            se=GPSPoint(lat=49.0, lon=32.1),
            center=GPSPoint(lat=49.05, lon=32.05),
            gsd=0.6,
        )

    @staticmethod
    def _generate_stub_metadata(n: int) -> Dict[int, dict]:
        """Generate placeholder tile metadata for dev/CI mock index.

        Each of the ``n`` entries gets a sequential ``tile_id``, a random
        ``gps_center`` and the fixed placeholder bounds.
        """
        return {
            i: {
                "tile_id": f"tile_{i:06d}",
                "gps_center": GPSPoint(lat=49.0 + np.random.rand(), lon=32.0 + np.random.rand()),
                "bounds": GlobalPlaceRecognition._placeholder_bounds(),
            }
            for i in range(n)
        }

    # ------------------------------------------------------------------
    # GPR-03: Similarity search ranked by descriptor distance
    # ------------------------------------------------------------------

    def _make_match(self, idx: int, dist: float) -> DatabaseMatch:
        """Build a ``DatabaseMatch`` for index ``idx`` at distance ``dist``.

        The similarity score is ``1 / (1 + dist)``, which maps a non-negative
        distance into (0, 1].
        """
        meta = self._metadata.get(idx, {"tile_id": f"tile_{idx}"})
        return DatabaseMatch(
            index=idx,
            tile_id=meta.get("tile_id", str(idx)),
            distance=dist,
            similarity_score=1.0 / (1.0 + dist),
        )

    def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
        """Search the index for the top-k most similar tiles.

        Uses Faiss when loaded, numpy L2 otherwise.
        Results are sorted by ascending L2 distance (= descending similarity).

        Args:
            descriptor: Query descriptor; flattened to a single float32 row.
            top_k: Maximum number of matches to return.

        Returns:
            Up to ``top_k`` matches. Empty when no index is loaded, when
            ``top_k`` is not positive, or when the Faiss search fails and no
            numpy matrix is available.
        """
        if not self._is_loaded:
            logger.error("GPR index not loaded — call load_index() first.")
            return []
        if top_k <= 0:
            # A negative top_k would otherwise slice off the *end* of the
            # sorted indices and return almost the whole database.
            return []

        q = descriptor.astype(np.float32).reshape(1, -1)

        # Faiss path
        if self._faiss_index is not None:
            try:
                distances, indices = self._faiss_index.search(q, top_k)
                # NOTE(review): the score formula assumes the index returns L2
                # distances; for an inner-product index (IndexFlatIP) larger
                # values mean *more* similar — confirm the index metric.
                # Faiss pads with -1 labels when fewer than top_k hits exist.
                return [
                    self._make_match(int(idx), float(dist))
                    for dist, idx in zip(distances[0], indices[0])
                    if idx >= 0
                ]
            except Exception as exc:
                logger.warning("Faiss search failed: %s", exc)

        # NumPy path
        if self._np_descriptors is None:
            return []
        diff = self._np_descriptors - q  # (N, DIM)
        distances = np.sum(diff ** 2, axis=1)
        top_indices = np.argsort(distances)[:top_k]
        return [self._make_match(int(idx), float(distances[idx])) for idx in top_indices]

    def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
        """Sort candidates by descriptor similarity (descending) — GPR-03."""
        return sorted(candidates, key=lambda c: c.similarity_score, reverse=True)

    def _matches_to_candidates(self, matches: List[DatabaseMatch]) -> List[TileCandidate]:
        """Convert database matches into ranked ``TileCandidate`` objects.

        ``rank`` is the 1-based position of the match in ``matches``. GPS
        centre and bounds come from the loaded metadata, with fixed
        placeholders when an entry lacks them.
        """
        candidates = []
        for rank, match in enumerate(matches, 1):
            meta = self._metadata.get(match.index, {})
            gps = meta["gps_center"] if "gps_center" in meta else GPSPoint(lat=49.0, lon=32.0)
            bounds = meta["bounds"] if "bounds" in meta else self._placeholder_bounds()
            candidates.append(TileCandidate(
                tile_id=match.tile_id,
                gps_center=gps,
                bounds=bounds,
                similarity_score=match.similarity_score,
                rank=rank,
            ))
        return self.rank_candidates(candidates)

    def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int = 5) -> List[TileCandidate]:
        """Describe ``image``, query the index and return ranked candidates."""
        desc = self.compute_location_descriptor(image)
        matches = self.query_database(desc, top_k)
        return self._matches_to_candidates(matches)

    def retrieve_candidate_tiles_for_chunk(
        self, chunk_images: List[np.ndarray], top_k: int = 5
    ) -> List[TileCandidate]:
        """Describe a chunk of images, query the index and return ranked candidates."""
        desc = self.compute_chunk_descriptor(chunk_images)
        matches = self.query_database(desc, top_k)
        return self._matches_to_candidates(matches)
|