Files
gps-denied-onboard/f15_sse_event_streamer.py
Denys Zaitsev d7e1066c60 Initial commit
2026-04-03 23:25:54 +03:00

193 lines
7.7 KiB
Python

import asyncio
import json
import logging
import uuid
from abc import ABC, abstractmethod
from collections import deque
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any, AsyncGenerator

from pydantic import BaseModel
# Module-level logger; the streamer uses it to report slow clients it disconnects.
logger = logging.getLogger(__name__)
# --- Data Models ---
class StreamConnection(BaseModel):
    """Bookkeeping record for one client's SSE subscription to a flight."""

    stream_id: str  # per-stream identifier (a uuid4 string when created by SSEEventStreamer)
    flight_id: str
    client_id: str
    created_at: datetime  # NOTE(review): the streamer stores naive UTC here — confirm consumers expect naive values
    last_event_id: Optional[str] = None  # Last-Event-ID supplied when the stream was opened, if any
class SSEEvent(BaseModel):
    """One server-sent event as passed through the streamer's client queues."""

    event: str  # event type; the value "comment" marks a heartbeat emitted as an SSE comment
    id: Optional[str]  # stringified per-flight counter; None for heartbeat comments
    data: str  # JSON-encoded payload (may be "" for a comment with no data)
# --- Interface ---
class ISSEEventStreamer(ABC):
    """Abstract contract for the F15 SSE event streamer.

    Implementations push per-flight events to connected clients. Each
    ``send_*`` method and ``close_stream`` return a bool; stream creation
    returns an async generator of SSE-ready dicts.
    """

    @abstractmethod
    def create_stream(
        self,
        flight_id: str,
        client_id: str,
        last_event_id: Optional[str] = None,
        event_types: Optional[List[str]] = None,
    ) -> AsyncGenerator[dict, None]:
        """Open a stream for ``client_id`` on ``flight_id`` and return its event generator."""

    @abstractmethod
    def send_frame_result(self, flight_id: str, frame_result: Any) -> bool:
        """Publish a processed-frame result to the flight's subscribers."""

    @abstractmethod
    def send_search_progress(self, flight_id: str, search_status: Any) -> bool:
        """Publish search-progress information to the flight's subscribers."""

    @abstractmethod
    def send_user_input_request(self, flight_id: str, request: Any) -> bool:
        """Publish a request for user input to the flight's subscribers."""

    @abstractmethod
    def send_refinement(self, flight_id: str, frame_id: int, updated_result: Any) -> bool:
        """Publish a refined result for ``frame_id`` to the flight's subscribers."""

    @abstractmethod
    def send_heartbeat(self, flight_id: str) -> bool:
        """Publish a keep-alive heartbeat to the flight's subscribers."""

    @abstractmethod
    def send_generic_event(self, flight_id: str, event_type: str, data: Any) -> bool:
        """Publish an arbitrary ``event_type`` with ``data`` to the flight's subscribers."""

    @abstractmethod
    def close_stream(self, flight_id: str, client_id: str) -> bool:
        """Close the stream held by ``client_id`` on ``flight_id``."""

    @abstractmethod
    def get_active_connections(self, flight_id: str) -> int:
        """Return how many clients are currently connected to ``flight_id``."""
# --- Implementation ---
class SSEEventStreamer(ISSEEventStreamer):
    """
    F15: SSE Event Streamer
    Manages real-time asynchronous data broadcasting to connected clients.
    Supports event buffering, replaying on reconnection, and filtering.

    Per flight it keeps a monotonically increasing event counter (used as the
    SSE ``id``), a bounded replay buffer of the most recent id-bearing events,
    and one bounded ``asyncio.Queue`` per connected client. Closing a stream
    puts a ``None`` sentinel on the client's queue so its generator exits.

    NOTE(review): no locking is done; this assumes every call happens on the
    event loop that consumes the queues.
    """

    def __init__(self, max_buffer_size: int = 1000, queue_maxsize: int = 100):
        """
        Args:
            max_buffer_size: Events retained per flight for Last-Event-ID replay.
            queue_maxsize: Per-client queue bound; a client whose queue fills
                up is disconnected. ``<= 0`` means unbounded (asyncio.Queue).
        """
        self.max_buffer_size = max_buffer_size
        self.queue_maxsize = queue_maxsize
        # flight_id -> client_id -> connection/queue
        self._connections: Dict[str, Dict[str, StreamConnection]] = {}
        self._client_queues: Dict[str, Dict[str, asyncio.Queue]] = {}
        # flight_id -> replay buffer (deque bounded by max_buffer_size)
        self._event_buffers: Dict[str, deque] = {}
        self._event_counters: Dict[str, int] = {}

    def _extract_data(self, model: Any) -> dict:
        """Serialize an incoming payload into a JSON-ready dict.

        Pydantic v2 models (``model_dump``) and v1 models (``dict``) are dumped.
        Plain dicts are shallow-copied so later mutation never reaches the
        caller's object. Anything else is wrapped as ``{"data": str(model)}``.
        """
        if hasattr(model, "model_dump"):
            return model.model_dump(mode="json")
        if hasattr(model, "dict"):
            # NOTE(review): v1 .dict() keeps datetimes etc., which json.dumps rejects.
            return model.dict()
        if isinstance(model, dict):
            return dict(model)
        return {"data": str(model)}

    def _broadcast(self, flight_id: str, event_type: str, data: dict) -> bool:
        """Build an SSE event, buffer it for replay, and fan it out to clients.

        Heartbeats (``event_type == "comment"``) carry no id and are neither
        counted nor buffered. They are never replayed, so buffering them would
        only evict real events from the replay window. A client whose queue is
        full is disconnected so one slow consumer cannot stall the broadcaster.

        Returns:
            Always ``True``.
        """
        if event_type == "comment":
            # Heartbeat: emitted as an SSE comment line, not replayable.
            sse_event = SSEEvent(event="comment", id=None, data=json.dumps(data) if data else "")
        else:
            payload = json.dumps(data)  # serialize first so a failure doesn't burn an id
            event_id = self._event_counters.get(flight_id, 0) + 1
            self._event_counters[flight_id] = event_id
            sse_event = SSEEvent(event=event_type, id=str(event_id), data=payload)
            buffer = self._event_buffers.get(flight_id)
            if buffer is None:
                buffer = deque(maxlen=max(self.max_buffer_size, 0))
                self._event_buffers[flight_id] = buffer
            buffer.append(sse_event)  # deque(maxlen) drops the oldest in O(1)

        queues = self._client_queues.get(flight_id)
        if queues:
            # Iterate a snapshot: closing a slow client mutates the mapping.
            for client_id, queue in list(queues.items()):
                try:
                    queue.put_nowait(sse_event)
                except asyncio.QueueFull:
                    logger.warning("Slow client %s on flight %s. Closing connection.", client_id, flight_id)
                    self.close_stream(flight_id, client_id)
        return True

    def _events_after(self, flight_id: str, last_event_id: Optional[str]) -> List[SSEEvent]:
        """Return buffered events of ``flight_id`` whose id exceeds ``last_event_id``."""
        if not last_event_id:
            return []
        try:
            last_id = int(last_event_id)
        except ValueError:
            logger.warning("Ignoring non-numeric Last-Event-ID %r for flight %s", last_event_id, flight_id)
            return []
        return [
            ev for ev in self._event_buffers.get(flight_id, ())
            if ev.id is not None and int(ev.id) > last_id
        ]

    @staticmethod
    def _wanted(event: SSEEvent, event_types: Optional[List[str]]) -> bool:
        """True if ``event`` passes the filter; comments (heartbeats) always pass."""
        return not event_types or event.event in event_types or event.event == "comment"

    @staticmethod
    def _to_sse_dict(event: SSEEvent) -> dict:
        """Format ``event`` as the dict shape sse_starlette expects."""
        if event.event == "comment":
            return {"comment": event.data}
        return {"event": event.event, "id": event.id, "data": event.data}

    @staticmethod
    def _signal_close(queue: asyncio.Queue) -> None:
        """Enqueue the ``None`` sentinel, discarding backlog if the queue is full."""
        while queue.full():
            # The client is being dropped; its pending events are moot.
            queue.get_nowait()
        queue.put_nowait(None)

    async def create_stream(self, flight_id: str, client_id: str, last_event_id: Optional[str] = None, event_types: Optional[List[str]] = None) -> AsyncGenerator[dict, None]:
        """Creates an async generator yielding SSE dictionaries formatted for sse_starlette.

        Args:
            flight_id: Flight to subscribe to.
            client_id: Client identity; reconnecting with the same id replaces
                (and terminates) that client's previous stream.
            last_event_id: If numeric, buffered events with a greater id are
                replayed first.
            event_types: Optional allow-list of event types; heartbeats always pass.
        """
        # Supersede an earlier stream for this client so its generator exits
        # instead of waiting forever on a queue nobody feeds.
        self.close_stream(flight_id, client_id)

        conn = StreamConnection(
            stream_id=str(uuid.uuid4()),
            flight_id=flight_id,
            client_id=client_id,
            created_at=datetime.now(timezone.utc).replace(tzinfo=None),  # naive UTC, as utcnow() gave
            last_event_id=last_event_id,
        )
        queue: asyncio.Queue = asyncio.Queue(maxsize=self.queue_maxsize)
        self._connections.setdefault(flight_id, {})[client_id] = conn
        self._client_queues.setdefault(flight_id, {})[client_id] = queue

        # Snapshot the replay set synchronously right after registering: later
        # broadcasts land only in the queue, giving no gaps and no duplicates,
        # and replay size is not limited by queue_maxsize.
        replay = self._events_after(flight_id, last_event_id)

        try:
            for event in replay:
                if self._wanted(event, event_types):
                    yield self._to_sse_dict(event)
            while True:
                event = await queue.get()
                if event is None:  # Sentinel from close_stream
                    break
                if self._wanted(event, event_types):
                    yield self._to_sse_dict(event)
        finally:
            # Only tear down our own registration; a newer stream for the same
            # client_id may already have replaced it.
            if self._client_queues.get(flight_id, {}).get(client_id) is queue:
                self.close_stream(flight_id, client_id)

    def send_frame_result(self, flight_id: str, frame_result: Any) -> bool:
        """Broadcast a ``frame_processed`` event."""
        return self._broadcast(flight_id, "frame_processed", self._extract_data(frame_result))

    def send_search_progress(self, flight_id: str, search_status: Any) -> bool:
        """Broadcast a ``search_expanded`` event."""
        return self._broadcast(flight_id, "search_expanded", self._extract_data(search_status))

    def send_user_input_request(self, flight_id: str, request: Any) -> bool:
        """Broadcast a ``user_input_needed`` event."""
        return self._broadcast(flight_id, "user_input_needed", self._extract_data(request))

    def send_refinement(self, flight_id: str, frame_id: int, updated_result: Any) -> bool:
        """Broadcast a ``frame_refined`` event whose payload carries ``refined: True``.

        NOTE(review): ``frame_id`` is not added to the payload; confirm the
        result object already identifies its frame.
        """
        # Build a new dict so a dict passed by the caller is never mutated.
        data = {**self._extract_data(updated_result), "refined": True}
        return self._broadcast(flight_id, "frame_refined", data)

    def send_heartbeat(self, flight_id: str) -> bool:
        """Broadcast a heartbeat as an SSE comment (not buffered, no id)."""
        return self._broadcast(flight_id, "comment", {"msg": "heartbeat"})

    def send_generic_event(self, flight_id: str, event_type: str, data: Any) -> bool:
        """Broadcast ``data`` under an arbitrary ``event_type``."""
        return self._broadcast(flight_id, event_type, self._extract_data(data))

    def close_stream(self, flight_id: str, client_id: str) -> bool:
        """Unregister ``client_id`` and wake its generator so it terminates.

        Returns:
            True if the client was registered on ``flight_id``, else False.
        """
        connections = self._connections.get(flight_id)
        if not connections or client_id not in connections:
            return False
        del connections[client_id]
        queue = self._client_queues.get(flight_id, {}).pop(client_id, None)
        if not connections:
            # Drop empty per-flight maps so they don't accumulate.
            self._connections.pop(flight_id, None)
            self._client_queues.pop(flight_id, None)
        if queue is not None:
            self._signal_close(queue)
        return True

    def get_active_connections(self, flight_id: str) -> int:
        """Number of clients currently registered on ``flight_id``."""
        return len(self._connections.get(flight_id, {}))