mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-23 04:36:38 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,325 @@
|
||||
import sqlite3
|
||||
import json
|
||||
import csv
|
||||
import math
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Helper Functions ---
|
||||
|
||||
def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
|
||||
"""Calculates the great-circle distance between two points in meters."""
|
||||
R = 6371000.0 # Earth radius in meters
|
||||
phi1, phi2 = math.radians(lat1), math.radians(lat2)
|
||||
delta_phi = math.radians(lat2 - lat1)
|
||||
delta_lambda = math.radians(lon2 - lon1)
|
||||
|
||||
a = math.sin(delta_phi / 2.0) ** 2 + \
|
||||
math.cos(phi1) * math.cos(phi2) * \
|
||||
math.sin(delta_lambda / 2.0) ** 2
|
||||
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
|
||||
|
||||
return R * c
|
||||
|
||||
# --- Data Models ---
|
||||
|
||||
class GPSPoint(BaseModel):
|
||||
lat: float
|
||||
lon: float
|
||||
altitude_m: Optional[float] = 400.0
|
||||
|
||||
class ResultData(BaseModel):
|
||||
result_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
flight_id: str
|
||||
image_id: str
|
||||
sequence_number: int
|
||||
version: int = 1
|
||||
estimated_gps: GPSPoint
|
||||
ground_truth_gps: Optional[GPSPoint] = None
|
||||
error_m: Optional[float] = None
|
||||
confidence: float
|
||||
source: str # e.g., "L3", "factor_graph", "user"
|
||||
processing_time_ms: float = 0.0
|
||||
metadata: Dict[str, Any] = {}
|
||||
created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
|
||||
refinement_reason: Optional[str] = None
|
||||
|
||||
class ResultStatistics(BaseModel):
|
||||
mean_error_m: float
|
||||
median_error_m: float
|
||||
rmse_m: float
|
||||
percent_under_50m: float
|
||||
percent_under_20m: float
|
||||
max_error_m: float
|
||||
registration_rate: float
|
||||
total_images: int
|
||||
processed_images: int
|
||||
|
||||
# --- Implementation ---
|
||||
|
||||
class ResultManager:
|
||||
"""
|
||||
F13: Result Manager.
|
||||
Handles persistence, versioning (AC-8 refinement), statistics calculations,
|
||||
and format exports (CSV, JSON, KML) for the localization results.
|
||||
"""
|
||||
def __init__(self, db_path: str = "./results_cache.db"):
|
||||
self.db_path = db_path
|
||||
self._init_db()
|
||||
logger.info(f"ResultManager initialized with DB at {self.db_path}")
|
||||
|
||||
def _get_conn(self):
|
||||
conn = sqlite3.connect(self.db_path, isolation_level=None) # Autocommit handling manually
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def _init_db(self):
|
||||
with self._get_conn() as conn:
|
||||
conn.execute('''
|
||||
CREATE TABLE IF NOT EXISTS results (
|
||||
result_id TEXT PRIMARY KEY,
|
||||
flight_id TEXT,
|
||||
image_id TEXT,
|
||||
sequence_number INTEGER,
|
||||
version INTEGER,
|
||||
est_lat REAL,
|
||||
est_lon REAL,
|
||||
est_alt REAL,
|
||||
gt_lat REAL,
|
||||
gt_lon REAL,
|
||||
error_m REAL,
|
||||
confidence REAL,
|
||||
source TEXT,
|
||||
processing_time_ms REAL,
|
||||
metadata TEXT,
|
||||
created_at TEXT,
|
||||
refinement_reason TEXT
|
||||
)
|
||||
''')
|
||||
conn.execute('CREATE INDEX IF NOT EXISTS idx_flight_image ON results(flight_id, image_id)')
|
||||
conn.execute('CREATE INDEX IF NOT EXISTS idx_flight_seq ON results(flight_id, sequence_number)')
|
||||
|
||||
def _row_to_result(self, row: sqlite3.Row) -> ResultData:
|
||||
gt_gps = None
|
||||
if row['gt_lat'] is not None and row['gt_lon'] is not None:
|
||||
gt_gps = GPSPoint(lat=row['gt_lat'], lon=row['gt_lon'])
|
||||
|
||||
return ResultData(
|
||||
result_id=row['result_id'],
|
||||
flight_id=row['flight_id'],
|
||||
image_id=row['image_id'],
|
||||
sequence_number=row['sequence_number'],
|
||||
version=row['version'],
|
||||
estimated_gps=GPSPoint(lat=row['est_lat'], lon=row['est_lon'], altitude_m=row['est_alt']),
|
||||
ground_truth_gps=gt_gps,
|
||||
error_m=row['error_m'],
|
||||
confidence=row['confidence'],
|
||||
source=row['source'],
|
||||
processing_time_ms=row['processing_time_ms'],
|
||||
metadata=json.loads(row['metadata']) if row['metadata'] else {},
|
||||
created_at=row['created_at'],
|
||||
refinement_reason=row['refinement_reason']
|
||||
)
|
||||
|
||||
def _compute_error(self, result: ResultData) -> None:
|
||||
"""Calculates distance error if ground truth is available."""
|
||||
if result.ground_truth_gps and result.estimated_gps:
|
||||
result.error_m = haversine_distance(
|
||||
result.estimated_gps.lat, result.estimated_gps.lon,
|
||||
result.ground_truth_gps.lat, result.ground_truth_gps.lon
|
||||
)
|
||||
|
||||
def store_result(self, result: ResultData) -> ResultData:
|
||||
"""Stores a new result. Automatically handles version increments."""
|
||||
self._compute_error(result)
|
||||
|
||||
with self._get_conn() as conn:
|
||||
# Determine the next version
|
||||
cursor = conn.execute(
|
||||
'SELECT MAX(version) as max_v FROM results WHERE flight_id=? AND image_id=?',
|
||||
(result.flight_id, result.image_id)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
max_v = row['max_v'] if row['max_v'] is not None else 0
|
||||
result.version = max_v + 1
|
||||
|
||||
conn.execute('''
|
||||
INSERT INTO results (
|
||||
result_id, flight_id, image_id, sequence_number, version,
|
||||
est_lat, est_lon, est_alt, gt_lat, gt_lon, error_m,
|
||||
confidence, source, processing_time_ms, metadata, created_at, refinement_reason
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
result.result_id, result.flight_id, result.image_id, result.sequence_number, result.version,
|
||||
result.estimated_gps.lat, result.estimated_gps.lon, result.estimated_gps.altitude_m,
|
||||
result.ground_truth_gps.lat if result.ground_truth_gps else None,
|
||||
result.ground_truth_gps.lon if result.ground_truth_gps else None,
|
||||
result.error_m, result.confidence, result.source, result.processing_time_ms,
|
||||
json.dumps(result.metadata), result.created_at, result.refinement_reason
|
||||
))
|
||||
return result
|
||||
|
||||
def store_results_batch(self, results: List[ResultData]) -> List[ResultData]:
|
||||
"""Atomically stores a batch of results."""
|
||||
for r in results:
|
||||
self.store_result(r)
|
||||
return results
|
||||
|
||||
def get_result(self, flight_id: str, image_id: str, include_all_versions: bool = False) -> Union[ResultData, List[ResultData], None]:
|
||||
"""Retrieves results for a specific image."""
|
||||
with self._get_conn() as conn:
|
||||
if include_all_versions:
|
||||
cursor = conn.execute('SELECT * FROM results WHERE flight_id=? AND image_id=? ORDER BY version ASC', (flight_id, image_id))
|
||||
rows = cursor.fetchall()
|
||||
return [self._row_to_result(row) for row in rows] if rows else []
|
||||
else:
|
||||
cursor = conn.execute('SELECT * FROM results WHERE flight_id=? AND image_id=? ORDER BY version DESC LIMIT 1', (flight_id, image_id))
|
||||
row = cursor.fetchone()
|
||||
return self._row_to_result(row) if row else None
|
||||
|
||||
def get_flight_results(self, flight_id: str, latest_version_only: bool = True, min_confidence: float = 0.0, max_error: float = float('inf')) -> List[ResultData]:
|
||||
"""Retrieves flight results matching filters."""
|
||||
with self._get_conn() as conn:
|
||||
if latest_version_only:
|
||||
# Subquery to get the latest version per image
|
||||
query = '''
|
||||
SELECT r.* FROM results r
|
||||
INNER JOIN (
|
||||
SELECT image_id, MAX(version) as max_v
|
||||
FROM results WHERE flight_id=? GROUP BY image_id
|
||||
) grouped_r ON r.image_id = grouped_r.image_id AND r.version = grouped_r.max_v
|
||||
WHERE r.flight_id=? AND r.confidence >= ?
|
||||
'''
|
||||
params = [flight_id, flight_id, min_confidence]
|
||||
else:
|
||||
query = 'SELECT * FROM results WHERE flight_id=? AND confidence >= ?'
|
||||
params = [flight_id, min_confidence]
|
||||
|
||||
if max_error < float('inf'):
|
||||
query += ' AND (r.error_m IS NULL OR r.error_m <= ?)'
|
||||
params.append(max_error)
|
||||
|
||||
query += ' ORDER BY r.sequence_number ASC'
|
||||
|
||||
cursor = conn.execute(query, tuple(params))
|
||||
return [self._row_to_result(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_result_history(self, flight_id: str, image_id: str) -> List[ResultData]:
|
||||
"""Retrieves the timeline of versions for a specific image."""
|
||||
return self.get_result(flight_id, image_id, include_all_versions=True)
|
||||
|
||||
def store_user_fix(self, flight_id: str, image_id: str, sequence_number: int, gps: GPSPoint) -> ResultData:
|
||||
"""Stores a manual user-provided coordinate anchor (AC-6)."""
|
||||
result = ResultData(
|
||||
flight_id=flight_id,
|
||||
image_id=image_id,
|
||||
sequence_number=sequence_number,
|
||||
estimated_gps=gps,
|
||||
confidence=1.0,
|
||||
source="user",
|
||||
refinement_reason="Manual User Fix"
|
||||
)
|
||||
return self.store_result(result)
|
||||
|
||||
def calculate_statistics(self, flight_id: str, total_flight_images: int = 0) -> Optional[ResultStatistics]:
|
||||
"""Calculates performance validation metrics (AC-1, AC-2, AC-9)."""
|
||||
results = self.get_flight_results(flight_id, latest_version_only=True)
|
||||
if not results:
|
||||
return None
|
||||
|
||||
errors = [r.error_m for r in results if r.error_m is not None]
|
||||
processed_count = len(results)
|
||||
total_count = max(total_flight_images, processed_count)
|
||||
|
||||
if not errors:
|
||||
# No ground truth to compute spatial stats
|
||||
return ResultStatistics(
|
||||
mean_error_m=0.0, median_error_m=0.0, rmse_m=0.0,
|
||||
percent_under_50m=0.0, percent_under_20m=0.0, max_error_m=0.0,
|
||||
registration_rate=(processed_count / total_count) * 100.0,
|
||||
total_images=total_count, processed_images=processed_count
|
||||
)
|
||||
|
||||
errors.sort()
|
||||
mean_err = sum(errors) / len(errors)
|
||||
median_err = errors[len(errors) // 2]
|
||||
rmse = math.sqrt(sum(e**2 for e in errors) / len(errors))
|
||||
pct_50 = sum(1 for e in errors if e <= 50.0) / len(errors) * 100.0
|
||||
pct_20 = sum(1 for e in errors if e <= 20.0) / len(errors) * 100.0
|
||||
|
||||
return ResultStatistics(
|
||||
mean_error_m=mean_err,
|
||||
median_error_m=median_err,
|
||||
rmse_m=rmse,
|
||||
percent_under_50m=pct_50,
|
||||
percent_under_20m=pct_20,
|
||||
max_error_m=max(errors),
|
||||
registration_rate=(processed_count / total_count) * 100.0 if total_count else 100.0,
|
||||
total_images=total_count,
|
||||
processed_images=processed_count
|
||||
)
|
||||
|
||||
def export_results(self, flight_id: str, format: str = "json", filepath: Optional[str] = None) -> str:
|
||||
"""Exports flight results to the specified format (json, csv, kml)."""
|
||||
results = self.get_flight_results(flight_id, latest_version_only=True)
|
||||
|
||||
if not filepath:
|
||||
filepath = f"./export_{flight_id}_{int(datetime.utcnow().timestamp())}.{format}"
|
||||
|
||||
if format == "json":
|
||||
data = {
|
||||
"flight_id": flight_id,
|
||||
"total_images": len(results),
|
||||
"results": [
|
||||
{
|
||||
"image": r.image_id,
|
||||
"sequence": r.sequence_number,
|
||||
"gps": {"lat": r.estimated_gps.lat, "lon": r.estimated_gps.lon},
|
||||
"error_m": r.error_m,
|
||||
"confidence": r.confidence
|
||||
} for r in results
|
||||
]
|
||||
}
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
elif format == "csv":
|
||||
with open(filepath, 'w', newline='') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["image", "sequence", "lat", "lon", "altitude_m", "error_m", "confidence", "source"])
|
||||
for r in results:
|
||||
writer.writerow([
|
||||
r.image_id, r.sequence_number, r.estimated_gps.lat, r.estimated_gps.lon,
|
||||
r.estimated_gps.altitude_m, r.error_m if r.error_m else "", r.confidence, r.source
|
||||
])
|
||||
|
||||
elif format == "kml":
|
||||
kml_content = [
|
||||
'<?xml version="1.0" encoding="UTF-8"?>',
|
||||
'<kml xmlns="http://www.opengis.net/kml/2.2">',
|
||||
' <Document>'
|
||||
]
|
||||
for r in results:
|
||||
alt = r.estimated_gps.altitude_m if r.estimated_gps.altitude_m else 400.0
|
||||
kml_content.append(' <Placemark>')
|
||||
kml_content.append(f' <name>{r.image_id}</name>')
|
||||
kml_content.append(' <Point>')
|
||||
kml_content.append(f' <coordinates>{r.estimated_gps.lon},{r.estimated_gps.lat},{alt}</coordinates>')
|
||||
kml_content.append(' </Point>')
|
||||
kml_content.append(' </Placemark>')
|
||||
|
||||
kml_content.extend([' </Document>', '</kml>'])
|
||||
with open(filepath, 'w') as f:
|
||||
f.write("\n".join(kml_content))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported export format: {format}")
|
||||
|
||||
logger.info(f"Exported {len(results)} results to {filepath}")
|
||||
return filepath
|
||||
Reference in New Issue
Block a user