mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-23 06:26:37 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,193 @@
|
||||
import asyncio
import json
import logging
import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Dict, List, Optional

from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Data Models ---
|
||||
|
||||
class StreamConnection(BaseModel):
    """Metadata for one client's SSE subscription to a flight's event stream."""

    stream_id: str  # unique identifier (uuid4) for this particular stream instance
    flight_id: str  # flight whose events this connection receives
    client_id: str  # identifies the subscribing client within the flight
    created_at: datetime  # when the stream was opened
    last_event_id: Optional[str] = None  # Last-Event-ID supplied on reconnect, if any
|
||||
|
||||
class SSEEvent(BaseModel):
    """A single Server-Sent Event as buffered/queued before wire formatting."""

    event: str  # SSE event type (e.g. "frame_processed"), or "comment" for heartbeats
    id: Optional[str]  # monotonic per-flight event ID; None for comment/heartbeat events
    data: str  # JSON-encoded payload string
|
||||
|
||||
# --- Interface ---
|
||||
|
||||
class ISSEEventStreamer(ABC):
    """Interface for broadcasting flight events to SSE clients.

    All `send_*` methods return True when the event was dispatched to the
    flight's connected clients (per the concrete implementation's contract).
    """

    @abstractmethod
    def create_stream(self, flight_id: str, client_id: str, last_event_id: Optional[str] = None, event_types: Optional[List[str]] = None) -> AsyncGenerator[dict, None]:
        """Open a stream for `client_id` on `flight_id`, yielding SSE-ready dicts.

        `last_event_id` requests replay of missed events; `event_types`
        restricts delivery to the listed event names.
        """
        pass

    @abstractmethod
    def send_frame_result(self, flight_id: str, frame_result: Any) -> bool:
        """Broadcast a processed-frame result to the flight's clients."""
        pass

    @abstractmethod
    def send_search_progress(self, flight_id: str, search_status: Any) -> bool:
        """Broadcast a search-progress update to the flight's clients."""
        pass

    @abstractmethod
    def send_user_input_request(self, flight_id: str, request: Any) -> bool:
        """Broadcast a request for user input to the flight's clients."""
        pass

    @abstractmethod
    def send_refinement(self, flight_id: str, frame_id: int, updated_result: Any) -> bool:
        """Broadcast a refined result for a previously processed frame."""
        pass

    @abstractmethod
    def send_heartbeat(self, flight_id: str) -> bool:
        """Send a keep-alive heartbeat to the flight's clients."""
        pass

    @abstractmethod
    def send_generic_event(self, flight_id: str, event_type: str, data: Any) -> bool:
        """Broadcast an arbitrary `event_type` with `data` as payload."""
        pass

    @abstractmethod
    def close_stream(self, flight_id: str, client_id: str) -> bool:
        """Tear down one client's stream; True if a connection was removed."""
        pass

    @abstractmethod
    def get_active_connections(self, flight_id: str) -> int:
        """Return the number of clients currently connected to `flight_id`."""
        pass
|
||||
|
||||
|
||||
# --- Implementation ---
|
||||
|
||||
class SSEEventStreamer(ISSEEventStreamer):
    """
    F15: SSE Event Streamer

    Manages real-time asynchronous data broadcasting to connected clients.
    Supports event buffering, replaying on reconnection, and filtering.
    """

    def __init__(self, max_buffer_size: int = 1000, queue_maxsize: int = 100):
        """
        Args:
            max_buffer_size: Max events retained per flight for replay.
            queue_maxsize: Per-client queue capacity; a client whose queue
                overflows is treated as too slow and is disconnected.
        """
        self.max_buffer_size = max_buffer_size
        self.queue_maxsize = queue_maxsize

        # flight_id -> client_id -> connection metadata / delivery queue
        self._connections: Dict[str, Dict[str, StreamConnection]] = {}
        self._client_queues: Dict[str, Dict[str, asyncio.Queue]] = {}

        # flight_id -> historical events buffer (for Last-Event-ID replay)
        self._event_buffers: Dict[str, List[SSEEvent]] = {}
        self._event_counters: Dict[str, int] = {}

    def _extract_data(self, model: Any) -> dict:
        """Helper to serialize incoming Pydantic models or dicts to JSON-ready dicts."""
        if hasattr(model, "model_dump"):
            # Pydantic v2: mode="json" coerces datetimes/UUIDs to JSON-safe values.
            return model.model_dump(mode="json")
        elif hasattr(model, "dict"):
            # Pydantic v1 fallback.
            return model.dict()
        elif isinstance(model, dict):
            return model
        # Last resort: wrap an arbitrary object's string form.
        return {"data": str(model)}

    def _broadcast(self, flight_id: str, event_type: str, data: dict) -> bool:
        """Core broadcasting logic: generates ID, buffers, and distributes to queues.

        Returns:
            True once the event has been offered to every connected client.
        """
        if flight_id not in self._event_counters:
            self._event_counters[flight_id] = 0
            self._event_buffers[flight_id] = []

        # Heartbeats are SSE comments: no ID, never buffered/replayed, so they
        # must not consume event IDs or bounded buffer slots.
        if event_type == "comment":
            sse_event = SSEEvent(event="comment", id=None, data=json.dumps(data) if data else "")
        else:
            self._event_counters[flight_id] += 1
            event_id = str(self._event_counters[flight_id])
            sse_event = SSEEvent(event=event_type, id=event_id, data=json.dumps(data))

            # Buffer standard events for replay on reconnection (bounded FIFO).
            buffer = self._event_buffers[flight_id]
            buffer.append(sse_event)
            if len(buffer) > self.max_buffer_size:
                buffer.pop(0)

        # Distribute to active client queues. Snapshot with list() because a
        # slow client is removed mid-iteration by close_stream().
        if flight_id in self._client_queues:
            for client_id, q in list(self._client_queues[flight_id].items()):
                try:
                    q.put_nowait(sse_event)
                except asyncio.QueueFull:
                    logger.warning(f"Slow client {client_id} on flight {flight_id}. Closing connection.")
                    self.close_stream(flight_id, client_id)
        return True

    async def create_stream(self, flight_id: str, client_id: str, last_event_id: Optional[str] = None, event_types: Optional[List[str]] = None) -> AsyncGenerator[dict, None]:
        """Creates an async generator yielding SSE dictionaries formatted for sse_starlette."""
        stream_id = str(uuid.uuid4())
        conn = StreamConnection(
            stream_id=stream_id,
            flight_id=flight_id,
            client_id=client_id,
            # Timezone-aware now; datetime.utcnow() is deprecated since 3.12.
            created_at=datetime.now(timezone.utc),
            last_event_id=last_event_id,
        )

        if flight_id not in self._connections:
            self._connections[flight_id] = {}
            self._client_queues[flight_id] = {}

        self._connections[flight_id][client_id] = conn
        q: asyncio.Queue = asyncio.Queue(maxsize=self.queue_maxsize)
        self._client_queues[flight_id][client_id] = q

        # Replay buffered events if the client is reconnecting.
        if last_event_id and flight_id in self._event_buffers:
            try:
                last_id_int = int(last_event_id)
                for ev in self._event_buffers[flight_id]:
                    if ev.id and int(ev.id) > last_id_int:
                        if not event_types or ev.event in event_types:
                            q.put_nowait(ev)
            except (ValueError, asyncio.QueueFull):
                # Non-numeric Last-Event-ID or an overfull queue: skip replay.
                pass

        try:
            while True:
                event = await q.get()
                if event is None:  # Sentinel enqueued by close_stream()
                    break

                # Heartbeat comments always pass the event-type filter.
                if event_types and event.event not in event_types and event.event != "comment":
                    continue

                if event.event == "comment":
                    yield {"comment": event.data}
                else:
                    yield {
                        "event": event.event,
                        "id": event.id,
                        "data": event.data
                    }
        finally:
            # Idempotent: returns False if the stream was already closed.
            self.close_stream(flight_id, client_id)

    def send_frame_result(self, flight_id: str, frame_result: Any) -> bool:
        """Broadcast a processed-frame result as a "frame_processed" event."""
        data = self._extract_data(frame_result)
        return self._broadcast(flight_id, "frame_processed", data)

    def send_search_progress(self, flight_id: str, search_status: Any) -> bool:
        """Broadcast a search-progress update as a "search_expanded" event."""
        data = self._extract_data(search_status)
        return self._broadcast(flight_id, "search_expanded", data)

    def send_user_input_request(self, flight_id: str, request: Any) -> bool:
        """Broadcast a user-input request as a "user_input_needed" event."""
        data = self._extract_data(request)
        return self._broadcast(flight_id, "user_input_needed", data)

    def send_refinement(self, flight_id: str, frame_id: int, updated_result: Any) -> bool:
        """Broadcast an updated frame result as a "frame_refined" event."""
        data = self._extract_data(updated_result)
        # Match specific structure typically requested
        data["refined"] = True
        return self._broadcast(flight_id, "frame_refined", data)

    def send_heartbeat(self, flight_id: str) -> bool:
        """Send a keep-alive heartbeat (emitted as an SSE comment)."""
        return self._broadcast(flight_id, "comment", {"msg": "heartbeat"})

    def send_generic_event(self, flight_id: str, event_type: str, data: Any) -> bool:
        """Broadcast an arbitrary event type with a serialized payload."""
        return self._broadcast(flight_id, event_type, self._extract_data(data))

    def close_stream(self, flight_id: str, client_id: str) -> bool:
        """Remove a client's connection and wake its consumer with a sentinel.

        Returns:
            True if a connection was removed, False if none existed.
        """
        if flight_id in self._connections and client_id in self._connections[flight_id]:
            del self._connections[flight_id][client_id]
            q = self._client_queues[flight_id].pop(client_id, None)
            if q is not None:
                # Enqueue the None sentinel so a consumer blocked on q.get()
                # terminates instead of waiting forever on an orphaned queue.
                try:
                    q.put_nowait(None)
                except asyncio.QueueFull:
                    # Queue is full (slow client): drop one pending event to
                    # make room for the sentinel.
                    try:
                        q.get_nowait()
                        q.put_nowait(None)
                    except (asyncio.QueueEmpty, asyncio.QueueFull):
                        pass
            return True
        return False

    def get_active_connections(self, flight_id: str) -> int:
        """Return how many clients are currently connected to `flight_id`."""
        return len(self._connections.get(flight_id, {}))
|
||||
Reference in New Issue
Block a user