feat: stage2 — SQLite DB layer (ORM, async engine, repository, cascade delete, 9 DB tests)

This commit is contained in:
Yuzviak
2026-03-22 22:25:44 +02:00
parent 445f3bd099
commit e47274bcbd
6 changed files with 798 additions and 3 deletions
+295
View File
@@ -0,0 +1,295 @@
"""Repository (DAO) implementing IFlightDatabase operations."""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from gps_denied.db.models import (
ChunkRow,
FlightRow,
FlightStateRow,
FrameResultRow,
GeofenceRow,
HeadingRow,
ImageRow,
WaypointRow,
)
class FlightRepository:
"""Async repository wrapping all DB operations for flights."""
def __init__(self, session: AsyncSession) -> None:
self._s = session
# ── Flight CRUD ───────────────────────────────────────────────────
async def insert_flight(
self,
*,
name: str,
description: str,
start_lat: float,
start_lon: float,
altitude: float,
camera_params: dict,
flight_id: str | None = None,
) -> FlightRow:
row = FlightRow(
name=name,
description=description,
start_lat=start_lat,
start_lon=start_lon,
altitude=altitude,
camera_params=camera_params,
)
if flight_id:
row.id = flight_id
self._s.add(row)
await self._s.flush()
# Create initial flight state
state = FlightStateRow(flight_id=row.id)
self._s.add(state)
await self._s.flush()
return row
async def get_flight(self, flight_id: str) -> FlightRow | None:
return await self._s.get(FlightRow, flight_id)
async def list_flights(
self,
*,
status: str | None = None,
limit: int = 50,
offset: int = 0,
) -> list[FlightRow]:
stmt = select(FlightRow).offset(offset).limit(limit).order_by(FlightRow.created_at.desc())
if status:
stmt = stmt.join(FlightStateRow).where(FlightStateRow.status == status)
result = await self._s.execute(stmt)
return list(result.scalars().all())
async def update_flight(self, flight_id: str, **kwargs: Any) -> bool:
kwargs["updated_at"] = datetime.now(tz=timezone.utc)
stmt = update(FlightRow).where(FlightRow.id == flight_id).values(**kwargs)
result = await self._s.execute(stmt)
return result.rowcount > 0 # type: ignore[union-attr]
async def delete_flight(self, flight_id: str) -> bool:
stmt = delete(FlightRow).where(FlightRow.id == flight_id)
result = await self._s.execute(stmt)
return result.rowcount > 0 # type: ignore[union-attr]
# ── Waypoints ─────────────────────────────────────────────────────
async def insert_waypoint(
self,
flight_id: str,
*,
lat: float,
lon: float,
altitude: float | None = None,
confidence: float = 0.0,
refined: bool = False,
waypoint_id: str | None = None,
) -> WaypointRow:
row = WaypointRow(
flight_id=flight_id,
lat=lat,
lon=lon,
altitude=altitude,
confidence=confidence,
refined=refined,
)
if waypoint_id:
row.id = waypoint_id
self._s.add(row)
await self._s.flush()
return row
async def get_waypoints(self, flight_id: str, *, limit: int | None = None) -> list[WaypointRow]:
stmt = (
select(WaypointRow)
.where(WaypointRow.flight_id == flight_id)
.order_by(WaypointRow.timestamp)
)
if limit:
stmt = stmt.limit(limit)
result = await self._s.execute(stmt)
return list(result.scalars().all())
async def update_waypoint(self, flight_id: str, waypoint_id: str, **kwargs: Any) -> bool:
stmt = (
update(WaypointRow)
.where(WaypointRow.id == waypoint_id, WaypointRow.flight_id == flight_id)
.values(**kwargs)
)
result = await self._s.execute(stmt)
return result.rowcount > 0 # type: ignore[union-attr]
# ── Flight State ──────────────────────────────────────────────────
async def save_flight_state(self, flight_id: str, **kwargs: Any) -> bool:
kwargs["updated_at"] = datetime.now(tz=timezone.utc)
stmt = update(FlightStateRow).where(FlightStateRow.flight_id == flight_id).values(**kwargs)
result = await self._s.execute(stmt)
return result.rowcount > 0 # type: ignore[union-attr]
async def load_flight_state(self, flight_id: str) -> FlightStateRow | None:
return await self._s.get(FlightStateRow, flight_id)
# ── Frame Results ─────────────────────────────────────────────────
async def save_frame_result(
self,
flight_id: str,
*,
frame_id: int,
gps_lat: float,
gps_lon: float,
altitude: float = 0.0,
heading: float = 0.0,
confidence: float = 0.0,
refined: bool = False,
) -> FrameResultRow:
row = FrameResultRow(
flight_id=flight_id,
frame_id=frame_id,
gps_lat=gps_lat,
gps_lon=gps_lon,
altitude=altitude,
heading=heading,
confidence=confidence,
refined=refined,
)
self._s.add(row)
await self._s.flush()
return row
async def get_frame_results(self, flight_id: str) -> list[FrameResultRow]:
stmt = (
select(FrameResultRow)
.where(FrameResultRow.flight_id == flight_id)
.order_by(FrameResultRow.frame_id)
)
result = await self._s.execute(stmt)
return list(result.scalars().all())
# ── Heading History ───────────────────────────────────────────────
async def save_heading(
self, flight_id: str, *, frame_id: int, heading: float
) -> HeadingRow:
row = HeadingRow(flight_id=flight_id, frame_id=frame_id, heading=heading)
self._s.add(row)
await self._s.flush()
return row
async def get_heading_history(
self, flight_id: str, *, last_n: int | None = None
) -> list[HeadingRow]:
stmt = (
select(HeadingRow)
.where(HeadingRow.flight_id == flight_id)
.order_by(HeadingRow.timestamp.desc())
)
if last_n:
stmt = stmt.limit(last_n)
result = await self._s.execute(stmt)
return list(result.scalars().all())
async def get_latest_heading(self, flight_id: str) -> float | None:
rows = await self.get_heading_history(flight_id, last_n=1)
return rows[0].heading if rows else None
# ── Images ────────────────────────────────────────────────────────
async def save_image_metadata(
self, flight_id: str, *, frame_id: int, file_path: str, metadata: dict | None = None
) -> ImageRow:
row = ImageRow(
flight_id=flight_id, frame_id=frame_id, file_path=file_path, metadata_json=metadata
)
self._s.add(row)
await self._s.flush()
return row
async def get_image_path(self, flight_id: str, frame_id: int) -> str | None:
stmt = select(ImageRow.file_path).where(
ImageRow.flight_id == flight_id, ImageRow.frame_id == frame_id
)
result = await self._s.execute(stmt)
return result.scalar_one_or_none()
# ── Chunks ────────────────────────────────────────────────────────
async def save_chunk(
self,
flight_id: str,
*,
chunk_id: str | None = None,
start_frame_id: int,
end_frame_id: int | None = None,
frames: list[int] | None = None,
is_active: bool = True,
has_anchor: bool = False,
anchor_frame_id: int | None = None,
anchor_lat: float | None = None,
anchor_lon: float | None = None,
matching_status: str = "pending",
) -> ChunkRow:
row = ChunkRow(
flight_id=flight_id,
start_frame_id=start_frame_id,
end_frame_id=end_frame_id,
frames=frames or [],
is_active=is_active,
has_anchor=has_anchor,
anchor_frame_id=anchor_frame_id,
anchor_lat=anchor_lat,
anchor_lon=anchor_lon,
matching_status=matching_status,
)
if chunk_id:
row.chunk_id = chunk_id
self._s.add(row)
await self._s.flush()
return row
async def load_chunks(self, flight_id: str) -> list[ChunkRow]:
stmt = select(ChunkRow).where(ChunkRow.flight_id == flight_id)
result = await self._s.execute(stmt)
return list(result.scalars().all())
async def delete_chunk(self, flight_id: str, chunk_id: str) -> bool:
stmt = delete(ChunkRow).where(
ChunkRow.chunk_id == chunk_id, ChunkRow.flight_id == flight_id
)
result = await self._s.execute(stmt)
return result.rowcount > 0 # type: ignore[union-attr]
# ── Geofences ─────────────────────────────────────────────────────
async def insert_geofence(
self,
flight_id: str,
*,
nw_lat: float,
nw_lon: float,
se_lat: float,
se_lon: float,
) -> GeofenceRow:
row = GeofenceRow(
flight_id=flight_id,
nw_lat=nw_lat,
nw_lon=nw_lon,
se_lat=se_lat,
se_lon=se_lon,
)
self._s.add(row)
await self._s.flush()
return row