import cv2
import numpy as np
import json
import os
import logging
from typing import List, Dict, Optional, Any, Tuple
from pydantic import BaseModel
from abc import ABC, abstractmethod

from f02_1_flight_lifecycle_manager import GPSPoint
from f04_satellite_data_manager import TileBounds

logger = logging.getLogger(__name__)


# --- Data Models ---

class TileCandidate(BaseModel):
    """A candidate satellite tile produced by global place recognition."""
    tile_id: str
    gps_center: GPSPoint
    # Optional TileBounds; typed as Any to avoid strict cyclic coupling with F04.
    bounds: Optional[Any] = None
    similarity_score: float
    rank: int
    spatial_score: Optional[float] = None


class DatabaseMatch(BaseModel):
    """A raw nearest-neighbour hit returned from the Faiss index."""
    index: int
    tile_id: str
    distance: float
    similarity_score: float


class SatelliteTile(BaseModel):
    """A satellite tile with its image, geo-reference and optional descriptor."""
    tile_id: str
    image: np.ndarray
    gps_center: GPSPoint
    bounds: Any
    descriptor: Optional[np.ndarray] = None

    # np.ndarray is not a pydantic-native type; allow it explicitly.
    model_config = {"arbitrary_types_allowed": True}


# --- Exceptions ---

class IndexNotFoundError(Exception):
    """Raised when the Faiss index file does not exist on disk."""
    pass


class IndexCorruptedError(Exception):
    """Raised when the loaded index fails integrity validation."""
    pass


class MetadataMismatchError(Exception):
    """Raised when tile metadata is missing, invalid, or inconsistent with the index."""
    pass


# --- Interface ---

class IGlobalPlaceRecognition(ABC):
    """Interface contract for F08 global place recognition."""

    @abstractmethod
    def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int) -> List[TileCandidate]:
        pass

    @abstractmethod
    def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
        pass

    @abstractmethod
    def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
        pass

    @abstractmethod
    def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
        pass

    @abstractmethod
    def load_index(self, flight_id: str, index_path: str) -> bool:
        pass

    @abstractmethod
    def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int) -> List[TileCandidate]:
        pass

    @abstractmethod
    def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
        pass


# --- Implementation ---

class GlobalPlaceRecognition(IGlobalPlaceRecognition):
    """
    F08: Global Place Recognition

    Computes DINOv2+VLAD semantic descriptors and queries a pre-built Faiss index
    of satellite tiles to relocalize the UAV after catastrophic tracking loss.
    """

    def __init__(self, model_manager=None, faiss_manager=None, satellite_manager=None):
        """Wire up external managers; all are optional (mock fallbacks exist).

        Args:
            model_manager: Provides `run_dinov2` for dense feature extraction.
            faiss_manager: Provides `load_index`, `get_stats`, and `search`.
            satellite_manager: F04 tile provider (reserved; not used directly here).
        """
        self.model_manager = model_manager
        self.faiss_manager = faiss_manager
        self.satellite_manager = satellite_manager
        self.is_index_loaded = False
        self.tile_metadata: Dict[int, Dict] = {}
        self.dim = 4096  # DINOv2 + VLAD standard dimension

    # --- Descriptor Computation (08.02) ---

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Convert BGR->RGB if needed, resize to DINOv2 input, and scale to [0, 1]."""
        if len(image.shape) == 3 and image.shape[2] == 3:
            img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            img = image
        # Standard DINOv2 input size.
        img = cv2.resize(img, (224, 224))
        return img.astype(np.float32) / 255.0

    def _extract_dense_features(self, preprocessed: np.ndarray) -> np.ndarray:
        """Run DINOv2 on the preprocessed image, or return deterministic mock features.

        Returns an array of shape [num_patches, feat_dim].
        """
        if self.model_manager and hasattr(self.model_manager, 'run_dinov2'):
            return self.model_manager.run_dinov2(preprocessed)
        # Mock fallback: seed from image content so identical inputs yield identical features.
        rng = np.random.RandomState(int(np.sum(preprocessed) * 1000) % (2**32))
        return rng.rand(256, 384).astype(np.float32)

    def _vlad_aggregate(self, dense_features: np.ndarray, codebook: Optional[np.ndarray] = None) -> np.ndarray:
        """Mock VLAD aggregation projecting dense features to a `self.dim` vector."""
        rng = np.random.RandomState(int(np.sum(dense_features) * 1000) % (2**32))
        vlad_desc = rng.rand(self.dim).astype(np.float32)
        return vlad_desc

    def _l2_normalize(self, descriptor: np.ndarray) -> np.ndarray:
        """L2-normalize a descriptor; a zero vector is returned unchanged."""
        norm = np.linalg.norm(descriptor)
        if norm == 0:
            return descriptor
        return descriptor / norm

    def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Compute a single L2-normalized DINOv2+VLAD descriptor for one image."""
        preprocessed = self._preprocess_image(image)
        dense_feat = self._extract_dense_features(preprocessed)
        vlad_desc = self._vlad_aggregate(dense_feat)
        return self._l2_normalize(vlad_desc)

    def _aggregate_chunk_descriptors(self, descriptors: List[np.ndarray], strategy: str = "mean") -> np.ndarray:
        """Fuse multiple per-frame descriptors into one, then L2-normalize.

        Raises:
            ValueError: On an empty descriptor list or unknown strategy.
        """
        if not descriptors:
            raise ValueError("Cannot aggregate empty descriptor list.")
        stacked = np.stack(descriptors)
        if strategy == "mean":
            agg = np.mean(stacked, axis=0)
        elif strategy == "max":
            agg = np.max(stacked, axis=0)
        elif strategy == "vlad":
            # Simplified fallback for vlad aggregation.
            agg = np.mean(stacked, axis=0)
        else:
            raise ValueError(f"Unknown aggregation strategy: {strategy}")
        return self._l2_normalize(agg)

    def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
        """Compute one descriptor summarizing a chunk of consecutive frames."""
        if not chunk_images:
            raise ValueError("Chunk images list is empty.")
        descriptors = [self.compute_location_descriptor(img) for img in chunk_images]
        return self._aggregate_chunk_descriptors(descriptors, strategy="mean")

    # --- Index Management (08.01) ---

    def _validate_index_integrity(self, index_dim: int, expected_count: int) -> bool:
        """Sanity-check index dimensionality and vector count.

        Raises:
            IndexCorruptedError: If dims are not a known descriptor size or the
                index holds no vectors (`expected_count` was previously ignored).
        """
        if index_dim not in [4096, 8192]:
            raise IndexCorruptedError(f"Invalid index dimensions: {index_dim}")
        if expected_count <= 0:
            raise IndexCorruptedError(f"Index contains no vectors (count={expected_count}).")
        return True

    def _load_tile_metadata(self, metadata_path: str) -> Dict[int, dict]:
        """Load the sidecar JSON mapping Faiss row index -> tile metadata dict.

        Raises:
            MetadataMismatchError: If the file is missing, empty, or invalid JSON.
        """
        if not os.path.exists(metadata_path):
            raise MetadataMismatchError("Metadata file not found.")
        try:
            with open(metadata_path, 'r') as f:
                content = f.read().strip()
            if not content:
                raise MetadataMismatchError("Metadata file is empty.")
            data = json.loads(content)
            if not data:
                raise MetadataMismatchError("Metadata file contains empty JSON object.")
        except json.JSONDecodeError as e:
            # Preserve the original parse error as the cause for debuggability.
            raise MetadataMismatchError("Metadata file contains invalid JSON.") from e
        # JSON object keys are strings; Faiss indices are ints.
        return {int(k): v for k, v in data.items()}

    def _verify_metadata_alignment(self, index_count: int, metadata: Dict) -> bool:
        """Ensure the metadata has exactly one entry per index vector."""
        if index_count != len(metadata):
            raise MetadataMismatchError(f"Index count ({index_count}) does not match metadata count ({len(metadata)}).")
        return True

    def load_index(self, flight_id: str, index_path: str) -> bool:
        """Load the Faiss index plus sidecar metadata and mark the instance ready.

        Raises:
            IndexNotFoundError: If `index_path` does not exist.
            IndexCorruptedError: If the index fails integrity checks.
            MetadataMismatchError: If metadata is missing/invalid/misaligned.
        """
        if not os.path.exists(index_path):
            raise IndexNotFoundError(f"Index file {index_path} not found.")
        # BUGFIX: derive the metadata path from the FINAL extension only;
        # str.replace(".index", ".json") rewrote every occurrence in the path.
        base, ext = os.path.splitext(index_path)
        meta_path = base + ".json" if ext == ".index" else index_path.replace(".index", ".json")
        if self.faiss_manager:
            self.faiss_manager.load_index(index_path)
            idx_count, idx_dim = self.faiss_manager.get_stats()
        else:
            # Mock Faiss loading.
            idx_count, idx_dim = 1000, 4096
        self._validate_index_integrity(idx_dim, idx_count)
        self.tile_metadata = self._load_tile_metadata(meta_path)
        self._verify_metadata_alignment(idx_count, self.tile_metadata)
        self.is_index_loaded = True
        logger.info("Successfully loaded global index for flight %s.", flight_id)
        return True

    # --- Candidate Retrieval (08.03) ---

    def _retrieve_tile_metadata(self, indices: List[int]) -> List[Dict[str, Any]]:
        """Fetches metadata for a list of tile indices."""
        # In a real system, this might delegate to F04 if metadata is not held in memory.
        return [self.tile_metadata.get(idx, {}) for idx in indices]

    def _build_candidates_from_matches(self, matches: List[DatabaseMatch]) -> List[TileCandidate]:
        """Convert raw index matches into TileCandidate objects (rank assigned later)."""
        candidates = []
        for m in matches:
            meta = self.tile_metadata.get(m.index, {})
            # Missing geo metadata degrades gracefully to (0, 0) rather than crashing.
            lat, lon = meta.get("lat", 0.0), meta.get("lon", 0.0)
            cand = TileCandidate(
                tile_id=m.tile_id,
                gps_center=GPSPoint(lat=lat, lon=lon),
                similarity_score=m.similarity_score,
                rank=0,
            )
            candidates.append(cand)
        return candidates

    def _distance_to_similarity(self, distance: float) -> float:
        """Map an L2 distance between unit vectors to a similarity in [0, 1].

        For L2 normalized vectors, Euclidean distance is in [0, 2].
        Sim = 1 - (dist^2 / 4) maps [0, 2] to [1, 0].
        """
        return max(0.0, 1.0 - (distance**2 / 4.0))

    def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
        """Search the index for the `top_k` nearest tiles to `descriptor`.

        Returns an empty list when no index has been loaded.
        """
        if not self.is_index_loaded:
            return []
        if self.faiss_manager:
            distances, indices = self.faiss_manager.search(descriptor.reshape(1, -1), top_k)
        else:
            # Mock Faiss search. BUGFIX: indices and distances must share the same
            # length k; previously distances always had top_k entries even when
            # fewer tiles existed.
            k = min(top_k, len(self.tile_metadata))
            indices = np.random.choice(len(self.tile_metadata), k, replace=False).reshape(1, -1)
            distances = np.sort(np.random.rand(k) * 1.5).reshape(1, -1)  # Distances sorted ascending
        matches = []
        for i in range(len(indices[0])):
            idx = int(indices[0][i])
            dist = float(distances[0][i])
            meta = self.tile_metadata.get(idx, {})
            tile_id = meta.get("tile_id", f"tile_{idx}")
            sim = self._distance_to_similarity(dist)
            matches.append(DatabaseMatch(index=idx, tile_id=tile_id, distance=dist, similarity_score=sim))
        return matches

    def _apply_spatial_reranking(self, candidates: List[TileCandidate], dead_reckoning_estimate: Optional[GPSPoint] = None) -> List[TileCandidate]:
        """Hook for GPS-proximity reranking; currently returns candidates unmodified."""
        return candidates

    def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
        """Sort candidates by similarity (descending), assign 1-based ranks, rerank.

        Note: sorts the passed-in list in place, matching existing caller expectations.
        """
        candidates.sort(key=lambda x: x.similarity_score, reverse=True)
        for i, cand in enumerate(candidates):
            cand.rank = i + 1
        return self._apply_spatial_reranking(candidates)

    def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int = 5) -> List[TileCandidate]:
        """End-to-end retrieval for a single frame: descriptor -> search -> rank."""
        descriptor = self.compute_location_descriptor(image)
        matches = self.query_database(descriptor, top_k)
        candidates = self._build_candidates_from_matches(matches)
        return self.rank_candidates(candidates)

    def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int = 5) -> List[TileCandidate]:
        """End-to-end retrieval for a chunk of frames using the fused chunk descriptor."""
        descriptor = self.compute_chunk_descriptor(chunk_images)
        matches = self.query_database(descriptor, top_k)
        candidates = self._build_candidates_from_matches(matches)
        return self.rank_candidates(candidates)