mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-22 08:56:37 +00:00
229 lines
9.4 KiB
Python
229 lines
9.4 KiB
Python
import logging
|
|
from datetime import datetime
|
|
from typing import List, Optional, Tuple, Dict, Any, Callable
|
|
from pydantic import BaseModel, Field
|
|
from abc import ABC, abstractmethod
|
|
|
|
from f02_1_flight_lifecycle_manager import GPSPoint
|
|
from f03_flight_database import FrameResult as F03FrameResult, Waypoint
|
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
# --- Data Models ---
|
|
|
|
class ObjectLocation(BaseModel):
    """An object located within a frame, with an image-space and a geographic position.

    Instances are carried in ``FrameResult.objects``.
    """

    # Identifier of the object.
    object_id: str
    # Position in the image; axis order (x, y) vs (row, col) is not shown here — TODO confirm.
    pixel: Tuple[float, float]
    # Geographic position of the object.
    gps: GPSPoint
    # Presumably the detector's class label — verify against the producer.
    class_name: str
    # Confidence score; its range is not enforced by this model.
    confidence: float
|
|
|
|
class FrameResult(BaseModel):
    """Per-frame localization result handled by the F14 result manager.

    Carries the same fields as the F03 persistence model's ``FrameResult``,
    plus ``objects``, which is not passed to the F03 model in this file.
    """

    # Frame identifier; also used to build exported image names (AD{frame_id:06d}.jpg).
    frame_id: int
    # Position associated with the frame center (provides .lat / .lon).
    gps_center: GPSPoint
    # Exported under the CSV column "altitude_m", i.e. meters.
    altitude: float
    # Heading; units and reference direction are not shown here — TODO confirm (degrees?).
    heading: float
    # Confidence of this result; averaged into FlightStatistics.mean_confidence.
    confidence: float
    # Frame timestamp; timezone handling is not shown here — TODO confirm.
    timestamp: datetime
    # Set to True once a batch refinement has been applied to this frame.
    refined: bool = False
    # Objects located in this frame.
    objects: List[ObjectLocation] = Field(default_factory=list)
    # Last-modified time (naive UTC by default); used to detect changed frames.
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12 and returns naive values.
    updated_at: datetime = Field(default_factory=datetime.utcnow)
|
|
|
|
class RefinedFrameResult(BaseModel):
    """Refined position and confidence for one frame, applied to an already stored result."""

    # Frame whose stored result is refined.
    frame_id: int
    # Replacement frame-center position.
    gps_center: GPSPoint
    # Replacement confidence.
    confidence: float
    # Replacement heading; None means "keep the stored heading".
    heading: Optional[float] = None
|
|
|
|
class FlightStatistics(BaseModel):
    """Aggregate statistics over the frame results of one flight."""

    # Number of frame results.
    total_frames: int
    # Number of processed frames (the ResultManager in this file reports total_frames).
    processed_frames: int
    # Number of frames marked as refined.
    refined_frames: int
    # Mean of the frames' confidence values (0.0 when there are no frames).
    mean_confidence: float
    # Processing time; units not shown here (the ResultManager in this file reports 0.0).
    processing_time: float
|
|
|
|
class FlightResults(BaseModel):
    """All frame results of one flight together with their aggregate statistics."""

    # Flight the results belong to.
    flight_id: str
    # Frame results of the flight.
    frames: List[FrameResult]
    # Statistics computed over ``frames``.
    statistics: FlightStatistics
|
|
|
|
# --- Interface ---
|
|
|
|
class IResultManager(ABC):
    """Abstract contract for persisting, publishing and exporting per-frame flight results."""

    @abstractmethod
    def update_frame_result(
        self, flight_id: str, frame_id: int, result: FrameResult
    ) -> bool:
        """Store ``result`` for ``frame_id`` of ``flight_id`` and report success."""

    @abstractmethod
    def publish_waypoint_update(self, flight_id: str, frame_id: int) -> bool:
        """Publish the stored result of a single frame and report whether it was sent."""

    @abstractmethod
    def get_flight_results(self, flight_id: str) -> FlightResults:
        """Return every frame result of ``flight_id`` together with aggregate statistics."""

    @abstractmethod
    def mark_refined(
        self, flight_id: str, refined_results: List[RefinedFrameResult]
    ) -> bool:
        """Apply a batch of refined frame results and report success."""

    @abstractmethod
    def get_changed_frames(self, flight_id: str, since: datetime) -> List[int]:
        """Return the ids of frames whose results were updated after ``since``."""

    @abstractmethod
    def update_results_after_chunk_merge(
        self, flight_id: str, refined_results: List[RefinedFrameResult]
    ) -> bool:
        """Apply refined results produced by a chunk merge and report success."""

    @abstractmethod
    def export_results(self, flight_id: str, format: str) -> str:
        """Serialize the results of ``flight_id`` in the requested ``format``."""
|
|
|
|
# --- Implementation ---
|
|
|
|
class ResultManager(IResultManager):
    """
    F14: Result Manager

    Handles atomic persistence and real-time publishing of individual frame
    processing results and batch refinement updates.

    Persistence goes through the F03 flight database (``f03_database``) and
    live publishing through the F15 streamer (``f15_streamer``). Either may be
    None; operations that need a missing collaborator return False, an empty
    result, or do nothing.
    """

    def __init__(self, f03_database=None, f15_streamer=None):
        # F03 database: used via save_frame_result, insert_waypoint,
        # get_frame_results and execute_transaction.
        self.f03 = f03_database
        # F15 streamer: used via send_frame_result and send_refinement.
        self.f15 = f15_streamer

    # --- Mapping helpers ---

    def _map_to_f03_result(self, result: FrameResult) -> F03FrameResult:
        """Convert an F14 FrameResult into the F03 persistence model.

        ``objects`` is not passed to the F03 model.
        """
        return F03FrameResult(
            frame_id=result.frame_id,
            gps_center=result.gps_center,
            altitude=result.altitude,
            heading=result.heading,
            confidence=result.confidence,
            refined=result.refined,
            timestamp=result.timestamp,
            updated_at=result.updated_at,
        )

    def _map_to_f14_result(self, f03_res: F03FrameResult) -> FrameResult:
        """Convert a stored F03 frame result back into an F14 FrameResult.

        A falsy stored altitude (e.g. None) becomes 0.0, and ``objects`` is
        always empty because it is not read from the F03 record.
        """
        return FrameResult(
            frame_id=f03_res.frame_id,
            gps_center=f03_res.gps_center,
            altitude=f03_res.altitude or 0.0,
            heading=f03_res.heading,
            confidence=f03_res.confidence,
            timestamp=f03_res.timestamp,
            refined=f03_res.refined,
            objects=[],
            updated_at=f03_res.updated_at,
        )

    # --- Single-frame updates ---

    def _build_frame_transaction(self, flight_id: str, result: FrameResult) -> List[Callable]:
        """Build the deferred DB operations that persist one frame result.

        Returns two zero-argument callables (save the frame result, insert its
        waypoint) meant to be run together by ``self.f03.execute_transaction``.
        Values are computed eagerly here, so each call's closures capture their
        own ``f03_result`` / ``waypoint``.
        """
        f03_result = self._map_to_f03_result(result)
        # NOTE(review): refinements re-insert a waypoint with the same id
        # ("wp_<frame_id>"); confirm insert_waypoint tolerates/upserts duplicates.
        waypoint = Waypoint(
            id=f"wp_{result.frame_id}",
            lat=result.gps_center.lat,
            lon=result.gps_center.lon,
            altitude=result.altitude,
            confidence=result.confidence,
            timestamp=result.timestamp,
            refined=result.refined,
        )
        return [
            lambda: self.f03.save_frame_result(flight_id, f03_result),
            lambda: self.f03.insert_waypoint(flight_id, waypoint),
        ]

    def update_frame_result(self, flight_id: str, frame_id: int, result: FrameResult) -> bool:
        """Persist ``result`` in one transaction and, on success, publish it.

        Returns False when no database is configured, otherwise the value
        returned by ``execute_transaction``. The publish outcome does not
        affect the return value.
        """
        if not self.f03:
            return False

        if frame_id != result.frame_id:
            # The record is stored under result.frame_id but published by
            # frame_id, so a mismatch streams a different frame.
            logger.warning(
                "update_frame_result: frame_id=%s does not match result.frame_id=%s",
                frame_id, result.frame_id,
            )

        operations = self._build_frame_transaction(flight_id, result)
        success = self.f03.execute_transaction(operations)
        if success:
            self.publish_waypoint_update(flight_id, frame_id)
        return success

    def publish_waypoint_update(self, flight_id: str, frame_id: int, max_attempts: int = 3) -> bool:
        """Look up the stored result of ``frame_id`` and stream it via F15.

        The lookup and send are retried up to ``max_attempts`` times when they
        raise. Returns True once the frame has been sent; False when the
        database or streamer is missing, the frame is not stored, or every
        attempt failed.
        """
        if not self.f03 or not self.f15:
            return False

        for attempt in range(1, max_attempts + 1):
            try:
                stored = next(
                    (r for r in self.f03.get_frame_results(flight_id) if r.frame_id == frame_id),
                    None,
                )
                if stored is None:
                    # A frame that is not stored will not appear on retry.
                    logger.error(
                        "Cannot publish waypoint: frame %s not found for flight %s",
                        frame_id, flight_id,
                    )
                    return False
                self.f15.send_frame_result(flight_id, self._map_to_f14_result(stored))
                return True
            except Exception as e:
                logger.warning("Transient error publishing waypoint (attempt %d): %s", attempt, e)

        logger.error(
            "Failed to publish waypoint for frame %s of flight %s after %d attempts.",
            frame_id, flight_id, max_attempts,
        )
        return False

    # --- Queries ---

    def _compute_flight_statistics(self, frames: List[FrameResult]) -> FlightStatistics:
        """Summarize ``frames``: counts, refined count and mean confidence.

        Every frame counts as processed; processing_time is not tracked (0.0).
        """
        total = len(frames)
        refined = sum(1 for f in frames if f.refined)
        mean_conf = sum(f.confidence for f in frames) / total if total else 0.0
        return FlightStatistics(
            total_frames=total,
            processed_frames=total,
            refined_frames=refined,
            mean_confidence=mean_conf,
            processing_time=0.0,
        )

    def get_flight_results(self, flight_id: str) -> FlightResults:
        """Return all stored frames of ``flight_id`` with their statistics.

        Without a database, returns no frames and all-zero statistics.
        """
        if not self.f03:
            frames: List[FrameResult] = []
        else:
            frames = [self._map_to_f14_result(r) for r in self.f03.get_frame_results(flight_id)]
        return FlightResults(
            flight_id=flight_id,
            frames=frames,
            statistics=self._compute_flight_statistics(frames),
        )

    # --- Batch refinement ---

    def _build_batch_refinement_transaction(self, flight_id: str, refined_results: List[RefinedFrameResult]) -> List[Callable]:
        """Build DB operations applying ``refined_results`` to the stored frames.

        Refinements for frames that are not stored are skipped. The objects
        returned by F03 are never mutated: each update is made on a copy, so a
        failed transaction cannot leave them half-refined. Repeated frame_ids
        build on the previous refinement in this batch.
        """
        stored = {res.frame_id: res for res in self.f03.get_frame_results(flight_id)}
        working: Dict[int, FrameResult] = {}
        operations: List[Callable] = []

        for ref in refined_results:
            if ref.frame_id not in stored:
                continue
            base = working.get(ref.frame_id)
            if base is None:
                base = self._map_to_f14_result(stored[ref.frame_id])
            updated = base.model_copy(update={
                "gps_center": ref.gps_center,
                "confidence": ref.confidence,
                # A refinement without a heading keeps the existing one.
                "heading": ref.heading if ref.heading is not None else base.heading,
                "refined": True,
                # Naive UTC, matching FrameResult.updated_at's default.
                "updated_at": datetime.utcnow(),
            })
            working[ref.frame_id] = updated
            operations.extend(self._build_frame_transaction(flight_id, updated))

        return operations

    def _publish_refinement_events(self, flight_id: str, frame_ids: List[int]):
        """Stream refined frames via F15 in ``frame_ids`` order (best-effort).

        Frames that are not stored are skipped. A failed send is logged and
        does not stop the others, because the refinement is already committed.
        """
        if not self.f03 or not self.f15:
            return

        wanted = set(frame_ids)
        updated_frames = {
            r.frame_id: self._map_to_f14_result(r)
            for r in self.f03.get_frame_results(flight_id)
            if r.frame_id in wanted
        }
        for f_id in frame_ids:
            frame = updated_frames.get(f_id)
            if frame is None:
                continue
            try:
                self.f15.send_refinement(flight_id, f_id, frame)
            except Exception:
                logger.exception(
                    "Failed to publish refinement for frame %s of flight %s", f_id, flight_id
                )

    def _apply_batch_refinement(self, flight_id: str, refined_results: List[RefinedFrameResult]) -> bool:
        """Persist a batch of refinements in one transaction, then publish them.

        Returns False without a database, True when no stored frame matched
        (no transaction is executed), otherwise the transaction's result.
        """
        if not self.f03:
            return False

        operations = self._build_batch_refinement_transaction(flight_id, refined_results)
        if not operations:
            return True

        success = self.f03.execute_transaction(operations)
        if success:
            self._publish_refinement_events(flight_id, [r.frame_id for r in refined_results])
        return success

    def mark_refined(self, flight_id: str, refined_results: List[RefinedFrameResult]) -> bool:
        """Apply refined frame results (see ``_apply_batch_refinement``)."""
        return self._apply_batch_refinement(flight_id, refined_results)

    def update_results_after_chunk_merge(self, flight_id: str, refined_results: List[RefinedFrameResult]) -> bool:
        """Apply refined results from a chunk merge (see ``_apply_batch_refinement``)."""
        return self._apply_batch_refinement(flight_id, refined_results)

    # --- Change tracking ---

    def _safe_dt_compare(self, dt1: datetime, dt2: datetime) -> bool:
        """Return True if ``dt1`` is strictly later than ``dt2``.

        Aware values are converted to naive UTC before comparing, so aware and
        naive datetimes can be mixed. Naive values are taken to already be UTC,
        as produced by ``datetime.utcnow``.
        """
        def to_naive_utc(dt: datetime) -> datetime:
            offset = dt.utcoffset()
            naive = dt.replace(tzinfo=None)
            # Subtracting the UTC offset turns local wall time into UTC wall time.
            return naive - offset if offset is not None else naive

        return to_naive_utc(dt1) > to_naive_utc(dt2)

    def get_changed_frames(self, flight_id: str, since: datetime) -> List[int]:
        """Return ids of stored frames whose ``updated_at`` is strictly after ``since``."""
        if not self.f03:
            return []
        return [
            r.frame_id
            for r in self.f03.get_frame_results(flight_id)
            if self._safe_dt_compare(r.updated_at, since)
        ]

    # --- Export ---

    def export_results(self, flight_id: str, format: str) -> str:
        """Serialize the results of ``flight_id`` as JSON, CSV or KML.

        ``format`` is case-insensitive. CSV and KML entries are ordered by
        frame_id. An unsupported format is logged and yields an empty string.
        """
        fmt = format.lower()
        if fmt not in ("json", "csv", "kml"):
            logger.warning("Unsupported export format: %s", format)
            return ""

        results = self.get_flight_results(flight_id)
        if fmt == "json":
            return results.model_dump_json(indent=2)

        frames = sorted(results.frames, key=lambda x: x.frame_id)
        if fmt == "csv":
            # error_m is not estimated here and is written as 0.0.
            lines = ["image,sequence,lat,lon,altitude_m,error_m,confidence,source"]
            lines.extend(
                f"AD{f.frame_id:06d}.jpg,{f.frame_id},{f.gps_center.lat},{f.gps_center.lon},{f.altitude},0.0,{f.confidence},factor_graph"
                for f in frames
            )
            return "\n".join(lines)

        # KML coordinates are ordered lon,lat[,alt].
        kml = ['<?xml version="1.0" encoding="UTF-8"?><kml xmlns="http://www.opengis.net/kml/2.2"><Document>']
        kml.extend(
            f"<Placemark><name>AD{f.frame_id:06d}.jpg</name><Point><coordinates>{f.gps_center.lon},{f.gps_center.lat},{f.altitude}</coordinates></Point></Placemark>"
            for f in frames
        )
        kml.append("</Document></kml>")
        return "\n".join(kml)