import sqlite3
import json
import csv
import math
import os
import uuid
import logging
import statistics
from contextlib import closing
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any, Union

from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)


# --- Helper Functions ---

def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Return the great-circle distance between two lat/lon points in meters.

    Uses the haversine formula on a spherical Earth (mean radius 6371 km).
    """
    R = 6371000.0  # mean Earth radius in meters
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)
    a = (math.sin(delta_phi / 2.0) ** 2
         + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2.0) ** 2)
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    return R * c


# --- Data Models ---

class GPSPoint(BaseModel):
    lat: float
    lon: float
    altitude_m: Optional[float] = 400.0  # default assumed flight altitude


class ResultData(BaseModel):
    """One (possibly refined) localization result for a single image."""
    result_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    flight_id: str
    image_id: str
    sequence_number: int
    version: int = 1  # assigned automatically by ResultManager on store
    estimated_gps: GPSPoint
    ground_truth_gps: Optional[GPSPoint] = None
    error_m: Optional[float] = None  # derived from ground truth when available
    confidence: float
    source: str  # e.g., "L3", "factor_graph", "user"
    processing_time_ms: float = 0.0
    # default_factory avoids a shared mutable default dict
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # NOTE(review): naive-UTC ISO string kept for backward compatibility with
    # already-stored rows; datetime.utcnow() is deprecated in Python 3.12 —
    # switching to an aware timestamp would change the stored format.
    created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
    refinement_reason: Optional[str] = None


class ResultStatistics(BaseModel):
    """Aggregate accuracy / coverage metrics for one flight."""
    mean_error_m: float
    median_error_m: float
    rmse_m: float
    percent_under_50m: float
    percent_under_20m: float
    max_error_m: float
    registration_rate: float  # processed / total, in percent
    total_images: int
    processed_images: int


# --- Implementation ---

class ResultManager:
    """
    F13: Result Manager.

    Handles persistence, versioning (AC-8 refinement), statistics
    calculations, and format exports (CSV, JSON, KML) for the
    localization results.
    """

    def __init__(self, db_path: str = "./results_cache.db"):
        self.db_path = db_path
        self._init_db()
        logger.info("ResultManager initialized with DB at %s", self.db_path)

    def _get_conn(self) -> sqlite3.Connection:
        """Open a new autocommit connection with dict-like row access.

        Callers are responsible for closing it (use contextlib.closing).
        """
        conn = sqlite3.connect(self.db_path, isolation_level=None)  # autocommit; txns issued explicitly
        conn.row_factory = sqlite3.Row
        return conn

    def _init_db(self) -> None:
        """Create the results table and indices if they do not exist."""
        with closing(self._get_conn()) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS results (
                    result_id TEXT PRIMARY KEY,
                    flight_id TEXT,
                    image_id TEXT,
                    sequence_number INTEGER,
                    version INTEGER,
                    est_lat REAL, est_lon REAL, est_alt REAL,
                    gt_lat REAL, gt_lon REAL,
                    error_m REAL,
                    confidence REAL,
                    source TEXT,
                    processing_time_ms REAL,
                    metadata TEXT,
                    created_at TEXT,
                    refinement_reason TEXT
                )
            ''')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_flight_image ON results(flight_id, image_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_flight_seq ON results(flight_id, sequence_number)')

    def _row_to_result(self, row: sqlite3.Row) -> ResultData:
        """Deserialize one DB row back into a ResultData model."""
        gt_gps = None
        if row['gt_lat'] is not None and row['gt_lon'] is not None:
            gt_gps = GPSPoint(lat=row['gt_lat'], lon=row['gt_lon'])
        return ResultData(
            result_id=row['result_id'],
            flight_id=row['flight_id'],
            image_id=row['image_id'],
            sequence_number=row['sequence_number'],
            version=row['version'],
            estimated_gps=GPSPoint(lat=row['est_lat'], lon=row['est_lon'], altitude_m=row['est_alt']),
            ground_truth_gps=gt_gps,
            error_m=row['error_m'],
            confidence=row['confidence'],
            source=row['source'],
            processing_time_ms=row['processing_time_ms'],
            metadata=json.loads(row['metadata']) if row['metadata'] else {},
            created_at=row['created_at'],
            refinement_reason=row['refinement_reason'],
        )

    def _compute_error(self, result: ResultData) -> None:
        """Calculates distance error if ground truth is available."""
        if result.ground_truth_gps and result.estimated_gps:
            result.error_m = haversine_distance(
                result.estimated_gps.lat, result.estimated_gps.lon,
                result.ground_truth_gps.lat, result.ground_truth_gps.lon,
            )

    def _next_version(self, conn: sqlite3.Connection, flight_id: str, image_id: str) -> int:
        """Return MAX(version)+1 for the (flight, image) pair (1 if none)."""
        cursor = conn.execute(
            'SELECT MAX(version) as max_v FROM results WHERE flight_id=? AND image_id=?',
            (flight_id, image_id),
        )
        row = cursor.fetchone()
        max_v = row['max_v'] if row['max_v'] is not None else 0
        return max_v + 1

    def _insert_row(self, conn: sqlite3.Connection, result: ResultData) -> None:
        """Insert one fully-populated ResultData as a new row."""
        conn.execute('''
            INSERT INTO results (
                result_id, flight_id, image_id, sequence_number, version,
                est_lat, est_lon, est_alt, gt_lat, gt_lon, error_m,
                confidence, source, processing_time_ms, metadata,
                created_at, refinement_reason
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', (
            result.result_id, result.flight_id, result.image_id,
            result.sequence_number, result.version,
            result.estimated_gps.lat, result.estimated_gps.lon,
            result.estimated_gps.altitude_m,
            result.ground_truth_gps.lat if result.ground_truth_gps else None,
            result.ground_truth_gps.lon if result.ground_truth_gps else None,
            result.error_m, result.confidence, result.source,
            result.processing_time_ms, json.dumps(result.metadata),
            result.created_at, result.refinement_reason,
        ))

    def store_result(self, result: ResultData) -> ResultData:
        """Store a new result, assigning the next version atomically.

        The MAX(version) lookup and the INSERT run inside a single
        BEGIN IMMEDIATE transaction so concurrent writers cannot claim
        the same version number.
        """
        self._compute_error(result)
        with closing(self._get_conn()) as conn:
            try:
                conn.execute('BEGIN IMMEDIATE')
                result.version = self._next_version(conn, result.flight_id, result.image_id)
                self._insert_row(conn, result)
                conn.execute('COMMIT')
            except Exception:
                conn.execute('ROLLBACK')
                raise
        return result

    def store_results_batch(self, results: List[ResultData]) -> List[ResultData]:
        """Atomically store a batch of results (all-or-nothing).

        Unlike calling store_result() per item, the whole batch shares
        one connection and one transaction, so a failure mid-batch
        leaves the database unchanged.
        """
        with closing(self._get_conn()) as conn:
            try:
                conn.execute('BEGIN IMMEDIATE')
                for r in results:
                    self._compute_error(r)
                    r.version = self._next_version(conn, r.flight_id, r.image_id)
                    self._insert_row(conn, r)
                conn.execute('COMMIT')
            except Exception:
                conn.execute('ROLLBACK')
                raise
        return results

    def get_result(self, flight_id: str, image_id: str,
                   include_all_versions: bool = False) -> Union[ResultData, List[ResultData], None]:
        """Retrieve results for a specific image.

        Returns the latest version (or None), or — when
        include_all_versions is True — all versions oldest-first (possibly []).
        """
        with closing(self._get_conn()) as conn:
            if include_all_versions:
                cursor = conn.execute(
                    'SELECT * FROM results WHERE flight_id=? AND image_id=? ORDER BY version ASC',
                    (flight_id, image_id),
                )
                return [self._row_to_result(row) for row in cursor.fetchall()]
            cursor = conn.execute(
                'SELECT * FROM results WHERE flight_id=? AND image_id=? ORDER BY version DESC LIMIT 1',
                (flight_id, image_id),
            )
            row = cursor.fetchone()
            return self._row_to_result(row) if row else None

    def get_flight_results(self, flight_id: str,
                           latest_version_only: bool = True,
                           min_confidence: float = 0.0,
                           max_error: float = float('inf')) -> List[ResultData]:
        """Retrieve flight results matching the given filters, in sequence order.

        Rows with NULL error_m (no ground truth) are never excluded by
        the max_error filter.
        """
        with closing(self._get_conn()) as conn:
            if latest_version_only:
                # Join against the per-image max version so only the newest
                # refinement of each image is returned.
                query = '''
                    SELECT r.* FROM results r
                    INNER JOIN (
                        SELECT image_id, MAX(version) as max_v
                        FROM results WHERE flight_id=? GROUP BY image_id
                    ) grouped_r
                    ON r.image_id = grouped_r.image_id AND r.version = grouped_r.max_v
                    WHERE r.flight_id=? AND r.confidence >= ?
                '''
                params: List[Any] = [flight_id, flight_id, min_confidence]
            else:
                # Alias the table as `r` so the appended filter/order clauses
                # below are valid in both branches (previously this branch
                # produced "no such column: r.error_m").
                query = 'SELECT r.* FROM results r WHERE r.flight_id=? AND r.confidence >= ?'
                params = [flight_id, min_confidence]
            if max_error < float('inf'):
                query += ' AND (r.error_m IS NULL OR r.error_m <= ?)'
                params.append(max_error)
            query += ' ORDER BY r.sequence_number ASC'
            cursor = conn.execute(query, tuple(params))
            return [self._row_to_result(row) for row in cursor.fetchall()]

    def get_result_history(self, flight_id: str, image_id: str) -> List[ResultData]:
        """Retrieves the timeline of versions for a specific image."""
        return self.get_result(flight_id, image_id, include_all_versions=True)

    def store_user_fix(self, flight_id: str, image_id: str,
                       sequence_number: int, gps: GPSPoint) -> ResultData:
        """Stores a manual user-provided coordinate anchor (AC-6)."""
        result = ResultData(
            flight_id=flight_id,
            image_id=image_id,
            sequence_number=sequence_number,
            estimated_gps=gps,
            confidence=1.0,
            source="user",
            refinement_reason="Manual User Fix",
        )
        return self.store_result(result)

    def calculate_statistics(self, flight_id: str,
                             total_flight_images: int = 0) -> Optional[ResultStatistics]:
        """Calculates performance validation metrics (AC-1, AC-2, AC-9).

        Returns None when the flight has no results at all. When no
        result carries ground truth, spatial metrics are reported as 0.
        """
        results = self.get_flight_results(flight_id, latest_version_only=True)
        if not results:
            return None
        errors = [r.error_m for r in results if r.error_m is not None]
        processed_count = len(results)
        total_count = max(total_flight_images, processed_count)
        if not errors:
            # No ground truth available: only coverage can be reported.
            return ResultStatistics(
                mean_error_m=0.0, median_error_m=0.0, rmse_m=0.0,
                percent_under_50m=0.0, percent_under_20m=0.0, max_error_m=0.0,
                registration_rate=(processed_count / total_count) * 100.0,
                total_images=total_count,
                processed_images=processed_count,
            )
        n = len(errors)
        mean_err = sum(errors) / n
        # statistics.median averages the two middle values for even n
        # (the previous index-based pick did not).
        median_err = statistics.median(errors)
        rmse = math.sqrt(sum(e ** 2 for e in errors) / n)
        pct_50 = sum(1 for e in errors if e <= 50.0) / n * 100.0
        pct_20 = sum(1 for e in errors if e <= 20.0) / n * 100.0
        return ResultStatistics(
            mean_error_m=mean_err,
            median_error_m=median_err,
            rmse_m=rmse,
            percent_under_50m=pct_50,
            percent_under_20m=pct_20,
            max_error_m=max(errors),
            registration_rate=(processed_count / total_count) * 100.0 if total_count else 100.0,
            total_images=total_count,
            processed_images=processed_count,
        )

    def export_results(self, flight_id: str, format: str = "json",
                       filepath: Optional[str] = None) -> str:
        """Exports flight results to the specified format (json, csv, kml).

        Returns the path written. Raises ValueError on an unknown format.
        """
        results = self.get_flight_results(flight_id, latest_version_only=True)
        if not filepath:
            # Aware UTC timestamp: naive utcnow().timestamp() would be
            # interpreted in local time and yield a wrong epoch.
            filepath = f"./export_{flight_id}_{int(datetime.now(timezone.utc).timestamp())}.{format}"
        if format == "json":
            data = {
                "flight_id": flight_id,
                "total_images": len(results),
                "results": [
                    {
                        "image": r.image_id,
                        "sequence": r.sequence_number,
                        "gps": {"lat": r.estimated_gps.lat, "lon": r.estimated_gps.lon},
                        "error_m": r.error_m,
                        "confidence": r.confidence,
                    }
                    for r in results
                ],
            }
            with open(filepath, 'w') as f:
                json.dump(data, f, indent=2)
        elif format == "csv":
            with open(filepath, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(["image", "sequence", "lat", "lon", "altitude_m",
                                 "error_m", "confidence", "source"])
                for r in results:
                    writer.writerow([
                        r.image_id,
                        r.sequence_number,
                        r.estimated_gps.lat,
                        r.estimated_gps.lon,
                        r.estimated_gps.altitude_m,
                        # `is not None` so a genuine 0.0 error is not dropped
                        r.error_m if r.error_m is not None else "",
                        r.confidence,
                        r.source,
                    ])
        elif format == "kml":
            # NOTE(review): the original KML tag literals were garbled in this
            # source; reconstructed as standard KML 2.2 Placemark markup.
            kml_content = [
                '<?xml version="1.0" encoding="UTF-8"?>',
                '<kml xmlns="http://www.opengis.net/kml/2.2">',
                '  <Document>',
            ]
            for r in results:
                # `is not None` so a genuine 0.0 altitude is preserved
                alt = r.estimated_gps.altitude_m if r.estimated_gps.altitude_m is not None else 400.0
                kml_content.append('    <Placemark>')
                kml_content.append(f'      <name>{r.image_id}</name>')
                kml_content.append('      <Point>')
                # KML coordinate order is lon,lat,alt
                kml_content.append(f'        <coordinates>{r.estimated_gps.lon},{r.estimated_gps.lat},{alt}</coordinates>')
                kml_content.append('      </Point>')
                kml_content.append('    </Placemark>')
            kml_content.extend(['  </Document>', '</kml>'])
            with open(filepath, 'w') as f:
                f.write("\n".join(kml_content))
        else:
            raise ValueError(f"Unsupported export format: {format}")
        logger.info("Exported %d results to %s", len(results), filepath)
        return filepath