# Mirror of https://github.com/azaion/gps-denied-onboard.git
# Synced 2026-04-22 22:46:36 +00:00
import cv2
|
|
import torch
|
|
import math
|
|
import numpy as np
|
|
import logging
|
|
from typing import List, Optional, Tuple
|
|
from pydantic import BaseModel
|
|
|
|
import os
|
|
|
|
# When set, replaces the heavyweight lightglue models with cheap random stand-ins
# so the pipeline can be exercised without GPU weights or the lightglue package.
USE_MOCK_MODELS = os.environ.get("USE_MOCK_MODELS", "0") == "1"

if USE_MOCK_MODELS:
    class SuperPoint(torch.nn.Module):
        """Mock feature extractor mirroring the lightglue SuperPoint interface."""

        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, x):
            # Emit 50 random keypoints scaled into the image's pixel range,
            # plus random descriptors/scores with lightglue-like keys.
            b, _, h, w = x.shape
            kpts = torch.rand(b, 50, 2, device=x.device)
            kpts[..., 0] *= w
            kpts[..., 1] *= h
            return {
                'keypoints': kpts,
                'descriptors': torch.rand(b, 256, 50, device=x.device),
                'scores': torch.rand(b, 50, device=x.device),
            }

        # BUG FIX: MetricRefinement calls `self.extractor.extract(...)` (the
        # lightglue convenience API), not forward(); without this alias mock
        # mode raised AttributeError.
        def extract(self, x, **kwargs):
            return self.forward(x)

    class LightGlue(torch.nn.Module):
        """Mock matcher that pairs the first 25 keypoints one-to-one."""

        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, data):
            device = data['image0']['keypoints'].device
            b = data['image0']['keypoints'].shape[0]
            matches = (
                torch.stack([torch.arange(25), torch.arange(25)], dim=-1)
                .unsqueeze(0)
                .repeat(b, 1, 1)
                .to(device)
            )
            return {'matches': matches, 'matching_scores': torch.rand(b, 25, device=device)}

    def rbd(data):
        """Remove the batch dimension from each entry (mock of lightglue.utils.rbd)."""
        return {k: v[0] for k, v in data.items()}
else:
    # Requires: pip install lightglue
    from lightglue import LightGlue, SuperPoint
    from lightglue.utils import rbd
|
|
|
|
# Module-level logger; handlers/levels are configured by the host application.
logger = logging.getLogger(__name__)
|
|
|
|
# --- Data Models ---
|
|
|
|
class GPSPoint(BaseModel):
    """A geographic coordinate in decimal degrees (presumably WGS84 — TODO confirm)."""
    lat: float  # latitude, decimal degrees
    lon: float  # longitude, decimal degrees
|
|
|
|
class TileBounds(BaseModel):
    """Geographic extent of a satellite tile: four corners, center, and pixel resolution."""
    nw: GPSPoint  # north-west corner
    ne: GPSPoint  # north-east corner
    sw: GPSPoint  # south-west corner
    se: GPSPoint  # south-east corner
    center: GPSPoint  # tile center
    gsd: float  # Ground Sampling Distance (meters/pixel)
|
|
|
|
class Sim3Transform(BaseModel):
    """A Sim(3) similarity transform: rotation + translation + uniform scale."""
    translation: np.ndarray  # translation vector (shape not enforced here)
    rotation: np.ndarray  # rotation matrix (shape not enforced here)
    scale: float  # uniform scale factor

    # Required so pydantic accepts raw np.ndarray fields without a validator.
    class Config: arbitrary_types_allowed = True
|
|
|
|
class AlignmentResult(BaseModel):
    """Outcome of aligning a single UAV image to a satellite tile."""
    matched: bool  # True when a valid alignment was produced
    homography: np.ndarray  # 3x3 UAV->satellite pixel mapping (cv2.findHomography)
    transform: np.ndarray  # 4x4 matrix for pipeline compatibility
    gps_center: GPSPoint  # UAV image center projected onto the tile's GPS frame
    confidence: float  # match reliability, clamped to [0, 1]
    inlier_count: int  # RANSAC inliers
    total_correspondences: int  # feature matches before RANSAC filtering
    reprojection_error: float  # mean reprojection error over inliers (pixels)

    # Required so pydantic accepts raw np.ndarray fields without a validator.
    class Config: arbitrary_types_allowed = True
|
|
|
|
class ChunkAlignmentResult(BaseModel):
    """Outcome of aligning a chunk (sequence of frames) to a satellite tile."""
    matched: bool  # True when a valid alignment was produced
    chunk_id: str  # identifier of the chunk
    chunk_center_gps: GPSPoint  # GPS of the chunk's representative center frame
    rotation_angle: float  # rotation of the chunk, presumably degrees — TODO confirm units
    confidence: float  # match reliability
    inlier_count: int  # RANSAC inliers
    transform: Sim3Transform  # similarity transform for the chunk
    reprojection_error: float  # mean reprojection error (pixels)

    # Required so pydantic accepts arbitrary field types (Sim3Transform holds ndarrays).
    class Config: arbitrary_types_allowed = True
|
|
|
|
# --- Implementation ---
|
|
|
|
class MetricRefinement:
    """
    F09: Metric Refinement Module.

    Performs dense cross-view geo-localization between UAV images and satellite
    tiles. Computes homography mappings, Mean Reprojection Error (MRE), and
    exact GPS coordinates.
    """

    # Minimum feature correspondences / RANSAC inliers required to accept an
    # alignment (previously hard-coded as 15 in two separate methods).
    MIN_CORRESPONDENCES = 15

    def __init__(self, device: str = "cuda", max_keypoints: int = 2048):
        """
        Args:
            device: Preferred torch device; silently falls back to CPU when
                CUDA is unavailable.
            max_keypoints: Upper bound on SuperPoint keypoints per image.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        # Lazy %-formatting avoids building the string when INFO is disabled.
        logger.info("Initializing Metric Refinement (SuperPoint+LightGlue) on %s", self.device)

        # SuperPoint + LightGlue serve as the high-accuracy "fine matcher".
        self.extractor = SuperPoint(max_num_keypoints=max_keypoints).eval().to(self.device)
        self.matcher = LightGlue(features='superpoint', depth_confidence=0.9).eval().to(self.device)

    def _preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Converts an image to a normalized grayscale tensor of shape (1, 1, H, W)."""
        if len(image.shape) == 3:
            # 3-channel input is assumed BGR (OpenCV convention).
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        tensor = torch.from_numpy(image).float() / 255.0
        return tensor[None, None, ...].to(self.device)

    def compute_homography(self, uav_image: np.ndarray, satellite_tile: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], int, float]:
        """
        Computes the UAV->satellite homography via feature matching + RANSAC.

        Returns:
            (3x3 homography or None, RANSAC inlier mask or None,
             total correspondence count, mean reprojection error in pixels)
        """
        tensor_uav = self._preprocess_image(uav_image)
        tensor_sat = self._preprocess_image(satellite_tile)

        with torch.no_grad():
            feats_uav = self.extractor.extract(tensor_uav)
            feats_sat = self.extractor.extract(tensor_sat)
            matches = self.matcher({'image0': feats_uav, 'image1': feats_sat})

        # Strip batch dimensions, then gather the matched keypoint pairs.
        feats0, feats1, matches01 = [rbd(x) for x in [feats_uav, feats_sat, matches]]
        kpts_uav = feats0['keypoints'][matches01['matches'][..., 0]].cpu().numpy()
        kpts_sat = feats1['keypoints'][matches01['matches'][..., 1]].cpu().numpy()

        total_correspondences = len(kpts_uav)
        if total_correspondences < self.MIN_CORRESPONDENCES:
            return None, None, total_correspondences, 0.0

        # 5.0 px RANSAC reprojection threshold.
        H, mask = cv2.findHomography(kpts_uav, kpts_sat, cv2.RANSAC, 5.0)

        reprojection_error = 0.0
        if H is not None and mask is not None and mask.sum() > 0:
            # Mean Reprojection Error (MRE) over inliers (AC-10 requirement).
            inliers_uav = kpts_uav[mask.ravel() == 1]
            inliers_sat = kpts_sat[mask.ravel() == 1]

            proj_uav = cv2.perspectiveTransform(inliers_uav.reshape(-1, 1, 2), H).reshape(-1, 2)
            errors = np.linalg.norm(proj_uav - inliers_sat, axis=1)
            reprojection_error = float(np.mean(errors))

        return H, mask, total_correspondences, reprojection_error

    def extract_gps_from_alignment(self, homography: np.ndarray, tile_bounds: TileBounds, image_center: Tuple[int, int]) -> GPSPoint:
        """
        Extracts GPS coordinates by projecting the UAV center pixel onto the
        satellite tile and interpolating via Ground Sampling Distance (GSD).

        Args:
            homography: 3x3 UAV->satellite pixel mapping.
            tile_bounds: Geographic bounds of the satellite tile.
            image_center: (cx, cy) pixel of the UAV image to localize.
        """
        cx, cy = image_center
        pt = np.array([cx, cy, 1.0], dtype=np.float64)
        sat_pt = homography @ pt
        # Normalize homogeneous coordinates.
        sat_x, sat_y = sat_pt[0] / sat_pt[2], sat_pt[1] / sat_pt[2]

        # Linear interpolation based on Web Mercator projection approximations;
        # longitude scale shrinks with cos(latitude).
        meters_per_deg_lat = 111319.9
        meters_per_deg_lon = meters_per_deg_lat * math.cos(math.radians(tile_bounds.nw.lat))

        delta_lat = (sat_y * tile_bounds.gsd) / meters_per_deg_lat
        delta_lon = (sat_x * tile_bounds.gsd) / meters_per_deg_lon

        # Pixel y grows southward from the NW corner, hence the subtraction.
        lat = tile_bounds.nw.lat - delta_lat
        lon = tile_bounds.nw.lon + delta_lon

        return GPSPoint(lat=lat, lon=lon)

    def compute_match_confidence(self, inlier_count: int, total_correspondences: int, reprojection_error: float) -> float:
        """Evaluates match reliability (in [0, 1]) from inliers and reprojection error."""
        if total_correspondences == 0:
            return 0.0

        inlier_ratio = inlier_count / total_correspondences

        # High confidence requires low reprojection error (< 1.0px) for AC-10 compliance.
        if inlier_count > 50 and reprojection_error < 1.0:
            return min(1.0, 0.8 + 0.2 * inlier_ratio)
        elif inlier_count > 25:
            return min(0.8, 0.5 + 0.3 * inlier_ratio)
        return max(0.0, 0.4 * inlier_ratio)

    def align_to_satellite(self, uav_image: np.ndarray, satellite_tile: np.ndarray, tile_bounds: Optional[TileBounds] = None) -> Optional[AlignmentResult]:
        """
        Aligns a single UAV image to a satellite tile.

        Returns None when no homography is found or the inlier count is below
        MIN_CORRESPONDENCES. Without tile_bounds, gps_center is (0, 0).
        """
        H, mask, total, mre = self.compute_homography(uav_image, satellite_tile)

        if H is None or mask is None:
            return None

        inliers = int(mask.sum())
        if inliers < self.MIN_CORRESPONDENCES:
            return None

        h, w = uav_image.shape[:2]
        center = (w // 2, h // 2)

        gps = self.extract_gps_from_alignment(H, tile_bounds, center) if tile_bounds else GPSPoint(lat=0.0, lon=0.0)
        conf = self.compute_match_confidence(inliers, total, mre)

        # Provide a mocked 4x4 matrix for downstream Sim3 compatibility:
        # embed the homography's linear part and translation into an identity.
        transform = np.eye(4)
        transform[:2, :2] = H[:2, :2]
        transform[0, 3] = H[0, 2]
        transform[1, 3] = H[1, 2]

        return AlignmentResult(
            matched=True,
            homography=H,
            transform=transform,
            gps_center=gps,
            confidence=conf,
            inlier_count=inliers,
            total_correspondences=total,
            reprojection_error=mre
        )

    def match_chunk_homography(self, chunk_images: List[np.ndarray], satellite_tile: np.ndarray) -> Optional[np.ndarray]:
        """Computes homography for a chunk via its center representative frame; None for an empty chunk."""
        # Guard: an empty chunk previously raised IndexError on chunk_images[0].
        if not chunk_images:
            return None
        center_idx = len(chunk_images) // 2
        H, _, _, _ = self.compute_homography(chunk_images[center_idx], satellite_tile)
        return H