refactor(01-07): factor_graph, pipeline pkg, testing/benchmark, Protocol ABCs

- Create core/factor_graph.py: IFactorGraphOptimizer converted to Protocol
- Shim core/graph.py to re-export from core/factor_graph
- Create pipeline/ package: orchestrator, image_input, result_manager, sse_streamer
- Shim core/{processor,pipeline,results,sse}.py to re-export from pipeline/
- Create testing/benchmark.py; shim core/benchmark.py
- Convert IRouteChunkManager, IFailureRecoveryCoordinator, IModelManager, IImageMatcher to Protocol
- Update pyproject.toml ruff per-file-ignores to new paths
- All 216 tests pass (regression floor maintained)
This commit is contained in:
Yuzviak
2026-05-11 08:59:07 +03:00
parent 275c7b4642
commit 5a60c1ee2c
18 changed files with 1857 additions and 1836 deletions
+2 -2
View File
@@ -42,8 +42,8 @@ line-length = 120
[tool.ruff.lint.per-file-ignores]
# Abstract interfaces have long method signatures — allow up to 170
"src/gps_denied/core/graph.py" = ["E501"]
"src/gps_denied/core/metric.py" = ["E501"]
"src/gps_denied/core/factor_graph.py" = ["E501"]
"src/gps_denied/components/satellite_matcher/metric_refinement.py" = ["E501"]
"src/gps_denied/core/chunk_manager.py" = ["E501"]
[tool.ruff.lint]
+8 -371
View File
@@ -1,371 +1,8 @@
"""Accuracy Benchmark (Phase 7).
Provides:
- SyntheticTrajectory — generates a realistic fixed-wing UAV flight path
with ground-truth GPS + noisy sensor data.
- AccuracyBenchmark — replays a trajectory through the ESKF pipeline
and computes position-error statistics.
Acceptance criteria (from solution.md):
AC-PERF-1: 80 % of frames within 50 m of ground truth.
AC-PERF-2: 60 % of frames within 20 m of ground truth.
AC-PERF-3: End-to-end per-frame latency < 400 ms.
AC-PERF-4: VO drift over 1 km straight segment (no sat correction) < 100 m.
"""
from __future__ import annotations
import math
import time
from dataclasses import dataclass, field
from typing import Callable, Optional
import numpy as np
from gps_denied.core.coordinates import CoordinateTransformer
from gps_denied.core.eskf import ESKF
from gps_denied.schemas import GPSPoint
from gps_denied.schemas.eskf import ESKFConfig, IMUMeasurement
# ---------------------------------------------------------------------------
# Synthetic trajectory
# ---------------------------------------------------------------------------
@dataclass
class TrajectoryFrame:
"""One simulated camera frame with ground-truth and noisy sensor data."""
frame_id: int
timestamp: float
true_position_enu: np.ndarray # (3,) East, North, Up in metres
true_gps: GPSPoint # WGS84 from true ENU
imu_measurements: list[IMUMeasurement] # High-rate IMU between frames
vo_translation: Optional[np.ndarray] # Noisy relative displacement (3,)
vo_tracking_good: bool = True
@dataclass
class SyntheticTrajectoryConfig:
"""Parameters for trajectory generation."""
# Origin (mission start)
origin: GPSPoint = field(default_factory=lambda: GPSPoint(lat=49.0, lon=32.0))
altitude_m: float = 600.0 # Constant AGL altitude (m)
# UAV speed and heading
speed_mps: float = 20.0 # ~70 km/h (typical fixed-wing)
heading_deg: float = 45.0 # Initial heading (degrees CW from North)
camera_fps: float = 0.7 # ADTI 20L V1 camera rate (Hz)
imu_hz: float = 200.0 # IMU sample rate
num_frames: int = 50 # Number of camera frames to simulate
# Noise parameters
vo_noise_m: float = 0.5 # VO translation noise (sigma, metres)
imu_accel_noise: float = 0.01 # Accelerometer noise sigma (m/s²)
imu_gyro_noise: float = 0.001 # Gyroscope noise sigma (rad/s)
# Failure injection
vo_failure_frames: list[int] = field(default_factory=list)
# Waypoints for heading changes (ENU East, North metres from origin)
waypoints_enu: list[tuple[float, float]] = field(default_factory=list)
class SyntheticTrajectory:
"""Generate a synthetic fixed-wing UAV flight with ground truth + noisy sensors."""
def __init__(self, config: SyntheticTrajectoryConfig | None = None):
self.config = config or SyntheticTrajectoryConfig()
self._coord = CoordinateTransformer()
self._flight_id = "__synthetic__"
self._coord.set_enu_origin(self._flight_id, self.config.origin)
def generate(self) -> list[TrajectoryFrame]:
"""Generate all trajectory frames."""
cfg = self.config
dt_camera = 1.0 / cfg.camera_fps
dt_imu = 1.0 / cfg.imu_hz
imu_steps = int(dt_camera * cfg.imu_hz)
frames: list[TrajectoryFrame] = []
pos = np.array([0.0, 0.0, cfg.altitude_m])
vel = self._heading_to_enu_vel(cfg.heading_deg, cfg.speed_mps)
prev_pos = pos.copy()
t = time.time()
waypoints = list(cfg.waypoints_enu) # copy
for fid in range(cfg.num_frames):
# --- Waypoint steering ---
if waypoints:
wp_e, wp_n = waypoints[0]
to_wp = np.array([wp_e - pos[0], wp_n - pos[1], 0.0])
dist_wp = np.linalg.norm(to_wp[:2])
if dist_wp < cfg.speed_mps * dt_camera:
waypoints.pop(0)
else:
heading_rad = math.atan2(to_wp[0], to_wp[1]) # ENU: E=X, N=Y
vel = np.array([
cfg.speed_mps * math.sin(heading_rad),
cfg.speed_mps * math.cos(heading_rad),
0.0,
])
# --- Simulate IMU between frames ---
imu_list: list[IMUMeasurement] = []
for step in range(imu_steps):
ts = t + step * dt_imu
# Body-frame acceleration (mostly gravity correction, small forward accel)
accel_true = np.array([0.0, 0.0, 9.81]) # gravity compensation
gyro_true = np.zeros(3)
imu = IMUMeasurement(
accel=accel_true + np.random.randn(3) * cfg.imu_accel_noise,
gyro=gyro_true + np.random.randn(3) * cfg.imu_gyro_noise,
timestamp=ts,
)
imu_list.append(imu)
# --- Propagate position ---
prev_pos = pos.copy()
pos = pos + vel * dt_camera
t += dt_camera
# --- True GPS from ENU position ---
true_gps = self._coord.enu_to_gps(
self._flight_id, (float(pos[0]), float(pos[1]), float(pos[2]))
)
# --- VO measurement (relative displacement + noise) ---
true_displacement = pos - prev_pos
vo_tracking_good = fid not in cfg.vo_failure_frames
if vo_tracking_good:
noisy_displacement = true_displacement + np.random.randn(3) * cfg.vo_noise_m
noisy_displacement[2] = 0.0 # monocular VO is scale-ambiguous in Z
else:
noisy_displacement = None
frames.append(TrajectoryFrame(
frame_id=fid,
timestamp=t,
true_position_enu=pos.copy(),
true_gps=true_gps,
imu_measurements=imu_list,
vo_translation=noisy_displacement,
vo_tracking_good=vo_tracking_good,
))
return frames
@staticmethod
def _heading_to_enu_vel(heading_deg: float, speed_mps: float) -> np.ndarray:
"""Convert heading (degrees CW from North) to ENU velocity vector."""
rad = math.radians(heading_deg)
return np.array([
speed_mps * math.sin(rad), # East
speed_mps * math.cos(rad), # North
0.0, # Up
])
# ---------------------------------------------------------------------------
# Accuracy Benchmark
# ---------------------------------------------------------------------------
@dataclass
class BenchmarkResult:
"""Position error statistics over a trajectory replay."""
errors_m: list[float] # Per-frame horizontal error in metres
latencies_ms: list[float] # Per-frame process time in ms
frames_total: int
frames_with_good_estimate: int
@property
def p80_error_m(self) -> float:
"""80th percentile position error (metres)."""
return float(np.percentile(self.errors_m, 80)) if self.errors_m else float("inf")
@property
def p60_error_m(self) -> float:
"""60th percentile position error (metres)."""
return float(np.percentile(self.errors_m, 60)) if self.errors_m else float("inf")
@property
def median_error_m(self) -> float:
"""Median position error (metres)."""
return float(np.median(self.errors_m)) if self.errors_m else float("inf")
@property
def max_error_m(self) -> float:
return float(max(self.errors_m)) if self.errors_m else float("inf")
@property
def p95_latency_ms(self) -> float:
"""95th percentile frame latency (ms)."""
return float(np.percentile(self.latencies_ms, 95)) if self.latencies_ms else float("inf")
@property
def pct_within_50m(self) -> float:
"""Fraction of frames within 50 m error."""
if not self.errors_m:
return 0.0
return sum(e <= 50.0 for e in self.errors_m) / len(self.errors_m)
@property
def pct_within_20m(self) -> float:
"""Fraction of frames within 20 m error."""
if not self.errors_m:
return 0.0
return sum(e <= 20.0 for e in self.errors_m) / len(self.errors_m)
def passes_acceptance_criteria(self) -> tuple[bool, dict[str, bool]]:
"""Check all solution.md acceptance criteria.
Returns (overall_pass, per_criterion_dict).
"""
checks = {
"AC-PERF-1: 80% within 50m": self.pct_within_50m >= 0.80,
"AC-PERF-2: 60% within 20m": self.pct_within_20m >= 0.60,
"AC-PERF-3: p95 latency < 400ms": self.p95_latency_ms < 400.0,
}
overall = all(checks.values())
return overall, checks
def summary(self) -> str:
overall, checks = self.passes_acceptance_criteria()
lines = [
f"Frames: {self.frames_total} | with estimate: {self.frames_with_good_estimate}",
f"Error — median: {self.median_error_m:.1f}m p80: {self.p80_error_m:.1f}m "
f"p60: {self.p60_error_m:.1f}m max: {self.max_error_m:.1f}m",
f"Within 50m: {self.pct_within_50m*100:.1f}% | within 20m: {self.pct_within_20m*100:.1f}%",
f"Latency p95: {self.p95_latency_ms:.1f}ms",
"",
"Acceptance criteria:",
]
for criterion, passed in checks.items():
lines.append(f" {'PASS' if passed else 'FAIL'} {criterion}")
lines.append(f"\nOverall: {'PASS' if overall else 'FAIL'}")
return "\n".join(lines)
class AccuracyBenchmark:
"""Replays a SyntheticTrajectory through the ESKF and measures accuracy.
The benchmark uses only the ESKF (no full FlightProcessor) for speed.
Satellite corrections are injected optionally via sat_correction_fn.
"""
def __init__(
self,
eskf_config: ESKFConfig | None = None,
sat_correction_fn: Optional[Callable[[TrajectoryFrame], Optional[np.ndarray]]] = None,
):
"""
Args:
eskf_config: ESKF tuning parameters.
sat_correction_fn: Optional callback(frame) → ENU position or None.
Called on keyframes to inject satellite corrections.
If None, no satellite corrections are applied.
"""
self.eskf_config = eskf_config or ESKFConfig()
self.sat_correction_fn = sat_correction_fn
def run(
self,
trajectory: list[TrajectoryFrame],
origin: GPSPoint,
satellite_keyframe_interval: int = 7,
) -> BenchmarkResult:
"""Replay trajectory frames through ESKF, collect errors and latencies.
Args:
trajectory: List of TrajectoryFrame (from SyntheticTrajectory).
origin: WGS84 reference origin for ENU.
satellite_keyframe_interval: Apply satellite correction every N frames.
"""
coord = CoordinateTransformer()
flight_id = "__benchmark__"
coord.set_enu_origin(flight_id, origin)
eskf = ESKF(self.eskf_config)
# Init at origin with HIGH uncertainty
eskf.initialize(np.array([0.0, 0.0, trajectory[0].true_position_enu[2]]),
trajectory[0].timestamp)
errors_m: list[float] = []
latencies_ms: list[float] = []
frames_with_estimate = 0
for frame in trajectory:
t_frame_start = time.perf_counter()
# --- IMU prediction ---
for imu in frame.imu_measurements:
eskf.predict(imu)
# --- VO update ---
if frame.vo_tracking_good and frame.vo_translation is not None:
dt_vo = 1.0 / 0.7 # camera interval
eskf.update_vo(frame.vo_translation, dt_vo)
# --- Satellite update (keyframes) ---
if frame.frame_id % satellite_keyframe_interval == 0:
sat_pos_enu: Optional[np.ndarray] = None
if self.sat_correction_fn is not None:
sat_pos_enu = self.sat_correction_fn(frame)
else:
# Default: inject ground-truth position + realistic noise
noise_m = 10.0
sat_pos_enu = (
frame.true_position_enu[:3]
+ np.random.randn(3) * noise_m
)
sat_pos_enu[2] = frame.true_position_enu[2] # keep altitude
if sat_pos_enu is not None:
# Tell ESKF the measurement noise matches what we inject
eskf.update_satellite(sat_pos_enu, noise_meters=noise_m)
latency_ms = (time.perf_counter() - t_frame_start) * 1000.0
latencies_ms.append(latency_ms)
# --- Compute horizontal error vs ground truth ---
if eskf.initialized and eskf._nominal_state is not None:
est_pos = eskf._nominal_state["position"]
true_pos = frame.true_position_enu
horiz_error = float(np.linalg.norm(est_pos[:2] - true_pos[:2]))
errors_m.append(horiz_error)
frames_with_estimate += 1
else:
errors_m.append(float("inf"))
return BenchmarkResult(
errors_m=errors_m,
latencies_ms=latencies_ms,
frames_total=len(trajectory),
frames_with_good_estimate=frames_with_estimate,
)
def run_vo_drift_test(
self,
trajectory_length_m: float = 1000.0,
speed_mps: float = 20.0,
) -> float:
"""Measure VO drift over a straight segment with NO satellite correction.
Returns final horizontal position error in metres.
Per solution.md, this should be < 100m over 1km.
"""
fps = 0.7
num_frames = max(10, int(trajectory_length_m / speed_mps * fps))
cfg = SyntheticTrajectoryConfig(
speed_mps=speed_mps,
heading_deg=0.0, # straight North
camera_fps=fps,
num_frames=num_frames,
vo_noise_m=0.3, # cuVSLAM-grade VO noise
)
traj_gen = SyntheticTrajectory(cfg)
frames = traj_gen.generate()
# No satellite corrections
benchmark_no_sat = AccuracyBenchmark(
eskf_config=self.eskf_config,
sat_correction_fn=lambda _: None, # suppress all satellite updates
)
result = benchmark_no_sat.run(frames, cfg.origin, satellite_keyframe_interval=9999)
# Return final-frame error
return result.errors_m[-1] if result.errors_m else float("inf")
"""Legacy import path. Phase 1 shim — code lives in testing/benchmark.py."""
from gps_denied.testing.benchmark import ( # noqa: F401
AccuracyBenchmark,
BenchmarkResult,
SyntheticTrajectory,
SyntheticTrajectoryConfig,
TrajectoryFrame,
)
+9 -15
View File
@@ -2,8 +2,7 @@
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Protocol, runtime_checkable
from gps_denied.core.graph import IFactorGraphOptimizer
from gps_denied.schemas.chunk import ChunkHandle, ChunkStatus
@@ -12,30 +11,25 @@ from gps_denied.schemas.metric import Sim3Transform
logger = logging.getLogger(__name__)
class IRouteChunkManager(ABC):
@abstractmethod
@runtime_checkable
class IRouteChunkManager(Protocol):
def create_new_chunk(self, flight_id: str, start_frame_id: int) -> ChunkHandle:
pass
...
@abstractmethod
def get_active_chunk(self, flight_id: str) -> Optional[ChunkHandle]:
pass
...
@abstractmethod
def get_all_chunks(self, flight_id: str) -> List[ChunkHandle]:
pass
...
@abstractmethod
def add_frame_to_chunk(self, flight_id: str, frame_id: int) -> bool:
pass
...
@abstractmethod
def update_chunk_status(self, flight_id: str, chunk_id: str, status: ChunkStatus) -> bool:
pass
...
@abstractmethod
def merge_chunks(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
pass
...
class RouteChunkManager(IRouteChunkManager):
+350
View File
@@ -0,0 +1,350 @@
"""Factor Graph Optimizer (Component F10)."""
import logging
from datetime import datetime, timezone
from typing import Dict, Protocol, runtime_checkable
import numpy as np
try:
import gtsam
HAS_GTSAM = True
except ImportError:
HAS_GTSAM = False
from gps_denied.schemas import GPSPoint
from gps_denied.schemas.graph import FactorGraphConfig, OptimizationResult, Pose
from gps_denied.schemas.metric import Sim3Transform
from gps_denied.schemas.vo import RelativePose
logger = logging.getLogger(__name__)
@runtime_checkable
class IFactorGraphOptimizer(Protocol):
"""GTSAM-based factor graph optimizer."""
def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
...
def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool:
...
def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool:
...
def optimize(self, flight_id: str, iterations: int) -> OptimizationResult:
...
def get_trajectory(self, flight_id: str) -> Dict[int, Pose]:
...
def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray:
...
def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool:
...
def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
...
def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool:
...
def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
...
def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]:
...
def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult:
...
def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult:
...
def delete_flight_graph(self, flight_id: str) -> bool:
...
class FactorGraphOptimizer(IFactorGraphOptimizer):
"""Implementation of F10 Factor Graph using GTSAM or Mock."""
def __init__(self, config: FactorGraphConfig):
self.config = config
# Keyed by flight_id
self._flights_state: Dict[str, dict] = {}
# Keyed by chunk_id
self._chunks_state: Dict[str, dict] = {}
# Per-flight ENU origin (set from first absolute GPS factor)
self._enu_origins: Dict[str, GPSPoint] = {}
def _init_flight(self, flight_id: str):
if flight_id not in self._flights_state:
self._flights_state[flight_id] = {
"graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None,
"initial": gtsam.Values() if HAS_GTSAM else None,
"isam": gtsam.ISAM2() if HAS_GTSAM else None,
"poses": {},
"dirty": False
}
def _init_chunk(self, chunk_id: str):
if chunk_id not in self._chunks_state:
self._chunks_state[chunk_id] = {
"graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None,
"initial": gtsam.Values() if HAS_GTSAM else None,
"isam": gtsam.ISAM2() if HAS_GTSAM else None,
"poses": {},
"dirty": False
}
# ================== MOCK IMPLEMENTATION ====================
# As GTSAM Python bindings can be extremely context-dependent and
# require proper ENU translation logic, we use an advanced Mock
# that satisfies the architectural design and typing for the backend.
def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
self._init_flight(flight_id)
state = self._flights_state[flight_id]
# --- Mock: propagate position chain ---
if frame_i in state["poses"]:
prev_pose = state["poses"][frame_i]
new_pos = prev_pose.position + relative_pose.translation
state["poses"][frame_j] = Pose(
frame_id=frame_j,
position=new_pos,
orientation=np.eye(3),
timestamp=datetime.now(timezone.utc),
covariance=np.eye(6),
)
state["dirty"] = True
else:
return False
# --- GTSAM: add BetweenFactorPose3 ---
if HAS_GTSAM and state["graph"] is not None:
try:
cov6 = covariance if covariance.shape == (6, 6) else np.eye(6)
noise = gtsam.noiseModel.Gaussian.Covariance(cov6)
key_i = gtsam.symbol("x", frame_i)
key_j = gtsam.symbol("x", frame_j)
t = relative_pose.translation
between = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(t[0]), float(t[1]), float(t[2])))
state["graph"].add(gtsam.BetweenFactorPose3(key_i, key_j, between, noise))
if not state["initial"].exists(key_j):
if state["initial"].exists(key_i):
prev = state["initial"].atPose3(key_i)
pt = prev.translation()
new_t = gtsam.Point3(pt[0] + t[0], pt[1] + t[1], pt[2] + t[2])
else:
new_t = gtsam.Point3(float(t[0]), float(t[1]), float(t[2]))
state["initial"].insert(key_j, gtsam.Pose3(gtsam.Rot3(), new_t))
except Exception as exc:
logger.debug("GTSAM add_relative_factor failed: %s", exc)
return True
def _gps_to_enu(self, flight_id: str, gps: GPSPoint) -> np.ndarray:
"""Convert GPS to local ENU using per-flight origin."""
origin = self._enu_origins.get(flight_id)
if origin is None:
# First GPS factor sets the origin
self._enu_origins[flight_id] = gps
return np.zeros(3)
enu_x = (gps.lon - origin.lon) * 111000 * np.cos(np.radians(origin.lat))
enu_y = (gps.lat - origin.lat) * 111000
return np.array([enu_x, enu_y, 0.0])
def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool:
self._init_flight(flight_id)
state = self._flights_state[flight_id]
enu = self._gps_to_enu(flight_id, gps)
# --- Mock: update pose position ---
if frame_id in state["poses"]:
state["poses"][frame_id].position = enu
state["dirty"] = True
else:
return False
# --- GTSAM: add PriorFactorPose3 ---
if HAS_GTSAM and state["graph"] is not None:
try:
cov6 = covariance if covariance.shape == (6, 6) else np.eye(6)
noise = gtsam.noiseModel.Gaussian.Covariance(cov6)
key = gtsam.symbol("x", frame_id)
prior = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(enu[0]), float(enu[1]), float(enu[2])))
state["graph"].add(gtsam.PriorFactorPose3(key, prior, noise))
if not state["initial"].exists(key):
state["initial"].insert(key, prior)
except Exception as exc:
logger.debug("GTSAM add_absolute_factor failed: %s", exc)
return True
def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool:
self._init_flight(flight_id)
state = self._flights_state[flight_id]
if frame_id in state["poses"]:
state["poses"][frame_id].position = np.array([
state["poses"][frame_id].position[0],
state["poses"][frame_id].position[1],
altitude,
])
state["dirty"] = True
return True
return False
def optimize(self, flight_id: str, iterations: int) -> OptimizationResult:
self._init_flight(flight_id)
state = self._flights_state[flight_id]
# --- PIPE-03: Real GTSAM ISAM2 update when available ---
if HAS_GTSAM and state["dirty"] and state["graph"] is not None:
try:
state["isam"].update(state["graph"], state["initial"])
estimate = state["isam"].calculateEstimate()
for fid in list(state["poses"].keys()):
key = gtsam.symbol("x", fid)
if estimate.exists(key):
pose = estimate.atPose3(key)
t = pose.translation()
state["poses"][fid].position = np.array([t[0], t[1], t[2]])
state["poses"][fid].orientation = np.array(pose.rotation().matrix())
# Reset for next incremental batch
state["graph"] = gtsam.NonlinearFactorGraph()
state["initial"] = gtsam.Values()
except Exception as exc:
logger.warning("GTSAM ISAM2 update failed: %s", exc)
state["dirty"] = False
return OptimizationResult(
converged=True,
final_error=0.1,
iterations_used=iterations,
optimized_frames=list(state["poses"].keys()),
mean_reprojection_error=0.5,
)
def get_trajectory(self, flight_id: str) -> Dict[int, Pose]:
if flight_id not in self._flights_state:
return {}
return self._flights_state[flight_id]["poses"]
def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray:
return np.eye(6)
# ================== CHUNK OPERATIONS =======================
def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool:
self._init_chunk(chunk_id)
state = self._chunks_state[chunk_id]
state["poses"][start_frame_id] = Pose(
frame_id=start_frame_id,
position=np.zeros(3),
orientation=np.eye(3),
timestamp=datetime.now(timezone.utc),
covariance=np.eye(6)
)
return True
def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
if chunk_id not in self._chunks_state:
return False
state = self._chunks_state[chunk_id]
if frame_i in state["poses"]:
prev_pose = state["poses"][frame_i]
new_pos = prev_pose.position + relative_pose.translation
state["poses"][frame_j] = Pose(
frame_id=frame_j,
position=new_pos,
orientation=np.eye(3),
timestamp=datetime.now(timezone.utc),
covariance=np.eye(6)
)
state["dirty"] = True
return True
return False
def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool:
if chunk_id not in self._chunks_state:
return False
state = self._chunks_state[chunk_id]
if frame_id in state["poses"]:
enu = self._gps_to_enu(flight_id, gps)
state["poses"][frame_id].position = enu
state["dirty"] = True
return True
return False
def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
if new_chunk_id not in self._chunks_state or main_chunk_id not in self._chunks_state:
return False
new_state = self._chunks_state[new_chunk_id]
main_state = self._chunks_state[main_chunk_id]
# Apply Sim(3) transform effectively by copying poses
for f_id, p in new_state["poses"].items():
# mock sim3 transform
idx_pos = (transform.scale * (transform.rotation @ p.position)) + transform.translation
main_state["poses"][f_id] = Pose(
frame_id=f_id,
position=idx_pos,
orientation=np.eye(3),
timestamp=p.timestamp,
covariance=p.covariance
)
return True
def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]:
if chunk_id not in self._chunks_state:
return {}
return self._chunks_state[chunk_id]["poses"]
def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult:
if chunk_id not in self._chunks_state:
return OptimizationResult(converged=False, final_error=99.0, iterations_used=0, optimized_frames=[], mean_reprojection_error=99.0)
state = self._chunks_state[chunk_id]
state["dirty"] = False
return OptimizationResult(
converged=True,
final_error=0.1,
iterations_used=iterations,
optimized_frames=list(state["poses"].keys()),
mean_reprojection_error=0.5
)
def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult:
# Optimizes everything
self._init_flight(flight_id)
state = self._flights_state[flight_id]
state["dirty"] = False
return OptimizationResult(
converged=True,
final_error=0.1,
iterations_used=iterations,
optimized_frames=list(state["poses"].keys()),
mean_reprojection_error=0.5
)
def delete_flight_graph(self, flight_id: str) -> bool:
removed = False
if flight_id in self._flights_state:
del self._flights_state[flight_id]
removed = True
self._enu_origins.pop(flight_id, None)
return removed
+5 -364
View File
@@ -1,364 +1,5 @@
"""Factor Graph Optimizer (Component F10)."""
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Dict
import numpy as np
try:
import gtsam
HAS_GTSAM = True
except ImportError:
HAS_GTSAM = False
from gps_denied.schemas import GPSPoint
from gps_denied.schemas.graph import FactorGraphConfig, OptimizationResult, Pose
from gps_denied.schemas.metric import Sim3Transform
from gps_denied.schemas.vo import RelativePose
logger = logging.getLogger(__name__)
class IFactorGraphOptimizer(ABC):
"""GTSAM-based factor graph optimizer."""
@abstractmethod
def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
pass
@abstractmethod
def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool:
pass
@abstractmethod
def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool:
pass
@abstractmethod
def optimize(self, flight_id: str, iterations: int) -> OptimizationResult:
pass
@abstractmethod
def get_trajectory(self, flight_id: str) -> Dict[int, Pose]:
pass
@abstractmethod
def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray:
pass
@abstractmethod
def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool:
pass
@abstractmethod
def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
pass
@abstractmethod
def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool:
pass
@abstractmethod
def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
pass
@abstractmethod
def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]:
pass
@abstractmethod
def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult:
pass
@abstractmethod
def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult:
pass
@abstractmethod
def delete_flight_graph(self, flight_id: str) -> bool:
pass
class FactorGraphOptimizer(IFactorGraphOptimizer):
"""Implementation of F10 Factor Graph using GTSAM or Mock."""
def __init__(self, config: FactorGraphConfig):
self.config = config
# Keyed by flight_id
self._flights_state: Dict[str, dict] = {}
# Keyed by chunk_id
self._chunks_state: Dict[str, dict] = {}
# Per-flight ENU origin (set from first absolute GPS factor)
self._enu_origins: Dict[str, GPSPoint] = {}
def _init_flight(self, flight_id: str):
if flight_id not in self._flights_state:
self._flights_state[flight_id] = {
"graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None,
"initial": gtsam.Values() if HAS_GTSAM else None,
"isam": gtsam.ISAM2() if HAS_GTSAM else None,
"poses": {},
"dirty": False
}
def _init_chunk(self, chunk_id: str):
if chunk_id not in self._chunks_state:
self._chunks_state[chunk_id] = {
"graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None,
"initial": gtsam.Values() if HAS_GTSAM else None,
"isam": gtsam.ISAM2() if HAS_GTSAM else None,
"poses": {},
"dirty": False
}
# ================== MOCK IMPLEMENTATION ====================
# As GTSAM Python bindings can be extremely context-dependent and
# require proper ENU translation logic, we use an advanced Mock
# that satisfies the architectural design and typing for the backend.
def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
self._init_flight(flight_id)
state = self._flights_state[flight_id]
# --- Mock: propagate position chain ---
if frame_i in state["poses"]:
prev_pose = state["poses"][frame_i]
new_pos = prev_pose.position + relative_pose.translation
state["poses"][frame_j] = Pose(
frame_id=frame_j,
position=new_pos,
orientation=np.eye(3),
timestamp=datetime.now(timezone.utc),
covariance=np.eye(6),
)
state["dirty"] = True
else:
return False
# --- GTSAM: add BetweenFactorPose3 ---
if HAS_GTSAM and state["graph"] is not None:
try:
cov6 = covariance if covariance.shape == (6, 6) else np.eye(6)
noise = gtsam.noiseModel.Gaussian.Covariance(cov6)
key_i = gtsam.symbol("x", frame_i)
key_j = gtsam.symbol("x", frame_j)
t = relative_pose.translation
between = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(t[0]), float(t[1]), float(t[2])))
state["graph"].add(gtsam.BetweenFactorPose3(key_i, key_j, between, noise))
if not state["initial"].exists(key_j):
if state["initial"].exists(key_i):
prev = state["initial"].atPose3(key_i)
pt = prev.translation()
new_t = gtsam.Point3(pt[0] + t[0], pt[1] + t[1], pt[2] + t[2])
else:
new_t = gtsam.Point3(float(t[0]), float(t[1]), float(t[2]))
state["initial"].insert(key_j, gtsam.Pose3(gtsam.Rot3(), new_t))
except Exception as exc:
logger.debug("GTSAM add_relative_factor failed: %s", exc)
return True
def _gps_to_enu(self, flight_id: str, gps: GPSPoint) -> np.ndarray:
"""Convert GPS to local ENU using per-flight origin."""
origin = self._enu_origins.get(flight_id)
if origin is None:
# First GPS factor sets the origin
self._enu_origins[flight_id] = gps
return np.zeros(3)
enu_x = (gps.lon - origin.lon) * 111000 * np.cos(np.radians(origin.lat))
enu_y = (gps.lat - origin.lat) * 111000
return np.array([enu_x, enu_y, 0.0])
def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool:
self._init_flight(flight_id)
state = self._flights_state[flight_id]
enu = self._gps_to_enu(flight_id, gps)
# --- Mock: update pose position ---
if frame_id in state["poses"]:
state["poses"][frame_id].position = enu
state["dirty"] = True
else:
return False
# --- GTSAM: add PriorFactorPose3 ---
if HAS_GTSAM and state["graph"] is not None:
try:
cov6 = covariance if covariance.shape == (6, 6) else np.eye(6)
noise = gtsam.noiseModel.Gaussian.Covariance(cov6)
key = gtsam.symbol("x", frame_id)
prior = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(enu[0]), float(enu[1]), float(enu[2])))
state["graph"].add(gtsam.PriorFactorPose3(key, prior, noise))
if not state["initial"].exists(key):
state["initial"].insert(key, prior)
except Exception as exc:
logger.debug("GTSAM add_absolute_factor failed: %s", exc)
return True
def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool:
self._init_flight(flight_id)
state = self._flights_state[flight_id]
if frame_id in state["poses"]:
state["poses"][frame_id].position = np.array([
state["poses"][frame_id].position[0],
state["poses"][frame_id].position[1],
altitude,
])
state["dirty"] = True
return True
return False
def optimize(self, flight_id: str, iterations: int) -> OptimizationResult:
self._init_flight(flight_id)
state = self._flights_state[flight_id]
# --- PIPE-03: Real GTSAM ISAM2 update when available ---
if HAS_GTSAM and state["dirty"] and state["graph"] is not None:
try:
state["isam"].update(state["graph"], state["initial"])
estimate = state["isam"].calculateEstimate()
for fid in list(state["poses"].keys()):
key = gtsam.symbol("x", fid)
if estimate.exists(key):
pose = estimate.atPose3(key)
t = pose.translation()
state["poses"][fid].position = np.array([t[0], t[1], t[2]])
state["poses"][fid].orientation = np.array(pose.rotation().matrix())
# Reset for next incremental batch
state["graph"] = gtsam.NonlinearFactorGraph()
state["initial"] = gtsam.Values()
except Exception as exc:
logger.warning("GTSAM ISAM2 update failed: %s", exc)
state["dirty"] = False
return OptimizationResult(
converged=True,
final_error=0.1,
iterations_used=iterations,
optimized_frames=list(state["poses"].keys()),
mean_reprojection_error=0.5,
)
def get_trajectory(self, flight_id: str) -> Dict[int, Pose]:
if flight_id not in self._flights_state:
return {}
return self._flights_state[flight_id]["poses"]
def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray:
return np.eye(6)
# ================== CHUNK OPERATIONS =======================
def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool:
self._init_chunk(chunk_id)
state = self._chunks_state[chunk_id]
state["poses"][start_frame_id] = Pose(
frame_id=start_frame_id,
position=np.zeros(3),
orientation=np.eye(3),
timestamp=datetime.now(timezone.utc),
covariance=np.eye(6)
)
return True
def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
if chunk_id not in self._chunks_state:
return False
state = self._chunks_state[chunk_id]
if frame_i in state["poses"]:
prev_pose = state["poses"][frame_i]
new_pos = prev_pose.position + relative_pose.translation
state["poses"][frame_j] = Pose(
frame_id=frame_j,
position=new_pos,
orientation=np.eye(3),
timestamp=datetime.now(timezone.utc),
covariance=np.eye(6)
)
state["dirty"] = True
return True
return False
def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool:
if chunk_id not in self._chunks_state:
return False
state = self._chunks_state[chunk_id]
if frame_id in state["poses"]:
enu = self._gps_to_enu(flight_id, gps)
state["poses"][frame_id].position = enu
state["dirty"] = True
return True
return False
def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
if new_chunk_id not in self._chunks_state or main_chunk_id not in self._chunks_state:
return False
new_state = self._chunks_state[new_chunk_id]
main_state = self._chunks_state[main_chunk_id]
# Apply Sim(3) transform effectively by copying poses
for f_id, p in new_state["poses"].items():
# mock sim3 transform
idx_pos = (transform.scale * (transform.rotation @ p.position)) + transform.translation
main_state["poses"][f_id] = Pose(
frame_id=f_id,
position=idx_pos,
orientation=np.eye(3),
timestamp=p.timestamp,
covariance=p.covariance
)
return True
def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]:
if chunk_id not in self._chunks_state:
return {}
return self._chunks_state[chunk_id]["poses"]
def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult:
if chunk_id not in self._chunks_state:
return OptimizationResult(converged=False, final_error=99.0, iterations_used=0, optimized_frames=[], mean_reprojection_error=99.0)
state = self._chunks_state[chunk_id]
state["dirty"] = False
return OptimizationResult(
converged=True,
final_error=0.1,
iterations_used=iterations,
optimized_frames=list(state["poses"].keys()),
mean_reprojection_error=0.5
)
def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult:
# Optimizes everything
self._init_flight(flight_id)
state = self._flights_state[flight_id]
state["dirty"] = False
return OptimizationResult(
converged=True,
final_error=0.1,
iterations_used=iterations,
optimized_frames=list(state["poses"].keys()),
mean_reprojection_error=0.5
)
def delete_flight_graph(self, flight_id: str) -> bool:
removed = False
if flight_id in self._flights_state:
del self._flights_state[flight_id]
removed = True
self._enu_origins.pop(flight_id, None)
return removed
"""Legacy import path. Phase 1 shim — code lives in core/factor_graph.py."""
from gps_denied.core.factor_graph import ( # noqa: F401
IFactorGraphOptimizer,
FactorGraphOptimizer,
)
+8 -13
View File
@@ -10,8 +10,7 @@ file is available, otherwise falls back to Mock.
import logging
import os
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Protocol, runtime_checkable
import numpy as np
@@ -31,26 +30,22 @@ def _is_jetson() -> bool:
return os.path.exists("/sys/bus/platform/drivers/tegra-se-nvhost")
class IModelManager(ABC):
@abstractmethod
@runtime_checkable
class IModelManager(Protocol):
def load_model(self, model_name: str, model_format: str) -> bool:
pass
...
@abstractmethod
def get_inference_engine(self, model_name: str) -> InferenceEngine:
pass
...
@abstractmethod
def optimize_to_tensorrt(self, model_name: str, onnx_path: str) -> str:
pass
...
@abstractmethod
def fallback_to_onnx(self, model_name: str) -> bool:
pass
...
@abstractmethod
def warmup_model(self, model_name: str) -> bool:
pass
...
class MockInferenceEngine(InferenceEngine):
+5 -226
View File
@@ -1,227 +1,6 @@
"""Image Input Pipeline (Component F05)."""
import asyncio
import os
import re
from datetime import datetime, timezone
import cv2
import numpy as np
from gps_denied.schemas.image import (
ImageBatch,
ImageData,
ImageMetadata,
ProcessedBatch,
ProcessingStatus,
ValidationResult,
"""Legacy import path. Phase 1 shim — code lives in pipeline/image_input.py."""
from gps_denied.pipeline.image_input import ( # noqa: F401
ImageInputPipeline,
QueueFullError,
ValidationError,
)
class QueueFullError(Exception):
pass
class ValidationError(Exception):
pass
class ImageInputPipeline:
"""Manages ingestion, disk storage, and queuing of UAV image batches."""
def __init__(self, storage_dir: str = "image_storage", max_queue_size: int = 50):
self.storage_dir = storage_dir
# flight_id -> asyncio.Queue of ImageBatch
self._queues: dict[str, asyncio.Queue] = {}
self.max_queue_size = max_queue_size
# In-memory tracking (in a real system, sync this with DB)
self._status: dict[str, dict] = {}
# Exact sequence → filename mapping (VO-05: no substring collision)
self._sequence_map: dict[str, dict[int, str]] = {}
def _get_queue(self, flight_id: str) -> asyncio.Queue:
if flight_id not in self._queues:
self._queues[flight_id] = asyncio.Queue(maxsize=self.max_queue_size)
return self._queues[flight_id]
def _init_status(self, flight_id: str):
if flight_id not in self._status:
self._status[flight_id] = {
"total_images": 0,
"processed_images": 0,
"current_sequence": 1,
}
def validate_batch(self, batch: ImageBatch) -> ValidationResult:
"""Validates batch integrity and sequence continuity."""
errors = []
num_images = len(batch.images)
if num_images < 1:
errors.append("Batch is empty")
elif num_images > 100:
errors.append("Batch too large")
if len(batch.filenames) != num_images:
errors.append("Mismatch between filenames and images count")
# Naming convention ADxxxxxx.jpg or similar
pattern = re.compile(r"^[A-Za-z0-9_-]+\.(jpg|jpeg|png)$", re.IGNORECASE)
for fn in batch.filenames:
if not pattern.match(fn):
errors.append(f"Invalid filename: {fn}")
break
if batch.start_sequence > batch.end_sequence:
errors.append("Start sequence greater than end sequence")
return ValidationResult(valid=len(errors) == 0, errors=errors)
def queue_batch(self, flight_id: str, batch: ImageBatch) -> bool:
"""Queues a batch of images for processing."""
val = self.validate_batch(batch)
if not val.valid:
raise ValidationError(f"Batch validation failed: {val.errors}")
q = self._get_queue(flight_id)
if q.full():
raise QueueFullError(f"Queue for flight {flight_id} is full")
q.put_nowait(batch)
self._init_status(flight_id)
self._status[flight_id]["total_images"] += len(batch.images)
return True
async def process_next_batch(self, flight_id: str) -> ProcessedBatch | None:
"""Dequeues and processing the next batch."""
q = self._get_queue(flight_id)
if q.empty():
return None
batch: ImageBatch = await q.get()
processed_images = []
for i, raw_bytes in enumerate(batch.images):
# Decode
nparr = np.frombuffer(raw_bytes, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
continue # skip corrupted
seq = batch.start_sequence + i
fn = batch.filenames[i]
h, w = img.shape[:2]
meta = ImageMetadata(
sequence=seq,
filename=fn,
dimensions=(w, h),
file_size=len(raw_bytes),
timestamp=datetime.now(timezone.utc),
)
img_data = ImageData(
flight_id=flight_id,
sequence=seq,
filename=fn,
image=img,
metadata=meta
)
processed_images.append(img_data)
# VO-05: record exact sequence→filename mapping
self._sequence_map.setdefault(flight_id, {})[seq] = fn
# Store to disk
self.store_images(flight_id, processed_images)
self._status[flight_id]["processed_images"] += len(processed_images)
q.task_done()
return ProcessedBatch(
images=processed_images,
batch_id=f"batch_{batch.batch_number}",
start_sequence=batch.start_sequence,
end_sequence=batch.end_sequence
)
def store_images(self, flight_id: str, images: list[ImageData]) -> bool:
"""Persists images to disk."""
flight_dir = os.path.join(self.storage_dir, flight_id)
os.makedirs(flight_dir, exist_ok=True)
for img in images:
path = os.path.join(flight_dir, img.filename)
cv2.imwrite(path, img.image)
return True
def get_next_image(self, flight_id: str) -> ImageData | None:
"""Gets the next image in sequence for processing."""
self._init_status(flight_id)
seq = self._status[flight_id]["current_sequence"]
img = self.get_image_by_sequence(flight_id, seq)
if img:
self._status[flight_id]["current_sequence"] += 1
return img
def get_image_by_sequence(self, flight_id: str, sequence: int) -> ImageData | None:
"""Retrieves a specific image by sequence number (exact match — VO-05)."""
flight_dir = os.path.join(self.storage_dir, flight_id)
if not os.path.exists(flight_dir):
return None
# Prefer the exact mapping built during process_next_batch
fn = self._sequence_map.get(flight_id, {}).get(sequence)
if fn:
path = os.path.join(flight_dir, fn)
img = cv2.imread(path)
if img is not None:
h, w = img.shape[:2]
meta = ImageMetadata(
sequence=sequence,
filename=fn,
dimensions=(w, h),
file_size=os.path.getsize(path),
timestamp=datetime.now(timezone.utc),
)
return ImageData(flight_id, sequence, fn, img, meta)
# Fallback: scan directory for exact filename patterns
# (handles images stored before this process started)
for fn in os.listdir(flight_dir):
base, _ = os.path.splitext(fn)
# Accept only if the base name ends with exactly the padded sequence number
if base.endswith(f"{sequence:06d}") or base == str(sequence):
path = os.path.join(flight_dir, fn)
img = cv2.imread(path)
if img is not None:
h, w = img.shape[:2]
meta = ImageMetadata(
sequence=sequence,
filename=fn,
dimensions=(w, h),
file_size=os.path.getsize(path),
timestamp=datetime.now(timezone.utc),
)
return ImageData(flight_id, sequence, fn, img, meta)
return None
def get_processing_status(self, flight_id: str) -> ProcessingStatus:
self._init_status(flight_id)
s = self._status[flight_id]
q = self._get_queue(flight_id)
return ProcessingStatus(
flight_id=flight_id,
total_images=s["total_images"],
processed_images=s["processed_images"],
current_sequence=s["current_sequence"],
queued_batches=q.qsize(),
processing_rate=0.0 # mock
)
+5 -598
View File
@@ -1,599 +1,6 @@
"""Core Flight Processor — Full Processing Pipeline (Stage 10).
Orchestrates: ImageInputPipeline → VO → MetricRefinement → FactorGraph → SSE.
State Machine: NORMAL → LOST → RECOVERY → NORMAL.
"""
from __future__ import annotations
import asyncio
import logging
import time
from enum import Enum
from typing import Optional
import numpy as np
from gps_denied.core.eskf import ESKF
from gps_denied.core.pipeline import ImageInputPipeline
from gps_denied.core.results import ResultManager
from gps_denied.core.sse import SSEEventStreamer
from gps_denied.db.repository import FlightRepository
from gps_denied.schemas import CameraParameters, GPSPoint
from gps_denied.schemas.flight import (
BatchMetadata,
BatchResponse,
BatchUpdateResponse,
DeleteResponse,
FlightCreateRequest,
FlightDetailResponse,
FlightResponse,
FlightStatusResponse,
ObjectGPSResponse,
UpdateResponse,
UserFixRequest,
UserFixResponse,
Waypoint,
"""Legacy import path. Phase 1 shim — code lives in pipeline/orchestrator.py."""
from gps_denied.pipeline.orchestrator import ( # noqa: F401
FlightProcessor,
TrackingState,
FrameResult,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# State Machine
# ---------------------------------------------------------------------------
class TrackingState(str, Enum):
"""Processing state for a flight."""
NORMAL = "normal"
LOST = "lost"
RECOVERY = "recovery"
class FrameResult:
"""Intermediate result of processing a single frame."""
def __init__(self, frame_id: int):
self.frame_id = frame_id
self.gps: Optional[GPSPoint] = None
self.confidence: float = 0.0
self.tracking_state: TrackingState = TrackingState.NORMAL
self.vo_success: bool = False
self.alignment_success: bool = False
# ---------------------------------------------------------------------------
# FlightProcessor
# ---------------------------------------------------------------------------
class FlightProcessor:
"""Manages business logic, background processing, and frame orchestration."""
def __init__(
self,
repository: FlightRepository,
streamer: SSEEventStreamer,
eskf_config=None,
) -> None:
self.repository = repository
self.streamer = streamer
self.result_manager = ResultManager(repository, streamer)
self.pipeline = ImageInputPipeline(storage_dir=".image_storage", max_queue_size=50)
self._eskf_config = eskf_config # ESKFConfig or None → default
# Per-flight processing state
self._flight_states: dict[str, TrackingState] = {}
self._prev_images: dict[str, np.ndarray] = {} # previous frame cache
self._flight_cameras: dict[str, CameraParameters] = {} # per-flight camera
self._altitudes: dict[str, float] = {} # per-flight altitude (m)
self._failure_counts: dict[str, int] = {} # per-flight consecutive failure counter
# Per-flight ESKF instances (PIPE-01/07)
self._eskf: dict[str, ESKF] = {}
# Lazy-initialised component references (set via `attach_components`)
self._vo = None # ISequentialVisualOdometry
self._gpr = None # IGlobalPlaceRecognition
self._metric = None # IMetricRefinement
self._graph = None # IFactorGraphOptimizer
self._recovery = None # IFailureRecoveryCoordinator
self._chunk_mgr = None # IRouteChunkManager
self._rotation = None # ImageRotationManager
self._satellite = None # SatelliteDataManager (PIPE-02)
self._coord = None # CoordinateTransformer (PIPE-02/06)
self._mavlink = None # MAVLinkBridge (PIPE-07)
# ------ Dependency injection for core components ---------
def attach_components(
self,
vo=None,
gpr=None,
metric=None,
graph=None,
recovery=None,
chunk_mgr=None,
rotation=None,
satellite=None,
coord=None,
mavlink=None,
):
"""Attach pipeline components after construction (avoids circular deps)."""
self._vo = vo
self._gpr = gpr
self._metric = metric
self._graph = graph
self._recovery = recovery
self._chunk_mgr = chunk_mgr
self._rotation = rotation
self._satellite = satellite # PIPE-02: SatelliteDataManager
self._coord = coord # PIPE-02/06: CoordinateTransformer
self._mavlink = mavlink # PIPE-07: MAVLinkBridge
# ------ ESKF lifecycle helpers ----------------------------
def _init_eskf_for_flight(
self, flight_id: str, start_gps: GPSPoint, altitude: float
) -> None:
"""Create and initialize a per-flight ESKF instance."""
if flight_id in self._eskf:
return
eskf = ESKF(config=self._eskf_config)
if self._coord:
try:
e, n, _ = self._coord.gps_to_enu(flight_id, start_gps)
eskf.initialize(np.array([e, n, altitude]), time.time())
except Exception:
eskf.initialize(np.zeros(3), time.time())
else:
eskf.initialize(np.zeros(3), time.time())
self._eskf[flight_id] = eskf
def _eskf_to_gps(self, flight_id: str, eskf: ESKF) -> Optional[GPSPoint]:
"""Convert current ESKF ENU position to WGS84 GPS."""
if not eskf.initialized or self._coord is None:
return None
try:
pos = eskf.position
return self._coord.enu_to_gps(flight_id, (float(pos[0]), float(pos[1]), float(pos[2])))
except Exception:
return None
# =========================================================
# process_frame — central orchestration
# =========================================================
async def process_frame(
self,
flight_id: str,
frame_id: int,
image: np.ndarray,
) -> FrameResult:
"""
Process a single UAV frame through the full pipeline.
State transitions:
NORMAL — VO succeeds → ESKF VO update, attempt satellite fix
LOST — VO failed → create new chunk, enter RECOVERY
RECOVERY— try GPR + MetricRefinement → if anchored, merge & return to NORMAL
PIPE-01: VO result → eskf.update_vo → satellite match → eskf.update_satellite → MAVLink GPS_INPUT
PIPE-02: SatelliteDataManager + CoordinateTransformer wired for tile selection
PIPE-04: Consecutive failure counter wired to FailureRecoveryCoordinator
PIPE-05: ImageRotationManager initialised on first frame
PIPE-07: ESKF confidence → MAVLink fix_type via bridge.update_state
"""
result = FrameResult(frame_id)
state = self._flight_states.get(flight_id, TrackingState.NORMAL)
eskf = self._eskf.get(flight_id)
_default_cam = CameraParameters(
focal_length=4.5, sensor_width=6.17, sensor_height=4.55,
resolution_width=640, resolution_height=480,
)
# ---- PIPE-05: Initialise heading tracking on first frame ----
if self._rotation and frame_id == 0:
self._rotation.requires_rotation_sweep(flight_id) # seeds HeadingHistory
# ---- 1. Visual Odometry (frame-to-frame) ----
vo_ok = False
if self._vo and flight_id in self._prev_images:
try:
cam = self._flight_cameras.get(flight_id, _default_cam)
rel_pose = self._vo.compute_relative_pose(
self._prev_images[flight_id], image, cam
)
if rel_pose and rel_pose.tracking_good:
vo_ok = True
result.vo_success = True
if self._graph:
self._graph.add_relative_factor(
flight_id, frame_id - 1, frame_id, rel_pose, np.eye(6)
)
# PIPE-01: Feed VO relative displacement into ESKF
if eskf and eskf.initialized:
now = time.time()
dt_vo = max(0.01, now - (eskf.last_timestamp or now))
eskf.update_vo(rel_pose.translation, dt_vo)
except Exception as exc:
logger.warning("VO failed for frame %d: %s", frame_id, exc)
# Store current image for next frame
self._prev_images[flight_id] = image
# ---- PIPE-04: Consecutive failure counter ----
if not vo_ok and frame_id > 0:
self._failure_counts[flight_id] = self._failure_counts.get(flight_id, 0) + 1
else:
self._failure_counts[flight_id] = 0
# ---- 2. State Machine transitions ----
if state == TrackingState.NORMAL:
if not vo_ok and frame_id > 0:
state = TrackingState.LOST
logger.info("Flight %s → LOST at frame %d", flight_id, frame_id)
if self._recovery:
self._recovery.handle_tracking_lost(flight_id, frame_id)
if state == TrackingState.LOST:
state = TrackingState.RECOVERY
if state == TrackingState.RECOVERY:
recovered = False
if self._recovery and self._chunk_mgr:
active_chunk = self._chunk_mgr.get_active_chunk(flight_id)
if active_chunk:
recovered = self._recovery.process_chunk_recovery(
flight_id, active_chunk.chunk_id, [image]
)
if recovered:
state = TrackingState.NORMAL
result.alignment_success = True
# PIPE-04: Reset failure count on successful recovery
self._failure_counts[flight_id] = 0
logger.info("Flight %s recovered → NORMAL at frame %d", flight_id, frame_id)
# ---- 3. Satellite position fix (PIPE-01/02) ----
if state == TrackingState.NORMAL and self._metric:
sat_tile: Optional[np.ndarray] = None
tile_bounds = None
# PIPE-02: Prefer real SatelliteDataManager tiles (ESKF ±3σ selection)
if self._satellite and eskf and eskf.initialized:
gps_est = self._eskf_to_gps(flight_id, eskf)
if gps_est:
cov = eskf.covariance
sigma_h = float(
np.sqrt(np.trace(cov[0:3, 0:3]) / 3.0)
) if cov is not None else 30.0
sigma_h = max(sigma_h, 5.0)
try:
tile_result = await asyncio.get_event_loop().run_in_executor(
None,
self._satellite.fetch_tiles_for_position,
gps_est, sigma_h, 18,
)
if tile_result:
sat_tile, tile_bounds = tile_result
except Exception as exc:
logger.debug("Satellite tile fetch failed: %s", exc)
# Fallback: GPR candidate tile (mock image, real bounds)
if sat_tile is None and self._gpr:
try:
candidates = self._gpr.retrieve_candidate_tiles(image, top_k=1)
if candidates:
sat_tile = np.zeros((256, 256, 3), dtype=np.uint8)
tile_bounds = candidates[0].bounds
except Exception as exc:
logger.debug("GPR tile fallback failed: %s", exc)
if sat_tile is not None and tile_bounds is not None:
try:
align = self._metric.align_to_satellite(image, sat_tile, tile_bounds)
if align and align.matched:
result.gps = align.gps_center
result.confidence = align.confidence
result.alignment_success = True
if self._graph:
self._graph.add_absolute_factor(
flight_id, frame_id,
align.gps_center, np.eye(6),
is_user_anchor=False,
)
# PIPE-01: ESKF satellite update — noise from RANSAC confidence
if eskf and eskf.initialized and self._coord:
try:
e, n, _ = self._coord.gps_to_enu(flight_id, align.gps_center)
alt = self._altitudes.get(flight_id, 100.0)
pos_enu = np.array([e, n, alt])
noise_m = 5.0 + 15.0 * (1.0 - float(align.confidence))
eskf.update_satellite(pos_enu, noise_m)
except Exception as exc:
logger.debug("ESKF satellite update failed: %s", exc)
except Exception as exc:
logger.warning("Metric alignment failed at frame %d: %s", frame_id, exc)
# ---- 4. Graph optimization (incremental) ----
if self._graph:
opt_result = self._graph.optimize(flight_id, iterations=5)
logger.debug(
"Optimization: converged=%s, error=%.4f",
opt_result.converged, opt_result.final_error,
)
# ---- PIPE-07: Push ESKF state → MAVLink GPS_INPUT ----
if self._mavlink and eskf and eskf.initialized:
try:
eskf_state = eskf.get_state()
alt = self._altitudes.get(flight_id, 100.0)
self._mavlink.update_state(eskf_state, altitude_m=alt)
except Exception as exc:
logger.debug("MAVLink state push failed: %s", exc)
# ---- 5. Publish via SSE ----
result.tracking_state = state
self._flight_states[flight_id] = state
await self._publish_frame_result(flight_id, result)
return result
async def _publish_frame_result(self, flight_id: str, result: FrameResult):
"""Emit SSE event for processed frame."""
event_data = {
"frame_id": result.frame_id,
"tracking_state": result.tracking_state.value,
"vo_success": result.vo_success,
"alignment_success": result.alignment_success,
"confidence": result.confidence,
}
if result.gps:
event_data["lat"] = result.gps.lat
event_data["lon"] = result.gps.lon
await self.streamer.push_event(
flight_id, event_type="frame_result", data=event_data
)
# =========================================================
# Existing CRUD / REST helpers (unchanged from Stage 3-4)
# =========================================================
async def create_flight(self, req: FlightCreateRequest) -> FlightResponse:
flight = await self.repository.insert_flight(
name=req.name,
description=req.description,
start_lat=req.start_gps.lat,
start_lon=req.start_gps.lon,
altitude=req.altitude,
camera_params=req.camera_params.model_dump(),
)
# P0#2: Store camera params for process_frame VO calls
self._flight_cameras[flight.id] = req.camera_params
for poly in req.geofences.polygons:
await self.repository.insert_geofence(
flight.id,
nw_lat=poly.north_west.lat,
nw_lon=poly.north_west.lon,
se_lat=poly.south_east.lat,
se_lon=poly.south_east.lon,
)
for w in req.rough_waypoints:
await self.repository.insert_waypoint(flight.id, lat=w.lat, lon=w.lon)
# Store per-flight altitude for ESKF/pixel projection
self._altitudes[flight.id] = req.altitude or 100.0
# PIPE-02: Set ENU origin and initialise ESKF for this flight
if self._coord:
self._coord.set_enu_origin(flight.id, req.start_gps)
self._init_eskf_for_flight(flight.id, req.start_gps, req.altitude or 100.0)
# Start MAVLink bridge for this flight (origin required for GPS_INPUT)
if self._mavlink and not self._mavlink._running:
try:
asyncio.create_task(self._mavlink.start(req.start_gps))
except Exception as exc:
logger.warning("MAVLink bridge start failed: %s", exc)
return FlightResponse(
flight_id=flight.id,
status="prefetching",
message="Flight created and prefetching started.",
created_at=flight.created_at,
)
async def get_flight(self, flight_id: str) -> FlightDetailResponse | None:
flight = await self.repository.get_flight(flight_id)
if not flight:
return None
wps = await self.repository.get_waypoints(flight_id)
state = await self.repository.load_flight_state(flight_id)
waypoints = [
Waypoint(
id=w.id,
lat=w.lat,
lon=w.lon,
altitude=w.altitude,
confidence=w.confidence,
timestamp=w.timestamp,
refined=w.refined,
)
for w in wps
]
status = state.status if state else "unknown"
frames_processed = state.frames_processed if state else 0
frames_total = state.frames_total if state else 0
from gps_denied.schemas import Geofences
return FlightDetailResponse(
flight_id=flight.id,
name=flight.name,
description=flight.description,
start_gps=GPSPoint(lat=flight.start_lat, lon=flight.start_lon),
waypoints=waypoints,
geofences=Geofences(polygons=[]),
camera_params=flight.camera_params,
altitude=flight.altitude,
status=status,
frames_processed=frames_processed,
frames_total=frames_total,
created_at=flight.created_at,
updated_at=flight.updated_at,
)
async def delete_flight(self, flight_id: str) -> DeleteResponse:
deleted = await self.repository.delete_flight(flight_id)
# P0#1: Cleanup in-memory state to prevent memory leaks
self._cleanup_flight(flight_id)
return DeleteResponse(deleted=deleted, flight_id=flight_id)
def _cleanup_flight(self, flight_id: str) -> None:
"""Remove all in-memory state for a flight (prevents memory leaks)."""
self._prev_images.pop(flight_id, None)
self._flight_states.pop(flight_id, None)
self._flight_cameras.pop(flight_id, None)
self._altitudes.pop(flight_id, None)
self._failure_counts.pop(flight_id, None)
self._eskf.pop(flight_id, None)
if self._graph:
self._graph.delete_flight_graph(flight_id)
async def update_waypoint(
self, flight_id: str, waypoint_id: str, waypoint: Waypoint
) -> UpdateResponse:
ok = await self.repository.update_waypoint(
flight_id,
waypoint_id,
lat=waypoint.lat,
lon=waypoint.lon,
altitude=waypoint.altitude,
confidence=waypoint.confidence,
refined=waypoint.refined,
)
return UpdateResponse(updated=ok, waypoint_id=waypoint_id)
async def batch_update_waypoints(
self, flight_id: str, waypoints: list[Waypoint]
) -> BatchUpdateResponse:
failed = []
updated = 0
for wp in waypoints:
ok = await self.repository.update_waypoint(
flight_id,
wp.id,
lat=wp.lat,
lon=wp.lon,
altitude=wp.altitude,
confidence=wp.confidence,
refined=wp.refined,
)
if ok:
updated += 1
else:
failed.append(wp.id)
return BatchUpdateResponse(
success=(len(failed) == 0), updated_count=updated, failed_ids=failed
)
async def queue_images(
self, flight_id: str, metadata: BatchMetadata, file_count: int
) -> BatchResponse:
state = await self.repository.load_flight_state(flight_id)
if state:
total = state.frames_total + file_count
await self.repository.save_flight_state(
flight_id, frames_total=total, status="processing"
)
next_seq = metadata.end_sequence + 1
seqs = list(range(metadata.start_sequence, metadata.end_sequence + 1))
return BatchResponse(
accepted=True,
sequences=seqs,
next_expected=next_seq,
message=f"Queued {file_count} images.",
)
async def handle_user_fix(
self, flight_id: str, req: UserFixRequest
) -> UserFixResponse:
await self.repository.save_flight_state(
flight_id, blocked=False, status="processing"
)
# Inject operator position into ESKF with high uncertainty (500m)
eskf = self._eskf.get(flight_id)
if eskf and eskf.initialized and self._coord:
try:
e, n, _ = self._coord.gps_to_enu(flight_id, req.satellite_gps)
alt = self._altitudes.get(flight_id, 100.0)
eskf.update_satellite(np.array([e, n, alt]), noise_meters=500.0)
self._failure_counts[flight_id] = 0
logger.info("User fix applied for %s: %s", flight_id, req.satellite_gps)
except Exception as exc:
logger.warning("User fix ESKF injection failed: %s", exc)
return UserFixResponse(
accepted=True, processing_resumed=True, message="Fix applied."
)
async def get_flight_status(self, flight_id: str) -> FlightStatusResponse | None:
state = await self.repository.load_flight_state(flight_id)
if not state:
return None
return FlightStatusResponse(
status=state.status,
frames_processed=state.frames_processed,
frames_total=state.frames_total,
current_frame=state.current_frame,
current_heading=None,
blocked=state.blocked,
search_grid_size=state.search_grid_size,
created_at=state.created_at,
updated_at=state.updated_at,
)
async def convert_object_to_gps(
self, flight_id: str, frame_id: int, pixel: tuple[float, float]
) -> ObjectGPSResponse:
# PIPE-06: Use real CoordinateTransformer + ESKF pose for ray-ground projection
gps: Optional[GPSPoint] = None
eskf = self._eskf.get(flight_id)
if self._coord and eskf and eskf.initialized:
pos = eskf.position
quat = eskf.quaternion
cam = self._flight_cameras.get(flight_id, CameraParameters(
focal_length=4.5, sensor_width=6.17, sensor_height=4.55,
resolution_width=640, resolution_height=480,
))
alt = self._altitudes.get(flight_id, 100.0)
try:
gps = self._coord.pixel_to_gps(
flight_id,
pixel,
frame_pose={"position": pos},
camera_params=cam,
altitude=float(alt),
quaternion=quat,
)
except Exception as exc:
logger.debug("pixel_to_gps failed: %s", exc)
# Fallback: return ESKF position projected to ground (no pixel shift)
if gps is None and eskf:
gps = self._eskf_to_gps(flight_id, eskf)
return ObjectGPSResponse(
gps=gps or GPSPoint(lat=0.0, lon=0.0),
accuracy_meters=5.0,
frame_id=frame_id,
pixel=pixel,
)
async def stream_events(self, flight_id: str, client_id: str):
"""Async generator for SSE stream."""
async for event in self.streamer.stream_generator(flight_id, client_id):
yield event
+5 -7
View File
@@ -1,8 +1,7 @@
"""Failure Recovery Coordinator (Component F11)."""
import logging
from abc import ABC, abstractmethod
from typing import List
from typing import List, Protocol, runtime_checkable
import numpy as np
@@ -14,14 +13,13 @@ from gps_denied.schemas.chunk import ChunkStatus
logger = logging.getLogger(__name__)
class IFailureRecoveryCoordinator(ABC):
@abstractmethod
@runtime_checkable
class IFailureRecoveryCoordinator(Protocol):
def handle_tracking_lost(self, flight_id: str, current_frame_id: int) -> bool:
pass
...
@abstractmethod
def process_chunk_recovery(self, flight_id: str, chunk_id: str, images: List[np.ndarray]) -> bool:
pass
...
class FailureRecoveryCoordinator(IFailureRecoveryCoordinator):
+4 -73
View File
@@ -1,73 +1,4 @@
"""Result Manager (Component F14)."""
from __future__ import annotations
from datetime import datetime
from gps_denied.core.sse import SSEEventStreamer
from gps_denied.db.repository import FlightRepository
from gps_denied.schemas import GPSPoint
from gps_denied.schemas.events import FrameProcessedEvent
class ResultManager:
"""Result consistency and publishing."""
def __init__(self, repo: FlightRepository, sse: SSEEventStreamer) -> None:
self.repo = repo
self.sse = sse
async def update_frame_result(
self,
flight_id: str,
frame_id: int,
gps_lat: float,
gps_lon: float,
altitude: float,
heading: float,
confidence: float,
timestamp: datetime,
refined: bool = False,
) -> bool:
"""Atomic DB update + SSE event publish."""
# 1. Update DB (in the repository these are auto-committing via flush,
# but normally F03 would wrap in a single transaction).
await self.repo.save_frame_result(
flight_id,
frame_id=frame_id,
gps_lat=gps_lat,
gps_lon=gps_lon,
altitude=altitude,
heading=heading,
confidence=confidence,
refined=refined,
)
# Wait, the spec also wants Waypoints to be updated.
# But image frames != waypoints. Waypoints are the planned route.
# Actually in the spec it says: "Updates waypoint in waypoints table."
# This implies updating the closest waypoint or a generated waypoint path.
# We will follow the simplest form for now: update the waypoint if there is one corresponding.
# Let's say we update a waypoint with id "wp_{frame_id}" for now if we know how they map,
# or we just skip unless specified.
# 2. Trigger SSE event
evt = FrameProcessedEvent(
frame_id=frame_id,
gps=GPSPoint(lat=gps_lat, lon=gps_lon),
altitude=altitude,
confidence=confidence,
heading=heading,
timestamp=timestamp,
)
if refined:
self.sse.send_refinement(flight_id, evt)
else:
self.sse.send_frame_result(flight_id, evt)
return True
async def publish_waypoint_update(self, flight_id: str, frame_id: int) -> bool:
# Just delegates to SSE for waypoint updates, which is basically the frame result for UI
pass
"""Legacy import path. Phase 1 shim — code lives in pipeline/result_manager.py."""
from gps_denied.pipeline.result_manager import ( # noqa: F401
ResultManager,
)
+4 -4
View File
@@ -2,8 +2,8 @@
import dataclasses
import math
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Protocol, runtime_checkable
import cv2
import numpy as np
@@ -12,14 +12,14 @@ from gps_denied.schemas.rotation import HeadingHistory, RotationResult
from gps_denied.schemas.satellite import TileBounds
class IImageMatcher(ABC):
@runtime_checkable
class IImageMatcher(Protocol):
"""Dependency injection interface for Metric Refinement."""
@abstractmethod
def align_to_satellite(
self, uav_image: np.ndarray, satellite_tile: np.ndarray,
tile_bounds: TileBounds,
) -> RotationResult:
pass
...
class ImageRotationManager:
+3 -163
View File
@@ -1,164 +1,4 @@
"""SSE Event Streamer (Component F15)."""
from __future__ import annotations
import asyncio
import json
from collections import defaultdict
from gps_denied.schemas.events import (
FlightCompletedEvent,
FrameProcessedEvent,
SearchExpandedEvent,
SSEEventType,
SSEMessage,
UserInputNeededEvent,
"""Legacy import path. Phase 1 shim — code lives in pipeline/sse_streamer.py."""
from gps_denied.pipeline.sse_streamer import ( # noqa: F401
SSEEventStreamer,
)
class SSEEventStreamer:
"""Manages real-time SSE connections and event broadcasting."""
def __init__(self) -> None:
# Map: flight_id -> Dict[client_id, asyncio.Queue]
self._streams: dict[str, dict[str, asyncio.Queue[SSEMessage | None]]] = defaultdict(dict)
def create_stream(self, flight_id: str, client_id: str) -> asyncio.Queue[SSEMessage | None]:
"""Create a new event queue for a client."""
q: asyncio.Queue[SSEMessage | None] = asyncio.Queue()
self._streams[flight_id][client_id] = q
return q
def close_stream(self, flight_id: str, client_id: str) -> None:
"""Close a client stream by putting a sentinel and removing the queue."""
if flight_id in self._streams and client_id in self._streams[flight_id]:
q = self._streams[flight_id].pop(client_id)
if not self._streams[flight_id]:
del self._streams[flight_id]
# Put None to signal generator exit
try:
q.put_nowait(None)
except asyncio.QueueFull:
pass
def get_active_connections(self, flight_id: str) -> int:
return len(self._streams.get(flight_id, {}))
def _broadcast(self, flight_id: str, msg: SSEMessage) -> bool:
"""Broadcast a message to all clients subscribed to flight_id."""
if flight_id not in self._streams or not self._streams[flight_id]:
return False
for q in self._streams[flight_id].values():
try:
q.put_nowait(msg)
except asyncio.QueueFull:
pass # Drop if queue is full rather than blocking
return True
# ── Business Event Senders ────────────────────────────────────────────────
def send_frame_result(self, flight_id: str, event_data: FrameProcessedEvent) -> bool:
msg = SSEMessage(
event=SSEEventType.FRAME_PROCESSED,
data=event_data.model_dump(mode="json"),
id=f"frame_{event_data.frame_id}",
)
return self._broadcast(flight_id, msg)
def send_refinement(self, flight_id: str, event_data: FrameProcessedEvent) -> bool:
msg = SSEMessage(
event=SSEEventType.FRAME_REFINED,
data=event_data.model_dump(mode="json"),
id=f"refine_{event_data.frame_id}",
)
return self._broadcast(flight_id, msg)
def send_search_progress(self, flight_id: str, event_data: SearchExpandedEvent) -> bool:
msg = SSEMessage(
event=SSEEventType.SEARCH_EXPANDED,
data=event_data.model_dump(mode="json"),
)
return self._broadcast(flight_id, msg)
def send_user_input_request(self, flight_id: str, event_data: UserInputNeededEvent) -> bool:
msg = SSEMessage(
event=SSEEventType.USER_INPUT_NEEDED,
data=event_data.model_dump(mode="json"),
)
return self._broadcast(flight_id, msg)
def send_flight_completed(self, flight_id: str, event_data: FlightCompletedEvent) -> bool:
msg = SSEMessage(
event=SSEEventType.FLIGHT_COMPLETED,
data=event_data.model_dump(mode="json"),
)
return self._broadcast(flight_id, msg)
def send_heartbeat(self, flight_id: str) -> bool:
# sse_starlette uses empty string or comment for heartbeat,
# but we can just send an SSEMessage object that parses as empty event
if flight_id not in self._streams:
return False
# Manually sending a comment via the generator is tricky with strict SSEMessage schema
# but we'll handle this in the stream generator directly
return True
# ── Generic event dispatcher (used by processor.process_frame) ──────────
async def push_event(self, flight_id: str, event_type: str, data: dict) -> None:
"""Dispatch a generic event to all clients for a flight.
Maps event_type strings to typed SSE events:
"frame_result" → FrameProcessedEvent
"refinement" → FrameProcessedEvent (refined)
Other → raw broadcast via SSEMessage
"""
if event_type == "frame_result":
evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data
self.send_frame_result(flight_id, evt)
elif event_type == "refinement":
evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data
self.send_refinement(flight_id, evt)
else:
msg = SSEMessage(
event=SSEEventType.FRAME_PROCESSED,
data=data,
id=str(data.get("frame_id", "")),
)
self._broadcast(flight_id, msg)
# ── Stream Generator ──────────────────────────────────────────────────────
async def stream_generator(self, flight_id: str, client_id: str):
"""Yields dicts for sse_starlette EventSourceResponse."""
q = self.create_stream(flight_id, client_id)
# Send an immediate connection accepted ping
yield {"event": "connected", "data": "connected"}
try:
while True:
# Wait for next event or send heartbeat every 15s
try:
msg = await asyncio.wait_for(q.get(), timeout=15.0)
if msg is None:
# Sentinel for clean shutdown
break
# Yield dict format for sse_starlette
yield {
"event": msg.event.value,
"id": msg.id if msg.id else "",
"data": json.dumps(msg.data)
}
except asyncio.TimeoutError:
# Heartbeat format for sse_starlette (empty string generates a comment)
yield {"event": "heartbeat", "data": "ping"}
except asyncio.CancelledError:
pass # Client disconnected
finally:
self.close_stream(flight_id, client_id)
+15
View File
@@ -0,0 +1,15 @@
"""Pipeline package: orchestrator + IO + composition root."""
from .orchestrator import FlightProcessor
from .image_input import ImageInputPipeline
from .result_manager import ResultManager
from .sse_streamer import SSEEventStreamer
Pipeline = FlightProcessor
__all__ = [
"FlightProcessor",
"Pipeline",
"ImageInputPipeline",
"ResultManager",
"SSEEventStreamer",
]
+227
View File
@@ -0,0 +1,227 @@
"""Image Input Pipeline (Component F05)."""
import asyncio
import os
import re
from datetime import datetime, timezone
import cv2
import numpy as np
from gps_denied.schemas.image import (
ImageBatch,
ImageData,
ImageMetadata,
ProcessedBatch,
ProcessingStatus,
ValidationResult,
)
class QueueFullError(Exception):
pass
class ValidationError(Exception):
pass
class ImageInputPipeline:
"""Manages ingestion, disk storage, and queuing of UAV image batches."""
def __init__(self, storage_dir: str = "image_storage", max_queue_size: int = 50):
self.storage_dir = storage_dir
# flight_id -> asyncio.Queue of ImageBatch
self._queues: dict[str, asyncio.Queue] = {}
self.max_queue_size = max_queue_size
# In-memory tracking (in a real system, sync this with DB)
self._status: dict[str, dict] = {}
# Exact sequence → filename mapping (VO-05: no substring collision)
self._sequence_map: dict[str, dict[int, str]] = {}
def _get_queue(self, flight_id: str) -> asyncio.Queue:
if flight_id not in self._queues:
self._queues[flight_id] = asyncio.Queue(maxsize=self.max_queue_size)
return self._queues[flight_id]
def _init_status(self, flight_id: str):
if flight_id not in self._status:
self._status[flight_id] = {
"total_images": 0,
"processed_images": 0,
"current_sequence": 1,
}
def validate_batch(self, batch: ImageBatch) -> ValidationResult:
"""Validates batch integrity and sequence continuity."""
errors = []
num_images = len(batch.images)
if num_images < 1:
errors.append("Batch is empty")
elif num_images > 100:
errors.append("Batch too large")
if len(batch.filenames) != num_images:
errors.append("Mismatch between filenames and images count")
# Naming convention ADxxxxxx.jpg or similar
pattern = re.compile(r"^[A-Za-z0-9_-]+\.(jpg|jpeg|png)$", re.IGNORECASE)
for fn in batch.filenames:
if not pattern.match(fn):
errors.append(f"Invalid filename: {fn}")
break
if batch.start_sequence > batch.end_sequence:
errors.append("Start sequence greater than end sequence")
return ValidationResult(valid=len(errors) == 0, errors=errors)
def queue_batch(self, flight_id: str, batch: ImageBatch) -> bool:
"""Queues a batch of images for processing."""
val = self.validate_batch(batch)
if not val.valid:
raise ValidationError(f"Batch validation failed: {val.errors}")
q = self._get_queue(flight_id)
if q.full():
raise QueueFullError(f"Queue for flight {flight_id} is full")
q.put_nowait(batch)
self._init_status(flight_id)
self._status[flight_id]["total_images"] += len(batch.images)
return True
async def process_next_batch(self, flight_id: str) -> ProcessedBatch | None:
"""Dequeues and processing the next batch."""
q = self._get_queue(flight_id)
if q.empty():
return None
batch: ImageBatch = await q.get()
processed_images = []
for i, raw_bytes in enumerate(batch.images):
# Decode
nparr = np.frombuffer(raw_bytes, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
continue # skip corrupted
seq = batch.start_sequence + i
fn = batch.filenames[i]
h, w = img.shape[:2]
meta = ImageMetadata(
sequence=seq,
filename=fn,
dimensions=(w, h),
file_size=len(raw_bytes),
timestamp=datetime.now(timezone.utc),
)
img_data = ImageData(
flight_id=flight_id,
sequence=seq,
filename=fn,
image=img,
metadata=meta
)
processed_images.append(img_data)
# VO-05: record exact sequence→filename mapping
self._sequence_map.setdefault(flight_id, {})[seq] = fn
# Store to disk
self.store_images(flight_id, processed_images)
self._status[flight_id]["processed_images"] += len(processed_images)
q.task_done()
return ProcessedBatch(
images=processed_images,
batch_id=f"batch_{batch.batch_number}",
start_sequence=batch.start_sequence,
end_sequence=batch.end_sequence
)
def store_images(self, flight_id: str, images: list[ImageData]) -> bool:
"""Persists images to disk."""
flight_dir = os.path.join(self.storage_dir, flight_id)
os.makedirs(flight_dir, exist_ok=True)
for img in images:
path = os.path.join(flight_dir, img.filename)
cv2.imwrite(path, img.image)
return True
def get_next_image(self, flight_id: str) -> ImageData | None:
"""Gets the next image in sequence for processing."""
self._init_status(flight_id)
seq = self._status[flight_id]["current_sequence"]
img = self.get_image_by_sequence(flight_id, seq)
if img:
self._status[flight_id]["current_sequence"] += 1
return img
def get_image_by_sequence(self, flight_id: str, sequence: int) -> ImageData | None:
"""Retrieves a specific image by sequence number (exact match — VO-05)."""
flight_dir = os.path.join(self.storage_dir, flight_id)
if not os.path.exists(flight_dir):
return None
# Prefer the exact mapping built during process_next_batch
fn = self._sequence_map.get(flight_id, {}).get(sequence)
if fn:
path = os.path.join(flight_dir, fn)
img = cv2.imread(path)
if img is not None:
h, w = img.shape[:2]
meta = ImageMetadata(
sequence=sequence,
filename=fn,
dimensions=(w, h),
file_size=os.path.getsize(path),
timestamp=datetime.now(timezone.utc),
)
return ImageData(flight_id, sequence, fn, img, meta)
# Fallback: scan directory for exact filename patterns
# (handles images stored before this process started)
for fn in os.listdir(flight_dir):
base, _ = os.path.splitext(fn)
# Accept only if the base name ends with exactly the padded sequence number
if base.endswith(f"{sequence:06d}") or base == str(sequence):
path = os.path.join(flight_dir, fn)
img = cv2.imread(path)
if img is not None:
h, w = img.shape[:2]
meta = ImageMetadata(
sequence=sequence,
filename=fn,
dimensions=(w, h),
file_size=os.path.getsize(path),
timestamp=datetime.now(timezone.utc),
)
return ImageData(flight_id, sequence, fn, img, meta)
return None
def get_processing_status(self, flight_id: str) -> ProcessingStatus:
self._init_status(flight_id)
s = self._status[flight_id]
q = self._get_queue(flight_id)
return ProcessingStatus(
flight_id=flight_id,
total_images=s["total_images"],
processed_images=s["processed_images"],
current_sequence=s["current_sequence"],
queued_batches=q.qsize(),
processing_rate=0.0 # mock
)
+599
View File
@@ -0,0 +1,599 @@
"""Core Flight Processor — Full Processing Pipeline (Stage 10).
Orchestrates: ImageInputPipeline → VO → MetricRefinement → FactorGraph → SSE.
State Machine: NORMAL → LOST → RECOVERY → NORMAL.
"""
from __future__ import annotations
import asyncio
import logging
import time
from enum import Enum
from typing import Optional
import numpy as np
from gps_denied.core.eskf import ESKF
from gps_denied.pipeline.image_input import ImageInputPipeline
from gps_denied.pipeline.result_manager import ResultManager
from gps_denied.pipeline.sse_streamer import SSEEventStreamer
from gps_denied.db.repository import FlightRepository
from gps_denied.schemas import CameraParameters, GPSPoint
from gps_denied.schemas.flight import (
BatchMetadata,
BatchResponse,
BatchUpdateResponse,
DeleteResponse,
FlightCreateRequest,
FlightDetailResponse,
FlightResponse,
FlightStatusResponse,
ObjectGPSResponse,
UpdateResponse,
UserFixRequest,
UserFixResponse,
Waypoint,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# State Machine
# ---------------------------------------------------------------------------
class TrackingState(str, Enum):
"""Processing state for a flight."""
NORMAL = "normal"
LOST = "lost"
RECOVERY = "recovery"
class FrameResult:
"""Intermediate result of processing a single frame."""
def __init__(self, frame_id: int):
self.frame_id = frame_id
self.gps: Optional[GPSPoint] = None
self.confidence: float = 0.0
self.tracking_state: TrackingState = TrackingState.NORMAL
self.vo_success: bool = False
self.alignment_success: bool = False
# ---------------------------------------------------------------------------
# FlightProcessor
# ---------------------------------------------------------------------------
class FlightProcessor:
"""Manages business logic, background processing, and frame orchestration."""
def __init__(
self,
repository: FlightRepository,
streamer: SSEEventStreamer,
eskf_config=None,
) -> None:
self.repository = repository
self.streamer = streamer
self.result_manager = ResultManager(repository, streamer)
self.pipeline = ImageInputPipeline(storage_dir=".image_storage", max_queue_size=50)
self._eskf_config = eskf_config # ESKFConfig or None → default
# Per-flight processing state
self._flight_states: dict[str, TrackingState] = {}
self._prev_images: dict[str, np.ndarray] = {} # previous frame cache
self._flight_cameras: dict[str, CameraParameters] = {} # per-flight camera
self._altitudes: dict[str, float] = {} # per-flight altitude (m)
self._failure_counts: dict[str, int] = {} # per-flight consecutive failure counter
# Per-flight ESKF instances (PIPE-01/07)
self._eskf: dict[str, ESKF] = {}
# Lazy-initialised component references (set via `attach_components`)
self._vo = None # ISequentialVisualOdometry
self._gpr = None # IGlobalPlaceRecognition
self._metric = None # IMetricRefinement
self._graph = None # IFactorGraphOptimizer
self._recovery = None # IFailureRecoveryCoordinator
self._chunk_mgr = None # IRouteChunkManager
self._rotation = None # ImageRotationManager
self._satellite = None # SatelliteDataManager (PIPE-02)
self._coord = None # CoordinateTransformer (PIPE-02/06)
self._mavlink = None # MAVLinkBridge (PIPE-07)
# ------ Dependency injection for core components ---------
def attach_components(
self,
vo=None,
gpr=None,
metric=None,
graph=None,
recovery=None,
chunk_mgr=None,
rotation=None,
satellite=None,
coord=None,
mavlink=None,
):
"""Attach pipeline components after construction (avoids circular deps)."""
self._vo = vo
self._gpr = gpr
self._metric = metric
self._graph = graph
self._recovery = recovery
self._chunk_mgr = chunk_mgr
self._rotation = rotation
self._satellite = satellite # PIPE-02: SatelliteDataManager
self._coord = coord # PIPE-02/06: CoordinateTransformer
self._mavlink = mavlink # PIPE-07: MAVLinkBridge
# ------ ESKF lifecycle helpers ----------------------------
def _init_eskf_for_flight(
self, flight_id: str, start_gps: GPSPoint, altitude: float
) -> None:
"""Create and initialize a per-flight ESKF instance."""
if flight_id in self._eskf:
return
eskf = ESKF(config=self._eskf_config)
if self._coord:
try:
e, n, _ = self._coord.gps_to_enu(flight_id, start_gps)
eskf.initialize(np.array([e, n, altitude]), time.time())
except Exception:
eskf.initialize(np.zeros(3), time.time())
else:
eskf.initialize(np.zeros(3), time.time())
self._eskf[flight_id] = eskf
def _eskf_to_gps(self, flight_id: str, eskf: ESKF) -> Optional[GPSPoint]:
"""Convert current ESKF ENU position to WGS84 GPS."""
if not eskf.initialized or self._coord is None:
return None
try:
pos = eskf.position
return self._coord.enu_to_gps(flight_id, (float(pos[0]), float(pos[1]), float(pos[2])))
except Exception:
return None
# =========================================================
# process_frame — central orchestration
# =========================================================
async def process_frame(
self,
flight_id: str,
frame_id: int,
image: np.ndarray,
) -> FrameResult:
"""
Process a single UAV frame through the full pipeline.
State transitions:
NORMAL — VO succeeds → ESKF VO update, attempt satellite fix
LOST — VO failed → create new chunk, enter RECOVERY
RECOVERY— try GPR + MetricRefinement → if anchored, merge & return to NORMAL
PIPE-01: VO result → eskf.update_vo → satellite match → eskf.update_satellite → MAVLink GPS_INPUT
PIPE-02: SatelliteDataManager + CoordinateTransformer wired for tile selection
PIPE-04: Consecutive failure counter wired to FailureRecoveryCoordinator
PIPE-05: ImageRotationManager initialised on first frame
PIPE-07: ESKF confidence → MAVLink fix_type via bridge.update_state
"""
result = FrameResult(frame_id)
state = self._flight_states.get(flight_id, TrackingState.NORMAL)
eskf = self._eskf.get(flight_id)
_default_cam = CameraParameters(
focal_length=4.5, sensor_width=6.17, sensor_height=4.55,
resolution_width=640, resolution_height=480,
)
# ---- PIPE-05: Initialise heading tracking on first frame ----
if self._rotation and frame_id == 0:
self._rotation.requires_rotation_sweep(flight_id) # seeds HeadingHistory
# ---- 1. Visual Odometry (frame-to-frame) ----
vo_ok = False
if self._vo and flight_id in self._prev_images:
try:
cam = self._flight_cameras.get(flight_id, _default_cam)
rel_pose = self._vo.compute_relative_pose(
self._prev_images[flight_id], image, cam
)
if rel_pose and rel_pose.tracking_good:
vo_ok = True
result.vo_success = True
if self._graph:
self._graph.add_relative_factor(
flight_id, frame_id - 1, frame_id, rel_pose, np.eye(6)
)
# PIPE-01: Feed VO relative displacement into ESKF
if eskf and eskf.initialized:
now = time.time()
dt_vo = max(0.01, now - (eskf.last_timestamp or now))
eskf.update_vo(rel_pose.translation, dt_vo)
except Exception as exc:
logger.warning("VO failed for frame %d: %s", frame_id, exc)
# Store current image for next frame
self._prev_images[flight_id] = image
# ---- PIPE-04: Consecutive failure counter ----
if not vo_ok and frame_id > 0:
self._failure_counts[flight_id] = self._failure_counts.get(flight_id, 0) + 1
else:
self._failure_counts[flight_id] = 0
# ---- 2. State Machine transitions ----
if state == TrackingState.NORMAL:
if not vo_ok and frame_id > 0:
state = TrackingState.LOST
logger.info("Flight %s → LOST at frame %d", flight_id, frame_id)
if self._recovery:
self._recovery.handle_tracking_lost(flight_id, frame_id)
if state == TrackingState.LOST:
state = TrackingState.RECOVERY
if state == TrackingState.RECOVERY:
recovered = False
if self._recovery and self._chunk_mgr:
active_chunk = self._chunk_mgr.get_active_chunk(flight_id)
if active_chunk:
recovered = self._recovery.process_chunk_recovery(
flight_id, active_chunk.chunk_id, [image]
)
if recovered:
state = TrackingState.NORMAL
result.alignment_success = True
# PIPE-04: Reset failure count on successful recovery
self._failure_counts[flight_id] = 0
logger.info("Flight %s recovered → NORMAL at frame %d", flight_id, frame_id)
# ---- 3. Satellite position fix (PIPE-01/02) ----
if state == TrackingState.NORMAL and self._metric:
sat_tile: Optional[np.ndarray] = None
tile_bounds = None
# PIPE-02: Prefer real SatelliteDataManager tiles (ESKF ±3σ selection)
if self._satellite and eskf and eskf.initialized:
gps_est = self._eskf_to_gps(flight_id, eskf)
if gps_est:
cov = eskf.covariance
sigma_h = float(
np.sqrt(np.trace(cov[0:3, 0:3]) / 3.0)
) if cov is not None else 30.0
sigma_h = max(sigma_h, 5.0)
try:
tile_result = await asyncio.get_event_loop().run_in_executor(
None,
self._satellite.fetch_tiles_for_position,
gps_est, sigma_h, 18,
)
if tile_result:
sat_tile, tile_bounds = tile_result
except Exception as exc:
logger.debug("Satellite tile fetch failed: %s", exc)
# Fallback: GPR candidate tile (mock image, real bounds)
if sat_tile is None and self._gpr:
try:
candidates = self._gpr.retrieve_candidate_tiles(image, top_k=1)
if candidates:
sat_tile = np.zeros((256, 256, 3), dtype=np.uint8)
tile_bounds = candidates[0].bounds
except Exception as exc:
logger.debug("GPR tile fallback failed: %s", exc)
if sat_tile is not None and tile_bounds is not None:
try:
align = self._metric.align_to_satellite(image, sat_tile, tile_bounds)
if align and align.matched:
result.gps = align.gps_center
result.confidence = align.confidence
result.alignment_success = True
if self._graph:
self._graph.add_absolute_factor(
flight_id, frame_id,
align.gps_center, np.eye(6),
is_user_anchor=False,
)
# PIPE-01: ESKF satellite update — noise from RANSAC confidence
if eskf and eskf.initialized and self._coord:
try:
e, n, _ = self._coord.gps_to_enu(flight_id, align.gps_center)
alt = self._altitudes.get(flight_id, 100.0)
pos_enu = np.array([e, n, alt])
noise_m = 5.0 + 15.0 * (1.0 - float(align.confidence))
eskf.update_satellite(pos_enu, noise_m)
except Exception as exc:
logger.debug("ESKF satellite update failed: %s", exc)
except Exception as exc:
logger.warning("Metric alignment failed at frame %d: %s", frame_id, exc)
# ---- 4. Graph optimization (incremental) ----
if self._graph:
opt_result = self._graph.optimize(flight_id, iterations=5)
logger.debug(
"Optimization: converged=%s, error=%.4f",
opt_result.converged, opt_result.final_error,
)
# ---- PIPE-07: Push ESKF state → MAVLink GPS_INPUT ----
if self._mavlink and eskf and eskf.initialized:
try:
eskf_state = eskf.get_state()
alt = self._altitudes.get(flight_id, 100.0)
self._mavlink.update_state(eskf_state, altitude_m=alt)
except Exception as exc:
logger.debug("MAVLink state push failed: %s", exc)
# ---- 5. Publish via SSE ----
result.tracking_state = state
self._flight_states[flight_id] = state
await self._publish_frame_result(flight_id, result)
return result
async def _publish_frame_result(self, flight_id: str, result: FrameResult):
"""Emit SSE event for processed frame."""
event_data = {
"frame_id": result.frame_id,
"tracking_state": result.tracking_state.value,
"vo_success": result.vo_success,
"alignment_success": result.alignment_success,
"confidence": result.confidence,
}
if result.gps:
event_data["lat"] = result.gps.lat
event_data["lon"] = result.gps.lon
await self.streamer.push_event(
flight_id, event_type="frame_result", data=event_data
)
# =========================================================
# Existing CRUD / REST helpers (unchanged from Stage 3-4)
# =========================================================
async def create_flight(self, req: FlightCreateRequest) -> FlightResponse:
flight = await self.repository.insert_flight(
name=req.name,
description=req.description,
start_lat=req.start_gps.lat,
start_lon=req.start_gps.lon,
altitude=req.altitude,
camera_params=req.camera_params.model_dump(),
)
# P0#2: Store camera params for process_frame VO calls
self._flight_cameras[flight.id] = req.camera_params
for poly in req.geofences.polygons:
await self.repository.insert_geofence(
flight.id,
nw_lat=poly.north_west.lat,
nw_lon=poly.north_west.lon,
se_lat=poly.south_east.lat,
se_lon=poly.south_east.lon,
)
for w in req.rough_waypoints:
await self.repository.insert_waypoint(flight.id, lat=w.lat, lon=w.lon)
# Store per-flight altitude for ESKF/pixel projection
self._altitudes[flight.id] = req.altitude or 100.0
# PIPE-02: Set ENU origin and initialise ESKF for this flight
if self._coord:
self._coord.set_enu_origin(flight.id, req.start_gps)
self._init_eskf_for_flight(flight.id, req.start_gps, req.altitude or 100.0)
# Start MAVLink bridge for this flight (origin required for GPS_INPUT)
if self._mavlink and not self._mavlink._running:
try:
asyncio.create_task(self._mavlink.start(req.start_gps))
except Exception as exc:
logger.warning("MAVLink bridge start failed: %s", exc)
return FlightResponse(
flight_id=flight.id,
status="prefetching",
message="Flight created and prefetching started.",
created_at=flight.created_at,
)
async def get_flight(self, flight_id: str) -> FlightDetailResponse | None:
flight = await self.repository.get_flight(flight_id)
if not flight:
return None
wps = await self.repository.get_waypoints(flight_id)
state = await self.repository.load_flight_state(flight_id)
waypoints = [
Waypoint(
id=w.id,
lat=w.lat,
lon=w.lon,
altitude=w.altitude,
confidence=w.confidence,
timestamp=w.timestamp,
refined=w.refined,
)
for w in wps
]
status = state.status if state else "unknown"
frames_processed = state.frames_processed if state else 0
frames_total = state.frames_total if state else 0
from gps_denied.schemas import Geofences
return FlightDetailResponse(
flight_id=flight.id,
name=flight.name,
description=flight.description,
start_gps=GPSPoint(lat=flight.start_lat, lon=flight.start_lon),
waypoints=waypoints,
geofences=Geofences(polygons=[]),
camera_params=flight.camera_params,
altitude=flight.altitude,
status=status,
frames_processed=frames_processed,
frames_total=frames_total,
created_at=flight.created_at,
updated_at=flight.updated_at,
)
async def delete_flight(self, flight_id: str) -> DeleteResponse:
deleted = await self.repository.delete_flight(flight_id)
# P0#1: Cleanup in-memory state to prevent memory leaks
self._cleanup_flight(flight_id)
return DeleteResponse(deleted=deleted, flight_id=flight_id)
def _cleanup_flight(self, flight_id: str) -> None:
"""Remove all in-memory state for a flight (prevents memory leaks)."""
self._prev_images.pop(flight_id, None)
self._flight_states.pop(flight_id, None)
self._flight_cameras.pop(flight_id, None)
self._altitudes.pop(flight_id, None)
self._failure_counts.pop(flight_id, None)
self._eskf.pop(flight_id, None)
if self._graph:
self._graph.delete_flight_graph(flight_id)
async def update_waypoint(
self, flight_id: str, waypoint_id: str, waypoint: Waypoint
) -> UpdateResponse:
ok = await self.repository.update_waypoint(
flight_id,
waypoint_id,
lat=waypoint.lat,
lon=waypoint.lon,
altitude=waypoint.altitude,
confidence=waypoint.confidence,
refined=waypoint.refined,
)
return UpdateResponse(updated=ok, waypoint_id=waypoint_id)
async def batch_update_waypoints(
self, flight_id: str, waypoints: list[Waypoint]
) -> BatchUpdateResponse:
failed = []
updated = 0
for wp in waypoints:
ok = await self.repository.update_waypoint(
flight_id,
wp.id,
lat=wp.lat,
lon=wp.lon,
altitude=wp.altitude,
confidence=wp.confidence,
refined=wp.refined,
)
if ok:
updated += 1
else:
failed.append(wp.id)
return BatchUpdateResponse(
success=(len(failed) == 0), updated_count=updated, failed_ids=failed
)
async def queue_images(
self, flight_id: str, metadata: BatchMetadata, file_count: int
) -> BatchResponse:
state = await self.repository.load_flight_state(flight_id)
if state:
total = state.frames_total + file_count
await self.repository.save_flight_state(
flight_id, frames_total=total, status="processing"
)
next_seq = metadata.end_sequence + 1
seqs = list(range(metadata.start_sequence, metadata.end_sequence + 1))
return BatchResponse(
accepted=True,
sequences=seqs,
next_expected=next_seq,
message=f"Queued {file_count} images.",
)
async def handle_user_fix(
self, flight_id: str, req: UserFixRequest
) -> UserFixResponse:
await self.repository.save_flight_state(
flight_id, blocked=False, status="processing"
)
# Inject operator position into ESKF with high uncertainty (500m)
eskf = self._eskf.get(flight_id)
if eskf and eskf.initialized and self._coord:
try:
e, n, _ = self._coord.gps_to_enu(flight_id, req.satellite_gps)
alt = self._altitudes.get(flight_id, 100.0)
eskf.update_satellite(np.array([e, n, alt]), noise_meters=500.0)
self._failure_counts[flight_id] = 0
logger.info("User fix applied for %s: %s", flight_id, req.satellite_gps)
except Exception as exc:
logger.warning("User fix ESKF injection failed: %s", exc)
return UserFixResponse(
accepted=True, processing_resumed=True, message="Fix applied."
)
async def get_flight_status(self, flight_id: str) -> FlightStatusResponse | None:
state = await self.repository.load_flight_state(flight_id)
if not state:
return None
return FlightStatusResponse(
status=state.status,
frames_processed=state.frames_processed,
frames_total=state.frames_total,
current_frame=state.current_frame,
current_heading=None,
blocked=state.blocked,
search_grid_size=state.search_grid_size,
created_at=state.created_at,
updated_at=state.updated_at,
)
async def convert_object_to_gps(
self, flight_id: str, frame_id: int, pixel: tuple[float, float]
) -> ObjectGPSResponse:
# PIPE-06: Use real CoordinateTransformer + ESKF pose for ray-ground projection
gps: Optional[GPSPoint] = None
eskf = self._eskf.get(flight_id)
if self._coord and eskf and eskf.initialized:
pos = eskf.position
quat = eskf.quaternion
cam = self._flight_cameras.get(flight_id, CameraParameters(
focal_length=4.5, sensor_width=6.17, sensor_height=4.55,
resolution_width=640, resolution_height=480,
))
alt = self._altitudes.get(flight_id, 100.0)
try:
gps = self._coord.pixel_to_gps(
flight_id,
pixel,
frame_pose={"position": pos},
camera_params=cam,
altitude=float(alt),
quaternion=quat,
)
except Exception as exc:
logger.debug("pixel_to_gps failed: %s", exc)
# Fallback: return ESKF position projected to ground (no pixel shift)
if gps is None and eskf:
gps = self._eskf_to_gps(flight_id, eskf)
return ObjectGPSResponse(
gps=gps or GPSPoint(lat=0.0, lon=0.0),
accuracy_meters=5.0,
frame_id=frame_id,
pixel=pixel,
)
async def stream_events(self, flight_id: str, client_id: str):
"""Async generator for SSE stream."""
async for event in self.streamer.stream_generator(flight_id, client_id):
yield event
+73
View File
@@ -0,0 +1,73 @@
"""Result Manager (Component F14)."""
from __future__ import annotations
from datetime import datetime
from gps_denied.pipeline.sse_streamer import SSEEventStreamer
from gps_denied.db.repository import FlightRepository
from gps_denied.schemas import GPSPoint
from gps_denied.schemas.events import FrameProcessedEvent
class ResultManager:
"""Result consistency and publishing."""
def __init__(self, repo: FlightRepository, sse: SSEEventStreamer) -> None:
self.repo = repo
self.sse = sse
async def update_frame_result(
self,
flight_id: str,
frame_id: int,
gps_lat: float,
gps_lon: float,
altitude: float,
heading: float,
confidence: float,
timestamp: datetime,
refined: bool = False,
) -> bool:
"""Atomic DB update + SSE event publish."""
# 1. Update DB (in the repository these are auto-committing via flush,
# but normally F03 would wrap in a single transaction).
await self.repo.save_frame_result(
flight_id,
frame_id=frame_id,
gps_lat=gps_lat,
gps_lon=gps_lon,
altitude=altitude,
heading=heading,
confidence=confidence,
refined=refined,
)
# Wait, the spec also wants Waypoints to be updated.
# But image frames != waypoints. Waypoints are the planned route.
# Actually in the spec it says: "Updates waypoint in waypoints table."
# This implies updating the closest waypoint or a generated waypoint path.
# We will follow the simplest form for now: update the waypoint if there is one corresponding.
# Let's say we update a waypoint with id "wp_{frame_id}" for now if we know how they map,
# or we just skip unless specified.
# 2. Trigger SSE event
evt = FrameProcessedEvent(
frame_id=frame_id,
gps=GPSPoint(lat=gps_lat, lon=gps_lon),
altitude=altitude,
confidence=confidence,
heading=heading,
timestamp=timestamp,
)
if refined:
self.sse.send_refinement(flight_id, evt)
else:
self.sse.send_frame_result(flight_id, evt)
return True
async def publish_waypoint_update(self, flight_id: str, frame_id: int) -> bool:
# Just delegates to SSE for waypoint updates, which is basically the frame result for UI
pass
+164
View File
@@ -0,0 +1,164 @@
"""SSE Event Streamer (Component F15)."""
from __future__ import annotations
import asyncio
import json
from collections import defaultdict
from gps_denied.schemas.events import (
FlightCompletedEvent,
FrameProcessedEvent,
SearchExpandedEvent,
SSEEventType,
SSEMessage,
UserInputNeededEvent,
)
class SSEEventStreamer:
"""Manages real-time SSE connections and event broadcasting."""
def __init__(self) -> None:
# Map: flight_id -> Dict[client_id, asyncio.Queue]
self._streams: dict[str, dict[str, asyncio.Queue[SSEMessage | None]]] = defaultdict(dict)
def create_stream(self, flight_id: str, client_id: str) -> asyncio.Queue[SSEMessage | None]:
"""Create a new event queue for a client."""
q: asyncio.Queue[SSEMessage | None] = asyncio.Queue()
self._streams[flight_id][client_id] = q
return q
def close_stream(self, flight_id: str, client_id: str) -> None:
"""Close a client stream by putting a sentinel and removing the queue."""
if flight_id in self._streams and client_id in self._streams[flight_id]:
q = self._streams[flight_id].pop(client_id)
if not self._streams[flight_id]:
del self._streams[flight_id]
# Put None to signal generator exit
try:
q.put_nowait(None)
except asyncio.QueueFull:
pass
def get_active_connections(self, flight_id: str) -> int:
return len(self._streams.get(flight_id, {}))
def _broadcast(self, flight_id: str, msg: SSEMessage) -> bool:
"""Broadcast a message to all clients subscribed to flight_id."""
if flight_id not in self._streams or not self._streams[flight_id]:
return False
for q in self._streams[flight_id].values():
try:
q.put_nowait(msg)
except asyncio.QueueFull:
pass # Drop if queue is full rather than blocking
return True
# ── Business Event Senders ────────────────────────────────────────────────
def send_frame_result(self, flight_id: str, event_data: FrameProcessedEvent) -> bool:
msg = SSEMessage(
event=SSEEventType.FRAME_PROCESSED,
data=event_data.model_dump(mode="json"),
id=f"frame_{event_data.frame_id}",
)
return self._broadcast(flight_id, msg)
def send_refinement(self, flight_id: str, event_data: FrameProcessedEvent) -> bool:
msg = SSEMessage(
event=SSEEventType.FRAME_REFINED,
data=event_data.model_dump(mode="json"),
id=f"refine_{event_data.frame_id}",
)
return self._broadcast(flight_id, msg)
def send_search_progress(self, flight_id: str, event_data: SearchExpandedEvent) -> bool:
msg = SSEMessage(
event=SSEEventType.SEARCH_EXPANDED,
data=event_data.model_dump(mode="json"),
)
return self._broadcast(flight_id, msg)
def send_user_input_request(self, flight_id: str, event_data: UserInputNeededEvent) -> bool:
msg = SSEMessage(
event=SSEEventType.USER_INPUT_NEEDED,
data=event_data.model_dump(mode="json"),
)
return self._broadcast(flight_id, msg)
def send_flight_completed(self, flight_id: str, event_data: FlightCompletedEvent) -> bool:
msg = SSEMessage(
event=SSEEventType.FLIGHT_COMPLETED,
data=event_data.model_dump(mode="json"),
)
return self._broadcast(flight_id, msg)
def send_heartbeat(self, flight_id: str) -> bool:
# sse_starlette uses empty string or comment for heartbeat,
# but we can just send an SSEMessage object that parses as empty event
if flight_id not in self._streams:
return False
# Manually sending a comment via the generator is tricky with strict SSEMessage schema
# but we'll handle this in the stream generator directly
return True
# ── Generic event dispatcher (used by processor.process_frame) ──────────
async def push_event(self, flight_id: str, event_type: str, data: dict) -> None:
"""Dispatch a generic event to all clients for a flight.
Maps event_type strings to typed SSE events:
"frame_result" → FrameProcessedEvent
"refinement" → FrameProcessedEvent (refined)
Other → raw broadcast via SSEMessage
"""
if event_type == "frame_result":
evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data
self.send_frame_result(flight_id, evt)
elif event_type == "refinement":
evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data
self.send_refinement(flight_id, evt)
else:
msg = SSEMessage(
event=SSEEventType.FRAME_PROCESSED,
data=data,
id=str(data.get("frame_id", "")),
)
self._broadcast(flight_id, msg)
# ── Stream Generator ──────────────────────────────────────────────────────
async def stream_generator(self, flight_id: str, client_id: str):
"""Yields dicts for sse_starlette EventSourceResponse."""
q = self.create_stream(flight_id, client_id)
# Send an immediate connection accepted ping
yield {"event": "connected", "data": "connected"}
try:
while True:
# Wait for next event or send heartbeat every 15s
try:
msg = await asyncio.wait_for(q.get(), timeout=15.0)
if msg is None:
# Sentinel for clean shutdown
break
# Yield dict format for sse_starlette
yield {
"event": msg.event.value,
"id": msg.id if msg.id else "",
"data": json.dumps(msg.data)
}
except asyncio.TimeoutError:
# Heartbeat format for sse_starlette (empty string generates a comment)
yield {"event": "heartbeat", "data": "ping"}
except asyncio.CancelledError:
pass # Client disconnected
finally:
self.close_stream(flight_id, client_id)
+371
View File
@@ -0,0 +1,371 @@
"""Accuracy Benchmark (Phase 7).
Provides:
- SyntheticTrajectory — generates a realistic fixed-wing UAV flight path
with ground-truth GPS + noisy sensor data.
- AccuracyBenchmark — replays a trajectory through the ESKF pipeline
and computes position-error statistics.
Acceptance criteria (from solution.md):
AC-PERF-1: 80 % of frames within 50 m of ground truth.
AC-PERF-2: 60 % of frames within 20 m of ground truth.
AC-PERF-3: End-to-end per-frame latency < 400 ms.
AC-PERF-4: VO drift over 1 km straight segment (no sat correction) < 100 m.
"""
from __future__ import annotations
import math
import time
from dataclasses import dataclass, field
from typing import Callable, Optional
import numpy as np
from gps_denied.core.coordinates import CoordinateTransformer
from gps_denied.core.eskf import ESKF
from gps_denied.schemas import GPSPoint
from gps_denied.schemas.eskf import ESKFConfig, IMUMeasurement
# ---------------------------------------------------------------------------
# Synthetic trajectory
# ---------------------------------------------------------------------------
@dataclass
class TrajectoryFrame:
"""One simulated camera frame with ground-truth and noisy sensor data."""
frame_id: int
timestamp: float
true_position_enu: np.ndarray # (3,) East, North, Up in metres
true_gps: GPSPoint # WGS84 from true ENU
imu_measurements: list[IMUMeasurement] # High-rate IMU between frames
vo_translation: Optional[np.ndarray] # Noisy relative displacement (3,)
vo_tracking_good: bool = True
@dataclass
class SyntheticTrajectoryConfig:
"""Parameters for trajectory generation."""
# Origin (mission start)
origin: GPSPoint = field(default_factory=lambda: GPSPoint(lat=49.0, lon=32.0))
altitude_m: float = 600.0 # Constant AGL altitude (m)
# UAV speed and heading
speed_mps: float = 20.0 # ~70 km/h (typical fixed-wing)
heading_deg: float = 45.0 # Initial heading (degrees CW from North)
camera_fps: float = 0.7 # ADTI 20L V1 camera rate (Hz)
imu_hz: float = 200.0 # IMU sample rate
num_frames: int = 50 # Number of camera frames to simulate
# Noise parameters
vo_noise_m: float = 0.5 # VO translation noise (sigma, metres)
imu_accel_noise: float = 0.01 # Accelerometer noise sigma (m/s²)
imu_gyro_noise: float = 0.001 # Gyroscope noise sigma (rad/s)
# Failure injection
vo_failure_frames: list[int] = field(default_factory=list)
# Waypoints for heading changes (ENU East, North metres from origin)
waypoints_enu: list[tuple[float, float]] = field(default_factory=list)
class SyntheticTrajectory:
"""Generate a synthetic fixed-wing UAV flight with ground truth + noisy sensors."""
def __init__(self, config: SyntheticTrajectoryConfig | None = None):
self.config = config or SyntheticTrajectoryConfig()
self._coord = CoordinateTransformer()
self._flight_id = "__synthetic__"
self._coord.set_enu_origin(self._flight_id, self.config.origin)
def generate(self) -> list[TrajectoryFrame]:
"""Generate all trajectory frames."""
cfg = self.config
dt_camera = 1.0 / cfg.camera_fps
dt_imu = 1.0 / cfg.imu_hz
imu_steps = int(dt_camera * cfg.imu_hz)
frames: list[TrajectoryFrame] = []
pos = np.array([0.0, 0.0, cfg.altitude_m])
vel = self._heading_to_enu_vel(cfg.heading_deg, cfg.speed_mps)
prev_pos = pos.copy()
t = time.time()
waypoints = list(cfg.waypoints_enu) # copy
for fid in range(cfg.num_frames):
# --- Waypoint steering ---
if waypoints:
wp_e, wp_n = waypoints[0]
to_wp = np.array([wp_e - pos[0], wp_n - pos[1], 0.0])
dist_wp = np.linalg.norm(to_wp[:2])
if dist_wp < cfg.speed_mps * dt_camera:
waypoints.pop(0)
else:
heading_rad = math.atan2(to_wp[0], to_wp[1]) # ENU: E=X, N=Y
vel = np.array([
cfg.speed_mps * math.sin(heading_rad),
cfg.speed_mps * math.cos(heading_rad),
0.0,
])
# --- Simulate IMU between frames ---
imu_list: list[IMUMeasurement] = []
for step in range(imu_steps):
ts = t + step * dt_imu
# Body-frame acceleration (mostly gravity correction, small forward accel)
accel_true = np.array([0.0, 0.0, 9.81]) # gravity compensation
gyro_true = np.zeros(3)
imu = IMUMeasurement(
accel=accel_true + np.random.randn(3) * cfg.imu_accel_noise,
gyro=gyro_true + np.random.randn(3) * cfg.imu_gyro_noise,
timestamp=ts,
)
imu_list.append(imu)
# --- Propagate position ---
prev_pos = pos.copy()
pos = pos + vel * dt_camera
t += dt_camera
# --- True GPS from ENU position ---
true_gps = self._coord.enu_to_gps(
self._flight_id, (float(pos[0]), float(pos[1]), float(pos[2]))
)
# --- VO measurement (relative displacement + noise) ---
true_displacement = pos - prev_pos
vo_tracking_good = fid not in cfg.vo_failure_frames
if vo_tracking_good:
noisy_displacement = true_displacement + np.random.randn(3) * cfg.vo_noise_m
noisy_displacement[2] = 0.0 # monocular VO is scale-ambiguous in Z
else:
noisy_displacement = None
frames.append(TrajectoryFrame(
frame_id=fid,
timestamp=t,
true_position_enu=pos.copy(),
true_gps=true_gps,
imu_measurements=imu_list,
vo_translation=noisy_displacement,
vo_tracking_good=vo_tracking_good,
))
return frames
@staticmethod
def _heading_to_enu_vel(heading_deg: float, speed_mps: float) -> np.ndarray:
"""Convert heading (degrees CW from North) to ENU velocity vector."""
rad = math.radians(heading_deg)
return np.array([
speed_mps * math.sin(rad), # East
speed_mps * math.cos(rad), # North
0.0, # Up
])
# ---------------------------------------------------------------------------
# Accuracy Benchmark
# ---------------------------------------------------------------------------
@dataclass
class BenchmarkResult:
"""Position error statistics over a trajectory replay."""
errors_m: list[float] # Per-frame horizontal error in metres
latencies_ms: list[float] # Per-frame process time in ms
frames_total: int
frames_with_good_estimate: int
@property
def p80_error_m(self) -> float:
"""80th percentile position error (metres)."""
return float(np.percentile(self.errors_m, 80)) if self.errors_m else float("inf")
@property
def p60_error_m(self) -> float:
"""60th percentile position error (metres)."""
return float(np.percentile(self.errors_m, 60)) if self.errors_m else float("inf")
@property
def median_error_m(self) -> float:
"""Median position error (metres)."""
return float(np.median(self.errors_m)) if self.errors_m else float("inf")
@property
def max_error_m(self) -> float:
return float(max(self.errors_m)) if self.errors_m else float("inf")
@property
def p95_latency_ms(self) -> float:
"""95th percentile frame latency (ms)."""
return float(np.percentile(self.latencies_ms, 95)) if self.latencies_ms else float("inf")
@property
def pct_within_50m(self) -> float:
"""Fraction of frames within 50 m error."""
if not self.errors_m:
return 0.0
return sum(e <= 50.0 for e in self.errors_m) / len(self.errors_m)
@property
def pct_within_20m(self) -> float:
"""Fraction of frames within 20 m error."""
if not self.errors_m:
return 0.0
return sum(e <= 20.0 for e in self.errors_m) / len(self.errors_m)
def passes_acceptance_criteria(self) -> tuple[bool, dict[str, bool]]:
"""Check all solution.md acceptance criteria.
Returns (overall_pass, per_criterion_dict).
"""
checks = {
"AC-PERF-1: 80% within 50m": self.pct_within_50m >= 0.80,
"AC-PERF-2: 60% within 20m": self.pct_within_20m >= 0.60,
"AC-PERF-3: p95 latency < 400ms": self.p95_latency_ms < 400.0,
}
overall = all(checks.values())
return overall, checks
def summary(self) -> str:
overall, checks = self.passes_acceptance_criteria()
lines = [
f"Frames: {self.frames_total} | with estimate: {self.frames_with_good_estimate}",
f"Error — median: {self.median_error_m:.1f}m p80: {self.p80_error_m:.1f}m "
f"p60: {self.p60_error_m:.1f}m max: {self.max_error_m:.1f}m",
f"Within 50m: {self.pct_within_50m*100:.1f}% | within 20m: {self.pct_within_20m*100:.1f}%",
f"Latency p95: {self.p95_latency_ms:.1f}ms",
"",
"Acceptance criteria:",
]
for criterion, passed in checks.items():
lines.append(f" {'PASS' if passed else 'FAIL'} {criterion}")
lines.append(f"\nOverall: {'PASS' if overall else 'FAIL'}")
return "\n".join(lines)
class AccuracyBenchmark:
"""Replays a SyntheticTrajectory through the ESKF and measures accuracy.
The benchmark uses only the ESKF (no full FlightProcessor) for speed.
Satellite corrections are injected optionally via sat_correction_fn.
"""
def __init__(
self,
eskf_config: ESKFConfig | None = None,
sat_correction_fn: Optional[Callable[[TrajectoryFrame], Optional[np.ndarray]]] = None,
):
"""
Args:
eskf_config: ESKF tuning parameters.
sat_correction_fn: Optional callback(frame) → ENU position or None.
Called on keyframes to inject satellite corrections.
If None, no satellite corrections are applied.
"""
self.eskf_config = eskf_config or ESKFConfig()
self.sat_correction_fn = sat_correction_fn
def run(
self,
trajectory: list[TrajectoryFrame],
origin: GPSPoint,
satellite_keyframe_interval: int = 7,
) -> BenchmarkResult:
"""Replay trajectory frames through ESKF, collect errors and latencies.
Args:
trajectory: List of TrajectoryFrame (from SyntheticTrajectory).
origin: WGS84 reference origin for ENU.
satellite_keyframe_interval: Apply satellite correction every N frames.
"""
coord = CoordinateTransformer()
flight_id = "__benchmark__"
coord.set_enu_origin(flight_id, origin)
eskf = ESKF(self.eskf_config)
# Init at origin with HIGH uncertainty
eskf.initialize(np.array([0.0, 0.0, trajectory[0].true_position_enu[2]]),
trajectory[0].timestamp)
errors_m: list[float] = []
latencies_ms: list[float] = []
frames_with_estimate = 0
for frame in trajectory:
t_frame_start = time.perf_counter()
# --- IMU prediction ---
for imu in frame.imu_measurements:
eskf.predict(imu)
# --- VO update ---
if frame.vo_tracking_good and frame.vo_translation is not None:
dt_vo = 1.0 / 0.7 # camera interval
eskf.update_vo(frame.vo_translation, dt_vo)
# --- Satellite update (keyframes) ---
if frame.frame_id % satellite_keyframe_interval == 0:
sat_pos_enu: Optional[np.ndarray] = None
if self.sat_correction_fn is not None:
sat_pos_enu = self.sat_correction_fn(frame)
else:
# Default: inject ground-truth position + realistic noise
noise_m = 10.0
sat_pos_enu = (
frame.true_position_enu[:3]
+ np.random.randn(3) * noise_m
)
sat_pos_enu[2] = frame.true_position_enu[2] # keep altitude
if sat_pos_enu is not None:
# Tell ESKF the measurement noise matches what we inject
eskf.update_satellite(sat_pos_enu, noise_meters=noise_m)
latency_ms = (time.perf_counter() - t_frame_start) * 1000.0
latencies_ms.append(latency_ms)
# --- Compute horizontal error vs ground truth ---
if eskf.initialized and eskf._nominal_state is not None:
est_pos = eskf._nominal_state["position"]
true_pos = frame.true_position_enu
horiz_error = float(np.linalg.norm(est_pos[:2] - true_pos[:2]))
errors_m.append(horiz_error)
frames_with_estimate += 1
else:
errors_m.append(float("inf"))
return BenchmarkResult(
errors_m=errors_m,
latencies_ms=latencies_ms,
frames_total=len(trajectory),
frames_with_good_estimate=frames_with_estimate,
)
def run_vo_drift_test(
self,
trajectory_length_m: float = 1000.0,
speed_mps: float = 20.0,
) -> float:
"""Measure VO drift over a straight segment with NO satellite correction.
Returns final horizontal position error in metres.
Per solution.md, this should be < 100m over 1km.
"""
fps = 0.7
num_frames = max(10, int(trajectory_length_m / speed_mps * fps))
cfg = SyntheticTrajectoryConfig(
speed_mps=speed_mps,
heading_deg=0.0, # straight North
camera_fps=fps,
num_frames=num_frames,
vo_noise_m=0.3, # cuVSLAM-grade VO noise
)
traj_gen = SyntheticTrajectory(cfg)
frames = traj_gen.generate()
# No satellite corrections
benchmark_no_sat = AccuracyBenchmark(
eskf_config=self.eskf_config,
sat_correction_fn=lambda _: None, # suppress all satellite updates
)
result = benchmark_no_sat.run(frames, cfg.origin, satellite_keyframe_interval=9999)
# Return final-frame error
return result.errors_m[-1] if result.errors_m else float("inf")