"""F09 Metric Refinement.

Dense cross-view geo-localization between UAV images and satellite tiles:
SuperPoint + LightGlue feature matching, RANSAC homography estimation,
Mean Reprojection Error (MRE), and GPS extraction via GSD interpolation.

Set USE_MOCK_MODELS=1 in the environment to run without the `lightglue`
package (random-feature stand-ins, useful for pipeline smoke tests).
"""

import logging
import math
import os
from typing import List, Optional, Tuple

import cv2
import numpy as np
import torch
from pydantic import BaseModel

USE_MOCK_MODELS = os.environ.get("USE_MOCK_MODELS", "0") == "1"

if USE_MOCK_MODELS:

    class SuperPoint(torch.nn.Module):
        """Stand-in extractor emitting 50 random keypoints per image."""

        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, x):
            b, _, h, w = x.shape
            kpts = torch.rand(b, 50, 2, device=x.device)
            # Scale normalized coordinates to pixel space (x within width,
            # y within height) so downstream homography math is plausible.
            kpts[..., 0] *= w
            kpts[..., 1] *= h
            return {
                'keypoints': kpts,
                'descriptors': torch.rand(b, 256, 50, device=x.device),
                'scores': torch.rand(b, 50, device=x.device),
            }

        def extract(self, x):
            # BUGFIX: the real SuperPoint exposes `.extract()`, which is what
            # MetricRefinement.compute_homography calls. The mock previously
            # implemented only forward(), so mock mode crashed with
            # AttributeError. Delegate to forward(); rbd() strips the batch
            # dim afterwards, matching the real API's batched output.
            return self.forward(x)

    class LightGlue(torch.nn.Module):
        """Stand-in matcher returning 25 identity correspondences."""

        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, data):
            kpts0 = data['image0']['keypoints']
            b = kpts0.shape[0]
            idx = torch.arange(25, device=kpts0.device)
            matches = torch.stack([idx, idx], dim=-1).unsqueeze(0).repeat(b, 1, 1)
            return {
                'matches': matches,
                'matching_scores': torch.rand(b, 25, device=kpts0.device),
            }

    def rbd(data):
        """Remove the batch dimension (mirrors lightglue.utils.rbd)."""
        return {k: v[0] for k, v in data.items()}

else:
    # Requires: pip install lightglue
    from lightglue import LightGlue, SuperPoint
    from lightglue.utils import rbd

logger = logging.getLogger(__name__)


# --- Data Models ---

class GPSPoint(BaseModel):
    """A WGS84 latitude/longitude pair in decimal degrees."""
    lat: float
    lon: float


class TileBounds(BaseModel):
    """Geographic extent of a satellite tile plus its ground resolution."""
    nw: GPSPoint
    ne: GPSPoint
    sw: GPSPoint
    se: GPSPoint
    center: GPSPoint
    gsd: float  # Ground Sampling Distance (meters/pixel)


class Sim3Transform(BaseModel):
    """Similarity transform (translation, rotation, uniform scale)."""
    translation: np.ndarray
    rotation: np.ndarray
    scale: float

    class Config:
        arbitrary_types_allowed = True  # np.ndarray is not a pydantic type


class AlignmentResult(BaseModel):
    """Outcome of aligning one UAV image against one satellite tile."""
    matched: bool
    homography: np.ndarray
    transform: np.ndarray  # 4x4 matrix for pipeline compatibility
    gps_center: GPSPoint
    confidence: float
    inlier_count: int
    total_correspondences: int
    reprojection_error: float

    class Config:
        arbitrary_types_allowed = True


class ChunkAlignmentResult(BaseModel):
    """Outcome of aligning a chunk (sequence of frames) to a tile."""
    matched: bool
    chunk_id: str
    chunk_center_gps: GPSPoint
    rotation_angle: float
    confidence: float
    inlier_count: int
    transform: Sim3Transform
    reprojection_error: float

    class Config:
        arbitrary_types_allowed = True


# --- Implementation ---

class MetricRefinement:
    """
    F09: Metric Refinement Module.
    Performs dense cross-view geo-localization between UAV images and
    satellite tiles. Computes homography mappings, Mean Reprojection
    Error (MRE), and exact GPS coordinates.
    """

    def __init__(self, device: str = "cuda", max_keypoints: int = 2048):
        # Fall back to CPU transparently when CUDA is unavailable.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing Metric Refinement (SuperPoint+LightGlue) on {self.device}")
        # Using SuperPoint + LightGlue as the high-accuracy "Fine Matcher"
        self.extractor = SuperPoint(max_num_keypoints=max_keypoints).eval().to(self.device)
        self.matcher = LightGlue(features='superpoint', depth_confidence=0.9).eval().to(self.device)

    def _preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Converts an image to a normalized grayscale tensor for feature extraction."""
        if len(image.shape) == 3:
            # assumes 3-channel input is BGR (OpenCV convention) — TODO confirm
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        tensor = torch.from_numpy(image).float() / 255.0
        return tensor[None, None, ...].to(self.device)  # shape (1, 1, H, W)

    def compute_homography(self, uav_image: np.ndarray,
                           satellite_tile: np.ndarray) -> Tuple[Optional[np.ndarray],
                                                                Optional[np.ndarray],
                                                                int, float]:
        """
        Computes homography transformation from UAV to satellite.

        Returns:
            (Homography Matrix, Inlier Mask, Total Correspondences,
            Reprojection Error). Homography and mask are None when fewer
            than 15 correspondences are found or RANSAC fails.
        """
        tensor_uav = self._preprocess_image(uav_image)
        tensor_sat = self._preprocess_image(satellite_tile)

        with torch.no_grad():
            feats_uav = self.extractor.extract(tensor_uav)
            feats_sat = self.extractor.extract(tensor_sat)
            matches = self.matcher({'image0': feats_uav, 'image1': feats_sat})

        # Strip batch dimension from all three result dicts.
        feats0, feats1, matches01 = [rbd(x) for x in [feats_uav, feats_sat, matches]]
        kpts_uav = feats0['keypoints'][matches01['matches'][..., 0]].cpu().numpy()
        kpts_sat = feats1['keypoints'][matches01['matches'][..., 1]].cpu().numpy()

        total_correspondences = len(kpts_uav)
        # RANSAC needs only 4 points, but below 15 the estimate is unreliable.
        if total_correspondences < 15:
            return None, None, total_correspondences, 0.0

        H, mask = cv2.findHomography(kpts_uav, kpts_sat, cv2.RANSAC, 5.0)

        reprojection_error = 0.0
        if H is not None and mask is not None and mask.sum() > 0:
            # Calculate Mean Reprojection Error (MRE) for inliers (AC-10 requirement)
            inliers_uav = kpts_uav[mask.ravel() == 1]
            inliers_sat = kpts_sat[mask.ravel() == 1]
            proj_uav = cv2.perspectiveTransform(inliers_uav.reshape(-1, 1, 2), H).reshape(-1, 2)
            errors = np.linalg.norm(proj_uav - inliers_sat, axis=1)
            reprojection_error = float(np.mean(errors))

        return H, mask, total_correspondences, reprojection_error

    def extract_gps_from_alignment(self, homography: np.ndarray,
                                   tile_bounds: TileBounds,
                                   image_center: Tuple[int, int]) -> GPSPoint:
        """
        Extracts GPS coordinates by projecting the UAV center pixel onto the
        satellite tile and interpolating via Ground Sampling Distance (GSD).
        """
        cx, cy = image_center
        pt = np.array([cx, cy, 1.0], dtype=np.float64)
        sat_pt = homography @ pt
        # Dehomogenize the projected point.
        sat_x, sat_y = sat_pt[0] / sat_pt[2], sat_pt[1] / sat_pt[2]

        # Linear interpolation based on Web Mercator projection approximations
        meters_per_deg_lat = 111319.9
        meters_per_deg_lon = meters_per_deg_lat * math.cos(math.radians(tile_bounds.nw.lat))

        delta_lat = (sat_y * tile_bounds.gsd) / meters_per_deg_lat
        delta_lon = (sat_x * tile_bounds.gsd) / meters_per_deg_lon

        # Pixel origin is the NW corner: y grows southward (lat decreases),
        # x grows eastward (lon increases).
        lat = tile_bounds.nw.lat - delta_lat
        lon = tile_bounds.nw.lon + delta_lon
        return GPSPoint(lat=lat, lon=lon)

    def compute_match_confidence(self, inlier_count: int,
                                 total_correspondences: int,
                                 reprojection_error: float) -> float:
        """Evaluates match reliability based on inliers and geometric reprojection error."""
        if total_correspondences == 0:
            return 0.0
        inlier_ratio = inlier_count / total_correspondences
        # High confidence requires low reprojection error (< 1.0px) for AC-10 compliance
        if inlier_count > 50 and reprojection_error < 1.0:
            return min(1.0, 0.8 + 0.2 * inlier_ratio)
        elif inlier_count > 25:
            return min(0.8, 0.5 + 0.3 * inlier_ratio)
        return max(0.0, 0.4 * inlier_ratio)

    def align_to_satellite(self, uav_image: np.ndarray,
                           satellite_tile: np.ndarray,
                           tile_bounds: Optional[TileBounds] = None) -> Optional[AlignmentResult]:
        """Aligns a single UAV image to a satellite tile.

        Returns None when matching fails (too few correspondences or
        fewer than 15 RANSAC inliers).
        """
        H, mask, total, mre = self.compute_homography(uav_image, satellite_tile)
        if H is None or mask is None:
            return None

        inliers = int(mask.sum())
        if inliers < 15:
            return None

        h, w = uav_image.shape[:2]
        center = (w // 2, h // 2)
        # Without tile bounds we cannot geo-reference; emit a null GPS point.
        gps = self.extract_gps_from_alignment(H, tile_bounds, center) if tile_bounds else GPSPoint(lat=0.0, lon=0.0)
        conf = self.compute_match_confidence(inliers, total, mre)

        # Provide a mocked 4x4 matrix for downstream Sim3 compatibility
        transform = np.eye(4)
        transform[:2, :2] = H[:2, :2]
        transform[0, 3] = H[0, 2]
        transform[1, 3] = H[1, 2]

        return AlignmentResult(
            matched=True,
            homography=H,
            transform=transform,
            gps_center=gps,
            confidence=conf,
            inlier_count=inliers,
            total_correspondences=total,
            reprojection_error=mre
        )

    def match_chunk_homography(self, chunk_images: List[np.ndarray],
                               satellite_tile: np.ndarray) -> Optional[np.ndarray]:
        """Computes homography for a chunk by evaluating the center representative frame."""
        # BUGFIX: an empty chunk previously raised IndexError; return None
        # to match the Optional contract instead.
        if not chunk_images:
            return None
        center_idx = len(chunk_images) // 2
        H, _, _, _ = self.compute_homography(chunk_images[center_idx], satellite_tile)
        return H