# Mirror of https://github.com/azaion/gps-denied-onboard.git
# Synced 2026-04-22 11:16:37 +00:00
# 325 lines, 14 KiB, Python
import sqlite3
|
|
import json
|
|
import csv
|
|
import math
|
|
import os
|
|
import uuid
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import List, Optional, Dict, Any, Union
|
|
from pydantic import BaseModel, Field
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# --- Helper Functions ---
|
|
|
|
def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Calculates the great-circle distance between two points in meters.

    Uses the haversine formula on a sphere of radius 6,371,000 m.

    Args:
        lat1: Latitude of the first point, in degrees.
        lon1: Longitude of the first point, in degrees.
        lat2: Latitude of the second point, in degrees.
        lon2: Longitude of the second point, in degrees.

    Returns:
        The distance along the sphere's surface, in meters.
    """
    R = 6371000.0  # Earth radius in meters
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)

    a = (
        math.sin(delta_phi / 2.0) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2.0) ** 2
    )
    # Mathematically 0 <= a <= 1, but rounding can overshoot by an ulp for
    # (near-)antipodal points, which would make sqrt(1 - a) raise ValueError.
    a = min(1.0, max(0.0, a))
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

    return R * c
|
|
|
|
# --- Data Models ---
|
|
|
|
class GPSPoint(BaseModel):
    """A geographic position: latitude/longitude plus an optional altitude."""

    # Latitude and longitude; the module's distance helper converts them with
    # math.radians, so they are expressed in degrees.
    lat: float
    lon: float
    # Altitude in meters; 400.0 is used when the caller supplies none.
    altitude_m: Optional[float] = Field(default=400.0)
|
|
|
|
class ResultData(BaseModel):
    """One versioned localization result for a single image of a flight."""

    # Unique row identifier; a fresh UUID4 string unless supplied.
    result_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    flight_id: str
    image_id: str
    sequence_number: int
    # Version of the estimate for this (flight_id, image_id); defaults to 1.
    version: int = 1
    estimated_gps: GPSPoint
    # Reference position, when known; enables error_m to be computed.
    ground_truth_gps: Optional[GPSPoint] = None
    # Distance in meters between estimate and ground truth, if available.
    error_m: Optional[float] = None
    confidence: float
    source: str  # e.g., "L3", "factor_graph", "user"
    processing_time_ms: float = 0.0
    # Free-form extra data. Built by a factory, like the other generated
    # defaults in this model, rather than from a shared mutable literal.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Creation time as a naive-UTC ISO-8601 string.
    created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
    # Why this version was produced (e.g. a refinement or manual fix), if any.
    refinement_reason: Optional[str] = None
|
|
|
|
class ResultStatistics(BaseModel):
    """Aggregate accuracy and coverage metrics for one flight."""

    # --- Error distribution, in meters ---
    mean_error_m: float
    median_error_m: float
    rmse_m: float

    # --- Share of results within an error threshold, as percentages (0-100) ---
    percent_under_50m: float
    percent_under_20m: float

    # Largest single error, in meters.
    max_error_m: float

    # --- Coverage ---
    # processed_images / total_images, as a percentage (0-100).
    registration_rate: float
    total_images: int
    processed_images: int
|
|
|
|
# --- Implementation ---
|
|
|
|
class ResultManager:
    """
    F13: Result Manager.

    Handles persistence, versioning (AC-8 refinement), statistics calculations,
    and format exports (CSV, JSON, KML) for the localization results.

    Results are kept in one SQLite table, ``results``. Each result stored for
    the same (flight_id, image_id) pair receives the next version number, so
    earlier estimates remain available as history. Every operation opens its
    own connection and closes it before returning.
    """

    # Formats accepted by export_results().
    _EXPORT_FORMATS = ("json", "csv", "kml")
    # Altitude written to KML when a result carries no altitude at all.
    _KML_FALLBACK_ALTITUDE_M = 400.0

    def __init__(self, db_path: str = "./results_cache.db"):
        """Creates the manager and ensures the schema exists at ``db_path``."""
        self.db_path = db_path
        self._init_db()
        logger.info("ResultManager initialized with DB at %s", self.db_path)

    # ------------------------------------------------------------------
    # Connection helpers
    # ------------------------------------------------------------------

    def _get_conn(self):
        """Opens a new autocommit connection whose rows support name access.

        The caller owns the connection and must close it.
        """
        # isolation_level=None: sqlite3 never opens transactions implicitly;
        # writers issue BEGIN/COMMIT/ROLLBACK themselves.
        conn = sqlite3.connect(self.db_path, isolation_level=None)
        conn.row_factory = sqlite3.Row
        return conn

    def _fetch_all(self, query: str, params: tuple = ()) -> List[sqlite3.Row]:
        """Runs a read query on a fresh connection and returns every row."""
        conn = self._get_conn()
        try:
            return conn.execute(query, params).fetchall()
        finally:
            # `with conn:` would only commit/rollback, never close.
            conn.close()

    def _init_db(self):
        """Creates the ``results`` table and its indexes if they are missing."""
        conn = self._get_conn()
        try:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS results (
                    result_id TEXT PRIMARY KEY,
                    flight_id TEXT,
                    image_id TEXT,
                    sequence_number INTEGER,
                    version INTEGER,
                    est_lat REAL,
                    est_lon REAL,
                    est_alt REAL,
                    gt_lat REAL,
                    gt_lon REAL,
                    error_m REAL,
                    confidence REAL,
                    source TEXT,
                    processing_time_ms REAL,
                    metadata TEXT,
                    created_at TEXT,
                    refinement_reason TEXT
                )
            ''')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_flight_image ON results(flight_id, image_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_flight_seq ON results(flight_id, sequence_number)')
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Row <-> model conversion
    # ------------------------------------------------------------------

    def _row_to_result(self, row: sqlite3.Row) -> ResultData:
        """Builds a ResultData from one ``results`` row.

        Ground truth is restored only when both gt_lat and gt_lon are present;
        the table has no ground-truth altitude column, so that GPSPoint takes
        its default altitude.
        """
        gt_gps = None
        if row['gt_lat'] is not None and row['gt_lon'] is not None:
            gt_gps = GPSPoint(lat=row['gt_lat'], lon=row['gt_lon'])

        return ResultData(
            result_id=row['result_id'],
            flight_id=row['flight_id'],
            image_id=row['image_id'],
            sequence_number=row['sequence_number'],
            version=row['version'],
            estimated_gps=GPSPoint(lat=row['est_lat'], lon=row['est_lon'], altitude_m=row['est_alt']),
            ground_truth_gps=gt_gps,
            error_m=row['error_m'],
            confidence=row['confidence'],
            source=row['source'],
            processing_time_ms=row['processing_time_ms'],
            metadata=json.loads(row['metadata']) if row['metadata'] else {},
            created_at=row['created_at'],
            refinement_reason=row['refinement_reason'],
        )

    def _compute_error(self, result: ResultData) -> None:
        """Calculates distance error if ground truth is available.

        Leaves ``result.error_m`` untouched when there is no ground truth.
        """
        if result.ground_truth_gps is not None and result.estimated_gps is not None:
            result.error_m = haversine_distance(
                result.estimated_gps.lat, result.estimated_gps.lon,
                result.ground_truth_gps.lat, result.ground_truth_gps.lon,
            )

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def _insert_result(self, conn: sqlite3.Connection, result: ResultData) -> None:
        """Assigns the next version to ``result`` and inserts it via ``conn``.

        Must run inside a transaction opened by the caller so that reading
        MAX(version) and inserting the new row happen as one unit.
        """
        row = conn.execute(
            'SELECT MAX(version) as max_v FROM results WHERE flight_id=? AND image_id=?',
            (result.flight_id, result.image_id),
        ).fetchone()
        max_v = row['max_v'] if row['max_v'] is not None else 0
        result.version = max_v + 1

        gt = result.ground_truth_gps
        conn.execute('''
            INSERT INTO results (
                result_id, flight_id, image_id, sequence_number, version,
                est_lat, est_lon, est_alt, gt_lat, gt_lon, error_m,
                confidence, source, processing_time_ms, metadata, created_at, refinement_reason
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', (
            result.result_id, result.flight_id, result.image_id, result.sequence_number, result.version,
            result.estimated_gps.lat, result.estimated_gps.lon, result.estimated_gps.altitude_m,
            gt.lat if gt is not None else None,
            gt.lon if gt is not None else None,
            result.error_m, result.confidence, result.source, result.processing_time_ms,
            json.dumps(result.metadata), result.created_at, result.refinement_reason,
        ))

    def _store_in_transaction(self, results: List[ResultData]) -> None:
        """Stores all ``results`` in one transaction: all rows land or none do.

        On failure the transaction is rolled back and the exception re-raised.
        The in-memory objects may already carry an updated ``error_m`` and
        ``version`` at that point.
        """
        if not results:
            return

        conn = self._get_conn()
        try:
            # IMMEDIATE takes the write lock up front, so the MAX(version)
            # read cannot be invalidated by another writer before the insert.
            conn.execute('BEGIN IMMEDIATE')
            try:
                for result in results:
                    self._compute_error(result)
                    self._insert_result(conn, result)
            except Exception:
                conn.execute('ROLLBACK')
                raise
            conn.execute('COMMIT')
        finally:
            conn.close()

    def store_result(self, result: ResultData) -> ResultData:
        """Stores a new result. Automatically handles version increments.

        ``result.error_m`` is recomputed when ground truth is present and
        ``result.version`` is overwritten with the assigned version.

        Returns:
            The same ``result`` object, updated in place.
        """
        self._store_in_transaction([result])
        return result

    def store_results_batch(self, results: List[ResultData]) -> List[ResultData]:
        """Atomically stores a batch of results.

        All results are written in a single transaction; if any insert fails
        nothing from the batch is persisted and the error propagates.

        Returns:
            The same ``results`` list, with each item updated in place.
        """
        self._store_in_transaction(results)
        return results

    def store_user_fix(self, flight_id: str, image_id: str, sequence_number: int, gps: GPSPoint) -> ResultData:
        """Stores a manual user-provided coordinate anchor (AC-6).

        The fix is saved as a new version with confidence 1.0 and source
        "user".
        """
        result = ResultData(
            flight_id=flight_id,
            image_id=image_id,
            sequence_number=sequence_number,
            estimated_gps=gps,
            confidence=1.0,
            source="user",
            refinement_reason="Manual User Fix",
        )
        return self.store_result(result)

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def get_result(self, flight_id: str, image_id: str, include_all_versions: bool = False) -> Union[ResultData, List[ResultData], None]:
        """Retrieves results for a specific image.

        Returns:
            With ``include_all_versions``: every version, oldest first
            (possibly an empty list). Otherwise: the highest version, or
            ``None`` when the image has no stored result.
        """
        if include_all_versions:
            rows = self._fetch_all(
                'SELECT * FROM results WHERE flight_id=? AND image_id=? ORDER BY version ASC',
                (flight_id, image_id),
            )
            return [self._row_to_result(row) for row in rows]

        rows = self._fetch_all(
            'SELECT * FROM results WHERE flight_id=? AND image_id=? ORDER BY version DESC LIMIT 1',
            (flight_id, image_id),
        )
        return self._row_to_result(rows[0]) if rows else None

    def get_flight_results(self, flight_id: str, latest_version_only: bool = True, min_confidence: float = 0.0, max_error: float = float('inf')) -> List[ResultData]:
        """Retrieves flight results matching filters.

        Args:
            flight_id: Flight whose results are returned.
            latest_version_only: Keep only the highest version per image when
                true; return every stored version otherwise.
            min_confidence: Lower bound (inclusive) on ``confidence``.
            max_error: Upper bound (inclusive) on ``error_m``. Results without
                an error value always pass. Infinity disables the filter.

        Returns:
            Matching results ordered by sequence number, then version.
        """
        if latest_version_only:
            # Subquery to get the latest version per image
            query = '''
                SELECT r.* FROM results r
                INNER JOIN (
                    SELECT image_id, MAX(version) as max_v
                    FROM results WHERE flight_id=? GROUP BY image_id
                ) grouped_r ON r.image_id = grouped_r.image_id AND r.version = grouped_r.max_v
                WHERE r.flight_id=? AND r.confidence >= ?
            '''
            params = [flight_id, flight_id, min_confidence]
        else:
            # Alias the table as `r` here too: the clauses appended below
            # reference r.error_m / r.sequence_number in both branches.
            query = 'SELECT r.* FROM results r WHERE r.flight_id=? AND r.confidence >= ?'
            params = [flight_id, min_confidence]

        if max_error < float('inf'):
            query += ' AND (r.error_m IS NULL OR r.error_m <= ?)'
            params.append(max_error)

        # Version is a tie-breaker so the all-versions listing is deterministic.
        query += ' ORDER BY r.sequence_number ASC, r.version ASC'

        return [self._row_to_result(row) for row in self._fetch_all(query, tuple(params))]

    def get_result_history(self, flight_id: str, image_id: str) -> List[ResultData]:
        """Retrieves the timeline of versions for a specific image, oldest first."""
        return self.get_result(flight_id, image_id, include_all_versions=True)

    # ------------------------------------------------------------------
    # Statistics
    # ------------------------------------------------------------------

    def calculate_statistics(self, flight_id: str, total_flight_images: int = 0) -> Optional[ResultStatistics]:
        """Calculates performance validation metrics (AC-1, AC-2, AC-9).

        Only the latest version of each image is considered. Error metrics
        are computed over results that have an ``error_m``; when none do,
        they are all reported as 0.0.

        Args:
            flight_id: Flight to summarize.
            total_flight_images: Number of images in the flight. Values below
                the number of processed images are raised to that number.

        Returns:
            The statistics, or ``None`` when the flight has no results.
        """
        results = self.get_flight_results(flight_id, latest_version_only=True)
        if not results:
            return None

        processed_count = len(results)
        # processed_count >= 1 here, so total_count is never zero.
        total_count = max(total_flight_images, processed_count)
        registration_rate = (processed_count / total_count) * 100.0

        errors = sorted(r.error_m for r in results if r.error_m is not None)
        if not errors:
            # No ground truth to compute spatial stats
            return ResultStatistics(
                mean_error_m=0.0, median_error_m=0.0, rmse_m=0.0,
                percent_under_50m=0.0, percent_under_20m=0.0, max_error_m=0.0,
                registration_rate=registration_rate,
                total_images=total_count, processed_images=processed_count,
            )

        count = len(errors)
        mid = count // 2
        if count % 2:
            median_err = errors[mid]
        else:
            # Even sample size: the median is the mean of the two middle values.
            median_err = (errors[mid - 1] + errors[mid]) / 2.0

        return ResultStatistics(
            mean_error_m=sum(errors) / count,
            median_error_m=median_err,
            rmse_m=math.sqrt(sum(e ** 2 for e in errors) / count),
            percent_under_50m=sum(1 for e in errors if e <= 50.0) / count * 100.0,
            percent_under_20m=sum(1 for e in errors if e <= 20.0) / count * 100.0,
            max_error_m=errors[-1],
            registration_rate=registration_rate,
            total_images=total_count,
            processed_images=processed_count,
        )

    # ------------------------------------------------------------------
    # Export
    # ------------------------------------------------------------------

    def export_results(self, flight_id: str, format: str = "json", filepath: Optional[str] = None) -> str:
        """Exports flight results to the specified format (json, csv, kml).

        Only the latest version of each image is exported.

        Args:
            flight_id: Flight to export.
            format: One of "json", "csv" or "kml".
            filepath: Destination path. When omitted, a name of the form
                ``./export_<flight_id>_<epoch seconds>.<format>`` is used.

        Returns:
            The path that was written.

        Raises:
            ValueError: If ``format`` is not supported.
        """
        # Reject bad formats before touching the database or the filesystem.
        if format not in self._EXPORT_FORMATS:
            raise ValueError(f"Unsupported export format: {format}")

        results = self.get_flight_results(flight_id, latest_version_only=True)

        if not filepath:
            # datetime.now().timestamp() is the true epoch; a naive utcnow()
            # would be misread as local time and shifted by the UTC offset.
            filepath = f"./export_{flight_id}_{int(datetime.now().timestamp())}.{format}"

        if format == "json":
            self._write_json(filepath, flight_id, results)
        elif format == "csv":
            self._write_csv(filepath, results)
        else:
            self._write_kml(filepath, results)

        logger.info("Exported %d results to %s", len(results), filepath)
        return filepath

    @staticmethod
    def _write_json(filepath: str, flight_id: str, results: List[ResultData]) -> None:
        """Writes the results as an indented JSON document."""
        data = {
            "flight_id": flight_id,
            "total_images": len(results),
            "results": [
                {
                    "image": r.image_id,
                    "sequence": r.sequence_number,
                    "gps": {"lat": r.estimated_gps.lat, "lon": r.estimated_gps.lon},
                    "error_m": r.error_m,
                    "confidence": r.confidence,
                }
                for r in results
            ],
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    @staticmethod
    def _write_csv(filepath: str, results: List[ResultData]) -> None:
        """Writes the results as CSV with a header row."""
        with open(filepath, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(["image", "sequence", "lat", "lon", "altitude_m", "error_m", "confidence", "source"])
            writer.writerows(
                [
                    r.image_id, r.sequence_number, r.estimated_gps.lat, r.estimated_gps.lon,
                    r.estimated_gps.altitude_m,
                    # Only a missing error is blank; a perfect 0.0 is kept.
                    "" if r.error_m is None else r.error_m,
                    r.confidence, r.source,
                ]
                for r in results
            )

    def _write_kml(self, filepath: str, results: List[ResultData]) -> None:
        """Writes the results as KML, one Placemark per image."""
        from xml.sax.saxutils import escape

        kml_content = [
            '<?xml version="1.0" encoding="UTF-8"?>',
            '<kml xmlns="http://www.opengis.net/kml/2.2">',
            ' <Document>',
        ]
        for r in results:
            alt = r.estimated_gps.altitude_m
            if alt is None:
                # Only a missing altitude falls back; 0.0 is a real value.
                alt = self._KML_FALLBACK_ALTITUDE_M
            kml_content.append(' <Placemark>')
            # Escape &, < and > so arbitrary image ids keep the XML well-formed.
            kml_content.append(f' <name>{escape(str(r.image_id))}</name>')
            kml_content.append(' <Point>')
            kml_content.append(f' <coordinates>{r.estimated_gps.lon},{r.estimated_gps.lat},{alt}</coordinates>')
            kml_content.append(' </Point>')
            kml_content.append(' </Placemark>')

        kml_content.extend([' </Document>', '</kml>'])
        # The XML declaration promises UTF-8, so write exactly that.
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write("\n".join(kml_content))