mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-22 22:06:37 +00:00
193 lines
7.7 KiB
Python
193 lines
7.7 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Any, AsyncGenerator
|
|
from pydantic import BaseModel
|
|
from abc import ABC, abstractmethod
|
|
|
|
# Module-scoped logger, keyed by this module's import name.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
# --- Data Models ---
|
|
|
|
class StreamConnection(BaseModel):
    """Bookkeeping record for one open SSE stream of a client on a flight."""

    # Random UUID4 string generated when the stream is opened.
    stream_id: str
    flight_id: str
    client_id: str
    # Naive UTC timestamp of when the stream was opened.
    created_at: datetime
    # Last-Event-ID the client (re)connected with, if it sent one.
    last_event_id: Optional[str] = None
|
|
|
|
class SSEEvent(BaseModel):
    """One server-sent event, as held in replay buffers and client queues."""

    # Event type name; the value "comment" marks an SSE comment (heartbeat).
    event: str
    # Required but nullable: None for comments, otherwise the per-flight
    # sequence number rendered as a string.
    id: Optional[str]
    # JSON-encoded payload (or the comment text for comments).
    data: str
|
|
|
|
# --- Interface ---
|
|
|
|
class ISSEEventStreamer(ABC):
    """Contract for a component that pushes flight events to SSE clients."""

    @abstractmethod
    def create_stream(self, flight_id: str, client_id: str, last_event_id: Optional[str] = None, event_types: Optional[List[str]] = None) -> AsyncGenerator[dict, None]:
        """Open a stream for a client and return an async generator of SSE dicts."""

    @abstractmethod
    def send_frame_result(self, flight_id: str, frame_result: Any) -> bool:
        """Publish a processed-frame result to the flight's clients."""

    @abstractmethod
    def send_search_progress(self, flight_id: str, search_status: Any) -> bool:
        """Publish a search-progress update to the flight's clients."""

    @abstractmethod
    def send_user_input_request(self, flight_id: str, request: Any) -> bool:
        """Publish a request for user input to the flight's clients."""

    @abstractmethod
    def send_refinement(self, flight_id: str, frame_id: int, updated_result: Any) -> bool:
        """Publish a refined result for an already-reported frame."""

    @abstractmethod
    def send_heartbeat(self, flight_id: str) -> bool:
        """Publish a keep-alive to the flight's clients."""

    @abstractmethod
    def send_generic_event(self, flight_id: str, event_type: str, data: Any) -> bool:
        """Publish an arbitrary event type with the given payload."""

    @abstractmethod
    def close_stream(self, flight_id: str, client_id: str) -> bool:
        """Close a client's stream; report whether one was open."""

    @abstractmethod
    def get_active_connections(self, flight_id: str) -> int:
        """Return the number of open streams for a flight."""
|
|
|
|
|
|
# --- Implementation ---
|
|
|
|
class SSEEventStreamer(ISSEEventStreamer):
    """
    F15: SSE Event Streamer

    Manages real-time asynchronous data broadcasting to connected clients.
    Supports event buffering, replaying on reconnection, and filtering.

    Events are numbered per flight starting at 1. The most recent
    ``max_buffer_size`` of them are kept so that a client reconnecting with a
    ``Last-Event-ID`` can be replayed what it missed. Heartbeats are sent as
    SSE comments: they carry no id and are neither numbered nor buffered.
    Each open stream owns a queue bounded by ``queue_maxsize`` (0 means
    unbounded). A client whose queue overflows is disconnected; it can
    reconnect and replay from the buffer.
    """

    def __init__(self, max_buffer_size: int = 1000, queue_maxsize: int = 100):
        self.max_buffer_size = max_buffer_size
        self.queue_maxsize = queue_maxsize

        # flight_id -> client_id -> connection/queue
        self._connections: Dict[str, Dict[str, StreamConnection]] = {}
        self._client_queues: Dict[str, Dict[str, asyncio.Queue]] = {}

        # flight_id -> historical events buffer (oldest first)
        self._event_buffers: Dict[str, List[SSEEvent]] = {}
        # flight_id -> id of the last numbered event issued for that flight
        self._event_counters: Dict[str, int] = {}

    # --- Serialization helpers ---

    def _extract_data(self, model: Any) -> dict:
        """Serialize a Pydantic model (v2 or v1) or dict into a JSON-ready dict.

        Dicts are returned as-is (not copied). Any other object is wrapped
        as ``{"data": str(model)}``.
        """
        if hasattr(model, "model_dump"):
            return model.model_dump(mode="json")
        if hasattr(model, "dict"):
            return model.dict()
        if isinstance(model, dict):
            return model
        return {"data": str(model)}

    @staticmethod
    def _to_sse_dict(event: SSEEvent) -> dict:
        """Format an event as the dict shape consumed by sse_starlette."""
        if event.event == "comment":
            return {"comment": event.data}
        return {"event": event.event, "id": event.id, "data": event.data}

    # --- Broadcasting ---

    def _broadcast(self, flight_id: str, event_type: str, data: dict) -> bool:
        """Core broadcasting logic: build the event, buffer it, fan it out to queues.

        ``event_type == "comment"`` produces an SSE comment (used for
        heartbeats). Comments get no id and are not buffered, so they
        neither leave gaps in the id sequence nor evict replayable events.
        Clients whose queue is full are disconnected via ``close_stream``.

        Returns:
            Always ``True``.
        """
        if event_type == "comment":
            sse_event = SSEEvent(event="comment", id=None, data=json.dumps(data) if data else "")
        else:
            # Serialize before taking an id so a failed dumps() does not burn one.
            payload = json.dumps(data)
            event_number = self._event_counters.get(flight_id, 0) + 1
            self._event_counters[flight_id] = event_number
            sse_event = SSEEvent(event=event_type, id=str(event_number), data=payload)

            buffer = self._event_buffers.setdefault(flight_id, [])
            buffer.append(sse_event)
            overflow = len(buffer) - self.max_buffer_size
            if overflow > 0:
                del buffer[:overflow]

        # Iterate over a snapshot: close_stream() mutates the mapping.
        for client_id, q in list(self._client_queues.get(flight_id, {}).items()):
            try:
                q.put_nowait(sse_event)
            except asyncio.QueueFull:
                logger.warning("Slow client %s on flight %s. Closing connection.", client_id, flight_id)
                self.close_stream(flight_id, client_id)
        return True

    # --- Stream lifecycle ---

    async def create_stream(self, flight_id: str, client_id: str, last_event_id: Optional[str] = None, event_types: Optional[List[str]] = None) -> AsyncGenerator[dict, None]:
        """Creates an async generator yielding SSE dictionaries formatted for sse_starlette.

        Args:
            flight_id: Flight whose events are streamed.
            client_id: Client identity. Opening another stream with the same
                (flight_id, client_id) closes this one.
            last_event_id: The client's Last-Event-ID. Buffered events with a
                larger numeric id are replayed before live events.
                Non-numeric values are logged and ignored.
            event_types: If given, only these event types are yielded.
                Heartbeat comments always pass.

        Yields:
            ``{"comment": ...}`` for heartbeats, otherwise a dict with
            ``event``, ``id`` and ``data`` keys.

        The connection is registered when iteration starts (an async
        generator body does not run before the first ``__anext__``). It is
        unregistered when the generator finishes or is closed.
        """
        q: asyncio.Queue = asyncio.Queue(maxsize=self.queue_maxsize)
        # There is no await between taking the replay snapshot and
        # registering the queue. Every event is therefore delivered exactly
        # once: either from the snapshot or through the queue. Replaying
        # directly (rather than via the bounded queue) means a large backlog
        # is not truncated at queue_maxsize.
        backlog = self._missed_events(flight_id, last_event_id, event_types)
        self._register(flight_id, client_id, q, last_event_id)

        try:
            for missed in backlog:
                yield self._to_sse_dict(missed)

            while True:
                event = await q.get()
                if event is None:  # Sentinel from close_stream(): end the stream.
                    break
                if event_types and event.event != "comment" and event.event not in event_types:
                    continue
                yield self._to_sse_dict(event)
        finally:
            self._release(flight_id, client_id, q)

    def _register(self, flight_id: str, client_id: str, q: asyncio.Queue, last_event_id: Optional[str]) -> None:
        """Record a new stream for (flight_id, client_id), superseding any existing one."""
        # A reconnect with the same client_id replaces the old stream. Close
        # the old one so its generator ends instead of waiting forever, and so
        # its cleanup cannot later unregister this new stream.
        self.close_stream(flight_id, client_id)

        conn = StreamConnection(
            stream_id=str(uuid.uuid4()),
            flight_id=flight_id,
            client_id=client_id,
            created_at=datetime.utcnow(),
            last_event_id=last_event_id,
        )
        self._connections.setdefault(flight_id, {})[client_id] = conn
        self._client_queues.setdefault(flight_id, {})[client_id] = q

    def _missed_events(self, flight_id: str, last_event_id: Optional[str], event_types: Optional[List[str]]) -> List[SSEEvent]:
        """Return buffered events newer than ``last_event_id`` that pass the type filter."""
        if not last_event_id:
            return []
        try:
            last_id = int(last_event_id)
        except ValueError:
            logger.warning("Ignoring non-numeric Last-Event-ID %r on flight %s", last_event_id, flight_id)
            return []

        missed: List[SSEEvent] = []
        for ev in self._event_buffers.get(flight_id, []):
            if not ev.id or int(ev.id) <= last_id:
                continue
            if event_types and ev.event not in event_types:
                continue
            missed.append(ev)
        return missed

    def _release(self, flight_id: str, client_id: str, q: asyncio.Queue) -> None:
        """Unregister the stream owning ``q``, unless it was already closed or replaced."""
        if self._client_queues.get(flight_id, {}).get(client_id) is q:
            self.close_stream(flight_id, client_id)

    @staticmethod
    def _signal_close(q: asyncio.Queue) -> None:
        """Enqueue the ``None`` end-of-stream sentinel, emptying ``q`` first if it is full."""
        try:
            q.put_nowait(None)
        except asyncio.QueueFull:
            # Drop ALL undelivered events, not just enough to make room. The
            # client's last received id then precedes everything it missed,
            # so a reconnect with Last-Event-ID replays them without a gap.
            while not q.empty():
                q.get_nowait()
            q.put_nowait(None)

    # --- Public send API ---

    def send_frame_result(self, flight_id: str, frame_result: Any) -> bool:
        """Broadcast a ``frame_processed`` event."""
        return self._broadcast(flight_id, "frame_processed", self._extract_data(frame_result))

    def send_search_progress(self, flight_id: str, search_status: Any) -> bool:
        """Broadcast a ``search_expanded`` event."""
        return self._broadcast(flight_id, "search_expanded", self._extract_data(search_status))

    def send_user_input_request(self, flight_id: str, request: Any) -> bool:
        """Broadcast a ``user_input_needed`` event."""
        return self._broadcast(flight_id, "user_input_needed", self._extract_data(request))

    def send_refinement(self, flight_id: str, frame_id: int, updated_result: Any) -> bool:
        """Broadcast a ``frame_refined`` event with ``refined=True`` added to the payload.

        ``frame_id`` is added to the payload only if it does not already
        contain one.
        """
        # Copy: _extract_data returns dict inputs as-is, and the caller's dict
        # must not be mutated.
        data = dict(self._extract_data(updated_result))
        # Match specific structure typically requested
        data["refined"] = True
        data.setdefault("frame_id", frame_id)
        return self._broadcast(flight_id, "frame_refined", data)

    def send_heartbeat(self, flight_id: str) -> bool:
        """Broadcast a heartbeat as an SSE comment (no id, not buffered)."""
        return self._broadcast(flight_id, "comment", {"msg": "heartbeat"})

    def send_generic_event(self, flight_id: str, event_type: str, data: Any) -> bool:
        """Broadcast ``data`` under an arbitrary event type."""
        return self._broadcast(flight_id, event_type, self._extract_data(data))

    def close_stream(self, flight_id: str, client_id: str) -> bool:
        """Close a client's stream.

        Unregisters the connection and enqueues the ``None`` sentinel, so
        that the client's ``create_stream`` generator terminates instead of
        blocking on an orphaned queue. Empty per-flight maps are removed.

        Returns:
            True if an open stream was found and closed, False otherwise.
        """
        connections = self._connections.get(flight_id)
        if not connections or client_id not in connections:
            return False

        del connections[client_id]
        if not connections:
            del self._connections[flight_id]

        queues = self._client_queues.get(flight_id, {})
        q = queues.pop(client_id, None)
        if not queues:
            self._client_queues.pop(flight_id, None)

        if q is not None:
            self._signal_close(q)
        return True

    def get_active_connections(self, flight_id: str) -> int:
        """Return the number of registered streams for ``flight_id``."""
        return len(self._connections.get(flight_id, {}))