# Source file: gps-denied-onboard/f04_satellite_data_manager.py
# Last commit: d7e1066c60 "Initial commit" by Denys Zaitsev, 2026-04-03 23:25:54 +03:00
# Size: 343 lines, 14 KiB (Python)

import os
import math
import logging
import shutil
from typing import List, Dict, Optional, Iterator, Tuple
from pathlib import Path
from abc import ABC, abstractmethod
from pydantic import BaseModel
import numpy as np
import cv2
import httpx
import diskcache
import concurrent.futures
from f02_1_flight_lifecycle_manager import GPSPoint
import h06_web_mercator_utils as H06
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# --- Data Models ---
class TileCoords(BaseModel):
    """Address of a single map tile: tile indices ``x`` and ``y`` at ``zoom``.

    Instances hash and compare by the ``(x, y, zoom)`` triple so they can be
    used as set members and dictionary keys.
    """

    x: int
    y: int
    zoom: int

    def __hash__(self):
        return hash((self.x, self.y, self.zoom))

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of failing with
        # AttributeError on ``other.x`` (e.g. ``tile == None``).
        if not isinstance(other, TileCoords):
            return NotImplemented
        return (self.x, self.y, self.zoom) == (other.x, other.y, other.zoom)
class TileBounds(BaseModel):
    """Geographic footprint of one tile: four corners, the centre and ``gsd``."""

    # Corner points of the tile (north-west, north-east, south-west, south-east).
    nw: GPSPoint
    ne: GPSPoint
    sw: GPSPoint
    se: GPSPoint
    # Centre point of the tile.
    center: GPSPoint
    # Value reported under the "gsd" key by the Web Mercator helper; presumably
    # ground sampling distance -- TODO confirm meaning and units against H06.
    gsd: float
class CacheConfig(BaseModel):
    """Settings for the on-disk tile cache."""

    # Root directory of the cache.
    cache_dir: str = "./satellite_cache"
    # NOTE(review): the three settings below are declared with defaults but are
    # not read by the code visible in this file.
    max_size_gb: int = 50
    eviction_policy: str = "lru"
    ttl_days: int = 30
# --- Interface ---
class ISatelliteDataManager(ABC):
    """Abstract contract for tile fetching, tile caching and tile-grid maths."""

    @abstractmethod
    def fetch_tile(self, lat: float, lon: float, zoom: int) -> Optional[np.ndarray]:
        """Return the tile image for the given position and zoom, or None."""

    @abstractmethod
    def fetch_tile_grid(self, center_lat: float, center_lon: float, grid_size: int, zoom: int) -> Dict[str, np.ndarray]:
        """Return tile images for a grid of ``grid_size`` tiles around a position."""

    @abstractmethod
    def prefetch_route_corridor(self, waypoints: List[GPSPoint], corridor_width_m: float, zoom: int) -> bool:
        """Prefetch tiles along a route; return whether the prefetch succeeded."""

    @abstractmethod
    def progressive_fetch(self, center_lat: float, center_lon: float, grid_sizes: List[int], zoom: int) -> Iterator[Dict[str, np.ndarray]]:
        """Yield one tile-grid result per entry of ``grid_sizes``."""

    @abstractmethod
    def cache_tile(self, flight_id: str, tile_coords: TileCoords, tile_data: np.ndarray) -> bool:
        """Store a tile image for a flight; return whether it was stored."""

    @abstractmethod
    def get_cached_tile(self, flight_id: str, tile_coords: TileCoords) -> Optional[np.ndarray]:
        """Return a previously cached tile image, or None."""

    @abstractmethod
    def get_tile_grid(self, center: TileCoords, grid_size: int) -> List[TileCoords]:
        """Return the coordinates of a grid of ``grid_size`` tiles around ``center``."""

    @abstractmethod
    def compute_tile_coords(self, lat: float, lon: float, zoom: int) -> TileCoords:
        """Convert a geographic position to tile coordinates at ``zoom``."""

    @abstractmethod
    def expand_search_grid(self, center: TileCoords, current_size: int, new_size: int) -> List[TileCoords]:
        """Return the tiles needed to grow a grid from ``current_size`` to ``new_size``."""

    @abstractmethod
    def compute_tile_bounds(self, tile_coords: TileCoords) -> TileBounds:
        """Return the geographic bounds of a tile."""

    @abstractmethod
    def clear_flight_cache(self, flight_id: str) -> bool:
        """Remove the cached tiles of one flight; return whether that succeeded."""
# --- Implementation ---
class SatelliteDataManager(ISatelliteDataManager):
    """
    Manages satellite tile retrieval, local disk caching, and Web Mercator
    coordinate transformations to support the Geospatial Anchoring Back-End.

    On-disk layout, relative to ``CacheConfig.cache_dir``::

        global/<zoom>/<x>_<y>.png        cache shared by every flight
        <flight_id>/<zoom>/<x>_<y>.png   per-flight cache
        index/                           diskcache index of cached tile keys
    """

    # Directory names under the cache root that are not per-flight caches and
    # must therefore never be purged through clear_flight_cache().
    _RESERVED_DIR_NAMES = frozenset({"global", "index"})
    # Equatorial circumference of the Web Mercator sphere in metres, i.e. the
    # ground width of the single zoom-0 tile at the equator.
    _EARTH_CIRCUMFERENCE_M = 2.0 * math.pi * 6378137.0
    # Latitude limit of the Web Mercator projection, in degrees.
    _MAX_MERCATOR_LAT_DEG = 85.05112878
    # Sampling interval, in degrees, used when interpolating between waypoints.
    _CORRIDOR_STEP_DEG = 0.001

    def __init__(self, config: Optional[CacheConfig] = None, provider_api_url: str = "http://mock-satellite-provider/api/tiles"):
        """Create the cache directory layout and open the tile index.

        Args:
            config: Cache settings; a default ``CacheConfig`` is used when omitted.
            provider_api_url: Base URL of the tile provider. A URL containing
                ``mock-satellite-provider`` makes fetches return blank tiles
                without any network access.
        """
        self.config = config or CacheConfig()
        self.base_dir = Path(self.config.cache_dir)
        self.global_dir = self.base_dir / "global"
        self.provider_api_url = provider_api_url
        # Create the directories first so the index is opened inside an existing root.
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.global_dir.mkdir(parents=True, exist_ok=True)
        self.index_cache = diskcache.Cache(str(self.base_dir / "index"))

    # --- 04.01 Cache Management ---

    def _flight_dir(self, flight_id: str) -> Path:
        """Return the cache directory that belongs to ``flight_id``."""
        return self.global_dir if flight_id == "global" else self.base_dir / flight_id

    def _generate_cache_path(self, flight_id: str, tile_coords: TileCoords) -> Path:
        """Return the PNG path of a tile inside the cache of ``flight_id``."""
        return self._flight_dir(flight_id) / str(tile_coords.zoom) / f"{tile_coords.x}_{tile_coords.y}.png"

    def _ensure_cache_directory(self, flight_id: str, zoom: int) -> bool:
        """Create the per-zoom directory of a flight cache if needed; always returns True."""
        (self._flight_dir(flight_id) / str(zoom)).mkdir(parents=True, exist_ok=True)
        return True

    def _serialize_tile(self, tile_data: np.ndarray) -> bytes:
        """Encode a tile image as PNG bytes.

        Raises:
            ValueError: If OpenCV reports that encoding failed.
        """
        success, buffer = cv2.imencode('.png', tile_data)
        if not success:
            raise ValueError("Failed to encode tile to PNG.")
        return buffer.tobytes()

    def _deserialize_tile(self, data: bytes) -> Optional[np.ndarray]:
        """Decode image bytes into a colour image array.

        Returns None when decoding raises. ``cv2.imdecode`` may also return
        None by itself for data it cannot decode; that value is passed through.
        """
        try:
            np_arr = np.frombuffer(data, np.uint8)
            return cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        except Exception as e:
            logger.warning(f"Tile deserialization failed: {e}")
            return None

    def _update_cache_index(self, flight_id: str, tile_coords: TileCoords, action: str) -> None:
        """Add (``"add"``) or drop (``"remove"``) a tile key in the index; other actions are ignored."""
        key = f"{flight_id}_{tile_coords.zoom}_{tile_coords.x}_{tile_coords.y}"
        if action == "add":
            self.index_cache.set(key, True)
        elif action == "remove":
            self.index_cache.delete(key)

    def _read_tile_file(self, path: Path) -> Optional[np.ndarray]:
        """Load and decode one cached PNG; return None if it is missing or unusable."""
        try:
            with open(path, 'rb') as f:
                raw = f.read()
        except FileNotFoundError:
            # A plain cache miss, not an error.
            return None
        except OSError as e:
            logger.warning(f"Corrupted cache file at {path}: {e}")
            return None
        tile = self._deserialize_tile(raw)
        if tile is None:
            logger.warning(f"Corrupted cache file at {path}")
        return tile

    def cache_tile(self, flight_id: str, tile_coords: TileCoords, tile_data: np.ndarray) -> bool:
        """Write a tile as PNG into the cache of ``flight_id`` and record it in the index.

        Returns:
            True on success, False if any step failed (the failure is logged).
        """
        # Bound before the try block so the error handler can always report it.
        path: Optional[Path] = None
        try:
            path = self._generate_cache_path(flight_id, tile_coords)
            self._ensure_cache_directory(flight_id, tile_coords.zoom)
            tile_bytes = self._serialize_tile(tile_data)
            with open(path, 'wb') as f:
                f.write(tile_bytes)
            self._update_cache_index(flight_id, tile_coords, "add")
            return True
        except Exception as e:
            target = path if path is not None else f"<flight {flight_id!r}>"
            logger.error(f"Failed to cache tile to {target}: {e}")
            return False

    def _check_global_cache(self, tile_coords: TileCoords) -> Optional[np.ndarray]:
        """Return the tile from the shared global cache, or None."""
        return self._read_tile_file(self._generate_cache_path("global", tile_coords))

    def get_cached_tile(self, flight_id: str, tile_coords: TileCoords) -> Optional[np.ndarray]:
        """Return a cached tile, preferring the flight cache over the global one.

        A missing or unreadable per-flight file falls through to the global
        cache instead of hiding a usable global copy.
        """
        # For the global flight the "flight" file IS the global file; read it once.
        if flight_id != "global":
            tile = self._read_tile_file(self._generate_cache_path(flight_id, tile_coords))
            if tile is not None:
                return tile
        # Fallback to global shared cache
        return self._check_global_cache(tile_coords)

    def clear_flight_cache(self, flight_id: str) -> bool:
        """Delete every cached tile of one flight together with its index entries.

        Returns:
            True if the flight cache is gone afterwards (including when it never
            existed); False for the global cache, for ids that do not name a
            direct child of the cache root, or when deletion failed.
        """
        if flight_id == "global":
            return False  # Prevent accidental global purge
        # "" and "." resolve to the cache root itself, ".." escapes it, and
        # "index" is the live diskcache directory: none of them is a flight.
        if not flight_id or flight_id in (".", "..") or flight_id in self._RESERVED_DIR_NAMES:
            logger.error(f"Refusing to clear cache for invalid flight id {flight_id!r}")
            return False
        flight_dir = self.base_dir / flight_id
        # Ids containing path separators (or absolute paths) point elsewhere.
        if flight_dir.parent != self.base_dir:
            logger.error(f"Refusing to clear cache for invalid flight id {flight_id!r}")
            return False
        if not flight_dir.exists():
            return True
        # Derive the index keys from the <zoom>/<x>_<y>.png layout before the
        # files disappear; they match the keys written by _update_cache_index().
        index_keys = [
            f"{flight_id}_{tile_file.parent.name}_{tile_file.stem}"
            for tile_file in flight_dir.glob("*/*.png")
        ]
        try:
            shutil.rmtree(flight_dir)
        except OSError as e:
            logger.error(f"Failed to clear flight cache at {flight_dir}: {e}")
            return False
        try:
            for key in index_keys:
                self.index_cache.delete(key)
        except Exception as e:
            # The tiles are already gone; a stale index entry is only bookkeeping.
            logger.warning(f"Failed to purge index entries for flight {flight_id}: {e}")
        return True

    # --- 04.02 Coordinate Operations (Web Mercator) ---

    def compute_tile_coords(self, lat: float, lon: float, zoom: int) -> TileCoords:
        """Convert a geographic position to tile coordinates at ``zoom``."""
        x, y = H06.latlon_to_tile(lat, lon, zoom)
        return TileCoords(x=x, y=y, zoom=zoom)

    def _tile_to_latlon(self, x: int, y: int, zoom: int) -> Tuple[float, float]:
        """Convert tile coordinates to a (lat, lon) pair via the Web Mercator helper."""
        return H06.tile_to_latlon(x, y, zoom)

    def compute_tile_bounds(self, tile_coords: TileCoords) -> TileBounds:
        """Return corners, centre and ``gsd`` of a tile as reported by the Web Mercator helper."""
        bounds = H06.compute_tile_bounds(tile_coords.x, tile_coords.y, tile_coords.zoom)
        return TileBounds(
            nw=GPSPoint(lat=bounds["nw"][0], lon=bounds["nw"][1]),
            ne=GPSPoint(lat=bounds["ne"][0], lon=bounds["ne"][1]),
            sw=GPSPoint(lat=bounds["sw"][0], lon=bounds["sw"][1]),
            se=GPSPoint(lat=bounds["se"][0], lon=bounds["se"][1]),
            center=GPSPoint(lat=bounds["center"][0], lon=bounds["center"][1]),
            gsd=bounds["gsd"]
        )

    def _compute_grid_offset(self, grid_size: int) -> int:
        """Map a tile count to an offset: 0 for <=1, 1 for <=9, 2 for <=16, else isqrt // 2."""
        if grid_size <= 1:
            return 0
        if grid_size <= 9:
            return 1
        if grid_size <= 16:
            return 2
        return int(math.sqrt(grid_size)) // 2

    def _grid_size_to_dimensions(self, grid_size: int) -> Tuple[int, int]:
        """Return the (rows, cols) of the smallest square grid holding ``grid_size`` tiles.

        Non-positive sizes yield ``(0, 0)`` rather than failing inside ``math.sqrt``.
        """
        if grid_size <= 0:
            return (0, 0)
        dim = int(math.ceil(math.sqrt(grid_size)))
        return (dim, dim)

    def _generate_grid_tiles(self, center: TileCoords, rows: int, cols: int) -> List[TileCoords]:
        """Return a ``rows`` x ``cols`` block of tiles around ``center`` in row-major order."""
        left = center.x - cols // 2
        top = center.y - rows // 2
        return [
            TileCoords(x=left + dx, y=top + dy, zoom=center.zoom)
            for dy in range(rows)
            for dx in range(cols)
        ]

    def get_tile_grid(self, center: TileCoords, grid_size: int) -> List[TileCoords]:
        """Return ``grid_size`` tiles around ``center``.

        Perfect-square sizes give the full square in row-major order. For any
        other size the enclosing square is trimmed to the tiles nearest to
        ``center``, so the centre tile itself is always part of the result.
        """
        rows, cols = self._grid_size_to_dimensions(grid_size)
        tiles = self._generate_grid_tiles(center, rows, cols)
        if len(tiles) <= grid_size:
            return tiles
        # Stable sort: equally distant tiles keep their row-major order.
        tiles.sort(key=lambda t: (t.x - center.x) ** 2 + (t.y - center.y) ** 2)
        return tiles[:grid_size]

    def expand_search_grid(self, center: TileCoords, current_size: int, new_size: int) -> List[TileCoords]:
        """Return the tiles of the ``new_size`` grid that the ``current_size`` grid lacks.

        The result follows the order of the new grid, so it is deterministic.
        """
        already_covered = set(self.get_tile_grid(center, current_size))
        return [tile for tile in self.get_tile_grid(center, new_size) if tile not in already_covered]

    # --- 04.03 Tile Fetching ---

    def _generate_tile_id(self, tile_coords: TileCoords) -> str:
        """Return the ``<zoom>_<x>_<y>`` key used in fetch result dictionaries."""
        return f"{tile_coords.zoom}_{tile_coords.x}_{tile_coords.y}"

    def _fetch_from_api(self, tile_coords: TileCoords) -> Optional[np.ndarray]:
        """Download one tile from the provider; return None on HTTP failure."""
        # Fast-path fallback for local development without a real provider configured
        if "mock-satellite-provider" in self.provider_api_url:
            return np.zeros((256, 256, 3), dtype=np.uint8)
        # The provider is addressed by the geographic position of the tile centre.
        lat, lon = self._tile_to_latlon(tile_coords.x + 0.5, tile_coords.y + 0.5, tile_coords.zoom)
        url = f"{self.provider_api_url}?lat={lat}&lon={lon}&zoom={tile_coords.zoom}"
        try:
            response = httpx.get(url, timeout=5.0)
            response.raise_for_status()
            return self._deserialize_tile(response.content)
        except httpx.HTTPError as e:
            logger.error(f"HTTP fetch failed for {url}: {e}")
            return None

    def _fetch_with_retry(self, tile_coords: TileCoords, max_retries: int = 3) -> Optional[np.ndarray]:
        """Try the provider up to ``max_retries`` times; return the first tile obtained or None."""
        for _ in range(max_retries):
            tile = self._fetch_from_api(tile_coords)
            if tile is not None:
                return tile
        return None

    def _fetch_tiles_parallel(self, tiles: List[TileCoords], max_concurrent: int = 20) -> Dict[str, np.ndarray]:
        """Fetch many tiles on a thread pool.

        Returns:
            Mapping from tile id to image for every tile that was obtained.
            A tile whose fetch fails or raises is logged and left out; it does
            not abort the rest of the batch.
        """
        results: Dict[str, np.ndarray] = {}
        if not tiles:
            return results
        with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, max_concurrent)) as executor:
            future_to_tile = {executor.submit(self._fetch_with_retry, tile): tile for tile in tiles}
            for future in concurrent.futures.as_completed(future_to_tile):
                tile_id = self._generate_tile_id(future_to_tile[future])
                try:
                    data = future.result()
                except Exception as e:
                    logger.error(f"Tile fetch raised for {tile_id}: {e}")
                    continue
                if data is not None:
                    results[tile_id] = data
        return results

    def fetch_tile(self, lat: float, lon: float, zoom: int, flight_id: str = "global") -> Optional[np.ndarray]:
        """Return the tile covering (lat, lon), using the cache before the provider.

        A freshly downloaded tile is stored in the cache of ``flight_id`` and in
        the shared global cache.

        Returns:
            The tile image, or None for an out-of-range position or when the
            tile is neither cached nor downloadable.
        """
        if not (-90.0 <= lat <= 90.0) or not (-180.0 <= lon <= 180.0):
            return None
        coords = self.compute_tile_coords(lat, lon, zoom)
        cached = self.get_cached_tile(flight_id, coords)
        if cached is not None:
            return cached
        fetched = self._fetch_with_retry(coords)
        if fetched is not None:
            self.cache_tile(flight_id, coords, fetched)
            # Also update global cache, unless that is what was just written.
            if flight_id != "global":
                self.cache_tile("global", coords, fetched)
        return fetched

    def fetch_tile_grid(self, center_lat: float, center_lon: float, grid_size: int, zoom: int) -> Dict[str, np.ndarray]:
        """Fetch a grid of tiles around a position; tiles that cannot be obtained are omitted."""
        center_coords = self.compute_tile_coords(center_lat, center_lon, zoom)
        result = {}
        for coords in self.get_tile_grid(center_coords, grid_size):
            # fetch_tile() is addressed by position, so use the centre of each tile.
            tile_lat, tile_lon = self._tile_to_latlon(coords.x + 0.5, coords.y + 0.5, coords.zoom)
            tile = self.fetch_tile(tile_lat, tile_lon, coords.zoom)
            if tile is not None:
                result[self._generate_tile_id(coords)] = tile
        return result

    def progressive_fetch(self, center_lat: float, center_lon: float, grid_sizes: List[int], zoom: int) -> Iterator[Dict[str, np.ndarray]]:
        """Lazily yield one ``fetch_tile_grid`` result per entry of ``grid_sizes``."""
        for size in grid_sizes:
            yield self.fetch_tile_grid(center_lat, center_lon, size, zoom)

    def _corridor_grid_size(self, lat: float, corridor_width_m: float, zoom: int) -> int:
        """Return the tile count of the square grid that spans the corridor at ``lat``.

        The grid reaches ``corridor_width_m / 2`` to each side of a sample
        point and is never smaller than 3 x 3. A width that is missing,
        non-numeric, non-finite or non-positive also yields 3 x 3.
        """
        radius = 1
        if isinstance(corridor_width_m, (int, float)) and math.isfinite(corridor_width_m) and corridor_width_m > 0:
            bounded_lat = max(-self._MAX_MERCATOR_LAT_DEG, min(self._MAX_MERCATOR_LAT_DEG, lat))
            # Ground width of one Web Mercator tile at this latitude and zoom.
            tile_width_m = self._EARTH_CIRCUMFERENCE_M * math.cos(math.radians(bounded_lat)) / (2 ** zoom)
            if math.isfinite(tile_width_m) and tile_width_m > 0:
                radius = max(1, math.ceil(corridor_width_m / (2.0 * tile_width_m)))
        side = 2 * radius + 1
        return side * side

    def _compute_corridor_tiles(self, waypoints: List[GPSPoint], corridor_width_m: float, zoom: int) -> List[TileCoords]:
        """Return the distinct tiles covering a corridor along the waypoints.

        Every waypoint and every interpolated sample contributes a square grid
        sized from ``corridor_width_m``. A large width at a high zoom produces
        many tiles per sample; choosing sensible values is left to the caller.
        """
        if not waypoints:
            return []
        tiles = set()

        def add_sample(lat: float, lon: float) -> None:
            center = self.compute_tile_coords(lat, lon, zoom)
            tiles.update(self.get_tile_grid(center, self._corridor_grid_size(lat, corridor_width_m, zoom)))

        # Add tiles for all exact waypoints
        for wp in waypoints:
            add_sample(wp.lat, wp.lon)
        # Interpolate between waypoints to ensure a continuous corridor (avoiding gaps on long straightaways)
        for wp1, wp2 in zip(waypoints, waypoints[1:]):
            dist_lat = wp2.lat - wp1.lat
            dist_lon = wp2.lon - wp1.lon
            steps = max(
                int(abs(dist_lat) / self._CORRIDOR_STEP_DEG),
                int(abs(dist_lon) / self._CORRIDOR_STEP_DEG),
                1,
            )
            # The end points were added above, so only interior samples remain.
            for step in range(1, steps):
                fraction = step / steps
                add_sample(wp1.lat + dist_lat * fraction, wp1.lon + dist_lon * fraction)
        return list(tiles)

    def prefetch_route_corridor(self, waypoints: List[GPSPoint], corridor_width_m: float, zoom: int) -> bool:
        """Download the tiles of a route corridor into the global cache.

        Tiles that already have a file in the global cache are not downloaded
        again.

        Returns:
            True if the corridor is fully cached already or at least one
            missing tile was retrieved; False for an empty route or when no
            missing tile could be retrieved.
        """
        if not waypoints:
            return False
        corridor_tiles = self._compute_corridor_tiles(waypoints, corridor_width_m, zoom)
        if not corridor_tiles:
            return False
        tiles_to_fetch = [
            tile for tile in corridor_tiles
            if not self._generate_cache_path("global", tile).exists()
        ]
        if not tiles_to_fetch:
            return True
        results = self._fetch_tiles_parallel(tiles_to_fetch)
        if not results:  # Complete failure (no tiles retrieved)
            return False
        for tile in tiles_to_fetch:
            data = results.get(self._generate_tile_id(tile))
            if data is not None:
                self.cache_tile("global", tile, data)
        return True