# gps-denied-onboard/f13_result_manager.py
# Denys Zaitsev, commit d7e1066c60 ("Initial commit"), 2026-04-03 23:25:54 +03:00
# Original listing: 325 lines, 14 KiB, Python.

import csv
import json
import logging
import math
import os
import sqlite3
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field
# Module-level logger named after this module; ResultManager reports
# initialization and exports through it.
logger = logging.getLogger(__name__)
# --- Helper Functions ---
def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Return the great-circle distance between two points in meters.

    Uses the haversine formula on a sphere of radius 6 371 000 m.

    Args:
        lat1: Latitude of the first point, in degrees.
        lon1: Longitude of the first point, in degrees.
        lat2: Latitude of the second point, in degrees.
        lon2: Longitude of the second point, in degrees.

    Returns:
        The non-negative surface distance between the points, in meters.
    """
    earth_radius_m = 6371000.0
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)
    a = (
        math.sin(delta_phi / 2.0) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2.0) ** 2
    )
    # Floating-point rounding can push ``a`` marginally outside [0, 1] for
    # (near-)antipodal points, which would make math.sqrt(1 - a) raise
    # "math domain error". Clamp it back into the valid range.
    a = min(1.0, max(0.0, a))
    c = 2.0 * math.atan2(math.sqrt(a), math.sqrt(1.0 - a))
    return earth_radius_m * c
# --- Data Models ---
class GPSPoint(BaseModel):
    """A geographic position: latitude/longitude plus an optional altitude."""

    # Both coordinates are in degrees (they are fed to math.radians and
    # written as KML coordinates elsewhere in this module).
    lat: float
    lon: float
    # Altitude in meters; defaults to 400.0 when not supplied, may be None.
    altitude_m: Union[float, None] = 400.0
class ResultData(BaseModel):
    """One versioned localization result for a single image of a flight."""

    result_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    flight_id: str
    image_id: str
    sequence_number: int
    # Overwritten by ResultManager when the result is stored.
    version: int = 1
    estimated_gps: GPSPoint
    ground_truth_gps: Optional[GPSPoint] = None
    # Distance (meters) between estimate and ground truth, when ground truth exists.
    error_m: Optional[float] = None
    confidence: float
    source: str  # e.g., "L3", "factor_graph", "user"
    processing_time_ms: float = 0.0
    # default_factory gives each instance its own dict rather than a shared literal.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Naive ISO-8601 UTC timestamp (same format as the former datetime.utcnow(),
    # which is deprecated since Python 3.12).
    created_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    )
    refinement_reason: Optional[str] = None
class ResultStatistics(BaseModel):
    """Accuracy and coverage summary for one flight.

    Produced by ResultManager.calculate_statistics: error figures are in
    meters, while the ``percent_*`` fields and ``registration_rate`` are
    percentages computed there as ratios multiplied by 100.
    """

    # Error distribution over the images that have an error_m value.
    mean_error_m: float
    median_error_m: float
    rmse_m: float
    percent_under_50m: float  # share of errors <= 50 m
    percent_under_20m: float  # share of errors <= 20 m
    max_error_m: float
    # Coverage: processed_images / total_images * 100.
    registration_rate: float
    total_images: int
    processed_images: int
# --- Implementation ---
class ResultManager:
    """
    F13: Result Manager.

    Handles persistence, versioning (AC-8 refinement), statistics calculations,
    and format exports (CSV, JSON, KML) for the localization results.

    Every public method opens its own SQLite connection and closes it before
    returning. Writes run inside an explicit ``BEGIN IMMEDIATE`` transaction,
    so version numbering (read ``MAX(version)``, then ``INSERT``) cannot
    interleave with another writer, and batch stores are all-or-nothing.
    """

    def __init__(self, db_path: str = "./results_cache.db"):
        """Open the results database and ensure its schema exists.

        Args:
            db_path: Filesystem path of the SQLite database file.
        """
        self.db_path = db_path
        self._init_db()
        logger.info("ResultManager initialized with DB at %s", self.db_path)

    def _get_conn(self) -> sqlite3.Connection:
        """Open a new connection whose rows can be indexed by column name.

        The caller owns the connection and must close it; internal code goes
        through ``_connection``, which does so automatically.
        """
        # isolation_level=None disables the sqlite3 module's implicit
        # transactions; write transactions are opened explicitly instead.
        conn = sqlite3.connect(self.db_path, isolation_level=None)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _connection(self, write: bool = False):
        """Yield a connection from ``_get_conn`` and always close it afterwards.

        Args:
            write: When True, the ``with`` body runs inside a
                ``BEGIN IMMEDIATE`` transaction that is committed on success
                and rolled back if the body raises.
        """
        conn = self._get_conn()
        try:
            if not write:
                yield conn
                return
            # IMMEDIATE takes SQLite's write lock up front, so no other writer
            # can insert between this transaction's reads and its writes.
            conn.execute("BEGIN IMMEDIATE")
            try:
                yield conn
            except BaseException:
                # SQLite may already have rolled back on some errors.
                if conn.in_transaction:
                    conn.execute("ROLLBACK")
                raise
            conn.execute("COMMIT")
        finally:
            # `with sqlite3.connect(...)` alone only ends a transaction and
            # leaves the connection open, so close it explicitly here.
            conn.close()

    def _init_db(self) -> None:
        """Create the ``results`` table and its lookup indexes if absent."""
        with self._connection(write=True) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS results (
                    result_id TEXT PRIMARY KEY,
                    flight_id TEXT,
                    image_id TEXT,
                    sequence_number INTEGER,
                    version INTEGER,
                    est_lat REAL,
                    est_lon REAL,
                    est_alt REAL,
                    gt_lat REAL,
                    gt_lon REAL,
                    error_m REAL,
                    confidence REAL,
                    source TEXT,
                    processing_time_ms REAL,
                    metadata TEXT,
                    created_at TEXT,
                    refinement_reason TEXT
                )
            ''')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_flight_image ON results(flight_id, image_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_flight_seq ON results(flight_id, sequence_number)')

    def _row_to_result(self, row: sqlite3.Row) -> ResultData:
        """Convert a ``results`` table row back into a ``ResultData``."""
        gt_gps = None
        if row['gt_lat'] is not None and row['gt_lon'] is not None:
            # Ground-truth altitude is not persisted (no column), so it takes
            # GPSPoint's default.
            gt_gps = GPSPoint(lat=row['gt_lat'], lon=row['gt_lon'])
        return ResultData(
            result_id=row['result_id'],
            flight_id=row['flight_id'],
            image_id=row['image_id'],
            sequence_number=row['sequence_number'],
            version=row['version'],
            estimated_gps=GPSPoint(lat=row['est_lat'], lon=row['est_lon'], altitude_m=row['est_alt']),
            ground_truth_gps=gt_gps,
            error_m=row['error_m'],
            confidence=row['confidence'],
            source=row['source'],
            processing_time_ms=row['processing_time_ms'],
            metadata=json.loads(row['metadata']) if row['metadata'] else {},
            created_at=row['created_at'],
            refinement_reason=row['refinement_reason'],
        )

    def _compute_error(self, result: ResultData) -> None:
        """Set ``result.error_m`` to the haversine distance when ground truth exists."""
        if result.ground_truth_gps is not None and result.estimated_gps is not None:
            result.error_m = haversine_distance(
                result.estimated_gps.lat, result.estimated_gps.lon,
                result.ground_truth_gps.lat, result.ground_truth_gps.lon,
            )

    def _insert_result(self, conn: sqlite3.Connection, result: ResultData) -> None:
        """Assign ``result`` the next version for its image and INSERT it via ``conn``.

        Must run inside a write transaction so the MAX(version) read and the
        INSERT are atomic.
        """
        self._compute_error(result)
        row = conn.execute(
            'SELECT MAX(version) as max_v FROM results WHERE flight_id=? AND image_id=?',
            (result.flight_id, result.image_id),
        ).fetchone()
        max_v = row['max_v'] if row['max_v'] is not None else 0
        result.version = max_v + 1
        gt = result.ground_truth_gps
        conn.execute('''
            INSERT INTO results (
                result_id, flight_id, image_id, sequence_number, version,
                est_lat, est_lon, est_alt, gt_lat, gt_lon, error_m,
                confidence, source, processing_time_ms, metadata, created_at, refinement_reason
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', (
            result.result_id, result.flight_id, result.image_id, result.sequence_number, result.version,
            result.estimated_gps.lat, result.estimated_gps.lon, result.estimated_gps.altitude_m,
            gt.lat if gt else None,
            gt.lon if gt else None,
            result.error_m, result.confidence, result.source, result.processing_time_ms,
            json.dumps(result.metadata), result.created_at, result.refinement_reason,
        ))

    def store_result(self, result: ResultData) -> ResultData:
        """Store a new result as the next version for its (flight_id, image_id).

        ``result`` is updated in place: ``error_m`` is recomputed when ground
        truth is present, and ``version`` is set to one past the highest
        stored version for that image (1 for a new image).

        Returns:
            The same ``result`` object.
        """
        with self._connection(write=True) as conn:
            self._insert_result(conn, result)
        return result

    def store_results_batch(self, results: List[ResultData]) -> List[ResultData]:
        """Atomically store a batch of results in one transaction.

        Results are versioned in list order (entries for the same image get
        consecutive versions). If any insert fails, nothing is committed,
        although the in-memory ``version``/``error_m`` of earlier entries will
        already have been updated.

        Returns:
            The same ``results`` list.
        """
        with self._connection(write=True) as conn:
            for result in results:
                self._insert_result(conn, result)
        return results

    def get_result(self, flight_id: str, image_id: str, include_all_versions: bool = False) -> Union[ResultData, List[ResultData], None]:
        """Retrieve results for a specific image.

        Returns:
            With ``include_all_versions``: every version, oldest first (empty
            list if none). Otherwise: the latest version, or None.
        """
        with self._connection() as conn:
            if include_all_versions:
                rows = conn.execute(
                    'SELECT * FROM results WHERE flight_id=? AND image_id=? ORDER BY version ASC',
                    (flight_id, image_id),
                ).fetchall()
                return [self._row_to_result(row) for row in rows]
            row = conn.execute(
                'SELECT * FROM results WHERE flight_id=? AND image_id=? ORDER BY version DESC LIMIT 1',
                (flight_id, image_id),
            ).fetchone()
            return self._row_to_result(row) if row is not None else None

    def get_flight_results(self, flight_id: str, latest_version_only: bool = True, min_confidence: float = 0.0, max_error: float = float('inf')) -> List[ResultData]:
        """Retrieve a flight's results matching the given filters.

        Args:
            flight_id: Flight to query.
            latest_version_only: Keep only the highest version per image. The
                confidence/error filters are applied after that selection.
            min_confidence: Minimum ``confidence`` (inclusive).
            max_error: Maximum ``error_m`` (inclusive); rows without an error
                value always pass. No filter when infinite.

        Returns:
            Matching results ordered by sequence number, then version.
        """
        params: List[Any]
        if latest_version_only:
            # Subquery picks the latest version per image of this flight.
            query = '''
                SELECT r.* FROM results r
                INNER JOIN (
                    SELECT image_id, MAX(version) as max_v
                    FROM results WHERE flight_id=? GROUP BY image_id
                ) grouped_r ON r.image_id = grouped_r.image_id AND r.version = grouped_r.max_v
                WHERE r.flight_id=? AND r.confidence >= ?
            '''
            params = [flight_id, flight_id, min_confidence]
        else:
            # Alias the table as ``r`` here too, so the shared filter and
            # ORDER BY clauses appended below resolve in both branches.
            query = 'SELECT r.* FROM results r WHERE r.flight_id=? AND r.confidence >= ?'
            params = [flight_id, min_confidence]
        if max_error < float('inf'):
            query += ' AND (r.error_m IS NULL OR r.error_m <= ?)'
            params.append(max_error)
        query += ' ORDER BY r.sequence_number ASC, r.version ASC'
        with self._connection() as conn:
            rows = conn.execute(query, tuple(params)).fetchall()
        return [self._row_to_result(row) for row in rows]

    def get_result_history(self, flight_id: str, image_id: str) -> List[ResultData]:
        """Retrieve every stored version of an image's result, oldest first."""
        return self.get_result(flight_id, image_id, include_all_versions=True)

    def store_user_fix(self, flight_id: str, image_id: str, sequence_number: int, gps: GPSPoint) -> ResultData:
        """Store a manual user-provided coordinate anchor (AC-6) as a new version.

        The fix is stored with confidence 1.0 and source "user".
        """
        result = ResultData(
            flight_id=flight_id,
            image_id=image_id,
            sequence_number=sequence_number,
            estimated_gps=gps,
            confidence=1.0,
            source="user",
            refinement_reason="Manual User Fix",
        )
        return self.store_result(result)

    def calculate_statistics(self, flight_id: str, total_flight_images: int = 0) -> Optional[ResultStatistics]:
        """Calculate performance validation metrics (AC-1, AC-2, AC-9).

        Only the latest version of each image is considered. Error metrics use
        the images that have an ``error_m``; when none do, they are all 0.0.

        Args:
            flight_id: Flight to summarize.
            total_flight_images: Number of images in the flight, used as the
                registration-rate denominator (raised to the number of stored
                images if smaller).

        Returns:
            The statistics, or None if the flight has no stored results.
        """
        results = self.get_flight_results(flight_id, latest_version_only=True)
        if not results:
            return None
        processed_count = len(results)
        total_count = max(total_flight_images, processed_count)
        # total_count >= processed_count >= 1 here, so this cannot divide by zero.
        registration_rate = (processed_count / total_count) * 100.0
        errors = sorted(r.error_m for r in results if r.error_m is not None)
        if not errors:
            # No ground truth to compute spatial stats.
            return ResultStatistics(
                mean_error_m=0.0, median_error_m=0.0, rmse_m=0.0,
                percent_under_50m=0.0, percent_under_20m=0.0, max_error_m=0.0,
                registration_rate=registration_rate,
                total_images=total_count, processed_images=processed_count,
            )
        n = len(errors)
        mid = n // 2
        # For an even count the median is the mean of the two middle values.
        median_err = errors[mid] if n % 2 else (errors[mid - 1] + errors[mid]) / 2.0
        return ResultStatistics(
            mean_error_m=sum(errors) / n,
            median_error_m=median_err,
            rmse_m=math.sqrt(sum(e ** 2 for e in errors) / n),
            percent_under_50m=sum(1 for e in errors if e <= 50.0) / n * 100.0,
            percent_under_20m=sum(1 for e in errors if e <= 20.0) / n * 100.0,
            max_error_m=errors[-1],
            registration_rate=registration_rate,
            total_images=total_count,
            processed_images=processed_count,
        )

    def export_results(self, flight_id: str, format: str = "json", filepath: Optional[str] = None) -> str:
        """Export the latest version of each flight result to a file.

        Args:
            flight_id: Flight to export.
            format: One of "json", "csv", "kml".
            filepath: Destination path; defaults to
                ``./export_<flight_id>_<unix-seconds>.<format>``.

        Returns:
            The path written.

        Raises:
            ValueError: If ``format`` is not supported (checked before any I/O).
        """
        writers = {
            "json": self._write_json,
            "csv": self._write_csv,
            "kml": self._write_kml,
        }
        writer = writers.get(format)
        if writer is None:
            raise ValueError(f"Unsupported export format: {format}")
        results = self.get_flight_results(flight_id, latest_version_only=True)
        if not filepath:
            # An aware UTC datetime yields the true epoch; the naive value from
            # utcnow() would be interpreted as local time by .timestamp().
            filepath = f"./export_{flight_id}_{int(datetime.now(timezone.utc).timestamp())}.{format}"
        writer(flight_id, results, filepath)
        logger.info("Exported %d results to %s", len(results), filepath)
        return filepath

    @staticmethod
    def _write_json(flight_id: str, results: List[ResultData], filepath: str) -> None:
        """Write a JSON document with position, error and confidence per image."""
        data = {
            "flight_id": flight_id,
            "total_images": len(results),
            "results": [
                {
                    "image": r.image_id,
                    "sequence": r.sequence_number,
                    "gps": {"lat": r.estimated_gps.lat, "lon": r.estimated_gps.lon},
                    "error_m": r.error_m,
                    "confidence": r.confidence,
                }
                for r in results
            ],
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    @staticmethod
    def _write_csv(flight_id: str, results: List[ResultData], filepath: str) -> None:
        """Write one CSV row per result (``flight_id`` is unused)."""
        with open(filepath, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(["image", "sequence", "lat", "lon", "altitude_m", "error_m", "confidence", "source"])
            for r in results:
                writer.writerow([
                    r.image_id, r.sequence_number, r.estimated_gps.lat, r.estimated_gps.lon,
                    r.estimated_gps.altitude_m,
                    # Only a missing error becomes blank; a genuine 0.0 is kept.
                    "" if r.error_m is None else r.error_m,
                    r.confidence, r.source,
                ])

    @staticmethod
    def _write_kml(flight_id: str, results: List[ResultData], filepath: str) -> None:
        """Write a KML document with one Placemark per result (``flight_id`` is unused)."""
        from xml.sax.saxutils import escape

        kml_content = [
            '<?xml version="1.0" encoding="UTF-8"?>',
            '<kml xmlns="http://www.opengis.net/kml/2.2">',
            ' <Document>',
        ]
        for r in results:
            alt = r.estimated_gps.altitude_m
            if alt is None:
                # Only a missing altitude gets the default; 0.0 is a valid value.
                alt = 400.0
            kml_content.extend([
                ' <Placemark>',
                # Escape &, < and > so arbitrary image ids keep the XML well-formed.
                f' <name>{escape(r.image_id)}</name>',
                ' <Point>',
                f' <coordinates>{r.estimated_gps.lon},{r.estimated_gps.lat},{alt}</coordinates>',
                ' </Point>',
                ' </Placemark>',
            ])
        kml_content.extend([' </Document>', '</kml>'])
        # Encoding matches the UTF-8 declared in the XML prolog.
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write("\n".join(kml_content))