Files
gps-denied-onboard/f08_global_place_recognition.py
Denys Zaitsev d7e1066c60 Initial commit
2026-04-03 23:25:54 +03:00

259 lines
11 KiB
Python

import cv2
import numpy as np
import json
import os
import logging
from typing import List, Dict, Optional, Any, Tuple
from pydantic import BaseModel
from abc import ABC, abstractmethod
from f02_1_flight_lifecycle_manager import GPSPoint
from f04_satellite_data_manager import TileBounds
# Module-level logger named after this module, per standard logging convention.
logger = logging.getLogger(__name__)
# --- Data Models ---
class TileCandidate(BaseModel):
    """A candidate satellite tile hypothesis for the UAV's current location."""
    tile_id: str
    gps_center: GPSPoint  # tile center coordinate (project GPSPoint model)
    bounds: Optional[Any] = None  # Optional TileBounds to avoid strict cyclic coupling
    similarity_score: float  # descriptor similarity in [0, 1]; higher is better
    rank: int  # 1-based rank assigned by rank_candidates(); 0 until ranked
    spatial_score: Optional[float] = None  # reserved for spatial re-ranking heuristics
class DatabaseMatch(BaseModel):
    """A raw nearest-neighbour hit returned from the tile descriptor index."""
    index: int  # row index into the Faiss index / tile metadata map
    tile_id: str
    distance: float  # L2 distance in descriptor space (lower is closer)
    similarity_score: float  # distance remapped to [0, 1]; higher is better
class SatelliteTile(BaseModel):
    """A satellite tile image with its georeference and optional cached descriptor."""
    tile_id: str
    image: np.ndarray  # tile raster; non-pydantic type, needs arbitrary_types_allowed
    gps_center: GPSPoint  # tile center coordinate (project GPSPoint model)
    bounds: Any  # presumably a TileBounds (see f04 import) — kept loose to avoid coupling
    descriptor: Optional[np.ndarray] = None  # cached global descriptor, if precomputed
    # Permit np.ndarray field values, which pydantic cannot validate natively.
    model_config = {"arbitrary_types_allowed": True}
# --- Exceptions ---
class IndexNotFoundError(Exception):
    """Raised when the Faiss index file is missing on disk."""


class IndexCorruptedError(Exception):
    """Raised when a loaded index fails integrity validation."""


class MetadataMismatchError(Exception):
    """Raised when tile metadata is missing, invalid, or misaligned with the index."""
# --- Interface ---
class IGlobalPlaceRecognition(ABC):
    """Contract for the F08 global place-recognition component."""

    @abstractmethod
    def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int) -> List[TileCandidate]:
        """Full pipeline for one image: descriptor -> index query -> ranked candidates."""

    @abstractmethod
    def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Compute a global location descriptor for a single image."""

    @abstractmethod
    def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
        """Nearest-neighbour search of the tile descriptor index."""

    @abstractmethod
    def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
        """Order candidates by similarity and assign ranks."""

    @abstractmethod
    def load_index(self, flight_id: str, index_path: str) -> bool:
        """Load the pre-built index and tile metadata for a flight."""

    @abstractmethod
    def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int) -> List[TileCandidate]:
        """Full pipeline for a chunk of images fused into one descriptor."""

    @abstractmethod
    def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
        """Compute one fused global descriptor for a chunk of images."""
# --- Implementation ---
class GlobalPlaceRecognition(IGlobalPlaceRecognition):
    """
    F08: Global Place Recognition

    Computes DINOv2+VLAD semantic descriptors and queries a pre-built Faiss index
    of satellite tiles to relocalize the UAV after catastrophic tracking loss.

    All collaborators (model, Faiss, satellite managers) are optional; when a
    manager is absent, a deterministic mock stands in so the pipeline can run
    end-to-end without models or an index on disk.
    """

    def __init__(self, model_manager=None, faiss_manager=None, satellite_manager=None):
        self.model_manager = model_manager
        self.faiss_manager = faiss_manager
        self.satellite_manager = satellite_manager
        self.is_index_loaded = False
        # Faiss row index -> tile metadata dict ("tile_id", "lat", "lon", ...).
        self.tile_metadata: Dict[int, Dict] = {}
        self.dim = 4096  # DINOv2 + VLAD standard descriptor dimension

    # --- Descriptor Computation (08.02) ---
    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Convert 3-channel BGR to RGB, resize to 224x224, and scale to [0, 1] float32."""
        if len(image.shape) == 3 and image.shape[2] == 3:
            img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            img = image
        img = cv2.resize(img, (224, 224))  # standard DINOv2 input size
        return img.astype(np.float32) / 255.0

    def _extract_dense_features(self, preprocessed: np.ndarray) -> np.ndarray:
        """Run DINOv2 dense feature extraction, or a content-seeded mock without a model."""
        if self.model_manager and hasattr(self.model_manager, 'run_dinov2'):
            return self.model_manager.run_dinov2(preprocessed)
        # Mock fallback: deterministic per-image random features [num_patches, feat_dim]
        rng = np.random.RandomState(int(np.sum(preprocessed) * 1000) % (2**32))
        return rng.rand(256, 384).astype(np.float32)

    def _vlad_aggregate(self, dense_features: np.ndarray, codebook: Optional[np.ndarray] = None) -> np.ndarray:
        """Aggregate dense features into a single self.dim-dimensional descriptor.

        Mock implementation: deterministic per-input random projection to 4096 dims.
        """
        rng = np.random.RandomState(int(np.sum(dense_features) * 1000) % (2**32))
        vlad_desc = rng.rand(self.dim).astype(np.float32)
        return vlad_desc

    def _l2_normalize(self, descriptor: np.ndarray) -> np.ndarray:
        """Scale the descriptor to unit L2 norm; a zero vector passes through unchanged."""
        norm = np.linalg.norm(descriptor)
        if norm == 0:
            return descriptor
        return descriptor / norm

    def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Preprocess -> dense features -> VLAD -> L2-normalized global descriptor."""
        preprocessed = self._preprocess_image(image)
        dense_feat = self._extract_dense_features(preprocessed)
        vlad_desc = self._vlad_aggregate(dense_feat)
        return self._l2_normalize(vlad_desc)

    def _aggregate_chunk_descriptors(self, descriptors: List[np.ndarray], strategy: str = "mean") -> np.ndarray:
        """Fuse per-frame descriptors into one L2-normalized descriptor.

        Supported strategies: "mean", "max", "vlad" (simplified to mean).
        Raises ValueError for an empty list or an unknown strategy.
        """
        if not descriptors:
            raise ValueError("Cannot aggregate empty descriptor list.")
        stacked = np.stack(descriptors)
        if strategy == "mean":
            agg = np.mean(stacked, axis=0)
        elif strategy == "max":
            agg = np.max(stacked, axis=0)
        elif strategy == "vlad":
            agg = np.mean(stacked, axis=0)  # Simplified fallback for vlad aggregation
        else:
            raise ValueError(f"Unknown aggregation strategy: {strategy}")
        return self._l2_normalize(agg)

    def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
        """Compute one descriptor per frame and mean-aggregate them.

        Raises ValueError if chunk_images is empty.
        """
        if not chunk_images:
            raise ValueError("Chunk images list is empty.")
        descriptors = [self.compute_location_descriptor(img) for img in chunk_images]
        return self._aggregate_chunk_descriptors(descriptors, strategy="mean")

    # --- Index Management (08.01) ---
    def _validate_index_integrity(self, index_dim: int, expected_count: int) -> bool:
        """Sanity-check index dimensionality; raises IndexCorruptedError if invalid."""
        if index_dim not in [4096, 8192]:
            raise IndexCorruptedError(f"Invalid index dimensions: {index_dim}")
        return True

    @staticmethod
    def _metadata_path_for(index_path: str) -> str:
        """Derive the sidecar JSON metadata path from an index path.

        BUGFIX: the previous `index_path.replace(".index", ".json")` replaced
        every ".index" occurrence anywhere in the path (e.g. inside a directory
        name) and silently returned the index path itself when the suffix was
        absent. Only a trailing ".index" extension is swapped now; any other
        extension gets ".json" appended.
        """
        root, ext = os.path.splitext(index_path)
        if ext == ".index":
            return root + ".json"
        return index_path + ".json"

    def _load_tile_metadata(self, metadata_path: str) -> Dict[int, dict]:
        """Load and validate tile metadata JSON; keys are cast to int row indices.

        Raises MetadataMismatchError when the file is missing, empty, or invalid.
        """
        if not os.path.exists(metadata_path):
            raise MetadataMismatchError("Metadata file not found.")
        try:
            with open(metadata_path, 'r') as f:
                content = f.read().strip()
            if not content:
                raise MetadataMismatchError("Metadata file is empty.")
            data = json.loads(content)
            if not data:
                raise MetadataMismatchError("Metadata file contains empty JSON object.")
        except json.JSONDecodeError as exc:
            # Chain the cause so the original parse error stays visible.
            raise MetadataMismatchError("Metadata file contains invalid JSON.") from exc
        return {int(k): v for k, v in data.items()}

    def _verify_metadata_alignment(self, index_count: int, metadata: Dict) -> bool:
        """Ensure the index vector count equals the metadata entry count."""
        if index_count != len(metadata):
            raise MetadataMismatchError(f"Index count ({index_count}) does not match metadata count ({len(metadata)}).")
        return True

    def load_index(self, flight_id: str, index_path: str) -> bool:
        """Load the Faiss index and its sidecar metadata for a flight.

        Raises:
            IndexNotFoundError: index file missing on disk.
            IndexCorruptedError: index dimensionality is invalid.
            MetadataMismatchError: metadata missing/invalid or count mismatch.
        """
        meta_path = self._metadata_path_for(index_path)
        if not os.path.exists(index_path):
            raise IndexNotFoundError(f"Index file {index_path} not found.")
        if self.faiss_manager:
            self.faiss_manager.load_index(index_path)
            idx_count, idx_dim = self.faiss_manager.get_stats()
        else:
            # Mock Faiss loading
            idx_count, idx_dim = 1000, 4096
        self._validate_index_integrity(idx_dim, idx_count)
        self.tile_metadata = self._load_tile_metadata(meta_path)
        self._verify_metadata_alignment(idx_count, self.tile_metadata)
        self.is_index_loaded = True
        logger.info("Successfully loaded global index for flight %s.", flight_id)
        return True

    # --- Candidate Retrieval (08.03) ---
    def _retrieve_tile_metadata(self, indices: List[int]) -> List[Dict[str, Any]]:
        """Fetch metadata dicts for a list of tile indices; unknown indices yield {}."""
        # In a real system, this might delegate to F04 if metadata is not held in memory
        return [self.tile_metadata.get(idx, {}) for idx in indices]

    def _build_candidates_from_matches(self, matches: List[DatabaseMatch]) -> List[TileCandidate]:
        """Convert raw index matches into TileCandidate objects (ranks assigned later)."""
        candidates = []
        for m in matches:
            meta = self.tile_metadata.get(m.index, {})
            # Missing coordinates default to (0, 0) rather than aborting the rescue path.
            lat, lon = meta.get("lat", 0.0), meta.get("lon", 0.0)
            cand = TileCandidate(
                tile_id=m.tile_id,
                gps_center=GPSPoint(lat=lat, lon=lon),
                similarity_score=m.similarity_score,
                rank=0,
            )
            candidates.append(cand)
        return candidates

    def _distance_to_similarity(self, distance: float) -> float:
        """Map an L2 distance between unit vectors onto a similarity in [0, 1].

        For L2-normalized vectors the Euclidean distance lies in [0, 2];
        sim = 1 - d^2/4 maps [0, 2] onto [1, 0] (clamped at 0).
        """
        return max(0.0, 1.0 - (distance**2 / 4.0))

    def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
        """Search the index for the top_k nearest tiles; returns [] without a loaded index."""
        if not self.is_index_loaded:
            return []
        if self.faiss_manager:
            distances, indices = self.faiss_manager.search(descriptor.reshape(1, -1), top_k)
        else:
            # Mock Faiss search. BUGFIX: size the distance array to the number of
            # indices actually drawn (k), not top_k, so the two arrays stay aligned
            # when top_k exceeds the tile count; also guard the empty-metadata case.
            k = min(top_k, len(self.tile_metadata))
            if k <= 0:
                return []
            indices = np.random.choice(len(self.tile_metadata), k, replace=False).reshape(1, -1)
            distances = np.sort(np.random.rand(k) * 1.5).reshape(1, -1)  # ascending distances
        matches = []
        for i in range(len(indices[0])):
            idx = int(indices[0][i])
            dist = float(distances[0][i])
            meta = self.tile_metadata.get(idx, {})
            tile_id = meta.get("tile_id", f"tile_{idx}")
            sim = self._distance_to_similarity(dist)
            matches.append(DatabaseMatch(index=idx, tile_id=tile_id, distance=dist, similarity_score=sim))
        return matches

    def _apply_spatial_reranking(self, candidates: List[TileCandidate], dead_reckoning_estimate: Optional[GPSPoint] = None) -> List[TileCandidate]:
        """Hook for GPS-proximity re-ranking; currently returns candidates unmodified."""
        return candidates

    def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
        """Sort by similarity (descending), assign 1-based ranks, then apply spatial hook.

        Note: sorts the provided list in place.
        """
        candidates.sort(key=lambda x: x.similarity_score, reverse=True)
        for i, cand in enumerate(candidates):
            cand.rank = i + 1
        return self._apply_spatial_reranking(candidates)

    def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int = 5) -> List[TileCandidate]:
        """End-to-end retrieval for one image: descriptor -> query -> ranked candidates."""
        descriptor = self.compute_location_descriptor(image)
        matches = self.query_database(descriptor, top_k)
        candidates = self._build_candidates_from_matches(matches)
        return self.rank_candidates(candidates)

    def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int = 5) -> List[TileCandidate]:
        """End-to-end retrieval for an image chunk fused into a single descriptor."""
        descriptor = self.compute_chunk_descriptor(chunk_images)
        matches = self.query_database(descriptor, top_k)
        candidates = self._build_candidates_from_matches(matches)
        return self.rank_candidates(candidates)