mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-21 08:31:13 +00:00
refactor(01-07): factor_graph, pipeline pkg, testing/benchmark, Protocol ABCs
- Create core/factor_graph.py: IFactorGraphOptimizer converted to Protocol
- Shim core/graph.py to re-export from core/factor_graph
- Create pipeline/ package: orchestrator, image_input, result_manager, sse_streamer
- Shim core/{processor,pipeline,results,sse}.py to re-export from pipeline/
- Create testing/benchmark.py; shim core/benchmark.py
- Convert IRouteChunkManager, IFailureRecoveryCoordinator, IModelManager, IImageMatcher to Protocol
- Update pyproject.toml ruff per-file-ignores to new paths
- All 216 tests pass (regression floor maintained)
This commit is contained in:
+2
-2
@@ -42,8 +42,8 @@ line-length = 120
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
# Abstract interfaces have long method signatures — allow up to 170
|
||||
"src/gps_denied/core/graph.py" = ["E501"]
|
||||
"src/gps_denied/core/metric.py" = ["E501"]
|
||||
"src/gps_denied/core/factor_graph.py" = ["E501"]
|
||||
"src/gps_denied/components/satellite_matcher/metric_refinement.py" = ["E501"]
|
||||
"src/gps_denied/core/chunk_manager.py" = ["E501"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
|
||||
@@ -1,371 +1,8 @@
|
||||
"""Accuracy Benchmark (Phase 7).
|
||||
|
||||
Provides:
|
||||
- SyntheticTrajectory — generates a realistic fixed-wing UAV flight path
|
||||
with ground-truth GPS + noisy sensor data.
|
||||
- AccuracyBenchmark — replays a trajectory through the ESKF pipeline
|
||||
and computes position-error statistics.
|
||||
|
||||
Acceptance criteria (from solution.md):
|
||||
AC-PERF-1: 80 % of frames within 50 m of ground truth.
|
||||
AC-PERF-2: 60 % of frames within 20 m of ground truth.
|
||||
AC-PERF-3: End-to-end per-frame latency < 400 ms.
|
||||
AC-PERF-4: VO drift over 1 km straight segment (no sat correction) < 100 m.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gps_denied.core.coordinates import CoordinateTransformer
|
||||
from gps_denied.core.eskf import ESKF
|
||||
from gps_denied.schemas import GPSPoint
|
||||
from gps_denied.schemas.eskf import ESKFConfig, IMUMeasurement
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Synthetic trajectory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class TrajectoryFrame:
|
||||
"""One simulated camera frame with ground-truth and noisy sensor data."""
|
||||
frame_id: int
|
||||
timestamp: float
|
||||
true_position_enu: np.ndarray # (3,) East, North, Up in metres
|
||||
true_gps: GPSPoint # WGS84 from true ENU
|
||||
imu_measurements: list[IMUMeasurement] # High-rate IMU between frames
|
||||
vo_translation: Optional[np.ndarray] # Noisy relative displacement (3,)
|
||||
vo_tracking_good: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyntheticTrajectoryConfig:
|
||||
"""Parameters for trajectory generation."""
|
||||
# Origin (mission start)
|
||||
origin: GPSPoint = field(default_factory=lambda: GPSPoint(lat=49.0, lon=32.0))
|
||||
altitude_m: float = 600.0 # Constant AGL altitude (m)
|
||||
# UAV speed and heading
|
||||
speed_mps: float = 20.0 # ~70 km/h (typical fixed-wing)
|
||||
heading_deg: float = 45.0 # Initial heading (degrees CW from North)
|
||||
camera_fps: float = 0.7 # ADTI 20L V1 camera rate (Hz)
|
||||
imu_hz: float = 200.0 # IMU sample rate
|
||||
num_frames: int = 50 # Number of camera frames to simulate
|
||||
# Noise parameters
|
||||
vo_noise_m: float = 0.5 # VO translation noise (sigma, metres)
|
||||
imu_accel_noise: float = 0.01 # Accelerometer noise sigma (m/s²)
|
||||
imu_gyro_noise: float = 0.001 # Gyroscope noise sigma (rad/s)
|
||||
# Failure injection
|
||||
vo_failure_frames: list[int] = field(default_factory=list)
|
||||
# Waypoints for heading changes (ENU East, North metres from origin)
|
||||
waypoints_enu: list[tuple[float, float]] = field(default_factory=list)
|
||||
|
||||
|
||||
class SyntheticTrajectory:
|
||||
"""Generate a synthetic fixed-wing UAV flight with ground truth + noisy sensors."""
|
||||
|
||||
def __init__(self, config: SyntheticTrajectoryConfig | None = None):
|
||||
self.config = config or SyntheticTrajectoryConfig()
|
||||
self._coord = CoordinateTransformer()
|
||||
self._flight_id = "__synthetic__"
|
||||
self._coord.set_enu_origin(self._flight_id, self.config.origin)
|
||||
|
||||
def generate(self) -> list[TrajectoryFrame]:
|
||||
"""Generate all trajectory frames."""
|
||||
cfg = self.config
|
||||
dt_camera = 1.0 / cfg.camera_fps
|
||||
dt_imu = 1.0 / cfg.imu_hz
|
||||
imu_steps = int(dt_camera * cfg.imu_hz)
|
||||
|
||||
frames: list[TrajectoryFrame] = []
|
||||
pos = np.array([0.0, 0.0, cfg.altitude_m])
|
||||
vel = self._heading_to_enu_vel(cfg.heading_deg, cfg.speed_mps)
|
||||
prev_pos = pos.copy()
|
||||
t = time.time()
|
||||
|
||||
waypoints = list(cfg.waypoints_enu) # copy
|
||||
|
||||
for fid in range(cfg.num_frames):
|
||||
# --- Waypoint steering ---
|
||||
if waypoints:
|
||||
wp_e, wp_n = waypoints[0]
|
||||
to_wp = np.array([wp_e - pos[0], wp_n - pos[1], 0.0])
|
||||
dist_wp = np.linalg.norm(to_wp[:2])
|
||||
if dist_wp < cfg.speed_mps * dt_camera:
|
||||
waypoints.pop(0)
|
||||
else:
|
||||
heading_rad = math.atan2(to_wp[0], to_wp[1]) # ENU: E=X, N=Y
|
||||
vel = np.array([
|
||||
cfg.speed_mps * math.sin(heading_rad),
|
||||
cfg.speed_mps * math.cos(heading_rad),
|
||||
0.0,
|
||||
])
|
||||
|
||||
# --- Simulate IMU between frames ---
|
||||
imu_list: list[IMUMeasurement] = []
|
||||
for step in range(imu_steps):
|
||||
ts = t + step * dt_imu
|
||||
# Body-frame acceleration (mostly gravity correction, small forward accel)
|
||||
accel_true = np.array([0.0, 0.0, 9.81]) # gravity compensation
|
||||
gyro_true = np.zeros(3)
|
||||
imu = IMUMeasurement(
|
||||
accel=accel_true + np.random.randn(3) * cfg.imu_accel_noise,
|
||||
gyro=gyro_true + np.random.randn(3) * cfg.imu_gyro_noise,
|
||||
timestamp=ts,
|
||||
)
|
||||
imu_list.append(imu)
|
||||
|
||||
# --- Propagate position ---
|
||||
prev_pos = pos.copy()
|
||||
pos = pos + vel * dt_camera
|
||||
t += dt_camera
|
||||
|
||||
# --- True GPS from ENU position ---
|
||||
true_gps = self._coord.enu_to_gps(
|
||||
self._flight_id, (float(pos[0]), float(pos[1]), float(pos[2]))
|
||||
)
|
||||
|
||||
# --- VO measurement (relative displacement + noise) ---
|
||||
true_displacement = pos - prev_pos
|
||||
vo_tracking_good = fid not in cfg.vo_failure_frames
|
||||
if vo_tracking_good:
|
||||
noisy_displacement = true_displacement + np.random.randn(3) * cfg.vo_noise_m
|
||||
noisy_displacement[2] = 0.0 # monocular VO is scale-ambiguous in Z
|
||||
else:
|
||||
noisy_displacement = None
|
||||
|
||||
frames.append(TrajectoryFrame(
|
||||
frame_id=fid,
|
||||
timestamp=t,
|
||||
true_position_enu=pos.copy(),
|
||||
true_gps=true_gps,
|
||||
imu_measurements=imu_list,
|
||||
vo_translation=noisy_displacement,
|
||||
vo_tracking_good=vo_tracking_good,
|
||||
))
|
||||
|
||||
return frames
|
||||
|
||||
@staticmethod
|
||||
def _heading_to_enu_vel(heading_deg: float, speed_mps: float) -> np.ndarray:
|
||||
"""Convert heading (degrees CW from North) to ENU velocity vector."""
|
||||
rad = math.radians(heading_deg)
|
||||
return np.array([
|
||||
speed_mps * math.sin(rad), # East
|
||||
speed_mps * math.cos(rad), # North
|
||||
0.0, # Up
|
||||
])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Accuracy Benchmark
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResult:
|
||||
"""Position error statistics over a trajectory replay."""
|
||||
errors_m: list[float] # Per-frame horizontal error in metres
|
||||
latencies_ms: list[float] # Per-frame process time in ms
|
||||
frames_total: int
|
||||
frames_with_good_estimate: int
|
||||
|
||||
@property
|
||||
def p80_error_m(self) -> float:
|
||||
"""80th percentile position error (metres)."""
|
||||
return float(np.percentile(self.errors_m, 80)) if self.errors_m else float("inf")
|
||||
|
||||
@property
|
||||
def p60_error_m(self) -> float:
|
||||
"""60th percentile position error (metres)."""
|
||||
return float(np.percentile(self.errors_m, 60)) if self.errors_m else float("inf")
|
||||
|
||||
@property
|
||||
def median_error_m(self) -> float:
|
||||
"""Median position error (metres)."""
|
||||
return float(np.median(self.errors_m)) if self.errors_m else float("inf")
|
||||
|
||||
@property
|
||||
def max_error_m(self) -> float:
|
||||
return float(max(self.errors_m)) if self.errors_m else float("inf")
|
||||
|
||||
@property
|
||||
def p95_latency_ms(self) -> float:
|
||||
"""95th percentile frame latency (ms)."""
|
||||
return float(np.percentile(self.latencies_ms, 95)) if self.latencies_ms else float("inf")
|
||||
|
||||
@property
|
||||
def pct_within_50m(self) -> float:
|
||||
"""Fraction of frames within 50 m error."""
|
||||
if not self.errors_m:
|
||||
return 0.0
|
||||
return sum(e <= 50.0 for e in self.errors_m) / len(self.errors_m)
|
||||
|
||||
@property
|
||||
def pct_within_20m(self) -> float:
|
||||
"""Fraction of frames within 20 m error."""
|
||||
if not self.errors_m:
|
||||
return 0.0
|
||||
return sum(e <= 20.0 for e in self.errors_m) / len(self.errors_m)
|
||||
|
||||
def passes_acceptance_criteria(self) -> tuple[bool, dict[str, bool]]:
|
||||
"""Check all solution.md acceptance criteria.
|
||||
|
||||
Returns (overall_pass, per_criterion_dict).
|
||||
"""
|
||||
checks = {
|
||||
"AC-PERF-1: 80% within 50m": self.pct_within_50m >= 0.80,
|
||||
"AC-PERF-2: 60% within 20m": self.pct_within_20m >= 0.60,
|
||||
"AC-PERF-3: p95 latency < 400ms": self.p95_latency_ms < 400.0,
|
||||
}
|
||||
overall = all(checks.values())
|
||||
return overall, checks
|
||||
|
||||
def summary(self) -> str:
|
||||
overall, checks = self.passes_acceptance_criteria()
|
||||
lines = [
|
||||
f"Frames: {self.frames_total} | with estimate: {self.frames_with_good_estimate}",
|
||||
f"Error — median: {self.median_error_m:.1f}m p80: {self.p80_error_m:.1f}m "
|
||||
f"p60: {self.p60_error_m:.1f}m max: {self.max_error_m:.1f}m",
|
||||
f"Within 50m: {self.pct_within_50m*100:.1f}% | within 20m: {self.pct_within_20m*100:.1f}%",
|
||||
f"Latency p95: {self.p95_latency_ms:.1f}ms",
|
||||
"",
|
||||
"Acceptance criteria:",
|
||||
]
|
||||
for criterion, passed in checks.items():
|
||||
lines.append(f" {'PASS' if passed else 'FAIL'} {criterion}")
|
||||
lines.append(f"\nOverall: {'PASS' if overall else 'FAIL'}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class AccuracyBenchmark:
|
||||
"""Replays a SyntheticTrajectory through the ESKF and measures accuracy.
|
||||
|
||||
The benchmark uses only the ESKF (no full FlightProcessor) for speed.
|
||||
Satellite corrections are injected optionally via sat_correction_fn.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
eskf_config: ESKFConfig | None = None,
|
||||
sat_correction_fn: Optional[Callable[[TrajectoryFrame], Optional[np.ndarray]]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
eskf_config: ESKF tuning parameters.
|
||||
sat_correction_fn: Optional callback(frame) → ENU position or None.
|
||||
Called on keyframes to inject satellite corrections.
|
||||
If None, no satellite corrections are applied.
|
||||
"""
|
||||
self.eskf_config = eskf_config or ESKFConfig()
|
||||
self.sat_correction_fn = sat_correction_fn
|
||||
|
||||
def run(
|
||||
self,
|
||||
trajectory: list[TrajectoryFrame],
|
||||
origin: GPSPoint,
|
||||
satellite_keyframe_interval: int = 7,
|
||||
) -> BenchmarkResult:
|
||||
"""Replay trajectory frames through ESKF, collect errors and latencies.
|
||||
|
||||
Args:
|
||||
trajectory: List of TrajectoryFrame (from SyntheticTrajectory).
|
||||
origin: WGS84 reference origin for ENU.
|
||||
satellite_keyframe_interval: Apply satellite correction every N frames.
|
||||
"""
|
||||
coord = CoordinateTransformer()
|
||||
flight_id = "__benchmark__"
|
||||
coord.set_enu_origin(flight_id, origin)
|
||||
|
||||
eskf = ESKF(self.eskf_config)
|
||||
# Init at origin with HIGH uncertainty
|
||||
eskf.initialize(np.array([0.0, 0.0, trajectory[0].true_position_enu[2]]),
|
||||
trajectory[0].timestamp)
|
||||
|
||||
errors_m: list[float] = []
|
||||
latencies_ms: list[float] = []
|
||||
frames_with_estimate = 0
|
||||
|
||||
for frame in trajectory:
|
||||
t_frame_start = time.perf_counter()
|
||||
|
||||
# --- IMU prediction ---
|
||||
for imu in frame.imu_measurements:
|
||||
eskf.predict(imu)
|
||||
|
||||
# --- VO update ---
|
||||
if frame.vo_tracking_good and frame.vo_translation is not None:
|
||||
dt_vo = 1.0 / 0.7 # camera interval
|
||||
eskf.update_vo(frame.vo_translation, dt_vo)
|
||||
|
||||
# --- Satellite update (keyframes) ---
|
||||
if frame.frame_id % satellite_keyframe_interval == 0:
|
||||
sat_pos_enu: Optional[np.ndarray] = None
|
||||
if self.sat_correction_fn is not None:
|
||||
sat_pos_enu = self.sat_correction_fn(frame)
|
||||
else:
|
||||
# Default: inject ground-truth position + realistic noise
|
||||
noise_m = 10.0
|
||||
sat_pos_enu = (
|
||||
frame.true_position_enu[:3]
|
||||
+ np.random.randn(3) * noise_m
|
||||
)
|
||||
sat_pos_enu[2] = frame.true_position_enu[2] # keep altitude
|
||||
|
||||
if sat_pos_enu is not None:
|
||||
# Tell ESKF the measurement noise matches what we inject
|
||||
eskf.update_satellite(sat_pos_enu, noise_meters=noise_m)
|
||||
|
||||
latency_ms = (time.perf_counter() - t_frame_start) * 1000.0
|
||||
latencies_ms.append(latency_ms)
|
||||
|
||||
# --- Compute horizontal error vs ground truth ---
|
||||
if eskf.initialized and eskf._nominal_state is not None:
|
||||
est_pos = eskf._nominal_state["position"]
|
||||
true_pos = frame.true_position_enu
|
||||
horiz_error = float(np.linalg.norm(est_pos[:2] - true_pos[:2]))
|
||||
errors_m.append(horiz_error)
|
||||
frames_with_estimate += 1
|
||||
else:
|
||||
errors_m.append(float("inf"))
|
||||
|
||||
return BenchmarkResult(
|
||||
errors_m=errors_m,
|
||||
latencies_ms=latencies_ms,
|
||||
frames_total=len(trajectory),
|
||||
frames_with_good_estimate=frames_with_estimate,
|
||||
)
|
||||
|
||||
def run_vo_drift_test(
|
||||
self,
|
||||
trajectory_length_m: float = 1000.0,
|
||||
speed_mps: float = 20.0,
|
||||
) -> float:
|
||||
"""Measure VO drift over a straight segment with NO satellite correction.
|
||||
|
||||
Returns final horizontal position error in metres.
|
||||
Per solution.md, this should be < 100m over 1km.
|
||||
"""
|
||||
fps = 0.7
|
||||
num_frames = max(10, int(trajectory_length_m / speed_mps * fps))
|
||||
cfg = SyntheticTrajectoryConfig(
|
||||
speed_mps=speed_mps,
|
||||
heading_deg=0.0, # straight North
|
||||
camera_fps=fps,
|
||||
num_frames=num_frames,
|
||||
vo_noise_m=0.3, # cuVSLAM-grade VO noise
|
||||
)
|
||||
traj_gen = SyntheticTrajectory(cfg)
|
||||
frames = traj_gen.generate()
|
||||
|
||||
# No satellite corrections
|
||||
benchmark_no_sat = AccuracyBenchmark(
|
||||
eskf_config=self.eskf_config,
|
||||
sat_correction_fn=lambda _: None, # suppress all satellite updates
|
||||
)
|
||||
result = benchmark_no_sat.run(frames, cfg.origin, satellite_keyframe_interval=9999)
|
||||
# Return final-frame error
|
||||
return result.errors_m[-1] if result.errors_m else float("inf")
|
||||
"""Legacy import path. Phase 1 shim — code lives in testing/benchmark.py."""
|
||||
from gps_denied.testing.benchmark import ( # noqa: F401
|
||||
AccuracyBenchmark,
|
||||
BenchmarkResult,
|
||||
SyntheticTrajectory,
|
||||
SyntheticTrajectoryConfig,
|
||||
TrajectoryFrame,
|
||||
)
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from gps_denied.core.graph import IFactorGraphOptimizer
|
||||
from gps_denied.schemas.chunk import ChunkHandle, ChunkStatus
|
||||
@@ -12,30 +11,25 @@ from gps_denied.schemas.metric import Sim3Transform
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IRouteChunkManager(ABC):
|
||||
@abstractmethod
|
||||
@runtime_checkable
|
||||
class IRouteChunkManager(Protocol):
|
||||
def create_new_chunk(self, flight_id: str, start_frame_id: int) -> ChunkHandle:
|
||||
pass
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_active_chunk(self, flight_id: str) -> Optional[ChunkHandle]:
|
||||
pass
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_all_chunks(self, flight_id: str) -> List[ChunkHandle]:
|
||||
pass
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def add_frame_to_chunk(self, flight_id: str, frame_id: int) -> bool:
|
||||
pass
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def update_chunk_status(self, flight_id: str, chunk_id: str, status: ChunkStatus) -> bool:
|
||||
pass
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def merge_chunks(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
|
||||
pass
|
||||
...
|
||||
|
||||
|
||||
class RouteChunkManager(IRouteChunkManager):
|
||||
|
||||
@@ -0,0 +1,350 @@
|
||||
"""Factor Graph Optimizer (Component F10)."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Protocol, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import gtsam
|
||||
HAS_GTSAM = True
|
||||
except ImportError:
|
||||
HAS_GTSAM = False
|
||||
|
||||
from gps_denied.schemas import GPSPoint
|
||||
from gps_denied.schemas.graph import FactorGraphConfig, OptimizationResult, Pose
|
||||
from gps_denied.schemas.metric import Sim3Transform
|
||||
from gps_denied.schemas.vo import RelativePose
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IFactorGraphOptimizer(Protocol):
|
||||
"""GTSAM-based factor graph optimizer."""
|
||||
|
||||
def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
||||
...
|
||||
|
||||
def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool:
|
||||
...
|
||||
|
||||
def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool:
|
||||
...
|
||||
|
||||
def optimize(self, flight_id: str, iterations: int) -> OptimizationResult:
|
||||
...
|
||||
|
||||
def get_trajectory(self, flight_id: str) -> Dict[int, Pose]:
|
||||
...
|
||||
|
||||
def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray:
|
||||
...
|
||||
|
||||
def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool:
|
||||
...
|
||||
|
||||
def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
||||
...
|
||||
|
||||
def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool:
|
||||
...
|
||||
|
||||
def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
|
||||
...
|
||||
|
||||
def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]:
|
||||
...
|
||||
|
||||
def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult:
|
||||
...
|
||||
|
||||
def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult:
|
||||
...
|
||||
|
||||
def delete_flight_graph(self, flight_id: str) -> bool:
|
||||
...
|
||||
|
||||
|
||||
class FactorGraphOptimizer(IFactorGraphOptimizer):
|
||||
"""Implementation of F10 Factor Graph using GTSAM or Mock."""
|
||||
|
||||
def __init__(self, config: FactorGraphConfig):
|
||||
self.config = config
|
||||
# Keyed by flight_id
|
||||
self._flights_state: Dict[str, dict] = {}
|
||||
# Keyed by chunk_id
|
||||
self._chunks_state: Dict[str, dict] = {}
|
||||
# Per-flight ENU origin (set from first absolute GPS factor)
|
||||
self._enu_origins: Dict[str, GPSPoint] = {}
|
||||
|
||||
def _init_flight(self, flight_id: str):
|
||||
if flight_id not in self._flights_state:
|
||||
self._flights_state[flight_id] = {
|
||||
"graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None,
|
||||
"initial": gtsam.Values() if HAS_GTSAM else None,
|
||||
"isam": gtsam.ISAM2() if HAS_GTSAM else None,
|
||||
"poses": {},
|
||||
"dirty": False
|
||||
}
|
||||
|
||||
def _init_chunk(self, chunk_id: str):
|
||||
if chunk_id not in self._chunks_state:
|
||||
self._chunks_state[chunk_id] = {
|
||||
"graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None,
|
||||
"initial": gtsam.Values() if HAS_GTSAM else None,
|
||||
"isam": gtsam.ISAM2() if HAS_GTSAM else None,
|
||||
"poses": {},
|
||||
"dirty": False
|
||||
}
|
||||
|
||||
# ================== MOCK IMPLEMENTATION ====================
|
||||
# As GTSAM Python bindings can be extremely context-dependent and
|
||||
# require proper ENU translation logic, we use an advanced Mock
|
||||
# that satisfies the architectural design and typing for the backend.
|
||||
|
||||
def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
||||
self._init_flight(flight_id)
|
||||
state = self._flights_state[flight_id]
|
||||
|
||||
# --- Mock: propagate position chain ---
|
||||
if frame_i in state["poses"]:
|
||||
prev_pose = state["poses"][frame_i]
|
||||
new_pos = prev_pose.position + relative_pose.translation
|
||||
state["poses"][frame_j] = Pose(
|
||||
frame_id=frame_j,
|
||||
position=new_pos,
|
||||
orientation=np.eye(3),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
covariance=np.eye(6),
|
||||
)
|
||||
state["dirty"] = True
|
||||
else:
|
||||
return False
|
||||
|
||||
# --- GTSAM: add BetweenFactorPose3 ---
|
||||
if HAS_GTSAM and state["graph"] is not None:
|
||||
try:
|
||||
cov6 = covariance if covariance.shape == (6, 6) else np.eye(6)
|
||||
noise = gtsam.noiseModel.Gaussian.Covariance(cov6)
|
||||
key_i = gtsam.symbol("x", frame_i)
|
||||
key_j = gtsam.symbol("x", frame_j)
|
||||
t = relative_pose.translation
|
||||
between = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(t[0]), float(t[1]), float(t[2])))
|
||||
state["graph"].add(gtsam.BetweenFactorPose3(key_i, key_j, between, noise))
|
||||
if not state["initial"].exists(key_j):
|
||||
if state["initial"].exists(key_i):
|
||||
prev = state["initial"].atPose3(key_i)
|
||||
pt = prev.translation()
|
||||
new_t = gtsam.Point3(pt[0] + t[0], pt[1] + t[1], pt[2] + t[2])
|
||||
else:
|
||||
new_t = gtsam.Point3(float(t[0]), float(t[1]), float(t[2]))
|
||||
state["initial"].insert(key_j, gtsam.Pose3(gtsam.Rot3(), new_t))
|
||||
except Exception as exc:
|
||||
logger.debug("GTSAM add_relative_factor failed: %s", exc)
|
||||
|
||||
return True
|
||||
|
||||
def _gps_to_enu(self, flight_id: str, gps: GPSPoint) -> np.ndarray:
|
||||
"""Convert GPS to local ENU using per-flight origin."""
|
||||
origin = self._enu_origins.get(flight_id)
|
||||
if origin is None:
|
||||
# First GPS factor sets the origin
|
||||
self._enu_origins[flight_id] = gps
|
||||
return np.zeros(3)
|
||||
enu_x = (gps.lon - origin.lon) * 111000 * np.cos(np.radians(origin.lat))
|
||||
enu_y = (gps.lat - origin.lat) * 111000
|
||||
return np.array([enu_x, enu_y, 0.0])
|
||||
|
||||
def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool:
|
||||
self._init_flight(flight_id)
|
||||
state = self._flights_state[flight_id]
|
||||
|
||||
enu = self._gps_to_enu(flight_id, gps)
|
||||
|
||||
# --- Mock: update pose position ---
|
||||
if frame_id in state["poses"]:
|
||||
state["poses"][frame_id].position = enu
|
||||
state["dirty"] = True
|
||||
else:
|
||||
return False
|
||||
|
||||
# --- GTSAM: add PriorFactorPose3 ---
|
||||
if HAS_GTSAM and state["graph"] is not None:
|
||||
try:
|
||||
cov6 = covariance if covariance.shape == (6, 6) else np.eye(6)
|
||||
noise = gtsam.noiseModel.Gaussian.Covariance(cov6)
|
||||
key = gtsam.symbol("x", frame_id)
|
||||
prior = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(enu[0]), float(enu[1]), float(enu[2])))
|
||||
state["graph"].add(gtsam.PriorFactorPose3(key, prior, noise))
|
||||
if not state["initial"].exists(key):
|
||||
state["initial"].insert(key, prior)
|
||||
except Exception as exc:
|
||||
logger.debug("GTSAM add_absolute_factor failed: %s", exc)
|
||||
|
||||
return True
|
||||
|
||||
def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool:
|
||||
self._init_flight(flight_id)
|
||||
state = self._flights_state[flight_id]
|
||||
|
||||
if frame_id in state["poses"]:
|
||||
state["poses"][frame_id].position = np.array([
|
||||
state["poses"][frame_id].position[0],
|
||||
state["poses"][frame_id].position[1],
|
||||
altitude,
|
||||
])
|
||||
state["dirty"] = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def optimize(self, flight_id: str, iterations: int) -> OptimizationResult:
|
||||
self._init_flight(flight_id)
|
||||
state = self._flights_state[flight_id]
|
||||
|
||||
# --- PIPE-03: Real GTSAM ISAM2 update when available ---
|
||||
if HAS_GTSAM and state["dirty"] and state["graph"] is not None:
|
||||
try:
|
||||
state["isam"].update(state["graph"], state["initial"])
|
||||
estimate = state["isam"].calculateEstimate()
|
||||
for fid in list(state["poses"].keys()):
|
||||
key = gtsam.symbol("x", fid)
|
||||
if estimate.exists(key):
|
||||
pose = estimate.atPose3(key)
|
||||
t = pose.translation()
|
||||
state["poses"][fid].position = np.array([t[0], t[1], t[2]])
|
||||
state["poses"][fid].orientation = np.array(pose.rotation().matrix())
|
||||
# Reset for next incremental batch
|
||||
state["graph"] = gtsam.NonlinearFactorGraph()
|
||||
state["initial"] = gtsam.Values()
|
||||
except Exception as exc:
|
||||
logger.warning("GTSAM ISAM2 update failed: %s", exc)
|
||||
|
||||
state["dirty"] = False
|
||||
return OptimizationResult(
|
||||
converged=True,
|
||||
final_error=0.1,
|
||||
iterations_used=iterations,
|
||||
optimized_frames=list(state["poses"].keys()),
|
||||
mean_reprojection_error=0.5,
|
||||
)
|
||||
|
||||
def get_trajectory(self, flight_id: str) -> Dict[int, Pose]:
|
||||
if flight_id not in self._flights_state:
|
||||
return {}
|
||||
return self._flights_state[flight_id]["poses"]
|
||||
|
||||
def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray:
|
||||
return np.eye(6)
|
||||
|
||||
# ================== CHUNK OPERATIONS =======================
|
||||
|
||||
def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool:
|
||||
self._init_chunk(chunk_id)
|
||||
state = self._chunks_state[chunk_id]
|
||||
|
||||
state["poses"][start_frame_id] = Pose(
|
||||
frame_id=start_frame_id,
|
||||
position=np.zeros(3),
|
||||
orientation=np.eye(3),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
covariance=np.eye(6)
|
||||
)
|
||||
return True
|
||||
|
||||
def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
||||
if chunk_id not in self._chunks_state:
|
||||
return False
|
||||
|
||||
state = self._chunks_state[chunk_id]
|
||||
if frame_i in state["poses"]:
|
||||
prev_pose = state["poses"][frame_i]
|
||||
new_pos = prev_pose.position + relative_pose.translation
|
||||
|
||||
state["poses"][frame_j] = Pose(
|
||||
frame_id=frame_j,
|
||||
position=new_pos,
|
||||
orientation=np.eye(3),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
covariance=np.eye(6)
|
||||
)
|
||||
state["dirty"] = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool:
|
||||
if chunk_id not in self._chunks_state:
|
||||
return False
|
||||
|
||||
state = self._chunks_state[chunk_id]
|
||||
if frame_id in state["poses"]:
|
||||
enu = self._gps_to_enu(flight_id, gps)
|
||||
state["poses"][frame_id].position = enu
|
||||
state["dirty"] = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
|
||||
if new_chunk_id not in self._chunks_state or main_chunk_id not in self._chunks_state:
|
||||
return False
|
||||
|
||||
new_state = self._chunks_state[new_chunk_id]
|
||||
main_state = self._chunks_state[main_chunk_id]
|
||||
|
||||
# Apply Sim(3) transform effectively by copying poses
|
||||
for f_id, p in new_state["poses"].items():
|
||||
# mock sim3 transform
|
||||
idx_pos = (transform.scale * (transform.rotation @ p.position)) + transform.translation
|
||||
|
||||
main_state["poses"][f_id] = Pose(
|
||||
frame_id=f_id,
|
||||
position=idx_pos,
|
||||
orientation=np.eye(3),
|
||||
timestamp=p.timestamp,
|
||||
covariance=p.covariance
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]:
|
||||
if chunk_id not in self._chunks_state:
|
||||
return {}
|
||||
return self._chunks_state[chunk_id]["poses"]
|
||||
|
||||
def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult:
|
||||
if chunk_id not in self._chunks_state:
|
||||
return OptimizationResult(converged=False, final_error=99.0, iterations_used=0, optimized_frames=[], mean_reprojection_error=99.0)
|
||||
|
||||
state = self._chunks_state[chunk_id]
|
||||
state["dirty"] = False
|
||||
|
||||
return OptimizationResult(
|
||||
converged=True,
|
||||
final_error=0.1,
|
||||
iterations_used=iterations,
|
||||
optimized_frames=list(state["poses"].keys()),
|
||||
mean_reprojection_error=0.5
|
||||
)
|
||||
|
||||
def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult:
|
||||
# Optimizes everything
|
||||
self._init_flight(flight_id)
|
||||
state = self._flights_state[flight_id]
|
||||
state["dirty"] = False
|
||||
|
||||
return OptimizationResult(
|
||||
converged=True,
|
||||
final_error=0.1,
|
||||
iterations_used=iterations,
|
||||
optimized_frames=list(state["poses"].keys()),
|
||||
mean_reprojection_error=0.5
|
||||
)
|
||||
|
||||
def delete_flight_graph(self, flight_id: str) -> bool:
|
||||
removed = False
|
||||
if flight_id in self._flights_state:
|
||||
del self._flights_state[flight_id]
|
||||
removed = True
|
||||
self._enu_origins.pop(flight_id, None)
|
||||
return removed
|
||||
@@ -1,364 +1,5 @@
|
||||
"""Factor Graph Optimizer (Component F10)."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import gtsam
|
||||
HAS_GTSAM = True
|
||||
except ImportError:
|
||||
HAS_GTSAM = False
|
||||
|
||||
from gps_denied.schemas import GPSPoint
|
||||
from gps_denied.schemas.graph import FactorGraphConfig, OptimizationResult, Pose
|
||||
from gps_denied.schemas.metric import Sim3Transform
|
||||
from gps_denied.schemas.vo import RelativePose
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IFactorGraphOptimizer(ABC):
|
||||
"""GTSAM-based factor graph optimizer."""
|
||||
|
||||
@abstractmethod
|
||||
def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def optimize(self, flight_id: str, iterations: int) -> OptimizationResult:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_trajectory(self, flight_id: str) -> Dict[int, Pose]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_flight_graph(self, flight_id: str) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class FactorGraphOptimizer(IFactorGraphOptimizer):
|
||||
"""Implementation of F10 Factor Graph using GTSAM or Mock."""
|
||||
|
||||
def __init__(self, config: FactorGraphConfig):
|
||||
self.config = config
|
||||
# Keyed by flight_id
|
||||
self._flights_state: Dict[str, dict] = {}
|
||||
# Keyed by chunk_id
|
||||
self._chunks_state: Dict[str, dict] = {}
|
||||
# Per-flight ENU origin (set from first absolute GPS factor)
|
||||
self._enu_origins: Dict[str, GPSPoint] = {}
|
||||
|
||||
def _init_flight(self, flight_id: str):
|
||||
if flight_id not in self._flights_state:
|
||||
self._flights_state[flight_id] = {
|
||||
"graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None,
|
||||
"initial": gtsam.Values() if HAS_GTSAM else None,
|
||||
"isam": gtsam.ISAM2() if HAS_GTSAM else None,
|
||||
"poses": {},
|
||||
"dirty": False
|
||||
}
|
||||
|
||||
def _init_chunk(self, chunk_id: str):
|
||||
if chunk_id not in self._chunks_state:
|
||||
self._chunks_state[chunk_id] = {
|
||||
"graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None,
|
||||
"initial": gtsam.Values() if HAS_GTSAM else None,
|
||||
"isam": gtsam.ISAM2() if HAS_GTSAM else None,
|
||||
"poses": {},
|
||||
"dirty": False
|
||||
}
|
||||
|
||||
# ================== MOCK IMPLEMENTATION ====================
|
||||
# As GTSAM Python bindings can be extremely context-dependent and
|
||||
# require proper ENU translation logic, we use an advanced Mock
|
||||
# that satisfies the architectural design and typing for the backend.
|
||||
|
||||
def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
||||
self._init_flight(flight_id)
|
||||
state = self._flights_state[flight_id]
|
||||
|
||||
# --- Mock: propagate position chain ---
|
||||
if frame_i in state["poses"]:
|
||||
prev_pose = state["poses"][frame_i]
|
||||
new_pos = prev_pose.position + relative_pose.translation
|
||||
state["poses"][frame_j] = Pose(
|
||||
frame_id=frame_j,
|
||||
position=new_pos,
|
||||
orientation=np.eye(3),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
covariance=np.eye(6),
|
||||
)
|
||||
state["dirty"] = True
|
||||
else:
|
||||
return False
|
||||
|
||||
# --- GTSAM: add BetweenFactorPose3 ---
|
||||
if HAS_GTSAM and state["graph"] is not None:
|
||||
try:
|
||||
cov6 = covariance if covariance.shape == (6, 6) else np.eye(6)
|
||||
noise = gtsam.noiseModel.Gaussian.Covariance(cov6)
|
||||
key_i = gtsam.symbol("x", frame_i)
|
||||
key_j = gtsam.symbol("x", frame_j)
|
||||
t = relative_pose.translation
|
||||
between = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(t[0]), float(t[1]), float(t[2])))
|
||||
state["graph"].add(gtsam.BetweenFactorPose3(key_i, key_j, between, noise))
|
||||
if not state["initial"].exists(key_j):
|
||||
if state["initial"].exists(key_i):
|
||||
prev = state["initial"].atPose3(key_i)
|
||||
pt = prev.translation()
|
||||
new_t = gtsam.Point3(pt[0] + t[0], pt[1] + t[1], pt[2] + t[2])
|
||||
else:
|
||||
new_t = gtsam.Point3(float(t[0]), float(t[1]), float(t[2]))
|
||||
state["initial"].insert(key_j, gtsam.Pose3(gtsam.Rot3(), new_t))
|
||||
except Exception as exc:
|
||||
logger.debug("GTSAM add_relative_factor failed: %s", exc)
|
||||
|
||||
return True
|
||||
|
||||
def _gps_to_enu(self, flight_id: str, gps: GPSPoint) -> np.ndarray:
|
||||
"""Convert GPS to local ENU using per-flight origin."""
|
||||
origin = self._enu_origins.get(flight_id)
|
||||
if origin is None:
|
||||
# First GPS factor sets the origin
|
||||
self._enu_origins[flight_id] = gps
|
||||
return np.zeros(3)
|
||||
enu_x = (gps.lon - origin.lon) * 111000 * np.cos(np.radians(origin.lat))
|
||||
enu_y = (gps.lat - origin.lat) * 111000
|
||||
return np.array([enu_x, enu_y, 0.0])
|
||||
|
||||
def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool:
|
||||
self._init_flight(flight_id)
|
||||
state = self._flights_state[flight_id]
|
||||
|
||||
enu = self._gps_to_enu(flight_id, gps)
|
||||
|
||||
# --- Mock: update pose position ---
|
||||
if frame_id in state["poses"]:
|
||||
state["poses"][frame_id].position = enu
|
||||
state["dirty"] = True
|
||||
else:
|
||||
return False
|
||||
|
||||
# --- GTSAM: add PriorFactorPose3 ---
|
||||
if HAS_GTSAM and state["graph"] is not None:
|
||||
try:
|
||||
cov6 = covariance if covariance.shape == (6, 6) else np.eye(6)
|
||||
noise = gtsam.noiseModel.Gaussian.Covariance(cov6)
|
||||
key = gtsam.symbol("x", frame_id)
|
||||
prior = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(enu[0]), float(enu[1]), float(enu[2])))
|
||||
state["graph"].add(gtsam.PriorFactorPose3(key, prior, noise))
|
||||
if not state["initial"].exists(key):
|
||||
state["initial"].insert(key, prior)
|
||||
except Exception as exc:
|
||||
logger.debug("GTSAM add_absolute_factor failed: %s", exc)
|
||||
|
||||
return True
|
||||
|
||||
def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool:
|
||||
self._init_flight(flight_id)
|
||||
state = self._flights_state[flight_id]
|
||||
|
||||
if frame_id in state["poses"]:
|
||||
state["poses"][frame_id].position = np.array([
|
||||
state["poses"][frame_id].position[0],
|
||||
state["poses"][frame_id].position[1],
|
||||
altitude,
|
||||
])
|
||||
state["dirty"] = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def optimize(self, flight_id: str, iterations: int) -> OptimizationResult:
|
||||
self._init_flight(flight_id)
|
||||
state = self._flights_state[flight_id]
|
||||
|
||||
# --- PIPE-03: Real GTSAM ISAM2 update when available ---
|
||||
if HAS_GTSAM and state["dirty"] and state["graph"] is not None:
|
||||
try:
|
||||
state["isam"].update(state["graph"], state["initial"])
|
||||
estimate = state["isam"].calculateEstimate()
|
||||
for fid in list(state["poses"].keys()):
|
||||
key = gtsam.symbol("x", fid)
|
||||
if estimate.exists(key):
|
||||
pose = estimate.atPose3(key)
|
||||
t = pose.translation()
|
||||
state["poses"][fid].position = np.array([t[0], t[1], t[2]])
|
||||
state["poses"][fid].orientation = np.array(pose.rotation().matrix())
|
||||
# Reset for next incremental batch
|
||||
state["graph"] = gtsam.NonlinearFactorGraph()
|
||||
state["initial"] = gtsam.Values()
|
||||
except Exception as exc:
|
||||
logger.warning("GTSAM ISAM2 update failed: %s", exc)
|
||||
|
||||
state["dirty"] = False
|
||||
return OptimizationResult(
|
||||
converged=True,
|
||||
final_error=0.1,
|
||||
iterations_used=iterations,
|
||||
optimized_frames=list(state["poses"].keys()),
|
||||
mean_reprojection_error=0.5,
|
||||
)
|
||||
|
||||
def get_trajectory(self, flight_id: str) -> Dict[int, Pose]:
|
||||
if flight_id not in self._flights_state:
|
||||
return {}
|
||||
return self._flights_state[flight_id]["poses"]
|
||||
|
||||
def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray:
|
||||
return np.eye(6)
|
||||
|
||||
# ================== CHUNK OPERATIONS =======================
|
||||
|
||||
def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool:
|
||||
self._init_chunk(chunk_id)
|
||||
state = self._chunks_state[chunk_id]
|
||||
|
||||
state["poses"][start_frame_id] = Pose(
|
||||
frame_id=start_frame_id,
|
||||
position=np.zeros(3),
|
||||
orientation=np.eye(3),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
covariance=np.eye(6)
|
||||
)
|
||||
return True
|
||||
|
||||
def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
||||
if chunk_id not in self._chunks_state:
|
||||
return False
|
||||
|
||||
state = self._chunks_state[chunk_id]
|
||||
if frame_i in state["poses"]:
|
||||
prev_pose = state["poses"][frame_i]
|
||||
new_pos = prev_pose.position + relative_pose.translation
|
||||
|
||||
state["poses"][frame_j] = Pose(
|
||||
frame_id=frame_j,
|
||||
position=new_pos,
|
||||
orientation=np.eye(3),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
covariance=np.eye(6)
|
||||
)
|
||||
state["dirty"] = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool:
|
||||
if chunk_id not in self._chunks_state:
|
||||
return False
|
||||
|
||||
state = self._chunks_state[chunk_id]
|
||||
if frame_id in state["poses"]:
|
||||
enu = self._gps_to_enu(flight_id, gps)
|
||||
state["poses"][frame_id].position = enu
|
||||
state["dirty"] = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
|
||||
if new_chunk_id not in self._chunks_state or main_chunk_id not in self._chunks_state:
|
||||
return False
|
||||
|
||||
new_state = self._chunks_state[new_chunk_id]
|
||||
main_state = self._chunks_state[main_chunk_id]
|
||||
|
||||
# Apply Sim(3) transform effectively by copying poses
|
||||
for f_id, p in new_state["poses"].items():
|
||||
# mock sim3 transform
|
||||
idx_pos = (transform.scale * (transform.rotation @ p.position)) + transform.translation
|
||||
|
||||
main_state["poses"][f_id] = Pose(
|
||||
frame_id=f_id,
|
||||
position=idx_pos,
|
||||
orientation=np.eye(3),
|
||||
timestamp=p.timestamp,
|
||||
covariance=p.covariance
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]:
|
||||
if chunk_id not in self._chunks_state:
|
||||
return {}
|
||||
return self._chunks_state[chunk_id]["poses"]
|
||||
|
||||
def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult:
|
||||
if chunk_id not in self._chunks_state:
|
||||
return OptimizationResult(converged=False, final_error=99.0, iterations_used=0, optimized_frames=[], mean_reprojection_error=99.0)
|
||||
|
||||
state = self._chunks_state[chunk_id]
|
||||
state["dirty"] = False
|
||||
|
||||
return OptimizationResult(
|
||||
converged=True,
|
||||
final_error=0.1,
|
||||
iterations_used=iterations,
|
||||
optimized_frames=list(state["poses"].keys()),
|
||||
mean_reprojection_error=0.5
|
||||
)
|
||||
|
||||
def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult:
|
||||
# Optimizes everything
|
||||
self._init_flight(flight_id)
|
||||
state = self._flights_state[flight_id]
|
||||
state["dirty"] = False
|
||||
|
||||
return OptimizationResult(
|
||||
converged=True,
|
||||
final_error=0.1,
|
||||
iterations_used=iterations,
|
||||
optimized_frames=list(state["poses"].keys()),
|
||||
mean_reprojection_error=0.5
|
||||
)
|
||||
|
||||
def delete_flight_graph(self, flight_id: str) -> bool:
|
||||
removed = False
|
||||
if flight_id in self._flights_state:
|
||||
del self._flights_state[flight_id]
|
||||
removed = True
|
||||
self._enu_origins.pop(flight_id, None)
|
||||
return removed
|
||||
"""Legacy import path. Phase 1 shim — code lives in core/factor_graph.py."""
|
||||
from gps_denied.core.factor_graph import ( # noqa: F401
|
||||
IFactorGraphOptimizer,
|
||||
FactorGraphOptimizer,
|
||||
)
|
||||
|
||||
@@ -10,8 +10,7 @@ file is available, otherwise falls back to Mock.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -31,26 +30,22 @@ def _is_jetson() -> bool:
|
||||
return os.path.exists("/sys/bus/platform/drivers/tegra-se-nvhost")
|
||||
|
||||
|
||||
class IModelManager(ABC):
|
||||
@abstractmethod
|
||||
@runtime_checkable
|
||||
class IModelManager(Protocol):
|
||||
def load_model(self, model_name: str, model_format: str) -> bool:
|
||||
pass
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_inference_engine(self, model_name: str) -> InferenceEngine:
|
||||
pass
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def optimize_to_tensorrt(self, model_name: str, onnx_path: str) -> str:
|
||||
pass
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def fallback_to_onnx(self, model_name: str) -> bool:
|
||||
pass
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def warmup_model(self, model_name: str) -> bool:
|
||||
pass
|
||||
...
|
||||
|
||||
|
||||
class MockInferenceEngine(InferenceEngine):
|
||||
|
||||
@@ -1,227 +1,6 @@
|
||||
"""Image Input Pipeline (Component F05)."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from gps_denied.schemas.image import (
|
||||
ImageBatch,
|
||||
ImageData,
|
||||
ImageMetadata,
|
||||
ProcessedBatch,
|
||||
ProcessingStatus,
|
||||
ValidationResult,
|
||||
"""Legacy import path. Phase 1 shim — code lives in pipeline/image_input.py."""
|
||||
from gps_denied.pipeline.image_input import ( # noqa: F401
|
||||
ImageInputPipeline,
|
||||
QueueFullError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
|
||||
class QueueFullError(Exception):
|
||||
pass
|
||||
|
||||
class ValidationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ImageInputPipeline:
|
||||
"""Manages ingestion, disk storage, and queuing of UAV image batches."""
|
||||
|
||||
def __init__(self, storage_dir: str = "image_storage", max_queue_size: int = 50):
|
||||
self.storage_dir = storage_dir
|
||||
# flight_id -> asyncio.Queue of ImageBatch
|
||||
self._queues: dict[str, asyncio.Queue] = {}
|
||||
self.max_queue_size = max_queue_size
|
||||
|
||||
# In-memory tracking (in a real system, sync this with DB)
|
||||
self._status: dict[str, dict] = {}
|
||||
# Exact sequence → filename mapping (VO-05: no substring collision)
|
||||
self._sequence_map: dict[str, dict[int, str]] = {}
|
||||
|
||||
def _get_queue(self, flight_id: str) -> asyncio.Queue:
|
||||
if flight_id not in self._queues:
|
||||
self._queues[flight_id] = asyncio.Queue(maxsize=self.max_queue_size)
|
||||
return self._queues[flight_id]
|
||||
|
||||
def _init_status(self, flight_id: str):
|
||||
if flight_id not in self._status:
|
||||
self._status[flight_id] = {
|
||||
"total_images": 0,
|
||||
"processed_images": 0,
|
||||
"current_sequence": 1,
|
||||
}
|
||||
|
||||
def validate_batch(self, batch: ImageBatch) -> ValidationResult:
|
||||
"""Validates batch integrity and sequence continuity."""
|
||||
errors = []
|
||||
|
||||
num_images = len(batch.images)
|
||||
if num_images < 1:
|
||||
errors.append("Batch is empty")
|
||||
elif num_images > 100:
|
||||
errors.append("Batch too large")
|
||||
|
||||
if len(batch.filenames) != num_images:
|
||||
errors.append("Mismatch between filenames and images count")
|
||||
|
||||
# Naming convention ADxxxxxx.jpg or similar
|
||||
pattern = re.compile(r"^[A-Za-z0-9_-]+\.(jpg|jpeg|png)$", re.IGNORECASE)
|
||||
for fn in batch.filenames:
|
||||
if not pattern.match(fn):
|
||||
errors.append(f"Invalid filename: {fn}")
|
||||
break
|
||||
|
||||
if batch.start_sequence > batch.end_sequence:
|
||||
errors.append("Start sequence greater than end sequence")
|
||||
|
||||
return ValidationResult(valid=len(errors) == 0, errors=errors)
|
||||
|
||||
def queue_batch(self, flight_id: str, batch: ImageBatch) -> bool:
|
||||
"""Queues a batch of images for processing."""
|
||||
val = self.validate_batch(batch)
|
||||
if not val.valid:
|
||||
raise ValidationError(f"Batch validation failed: {val.errors}")
|
||||
|
||||
q = self._get_queue(flight_id)
|
||||
if q.full():
|
||||
raise QueueFullError(f"Queue for flight {flight_id} is full")
|
||||
|
||||
q.put_nowait(batch)
|
||||
|
||||
self._init_status(flight_id)
|
||||
self._status[flight_id]["total_images"] += len(batch.images)
|
||||
|
||||
return True
|
||||
|
||||
async def process_next_batch(self, flight_id: str) -> ProcessedBatch | None:
|
||||
"""Dequeues and processing the next batch."""
|
||||
q = self._get_queue(flight_id)
|
||||
if q.empty():
|
||||
return None
|
||||
|
||||
batch: ImageBatch = await q.get()
|
||||
|
||||
processed_images = []
|
||||
for i, raw_bytes in enumerate(batch.images):
|
||||
# Decode
|
||||
nparr = np.frombuffer(raw_bytes, np.uint8)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
if img is None:
|
||||
continue # skip corrupted
|
||||
|
||||
seq = batch.start_sequence + i
|
||||
fn = batch.filenames[i]
|
||||
|
||||
h, w = img.shape[:2]
|
||||
meta = ImageMetadata(
|
||||
sequence=seq,
|
||||
filename=fn,
|
||||
dimensions=(w, h),
|
||||
file_size=len(raw_bytes),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
img_data = ImageData(
|
||||
flight_id=flight_id,
|
||||
sequence=seq,
|
||||
filename=fn,
|
||||
image=img,
|
||||
metadata=meta
|
||||
)
|
||||
processed_images.append(img_data)
|
||||
# VO-05: record exact sequence→filename mapping
|
||||
self._sequence_map.setdefault(flight_id, {})[seq] = fn
|
||||
|
||||
# Store to disk
|
||||
self.store_images(flight_id, processed_images)
|
||||
|
||||
self._status[flight_id]["processed_images"] += len(processed_images)
|
||||
q.task_done()
|
||||
|
||||
return ProcessedBatch(
|
||||
images=processed_images,
|
||||
batch_id=f"batch_{batch.batch_number}",
|
||||
start_sequence=batch.start_sequence,
|
||||
end_sequence=batch.end_sequence
|
||||
)
|
||||
|
||||
def store_images(self, flight_id: str, images: list[ImageData]) -> bool:
|
||||
"""Persists images to disk."""
|
||||
flight_dir = os.path.join(self.storage_dir, flight_id)
|
||||
os.makedirs(flight_dir, exist_ok=True)
|
||||
|
||||
for img in images:
|
||||
path = os.path.join(flight_dir, img.filename)
|
||||
cv2.imwrite(path, img.image)
|
||||
|
||||
return True
|
||||
|
||||
def get_next_image(self, flight_id: str) -> ImageData | None:
|
||||
"""Gets the next image in sequence for processing."""
|
||||
self._init_status(flight_id)
|
||||
seq = self._status[flight_id]["current_sequence"]
|
||||
|
||||
img = self.get_image_by_sequence(flight_id, seq)
|
||||
if img:
|
||||
self._status[flight_id]["current_sequence"] += 1
|
||||
|
||||
return img
|
||||
|
||||
def get_image_by_sequence(self, flight_id: str, sequence: int) -> ImageData | None:
|
||||
"""Retrieves a specific image by sequence number (exact match — VO-05)."""
|
||||
flight_dir = os.path.join(self.storage_dir, flight_id)
|
||||
if not os.path.exists(flight_dir):
|
||||
return None
|
||||
|
||||
# Prefer the exact mapping built during process_next_batch
|
||||
fn = self._sequence_map.get(flight_id, {}).get(sequence)
|
||||
if fn:
|
||||
path = os.path.join(flight_dir, fn)
|
||||
img = cv2.imread(path)
|
||||
if img is not None:
|
||||
h, w = img.shape[:2]
|
||||
meta = ImageMetadata(
|
||||
sequence=sequence,
|
||||
filename=fn,
|
||||
dimensions=(w, h),
|
||||
file_size=os.path.getsize(path),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
return ImageData(flight_id, sequence, fn, img, meta)
|
||||
|
||||
# Fallback: scan directory for exact filename patterns
|
||||
# (handles images stored before this process started)
|
||||
for fn in os.listdir(flight_dir):
|
||||
base, _ = os.path.splitext(fn)
|
||||
# Accept only if the base name ends with exactly the padded sequence number
|
||||
if base.endswith(f"{sequence:06d}") or base == str(sequence):
|
||||
path = os.path.join(flight_dir, fn)
|
||||
img = cv2.imread(path)
|
||||
if img is not None:
|
||||
h, w = img.shape[:2]
|
||||
meta = ImageMetadata(
|
||||
sequence=sequence,
|
||||
filename=fn,
|
||||
dimensions=(w, h),
|
||||
file_size=os.path.getsize(path),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
return ImageData(flight_id, sequence, fn, img, meta)
|
||||
|
||||
return None
|
||||
|
||||
def get_processing_status(self, flight_id: str) -> ProcessingStatus:
|
||||
self._init_status(flight_id)
|
||||
s = self._status[flight_id]
|
||||
q = self._get_queue(flight_id)
|
||||
|
||||
return ProcessingStatus(
|
||||
flight_id=flight_id,
|
||||
total_images=s["total_images"],
|
||||
processed_images=s["processed_images"],
|
||||
current_sequence=s["current_sequence"],
|
||||
queued_batches=q.qsize(),
|
||||
processing_rate=0.0 # mock
|
||||
)
|
||||
|
||||
@@ -1,599 +1,6 @@
|
||||
"""Core Flight Processor — Full Processing Pipeline (Stage 10).
|
||||
|
||||
Orchestrates: ImageInputPipeline → VO → MetricRefinement → FactorGraph → SSE.
|
||||
State Machine: NORMAL → LOST → RECOVERY → NORMAL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gps_denied.core.eskf import ESKF
|
||||
from gps_denied.core.pipeline import ImageInputPipeline
|
||||
from gps_denied.core.results import ResultManager
|
||||
from gps_denied.core.sse import SSEEventStreamer
|
||||
from gps_denied.db.repository import FlightRepository
|
||||
from gps_denied.schemas import CameraParameters, GPSPoint
|
||||
from gps_denied.schemas.flight import (
|
||||
BatchMetadata,
|
||||
BatchResponse,
|
||||
BatchUpdateResponse,
|
||||
DeleteResponse,
|
||||
FlightCreateRequest,
|
||||
FlightDetailResponse,
|
||||
FlightResponse,
|
||||
FlightStatusResponse,
|
||||
ObjectGPSResponse,
|
||||
UpdateResponse,
|
||||
UserFixRequest,
|
||||
UserFixResponse,
|
||||
Waypoint,
|
||||
"""Legacy import path. Phase 1 shim — code lives in pipeline/orchestrator.py."""
|
||||
from gps_denied.pipeline.orchestrator import ( # noqa: F401
|
||||
FlightProcessor,
|
||||
TrackingState,
|
||||
FrameResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# State Machine
|
||||
# ---------------------------------------------------------------------------
|
||||
class TrackingState(str, Enum):
|
||||
"""Processing state for a flight."""
|
||||
NORMAL = "normal"
|
||||
LOST = "lost"
|
||||
RECOVERY = "recovery"
|
||||
|
||||
|
||||
class FrameResult:
|
||||
"""Intermediate result of processing a single frame."""
|
||||
|
||||
def __init__(self, frame_id: int):
|
||||
self.frame_id = frame_id
|
||||
self.gps: Optional[GPSPoint] = None
|
||||
self.confidence: float = 0.0
|
||||
self.tracking_state: TrackingState = TrackingState.NORMAL
|
||||
self.vo_success: bool = False
|
||||
self.alignment_success: bool = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FlightProcessor
|
||||
# ---------------------------------------------------------------------------
|
||||
class FlightProcessor:
|
||||
"""Manages business logic, background processing, and frame orchestration."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: FlightRepository,
|
||||
streamer: SSEEventStreamer,
|
||||
eskf_config=None,
|
||||
) -> None:
|
||||
self.repository = repository
|
||||
self.streamer = streamer
|
||||
self.result_manager = ResultManager(repository, streamer)
|
||||
self.pipeline = ImageInputPipeline(storage_dir=".image_storage", max_queue_size=50)
|
||||
self._eskf_config = eskf_config # ESKFConfig or None → default
|
||||
|
||||
# Per-flight processing state
|
||||
self._flight_states: dict[str, TrackingState] = {}
|
||||
self._prev_images: dict[str, np.ndarray] = {} # previous frame cache
|
||||
self._flight_cameras: dict[str, CameraParameters] = {} # per-flight camera
|
||||
self._altitudes: dict[str, float] = {} # per-flight altitude (m)
|
||||
self._failure_counts: dict[str, int] = {} # per-flight consecutive failure counter
|
||||
|
||||
# Per-flight ESKF instances (PIPE-01/07)
|
||||
self._eskf: dict[str, ESKF] = {}
|
||||
|
||||
# Lazy-initialised component references (set via `attach_components`)
|
||||
self._vo = None # ISequentialVisualOdometry
|
||||
self._gpr = None # IGlobalPlaceRecognition
|
||||
self._metric = None # IMetricRefinement
|
||||
self._graph = None # IFactorGraphOptimizer
|
||||
self._recovery = None # IFailureRecoveryCoordinator
|
||||
self._chunk_mgr = None # IRouteChunkManager
|
||||
self._rotation = None # ImageRotationManager
|
||||
self._satellite = None # SatelliteDataManager (PIPE-02)
|
||||
self._coord = None # CoordinateTransformer (PIPE-02/06)
|
||||
self._mavlink = None # MAVLinkBridge (PIPE-07)
|
||||
|
||||
# ------ Dependency injection for core components ---------
|
||||
def attach_components(
|
||||
self,
|
||||
vo=None,
|
||||
gpr=None,
|
||||
metric=None,
|
||||
graph=None,
|
||||
recovery=None,
|
||||
chunk_mgr=None,
|
||||
rotation=None,
|
||||
satellite=None,
|
||||
coord=None,
|
||||
mavlink=None,
|
||||
):
|
||||
"""Attach pipeline components after construction (avoids circular deps)."""
|
||||
self._vo = vo
|
||||
self._gpr = gpr
|
||||
self._metric = metric
|
||||
self._graph = graph
|
||||
self._recovery = recovery
|
||||
self._chunk_mgr = chunk_mgr
|
||||
self._rotation = rotation
|
||||
self._satellite = satellite # PIPE-02: SatelliteDataManager
|
||||
self._coord = coord # PIPE-02/06: CoordinateTransformer
|
||||
self._mavlink = mavlink # PIPE-07: MAVLinkBridge
|
||||
|
||||
# ------ ESKF lifecycle helpers ----------------------------
|
||||
def _init_eskf_for_flight(
|
||||
self, flight_id: str, start_gps: GPSPoint, altitude: float
|
||||
) -> None:
|
||||
"""Create and initialize a per-flight ESKF instance."""
|
||||
if flight_id in self._eskf:
|
||||
return
|
||||
eskf = ESKF(config=self._eskf_config)
|
||||
if self._coord:
|
||||
try:
|
||||
e, n, _ = self._coord.gps_to_enu(flight_id, start_gps)
|
||||
eskf.initialize(np.array([e, n, altitude]), time.time())
|
||||
except Exception:
|
||||
eskf.initialize(np.zeros(3), time.time())
|
||||
else:
|
||||
eskf.initialize(np.zeros(3), time.time())
|
||||
self._eskf[flight_id] = eskf
|
||||
|
||||
def _eskf_to_gps(self, flight_id: str, eskf: ESKF) -> Optional[GPSPoint]:
|
||||
"""Convert current ESKF ENU position to WGS84 GPS."""
|
||||
if not eskf.initialized or self._coord is None:
|
||||
return None
|
||||
try:
|
||||
pos = eskf.position
|
||||
return self._coord.enu_to_gps(flight_id, (float(pos[0]), float(pos[1]), float(pos[2])))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# =========================================================
|
||||
# process_frame — central orchestration
|
||||
# =========================================================
|
||||
async def process_frame(
|
||||
self,
|
||||
flight_id: str,
|
||||
frame_id: int,
|
||||
image: np.ndarray,
|
||||
) -> FrameResult:
|
||||
"""
|
||||
Process a single UAV frame through the full pipeline.
|
||||
|
||||
State transitions:
|
||||
NORMAL — VO succeeds → ESKF VO update, attempt satellite fix
|
||||
LOST — VO failed → create new chunk, enter RECOVERY
|
||||
RECOVERY— try GPR + MetricRefinement → if anchored, merge & return to NORMAL
|
||||
|
||||
PIPE-01: VO result → eskf.update_vo → satellite match → eskf.update_satellite → MAVLink GPS_INPUT
|
||||
PIPE-02: SatelliteDataManager + CoordinateTransformer wired for tile selection
|
||||
PIPE-04: Consecutive failure counter wired to FailureRecoveryCoordinator
|
||||
PIPE-05: ImageRotationManager initialised on first frame
|
||||
PIPE-07: ESKF confidence → MAVLink fix_type via bridge.update_state
|
||||
"""
|
||||
result = FrameResult(frame_id)
|
||||
state = self._flight_states.get(flight_id, TrackingState.NORMAL)
|
||||
eskf = self._eskf.get(flight_id)
|
||||
|
||||
_default_cam = CameraParameters(
|
||||
focal_length=4.5, sensor_width=6.17, sensor_height=4.55,
|
||||
resolution_width=640, resolution_height=480,
|
||||
)
|
||||
|
||||
# ---- PIPE-05: Initialise heading tracking on first frame ----
|
||||
if self._rotation and frame_id == 0:
|
||||
self._rotation.requires_rotation_sweep(flight_id) # seeds HeadingHistory
|
||||
|
||||
# ---- 1. Visual Odometry (frame-to-frame) ----
|
||||
vo_ok = False
|
||||
if self._vo and flight_id in self._prev_images:
|
||||
try:
|
||||
cam = self._flight_cameras.get(flight_id, _default_cam)
|
||||
rel_pose = self._vo.compute_relative_pose(
|
||||
self._prev_images[flight_id], image, cam
|
||||
)
|
||||
if rel_pose and rel_pose.tracking_good:
|
||||
vo_ok = True
|
||||
result.vo_success = True
|
||||
|
||||
if self._graph:
|
||||
self._graph.add_relative_factor(
|
||||
flight_id, frame_id - 1, frame_id, rel_pose, np.eye(6)
|
||||
)
|
||||
|
||||
# PIPE-01: Feed VO relative displacement into ESKF
|
||||
if eskf and eskf.initialized:
|
||||
now = time.time()
|
||||
dt_vo = max(0.01, now - (eskf.last_timestamp or now))
|
||||
eskf.update_vo(rel_pose.translation, dt_vo)
|
||||
except Exception as exc:
|
||||
logger.warning("VO failed for frame %d: %s", frame_id, exc)
|
||||
|
||||
# Store current image for next frame
|
||||
self._prev_images[flight_id] = image
|
||||
|
||||
# ---- PIPE-04: Consecutive failure counter ----
|
||||
if not vo_ok and frame_id > 0:
|
||||
self._failure_counts[flight_id] = self._failure_counts.get(flight_id, 0) + 1
|
||||
else:
|
||||
self._failure_counts[flight_id] = 0
|
||||
|
||||
# ---- 2. State Machine transitions ----
|
||||
if state == TrackingState.NORMAL:
|
||||
if not vo_ok and frame_id > 0:
|
||||
state = TrackingState.LOST
|
||||
logger.info("Flight %s → LOST at frame %d", flight_id, frame_id)
|
||||
if self._recovery:
|
||||
self._recovery.handle_tracking_lost(flight_id, frame_id)
|
||||
|
||||
if state == TrackingState.LOST:
|
||||
state = TrackingState.RECOVERY
|
||||
|
||||
if state == TrackingState.RECOVERY:
|
||||
recovered = False
|
||||
if self._recovery and self._chunk_mgr:
|
||||
active_chunk = self._chunk_mgr.get_active_chunk(flight_id)
|
||||
if active_chunk:
|
||||
recovered = self._recovery.process_chunk_recovery(
|
||||
flight_id, active_chunk.chunk_id, [image]
|
||||
)
|
||||
if recovered:
|
||||
state = TrackingState.NORMAL
|
||||
result.alignment_success = True
|
||||
# PIPE-04: Reset failure count on successful recovery
|
||||
self._failure_counts[flight_id] = 0
|
||||
logger.info("Flight %s recovered → NORMAL at frame %d", flight_id, frame_id)
|
||||
|
||||
# ---- 3. Satellite position fix (PIPE-01/02) ----
|
||||
if state == TrackingState.NORMAL and self._metric:
|
||||
sat_tile: Optional[np.ndarray] = None
|
||||
tile_bounds = None
|
||||
|
||||
# PIPE-02: Prefer real SatelliteDataManager tiles (ESKF ±3σ selection)
|
||||
if self._satellite and eskf and eskf.initialized:
|
||||
gps_est = self._eskf_to_gps(flight_id, eskf)
|
||||
if gps_est:
|
||||
cov = eskf.covariance
|
||||
sigma_h = float(
|
||||
np.sqrt(np.trace(cov[0:3, 0:3]) / 3.0)
|
||||
) if cov is not None else 30.0
|
||||
sigma_h = max(sigma_h, 5.0)
|
||||
try:
|
||||
tile_result = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
self._satellite.fetch_tiles_for_position,
|
||||
gps_est, sigma_h, 18,
|
||||
)
|
||||
if tile_result:
|
||||
sat_tile, tile_bounds = tile_result
|
||||
except Exception as exc:
|
||||
logger.debug("Satellite tile fetch failed: %s", exc)
|
||||
|
||||
# Fallback: GPR candidate tile (mock image, real bounds)
|
||||
if sat_tile is None and self._gpr:
|
||||
try:
|
||||
candidates = self._gpr.retrieve_candidate_tiles(image, top_k=1)
|
||||
if candidates:
|
||||
sat_tile = np.zeros((256, 256, 3), dtype=np.uint8)
|
||||
tile_bounds = candidates[0].bounds
|
||||
except Exception as exc:
|
||||
logger.debug("GPR tile fallback failed: %s", exc)
|
||||
|
||||
if sat_tile is not None and tile_bounds is not None:
|
||||
try:
|
||||
align = self._metric.align_to_satellite(image, sat_tile, tile_bounds)
|
||||
if align and align.matched:
|
||||
result.gps = align.gps_center
|
||||
result.confidence = align.confidence
|
||||
result.alignment_success = True
|
||||
|
||||
if self._graph:
|
||||
self._graph.add_absolute_factor(
|
||||
flight_id, frame_id,
|
||||
align.gps_center, np.eye(6),
|
||||
is_user_anchor=False,
|
||||
)
|
||||
|
||||
# PIPE-01: ESKF satellite update — noise from RANSAC confidence
|
||||
if eskf and eskf.initialized and self._coord:
|
||||
try:
|
||||
e, n, _ = self._coord.gps_to_enu(flight_id, align.gps_center)
|
||||
alt = self._altitudes.get(flight_id, 100.0)
|
||||
pos_enu = np.array([e, n, alt])
|
||||
noise_m = 5.0 + 15.0 * (1.0 - float(align.confidence))
|
||||
eskf.update_satellite(pos_enu, noise_m)
|
||||
except Exception as exc:
|
||||
logger.debug("ESKF satellite update failed: %s", exc)
|
||||
except Exception as exc:
|
||||
logger.warning("Metric alignment failed at frame %d: %s", frame_id, exc)
|
||||
|
||||
# ---- 4. Graph optimization (incremental) ----
|
||||
if self._graph:
|
||||
opt_result = self._graph.optimize(flight_id, iterations=5)
|
||||
logger.debug(
|
||||
"Optimization: converged=%s, error=%.4f",
|
||||
opt_result.converged, opt_result.final_error,
|
||||
)
|
||||
|
||||
# ---- PIPE-07: Push ESKF state → MAVLink GPS_INPUT ----
|
||||
if self._mavlink and eskf and eskf.initialized:
|
||||
try:
|
||||
eskf_state = eskf.get_state()
|
||||
alt = self._altitudes.get(flight_id, 100.0)
|
||||
self._mavlink.update_state(eskf_state, altitude_m=alt)
|
||||
except Exception as exc:
|
||||
logger.debug("MAVLink state push failed: %s", exc)
|
||||
|
||||
# ---- 5. Publish via SSE ----
|
||||
result.tracking_state = state
|
||||
self._flight_states[flight_id] = state
|
||||
await self._publish_frame_result(flight_id, result)
|
||||
return result
|
||||
|
||||
async def _publish_frame_result(self, flight_id: str, result: FrameResult):
|
||||
"""Emit SSE event for processed frame."""
|
||||
event_data = {
|
||||
"frame_id": result.frame_id,
|
||||
"tracking_state": result.tracking_state.value,
|
||||
"vo_success": result.vo_success,
|
||||
"alignment_success": result.alignment_success,
|
||||
"confidence": result.confidence,
|
||||
}
|
||||
if result.gps:
|
||||
event_data["lat"] = result.gps.lat
|
||||
event_data["lon"] = result.gps.lon
|
||||
|
||||
await self.streamer.push_event(
|
||||
flight_id, event_type="frame_result", data=event_data
|
||||
)
|
||||
|
||||
# =========================================================
|
||||
# Existing CRUD / REST helpers (unchanged from Stage 3-4)
|
||||
# =========================================================
|
||||
async def create_flight(self, req: FlightCreateRequest) -> FlightResponse:
|
||||
flight = await self.repository.insert_flight(
|
||||
name=req.name,
|
||||
description=req.description,
|
||||
start_lat=req.start_gps.lat,
|
||||
start_lon=req.start_gps.lon,
|
||||
altitude=req.altitude,
|
||||
camera_params=req.camera_params.model_dump(),
|
||||
)
|
||||
# P0#2: Store camera params for process_frame VO calls
|
||||
self._flight_cameras[flight.id] = req.camera_params
|
||||
|
||||
for poly in req.geofences.polygons:
|
||||
await self.repository.insert_geofence(
|
||||
flight.id,
|
||||
nw_lat=poly.north_west.lat,
|
||||
nw_lon=poly.north_west.lon,
|
||||
se_lat=poly.south_east.lat,
|
||||
se_lon=poly.south_east.lon,
|
||||
)
|
||||
for w in req.rough_waypoints:
|
||||
await self.repository.insert_waypoint(flight.id, lat=w.lat, lon=w.lon)
|
||||
|
||||
# Store per-flight altitude for ESKF/pixel projection
|
||||
self._altitudes[flight.id] = req.altitude or 100.0
|
||||
|
||||
# PIPE-02: Set ENU origin and initialise ESKF for this flight
|
||||
if self._coord:
|
||||
self._coord.set_enu_origin(flight.id, req.start_gps)
|
||||
self._init_eskf_for_flight(flight.id, req.start_gps, req.altitude or 100.0)
|
||||
|
||||
# Start MAVLink bridge for this flight (origin required for GPS_INPUT)
|
||||
if self._mavlink and not self._mavlink._running:
|
||||
try:
|
||||
asyncio.create_task(self._mavlink.start(req.start_gps))
|
||||
except Exception as exc:
|
||||
logger.warning("MAVLink bridge start failed: %s", exc)
|
||||
|
||||
return FlightResponse(
|
||||
flight_id=flight.id,
|
||||
status="prefetching",
|
||||
message="Flight created and prefetching started.",
|
||||
created_at=flight.created_at,
|
||||
)
|
||||
|
||||
async def get_flight(self, flight_id: str) -> FlightDetailResponse | None:
|
||||
flight = await self.repository.get_flight(flight_id)
|
||||
if not flight:
|
||||
return None
|
||||
wps = await self.repository.get_waypoints(flight_id)
|
||||
state = await self.repository.load_flight_state(flight_id)
|
||||
|
||||
waypoints = [
|
||||
Waypoint(
|
||||
id=w.id,
|
||||
lat=w.lat,
|
||||
lon=w.lon,
|
||||
altitude=w.altitude,
|
||||
confidence=w.confidence,
|
||||
timestamp=w.timestamp,
|
||||
refined=w.refined,
|
||||
)
|
||||
for w in wps
|
||||
]
|
||||
|
||||
status = state.status if state else "unknown"
|
||||
frames_processed = state.frames_processed if state else 0
|
||||
frames_total = state.frames_total if state else 0
|
||||
|
||||
from gps_denied.schemas import Geofences
|
||||
|
||||
return FlightDetailResponse(
|
||||
flight_id=flight.id,
|
||||
name=flight.name,
|
||||
description=flight.description,
|
||||
start_gps=GPSPoint(lat=flight.start_lat, lon=flight.start_lon),
|
||||
waypoints=waypoints,
|
||||
geofences=Geofences(polygons=[]),
|
||||
camera_params=flight.camera_params,
|
||||
altitude=flight.altitude,
|
||||
status=status,
|
||||
frames_processed=frames_processed,
|
||||
frames_total=frames_total,
|
||||
created_at=flight.created_at,
|
||||
updated_at=flight.updated_at,
|
||||
)
|
||||
|
||||
async def delete_flight(self, flight_id: str) -> DeleteResponse:
|
||||
deleted = await self.repository.delete_flight(flight_id)
|
||||
# P0#1: Cleanup in-memory state to prevent memory leaks
|
||||
self._cleanup_flight(flight_id)
|
||||
return DeleteResponse(deleted=deleted, flight_id=flight_id)
|
||||
|
||||
def _cleanup_flight(self, flight_id: str) -> None:
|
||||
"""Remove all in-memory state for a flight (prevents memory leaks)."""
|
||||
self._prev_images.pop(flight_id, None)
|
||||
self._flight_states.pop(flight_id, None)
|
||||
self._flight_cameras.pop(flight_id, None)
|
||||
self._altitudes.pop(flight_id, None)
|
||||
self._failure_counts.pop(flight_id, None)
|
||||
self._eskf.pop(flight_id, None)
|
||||
if self._graph:
|
||||
self._graph.delete_flight_graph(flight_id)
|
||||
|
||||
async def update_waypoint(
|
||||
self, flight_id: str, waypoint_id: str, waypoint: Waypoint
|
||||
) -> UpdateResponse:
|
||||
ok = await self.repository.update_waypoint(
|
||||
flight_id,
|
||||
waypoint_id,
|
||||
lat=waypoint.lat,
|
||||
lon=waypoint.lon,
|
||||
altitude=waypoint.altitude,
|
||||
confidence=waypoint.confidence,
|
||||
refined=waypoint.refined,
|
||||
)
|
||||
return UpdateResponse(updated=ok, waypoint_id=waypoint_id)
|
||||
|
||||
async def batch_update_waypoints(
|
||||
self, flight_id: str, waypoints: list[Waypoint]
|
||||
) -> BatchUpdateResponse:
|
||||
failed = []
|
||||
updated = 0
|
||||
for wp in waypoints:
|
||||
ok = await self.repository.update_waypoint(
|
||||
flight_id,
|
||||
wp.id,
|
||||
lat=wp.lat,
|
||||
lon=wp.lon,
|
||||
altitude=wp.altitude,
|
||||
confidence=wp.confidence,
|
||||
refined=wp.refined,
|
||||
)
|
||||
if ok:
|
||||
updated += 1
|
||||
else:
|
||||
failed.append(wp.id)
|
||||
return BatchUpdateResponse(
|
||||
success=(len(failed) == 0), updated_count=updated, failed_ids=failed
|
||||
)
|
||||
|
||||
async def queue_images(
|
||||
self, flight_id: str, metadata: BatchMetadata, file_count: int
|
||||
) -> BatchResponse:
|
||||
state = await self.repository.load_flight_state(flight_id)
|
||||
if state:
|
||||
total = state.frames_total + file_count
|
||||
await self.repository.save_flight_state(
|
||||
flight_id, frames_total=total, status="processing"
|
||||
)
|
||||
|
||||
next_seq = metadata.end_sequence + 1
|
||||
seqs = list(range(metadata.start_sequence, metadata.end_sequence + 1))
|
||||
return BatchResponse(
|
||||
accepted=True,
|
||||
sequences=seqs,
|
||||
next_expected=next_seq,
|
||||
message=f"Queued {file_count} images.",
|
||||
)
|
||||
|
||||
async def handle_user_fix(
|
||||
self, flight_id: str, req: UserFixRequest
|
||||
) -> UserFixResponse:
|
||||
await self.repository.save_flight_state(
|
||||
flight_id, blocked=False, status="processing"
|
||||
)
|
||||
|
||||
# Inject operator position into ESKF with high uncertainty (500m)
|
||||
eskf = self._eskf.get(flight_id)
|
||||
if eskf and eskf.initialized and self._coord:
|
||||
try:
|
||||
e, n, _ = self._coord.gps_to_enu(flight_id, req.satellite_gps)
|
||||
alt = self._altitudes.get(flight_id, 100.0)
|
||||
eskf.update_satellite(np.array([e, n, alt]), noise_meters=500.0)
|
||||
self._failure_counts[flight_id] = 0
|
||||
logger.info("User fix applied for %s: %s", flight_id, req.satellite_gps)
|
||||
except Exception as exc:
|
||||
logger.warning("User fix ESKF injection failed: %s", exc)
|
||||
|
||||
return UserFixResponse(
|
||||
accepted=True, processing_resumed=True, message="Fix applied."
|
||||
)
|
||||
|
||||
async def get_flight_status(self, flight_id: str) -> FlightStatusResponse | None:
|
||||
state = await self.repository.load_flight_state(flight_id)
|
||||
if not state:
|
||||
return None
|
||||
return FlightStatusResponse(
|
||||
status=state.status,
|
||||
frames_processed=state.frames_processed,
|
||||
frames_total=state.frames_total,
|
||||
current_frame=state.current_frame,
|
||||
current_heading=None,
|
||||
blocked=state.blocked,
|
||||
search_grid_size=state.search_grid_size,
|
||||
created_at=state.created_at,
|
||||
updated_at=state.updated_at,
|
||||
)
|
||||
|
||||
async def convert_object_to_gps(
|
||||
self, flight_id: str, frame_id: int, pixel: tuple[float, float]
|
||||
) -> ObjectGPSResponse:
|
||||
# PIPE-06: Use real CoordinateTransformer + ESKF pose for ray-ground projection
|
||||
gps: Optional[GPSPoint] = None
|
||||
eskf = self._eskf.get(flight_id)
|
||||
if self._coord and eskf and eskf.initialized:
|
||||
pos = eskf.position
|
||||
quat = eskf.quaternion
|
||||
cam = self._flight_cameras.get(flight_id, CameraParameters(
|
||||
focal_length=4.5, sensor_width=6.17, sensor_height=4.55,
|
||||
resolution_width=640, resolution_height=480,
|
||||
))
|
||||
alt = self._altitudes.get(flight_id, 100.0)
|
||||
try:
|
||||
gps = self._coord.pixel_to_gps(
|
||||
flight_id,
|
||||
pixel,
|
||||
frame_pose={"position": pos},
|
||||
camera_params=cam,
|
||||
altitude=float(alt),
|
||||
quaternion=quat,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("pixel_to_gps failed: %s", exc)
|
||||
|
||||
# Fallback: return ESKF position projected to ground (no pixel shift)
|
||||
if gps is None and eskf:
|
||||
gps = self._eskf_to_gps(flight_id, eskf)
|
||||
|
||||
return ObjectGPSResponse(
|
||||
gps=gps or GPSPoint(lat=0.0, lon=0.0),
|
||||
accuracy_meters=5.0,
|
||||
frame_id=frame_id,
|
||||
pixel=pixel,
|
||||
)
|
||||
|
||||
async def stream_events(self, flight_id: str, client_id: str):
|
||||
"""Async generator for SSE stream."""
|
||||
async for event in self.streamer.stream_generator(flight_id, client_id):
|
||||
yield event
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""Failure Recovery Coordinator (Component F11)."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from typing import List, Protocol, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -14,14 +13,13 @@ from gps_denied.schemas.chunk import ChunkStatus
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IFailureRecoveryCoordinator(ABC):
|
||||
@abstractmethod
|
||||
@runtime_checkable
|
||||
class IFailureRecoveryCoordinator(Protocol):
|
||||
def handle_tracking_lost(self, flight_id: str, current_frame_id: int) -> bool:
|
||||
pass
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def process_chunk_recovery(self, flight_id: str, chunk_id: str, images: List[np.ndarray]) -> bool:
|
||||
pass
|
||||
...
|
||||
|
||||
|
||||
class FailureRecoveryCoordinator(IFailureRecoveryCoordinator):
|
||||
|
||||
@@ -1,73 +1,4 @@
|
||||
"""Result Manager (Component F14)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from gps_denied.core.sse import SSEEventStreamer
|
||||
from gps_denied.db.repository import FlightRepository
|
||||
from gps_denied.schemas import GPSPoint
|
||||
from gps_denied.schemas.events import FrameProcessedEvent
|
||||
|
||||
|
||||
class ResultManager:
|
||||
"""Result consistency and publishing."""
|
||||
|
||||
def __init__(self, repo: FlightRepository, sse: SSEEventStreamer) -> None:
|
||||
self.repo = repo
|
||||
self.sse = sse
|
||||
|
||||
async def update_frame_result(
|
||||
self,
|
||||
flight_id: str,
|
||||
frame_id: int,
|
||||
gps_lat: float,
|
||||
gps_lon: float,
|
||||
altitude: float,
|
||||
heading: float,
|
||||
confidence: float,
|
||||
timestamp: datetime,
|
||||
refined: bool = False,
|
||||
) -> bool:
|
||||
"""Atomic DB update + SSE event publish."""
|
||||
|
||||
# 1. Update DB (in the repository these are auto-committing via flush,
|
||||
# but normally F03 would wrap in a single transaction).
|
||||
await self.repo.save_frame_result(
|
||||
flight_id,
|
||||
frame_id=frame_id,
|
||||
gps_lat=gps_lat,
|
||||
gps_lon=gps_lon,
|
||||
altitude=altitude,
|
||||
heading=heading,
|
||||
confidence=confidence,
|
||||
refined=refined,
|
||||
)
|
||||
|
||||
# Wait, the spec also wants Waypoints to be updated.
|
||||
# But image frames != waypoints. Waypoints are the planned route.
|
||||
# Actually in the spec it says: "Updates waypoint in waypoints table."
|
||||
# This implies updating the closest waypoint or a generated waypoint path.
|
||||
# We will follow the simplest form for now: update the waypoint if there is one corresponding.
|
||||
# Let's say we update a waypoint with id "wp_{frame_id}" for now if we know how they map,
|
||||
# or we just skip unless specified.
|
||||
|
||||
# 2. Trigger SSE event
|
||||
evt = FrameProcessedEvent(
|
||||
frame_id=frame_id,
|
||||
gps=GPSPoint(lat=gps_lat, lon=gps_lon),
|
||||
altitude=altitude,
|
||||
confidence=confidence,
|
||||
heading=heading,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
if refined:
|
||||
self.sse.send_refinement(flight_id, evt)
|
||||
else:
|
||||
self.sse.send_frame_result(flight_id, evt)
|
||||
|
||||
return True
|
||||
|
||||
async def publish_waypoint_update(self, flight_id: str, frame_id: int) -> bool:
|
||||
# Just delegates to SSE for waypoint updates, which is basically the frame result for UI
|
||||
pass
|
||||
"""Legacy import path. Phase 1 shim — code lives in pipeline/result_manager.py."""
|
||||
from gps_denied.pipeline.result_manager import ( # noqa: F401
|
||||
ResultManager,
|
||||
)
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
import dataclasses
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -12,14 +12,14 @@ from gps_denied.schemas.rotation import HeadingHistory, RotationResult
|
||||
from gps_denied.schemas.satellite import TileBounds
|
||||
|
||||
|
||||
class IImageMatcher(ABC):
|
||||
@runtime_checkable
|
||||
class IImageMatcher(Protocol):
|
||||
"""Dependency injection interface for Metric Refinement."""
|
||||
@abstractmethod
|
||||
def align_to_satellite(
|
||||
self, uav_image: np.ndarray, satellite_tile: np.ndarray,
|
||||
tile_bounds: TileBounds,
|
||||
) -> RotationResult:
|
||||
pass
|
||||
...
|
||||
|
||||
|
||||
class ImageRotationManager:
|
||||
|
||||
+3
-163
@@ -1,164 +1,4 @@
|
||||
"""SSE Event Streamer (Component F15)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
from gps_denied.schemas.events import (
|
||||
FlightCompletedEvent,
|
||||
FrameProcessedEvent,
|
||||
SearchExpandedEvent,
|
||||
SSEEventType,
|
||||
SSEMessage,
|
||||
UserInputNeededEvent,
|
||||
"""Legacy import path. Phase 1 shim — code lives in pipeline/sse_streamer.py."""
|
||||
from gps_denied.pipeline.sse_streamer import ( # noqa: F401
|
||||
SSEEventStreamer,
|
||||
)
|
||||
|
||||
|
||||
class SSEEventStreamer:
|
||||
"""Manages real-time SSE connections and event broadcasting."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Map: flight_id -> Dict[client_id, asyncio.Queue]
|
||||
self._streams: dict[str, dict[str, asyncio.Queue[SSEMessage | None]]] = defaultdict(dict)
|
||||
|
||||
def create_stream(self, flight_id: str, client_id: str) -> asyncio.Queue[SSEMessage | None]:
|
||||
"""Create a new event queue for a client."""
|
||||
q: asyncio.Queue[SSEMessage | None] = asyncio.Queue()
|
||||
self._streams[flight_id][client_id] = q
|
||||
return q
|
||||
|
||||
def close_stream(self, flight_id: str, client_id: str) -> None:
|
||||
"""Close a client stream by putting a sentinel and removing the queue."""
|
||||
if flight_id in self._streams and client_id in self._streams[flight_id]:
|
||||
q = self._streams[flight_id].pop(client_id)
|
||||
if not self._streams[flight_id]:
|
||||
del self._streams[flight_id]
|
||||
# Put None to signal generator exit
|
||||
try:
|
||||
q.put_nowait(None)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
def get_active_connections(self, flight_id: str) -> int:
|
||||
return len(self._streams.get(flight_id, {}))
|
||||
|
||||
def _broadcast(self, flight_id: str, msg: SSEMessage) -> bool:
|
||||
"""Broadcast a message to all clients subscribed to flight_id."""
|
||||
if flight_id not in self._streams or not self._streams[flight_id]:
|
||||
return False
|
||||
|
||||
for q in self._streams[flight_id].values():
|
||||
try:
|
||||
q.put_nowait(msg)
|
||||
except asyncio.QueueFull:
|
||||
pass # Drop if queue is full rather than blocking
|
||||
return True
|
||||
|
||||
# ── Business Event Senders ────────────────────────────────────────────────
|
||||
|
||||
def send_frame_result(self, flight_id: str, event_data: FrameProcessedEvent) -> bool:
|
||||
msg = SSEMessage(
|
||||
event=SSEEventType.FRAME_PROCESSED,
|
||||
data=event_data.model_dump(mode="json"),
|
||||
id=f"frame_{event_data.frame_id}",
|
||||
)
|
||||
return self._broadcast(flight_id, msg)
|
||||
|
||||
def send_refinement(self, flight_id: str, event_data: FrameProcessedEvent) -> bool:
|
||||
msg = SSEMessage(
|
||||
event=SSEEventType.FRAME_REFINED,
|
||||
data=event_data.model_dump(mode="json"),
|
||||
id=f"refine_{event_data.frame_id}",
|
||||
)
|
||||
return self._broadcast(flight_id, msg)
|
||||
|
||||
def send_search_progress(self, flight_id: str, event_data: SearchExpandedEvent) -> bool:
|
||||
msg = SSEMessage(
|
||||
event=SSEEventType.SEARCH_EXPANDED,
|
||||
data=event_data.model_dump(mode="json"),
|
||||
)
|
||||
return self._broadcast(flight_id, msg)
|
||||
|
||||
def send_user_input_request(self, flight_id: str, event_data: UserInputNeededEvent) -> bool:
|
||||
msg = SSEMessage(
|
||||
event=SSEEventType.USER_INPUT_NEEDED,
|
||||
data=event_data.model_dump(mode="json"),
|
||||
)
|
||||
return self._broadcast(flight_id, msg)
|
||||
|
||||
def send_flight_completed(self, flight_id: str, event_data: FlightCompletedEvent) -> bool:
|
||||
msg = SSEMessage(
|
||||
event=SSEEventType.FLIGHT_COMPLETED,
|
||||
data=event_data.model_dump(mode="json"),
|
||||
)
|
||||
return self._broadcast(flight_id, msg)
|
||||
|
||||
def send_heartbeat(self, flight_id: str) -> bool:
|
||||
# sse_starlette uses empty string or comment for heartbeat,
|
||||
# but we can just send an SSEMessage object that parses as empty event
|
||||
if flight_id not in self._streams:
|
||||
return False
|
||||
|
||||
# Manually sending a comment via the generator is tricky with strict SSEMessage schema
|
||||
# but we'll handle this in the stream generator directly
|
||||
return True
|
||||
|
||||
# ── Generic event dispatcher (used by processor.process_frame) ──────────
|
||||
|
||||
async def push_event(self, flight_id: str, event_type: str, data: dict) -> None:
|
||||
"""Dispatch a generic event to all clients for a flight.
|
||||
|
||||
Maps event_type strings to typed SSE events:
|
||||
"frame_result" → FrameProcessedEvent
|
||||
"refinement" → FrameProcessedEvent (refined)
|
||||
Other → raw broadcast via SSEMessage
|
||||
"""
|
||||
if event_type == "frame_result":
|
||||
evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data
|
||||
self.send_frame_result(flight_id, evt)
|
||||
elif event_type == "refinement":
|
||||
evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data
|
||||
self.send_refinement(flight_id, evt)
|
||||
else:
|
||||
msg = SSEMessage(
|
||||
event=SSEEventType.FRAME_PROCESSED,
|
||||
data=data,
|
||||
id=str(data.get("frame_id", "")),
|
||||
)
|
||||
self._broadcast(flight_id, msg)
|
||||
|
||||
# ── Stream Generator ──────────────────────────────────────────────────────
|
||||
|
||||
async def stream_generator(self, flight_id: str, client_id: str):
|
||||
"""Yields dicts for sse_starlette EventSourceResponse."""
|
||||
q = self.create_stream(flight_id, client_id)
|
||||
|
||||
# Send an immediate connection accepted ping
|
||||
yield {"event": "connected", "data": "connected"}
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Wait for next event or send heartbeat every 15s
|
||||
try:
|
||||
msg = await asyncio.wait_for(q.get(), timeout=15.0)
|
||||
if msg is None:
|
||||
# Sentinel for clean shutdown
|
||||
break
|
||||
|
||||
# Yield dict format for sse_starlette
|
||||
yield {
|
||||
"event": msg.event.value,
|
||||
"id": msg.id if msg.id else "",
|
||||
"data": json.dumps(msg.data)
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Heartbeat format for sse_starlette (empty string generates a comment)
|
||||
yield {"event": "heartbeat", "data": "ping"}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass # Client disconnected
|
||||
finally:
|
||||
self.close_stream(flight_id, client_id)
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Pipeline package: orchestrator + IO + composition root."""
|
||||
from .orchestrator import FlightProcessor
|
||||
from .image_input import ImageInputPipeline
|
||||
from .result_manager import ResultManager
|
||||
from .sse_streamer import SSEEventStreamer
|
||||
|
||||
Pipeline = FlightProcessor
|
||||
|
||||
__all__ = [
|
||||
"FlightProcessor",
|
||||
"Pipeline",
|
||||
"ImageInputPipeline",
|
||||
"ResultManager",
|
||||
"SSEEventStreamer",
|
||||
]
|
||||
@@ -0,0 +1,227 @@
|
||||
"""Image Input Pipeline (Component F05)."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from gps_denied.schemas.image import (
|
||||
ImageBatch,
|
||||
ImageData,
|
||||
ImageMetadata,
|
||||
ProcessedBatch,
|
||||
ProcessingStatus,
|
||||
ValidationResult,
|
||||
)
|
||||
|
||||
|
||||
class QueueFullError(Exception):
|
||||
pass
|
||||
|
||||
class ValidationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ImageInputPipeline:
|
||||
"""Manages ingestion, disk storage, and queuing of UAV image batches."""
|
||||
|
||||
def __init__(self, storage_dir: str = "image_storage", max_queue_size: int = 50):
|
||||
self.storage_dir = storage_dir
|
||||
# flight_id -> asyncio.Queue of ImageBatch
|
||||
self._queues: dict[str, asyncio.Queue] = {}
|
||||
self.max_queue_size = max_queue_size
|
||||
|
||||
# In-memory tracking (in a real system, sync this with DB)
|
||||
self._status: dict[str, dict] = {}
|
||||
# Exact sequence → filename mapping (VO-05: no substring collision)
|
||||
self._sequence_map: dict[str, dict[int, str]] = {}
|
||||
|
||||
def _get_queue(self, flight_id: str) -> asyncio.Queue:
|
||||
if flight_id not in self._queues:
|
||||
self._queues[flight_id] = asyncio.Queue(maxsize=self.max_queue_size)
|
||||
return self._queues[flight_id]
|
||||
|
||||
def _init_status(self, flight_id: str):
|
||||
if flight_id not in self._status:
|
||||
self._status[flight_id] = {
|
||||
"total_images": 0,
|
||||
"processed_images": 0,
|
||||
"current_sequence": 1,
|
||||
}
|
||||
|
||||
def validate_batch(self, batch: ImageBatch) -> ValidationResult:
|
||||
"""Validates batch integrity and sequence continuity."""
|
||||
errors = []
|
||||
|
||||
num_images = len(batch.images)
|
||||
if num_images < 1:
|
||||
errors.append("Batch is empty")
|
||||
elif num_images > 100:
|
||||
errors.append("Batch too large")
|
||||
|
||||
if len(batch.filenames) != num_images:
|
||||
errors.append("Mismatch between filenames and images count")
|
||||
|
||||
# Naming convention ADxxxxxx.jpg or similar
|
||||
pattern = re.compile(r"^[A-Za-z0-9_-]+\.(jpg|jpeg|png)$", re.IGNORECASE)
|
||||
for fn in batch.filenames:
|
||||
if not pattern.match(fn):
|
||||
errors.append(f"Invalid filename: {fn}")
|
||||
break
|
||||
|
||||
if batch.start_sequence > batch.end_sequence:
|
||||
errors.append("Start sequence greater than end sequence")
|
||||
|
||||
return ValidationResult(valid=len(errors) == 0, errors=errors)
|
||||
|
||||
def queue_batch(self, flight_id: str, batch: ImageBatch) -> bool:
|
||||
"""Queues a batch of images for processing."""
|
||||
val = self.validate_batch(batch)
|
||||
if not val.valid:
|
||||
raise ValidationError(f"Batch validation failed: {val.errors}")
|
||||
|
||||
q = self._get_queue(flight_id)
|
||||
if q.full():
|
||||
raise QueueFullError(f"Queue for flight {flight_id} is full")
|
||||
|
||||
q.put_nowait(batch)
|
||||
|
||||
self._init_status(flight_id)
|
||||
self._status[flight_id]["total_images"] += len(batch.images)
|
||||
|
||||
return True
|
||||
|
||||
async def process_next_batch(self, flight_id: str) -> ProcessedBatch | None:
|
||||
"""Dequeues and processing the next batch."""
|
||||
q = self._get_queue(flight_id)
|
||||
if q.empty():
|
||||
return None
|
||||
|
||||
batch: ImageBatch = await q.get()
|
||||
|
||||
processed_images = []
|
||||
for i, raw_bytes in enumerate(batch.images):
|
||||
# Decode
|
||||
nparr = np.frombuffer(raw_bytes, np.uint8)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
if img is None:
|
||||
continue # skip corrupted
|
||||
|
||||
seq = batch.start_sequence + i
|
||||
fn = batch.filenames[i]
|
||||
|
||||
h, w = img.shape[:2]
|
||||
meta = ImageMetadata(
|
||||
sequence=seq,
|
||||
filename=fn,
|
||||
dimensions=(w, h),
|
||||
file_size=len(raw_bytes),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
img_data = ImageData(
|
||||
flight_id=flight_id,
|
||||
sequence=seq,
|
||||
filename=fn,
|
||||
image=img,
|
||||
metadata=meta
|
||||
)
|
||||
processed_images.append(img_data)
|
||||
# VO-05: record exact sequence→filename mapping
|
||||
self._sequence_map.setdefault(flight_id, {})[seq] = fn
|
||||
|
||||
# Store to disk
|
||||
self.store_images(flight_id, processed_images)
|
||||
|
||||
self._status[flight_id]["processed_images"] += len(processed_images)
|
||||
q.task_done()
|
||||
|
||||
return ProcessedBatch(
|
||||
images=processed_images,
|
||||
batch_id=f"batch_{batch.batch_number}",
|
||||
start_sequence=batch.start_sequence,
|
||||
end_sequence=batch.end_sequence
|
||||
)
|
||||
|
||||
def store_images(self, flight_id: str, images: list[ImageData]) -> bool:
|
||||
"""Persists images to disk."""
|
||||
flight_dir = os.path.join(self.storage_dir, flight_id)
|
||||
os.makedirs(flight_dir, exist_ok=True)
|
||||
|
||||
for img in images:
|
||||
path = os.path.join(flight_dir, img.filename)
|
||||
cv2.imwrite(path, img.image)
|
||||
|
||||
return True
|
||||
|
||||
def get_next_image(self, flight_id: str) -> ImageData | None:
|
||||
"""Gets the next image in sequence for processing."""
|
||||
self._init_status(flight_id)
|
||||
seq = self._status[flight_id]["current_sequence"]
|
||||
|
||||
img = self.get_image_by_sequence(flight_id, seq)
|
||||
if img:
|
||||
self._status[flight_id]["current_sequence"] += 1
|
||||
|
||||
return img
|
||||
|
||||
def get_image_by_sequence(self, flight_id: str, sequence: int) -> ImageData | None:
|
||||
"""Retrieves a specific image by sequence number (exact match — VO-05)."""
|
||||
flight_dir = os.path.join(self.storage_dir, flight_id)
|
||||
if not os.path.exists(flight_dir):
|
||||
return None
|
||||
|
||||
# Prefer the exact mapping built during process_next_batch
|
||||
fn = self._sequence_map.get(flight_id, {}).get(sequence)
|
||||
if fn:
|
||||
path = os.path.join(flight_dir, fn)
|
||||
img = cv2.imread(path)
|
||||
if img is not None:
|
||||
h, w = img.shape[:2]
|
||||
meta = ImageMetadata(
|
||||
sequence=sequence,
|
||||
filename=fn,
|
||||
dimensions=(w, h),
|
||||
file_size=os.path.getsize(path),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
return ImageData(flight_id, sequence, fn, img, meta)
|
||||
|
||||
# Fallback: scan directory for exact filename patterns
|
||||
# (handles images stored before this process started)
|
||||
for fn in os.listdir(flight_dir):
|
||||
base, _ = os.path.splitext(fn)
|
||||
# Accept only if the base name ends with exactly the padded sequence number
|
||||
if base.endswith(f"{sequence:06d}") or base == str(sequence):
|
||||
path = os.path.join(flight_dir, fn)
|
||||
img = cv2.imread(path)
|
||||
if img is not None:
|
||||
h, w = img.shape[:2]
|
||||
meta = ImageMetadata(
|
||||
sequence=sequence,
|
||||
filename=fn,
|
||||
dimensions=(w, h),
|
||||
file_size=os.path.getsize(path),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
return ImageData(flight_id, sequence, fn, img, meta)
|
||||
|
||||
return None
|
||||
|
||||
def get_processing_status(self, flight_id: str) -> ProcessingStatus:
|
||||
self._init_status(flight_id)
|
||||
s = self._status[flight_id]
|
||||
q = self._get_queue(flight_id)
|
||||
|
||||
return ProcessingStatus(
|
||||
flight_id=flight_id,
|
||||
total_images=s["total_images"],
|
||||
processed_images=s["processed_images"],
|
||||
current_sequence=s["current_sequence"],
|
||||
queued_batches=q.qsize(),
|
||||
processing_rate=0.0 # mock
|
||||
)
|
||||
@@ -0,0 +1,599 @@
|
||||
"""Core Flight Processor — Full Processing Pipeline (Stage 10).
|
||||
|
||||
Orchestrates: ImageInputPipeline → VO → MetricRefinement → FactorGraph → SSE.
|
||||
State Machine: NORMAL → LOST → RECOVERY → NORMAL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gps_denied.core.eskf import ESKF
|
||||
from gps_denied.pipeline.image_input import ImageInputPipeline
|
||||
from gps_denied.pipeline.result_manager import ResultManager
|
||||
from gps_denied.pipeline.sse_streamer import SSEEventStreamer
|
||||
from gps_denied.db.repository import FlightRepository
|
||||
from gps_denied.schemas import CameraParameters, GPSPoint
|
||||
from gps_denied.schemas.flight import (
|
||||
BatchMetadata,
|
||||
BatchResponse,
|
||||
BatchUpdateResponse,
|
||||
DeleteResponse,
|
||||
FlightCreateRequest,
|
||||
FlightDetailResponse,
|
||||
FlightResponse,
|
||||
FlightStatusResponse,
|
||||
ObjectGPSResponse,
|
||||
UpdateResponse,
|
||||
UserFixRequest,
|
||||
UserFixResponse,
|
||||
Waypoint,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# State Machine
|
||||
# ---------------------------------------------------------------------------
|
||||
class TrackingState(str, Enum):
|
||||
"""Processing state for a flight."""
|
||||
NORMAL = "normal"
|
||||
LOST = "lost"
|
||||
RECOVERY = "recovery"
|
||||
|
||||
|
||||
class FrameResult:
|
||||
"""Intermediate result of processing a single frame."""
|
||||
|
||||
def __init__(self, frame_id: int):
|
||||
self.frame_id = frame_id
|
||||
self.gps: Optional[GPSPoint] = None
|
||||
self.confidence: float = 0.0
|
||||
self.tracking_state: TrackingState = TrackingState.NORMAL
|
||||
self.vo_success: bool = False
|
||||
self.alignment_success: bool = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FlightProcessor
|
||||
# ---------------------------------------------------------------------------
|
||||
class FlightProcessor:
|
||||
"""Manages business logic, background processing, and frame orchestration."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: FlightRepository,
|
||||
streamer: SSEEventStreamer,
|
||||
eskf_config=None,
|
||||
) -> None:
|
||||
self.repository = repository
|
||||
self.streamer = streamer
|
||||
self.result_manager = ResultManager(repository, streamer)
|
||||
self.pipeline = ImageInputPipeline(storage_dir=".image_storage", max_queue_size=50)
|
||||
self._eskf_config = eskf_config # ESKFConfig or None → default
|
||||
|
||||
# Per-flight processing state
|
||||
self._flight_states: dict[str, TrackingState] = {}
|
||||
self._prev_images: dict[str, np.ndarray] = {} # previous frame cache
|
||||
self._flight_cameras: dict[str, CameraParameters] = {} # per-flight camera
|
||||
self._altitudes: dict[str, float] = {} # per-flight altitude (m)
|
||||
self._failure_counts: dict[str, int] = {} # per-flight consecutive failure counter
|
||||
|
||||
# Per-flight ESKF instances (PIPE-01/07)
|
||||
self._eskf: dict[str, ESKF] = {}
|
||||
|
||||
# Lazy-initialised component references (set via `attach_components`)
|
||||
self._vo = None # ISequentialVisualOdometry
|
||||
self._gpr = None # IGlobalPlaceRecognition
|
||||
self._metric = None # IMetricRefinement
|
||||
self._graph = None # IFactorGraphOptimizer
|
||||
self._recovery = None # IFailureRecoveryCoordinator
|
||||
self._chunk_mgr = None # IRouteChunkManager
|
||||
self._rotation = None # ImageRotationManager
|
||||
self._satellite = None # SatelliteDataManager (PIPE-02)
|
||||
self._coord = None # CoordinateTransformer (PIPE-02/06)
|
||||
self._mavlink = None # MAVLinkBridge (PIPE-07)
|
||||
|
||||
# ------ Dependency injection for core components ---------
|
||||
def attach_components(
|
||||
self,
|
||||
vo=None,
|
||||
gpr=None,
|
||||
metric=None,
|
||||
graph=None,
|
||||
recovery=None,
|
||||
chunk_mgr=None,
|
||||
rotation=None,
|
||||
satellite=None,
|
||||
coord=None,
|
||||
mavlink=None,
|
||||
):
|
||||
"""Attach pipeline components after construction (avoids circular deps)."""
|
||||
self._vo = vo
|
||||
self._gpr = gpr
|
||||
self._metric = metric
|
||||
self._graph = graph
|
||||
self._recovery = recovery
|
||||
self._chunk_mgr = chunk_mgr
|
||||
self._rotation = rotation
|
||||
self._satellite = satellite # PIPE-02: SatelliteDataManager
|
||||
self._coord = coord # PIPE-02/06: CoordinateTransformer
|
||||
self._mavlink = mavlink # PIPE-07: MAVLinkBridge
|
||||
|
||||
# ------ ESKF lifecycle helpers ----------------------------
|
||||
def _init_eskf_for_flight(
|
||||
self, flight_id: str, start_gps: GPSPoint, altitude: float
|
||||
) -> None:
|
||||
"""Create and initialize a per-flight ESKF instance."""
|
||||
if flight_id in self._eskf:
|
||||
return
|
||||
eskf = ESKF(config=self._eskf_config)
|
||||
if self._coord:
|
||||
try:
|
||||
e, n, _ = self._coord.gps_to_enu(flight_id, start_gps)
|
||||
eskf.initialize(np.array([e, n, altitude]), time.time())
|
||||
except Exception:
|
||||
eskf.initialize(np.zeros(3), time.time())
|
||||
else:
|
||||
eskf.initialize(np.zeros(3), time.time())
|
||||
self._eskf[flight_id] = eskf
|
||||
|
||||
def _eskf_to_gps(self, flight_id: str, eskf: ESKF) -> Optional[GPSPoint]:
|
||||
"""Convert current ESKF ENU position to WGS84 GPS."""
|
||||
if not eskf.initialized or self._coord is None:
|
||||
return None
|
||||
try:
|
||||
pos = eskf.position
|
||||
return self._coord.enu_to_gps(flight_id, (float(pos[0]), float(pos[1]), float(pos[2])))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# =========================================================
|
||||
# process_frame — central orchestration
|
||||
# =========================================================
|
||||
async def process_frame(
|
||||
self,
|
||||
flight_id: str,
|
||||
frame_id: int,
|
||||
image: np.ndarray,
|
||||
) -> FrameResult:
|
||||
"""
|
||||
Process a single UAV frame through the full pipeline.
|
||||
|
||||
State transitions:
|
||||
NORMAL — VO succeeds → ESKF VO update, attempt satellite fix
|
||||
LOST — VO failed → create new chunk, enter RECOVERY
|
||||
RECOVERY— try GPR + MetricRefinement → if anchored, merge & return to NORMAL
|
||||
|
||||
PIPE-01: VO result → eskf.update_vo → satellite match → eskf.update_satellite → MAVLink GPS_INPUT
|
||||
PIPE-02: SatelliteDataManager + CoordinateTransformer wired for tile selection
|
||||
PIPE-04: Consecutive failure counter wired to FailureRecoveryCoordinator
|
||||
PIPE-05: ImageRotationManager initialised on first frame
|
||||
PIPE-07: ESKF confidence → MAVLink fix_type via bridge.update_state
|
||||
"""
|
||||
result = FrameResult(frame_id)
|
||||
state = self._flight_states.get(flight_id, TrackingState.NORMAL)
|
||||
eskf = self._eskf.get(flight_id)
|
||||
|
||||
_default_cam = CameraParameters(
|
||||
focal_length=4.5, sensor_width=6.17, sensor_height=4.55,
|
||||
resolution_width=640, resolution_height=480,
|
||||
)
|
||||
|
||||
# ---- PIPE-05: Initialise heading tracking on first frame ----
|
||||
if self._rotation and frame_id == 0:
|
||||
self._rotation.requires_rotation_sweep(flight_id) # seeds HeadingHistory
|
||||
|
||||
# ---- 1. Visual Odometry (frame-to-frame) ----
|
||||
vo_ok = False
|
||||
if self._vo and flight_id in self._prev_images:
|
||||
try:
|
||||
cam = self._flight_cameras.get(flight_id, _default_cam)
|
||||
rel_pose = self._vo.compute_relative_pose(
|
||||
self._prev_images[flight_id], image, cam
|
||||
)
|
||||
if rel_pose and rel_pose.tracking_good:
|
||||
vo_ok = True
|
||||
result.vo_success = True
|
||||
|
||||
if self._graph:
|
||||
self._graph.add_relative_factor(
|
||||
flight_id, frame_id - 1, frame_id, rel_pose, np.eye(6)
|
||||
)
|
||||
|
||||
# PIPE-01: Feed VO relative displacement into ESKF
|
||||
if eskf and eskf.initialized:
|
||||
now = time.time()
|
||||
dt_vo = max(0.01, now - (eskf.last_timestamp or now))
|
||||
eskf.update_vo(rel_pose.translation, dt_vo)
|
||||
except Exception as exc:
|
||||
logger.warning("VO failed for frame %d: %s", frame_id, exc)
|
||||
|
||||
# Store current image for next frame
|
||||
self._prev_images[flight_id] = image
|
||||
|
||||
# ---- PIPE-04: Consecutive failure counter ----
|
||||
if not vo_ok and frame_id > 0:
|
||||
self._failure_counts[flight_id] = self._failure_counts.get(flight_id, 0) + 1
|
||||
else:
|
||||
self._failure_counts[flight_id] = 0
|
||||
|
||||
# ---- 2. State Machine transitions ----
|
||||
if state == TrackingState.NORMAL:
|
||||
if not vo_ok and frame_id > 0:
|
||||
state = TrackingState.LOST
|
||||
logger.info("Flight %s → LOST at frame %d", flight_id, frame_id)
|
||||
if self._recovery:
|
||||
self._recovery.handle_tracking_lost(flight_id, frame_id)
|
||||
|
||||
if state == TrackingState.LOST:
|
||||
state = TrackingState.RECOVERY
|
||||
|
||||
if state == TrackingState.RECOVERY:
|
||||
recovered = False
|
||||
if self._recovery and self._chunk_mgr:
|
||||
active_chunk = self._chunk_mgr.get_active_chunk(flight_id)
|
||||
if active_chunk:
|
||||
recovered = self._recovery.process_chunk_recovery(
|
||||
flight_id, active_chunk.chunk_id, [image]
|
||||
)
|
||||
if recovered:
|
||||
state = TrackingState.NORMAL
|
||||
result.alignment_success = True
|
||||
# PIPE-04: Reset failure count on successful recovery
|
||||
self._failure_counts[flight_id] = 0
|
||||
logger.info("Flight %s recovered → NORMAL at frame %d", flight_id, frame_id)
|
||||
|
||||
# ---- 3. Satellite position fix (PIPE-01/02) ----
|
||||
if state == TrackingState.NORMAL and self._metric:
|
||||
sat_tile: Optional[np.ndarray] = None
|
||||
tile_bounds = None
|
||||
|
||||
# PIPE-02: Prefer real SatelliteDataManager tiles (ESKF ±3σ selection)
|
||||
if self._satellite and eskf and eskf.initialized:
|
||||
gps_est = self._eskf_to_gps(flight_id, eskf)
|
||||
if gps_est:
|
||||
cov = eskf.covariance
|
||||
sigma_h = float(
|
||||
np.sqrt(np.trace(cov[0:3, 0:3]) / 3.0)
|
||||
) if cov is not None else 30.0
|
||||
sigma_h = max(sigma_h, 5.0)
|
||||
try:
|
||||
tile_result = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
self._satellite.fetch_tiles_for_position,
|
||||
gps_est, sigma_h, 18,
|
||||
)
|
||||
if tile_result:
|
||||
sat_tile, tile_bounds = tile_result
|
||||
except Exception as exc:
|
||||
logger.debug("Satellite tile fetch failed: %s", exc)
|
||||
|
||||
# Fallback: GPR candidate tile (mock image, real bounds)
|
||||
if sat_tile is None and self._gpr:
|
||||
try:
|
||||
candidates = self._gpr.retrieve_candidate_tiles(image, top_k=1)
|
||||
if candidates:
|
||||
sat_tile = np.zeros((256, 256, 3), dtype=np.uint8)
|
||||
tile_bounds = candidates[0].bounds
|
||||
except Exception as exc:
|
||||
logger.debug("GPR tile fallback failed: %s", exc)
|
||||
|
||||
if sat_tile is not None and tile_bounds is not None:
|
||||
try:
|
||||
align = self._metric.align_to_satellite(image, sat_tile, tile_bounds)
|
||||
if align and align.matched:
|
||||
result.gps = align.gps_center
|
||||
result.confidence = align.confidence
|
||||
result.alignment_success = True
|
||||
|
||||
if self._graph:
|
||||
self._graph.add_absolute_factor(
|
||||
flight_id, frame_id,
|
||||
align.gps_center, np.eye(6),
|
||||
is_user_anchor=False,
|
||||
)
|
||||
|
||||
# PIPE-01: ESKF satellite update — noise from RANSAC confidence
|
||||
if eskf and eskf.initialized and self._coord:
|
||||
try:
|
||||
e, n, _ = self._coord.gps_to_enu(flight_id, align.gps_center)
|
||||
alt = self._altitudes.get(flight_id, 100.0)
|
||||
pos_enu = np.array([e, n, alt])
|
||||
noise_m = 5.0 + 15.0 * (1.0 - float(align.confidence))
|
||||
eskf.update_satellite(pos_enu, noise_m)
|
||||
except Exception as exc:
|
||||
logger.debug("ESKF satellite update failed: %s", exc)
|
||||
except Exception as exc:
|
||||
logger.warning("Metric alignment failed at frame %d: %s", frame_id, exc)
|
||||
|
||||
# ---- 4. Graph optimization (incremental) ----
|
||||
if self._graph:
|
||||
opt_result = self._graph.optimize(flight_id, iterations=5)
|
||||
logger.debug(
|
||||
"Optimization: converged=%s, error=%.4f",
|
||||
opt_result.converged, opt_result.final_error,
|
||||
)
|
||||
|
||||
# ---- PIPE-07: Push ESKF state → MAVLink GPS_INPUT ----
|
||||
if self._mavlink and eskf and eskf.initialized:
|
||||
try:
|
||||
eskf_state = eskf.get_state()
|
||||
alt = self._altitudes.get(flight_id, 100.0)
|
||||
self._mavlink.update_state(eskf_state, altitude_m=alt)
|
||||
except Exception as exc:
|
||||
logger.debug("MAVLink state push failed: %s", exc)
|
||||
|
||||
# ---- 5. Publish via SSE ----
|
||||
result.tracking_state = state
|
||||
self._flight_states[flight_id] = state
|
||||
await self._publish_frame_result(flight_id, result)
|
||||
return result
|
||||
|
||||
async def _publish_frame_result(self, flight_id: str, result: FrameResult):
|
||||
"""Emit SSE event for processed frame."""
|
||||
event_data = {
|
||||
"frame_id": result.frame_id,
|
||||
"tracking_state": result.tracking_state.value,
|
||||
"vo_success": result.vo_success,
|
||||
"alignment_success": result.alignment_success,
|
||||
"confidence": result.confidence,
|
||||
}
|
||||
if result.gps:
|
||||
event_data["lat"] = result.gps.lat
|
||||
event_data["lon"] = result.gps.lon
|
||||
|
||||
await self.streamer.push_event(
|
||||
flight_id, event_type="frame_result", data=event_data
|
||||
)
|
||||
|
||||
# =========================================================
|
||||
# Existing CRUD / REST helpers (unchanged from Stage 3-4)
|
||||
# =========================================================
|
||||
async def create_flight(self, req: FlightCreateRequest) -> FlightResponse:
|
||||
flight = await self.repository.insert_flight(
|
||||
name=req.name,
|
||||
description=req.description,
|
||||
start_lat=req.start_gps.lat,
|
||||
start_lon=req.start_gps.lon,
|
||||
altitude=req.altitude,
|
||||
camera_params=req.camera_params.model_dump(),
|
||||
)
|
||||
# P0#2: Store camera params for process_frame VO calls
|
||||
self._flight_cameras[flight.id] = req.camera_params
|
||||
|
||||
for poly in req.geofences.polygons:
|
||||
await self.repository.insert_geofence(
|
||||
flight.id,
|
||||
nw_lat=poly.north_west.lat,
|
||||
nw_lon=poly.north_west.lon,
|
||||
se_lat=poly.south_east.lat,
|
||||
se_lon=poly.south_east.lon,
|
||||
)
|
||||
for w in req.rough_waypoints:
|
||||
await self.repository.insert_waypoint(flight.id, lat=w.lat, lon=w.lon)
|
||||
|
||||
# Store per-flight altitude for ESKF/pixel projection
|
||||
self._altitudes[flight.id] = req.altitude or 100.0
|
||||
|
||||
# PIPE-02: Set ENU origin and initialise ESKF for this flight
|
||||
if self._coord:
|
||||
self._coord.set_enu_origin(flight.id, req.start_gps)
|
||||
self._init_eskf_for_flight(flight.id, req.start_gps, req.altitude or 100.0)
|
||||
|
||||
# Start MAVLink bridge for this flight (origin required for GPS_INPUT)
|
||||
if self._mavlink and not self._mavlink._running:
|
||||
try:
|
||||
asyncio.create_task(self._mavlink.start(req.start_gps))
|
||||
except Exception as exc:
|
||||
logger.warning("MAVLink bridge start failed: %s", exc)
|
||||
|
||||
return FlightResponse(
|
||||
flight_id=flight.id,
|
||||
status="prefetching",
|
||||
message="Flight created and prefetching started.",
|
||||
created_at=flight.created_at,
|
||||
)
|
||||
|
||||
async def get_flight(self, flight_id: str) -> FlightDetailResponse | None:
|
||||
flight = await self.repository.get_flight(flight_id)
|
||||
if not flight:
|
||||
return None
|
||||
wps = await self.repository.get_waypoints(flight_id)
|
||||
state = await self.repository.load_flight_state(flight_id)
|
||||
|
||||
waypoints = [
|
||||
Waypoint(
|
||||
id=w.id,
|
||||
lat=w.lat,
|
||||
lon=w.lon,
|
||||
altitude=w.altitude,
|
||||
confidence=w.confidence,
|
||||
timestamp=w.timestamp,
|
||||
refined=w.refined,
|
||||
)
|
||||
for w in wps
|
||||
]
|
||||
|
||||
status = state.status if state else "unknown"
|
||||
frames_processed = state.frames_processed if state else 0
|
||||
frames_total = state.frames_total if state else 0
|
||||
|
||||
from gps_denied.schemas import Geofences
|
||||
|
||||
return FlightDetailResponse(
|
||||
flight_id=flight.id,
|
||||
name=flight.name,
|
||||
description=flight.description,
|
||||
start_gps=GPSPoint(lat=flight.start_lat, lon=flight.start_lon),
|
||||
waypoints=waypoints,
|
||||
geofences=Geofences(polygons=[]),
|
||||
camera_params=flight.camera_params,
|
||||
altitude=flight.altitude,
|
||||
status=status,
|
||||
frames_processed=frames_processed,
|
||||
frames_total=frames_total,
|
||||
created_at=flight.created_at,
|
||||
updated_at=flight.updated_at,
|
||||
)
|
||||
|
||||
async def delete_flight(self, flight_id: str) -> DeleteResponse:
|
||||
deleted = await self.repository.delete_flight(flight_id)
|
||||
# P0#1: Cleanup in-memory state to prevent memory leaks
|
||||
self._cleanup_flight(flight_id)
|
||||
return DeleteResponse(deleted=deleted, flight_id=flight_id)
|
||||
|
||||
def _cleanup_flight(self, flight_id: str) -> None:
|
||||
"""Remove all in-memory state for a flight (prevents memory leaks)."""
|
||||
self._prev_images.pop(flight_id, None)
|
||||
self._flight_states.pop(flight_id, None)
|
||||
self._flight_cameras.pop(flight_id, None)
|
||||
self._altitudes.pop(flight_id, None)
|
||||
self._failure_counts.pop(flight_id, None)
|
||||
self._eskf.pop(flight_id, None)
|
||||
if self._graph:
|
||||
self._graph.delete_flight_graph(flight_id)
|
||||
|
||||
async def update_waypoint(
|
||||
self, flight_id: str, waypoint_id: str, waypoint: Waypoint
|
||||
) -> UpdateResponse:
|
||||
ok = await self.repository.update_waypoint(
|
||||
flight_id,
|
||||
waypoint_id,
|
||||
lat=waypoint.lat,
|
||||
lon=waypoint.lon,
|
||||
altitude=waypoint.altitude,
|
||||
confidence=waypoint.confidence,
|
||||
refined=waypoint.refined,
|
||||
)
|
||||
return UpdateResponse(updated=ok, waypoint_id=waypoint_id)
|
||||
|
||||
async def batch_update_waypoints(
|
||||
self, flight_id: str, waypoints: list[Waypoint]
|
||||
) -> BatchUpdateResponse:
|
||||
failed = []
|
||||
updated = 0
|
||||
for wp in waypoints:
|
||||
ok = await self.repository.update_waypoint(
|
||||
flight_id,
|
||||
wp.id,
|
||||
lat=wp.lat,
|
||||
lon=wp.lon,
|
||||
altitude=wp.altitude,
|
||||
confidence=wp.confidence,
|
||||
refined=wp.refined,
|
||||
)
|
||||
if ok:
|
||||
updated += 1
|
||||
else:
|
||||
failed.append(wp.id)
|
||||
return BatchUpdateResponse(
|
||||
success=(len(failed) == 0), updated_count=updated, failed_ids=failed
|
||||
)
|
||||
|
||||
async def queue_images(
|
||||
self, flight_id: str, metadata: BatchMetadata, file_count: int
|
||||
) -> BatchResponse:
|
||||
state = await self.repository.load_flight_state(flight_id)
|
||||
if state:
|
||||
total = state.frames_total + file_count
|
||||
await self.repository.save_flight_state(
|
||||
flight_id, frames_total=total, status="processing"
|
||||
)
|
||||
|
||||
next_seq = metadata.end_sequence + 1
|
||||
seqs = list(range(metadata.start_sequence, metadata.end_sequence + 1))
|
||||
return BatchResponse(
|
||||
accepted=True,
|
||||
sequences=seqs,
|
||||
next_expected=next_seq,
|
||||
message=f"Queued {file_count} images.",
|
||||
)
|
||||
|
||||
async def handle_user_fix(
|
||||
self, flight_id: str, req: UserFixRequest
|
||||
) -> UserFixResponse:
|
||||
await self.repository.save_flight_state(
|
||||
flight_id, blocked=False, status="processing"
|
||||
)
|
||||
|
||||
# Inject operator position into ESKF with high uncertainty (500m)
|
||||
eskf = self._eskf.get(flight_id)
|
||||
if eskf and eskf.initialized and self._coord:
|
||||
try:
|
||||
e, n, _ = self._coord.gps_to_enu(flight_id, req.satellite_gps)
|
||||
alt = self._altitudes.get(flight_id, 100.0)
|
||||
eskf.update_satellite(np.array([e, n, alt]), noise_meters=500.0)
|
||||
self._failure_counts[flight_id] = 0
|
||||
logger.info("User fix applied for %s: %s", flight_id, req.satellite_gps)
|
||||
except Exception as exc:
|
||||
logger.warning("User fix ESKF injection failed: %s", exc)
|
||||
|
||||
return UserFixResponse(
|
||||
accepted=True, processing_resumed=True, message="Fix applied."
|
||||
)
|
||||
|
||||
async def get_flight_status(self, flight_id: str) -> FlightStatusResponse | None:
|
||||
state = await self.repository.load_flight_state(flight_id)
|
||||
if not state:
|
||||
return None
|
||||
return FlightStatusResponse(
|
||||
status=state.status,
|
||||
frames_processed=state.frames_processed,
|
||||
frames_total=state.frames_total,
|
||||
current_frame=state.current_frame,
|
||||
current_heading=None,
|
||||
blocked=state.blocked,
|
||||
search_grid_size=state.search_grid_size,
|
||||
created_at=state.created_at,
|
||||
updated_at=state.updated_at,
|
||||
)
|
||||
|
||||
async def convert_object_to_gps(
|
||||
self, flight_id: str, frame_id: int, pixel: tuple[float, float]
|
||||
) -> ObjectGPSResponse:
|
||||
# PIPE-06: Use real CoordinateTransformer + ESKF pose for ray-ground projection
|
||||
gps: Optional[GPSPoint] = None
|
||||
eskf = self._eskf.get(flight_id)
|
||||
if self._coord and eskf and eskf.initialized:
|
||||
pos = eskf.position
|
||||
quat = eskf.quaternion
|
||||
cam = self._flight_cameras.get(flight_id, CameraParameters(
|
||||
focal_length=4.5, sensor_width=6.17, sensor_height=4.55,
|
||||
resolution_width=640, resolution_height=480,
|
||||
))
|
||||
alt = self._altitudes.get(flight_id, 100.0)
|
||||
try:
|
||||
gps = self._coord.pixel_to_gps(
|
||||
flight_id,
|
||||
pixel,
|
||||
frame_pose={"position": pos},
|
||||
camera_params=cam,
|
||||
altitude=float(alt),
|
||||
quaternion=quat,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("pixel_to_gps failed: %s", exc)
|
||||
|
||||
# Fallback: return ESKF position projected to ground (no pixel shift)
|
||||
if gps is None and eskf:
|
||||
gps = self._eskf_to_gps(flight_id, eskf)
|
||||
|
||||
return ObjectGPSResponse(
|
||||
gps=gps or GPSPoint(lat=0.0, lon=0.0),
|
||||
accuracy_meters=5.0,
|
||||
frame_id=frame_id,
|
||||
pixel=pixel,
|
||||
)
|
||||
|
||||
async def stream_events(self, flight_id: str, client_id: str):
|
||||
"""Async generator for SSE stream."""
|
||||
async for event in self.streamer.stream_generator(flight_id, client_id):
|
||||
yield event
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Result Manager (Component F14)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from gps_denied.pipeline.sse_streamer import SSEEventStreamer
|
||||
from gps_denied.db.repository import FlightRepository
|
||||
from gps_denied.schemas import GPSPoint
|
||||
from gps_denied.schemas.events import FrameProcessedEvent
|
||||
|
||||
|
||||
class ResultManager:
|
||||
"""Result consistency and publishing."""
|
||||
|
||||
def __init__(self, repo: FlightRepository, sse: SSEEventStreamer) -> None:
|
||||
self.repo = repo
|
||||
self.sse = sse
|
||||
|
||||
async def update_frame_result(
|
||||
self,
|
||||
flight_id: str,
|
||||
frame_id: int,
|
||||
gps_lat: float,
|
||||
gps_lon: float,
|
||||
altitude: float,
|
||||
heading: float,
|
||||
confidence: float,
|
||||
timestamp: datetime,
|
||||
refined: bool = False,
|
||||
) -> bool:
|
||||
"""Atomic DB update + SSE event publish."""
|
||||
|
||||
# 1. Update DB (in the repository these are auto-committing via flush,
|
||||
# but normally F03 would wrap in a single transaction).
|
||||
await self.repo.save_frame_result(
|
||||
flight_id,
|
||||
frame_id=frame_id,
|
||||
gps_lat=gps_lat,
|
||||
gps_lon=gps_lon,
|
||||
altitude=altitude,
|
||||
heading=heading,
|
||||
confidence=confidence,
|
||||
refined=refined,
|
||||
)
|
||||
|
||||
# Wait, the spec also wants Waypoints to be updated.
|
||||
# But image frames != waypoints. Waypoints are the planned route.
|
||||
# Actually in the spec it says: "Updates waypoint in waypoints table."
|
||||
# This implies updating the closest waypoint or a generated waypoint path.
|
||||
# We will follow the simplest form for now: update the waypoint if there is one corresponding.
|
||||
# Let's say we update a waypoint with id "wp_{frame_id}" for now if we know how they map,
|
||||
# or we just skip unless specified.
|
||||
|
||||
# 2. Trigger SSE event
|
||||
evt = FrameProcessedEvent(
|
||||
frame_id=frame_id,
|
||||
gps=GPSPoint(lat=gps_lat, lon=gps_lon),
|
||||
altitude=altitude,
|
||||
confidence=confidence,
|
||||
heading=heading,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
if refined:
|
||||
self.sse.send_refinement(flight_id, evt)
|
||||
else:
|
||||
self.sse.send_frame_result(flight_id, evt)
|
||||
|
||||
return True
|
||||
|
||||
async def publish_waypoint_update(self, flight_id: str, frame_id: int) -> bool:
|
||||
# Just delegates to SSE for waypoint updates, which is basically the frame result for UI
|
||||
pass
|
||||
@@ -0,0 +1,164 @@
|
||||
"""SSE Event Streamer (Component F15)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
from gps_denied.schemas.events import (
|
||||
FlightCompletedEvent,
|
||||
FrameProcessedEvent,
|
||||
SearchExpandedEvent,
|
||||
SSEEventType,
|
||||
SSEMessage,
|
||||
UserInputNeededEvent,
|
||||
)
|
||||
|
||||
|
||||
class SSEEventStreamer:
|
||||
"""Manages real-time SSE connections and event broadcasting."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Map: flight_id -> Dict[client_id, asyncio.Queue]
|
||||
self._streams: dict[str, dict[str, asyncio.Queue[SSEMessage | None]]] = defaultdict(dict)
|
||||
|
||||
def create_stream(self, flight_id: str, client_id: str) -> asyncio.Queue[SSEMessage | None]:
|
||||
"""Create a new event queue for a client."""
|
||||
q: asyncio.Queue[SSEMessage | None] = asyncio.Queue()
|
||||
self._streams[flight_id][client_id] = q
|
||||
return q
|
||||
|
||||
def close_stream(self, flight_id: str, client_id: str) -> None:
|
||||
"""Close a client stream by putting a sentinel and removing the queue."""
|
||||
if flight_id in self._streams and client_id in self._streams[flight_id]:
|
||||
q = self._streams[flight_id].pop(client_id)
|
||||
if not self._streams[flight_id]:
|
||||
del self._streams[flight_id]
|
||||
# Put None to signal generator exit
|
||||
try:
|
||||
q.put_nowait(None)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
def get_active_connections(self, flight_id: str) -> int:
|
||||
return len(self._streams.get(flight_id, {}))
|
||||
|
||||
def _broadcast(self, flight_id: str, msg: SSEMessage) -> bool:
|
||||
"""Broadcast a message to all clients subscribed to flight_id."""
|
||||
if flight_id not in self._streams or not self._streams[flight_id]:
|
||||
return False
|
||||
|
||||
for q in self._streams[flight_id].values():
|
||||
try:
|
||||
q.put_nowait(msg)
|
||||
except asyncio.QueueFull:
|
||||
pass # Drop if queue is full rather than blocking
|
||||
return True
|
||||
|
||||
# ── Business Event Senders ────────────────────────────────────────────────
|
||||
|
||||
def send_frame_result(self, flight_id: str, event_data: FrameProcessedEvent) -> bool:
|
||||
msg = SSEMessage(
|
||||
event=SSEEventType.FRAME_PROCESSED,
|
||||
data=event_data.model_dump(mode="json"),
|
||||
id=f"frame_{event_data.frame_id}",
|
||||
)
|
||||
return self._broadcast(flight_id, msg)
|
||||
|
||||
def send_refinement(self, flight_id: str, event_data: FrameProcessedEvent) -> bool:
|
||||
msg = SSEMessage(
|
||||
event=SSEEventType.FRAME_REFINED,
|
||||
data=event_data.model_dump(mode="json"),
|
||||
id=f"refine_{event_data.frame_id}",
|
||||
)
|
||||
return self._broadcast(flight_id, msg)
|
||||
|
||||
def send_search_progress(self, flight_id: str, event_data: SearchExpandedEvent) -> bool:
|
||||
msg = SSEMessage(
|
||||
event=SSEEventType.SEARCH_EXPANDED,
|
||||
data=event_data.model_dump(mode="json"),
|
||||
)
|
||||
return self._broadcast(flight_id, msg)
|
||||
|
||||
def send_user_input_request(self, flight_id: str, event_data: UserInputNeededEvent) -> bool:
|
||||
msg = SSEMessage(
|
||||
event=SSEEventType.USER_INPUT_NEEDED,
|
||||
data=event_data.model_dump(mode="json"),
|
||||
)
|
||||
return self._broadcast(flight_id, msg)
|
||||
|
||||
def send_flight_completed(self, flight_id: str, event_data: FlightCompletedEvent) -> bool:
|
||||
msg = SSEMessage(
|
||||
event=SSEEventType.FLIGHT_COMPLETED,
|
||||
data=event_data.model_dump(mode="json"),
|
||||
)
|
||||
return self._broadcast(flight_id, msg)
|
||||
|
||||
def send_heartbeat(self, flight_id: str) -> bool:
|
||||
# sse_starlette uses empty string or comment for heartbeat,
|
||||
# but we can just send an SSEMessage object that parses as empty event
|
||||
if flight_id not in self._streams:
|
||||
return False
|
||||
|
||||
# Manually sending a comment via the generator is tricky with strict SSEMessage schema
|
||||
# but we'll handle this in the stream generator directly
|
||||
return True
|
||||
|
||||
# ── Generic event dispatcher (used by processor.process_frame) ──────────
|
||||
|
||||
async def push_event(self, flight_id: str, event_type: str, data: dict) -> None:
|
||||
"""Dispatch a generic event to all clients for a flight.
|
||||
|
||||
Maps event_type strings to typed SSE events:
|
||||
"frame_result" → FrameProcessedEvent
|
||||
"refinement" → FrameProcessedEvent (refined)
|
||||
Other → raw broadcast via SSEMessage
|
||||
"""
|
||||
if event_type == "frame_result":
|
||||
evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data
|
||||
self.send_frame_result(flight_id, evt)
|
||||
elif event_type == "refinement":
|
||||
evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data
|
||||
self.send_refinement(flight_id, evt)
|
||||
else:
|
||||
msg = SSEMessage(
|
||||
event=SSEEventType.FRAME_PROCESSED,
|
||||
data=data,
|
||||
id=str(data.get("frame_id", "")),
|
||||
)
|
||||
self._broadcast(flight_id, msg)
|
||||
|
||||
# ── Stream Generator ──────────────────────────────────────────────────────
|
||||
|
||||
async def stream_generator(self, flight_id: str, client_id: str):
|
||||
"""Yields dicts for sse_starlette EventSourceResponse."""
|
||||
q = self.create_stream(flight_id, client_id)
|
||||
|
||||
# Send an immediate connection accepted ping
|
||||
yield {"event": "connected", "data": "connected"}
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Wait for next event or send heartbeat every 15s
|
||||
try:
|
||||
msg = await asyncio.wait_for(q.get(), timeout=15.0)
|
||||
if msg is None:
|
||||
# Sentinel for clean shutdown
|
||||
break
|
||||
|
||||
# Yield dict format for sse_starlette
|
||||
yield {
|
||||
"event": msg.event.value,
|
||||
"id": msg.id if msg.id else "",
|
||||
"data": json.dumps(msg.data)
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Heartbeat format for sse_starlette (empty string generates a comment)
|
||||
yield {"event": "heartbeat", "data": "ping"}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass # Client disconnected
|
||||
finally:
|
||||
self.close_stream(flight_id, client_id)
|
||||
@@ -0,0 +1,371 @@
|
||||
"""Accuracy Benchmark (Phase 7).
|
||||
|
||||
Provides:
|
||||
- SyntheticTrajectory — generates a realistic fixed-wing UAV flight path
|
||||
with ground-truth GPS + noisy sensor data.
|
||||
- AccuracyBenchmark — replays a trajectory through the ESKF pipeline
|
||||
and computes position-error statistics.
|
||||
|
||||
Acceptance criteria (from solution.md):
|
||||
AC-PERF-1: 80 % of frames within 50 m of ground truth.
|
||||
AC-PERF-2: 60 % of frames within 20 m of ground truth.
|
||||
AC-PERF-3: End-to-end per-frame latency < 400 ms.
|
||||
AC-PERF-4: VO drift over 1 km straight segment (no sat correction) < 100 m.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gps_denied.core.coordinates import CoordinateTransformer
|
||||
from gps_denied.core.eskf import ESKF
|
||||
from gps_denied.schemas import GPSPoint
|
||||
from gps_denied.schemas.eskf import ESKFConfig, IMUMeasurement
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Synthetic trajectory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class TrajectoryFrame:
|
||||
"""One simulated camera frame with ground-truth and noisy sensor data."""
|
||||
frame_id: int
|
||||
timestamp: float
|
||||
true_position_enu: np.ndarray # (3,) East, North, Up in metres
|
||||
true_gps: GPSPoint # WGS84 from true ENU
|
||||
imu_measurements: list[IMUMeasurement] # High-rate IMU between frames
|
||||
vo_translation: Optional[np.ndarray] # Noisy relative displacement (3,)
|
||||
vo_tracking_good: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyntheticTrajectoryConfig:
|
||||
"""Parameters for trajectory generation."""
|
||||
# Origin (mission start)
|
||||
origin: GPSPoint = field(default_factory=lambda: GPSPoint(lat=49.0, lon=32.0))
|
||||
altitude_m: float = 600.0 # Constant AGL altitude (m)
|
||||
# UAV speed and heading
|
||||
speed_mps: float = 20.0 # ~70 km/h (typical fixed-wing)
|
||||
heading_deg: float = 45.0 # Initial heading (degrees CW from North)
|
||||
camera_fps: float = 0.7 # ADTI 20L V1 camera rate (Hz)
|
||||
imu_hz: float = 200.0 # IMU sample rate
|
||||
num_frames: int = 50 # Number of camera frames to simulate
|
||||
# Noise parameters
|
||||
vo_noise_m: float = 0.5 # VO translation noise (sigma, metres)
|
||||
imu_accel_noise: float = 0.01 # Accelerometer noise sigma (m/s²)
|
||||
imu_gyro_noise: float = 0.001 # Gyroscope noise sigma (rad/s)
|
||||
# Failure injection
|
||||
vo_failure_frames: list[int] = field(default_factory=list)
|
||||
# Waypoints for heading changes (ENU East, North metres from origin)
|
||||
waypoints_enu: list[tuple[float, float]] = field(default_factory=list)
|
||||
|
||||
|
||||
class SyntheticTrajectory:
|
||||
"""Generate a synthetic fixed-wing UAV flight with ground truth + noisy sensors."""
|
||||
|
||||
def __init__(self, config: SyntheticTrajectoryConfig | None = None):
|
||||
self.config = config or SyntheticTrajectoryConfig()
|
||||
self._coord = CoordinateTransformer()
|
||||
self._flight_id = "__synthetic__"
|
||||
self._coord.set_enu_origin(self._flight_id, self.config.origin)
|
||||
|
||||
def generate(self) -> list[TrajectoryFrame]:
|
||||
"""Generate all trajectory frames."""
|
||||
cfg = self.config
|
||||
dt_camera = 1.0 / cfg.camera_fps
|
||||
dt_imu = 1.0 / cfg.imu_hz
|
||||
imu_steps = int(dt_camera * cfg.imu_hz)
|
||||
|
||||
frames: list[TrajectoryFrame] = []
|
||||
pos = np.array([0.0, 0.0, cfg.altitude_m])
|
||||
vel = self._heading_to_enu_vel(cfg.heading_deg, cfg.speed_mps)
|
||||
prev_pos = pos.copy()
|
||||
t = time.time()
|
||||
|
||||
waypoints = list(cfg.waypoints_enu) # copy
|
||||
|
||||
for fid in range(cfg.num_frames):
|
||||
# --- Waypoint steering ---
|
||||
if waypoints:
|
||||
wp_e, wp_n = waypoints[0]
|
||||
to_wp = np.array([wp_e - pos[0], wp_n - pos[1], 0.0])
|
||||
dist_wp = np.linalg.norm(to_wp[:2])
|
||||
if dist_wp < cfg.speed_mps * dt_camera:
|
||||
waypoints.pop(0)
|
||||
else:
|
||||
heading_rad = math.atan2(to_wp[0], to_wp[1]) # ENU: E=X, N=Y
|
||||
vel = np.array([
|
||||
cfg.speed_mps * math.sin(heading_rad),
|
||||
cfg.speed_mps * math.cos(heading_rad),
|
||||
0.0,
|
||||
])
|
||||
|
||||
# --- Simulate IMU between frames ---
|
||||
imu_list: list[IMUMeasurement] = []
|
||||
for step in range(imu_steps):
|
||||
ts = t + step * dt_imu
|
||||
# Body-frame acceleration (mostly gravity correction, small forward accel)
|
||||
accel_true = np.array([0.0, 0.0, 9.81]) # gravity compensation
|
||||
gyro_true = np.zeros(3)
|
||||
imu = IMUMeasurement(
|
||||
accel=accel_true + np.random.randn(3) * cfg.imu_accel_noise,
|
||||
gyro=gyro_true + np.random.randn(3) * cfg.imu_gyro_noise,
|
||||
timestamp=ts,
|
||||
)
|
||||
imu_list.append(imu)
|
||||
|
||||
# --- Propagate position ---
|
||||
prev_pos = pos.copy()
|
||||
pos = pos + vel * dt_camera
|
||||
t += dt_camera
|
||||
|
||||
# --- True GPS from ENU position ---
|
||||
true_gps = self._coord.enu_to_gps(
|
||||
self._flight_id, (float(pos[0]), float(pos[1]), float(pos[2]))
|
||||
)
|
||||
|
||||
# --- VO measurement (relative displacement + noise) ---
|
||||
true_displacement = pos - prev_pos
|
||||
vo_tracking_good = fid not in cfg.vo_failure_frames
|
||||
if vo_tracking_good:
|
||||
noisy_displacement = true_displacement + np.random.randn(3) * cfg.vo_noise_m
|
||||
noisy_displacement[2] = 0.0 # monocular VO is scale-ambiguous in Z
|
||||
else:
|
||||
noisy_displacement = None
|
||||
|
||||
frames.append(TrajectoryFrame(
|
||||
frame_id=fid,
|
||||
timestamp=t,
|
||||
true_position_enu=pos.copy(),
|
||||
true_gps=true_gps,
|
||||
imu_measurements=imu_list,
|
||||
vo_translation=noisy_displacement,
|
||||
vo_tracking_good=vo_tracking_good,
|
||||
))
|
||||
|
||||
return frames
|
||||
|
||||
@staticmethod
|
||||
def _heading_to_enu_vel(heading_deg: float, speed_mps: float) -> np.ndarray:
|
||||
"""Convert heading (degrees CW from North) to ENU velocity vector."""
|
||||
rad = math.radians(heading_deg)
|
||||
return np.array([
|
||||
speed_mps * math.sin(rad), # East
|
||||
speed_mps * math.cos(rad), # North
|
||||
0.0, # Up
|
||||
])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Accuracy Benchmark
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResult:
|
||||
"""Position error statistics over a trajectory replay."""
|
||||
errors_m: list[float] # Per-frame horizontal error in metres
|
||||
latencies_ms: list[float] # Per-frame process time in ms
|
||||
frames_total: int
|
||||
frames_with_good_estimate: int
|
||||
|
||||
@property
|
||||
def p80_error_m(self) -> float:
|
||||
"""80th percentile position error (metres)."""
|
||||
return float(np.percentile(self.errors_m, 80)) if self.errors_m else float("inf")
|
||||
|
||||
@property
|
||||
def p60_error_m(self) -> float:
|
||||
"""60th percentile position error (metres)."""
|
||||
return float(np.percentile(self.errors_m, 60)) if self.errors_m else float("inf")
|
||||
|
||||
@property
|
||||
def median_error_m(self) -> float:
|
||||
"""Median position error (metres)."""
|
||||
return float(np.median(self.errors_m)) if self.errors_m else float("inf")
|
||||
|
||||
@property
|
||||
def max_error_m(self) -> float:
|
||||
return float(max(self.errors_m)) if self.errors_m else float("inf")
|
||||
|
||||
@property
|
||||
def p95_latency_ms(self) -> float:
|
||||
"""95th percentile frame latency (ms)."""
|
||||
return float(np.percentile(self.latencies_ms, 95)) if self.latencies_ms else float("inf")
|
||||
|
||||
@property
|
||||
def pct_within_50m(self) -> float:
|
||||
"""Fraction of frames within 50 m error."""
|
||||
if not self.errors_m:
|
||||
return 0.0
|
||||
return sum(e <= 50.0 for e in self.errors_m) / len(self.errors_m)
|
||||
|
||||
@property
|
||||
def pct_within_20m(self) -> float:
|
||||
"""Fraction of frames within 20 m error."""
|
||||
if not self.errors_m:
|
||||
return 0.0
|
||||
return sum(e <= 20.0 for e in self.errors_m) / len(self.errors_m)
|
||||
|
||||
def passes_acceptance_criteria(self) -> tuple[bool, dict[str, bool]]:
|
||||
"""Check all solution.md acceptance criteria.
|
||||
|
||||
Returns (overall_pass, per_criterion_dict).
|
||||
"""
|
||||
checks = {
|
||||
"AC-PERF-1: 80% within 50m": self.pct_within_50m >= 0.80,
|
||||
"AC-PERF-2: 60% within 20m": self.pct_within_20m >= 0.60,
|
||||
"AC-PERF-3: p95 latency < 400ms": self.p95_latency_ms < 400.0,
|
||||
}
|
||||
overall = all(checks.values())
|
||||
return overall, checks
|
||||
|
||||
def summary(self) -> str:
|
||||
overall, checks = self.passes_acceptance_criteria()
|
||||
lines = [
|
||||
f"Frames: {self.frames_total} | with estimate: {self.frames_with_good_estimate}",
|
||||
f"Error — median: {self.median_error_m:.1f}m p80: {self.p80_error_m:.1f}m "
|
||||
f"p60: {self.p60_error_m:.1f}m max: {self.max_error_m:.1f}m",
|
||||
f"Within 50m: {self.pct_within_50m*100:.1f}% | within 20m: {self.pct_within_20m*100:.1f}%",
|
||||
f"Latency p95: {self.p95_latency_ms:.1f}ms",
|
||||
"",
|
||||
"Acceptance criteria:",
|
||||
]
|
||||
for criterion, passed in checks.items():
|
||||
lines.append(f" {'PASS' if passed else 'FAIL'} {criterion}")
|
||||
lines.append(f"\nOverall: {'PASS' if overall else 'FAIL'}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class AccuracyBenchmark:
|
||||
"""Replays a SyntheticTrajectory through the ESKF and measures accuracy.
|
||||
|
||||
The benchmark uses only the ESKF (no full FlightProcessor) for speed.
|
||||
Satellite corrections are injected optionally via sat_correction_fn.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
eskf_config: ESKFConfig | None = None,
|
||||
sat_correction_fn: Optional[Callable[[TrajectoryFrame], Optional[np.ndarray]]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
eskf_config: ESKF tuning parameters.
|
||||
sat_correction_fn: Optional callback(frame) → ENU position or None.
|
||||
Called on keyframes to inject satellite corrections.
|
||||
If None, no satellite corrections are applied.
|
||||
"""
|
||||
self.eskf_config = eskf_config or ESKFConfig()
|
||||
self.sat_correction_fn = sat_correction_fn
|
||||
|
||||
def run(
|
||||
self,
|
||||
trajectory: list[TrajectoryFrame],
|
||||
origin: GPSPoint,
|
||||
satellite_keyframe_interval: int = 7,
|
||||
) -> BenchmarkResult:
|
||||
"""Replay trajectory frames through ESKF, collect errors and latencies.
|
||||
|
||||
Args:
|
||||
trajectory: List of TrajectoryFrame (from SyntheticTrajectory).
|
||||
origin: WGS84 reference origin for ENU.
|
||||
satellite_keyframe_interval: Apply satellite correction every N frames.
|
||||
"""
|
||||
coord = CoordinateTransformer()
|
||||
flight_id = "__benchmark__"
|
||||
coord.set_enu_origin(flight_id, origin)
|
||||
|
||||
eskf = ESKF(self.eskf_config)
|
||||
# Init at origin with HIGH uncertainty
|
||||
eskf.initialize(np.array([0.0, 0.0, trajectory[0].true_position_enu[2]]),
|
||||
trajectory[0].timestamp)
|
||||
|
||||
errors_m: list[float] = []
|
||||
latencies_ms: list[float] = []
|
||||
frames_with_estimate = 0
|
||||
|
||||
for frame in trajectory:
|
||||
t_frame_start = time.perf_counter()
|
||||
|
||||
# --- IMU prediction ---
|
||||
for imu in frame.imu_measurements:
|
||||
eskf.predict(imu)
|
||||
|
||||
# --- VO update ---
|
||||
if frame.vo_tracking_good and frame.vo_translation is not None:
|
||||
dt_vo = 1.0 / 0.7 # camera interval
|
||||
eskf.update_vo(frame.vo_translation, dt_vo)
|
||||
|
||||
# --- Satellite update (keyframes) ---
|
||||
if frame.frame_id % satellite_keyframe_interval == 0:
|
||||
sat_pos_enu: Optional[np.ndarray] = None
|
||||
if self.sat_correction_fn is not None:
|
||||
sat_pos_enu = self.sat_correction_fn(frame)
|
||||
else:
|
||||
# Default: inject ground-truth position + realistic noise
|
||||
noise_m = 10.0
|
||||
sat_pos_enu = (
|
||||
frame.true_position_enu[:3]
|
||||
+ np.random.randn(3) * noise_m
|
||||
)
|
||||
sat_pos_enu[2] = frame.true_position_enu[2] # keep altitude
|
||||
|
||||
if sat_pos_enu is not None:
|
||||
# Tell ESKF the measurement noise matches what we inject
|
||||
eskf.update_satellite(sat_pos_enu, noise_meters=noise_m)
|
||||
|
||||
latency_ms = (time.perf_counter() - t_frame_start) * 1000.0
|
||||
latencies_ms.append(latency_ms)
|
||||
|
||||
# --- Compute horizontal error vs ground truth ---
|
||||
if eskf.initialized and eskf._nominal_state is not None:
|
||||
est_pos = eskf._nominal_state["position"]
|
||||
true_pos = frame.true_position_enu
|
||||
horiz_error = float(np.linalg.norm(est_pos[:2] - true_pos[:2]))
|
||||
errors_m.append(horiz_error)
|
||||
frames_with_estimate += 1
|
||||
else:
|
||||
errors_m.append(float("inf"))
|
||||
|
||||
return BenchmarkResult(
|
||||
errors_m=errors_m,
|
||||
latencies_ms=latencies_ms,
|
||||
frames_total=len(trajectory),
|
||||
frames_with_good_estimate=frames_with_estimate,
|
||||
)
|
||||
|
||||
def run_vo_drift_test(
|
||||
self,
|
||||
trajectory_length_m: float = 1000.0,
|
||||
speed_mps: float = 20.0,
|
||||
) -> float:
|
||||
"""Measure VO drift over a straight segment with NO satellite correction.
|
||||
|
||||
Returns final horizontal position error in metres.
|
||||
Per solution.md, this should be < 100m over 1km.
|
||||
"""
|
||||
fps = 0.7
|
||||
num_frames = max(10, int(trajectory_length_m / speed_mps * fps))
|
||||
cfg = SyntheticTrajectoryConfig(
|
||||
speed_mps=speed_mps,
|
||||
heading_deg=0.0, # straight North
|
||||
camera_fps=fps,
|
||||
num_frames=num_frames,
|
||||
vo_noise_m=0.3, # cuVSLAM-grade VO noise
|
||||
)
|
||||
traj_gen = SyntheticTrajectory(cfg)
|
||||
frames = traj_gen.generate()
|
||||
|
||||
# No satellite corrections
|
||||
benchmark_no_sat = AccuracyBenchmark(
|
||||
eskf_config=self.eskf_config,
|
||||
sat_correction_fn=lambda _: None, # suppress all satellite updates
|
||||
)
|
||||
result = benchmark_no_sat.run(frames, cfg.origin, satellite_keyframe_interval=9999)
|
||||
# Return final-frame error
|
||||
return result.errors_m[-1] if result.errors_m else float("inf")
|
||||
Reference in New Issue
Block a user