mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-23 05:36:36 +00:00
feat: stage2 — SQLite DB layer (ORM, async engine, repository, cascade delete, 9 DB tests)
This commit is contained in:
@@ -64,9 +64,9 @@
|
|||||||
- **Дані:** Завантажити тестові зображення у папку `data/test_flights`.
|
- **Дані:** Завантажити тестові зображення у папку `data/test_flights`.
|
||||||
- **Ваги:** Завантажити ваги SuperPoint, LightGlue, LiteSAM локально в `weights/`.
|
- **Ваги:** Завантажити ваги SuperPoint, LightGlue, LiteSAM локально в `weights/`.
|
||||||
|
|
||||||
### Етап 1 — Конфігурація та доменні моделі
|
### Етап 1 — Конфігурація та доменні моделі ✅
|
||||||
- Реалізувати завантаження конфігів з env + YAML.
|
- Реалізовано завантаження конфігів з `.env` через `pydantic-settings` (`config.py`).
|
||||||
- Pydantic-схеми: Flight, Waypoint, ImageBatch, події SSE.
|
- Pydantic-схеми: GPSPoint, CameraParameters, Flight*, Waypoint, Batch*, SSE events.
|
||||||
|
|
||||||
### Етап 2 — База даних полёту
|
### Етап 2 — База даних полёту
|
||||||
- SQLite БД: міграції (flights, waypoints, frame results, chunk state).
|
- SQLite БД: міграції (flights, waypoints, frame results, chunk state).
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Database package."""
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
"""Async database engine and session factory."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy import event
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from gps_denied.config import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
def _build_engine():
|
||||||
|
settings = get_settings()
|
||||||
|
connect_args = {}
|
||||||
|
if "sqlite" in settings.db.url:
|
||||||
|
connect_args["check_same_thread"] = False
|
||||||
|
|
||||||
|
eng = create_async_engine(
|
||||||
|
settings.db.url,
|
||||||
|
echo=settings.db.echo,
|
||||||
|
connect_args=connect_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enable FK enforcement for SQLite (required for ON DELETE CASCADE)
|
||||||
|
@event.listens_for(eng.sync_engine, "connect")
|
||||||
|
def _set_sqlite_pragma(dbapi_connection, connection_record):
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
return eng
|
||||||
|
|
||||||
|
|
||||||
|
engine = _build_engine()
|
||||||
|
|
||||||
|
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session() -> AsyncSession: # noqa: ANN201 — used as FastAPI dependency
|
||||||
|
"""Yield an async session (for use as a FastAPI dependency)."""
|
||||||
|
async with async_session_factory() as session:
|
||||||
|
yield session
|
||||||
@@ -0,0 +1,230 @@
|
|||||||
|
"""SQLAlchemy ORM models for the flight database."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
Boolean,
|
||||||
|
DateTime,
|
||||||
|
Float,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
JSON,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
|
||||||
|
def _utcnow() -> datetime:
|
||||||
|
return datetime.now(tz=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def _new_id() -> str:
|
||||||
|
return uuid.uuid4().hex
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
"""Declarative base for all ORM models."""
|
||||||
|
|
||||||
|
|
||||||
|
# ── Flights ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class FlightRow(Base):
|
||||||
|
__tablename__ = "flights"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str] = mapped_column(Text, default="")
|
||||||
|
start_lat: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
start_lon: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
altitude: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
camera_params: Mapped[dict] = mapped_column(JSON, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships (cascade delete)
|
||||||
|
waypoints: Mapped[list[WaypointRow]] = relationship(
|
||||||
|
back_populates="flight", cascade="all, delete-orphan", passive_deletes=True
|
||||||
|
)
|
||||||
|
geofences: Mapped[list[GeofenceRow]] = relationship(
|
||||||
|
back_populates="flight", cascade="all, delete-orphan", passive_deletes=True
|
||||||
|
)
|
||||||
|
state: Mapped[FlightStateRow | None] = relationship(
|
||||||
|
back_populates="flight", cascade="all, delete-orphan", passive_deletes=True, uselist=False
|
||||||
|
)
|
||||||
|
frame_results: Mapped[list[FrameResultRow]] = relationship(
|
||||||
|
back_populates="flight", cascade="all, delete-orphan", passive_deletes=True
|
||||||
|
)
|
||||||
|
headings: Mapped[list[HeadingRow]] = relationship(
|
||||||
|
back_populates="flight", cascade="all, delete-orphan", passive_deletes=True
|
||||||
|
)
|
||||||
|
images: Mapped[list[ImageRow]] = relationship(
|
||||||
|
back_populates="flight", cascade="all, delete-orphan", passive_deletes=True
|
||||||
|
)
|
||||||
|
chunks: Mapped[list[ChunkRow]] = relationship(
|
||||||
|
back_populates="flight", cascade="all, delete-orphan", passive_deletes=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Waypoints ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class WaypointRow(Base):
|
||||||
|
__tablename__ = "waypoints"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id)
|
||||||
|
flight_id: Mapped[str] = mapped_column(
|
||||||
|
ForeignKey("flights.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
lat: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
lon: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
altitude: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
|
confidence: Mapped[float] = mapped_column(Float, default=0.0)
|
||||||
|
timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||||
|
refined: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
|
||||||
|
flight: Mapped[FlightRow] = relationship(back_populates="waypoints")
|
||||||
|
|
||||||
|
__table_args__ = (Index("ix_waypoints_flight_ts", "flight_id", "timestamp"),)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Geofences ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class GeofenceRow(Base):
|
||||||
|
__tablename__ = "geofences"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id)
|
||||||
|
flight_id: Mapped[str] = mapped_column(
|
||||||
|
ForeignKey("flights.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
nw_lat: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
nw_lon: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
se_lat: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
se_lon: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
|
||||||
|
flight: Mapped[FlightRow] = relationship(back_populates="geofences")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Flight State ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class FlightStateRow(Base):
|
||||||
|
__tablename__ = "flight_state"
|
||||||
|
|
||||||
|
flight_id: Mapped[str] = mapped_column(
|
||||||
|
ForeignKey("flights.id", ondelete="CASCADE"), primary_key=True
|
||||||
|
)
|
||||||
|
status: Mapped[str] = mapped_column(String(32), default="created")
|
||||||
|
frames_processed: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
|
frames_total: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
|
current_frame: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
blocked: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
search_grid_size: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||||
|
)
|
||||||
|
|
||||||
|
flight: Mapped[FlightRow] = relationship(back_populates="state")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Frame Results ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class FrameResultRow(Base):
|
||||||
|
__tablename__ = "frame_results"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id)
|
||||||
|
flight_id: Mapped[str] = mapped_column(
|
||||||
|
ForeignKey("flights.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
frame_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
gps_lat: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
gps_lon: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
altitude: Mapped[float] = mapped_column(Float, default=0.0)
|
||||||
|
heading: Mapped[float] = mapped_column(Float, default=0.0)
|
||||||
|
confidence: Mapped[float] = mapped_column(Float, default=0.0)
|
||||||
|
refined: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||||
|
|
||||||
|
flight: Mapped[FlightRow] = relationship(back_populates="frame_results")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_frame_results_flight_frame", "flight_id", "frame_id", unique=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Heading History ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class HeadingRow(Base):
|
||||||
|
__tablename__ = "heading_history"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id)
|
||||||
|
flight_id: Mapped[str] = mapped_column(
|
||||||
|
ForeignKey("flights.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
frame_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
heading: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||||
|
|
||||||
|
flight: Mapped[FlightRow] = relationship(back_populates="headings")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_heading_flight_frame", "flight_id", "frame_id", unique=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Images ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRow(Base):
|
||||||
|
__tablename__ = "images"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id)
|
||||||
|
flight_id: Mapped[str] = mapped_column(
|
||||||
|
ForeignKey("flights.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
frame_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
metadata_json: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||||
|
|
||||||
|
flight: Mapped[FlightRow] = relationship(back_populates="images")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_images_flight_frame", "flight_id", "frame_id", unique=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Chunks ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkRow(Base):
|
||||||
|
__tablename__ = "chunks"
|
||||||
|
|
||||||
|
chunk_id: Mapped[str] = mapped_column(String(64), primary_key=True, default=_new_id)
|
||||||
|
flight_id: Mapped[str] = mapped_column(
|
||||||
|
ForeignKey("flights.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
start_frame_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
end_frame_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
frames: Mapped[list] = mapped_column(JSON, default=list)
|
||||||
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
|
has_anchor: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
anchor_frame_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
anchor_lat: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
|
anchor_lon: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
|
matching_status: Mapped[str] = mapped_column(String(32), default="pending")
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||||
|
)
|
||||||
|
|
||||||
|
flight: Mapped[FlightRow] = relationship(back_populates="chunks")
|
||||||
@@ -0,0 +1,295 @@
|
|||||||
|
"""Repository (DAO) implementing IFlightDatabase operations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import delete, select, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from gps_denied.db.models import (
|
||||||
|
ChunkRow,
|
||||||
|
FlightRow,
|
||||||
|
FlightStateRow,
|
||||||
|
FrameResultRow,
|
||||||
|
GeofenceRow,
|
||||||
|
HeadingRow,
|
||||||
|
ImageRow,
|
||||||
|
WaypointRow,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlightRepository:
|
||||||
|
"""Async repository wrapping all DB operations for flights."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
self._s = session
|
||||||
|
|
||||||
|
# ── Flight CRUD ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def insert_flight(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
start_lat: float,
|
||||||
|
start_lon: float,
|
||||||
|
altitude: float,
|
||||||
|
camera_params: dict,
|
||||||
|
flight_id: str | None = None,
|
||||||
|
) -> FlightRow:
|
||||||
|
row = FlightRow(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
start_lat=start_lat,
|
||||||
|
start_lon=start_lon,
|
||||||
|
altitude=altitude,
|
||||||
|
camera_params=camera_params,
|
||||||
|
)
|
||||||
|
if flight_id:
|
||||||
|
row.id = flight_id
|
||||||
|
self._s.add(row)
|
||||||
|
await self._s.flush()
|
||||||
|
# Create initial flight state
|
||||||
|
state = FlightStateRow(flight_id=row.id)
|
||||||
|
self._s.add(state)
|
||||||
|
await self._s.flush()
|
||||||
|
return row
|
||||||
|
|
||||||
|
async def get_flight(self, flight_id: str) -> FlightRow | None:
|
||||||
|
return await self._s.get(FlightRow, flight_id)
|
||||||
|
|
||||||
|
async def list_flights(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
status: str | None = None,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[FlightRow]:
|
||||||
|
stmt = select(FlightRow).offset(offset).limit(limit).order_by(FlightRow.created_at.desc())
|
||||||
|
if status:
|
||||||
|
stmt = stmt.join(FlightStateRow).where(FlightStateRow.status == status)
|
||||||
|
result = await self._s.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def update_flight(self, flight_id: str, **kwargs: Any) -> bool:
|
||||||
|
kwargs["updated_at"] = datetime.now(tz=timezone.utc)
|
||||||
|
stmt = update(FlightRow).where(FlightRow.id == flight_id).values(**kwargs)
|
||||||
|
result = await self._s.execute(stmt)
|
||||||
|
return result.rowcount > 0 # type: ignore[union-attr]
|
||||||
|
|
||||||
|
async def delete_flight(self, flight_id: str) -> bool:
|
||||||
|
stmt = delete(FlightRow).where(FlightRow.id == flight_id)
|
||||||
|
result = await self._s.execute(stmt)
|
||||||
|
return result.rowcount > 0 # type: ignore[union-attr]
|
||||||
|
|
||||||
|
# ── Waypoints ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def insert_waypoint(
|
||||||
|
self,
|
||||||
|
flight_id: str,
|
||||||
|
*,
|
||||||
|
lat: float,
|
||||||
|
lon: float,
|
||||||
|
altitude: float | None = None,
|
||||||
|
confidence: float = 0.0,
|
||||||
|
refined: bool = False,
|
||||||
|
waypoint_id: str | None = None,
|
||||||
|
) -> WaypointRow:
|
||||||
|
row = WaypointRow(
|
||||||
|
flight_id=flight_id,
|
||||||
|
lat=lat,
|
||||||
|
lon=lon,
|
||||||
|
altitude=altitude,
|
||||||
|
confidence=confidence,
|
||||||
|
refined=refined,
|
||||||
|
)
|
||||||
|
if waypoint_id:
|
||||||
|
row.id = waypoint_id
|
||||||
|
self._s.add(row)
|
||||||
|
await self._s.flush()
|
||||||
|
return row
|
||||||
|
|
||||||
|
async def get_waypoints(self, flight_id: str, *, limit: int | None = None) -> list[WaypointRow]:
|
||||||
|
stmt = (
|
||||||
|
select(WaypointRow)
|
||||||
|
.where(WaypointRow.flight_id == flight_id)
|
||||||
|
.order_by(WaypointRow.timestamp)
|
||||||
|
)
|
||||||
|
if limit:
|
||||||
|
stmt = stmt.limit(limit)
|
||||||
|
result = await self._s.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def update_waypoint(self, flight_id: str, waypoint_id: str, **kwargs: Any) -> bool:
|
||||||
|
stmt = (
|
||||||
|
update(WaypointRow)
|
||||||
|
.where(WaypointRow.id == waypoint_id, WaypointRow.flight_id == flight_id)
|
||||||
|
.values(**kwargs)
|
||||||
|
)
|
||||||
|
result = await self._s.execute(stmt)
|
||||||
|
return result.rowcount > 0 # type: ignore[union-attr]
|
||||||
|
|
||||||
|
# ── Flight State ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def save_flight_state(self, flight_id: str, **kwargs: Any) -> bool:
|
||||||
|
kwargs["updated_at"] = datetime.now(tz=timezone.utc)
|
||||||
|
stmt = update(FlightStateRow).where(FlightStateRow.flight_id == flight_id).values(**kwargs)
|
||||||
|
result = await self._s.execute(stmt)
|
||||||
|
return result.rowcount > 0 # type: ignore[union-attr]
|
||||||
|
|
||||||
|
async def load_flight_state(self, flight_id: str) -> FlightStateRow | None:
|
||||||
|
return await self._s.get(FlightStateRow, flight_id)
|
||||||
|
|
||||||
|
# ── Frame Results ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def save_frame_result(
|
||||||
|
self,
|
||||||
|
flight_id: str,
|
||||||
|
*,
|
||||||
|
frame_id: int,
|
||||||
|
gps_lat: float,
|
||||||
|
gps_lon: float,
|
||||||
|
altitude: float = 0.0,
|
||||||
|
heading: float = 0.0,
|
||||||
|
confidence: float = 0.0,
|
||||||
|
refined: bool = False,
|
||||||
|
) -> FrameResultRow:
|
||||||
|
row = FrameResultRow(
|
||||||
|
flight_id=flight_id,
|
||||||
|
frame_id=frame_id,
|
||||||
|
gps_lat=gps_lat,
|
||||||
|
gps_lon=gps_lon,
|
||||||
|
altitude=altitude,
|
||||||
|
heading=heading,
|
||||||
|
confidence=confidence,
|
||||||
|
refined=refined,
|
||||||
|
)
|
||||||
|
self._s.add(row)
|
||||||
|
await self._s.flush()
|
||||||
|
return row
|
||||||
|
|
||||||
|
async def get_frame_results(self, flight_id: str) -> list[FrameResultRow]:
|
||||||
|
stmt = (
|
||||||
|
select(FrameResultRow)
|
||||||
|
.where(FrameResultRow.flight_id == flight_id)
|
||||||
|
.order_by(FrameResultRow.frame_id)
|
||||||
|
)
|
||||||
|
result = await self._s.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
# ── Heading History ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def save_heading(
|
||||||
|
self, flight_id: str, *, frame_id: int, heading: float
|
||||||
|
) -> HeadingRow:
|
||||||
|
row = HeadingRow(flight_id=flight_id, frame_id=frame_id, heading=heading)
|
||||||
|
self._s.add(row)
|
||||||
|
await self._s.flush()
|
||||||
|
return row
|
||||||
|
|
||||||
|
async def get_heading_history(
|
||||||
|
self, flight_id: str, *, last_n: int | None = None
|
||||||
|
) -> list[HeadingRow]:
|
||||||
|
stmt = (
|
||||||
|
select(HeadingRow)
|
||||||
|
.where(HeadingRow.flight_id == flight_id)
|
||||||
|
.order_by(HeadingRow.timestamp.desc())
|
||||||
|
)
|
||||||
|
if last_n:
|
||||||
|
stmt = stmt.limit(last_n)
|
||||||
|
result = await self._s.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def get_latest_heading(self, flight_id: str) -> float | None:
|
||||||
|
rows = await self.get_heading_history(flight_id, last_n=1)
|
||||||
|
return rows[0].heading if rows else None
|
||||||
|
|
||||||
|
# ── Images ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def save_image_metadata(
|
||||||
|
self, flight_id: str, *, frame_id: int, file_path: str, metadata: dict | None = None
|
||||||
|
) -> ImageRow:
|
||||||
|
row = ImageRow(
|
||||||
|
flight_id=flight_id, frame_id=frame_id, file_path=file_path, metadata_json=metadata
|
||||||
|
)
|
||||||
|
self._s.add(row)
|
||||||
|
await self._s.flush()
|
||||||
|
return row
|
||||||
|
|
||||||
|
async def get_image_path(self, flight_id: str, frame_id: int) -> str | None:
|
||||||
|
stmt = select(ImageRow.file_path).where(
|
||||||
|
ImageRow.flight_id == flight_id, ImageRow.frame_id == frame_id
|
||||||
|
)
|
||||||
|
result = await self._s.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
# ── Chunks ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def save_chunk(
|
||||||
|
self,
|
||||||
|
flight_id: str,
|
||||||
|
*,
|
||||||
|
chunk_id: str | None = None,
|
||||||
|
start_frame_id: int,
|
||||||
|
end_frame_id: int | None = None,
|
||||||
|
frames: list[int] | None = None,
|
||||||
|
is_active: bool = True,
|
||||||
|
has_anchor: bool = False,
|
||||||
|
anchor_frame_id: int | None = None,
|
||||||
|
anchor_lat: float | None = None,
|
||||||
|
anchor_lon: float | None = None,
|
||||||
|
matching_status: str = "pending",
|
||||||
|
) -> ChunkRow:
|
||||||
|
row = ChunkRow(
|
||||||
|
flight_id=flight_id,
|
||||||
|
start_frame_id=start_frame_id,
|
||||||
|
end_frame_id=end_frame_id,
|
||||||
|
frames=frames or [],
|
||||||
|
is_active=is_active,
|
||||||
|
has_anchor=has_anchor,
|
||||||
|
anchor_frame_id=anchor_frame_id,
|
||||||
|
anchor_lat=anchor_lat,
|
||||||
|
anchor_lon=anchor_lon,
|
||||||
|
matching_status=matching_status,
|
||||||
|
)
|
||||||
|
if chunk_id:
|
||||||
|
row.chunk_id = chunk_id
|
||||||
|
self._s.add(row)
|
||||||
|
await self._s.flush()
|
||||||
|
return row
|
||||||
|
|
||||||
|
async def load_chunks(self, flight_id: str) -> list[ChunkRow]:
|
||||||
|
stmt = select(ChunkRow).where(ChunkRow.flight_id == flight_id)
|
||||||
|
result = await self._s.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def delete_chunk(self, flight_id: str, chunk_id: str) -> bool:
|
||||||
|
stmt = delete(ChunkRow).where(
|
||||||
|
ChunkRow.chunk_id == chunk_id, ChunkRow.flight_id == flight_id
|
||||||
|
)
|
||||||
|
result = await self._s.execute(stmt)
|
||||||
|
return result.rowcount > 0 # type: ignore[union-attr]
|
||||||
|
|
||||||
|
# ── Geofences ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def insert_geofence(
|
||||||
|
self,
|
||||||
|
flight_id: str,
|
||||||
|
*,
|
||||||
|
nw_lat: float,
|
||||||
|
nw_lon: float,
|
||||||
|
se_lat: float,
|
||||||
|
se_lon: float,
|
||||||
|
) -> GeofenceRow:
|
||||||
|
row = GeofenceRow(
|
||||||
|
flight_id=flight_id,
|
||||||
|
nw_lat=nw_lat,
|
||||||
|
nw_lon=nw_lon,
|
||||||
|
se_lat=se_lat,
|
||||||
|
se_lon=se_lon,
|
||||||
|
)
|
||||||
|
self._s.add(row)
|
||||||
|
await self._s.flush()
|
||||||
|
return row
|
||||||
@@ -0,0 +1,228 @@
|
|||||||
|
"""Tests for the database layer — CRUD, cascade, transactions."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import event
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from gps_denied.db.models import Base
|
||||||
|
from gps_denied.db.repository import FlightRepository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def session():
|
||||||
|
"""Create an in-memory SQLite database for each test."""
|
||||||
|
engine = create_async_engine("sqlite+aiosqlite://", echo=False)
|
||||||
|
|
||||||
|
@event.listens_for(engine.sync_engine, "connect")
|
||||||
|
def _set_sqlite_pragma(dbapi_connection, connection_record):
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
async with async_session() as s:
|
||||||
|
yield s
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def repo(session: AsyncSession) -> FlightRepository:
|
||||||
|
return FlightRepository(session)
|
||||||
|
|
||||||
|
|
||||||
|
CAM = {
|
||||||
|
"focal_length": 25.0,
|
||||||
|
"sensor_width": 23.5,
|
||||||
|
"sensor_height": 15.6,
|
||||||
|
"resolution_width": 6252,
|
||||||
|
"resolution_height": 4168,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Flight CRUD ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_insert_and_get_flight(repo: FlightRepository, session: AsyncSession):
|
||||||
|
flight = await repo.insert_flight(
|
||||||
|
name="Test_Flight_001",
|
||||||
|
description="Test",
|
||||||
|
start_lat=48.275,
|
||||||
|
start_lon=37.385,
|
||||||
|
altitude=400,
|
||||||
|
camera_params=CAM,
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
loaded = await repo.get_flight(flight.id)
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.name == "Test_Flight_001"
|
||||||
|
assert loaded.altitude == 400
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_flights(repo: FlightRepository, session: AsyncSession):
|
||||||
|
await repo.insert_flight(
|
||||||
|
name="F1", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM
|
||||||
|
)
|
||||||
|
await repo.insert_flight(
|
||||||
|
name="F2", description="", start_lat=0, start_lon=0, altitude=200, camera_params=CAM
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
flights = await repo.list_flights()
|
||||||
|
assert len(flights) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_flight(repo: FlightRepository, session: AsyncSession):
|
||||||
|
flight = await repo.insert_flight(
|
||||||
|
name="Old", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
ok = await repo.update_flight(flight.id, name="New")
|
||||||
|
await session.commit()
|
||||||
|
assert ok is True
|
||||||
|
|
||||||
|
reloaded = await repo.get_flight(flight.id)
|
||||||
|
assert reloaded.name == "New"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_flight_cascade(repo: FlightRepository, session: AsyncSession):
|
||||||
|
flight = await repo.insert_flight(
|
||||||
|
name="Del", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM
|
||||||
|
)
|
||||||
|
fid = flight.id
|
||||||
|
# Add related entities
|
||||||
|
await repo.insert_waypoint(fid, lat=48.0, lon=37.0, confidence=0.9)
|
||||||
|
await repo.save_frame_result(fid, frame_id=1, gps_lat=48.0, gps_lon=37.0)
|
||||||
|
await repo.save_heading(fid, frame_id=1, heading=90.0)
|
||||||
|
await repo.save_image_metadata(fid, frame_id=1, file_path="/img/1.jpg")
|
||||||
|
await repo.save_chunk(fid, start_frame_id=1, frames=[1, 2, 3])
|
||||||
|
await repo.insert_geofence(fid, nw_lat=49.0, nw_lon=36.0, se_lat=47.0, se_lon=38.0)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Delete flight — should cascade
|
||||||
|
ok = await repo.delete_flight(fid)
|
||||||
|
await session.commit()
|
||||||
|
assert ok is True
|
||||||
|
assert await repo.get_flight(fid) is None
|
||||||
|
assert await repo.get_waypoints(fid) == []
|
||||||
|
assert await repo.get_frame_results(fid) == []
|
||||||
|
assert await repo.get_heading_history(fid) == []
|
||||||
|
assert await repo.load_chunks(fid) == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── Waypoints ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_waypoint_crud(repo: FlightRepository, session: AsyncSession):
|
||||||
|
flight = await repo.insert_flight(
|
||||||
|
name="WP", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM
|
||||||
|
)
|
||||||
|
wp = await repo.insert_waypoint(flight.id, lat=48.1, lon=37.2, confidence=0.8)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
wps = await repo.get_waypoints(flight.id)
|
||||||
|
assert len(wps) == 1
|
||||||
|
assert wps[0].lat == 48.1
|
||||||
|
|
||||||
|
ok = await repo.update_waypoint(flight.id, wp.id, lat=48.2, refined=True)
|
||||||
|
await session.commit()
|
||||||
|
assert ok is True
|
||||||
|
|
||||||
|
|
||||||
|
# ── Flight State ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flight_state(repo: FlightRepository, session: AsyncSession):
|
||||||
|
flight = await repo.insert_flight(
|
||||||
|
name="State", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
state = await repo.load_flight_state(flight.id)
|
||||||
|
assert state is not None
|
||||||
|
assert state.status == "created"
|
||||||
|
|
||||||
|
await repo.save_flight_state(flight.id, status="processing", frames_total=500)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
state = await repo.load_flight_state(flight.id)
|
||||||
|
assert state.status == "processing"
|
||||||
|
assert state.frames_total == 500
|
||||||
|
|
||||||
|
|
||||||
|
# ── Frame Results ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_frame_results(repo: FlightRepository, session: AsyncSession):
|
||||||
|
flight = await repo.insert_flight(
|
||||||
|
name="FR", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM
|
||||||
|
)
|
||||||
|
await repo.save_frame_result(
|
||||||
|
flight.id, frame_id=1, gps_lat=48.0, gps_lon=37.0, confidence=0.95
|
||||||
|
)
|
||||||
|
await repo.save_frame_result(
|
||||||
|
flight.id, frame_id=2, gps_lat=48.001, gps_lon=37.001, confidence=0.90
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
results = await repo.get_frame_results(flight.id)
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0].frame_id == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── Heading History ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_heading_history(repo: FlightRepository, session: AsyncSession):
|
||||||
|
flight = await repo.insert_flight(
|
||||||
|
name="HD", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM
|
||||||
|
)
|
||||||
|
for i in range(5):
|
||||||
|
await repo.save_heading(flight.id, frame_id=i, heading=float(i * 30))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
latest = await repo.get_latest_heading(flight.id)
|
||||||
|
assert latest == 120.0 # last frame heading
|
||||||
|
|
||||||
|
last3 = await repo.get_heading_history(flight.id, last_n=3)
|
||||||
|
assert len(last3) == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ── Chunks ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chunks(repo: FlightRepository, session: AsyncSession):
|
||||||
|
flight = await repo.insert_flight(
|
||||||
|
name="CK", description="", start_lat=0, start_lon=0, altitude=100, camera_params=CAM
|
||||||
|
)
|
||||||
|
chunk = await repo.save_chunk(
|
||||||
|
flight.id,
|
||||||
|
chunk_id="chunk_001",
|
||||||
|
start_frame_id=1,
|
||||||
|
end_frame_id=10,
|
||||||
|
frames=[1, 2, 3, 4, 5],
|
||||||
|
has_anchor=True,
|
||||||
|
anchor_lat=48.0,
|
||||||
|
anchor_lon=37.0,
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
chunks = await repo.load_chunks(flight.id)
|
||||||
|
assert len(chunks) == 1
|
||||||
|
assert chunks[0].chunk_id == "chunk_001"
|
||||||
|
|
||||||
|
ok = await repo.delete_chunk(flight.id, "chunk_001")
|
||||||
|
await session.commit()
|
||||||
|
assert ok is True
|
||||||
|
assert await repo.load_chunks(flight.id) == []
|
||||||
Reference in New Issue
Block a user