mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-23 03:46:37 +00:00
feat: stage8 — Global Place Recognition and Metric Refinement
This commit is contained in:
@@ -17,6 +17,8 @@
|
||||
| **Менеджер ротацій (F06)** | Оберти 360° блоками по 30° для підбору орієнтації; трекінг історії курсу з виявленням різких поворотів (>45°). |
|
||||
| **Model Manager (F16)** | Архітектура завантаження ML моделей (Mock/Fallback). |
|
||||
| **Візуальна Одометрія (F07)** | Суперпоінт / LightGlue імітація. OpenCV (`findEssentialMat` + RANSAC + `recoverPose`) для розрахунку відносного руху між кадрами без відомого масштабу. |
|
||||
| **Global Place Recognition (F08)** | Розпізнавання місцевості (DINOv2/AnyLoc мок), використання імпровізованого Faiss-індексу для ранжування кандидатів. |
|
||||
| **Metric Refinement (F09)** | Вимірювання абсолютної GPS-координати (LiteSAM мок) через гомографію з супутниковим знімком та bounds scaling. |
|
||||
| **Граф поз (VO/GPR)** | GTSAM (Python) - очікується в наступних етапах |
|
||||
|
||||
## Швидкий старт
|
||||
|
||||
@@ -91,9 +91,9 @@
|
||||
### Етап 7 — Model manager та послідовний VO ✅
|
||||
- Завантаження локальних вагів (SuperPoint+LightGlue - Mock), побудова ланцюжка відносних оцінок (`SequentialVisualOdometry`).
|
||||
|
||||
### Етап 8 — Глобальне місце та метричне уточнення
|
||||
- Крос-в'ю вирівнювання до тайла Google Maps.
|
||||
|
||||
### Етап 8 — Глобальне місце та метричне уточнення ✅
|
||||
- Крос-в'ю вирівнювання до тайла Google Maps (F09 LiteSAM Mock).
|
||||
- Отримання абсолютної GPS координати через Global Place Recognition (F08 AnyLoc/Faiss Mock).
|
||||
### Етап 9 — Фактор-граф і чанки (GTSAM)
|
||||
- Побудова чинників (відносні VO + абсолютні якорі). Оптимізація траєкторії через GTSAM.
|
||||
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
"""Global Place Recognition (Component F08)."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gps_denied.core.models import IModelManager
|
||||
from gps_denied.schemas.flight import GPSPoint
|
||||
from gps_denied.schemas.gpr import DatabaseMatch, TileCandidate
|
||||
from gps_denied.schemas.satellite import TileBounds
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IGlobalPlaceRecognition(ABC):
    """Interface for Global Place Recognition (F08).

    Coarse localization of a UAV image (or a chunk of images) against a
    database of satellite-tile descriptors.
    """

    @abstractmethod
    def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int) -> List[TileCandidate]:
        """Return the top-k satellite tile candidates for a single image."""
        pass

    @abstractmethod
    def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Compute a global place descriptor vector for one image."""
        pass

    @abstractmethod
    def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
        """Look up the top-k nearest database entries for a descriptor."""
        pass

    @abstractmethod
    def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
        """Order candidates from best to worst match."""
        pass

    @abstractmethod
    def load_index(self, flight_id: str, index_path: str) -> bool:
        """Load the descriptor index for a flight; return True on success."""
        pass

    @abstractmethod
    def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int) -> List[TileCandidate]:
        """Return the top-k tile candidates for a chunk (sequence) of images."""
        pass

    @abstractmethod
    def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
        """Aggregate per-image descriptors into a single chunk descriptor."""
        pass
|
||||
|
||||
|
||||
class GlobalPlaceRecognition(IGlobalPlaceRecognition):
    """AnyLoc (DINOv2) coarse localization component (F08).

    Uses a mock Faiss-style index: a dense (db_size, dim) descriptor matrix
    queried with an exhaustive L2 search, plus per-row metadata (tile id,
    GPS center, bounds).
    """

    # Dimensionality of the (mock) DINOv2 VLAD descriptor.
    _DESCRIPTOR_DIM = 4096

    def __init__(self, model_manager: IModelManager):
        self.model_manager = model_manager

        # Mock Faiss index - stores descriptors and metadata.
        self._mock_db_descriptors: np.ndarray | None = None
        self._mock_db_metadata: Dict[int, dict] = {}
        self._is_loaded = False

    def compute_location_descriptor(self, image: np.ndarray) -> np.ndarray:
        """Run the DINOv2 engine on one image and return its descriptor."""
        engine = self.model_manager.get_inference_engine("DINOv2")
        return engine.infer(image)

    def compute_chunk_descriptor(self, chunk_images: List[np.ndarray]) -> np.ndarray:
        """Mean-aggregate per-image descriptors and L2-normalize the result.

        An empty chunk yields an (unnormalized) zero vector.
        """
        if not chunk_images:
            return np.zeros(self._DESCRIPTOR_DIM)

        descriptors = [self.compute_location_descriptor(img) for img in chunk_images]
        agg = np.mean(descriptors, axis=0)
        # Guard against a zero-norm aggregate before dividing.
        return agg / max(1e-12, np.linalg.norm(agg))

    def load_index(self, flight_id: str, index_path: str) -> bool:
        """Mock loading of a Faiss index.

        In reality this would read index_path; here we generate synthetic
        L2-normalized descriptors and dummy metadata. Always returns True.
        """
        logger.info("Loading semantic index from %s for flight %s", index_path, flight_id)

        # Create 1000 random tiles in the DB.
        db_size = 1000
        dim = self._DESCRIPTOR_DIM

        # Generate random L2-normalized descriptors.
        vecs = np.random.rand(db_size, dim)
        norms = np.linalg.norm(vecs, axis=1, keepdims=True)
        self._mock_db_descriptors = vecs / norms

        # Rebuild metadata from scratch so repeated loads never keep stale rows.
        self._mock_db_metadata = {}
        for i in range(db_size):
            self._mock_db_metadata[i] = {
                "tile_id": f"tile_sync_{i}",
                "gps_center": GPSPoint(lat=49.0 + np.random.rand(), lon=32.0 + np.random.rand()),
                "bounds": TileBounds(
                    nw=GPSPoint(lat=49.1, lon=32.0),
                    ne=GPSPoint(lat=49.1, lon=32.1),
                    sw=GPSPoint(lat=49.0, lon=32.0),
                    se=GPSPoint(lat=49.0, lon=32.1),
                    center=GPSPoint(lat=49.05, lon=32.05),
                    gsd=0.3
                )
            }

        self._is_loaded = True
        return True

    def query_database(self, descriptor: np.ndarray, top_k: int) -> List[DatabaseMatch]:
        """Exhaustive L2 search over the mock index; returns top-k matches.

        Returns an empty list (and logs an error) if no index is loaded.
        """
        if not self._is_loaded or self._mock_db_descriptors is None:
            logger.error("Faiss index is not loaded.")
            return []

        # Mock Faiss search: squared L2 distance ||A - B||^2 per row.
        diff = self._mock_db_descriptors - descriptor
        distances = np.sum(diff**2, axis=1)

        # Indices of the top-k smallest distances (ascending order).
        top_indices = np.argsort(distances)[:top_k]

        matches = []
        for idx in top_indices:
            dist = float(distances[idx])
            sim = 1.0 / (1.0 + dist)  # map distance to a (0, 1] similarity

            # Cast the numpy index to a plain int for the metadata dict.
            meta = self._mock_db_metadata[int(idx)]

            matches.append(DatabaseMatch(
                index=int(idx),
                tile_id=meta["tile_id"],
                distance=dist,
                similarity_score=sim
            ))

        return matches

    def rank_candidates(self, candidates: List[TileCandidate]) -> List[TileCandidate]:
        """Rank by similarity (descending); spatial re-ranking is future work."""
        return sorted(candidates, key=lambda c: c.similarity_score, reverse=True)

    def _matches_to_candidates(self, matches: List[DatabaseMatch]) -> List[TileCandidate]:
        """Convert raw index matches into ranked TileCandidate objects.

        Ranks are assigned *after* rank_candidates() so they always agree
        with the final order, even if the incoming matches are unsorted.
        """
        candidates = []
        for match in matches:
            meta = self._mock_db_metadata[match.index]

            candidates.append(TileCandidate(
                tile_id=match.tile_id,
                gps_center=meta["gps_center"],
                bounds=meta["bounds"],
                similarity_score=match.similarity_score,
                rank=0  # provisional; replaced after ranking below
            ))

        ranked = self.rank_candidates(candidates)
        for position, candidate in enumerate(ranked, 1):
            candidate.rank = position
        return ranked

    def retrieve_candidate_tiles(self, image: np.ndarray, top_k: int = 5) -> List[TileCandidate]:
        """Single-image retrieval: descriptor -> index query -> ranked tiles."""
        desc = self.compute_location_descriptor(image)
        matches = self.query_database(desc, top_k)
        return self._matches_to_candidates(matches)

    def retrieve_candidate_tiles_for_chunk(self, chunk_images: List[np.ndarray], top_k: int = 5) -> List[TileCandidate]:
        """Chunk retrieval: aggregated descriptor -> index query -> ranked tiles."""
        desc = self.compute_chunk_descriptor(chunk_images)
        matches = self.query_database(desc, top_k)
        return self._matches_to_candidates(matches)
|
||||
@@ -0,0 +1,158 @@
|
||||
"""Metric Refinement (Component F09)."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gps_denied.core.models import IModelManager
|
||||
from gps_denied.schemas.flight import GPSPoint
|
||||
from gps_denied.schemas.metric import AlignmentResult, ChunkAlignmentResult, Sim3Transform
|
||||
from gps_denied.schemas.satellite import TileBounds
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IMetricRefinement(ABC):
    """Interface for Metric Refinement (F09).

    Alignment of UAV imagery against a satellite tile to recover an
    absolute GPS coordinate.
    """

    @abstractmethod
    def align_to_satellite(self, uav_image: np.ndarray, satellite_tile: np.ndarray, tile_bounds: TileBounds) -> Optional[AlignmentResult]:
        """Align one UAV image to a tile; None if the match is rejected."""
        pass

    @abstractmethod
    def compute_homography(self, uav_image: np.ndarray, satellite_tile: np.ndarray) -> Optional[np.ndarray]:
        """Estimate the 3x3 UAV->satellite homography; None on failure."""
        pass

    @abstractmethod
    def extract_gps_from_alignment(self, homography: np.ndarray, tile_bounds: TileBounds, image_center: Tuple[int, int]) -> GPSPoint:
        """Map an image-center pixel through the homography to a GPS point."""
        pass

    @abstractmethod
    def compute_match_confidence(self, alignment: AlignmentResult) -> float:
        """Score an alignment in [0, 1]."""
        pass

    @abstractmethod
    def align_chunk_to_satellite(self, chunk_images: List[np.ndarray], satellite_tile: np.ndarray, tile_bounds: TileBounds) -> Optional[ChunkAlignmentResult]:
        """Align a chunk of UAV images to a tile; None if rejected."""
        pass

    @abstractmethod
    def match_chunk_homography(self, chunk_images: List[np.ndarray], satellite_tile: np.ndarray) -> Optional[np.ndarray]:
        """Estimate a homography for a chunk of images; None on failure."""
        pass
|
||||
|
||||
|
||||
class MetricRefinement(IMetricRefinement):
    """LiteSAM-based alignment of UAV imagery to satellite tiles (F09)."""

    # Minimum RANSAC inliers to accept a single-image / chunk match.
    # NOTE(review): keep in sync with LiteSAMConfig.min_inliers and
    # LiteSAMConfig.chunk_min_inliers.
    MIN_INLIERS = 15
    CHUNK_MIN_INLIERS = 30

    # Mock satellite tiles are assumed square with this side length (pixels).
    TILE_SIZE = 256.0

    def __init__(self, model_manager: IModelManager):
        self.model_manager = model_manager

    def compute_homography(self, uav_image: np.ndarray, satellite_tile: np.ndarray) -> Optional[np.ndarray]:
        """Estimate the UAV->satellite homography, or None if too few inliers."""
        engine = self.model_manager.get_inference_engine("LiteSAM")
        # In reality we pass both images; for the mock we just invoke the
        # engine to get a result in the generated format.
        res = engine.infer({"img1": uav_image, "img2": satellite_tile})

        if res["inlier_count"] < self.MIN_INLIERS:
            return None

        return res["homography"]

    def extract_gps_from_alignment(self, homography: np.ndarray, tile_bounds: TileBounds, image_center: Tuple[int, int]) -> GPSPoint:
        """Project the image center through the homography and map the
        resulting tile pixel to a GPS coordinate via the tile bounds.

        Pixel convention: ty=0 is North (latitude decreases downward),
        tx=0 is West (longitude increases rightward). Assumes square
        TILE_SIZE-pixel tiles for this mock calculation.
        """
        cx, cy = image_center
        # Homogeneous projection: transformed = H @ [cx, cy, 1]^T.
        pt = np.array([cx, cy, 1.0])
        transformed = homography @ pt
        transformed = transformed / transformed[2]

        tx, ty = transformed[0], transformed[1]

        tile_size = self.TILE_SIZE

        lat_span = tile_bounds.nw.lat - tile_bounds.sw.lat
        lon_span = tile_bounds.ne.lon - tile_bounds.nw.lon

        # Relative offsets within the tile; ty grows downward, so latitude
        # is measured up from the southern edge.
        lat_rel = (tile_size - ty) / tile_size
        lon_rel = tx / tile_size

        target_lat = tile_bounds.sw.lat + (lat_span * lat_rel)
        target_lon = tile_bounds.nw.lon + (lon_span * lon_rel)

        return GPSPoint(lat=target_lat, lon=target_lon)

    def align_to_satellite(self, uav_image: np.ndarray, satellite_tile: np.ndarray, tile_bounds: TileBounds) -> Optional[AlignmentResult]:
        """Align one UAV image to a satellite tile.

        Returns None when the match has too few inliers or the combined
        confidence score is below 0.5.
        """
        engine = self.model_manager.get_inference_engine("LiteSAM")

        res = engine.infer({"img1": uav_image, "img2": satellite_tile})

        if res["inlier_count"] < self.MIN_INLIERS:
            return None

        # Fall back to a nominal 640x480 frame if the input has no shape.
        h, w = uav_image.shape[:2] if hasattr(uav_image, "shape") else (480, 640)
        gps = self.extract_gps_from_alignment(res["homography"], tile_bounds, (w // 2, h // 2))

        align = AlignmentResult(
            matched=True,
            homography=res["homography"],
            gps_center=gps,
            confidence=res["confidence"],
            inlier_count=res["inlier_count"],
            total_correspondences=100,  # Mock total
            reprojection_error=np.random.rand() * 2.0  # mock 0..2 px
        )

        return align if self.compute_match_confidence(align) > 0.5 else None

    def compute_match_confidence(self, alignment: AlignmentResult) -> float:
        """Heuristic confidence combining engine confidence and reprojection error.

        Returns a value clamped to [0, 1].
        """
        score = alignment.confidence
        # Penalty for high reprojection error (> 2 px).
        if alignment.reprojection_error > 2.0:
            score -= 0.2
        return max(0.0, min(1.0, score))

    def match_chunk_homography(self, chunk_images: List[np.ndarray], satellite_tile: np.ndarray) -> Optional[np.ndarray]:
        """Estimate a chunk homography; None for an empty chunk.

        Proper aggregation is complex, so the mock uses the first image only.
        """
        if not chunk_images:
            return None
        return self.compute_homography(chunk_images[0], satellite_tile)

    def align_chunk_to_satellite(self, chunk_images: List[np.ndarray], satellite_tile: np.ndarray, tile_bounds: TileBounds) -> Optional[ChunkAlignmentResult]:
        """Align a chunk of UAV images to a satellite tile.

        Returns None for an empty chunk or when inliers fall below the
        (stricter) chunk threshold.
        """
        if not chunk_images:
            return None

        engine = self.model_manager.get_inference_engine("LiteSAM")
        res = engine.infer({"img1": chunk_images[0], "img2": satellite_tile})

        # Demand more inliers for a chunk than for a single image.
        if res["inlier_count"] < self.CHUNK_MIN_INLIERS:
            return None

        h, w = chunk_images[0].shape[:2] if hasattr(chunk_images[0], "shape") else (480, 640)
        gps = self.extract_gps_from_alignment(res["homography"], tile_bounds, (w // 2, h // 2))

        # Fake Sim(3) transform — placeholder until real estimation exists.
        sim3 = Sim3Transform(
            translation=np.array([10., 0., 0.]),
            rotation=np.eye(3),
            scale=1.0
        )

        chunk_align = ChunkAlignmentResult(
            matched=True,
            chunk_id="chunk1",
            chunk_center_gps=gps,
            rotation_angle=0.0,
            confidence=res["confidence"],
            inlier_count=res["inlier_count"],
            transform=sim3,
            reprojection_error=1.0
        )

        return chunk_align
|
||||
@@ -75,9 +75,31 @@ class MockInferenceEngine(InferenceEngine):
|
||||
"keypoints2": kp2[indices2]
|
||||
}
|
||||
|
||||
elif self.model_name == "DINOv2":
|
||||
# Mock generating 4096-dim VLAD descriptor
|
||||
dim = 4096
|
||||
desc = np.random.rand(dim)
|
||||
# L2 normalize
|
||||
return desc / np.linalg.norm(desc)
|
||||
|
||||
elif self.model_name == "LiteSAM":
|
||||
# Just a placeholder for F09
|
||||
pass
|
||||
# Mock LiteSAM matching between UAV and satellite image
|
||||
# Returns a generated Homography and valid correspondences count
|
||||
|
||||
# Simulated 3x3 homography matrix (identity with minor translation)
|
||||
homography = np.eye(3, dtype=np.float64)
|
||||
homography[0, 2] = np.random.uniform(-50, 50)
|
||||
homography[1, 2] = np.random.uniform(-50, 50)
|
||||
|
||||
# Simple simulation: 80% chance to "match"
|
||||
matched = np.random.rand() > 0.2
|
||||
inliers = np.random.randint(20, 100) if matched else np.random.randint(0, 15)
|
||||
|
||||
return {
|
||||
"homography": homography,
|
||||
"inlier_count": inliers,
|
||||
"confidence": min(1.0, inliers / 100.0)
|
||||
}
|
||||
|
||||
raise ValueError(f"Unknown mock model: {self.model_name}")
|
||||
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
"""Global Place Recognition schemas (Component F08)."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from gps_denied.schemas.flight import GPSPoint
|
||||
from gps_denied.schemas.satellite import TileBounds
|
||||
|
||||
|
||||
class TileCandidate(BaseModel):
    """A matched satellite tile candidate."""
    tile_id: str                           # unique tile identifier
    gps_center: GPSPoint                   # GPS coordinate of the tile center
    bounds: TileBounds                     # geographic corners of the tile
    similarity_score: float                # descriptor similarity (higher is better)
    rank: int                              # 1-based position after ranking
    spatial_score: Optional[float] = None  # optional spatial re-ranking score
|
||||
|
||||
|
||||
class DatabaseMatch(BaseModel):
    """Raw index match from Faiss queries."""
    index: int              # row index into the descriptor database
    tile_id: str            # identifier of the matched tile
    distance: float         # raw L2 distance returned by the index
    similarity_score: float # distance mapped to a similarity score
|
||||
|
||||
|
||||
class SatelliteTile(BaseModel):
    """A stored satellite tile representation for indexing."""
    # np.ndarray fields require arbitrary (non-pydantic) types.
    model_config = {"arbitrary_types_allowed": True}

    tile_id: str                             # unique tile identifier
    image: np.ndarray                        # tile pixel data
    gps_center: GPSPoint                     # GPS coordinate of the tile center
    bounds: TileBounds                       # geographic corners of the tile
    descriptor: Optional[np.ndarray] = None  # cached place descriptor, if computed
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Metric Refinement schemas (Component F09)."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from gps_denied.schemas.flight import GPSPoint
|
||||
|
||||
|
||||
class AlignmentResult(BaseModel):
    """Result of aligning a UAV image to a single satellite tile."""
    # np.ndarray fields require arbitrary (non-pydantic) types.
    model_config = {"arbitrary_types_allowed": True}

    matched: bool               # whether the alignment was accepted
    homography: np.ndarray      # (3, 3) UAV->satellite homography
    gps_center: GPSPoint        # GPS coordinate of the aligned image center
    confidence: float           # matcher confidence
    inlier_count: int           # RANSAC inliers supporting the homography
    total_correspondences: int  # total candidate correspondences
    reprojection_error: float   # Mean error in pixels
|
||||
|
||||
|
||||
class Sim3Transform(BaseModel):
    """Sim(3) transformation: scale, rotation, translation."""
    # np.ndarray fields require arbitrary (non-pydantic) types.
    model_config = {"arbitrary_types_allowed": True}

    translation: np.ndarray  # (3,) translation vector
    rotation: np.ndarray     # (3, 3) rotation matrix
    scale: float             # uniform scale factor
|
||||
|
||||
|
||||
class ChunkAlignmentResult(BaseModel):
    """Result of aligning a chunk array of UAV images to a satellite tile."""
    # Sim3Transform carries np.ndarray fields, hence arbitrary types.
    model_config = {"arbitrary_types_allowed": True}

    matched: bool              # whether the chunk alignment was accepted
    chunk_id: str              # identifier of the aligned chunk
    chunk_center_gps: GPSPoint # GPS coordinate of the chunk center
    rotation_angle: float      # estimated in-plane rotation
    confidence: float          # matcher confidence
    inlier_count: int          # RANSAC inliers supporting the alignment
    transform: Sim3Transform   # estimated Sim(3) transform
    reprojection_error: float  # mean error in pixels
|
||||
|
||||
|
||||
class LiteSAMConfig(BaseModel):
    """Configuration for LiteSAM alignment."""
    model_path: str = "mock_path"        # path to LiteSAM weights (mock)
    confidence_threshold: float = 0.7    # minimum confidence to accept a match
    min_inliers: int = 15                # minimum inliers for single-image alignment
    max_reprojection_error: float = 2.0  # pixels
    multi_scale_levels: int = 3          # number of scale levels (presumably an image pyramid — confirm)
    chunk_min_inliers: int = 30          # minimum inliers for chunk alignment
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Tests for Global Place Recognition (F08)."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gps_denied.core.gpr import GlobalPlaceRecognition
|
||||
from gps_denied.core.models import ModelManager
|
||||
from gps_denied.schemas.gpr import TileCandidate
|
||||
|
||||
|
||||
@pytest.fixture
def gpr():
    """GlobalPlaceRecognition wired to a mock manager, index pre-loaded."""
    component = GlobalPlaceRecognition(ModelManager())
    component.load_index("flight_123", "dummy_path.faiss")
    return component
|
||||
|
||||
def test_compute_location_descriptor(gpr):
    frame = np.zeros((200, 200, 3), dtype=np.uint8)

    descriptor = gpr.compute_location_descriptor(frame)

    assert descriptor.shape == (4096,)
    # Descriptors must come back L2-normalized.
    assert np.isclose(np.linalg.norm(descriptor), 1.0)
|
||||
|
||||
def test_retrieve_candidate_tiles(gpr):
    frame = np.zeros((200, 200, 3), dtype=np.uint8)

    results = gpr.retrieve_candidate_tiles(frame, top_k=5)

    assert len(results) == 5
    assert all(isinstance(candidate, TileCandidate) for candidate in results)
    assert all(candidate.similarity_score >= 0.0 for candidate in results)
|
||||
|
||||
def test_retrieve_candidate_tiles_for_chunk(gpr):
    frames = [np.zeros((200, 200, 3), dtype=np.uint8) for _ in range(5)]

    results = gpr.retrieve_candidate_tiles_for_chunk(frames, top_k=3)

    assert len(results) == 3
    # Candidates must come back best-first.
    assert results[0].similarity_score >= results[1].similarity_score
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Tests for Metric Refinement (F09)."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gps_denied.core.metric import MetricRefinement
|
||||
from gps_denied.core.models import ModelManager
|
||||
from gps_denied.schemas.flight import GPSPoint
|
||||
from gps_denied.schemas.metric import AlignmentResult, ChunkAlignmentResult
|
||||
from gps_denied.schemas.satellite import TileBounds
|
||||
|
||||
|
||||
@pytest.fixture
def metric():
    """MetricRefinement wired to a fresh (mock) model manager."""
    return MetricRefinement(ModelManager())
|
||||
|
||||
@pytest.fixture
def bounds():
    """Tile bounds covering exactly 1 degree of lat and lon around (49, 32)."""
    north, south = 50.0, 49.0
    west, east = 32.0, 33.0
    return TileBounds(
        nw=GPSPoint(lat=north, lon=west),
        ne=GPSPoint(lat=north, lon=east),
        sw=GPSPoint(lat=south, lon=west),
        se=GPSPoint(lat=south, lon=east),
        center=GPSPoint(lat=49.5, lon=32.5),
        gsd=1.0,  # dummy
    )
|
||||
|
||||
def test_extract_gps_from_alignment(metric, bounds):
    # An identity homography maps the image center onto itself.
    identity = np.eye(3, dtype=np.float64)

    # Mock tiles are 256x256, so (128, 128) is the exact center pixel.
    gps = metric.extract_gps_from_alignment(identity, bounds, (128, 128))

    # The tile center must map exactly to the center of the bounds.
    assert np.isclose(gps.lat, 49.5)
    assert np.isclose(gps.lon, 32.5)
|
||||
|
||||
def test_align_to_satellite(metric, bounds, monkeypatch):
    # Force a deterministic, high-quality match from the mock engine.
    def fake_infer(*_args, **_kwargs):
        return {
            "homography": np.eye(3, dtype=np.float64),
            "inlier_count": 80,
            "confidence": 0.8,
        }

    litesam = metric.model_manager.get_inference_engine("LiteSAM")
    monkeypatch.setattr(litesam, "infer", fake_infer)

    uav_img = np.zeros((256, 256, 3))
    sat_img = np.zeros((256, 256, 3))

    result = metric.align_to_satellite(uav_img, sat_img, bounds)

    assert result is not None
    assert isinstance(result, AlignmentResult)
    assert result.matched is True
    assert result.inlier_count == 80
|
||||
|
||||
def test_align_chunk_to_satellite(metric, bounds, monkeypatch):
    # Force a deterministic, high-quality match from the mock engine.
    def fake_infer(*_args, **_kwargs):
        return {
            "homography": np.eye(3, dtype=np.float64),
            "inlier_count": 80,
            "confidence": 0.8,
        }

    litesam = metric.model_manager.get_inference_engine("LiteSAM")
    monkeypatch.setattr(litesam, "infer", fake_infer)

    chunk = [np.zeros((256, 256, 3)) for _ in range(5)]
    sat_img = np.zeros((256, 256, 3))

    result = metric.align_chunk_to_satellite(chunk, sat_img, bounds)

    assert result is not None
    assert isinstance(result, ChunkAlignmentResult)
    assert result.matched is True
    assert result.chunk_id == "chunk1"
|
||||
Reference in New Issue
Block a user