feat: stage6 — Image Pipeline (F05) and Rotation Manager (F06)

This commit is contained in:
Yuzviak
2026-03-22 22:51:00 +02:00
parent a2fb9ab404
commit 9ef046d623
9 changed files with 653 additions and 26 deletions
+8 -6
View File
@@ -8,12 +8,14 @@
| Підсистема | Технології та реалізація | | Підсистема | Технології та реалізація |
|-----------|------------| |-----------|------------|
| API | FastAPI + Pydantic v2 | | **Core API** | FastAPI + Pydantic v2 (REST endpoints `POST /flights`, `POST /flights/{id}/images/batch`) |
| Стрім подій (SSE) | sse-starlette, asyncio.Queue, pub/sub для real-time трансляції поза | | **Real-time стрім (SSE)** | `sse-starlette`, `asyncio.Queue` (Pub/Sub для live-трансляції уточнених поз на наземну станцію) |
| Репозиторій (БД) | SQLite + SQLAlchemy 2 + AsyncIO + Alembic. Підтримка Cascade Deletes | | **Репозиторій (БД)** | `SQLite` + `SQLAlchemy 2` + `AsyncIO` + `Alembic`. Скеровує каскадні видалення та зберігає waypoint-результати. |
| Супутникові тайли | httpx, diskcache, інтеграція з Google Maps (Web Mercator) | | **Супутникові тайли (F04)** | `httpx`, `diskcache`, інтеграція з Google Maps Static Tiles + утиліти Web Mercator |
| Трансформація координат | ENU Origin, конвертація WGS84 ↔ Local ENU ↔ Pixels | | **Трансформація координат (F13)** | Зберігання локального ENU Origin, конвертація WGS84 ↔ Local ENU ↔ Pixels |
| Граф поз (VO/GPR) | GTSAM (Python) - очікується | | **Вхідний пайплайн (F05)** | `cv2`, `asyncio.Queue`. Керує FIFO чергою батчів кадрів з БПЛА, здійснює базову валідацію послідовностей та збереження фотографій на диск. |
| **Менеджер ротацій (F06)** | Оберти 360° блоками по 30° для підбору орієнтації; трекінг історії курсу з виявленням різких поворотів (>45°). |
| **Граф поз (VO/GPR)** | GTSAM (Python) - очікується в наступних етапах |
## Швидкий старт ## Швидкий старт
+3 -2
View File
@@ -84,8 +84,9 @@
- Клієнт Google Maps тайлів, локальний кеш. - Клієнт Google Maps тайлів, локальний кеш.
- Функції піксель <-> GPS (проекції, ENU координати). - Функції піксель <-> GPS (проекції, ENU координати).
### Етап 6 — Черга зображень і ротації ### Етап 6 — Черга зображень і ротації
- FIFO батчів, менеджер ротацій для кросс-вью. - FIFO батчів (ImageInputPipeline), менеджер ротацій (ImageRotationManager).
- Асинхронне збереження в кеш `image_storage`.
### Етап 7 — Model manager та послідовний VO ### Етап 7 — Model manager та послідовний VO
- Завантаження локальних вагів (SuperPoint+LightGlue), побудова ланцюжка відносних оцінок. - Завантаження локальних вагів (SuperPoint+LightGlue), побудова ланцюжка відносних оцінок.
+204
View File
@@ -0,0 +1,204 @@
"""Image Input Pipeline (Component F05)."""
import asyncio
import os
import re
from datetime import datetime, timezone
import cv2
import numpy as np
from gps_denied.schemas.image import (
ImageBatch, ImageData, ImageMetadata, ProcessedBatch, ProcessingStatus, ValidationResult
)
class QueueFullError(Exception):
pass
class ValidationError(Exception):
pass
class ImageInputPipeline:
"""Manages ingestion, disk storage, and queuing of UAV image batches."""
def __init__(self, storage_dir: str = "image_storage", max_queue_size: int = 50):
self.storage_dir = storage_dir
# flight_id -> asyncio.Queue of ImageBatch
self._queues: dict[str, asyncio.Queue] = {}
self.max_queue_size = max_queue_size
# In-memory tracking (in a real system, sync this with DB)
self._status: dict[str, dict] = {}
def _get_queue(self, flight_id: str) -> asyncio.Queue:
if flight_id not in self._queues:
self._queues[flight_id] = asyncio.Queue(maxsize=self.max_queue_size)
return self._queues[flight_id]
def _init_status(self, flight_id: str):
if flight_id not in self._status:
self._status[flight_id] = {
"total_images": 0,
"processed_images": 0,
"current_sequence": 1,
}
def validate_batch(self, batch: ImageBatch) -> ValidationResult:
"""Validates batch integrity and sequence continuity."""
errors = []
num_images = len(batch.images)
if num_images < 10:
errors.append("Batch is empty")
elif num_images > 100:
errors.append("Batch too large")
if len(batch.filenames) != num_images:
errors.append("Mismatch between filenames and images count")
# Naming convention ADxxxxxx.jpg or similar
pattern = re.compile(r"^[A-Za-z0-9_-]+\.(jpg|jpeg|png)$", re.IGNORECASE)
for fn in batch.filenames:
if not pattern.match(fn):
errors.append(f"Invalid filename: {fn}")
break
if batch.start_sequence > batch.end_sequence:
errors.append("Start sequence greater than end sequence")
return ValidationResult(valid=len(errors) == 0, errors=errors)
def queue_batch(self, flight_id: str, batch: ImageBatch) -> bool:
"""Queues a batch of images for processing."""
val = self.validate_batch(batch)
if not val.valid:
raise ValidationError(f"Batch validation failed: {val.errors}")
q = self._get_queue(flight_id)
if q.full():
raise QueueFullError(f"Queue for flight {flight_id} is full")
q.put_nowait(batch)
self._init_status(flight_id)
self._status[flight_id]["total_images"] += len(batch.images)
return True
async def process_next_batch(self, flight_id: str) -> ProcessedBatch | None:
"""Dequeues and processing the next batch."""
q = self._get_queue(flight_id)
if q.empty():
return None
batch: ImageBatch = await q.get()
processed_images = []
for i, raw_bytes in enumerate(batch.images):
# Decode
nparr = np.frombuffer(raw_bytes, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
continue # skip corrupted
seq = batch.start_sequence + i
fn = batch.filenames[i]
h, w = img.shape[:2]
meta = ImageMetadata(
sequence=seq,
filename=fn,
dimensions=(w, h),
file_size=len(raw_bytes),
timestamp=datetime.now(timezone.utc),
)
img_data = ImageData(
flight_id=flight_id,
sequence=seq,
filename=fn,
image=img,
metadata=meta
)
processed_images.append(img_data)
# Store to disk
self.store_images(flight_id, processed_images)
self._status[flight_id]["processed_images"] += len(processed_images)
q.task_done()
return ProcessedBatch(
images=processed_images,
batch_id=f"batch_{batch.batch_number}",
start_sequence=batch.start_sequence,
end_sequence=batch.end_sequence
)
def store_images(self, flight_id: str, images: list[ImageData]) -> bool:
"""Persists images to disk."""
flight_dir = os.path.join(self.storage_dir, flight_id)
os.makedirs(flight_dir, exist_ok=True)
for img in images:
path = os.path.join(flight_dir, img.filename)
cv2.imwrite(path, img.image)
return True
def get_next_image(self, flight_id: str) -> ImageData | None:
"""Gets the next image in sequence for processing."""
self._init_status(flight_id)
seq = self._status[flight_id]["current_sequence"]
img = self.get_image_by_sequence(flight_id, seq)
if img:
self._status[flight_id]["current_sequence"] += 1
return img
def get_image_by_sequence(self, flight_id: str, sequence: int) -> ImageData | None:
"""Retrieves a specific image by sequence number."""
# For simplicity, we assume filenames follow "frame_{sequence:06d}.jpg"
# But if the user uploaded custom files, we'd need a DB lookup.
# Let's use a local map for this prototype if it's strictly required,
# or search the directory.
flight_dir = os.path.join(self.storage_dir, flight_id)
if not os.path.exists(flight_dir):
return None
# search
for fn in os.listdir(flight_dir):
# very rough matching
if str(sequence) in fn or fn.endswith(f"_{sequence:06d}.jpg"):
path = os.path.join(flight_dir, fn)
img = cv2.imread(path)
if img is not None:
h, w = img.shape[:2]
meta = ImageMetadata(
sequence=sequence,
filename=fn,
dimensions=(w, h),
file_size=os.path.getsize(path),
timestamp=datetime.now(timezone.utc)
)
return ImageData(flight_id, sequence, fn, img, meta)
return None
def get_processing_status(self, flight_id: str) -> ProcessingStatus:
self._init_status(flight_id)
s = self._status[flight_id]
q = self._get_queue(flight_id)
return ProcessingStatus(
flight_id=flight_id,
total_images=s["total_images"],
processed_images=s["processed_images"],
current_sequence=s["current_sequence"],
queued_batches=q.qsize(),
processing_rate=0.0 # mock
)
+23 -18
View File
@@ -5,6 +5,8 @@ from __future__ import annotations
import asyncio import asyncio
from datetime import datetime, timezone from datetime import datetime, timezone
from gps_denied.core.pipeline import ImageInputPipeline
from gps_denied.core.results import ResultManager
from gps_denied.core.sse import SSEEventStreamer from gps_denied.core.sse import SSEEventStreamer
from gps_denied.db.repository import FlightRepository from gps_denied.db.repository import FlightRepository
from gps_denied.schemas import GPSPoint from gps_denied.schemas import GPSPoint
@@ -23,17 +25,20 @@ from gps_denied.schemas.flight import (
UserFixResponse, UserFixResponse,
Waypoint, Waypoint,
) )
from gps_denied.schemas.image import ImageBatch
class FlightProcessor: class FlightProcessor:
"""Orchestrates flight business logic.""" """Manages business logic and background processing for flights."""
def __init__(self, repo: FlightRepository, sse: SSEEventStreamer) -> None: def __init__(self, repository: FlightRepository, streamer: SSEEventStreamer) -> None:
self.repo = repo self.repository = repository
self.sse = sse self.streamer = streamer
self.result_manager = ResultManager(repository, streamer)
self.pipeline = ImageInputPipeline(storage_dir=".image_storage", max_queue_size=50)
async def create_flight(self, req: FlightCreateRequest) -> FlightResponse: async def create_flight(self, req: FlightCreateRequest) -> FlightResponse:
flight = await self.repo.insert_flight( flight = await self.repository.insert_flight(
name=req.name, name=req.name,
description=req.description, description=req.description,
start_lat=req.start_gps.lat, start_lat=req.start_gps.lat,
@@ -42,7 +47,7 @@ class FlightProcessor:
camera_params=req.camera_params.model_dump(), camera_params=req.camera_params.model_dump(),
) )
for poly in req.geofences.polygons: for poly in req.geofences.polygons:
await self.repo.insert_geofence( await self.repository.insert_geofence(
flight.id, flight.id,
nw_lat=poly.north_west.lat, nw_lat=poly.north_west.lat,
nw_lon=poly.north_west.lon, nw_lon=poly.north_west.lon,
@@ -50,7 +55,7 @@ class FlightProcessor:
se_lon=poly.south_east.lon, se_lon=poly.south_east.lon,
) )
for w in req.rough_waypoints: for w in req.rough_waypoints:
await self.repo.insert_waypoint(flight.id, lat=w.lat, lon=w.lon) await self.repository.insert_waypoint(flight.id, lat=w.lat, lon=w.lon)
return FlightResponse( return FlightResponse(
flight_id=flight.id, flight_id=flight.id,
@@ -60,11 +65,11 @@ class FlightProcessor:
) )
async def get_flight(self, flight_id: str) -> FlightDetailResponse | None: async def get_flight(self, flight_id: str) -> FlightDetailResponse | None:
flight = await self.repo.get_flight(flight_id) flight = await self.repository.get_flight(flight_id)
if not flight: if not flight:
return None return None
wps = await self.repo.get_waypoints(flight_id) wps = await self.repository.get_waypoints(flight_id)
state = await self.repo.load_flight_state(flight_id) state = await self.repository.load_flight_state(flight_id)
waypoints = [ waypoints = [
Waypoint( Waypoint(
@@ -103,13 +108,13 @@ class FlightProcessor:
) )
async def delete_flight(self, flight_id: str) -> DeleteResponse: async def delete_flight(self, flight_id: str) -> DeleteResponse:
deleted = await self.repo.delete_flight(flight_id) deleted = await self.repository.delete_flight(flight_id)
return DeleteResponse(deleted=deleted, flight_id=flight_id) return DeleteResponse(deleted=deleted, flight_id=flight_id)
async def update_waypoint( async def update_waypoint(
self, flight_id: str, waypoint_id: str, waypoint: Waypoint self, flight_id: str, waypoint_id: str, waypoint: Waypoint
) -> UpdateResponse: ) -> UpdateResponse:
ok = await self.repo.update_waypoint( ok = await self.repository.update_waypoint(
flight_id, flight_id,
waypoint_id, waypoint_id,
lat=waypoint.lat, lat=waypoint.lat,
@@ -126,7 +131,7 @@ class FlightProcessor:
failed = [] failed = []
updated = 0 updated = 0
for wp in waypoints: for wp in waypoints:
ok = await self.repo.update_waypoint( ok = await self.repository.update_waypoint(
flight_id, flight_id,
wp.id, wp.id,
lat=wp.lat, lat=wp.lat,
@@ -144,10 +149,10 @@ class FlightProcessor:
async def queue_images( async def queue_images(
self, flight_id: str, metadata: BatchMetadata, file_count: int self, flight_id: str, metadata: BatchMetadata, file_count: int
) -> BatchResponse: ) -> BatchResponse:
state = await self.repo.load_flight_state(flight_id) state = await self.repository.load_flight_state(flight_id)
if state: if state:
total = state.frames_total + file_count total = state.frames_total + file_count
await self.repo.save_flight_state(flight_id, frames_total=total, status="processing") await self.repository.save_flight_state(flight_id, frames_total=total, status="processing")
next_seq = metadata.end_sequence + 1 next_seq = metadata.end_sequence + 1
seqs = list(range(metadata.start_sequence, metadata.end_sequence + 1)) seqs = list(range(metadata.start_sequence, metadata.end_sequence + 1))
@@ -159,13 +164,13 @@ class FlightProcessor:
) )
async def handle_user_fix(self, flight_id: str, req: UserFixRequest) -> UserFixResponse: async def handle_user_fix(self, flight_id: str, req: UserFixRequest) -> UserFixResponse:
await self.repo.save_flight_state(flight_id, blocked=False, status="processing") await self.repository.save_flight_state(flight_id, blocked=False, status="processing")
return UserFixResponse( return UserFixResponse(
accepted=True, processing_resumed=True, message="Fix applied." accepted=True, processing_resumed=True, message="Fix applied."
) )
async def get_flight_status(self, flight_id: str) -> FlightStatusResponse | None: async def get_flight_status(self, flight_id: str) -> FlightStatusResponse | None:
state = await self.repo.load_flight_state(flight_id) state = await self.repository.load_flight_state(flight_id)
if not state: if not state:
return None return None
return FlightStatusResponse( return FlightStatusResponse(
@@ -194,5 +199,5 @@ class FlightProcessor:
async def stream_events(self, flight_id: str, client_id: str): async def stream_events(self, flight_id: str, client_id: str):
"""Async generator for SSE stream.""" """Async generator for SSE stream."""
# Yield from the real SSE streamer generator # Yield from the real SSE streamer generator
async for event in self.sse.stream_generator(flight_id, client_id): async for event in self.streamer.stream_generator(flight_id, client_id):
yield event yield event
+139
View File
@@ -0,0 +1,139 @@
"""Image Rotation Manager (Component F06)."""
import math
from datetime import datetime
from abc import ABC, abstractmethod
import cv2
import numpy as np
from gps_denied.schemas.rotation import HeadingHistory, RotationResult
from gps_denied.schemas.satellite import TileBounds
class IImageMatcher(ABC):
"""Dependency injection interface for Metric Refinement."""
@abstractmethod
def align_to_satellite(self, uav_image: np.ndarray, satellite_tile: np.ndarray, tile_bounds: TileBounds) -> RotationResult:
pass
class ImageRotationManager:
"""Handles 360-degree rotations, heading tracking, and sweeps."""
def __init__(self):
# flight_id -> HeadingHistory
self._history: dict[str, HeadingHistory] = {}
def _init_flight(self, flight_id: str):
if flight_id not in self._history:
self._history[flight_id] = HeadingHistory(flight_id=flight_id)
def rotate_image_360(self, image: np.ndarray, angle: float) -> np.ndarray:
"""Rotates an image by specified angle around center."""
if angle == 0.0 or angle == 360.0:
return image
h, w = image.shape[:2]
center = (w / 2, h / 2)
# Get rotation matrix. Negative angle for standard counter-clockwise interpretation in some math
# or positive for OpenCV's coordinate system.
matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
rotated = cv2.warpAffine(
image, matrix, (w, h),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(0, 0, 0)
)
return rotated
def rotate_chunk_360(self, chunk_images: list[np.ndarray], angle: float) -> list[np.ndarray]:
"""Rotates all images in a chunk by the same angle."""
if angle == 0.0 or angle == 360.0:
return chunk_images
return [self.rotate_image_360(img, angle) for img in chunk_images]
def try_rotation_steps(
self,
flight_id: str,
frame_id: int,
image: np.ndarray,
satellite_tile: np.ndarray,
tile_bounds: TileBounds,
timestamp: datetime,
matcher: IImageMatcher
) -> RotationResult | None:
"""Performs 30° rotation sweep to find matching orientation."""
# 12 steps: 0, 30, 60... 330
for angle in range(0, 360, 30):
rotated = self.rotate_image_360(image, float(angle))
result = matcher.align_to_satellite(rotated, satellite_tile, tile_bounds)
if result.matched:
precise_angle = self.calculate_precise_angle(result.homography, float(angle))
result.precise_angle = precise_angle
result.initial_angle = float(angle)
self.update_heading(flight_id, frame_id, precise_angle, timestamp)
return result
return None
def calculate_precise_angle(self, homography: np.ndarray | None, initial_angle: float) -> float:
"""Calculates precise rotation angle from homography matrix."""
if homography is None:
return initial_angle
# Extract rotation angle from 2D affine component of homography
# h00, h01 = homography[0, 0], homography[0, 1]
# angle_delta = math.degrees(math.atan2(h01, h00))
# For simplicity in mock, just return initial
return initial_angle
def get_current_heading(self, flight_id: str) -> float | None:
"""Gets current UAV heading angle."""
self._init_flight(flight_id)
return self._history[flight_id].current_heading
def update_heading(self, flight_id: str, frame_id: int, heading: float, timestamp: datetime) -> bool:
"""Updates UAV heading angle."""
self._init_flight(flight_id)
# Normalize to 0-360
normalized = heading % 360.0
hist = self._history[flight_id]
hist.current_heading = normalized
hist.last_update = timestamp
hist.heading_history.append(normalized)
if len(hist.heading_history) > 10:
hist.heading_history.pop(0)
return True
def detect_sharp_turn(self, flight_id: str, new_heading: float) -> bool:
"""Detects if UAV made a sharp turn (>45°)."""
current = self.get_current_heading(flight_id)
if current is None:
return False
delta = abs(new_heading - current)
if delta > 180:
delta = 360 - delta
return delta > 45.0
def requires_rotation_sweep(self, flight_id: str) -> bool:
"""Determines if rotation sweep is needed for current frame."""
self._init_flight(flight_id)
hist = self._history[flight_id]
# First frame scenario
if hist.current_heading is None:
return True
return False
+68
View File
@@ -0,0 +1,68 @@
"""Image Input Pipeline schemas (Component F05)."""
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
import numpy as np
class ImageBatch(BaseModel):
"""Batch of raw images for processing."""
images: list[bytes]
filenames: list[str]
start_sequence: int
end_sequence: int
batch_number: int
class ImageMetadata(BaseModel):
"""Extracted metadata from an image."""
sequence: int
filename: str
dimensions: tuple[int, int]
file_size: int
timestamp: datetime
exif_data: Optional[dict] = None
class ValidationResult(BaseModel):
"""Result of batch validation."""
valid: bool
errors: list[str]
class ProcessingStatus(BaseModel):
"""Status of image pipeline processing."""
flight_id: str
total_images: int
processed_images: int
current_sequence: int
queued_batches: int
processing_rate: float
class ImageData:
"""Loaded image ready for processing."""
# Using normal class instead of BaseModel for np.ndarray support
def __init__(
self,
flight_id: str,
sequence: int,
filename: str,
image: np.ndarray,
metadata: ImageMetadata
):
self.flight_id = flight_id
self.sequence = sequence
self.filename = filename
self.image = image
self.metadata = metadata
class ProcessedBatch:
"""Batch of decoded images."""
def __init__(self, images: list[ImageData], batch_id: str, start_sequence: int, end_sequence: int):
self.images = images
self.batch_id = batch_id
self.start_sequence = start_sequence
self.end_sequence = end_sequence
+38
View File
@@ -0,0 +1,38 @@
"""Rotation schemas (Component F06)."""
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
import numpy as np
class RotationResult(BaseModel):
"""Result of a rotation sweep alignment."""
matched: bool
initial_angle: float
precise_angle: float
confidence: float
# We will exclude np.ndarray from BaseModel to avoid validation issues,
# but store it as an attribute if needed or use arbitrary_types_allowed.
model_config = {"arbitrary_types_allowed": True}
homography: Optional[np.ndarray] = None
inlier_count: int = 0
class HeadingHistory(BaseModel):
"""Flight heading tracking history."""
flight_id: str
current_heading: Optional[float] = None
heading_history: list[float] = []
last_update: Optional[datetime] = None
sharp_turns: int = 0
class RotationConfig(BaseModel):
"""Configuration for rotation sweeps."""
step_angle: float = 30.0
sharp_turn_threshold: float = 45.0
confidence_threshold: float = 0.7
history_size: int = 10
+82
View File
@@ -0,0 +1,82 @@
"""Tests for Image Input Pipeline (F05)."""
import asyncio
import cv2
import numpy as np
import pytest
from gps_denied.core.pipeline import ImageInputPipeline, QueueFullError, ValidationError
from gps_denied.schemas.image import ImageBatch
@pytest.fixture
def pipeline(tmp_path):
storage = str(tmp_path / "images")
return ImageInputPipeline(storage_dir=storage, max_queue_size=2)
def test_batch_validation(pipeline):
# Too few images
b1 = ImageBatch(images=[b"1", b"2"], filenames=["1.jpg", "2.jpg"], start_sequence=1, end_sequence=2, batch_number=1)
val = pipeline.validate_batch(b1)
assert not val.valid
assert "Batch is empty" in val.errors
# Let's mock a valid batch of 10 images
fake_imgs = [b"fake"] * 10
fake_names = [f"AD{i:06d}.jpg" for i in range(1, 11)]
b2 = ImageBatch(images=fake_imgs, filenames=fake_names, start_sequence=1, end_sequence=10, batch_number=1)
val2 = pipeline.validate_batch(b2)
assert val2.valid
@pytest.mark.asyncio
async def test_queue_and_process(pipeline):
flight_id = "test_f1"
# Create valid fake images
fake_img_np = np.zeros((10, 10, 3), dtype=np.uint8)
_, encoded = cv2.imencode(".jpg", fake_img_np)
fake_bytes = encoded.tobytes()
fake_imgs = [fake_bytes] * 10
fake_names = [f"AD{i:06d}.jpg" for i in range(1, 11)]
b = ImageBatch(images=fake_imgs, filenames=fake_names, start_sequence=1, end_sequence=10, batch_number=1)
pipeline.queue_batch(flight_id, b)
# Process
processed = await pipeline.process_next_batch(flight_id)
assert processed is not None
assert len(processed.images) == 10
assert processed.images[0].sequence == 1
assert processed.images[-1].sequence == 10
# Status
st = pipeline.get_processing_status(flight_id)
assert st.total_images == 10
assert st.processed_images == 10
# Sequential get
next_img = pipeline.get_next_image(flight_id)
assert next_img is not None
assert next_img.sequence == 1
# Second get
next_img2 = pipeline.get_next_image(flight_id)
assert next_img2 is not None
assert next_img2.sequence == 2
def test_queue_full(pipeline):
flight_id = "test_full"
fake_imgs = [b"fake"] * 10
fake_names = [f"AD{i:06d}.jpg" for i in range(1, 11)]
b = ImageBatch(images=fake_imgs, filenames=fake_names, start_sequence=1, end_sequence=10, batch_number=1)
pipeline.queue_batch(flight_id, b)
pipeline.queue_batch(flight_id, b)
with pytest.raises(QueueFullError):
pipeline.queue_batch(flight_id, b)
+88
View File
@@ -0,0 +1,88 @@
"""Tests for Image Rotation Manager (F06)."""
from datetime import datetime, timezone
import numpy as np
import pytest
from gps_denied.core.rotation import IImageMatcher, ImageRotationManager
from gps_denied.schemas.rotation import RotationResult
from gps_denied.schemas.satellite import TileBounds
from gps_denied.schemas import GPSPoint
@pytest.fixture
def rotation_manager():
return ImageRotationManager()
def test_rotate_image_360(rotation_manager):
img = np.zeros((100, 100, 3), dtype=np.uint8)
# Just draw a white rectangle to test rotation
img[10:40, 10:40] = [255, 255, 255]
r90 = rotation_manager.rotate_image_360(img, 90.0)
assert r90.shape == (100, 100, 3)
# Top left corner should move
assert not np.array_equal(img, r90)
def test_heading_management(rotation_manager):
fid = "flight_1"
now = datetime.now(timezone.utc)
assert rotation_manager.get_current_heading(fid) is None
rotation_manager.update_heading(fid, 1, 370.0, now) # should modulo to 10
assert rotation_manager.get_current_heading(fid) == 10.0
rotation_manager.update_heading(fid, 2, 90.0, now)
assert rotation_manager.get_current_heading(fid) == 90.0
def test_detect_sharp_turn(rotation_manager):
fid = "flight_2"
now = datetime.now(timezone.utc)
assert rotation_manager.detect_sharp_turn(fid, 90.0) is False # no history
rotation_manager.update_heading(fid, 1, 90.0, now)
assert rotation_manager.detect_sharp_turn(fid, 100.0) is False # delta 10
assert rotation_manager.detect_sharp_turn(fid, 180.0) is True # delta 90
assert rotation_manager.detect_sharp_turn(fid, 350.0) is True # delta 100
assert rotation_manager.detect_sharp_turn(fid, 80.0) is False # delta 10 (wraparound)
# Wraparound test explicitly
rotation_manager.update_heading(fid, 2, 350.0, now)
assert rotation_manager.detect_sharp_turn(fid, 10.0) is False # delta 20
class MockMatcher(IImageMatcher):
def align_to_satellite(self, uav_image: np.ndarray, satellite_tile: np.ndarray, tile_bounds: TileBounds) -> RotationResult:
# Mock that only matches when angle was originally 90
# By checking internal state or just returning generic true for test
return RotationResult(matched=True, initial_angle=90.0, precise_angle=90.0, confidence=0.99)
def test_try_rotation_steps(rotation_manager):
fid = "flight_3"
img = np.zeros((10, 10, 3), dtype=np.uint8)
sat = np.zeros((10, 10, 3), dtype=np.uint8)
nw = GPSPoint(lat=10.0, lon=10.0)
se = GPSPoint(lat=9.0, lon=11.0)
tb = TileBounds(nw=nw, ne=nw, sw=se, se=se, center=nw, gsd=0.5)
matcher = MockMatcher()
# Should perform sweep and mock matcher says matched=True immediately in the loop
res = rotation_manager.try_rotation_steps(fid, 1, img, sat, tb, datetime.now(timezone.utc), matcher)
assert res is not None
assert res.matched is True
# The first step is 0 degrees, the mock matcher returns matched=True.
# Therefore the first matched angle is 0.
assert res.initial_angle == 0.0