gps-denied-onboard/f09_metric_refinement.py

import cv2
import torch
import math
import numpy as np
import logging
from typing import List, Optional, Tuple
from pydantic import BaseModel
import os
USE_MOCK_MODELS = os.environ.get("USE_MOCK_MODELS", "0") == "1"
if USE_MOCK_MODELS:
    class SuperPoint(torch.nn.Module):
        """Mock extractor mimicking the lightglue SuperPoint output layout."""
        def __init__(self, **kwargs): super().__init__()
        def extract(self, x):
            # Return random keypoints/descriptors/scores shaped like real features.
            # Named `extract` to match the API used by MetricRefinement below.
            b, _, h, w = x.shape
            kpts = torch.rand(b, 50, 2, device=x.device)
            kpts[..., 0] *= w
            kpts[..., 1] *= h
            return {'keypoints': kpts, 'descriptors': torch.rand(b, 256, 50, device=x.device), 'scores': torch.rand(b, 50, device=x.device)}
    class LightGlue(torch.nn.Module):
        """Mock matcher that pairs the first 25 keypoints of each image."""
        def __init__(self, **kwargs): super().__init__()
        def forward(self, data):
            b = data['image0']['keypoints'].shape[0]
            matches = torch.stack([torch.arange(25), torch.arange(25)], dim=-1).unsqueeze(0).repeat(b, 1, 1).to(data['image0']['keypoints'].device)
            return {'matches': matches, 'matching_scores': torch.rand(b, 25, device=data['image0']['keypoints'].device)}
    def rbd(data):
        """Remove the leading batch dimension, mirroring lightglue.utils.rbd."""
        return {k: v[0] for k, v in data.items()}
else:
    # Requires: pip install lightglue
    from lightglue import LightGlue, SuperPoint
    from lightglue.utils import rbd
logger = logging.getLogger(__name__)
# --- Data Models ---
class GPSPoint(BaseModel):
    lat: float
    lon: float

class TileBounds(BaseModel):
    nw: GPSPoint
    ne: GPSPoint
    sw: GPSPoint
    se: GPSPoint
    center: GPSPoint
    gsd: float  # Ground Sampling Distance (meters/pixel)

class Sim3Transform(BaseModel):
    translation: np.ndarray
    rotation: np.ndarray
    scale: float
    class Config: arbitrary_types_allowed = True

class AlignmentResult(BaseModel):
    matched: bool
    homography: np.ndarray
    transform: np.ndarray  # 4x4 matrix for pipeline compatibility
    gps_center: GPSPoint
    confidence: float
    inlier_count: int
    total_correspondences: int
    reprojection_error: float
    class Config: arbitrary_types_allowed = True

class ChunkAlignmentResult(BaseModel):
    matched: bool
    chunk_id: str
    chunk_center_gps: GPSPoint
    rotation_angle: float
    confidence: float
    inlier_count: int
    transform: Sim3Transform
    reprojection_error: float
    class Config: arbitrary_types_allowed = True
# --- Implementation ---
class MetricRefinement:
    """
    F09: Metric Refinement Module.
    Performs dense cross-view geo-localization between UAV images and satellite tiles.
    Computes homography mappings, Mean Reprojection Error (MRE), and exact GPS coordinates.
    """
    def __init__(self, device: str = "cuda", max_keypoints: int = 2048):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing Metric Refinement (SuperPoint+LightGlue) on {self.device}")
        # Using SuperPoint + LightGlue as the high-accuracy "Fine Matcher"
        self.extractor = SuperPoint(max_num_keypoints=max_keypoints).eval().to(self.device)
        self.matcher = LightGlue(features='superpoint', depth_confidence=0.9).eval().to(self.device)

    def _preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Converts an image to a normalized grayscale tensor for feature extraction."""
        if len(image.shape) == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        tensor = torch.from_numpy(image).float() / 255.0
        return tensor[None, None, ...].to(self.device)
    def compute_homography(self, uav_image: np.ndarray, satellite_tile: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], int, float]:
        """
        Computes the homography transformation from the UAV image to the satellite tile.
        Returns: (Homography Matrix, Inlier Mask, Total Correspondences, Reprojection Error)
        """
        tensor_uav = self._preprocess_image(uav_image)
        tensor_sat = self._preprocess_image(satellite_tile)
        with torch.no_grad():
            feats_uav = self.extractor.extract(tensor_uav)
            feats_sat = self.extractor.extract(tensor_sat)
            matches = self.matcher({'image0': feats_uav, 'image1': feats_sat})
        feats0, feats1, matches01 = [rbd(x) for x in [feats_uav, feats_sat, matches]]
        kpts_uav = feats0['keypoints'][matches01['matches'][..., 0]].cpu().numpy()
        kpts_sat = feats1['keypoints'][matches01['matches'][..., 1]].cpu().numpy()
        total_correspondences = len(kpts_uav)
        if total_correspondences < 15:
            return None, None, total_correspondences, 0.0
        H, mask = cv2.findHomography(kpts_uav, kpts_sat, cv2.RANSAC, 5.0)
        reprojection_error = 0.0
        if H is not None and mask is not None and mask.sum() > 0:
            # Calculate Mean Reprojection Error (MRE) for inliers (AC-10 requirement)
            inliers_uav = kpts_uav[mask.ravel() == 1]
            inliers_sat = kpts_sat[mask.ravel() == 1]
            proj_uav = cv2.perspectiveTransform(inliers_uav.reshape(-1, 1, 2), H).reshape(-1, 2)
            errors = np.linalg.norm(proj_uav - inliers_sat, axis=1)
            reprojection_error = float(np.mean(errors))
        return H, mask, total_correspondences, reprojection_error
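    # E.g. (hypothetical values) if three inlier projections deviate from their
    # satellite keypoints by 0.4 px, 0.8 px and 0.6 px, the reported MRE is 0.6 px.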
    def extract_gps_from_alignment(self, homography: np.ndarray, tile_bounds: TileBounds, image_center: Tuple[int, int]) -> GPSPoint:
        """
        Extracts GPS coordinates by projecting the UAV center pixel onto the satellite tile
        and interpolating via the Ground Sampling Distance (GSD).
        """
        cx, cy = image_center
        pt = np.array([cx, cy, 1.0], dtype=np.float64)
        sat_pt = homography @ pt
        sat_x, sat_y = sat_pt[0] / sat_pt[2], sat_pt[1] / sat_pt[2]
        # Linear interpolation: pixel offset from the tile's NW corner -> metres (via GSD)
        # -> degrees, using a local flat-Earth approximation (metres per degree of
        # longitude scaled by cos(latitude)).
        meters_per_deg_lat = 111319.9
        meters_per_deg_lon = meters_per_deg_lat * math.cos(math.radians(tile_bounds.nw.lat))
        delta_lat = (sat_y * tile_bounds.gsd) / meters_per_deg_lat
        delta_lon = (sat_x * tile_bounds.gsd) / meters_per_deg_lon
        lat = tile_bounds.nw.lat - delta_lat
        lon = tile_bounds.nw.lon + delta_lon
        return GPSPoint(lat=lat, lon=lon)
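    # Worked example (illustrative numbers only): with gsd = 0.5 m/px and the UAV
    # center projecting to satellite pixel (sat_x=200, sat_y=100), the offset from
    # the NW corner is 100 m east and 50 m south, i.e. roughly +0.00127 deg of
    # longitude (at ~45 deg latitude) and -0.00045 deg of latitude.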
    def compute_match_confidence(self, inlier_count: int, total_correspondences: int, reprojection_error: float) -> float:
        """Evaluates match reliability based on inliers and geometric reprojection error."""
        if total_correspondences == 0: return 0.0
        inlier_ratio = inlier_count / total_correspondences
        # High confidence requires low reprojection error (< 1.0 px) for AC-10 compliance
        if inlier_count > 50 and reprojection_error < 1.0:
            return min(1.0, 0.8 + 0.2 * inlier_ratio)
        elif inlier_count > 25:
            return min(0.8, 0.5 + 0.3 * inlier_ratio)
        return max(0.0, 0.4 * inlier_ratio)
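    # E.g. (hypothetical values) 60 inliers out of 80 correspondences with a 0.7 px
    # MRE gives inlier_ratio = 0.75 and confidence = min(1.0, 0.8 + 0.2 * 0.75) = 0.95.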
    def align_to_satellite(self, uav_image: np.ndarray, satellite_tile: np.ndarray, tile_bounds: Optional[TileBounds] = None) -> Optional[AlignmentResult]:
        """Aligns a single UAV image to a satellite tile."""
        H, mask, total, mre = self.compute_homography(uav_image, satellite_tile)
        if H is None or mask is None:
            return None
        inliers = int(mask.sum())
        if inliers < 15:
            return None
        h, w = uav_image.shape[:2]
        center = (w // 2, h // 2)
        gps = self.extract_gps_from_alignment(H, tile_bounds, center) if tile_bounds else GPSPoint(lat=0.0, lon=0.0)
        conf = self.compute_match_confidence(inliers, total, mre)
        # Approximate 4x4 transform (the homography's 2D part embedded in an identity
        # matrix) for downstream Sim3 compatibility
        transform = np.eye(4)
        transform[:2, :2] = H[:2, :2]
        transform[0, 3] = H[0, 2]
        transform[1, 3] = H[1, 2]
        return AlignmentResult(
            matched=True,
            homography=H,
            transform=transform,
            gps_center=gps,
            confidence=conf,
            inlier_count=inliers,
            total_correspondences=total,
            reprojection_error=mre
        )
    def match_chunk_homography(self, chunk_images: List[np.ndarray], satellite_tile: np.ndarray) -> Optional[np.ndarray]:
        """Computes homography for a chunk by evaluating the center representative frame."""
        center_idx = len(chunk_images) // 2
        H, _, _, _ = self.compute_homography(chunk_images[center_idx], satellite_tile)
        return H
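
if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original pipeline): the tile bounds,
    # image sizes and GSD below are hypothetical, and with USE_MOCK_MODELS=1 the
    # random features are expected to fail geometric verification.
    logging.basicConfig(level=logging.INFO)
    refiner = MetricRefinement(device="cuda", max_keypoints=1024)
    uav = np.random.randint(0, 255, (480, 640), dtype=np.uint8)
    sat = np.random.randint(0, 255, (512, 512), dtype=np.uint8)
    bounds = TileBounds(
        nw=GPSPoint(lat=50.4510, lon=30.5220),
        ne=GPSPoint(lat=50.4510, lon=30.5260),
        sw=GPSPoint(lat=50.4480, lon=30.5220),
        se=GPSPoint(lat=50.4480, lon=30.5260),
        center=GPSPoint(lat=50.4495, lon=30.5240),
        gsd=0.5,
    )
    result = refiner.align_to_satellite(uav, sat, bounds)
    print("alignment:", "matched" if result else "no reliable match")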