mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-21 09:01:14 +00:00
feat(01-04): move GPR impl to components/gpr/faiss_gpr.py, shim core/gpr.py
- Create components/gpr/faiss_gpr.py with 269 LOC (verbatim copy + module docstring) - Inline numpy fallback kept as specified (Phase 4 VPR-03 owns the split) - Update components/gpr/__init__.py: barrel-export GlobalPlaceRecognition (impl), IGlobalPlaceRecognition (protocol), _faiss, _FAISS_AVAILABLE - Replace core/gpr.py with re-export shim preserving all public names
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
"""GPR component barrel exports.
|
||||
|
||||
``GlobalPlaceRecognition`` resolves to the Faiss-backed implementation
|
||||
(faiss_gpr.py). The structural Protocol lives in protocol.py and is
|
||||
re-exported as ``IGlobalPlaceRecognition``.
|
||||
"""
|
||||
|
||||
from gps_denied.components.gpr.faiss_gpr import ( # noqa: F401
|
||||
GlobalPlaceRecognition,
|
||||
_faiss,
|
||||
_FAISS_AVAILABLE,
|
||||
)
|
||||
from gps_denied.components.gpr.protocol import ( # noqa: F401
|
||||
IGlobalPlaceRecognition,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GlobalPlaceRecognition",
|
||||
"IGlobalPlaceRecognition",
|
||||
"_faiss",
|
||||
"_FAISS_AVAILABLE",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
"""Faiss-backed GlobalPlaceRecognition with inline numpy fallback.
|
||||
|
||||
Phase 4 (VPR-03) may split numpy fallback into a sibling module.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gps_denied.core.models import IModelManager
|
||||
from gps_denied.schemas import GPSPoint
|
||||
from gps_denied.schemas.gpr import DatabaseMatch, TileCandidate
|
||||
from gps_denied.schemas.satellite import TileBounds
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Attempt to import Faiss (optional — only available on Jetson or with faiss-cpu installed)
|
||||
try:
|
||||
import faiss as _faiss # type: ignore
|
||||
_FAISS_AVAILABLE = True
|
||||
logger.info("Faiss available — real index search enabled")
|
||||
except ImportError:
|
||||
_faiss = None # type: ignore
|
||||
_FAISS_AVAILABLE = False
|
||||
logger.info("Faiss not available — using numpy L2 fallback for GPR")
|
||||
|
||||
|
||||
class IGlobalPlaceRecognition(ABC):
|
||||
@abstractmethod
|
||||
def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int) -> List[TileCandidate]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_index(self, flight_id: str, index_path: str) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int) -> List[TileCandidate]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
|
||||
pass
|
||||
|
||||
|
||||
class GlobalPlaceRecognition(IGlobalPlaceRecognition):
|
||||
"""AnyLoc-VLAD-DINOv2 coarse localisation component — sprint 1 GPR baseline.
|
||||
|
||||
GPR-01: load_index() tries to open a real Faiss .index file; falls back to
|
||||
a NumPy L2 mock when the file is missing or Faiss is not installed.
|
||||
GPR-02: Descriptor computed via DINOv2 engine (TRT FP16 on Jetson, Mock on
|
||||
dev/CI). INT8 quantization is disabled — broken for ViT on Jetson
|
||||
(NVIDIA/TRT#4348, facebookresearch/dinov2#489).
|
||||
GPR-03: Candidates ranked by descriptor similarity (L2 → converted to [0,1]).
|
||||
|
||||
Selected over NetVLAD (deprecated, −2.4% R@1 on MSLS 2024) and SuperPoint+
|
||||
LightGlue (unvalidated for cross-view UAV↔satellite gap at sprint 1).
|
||||
Stage 2 evaluation: SP+LG+FAISS per _docs/03_backlog/stage2_ideas/.
|
||||
Long-term target: EigenPlaces (ICCV 2023) — cleaner ONNX export.
|
||||
|
||||
Ref: docs/superpowers/specs/2026-04-18-oss-stack-tech-audit-design.md §2.3
|
||||
"""
|
||||
|
||||
_DIM = 4096 # DINOv2 VLAD descriptor dimension
|
||||
|
||||
def __init__(self, model_manager: IModelManager):
|
||||
self.model_manager = model_manager
|
||||
|
||||
# Index storage — one of: Faiss index OR numpy matrix
|
||||
self._faiss_index = None # faiss.IndexFlatIP or similar
|
||||
self._np_descriptors: np.ndarray | None = None # (N, DIM) fallback
|
||||
self._metadata: Dict[int, dict] = {}
|
||||
self._is_loaded = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# GPR-02: Descriptor extraction via DINOv2
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
|
||||
"""Run DINOv2 inference and return an L2-normalised descriptor."""
|
||||
engine = self.model_manager.get_inference_engine("DINOv2")
|
||||
desc = engine.infer(image)
|
||||
norm = np.linalg.norm(desc)
|
||||
return desc / max(norm, 1e-12)
|
||||
|
||||
def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
|
||||
"""Mean-aggregate per-frame DINOv2 descriptors for a chunk."""
|
||||
if not chunk_images:
|
||||
return np.zeros(self._DIM, dtype=np.float32)
|
||||
descs = [self.compute_location_descriptor(img) for img in chunk_images]
|
||||
agg = np.mean(descs, axis=0)
|
||||
return agg / max(np.linalg.norm(agg), 1e-12)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# GPR-01: Index loading
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load_index(self, flight_id: str, index_path: str) -> bool:
|
||||
"""Load a Faiss descriptor index from disk (GPR-01).
|
||||
|
||||
Falls back to a NumPy random-vector mock when:
|
||||
- `index_path` does not exist, OR
|
||||
- Faiss is not installed (dev/CI without faiss-cpu).
|
||||
"""
|
||||
logger.info("Loading GPR index for flight=%s path=%s", flight_id, index_path)
|
||||
|
||||
# Try real Faiss load ------------------------------------------------
|
||||
if _FAISS_AVAILABLE and os.path.isfile(index_path):
|
||||
try:
|
||||
self._faiss_index = _faiss.read_index(index_path)
|
||||
# Load companion metadata JSON if present
|
||||
meta_path = os.path.splitext(index_path)[0] + "_meta.json"
|
||||
if os.path.isfile(meta_path):
|
||||
with open(meta_path) as f:
|
||||
raw = json.load(f)
|
||||
self._metadata = {int(k): v for k, v in raw.items()}
|
||||
# Deserialise GPSPoint / TileBounds from dicts
|
||||
for idx, m in self._metadata.items():
|
||||
if isinstance(m.get("gps_center"), dict):
|
||||
m["gps_center"] = GPSPoint(**m["gps_center"])
|
||||
if isinstance(m.get("bounds"), dict):
|
||||
bounds_d = m["bounds"]
|
||||
for corner in ("nw", "ne", "sw", "se", "center"):
|
||||
if isinstance(bounds_d.get(corner), dict):
|
||||
bounds_d[corner] = GPSPoint(**bounds_d[corner])
|
||||
m["bounds"] = TileBounds(**bounds_d)
|
||||
else:
|
||||
self._metadata = self._generate_stub_metadata(self._faiss_index.ntotal)
|
||||
self._is_loaded = True
|
||||
logger.info("Faiss index loaded: %d vectors", self._faiss_index.ntotal)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("Faiss load failed (%s) — falling back to numpy mock", exc)
|
||||
|
||||
# NumPy mock fallback ------------------------------------------------
|
||||
logger.info("GPR: using numpy mock index (dev/CI mode)")
|
||||
db_size = 1000
|
||||
vecs = np.random.rand(db_size, self._DIM).astype(np.float32)
|
||||
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
|
||||
self._np_descriptors = vecs / np.maximum(norms, 1e-12)
|
||||
self._metadata = self._generate_stub_metadata(db_size)
|
||||
self._is_loaded = True
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _generate_stub_metadata(n: int) -> Dict[int, dict]:
|
||||
"""Generate placeholder tile metadata for dev/CI mock index."""
|
||||
meta: Dict[int, dict] = {}
|
||||
for i in range(n):
|
||||
meta[i] = {
|
||||
"tile_id": f"tile_{i:06d}",
|
||||
"gps_center": GPSPoint(lat=49.0 + np.random.rand(), lon=32.0 + np.random.rand()),
|
||||
"bounds": TileBounds(
|
||||
nw=GPSPoint(lat=49.1, lon=32.0),
|
||||
ne=GPSPoint(lat=49.1, lon=32.1),
|
||||
sw=GPSPoint(lat=49.0, lon=32.0),
|
||||
se=GPSPoint(lat=49.0, lon=32.1),
|
||||
center=GPSPoint(lat=49.05, lon=32.05),
|
||||
gsd=0.6,
|
||||
),
|
||||
}
|
||||
return meta
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# GPR-03: Similarity search ranked by descriptor distance
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
|
||||
"""Search the index for the top-k most similar tiles.
|
||||
|
||||
Uses Faiss when loaded, numpy L2 otherwise.
|
||||
Results are sorted by ascending L2 distance (= descending similarity).
|
||||
"""
|
||||
if not self._is_loaded:
|
||||
logger.error("GPR index not loaded — call load_index() first.")
|
||||
return []
|
||||
|
||||
q = descriptor.astype(np.float32).reshape(1, -1)
|
||||
|
||||
# Faiss path
|
||||
if self._faiss_index is not None:
|
||||
try:
|
||||
distances, indices = self._faiss_index.search(q, top_k)
|
||||
results = []
|
||||
for dist, idx in zip(distances[0], indices[0]):
|
||||
if idx < 0:
|
||||
continue
|
||||
sim = 1.0 / (1.0 + float(dist))
|
||||
meta = self._metadata.get(int(idx), {"tile_id": f"tile_{idx}"})
|
||||
results.append(DatabaseMatch(
|
||||
index=int(idx),
|
||||
tile_id=meta.get("tile_id", str(idx)),
|
||||
distance=float(dist),
|
||||
similarity_score=sim,
|
||||
))
|
||||
return results
|
||||
except Exception as exc:
|
||||
logger.warning("Faiss search failed: %s", exc)
|
||||
|
||||
# NumPy path
|
||||
if self._np_descriptors is None:
|
||||
return []
|
||||
diff = self._np_descriptors - q # (N, DIM)
|
||||
distances = np.sum(diff ** 2, axis=1)
|
||||
top_indices = np.argsort(distances)[:top_k]
|
||||
|
||||
results = []
|
||||
for idx in top_indices:
|
||||
dist = float(distances[idx])
|
||||
sim = 1.0 / (1.0 + dist)
|
||||
meta = self._metadata.get(int(idx), {"tile_id": f"tile_{idx}"})
|
||||
results.append(DatabaseMatch(
|
||||
index=int(idx),
|
||||
tile_id=meta.get("tile_id", str(idx)),
|
||||
distance=dist,
|
||||
similarity_score=sim,
|
||||
))
|
||||
return results
|
||||
|
||||
def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
|
||||
"""Sort candidates by descriptor similarity (descending) — GPR-03."""
|
||||
return sorted(candidates, key=lambda c: c.similarity_score, reverse=True)
|
||||
|
||||
def _matches_to_candidates(self, matches: List[DatabaseMatch]) -> List[TileCandidate]:
|
||||
candidates = []
|
||||
for rank, match in enumerate(matches, 1):
|
||||
meta = self._metadata.get(match.index, {})
|
||||
gps = meta.get("gps_center", GPSPoint(lat=49.0, lon=32.0))
|
||||
bounds = meta.get("bounds", TileBounds(
|
||||
nw=GPSPoint(lat=49.1, lon=32.0), ne=GPSPoint(lat=49.1, lon=32.1),
|
||||
sw=GPSPoint(lat=49.0, lon=32.0), se=GPSPoint(lat=49.0, lon=32.1),
|
||||
center=GPSPoint(lat=49.05, lon=32.05), gsd=0.6,
|
||||
))
|
||||
candidates.append(TileCandidate(
|
||||
tile_id=match.tile_id,
|
||||
gps_center=gps,
|
||||
bounds=bounds,
|
||||
similarity_score=match.similarity_score,
|
||||
rank=rank,
|
||||
))
|
||||
return self.rank_candidates(candidates)
|
||||
|
||||
def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int = 5) -> List[TileCandidate]:
|
||||
desc = self.compute_location_descriptor(image)
|
||||
matches = self.query_database(desc, top_k)
|
||||
return self._matches_to_candidates(matches)
|
||||
|
||||
def retrieve_candidate_tiles_for_chunk(
|
||||
self, chunk_images: List[np.ndarray], top_k: int = 5
|
||||
) -> List[TileCandidate]:
|
||||
desc = self.compute_chunk_descriptor(chunk_images)
|
||||
matches = self.query_database(desc, top_k)
|
||||
return self._matches_to_candidates(matches)
|
||||
+15
-270
@@ -1,271 +1,16 @@
|
||||
"""Global Place Recognition (Component F08).
|
||||
"""Legacy import path for GPR. Phase 1 shim — code lives in components/gpr/."""
|
||||
from gps_denied.components.gpr.protocol import (
|
||||
IGlobalPlaceRecognition, # noqa: F401
|
||||
)
|
||||
from gps_denied.components.gpr.faiss_gpr import (
|
||||
GlobalPlaceRecognition,
|
||||
_faiss,
|
||||
_FAISS_AVAILABLE,
|
||||
)
|
||||
|
||||
GPR-01: Loads a real Faiss index from disk when available; numpy-L2 fallback for dev/CI.
|
||||
GPR-02: DINOv2/AnyLoc TRT FP16 on Jetson; MockInferenceEngine on dev/CI (via ModelManager).
|
||||
GPR-03: Candidates ranked by DINOv2 descriptor similarity (dot-product / L2 distance).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gps_denied.core.models import IModelManager
|
||||
from gps_denied.schemas import GPSPoint
|
||||
from gps_denied.schemas.gpr import DatabaseMatch, TileCandidate
|
||||
from gps_denied.schemas.satellite import TileBounds
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Attempt to import Faiss (optional — only available on Jetson or with faiss-cpu installed)
|
||||
try:
|
||||
import faiss as _faiss # type: ignore
|
||||
_FAISS_AVAILABLE = True
|
||||
logger.info("Faiss available — real index search enabled")
|
||||
except ImportError:
|
||||
_faiss = None # type: ignore
|
||||
_FAISS_AVAILABLE = False
|
||||
logger.info("Faiss not available — using numpy L2 fallback for GPR")
|
||||
|
||||
|
||||
class IGlobalPlaceRecognition(ABC):
|
||||
@abstractmethod
|
||||
def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int) -> List[TileCandidate]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_index(self, flight_id: str, index_path: str) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int) -> List[TileCandidate]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
|
||||
pass
|
||||
|
||||
|
||||
class GlobalPlaceRecognition(IGlobalPlaceRecognition):
|
||||
"""AnyLoc-VLAD-DINOv2 coarse localisation component — sprint 1 GPR baseline.
|
||||
|
||||
GPR-01: load_index() tries to open a real Faiss .index file; falls back to
|
||||
a NumPy L2 mock when the file is missing or Faiss is not installed.
|
||||
GPR-02: Descriptor computed via DINOv2 engine (TRT FP16 on Jetson, Mock on
|
||||
dev/CI). INT8 quantization is disabled — broken for ViT on Jetson
|
||||
(NVIDIA/TRT#4348, facebookresearch/dinov2#489).
|
||||
GPR-03: Candidates ranked by descriptor similarity (L2 → converted to [0,1]).
|
||||
|
||||
Selected over NetVLAD (deprecated, −2.4% R@1 on MSLS 2024) and SuperPoint+
|
||||
LightGlue (unvalidated for cross-view UAV↔satellite gap at sprint 1).
|
||||
Stage 2 evaluation: SP+LG+FAISS per _docs/03_backlog/stage2_ideas/.
|
||||
Long-term target: EigenPlaces (ICCV 2023) — cleaner ONNX export.
|
||||
|
||||
Ref: docs/superpowers/specs/2026-04-18-oss-stack-tech-audit-design.md §2.3
|
||||
"""
|
||||
|
||||
_DIM = 4096 # DINOv2 VLAD descriptor dimension
|
||||
|
||||
def __init__(self, model_manager: IModelManager):
|
||||
self.model_manager = model_manager
|
||||
|
||||
# Index storage — one of: Faiss index OR numpy matrix
|
||||
self._faiss_index = None # faiss.IndexFlatIP or similar
|
||||
self._np_descriptors: np.ndarray | None = None # (N, DIM) fallback
|
||||
self._metadata: Dict[int, dict] = {}
|
||||
self._is_loaded = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# GPR-02: Descriptor extraction via DINOv2
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
|
||||
"""Run DINOv2 inference and return an L2-normalised descriptor."""
|
||||
engine = self.model_manager.get_inference_engine("DINOv2")
|
||||
desc = engine.infer(image)
|
||||
norm = np.linalg.norm(desc)
|
||||
return desc / max(norm, 1e-12)
|
||||
|
||||
def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
|
||||
"""Mean-aggregate per-frame DINOv2 descriptors for a chunk."""
|
||||
if not chunk_images:
|
||||
return np.zeros(self._DIM, dtype=np.float32)
|
||||
descs = [self.compute_location_descriptor(img) for img in chunk_images]
|
||||
agg = np.mean(descs, axis=0)
|
||||
return agg / max(np.linalg.norm(agg), 1e-12)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# GPR-01: Index loading
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load_index(self, flight_id: str, index_path: str) -> bool:
|
||||
"""Load a Faiss descriptor index from disk (GPR-01).
|
||||
|
||||
Falls back to a NumPy random-vector mock when:
|
||||
- `index_path` does not exist, OR
|
||||
- Faiss is not installed (dev/CI without faiss-cpu).
|
||||
"""
|
||||
logger.info("Loading GPR index for flight=%s path=%s", flight_id, index_path)
|
||||
|
||||
# Try real Faiss load ------------------------------------------------
|
||||
if _FAISS_AVAILABLE and os.path.isfile(index_path):
|
||||
try:
|
||||
self._faiss_index = _faiss.read_index(index_path)
|
||||
# Load companion metadata JSON if present
|
||||
meta_path = os.path.splitext(index_path)[0] + "_meta.json"
|
||||
if os.path.isfile(meta_path):
|
||||
with open(meta_path) as f:
|
||||
raw = json.load(f)
|
||||
self._metadata = {int(k): v for k, v in raw.items()}
|
||||
# Deserialise GPSPoint / TileBounds from dicts
|
||||
for idx, m in self._metadata.items():
|
||||
if isinstance(m.get("gps_center"), dict):
|
||||
m["gps_center"] = GPSPoint(**m["gps_center"])
|
||||
if isinstance(m.get("bounds"), dict):
|
||||
bounds_d = m["bounds"]
|
||||
for corner in ("nw", "ne", "sw", "se", "center"):
|
||||
if isinstance(bounds_d.get(corner), dict):
|
||||
bounds_d[corner] = GPSPoint(**bounds_d[corner])
|
||||
m["bounds"] = TileBounds(**bounds_d)
|
||||
else:
|
||||
self._metadata = self._generate_stub_metadata(self._faiss_index.ntotal)
|
||||
self._is_loaded = True
|
||||
logger.info("Faiss index loaded: %d vectors", self._faiss_index.ntotal)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("Faiss load failed (%s) — falling back to numpy mock", exc)
|
||||
|
||||
# NumPy mock fallback ------------------------------------------------
|
||||
logger.info("GPR: using numpy mock index (dev/CI mode)")
|
||||
db_size = 1000
|
||||
vecs = np.random.rand(db_size, self._DIM).astype(np.float32)
|
||||
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
|
||||
self._np_descriptors = vecs / np.maximum(norms, 1e-12)
|
||||
self._metadata = self._generate_stub_metadata(db_size)
|
||||
self._is_loaded = True
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _generate_stub_metadata(n: int) -> Dict[int, dict]:
|
||||
"""Generate placeholder tile metadata for dev/CI mock index."""
|
||||
meta: Dict[int, dict] = {}
|
||||
for i in range(n):
|
||||
meta[i] = {
|
||||
"tile_id": f"tile_{i:06d}",
|
||||
"gps_center": GPSPoint(lat=49.0 + np.random.rand(), lon=32.0 + np.random.rand()),
|
||||
"bounds": TileBounds(
|
||||
nw=GPSPoint(lat=49.1, lon=32.0),
|
||||
ne=GPSPoint(lat=49.1, lon=32.1),
|
||||
sw=GPSPoint(lat=49.0, lon=32.0),
|
||||
se=GPSPoint(lat=49.0, lon=32.1),
|
||||
center=GPSPoint(lat=49.05, lon=32.05),
|
||||
gsd=0.6,
|
||||
),
|
||||
}
|
||||
return meta
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# GPR-03: Similarity search ranked by descriptor distance
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
|
||||
"""Search the index for the top-k most similar tiles.
|
||||
|
||||
Uses Faiss when loaded, numpy L2 otherwise.
|
||||
Results are sorted by ascending L2 distance (= descending similarity).
|
||||
"""
|
||||
if not self._is_loaded:
|
||||
logger.error("GPR index not loaded — call load_index() first.")
|
||||
return []
|
||||
|
||||
q = descriptor.astype(np.float32).reshape(1, -1)
|
||||
|
||||
# Faiss path
|
||||
if self._faiss_index is not None:
|
||||
try:
|
||||
distances, indices = self._faiss_index.search(q, top_k)
|
||||
results = []
|
||||
for dist, idx in zip(distances[0], indices[0]):
|
||||
if idx < 0:
|
||||
continue
|
||||
sim = 1.0 / (1.0 + float(dist))
|
||||
meta = self._metadata.get(int(idx), {"tile_id": f"tile_{idx}"})
|
||||
results.append(DatabaseMatch(
|
||||
index=int(idx),
|
||||
tile_id=meta.get("tile_id", str(idx)),
|
||||
distance=float(dist),
|
||||
similarity_score=sim,
|
||||
))
|
||||
return results
|
||||
except Exception as exc:
|
||||
logger.warning("Faiss search failed: %s", exc)
|
||||
|
||||
# NumPy path
|
||||
if self._np_descriptors is None:
|
||||
return []
|
||||
diff = self._np_descriptors - q # (N, DIM)
|
||||
distances = np.sum(diff ** 2, axis=1)
|
||||
top_indices = np.argsort(distances)[:top_k]
|
||||
|
||||
results = []
|
||||
for idx in top_indices:
|
||||
dist = float(distances[idx])
|
||||
sim = 1.0 / (1.0 + dist)
|
||||
meta = self._metadata.get(int(idx), {"tile_id": f"tile_{idx}"})
|
||||
results.append(DatabaseMatch(
|
||||
index=int(idx),
|
||||
tile_id=meta.get("tile_id", str(idx)),
|
||||
distance=dist,
|
||||
similarity_score=sim,
|
||||
))
|
||||
return results
|
||||
|
||||
def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
|
||||
"""Sort candidates by descriptor similarity (descending) — GPR-03."""
|
||||
return sorted(candidates, key=lambda c: c.similarity_score, reverse=True)
|
||||
|
||||
def _matches_to_candidates(self, matches: List[DatabaseMatch]) -> List[TileCandidate]:
|
||||
candidates = []
|
||||
for rank, match in enumerate(matches, 1):
|
||||
meta = self._metadata.get(match.index, {})
|
||||
gps = meta.get("gps_center", GPSPoint(lat=49.0, lon=32.0))
|
||||
bounds = meta.get("bounds", TileBounds(
|
||||
nw=GPSPoint(lat=49.1, lon=32.0), ne=GPSPoint(lat=49.1, lon=32.1),
|
||||
sw=GPSPoint(lat=49.0, lon=32.0), se=GPSPoint(lat=49.0, lon=32.1),
|
||||
center=GPSPoint(lat=49.05, lon=32.05), gsd=0.6,
|
||||
))
|
||||
candidates.append(TileCandidate(
|
||||
tile_id=match.tile_id,
|
||||
gps_center=gps,
|
||||
bounds=bounds,
|
||||
similarity_score=match.similarity_score,
|
||||
rank=rank,
|
||||
))
|
||||
return self.rank_candidates(candidates)
|
||||
|
||||
def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int = 5) -> List[TileCandidate]:
|
||||
desc = self.compute_location_descriptor(image)
|
||||
matches = self.query_database(desc, top_k)
|
||||
return self._matches_to_candidates(matches)
|
||||
|
||||
def retrieve_candidate_tiles_for_chunk(
|
||||
self, chunk_images: List[np.ndarray], top_k: int = 5
|
||||
) -> List[TileCandidate]:
|
||||
desc = self.compute_chunk_descriptor(chunk_images)
|
||||
matches = self.query_database(desc, top_k)
|
||||
return self._matches_to_candidates(matches)
|
||||
__all__ = [
|
||||
"GlobalPlaceRecognition",
|
||||
"IGlobalPlaceRecognition",
|
||||
"_faiss",
|
||||
"_FAISS_AVAILABLE",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user