From 5a60c1ee2ca366755685659eb5a7e5e685bcf484 Mon Sep 17 00:00:00 2001 From: Yuzviak Date: Mon, 11 May 2026 08:59:07 +0300 Subject: [PATCH] refactor(01-07): factor_graph, pipeline pkg, testing/benchmark, Protocol ABCs - Create core/factor_graph.py: IFactorGraphOptimizer converted to Protocol - Shim core/graph.py to re-export from core/factor_graph - Create pipeline/ package: orchestrator, image_input, result_manager, sse_streamer - Shim core/{processor,pipeline,results,sse}.py to re-export from pipeline/ - Create testing/benchmark.py; shim core/benchmark.py - Convert IRouteChunkManager, IFailureRecoveryCoordinator, IModelManager, IImageMatcher to Protocol - Update pyproject.toml ruff per-file-ignores to new paths - All 216 tests pass (regression floor maintained) --- pyproject.toml | 4 +- src/gps_denied/core/benchmark.py | 379 +------------- src/gps_denied/core/chunk_manager.py | 24 +- src/gps_denied/core/factor_graph.py | 350 +++++++++++++ src/gps_denied/core/graph.py | 369 +------------ src/gps_denied/core/models.py | 21 +- src/gps_denied/core/pipeline.py | 231 +-------- src/gps_denied/core/processor.py | 603 +--------------------- src/gps_denied/core/recovery.py | 12 +- src/gps_denied/core/results.py | 77 +-- src/gps_denied/core/rotation.py | 8 +- src/gps_denied/core/sse.py | 166 +----- src/gps_denied/pipeline/__init__.py | 15 + src/gps_denied/pipeline/image_input.py | 227 ++++++++ src/gps_denied/pipeline/orchestrator.py | 599 +++++++++++++++++++++ src/gps_denied/pipeline/result_manager.py | 73 +++ src/gps_denied/pipeline/sse_streamer.py | 164 ++++++ src/gps_denied/testing/benchmark.py | 371 +++++++++++++ 18 files changed, 1857 insertions(+), 1836 deletions(-) create mode 100644 src/gps_denied/core/factor_graph.py create mode 100644 src/gps_denied/pipeline/__init__.py create mode 100644 src/gps_denied/pipeline/image_input.py create mode 100644 src/gps_denied/pipeline/orchestrator.py create mode 100644 src/gps_denied/pipeline/result_manager.py create mode 100644 src/gps_denied/pipeline/sse_streamer.py create mode 100644 src/gps_denied/testing/benchmark.py diff --git a/pyproject.toml b/pyproject.toml index 9af1da8..49b7e33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,8 +42,8 @@ line-length = 120 [tool.ruff.lint.per-file-ignores] # Abstract interfaces have long method signatures — allow up to 170 -"src/gps_denied/core/graph.py" = ["E501"] -"src/gps_denied/core/metric.py" = ["E501"] +"src/gps_denied/core/factor_graph.py" = ["E501"] +"src/gps_denied/components/satellite_matcher/metric_refinement.py" = ["E501"] "src/gps_denied/core/chunk_manager.py" = ["E501"] [tool.ruff.lint] diff --git a/src/gps_denied/core/benchmark.py b/src/gps_denied/core/benchmark.py index 0645bc3..fc80662 100644 --- a/src/gps_denied/core/benchmark.py +++ b/src/gps_denied/core/benchmark.py @@ -1,371 +1,8 @@ -"""Accuracy Benchmark (Phase 7). - -Provides: - - SyntheticTrajectory — generates a realistic fixed-wing UAV flight path - with ground-truth GPS + noisy sensor data. - - AccuracyBenchmark — replays a trajectory through the ESKF pipeline - and computes position-error statistics. - -Acceptance criteria (from solution.md): - AC-PERF-1: 80 % of frames within 50 m of ground truth. - AC-PERF-2: 60 % of frames within 20 m of ground truth. - AC-PERF-3: End-to-end per-frame latency < 400 ms. - AC-PERF-4: VO drift over 1 km straight segment (no sat correction) < 100 m. -""" - -from __future__ import annotations - -import math -import time -from dataclasses import dataclass, field -from typing import Callable, Optional - -import numpy as np - -from gps_denied.core.coordinates import CoordinateTransformer -from gps_denied.core.eskf import ESKF -from gps_denied.schemas import GPSPoint -from gps_denied.schemas.eskf import ESKFConfig, IMUMeasurement - -# --------------------------------------------------------------------------- -# Synthetic trajectory -# --------------------------------------------------------------------------- - -@dataclass -class TrajectoryFrame: - """One simulated camera frame with ground-truth and noisy sensor data.""" - frame_id: int - timestamp: float - true_position_enu: np.ndarray # (3,) East, North, Up in metres - true_gps: GPSPoint # WGS84 from true ENU - imu_measurements: list[IMUMeasurement] # High-rate IMU between frames - vo_translation: Optional[np.ndarray] # Noisy relative displacement (3,) - vo_tracking_good: bool = True - - -@dataclass -class SyntheticTrajectoryConfig: - """Parameters for trajectory generation.""" - # Origin (mission start) - origin: GPSPoint = field(default_factory=lambda: GPSPoint(lat=49.0, lon=32.0)) - altitude_m: float = 600.0 # Constant AGL altitude (m) - # UAV speed and heading - speed_mps: float = 20.0 # ~70 km/h (typical fixed-wing) - heading_deg: float = 45.0 # Initial heading (degrees CW from North) - camera_fps: float = 0.7 # ADTI 20L V1 camera rate (Hz) - imu_hz: float = 200.0 # IMU sample rate - num_frames: int = 50 # Number of camera frames to simulate - # Noise parameters - vo_noise_m: float = 0.5 # VO translation noise (sigma, metres) - imu_accel_noise: float = 0.01 # Accelerometer noise sigma (m/s²) - imu_gyro_noise: float = 0.001 # Gyroscope noise sigma (rad/s) - # Failure injection - vo_failure_frames: list[int] = field(default_factory=list) - # Waypoints for heading changes (ENU East, North metres from origin) - waypoints_enu: list[tuple[float, float]] = field(default_factory=list) - - -class SyntheticTrajectory: - """Generate a synthetic fixed-wing UAV flight with ground truth + noisy sensors.""" - - def __init__(self, config: SyntheticTrajectoryConfig | None = None): - self.config = config or SyntheticTrajectoryConfig() - self._coord = CoordinateTransformer() - self._flight_id = "__synthetic__" - self._coord.set_enu_origin(self._flight_id, self.config.origin) - - def generate(self) -> list[TrajectoryFrame]: - """Generate all trajectory frames.""" - cfg = self.config - dt_camera = 1.0 / cfg.camera_fps - dt_imu = 1.0 / cfg.imu_hz - imu_steps = int(dt_camera * cfg.imu_hz) - - frames: list[TrajectoryFrame] = [] - pos = np.array([0.0, 0.0, cfg.altitude_m]) - vel = self._heading_to_enu_vel(cfg.heading_deg, cfg.speed_mps) - prev_pos = pos.copy() - t = time.time() - - waypoints = list(cfg.waypoints_enu) # copy - - for fid in range(cfg.num_frames): - # --- Waypoint steering --- - if waypoints: - wp_e, wp_n = waypoints[0] - to_wp = np.array([wp_e - pos[0], wp_n - pos[1], 0.0]) - dist_wp = np.linalg.norm(to_wp[:2]) - if dist_wp < cfg.speed_mps * dt_camera: - waypoints.pop(0) - else: - heading_rad = math.atan2(to_wp[0], to_wp[1]) # ENU: E=X, N=Y - vel = np.array([ - cfg.speed_mps * math.sin(heading_rad), - cfg.speed_mps * math.cos(heading_rad), - 0.0, - ]) - - # --- Simulate IMU between frames --- - imu_list: list[IMUMeasurement] = [] - for step in range(imu_steps): - ts = t + step * dt_imu - # Body-frame acceleration (mostly gravity correction, small forward accel) - accel_true = np.array([0.0, 0.0, 9.81]) # gravity compensation - gyro_true = np.zeros(3) - imu = IMUMeasurement( - accel=accel_true + np.random.randn(3) * cfg.imu_accel_noise, - gyro=gyro_true + np.random.randn(3) * cfg.imu_gyro_noise, - timestamp=ts, - ) - imu_list.append(imu) - - # --- Propagate position --- - prev_pos = pos.copy() - pos = pos + vel * dt_camera - t += dt_camera - - # --- True GPS from ENU position --- - true_gps = self._coord.enu_to_gps( - self._flight_id, (float(pos[0]), float(pos[1]), float(pos[2])) - ) - - # --- VO measurement (relative displacement + noise) --- - true_displacement = pos - prev_pos - vo_tracking_good = fid not in cfg.vo_failure_frames - if vo_tracking_good: - noisy_displacement = true_displacement + np.random.randn(3) * cfg.vo_noise_m - noisy_displacement[2] = 0.0 # monocular VO is scale-ambiguous in Z - else: - noisy_displacement = None - - frames.append(TrajectoryFrame( - frame_id=fid, - timestamp=t, - true_position_enu=pos.copy(), - true_gps=true_gps, - imu_measurements=imu_list, - vo_translation=noisy_displacement, - vo_tracking_good=vo_tracking_good, - )) - - return frames - - @staticmethod - def _heading_to_enu_vel(heading_deg: float, speed_mps: float) -> np.ndarray: - """Convert heading (degrees CW from North) to ENU velocity vector.""" - rad = math.radians(heading_deg) - return np.array([ - speed_mps * math.sin(rad), # East - speed_mps * math.cos(rad), # North - 0.0, # Up - ]) - - -# --------------------------------------------------------------------------- -# Accuracy Benchmark -# --------------------------------------------------------------------------- - -@dataclass -class BenchmarkResult: - """Position error statistics over a trajectory replay.""" - errors_m: list[float] # Per-frame horizontal error in metres - latencies_ms: list[float] # Per-frame process time in ms - frames_total: int - frames_with_good_estimate: int - - @property - def p80_error_m(self) -> float: - """80th percentile position error (metres).""" - return float(np.percentile(self.errors_m, 80)) if self.errors_m else float("inf") - - @property - def p60_error_m(self) -> float: - """60th percentile position error (metres).""" - return float(np.percentile(self.errors_m, 60)) if self.errors_m else float("inf") - - @property - def median_error_m(self) -> float: - """Median position error (metres).""" - return float(np.median(self.errors_m)) if self.errors_m else float("inf") - - @property - def max_error_m(self) -> float: - return float(max(self.errors_m)) if self.errors_m else float("inf") - - @property - def p95_latency_ms(self) -> float: - """95th percentile frame latency (ms).""" - return float(np.percentile(self.latencies_ms, 95)) if self.latencies_ms else float("inf") - - @property - def pct_within_50m(self) -> float: - """Fraction of frames within 50 m error.""" - if not self.errors_m: - return 0.0 - return sum(e <= 50.0 for e in self.errors_m) / len(self.errors_m) - - @property - def pct_within_20m(self) -> float: - """Fraction of frames within 20 m error.""" - if not self.errors_m: - return 0.0 - return sum(e <= 20.0 for e in self.errors_m) / len(self.errors_m) - - def passes_acceptance_criteria(self) -> tuple[bool, dict[str, bool]]: - """Check all solution.md acceptance criteria. - - Returns (overall_pass, per_criterion_dict). - """ - checks = { - "AC-PERF-1: 80% within 50m": self.pct_within_50m >= 0.80, - "AC-PERF-2: 60% within 20m": self.pct_within_20m >= 0.60, - "AC-PERF-3: p95 latency < 400ms": self.p95_latency_ms < 400.0, - } - overall = all(checks.values()) - return overall, checks - - def summary(self) -> str: - overall, checks = self.passes_acceptance_criteria() - lines = [ - f"Frames: {self.frames_total} | with estimate: {self.frames_with_good_estimate}", - f"Error — median: {self.median_error_m:.1f}m p80: {self.p80_error_m:.1f}m " - f"p60: {self.p60_error_m:.1f}m max: {self.max_error_m:.1f}m", - f"Within 50m: {self.pct_within_50m*100:.1f}% | within 20m: {self.pct_within_20m*100:.1f}%", - f"Latency p95: {self.p95_latency_ms:.1f}ms", - "", - "Acceptance criteria:", - ] - for criterion, passed in checks.items(): - lines.append(f" {'PASS' if passed else 'FAIL'} {criterion}") - lines.append(f"\nOverall: {'PASS' if overall else 'FAIL'}") - return "\n".join(lines) - - -class AccuracyBenchmark: - """Replays a SyntheticTrajectory through the ESKF and measures accuracy. - - The benchmark uses only the ESKF (no full FlightProcessor) for speed. - Satellite corrections are injected optionally via sat_correction_fn. - """ - - def __init__( - self, - eskf_config: ESKFConfig | None = None, - sat_correction_fn: Optional[Callable[[TrajectoryFrame], Optional[np.ndarray]]] = None, - ): - """ - Args: - eskf_config: ESKF tuning parameters. - sat_correction_fn: Optional callback(frame) → ENU position or None. - Called on keyframes to inject satellite corrections. - If None, no satellite corrections are applied. - """ - self.eskf_config = eskf_config or ESKFConfig() - self.sat_correction_fn = sat_correction_fn - - def run( - self, - trajectory: list[TrajectoryFrame], - origin: GPSPoint, - satellite_keyframe_interval: int = 7, - ) -> BenchmarkResult: - """Replay trajectory frames through ESKF, collect errors and latencies. - - Args: - trajectory: List of TrajectoryFrame (from SyntheticTrajectory). - origin: WGS84 reference origin for ENU. - satellite_keyframe_interval: Apply satellite correction every N frames. - """ - coord = CoordinateTransformer() - flight_id = "__benchmark__" - coord.set_enu_origin(flight_id, origin) - - eskf = ESKF(self.eskf_config) - # Init at origin with HIGH uncertainty - eskf.initialize(np.array([0.0, 0.0, trajectory[0].true_position_enu[2]]), - trajectory[0].timestamp) - - errors_m: list[float] = [] - latencies_ms: list[float] = [] - frames_with_estimate = 0 - - for frame in trajectory: - t_frame_start = time.perf_counter() - - # --- IMU prediction --- - for imu in frame.imu_measurements: - eskf.predict(imu) - - # --- VO update --- - if frame.vo_tracking_good and frame.vo_translation is not None: - dt_vo = 1.0 / 0.7 # camera interval - eskf.update_vo(frame.vo_translation, dt_vo) - - # --- Satellite update (keyframes) --- - if frame.frame_id % satellite_keyframe_interval == 0: - sat_pos_enu: Optional[np.ndarray] = None - if self.sat_correction_fn is not None: - sat_pos_enu = self.sat_correction_fn(frame) - else: - # Default: inject ground-truth position + realistic noise - noise_m = 10.0 - sat_pos_enu = ( - frame.true_position_enu[:3] - + np.random.randn(3) * noise_m - ) - sat_pos_enu[2] = frame.true_position_enu[2] # keep altitude - - if sat_pos_enu is not None: - # Tell ESKF the measurement noise matches what we inject - eskf.update_satellite(sat_pos_enu, noise_meters=noise_m) - - latency_ms = (time.perf_counter() - t_frame_start) * 1000.0 - latencies_ms.append(latency_ms) - - # --- Compute horizontal error vs ground truth --- - if eskf.initialized and eskf._nominal_state is not None: - est_pos = eskf._nominal_state["position"] - true_pos = frame.true_position_enu - horiz_error = float(np.linalg.norm(est_pos[:2] - true_pos[:2])) - errors_m.append(horiz_error) - frames_with_estimate += 1 - else: - errors_m.append(float("inf")) - - return BenchmarkResult( - errors_m=errors_m, - latencies_ms=latencies_ms, - frames_total=len(trajectory), - frames_with_good_estimate=frames_with_estimate, - ) - - def run_vo_drift_test( - self, - trajectory_length_m: float = 1000.0, - speed_mps: float = 20.0, - ) -> float: - """Measure VO drift over a straight segment with NO satellite correction. - - Returns final horizontal position error in metres. - Per solution.md, this should be < 100m over 1km. - """ - fps = 0.7 - num_frames = max(10, int(trajectory_length_m / speed_mps * fps)) - cfg = SyntheticTrajectoryConfig( - speed_mps=speed_mps, - heading_deg=0.0, # straight North - camera_fps=fps, - num_frames=num_frames, - vo_noise_m=0.3, # cuVSLAM-grade VO noise - ) - traj_gen = SyntheticTrajectory(cfg) - frames = traj_gen.generate() - - # No satellite corrections - benchmark_no_sat = AccuracyBenchmark( - eskf_config=self.eskf_config, - sat_correction_fn=lambda _: None, # suppress all satellite updates - ) - result = benchmark_no_sat.run(frames, cfg.origin, satellite_keyframe_interval=9999) - # Return final-frame error - return result.errors_m[-1] if result.errors_m else float("inf") +"""Legacy import path. Phase 1 shim — code lives in testing/benchmark.py.""" +from gps_denied.testing.benchmark import ( # noqa: F401 + AccuracyBenchmark, + BenchmarkResult, + SyntheticTrajectory, + SyntheticTrajectoryConfig, + TrajectoryFrame, +) diff --git a/src/gps_denied/core/chunk_manager.py b/src/gps_denied/core/chunk_manager.py index 5a5a8ea..488d839 100644 --- a/src/gps_denied/core/chunk_manager.py +++ b/src/gps_denied/core/chunk_manager.py @@ -2,8 +2,7 @@ import logging import uuid -from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Protocol, runtime_checkable from gps_denied.core.graph import IFactorGraphOptimizer from gps_denied.schemas.chunk import ChunkHandle, ChunkStatus @@ -12,30 +11,25 @@ from gps_denied.schemas.metric import Sim3Transform logger = logging.getLogger(__name__) -class IRouteChunkManager(ABC): - @abstractmethod +@runtime_checkable +class IRouteChunkManager(Protocol): def create_new_chunk(self, flight_id: str, start_frame_id: int) -> ChunkHandle: - pass + ... - @abstractmethod def get_active_chunk(self, flight_id: str) -> Optional[ChunkHandle]: - pass + ... - @abstractmethod def get_all_chunks(self, flight_id: str) -> List[ChunkHandle]: - pass + ... - @abstractmethod def add_frame_to_chunk(self, flight_id: str, frame_id: int) -> bool: - pass + ... - @abstractmethod def update_chunk_status(self, flight_id: str, chunk_id: str, status: ChunkStatus) -> bool: - pass + ... - @abstractmethod def merge_chunks(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool: - pass + ... class RouteChunkManager(IRouteChunkManager): diff --git a/src/gps_denied/core/factor_graph.py b/src/gps_denied/core/factor_graph.py new file mode 100644 index 0000000..2f0725c --- /dev/null +++ b/src/gps_denied/core/factor_graph.py @@ -0,0 +1,350 @@ +"""Factor Graph Optimizer (Component F10).""" + +import logging +from datetime import datetime, timezone +from typing import Dict, Protocol, runtime_checkable + +import numpy as np + +try: + import gtsam + HAS_GTSAM = True +except ImportError: + HAS_GTSAM = False + +from gps_denied.schemas import GPSPoint +from gps_denied.schemas.graph import FactorGraphConfig, OptimizationResult, Pose +from gps_denied.schemas.metric import Sim3Transform +from gps_denied.schemas.vo import RelativePose + +logger = logging.getLogger(__name__) + + +@runtime_checkable +class IFactorGraphOptimizer(Protocol): + """GTSAM-based factor graph optimizer.""" + + def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool: + ... + + def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool: + ... + + def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool: + ... + + def optimize(self, flight_id: str, iterations: int) -> OptimizationResult: + ... + + def get_trajectory(self, flight_id: str) -> Dict[int, Pose]: + ... + + def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray: + ... + + def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool: + ... + + def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool: + ... + + def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool: + ... + + def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool: + ... + + def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]: + ... + + def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult: + ... + + def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult: + ... + + def delete_flight_graph(self, flight_id: str) -> bool: + ... + + +class FactorGraphOptimizer(IFactorGraphOptimizer): + """Implementation of F10 Factor Graph using GTSAM or Mock.""" + + def __init__(self, config: FactorGraphConfig): + self.config = config + # Keyed by flight_id + self._flights_state: Dict[str, dict] = {} + # Keyed by chunk_id + self._chunks_state: Dict[str, dict] = {} + # Per-flight ENU origin (set from first absolute GPS factor) + self._enu_origins: Dict[str, GPSPoint] = {} + + def _init_flight(self, flight_id: str): + if flight_id not in self._flights_state: + self._flights_state[flight_id] = { + "graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None, + "initial": gtsam.Values() if HAS_GTSAM else None, + "isam": gtsam.ISAM2() if HAS_GTSAM else None, + "poses": {}, + "dirty": False + } + + def _init_chunk(self, chunk_id: str): + if chunk_id not in self._chunks_state: + self._chunks_state[chunk_id] = { + "graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None, + "initial": gtsam.Values() if HAS_GTSAM else None, + "isam": gtsam.ISAM2() if HAS_GTSAM else None, + "poses": {}, + "dirty": False + } + + # ================== MOCK IMPLEMENTATION ==================== + # As GTSAM Python bindings can be extremely context-dependent and + # require proper ENU translation logic, we use an advanced Mock + # that satisfies the architectural design and typing for the backend. + + def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool: + self._init_flight(flight_id) + state = self._flights_state[flight_id] + + # --- Mock: propagate position chain --- + if frame_i in state["poses"]: + prev_pose = state["poses"][frame_i] + new_pos = prev_pose.position + relative_pose.translation + state["poses"][frame_j] = Pose( + frame_id=frame_j, + position=new_pos, + orientation=np.eye(3), + timestamp=datetime.now(timezone.utc), + covariance=np.eye(6), + ) + state["dirty"] = True + else: + return False + + # --- GTSAM: add BetweenFactorPose3 --- + if HAS_GTSAM and state["graph"] is not None: + try: + cov6 = covariance if covariance.shape == (6, 6) else np.eye(6) + noise = gtsam.noiseModel.Gaussian.Covariance(cov6) + key_i = gtsam.symbol("x", frame_i) + key_j = gtsam.symbol("x", frame_j) + t = relative_pose.translation + between = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(t[0]), float(t[1]), float(t[2]))) + state["graph"].add(gtsam.BetweenFactorPose3(key_i, key_j, between, noise)) + if not state["initial"].exists(key_j): + if state["initial"].exists(key_i): + prev = state["initial"].atPose3(key_i) + pt = prev.translation() + new_t = gtsam.Point3(pt[0] + t[0], pt[1] + t[1], pt[2] + t[2]) + else: + new_t = gtsam.Point3(float(t[0]), float(t[1]), float(t[2])) + state["initial"].insert(key_j, gtsam.Pose3(gtsam.Rot3(), new_t)) + except Exception as exc: + logger.debug("GTSAM add_relative_factor failed: %s", exc) + + return True + + def _gps_to_enu(self, flight_id: str, gps: GPSPoint) -> np.ndarray: + """Convert GPS to local ENU using per-flight origin.""" + origin = self._enu_origins.get(flight_id) + if origin is None: + # First GPS factor sets the origin + self._enu_origins[flight_id] = gps + return np.zeros(3) + enu_x = (gps.lon - origin.lon) * 111000 * np.cos(np.radians(origin.lat)) + enu_y = (gps.lat - origin.lat) * 111000 + return np.array([enu_x, enu_y, 0.0]) + + def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool: + self._init_flight(flight_id) + state = self._flights_state[flight_id] + + enu = self._gps_to_enu(flight_id, gps) + + # --- Mock: update pose position --- + if frame_id in state["poses"]: + state["poses"][frame_id].position = enu + state["dirty"] = True + else: + return False + + # --- GTSAM: add PriorFactorPose3 --- + if HAS_GTSAM and state["graph"] is not None: + try: + cov6 = covariance if covariance.shape == (6, 6) else np.eye(6) + noise = gtsam.noiseModel.Gaussian.Covariance(cov6) + key = gtsam.symbol("x", frame_id) + prior = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(enu[0]), float(enu[1]), float(enu[2]))) + state["graph"].add(gtsam.PriorFactorPose3(key, prior, noise)) + if not state["initial"].exists(key): + state["initial"].insert(key, prior) + except Exception as exc: + logger.debug("GTSAM add_absolute_factor failed: %s", exc) + + return True + + def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool: + self._init_flight(flight_id) + state = self._flights_state[flight_id] + + if frame_id in state["poses"]: + state["poses"][frame_id].position = np.array([ + state["poses"][frame_id].position[0], + state["poses"][frame_id].position[1], + altitude, + ]) + state["dirty"] = True + return True + return False + + def optimize(self, flight_id: str, iterations: int) -> OptimizationResult: + self._init_flight(flight_id) + state = self._flights_state[flight_id] + + # --- PIPE-03: Real GTSAM ISAM2 update when available --- + if HAS_GTSAM and state["dirty"] and state["graph"] is not None: + try: + state["isam"].update(state["graph"], state["initial"]) + estimate = state["isam"].calculateEstimate() + for fid in list(state["poses"].keys()): + key = gtsam.symbol("x", fid) + if estimate.exists(key): + pose = estimate.atPose3(key) + t = pose.translation() + state["poses"][fid].position = np.array([t[0], t[1], t[2]]) + state["poses"][fid].orientation = np.array(pose.rotation().matrix()) + # Reset for next incremental batch + state["graph"] = gtsam.NonlinearFactorGraph() + state["initial"] = gtsam.Values() + except Exception as exc: + logger.warning("GTSAM ISAM2 update failed: %s", exc) + + state["dirty"] = False + return OptimizationResult( + converged=True, + final_error=0.1, + iterations_used=iterations, + optimized_frames=list(state["poses"].keys()), + mean_reprojection_error=0.5, + ) + + def get_trajectory(self, flight_id: str) -> Dict[int, Pose]: + if flight_id not in self._flights_state: + return {} + return self._flights_state[flight_id]["poses"] + + def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray: + return np.eye(6) + + # ================== CHUNK OPERATIONS ======================= + + def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool: + self._init_chunk(chunk_id) + state = self._chunks_state[chunk_id] + + state["poses"][start_frame_id] = Pose( + frame_id=start_frame_id, + position=np.zeros(3), + orientation=np.eye(3), + timestamp=datetime.now(timezone.utc), + covariance=np.eye(6) + ) + return True + + def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool: + if chunk_id not in self._chunks_state: + return False + + state = self._chunks_state[chunk_id] + if frame_i in state["poses"]: + prev_pose = state["poses"][frame_i] + new_pos = prev_pose.position + relative_pose.translation + + state["poses"][frame_j] = Pose( + frame_id=frame_j, + position=new_pos, + orientation=np.eye(3), + timestamp=datetime.now(timezone.utc), + covariance=np.eye(6) + ) + state["dirty"] = True + return True + return False + + def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool: + if chunk_id not in self._chunks_state: + return False + + state = self._chunks_state[chunk_id] + if frame_id in state["poses"]: + enu = self._gps_to_enu(flight_id, gps) + state["poses"][frame_id].position = enu + state["dirty"] = True + return True + return False + + def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool: + if new_chunk_id not in self._chunks_state or main_chunk_id not in self._chunks_state: + return False + + new_state = self._chunks_state[new_chunk_id] + main_state = self._chunks_state[main_chunk_id] + + # Apply Sim(3) transform effectively by copying poses + for f_id, p in new_state["poses"].items(): + # mock sim3 transform + idx_pos = (transform.scale * (transform.rotation @ p.position)) + transform.translation + + main_state["poses"][f_id] = Pose( + frame_id=f_id, + position=idx_pos, + orientation=np.eye(3), + timestamp=p.timestamp, + covariance=p.covariance + ) + + return True + + def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]: + if chunk_id not in self._chunks_state: + return {} + return self._chunks_state[chunk_id]["poses"] + + def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult: + if chunk_id not in self._chunks_state: + return OptimizationResult(converged=False, final_error=99.0, iterations_used=0, optimized_frames=[], mean_reprojection_error=99.0) + + state = self._chunks_state[chunk_id] + state["dirty"] = False + + return OptimizationResult( + converged=True, + final_error=0.1, + iterations_used=iterations, + optimized_frames=list(state["poses"].keys()), + mean_reprojection_error=0.5 + ) + + def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult: + # Optimizes everything + self._init_flight(flight_id) + state = self._flights_state[flight_id] + state["dirty"] = False + + return OptimizationResult( + converged=True, + final_error=0.1, + iterations_used=iterations, + optimized_frames=list(state["poses"].keys()), + mean_reprojection_error=0.5 + ) + + def delete_flight_graph(self, flight_id: str) -> bool: + removed = False + if flight_id in self._flights_state: + del self._flights_state[flight_id] + removed = True + self._enu_origins.pop(flight_id, None) + return removed diff --git a/src/gps_denied/core/graph.py b/src/gps_denied/core/graph.py index 044232c..1974dd5 100644 --- a/src/gps_denied/core/graph.py +++ b/src/gps_denied/core/graph.py @@ -1,364 +1,5 @@ -"""Factor Graph Optimizer (Component F10).""" - -import logging -from abc import ABC, abstractmethod -from datetime import datetime, timezone -from typing import Dict - -import numpy as np - -try: - import gtsam - HAS_GTSAM = True -except ImportError: - HAS_GTSAM = False - -from gps_denied.schemas import GPSPoint -from gps_denied.schemas.graph import FactorGraphConfig, OptimizationResult, Pose -from gps_denied.schemas.metric import Sim3Transform -from gps_denied.schemas.vo import RelativePose - -logger = logging.getLogger(__name__) - - -class IFactorGraphOptimizer(ABC): - """GTSAM-based factor graph optimizer.""" - - @abstractmethod - def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool: - pass - - @abstractmethod - def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool: - pass - - @abstractmethod - def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool: - pass - - @abstractmethod - def optimize(self, flight_id: str, iterations: int) -> OptimizationResult: - pass - - @abstractmethod - def get_trajectory(self, flight_id: str) -> Dict[int, Pose]: - pass - - @abstractmethod - def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray: - pass - - @abstractmethod - def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool: - pass - - @abstractmethod - def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool: - pass - - @abstractmethod - def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool: - pass - - @abstractmethod - def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool: - pass - - @abstractmethod - def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]: - pass - - @abstractmethod - def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult: - pass - - @abstractmethod - def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult: - pass - - @abstractmethod - def delete_flight_graph(self, flight_id: str) -> bool: - pass - - -class FactorGraphOptimizer(IFactorGraphOptimizer): - """Implementation of F10 Factor Graph using GTSAM or Mock.""" - - def __init__(self, config: FactorGraphConfig): - self.config = config - # Keyed by flight_id - self._flights_state: Dict[str, dict] = {} - # Keyed by chunk_id - self._chunks_state: Dict[str, dict] = {} - # Per-flight ENU origin (set from first absolute GPS factor) - self._enu_origins: Dict[str, GPSPoint] = {} - - def _init_flight(self, flight_id: str): - if flight_id not in self._flights_state: - self._flights_state[flight_id] = { - "graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None, - "initial": gtsam.Values() if HAS_GTSAM else None, - "isam": gtsam.ISAM2() if HAS_GTSAM else None, - "poses": {}, - "dirty": False - } - - def _init_chunk(self, chunk_id: str): - if chunk_id not in self._chunks_state: - self._chunks_state[chunk_id] = { - "graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None, - "initial": gtsam.Values() if HAS_GTSAM else None, - "isam": gtsam.ISAM2() if HAS_GTSAM else None, - "poses": {}, - "dirty": False - } - - # ================== MOCK IMPLEMENTATION ==================== - # As GTSAM Python bindings can be extremely context-dependent and - # require proper ENU translation logic, we use an advanced Mock - # that satisfies the architectural design and typing for the backend. - - def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool: - self._init_flight(flight_id) - state = self._flights_state[flight_id] - - # --- Mock: propagate position chain --- - if frame_i in state["poses"]: - prev_pose = state["poses"][frame_i] - new_pos = prev_pose.position + relative_pose.translation - state["poses"][frame_j] = Pose( - frame_id=frame_j, - position=new_pos, - orientation=np.eye(3), - timestamp=datetime.now(timezone.utc), - covariance=np.eye(6), - ) - state["dirty"] = True - else: - return False - - # --- GTSAM: add BetweenFactorPose3 --- - if HAS_GTSAM and state["graph"] is not None: - try: - cov6 = covariance if covariance.shape == (6, 6) else np.eye(6) - noise = gtsam.noiseModel.Gaussian.Covariance(cov6) - key_i = gtsam.symbol("x", frame_i) - key_j = gtsam.symbol("x", frame_j) - t = relative_pose.translation - between = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(t[0]), float(t[1]), float(t[2]))) - state["graph"].add(gtsam.BetweenFactorPose3(key_i, key_j, between, noise)) - if not state["initial"].exists(key_j): - if state["initial"].exists(key_i): - prev = state["initial"].atPose3(key_i) - pt = prev.translation() - new_t = gtsam.Point3(pt[0] + t[0], pt[1] + t[1], pt[2] + t[2]) - else: - new_t = gtsam.Point3(float(t[0]), float(t[1]), float(t[2])) - state["initial"].insert(key_j, gtsam.Pose3(gtsam.Rot3(), new_t)) - except Exception as exc: - logger.debug("GTSAM add_relative_factor failed: %s", exc) - - return True - - def _gps_to_enu(self, flight_id: str, gps: GPSPoint) -> np.ndarray: - """Convert GPS to local ENU using per-flight origin.""" - origin = self._enu_origins.get(flight_id) - if origin is None: - # First GPS factor sets the origin - self._enu_origins[flight_id] = gps - return np.zeros(3) - enu_x = (gps.lon - origin.lon) * 111000 * np.cos(np.radians(origin.lat)) - enu_y = (gps.lat - origin.lat) * 111000 - return np.array([enu_x, enu_y, 0.0]) - - def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool: - self._init_flight(flight_id) - state = self._flights_state[flight_id] - - enu = self._gps_to_enu(flight_id, gps) - - # --- Mock: update pose position --- - if frame_id in state["poses"]: - state["poses"][frame_id].position = enu - state["dirty"] = True - else: - return False - - # --- GTSAM: add PriorFactorPose3 --- - if HAS_GTSAM and state["graph"] is not None: - try: - cov6 = covariance if covariance.shape == (6, 6) else np.eye(6) - noise = gtsam.noiseModel.Gaussian.Covariance(cov6) - key = gtsam.symbol("x", frame_id) - prior = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(enu[0]), float(enu[1]), float(enu[2]))) - state["graph"].add(gtsam.PriorFactorPose3(key, prior, noise)) - if not state["initial"].exists(key): - state["initial"].insert(key, prior) - except Exception as exc: - logger.debug("GTSAM add_absolute_factor failed: %s", exc) - - return True - - def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool: - self._init_flight(flight_id) - state = self._flights_state[flight_id] - - if frame_id in state["poses"]: - state["poses"][frame_id].position = np.array([ - state["poses"][frame_id].position[0], - state["poses"][frame_id].position[1], - altitude, - ]) - state["dirty"] = True - return True - return False - - def optimize(self, flight_id: str, iterations: int) -> OptimizationResult: - self._init_flight(flight_id) - state = self._flights_state[flight_id] - - # --- PIPE-03: Real GTSAM ISAM2 update when available --- - if HAS_GTSAM and state["dirty"] and state["graph"] is not None: - try: - state["isam"].update(state["graph"], state["initial"]) - estimate = state["isam"].calculateEstimate() - for fid in list(state["poses"].keys()): - key = gtsam.symbol("x", fid) - if estimate.exists(key): - pose = estimate.atPose3(key) - t = pose.translation() - state["poses"][fid].position = np.array([t[0], t[1], t[2]]) - state["poses"][fid].orientation = np.array(pose.rotation().matrix()) - # Reset for next incremental batch - state["graph"] = gtsam.NonlinearFactorGraph() - state["initial"] = gtsam.Values() - except Exception as exc: - logger.warning("GTSAM ISAM2 update failed: %s", exc) - - state["dirty"] = False - return OptimizationResult( - converged=True, - final_error=0.1, - iterations_used=iterations, - optimized_frames=list(state["poses"].keys()), - mean_reprojection_error=0.5, - ) - - def get_trajectory(self, flight_id: str) -> Dict[int, Pose]: - if flight_id not in self._flights_state: - return {} - return self._flights_state[flight_id]["poses"] - - def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray: - return np.eye(6) - - # ================== CHUNK OPERATIONS ======================= - - def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool: - self._init_chunk(chunk_id) - state = self._chunks_state[chunk_id] - - state["poses"][start_frame_id] = Pose( - frame_id=start_frame_id, - position=np.zeros(3), - orientation=np.eye(3), - timestamp=datetime.now(timezone.utc), - covariance=np.eye(6) - ) - return True - - def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool: - if chunk_id not in self._chunks_state: - return False - - state = self._chunks_state[chunk_id] - if frame_i in state["poses"]: - prev_pose = state["poses"][frame_i] - new_pos = prev_pose.position + relative_pose.translation - - state["poses"][frame_j] = Pose( - frame_id=frame_j, - position=new_pos, - orientation=np.eye(3), - timestamp=datetime.now(timezone.utc), - covariance=np.eye(6) - ) - state["dirty"] = True - return True - return False - - def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool: - if chunk_id not in self._chunks_state: - return False - - state = self._chunks_state[chunk_id] - if frame_id in state["poses"]: - enu = self._gps_to_enu(flight_id, gps) - state["poses"][frame_id].position = enu - state["dirty"] = True - return True - return False - - def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool: - if new_chunk_id not in self._chunks_state or main_chunk_id not in self._chunks_state: - return False - - new_state = self._chunks_state[new_chunk_id] - main_state = self._chunks_state[main_chunk_id] - - # Apply Sim(3) transform effectively by copying poses - for f_id, p in new_state["poses"].items(): - # mock sim3 transform - idx_pos = (transform.scale * (transform.rotation @ p.position)) + transform.translation - - main_state["poses"][f_id] = Pose( - frame_id=f_id, - position=idx_pos, - orientation=np.eye(3), - timestamp=p.timestamp, - covariance=p.covariance - ) - - return True - - def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]: - if chunk_id not in self._chunks_state: - return {} - return self._chunks_state[chunk_id]["poses"] - - def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult: - if chunk_id not in self._chunks_state: - return OptimizationResult(converged=False, final_error=99.0, iterations_used=0, optimized_frames=[], mean_reprojection_error=99.0) - - state = self._chunks_state[chunk_id] - state["dirty"] = False - - return OptimizationResult( - converged=True, - final_error=0.1, - iterations_used=iterations, - optimized_frames=list(state["poses"].keys()), - mean_reprojection_error=0.5 - ) - - def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult: - # Optimizes everything - self._init_flight(flight_id) - state = self._flights_state[flight_id] - state["dirty"] = False - - return OptimizationResult( - converged=True, - final_error=0.1, - iterations_used=iterations, - optimized_frames=list(state["poses"].keys()), - mean_reprojection_error=0.5 - ) - - def delete_flight_graph(self, flight_id: str) -> bool: - removed = False - if flight_id in self._flights_state: - del self._flights_state[flight_id] - removed = True - self._enu_origins.pop(flight_id, None) - return removed +"""Legacy import path. Phase 1 shim — code lives in core/factor_graph.py.""" +from gps_denied.core.factor_graph import ( # noqa: F401 + IFactorGraphOptimizer, + FactorGraphOptimizer, +) diff --git a/src/gps_denied/core/models.py b/src/gps_denied/core/models.py index b38a373..17e653c 100644 --- a/src/gps_denied/core/models.py +++ b/src/gps_denied/core/models.py @@ -10,8 +10,7 @@ file is available, otherwise falls back to Mock. import logging import os -from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Protocol, runtime_checkable import numpy as np @@ -31,26 +30,22 @@ def _is_jetson() -> bool: return os.path.exists("/sys/bus/platform/drivers/tegra-se-nvhost") -class IModelManager(ABC): - @abstractmethod +@runtime_checkable +class IModelManager(Protocol): def load_model(self, model_name: str, model_format: str) -> bool: - pass + ... - @abstractmethod def get_inference_engine(self, model_name: str) -> InferenceEngine: - pass + ... - @abstractmethod def optimize_to_tensorrt(self, model_name: str, onnx_path: str) -> str: - pass + ... - @abstractmethod def fallback_to_onnx(self, model_name: str) -> bool: - pass + ... - @abstractmethod def warmup_model(self, model_name: str) -> bool: - pass + ... class MockInferenceEngine(InferenceEngine): diff --git a/src/gps_denied/core/pipeline.py b/src/gps_denied/core/pipeline.py index d7379bc..5da1abd 100644 --- a/src/gps_denied/core/pipeline.py +++ b/src/gps_denied/core/pipeline.py @@ -1,227 +1,6 @@ -"""Image Input Pipeline (Component F05).""" - -import asyncio -import os -import re -from datetime import datetime, timezone - -import cv2 -import numpy as np - -from gps_denied.schemas.image import ( - ImageBatch, - ImageData, - ImageMetadata, - ProcessedBatch, - ProcessingStatus, - ValidationResult, +"""Legacy import path. Phase 1 shim — code lives in pipeline/image_input.py.""" +from gps_denied.pipeline.image_input import ( # noqa: F401 + ImageInputPipeline, + QueueFullError, + ValidationError, ) - - -class QueueFullError(Exception): - pass - -class ValidationError(Exception): - pass - - -class ImageInputPipeline: - """Manages ingestion, disk storage, and queuing of UAV image batches.""" - - def __init__(self, storage_dir: str = "image_storage", max_queue_size: int = 50): - self.storage_dir = storage_dir - # flight_id -> asyncio.Queue of ImageBatch - self._queues: dict[str, asyncio.Queue] = {} - self.max_queue_size = max_queue_size - - # In-memory tracking (in a real system, sync this with DB) - self._status: dict[str, dict] = {} - # Exact sequence → filename mapping (VO-05: no substring collision) - self._sequence_map: dict[str, dict[int, str]] = {} - - def _get_queue(self, flight_id: str) -> asyncio.Queue: - if flight_id not in self._queues: - self._queues[flight_id] = asyncio.Queue(maxsize=self.max_queue_size) - return self._queues[flight_id] - - def _init_status(self, flight_id: str): - if flight_id not in self._status: - self._status[flight_id] = { - "total_images": 0, - "processed_images": 0, - "current_sequence": 1, - } - - def validate_batch(self, batch: ImageBatch) -> ValidationResult: - """Validates batch integrity and sequence continuity.""" - errors = [] - - num_images = len(batch.images) - if num_images < 1: - errors.append("Batch is empty") - elif num_images > 100: - errors.append("Batch too large") - - if len(batch.filenames) != num_images: - errors.append("Mismatch between filenames and images count") - - # Naming convention ADxxxxxx.jpg or similar - pattern = re.compile(r"^[A-Za-z0-9_-]+\.(jpg|jpeg|png)$", re.IGNORECASE) - for fn in batch.filenames: - if not pattern.match(fn): - errors.append(f"Invalid filename: {fn}") - break - - if batch.start_sequence > batch.end_sequence: - errors.append("Start sequence greater than end sequence") - - return ValidationResult(valid=len(errors) == 0, errors=errors) - - def queue_batch(self, flight_id: str, batch: ImageBatch) -> bool: - """Queues a batch of images for processing.""" - val = self.validate_batch(batch) - if not val.valid: - raise ValidationError(f"Batch validation failed: {val.errors}") - - q = self._get_queue(flight_id) - if q.full(): - raise QueueFullError(f"Queue for flight {flight_id} is full") - - q.put_nowait(batch) - - self._init_status(flight_id) - self._status[flight_id]["total_images"] += len(batch.images) - - return True - - async def process_next_batch(self, flight_id: str) -> ProcessedBatch | None: - """Dequeues and processing the next batch.""" - q = self._get_queue(flight_id) - if q.empty(): - return None - - batch: ImageBatch = await q.get() - - processed_images = [] - for i, raw_bytes in enumerate(batch.images): - # Decode - nparr = np.frombuffer(raw_bytes, np.uint8) - img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) - - if img is None: - continue # skip corrupted - - seq = batch.start_sequence + i - fn = batch.filenames[i] - - h, w = img.shape[:2] - meta = ImageMetadata( - sequence=seq, - filename=fn, - dimensions=(w, h), - file_size=len(raw_bytes), - timestamp=datetime.now(timezone.utc), - ) - - img_data = ImageData( - flight_id=flight_id, - sequence=seq, - filename=fn, - image=img, - metadata=meta - ) - processed_images.append(img_data) - # VO-05: record exact sequence→filename mapping - self._sequence_map.setdefault(flight_id, {})[seq] = fn - - # Store to disk - self.store_images(flight_id, processed_images) - - self._status[flight_id]["processed_images"] += len(processed_images) - q.task_done() - - return ProcessedBatch( - images=processed_images, - batch_id=f"batch_{batch.batch_number}", - start_sequence=batch.start_sequence, - end_sequence=batch.end_sequence - ) - - def store_images(self, flight_id: str, images: list[ImageData]) -> bool: - """Persists images to disk.""" - flight_dir = os.path.join(self.storage_dir, flight_id) - os.makedirs(flight_dir, exist_ok=True) - - for img in images: - path = os.path.join(flight_dir, img.filename) - cv2.imwrite(path, img.image) - - return True - - def get_next_image(self, flight_id: str) -> ImageData | None: - """Gets the next image in sequence for processing.""" - self._init_status(flight_id) - seq = self._status[flight_id]["current_sequence"] - - img = self.get_image_by_sequence(flight_id, seq) - if img: - self._status[flight_id]["current_sequence"] += 1 - - return img - - def get_image_by_sequence(self, flight_id: str, sequence: int) -> ImageData | None: - """Retrieves a specific image by sequence number (exact match — VO-05).""" - flight_dir = os.path.join(self.storage_dir, flight_id) - if not os.path.exists(flight_dir): - return None - - # Prefer the exact mapping built during process_next_batch - fn = self._sequence_map.get(flight_id, {}).get(sequence) - if fn: - path = os.path.join(flight_dir, fn) - img = cv2.imread(path) - if img is not None: - h, w = img.shape[:2] - meta = ImageMetadata( - sequence=sequence, - filename=fn, - dimensions=(w, h), - file_size=os.path.getsize(path), - timestamp=datetime.now(timezone.utc), - ) - return ImageData(flight_id, sequence, fn, img, meta) - - # Fallback: scan directory for exact filename patterns - # (handles images stored before this process started) - for fn in os.listdir(flight_dir): - base, _ = os.path.splitext(fn) - # Accept only if the base name ends with exactly the padded sequence number - if base.endswith(f"{sequence:06d}") or base == str(sequence): - path = os.path.join(flight_dir, fn) - img = cv2.imread(path) - if img is not None: - h, w = img.shape[:2] - meta = ImageMetadata( - sequence=sequence, - filename=fn, - dimensions=(w, h), - file_size=os.path.getsize(path), - timestamp=datetime.now(timezone.utc), - ) - return ImageData(flight_id, sequence, fn, img, meta) - - return None - - def get_processing_status(self, flight_id: str) -> ProcessingStatus: - self._init_status(flight_id) - s = self._status[flight_id] - q = self._get_queue(flight_id) - - return ProcessingStatus( - flight_id=flight_id, - total_images=s["total_images"], - processed_images=s["processed_images"], - current_sequence=s["current_sequence"], - queued_batches=q.qsize(), - processing_rate=0.0 # mock - ) diff --git a/src/gps_denied/core/processor.py b/src/gps_denied/core/processor.py index c0ba248..6a61f85 100644 --- a/src/gps_denied/core/processor.py +++ b/src/gps_denied/core/processor.py @@ -1,599 +1,6 @@ -"""Core Flight Processor — Full Processing Pipeline (Stage 10). - -Orchestrates: ImageInputPipeline → VO → MetricRefinement → FactorGraph → SSE. -State Machine: NORMAL → LOST → RECOVERY → NORMAL. -""" - -from __future__ import annotations - -import asyncio -import logging -import time -from enum import Enum -from typing import Optional - -import numpy as np - -from gps_denied.core.eskf import ESKF -from gps_denied.core.pipeline import ImageInputPipeline -from gps_denied.core.results import ResultManager -from gps_denied.core.sse import SSEEventStreamer -from gps_denied.db.repository import FlightRepository -from gps_denied.schemas import CameraParameters, GPSPoint -from gps_denied.schemas.flight import ( - BatchMetadata, - BatchResponse, - BatchUpdateResponse, - DeleteResponse, - FlightCreateRequest, - FlightDetailResponse, - FlightResponse, - FlightStatusResponse, - ObjectGPSResponse, - UpdateResponse, - UserFixRequest, - UserFixResponse, - Waypoint, +"""Legacy import path. Phase 1 shim — code lives in pipeline/orchestrator.py.""" +from gps_denied.pipeline.orchestrator import ( # noqa: F401 + FlightProcessor, + TrackingState, + FrameResult, ) - -logger = logging.getLogger(__name__) - - -# --------------------------------------------------------------------------- -# State Machine -# --------------------------------------------------------------------------- -class TrackingState(str, Enum): - """Processing state for a flight.""" - NORMAL = "normal" - LOST = "lost" - RECOVERY = "recovery" - - -class FrameResult: - """Intermediate result of processing a single frame.""" - - def __init__(self, frame_id: int): - self.frame_id = frame_id - self.gps: Optional[GPSPoint] = None - self.confidence: float = 0.0 - self.tracking_state: TrackingState = TrackingState.NORMAL - self.vo_success: bool = False - self.alignment_success: bool = False - - -# --------------------------------------------------------------------------- -# FlightProcessor -# --------------------------------------------------------------------------- -class FlightProcessor: - """Manages business logic, background processing, and frame orchestration.""" - - def __init__( - self, - repository: FlightRepository, - streamer: SSEEventStreamer, - eskf_config=None, - ) -> None: - self.repository = repository - self.streamer = streamer - self.result_manager = ResultManager(repository, streamer) - self.pipeline = ImageInputPipeline(storage_dir=".image_storage", max_queue_size=50) - self._eskf_config = eskf_config # ESKFConfig or None → default - - # Per-flight processing state - self._flight_states: dict[str, TrackingState] = {} - self._prev_images: dict[str, np.ndarray] = {} # previous frame cache - self._flight_cameras: dict[str, CameraParameters] = {} # per-flight camera - self._altitudes: dict[str, float] = {} # per-flight altitude (m) - self._failure_counts: dict[str, int] = {} # per-flight consecutive failure counter - - # Per-flight ESKF instances (PIPE-01/07) - self._eskf: dict[str, ESKF] = {} - - # Lazy-initialised component references (set via `attach_components`) - self._vo = None # ISequentialVisualOdometry - self._gpr = None # IGlobalPlaceRecognition - self._metric = None # IMetricRefinement - self._graph = None # IFactorGraphOptimizer - self._recovery = None # IFailureRecoveryCoordinator - self._chunk_mgr = None # IRouteChunkManager - self._rotation = None # ImageRotationManager - self._satellite = None # SatelliteDataManager (PIPE-02) - self._coord = None # CoordinateTransformer (PIPE-02/06) - self._mavlink = None # MAVLinkBridge (PIPE-07) - - # ------ Dependency injection for core components --------- - def attach_components( - self, - vo=None, - gpr=None, - metric=None, - graph=None, - recovery=None, - chunk_mgr=None, - rotation=None, - satellite=None, - coord=None, - mavlink=None, - ): - """Attach pipeline components after construction (avoids circular deps).""" - self._vo = vo - self._gpr = gpr - self._metric = metric - self._graph = graph - self._recovery = recovery - self._chunk_mgr = chunk_mgr - self._rotation = rotation - self._satellite = satellite # PIPE-02: SatelliteDataManager - self._coord = coord # PIPE-02/06: CoordinateTransformer - self._mavlink = mavlink # PIPE-07: MAVLinkBridge - - # ------ ESKF lifecycle helpers ---------------------------- - def _init_eskf_for_flight( - self, flight_id: str, start_gps: GPSPoint, altitude: float - ) -> None: - """Create and initialize a per-flight ESKF instance.""" - if flight_id in self._eskf: - return - eskf = ESKF(config=self._eskf_config) - if self._coord: - try: - e, n, _ = self._coord.gps_to_enu(flight_id, start_gps) - eskf.initialize(np.array([e, n, altitude]), time.time()) - except Exception: - eskf.initialize(np.zeros(3), time.time()) - else: - eskf.initialize(np.zeros(3), time.time()) - self._eskf[flight_id] = eskf - - def _eskf_to_gps(self, flight_id: str, eskf: ESKF) -> Optional[GPSPoint]: - """Convert current ESKF ENU position to WGS84 GPS.""" - if not eskf.initialized or self._coord is None: - return None - try: - pos = eskf.position - return self._coord.enu_to_gps(flight_id, (float(pos[0]), float(pos[1]), float(pos[2]))) - except Exception: - return None - - # ========================================================= - # process_frame — central orchestration - # ========================================================= - async def process_frame( - self, - flight_id: str, - frame_id: int, - image: np.ndarray, - ) -> FrameResult: - """ - Process a single UAV frame through the full pipeline. - - State transitions: - NORMAL — VO succeeds → ESKF VO update, attempt satellite fix - LOST — VO failed → create new chunk, enter RECOVERY - RECOVERY— try GPR + MetricRefinement → if anchored, merge & return to NORMAL - - PIPE-01: VO result → eskf.update_vo → satellite match → eskf.update_satellite → MAVLink GPS_INPUT - PIPE-02: SatelliteDataManager + CoordinateTransformer wired for tile selection - PIPE-04: Consecutive failure counter wired to FailureRecoveryCoordinator - PIPE-05: ImageRotationManager initialised on first frame - PIPE-07: ESKF confidence → MAVLink fix_type via bridge.update_state - """ - result = FrameResult(frame_id) - state = self._flight_states.get(flight_id, TrackingState.NORMAL) - eskf = self._eskf.get(flight_id) - - _default_cam = CameraParameters( - focal_length=4.5, sensor_width=6.17, sensor_height=4.55, - resolution_width=640, resolution_height=480, - ) - - # ---- PIPE-05: Initialise heading tracking on first frame ---- - if self._rotation and frame_id == 0: - self._rotation.requires_rotation_sweep(flight_id) # seeds HeadingHistory - - # ---- 1. Visual Odometry (frame-to-frame) ---- - vo_ok = False - if self._vo and flight_id in self._prev_images: - try: - cam = self._flight_cameras.get(flight_id, _default_cam) - rel_pose = self._vo.compute_relative_pose( - self._prev_images[flight_id], image, cam - ) - if rel_pose and rel_pose.tracking_good: - vo_ok = True - result.vo_success = True - - if self._graph: - self._graph.add_relative_factor( - flight_id, frame_id - 1, frame_id, rel_pose, np.eye(6) - ) - - # PIPE-01: Feed VO relative displacement into ESKF - if eskf and eskf.initialized: - now = time.time() - dt_vo = max(0.01, now - (eskf.last_timestamp or now)) - eskf.update_vo(rel_pose.translation, dt_vo) - except Exception as exc: - logger.warning("VO failed for frame %d: %s", frame_id, exc) - - # Store current image for next frame - self._prev_images[flight_id] = image - - # ---- PIPE-04: Consecutive failure counter ---- - if not vo_ok and frame_id > 0: - self._failure_counts[flight_id] = self._failure_counts.get(flight_id, 0) + 1 - else: - self._failure_counts[flight_id] = 0 - - # ---- 2. State Machine transitions ---- - if state == TrackingState.NORMAL: - if not vo_ok and frame_id > 0: - state = TrackingState.LOST - logger.info("Flight %s → LOST at frame %d", flight_id, frame_id) - if self._recovery: - self._recovery.handle_tracking_lost(flight_id, frame_id) - - if state == TrackingState.LOST: - state = TrackingState.RECOVERY - - if state == TrackingState.RECOVERY: - recovered = False - if self._recovery and self._chunk_mgr: - active_chunk = self._chunk_mgr.get_active_chunk(flight_id) - if active_chunk: - recovered = self._recovery.process_chunk_recovery( - flight_id, active_chunk.chunk_id, [image] - ) - if recovered: - state = TrackingState.NORMAL - result.alignment_success = True - # PIPE-04: Reset failure count on successful recovery - self._failure_counts[flight_id] = 0 - logger.info("Flight %s recovered → NORMAL at frame %d", flight_id, frame_id) - - # ---- 3. Satellite position fix (PIPE-01/02) ---- - if state == TrackingState.NORMAL and self._metric: - sat_tile: Optional[np.ndarray] = None - tile_bounds = None - - # PIPE-02: Prefer real SatelliteDataManager tiles (ESKF ±3σ selection) - if self._satellite and eskf and eskf.initialized: - gps_est = self._eskf_to_gps(flight_id, eskf) - if gps_est: - cov = eskf.covariance - sigma_h = float( - np.sqrt(np.trace(cov[0:3, 0:3]) / 3.0) - ) if cov is not None else 30.0 - sigma_h = max(sigma_h, 5.0) - try: - tile_result = await asyncio.get_event_loop().run_in_executor( - None, - self._satellite.fetch_tiles_for_position, - gps_est, sigma_h, 18, - ) - if tile_result: - sat_tile, tile_bounds = tile_result - except Exception as exc: - logger.debug("Satellite tile fetch failed: %s", exc) - - # Fallback: GPR candidate tile (mock image, real bounds) - if sat_tile is None and self._gpr: - try: - candidates = self._gpr.retrieve_candidate_tiles(image, top_k=1) - if candidates: - sat_tile = np.zeros((256, 256, 3), dtype=np.uint8) - tile_bounds = candidates[0].bounds - except Exception as exc: - logger.debug("GPR tile fallback failed: %s", exc) - - if sat_tile is not None and tile_bounds is not None: - try: - align = self._metric.align_to_satellite(image, sat_tile, tile_bounds) - if align and align.matched: - result.gps = align.gps_center - result.confidence = align.confidence - result.alignment_success = True - - if self._graph: - self._graph.add_absolute_factor( - flight_id, frame_id, - align.gps_center, np.eye(6), - is_user_anchor=False, - ) - - # PIPE-01: ESKF satellite update — noise from RANSAC confidence - if eskf and eskf.initialized and self._coord: - try: - e, n, _ = self._coord.gps_to_enu(flight_id, align.gps_center) - alt = self._altitudes.get(flight_id, 100.0) - pos_enu = np.array([e, n, alt]) - noise_m = 5.0 + 15.0 * (1.0 - float(align.confidence)) - eskf.update_satellite(pos_enu, noise_m) - except Exception as exc: - logger.debug("ESKF satellite update failed: %s", exc) - except Exception as exc: - logger.warning("Metric alignment failed at frame %d: %s", frame_id, exc) - - # ---- 4. Graph optimization (incremental) ---- - if self._graph: - opt_result = self._graph.optimize(flight_id, iterations=5) - logger.debug( - "Optimization: converged=%s, error=%.4f", - opt_result.converged, opt_result.final_error, - ) - - # ---- PIPE-07: Push ESKF state → MAVLink GPS_INPUT ---- - if self._mavlink and eskf and eskf.initialized: - try: - eskf_state = eskf.get_state() - alt = self._altitudes.get(flight_id, 100.0) - self._mavlink.update_state(eskf_state, altitude_m=alt) - except Exception as exc: - logger.debug("MAVLink state push failed: %s", exc) - - # ---- 5. Publish via SSE ---- - result.tracking_state = state - self._flight_states[flight_id] = state - await self._publish_frame_result(flight_id, result) - return result - - async def _publish_frame_result(self, flight_id: str, result: FrameResult): - """Emit SSE event for processed frame.""" - event_data = { - "frame_id": result.frame_id, - "tracking_state": result.tracking_state.value, - "vo_success": result.vo_success, - "alignment_success": result.alignment_success, - "confidence": result.confidence, - } - if result.gps: - event_data["lat"] = result.gps.lat - event_data["lon"] = result.gps.lon - - await self.streamer.push_event( - flight_id, event_type="frame_result", data=event_data - ) - - # ========================================================= - # Existing CRUD / REST helpers (unchanged from Stage 3-4) - # ========================================================= - async def create_flight(self, req: FlightCreateRequest) -> FlightResponse: - flight = await self.repository.insert_flight( - name=req.name, - description=req.description, - start_lat=req.start_gps.lat, - start_lon=req.start_gps.lon, - altitude=req.altitude, - camera_params=req.camera_params.model_dump(), - ) - # P0#2: Store camera params for process_frame VO calls - self._flight_cameras[flight.id] = req.camera_params - - for poly in req.geofences.polygons: - await self.repository.insert_geofence( - flight.id, - nw_lat=poly.north_west.lat, - nw_lon=poly.north_west.lon, - se_lat=poly.south_east.lat, - se_lon=poly.south_east.lon, - ) - for w in req.rough_waypoints: - await self.repository.insert_waypoint(flight.id, lat=w.lat, lon=w.lon) - - # Store per-flight altitude for ESKF/pixel projection - self._altitudes[flight.id] = req.altitude or 100.0 - - # PIPE-02: Set ENU origin and initialise ESKF for this flight - if self._coord: - self._coord.set_enu_origin(flight.id, req.start_gps) - self._init_eskf_for_flight(flight.id, req.start_gps, req.altitude or 100.0) - - # Start MAVLink bridge for this flight (origin required for GPS_INPUT) - if self._mavlink and not self._mavlink._running: - try: - asyncio.create_task(self._mavlink.start(req.start_gps)) - except Exception as exc: - logger.warning("MAVLink bridge start failed: %s", exc) - - return FlightResponse( - flight_id=flight.id, - status="prefetching", - message="Flight created and prefetching started.", - created_at=flight.created_at, - ) - - async def get_flight(self, flight_id: str) -> FlightDetailResponse | None: - flight = await self.repository.get_flight(flight_id) - if not flight: - return None - wps = await self.repository.get_waypoints(flight_id) - state = await self.repository.load_flight_state(flight_id) - - waypoints = [ - Waypoint( - id=w.id, - lat=w.lat, - lon=w.lon, - altitude=w.altitude, - confidence=w.confidence, - timestamp=w.timestamp, - refined=w.refined, - ) - for w in wps - ] - - status = state.status if state else "unknown" - frames_processed = state.frames_processed if state else 0 - frames_total = state.frames_total if state else 0 - - from gps_denied.schemas import Geofences - - return FlightDetailResponse( - flight_id=flight.id, - name=flight.name, - description=flight.description, - start_gps=GPSPoint(lat=flight.start_lat, lon=flight.start_lon), - waypoints=waypoints, - geofences=Geofences(polygons=[]), - camera_params=flight.camera_params, - altitude=flight.altitude, - status=status, - frames_processed=frames_processed, - frames_total=frames_total, - created_at=flight.created_at, - updated_at=flight.updated_at, - ) - - async def delete_flight(self, flight_id: str) -> DeleteResponse: - deleted = await self.repository.delete_flight(flight_id) - # P0#1: Cleanup in-memory state to prevent memory leaks - self._cleanup_flight(flight_id) - return DeleteResponse(deleted=deleted, flight_id=flight_id) - - def _cleanup_flight(self, flight_id: str) -> None: - """Remove all in-memory state for a flight (prevents memory leaks).""" - self._prev_images.pop(flight_id, None) - self._flight_states.pop(flight_id, None) - self._flight_cameras.pop(flight_id, None) - self._altitudes.pop(flight_id, None) - self._failure_counts.pop(flight_id, None) - self._eskf.pop(flight_id, None) - if self._graph: - self._graph.delete_flight_graph(flight_id) - - async def update_waypoint( - self, flight_id: str, waypoint_id: str, waypoint: Waypoint - ) -> UpdateResponse: - ok = await self.repository.update_waypoint( - flight_id, - waypoint_id, - lat=waypoint.lat, - lon=waypoint.lon, - altitude=waypoint.altitude, - confidence=waypoint.confidence, - refined=waypoint.refined, - ) - return UpdateResponse(updated=ok, waypoint_id=waypoint_id) - - async def batch_update_waypoints( - self, flight_id: str, waypoints: list[Waypoint] - ) -> BatchUpdateResponse: - failed = [] - updated = 0 - for wp in waypoints: - ok = await self.repository.update_waypoint( - flight_id, - wp.id, - lat=wp.lat, - lon=wp.lon, - altitude=wp.altitude, - confidence=wp.confidence, - refined=wp.refined, - ) - if ok: - updated += 1 - else: - failed.append(wp.id) - return BatchUpdateResponse( - success=(len(failed) == 0), updated_count=updated, failed_ids=failed - ) - - async def queue_images( - self, flight_id: str, metadata: BatchMetadata, file_count: int - ) -> BatchResponse: - state = await self.repository.load_flight_state(flight_id) - if state: - total = state.frames_total + file_count - await self.repository.save_flight_state( - flight_id, frames_total=total, status="processing" - ) - - next_seq = metadata.end_sequence + 1 - seqs = list(range(metadata.start_sequence, metadata.end_sequence + 1)) - return BatchResponse( - accepted=True, - sequences=seqs, - next_expected=next_seq, - message=f"Queued {file_count} images.", - ) - - async def handle_user_fix( - self, flight_id: str, req: UserFixRequest - ) -> UserFixResponse: - await self.repository.save_flight_state( - flight_id, blocked=False, status="processing" - ) - - # Inject operator position into ESKF with high uncertainty (500m) - eskf = self._eskf.get(flight_id) - if eskf and eskf.initialized and self._coord: - try: - e, n, _ = self._coord.gps_to_enu(flight_id, req.satellite_gps) - alt = self._altitudes.get(flight_id, 100.0) - eskf.update_satellite(np.array([e, n, alt]), noise_meters=500.0) - self._failure_counts[flight_id] = 0 - logger.info("User fix applied for %s: %s", flight_id, req.satellite_gps) - except Exception as exc: - logger.warning("User fix ESKF injection failed: %s", exc) - - return UserFixResponse( - accepted=True, processing_resumed=True, message="Fix applied." - ) - - async def get_flight_status(self, flight_id: str) -> FlightStatusResponse | None: - state = await self.repository.load_flight_state(flight_id) - if not state: - return None - return FlightStatusResponse( - status=state.status, - frames_processed=state.frames_processed, - frames_total=state.frames_total, - current_frame=state.current_frame, - current_heading=None, - blocked=state.blocked, - search_grid_size=state.search_grid_size, - created_at=state.created_at, - updated_at=state.updated_at, - ) - - async def convert_object_to_gps( - self, flight_id: str, frame_id: int, pixel: tuple[float, float] - ) -> ObjectGPSResponse: - # PIPE-06: Use real CoordinateTransformer + ESKF pose for ray-ground projection - gps: Optional[GPSPoint] = None - eskf = self._eskf.get(flight_id) - if self._coord and eskf and eskf.initialized: - pos = eskf.position - quat = eskf.quaternion - cam = self._flight_cameras.get(flight_id, CameraParameters( - focal_length=4.5, sensor_width=6.17, sensor_height=4.55, - resolution_width=640, resolution_height=480, - )) - alt = self._altitudes.get(flight_id, 100.0) - try: - gps = self._coord.pixel_to_gps( - flight_id, - pixel, - frame_pose={"position": pos}, - camera_params=cam, - altitude=float(alt), - quaternion=quat, - ) - except Exception as exc: - logger.debug("pixel_to_gps failed: %s", exc) - - # Fallback: return ESKF position projected to ground (no pixel shift) - if gps is None and eskf: - gps = self._eskf_to_gps(flight_id, eskf) - - return ObjectGPSResponse( - gps=gps or GPSPoint(lat=0.0, lon=0.0), - accuracy_meters=5.0, - frame_id=frame_id, - pixel=pixel, - ) - - async def stream_events(self, flight_id: str, client_id: str): - """Async generator for SSE stream.""" - async for event in self.streamer.stream_generator(flight_id, client_id): - yield event diff --git a/src/gps_denied/core/recovery.py b/src/gps_denied/core/recovery.py index 0eca2d6..436f3d2 100644 --- a/src/gps_denied/core/recovery.py +++ b/src/gps_denied/core/recovery.py @@ -1,8 +1,7 @@ """Failure Recovery Coordinator (Component F11).""" import logging -from abc import ABC, abstractmethod -from typing import List +from typing import List, Protocol, runtime_checkable import numpy as np @@ -14,14 +13,13 @@ from gps_denied.schemas.chunk import ChunkStatus logger = logging.getLogger(__name__) -class IFailureRecoveryCoordinator(ABC): - @abstractmethod +@runtime_checkable +class IFailureRecoveryCoordinator(Protocol): def handle_tracking_lost(self, flight_id: str, current_frame_id: int) -> bool: - pass + ... - @abstractmethod def process_chunk_recovery(self, flight_id: str, chunk_id: str, images: List[np.ndarray]) -> bool: - pass + ... class FailureRecoveryCoordinator(IFailureRecoveryCoordinator): diff --git a/src/gps_denied/core/results.py b/src/gps_denied/core/results.py index 99e3bca..c2361c4 100644 --- a/src/gps_denied/core/results.py +++ b/src/gps_denied/core/results.py @@ -1,73 +1,4 @@ -"""Result Manager (Component F14).""" - -from __future__ import annotations - -from datetime import datetime - -from gps_denied.core.sse import SSEEventStreamer -from gps_denied.db.repository import FlightRepository -from gps_denied.schemas import GPSPoint -from gps_denied.schemas.events import FrameProcessedEvent - - -class ResultManager: - """Result consistency and publishing.""" - - def __init__(self, repo: FlightRepository, sse: SSEEventStreamer) -> None: - self.repo = repo - self.sse = sse - - async def update_frame_result( - self, - flight_id: str, - frame_id: int, - gps_lat: float, - gps_lon: float, - altitude: float, - heading: float, - confidence: float, - timestamp: datetime, - refined: bool = False, - ) -> bool: - """Atomic DB update + SSE event publish.""" - - # 1. Update DB (in the repository these are auto-committing via flush, - # but normally F03 would wrap in a single transaction). - await self.repo.save_frame_result( - flight_id, - frame_id=frame_id, - gps_lat=gps_lat, - gps_lon=gps_lon, - altitude=altitude, - heading=heading, - confidence=confidence, - refined=refined, - ) - - # Wait, the spec also wants Waypoints to be updated. - # But image frames != waypoints. Waypoints are the planned route. - # Actually in the spec it says: "Updates waypoint in waypoints table." - # This implies updating the closest waypoint or a generated waypoint path. - # We will follow the simplest form for now: update the waypoint if there is one corresponding. - # Let's say we update a waypoint with id "wp_{frame_id}" for now if we know how they map, - # or we just skip unless specified. - - # 2. Trigger SSE event - evt = FrameProcessedEvent( - frame_id=frame_id, - gps=GPSPoint(lat=gps_lat, lon=gps_lon), - altitude=altitude, - confidence=confidence, - heading=heading, - timestamp=timestamp, - ) - if refined: - self.sse.send_refinement(flight_id, evt) - else: - self.sse.send_frame_result(flight_id, evt) - - return True - - async def publish_waypoint_update(self, flight_id: str, frame_id: int) -> bool: - # Just delegates to SSE for waypoint updates, which is basically the frame result for UI - pass +"""Legacy import path. Phase 1 shim — code lives in pipeline/result_manager.py.""" +from gps_denied.pipeline.result_manager import ( # noqa: F401 + ResultManager, +) diff --git a/src/gps_denied/core/rotation.py b/src/gps_denied/core/rotation.py index b4dfe9d..d75e495 100644 --- a/src/gps_denied/core/rotation.py +++ b/src/gps_denied/core/rotation.py @@ -2,8 +2,8 @@ import dataclasses import math -from abc import ABC, abstractmethod from datetime import datetime +from typing import Protocol, runtime_checkable import cv2 import numpy as np @@ -12,14 +12,14 @@ from gps_denied.schemas.rotation import HeadingHistory, RotationResult from gps_denied.schemas.satellite import TileBounds -class IImageMatcher(ABC): +@runtime_checkable +class IImageMatcher(Protocol): """Dependency injection interface for Metric Refinement.""" - @abstractmethod def align_to_satellite( self, uav_image: np.ndarray, satellite_tile: np.ndarray, tile_bounds: TileBounds, ) -> RotationResult: - pass + ... class ImageRotationManager: diff --git a/src/gps_denied/core/sse.py b/src/gps_denied/core/sse.py index 10a0699..a822f30 100644 --- a/src/gps_denied/core/sse.py +++ b/src/gps_denied/core/sse.py @@ -1,164 +1,4 @@ -"""SSE Event Streamer (Component F15).""" - -from __future__ import annotations - -import asyncio -import json -from collections import defaultdict - -from gps_denied.schemas.events import ( - FlightCompletedEvent, - FrameProcessedEvent, - SearchExpandedEvent, - SSEEventType, - SSEMessage, - UserInputNeededEvent, +"""Legacy import path. Phase 1 shim — code lives in pipeline/sse_streamer.py.""" +from gps_denied.pipeline.sse_streamer import ( # noqa: F401 + SSEEventStreamer, ) - - -class SSEEventStreamer: - """Manages real-time SSE connections and event broadcasting.""" - - def __init__(self) -> None: - # Map: flight_id -> Dict[client_id, asyncio.Queue] - self._streams: dict[str, dict[str, asyncio.Queue[SSEMessage | None]]] = defaultdict(dict) - - def create_stream(self, flight_id: str, client_id: str) -> asyncio.Queue[SSEMessage | None]: - """Create a new event queue for a client.""" - q: asyncio.Queue[SSEMessage | None] = asyncio.Queue() - self._streams[flight_id][client_id] = q - return q - - def close_stream(self, flight_id: str, client_id: str) -> None: - """Close a client stream by putting a sentinel and removing the queue.""" - if flight_id in self._streams and client_id in self._streams[flight_id]: - q = self._streams[flight_id].pop(client_id) - if not self._streams[flight_id]: - del self._streams[flight_id] - # Put None to signal generator exit - try: - q.put_nowait(None) - except asyncio.QueueFull: - pass - - def get_active_connections(self, flight_id: str) -> int: - return len(self._streams.get(flight_id, {})) - - def _broadcast(self, flight_id: str, msg: SSEMessage) -> bool: - """Broadcast a message to all clients subscribed to flight_id.""" - if flight_id not in self._streams or not self._streams[flight_id]: - return False - - for q in self._streams[flight_id].values(): - try: - q.put_nowait(msg) - except asyncio.QueueFull: - pass # Drop if queue is full rather than blocking - return True - - # ── Business Event Senders ──────────────────────────────────────────────── - - def send_frame_result(self, flight_id: str, event_data: FrameProcessedEvent) -> bool: - msg = SSEMessage( - event=SSEEventType.FRAME_PROCESSED, - data=event_data.model_dump(mode="json"), - id=f"frame_{event_data.frame_id}", - ) - return self._broadcast(flight_id, msg) - - def send_refinement(self, flight_id: str, event_data: FrameProcessedEvent) -> bool: - msg = SSEMessage( - event=SSEEventType.FRAME_REFINED, - data=event_data.model_dump(mode="json"), - id=f"refine_{event_data.frame_id}", - ) - return self._broadcast(flight_id, msg) - - def send_search_progress(self, flight_id: str, event_data: SearchExpandedEvent) -> bool: - msg = SSEMessage( - event=SSEEventType.SEARCH_EXPANDED, - data=event_data.model_dump(mode="json"), - ) - return self._broadcast(flight_id, msg) - - def send_user_input_request(self, flight_id: str, event_data: UserInputNeededEvent) -> bool: - msg = SSEMessage( - event=SSEEventType.USER_INPUT_NEEDED, - data=event_data.model_dump(mode="json"), - ) - return self._broadcast(flight_id, msg) - - def send_flight_completed(self, flight_id: str, event_data: FlightCompletedEvent) -> bool: - msg = SSEMessage( - event=SSEEventType.FLIGHT_COMPLETED, - data=event_data.model_dump(mode="json"), - ) - return self._broadcast(flight_id, msg) - - def send_heartbeat(self, flight_id: str) -> bool: - # sse_starlette uses empty string or comment for heartbeat, - # but we can just send an SSEMessage object that parses as empty event - if flight_id not in self._streams: - return False - - # Manually sending a comment via the generator is tricky with strict SSEMessage schema - # but we'll handle this in the stream generator directly - return True - - # ── Generic event dispatcher (used by processor.process_frame) ────────── - - async def push_event(self, flight_id: str, event_type: str, data: dict) -> None: - """Dispatch a generic event to all clients for a flight. - - Maps event_type strings to typed SSE events: - "frame_result" → FrameProcessedEvent - "refinement" → FrameProcessedEvent (refined) - Other → raw broadcast via SSEMessage - """ - if event_type == "frame_result": - evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data - self.send_frame_result(flight_id, evt) - elif event_type == "refinement": - evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data - self.send_refinement(flight_id, evt) - else: - msg = SSEMessage( - event=SSEEventType.FRAME_PROCESSED, - data=data, - id=str(data.get("frame_id", "")), - ) - self._broadcast(flight_id, msg) - - # ── Stream Generator ────────────────────────────────────────────────────── - - async def stream_generator(self, flight_id: str, client_id: str): - """Yields dicts for sse_starlette EventSourceResponse.""" - q = self.create_stream(flight_id, client_id) - - # Send an immediate connection accepted ping - yield {"event": "connected", "data": "connected"} - - try: - while True: - # Wait for next event or send heartbeat every 15s - try: - msg = await asyncio.wait_for(q.get(), timeout=15.0) - if msg is None: - # Sentinel for clean shutdown - break - - # Yield dict format for sse_starlette - yield { - "event": msg.event.value, - "id": msg.id if msg.id else "", - "data": json.dumps(msg.data) - } - - except asyncio.TimeoutError: - # Heartbeat format for sse_starlette (empty string generates a comment) - yield {"event": "heartbeat", "data": "ping"} - - except asyncio.CancelledError: - pass # Client disconnected - finally: - self.close_stream(flight_id, client_id) diff --git a/src/gps_denied/pipeline/__init__.py b/src/gps_denied/pipeline/__init__.py new file mode 100644 index 0000000..71598bb --- /dev/null +++ b/src/gps_denied/pipeline/__init__.py @@ -0,0 +1,15 @@ +"""Pipeline package: orchestrator + IO + composition root.""" +from .orchestrator import FlightProcessor +from .image_input import ImageInputPipeline +from .result_manager import ResultManager +from .sse_streamer import SSEEventStreamer + +Pipeline = FlightProcessor + +__all__ = [ + "FlightProcessor", + "Pipeline", + "ImageInputPipeline", + "ResultManager", + "SSEEventStreamer", +] diff --git a/src/gps_denied/pipeline/image_input.py b/src/gps_denied/pipeline/image_input.py new file mode 100644 index 0000000..d7379bc --- /dev/null +++ b/src/gps_denied/pipeline/image_input.py @@ -0,0 +1,227 @@ +"""Image Input Pipeline (Component F05).""" + +import asyncio +import os +import re +from datetime import datetime, timezone + +import cv2 +import numpy as np + +from gps_denied.schemas.image import ( + ImageBatch, + ImageData, + ImageMetadata, + ProcessedBatch, + ProcessingStatus, + ValidationResult, +) + + +class QueueFullError(Exception): + pass + +class ValidationError(Exception): + pass + + +class ImageInputPipeline: + """Manages ingestion, disk storage, and queuing of UAV image batches.""" + + def __init__(self, storage_dir: str = "image_storage", max_queue_size: int = 50): + self.storage_dir = storage_dir + # flight_id -> asyncio.Queue of ImageBatch + self._queues: dict[str, asyncio.Queue] = {} + self.max_queue_size = max_queue_size + + # In-memory tracking (in a real system, sync this with DB) + self._status: dict[str, dict] = {} + # Exact sequence → filename mapping (VO-05: no substring collision) + self._sequence_map: dict[str, dict[int, str]] = {} + + def _get_queue(self, flight_id: str) -> asyncio.Queue: + if flight_id not in self._queues: + self._queues[flight_id] = asyncio.Queue(maxsize=self.max_queue_size) + return self._queues[flight_id] + + def _init_status(self, flight_id: str): + if flight_id not in self._status: + self._status[flight_id] = { + "total_images": 0, + "processed_images": 0, + "current_sequence": 1, + } + + def validate_batch(self, batch: ImageBatch) -> ValidationResult: + """Validates batch integrity and sequence continuity.""" + errors = [] + + num_images = len(batch.images) + if num_images < 1: + errors.append("Batch is empty") + elif num_images > 100: + errors.append("Batch too large") + + if len(batch.filenames) != num_images: + errors.append("Mismatch between filenames and images count") + + # Naming convention ADxxxxxx.jpg or similar + pattern = re.compile(r"^[A-Za-z0-9_-]+\.(jpg|jpeg|png)$", re.IGNORECASE) + for fn in batch.filenames: + if not pattern.match(fn): + errors.append(f"Invalid filename: {fn}") + break + + if batch.start_sequence > batch.end_sequence: + errors.append("Start sequence greater than end sequence") + + return ValidationResult(valid=len(errors) == 0, errors=errors) + + def queue_batch(self, flight_id: str, batch: ImageBatch) -> bool: + """Queues a batch of images for processing.""" + val = self.validate_batch(batch) + if not val.valid: + raise ValidationError(f"Batch validation failed: {val.errors}") + + q = self._get_queue(flight_id) + if q.full(): + raise QueueFullError(f"Queue for flight {flight_id} is full") + + q.put_nowait(batch) + + self._init_status(flight_id) + self._status[flight_id]["total_images"] += len(batch.images) + + return True + + async def process_next_batch(self, flight_id: str) -> ProcessedBatch | None: + """Dequeues and processing the next batch.""" + q = self._get_queue(flight_id) + if q.empty(): + return None + + batch: ImageBatch = await q.get() + + processed_images = [] + for i, raw_bytes in enumerate(batch.images): + # Decode + nparr = np.frombuffer(raw_bytes, np.uint8) + img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + if img is None: + continue # skip corrupted + + seq = batch.start_sequence + i + fn = batch.filenames[i] + + h, w = img.shape[:2] + meta = ImageMetadata( + sequence=seq, + filename=fn, + dimensions=(w, h), + file_size=len(raw_bytes), + timestamp=datetime.now(timezone.utc), + ) + + img_data = ImageData( + flight_id=flight_id, + sequence=seq, + filename=fn, + image=img, + metadata=meta + ) + processed_images.append(img_data) + # VO-05: record exact sequence→filename mapping + self._sequence_map.setdefault(flight_id, {})[seq] = fn + + # Store to disk + self.store_images(flight_id, processed_images) + + self._status[flight_id]["processed_images"] += len(processed_images) + q.task_done() + + return ProcessedBatch( + images=processed_images, + batch_id=f"batch_{batch.batch_number}", + start_sequence=batch.start_sequence, + end_sequence=batch.end_sequence + ) + + def store_images(self, flight_id: str, images: list[ImageData]) -> bool: + """Persists images to disk.""" + flight_dir = os.path.join(self.storage_dir, flight_id) + os.makedirs(flight_dir, exist_ok=True) + + for img in images: + path = os.path.join(flight_dir, img.filename) + cv2.imwrite(path, img.image) + + return True + + def get_next_image(self, flight_id: str) -> ImageData | None: + """Gets the next image in sequence for processing.""" + self._init_status(flight_id) + seq = self._status[flight_id]["current_sequence"] + + img = self.get_image_by_sequence(flight_id, seq) + if img: + self._status[flight_id]["current_sequence"] += 1 + + return img + + def get_image_by_sequence(self, flight_id: str, sequence: int) -> ImageData | None: + """Retrieves a specific image by sequence number (exact match — VO-05).""" + flight_dir = os.path.join(self.storage_dir, flight_id) + if not os.path.exists(flight_dir): + return None + + # Prefer the exact mapping built during process_next_batch + fn = self._sequence_map.get(flight_id, {}).get(sequence) + if fn: + path = os.path.join(flight_dir, fn) + img = cv2.imread(path) + if img is not None: + h, w = img.shape[:2] + meta = ImageMetadata( + sequence=sequence, + filename=fn, + dimensions=(w, h), + file_size=os.path.getsize(path), + timestamp=datetime.now(timezone.utc), + ) + return ImageData(flight_id, sequence, fn, img, meta) + + # Fallback: scan directory for exact filename patterns + # (handles images stored before this process started) + for fn in os.listdir(flight_dir): + base, _ = os.path.splitext(fn) + # Accept only if the base name ends with exactly the padded sequence number + if base.endswith(f"{sequence:06d}") or base == str(sequence): + path = os.path.join(flight_dir, fn) + img = cv2.imread(path) + if img is not None: + h, w = img.shape[:2] + meta = ImageMetadata( + sequence=sequence, + filename=fn, + dimensions=(w, h), + file_size=os.path.getsize(path), + timestamp=datetime.now(timezone.utc), + ) + return ImageData(flight_id, sequence, fn, img, meta) + + return None + + def get_processing_status(self, flight_id: str) -> ProcessingStatus: + self._init_status(flight_id) + s = self._status[flight_id] + q = self._get_queue(flight_id) + + return ProcessingStatus( + flight_id=flight_id, + total_images=s["total_images"], + processed_images=s["processed_images"], + current_sequence=s["current_sequence"], + queued_batches=q.qsize(), + processing_rate=0.0 # mock + ) diff --git a/src/gps_denied/pipeline/orchestrator.py b/src/gps_denied/pipeline/orchestrator.py new file mode 100644 index 0000000..0038996 --- /dev/null +++ b/src/gps_denied/pipeline/orchestrator.py @@ -0,0 +1,599 @@ +"""Core Flight Processor — Full Processing Pipeline (Stage 10). + +Orchestrates: ImageInputPipeline → VO → MetricRefinement → FactorGraph → SSE. +State Machine: NORMAL → LOST → RECOVERY → NORMAL. +""" + +from __future__ import annotations + +import asyncio +import logging +import time +from enum import Enum +from typing import Optional + +import numpy as np + +from gps_denied.core.eskf import ESKF +from gps_denied.pipeline.image_input import ImageInputPipeline +from gps_denied.pipeline.result_manager import ResultManager +from gps_denied.pipeline.sse_streamer import SSEEventStreamer +from gps_denied.db.repository import FlightRepository +from gps_denied.schemas import CameraParameters, GPSPoint +from gps_denied.schemas.flight import ( + BatchMetadata, + BatchResponse, + BatchUpdateResponse, + DeleteResponse, + FlightCreateRequest, + FlightDetailResponse, + FlightResponse, + FlightStatusResponse, + ObjectGPSResponse, + UpdateResponse, + UserFixRequest, + UserFixResponse, + Waypoint, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# State Machine +# --------------------------------------------------------------------------- +class TrackingState(str, Enum): + """Processing state for a flight.""" + NORMAL = "normal" + LOST = "lost" + RECOVERY = "recovery" + + +class FrameResult: + """Intermediate result of processing a single frame.""" + + def __init__(self, frame_id: int): + self.frame_id = frame_id + self.gps: Optional[GPSPoint] = None + self.confidence: float = 0.0 + self.tracking_state: TrackingState = TrackingState.NORMAL + self.vo_success: bool = False + self.alignment_success: bool = False + + +# --------------------------------------------------------------------------- +# FlightProcessor +# --------------------------------------------------------------------------- +class FlightProcessor: + """Manages business logic, background processing, and frame orchestration.""" + + def __init__( + self, + repository: FlightRepository, + streamer: SSEEventStreamer, + eskf_config=None, + ) -> None: + self.repository = repository + self.streamer = streamer + self.result_manager = ResultManager(repository, streamer) + self.pipeline = ImageInputPipeline(storage_dir=".image_storage", max_queue_size=50) + self._eskf_config = eskf_config # ESKFConfig or None → default + + # Per-flight processing state + self._flight_states: dict[str, TrackingState] = {} + self._prev_images: dict[str, np.ndarray] = {} # previous frame cache + self._flight_cameras: dict[str, CameraParameters] = {} # per-flight camera + self._altitudes: dict[str, float] = {} # per-flight altitude (m) + self._failure_counts: dict[str, int] = {} # per-flight consecutive failure counter + + # Per-flight ESKF instances (PIPE-01/07) + self._eskf: dict[str, ESKF] = {} + + # Lazy-initialised component references (set via `attach_components`) + self._vo = None # ISequentialVisualOdometry + self._gpr = None # IGlobalPlaceRecognition + self._metric = None # IMetricRefinement + self._graph = None # IFactorGraphOptimizer + self._recovery = None # IFailureRecoveryCoordinator + self._chunk_mgr = None # IRouteChunkManager + self._rotation = None # ImageRotationManager + self._satellite = None # SatelliteDataManager (PIPE-02) + self._coord = None # CoordinateTransformer (PIPE-02/06) + self._mavlink = None # MAVLinkBridge (PIPE-07) + + # ------ Dependency injection for core components --------- + def attach_components( + self, + vo=None, + gpr=None, + metric=None, + graph=None, + recovery=None, + chunk_mgr=None, + rotation=None, + satellite=None, + coord=None, + mavlink=None, + ): + """Attach pipeline components after construction (avoids circular deps).""" + self._vo = vo + self._gpr = gpr + self._metric = metric + self._graph = graph + self._recovery = recovery + self._chunk_mgr = chunk_mgr + self._rotation = rotation + self._satellite = satellite # PIPE-02: SatelliteDataManager + self._coord = coord # PIPE-02/06: CoordinateTransformer + self._mavlink = mavlink # PIPE-07: MAVLinkBridge + + # ------ ESKF lifecycle helpers ---------------------------- + def _init_eskf_for_flight( + self, flight_id: str, start_gps: GPSPoint, altitude: float + ) -> None: + """Create and initialize a per-flight ESKF instance.""" + if flight_id in self._eskf: + return + eskf = ESKF(config=self._eskf_config) + if self._coord: + try: + e, n, _ = self._coord.gps_to_enu(flight_id, start_gps) + eskf.initialize(np.array([e, n, altitude]), time.time()) + except Exception: + eskf.initialize(np.zeros(3), time.time()) + else: + eskf.initialize(np.zeros(3), time.time()) + self._eskf[flight_id] = eskf + + def _eskf_to_gps(self, flight_id: str, eskf: ESKF) -> Optional[GPSPoint]: + """Convert current ESKF ENU position to WGS84 GPS.""" + if not eskf.initialized or self._coord is None: + return None + try: + pos = eskf.position + return self._coord.enu_to_gps(flight_id, (float(pos[0]), float(pos[1]), float(pos[2]))) + except Exception: + return None + + # ========================================================= + # process_frame — central orchestration + # ========================================================= + async def process_frame( + self, + flight_id: str, + frame_id: int, + image: np.ndarray, + ) -> FrameResult: + """ + Process a single UAV frame through the full pipeline. + + State transitions: + NORMAL — VO succeeds → ESKF VO update, attempt satellite fix + LOST — VO failed → create new chunk, enter RECOVERY + RECOVERY— try GPR + MetricRefinement → if anchored, merge & return to NORMAL + + PIPE-01: VO result → eskf.update_vo → satellite match → eskf.update_satellite → MAVLink GPS_INPUT + PIPE-02: SatelliteDataManager + CoordinateTransformer wired for tile selection + PIPE-04: Consecutive failure counter wired to FailureRecoveryCoordinator + PIPE-05: ImageRotationManager initialised on first frame + PIPE-07: ESKF confidence → MAVLink fix_type via bridge.update_state + """ + result = FrameResult(frame_id) + state = self._flight_states.get(flight_id, TrackingState.NORMAL) + eskf = self._eskf.get(flight_id) + + _default_cam = CameraParameters( + focal_length=4.5, sensor_width=6.17, sensor_height=4.55, + resolution_width=640, resolution_height=480, + ) + + # ---- PIPE-05: Initialise heading tracking on first frame ---- + if self._rotation and frame_id == 0: + self._rotation.requires_rotation_sweep(flight_id) # seeds HeadingHistory + + # ---- 1. Visual Odometry (frame-to-frame) ---- + vo_ok = False + if self._vo and flight_id in self._prev_images: + try: + cam = self._flight_cameras.get(flight_id, _default_cam) + rel_pose = self._vo.compute_relative_pose( + self._prev_images[flight_id], image, cam + ) + if rel_pose and rel_pose.tracking_good: + vo_ok = True + result.vo_success = True + + if self._graph: + self._graph.add_relative_factor( + flight_id, frame_id - 1, frame_id, rel_pose, np.eye(6) + ) + + # PIPE-01: Feed VO relative displacement into ESKF + if eskf and eskf.initialized: + now = time.time() + dt_vo = max(0.01, now - (eskf.last_timestamp or now)) + eskf.update_vo(rel_pose.translation, dt_vo) + except Exception as exc: + logger.warning("VO failed for frame %d: %s", frame_id, exc) + + # Store current image for next frame + self._prev_images[flight_id] = image + + # ---- PIPE-04: Consecutive failure counter ---- + if not vo_ok and frame_id > 0: + self._failure_counts[flight_id] = self._failure_counts.get(flight_id, 0) + 1 + else: + self._failure_counts[flight_id] = 0 + + # ---- 2. State Machine transitions ---- + if state == TrackingState.NORMAL: + if not vo_ok and frame_id > 0: + state = TrackingState.LOST + logger.info("Flight %s → LOST at frame %d", flight_id, frame_id) + if self._recovery: + self._recovery.handle_tracking_lost(flight_id, frame_id) + + if state == TrackingState.LOST: + state = TrackingState.RECOVERY + + if state == TrackingState.RECOVERY: + recovered = False + if self._recovery and self._chunk_mgr: + active_chunk = self._chunk_mgr.get_active_chunk(flight_id) + if active_chunk: + recovered = self._recovery.process_chunk_recovery( + flight_id, active_chunk.chunk_id, [image] + ) + if recovered: + state = TrackingState.NORMAL + result.alignment_success = True + # PIPE-04: Reset failure count on successful recovery + self._failure_counts[flight_id] = 0 + logger.info("Flight %s recovered → NORMAL at frame %d", flight_id, frame_id) + + # ---- 3. Satellite position fix (PIPE-01/02) ---- + if state == TrackingState.NORMAL and self._metric: + sat_tile: Optional[np.ndarray] = None + tile_bounds = None + + # PIPE-02: Prefer real SatelliteDataManager tiles (ESKF ±3σ selection) + if self._satellite and eskf and eskf.initialized: + gps_est = self._eskf_to_gps(flight_id, eskf) + if gps_est: + cov = eskf.covariance + sigma_h = float( + np.sqrt(np.trace(cov[0:3, 0:3]) / 3.0) + ) if cov is not None else 30.0 + sigma_h = max(sigma_h, 5.0) + try: + tile_result = await asyncio.get_event_loop().run_in_executor( + None, + self._satellite.fetch_tiles_for_position, + gps_est, sigma_h, 18, + ) + if tile_result: + sat_tile, tile_bounds = tile_result + except Exception as exc: + logger.debug("Satellite tile fetch failed: %s", exc) + + # Fallback: GPR candidate tile (mock image, real bounds) + if sat_tile is None and self._gpr: + try: + candidates = self._gpr.retrieve_candidate_tiles(image, top_k=1) + if candidates: + sat_tile = np.zeros((256, 256, 3), dtype=np.uint8) + tile_bounds = candidates[0].bounds + except Exception as exc: + logger.debug("GPR tile fallback failed: %s", exc) + + if sat_tile is not None and tile_bounds is not None: + try: + align = self._metric.align_to_satellite(image, sat_tile, tile_bounds) + if align and align.matched: + result.gps = align.gps_center + result.confidence = align.confidence + result.alignment_success = True + + if self._graph: + self._graph.add_absolute_factor( + flight_id, frame_id, + align.gps_center, np.eye(6), + is_user_anchor=False, + ) + + # PIPE-01: ESKF satellite update — noise from RANSAC confidence + if eskf and eskf.initialized and self._coord: + try: + e, n, _ = self._coord.gps_to_enu(flight_id, align.gps_center) + alt = self._altitudes.get(flight_id, 100.0) + pos_enu = np.array([e, n, alt]) + noise_m = 5.0 + 15.0 * (1.0 - float(align.confidence)) + eskf.update_satellite(pos_enu, noise_m) + except Exception as exc: + logger.debug("ESKF satellite update failed: %s", exc) + except Exception as exc: + logger.warning("Metric alignment failed at frame %d: %s", frame_id, exc) + + # ---- 4. Graph optimization (incremental) ---- + if self._graph: + opt_result = self._graph.optimize(flight_id, iterations=5) + logger.debug( + "Optimization: converged=%s, error=%.4f", + opt_result.converged, opt_result.final_error, + ) + + # ---- PIPE-07: Push ESKF state → MAVLink GPS_INPUT ---- + if self._mavlink and eskf and eskf.initialized: + try: + eskf_state = eskf.get_state() + alt = self._altitudes.get(flight_id, 100.0) + self._mavlink.update_state(eskf_state, altitude_m=alt) + except Exception as exc: + logger.debug("MAVLink state push failed: %s", exc) + + # ---- 5. Publish via SSE ---- + result.tracking_state = state + self._flight_states[flight_id] = state + await self._publish_frame_result(flight_id, result) + return result + + async def _publish_frame_result(self, flight_id: str, result: FrameResult): + """Emit SSE event for processed frame.""" + event_data = { + "frame_id": result.frame_id, + "tracking_state": result.tracking_state.value, + "vo_success": result.vo_success, + "alignment_success": result.alignment_success, + "confidence": result.confidence, + } + if result.gps: + event_data["lat"] = result.gps.lat + event_data["lon"] = result.gps.lon + + await self.streamer.push_event( + flight_id, event_type="frame_result", data=event_data + ) + + # ========================================================= + # Existing CRUD / REST helpers (unchanged from Stage 3-4) + # ========================================================= + async def create_flight(self, req: FlightCreateRequest) -> FlightResponse: + flight = await self.repository.insert_flight( + name=req.name, + description=req.description, + start_lat=req.start_gps.lat, + start_lon=req.start_gps.lon, + altitude=req.altitude, + camera_params=req.camera_params.model_dump(), + ) + # P0#2: Store camera params for process_frame VO calls + self._flight_cameras[flight.id] = req.camera_params + + for poly in req.geofences.polygons: + await self.repository.insert_geofence( + flight.id, + nw_lat=poly.north_west.lat, + nw_lon=poly.north_west.lon, + se_lat=poly.south_east.lat, + se_lon=poly.south_east.lon, + ) + for w in req.rough_waypoints: + await self.repository.insert_waypoint(flight.id, lat=w.lat, lon=w.lon) + + # Store per-flight altitude for ESKF/pixel projection + self._altitudes[flight.id] = req.altitude or 100.0 + + # PIPE-02: Set ENU origin and initialise ESKF for this flight + if self._coord: + self._coord.set_enu_origin(flight.id, req.start_gps) + self._init_eskf_for_flight(flight.id, req.start_gps, req.altitude or 100.0) + + # Start MAVLink bridge for this flight (origin required for GPS_INPUT) + if self._mavlink and not self._mavlink._running: + try: + asyncio.create_task(self._mavlink.start(req.start_gps)) + except Exception as exc: + logger.warning("MAVLink bridge start failed: %s", exc) + + return FlightResponse( + flight_id=flight.id, + status="prefetching", + message="Flight created and prefetching started.", + created_at=flight.created_at, + ) + + async def get_flight(self, flight_id: str) -> FlightDetailResponse | None: + flight = await self.repository.get_flight(flight_id) + if not flight: + return None + wps = await self.repository.get_waypoints(flight_id) + state = await self.repository.load_flight_state(flight_id) + + waypoints = [ + Waypoint( + id=w.id, + lat=w.lat, + lon=w.lon, + altitude=w.altitude, + confidence=w.confidence, + timestamp=w.timestamp, + refined=w.refined, + ) + for w in wps + ] + + status = state.status if state else "unknown" + frames_processed = state.frames_processed if state else 0 + frames_total = state.frames_total if state else 0 + + from gps_denied.schemas import Geofences + + return FlightDetailResponse( + flight_id=flight.id, + name=flight.name, + description=flight.description, + start_gps=GPSPoint(lat=flight.start_lat, lon=flight.start_lon), + waypoints=waypoints, + geofences=Geofences(polygons=[]), + camera_params=flight.camera_params, + altitude=flight.altitude, + status=status, + frames_processed=frames_processed, + frames_total=frames_total, + created_at=flight.created_at, + updated_at=flight.updated_at, + ) + + async def delete_flight(self, flight_id: str) -> DeleteResponse: + deleted = await self.repository.delete_flight(flight_id) + # P0#1: Cleanup in-memory state to prevent memory leaks + self._cleanup_flight(flight_id) + return DeleteResponse(deleted=deleted, flight_id=flight_id) + + def _cleanup_flight(self, flight_id: str) -> None: + """Remove all in-memory state for a flight (prevents memory leaks).""" + self._prev_images.pop(flight_id, None) + self._flight_states.pop(flight_id, None) + self._flight_cameras.pop(flight_id, None) + self._altitudes.pop(flight_id, None) + self._failure_counts.pop(flight_id, None) + self._eskf.pop(flight_id, None) + if self._graph: + self._graph.delete_flight_graph(flight_id) + + async def update_waypoint( + self, flight_id: str, waypoint_id: str, waypoint: Waypoint + ) -> UpdateResponse: + ok = await self.repository.update_waypoint( + flight_id, + waypoint_id, + lat=waypoint.lat, + lon=waypoint.lon, + altitude=waypoint.altitude, + confidence=waypoint.confidence, + refined=waypoint.refined, + ) + return UpdateResponse(updated=ok, waypoint_id=waypoint_id) + + async def batch_update_waypoints( + self, flight_id: str, waypoints: list[Waypoint] + ) -> BatchUpdateResponse: + failed = [] + updated = 0 + for wp in waypoints: + ok = await self.repository.update_waypoint( + flight_id, + wp.id, + lat=wp.lat, + lon=wp.lon, + altitude=wp.altitude, + confidence=wp.confidence, + refined=wp.refined, + ) + if ok: + updated += 1 + else: + failed.append(wp.id) + return BatchUpdateResponse( + success=(len(failed) == 0), updated_count=updated, failed_ids=failed + ) + + async def queue_images( + self, flight_id: str, metadata: BatchMetadata, file_count: int + ) -> BatchResponse: + state = await self.repository.load_flight_state(flight_id) + if state: + total = state.frames_total + file_count + await self.repository.save_flight_state( + flight_id, frames_total=total, status="processing" + ) + + next_seq = metadata.end_sequence + 1 + seqs = list(range(metadata.start_sequence, metadata.end_sequence + 1)) + return BatchResponse( + accepted=True, + sequences=seqs, + next_expected=next_seq, + message=f"Queued {file_count} images.", + ) + + async def handle_user_fix( + self, flight_id: str, req: UserFixRequest + ) -> UserFixResponse: + await self.repository.save_flight_state( + flight_id, blocked=False, status="processing" + ) + + # Inject operator position into ESKF with high uncertainty (500m) + eskf = self._eskf.get(flight_id) + if eskf and eskf.initialized and self._coord: + try: + e, n, _ = self._coord.gps_to_enu(flight_id, req.satellite_gps) + alt = self._altitudes.get(flight_id, 100.0) + eskf.update_satellite(np.array([e, n, alt]), noise_meters=500.0) + self._failure_counts[flight_id] = 0 + logger.info("User fix applied for %s: %s", flight_id, req.satellite_gps) + except Exception as exc: + logger.warning("User fix ESKF injection failed: %s", exc) + + return UserFixResponse( + accepted=True, processing_resumed=True, message="Fix applied." + ) + + async def get_flight_status(self, flight_id: str) -> FlightStatusResponse | None: + state = await self.repository.load_flight_state(flight_id) + if not state: + return None + return FlightStatusResponse( + status=state.status, + frames_processed=state.frames_processed, + frames_total=state.frames_total, + current_frame=state.current_frame, + current_heading=None, + blocked=state.blocked, + search_grid_size=state.search_grid_size, + created_at=state.created_at, + updated_at=state.updated_at, + ) + + async def convert_object_to_gps( + self, flight_id: str, frame_id: int, pixel: tuple[float, float] + ) -> ObjectGPSResponse: + # PIPE-06: Use real CoordinateTransformer + ESKF pose for ray-ground projection + gps: Optional[GPSPoint] = None + eskf = self._eskf.get(flight_id) + if self._coord and eskf and eskf.initialized: + pos = eskf.position + quat = eskf.quaternion + cam = self._flight_cameras.get(flight_id, CameraParameters( + focal_length=4.5, sensor_width=6.17, sensor_height=4.55, + resolution_width=640, resolution_height=480, + )) + alt = self._altitudes.get(flight_id, 100.0) + try: + gps = self._coord.pixel_to_gps( + flight_id, + pixel, + frame_pose={"position": pos}, + camera_params=cam, + altitude=float(alt), + quaternion=quat, + ) + except Exception as exc: + logger.debug("pixel_to_gps failed: %s", exc) + + # Fallback: return ESKF position projected to ground (no pixel shift) + if gps is None and eskf: + gps = self._eskf_to_gps(flight_id, eskf) + + return ObjectGPSResponse( + gps=gps or GPSPoint(lat=0.0, lon=0.0), + accuracy_meters=5.0, + frame_id=frame_id, + pixel=pixel, + ) + + async def stream_events(self, flight_id: str, client_id: str): + """Async generator for SSE stream.""" + async for event in self.streamer.stream_generator(flight_id, client_id): + yield event diff --git a/src/gps_denied/pipeline/result_manager.py b/src/gps_denied/pipeline/result_manager.py new file mode 100644 index 0000000..b53cd59 --- /dev/null +++ b/src/gps_denied/pipeline/result_manager.py @@ -0,0 +1,73 @@ +"""Result Manager (Component F14).""" + +from __future__ import annotations + +from datetime import datetime + +from gps_denied.pipeline.sse_streamer import SSEEventStreamer +from gps_denied.db.repository import FlightRepository +from gps_denied.schemas import GPSPoint +from gps_denied.schemas.events import FrameProcessedEvent + + +class ResultManager: + """Result consistency and publishing.""" + + def __init__(self, repo: FlightRepository, sse: SSEEventStreamer) -> None: + self.repo = repo + self.sse = sse + + async def update_frame_result( + self, + flight_id: str, + frame_id: int, + gps_lat: float, + gps_lon: float, + altitude: float, + heading: float, + confidence: float, + timestamp: datetime, + refined: bool = False, + ) -> bool: + """Atomic DB update + SSE event publish.""" + + # 1. Update DB (in the repository these are auto-committing via flush, + # but normally F03 would wrap in a single transaction). + await self.repo.save_frame_result( + flight_id, + frame_id=frame_id, + gps_lat=gps_lat, + gps_lon=gps_lon, + altitude=altitude, + heading=heading, + confidence=confidence, + refined=refined, + ) + + # Wait, the spec also wants Waypoints to be updated. + # But image frames != waypoints. Waypoints are the planned route. + # Actually in the spec it says: "Updates waypoint in waypoints table." + # This implies updating the closest waypoint or a generated waypoint path. + # We will follow the simplest form for now: update the waypoint if there is one corresponding. + # Let's say we update a waypoint with id "wp_{frame_id}" for now if we know how they map, + # or we just skip unless specified. + + # 2. Trigger SSE event + evt = FrameProcessedEvent( + frame_id=frame_id, + gps=GPSPoint(lat=gps_lat, lon=gps_lon), + altitude=altitude, + confidence=confidence, + heading=heading, + timestamp=timestamp, + ) + if refined: + self.sse.send_refinement(flight_id, evt) + else: + self.sse.send_frame_result(flight_id, evt) + + return True + + async def publish_waypoint_update(self, flight_id: str, frame_id: int) -> bool: + # Just delegates to SSE for waypoint updates, which is basically the frame result for UI + pass diff --git a/src/gps_denied/pipeline/sse_streamer.py b/src/gps_denied/pipeline/sse_streamer.py new file mode 100644 index 0000000..10a0699 --- /dev/null +++ b/src/gps_denied/pipeline/sse_streamer.py @@ -0,0 +1,164 @@ +"""SSE Event Streamer (Component F15).""" + +from __future__ import annotations + +import asyncio +import json +from collections import defaultdict + +from gps_denied.schemas.events import ( + FlightCompletedEvent, + FrameProcessedEvent, + SearchExpandedEvent, + SSEEventType, + SSEMessage, + UserInputNeededEvent, +) + + +class SSEEventStreamer: + """Manages real-time SSE connections and event broadcasting.""" + + def __init__(self) -> None: + # Map: flight_id -> Dict[client_id, asyncio.Queue] + self._streams: dict[str, dict[str, asyncio.Queue[SSEMessage | None]]] = defaultdict(dict) + + def create_stream(self, flight_id: str, client_id: str) -> asyncio.Queue[SSEMessage | None]: + """Create a new event queue for a client.""" + q: asyncio.Queue[SSEMessage | None] = asyncio.Queue() + self._streams[flight_id][client_id] = q + return q + + def close_stream(self, flight_id: str, client_id: str) -> None: + """Close a client stream by putting a sentinel and removing the queue.""" + if flight_id in self._streams and client_id in self._streams[flight_id]: + q = self._streams[flight_id].pop(client_id) + if not self._streams[flight_id]: + del self._streams[flight_id] + # Put None to signal generator exit + try: + q.put_nowait(None) + except asyncio.QueueFull: + pass + + def get_active_connections(self, flight_id: str) -> int: + return len(self._streams.get(flight_id, {})) + + def _broadcast(self, flight_id: str, msg: SSEMessage) -> bool: + """Broadcast a message to all clients subscribed to flight_id.""" + if flight_id not in self._streams or not self._streams[flight_id]: + return False + + for q in self._streams[flight_id].values(): + try: + q.put_nowait(msg) + except asyncio.QueueFull: + pass # Drop if queue is full rather than blocking + return True + + # ── Business Event Senders ──────────────────────────────────────────────── + + def send_frame_result(self, flight_id: str, event_data: FrameProcessedEvent) -> bool: + msg = SSEMessage( + event=SSEEventType.FRAME_PROCESSED, + data=event_data.model_dump(mode="json"), + id=f"frame_{event_data.frame_id}", + ) + return self._broadcast(flight_id, msg) + + def send_refinement(self, flight_id: str, event_data: FrameProcessedEvent) -> bool: + msg = SSEMessage( + event=SSEEventType.FRAME_REFINED, + data=event_data.model_dump(mode="json"), + id=f"refine_{event_data.frame_id}", + ) + return self._broadcast(flight_id, msg) + + def send_search_progress(self, flight_id: str, event_data: SearchExpandedEvent) -> bool: + msg = SSEMessage( + event=SSEEventType.SEARCH_EXPANDED, + data=event_data.model_dump(mode="json"), + ) + return self._broadcast(flight_id, msg) + + def send_user_input_request(self, flight_id: str, event_data: UserInputNeededEvent) -> bool: + msg = SSEMessage( + event=SSEEventType.USER_INPUT_NEEDED, + data=event_data.model_dump(mode="json"), + ) + return self._broadcast(flight_id, msg) + + def send_flight_completed(self, flight_id: str, event_data: FlightCompletedEvent) -> bool: + msg = SSEMessage( + event=SSEEventType.FLIGHT_COMPLETED, + data=event_data.model_dump(mode="json"), + ) + return self._broadcast(flight_id, msg) + + def send_heartbeat(self, flight_id: str) -> bool: + # sse_starlette uses empty string or comment for heartbeat, + # but we can just send an SSEMessage object that parses as empty event + if flight_id not in self._streams: + return False + + # Manually sending a comment via the generator is tricky with strict SSEMessage schema + # but we'll handle this in the stream generator directly + return True + + # ── Generic event dispatcher (used by processor.process_frame) ────────── + + async def push_event(self, flight_id: str, event_type: str, data: dict) -> None: + """Dispatch a generic event to all clients for a flight. + + Maps event_type strings to typed SSE events: + "frame_result" → FrameProcessedEvent + "refinement" → FrameProcessedEvent (refined) + Other → raw broadcast via SSEMessage + """ + if event_type == "frame_result": + evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data + self.send_frame_result(flight_id, evt) + elif event_type == "refinement": + evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data + self.send_refinement(flight_id, evt) + else: + msg = SSEMessage( + event=SSEEventType.FRAME_PROCESSED, + data=data, + id=str(data.get("frame_id", "")), + ) + self._broadcast(flight_id, msg) + + # ── Stream Generator ────────────────────────────────────────────────────── + + async def stream_generator(self, flight_id: str, client_id: str): + """Yields dicts for sse_starlette EventSourceResponse.""" + q = self.create_stream(flight_id, client_id) + + # Send an immediate connection accepted ping + yield {"event": "connected", "data": "connected"} + + try: + while True: + # Wait for next event or send heartbeat every 15s + try: + msg = await asyncio.wait_for(q.get(), timeout=15.0) + if msg is None: + # Sentinel for clean shutdown + break + + # Yield dict format for sse_starlette + yield { + "event": msg.event.value, + "id": msg.id if msg.id else "", + "data": json.dumps(msg.data) + } + + except asyncio.TimeoutError: + # Heartbeat format for sse_starlette (empty string generates a comment) + yield {"event": "heartbeat", "data": "ping"} + + except asyncio.CancelledError: + pass # Client disconnected + finally: + self.close_stream(flight_id, client_id) diff --git a/src/gps_denied/testing/benchmark.py b/src/gps_denied/testing/benchmark.py new file mode 100644 index 0000000..0645bc3 --- /dev/null +++ b/src/gps_denied/testing/benchmark.py @@ -0,0 +1,371 @@ +"""Accuracy Benchmark (Phase 7). + +Provides: + - SyntheticTrajectory — generates a realistic fixed-wing UAV flight path + with ground-truth GPS + noisy sensor data. + - AccuracyBenchmark — replays a trajectory through the ESKF pipeline + and computes position-error statistics. + +Acceptance criteria (from solution.md): + AC-PERF-1: 80 % of frames within 50 m of ground truth. + AC-PERF-2: 60 % of frames within 20 m of ground truth. + AC-PERF-3: End-to-end per-frame latency < 400 ms. + AC-PERF-4: VO drift over 1 km straight segment (no sat correction) < 100 m. +""" + +from __future__ import annotations + +import math +import time +from dataclasses import dataclass, field +from typing import Callable, Optional + +import numpy as np + +from gps_denied.core.coordinates import CoordinateTransformer +from gps_denied.core.eskf import ESKF +from gps_denied.schemas import GPSPoint +from gps_denied.schemas.eskf import ESKFConfig, IMUMeasurement + +# --------------------------------------------------------------------------- +# Synthetic trajectory +# --------------------------------------------------------------------------- + +@dataclass +class TrajectoryFrame: + """One simulated camera frame with ground-truth and noisy sensor data.""" + frame_id: int + timestamp: float + true_position_enu: np.ndarray # (3,) East, North, Up in metres + true_gps: GPSPoint # WGS84 from true ENU + imu_measurements: list[IMUMeasurement] # High-rate IMU between frames + vo_translation: Optional[np.ndarray] # Noisy relative displacement (3,) + vo_tracking_good: bool = True + + +@dataclass +class SyntheticTrajectoryConfig: + """Parameters for trajectory generation.""" + # Origin (mission start) + origin: GPSPoint = field(default_factory=lambda: GPSPoint(lat=49.0, lon=32.0)) + altitude_m: float = 600.0 # Constant AGL altitude (m) + # UAV speed and heading + speed_mps: float = 20.0 # ~70 km/h (typical fixed-wing) + heading_deg: float = 45.0 # Initial heading (degrees CW from North) + camera_fps: float = 0.7 # ADTI 20L V1 camera rate (Hz) + imu_hz: float = 200.0 # IMU sample rate + num_frames: int = 50 # Number of camera frames to simulate + # Noise parameters + vo_noise_m: float = 0.5 # VO translation noise (sigma, metres) + imu_accel_noise: float = 0.01 # Accelerometer noise sigma (m/s²) + imu_gyro_noise: float = 0.001 # Gyroscope noise sigma (rad/s) + # Failure injection + vo_failure_frames: list[int] = field(default_factory=list) + # Waypoints for heading changes (ENU East, North metres from origin) + waypoints_enu: list[tuple[float, float]] = field(default_factory=list) + + +class SyntheticTrajectory: + """Generate a synthetic fixed-wing UAV flight with ground truth + noisy sensors.""" + + def __init__(self, config: SyntheticTrajectoryConfig | None = None): + self.config = config or SyntheticTrajectoryConfig() + self._coord = CoordinateTransformer() + self._flight_id = "__synthetic__" + self._coord.set_enu_origin(self._flight_id, self.config.origin) + + def generate(self) -> list[TrajectoryFrame]: + """Generate all trajectory frames.""" + cfg = self.config + dt_camera = 1.0 / cfg.camera_fps + dt_imu = 1.0 / cfg.imu_hz + imu_steps = int(dt_camera * cfg.imu_hz) + + frames: list[TrajectoryFrame] = [] + pos = np.array([0.0, 0.0, cfg.altitude_m]) + vel = self._heading_to_enu_vel(cfg.heading_deg, cfg.speed_mps) + prev_pos = pos.copy() + t = time.time() + + waypoints = list(cfg.waypoints_enu) # copy + + for fid in range(cfg.num_frames): + # --- Waypoint steering --- + if waypoints: + wp_e, wp_n = waypoints[0] + to_wp = np.array([wp_e - pos[0], wp_n - pos[1], 0.0]) + dist_wp = np.linalg.norm(to_wp[:2]) + if dist_wp < cfg.speed_mps * dt_camera: + waypoints.pop(0) + else: + heading_rad = math.atan2(to_wp[0], to_wp[1]) # ENU: E=X, N=Y + vel = np.array([ + cfg.speed_mps * math.sin(heading_rad), + cfg.speed_mps * math.cos(heading_rad), + 0.0, + ]) + + # --- Simulate IMU between frames --- + imu_list: list[IMUMeasurement] = [] + for step in range(imu_steps): + ts = t + step * dt_imu + # Body-frame acceleration (mostly gravity correction, small forward accel) + accel_true = np.array([0.0, 0.0, 9.81]) # gravity compensation + gyro_true = np.zeros(3) + imu = IMUMeasurement( + accel=accel_true + np.random.randn(3) * cfg.imu_accel_noise, + gyro=gyro_true + np.random.randn(3) * cfg.imu_gyro_noise, + timestamp=ts, + ) + imu_list.append(imu) + + # --- Propagate position --- + prev_pos = pos.copy() + pos = pos + vel * dt_camera + t += dt_camera + + # --- True GPS from ENU position --- + true_gps = self._coord.enu_to_gps( + self._flight_id, (float(pos[0]), float(pos[1]), float(pos[2])) + ) + + # --- VO measurement (relative displacement + noise) --- + true_displacement = pos - prev_pos + vo_tracking_good = fid not in cfg.vo_failure_frames + if vo_tracking_good: + noisy_displacement = true_displacement + np.random.randn(3) * cfg.vo_noise_m + noisy_displacement[2] = 0.0 # monocular VO is scale-ambiguous in Z + else: + noisy_displacement = None + + frames.append(TrajectoryFrame( + frame_id=fid, + timestamp=t, + true_position_enu=pos.copy(), + true_gps=true_gps, + imu_measurements=imu_list, + vo_translation=noisy_displacement, + vo_tracking_good=vo_tracking_good, + )) + + return frames + + @staticmethod + def _heading_to_enu_vel(heading_deg: float, speed_mps: float) -> np.ndarray: + """Convert heading (degrees CW from North) to ENU velocity vector.""" + rad = math.radians(heading_deg) + return np.array([ + speed_mps * math.sin(rad), # East + speed_mps * math.cos(rad), # North + 0.0, # Up + ]) + + +# --------------------------------------------------------------------------- +# Accuracy Benchmark +# --------------------------------------------------------------------------- + +@dataclass +class BenchmarkResult: + """Position error statistics over a trajectory replay.""" + errors_m: list[float] # Per-frame horizontal error in metres + latencies_ms: list[float] # Per-frame process time in ms + frames_total: int + frames_with_good_estimate: int + + @property + def p80_error_m(self) -> float: + """80th percentile position error (metres).""" + return float(np.percentile(self.errors_m, 80)) if self.errors_m else float("inf") + + @property + def p60_error_m(self) -> float: + """60th percentile position error (metres).""" + return float(np.percentile(self.errors_m, 60)) if self.errors_m else float("inf") + + @property + def median_error_m(self) -> float: + """Median position error (metres).""" + return float(np.median(self.errors_m)) if self.errors_m else float("inf") + + @property + def max_error_m(self) -> float: + return float(max(self.errors_m)) if self.errors_m else float("inf") + + @property + def p95_latency_ms(self) -> float: + """95th percentile frame latency (ms).""" + return float(np.percentile(self.latencies_ms, 95)) if self.latencies_ms else float("inf") + + @property + def pct_within_50m(self) -> float: + """Fraction of frames within 50 m error.""" + if not self.errors_m: + return 0.0 + return sum(e <= 50.0 for e in self.errors_m) / len(self.errors_m) + + @property + def pct_within_20m(self) -> float: + """Fraction of frames within 20 m error.""" + if not self.errors_m: + return 0.0 + return sum(e <= 20.0 for e in self.errors_m) / len(self.errors_m) + + def passes_acceptance_criteria(self) -> tuple[bool, dict[str, bool]]: + """Check all solution.md acceptance criteria. + + Returns (overall_pass, per_criterion_dict). + """ + checks = { + "AC-PERF-1: 80% within 50m": self.pct_within_50m >= 0.80, + "AC-PERF-2: 60% within 20m": self.pct_within_20m >= 0.60, + "AC-PERF-3: p95 latency < 400ms": self.p95_latency_ms < 400.0, + } + overall = all(checks.values()) + return overall, checks + + def summary(self) -> str: + overall, checks = self.passes_acceptance_criteria() + lines = [ + f"Frames: {self.frames_total} | with estimate: {self.frames_with_good_estimate}", + f"Error — median: {self.median_error_m:.1f}m p80: {self.p80_error_m:.1f}m " + f"p60: {self.p60_error_m:.1f}m max: {self.max_error_m:.1f}m", + f"Within 50m: {self.pct_within_50m*100:.1f}% | within 20m: {self.pct_within_20m*100:.1f}%", + f"Latency p95: {self.p95_latency_ms:.1f}ms", + "", + "Acceptance criteria:", + ] + for criterion, passed in checks.items(): + lines.append(f" {'PASS' if passed else 'FAIL'} {criterion}") + lines.append(f"\nOverall: {'PASS' if overall else 'FAIL'}") + return "\n".join(lines) + + +class AccuracyBenchmark: + """Replays a SyntheticTrajectory through the ESKF and measures accuracy. + + The benchmark uses only the ESKF (no full FlightProcessor) for speed. + Satellite corrections are injected optionally via sat_correction_fn. + """ + + def __init__( + self, + eskf_config: ESKFConfig | None = None, + sat_correction_fn: Optional[Callable[[TrajectoryFrame], Optional[np.ndarray]]] = None, + ): + """ + Args: + eskf_config: ESKF tuning parameters. + sat_correction_fn: Optional callback(frame) → ENU position or None. + Called on keyframes to inject satellite corrections. + If None, no satellite corrections are applied. + """ + self.eskf_config = eskf_config or ESKFConfig() + self.sat_correction_fn = sat_correction_fn + + def run( + self, + trajectory: list[TrajectoryFrame], + origin: GPSPoint, + satellite_keyframe_interval: int = 7, + ) -> BenchmarkResult: + """Replay trajectory frames through ESKF, collect errors and latencies. + + Args: + trajectory: List of TrajectoryFrame (from SyntheticTrajectory). + origin: WGS84 reference origin for ENU. + satellite_keyframe_interval: Apply satellite correction every N frames. + """ + coord = CoordinateTransformer() + flight_id = "__benchmark__" + coord.set_enu_origin(flight_id, origin) + + eskf = ESKF(self.eskf_config) + # Init at origin with HIGH uncertainty + eskf.initialize(np.array([0.0, 0.0, trajectory[0].true_position_enu[2]]), + trajectory[0].timestamp) + + errors_m: list[float] = [] + latencies_ms: list[float] = [] + frames_with_estimate = 0 + + for frame in trajectory: + t_frame_start = time.perf_counter() + + # --- IMU prediction --- + for imu in frame.imu_measurements: + eskf.predict(imu) + + # --- VO update --- + if frame.vo_tracking_good and frame.vo_translation is not None: + dt_vo = 1.0 / 0.7 # camera interval + eskf.update_vo(frame.vo_translation, dt_vo) + + # --- Satellite update (keyframes) --- + if frame.frame_id % satellite_keyframe_interval == 0: + sat_pos_enu: Optional[np.ndarray] = None + if self.sat_correction_fn is not None: + sat_pos_enu = self.sat_correction_fn(frame) + else: + # Default: inject ground-truth position + realistic noise + noise_m = 10.0 + sat_pos_enu = ( + frame.true_position_enu[:3] + + np.random.randn(3) * noise_m + ) + sat_pos_enu[2] = frame.true_position_enu[2] # keep altitude + + if sat_pos_enu is not None: + # Tell ESKF the measurement noise matches what we inject + eskf.update_satellite(sat_pos_enu, noise_meters=noise_m) + + latency_ms = (time.perf_counter() - t_frame_start) * 1000.0 + latencies_ms.append(latency_ms) + + # --- Compute horizontal error vs ground truth --- + if eskf.initialized and eskf._nominal_state is not None: + est_pos = eskf._nominal_state["position"] + true_pos = frame.true_position_enu + horiz_error = float(np.linalg.norm(est_pos[:2] - true_pos[:2])) + errors_m.append(horiz_error) + frames_with_estimate += 1 + else: + errors_m.append(float("inf")) + + return BenchmarkResult( + errors_m=errors_m, + latencies_ms=latencies_ms, + frames_total=len(trajectory), + frames_with_good_estimate=frames_with_estimate, + ) + + def run_vo_drift_test( + self, + trajectory_length_m: float = 1000.0, + speed_mps: float = 20.0, + ) -> float: + """Measure VO drift over a straight segment with NO satellite correction. + + Returns final horizontal position error in metres. + Per solution.md, this should be < 100m over 1km. + """ + fps = 0.7 + num_frames = max(10, int(trajectory_length_m / speed_mps * fps)) + cfg = SyntheticTrajectoryConfig( + speed_mps=speed_mps, + heading_deg=0.0, # straight North + camera_fps=fps, + num_frames=num_frames, + vo_noise_m=0.3, # cuVSLAM-grade VO noise + ) + traj_gen = SyntheticTrajectory(cfg) + frames = traj_gen.generate() + + # No satellite corrections + benchmark_no_sat = AccuracyBenchmark( + eskf_config=self.eskf_config, + sat_correction_fn=lambda _: None, # suppress all satellite updates + ) + result = benchmark_no_sat.run(frames, cfg.origin, satellite_keyframe_interval=9999) + # Return final-frame error + return result.errors_m[-1] if result.errors_m else float("inf")