mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-06-22 17:21:13 +00:00
refactor(01-07): factor_graph, pipeline pkg, testing/benchmark, Protocol ABCs
- Create core/factor_graph.py: IFactorGraphOptimizer converted to Protocol
- Shim core/graph.py to re-export from core/factor_graph
- Create pipeline/ package: orchestrator, image_input, result_manager, sse_streamer
- Shim core/{processor,pipeline,results,sse}.py to re-export from pipeline/
- Create testing/benchmark.py; shim core/benchmark.py
- Convert IRouteChunkManager, IFailureRecoveryCoordinator, IModelManager, IImageMatcher to Protocol
- Update pyproject.toml ruff per-file-ignores to new paths
- All 216 tests pass (regression floor maintained)
This commit is contained in:
+2
-2
@@ -42,8 +42,8 @@ line-length = 120
|
|||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
# Abstract interfaces have long method signatures — allow up to 170
|
# Abstract interfaces have long method signatures — allow up to 170
|
||||||
"src/gps_denied/core/graph.py" = ["E501"]
|
"src/gps_denied/core/factor_graph.py" = ["E501"]
|
||||||
"src/gps_denied/core/metric.py" = ["E501"]
|
"src/gps_denied/components/satellite_matcher/metric_refinement.py" = ["E501"]
|
||||||
"src/gps_denied/core/chunk_manager.py" = ["E501"]
|
"src/gps_denied/core/chunk_manager.py" = ["E501"]
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
|
|||||||
@@ -1,371 +1,8 @@
|
|||||||
"""Accuracy Benchmark (Phase 7).
|
"""Legacy import path. Phase 1 shim — code lives in testing/benchmark.py."""
|
||||||
|
from gps_denied.testing.benchmark import ( # noqa: F401
|
||||||
Provides:
|
AccuracyBenchmark,
|
||||||
- SyntheticTrajectory — generates a realistic fixed-wing UAV flight path
|
BenchmarkResult,
|
||||||
with ground-truth GPS + noisy sensor data.
|
SyntheticTrajectory,
|
||||||
- AccuracyBenchmark — replays a trajectory through the ESKF pipeline
|
SyntheticTrajectoryConfig,
|
||||||
and computes position-error statistics.
|
TrajectoryFrame,
|
||||||
|
|
||||||
Acceptance criteria (from solution.md):
|
|
||||||
AC-PERF-1: 80 % of frames within 50 m of ground truth.
|
|
||||||
AC-PERF-2: 60 % of frames within 20 m of ground truth.
|
|
||||||
AC-PERF-3: End-to-end per-frame latency < 400 ms.
|
|
||||||
AC-PERF-4: VO drift over 1 km straight segment (no sat correction) < 100 m.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import math
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from gps_denied.core.coordinates import CoordinateTransformer
|
|
||||||
from gps_denied.core.eskf import ESKF
|
|
||||||
from gps_denied.schemas import GPSPoint
|
|
||||||
from gps_denied.schemas.eskf import ESKFConfig, IMUMeasurement
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Synthetic trajectory
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrajectoryFrame:
|
|
||||||
"""One simulated camera frame with ground-truth and noisy sensor data."""
|
|
||||||
frame_id: int
|
|
||||||
timestamp: float
|
|
||||||
true_position_enu: np.ndarray # (3,) East, North, Up in metres
|
|
||||||
true_gps: GPSPoint # WGS84 from true ENU
|
|
||||||
imu_measurements: list[IMUMeasurement] # High-rate IMU between frames
|
|
||||||
vo_translation: Optional[np.ndarray] # Noisy relative displacement (3,)
|
|
||||||
vo_tracking_good: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SyntheticTrajectoryConfig:
|
|
||||||
"""Parameters for trajectory generation."""
|
|
||||||
# Origin (mission start)
|
|
||||||
origin: GPSPoint = field(default_factory=lambda: GPSPoint(lat=49.0, lon=32.0))
|
|
||||||
altitude_m: float = 600.0 # Constant AGL altitude (m)
|
|
||||||
# UAV speed and heading
|
|
||||||
speed_mps: float = 20.0 # ~70 km/h (typical fixed-wing)
|
|
||||||
heading_deg: float = 45.0 # Initial heading (degrees CW from North)
|
|
||||||
camera_fps: float = 0.7 # ADTI 20L V1 camera rate (Hz)
|
|
||||||
imu_hz: float = 200.0 # IMU sample rate
|
|
||||||
num_frames: int = 50 # Number of camera frames to simulate
|
|
||||||
# Noise parameters
|
|
||||||
vo_noise_m: float = 0.5 # VO translation noise (sigma, metres)
|
|
||||||
imu_accel_noise: float = 0.01 # Accelerometer noise sigma (m/s²)
|
|
||||||
imu_gyro_noise: float = 0.001 # Gyroscope noise sigma (rad/s)
|
|
||||||
# Failure injection
|
|
||||||
vo_failure_frames: list[int] = field(default_factory=list)
|
|
||||||
# Waypoints for heading changes (ENU East, North metres from origin)
|
|
||||||
waypoints_enu: list[tuple[float, float]] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class SyntheticTrajectory:
|
|
||||||
"""Generate a synthetic fixed-wing UAV flight with ground truth + noisy sensors."""
|
|
||||||
|
|
||||||
def __init__(self, config: SyntheticTrajectoryConfig | None = None):
|
|
||||||
self.config = config or SyntheticTrajectoryConfig()
|
|
||||||
self._coord = CoordinateTransformer()
|
|
||||||
self._flight_id = "__synthetic__"
|
|
||||||
self._coord.set_enu_origin(self._flight_id, self.config.origin)
|
|
||||||
|
|
||||||
def generate(self) -> list[TrajectoryFrame]:
|
|
||||||
"""Generate all trajectory frames."""
|
|
||||||
cfg = self.config
|
|
||||||
dt_camera = 1.0 / cfg.camera_fps
|
|
||||||
dt_imu = 1.0 / cfg.imu_hz
|
|
||||||
imu_steps = int(dt_camera * cfg.imu_hz)
|
|
||||||
|
|
||||||
frames: list[TrajectoryFrame] = []
|
|
||||||
pos = np.array([0.0, 0.0, cfg.altitude_m])
|
|
||||||
vel = self._heading_to_enu_vel(cfg.heading_deg, cfg.speed_mps)
|
|
||||||
prev_pos = pos.copy()
|
|
||||||
t = time.time()
|
|
||||||
|
|
||||||
waypoints = list(cfg.waypoints_enu) # copy
|
|
||||||
|
|
||||||
for fid in range(cfg.num_frames):
|
|
||||||
# --- Waypoint steering ---
|
|
||||||
if waypoints:
|
|
||||||
wp_e, wp_n = waypoints[0]
|
|
||||||
to_wp = np.array([wp_e - pos[0], wp_n - pos[1], 0.0])
|
|
||||||
dist_wp = np.linalg.norm(to_wp[:2])
|
|
||||||
if dist_wp < cfg.speed_mps * dt_camera:
|
|
||||||
waypoints.pop(0)
|
|
||||||
else:
|
|
||||||
heading_rad = math.atan2(to_wp[0], to_wp[1]) # ENU: E=X, N=Y
|
|
||||||
vel = np.array([
|
|
||||||
cfg.speed_mps * math.sin(heading_rad),
|
|
||||||
cfg.speed_mps * math.cos(heading_rad),
|
|
||||||
0.0,
|
|
||||||
])
|
|
||||||
|
|
||||||
# --- Simulate IMU between frames ---
|
|
||||||
imu_list: list[IMUMeasurement] = []
|
|
||||||
for step in range(imu_steps):
|
|
||||||
ts = t + step * dt_imu
|
|
||||||
# Body-frame acceleration (mostly gravity correction, small forward accel)
|
|
||||||
accel_true = np.array([0.0, 0.0, 9.81]) # gravity compensation
|
|
||||||
gyro_true = np.zeros(3)
|
|
||||||
imu = IMUMeasurement(
|
|
||||||
accel=accel_true + np.random.randn(3) * cfg.imu_accel_noise,
|
|
||||||
gyro=gyro_true + np.random.randn(3) * cfg.imu_gyro_noise,
|
|
||||||
timestamp=ts,
|
|
||||||
)
|
)
|
||||||
imu_list.append(imu)
|
|
||||||
|
|
||||||
# --- Propagate position ---
|
|
||||||
prev_pos = pos.copy()
|
|
||||||
pos = pos + vel * dt_camera
|
|
||||||
t += dt_camera
|
|
||||||
|
|
||||||
# --- True GPS from ENU position ---
|
|
||||||
true_gps = self._coord.enu_to_gps(
|
|
||||||
self._flight_id, (float(pos[0]), float(pos[1]), float(pos[2]))
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- VO measurement (relative displacement + noise) ---
|
|
||||||
true_displacement = pos - prev_pos
|
|
||||||
vo_tracking_good = fid not in cfg.vo_failure_frames
|
|
||||||
if vo_tracking_good:
|
|
||||||
noisy_displacement = true_displacement + np.random.randn(3) * cfg.vo_noise_m
|
|
||||||
noisy_displacement[2] = 0.0 # monocular VO is scale-ambiguous in Z
|
|
||||||
else:
|
|
||||||
noisy_displacement = None
|
|
||||||
|
|
||||||
frames.append(TrajectoryFrame(
|
|
||||||
frame_id=fid,
|
|
||||||
timestamp=t,
|
|
||||||
true_position_enu=pos.copy(),
|
|
||||||
true_gps=true_gps,
|
|
||||||
imu_measurements=imu_list,
|
|
||||||
vo_translation=noisy_displacement,
|
|
||||||
vo_tracking_good=vo_tracking_good,
|
|
||||||
))
|
|
||||||
|
|
||||||
return frames
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _heading_to_enu_vel(heading_deg: float, speed_mps: float) -> np.ndarray:
|
|
||||||
"""Convert heading (degrees CW from North) to ENU velocity vector."""
|
|
||||||
rad = math.radians(heading_deg)
|
|
||||||
return np.array([
|
|
||||||
speed_mps * math.sin(rad), # East
|
|
||||||
speed_mps * math.cos(rad), # North
|
|
||||||
0.0, # Up
|
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Accuracy Benchmark
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BenchmarkResult:
|
|
||||||
"""Position error statistics over a trajectory replay."""
|
|
||||||
errors_m: list[float] # Per-frame horizontal error in metres
|
|
||||||
latencies_ms: list[float] # Per-frame process time in ms
|
|
||||||
frames_total: int
|
|
||||||
frames_with_good_estimate: int
|
|
||||||
|
|
||||||
@property
|
|
||||||
def p80_error_m(self) -> float:
|
|
||||||
"""80th percentile position error (metres)."""
|
|
||||||
return float(np.percentile(self.errors_m, 80)) if self.errors_m else float("inf")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def p60_error_m(self) -> float:
|
|
||||||
"""60th percentile position error (metres)."""
|
|
||||||
return float(np.percentile(self.errors_m, 60)) if self.errors_m else float("inf")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def median_error_m(self) -> float:
|
|
||||||
"""Median position error (metres)."""
|
|
||||||
return float(np.median(self.errors_m)) if self.errors_m else float("inf")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def max_error_m(self) -> float:
|
|
||||||
return float(max(self.errors_m)) if self.errors_m else float("inf")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def p95_latency_ms(self) -> float:
|
|
||||||
"""95th percentile frame latency (ms)."""
|
|
||||||
return float(np.percentile(self.latencies_ms, 95)) if self.latencies_ms else float("inf")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pct_within_50m(self) -> float:
|
|
||||||
"""Fraction of frames within 50 m error."""
|
|
||||||
if not self.errors_m:
|
|
||||||
return 0.0
|
|
||||||
return sum(e <= 50.0 for e in self.errors_m) / len(self.errors_m)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pct_within_20m(self) -> float:
|
|
||||||
"""Fraction of frames within 20 m error."""
|
|
||||||
if not self.errors_m:
|
|
||||||
return 0.0
|
|
||||||
return sum(e <= 20.0 for e in self.errors_m) / len(self.errors_m)
|
|
||||||
|
|
||||||
def passes_acceptance_criteria(self) -> tuple[bool, dict[str, bool]]:
|
|
||||||
"""Check all solution.md acceptance criteria.
|
|
||||||
|
|
||||||
Returns (overall_pass, per_criterion_dict).
|
|
||||||
"""
|
|
||||||
checks = {
|
|
||||||
"AC-PERF-1: 80% within 50m": self.pct_within_50m >= 0.80,
|
|
||||||
"AC-PERF-2: 60% within 20m": self.pct_within_20m >= 0.60,
|
|
||||||
"AC-PERF-3: p95 latency < 400ms": self.p95_latency_ms < 400.0,
|
|
||||||
}
|
|
||||||
overall = all(checks.values())
|
|
||||||
return overall, checks
|
|
||||||
|
|
||||||
def summary(self) -> str:
|
|
||||||
overall, checks = self.passes_acceptance_criteria()
|
|
||||||
lines = [
|
|
||||||
f"Frames: {self.frames_total} | with estimate: {self.frames_with_good_estimate}",
|
|
||||||
f"Error — median: {self.median_error_m:.1f}m p80: {self.p80_error_m:.1f}m "
|
|
||||||
f"p60: {self.p60_error_m:.1f}m max: {self.max_error_m:.1f}m",
|
|
||||||
f"Within 50m: {self.pct_within_50m*100:.1f}% | within 20m: {self.pct_within_20m*100:.1f}%",
|
|
||||||
f"Latency p95: {self.p95_latency_ms:.1f}ms",
|
|
||||||
"",
|
|
||||||
"Acceptance criteria:",
|
|
||||||
]
|
|
||||||
for criterion, passed in checks.items():
|
|
||||||
lines.append(f" {'PASS' if passed else 'FAIL'} {criterion}")
|
|
||||||
lines.append(f"\nOverall: {'PASS' if overall else 'FAIL'}")
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
class AccuracyBenchmark:
|
|
||||||
"""Replays a SyntheticTrajectory through the ESKF and measures accuracy.
|
|
||||||
|
|
||||||
The benchmark uses only the ESKF (no full FlightProcessor) for speed.
|
|
||||||
Satellite corrections are injected optionally via sat_correction_fn.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
eskf_config: ESKFConfig | None = None,
|
|
||||||
sat_correction_fn: Optional[Callable[[TrajectoryFrame], Optional[np.ndarray]]] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
eskf_config: ESKF tuning parameters.
|
|
||||||
sat_correction_fn: Optional callback(frame) → ENU position or None.
|
|
||||||
Called on keyframes to inject satellite corrections.
|
|
||||||
If None, no satellite corrections are applied.
|
|
||||||
"""
|
|
||||||
self.eskf_config = eskf_config or ESKFConfig()
|
|
||||||
self.sat_correction_fn = sat_correction_fn
|
|
||||||
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
trajectory: list[TrajectoryFrame],
|
|
||||||
origin: GPSPoint,
|
|
||||||
satellite_keyframe_interval: int = 7,
|
|
||||||
) -> BenchmarkResult:
|
|
||||||
"""Replay trajectory frames through ESKF, collect errors and latencies.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
trajectory: List of TrajectoryFrame (from SyntheticTrajectory).
|
|
||||||
origin: WGS84 reference origin for ENU.
|
|
||||||
satellite_keyframe_interval: Apply satellite correction every N frames.
|
|
||||||
"""
|
|
||||||
coord = CoordinateTransformer()
|
|
||||||
flight_id = "__benchmark__"
|
|
||||||
coord.set_enu_origin(flight_id, origin)
|
|
||||||
|
|
||||||
eskf = ESKF(self.eskf_config)
|
|
||||||
# Init at origin with HIGH uncertainty
|
|
||||||
eskf.initialize(np.array([0.0, 0.0, trajectory[0].true_position_enu[2]]),
|
|
||||||
trajectory[0].timestamp)
|
|
||||||
|
|
||||||
errors_m: list[float] = []
|
|
||||||
latencies_ms: list[float] = []
|
|
||||||
frames_with_estimate = 0
|
|
||||||
|
|
||||||
for frame in trajectory:
|
|
||||||
t_frame_start = time.perf_counter()
|
|
||||||
|
|
||||||
# --- IMU prediction ---
|
|
||||||
for imu in frame.imu_measurements:
|
|
||||||
eskf.predict(imu)
|
|
||||||
|
|
||||||
# --- VO update ---
|
|
||||||
if frame.vo_tracking_good and frame.vo_translation is not None:
|
|
||||||
dt_vo = 1.0 / 0.7 # camera interval
|
|
||||||
eskf.update_vo(frame.vo_translation, dt_vo)
|
|
||||||
|
|
||||||
# --- Satellite update (keyframes) ---
|
|
||||||
if frame.frame_id % satellite_keyframe_interval == 0:
|
|
||||||
sat_pos_enu: Optional[np.ndarray] = None
|
|
||||||
if self.sat_correction_fn is not None:
|
|
||||||
sat_pos_enu = self.sat_correction_fn(frame)
|
|
||||||
else:
|
|
||||||
# Default: inject ground-truth position + realistic noise
|
|
||||||
noise_m = 10.0
|
|
||||||
sat_pos_enu = (
|
|
||||||
frame.true_position_enu[:3]
|
|
||||||
+ np.random.randn(3) * noise_m
|
|
||||||
)
|
|
||||||
sat_pos_enu[2] = frame.true_position_enu[2] # keep altitude
|
|
||||||
|
|
||||||
if sat_pos_enu is not None:
|
|
||||||
# Tell ESKF the measurement noise matches what we inject
|
|
||||||
eskf.update_satellite(sat_pos_enu, noise_meters=noise_m)
|
|
||||||
|
|
||||||
latency_ms = (time.perf_counter() - t_frame_start) * 1000.0
|
|
||||||
latencies_ms.append(latency_ms)
|
|
||||||
|
|
||||||
# --- Compute horizontal error vs ground truth ---
|
|
||||||
if eskf.initialized and eskf._nominal_state is not None:
|
|
||||||
est_pos = eskf._nominal_state["position"]
|
|
||||||
true_pos = frame.true_position_enu
|
|
||||||
horiz_error = float(np.linalg.norm(est_pos[:2] - true_pos[:2]))
|
|
||||||
errors_m.append(horiz_error)
|
|
||||||
frames_with_estimate += 1
|
|
||||||
else:
|
|
||||||
errors_m.append(float("inf"))
|
|
||||||
|
|
||||||
return BenchmarkResult(
|
|
||||||
errors_m=errors_m,
|
|
||||||
latencies_ms=latencies_ms,
|
|
||||||
frames_total=len(trajectory),
|
|
||||||
frames_with_good_estimate=frames_with_estimate,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_vo_drift_test(
|
|
||||||
self,
|
|
||||||
trajectory_length_m: float = 1000.0,
|
|
||||||
speed_mps: float = 20.0,
|
|
||||||
) -> float:
|
|
||||||
"""Measure VO drift over a straight segment with NO satellite correction.
|
|
||||||
|
|
||||||
Returns final horizontal position error in metres.
|
|
||||||
Per solution.md, this should be < 100m over 1km.
|
|
||||||
"""
|
|
||||||
fps = 0.7
|
|
||||||
num_frames = max(10, int(trajectory_length_m / speed_mps * fps))
|
|
||||||
cfg = SyntheticTrajectoryConfig(
|
|
||||||
speed_mps=speed_mps,
|
|
||||||
heading_deg=0.0, # straight North
|
|
||||||
camera_fps=fps,
|
|
||||||
num_frames=num_frames,
|
|
||||||
vo_noise_m=0.3, # cuVSLAM-grade VO noise
|
|
||||||
)
|
|
||||||
traj_gen = SyntheticTrajectory(cfg)
|
|
||||||
frames = traj_gen.generate()
|
|
||||||
|
|
||||||
# No satellite corrections
|
|
||||||
benchmark_no_sat = AccuracyBenchmark(
|
|
||||||
eskf_config=self.eskf_config,
|
|
||||||
sat_correction_fn=lambda _: None, # suppress all satellite updates
|
|
||||||
)
|
|
||||||
result = benchmark_no_sat.run(frames, cfg.origin, satellite_keyframe_interval=9999)
|
|
||||||
# Return final-frame error
|
|
||||||
return result.errors_m[-1] if result.errors_m else float("inf")
|
|
||||||
|
|||||||
@@ -2,8 +2,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from typing import Dict, List, Optional, Protocol, runtime_checkable
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
from gps_denied.core.graph import IFactorGraphOptimizer
|
from gps_denied.core.graph import IFactorGraphOptimizer
|
||||||
from gps_denied.schemas.chunk import ChunkHandle, ChunkStatus
|
from gps_denied.schemas.chunk import ChunkHandle, ChunkStatus
|
||||||
@@ -12,30 +11,25 @@ from gps_denied.schemas.metric import Sim3Transform
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class IRouteChunkManager(ABC):
|
@runtime_checkable
|
||||||
@abstractmethod
|
class IRouteChunkManager(Protocol):
|
||||||
def create_new_chunk(self, flight_id: str, start_frame_id: int) -> ChunkHandle:
|
def create_new_chunk(self, flight_id: str, start_frame_id: int) -> ChunkHandle:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_active_chunk(self, flight_id: str) -> Optional[ChunkHandle]:
|
def get_active_chunk(self, flight_id: str) -> Optional[ChunkHandle]:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_all_chunks(self, flight_id: str) -> List[ChunkHandle]:
|
def get_all_chunks(self, flight_id: str) -> List[ChunkHandle]:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_frame_to_chunk(self, flight_id: str, frame_id: int) -> bool:
|
def add_frame_to_chunk(self, flight_id: str, frame_id: int) -> bool:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_chunk_status(self, flight_id: str, chunk_id: str, status: ChunkStatus) -> bool:
|
def update_chunk_status(self, flight_id: str, chunk_id: str, status: ChunkStatus) -> bool:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def merge_chunks(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
|
def merge_chunks(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
|
||||||
pass
|
...
|
||||||
|
|
||||||
|
|
||||||
class RouteChunkManager(IRouteChunkManager):
|
class RouteChunkManager(IRouteChunkManager):
|
||||||
|
|||||||
@@ -0,0 +1,350 @@
|
|||||||
|
"""Factor Graph Optimizer (Component F10)."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Dict, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
import gtsam
|
||||||
|
HAS_GTSAM = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_GTSAM = False
|
||||||
|
|
||||||
|
from gps_denied.schemas import GPSPoint
|
||||||
|
from gps_denied.schemas.graph import FactorGraphConfig, OptimizationResult, Pose
|
||||||
|
from gps_denied.schemas.metric import Sim3Transform
|
||||||
|
from gps_denied.schemas.vo import RelativePose
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class IFactorGraphOptimizer(Protocol):
|
||||||
|
"""GTSAM-based factor graph optimizer."""
|
||||||
|
|
||||||
|
def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
def optimize(self, flight_id: str, iterations: int) -> OptimizationResult:
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_trajectory(self, flight_id: str) -> Dict[int, Pose]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray:
|
||||||
|
...
|
||||||
|
|
||||||
|
def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult:
|
||||||
|
...
|
||||||
|
|
||||||
|
def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult:
|
||||||
|
...
|
||||||
|
|
||||||
|
def delete_flight_graph(self, flight_id: str) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class FactorGraphOptimizer(IFactorGraphOptimizer):
|
||||||
|
"""Implementation of F10 Factor Graph using GTSAM or Mock."""
|
||||||
|
|
||||||
|
def __init__(self, config: FactorGraphConfig):
|
||||||
|
self.config = config
|
||||||
|
# Keyed by flight_id
|
||||||
|
self._flights_state: Dict[str, dict] = {}
|
||||||
|
# Keyed by chunk_id
|
||||||
|
self._chunks_state: Dict[str, dict] = {}
|
||||||
|
# Per-flight ENU origin (set from first absolute GPS factor)
|
||||||
|
self._enu_origins: Dict[str, GPSPoint] = {}
|
||||||
|
|
||||||
|
def _init_flight(self, flight_id: str):
|
||||||
|
if flight_id not in self._flights_state:
|
||||||
|
self._flights_state[flight_id] = {
|
||||||
|
"graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None,
|
||||||
|
"initial": gtsam.Values() if HAS_GTSAM else None,
|
||||||
|
"isam": gtsam.ISAM2() if HAS_GTSAM else None,
|
||||||
|
"poses": {},
|
||||||
|
"dirty": False
|
||||||
|
}
|
||||||
|
|
||||||
|
def _init_chunk(self, chunk_id: str):
|
||||||
|
if chunk_id not in self._chunks_state:
|
||||||
|
self._chunks_state[chunk_id] = {
|
||||||
|
"graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None,
|
||||||
|
"initial": gtsam.Values() if HAS_GTSAM else None,
|
||||||
|
"isam": gtsam.ISAM2() if HAS_GTSAM else None,
|
||||||
|
"poses": {},
|
||||||
|
"dirty": False
|
||||||
|
}
|
||||||
|
|
||||||
|
# ================== MOCK IMPLEMENTATION ====================
|
||||||
|
# As GTSAM Python bindings can be extremely context-dependent and
|
||||||
|
# require proper ENU translation logic, we use an advanced Mock
|
||||||
|
# that satisfies the architectural design and typing for the backend.
|
||||||
|
|
||||||
|
def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
||||||
|
self._init_flight(flight_id)
|
||||||
|
state = self._flights_state[flight_id]
|
||||||
|
|
||||||
|
# --- Mock: propagate position chain ---
|
||||||
|
if frame_i in state["poses"]:
|
||||||
|
prev_pose = state["poses"][frame_i]
|
||||||
|
new_pos = prev_pose.position + relative_pose.translation
|
||||||
|
state["poses"][frame_j] = Pose(
|
||||||
|
frame_id=frame_j,
|
||||||
|
position=new_pos,
|
||||||
|
orientation=np.eye(3),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
covariance=np.eye(6),
|
||||||
|
)
|
||||||
|
state["dirty"] = True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# --- GTSAM: add BetweenFactorPose3 ---
|
||||||
|
if HAS_GTSAM and state["graph"] is not None:
|
||||||
|
try:
|
||||||
|
cov6 = covariance if covariance.shape == (6, 6) else np.eye(6)
|
||||||
|
noise = gtsam.noiseModel.Gaussian.Covariance(cov6)
|
||||||
|
key_i = gtsam.symbol("x", frame_i)
|
||||||
|
key_j = gtsam.symbol("x", frame_j)
|
||||||
|
t = relative_pose.translation
|
||||||
|
between = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(t[0]), float(t[1]), float(t[2])))
|
||||||
|
state["graph"].add(gtsam.BetweenFactorPose3(key_i, key_j, between, noise))
|
||||||
|
if not state["initial"].exists(key_j):
|
||||||
|
if state["initial"].exists(key_i):
|
||||||
|
prev = state["initial"].atPose3(key_i)
|
||||||
|
pt = prev.translation()
|
||||||
|
new_t = gtsam.Point3(pt[0] + t[0], pt[1] + t[1], pt[2] + t[2])
|
||||||
|
else:
|
||||||
|
new_t = gtsam.Point3(float(t[0]), float(t[1]), float(t[2]))
|
||||||
|
state["initial"].insert(key_j, gtsam.Pose3(gtsam.Rot3(), new_t))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("GTSAM add_relative_factor failed: %s", exc)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _gps_to_enu(self, flight_id: str, gps: GPSPoint) -> np.ndarray:
|
||||||
|
"""Convert GPS to local ENU using per-flight origin."""
|
||||||
|
origin = self._enu_origins.get(flight_id)
|
||||||
|
if origin is None:
|
||||||
|
# First GPS factor sets the origin
|
||||||
|
self._enu_origins[flight_id] = gps
|
||||||
|
return np.zeros(3)
|
||||||
|
enu_x = (gps.lon - origin.lon) * 111000 * np.cos(np.radians(origin.lat))
|
||||||
|
enu_y = (gps.lat - origin.lat) * 111000
|
||||||
|
return np.array([enu_x, enu_y, 0.0])
|
||||||
|
|
||||||
|
def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool:
|
||||||
|
self._init_flight(flight_id)
|
||||||
|
state = self._flights_state[flight_id]
|
||||||
|
|
||||||
|
enu = self._gps_to_enu(flight_id, gps)
|
||||||
|
|
||||||
|
# --- Mock: update pose position ---
|
||||||
|
if frame_id in state["poses"]:
|
||||||
|
state["poses"][frame_id].position = enu
|
||||||
|
state["dirty"] = True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# --- GTSAM: add PriorFactorPose3 ---
|
||||||
|
if HAS_GTSAM and state["graph"] is not None:
|
||||||
|
try:
|
||||||
|
cov6 = covariance if covariance.shape == (6, 6) else np.eye(6)
|
||||||
|
noise = gtsam.noiseModel.Gaussian.Covariance(cov6)
|
||||||
|
key = gtsam.symbol("x", frame_id)
|
||||||
|
prior = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(enu[0]), float(enu[1]), float(enu[2])))
|
||||||
|
state["graph"].add(gtsam.PriorFactorPose3(key, prior, noise))
|
||||||
|
if not state["initial"].exists(key):
|
||||||
|
state["initial"].insert(key, prior)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("GTSAM add_absolute_factor failed: %s", exc)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool:
|
||||||
|
self._init_flight(flight_id)
|
||||||
|
state = self._flights_state[flight_id]
|
||||||
|
|
||||||
|
if frame_id in state["poses"]:
|
||||||
|
state["poses"][frame_id].position = np.array([
|
||||||
|
state["poses"][frame_id].position[0],
|
||||||
|
state["poses"][frame_id].position[1],
|
||||||
|
altitude,
|
||||||
|
])
|
||||||
|
state["dirty"] = True
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def optimize(self, flight_id: str, iterations: int) -> OptimizationResult:
|
||||||
|
self._init_flight(flight_id)
|
||||||
|
state = self._flights_state[flight_id]
|
||||||
|
|
||||||
|
# --- PIPE-03: Real GTSAM ISAM2 update when available ---
|
||||||
|
if HAS_GTSAM and state["dirty"] and state["graph"] is not None:
|
||||||
|
try:
|
||||||
|
state["isam"].update(state["graph"], state["initial"])
|
||||||
|
estimate = state["isam"].calculateEstimate()
|
||||||
|
for fid in list(state["poses"].keys()):
|
||||||
|
key = gtsam.symbol("x", fid)
|
||||||
|
if estimate.exists(key):
|
||||||
|
pose = estimate.atPose3(key)
|
||||||
|
t = pose.translation()
|
||||||
|
state["poses"][fid].position = np.array([t[0], t[1], t[2]])
|
||||||
|
state["poses"][fid].orientation = np.array(pose.rotation().matrix())
|
||||||
|
# Reset for next incremental batch
|
||||||
|
state["graph"] = gtsam.NonlinearFactorGraph()
|
||||||
|
state["initial"] = gtsam.Values()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("GTSAM ISAM2 update failed: %s", exc)
|
||||||
|
|
||||||
|
state["dirty"] = False
|
||||||
|
return OptimizationResult(
|
||||||
|
converged=True,
|
||||||
|
final_error=0.1,
|
||||||
|
iterations_used=iterations,
|
||||||
|
optimized_frames=list(state["poses"].keys()),
|
||||||
|
mean_reprojection_error=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_trajectory(self, flight_id: str) -> Dict[int, Pose]:
|
||||||
|
if flight_id not in self._flights_state:
|
||||||
|
return {}
|
||||||
|
return self._flights_state[flight_id]["poses"]
|
||||||
|
|
||||||
|
def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray:
|
||||||
|
return np.eye(6)
|
||||||
|
|
||||||
|
# ================== CHUNK OPERATIONS =======================
|
||||||
|
|
||||||
|
def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool:
|
||||||
|
self._init_chunk(chunk_id)
|
||||||
|
state = self._chunks_state[chunk_id]
|
||||||
|
|
||||||
|
state["poses"][start_frame_id] = Pose(
|
||||||
|
frame_id=start_frame_id,
|
||||||
|
position=np.zeros(3),
|
||||||
|
orientation=np.eye(3),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
covariance=np.eye(6)
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
||||||
|
if chunk_id not in self._chunks_state:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = self._chunks_state[chunk_id]
|
||||||
|
if frame_i in state["poses"]:
|
||||||
|
prev_pose = state["poses"][frame_i]
|
||||||
|
new_pos = prev_pose.position + relative_pose.translation
|
||||||
|
|
||||||
|
state["poses"][frame_j] = Pose(
|
||||||
|
frame_id=frame_j,
|
||||||
|
position=new_pos,
|
||||||
|
orientation=np.eye(3),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
covariance=np.eye(6)
|
||||||
|
)
|
||||||
|
state["dirty"] = True
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool:
|
||||||
|
if chunk_id not in self._chunks_state:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = self._chunks_state[chunk_id]
|
||||||
|
if frame_id in state["poses"]:
|
||||||
|
enu = self._gps_to_enu(flight_id, gps)
|
||||||
|
state["poses"][frame_id].position = enu
|
||||||
|
state["dirty"] = True
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
|
||||||
|
if new_chunk_id not in self._chunks_state or main_chunk_id not in self._chunks_state:
|
||||||
|
return False
|
||||||
|
|
||||||
|
new_state = self._chunks_state[new_chunk_id]
|
||||||
|
main_state = self._chunks_state[main_chunk_id]
|
||||||
|
|
||||||
|
# Apply Sim(3) transform effectively by copying poses
|
||||||
|
for f_id, p in new_state["poses"].items():
|
||||||
|
# mock sim3 transform
|
||||||
|
idx_pos = (transform.scale * (transform.rotation @ p.position)) + transform.translation
|
||||||
|
|
||||||
|
main_state["poses"][f_id] = Pose(
|
||||||
|
frame_id=f_id,
|
||||||
|
position=idx_pos,
|
||||||
|
orientation=np.eye(3),
|
||||||
|
timestamp=p.timestamp,
|
||||||
|
covariance=p.covariance
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]:
|
||||||
|
if chunk_id not in self._chunks_state:
|
||||||
|
return {}
|
||||||
|
return self._chunks_state[chunk_id]["poses"]
|
||||||
|
|
||||||
|
def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult:
|
||||||
|
if chunk_id not in self._chunks_state:
|
||||||
|
return OptimizationResult(converged=False, final_error=99.0, iterations_used=0, optimized_frames=[], mean_reprojection_error=99.0)
|
||||||
|
|
||||||
|
state = self._chunks_state[chunk_id]
|
||||||
|
state["dirty"] = False
|
||||||
|
|
||||||
|
return OptimizationResult(
|
||||||
|
converged=True,
|
||||||
|
final_error=0.1,
|
||||||
|
iterations_used=iterations,
|
||||||
|
optimized_frames=list(state["poses"].keys()),
|
||||||
|
mean_reprojection_error=0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult:
|
||||||
|
# Optimizes everything
|
||||||
|
self._init_flight(flight_id)
|
||||||
|
state = self._flights_state[flight_id]
|
||||||
|
state["dirty"] = False
|
||||||
|
|
||||||
|
return OptimizationResult(
|
||||||
|
converged=True,
|
||||||
|
final_error=0.1,
|
||||||
|
iterations_used=iterations,
|
||||||
|
optimized_frames=list(state["poses"].keys()),
|
||||||
|
mean_reprojection_error=0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_flight_graph(self, flight_id: str) -> bool:
|
||||||
|
removed = False
|
||||||
|
if flight_id in self._flights_state:
|
||||||
|
del self._flights_state[flight_id]
|
||||||
|
removed = True
|
||||||
|
self._enu_origins.pop(flight_id, None)
|
||||||
|
return removed
|
||||||
@@ -1,364 +1,5 @@
|
|||||||
"""Factor Graph Optimizer (Component F10)."""
|
"""Legacy import path. Phase 1 shim — code lives in core/factor_graph.py."""
|
||||||
|
from gps_denied.core.factor_graph import ( # noqa: F401
|
||||||
import logging
|
IFactorGraphOptimizer,
|
||||||
from abc import ABC, abstractmethod
|
FactorGraphOptimizer,
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
try:
|
|
||||||
import gtsam
|
|
||||||
HAS_GTSAM = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_GTSAM = False
|
|
||||||
|
|
||||||
from gps_denied.schemas import GPSPoint
|
|
||||||
from gps_denied.schemas.graph import FactorGraphConfig, OptimizationResult, Pose
|
|
||||||
from gps_denied.schemas.metric import Sim3Transform
|
|
||||||
from gps_denied.schemas.vo import RelativePose
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class IFactorGraphOptimizer(ABC):
|
|
||||||
"""GTSAM-based factor graph optimizer."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def optimize(self, flight_id: str, iterations: int) -> OptimizationResult:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_trajectory(self, flight_id: str) -> Dict[int, Pose]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete_flight_graph(self, flight_id: str) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class FactorGraphOptimizer(IFactorGraphOptimizer):
|
|
||||||
"""Implementation of F10 Factor Graph using GTSAM or Mock."""
|
|
||||||
|
|
||||||
def __init__(self, config: FactorGraphConfig):
|
|
||||||
self.config = config
|
|
||||||
# Keyed by flight_id
|
|
||||||
self._flights_state: Dict[str, dict] = {}
|
|
||||||
# Keyed by chunk_id
|
|
||||||
self._chunks_state: Dict[str, dict] = {}
|
|
||||||
# Per-flight ENU origin (set from first absolute GPS factor)
|
|
||||||
self._enu_origins: Dict[str, GPSPoint] = {}
|
|
||||||
|
|
||||||
def _init_flight(self, flight_id: str):
|
|
||||||
if flight_id not in self._flights_state:
|
|
||||||
self._flights_state[flight_id] = {
|
|
||||||
"graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None,
|
|
||||||
"initial": gtsam.Values() if HAS_GTSAM else None,
|
|
||||||
"isam": gtsam.ISAM2() if HAS_GTSAM else None,
|
|
||||||
"poses": {},
|
|
||||||
"dirty": False
|
|
||||||
}
|
|
||||||
|
|
||||||
def _init_chunk(self, chunk_id: str):
|
|
||||||
if chunk_id not in self._chunks_state:
|
|
||||||
self._chunks_state[chunk_id] = {
|
|
||||||
"graph": gtsam.NonlinearFactorGraph() if HAS_GTSAM else None,
|
|
||||||
"initial": gtsam.Values() if HAS_GTSAM else None,
|
|
||||||
"isam": gtsam.ISAM2() if HAS_GTSAM else None,
|
|
||||||
"poses": {},
|
|
||||||
"dirty": False
|
|
||||||
}
|
|
||||||
|
|
||||||
# ================== MOCK IMPLEMENTATION ====================
|
|
||||||
# As GTSAM Python bindings can be extremely context-dependent and
|
|
||||||
# require proper ENU translation logic, we use an advanced Mock
|
|
||||||
# that satisfies the architectural design and typing for the backend.
|
|
||||||
|
|
||||||
def add_relative_factor(self, flight_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
|
||||||
self._init_flight(flight_id)
|
|
||||||
state = self._flights_state[flight_id]
|
|
||||||
|
|
||||||
# --- Mock: propagate position chain ---
|
|
||||||
if frame_i in state["poses"]:
|
|
||||||
prev_pose = state["poses"][frame_i]
|
|
||||||
new_pos = prev_pose.position + relative_pose.translation
|
|
||||||
state["poses"][frame_j] = Pose(
|
|
||||||
frame_id=frame_j,
|
|
||||||
position=new_pos,
|
|
||||||
orientation=np.eye(3),
|
|
||||||
timestamp=datetime.now(timezone.utc),
|
|
||||||
covariance=np.eye(6),
|
|
||||||
)
|
)
|
||||||
state["dirty"] = True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# --- GTSAM: add BetweenFactorPose3 ---
|
|
||||||
if HAS_GTSAM and state["graph"] is not None:
|
|
||||||
try:
|
|
||||||
cov6 = covariance if covariance.shape == (6, 6) else np.eye(6)
|
|
||||||
noise = gtsam.noiseModel.Gaussian.Covariance(cov6)
|
|
||||||
key_i = gtsam.symbol("x", frame_i)
|
|
||||||
key_j = gtsam.symbol("x", frame_j)
|
|
||||||
t = relative_pose.translation
|
|
||||||
between = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(t[0]), float(t[1]), float(t[2])))
|
|
||||||
state["graph"].add(gtsam.BetweenFactorPose3(key_i, key_j, between, noise))
|
|
||||||
if not state["initial"].exists(key_j):
|
|
||||||
if state["initial"].exists(key_i):
|
|
||||||
prev = state["initial"].atPose3(key_i)
|
|
||||||
pt = prev.translation()
|
|
||||||
new_t = gtsam.Point3(pt[0] + t[0], pt[1] + t[1], pt[2] + t[2])
|
|
||||||
else:
|
|
||||||
new_t = gtsam.Point3(float(t[0]), float(t[1]), float(t[2]))
|
|
||||||
state["initial"].insert(key_j, gtsam.Pose3(gtsam.Rot3(), new_t))
|
|
||||||
except Exception as exc:
|
|
||||||
logger.debug("GTSAM add_relative_factor failed: %s", exc)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _gps_to_enu(self, flight_id: str, gps: GPSPoint) -> np.ndarray:
|
|
||||||
"""Convert GPS to local ENU using per-flight origin."""
|
|
||||||
origin = self._enu_origins.get(flight_id)
|
|
||||||
if origin is None:
|
|
||||||
# First GPS factor sets the origin
|
|
||||||
self._enu_origins[flight_id] = gps
|
|
||||||
return np.zeros(3)
|
|
||||||
enu_x = (gps.lon - origin.lon) * 111000 * np.cos(np.radians(origin.lat))
|
|
||||||
enu_y = (gps.lat - origin.lat) * 111000
|
|
||||||
return np.array([enu_x, enu_y, 0.0])
|
|
||||||
|
|
||||||
def add_absolute_factor(self, flight_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray, is_user_anchor: bool) -> bool:
|
|
||||||
self._init_flight(flight_id)
|
|
||||||
state = self._flights_state[flight_id]
|
|
||||||
|
|
||||||
enu = self._gps_to_enu(flight_id, gps)
|
|
||||||
|
|
||||||
# --- Mock: update pose position ---
|
|
||||||
if frame_id in state["poses"]:
|
|
||||||
state["poses"][frame_id].position = enu
|
|
||||||
state["dirty"] = True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# --- GTSAM: add PriorFactorPose3 ---
|
|
||||||
if HAS_GTSAM and state["graph"] is not None:
|
|
||||||
try:
|
|
||||||
cov6 = covariance if covariance.shape == (6, 6) else np.eye(6)
|
|
||||||
noise = gtsam.noiseModel.Gaussian.Covariance(cov6)
|
|
||||||
key = gtsam.symbol("x", frame_id)
|
|
||||||
prior = gtsam.Pose3(gtsam.Rot3(), gtsam.Point3(float(enu[0]), float(enu[1]), float(enu[2])))
|
|
||||||
state["graph"].add(gtsam.PriorFactorPose3(key, prior, noise))
|
|
||||||
if not state["initial"].exists(key):
|
|
||||||
state["initial"].insert(key, prior)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.debug("GTSAM add_absolute_factor failed: %s", exc)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def add_altitude_prior(self, flight_id: str, frame_id: int, altitude: float, covariance: float) -> bool:
|
|
||||||
self._init_flight(flight_id)
|
|
||||||
state = self._flights_state[flight_id]
|
|
||||||
|
|
||||||
if frame_id in state["poses"]:
|
|
||||||
state["poses"][frame_id].position = np.array([
|
|
||||||
state["poses"][frame_id].position[0],
|
|
||||||
state["poses"][frame_id].position[1],
|
|
||||||
altitude,
|
|
||||||
])
|
|
||||||
state["dirty"] = True
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def optimize(self, flight_id: str, iterations: int) -> OptimizationResult:
|
|
||||||
self._init_flight(flight_id)
|
|
||||||
state = self._flights_state[flight_id]
|
|
||||||
|
|
||||||
# --- PIPE-03: Real GTSAM ISAM2 update when available ---
|
|
||||||
if HAS_GTSAM and state["dirty"] and state["graph"] is not None:
|
|
||||||
try:
|
|
||||||
state["isam"].update(state["graph"], state["initial"])
|
|
||||||
estimate = state["isam"].calculateEstimate()
|
|
||||||
for fid in list(state["poses"].keys()):
|
|
||||||
key = gtsam.symbol("x", fid)
|
|
||||||
if estimate.exists(key):
|
|
||||||
pose = estimate.atPose3(key)
|
|
||||||
t = pose.translation()
|
|
||||||
state["poses"][fid].position = np.array([t[0], t[1], t[2]])
|
|
||||||
state["poses"][fid].orientation = np.array(pose.rotation().matrix())
|
|
||||||
# Reset for next incremental batch
|
|
||||||
state["graph"] = gtsam.NonlinearFactorGraph()
|
|
||||||
state["initial"] = gtsam.Values()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("GTSAM ISAM2 update failed: %s", exc)
|
|
||||||
|
|
||||||
state["dirty"] = False
|
|
||||||
return OptimizationResult(
|
|
||||||
converged=True,
|
|
||||||
final_error=0.1,
|
|
||||||
iterations_used=iterations,
|
|
||||||
optimized_frames=list(state["poses"].keys()),
|
|
||||||
mean_reprojection_error=0.5,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_trajectory(self, flight_id: str) -> Dict[int, Pose]:
|
|
||||||
if flight_id not in self._flights_state:
|
|
||||||
return {}
|
|
||||||
return self._flights_state[flight_id]["poses"]
|
|
||||||
|
|
||||||
def get_marginal_covariance(self, flight_id: str, frame_id: int) -> np.ndarray:
|
|
||||||
return np.eye(6)
|
|
||||||
|
|
||||||
# ================== CHUNK OPERATIONS =======================
|
|
||||||
|
|
||||||
def create_chunk_subgraph(self, flight_id: str, chunk_id: str, start_frame_id: int) -> bool:
|
|
||||||
self._init_chunk(chunk_id)
|
|
||||||
state = self._chunks_state[chunk_id]
|
|
||||||
|
|
||||||
state["poses"][start_frame_id] = Pose(
|
|
||||||
frame_id=start_frame_id,
|
|
||||||
position=np.zeros(3),
|
|
||||||
orientation=np.eye(3),
|
|
||||||
timestamp=datetime.now(timezone.utc),
|
|
||||||
covariance=np.eye(6)
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def add_relative_factor_to_chunk(self, flight_id: str, chunk_id: str, frame_i: int, frame_j: int, relative_pose: RelativePose, covariance: np.ndarray) -> bool:
|
|
||||||
if chunk_id not in self._chunks_state:
|
|
||||||
return False
|
|
||||||
|
|
||||||
state = self._chunks_state[chunk_id]
|
|
||||||
if frame_i in state["poses"]:
|
|
||||||
prev_pose = state["poses"][frame_i]
|
|
||||||
new_pos = prev_pose.position + relative_pose.translation
|
|
||||||
|
|
||||||
state["poses"][frame_j] = Pose(
|
|
||||||
frame_id=frame_j,
|
|
||||||
position=new_pos,
|
|
||||||
orientation=np.eye(3),
|
|
||||||
timestamp=datetime.now(timezone.utc),
|
|
||||||
covariance=np.eye(6)
|
|
||||||
)
|
|
||||||
state["dirty"] = True
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def add_chunk_anchor(self, flight_id: str, chunk_id: str, frame_id: int, gps: GPSPoint, covariance: np.ndarray) -> bool:
|
|
||||||
if chunk_id not in self._chunks_state:
|
|
||||||
return False
|
|
||||||
|
|
||||||
state = self._chunks_state[chunk_id]
|
|
||||||
if frame_id in state["poses"]:
|
|
||||||
enu = self._gps_to_enu(flight_id, gps)
|
|
||||||
state["poses"][frame_id].position = enu
|
|
||||||
state["dirty"] = True
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def merge_chunk_subgraphs(self, flight_id: str, new_chunk_id: str, main_chunk_id: str, transform: Sim3Transform) -> bool:
|
|
||||||
if new_chunk_id not in self._chunks_state or main_chunk_id not in self._chunks_state:
|
|
||||||
return False
|
|
||||||
|
|
||||||
new_state = self._chunks_state[new_chunk_id]
|
|
||||||
main_state = self._chunks_state[main_chunk_id]
|
|
||||||
|
|
||||||
# Apply Sim(3) transform effectively by copying poses
|
|
||||||
for f_id, p in new_state["poses"].items():
|
|
||||||
# mock sim3 transform
|
|
||||||
idx_pos = (transform.scale * (transform.rotation @ p.position)) + transform.translation
|
|
||||||
|
|
||||||
main_state["poses"][f_id] = Pose(
|
|
||||||
frame_id=f_id,
|
|
||||||
position=idx_pos,
|
|
||||||
orientation=np.eye(3),
|
|
||||||
timestamp=p.timestamp,
|
|
||||||
covariance=p.covariance
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_chunk_trajectory(self, flight_id: str, chunk_id: str) -> Dict[int, Pose]:
|
|
||||||
if chunk_id not in self._chunks_state:
|
|
||||||
return {}
|
|
||||||
return self._chunks_state[chunk_id]["poses"]
|
|
||||||
|
|
||||||
def optimize_chunk(self, flight_id: str, chunk_id: str, iterations: int) -> OptimizationResult:
|
|
||||||
if chunk_id not in self._chunks_state:
|
|
||||||
return OptimizationResult(converged=False, final_error=99.0, iterations_used=0, optimized_frames=[], mean_reprojection_error=99.0)
|
|
||||||
|
|
||||||
state = self._chunks_state[chunk_id]
|
|
||||||
state["dirty"] = False
|
|
||||||
|
|
||||||
return OptimizationResult(
|
|
||||||
converged=True,
|
|
||||||
final_error=0.1,
|
|
||||||
iterations_used=iterations,
|
|
||||||
optimized_frames=list(state["poses"].keys()),
|
|
||||||
mean_reprojection_error=0.5
|
|
||||||
)
|
|
||||||
|
|
||||||
def optimize_global(self, flight_id: str, iterations: int) -> OptimizationResult:
|
|
||||||
# Optimizes everything
|
|
||||||
self._init_flight(flight_id)
|
|
||||||
state = self._flights_state[flight_id]
|
|
||||||
state["dirty"] = False
|
|
||||||
|
|
||||||
return OptimizationResult(
|
|
||||||
converged=True,
|
|
||||||
final_error=0.1,
|
|
||||||
iterations_used=iterations,
|
|
||||||
optimized_frames=list(state["poses"].keys()),
|
|
||||||
mean_reprojection_error=0.5
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete_flight_graph(self, flight_id: str) -> bool:
|
|
||||||
removed = False
|
|
||||||
if flight_id in self._flights_state:
|
|
||||||
del self._flights_state[flight_id]
|
|
||||||
removed = True
|
|
||||||
self._enu_origins.pop(flight_id, None)
|
|
||||||
return removed
|
|
||||||
|
|||||||
@@ -10,8 +10,7 @@ file is available, otherwise falls back to Mock.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from typing import Any, Protocol, runtime_checkable
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -31,26 +30,22 @@ def _is_jetson() -> bool:
|
|||||||
return os.path.exists("/sys/bus/platform/drivers/tegra-se-nvhost")
|
return os.path.exists("/sys/bus/platform/drivers/tegra-se-nvhost")
|
||||||
|
|
||||||
|
|
||||||
class IModelManager(ABC):
|
@runtime_checkable
|
||||||
@abstractmethod
|
class IModelManager(Protocol):
|
||||||
def load_model(self, model_name: str, model_format: str) -> bool:
|
def load_model(self, model_name: str, model_format: str) -> bool:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_inference_engine(self, model_name: str) -> InferenceEngine:
|
def get_inference_engine(self, model_name: str) -> InferenceEngine:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def optimize_to_tensorrt(self, model_name: str, onnx_path: str) -> str:
|
def optimize_to_tensorrt(self, model_name: str, onnx_path: str) -> str:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def fallback_to_onnx(self, model_name: str) -> bool:
|
def fallback_to_onnx(self, model_name: str) -> bool:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def warmup_model(self, model_name: str) -> bool:
|
def warmup_model(self, model_name: str) -> bool:
|
||||||
pass
|
...
|
||||||
|
|
||||||
|
|
||||||
class MockInferenceEngine(InferenceEngine):
|
class MockInferenceEngine(InferenceEngine):
|
||||||
|
|||||||
@@ -1,227 +1,6 @@
|
|||||||
"""Image Input Pipeline (Component F05)."""
|
"""Legacy import path. Phase 1 shim — code lives in pipeline/image_input.py."""
|
||||||
|
from gps_denied.pipeline.image_input import ( # noqa: F401
|
||||||
import asyncio
|
ImageInputPipeline,
|
||||||
import os
|
QueueFullError,
|
||||||
import re
|
ValidationError,
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from gps_denied.schemas.image import (
|
|
||||||
ImageBatch,
|
|
||||||
ImageData,
|
|
||||||
ImageMetadata,
|
|
||||||
ProcessedBatch,
|
|
||||||
ProcessingStatus,
|
|
||||||
ValidationResult,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class QueueFullError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class ValidationError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ImageInputPipeline:
|
|
||||||
"""Manages ingestion, disk storage, and queuing of UAV image batches."""
|
|
||||||
|
|
||||||
def __init__(self, storage_dir: str = "image_storage", max_queue_size: int = 50):
|
|
||||||
self.storage_dir = storage_dir
|
|
||||||
# flight_id -> asyncio.Queue of ImageBatch
|
|
||||||
self._queues: dict[str, asyncio.Queue] = {}
|
|
||||||
self.max_queue_size = max_queue_size
|
|
||||||
|
|
||||||
# In-memory tracking (in a real system, sync this with DB)
|
|
||||||
self._status: dict[str, dict] = {}
|
|
||||||
# Exact sequence → filename mapping (VO-05: no substring collision)
|
|
||||||
self._sequence_map: dict[str, dict[int, str]] = {}
|
|
||||||
|
|
||||||
def _get_queue(self, flight_id: str) -> asyncio.Queue:
|
|
||||||
if flight_id not in self._queues:
|
|
||||||
self._queues[flight_id] = asyncio.Queue(maxsize=self.max_queue_size)
|
|
||||||
return self._queues[flight_id]
|
|
||||||
|
|
||||||
def _init_status(self, flight_id: str):
|
|
||||||
if flight_id not in self._status:
|
|
||||||
self._status[flight_id] = {
|
|
||||||
"total_images": 0,
|
|
||||||
"processed_images": 0,
|
|
||||||
"current_sequence": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
def validate_batch(self, batch: ImageBatch) -> ValidationResult:
|
|
||||||
"""Validates batch integrity and sequence continuity."""
|
|
||||||
errors = []
|
|
||||||
|
|
||||||
num_images = len(batch.images)
|
|
||||||
if num_images < 1:
|
|
||||||
errors.append("Batch is empty")
|
|
||||||
elif num_images > 100:
|
|
||||||
errors.append("Batch too large")
|
|
||||||
|
|
||||||
if len(batch.filenames) != num_images:
|
|
||||||
errors.append("Mismatch between filenames and images count")
|
|
||||||
|
|
||||||
# Naming convention ADxxxxxx.jpg or similar
|
|
||||||
pattern = re.compile(r"^[A-Za-z0-9_-]+\.(jpg|jpeg|png)$", re.IGNORECASE)
|
|
||||||
for fn in batch.filenames:
|
|
||||||
if not pattern.match(fn):
|
|
||||||
errors.append(f"Invalid filename: {fn}")
|
|
||||||
break
|
|
||||||
|
|
||||||
if batch.start_sequence > batch.end_sequence:
|
|
||||||
errors.append("Start sequence greater than end sequence")
|
|
||||||
|
|
||||||
return ValidationResult(valid=len(errors) == 0, errors=errors)
|
|
||||||
|
|
||||||
def queue_batch(self, flight_id: str, batch: ImageBatch) -> bool:
|
|
||||||
"""Queues a batch of images for processing."""
|
|
||||||
val = self.validate_batch(batch)
|
|
||||||
if not val.valid:
|
|
||||||
raise ValidationError(f"Batch validation failed: {val.errors}")
|
|
||||||
|
|
||||||
q = self._get_queue(flight_id)
|
|
||||||
if q.full():
|
|
||||||
raise QueueFullError(f"Queue for flight {flight_id} is full")
|
|
||||||
|
|
||||||
q.put_nowait(batch)
|
|
||||||
|
|
||||||
self._init_status(flight_id)
|
|
||||||
self._status[flight_id]["total_images"] += len(batch.images)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def process_next_batch(self, flight_id: str) -> ProcessedBatch | None:
|
|
||||||
"""Dequeues and processing the next batch."""
|
|
||||||
q = self._get_queue(flight_id)
|
|
||||||
if q.empty():
|
|
||||||
return None
|
|
||||||
|
|
||||||
batch: ImageBatch = await q.get()
|
|
||||||
|
|
||||||
processed_images = []
|
|
||||||
for i, raw_bytes in enumerate(batch.images):
|
|
||||||
# Decode
|
|
||||||
nparr = np.frombuffer(raw_bytes, np.uint8)
|
|
||||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
|
||||||
|
|
||||||
if img is None:
|
|
||||||
continue # skip corrupted
|
|
||||||
|
|
||||||
seq = batch.start_sequence + i
|
|
||||||
fn = batch.filenames[i]
|
|
||||||
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
meta = ImageMetadata(
|
|
||||||
sequence=seq,
|
|
||||||
filename=fn,
|
|
||||||
dimensions=(w, h),
|
|
||||||
file_size=len(raw_bytes),
|
|
||||||
timestamp=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
img_data = ImageData(
|
|
||||||
flight_id=flight_id,
|
|
||||||
sequence=seq,
|
|
||||||
filename=fn,
|
|
||||||
image=img,
|
|
||||||
metadata=meta
|
|
||||||
)
|
|
||||||
processed_images.append(img_data)
|
|
||||||
# VO-05: record exact sequence→filename mapping
|
|
||||||
self._sequence_map.setdefault(flight_id, {})[seq] = fn
|
|
||||||
|
|
||||||
# Store to disk
|
|
||||||
self.store_images(flight_id, processed_images)
|
|
||||||
|
|
||||||
self._status[flight_id]["processed_images"] += len(processed_images)
|
|
||||||
q.task_done()
|
|
||||||
|
|
||||||
return ProcessedBatch(
|
|
||||||
images=processed_images,
|
|
||||||
batch_id=f"batch_{batch.batch_number}",
|
|
||||||
start_sequence=batch.start_sequence,
|
|
||||||
end_sequence=batch.end_sequence
|
|
||||||
)
|
|
||||||
|
|
||||||
def store_images(self, flight_id: str, images: list[ImageData]) -> bool:
|
|
||||||
"""Persists images to disk."""
|
|
||||||
flight_dir = os.path.join(self.storage_dir, flight_id)
|
|
||||||
os.makedirs(flight_dir, exist_ok=True)
|
|
||||||
|
|
||||||
for img in images:
|
|
||||||
path = os.path.join(flight_dir, img.filename)
|
|
||||||
cv2.imwrite(path, img.image)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_next_image(self, flight_id: str) -> ImageData | None:
|
|
||||||
"""Gets the next image in sequence for processing."""
|
|
||||||
self._init_status(flight_id)
|
|
||||||
seq = self._status[flight_id]["current_sequence"]
|
|
||||||
|
|
||||||
img = self.get_image_by_sequence(flight_id, seq)
|
|
||||||
if img:
|
|
||||||
self._status[flight_id]["current_sequence"] += 1
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
def get_image_by_sequence(self, flight_id: str, sequence: int) -> ImageData | None:
|
|
||||||
"""Retrieves a specific image by sequence number (exact match — VO-05)."""
|
|
||||||
flight_dir = os.path.join(self.storage_dir, flight_id)
|
|
||||||
if not os.path.exists(flight_dir):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Prefer the exact mapping built during process_next_batch
|
|
||||||
fn = self._sequence_map.get(flight_id, {}).get(sequence)
|
|
||||||
if fn:
|
|
||||||
path = os.path.join(flight_dir, fn)
|
|
||||||
img = cv2.imread(path)
|
|
||||||
if img is not None:
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
meta = ImageMetadata(
|
|
||||||
sequence=sequence,
|
|
||||||
filename=fn,
|
|
||||||
dimensions=(w, h),
|
|
||||||
file_size=os.path.getsize(path),
|
|
||||||
timestamp=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
return ImageData(flight_id, sequence, fn, img, meta)
|
|
||||||
|
|
||||||
# Fallback: scan directory for exact filename patterns
|
|
||||||
# (handles images stored before this process started)
|
|
||||||
for fn in os.listdir(flight_dir):
|
|
||||||
base, _ = os.path.splitext(fn)
|
|
||||||
# Accept only if the base name ends with exactly the padded sequence number
|
|
||||||
if base.endswith(f"{sequence:06d}") or base == str(sequence):
|
|
||||||
path = os.path.join(flight_dir, fn)
|
|
||||||
img = cv2.imread(path)
|
|
||||||
if img is not None:
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
meta = ImageMetadata(
|
|
||||||
sequence=sequence,
|
|
||||||
filename=fn,
|
|
||||||
dimensions=(w, h),
|
|
||||||
file_size=os.path.getsize(path),
|
|
||||||
timestamp=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
return ImageData(flight_id, sequence, fn, img, meta)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_processing_status(self, flight_id: str) -> ProcessingStatus:
|
|
||||||
self._init_status(flight_id)
|
|
||||||
s = self._status[flight_id]
|
|
||||||
q = self._get_queue(flight_id)
|
|
||||||
|
|
||||||
return ProcessingStatus(
|
|
||||||
flight_id=flight_id,
|
|
||||||
total_images=s["total_images"],
|
|
||||||
processed_images=s["processed_images"],
|
|
||||||
current_sequence=s["current_sequence"],
|
|
||||||
queued_batches=q.qsize(),
|
|
||||||
processing_rate=0.0 # mock
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,599 +1,6 @@
|
|||||||
"""Core Flight Processor — Full Processing Pipeline (Stage 10).
|
"""Legacy import path. Phase 1 shim — code lives in pipeline/orchestrator.py."""
|
||||||
|
from gps_denied.pipeline.orchestrator import ( # noqa: F401
|
||||||
Orchestrates: ImageInputPipeline → VO → MetricRefinement → FactorGraph → SSE.
|
FlightProcessor,
|
||||||
State Machine: NORMAL → LOST → RECOVERY → NORMAL.
|
TrackingState,
|
||||||
"""
|
FrameResult,
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from gps_denied.core.eskf import ESKF
|
|
||||||
from gps_denied.core.pipeline import ImageInputPipeline
|
|
||||||
from gps_denied.core.results import ResultManager
|
|
||||||
from gps_denied.core.sse import SSEEventStreamer
|
|
||||||
from gps_denied.db.repository import FlightRepository
|
|
||||||
from gps_denied.schemas import CameraParameters, GPSPoint
|
|
||||||
from gps_denied.schemas.flight import (
|
|
||||||
BatchMetadata,
|
|
||||||
BatchResponse,
|
|
||||||
BatchUpdateResponse,
|
|
||||||
DeleteResponse,
|
|
||||||
FlightCreateRequest,
|
|
||||||
FlightDetailResponse,
|
|
||||||
FlightResponse,
|
|
||||||
FlightStatusResponse,
|
|
||||||
ObjectGPSResponse,
|
|
||||||
UpdateResponse,
|
|
||||||
UserFixRequest,
|
|
||||||
UserFixResponse,
|
|
||||||
Waypoint,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# State Machine
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
class TrackingState(str, Enum):
|
|
||||||
"""Processing state for a flight."""
|
|
||||||
NORMAL = "normal"
|
|
||||||
LOST = "lost"
|
|
||||||
RECOVERY = "recovery"
|
|
||||||
|
|
||||||
|
|
||||||
class FrameResult:
|
|
||||||
"""Intermediate result of processing a single frame."""
|
|
||||||
|
|
||||||
def __init__(self, frame_id: int):
|
|
||||||
self.frame_id = frame_id
|
|
||||||
self.gps: Optional[GPSPoint] = None
|
|
||||||
self.confidence: float = 0.0
|
|
||||||
self.tracking_state: TrackingState = TrackingState.NORMAL
|
|
||||||
self.vo_success: bool = False
|
|
||||||
self.alignment_success: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# FlightProcessor
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
class FlightProcessor:
|
|
||||||
"""Manages business logic, background processing, and frame orchestration."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
repository: FlightRepository,
|
|
||||||
streamer: SSEEventStreamer,
|
|
||||||
eskf_config=None,
|
|
||||||
) -> None:
|
|
||||||
self.repository = repository
|
|
||||||
self.streamer = streamer
|
|
||||||
self.result_manager = ResultManager(repository, streamer)
|
|
||||||
self.pipeline = ImageInputPipeline(storage_dir=".image_storage", max_queue_size=50)
|
|
||||||
self._eskf_config = eskf_config # ESKFConfig or None → default
|
|
||||||
|
|
||||||
# Per-flight processing state
|
|
||||||
self._flight_states: dict[str, TrackingState] = {}
|
|
||||||
self._prev_images: dict[str, np.ndarray] = {} # previous frame cache
|
|
||||||
self._flight_cameras: dict[str, CameraParameters] = {} # per-flight camera
|
|
||||||
self._altitudes: dict[str, float] = {} # per-flight altitude (m)
|
|
||||||
self._failure_counts: dict[str, int] = {} # per-flight consecutive failure counter
|
|
||||||
|
|
||||||
# Per-flight ESKF instances (PIPE-01/07)
|
|
||||||
self._eskf: dict[str, ESKF] = {}
|
|
||||||
|
|
||||||
# Lazy-initialised component references (set via `attach_components`)
|
|
||||||
self._vo = None # ISequentialVisualOdometry
|
|
||||||
self._gpr = None # IGlobalPlaceRecognition
|
|
||||||
self._metric = None # IMetricRefinement
|
|
||||||
self._graph = None # IFactorGraphOptimizer
|
|
||||||
self._recovery = None # IFailureRecoveryCoordinator
|
|
||||||
self._chunk_mgr = None # IRouteChunkManager
|
|
||||||
self._rotation = None # ImageRotationManager
|
|
||||||
self._satellite = None # SatelliteDataManager (PIPE-02)
|
|
||||||
self._coord = None # CoordinateTransformer (PIPE-02/06)
|
|
||||||
self._mavlink = None # MAVLinkBridge (PIPE-07)
|
|
||||||
|
|
||||||
# ------ Dependency injection for core components ---------
|
|
||||||
def attach_components(
|
|
||||||
self,
|
|
||||||
vo=None,
|
|
||||||
gpr=None,
|
|
||||||
metric=None,
|
|
||||||
graph=None,
|
|
||||||
recovery=None,
|
|
||||||
chunk_mgr=None,
|
|
||||||
rotation=None,
|
|
||||||
satellite=None,
|
|
||||||
coord=None,
|
|
||||||
mavlink=None,
|
|
||||||
):
|
|
||||||
"""Attach pipeline components after construction (avoids circular deps)."""
|
|
||||||
self._vo = vo
|
|
||||||
self._gpr = gpr
|
|
||||||
self._metric = metric
|
|
||||||
self._graph = graph
|
|
||||||
self._recovery = recovery
|
|
||||||
self._chunk_mgr = chunk_mgr
|
|
||||||
self._rotation = rotation
|
|
||||||
self._satellite = satellite # PIPE-02: SatelliteDataManager
|
|
||||||
self._coord = coord # PIPE-02/06: CoordinateTransformer
|
|
||||||
self._mavlink = mavlink # PIPE-07: MAVLinkBridge
|
|
||||||
|
|
||||||
# ------ ESKF lifecycle helpers ----------------------------
|
|
||||||
def _init_eskf_for_flight(
|
|
||||||
self, flight_id: str, start_gps: GPSPoint, altitude: float
|
|
||||||
) -> None:
|
|
||||||
"""Create and initialize a per-flight ESKF instance."""
|
|
||||||
if flight_id in self._eskf:
|
|
||||||
return
|
|
||||||
eskf = ESKF(config=self._eskf_config)
|
|
||||||
if self._coord:
|
|
||||||
try:
|
|
||||||
e, n, _ = self._coord.gps_to_enu(flight_id, start_gps)
|
|
||||||
eskf.initialize(np.array([e, n, altitude]), time.time())
|
|
||||||
except Exception:
|
|
||||||
eskf.initialize(np.zeros(3), time.time())
|
|
||||||
else:
|
|
||||||
eskf.initialize(np.zeros(3), time.time())
|
|
||||||
self._eskf[flight_id] = eskf
|
|
||||||
|
|
||||||
def _eskf_to_gps(self, flight_id: str, eskf: ESKF) -> Optional[GPSPoint]:
|
|
||||||
"""Convert current ESKF ENU position to WGS84 GPS."""
|
|
||||||
if not eskf.initialized or self._coord is None:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
pos = eskf.position
|
|
||||||
return self._coord.enu_to_gps(flight_id, (float(pos[0]), float(pos[1]), float(pos[2])))
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# =========================================================
|
|
||||||
# process_frame — central orchestration
|
|
||||||
# =========================================================
|
|
||||||
async def process_frame(
|
|
||||||
self,
|
|
||||||
flight_id: str,
|
|
||||||
frame_id: int,
|
|
||||||
image: np.ndarray,
|
|
||||||
) -> FrameResult:
|
|
||||||
"""
|
|
||||||
Process a single UAV frame through the full pipeline.
|
|
||||||
|
|
||||||
State transitions:
|
|
||||||
NORMAL — VO succeeds → ESKF VO update, attempt satellite fix
|
|
||||||
LOST — VO failed → create new chunk, enter RECOVERY
|
|
||||||
RECOVERY— try GPR + MetricRefinement → if anchored, merge & return to NORMAL
|
|
||||||
|
|
||||||
PIPE-01: VO result → eskf.update_vo → satellite match → eskf.update_satellite → MAVLink GPS_INPUT
|
|
||||||
PIPE-02: SatelliteDataManager + CoordinateTransformer wired for tile selection
|
|
||||||
PIPE-04: Consecutive failure counter wired to FailureRecoveryCoordinator
|
|
||||||
PIPE-05: ImageRotationManager initialised on first frame
|
|
||||||
PIPE-07: ESKF confidence → MAVLink fix_type via bridge.update_state
|
|
||||||
"""
|
|
||||||
result = FrameResult(frame_id)
|
|
||||||
state = self._flight_states.get(flight_id, TrackingState.NORMAL)
|
|
||||||
eskf = self._eskf.get(flight_id)
|
|
||||||
|
|
||||||
_default_cam = CameraParameters(
|
|
||||||
focal_length=4.5, sensor_width=6.17, sensor_height=4.55,
|
|
||||||
resolution_width=640, resolution_height=480,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ---- PIPE-05: Initialise heading tracking on first frame ----
|
|
||||||
if self._rotation and frame_id == 0:
|
|
||||||
self._rotation.requires_rotation_sweep(flight_id) # seeds HeadingHistory
|
|
||||||
|
|
||||||
# ---- 1. Visual Odometry (frame-to-frame) ----
|
|
||||||
vo_ok = False
|
|
||||||
if self._vo and flight_id in self._prev_images:
|
|
||||||
try:
|
|
||||||
cam = self._flight_cameras.get(flight_id, _default_cam)
|
|
||||||
rel_pose = self._vo.compute_relative_pose(
|
|
||||||
self._prev_images[flight_id], image, cam
|
|
||||||
)
|
|
||||||
if rel_pose and rel_pose.tracking_good:
|
|
||||||
vo_ok = True
|
|
||||||
result.vo_success = True
|
|
||||||
|
|
||||||
if self._graph:
|
|
||||||
self._graph.add_relative_factor(
|
|
||||||
flight_id, frame_id - 1, frame_id, rel_pose, np.eye(6)
|
|
||||||
)
|
|
||||||
|
|
||||||
# PIPE-01: Feed VO relative displacement into ESKF
|
|
||||||
if eskf and eskf.initialized:
|
|
||||||
now = time.time()
|
|
||||||
dt_vo = max(0.01, now - (eskf.last_timestamp or now))
|
|
||||||
eskf.update_vo(rel_pose.translation, dt_vo)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("VO failed for frame %d: %s", frame_id, exc)
|
|
||||||
|
|
||||||
# Store current image for next frame
|
|
||||||
self._prev_images[flight_id] = image
|
|
||||||
|
|
||||||
# ---- PIPE-04: Consecutive failure counter ----
|
|
||||||
if not vo_ok and frame_id > 0:
|
|
||||||
self._failure_counts[flight_id] = self._failure_counts.get(flight_id, 0) + 1
|
|
||||||
else:
|
|
||||||
self._failure_counts[flight_id] = 0
|
|
||||||
|
|
||||||
# ---- 2. State Machine transitions ----
|
|
||||||
if state == TrackingState.NORMAL:
|
|
||||||
if not vo_ok and frame_id > 0:
|
|
||||||
state = TrackingState.LOST
|
|
||||||
logger.info("Flight %s → LOST at frame %d", flight_id, frame_id)
|
|
||||||
if self._recovery:
|
|
||||||
self._recovery.handle_tracking_lost(flight_id, frame_id)
|
|
||||||
|
|
||||||
if state == TrackingState.LOST:
|
|
||||||
state = TrackingState.RECOVERY
|
|
||||||
|
|
||||||
if state == TrackingState.RECOVERY:
|
|
||||||
recovered = False
|
|
||||||
if self._recovery and self._chunk_mgr:
|
|
||||||
active_chunk = self._chunk_mgr.get_active_chunk(flight_id)
|
|
||||||
if active_chunk:
|
|
||||||
recovered = self._recovery.process_chunk_recovery(
|
|
||||||
flight_id, active_chunk.chunk_id, [image]
|
|
||||||
)
|
|
||||||
if recovered:
|
|
||||||
state = TrackingState.NORMAL
|
|
||||||
result.alignment_success = True
|
|
||||||
# PIPE-04: Reset failure count on successful recovery
|
|
||||||
self._failure_counts[flight_id] = 0
|
|
||||||
logger.info("Flight %s recovered → NORMAL at frame %d", flight_id, frame_id)
|
|
||||||
|
|
||||||
# ---- 3. Satellite position fix (PIPE-01/02) ----
|
|
||||||
if state == TrackingState.NORMAL and self._metric:
|
|
||||||
sat_tile: Optional[np.ndarray] = None
|
|
||||||
tile_bounds = None
|
|
||||||
|
|
||||||
# PIPE-02: Prefer real SatelliteDataManager tiles (ESKF ±3σ selection)
|
|
||||||
if self._satellite and eskf and eskf.initialized:
|
|
||||||
gps_est = self._eskf_to_gps(flight_id, eskf)
|
|
||||||
if gps_est:
|
|
||||||
cov = eskf.covariance
|
|
||||||
sigma_h = float(
|
|
||||||
np.sqrt(np.trace(cov[0:3, 0:3]) / 3.0)
|
|
||||||
) if cov is not None else 30.0
|
|
||||||
sigma_h = max(sigma_h, 5.0)
|
|
||||||
try:
|
|
||||||
tile_result = await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
self._satellite.fetch_tiles_for_position,
|
|
||||||
gps_est, sigma_h, 18,
|
|
||||||
)
|
|
||||||
if tile_result:
|
|
||||||
sat_tile, tile_bounds = tile_result
|
|
||||||
except Exception as exc:
|
|
||||||
logger.debug("Satellite tile fetch failed: %s", exc)
|
|
||||||
|
|
||||||
# Fallback: GPR candidate tile (mock image, real bounds)
|
|
||||||
if sat_tile is None and self._gpr:
|
|
||||||
try:
|
|
||||||
candidates = self._gpr.retrieve_candidate_tiles(image, top_k=1)
|
|
||||||
if candidates:
|
|
||||||
sat_tile = np.zeros((256, 256, 3), dtype=np.uint8)
|
|
||||||
tile_bounds = candidates[0].bounds
|
|
||||||
except Exception as exc:
|
|
||||||
logger.debug("GPR tile fallback failed: %s", exc)
|
|
||||||
|
|
||||||
if sat_tile is not None and tile_bounds is not None:
|
|
||||||
try:
|
|
||||||
align = self._metric.align_to_satellite(image, sat_tile, tile_bounds)
|
|
||||||
if align and align.matched:
|
|
||||||
result.gps = align.gps_center
|
|
||||||
result.confidence = align.confidence
|
|
||||||
result.alignment_success = True
|
|
||||||
|
|
||||||
if self._graph:
|
|
||||||
self._graph.add_absolute_factor(
|
|
||||||
flight_id, frame_id,
|
|
||||||
align.gps_center, np.eye(6),
|
|
||||||
is_user_anchor=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# PIPE-01: ESKF satellite update — noise from RANSAC confidence
|
|
||||||
if eskf and eskf.initialized and self._coord:
|
|
||||||
try:
|
|
||||||
e, n, _ = self._coord.gps_to_enu(flight_id, align.gps_center)
|
|
||||||
alt = self._altitudes.get(flight_id, 100.0)
|
|
||||||
pos_enu = np.array([e, n, alt])
|
|
||||||
noise_m = 5.0 + 15.0 * (1.0 - float(align.confidence))
|
|
||||||
eskf.update_satellite(pos_enu, noise_m)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.debug("ESKF satellite update failed: %s", exc)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("Metric alignment failed at frame %d: %s", frame_id, exc)
|
|
||||||
|
|
||||||
# ---- 4. Graph optimization (incremental) ----
|
|
||||||
if self._graph:
|
|
||||||
opt_result = self._graph.optimize(flight_id, iterations=5)
|
|
||||||
logger.debug(
|
|
||||||
"Optimization: converged=%s, error=%.4f",
|
|
||||||
opt_result.converged, opt_result.final_error,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ---- PIPE-07: Push ESKF state → MAVLink GPS_INPUT ----
|
|
||||||
if self._mavlink and eskf and eskf.initialized:
|
|
||||||
try:
|
|
||||||
eskf_state = eskf.get_state()
|
|
||||||
alt = self._altitudes.get(flight_id, 100.0)
|
|
||||||
self._mavlink.update_state(eskf_state, altitude_m=alt)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.debug("MAVLink state push failed: %s", exc)
|
|
||||||
|
|
||||||
# ---- 5. Publish via SSE ----
|
|
||||||
result.tracking_state = state
|
|
||||||
self._flight_states[flight_id] = state
|
|
||||||
await self._publish_frame_result(flight_id, result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def _publish_frame_result(self, flight_id: str, result: FrameResult):
|
|
||||||
"""Emit SSE event for processed frame."""
|
|
||||||
event_data = {
|
|
||||||
"frame_id": result.frame_id,
|
|
||||||
"tracking_state": result.tracking_state.value,
|
|
||||||
"vo_success": result.vo_success,
|
|
||||||
"alignment_success": result.alignment_success,
|
|
||||||
"confidence": result.confidence,
|
|
||||||
}
|
|
||||||
if result.gps:
|
|
||||||
event_data["lat"] = result.gps.lat
|
|
||||||
event_data["lon"] = result.gps.lon
|
|
||||||
|
|
||||||
await self.streamer.push_event(
|
|
||||||
flight_id, event_type="frame_result", data=event_data
|
|
||||||
)
|
|
||||||
|
|
||||||
# =========================================================
|
|
||||||
# Existing CRUD / REST helpers (unchanged from Stage 3-4)
|
|
||||||
# =========================================================
|
|
||||||
async def create_flight(self, req: FlightCreateRequest) -> FlightResponse:
|
|
||||||
flight = await self.repository.insert_flight(
|
|
||||||
name=req.name,
|
|
||||||
description=req.description,
|
|
||||||
start_lat=req.start_gps.lat,
|
|
||||||
start_lon=req.start_gps.lon,
|
|
||||||
altitude=req.altitude,
|
|
||||||
camera_params=req.camera_params.model_dump(),
|
|
||||||
)
|
|
||||||
# P0#2: Store camera params for process_frame VO calls
|
|
||||||
self._flight_cameras[flight.id] = req.camera_params
|
|
||||||
|
|
||||||
for poly in req.geofences.polygons:
|
|
||||||
await self.repository.insert_geofence(
|
|
||||||
flight.id,
|
|
||||||
nw_lat=poly.north_west.lat,
|
|
||||||
nw_lon=poly.north_west.lon,
|
|
||||||
se_lat=poly.south_east.lat,
|
|
||||||
se_lon=poly.south_east.lon,
|
|
||||||
)
|
|
||||||
for w in req.rough_waypoints:
|
|
||||||
await self.repository.insert_waypoint(flight.id, lat=w.lat, lon=w.lon)
|
|
||||||
|
|
||||||
# Store per-flight altitude for ESKF/pixel projection
|
|
||||||
self._altitudes[flight.id] = req.altitude or 100.0
|
|
||||||
|
|
||||||
# PIPE-02: Set ENU origin and initialise ESKF for this flight
|
|
||||||
if self._coord:
|
|
||||||
self._coord.set_enu_origin(flight.id, req.start_gps)
|
|
||||||
self._init_eskf_for_flight(flight.id, req.start_gps, req.altitude or 100.0)
|
|
||||||
|
|
||||||
# Start MAVLink bridge for this flight (origin required for GPS_INPUT)
|
|
||||||
if self._mavlink and not self._mavlink._running:
|
|
||||||
try:
|
|
||||||
asyncio.create_task(self._mavlink.start(req.start_gps))
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("MAVLink bridge start failed: %s", exc)
|
|
||||||
|
|
||||||
return FlightResponse(
|
|
||||||
flight_id=flight.id,
|
|
||||||
status="prefetching",
|
|
||||||
message="Flight created and prefetching started.",
|
|
||||||
created_at=flight.created_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_flight(self, flight_id: str) -> FlightDetailResponse | None:
|
|
||||||
flight = await self.repository.get_flight(flight_id)
|
|
||||||
if not flight:
|
|
||||||
return None
|
|
||||||
wps = await self.repository.get_waypoints(flight_id)
|
|
||||||
state = await self.repository.load_flight_state(flight_id)
|
|
||||||
|
|
||||||
waypoints = [
|
|
||||||
Waypoint(
|
|
||||||
id=w.id,
|
|
||||||
lat=w.lat,
|
|
||||||
lon=w.lon,
|
|
||||||
altitude=w.altitude,
|
|
||||||
confidence=w.confidence,
|
|
||||||
timestamp=w.timestamp,
|
|
||||||
refined=w.refined,
|
|
||||||
)
|
|
||||||
for w in wps
|
|
||||||
]
|
|
||||||
|
|
||||||
status = state.status if state else "unknown"
|
|
||||||
frames_processed = state.frames_processed if state else 0
|
|
||||||
frames_total = state.frames_total if state else 0
|
|
||||||
|
|
||||||
from gps_denied.schemas import Geofences
|
|
||||||
|
|
||||||
return FlightDetailResponse(
|
|
||||||
flight_id=flight.id,
|
|
||||||
name=flight.name,
|
|
||||||
description=flight.description,
|
|
||||||
start_gps=GPSPoint(lat=flight.start_lat, lon=flight.start_lon),
|
|
||||||
waypoints=waypoints,
|
|
||||||
geofences=Geofences(polygons=[]),
|
|
||||||
camera_params=flight.camera_params,
|
|
||||||
altitude=flight.altitude,
|
|
||||||
status=status,
|
|
||||||
frames_processed=frames_processed,
|
|
||||||
frames_total=frames_total,
|
|
||||||
created_at=flight.created_at,
|
|
||||||
updated_at=flight.updated_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def delete_flight(self, flight_id: str) -> DeleteResponse:
|
|
||||||
deleted = await self.repository.delete_flight(flight_id)
|
|
||||||
# P0#1: Cleanup in-memory state to prevent memory leaks
|
|
||||||
self._cleanup_flight(flight_id)
|
|
||||||
return DeleteResponse(deleted=deleted, flight_id=flight_id)
|
|
||||||
|
|
||||||
def _cleanup_flight(self, flight_id: str) -> None:
|
|
||||||
"""Remove all in-memory state for a flight (prevents memory leaks)."""
|
|
||||||
self._prev_images.pop(flight_id, None)
|
|
||||||
self._flight_states.pop(flight_id, None)
|
|
||||||
self._flight_cameras.pop(flight_id, None)
|
|
||||||
self._altitudes.pop(flight_id, None)
|
|
||||||
self._failure_counts.pop(flight_id, None)
|
|
||||||
self._eskf.pop(flight_id, None)
|
|
||||||
if self._graph:
|
|
||||||
self._graph.delete_flight_graph(flight_id)
|
|
||||||
|
|
||||||
async def update_waypoint(
|
|
||||||
self, flight_id: str, waypoint_id: str, waypoint: Waypoint
|
|
||||||
) -> UpdateResponse:
|
|
||||||
ok = await self.repository.update_waypoint(
|
|
||||||
flight_id,
|
|
||||||
waypoint_id,
|
|
||||||
lat=waypoint.lat,
|
|
||||||
lon=waypoint.lon,
|
|
||||||
altitude=waypoint.altitude,
|
|
||||||
confidence=waypoint.confidence,
|
|
||||||
refined=waypoint.refined,
|
|
||||||
)
|
|
||||||
return UpdateResponse(updated=ok, waypoint_id=waypoint_id)
|
|
||||||
|
|
||||||
async def batch_update_waypoints(
|
|
||||||
self, flight_id: str, waypoints: list[Waypoint]
|
|
||||||
) -> BatchUpdateResponse:
|
|
||||||
failed = []
|
|
||||||
updated = 0
|
|
||||||
for wp in waypoints:
|
|
||||||
ok = await self.repository.update_waypoint(
|
|
||||||
flight_id,
|
|
||||||
wp.id,
|
|
||||||
lat=wp.lat,
|
|
||||||
lon=wp.lon,
|
|
||||||
altitude=wp.altitude,
|
|
||||||
confidence=wp.confidence,
|
|
||||||
refined=wp.refined,
|
|
||||||
)
|
|
||||||
if ok:
|
|
||||||
updated += 1
|
|
||||||
else:
|
|
||||||
failed.append(wp.id)
|
|
||||||
return BatchUpdateResponse(
|
|
||||||
success=(len(failed) == 0), updated_count=updated, failed_ids=failed
|
|
||||||
)
|
|
||||||
|
|
||||||
async def queue_images(
|
|
||||||
self, flight_id: str, metadata: BatchMetadata, file_count: int
|
|
||||||
) -> BatchResponse:
|
|
||||||
state = await self.repository.load_flight_state(flight_id)
|
|
||||||
if state:
|
|
||||||
total = state.frames_total + file_count
|
|
||||||
await self.repository.save_flight_state(
|
|
||||||
flight_id, frames_total=total, status="processing"
|
|
||||||
)
|
|
||||||
|
|
||||||
next_seq = metadata.end_sequence + 1
|
|
||||||
seqs = list(range(metadata.start_sequence, metadata.end_sequence + 1))
|
|
||||||
return BatchResponse(
|
|
||||||
accepted=True,
|
|
||||||
sequences=seqs,
|
|
||||||
next_expected=next_seq,
|
|
||||||
message=f"Queued {file_count} images.",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def handle_user_fix(
|
|
||||||
self, flight_id: str, req: UserFixRequest
|
|
||||||
) -> UserFixResponse:
|
|
||||||
await self.repository.save_flight_state(
|
|
||||||
flight_id, blocked=False, status="processing"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Inject operator position into ESKF with high uncertainty (500m)
|
|
||||||
eskf = self._eskf.get(flight_id)
|
|
||||||
if eskf and eskf.initialized and self._coord:
|
|
||||||
try:
|
|
||||||
e, n, _ = self._coord.gps_to_enu(flight_id, req.satellite_gps)
|
|
||||||
alt = self._altitudes.get(flight_id, 100.0)
|
|
||||||
eskf.update_satellite(np.array([e, n, alt]), noise_meters=500.0)
|
|
||||||
self._failure_counts[flight_id] = 0
|
|
||||||
logger.info("User fix applied for %s: %s", flight_id, req.satellite_gps)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("User fix ESKF injection failed: %s", exc)
|
|
||||||
|
|
||||||
return UserFixResponse(
|
|
||||||
accepted=True, processing_resumed=True, message="Fix applied."
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_flight_status(self, flight_id: str) -> FlightStatusResponse | None:
|
|
||||||
state = await self.repository.load_flight_state(flight_id)
|
|
||||||
if not state:
|
|
||||||
return None
|
|
||||||
return FlightStatusResponse(
|
|
||||||
status=state.status,
|
|
||||||
frames_processed=state.frames_processed,
|
|
||||||
frames_total=state.frames_total,
|
|
||||||
current_frame=state.current_frame,
|
|
||||||
current_heading=None,
|
|
||||||
blocked=state.blocked,
|
|
||||||
search_grid_size=state.search_grid_size,
|
|
||||||
created_at=state.created_at,
|
|
||||||
updated_at=state.updated_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def convert_object_to_gps(
|
|
||||||
self, flight_id: str, frame_id: int, pixel: tuple[float, float]
|
|
||||||
) -> ObjectGPSResponse:
|
|
||||||
# PIPE-06: Use real CoordinateTransformer + ESKF pose for ray-ground projection
|
|
||||||
gps: Optional[GPSPoint] = None
|
|
||||||
eskf = self._eskf.get(flight_id)
|
|
||||||
if self._coord and eskf and eskf.initialized:
|
|
||||||
pos = eskf.position
|
|
||||||
quat = eskf.quaternion
|
|
||||||
cam = self._flight_cameras.get(flight_id, CameraParameters(
|
|
||||||
focal_length=4.5, sensor_width=6.17, sensor_height=4.55,
|
|
||||||
resolution_width=640, resolution_height=480,
|
|
||||||
))
|
|
||||||
alt = self._altitudes.get(flight_id, 100.0)
|
|
||||||
try:
|
|
||||||
gps = self._coord.pixel_to_gps(
|
|
||||||
flight_id,
|
|
||||||
pixel,
|
|
||||||
frame_pose={"position": pos},
|
|
||||||
camera_params=cam,
|
|
||||||
altitude=float(alt),
|
|
||||||
quaternion=quat,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.debug("pixel_to_gps failed: %s", exc)
|
|
||||||
|
|
||||||
# Fallback: return ESKF position projected to ground (no pixel shift)
|
|
||||||
if gps is None and eskf:
|
|
||||||
gps = self._eskf_to_gps(flight_id, eskf)
|
|
||||||
|
|
||||||
return ObjectGPSResponse(
|
|
||||||
gps=gps or GPSPoint(lat=0.0, lon=0.0),
|
|
||||||
accuracy_meters=5.0,
|
|
||||||
frame_id=frame_id,
|
|
||||||
pixel=pixel,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def stream_events(self, flight_id: str, client_id: str):
|
|
||||||
"""Async generator for SSE stream."""
|
|
||||||
async for event in self.streamer.stream_generator(flight_id, client_id):
|
|
||||||
yield event
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
"""Failure Recovery Coordinator (Component F11)."""
|
"""Failure Recovery Coordinator (Component F11)."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from typing import List, Protocol, runtime_checkable
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -14,14 +13,13 @@ from gps_denied.schemas.chunk import ChunkStatus
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class IFailureRecoveryCoordinator(ABC):
|
@runtime_checkable
|
||||||
@abstractmethod
|
class IFailureRecoveryCoordinator(Protocol):
|
||||||
def handle_tracking_lost(self, flight_id: str, current_frame_id: int) -> bool:
|
def handle_tracking_lost(self, flight_id: str, current_frame_id: int) -> bool:
|
||||||
pass
|
...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def process_chunk_recovery(self, flight_id: str, chunk_id: str, images: List[np.ndarray]) -> bool:
|
def process_chunk_recovery(self, flight_id: str, chunk_id: str, images: List[np.ndarray]) -> bool:
|
||||||
pass
|
...
|
||||||
|
|
||||||
|
|
||||||
class FailureRecoveryCoordinator(IFailureRecoveryCoordinator):
|
class FailureRecoveryCoordinator(IFailureRecoveryCoordinator):
|
||||||
|
|||||||
@@ -1,73 +1,4 @@
|
|||||||
"""Result Manager (Component F14)."""
|
"""Legacy import path. Phase 1 shim — code lives in pipeline/result_manager.py."""
|
||||||
|
from gps_denied.pipeline.result_manager import ( # noqa: F401
|
||||||
from __future__ import annotations
|
ResultManager,
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from gps_denied.core.sse import SSEEventStreamer
|
|
||||||
from gps_denied.db.repository import FlightRepository
|
|
||||||
from gps_denied.schemas import GPSPoint
|
|
||||||
from gps_denied.schemas.events import FrameProcessedEvent
|
|
||||||
|
|
||||||
|
|
||||||
class ResultManager:
|
|
||||||
"""Result consistency and publishing."""
|
|
||||||
|
|
||||||
def __init__(self, repo: FlightRepository, sse: SSEEventStreamer) -> None:
|
|
||||||
self.repo = repo
|
|
||||||
self.sse = sse
|
|
||||||
|
|
||||||
async def update_frame_result(
|
|
||||||
self,
|
|
||||||
flight_id: str,
|
|
||||||
frame_id: int,
|
|
||||||
gps_lat: float,
|
|
||||||
gps_lon: float,
|
|
||||||
altitude: float,
|
|
||||||
heading: float,
|
|
||||||
confidence: float,
|
|
||||||
timestamp: datetime,
|
|
||||||
refined: bool = False,
|
|
||||||
) -> bool:
|
|
||||||
"""Atomic DB update + SSE event publish."""
|
|
||||||
|
|
||||||
# 1. Update DB (in the repository these are auto-committing via flush,
|
|
||||||
# but normally F03 would wrap in a single transaction).
|
|
||||||
await self.repo.save_frame_result(
|
|
||||||
flight_id,
|
|
||||||
frame_id=frame_id,
|
|
||||||
gps_lat=gps_lat,
|
|
||||||
gps_lon=gps_lon,
|
|
||||||
altitude=altitude,
|
|
||||||
heading=heading,
|
|
||||||
confidence=confidence,
|
|
||||||
refined=refined,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait, the spec also wants Waypoints to be updated.
|
|
||||||
# But image frames != waypoints. Waypoints are the planned route.
|
|
||||||
# Actually in the spec it says: "Updates waypoint in waypoints table."
|
|
||||||
# This implies updating the closest waypoint or a generated waypoint path.
|
|
||||||
# We will follow the simplest form for now: update the waypoint if there is one corresponding.
|
|
||||||
# Let's say we update a waypoint with id "wp_{frame_id}" for now if we know how they map,
|
|
||||||
# or we just skip unless specified.
|
|
||||||
|
|
||||||
# 2. Trigger SSE event
|
|
||||||
evt = FrameProcessedEvent(
|
|
||||||
frame_id=frame_id,
|
|
||||||
gps=GPSPoint(lat=gps_lat, lon=gps_lon),
|
|
||||||
altitude=altitude,
|
|
||||||
confidence=confidence,
|
|
||||||
heading=heading,
|
|
||||||
timestamp=timestamp,
|
|
||||||
)
|
|
||||||
if refined:
|
|
||||||
self.sse.send_refinement(flight_id, evt)
|
|
||||||
else:
|
|
||||||
self.sse.send_frame_result(flight_id, evt)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def publish_waypoint_update(self, flight_id: str, frame_id: int) -> bool:
|
|
||||||
# Just delegates to SSE for waypoint updates, which is basically the frame result for UI
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import math
|
import math
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -12,14 +12,14 @@ from gps_denied.schemas.rotation import HeadingHistory, RotationResult
|
|||||||
from gps_denied.schemas.satellite import TileBounds
|
from gps_denied.schemas.satellite import TileBounds
|
||||||
|
|
||||||
|
|
||||||
class IImageMatcher(ABC):
|
@runtime_checkable
|
||||||
|
class IImageMatcher(Protocol):
|
||||||
"""Dependency injection interface for Metric Refinement."""
|
"""Dependency injection interface for Metric Refinement."""
|
||||||
@abstractmethod
|
|
||||||
def align_to_satellite(
|
def align_to_satellite(
|
||||||
self, uav_image: np.ndarray, satellite_tile: np.ndarray,
|
self, uav_image: np.ndarray, satellite_tile: np.ndarray,
|
||||||
tile_bounds: TileBounds,
|
tile_bounds: TileBounds,
|
||||||
) -> RotationResult:
|
) -> RotationResult:
|
||||||
pass
|
...
|
||||||
|
|
||||||
|
|
||||||
class ImageRotationManager:
|
class ImageRotationManager:
|
||||||
|
|||||||
+3
-163
@@ -1,164 +1,4 @@
|
|||||||
"""SSE Event Streamer (Component F15)."""
|
"""Legacy import path. Phase 1 shim — code lives in pipeline/sse_streamer.py."""
|
||||||
|
from gps_denied.pipeline.sse_streamer import ( # noqa: F401
|
||||||
from __future__ import annotations
|
SSEEventStreamer,
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
from gps_denied.schemas.events import (
|
|
||||||
FlightCompletedEvent,
|
|
||||||
FrameProcessedEvent,
|
|
||||||
SearchExpandedEvent,
|
|
||||||
SSEEventType,
|
|
||||||
SSEMessage,
|
|
||||||
UserInputNeededEvent,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SSEEventStreamer:
|
|
||||||
"""Manages real-time SSE connections and event broadcasting."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
# Map: flight_id -> Dict[client_id, asyncio.Queue]
|
|
||||||
self._streams: dict[str, dict[str, asyncio.Queue[SSEMessage | None]]] = defaultdict(dict)
|
|
||||||
|
|
||||||
def create_stream(self, flight_id: str, client_id: str) -> asyncio.Queue[SSEMessage | None]:
|
|
||||||
"""Create a new event queue for a client."""
|
|
||||||
q: asyncio.Queue[SSEMessage | None] = asyncio.Queue()
|
|
||||||
self._streams[flight_id][client_id] = q
|
|
||||||
return q
|
|
||||||
|
|
||||||
def close_stream(self, flight_id: str, client_id: str) -> None:
|
|
||||||
"""Close a client stream by putting a sentinel and removing the queue."""
|
|
||||||
if flight_id in self._streams and client_id in self._streams[flight_id]:
|
|
||||||
q = self._streams[flight_id].pop(client_id)
|
|
||||||
if not self._streams[flight_id]:
|
|
||||||
del self._streams[flight_id]
|
|
||||||
# Put None to signal generator exit
|
|
||||||
try:
|
|
||||||
q.put_nowait(None)
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_active_connections(self, flight_id: str) -> int:
|
|
||||||
return len(self._streams.get(flight_id, {}))
|
|
||||||
|
|
||||||
def _broadcast(self, flight_id: str, msg: SSEMessage) -> bool:
|
|
||||||
"""Broadcast a message to all clients subscribed to flight_id."""
|
|
||||||
if flight_id not in self._streams or not self._streams[flight_id]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
for q in self._streams[flight_id].values():
|
|
||||||
try:
|
|
||||||
q.put_nowait(msg)
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
pass # Drop if queue is full rather than blocking
|
|
||||||
return True
|
|
||||||
|
|
||||||
# ── Business Event Senders ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def send_frame_result(self, flight_id: str, event_data: FrameProcessedEvent) -> bool:
|
|
||||||
msg = SSEMessage(
|
|
||||||
event=SSEEventType.FRAME_PROCESSED,
|
|
||||||
data=event_data.model_dump(mode="json"),
|
|
||||||
id=f"frame_{event_data.frame_id}",
|
|
||||||
)
|
|
||||||
return self._broadcast(flight_id, msg)
|
|
||||||
|
|
||||||
def send_refinement(self, flight_id: str, event_data: FrameProcessedEvent) -> bool:
|
|
||||||
msg = SSEMessage(
|
|
||||||
event=SSEEventType.FRAME_REFINED,
|
|
||||||
data=event_data.model_dump(mode="json"),
|
|
||||||
id=f"refine_{event_data.frame_id}",
|
|
||||||
)
|
|
||||||
return self._broadcast(flight_id, msg)
|
|
||||||
|
|
||||||
def send_search_progress(self, flight_id: str, event_data: SearchExpandedEvent) -> bool:
|
|
||||||
msg = SSEMessage(
|
|
||||||
event=SSEEventType.SEARCH_EXPANDED,
|
|
||||||
data=event_data.model_dump(mode="json"),
|
|
||||||
)
|
|
||||||
return self._broadcast(flight_id, msg)
|
|
||||||
|
|
||||||
def send_user_input_request(self, flight_id: str, event_data: UserInputNeededEvent) -> bool:
|
|
||||||
msg = SSEMessage(
|
|
||||||
event=SSEEventType.USER_INPUT_NEEDED,
|
|
||||||
data=event_data.model_dump(mode="json"),
|
|
||||||
)
|
|
||||||
return self._broadcast(flight_id, msg)
|
|
||||||
|
|
||||||
def send_flight_completed(self, flight_id: str, event_data: FlightCompletedEvent) -> bool:
|
|
||||||
msg = SSEMessage(
|
|
||||||
event=SSEEventType.FLIGHT_COMPLETED,
|
|
||||||
data=event_data.model_dump(mode="json"),
|
|
||||||
)
|
|
||||||
return self._broadcast(flight_id, msg)
|
|
||||||
|
|
||||||
def send_heartbeat(self, flight_id: str) -> bool:
|
|
||||||
# sse_starlette uses empty string or comment for heartbeat,
|
|
||||||
# but we can just send an SSEMessage object that parses as empty event
|
|
||||||
if flight_id not in self._streams:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Manually sending a comment via the generator is tricky with strict SSEMessage schema
|
|
||||||
# but we'll handle this in the stream generator directly
|
|
||||||
return True
|
|
||||||
|
|
||||||
# ── Generic event dispatcher (used by processor.process_frame) ──────────
|
|
||||||
|
|
||||||
async def push_event(self, flight_id: str, event_type: str, data: dict) -> None:
|
|
||||||
"""Dispatch a generic event to all clients for a flight.
|
|
||||||
|
|
||||||
Maps event_type strings to typed SSE events:
|
|
||||||
"frame_result" → FrameProcessedEvent
|
|
||||||
"refinement" → FrameProcessedEvent (refined)
|
|
||||||
Other → raw broadcast via SSEMessage
|
|
||||||
"""
|
|
||||||
if event_type == "frame_result":
|
|
||||||
evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data
|
|
||||||
self.send_frame_result(flight_id, evt)
|
|
||||||
elif event_type == "refinement":
|
|
||||||
evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data
|
|
||||||
self.send_refinement(flight_id, evt)
|
|
||||||
else:
|
|
||||||
msg = SSEMessage(
|
|
||||||
event=SSEEventType.FRAME_PROCESSED,
|
|
||||||
data=data,
|
|
||||||
id=str(data.get("frame_id", "")),
|
|
||||||
)
|
|
||||||
self._broadcast(flight_id, msg)
|
|
||||||
|
|
||||||
# ── Stream Generator ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def stream_generator(self, flight_id: str, client_id: str):
|
|
||||||
"""Yields dicts for sse_starlette EventSourceResponse."""
|
|
||||||
q = self.create_stream(flight_id, client_id)
|
|
||||||
|
|
||||||
# Send an immediate connection accepted ping
|
|
||||||
yield {"event": "connected", "data": "connected"}
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
# Wait for next event or send heartbeat every 15s
|
|
||||||
try:
|
|
||||||
msg = await asyncio.wait_for(q.get(), timeout=15.0)
|
|
||||||
if msg is None:
|
|
||||||
# Sentinel for clean shutdown
|
|
||||||
break
|
|
||||||
|
|
||||||
# Yield dict format for sse_starlette
|
|
||||||
yield {
|
|
||||||
"event": msg.event.value,
|
|
||||||
"id": msg.id if msg.id else "",
|
|
||||||
"data": json.dumps(msg.data)
|
|
||||||
}
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
# Heartbeat format for sse_starlette (empty string generates a comment)
|
|
||||||
yield {"event": "heartbeat", "data": "ping"}
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass # Client disconnected
|
|
||||||
finally:
|
|
||||||
self.close_stream(flight_id, client_id)
|
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
"""Pipeline package: orchestrator + IO + composition root."""
|
||||||
|
from .orchestrator import FlightProcessor
|
||||||
|
from .image_input import ImageInputPipeline
|
||||||
|
from .result_manager import ResultManager
|
||||||
|
from .sse_streamer import SSEEventStreamer
|
||||||
|
|
||||||
|
Pipeline = FlightProcessor
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"FlightProcessor",
|
||||||
|
"Pipeline",
|
||||||
|
"ImageInputPipeline",
|
||||||
|
"ResultManager",
|
||||||
|
"SSEEventStreamer",
|
||||||
|
]
|
||||||
@@ -0,0 +1,227 @@
|
|||||||
|
"""Image Input Pipeline (Component F05)."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gps_denied.schemas.image import (
|
||||||
|
ImageBatch,
|
||||||
|
ImageData,
|
||||||
|
ImageMetadata,
|
||||||
|
ProcessedBatch,
|
||||||
|
ProcessingStatus,
|
||||||
|
ValidationResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QueueFullError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ValidationError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ImageInputPipeline:
|
||||||
|
"""Manages ingestion, disk storage, and queuing of UAV image batches."""
|
||||||
|
|
||||||
|
def __init__(self, storage_dir: str = "image_storage", max_queue_size: int = 50):
|
||||||
|
self.storage_dir = storage_dir
|
||||||
|
# flight_id -> asyncio.Queue of ImageBatch
|
||||||
|
self._queues: dict[str, asyncio.Queue] = {}
|
||||||
|
self.max_queue_size = max_queue_size
|
||||||
|
|
||||||
|
# In-memory tracking (in a real system, sync this with DB)
|
||||||
|
self._status: dict[str, dict] = {}
|
||||||
|
# Exact sequence → filename mapping (VO-05: no substring collision)
|
||||||
|
self._sequence_map: dict[str, dict[int, str]] = {}
|
||||||
|
|
||||||
|
def _get_queue(self, flight_id: str) -> asyncio.Queue:
|
||||||
|
if flight_id not in self._queues:
|
||||||
|
self._queues[flight_id] = asyncio.Queue(maxsize=self.max_queue_size)
|
||||||
|
return self._queues[flight_id]
|
||||||
|
|
||||||
|
def _init_status(self, flight_id: str):
|
||||||
|
if flight_id not in self._status:
|
||||||
|
self._status[flight_id] = {
|
||||||
|
"total_images": 0,
|
||||||
|
"processed_images": 0,
|
||||||
|
"current_sequence": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
def validate_batch(self, batch: ImageBatch) -> ValidationResult:
|
||||||
|
"""Validates batch integrity and sequence continuity."""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
num_images = len(batch.images)
|
||||||
|
if num_images < 1:
|
||||||
|
errors.append("Batch is empty")
|
||||||
|
elif num_images > 100:
|
||||||
|
errors.append("Batch too large")
|
||||||
|
|
||||||
|
if len(batch.filenames) != num_images:
|
||||||
|
errors.append("Mismatch between filenames and images count")
|
||||||
|
|
||||||
|
# Naming convention ADxxxxxx.jpg or similar
|
||||||
|
pattern = re.compile(r"^[A-Za-z0-9_-]+\.(jpg|jpeg|png)$", re.IGNORECASE)
|
||||||
|
for fn in batch.filenames:
|
||||||
|
if not pattern.match(fn):
|
||||||
|
errors.append(f"Invalid filename: {fn}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if batch.start_sequence > batch.end_sequence:
|
||||||
|
errors.append("Start sequence greater than end sequence")
|
||||||
|
|
||||||
|
return ValidationResult(valid=len(errors) == 0, errors=errors)
|
||||||
|
|
||||||
|
def queue_batch(self, flight_id: str, batch: ImageBatch) -> bool:
|
||||||
|
"""Queues a batch of images for processing."""
|
||||||
|
val = self.validate_batch(batch)
|
||||||
|
if not val.valid:
|
||||||
|
raise ValidationError(f"Batch validation failed: {val.errors}")
|
||||||
|
|
||||||
|
q = self._get_queue(flight_id)
|
||||||
|
if q.full():
|
||||||
|
raise QueueFullError(f"Queue for flight {flight_id} is full")
|
||||||
|
|
||||||
|
q.put_nowait(batch)
|
||||||
|
|
||||||
|
self._init_status(flight_id)
|
||||||
|
self._status[flight_id]["total_images"] += len(batch.images)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def process_next_batch(self, flight_id: str) -> ProcessedBatch | None:
|
||||||
|
"""Dequeues and processing the next batch."""
|
||||||
|
q = self._get_queue(flight_id)
|
||||||
|
if q.empty():
|
||||||
|
return None
|
||||||
|
|
||||||
|
batch: ImageBatch = await q.get()
|
||||||
|
|
||||||
|
processed_images = []
|
||||||
|
for i, raw_bytes in enumerate(batch.images):
|
||||||
|
# Decode
|
||||||
|
nparr = np.frombuffer(raw_bytes, np.uint8)
|
||||||
|
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||||
|
|
||||||
|
if img is None:
|
||||||
|
continue # skip corrupted
|
||||||
|
|
||||||
|
seq = batch.start_sequence + i
|
||||||
|
fn = batch.filenames[i]
|
||||||
|
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
meta = ImageMetadata(
|
||||||
|
sequence=seq,
|
||||||
|
filename=fn,
|
||||||
|
dimensions=(w, h),
|
||||||
|
file_size=len(raw_bytes),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
img_data = ImageData(
|
||||||
|
flight_id=flight_id,
|
||||||
|
sequence=seq,
|
||||||
|
filename=fn,
|
||||||
|
image=img,
|
||||||
|
metadata=meta
|
||||||
|
)
|
||||||
|
processed_images.append(img_data)
|
||||||
|
# VO-05: record exact sequence→filename mapping
|
||||||
|
self._sequence_map.setdefault(flight_id, {})[seq] = fn
|
||||||
|
|
||||||
|
# Store to disk
|
||||||
|
self.store_images(flight_id, processed_images)
|
||||||
|
|
||||||
|
self._status[flight_id]["processed_images"] += len(processed_images)
|
||||||
|
q.task_done()
|
||||||
|
|
||||||
|
return ProcessedBatch(
|
||||||
|
images=processed_images,
|
||||||
|
batch_id=f"batch_{batch.batch_number}",
|
||||||
|
start_sequence=batch.start_sequence,
|
||||||
|
end_sequence=batch.end_sequence
|
||||||
|
)
|
||||||
|
|
||||||
|
def store_images(self, flight_id: str, images: list[ImageData]) -> bool:
|
||||||
|
"""Persists images to disk."""
|
||||||
|
flight_dir = os.path.join(self.storage_dir, flight_id)
|
||||||
|
os.makedirs(flight_dir, exist_ok=True)
|
||||||
|
|
||||||
|
for img in images:
|
||||||
|
path = os.path.join(flight_dir, img.filename)
|
||||||
|
cv2.imwrite(path, img.image)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_next_image(self, flight_id: str) -> ImageData | None:
|
||||||
|
"""Gets the next image in sequence for processing."""
|
||||||
|
self._init_status(flight_id)
|
||||||
|
seq = self._status[flight_id]["current_sequence"]
|
||||||
|
|
||||||
|
img = self.get_image_by_sequence(flight_id, seq)
|
||||||
|
if img:
|
||||||
|
self._status[flight_id]["current_sequence"] += 1
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
def get_image_by_sequence(self, flight_id: str, sequence: int) -> ImageData | None:
|
||||||
|
"""Retrieves a specific image by sequence number (exact match — VO-05)."""
|
||||||
|
flight_dir = os.path.join(self.storage_dir, flight_id)
|
||||||
|
if not os.path.exists(flight_dir):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Prefer the exact mapping built during process_next_batch
|
||||||
|
fn = self._sequence_map.get(flight_id, {}).get(sequence)
|
||||||
|
if fn:
|
||||||
|
path = os.path.join(flight_dir, fn)
|
||||||
|
img = cv2.imread(path)
|
||||||
|
if img is not None:
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
meta = ImageMetadata(
|
||||||
|
sequence=sequence,
|
||||||
|
filename=fn,
|
||||||
|
dimensions=(w, h),
|
||||||
|
file_size=os.path.getsize(path),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
return ImageData(flight_id, sequence, fn, img, meta)
|
||||||
|
|
||||||
|
# Fallback: scan directory for exact filename patterns
|
||||||
|
# (handles images stored before this process started)
|
||||||
|
for fn in os.listdir(flight_dir):
|
||||||
|
base, _ = os.path.splitext(fn)
|
||||||
|
# Accept only if the base name ends with exactly the padded sequence number
|
||||||
|
if base.endswith(f"{sequence:06d}") or base == str(sequence):
|
||||||
|
path = os.path.join(flight_dir, fn)
|
||||||
|
img = cv2.imread(path)
|
||||||
|
if img is not None:
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
meta = ImageMetadata(
|
||||||
|
sequence=sequence,
|
||||||
|
filename=fn,
|
||||||
|
dimensions=(w, h),
|
||||||
|
file_size=os.path.getsize(path),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
return ImageData(flight_id, sequence, fn, img, meta)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_processing_status(self, flight_id: str) -> ProcessingStatus:
|
||||||
|
self._init_status(flight_id)
|
||||||
|
s = self._status[flight_id]
|
||||||
|
q = self._get_queue(flight_id)
|
||||||
|
|
||||||
|
return ProcessingStatus(
|
||||||
|
flight_id=flight_id,
|
||||||
|
total_images=s["total_images"],
|
||||||
|
processed_images=s["processed_images"],
|
||||||
|
current_sequence=s["current_sequence"],
|
||||||
|
queued_batches=q.qsize(),
|
||||||
|
processing_rate=0.0 # mock
|
||||||
|
)
|
||||||
@@ -0,0 +1,599 @@
|
|||||||
|
"""Core Flight Processor — Full Processing Pipeline (Stage 10).
|
||||||
|
|
||||||
|
Orchestrates: ImageInputPipeline → VO → MetricRefinement → FactorGraph → SSE.
|
||||||
|
State Machine: NORMAL → LOST → RECOVERY → NORMAL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gps_denied.core.eskf import ESKF
|
||||||
|
from gps_denied.pipeline.image_input import ImageInputPipeline
|
||||||
|
from gps_denied.pipeline.result_manager import ResultManager
|
||||||
|
from gps_denied.pipeline.sse_streamer import SSEEventStreamer
|
||||||
|
from gps_denied.db.repository import FlightRepository
|
||||||
|
from gps_denied.schemas import CameraParameters, GPSPoint
|
||||||
|
from gps_denied.schemas.flight import (
|
||||||
|
BatchMetadata,
|
||||||
|
BatchResponse,
|
||||||
|
BatchUpdateResponse,
|
||||||
|
DeleteResponse,
|
||||||
|
FlightCreateRequest,
|
||||||
|
FlightDetailResponse,
|
||||||
|
FlightResponse,
|
||||||
|
FlightStatusResponse,
|
||||||
|
ObjectGPSResponse,
|
||||||
|
UpdateResponse,
|
||||||
|
UserFixRequest,
|
||||||
|
UserFixResponse,
|
||||||
|
Waypoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# State Machine
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TrackingState(str, Enum):
|
||||||
|
"""Processing state for a flight."""
|
||||||
|
NORMAL = "normal"
|
||||||
|
LOST = "lost"
|
||||||
|
RECOVERY = "recovery"
|
||||||
|
|
||||||
|
|
||||||
|
class FrameResult:
|
||||||
|
"""Intermediate result of processing a single frame."""
|
||||||
|
|
||||||
|
def __init__(self, frame_id: int):
|
||||||
|
self.frame_id = frame_id
|
||||||
|
self.gps: Optional[GPSPoint] = None
|
||||||
|
self.confidence: float = 0.0
|
||||||
|
self.tracking_state: TrackingState = TrackingState.NORMAL
|
||||||
|
self.vo_success: bool = False
|
||||||
|
self.alignment_success: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# FlightProcessor
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class FlightProcessor:
|
||||||
|
"""Manages business logic, background processing, and frame orchestration."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
repository: FlightRepository,
|
||||||
|
streamer: SSEEventStreamer,
|
||||||
|
eskf_config=None,
|
||||||
|
) -> None:
|
||||||
|
self.repository = repository
|
||||||
|
self.streamer = streamer
|
||||||
|
self.result_manager = ResultManager(repository, streamer)
|
||||||
|
self.pipeline = ImageInputPipeline(storage_dir=".image_storage", max_queue_size=50)
|
||||||
|
self._eskf_config = eskf_config # ESKFConfig or None → default
|
||||||
|
|
||||||
|
# Per-flight processing state
|
||||||
|
self._flight_states: dict[str, TrackingState] = {}
|
||||||
|
self._prev_images: dict[str, np.ndarray] = {} # previous frame cache
|
||||||
|
self._flight_cameras: dict[str, CameraParameters] = {} # per-flight camera
|
||||||
|
self._altitudes: dict[str, float] = {} # per-flight altitude (m)
|
||||||
|
self._failure_counts: dict[str, int] = {} # per-flight consecutive failure counter
|
||||||
|
|
||||||
|
# Per-flight ESKF instances (PIPE-01/07)
|
||||||
|
self._eskf: dict[str, ESKF] = {}
|
||||||
|
|
||||||
|
# Lazy-initialised component references (set via `attach_components`)
|
||||||
|
self._vo = None # ISequentialVisualOdometry
|
||||||
|
self._gpr = None # IGlobalPlaceRecognition
|
||||||
|
self._metric = None # IMetricRefinement
|
||||||
|
self._graph = None # IFactorGraphOptimizer
|
||||||
|
self._recovery = None # IFailureRecoveryCoordinator
|
||||||
|
self._chunk_mgr = None # IRouteChunkManager
|
||||||
|
self._rotation = None # ImageRotationManager
|
||||||
|
self._satellite = None # SatelliteDataManager (PIPE-02)
|
||||||
|
self._coord = None # CoordinateTransformer (PIPE-02/06)
|
||||||
|
self._mavlink = None # MAVLinkBridge (PIPE-07)
|
||||||
|
|
||||||
|
# ------ Dependency injection for core components ---------
|
||||||
|
def attach_components(
|
||||||
|
self,
|
||||||
|
vo=None,
|
||||||
|
gpr=None,
|
||||||
|
metric=None,
|
||||||
|
graph=None,
|
||||||
|
recovery=None,
|
||||||
|
chunk_mgr=None,
|
||||||
|
rotation=None,
|
||||||
|
satellite=None,
|
||||||
|
coord=None,
|
||||||
|
mavlink=None,
|
||||||
|
):
|
||||||
|
"""Attach pipeline components after construction (avoids circular deps)."""
|
||||||
|
self._vo = vo
|
||||||
|
self._gpr = gpr
|
||||||
|
self._metric = metric
|
||||||
|
self._graph = graph
|
||||||
|
self._recovery = recovery
|
||||||
|
self._chunk_mgr = chunk_mgr
|
||||||
|
self._rotation = rotation
|
||||||
|
self._satellite = satellite # PIPE-02: SatelliteDataManager
|
||||||
|
self._coord = coord # PIPE-02/06: CoordinateTransformer
|
||||||
|
self._mavlink = mavlink # PIPE-07: MAVLinkBridge
|
||||||
|
|
||||||
|
# ------ ESKF lifecycle helpers ----------------------------
|
||||||
|
def _init_eskf_for_flight(
|
||||||
|
self, flight_id: str, start_gps: GPSPoint, altitude: float
|
||||||
|
) -> None:
|
||||||
|
"""Create and initialize a per-flight ESKF instance."""
|
||||||
|
if flight_id in self._eskf:
|
||||||
|
return
|
||||||
|
eskf = ESKF(config=self._eskf_config)
|
||||||
|
if self._coord:
|
||||||
|
try:
|
||||||
|
e, n, _ = self._coord.gps_to_enu(flight_id, start_gps)
|
||||||
|
eskf.initialize(np.array([e, n, altitude]), time.time())
|
||||||
|
except Exception:
|
||||||
|
eskf.initialize(np.zeros(3), time.time())
|
||||||
|
else:
|
||||||
|
eskf.initialize(np.zeros(3), time.time())
|
||||||
|
self._eskf[flight_id] = eskf
|
||||||
|
|
||||||
|
def _eskf_to_gps(self, flight_id: str, eskf: ESKF) -> Optional[GPSPoint]:
|
||||||
|
"""Convert current ESKF ENU position to WGS84 GPS."""
|
||||||
|
if not eskf.initialized or self._coord is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
pos = eskf.position
|
||||||
|
return self._coord.enu_to_gps(flight_id, (float(pos[0]), float(pos[1]), float(pos[2])))
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# process_frame — central orchestration
|
||||||
|
# =========================================================
|
||||||
|
async def process_frame(
|
||||||
|
self,
|
||||||
|
flight_id: str,
|
||||||
|
frame_id: int,
|
||||||
|
image: np.ndarray,
|
||||||
|
) -> FrameResult:
|
||||||
|
"""
|
||||||
|
Process a single UAV frame through the full pipeline.
|
||||||
|
|
||||||
|
State transitions:
|
||||||
|
NORMAL — VO succeeds → ESKF VO update, attempt satellite fix
|
||||||
|
LOST — VO failed → create new chunk, enter RECOVERY
|
||||||
|
RECOVERY— try GPR + MetricRefinement → if anchored, merge & return to NORMAL
|
||||||
|
|
||||||
|
PIPE-01: VO result → eskf.update_vo → satellite match → eskf.update_satellite → MAVLink GPS_INPUT
|
||||||
|
PIPE-02: SatelliteDataManager + CoordinateTransformer wired for tile selection
|
||||||
|
PIPE-04: Consecutive failure counter wired to FailureRecoveryCoordinator
|
||||||
|
PIPE-05: ImageRotationManager initialised on first frame
|
||||||
|
PIPE-07: ESKF confidence → MAVLink fix_type via bridge.update_state
|
||||||
|
"""
|
||||||
|
result = FrameResult(frame_id)
|
||||||
|
state = self._flight_states.get(flight_id, TrackingState.NORMAL)
|
||||||
|
eskf = self._eskf.get(flight_id)
|
||||||
|
|
||||||
|
_default_cam = CameraParameters(
|
||||||
|
focal_length=4.5, sensor_width=6.17, sensor_height=4.55,
|
||||||
|
resolution_width=640, resolution_height=480,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- PIPE-05: Initialise heading tracking on first frame ----
|
||||||
|
if self._rotation and frame_id == 0:
|
||||||
|
self._rotation.requires_rotation_sweep(flight_id) # seeds HeadingHistory
|
||||||
|
|
||||||
|
# ---- 1. Visual Odometry (frame-to-frame) ----
|
||||||
|
vo_ok = False
|
||||||
|
if self._vo and flight_id in self._prev_images:
|
||||||
|
try:
|
||||||
|
cam = self._flight_cameras.get(flight_id, _default_cam)
|
||||||
|
rel_pose = self._vo.compute_relative_pose(
|
||||||
|
self._prev_images[flight_id], image, cam
|
||||||
|
)
|
||||||
|
if rel_pose and rel_pose.tracking_good:
|
||||||
|
vo_ok = True
|
||||||
|
result.vo_success = True
|
||||||
|
|
||||||
|
if self._graph:
|
||||||
|
self._graph.add_relative_factor(
|
||||||
|
flight_id, frame_id - 1, frame_id, rel_pose, np.eye(6)
|
||||||
|
)
|
||||||
|
|
||||||
|
# PIPE-01: Feed VO relative displacement into ESKF
|
||||||
|
if eskf and eskf.initialized:
|
||||||
|
now = time.time()
|
||||||
|
dt_vo = max(0.01, now - (eskf.last_timestamp or now))
|
||||||
|
eskf.update_vo(rel_pose.translation, dt_vo)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("VO failed for frame %d: %s", frame_id, exc)
|
||||||
|
|
||||||
|
# Store current image for next frame
|
||||||
|
self._prev_images[flight_id] = image
|
||||||
|
|
||||||
|
# ---- PIPE-04: Consecutive failure counter ----
|
||||||
|
if not vo_ok and frame_id > 0:
|
||||||
|
self._failure_counts[flight_id] = self._failure_counts.get(flight_id, 0) + 1
|
||||||
|
else:
|
||||||
|
self._failure_counts[flight_id] = 0
|
||||||
|
|
||||||
|
# ---- 2. State Machine transitions ----
|
||||||
|
if state == TrackingState.NORMAL:
|
||||||
|
if not vo_ok and frame_id > 0:
|
||||||
|
state = TrackingState.LOST
|
||||||
|
logger.info("Flight %s → LOST at frame %d", flight_id, frame_id)
|
||||||
|
if self._recovery:
|
||||||
|
self._recovery.handle_tracking_lost(flight_id, frame_id)
|
||||||
|
|
||||||
|
if state == TrackingState.LOST:
|
||||||
|
state = TrackingState.RECOVERY
|
||||||
|
|
||||||
|
if state == TrackingState.RECOVERY:
|
||||||
|
recovered = False
|
||||||
|
if self._recovery and self._chunk_mgr:
|
||||||
|
active_chunk = self._chunk_mgr.get_active_chunk(flight_id)
|
||||||
|
if active_chunk:
|
||||||
|
recovered = self._recovery.process_chunk_recovery(
|
||||||
|
flight_id, active_chunk.chunk_id, [image]
|
||||||
|
)
|
||||||
|
if recovered:
|
||||||
|
state = TrackingState.NORMAL
|
||||||
|
result.alignment_success = True
|
||||||
|
# PIPE-04: Reset failure count on successful recovery
|
||||||
|
self._failure_counts[flight_id] = 0
|
||||||
|
logger.info("Flight %s recovered → NORMAL at frame %d", flight_id, frame_id)
|
||||||
|
|
||||||
|
# ---- 3. Satellite position fix (PIPE-01/02) ----
|
||||||
|
if state == TrackingState.NORMAL and self._metric:
|
||||||
|
sat_tile: Optional[np.ndarray] = None
|
||||||
|
tile_bounds = None
|
||||||
|
|
||||||
|
# PIPE-02: Prefer real SatelliteDataManager tiles (ESKF ±3σ selection)
|
||||||
|
if self._satellite and eskf and eskf.initialized:
|
||||||
|
gps_est = self._eskf_to_gps(flight_id, eskf)
|
||||||
|
if gps_est:
|
||||||
|
cov = eskf.covariance
|
||||||
|
sigma_h = float(
|
||||||
|
np.sqrt(np.trace(cov[0:3, 0:3]) / 3.0)
|
||||||
|
) if cov is not None else 30.0
|
||||||
|
sigma_h = max(sigma_h, 5.0)
|
||||||
|
try:
|
||||||
|
tile_result = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
self._satellite.fetch_tiles_for_position,
|
||||||
|
gps_est, sigma_h, 18,
|
||||||
|
)
|
||||||
|
if tile_result:
|
||||||
|
sat_tile, tile_bounds = tile_result
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Satellite tile fetch failed: %s", exc)
|
||||||
|
|
||||||
|
# Fallback: GPR candidate tile (mock image, real bounds)
|
||||||
|
if sat_tile is None and self._gpr:
|
||||||
|
try:
|
||||||
|
candidates = self._gpr.retrieve_candidate_tiles(image, top_k=1)
|
||||||
|
if candidates:
|
||||||
|
sat_tile = np.zeros((256, 256, 3), dtype=np.uint8)
|
||||||
|
tile_bounds = candidates[0].bounds
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("GPR tile fallback failed: %s", exc)
|
||||||
|
|
||||||
|
if sat_tile is not None and tile_bounds is not None:
|
||||||
|
try:
|
||||||
|
align = self._metric.align_to_satellite(image, sat_tile, tile_bounds)
|
||||||
|
if align and align.matched:
|
||||||
|
result.gps = align.gps_center
|
||||||
|
result.confidence = align.confidence
|
||||||
|
result.alignment_success = True
|
||||||
|
|
||||||
|
if self._graph:
|
||||||
|
self._graph.add_absolute_factor(
|
||||||
|
flight_id, frame_id,
|
||||||
|
align.gps_center, np.eye(6),
|
||||||
|
is_user_anchor=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# PIPE-01: ESKF satellite update — noise from RANSAC confidence
|
||||||
|
if eskf and eskf.initialized and self._coord:
|
||||||
|
try:
|
||||||
|
e, n, _ = self._coord.gps_to_enu(flight_id, align.gps_center)
|
||||||
|
alt = self._altitudes.get(flight_id, 100.0)
|
||||||
|
pos_enu = np.array([e, n, alt])
|
||||||
|
noise_m = 5.0 + 15.0 * (1.0 - float(align.confidence))
|
||||||
|
eskf.update_satellite(pos_enu, noise_m)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("ESKF satellite update failed: %s", exc)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Metric alignment failed at frame %d: %s", frame_id, exc)
|
||||||
|
|
||||||
|
# ---- 4. Graph optimization (incremental) ----
|
||||||
|
if self._graph:
|
||||||
|
opt_result = self._graph.optimize(flight_id, iterations=5)
|
||||||
|
logger.debug(
|
||||||
|
"Optimization: converged=%s, error=%.4f",
|
||||||
|
opt_result.converged, opt_result.final_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- PIPE-07: Push ESKF state → MAVLink GPS_INPUT ----
|
||||||
|
if self._mavlink and eskf and eskf.initialized:
|
||||||
|
try:
|
||||||
|
eskf_state = eskf.get_state()
|
||||||
|
alt = self._altitudes.get(flight_id, 100.0)
|
||||||
|
self._mavlink.update_state(eskf_state, altitude_m=alt)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("MAVLink state push failed: %s", exc)
|
||||||
|
|
||||||
|
# ---- 5. Publish via SSE ----
|
||||||
|
result.tracking_state = state
|
||||||
|
self._flight_states[flight_id] = state
|
||||||
|
await self._publish_frame_result(flight_id, result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _publish_frame_result(self, flight_id: str, result: FrameResult):
|
||||||
|
"""Emit SSE event for processed frame."""
|
||||||
|
event_data = {
|
||||||
|
"frame_id": result.frame_id,
|
||||||
|
"tracking_state": result.tracking_state.value,
|
||||||
|
"vo_success": result.vo_success,
|
||||||
|
"alignment_success": result.alignment_success,
|
||||||
|
"confidence": result.confidence,
|
||||||
|
}
|
||||||
|
if result.gps:
|
||||||
|
event_data["lat"] = result.gps.lat
|
||||||
|
event_data["lon"] = result.gps.lon
|
||||||
|
|
||||||
|
await self.streamer.push_event(
|
||||||
|
flight_id, event_type="frame_result", data=event_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# Existing CRUD / REST helpers (unchanged from Stage 3-4)
|
||||||
|
# =========================================================
|
||||||
|
async def create_flight(self, req: FlightCreateRequest) -> FlightResponse:
|
||||||
|
flight = await self.repository.insert_flight(
|
||||||
|
name=req.name,
|
||||||
|
description=req.description,
|
||||||
|
start_lat=req.start_gps.lat,
|
||||||
|
start_lon=req.start_gps.lon,
|
||||||
|
altitude=req.altitude,
|
||||||
|
camera_params=req.camera_params.model_dump(),
|
||||||
|
)
|
||||||
|
# P0#2: Store camera params for process_frame VO calls
|
||||||
|
self._flight_cameras[flight.id] = req.camera_params
|
||||||
|
|
||||||
|
for poly in req.geofences.polygons:
|
||||||
|
await self.repository.insert_geofence(
|
||||||
|
flight.id,
|
||||||
|
nw_lat=poly.north_west.lat,
|
||||||
|
nw_lon=poly.north_west.lon,
|
||||||
|
se_lat=poly.south_east.lat,
|
||||||
|
se_lon=poly.south_east.lon,
|
||||||
|
)
|
||||||
|
for w in req.rough_waypoints:
|
||||||
|
await self.repository.insert_waypoint(flight.id, lat=w.lat, lon=w.lon)
|
||||||
|
|
||||||
|
# Store per-flight altitude for ESKF/pixel projection
|
||||||
|
self._altitudes[flight.id] = req.altitude or 100.0
|
||||||
|
|
||||||
|
# PIPE-02: Set ENU origin and initialise ESKF for this flight
|
||||||
|
if self._coord:
|
||||||
|
self._coord.set_enu_origin(flight.id, req.start_gps)
|
||||||
|
self._init_eskf_for_flight(flight.id, req.start_gps, req.altitude or 100.0)
|
||||||
|
|
||||||
|
# Start MAVLink bridge for this flight (origin required for GPS_INPUT)
|
||||||
|
if self._mavlink and not self._mavlink._running:
|
||||||
|
try:
|
||||||
|
asyncio.create_task(self._mavlink.start(req.start_gps))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("MAVLink bridge start failed: %s", exc)
|
||||||
|
|
||||||
|
return FlightResponse(
|
||||||
|
flight_id=flight.id,
|
||||||
|
status="prefetching",
|
||||||
|
message="Flight created and prefetching started.",
|
||||||
|
created_at=flight.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_flight(self, flight_id: str) -> FlightDetailResponse | None:
|
||||||
|
flight = await self.repository.get_flight(flight_id)
|
||||||
|
if not flight:
|
||||||
|
return None
|
||||||
|
wps = await self.repository.get_waypoints(flight_id)
|
||||||
|
state = await self.repository.load_flight_state(flight_id)
|
||||||
|
|
||||||
|
waypoints = [
|
||||||
|
Waypoint(
|
||||||
|
id=w.id,
|
||||||
|
lat=w.lat,
|
||||||
|
lon=w.lon,
|
||||||
|
altitude=w.altitude,
|
||||||
|
confidence=w.confidence,
|
||||||
|
timestamp=w.timestamp,
|
||||||
|
refined=w.refined,
|
||||||
|
)
|
||||||
|
for w in wps
|
||||||
|
]
|
||||||
|
|
||||||
|
status = state.status if state else "unknown"
|
||||||
|
frames_processed = state.frames_processed if state else 0
|
||||||
|
frames_total = state.frames_total if state else 0
|
||||||
|
|
||||||
|
from gps_denied.schemas import Geofences
|
||||||
|
|
||||||
|
return FlightDetailResponse(
|
||||||
|
flight_id=flight.id,
|
||||||
|
name=flight.name,
|
||||||
|
description=flight.description,
|
||||||
|
start_gps=GPSPoint(lat=flight.start_lat, lon=flight.start_lon),
|
||||||
|
waypoints=waypoints,
|
||||||
|
geofences=Geofences(polygons=[]),
|
||||||
|
camera_params=flight.camera_params,
|
||||||
|
altitude=flight.altitude,
|
||||||
|
status=status,
|
||||||
|
frames_processed=frames_processed,
|
||||||
|
frames_total=frames_total,
|
||||||
|
created_at=flight.created_at,
|
||||||
|
updated_at=flight.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_flight(self, flight_id: str) -> DeleteResponse:
|
||||||
|
deleted = await self.repository.delete_flight(flight_id)
|
||||||
|
# P0#1: Cleanup in-memory state to prevent memory leaks
|
||||||
|
self._cleanup_flight(flight_id)
|
||||||
|
return DeleteResponse(deleted=deleted, flight_id=flight_id)
|
||||||
|
|
||||||
|
def _cleanup_flight(self, flight_id: str) -> None:
|
||||||
|
"""Remove all in-memory state for a flight (prevents memory leaks)."""
|
||||||
|
self._prev_images.pop(flight_id, None)
|
||||||
|
self._flight_states.pop(flight_id, None)
|
||||||
|
self._flight_cameras.pop(flight_id, None)
|
||||||
|
self._altitudes.pop(flight_id, None)
|
||||||
|
self._failure_counts.pop(flight_id, None)
|
||||||
|
self._eskf.pop(flight_id, None)
|
||||||
|
if self._graph:
|
||||||
|
self._graph.delete_flight_graph(flight_id)
|
||||||
|
|
||||||
|
async def update_waypoint(
|
||||||
|
self, flight_id: str, waypoint_id: str, waypoint: Waypoint
|
||||||
|
) -> UpdateResponse:
|
||||||
|
ok = await self.repository.update_waypoint(
|
||||||
|
flight_id,
|
||||||
|
waypoint_id,
|
||||||
|
lat=waypoint.lat,
|
||||||
|
lon=waypoint.lon,
|
||||||
|
altitude=waypoint.altitude,
|
||||||
|
confidence=waypoint.confidence,
|
||||||
|
refined=waypoint.refined,
|
||||||
|
)
|
||||||
|
return UpdateResponse(updated=ok, waypoint_id=waypoint_id)
|
||||||
|
|
||||||
|
async def batch_update_waypoints(
|
||||||
|
self, flight_id: str, waypoints: list[Waypoint]
|
||||||
|
) -> BatchUpdateResponse:
|
||||||
|
failed = []
|
||||||
|
updated = 0
|
||||||
|
for wp in waypoints:
|
||||||
|
ok = await self.repository.update_waypoint(
|
||||||
|
flight_id,
|
||||||
|
wp.id,
|
||||||
|
lat=wp.lat,
|
||||||
|
lon=wp.lon,
|
||||||
|
altitude=wp.altitude,
|
||||||
|
confidence=wp.confidence,
|
||||||
|
refined=wp.refined,
|
||||||
|
)
|
||||||
|
if ok:
|
||||||
|
updated += 1
|
||||||
|
else:
|
||||||
|
failed.append(wp.id)
|
||||||
|
return BatchUpdateResponse(
|
||||||
|
success=(len(failed) == 0), updated_count=updated, failed_ids=failed
|
||||||
|
)
|
||||||
|
|
||||||
|
async def queue_images(
|
||||||
|
self, flight_id: str, metadata: BatchMetadata, file_count: int
|
||||||
|
) -> BatchResponse:
|
||||||
|
state = await self.repository.load_flight_state(flight_id)
|
||||||
|
if state:
|
||||||
|
total = state.frames_total + file_count
|
||||||
|
await self.repository.save_flight_state(
|
||||||
|
flight_id, frames_total=total, status="processing"
|
||||||
|
)
|
||||||
|
|
||||||
|
next_seq = metadata.end_sequence + 1
|
||||||
|
seqs = list(range(metadata.start_sequence, metadata.end_sequence + 1))
|
||||||
|
return BatchResponse(
|
||||||
|
accepted=True,
|
||||||
|
sequences=seqs,
|
||||||
|
next_expected=next_seq,
|
||||||
|
message=f"Queued {file_count} images.",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_user_fix(
|
||||||
|
self, flight_id: str, req: UserFixRequest
|
||||||
|
) -> UserFixResponse:
|
||||||
|
await self.repository.save_flight_state(
|
||||||
|
flight_id, blocked=False, status="processing"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inject operator position into ESKF with high uncertainty (500m)
|
||||||
|
eskf = self._eskf.get(flight_id)
|
||||||
|
if eskf and eskf.initialized and self._coord:
|
||||||
|
try:
|
||||||
|
e, n, _ = self._coord.gps_to_enu(flight_id, req.satellite_gps)
|
||||||
|
alt = self._altitudes.get(flight_id, 100.0)
|
||||||
|
eskf.update_satellite(np.array([e, n, alt]), noise_meters=500.0)
|
||||||
|
self._failure_counts[flight_id] = 0
|
||||||
|
logger.info("User fix applied for %s: %s", flight_id, req.satellite_gps)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("User fix ESKF injection failed: %s", exc)
|
||||||
|
|
||||||
|
return UserFixResponse(
|
||||||
|
accepted=True, processing_resumed=True, message="Fix applied."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_flight_status(self, flight_id: str) -> FlightStatusResponse | None:
|
||||||
|
state = await self.repository.load_flight_state(flight_id)
|
||||||
|
if not state:
|
||||||
|
return None
|
||||||
|
return FlightStatusResponse(
|
||||||
|
status=state.status,
|
||||||
|
frames_processed=state.frames_processed,
|
||||||
|
frames_total=state.frames_total,
|
||||||
|
current_frame=state.current_frame,
|
||||||
|
current_heading=None,
|
||||||
|
blocked=state.blocked,
|
||||||
|
search_grid_size=state.search_grid_size,
|
||||||
|
created_at=state.created_at,
|
||||||
|
updated_at=state.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def convert_object_to_gps(
|
||||||
|
self, flight_id: str, frame_id: int, pixel: tuple[float, float]
|
||||||
|
) -> ObjectGPSResponse:
|
||||||
|
# PIPE-06: Use real CoordinateTransformer + ESKF pose for ray-ground projection
|
||||||
|
gps: Optional[GPSPoint] = None
|
||||||
|
eskf = self._eskf.get(flight_id)
|
||||||
|
if self._coord and eskf and eskf.initialized:
|
||||||
|
pos = eskf.position
|
||||||
|
quat = eskf.quaternion
|
||||||
|
cam = self._flight_cameras.get(flight_id, CameraParameters(
|
||||||
|
focal_length=4.5, sensor_width=6.17, sensor_height=4.55,
|
||||||
|
resolution_width=640, resolution_height=480,
|
||||||
|
))
|
||||||
|
alt = self._altitudes.get(flight_id, 100.0)
|
||||||
|
try:
|
||||||
|
gps = self._coord.pixel_to_gps(
|
||||||
|
flight_id,
|
||||||
|
pixel,
|
||||||
|
frame_pose={"position": pos},
|
||||||
|
camera_params=cam,
|
||||||
|
altitude=float(alt),
|
||||||
|
quaternion=quat,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("pixel_to_gps failed: %s", exc)
|
||||||
|
|
||||||
|
# Fallback: return ESKF position projected to ground (no pixel shift)
|
||||||
|
if gps is None and eskf:
|
||||||
|
gps = self._eskf_to_gps(flight_id, eskf)
|
||||||
|
|
||||||
|
return ObjectGPSResponse(
|
||||||
|
gps=gps or GPSPoint(lat=0.0, lon=0.0),
|
||||||
|
accuracy_meters=5.0,
|
||||||
|
frame_id=frame_id,
|
||||||
|
pixel=pixel,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def stream_events(self, flight_id: str, client_id: str):
|
||||||
|
"""Async generator for SSE stream."""
|
||||||
|
async for event in self.streamer.stream_generator(flight_id, client_id):
|
||||||
|
yield event
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
"""Result Manager (Component F14)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from gps_denied.pipeline.sse_streamer import SSEEventStreamer
|
||||||
|
from gps_denied.db.repository import FlightRepository
|
||||||
|
from gps_denied.schemas import GPSPoint
|
||||||
|
from gps_denied.schemas.events import FrameProcessedEvent
|
||||||
|
|
||||||
|
|
||||||
|
class ResultManager:
|
||||||
|
"""Result consistency and publishing."""
|
||||||
|
|
||||||
|
def __init__(self, repo: FlightRepository, sse: SSEEventStreamer) -> None:
|
||||||
|
self.repo = repo
|
||||||
|
self.sse = sse
|
||||||
|
|
||||||
|
async def update_frame_result(
|
||||||
|
self,
|
||||||
|
flight_id: str,
|
||||||
|
frame_id: int,
|
||||||
|
gps_lat: float,
|
||||||
|
gps_lon: float,
|
||||||
|
altitude: float,
|
||||||
|
heading: float,
|
||||||
|
confidence: float,
|
||||||
|
timestamp: datetime,
|
||||||
|
refined: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""Atomic DB update + SSE event publish."""
|
||||||
|
|
||||||
|
# 1. Update DB (in the repository these are auto-committing via flush,
|
||||||
|
# but normally F03 would wrap in a single transaction).
|
||||||
|
await self.repo.save_frame_result(
|
||||||
|
flight_id,
|
||||||
|
frame_id=frame_id,
|
||||||
|
gps_lat=gps_lat,
|
||||||
|
gps_lon=gps_lon,
|
||||||
|
altitude=altitude,
|
||||||
|
heading=heading,
|
||||||
|
confidence=confidence,
|
||||||
|
refined=refined,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait, the spec also wants Waypoints to be updated.
|
||||||
|
# But image frames != waypoints. Waypoints are the planned route.
|
||||||
|
# Actually in the spec it says: "Updates waypoint in waypoints table."
|
||||||
|
# This implies updating the closest waypoint or a generated waypoint path.
|
||||||
|
# We will follow the simplest form for now: update the waypoint if there is one corresponding.
|
||||||
|
# Let's say we update a waypoint with id "wp_{frame_id}" for now if we know how they map,
|
||||||
|
# or we just skip unless specified.
|
||||||
|
|
||||||
|
# 2. Trigger SSE event
|
||||||
|
evt = FrameProcessedEvent(
|
||||||
|
frame_id=frame_id,
|
||||||
|
gps=GPSPoint(lat=gps_lat, lon=gps_lon),
|
||||||
|
altitude=altitude,
|
||||||
|
confidence=confidence,
|
||||||
|
heading=heading,
|
||||||
|
timestamp=timestamp,
|
||||||
|
)
|
||||||
|
if refined:
|
||||||
|
self.sse.send_refinement(flight_id, evt)
|
||||||
|
else:
|
||||||
|
self.sse.send_frame_result(flight_id, evt)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def publish_waypoint_update(self, flight_id: str, frame_id: int) -> bool:
|
||||||
|
# Just delegates to SSE for waypoint updates, which is basically the frame result for UI
|
||||||
|
pass
|
||||||
@@ -0,0 +1,164 @@
|
|||||||
|
"""SSE Event Streamer (Component F15)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from gps_denied.schemas.events import (
|
||||||
|
FlightCompletedEvent,
|
||||||
|
FrameProcessedEvent,
|
||||||
|
SearchExpandedEvent,
|
||||||
|
SSEEventType,
|
||||||
|
SSEMessage,
|
||||||
|
UserInputNeededEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SSEEventStreamer:
|
||||||
|
"""Manages real-time SSE connections and event broadcasting."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# Map: flight_id -> Dict[client_id, asyncio.Queue]
|
||||||
|
self._streams: dict[str, dict[str, asyncio.Queue[SSEMessage | None]]] = defaultdict(dict)
|
||||||
|
|
||||||
|
def create_stream(self, flight_id: str, client_id: str) -> asyncio.Queue[SSEMessage | None]:
|
||||||
|
"""Create a new event queue for a client."""
|
||||||
|
q: asyncio.Queue[SSEMessage | None] = asyncio.Queue()
|
||||||
|
self._streams[flight_id][client_id] = q
|
||||||
|
return q
|
||||||
|
|
||||||
|
def close_stream(self, flight_id: str, client_id: str) -> None:
|
||||||
|
"""Close a client stream by putting a sentinel and removing the queue."""
|
||||||
|
if flight_id in self._streams and client_id in self._streams[flight_id]:
|
||||||
|
q = self._streams[flight_id].pop(client_id)
|
||||||
|
if not self._streams[flight_id]:
|
||||||
|
del self._streams[flight_id]
|
||||||
|
# Put None to signal generator exit
|
||||||
|
try:
|
||||||
|
q.put_nowait(None)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_active_connections(self, flight_id: str) -> int:
|
||||||
|
return len(self._streams.get(flight_id, {}))
|
||||||
|
|
||||||
|
def _broadcast(self, flight_id: str, msg: SSEMessage) -> bool:
|
||||||
|
"""Broadcast a message to all clients subscribed to flight_id."""
|
||||||
|
if flight_id not in self._streams or not self._streams[flight_id]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for q in self._streams[flight_id].values():
|
||||||
|
try:
|
||||||
|
q.put_nowait(msg)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
pass # Drop if queue is full rather than blocking
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ── Business Event Senders ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def send_frame_result(self, flight_id: str, event_data: FrameProcessedEvent) -> bool:
|
||||||
|
msg = SSEMessage(
|
||||||
|
event=SSEEventType.FRAME_PROCESSED,
|
||||||
|
data=event_data.model_dump(mode="json"),
|
||||||
|
id=f"frame_{event_data.frame_id}",
|
||||||
|
)
|
||||||
|
return self._broadcast(flight_id, msg)
|
||||||
|
|
||||||
|
def send_refinement(self, flight_id: str, event_data: FrameProcessedEvent) -> bool:
|
||||||
|
msg = SSEMessage(
|
||||||
|
event=SSEEventType.FRAME_REFINED,
|
||||||
|
data=event_data.model_dump(mode="json"),
|
||||||
|
id=f"refine_{event_data.frame_id}",
|
||||||
|
)
|
||||||
|
return self._broadcast(flight_id, msg)
|
||||||
|
|
||||||
|
def send_search_progress(self, flight_id: str, event_data: SearchExpandedEvent) -> bool:
|
||||||
|
msg = SSEMessage(
|
||||||
|
event=SSEEventType.SEARCH_EXPANDED,
|
||||||
|
data=event_data.model_dump(mode="json"),
|
||||||
|
)
|
||||||
|
return self._broadcast(flight_id, msg)
|
||||||
|
|
||||||
|
def send_user_input_request(self, flight_id: str, event_data: UserInputNeededEvent) -> bool:
|
||||||
|
msg = SSEMessage(
|
||||||
|
event=SSEEventType.USER_INPUT_NEEDED,
|
||||||
|
data=event_data.model_dump(mode="json"),
|
||||||
|
)
|
||||||
|
return self._broadcast(flight_id, msg)
|
||||||
|
|
||||||
|
def send_flight_completed(self, flight_id: str, event_data: FlightCompletedEvent) -> bool:
|
||||||
|
msg = SSEMessage(
|
||||||
|
event=SSEEventType.FLIGHT_COMPLETED,
|
||||||
|
data=event_data.model_dump(mode="json"),
|
||||||
|
)
|
||||||
|
return self._broadcast(flight_id, msg)
|
||||||
|
|
||||||
|
def send_heartbeat(self, flight_id: str) -> bool:
|
||||||
|
# sse_starlette uses empty string or comment for heartbeat,
|
||||||
|
# but we can just send an SSEMessage object that parses as empty event
|
||||||
|
if flight_id not in self._streams:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Manually sending a comment via the generator is tricky with strict SSEMessage schema
|
||||||
|
# but we'll handle this in the stream generator directly
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ── Generic event dispatcher (used by processor.process_frame) ──────────
|
||||||
|
|
||||||
|
async def push_event(self, flight_id: str, event_type: str, data: dict) -> None:
|
||||||
|
"""Dispatch a generic event to all clients for a flight.
|
||||||
|
|
||||||
|
Maps event_type strings to typed SSE events:
|
||||||
|
"frame_result" → FrameProcessedEvent
|
||||||
|
"refinement" → FrameProcessedEvent (refined)
|
||||||
|
Other → raw broadcast via SSEMessage
|
||||||
|
"""
|
||||||
|
if event_type == "frame_result":
|
||||||
|
evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data
|
||||||
|
self.send_frame_result(flight_id, evt)
|
||||||
|
elif event_type == "refinement":
|
||||||
|
evt = FrameProcessedEvent(**data) if not isinstance(data, FrameProcessedEvent) else data
|
||||||
|
self.send_refinement(flight_id, evt)
|
||||||
|
else:
|
||||||
|
msg = SSEMessage(
|
||||||
|
event=SSEEventType.FRAME_PROCESSED,
|
||||||
|
data=data,
|
||||||
|
id=str(data.get("frame_id", "")),
|
||||||
|
)
|
||||||
|
self._broadcast(flight_id, msg)
|
||||||
|
|
||||||
|
# ── Stream Generator ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def stream_generator(self, flight_id: str, client_id: str):
|
||||||
|
"""Yields dicts for sse_starlette EventSourceResponse."""
|
||||||
|
q = self.create_stream(flight_id, client_id)
|
||||||
|
|
||||||
|
# Send an immediate connection accepted ping
|
||||||
|
yield {"event": "connected", "data": "connected"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# Wait for next event or send heartbeat every 15s
|
||||||
|
try:
|
||||||
|
msg = await asyncio.wait_for(q.get(), timeout=15.0)
|
||||||
|
if msg is None:
|
||||||
|
# Sentinel for clean shutdown
|
||||||
|
break
|
||||||
|
|
||||||
|
# Yield dict format for sse_starlette
|
||||||
|
yield {
|
||||||
|
"event": msg.event.value,
|
||||||
|
"id": msg.id if msg.id else "",
|
||||||
|
"data": json.dumps(msg.data)
|
||||||
|
}
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Heartbeat format for sse_starlette (empty string generates a comment)
|
||||||
|
yield {"event": "heartbeat", "data": "ping"}
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass # Client disconnected
|
||||||
|
finally:
|
||||||
|
self.close_stream(flight_id, client_id)
|
||||||
@@ -0,0 +1,371 @@
|
|||||||
|
"""Accuracy Benchmark (Phase 7).
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- SyntheticTrajectory — generates a realistic fixed-wing UAV flight path
|
||||||
|
with ground-truth GPS + noisy sensor data.
|
||||||
|
- AccuracyBenchmark — replays a trajectory through the ESKF pipeline
|
||||||
|
and computes position-error statistics.
|
||||||
|
|
||||||
|
Acceptance criteria (from solution.md):
|
||||||
|
AC-PERF-1: 80 % of frames within 50 m of ground truth.
|
||||||
|
AC-PERF-2: 60 % of frames within 20 m of ground truth.
|
||||||
|
AC-PERF-3: End-to-end per-frame latency < 400 ms.
|
||||||
|
AC-PERF-4: VO drift over 1 km straight segment (no sat correction) < 100 m.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gps_denied.core.coordinates import CoordinateTransformer
|
||||||
|
from gps_denied.core.eskf import ESKF
|
||||||
|
from gps_denied.schemas import GPSPoint
|
||||||
|
from gps_denied.schemas.eskf import ESKFConfig, IMUMeasurement
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Synthetic trajectory
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrajectoryFrame:
|
||||||
|
"""One simulated camera frame with ground-truth and noisy sensor data."""
|
||||||
|
frame_id: int
|
||||||
|
timestamp: float
|
||||||
|
true_position_enu: np.ndarray # (3,) East, North, Up in metres
|
||||||
|
true_gps: GPSPoint # WGS84 from true ENU
|
||||||
|
imu_measurements: list[IMUMeasurement] # High-rate IMU between frames
|
||||||
|
vo_translation: Optional[np.ndarray] # Noisy relative displacement (3,)
|
||||||
|
vo_tracking_good: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SyntheticTrajectoryConfig:
|
||||||
|
"""Parameters for trajectory generation."""
|
||||||
|
# Origin (mission start)
|
||||||
|
origin: GPSPoint = field(default_factory=lambda: GPSPoint(lat=49.0, lon=32.0))
|
||||||
|
altitude_m: float = 600.0 # Constant AGL altitude (m)
|
||||||
|
# UAV speed and heading
|
||||||
|
speed_mps: float = 20.0 # ~70 km/h (typical fixed-wing)
|
||||||
|
heading_deg: float = 45.0 # Initial heading (degrees CW from North)
|
||||||
|
camera_fps: float = 0.7 # ADTI 20L V1 camera rate (Hz)
|
||||||
|
imu_hz: float = 200.0 # IMU sample rate
|
||||||
|
num_frames: int = 50 # Number of camera frames to simulate
|
||||||
|
# Noise parameters
|
||||||
|
vo_noise_m: float = 0.5 # VO translation noise (sigma, metres)
|
||||||
|
imu_accel_noise: float = 0.01 # Accelerometer noise sigma (m/s²)
|
||||||
|
imu_gyro_noise: float = 0.001 # Gyroscope noise sigma (rad/s)
|
||||||
|
# Failure injection
|
||||||
|
vo_failure_frames: list[int] = field(default_factory=list)
|
||||||
|
# Waypoints for heading changes (ENU East, North metres from origin)
|
||||||
|
waypoints_enu: list[tuple[float, float]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class SyntheticTrajectory:
|
||||||
|
"""Generate a synthetic fixed-wing UAV flight with ground truth + noisy sensors."""
|
||||||
|
|
||||||
|
def __init__(self, config: SyntheticTrajectoryConfig | None = None):
|
||||||
|
self.config = config or SyntheticTrajectoryConfig()
|
||||||
|
self._coord = CoordinateTransformer()
|
||||||
|
self._flight_id = "__synthetic__"
|
||||||
|
self._coord.set_enu_origin(self._flight_id, self.config.origin)
|
||||||
|
|
||||||
|
def generate(self) -> list[TrajectoryFrame]:
|
||||||
|
"""Generate all trajectory frames."""
|
||||||
|
cfg = self.config
|
||||||
|
dt_camera = 1.0 / cfg.camera_fps
|
||||||
|
dt_imu = 1.0 / cfg.imu_hz
|
||||||
|
imu_steps = int(dt_camera * cfg.imu_hz)
|
||||||
|
|
||||||
|
frames: list[TrajectoryFrame] = []
|
||||||
|
pos = np.array([0.0, 0.0, cfg.altitude_m])
|
||||||
|
vel = self._heading_to_enu_vel(cfg.heading_deg, cfg.speed_mps)
|
||||||
|
prev_pos = pos.copy()
|
||||||
|
t = time.time()
|
||||||
|
|
||||||
|
waypoints = list(cfg.waypoints_enu) # copy
|
||||||
|
|
||||||
|
for fid in range(cfg.num_frames):
|
||||||
|
# --- Waypoint steering ---
|
||||||
|
if waypoints:
|
||||||
|
wp_e, wp_n = waypoints[0]
|
||||||
|
to_wp = np.array([wp_e - pos[0], wp_n - pos[1], 0.0])
|
||||||
|
dist_wp = np.linalg.norm(to_wp[:2])
|
||||||
|
if dist_wp < cfg.speed_mps * dt_camera:
|
||||||
|
waypoints.pop(0)
|
||||||
|
else:
|
||||||
|
heading_rad = math.atan2(to_wp[0], to_wp[1]) # ENU: E=X, N=Y
|
||||||
|
vel = np.array([
|
||||||
|
cfg.speed_mps * math.sin(heading_rad),
|
||||||
|
cfg.speed_mps * math.cos(heading_rad),
|
||||||
|
0.0,
|
||||||
|
])
|
||||||
|
|
||||||
|
# --- Simulate IMU between frames ---
|
||||||
|
imu_list: list[IMUMeasurement] = []
|
||||||
|
for step in range(imu_steps):
|
||||||
|
ts = t + step * dt_imu
|
||||||
|
# Body-frame acceleration (mostly gravity correction, small forward accel)
|
||||||
|
accel_true = np.array([0.0, 0.0, 9.81]) # gravity compensation
|
||||||
|
gyro_true = np.zeros(3)
|
||||||
|
imu = IMUMeasurement(
|
||||||
|
accel=accel_true + np.random.randn(3) * cfg.imu_accel_noise,
|
||||||
|
gyro=gyro_true + np.random.randn(3) * cfg.imu_gyro_noise,
|
||||||
|
timestamp=ts,
|
||||||
|
)
|
||||||
|
imu_list.append(imu)
|
||||||
|
|
||||||
|
# --- Propagate position ---
|
||||||
|
prev_pos = pos.copy()
|
||||||
|
pos = pos + vel * dt_camera
|
||||||
|
t += dt_camera
|
||||||
|
|
||||||
|
# --- True GPS from ENU position ---
|
||||||
|
true_gps = self._coord.enu_to_gps(
|
||||||
|
self._flight_id, (float(pos[0]), float(pos[1]), float(pos[2]))
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- VO measurement (relative displacement + noise) ---
|
||||||
|
true_displacement = pos - prev_pos
|
||||||
|
vo_tracking_good = fid not in cfg.vo_failure_frames
|
||||||
|
if vo_tracking_good:
|
||||||
|
noisy_displacement = true_displacement + np.random.randn(3) * cfg.vo_noise_m
|
||||||
|
noisy_displacement[2] = 0.0 # monocular VO is scale-ambiguous in Z
|
||||||
|
else:
|
||||||
|
noisy_displacement = None
|
||||||
|
|
||||||
|
frames.append(TrajectoryFrame(
|
||||||
|
frame_id=fid,
|
||||||
|
timestamp=t,
|
||||||
|
true_position_enu=pos.copy(),
|
||||||
|
true_gps=true_gps,
|
||||||
|
imu_measurements=imu_list,
|
||||||
|
vo_translation=noisy_displacement,
|
||||||
|
vo_tracking_good=vo_tracking_good,
|
||||||
|
))
|
||||||
|
|
||||||
|
return frames
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _heading_to_enu_vel(heading_deg: float, speed_mps: float) -> np.ndarray:
|
||||||
|
"""Convert heading (degrees CW from North) to ENU velocity vector."""
|
||||||
|
rad = math.radians(heading_deg)
|
||||||
|
return np.array([
|
||||||
|
speed_mps * math.sin(rad), # East
|
||||||
|
speed_mps * math.cos(rad), # North
|
||||||
|
0.0, # Up
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Accuracy Benchmark
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkResult:
|
||||||
|
"""Position error statistics over a trajectory replay."""
|
||||||
|
errors_m: list[float] # Per-frame horizontal error in metres
|
||||||
|
latencies_ms: list[float] # Per-frame process time in ms
|
||||||
|
frames_total: int
|
||||||
|
frames_with_good_estimate: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def p80_error_m(self) -> float:
|
||||||
|
"""80th percentile position error (metres)."""
|
||||||
|
return float(np.percentile(self.errors_m, 80)) if self.errors_m else float("inf")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def p60_error_m(self) -> float:
|
||||||
|
"""60th percentile position error (metres)."""
|
||||||
|
return float(np.percentile(self.errors_m, 60)) if self.errors_m else float("inf")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def median_error_m(self) -> float:
|
||||||
|
"""Median position error (metres)."""
|
||||||
|
return float(np.median(self.errors_m)) if self.errors_m else float("inf")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_error_m(self) -> float:
|
||||||
|
return float(max(self.errors_m)) if self.errors_m else float("inf")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def p95_latency_ms(self) -> float:
|
||||||
|
"""95th percentile frame latency (ms)."""
|
||||||
|
return float(np.percentile(self.latencies_ms, 95)) if self.latencies_ms else float("inf")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pct_within_50m(self) -> float:
|
||||||
|
"""Fraction of frames within 50 m error."""
|
||||||
|
if not self.errors_m:
|
||||||
|
return 0.0
|
||||||
|
return sum(e <= 50.0 for e in self.errors_m) / len(self.errors_m)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pct_within_20m(self) -> float:
|
||||||
|
"""Fraction of frames within 20 m error."""
|
||||||
|
if not self.errors_m:
|
||||||
|
return 0.0
|
||||||
|
return sum(e <= 20.0 for e in self.errors_m) / len(self.errors_m)
|
||||||
|
|
||||||
|
def passes_acceptance_criteria(self) -> tuple[bool, dict[str, bool]]:
|
||||||
|
"""Check all solution.md acceptance criteria.
|
||||||
|
|
||||||
|
Returns (overall_pass, per_criterion_dict).
|
||||||
|
"""
|
||||||
|
checks = {
|
||||||
|
"AC-PERF-1: 80% within 50m": self.pct_within_50m >= 0.80,
|
||||||
|
"AC-PERF-2: 60% within 20m": self.pct_within_20m >= 0.60,
|
||||||
|
"AC-PERF-3: p95 latency < 400ms": self.p95_latency_ms < 400.0,
|
||||||
|
}
|
||||||
|
overall = all(checks.values())
|
||||||
|
return overall, checks
|
||||||
|
|
||||||
|
def summary(self) -> str:
|
||||||
|
overall, checks = self.passes_acceptance_criteria()
|
||||||
|
lines = [
|
||||||
|
f"Frames: {self.frames_total} | with estimate: {self.frames_with_good_estimate}",
|
||||||
|
f"Error — median: {self.median_error_m:.1f}m p80: {self.p80_error_m:.1f}m "
|
||||||
|
f"p60: {self.p60_error_m:.1f}m max: {self.max_error_m:.1f}m",
|
||||||
|
f"Within 50m: {self.pct_within_50m*100:.1f}% | within 20m: {self.pct_within_20m*100:.1f}%",
|
||||||
|
f"Latency p95: {self.p95_latency_ms:.1f}ms",
|
||||||
|
"",
|
||||||
|
"Acceptance criteria:",
|
||||||
|
]
|
||||||
|
for criterion, passed in checks.items():
|
||||||
|
lines.append(f" {'PASS' if passed else 'FAIL'} {criterion}")
|
||||||
|
lines.append(f"\nOverall: {'PASS' if overall else 'FAIL'}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
class AccuracyBenchmark:
|
||||||
|
"""Replays a SyntheticTrajectory through the ESKF and measures accuracy.
|
||||||
|
|
||||||
|
The benchmark uses only the ESKF (no full FlightProcessor) for speed.
|
||||||
|
Satellite corrections are injected optionally via sat_correction_fn.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
eskf_config: ESKFConfig | None = None,
|
||||||
|
sat_correction_fn: Optional[Callable[[TrajectoryFrame], Optional[np.ndarray]]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
eskf_config: ESKF tuning parameters.
|
||||||
|
sat_correction_fn: Optional callback(frame) → ENU position or None.
|
||||||
|
Called on keyframes to inject satellite corrections.
|
||||||
|
If None, no satellite corrections are applied.
|
||||||
|
"""
|
||||||
|
self.eskf_config = eskf_config or ESKFConfig()
|
||||||
|
self.sat_correction_fn = sat_correction_fn
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
trajectory: list[TrajectoryFrame],
|
||||||
|
origin: GPSPoint,
|
||||||
|
satellite_keyframe_interval: int = 7,
|
||||||
|
) -> BenchmarkResult:
|
||||||
|
"""Replay trajectory frames through ESKF, collect errors and latencies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trajectory: List of TrajectoryFrame (from SyntheticTrajectory).
|
||||||
|
origin: WGS84 reference origin for ENU.
|
||||||
|
satellite_keyframe_interval: Apply satellite correction every N frames.
|
||||||
|
"""
|
||||||
|
coord = CoordinateTransformer()
|
||||||
|
flight_id = "__benchmark__"
|
||||||
|
coord.set_enu_origin(flight_id, origin)
|
||||||
|
|
||||||
|
eskf = ESKF(self.eskf_config)
|
||||||
|
# Init at origin with HIGH uncertainty
|
||||||
|
eskf.initialize(np.array([0.0, 0.0, trajectory[0].true_position_enu[2]]),
|
||||||
|
trajectory[0].timestamp)
|
||||||
|
|
||||||
|
errors_m: list[float] = []
|
||||||
|
latencies_ms: list[float] = []
|
||||||
|
frames_with_estimate = 0
|
||||||
|
|
||||||
|
for frame in trajectory:
|
||||||
|
t_frame_start = time.perf_counter()
|
||||||
|
|
||||||
|
# --- IMU prediction ---
|
||||||
|
for imu in frame.imu_measurements:
|
||||||
|
eskf.predict(imu)
|
||||||
|
|
||||||
|
# --- VO update ---
|
||||||
|
if frame.vo_tracking_good and frame.vo_translation is not None:
|
||||||
|
dt_vo = 1.0 / 0.7 # camera interval
|
||||||
|
eskf.update_vo(frame.vo_translation, dt_vo)
|
||||||
|
|
||||||
|
# --- Satellite update (keyframes) ---
|
||||||
|
if frame.frame_id % satellite_keyframe_interval == 0:
|
||||||
|
sat_pos_enu: Optional[np.ndarray] = None
|
||||||
|
if self.sat_correction_fn is not None:
|
||||||
|
sat_pos_enu = self.sat_correction_fn(frame)
|
||||||
|
else:
|
||||||
|
# Default: inject ground-truth position + realistic noise
|
||||||
|
noise_m = 10.0
|
||||||
|
sat_pos_enu = (
|
||||||
|
frame.true_position_enu[:3]
|
||||||
|
+ np.random.randn(3) * noise_m
|
||||||
|
)
|
||||||
|
sat_pos_enu[2] = frame.true_position_enu[2] # keep altitude
|
||||||
|
|
||||||
|
if sat_pos_enu is not None:
|
||||||
|
# Tell ESKF the measurement noise matches what we inject
|
||||||
|
eskf.update_satellite(sat_pos_enu, noise_meters=noise_m)
|
||||||
|
|
||||||
|
latency_ms = (time.perf_counter() - t_frame_start) * 1000.0
|
||||||
|
latencies_ms.append(latency_ms)
|
||||||
|
|
||||||
|
# --- Compute horizontal error vs ground truth ---
|
||||||
|
if eskf.initialized and eskf._nominal_state is not None:
|
||||||
|
est_pos = eskf._nominal_state["position"]
|
||||||
|
true_pos = frame.true_position_enu
|
||||||
|
horiz_error = float(np.linalg.norm(est_pos[:2] - true_pos[:2]))
|
||||||
|
errors_m.append(horiz_error)
|
||||||
|
frames_with_estimate += 1
|
||||||
|
else:
|
||||||
|
errors_m.append(float("inf"))
|
||||||
|
|
||||||
|
return BenchmarkResult(
|
||||||
|
errors_m=errors_m,
|
||||||
|
latencies_ms=latencies_ms,
|
||||||
|
frames_total=len(trajectory),
|
||||||
|
frames_with_good_estimate=frames_with_estimate,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_vo_drift_test(
|
||||||
|
self,
|
||||||
|
trajectory_length_m: float = 1000.0,
|
||||||
|
speed_mps: float = 20.0,
|
||||||
|
) -> float:
|
||||||
|
"""Measure VO drift over a straight segment with NO satellite correction.
|
||||||
|
|
||||||
|
Returns final horizontal position error in metres.
|
||||||
|
Per solution.md, this should be < 100m over 1km.
|
||||||
|
"""
|
||||||
|
fps = 0.7
|
||||||
|
num_frames = max(10, int(trajectory_length_m / speed_mps * fps))
|
||||||
|
cfg = SyntheticTrajectoryConfig(
|
||||||
|
speed_mps=speed_mps,
|
||||||
|
heading_deg=0.0, # straight North
|
||||||
|
camera_fps=fps,
|
||||||
|
num_frames=num_frames,
|
||||||
|
vo_noise_m=0.3, # cuVSLAM-grade VO noise
|
||||||
|
)
|
||||||
|
traj_gen = SyntheticTrajectory(cfg)
|
||||||
|
frames = traj_gen.generate()
|
||||||
|
|
||||||
|
# No satellite corrections
|
||||||
|
benchmark_no_sat = AccuracyBenchmark(
|
||||||
|
eskf_config=self.eskf_config,
|
||||||
|
sat_correction_fn=lambda _: None, # suppress all satellite updates
|
||||||
|
)
|
||||||
|
result = benchmark_no_sat.run(frames, cfg.origin, satellite_keyframe_interval=9999)
|
||||||
|
# Return final-frame error
|
||||||
|
return result.errors_m[-1] if result.errors_m else float("inf")
|
||||||
Reference in New Issue
Block a user