feat: stage10 — Full processing cycle with State Machine

This commit is contained in:
Yuzviak
2026-03-22 23:14:33 +02:00
parent 74aa6454b8
commit c86cdc2e82
4 changed files with 342 additions and 15 deletions
+213 -13
View File
@@ -1,9 +1,18 @@
"""Core Flight Processor (Dummy / Stub for Stage 3)."""
"""Core Flight Processor — Full Processing Pipeline (Stage 10).
Orchestrates: ImageInputPipeline → VO → MetricRefinement → FactorGraph → SSE.
State Machine: NORMAL → LOST → RECOVERY → NORMAL.
"""
from __future__ import annotations
import asyncio
import logging
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
import numpy as np
from gps_denied.core.pipeline import ImageInputPipeline
from gps_denied.core.results import ResultManager
@@ -27,9 +36,36 @@ from gps_denied.schemas.flight import (
)
from gps_denied.schemas.image import ImageBatch
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# State Machine
# ---------------------------------------------------------------------------
class TrackingState(str, Enum):
"""Processing state for a flight."""
NORMAL = "normal"
LOST = "lost"
RECOVERY = "recovery"
class FrameResult:
"""Intermediate result of processing a single frame."""
def __init__(self, frame_id: int):
self.frame_id = frame_id
self.gps: Optional[GPSPoint] = None
self.confidence: float = 0.0
self.tracking_state: TrackingState = TrackingState.NORMAL
self.vo_success: bool = False
self.alignment_success: bool = False
# ---------------------------------------------------------------------------
# FlightProcessor
# ---------------------------------------------------------------------------
class FlightProcessor:
"""Manages business logic and background processing for flights."""
"""Manages business logic, background processing, and frame orchestration."""
def __init__(self, repository: FlightRepository, streamer: SSEEventStreamer) -> None:
self.repository = repository
@@ -37,6 +73,165 @@ class FlightProcessor:
self.result_manager = ResultManager(repository, streamer)
self.pipeline = ImageInputPipeline(storage_dir=".image_storage", max_queue_size=50)
# Per-flight processing state
self._flight_states: dict[str, TrackingState] = {}
self._prev_images: dict[str, np.ndarray] = {} # previous frame cache
# Lazy-initialised component references (set via `attach_components`)
self._vo = None # SequentialVisualOdometry
self._gpr = None # GlobalPlaceRecognition
self._metric = None # MetricRefinement
self._graph = None # FactorGraphOptimizer
self._recovery = None # FailureRecoveryCoordinator
self._chunk_mgr = None # RouteChunkManager
self._rotation = None # ImageRotationManager
# ------ Dependency injection for core components ---------
def attach_components(
self,
vo=None,
gpr=None,
metric=None,
graph=None,
recovery=None,
chunk_mgr=None,
rotation=None,
):
"""Attach pipeline components after construction (avoids circular deps)."""
self._vo = vo
self._gpr = gpr
self._metric = metric
self._graph = graph
self._recovery = recovery
self._chunk_mgr = chunk_mgr
self._rotation = rotation
# =========================================================
# process_frame — central orchestration
# =========================================================
async def process_frame(
self,
flight_id: str,
frame_id: int,
image: np.ndarray,
) -> FrameResult:
"""
Process a single UAV frame through the full pipeline.
State transitions:
NORMAL — VO succeeds → add relative factor, attempt drift correction
LOST — VO failed → create new chunk, enter RECOVERY
RECOVERY— try GPR + MetricRefinement → if anchored, merge & return to NORMAL
"""
result = FrameResult(frame_id)
state = self._flight_states.get(flight_id, TrackingState.NORMAL)
# ---- 1. Visual Odometry (frame-to-frame) ----
vo_ok = False
if self._vo and flight_id in self._prev_images:
try:
rel_pose = self._vo.compute_relative_pose(
self._prev_images[flight_id], image
)
if rel_pose and rel_pose.tracking_good:
vo_ok = True
result.vo_success = True
# Add factor to graph
if self._graph:
self._graph.add_relative_factor(
flight_id, frame_id - 1, frame_id,
rel_pose, np.eye(6)
)
except Exception as exc:
logger.warning("VO failed for frame %d: %s", frame_id, exc)
# Store current image for next frame
self._prev_images[flight_id] = image
# ---- 2. State Machine transitions ----
if state == TrackingState.NORMAL:
if not vo_ok and frame_id > 0:
# Transition → LOST
state = TrackingState.LOST
logger.info("Flight %s → LOST at frame %d", flight_id, frame_id)
if self._recovery:
self._recovery.handle_tracking_lost(flight_id, frame_id)
if state == TrackingState.LOST:
# Transition → RECOVERY
state = TrackingState.RECOVERY
if state == TrackingState.RECOVERY:
recovered = False
if self._recovery and self._chunk_mgr:
active_chunk = self._chunk_mgr.get_active_chunk(flight_id)
if active_chunk:
recovered = self._recovery.process_chunk_recovery(
flight_id, active_chunk.chunk_id, [image]
)
if recovered:
state = TrackingState.NORMAL
result.alignment_success = True
logger.info("Flight %s recovered → NORMAL at frame %d", flight_id, frame_id)
# ---- 3. Drift correction via Metric Refinement ----
if state == TrackingState.NORMAL and self._metric and self._gpr:
try:
candidates = self._gpr.retrieve_candidate_tiles(image, top_k=1)
if candidates:
best = candidates[0]
sat_img = np.zeros((256, 256, 3), dtype=np.uint8) # mock tile
align = self._metric.align_to_satellite(image, sat_img, best.bounds)
if align and align.matched:
result.gps = align.gps_center
result.confidence = align.confidence
result.alignment_success = True
if self._graph:
self._graph.add_absolute_factor(
flight_id, frame_id,
align.gps_center, np.eye(2),
is_user_anchor=False
)
except Exception as exc:
logger.warning("Drift correction failed at frame %d: %s", frame_id, exc)
# ---- 4. Graph optimization (incremental) ----
if self._graph:
opt_result = self._graph.optimize(flight_id, iterations=5)
logger.debug("Optimization: converged=%s, error=%.4f", opt_result.converged, opt_result.final_error)
# ---- 5. Publish via SSE ----
result.tracking_state = state
self._flight_states[flight_id] = state
await self._publish_frame_result(flight_id, result)
return result
async def _publish_frame_result(self, flight_id: str, result: FrameResult):
"""Emit SSE event for processed frame."""
event_data = {
"frame_id": result.frame_id,
"tracking_state": result.tracking_state.value,
"vo_success": result.vo_success,
"alignment_success": result.alignment_success,
"confidence": result.confidence,
}
if result.gps:
event_data["lat"] = result.gps.lat
event_data["lon"] = result.gps.lon
await self.streamer.push_event(
flight_id, event_type="frame_result", data=event_data
)
# =========================================================
# Existing CRUD / REST helpers (unchanged from Stage 3-4)
# =========================================================
async def create_flight(self, req: FlightCreateRequest) -> FlightResponse:
flight = await self.repository.insert_flight(
name=req.name,
@@ -83,14 +278,13 @@ class FlightProcessor:
)
for w in wps
]
status = state.status if state else "unknown"
frames_processed = state.frames_processed if state else 0
frames_total = state.frames_total if state else 0
# Assuming empty geofences for now unless loaded (omitted for brevity)
from gps_denied.schemas import Geofences
return FlightDetailResponse(
flight_id=flight.id,
name=flight.name,
@@ -144,7 +338,9 @@ class FlightProcessor:
updated += 1
else:
failed.append(wp.id)
return BatchUpdateResponse(success=(len(failed) == 0), updated_count=updated, failed_ids=failed)
return BatchUpdateResponse(
success=(len(failed) == 0), updated_count=updated, failed_ids=failed
)
async def queue_images(
self, flight_id: str, metadata: BatchMetadata, file_count: int
@@ -152,8 +348,10 @@ class FlightProcessor:
state = await self.repository.load_flight_state(flight_id)
if state:
total = state.frames_total + file_count
await self.repository.save_flight_state(flight_id, frames_total=total, status="processing")
await self.repository.save_flight_state(
flight_id, frames_total=total, status="processing"
)
next_seq = metadata.end_sequence + 1
seqs = list(range(metadata.start_sequence, metadata.end_sequence + 1))
return BatchResponse(
@@ -163,8 +361,12 @@ class FlightProcessor:
message=f"Queued {file_count} images.",
)
async def handle_user_fix(self, flight_id: str, req: UserFixRequest) -> UserFixResponse:
await self.repository.save_flight_state(flight_id, blocked=False, status="processing")
async def handle_user_fix(
self, flight_id: str, req: UserFixRequest
) -> UserFixResponse:
await self.repository.save_flight_state(
flight_id, blocked=False, status="processing"
)
return UserFixResponse(
accepted=True, processing_resumed=True, message="Fix applied."
)
@@ -178,7 +380,7 @@ class FlightProcessor:
frames_processed=state.frames_processed,
frames_total=state.frames_total,
current_frame=state.current_frame,
current_heading=None, # would load from latest
current_heading=None,
blocked=state.blocked,
search_grid_size=state.search_grid_size,
created_at=state.created_at,
@@ -188,7 +390,6 @@ class FlightProcessor:
async def convert_object_to_gps(
self, flight_id: str, frame_id: int, pixel: tuple[float, float]
) -> ObjectGPSResponse:
# Dummy math
return ObjectGPSResponse(
gps=GPSPoint(lat=48.0, lon=37.0),
accuracy_meters=5.0,
@@ -198,6 +399,5 @@ class FlightProcessor:
async def stream_events(self, flight_id: str, client_id: str):
"""Async generator for SSE stream."""
# Yield from the real SSE streamer generator
async for event in self.streamer.stream_generator(flight_id, client_id):
yield event