import logging
from typing import Any, Dict, Optional, Tuple

import cv2
import numpy as np
import torch
from lightglue import LightGlue, SuperPoint
from lightglue.utils import rbd

logger = logging.getLogger(__name__)

# BUGFIX: this flag was referenced in _load_global_encoder but never defined,
# which raised a NameError the first time the encoder was loaded. Default to
# the real DINOv2 model; flip to True in tests to avoid the torch.hub download.
USE_MOCK_MODELS = False


class CrossViewGeolocator:
    """
    Asynchronous Global Place Recognizer and Fine-Grained Matcher.

    Finds absolute metric GPS anchors for unscaled UAV keyframes in two
    stages: (1) coarse retrieval of candidate satellite tiles via a global
    descriptor + Faiss search, then (2) fine SuperPoint/LightGlue matching
    and RANSAC homography estimation against each candidate.
    """

    def __init__(self, faiss_manager, device: str = "cuda"):
        """
        Args:
            faiss_manager: Object exposing ``search(index, query, k)`` that
                returns ``(distances, indices)`` in Faiss convention.
            device: Preferred torch device name; silently falls back to CPU
                when CUDA is unavailable.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.faiss_manager = faiss_manager
        logger.info("Initializing Global Place Recognition (DINOv2) & Fine Matcher (LightGlue)")

        # Global descriptor model for fast Faiss retrieval.
        self.global_encoder = self._load_global_encoder()

        # Local feature extractor + matcher for metric alignment.
        self.extractor = SuperPoint(max_num_keypoints=2048).eval().to(self.device)
        self.matcher = LightGlue(features='superpoint', depth_confidence=0.9).eval().to(self.device)

        # Simulates the local geospatial SQLite cache of pre-downloaded
        # satellite tiles; maps a Faiss index id to a dict that must contain
        # at least a 'features' entry (pre-extracted SuperPoint features).
        self.satellite_cache: Dict[Any, Dict[str, Any]] = {}

    def _load_global_encoder(self):
        """Loads a Foundation Model (like DINOv2) for viewpoint-invariant descriptors."""
        if USE_MOCK_MODELS:
            # Lightweight stand-in mimicking DINOv2-S/14 output shape (1, 384).
            class MockEncoder:
                def __call__(self, x):
                    return torch.randn(1, 384).to(x.device)

            return MockEncoder()
        return torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(self.device)

    def extract_global_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Extracts a 1D vector signature resilient to seasonal/lighting changes.

        Args:
            image: HxW (grayscale) or HxWxC image array.

        Returns:
            The encoder's descriptor moved to CPU as a numpy array.
        """
        img_resized = cv2.resize(image, (224, 224))
        tensor = torch.from_numpy(img_resized).float() / 255.0
        # Adjust dimensions for PyTorch [B, C, H, W]
        if len(tensor.shape) == 3:
            tensor = tensor.permute(2, 0, 1).unsqueeze(0)
        else:
            # Grayscale input: replicate the single channel to 3 channels.
            tensor = tensor.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)
        tensor = tensor.to(self.device)
        with torch.no_grad():
            desc = self.global_encoder(tensor)
        return desc.cpu().numpy()

    def retrieve_and_match(
        self, uav_image: np.ndarray, index
    ) -> Tuple[bool, Optional[np.ndarray], Optional[Dict[str, Any]]]:
        """Searches the Faiss Index and computes the precise 2D-to-2D geodetic alignment.

        Args:
            uav_image: UAV keyframe, BGR (HxWx3) or grayscale (HxW).
            index: Faiss index handle forwarded to ``faiss_manager.search``.

        Returns:
            ``(success, homography, sat_info)``: ``homography`` maps UAV
            pixels to the best-matching satellite tile's pixels, and
            ``sat_info`` is that tile's cache entry. ``(False, None, None)``
            when no candidate yielded enough RANSAC inliers.
        """
        # 1. Global Search (Coarse): retrieve top-3 candidate tiles.
        global_desc = self.extract_global_descriptor(uav_image)
        _distances, indices = self.faiss_manager.search(index, global_desc, k=3)

        best_transform, best_inliers, best_sat_info = None, 0, None

        # 2. Extract UAV features once (Fine).
        uav_gray = cv2.cvtColor(uav_image, cv2.COLOR_BGR2GRAY) if len(uav_image.shape) == 3 else uav_image
        uav_tensor = torch.from_numpy(uav_gray).float()[None, None, ...].to(self.device) / 255.0
        with torch.no_grad():
            uav_feats = self.extractor.extract(uav_tensor)

        # 3. Fine-grained matching against top-K satellite tiles.
        for idx in indices[0]:
            if idx not in self.satellite_cache:
                # Tile not pre-downloaded; skip rather than fetch synchronously.
                continue
            sat_info = self.satellite_cache[idx]
            sat_feats = sat_info['features']

            with torch.no_grad():
                matches = self.matcher({'image0': uav_feats, 'image1': sat_feats})
            # rbd strips the batch dimension from LightGlue outputs.
            feats0, feats1, matches01 = [rbd(x) for x in [uav_feats, sat_feats, matches]]

            kpts_uav = feats0['keypoints'][matches01['matches'][..., 0]].cpu().numpy()
            kpts_sat = feats1['keypoints'][matches01['matches'][..., 1]].cpu().numpy()

            # Require comfortably more than the 4-point homography minimum.
            if len(kpts_uav) > 15:
                H, mask = cv2.findHomography(kpts_uav, kpts_sat, cv2.RANSAC, 5.0)
                inliers = mask.sum() if mask is not None else 0
                if inliers > best_inliers and inliers > 15:
                    best_inliers, best_transform, best_sat_info = inliers, H, sat_info

        return (best_transform is not None), best_transform, best_sat_info