import asyncio
import json
import logging
import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Dict, List, Optional

from pydantic import BaseModel

logger = logging.getLogger(__name__)


# --- Data Models ---


class StreamConnection(BaseModel):
    """Bookkeeping record for one client's subscription to a flight stream."""

    stream_id: str
    flight_id: str
    client_id: str
    created_at: datetime
    last_event_id: Optional[str] = None


class SSEEvent(BaseModel):
    """One wire-level server-sent event; ``id`` is None for comment (heartbeat) frames."""

    event: str
    id: Optional[str]
    data: str


# --- Interface ---


class ISSEEventStreamer(ABC):
    """Contract for broadcasting flight-processing events to connected SSE clients."""

    @abstractmethod
    def create_stream(self, flight_id: str, client_id: str, last_event_id: Optional[str] = None, event_types: Optional[List[str]] = None) -> AsyncGenerator[dict, None]:
        pass

    @abstractmethod
    def send_frame_result(self, flight_id: str, frame_result: Any) -> bool:
        pass

    @abstractmethod
    def send_search_progress(self, flight_id: str, search_status: Any) -> bool:
        pass

    @abstractmethod
    def send_user_input_request(self, flight_id: str, request: Any) -> bool:
        pass

    @abstractmethod
    def send_refinement(self, flight_id: str, frame_id: int, updated_result: Any) -> bool:
        pass

    @abstractmethod
    def send_heartbeat(self, flight_id: str) -> bool:
        pass

    @abstractmethod
    def send_generic_event(self, flight_id: str, event_type: str, data: Any) -> bool:
        pass

    @abstractmethod
    def close_stream(self, flight_id: str, client_id: str) -> bool:
        pass

    @abstractmethod
    def get_active_connections(self, flight_id: str) -> int:
        pass


# --- Implementation ---


class SSEEventStreamer(ISSEEventStreamer):
    """
    F15: SSE Event Streamer

    Manages real-time asynchronous data broadcasting to connected clients.
    Supports event buffering, replaying on reconnection, and filtering.
    """

    def __init__(self, max_buffer_size: int = 1000, queue_maxsize: int = 100):
        """
        Args:
            max_buffer_size: Max events retained per flight for reconnect replay.
            queue_maxsize: Per-client queue capacity; a client whose queue fills
                is considered too slow and is disconnected.
        """
        self.max_buffer_size = max_buffer_size
        self.queue_maxsize = queue_maxsize
        # flight_id -> client_id -> connection record / per-client event queue
        self._connections: Dict[str, Dict[str, StreamConnection]] = {}
        self._client_queues: Dict[str, Dict[str, asyncio.Queue]] = {}
        # flight_id -> bounded replay buffer (comment events are NOT buffered)
        self._event_buffers: Dict[str, List[SSEEvent]] = {}
        # flight_id -> monotonically increasing event ID counter
        self._event_counters: Dict[str, int] = {}

    def _extract_data(self, model: Any) -> dict:
        """Serialize a Pydantic model (v2 or v1) or dict to a JSON-ready dict."""
        if hasattr(model, "model_dump"):  # pydantic v2
            return model.model_dump(mode="json")
        elif hasattr(model, "dict"):  # pydantic v1
            return model.dict()
        elif isinstance(model, dict):
            return model
        return {"data": str(model)}

    def _broadcast(self, flight_id: str, event_type: str, data: dict) -> bool:
        """Core broadcasting logic: generates ID, buffers, and distributes to queues.

        Comment (heartbeat) events carry no ID and are not buffered, so they
        never evict replayable events or create gaps in the ID sequence.
        Slow clients (full queue) are disconnected instead of blocking.
        """
        if flight_id not in self._event_counters:
            self._event_counters[flight_id] = 0
            self._event_buffers[flight_id] = []

        # Heartbeats have special treatment (empty payload, SSE comment)
        if event_type == "comment":
            sse_event = SSEEvent(event="comment", id=None,
                                 data=json.dumps(data) if data else "")
        else:
            self._event_counters[flight_id] += 1
            sse_event = SSEEvent(event=event_type,
                                 id=str(self._event_counters[flight_id]),
                                 data=json.dumps(data))
            # Buffer standard events for Last-Event-ID replay, bounded in size.
            buffer = self._event_buffers[flight_id]
            buffer.append(sse_event)
            if len(buffer) > self.max_buffer_size:
                buffer.pop(0)

        # Distribute to active client queues (copy: close_stream mutates the dict)
        for client_id, q in list(self._client_queues.get(flight_id, {}).items()):
            try:
                q.put_nowait(sse_event)
            except asyncio.QueueFull:
                logger.warning(
                    "Slow client %s on flight %s. Closing connection.",
                    client_id, flight_id)
                self.close_stream(flight_id, client_id)
        return True

    async def create_stream(self, flight_id: str, client_id: str, last_event_id: Optional[str] = None, event_types: Optional[List[str]] = None) -> AsyncGenerator[dict, None]:
        """Creates an async generator yielding SSE dictionaries formatted for sse_starlette.

        If ``last_event_id`` is supplied, buffered events with a higher ID are
        replayed first (best-effort). ``event_types``, when given, filters
        named events; heartbeat comments always pass through.
        """
        conn = StreamConnection(
            stream_id=str(uuid.uuid4()),
            flight_id=flight_id,
            client_id=client_id,
            created_at=datetime.now(timezone.utc),  # utcnow() is deprecated/naive
            last_event_id=last_event_id,
        )
        # setdefault on each map independently: tolerates partial cleanup state.
        self._connections.setdefault(flight_id, {})[client_id] = conn
        q: asyncio.Queue = asyncio.Queue(maxsize=self.queue_maxsize)
        self._client_queues.setdefault(flight_id, {})[client_id] = q

        # Replay buffered events if the client is reconnecting
        if last_event_id and flight_id in self._event_buffers:
            try:
                last_id_int = int(last_event_id)
                for ev in self._event_buffers[flight_id]:
                    if ev.id and int(ev.id) > last_id_int:
                        if not event_types or ev.event in event_types:
                            q.put_nowait(ev)
            except (ValueError, asyncio.QueueFull):
                # Best-effort: a malformed ID or a full queue aborts the replay.
                pass

        try:
            while True:
                event = await q.get()
                if event is None:  # Sentinel value to cleanly close (see close_stream)
                    break
                if event_types and event.event not in event_types and event.event != "comment":
                    continue
                if event.event == "comment":
                    yield {"comment": event.data}
                else:
                    yield {"event": event.event, "id": event.id, "data": event.data}
        finally:
            self.close_stream(flight_id, client_id)

    def send_frame_result(self, flight_id: str, frame_result: Any) -> bool:
        """Broadcast a processed-frame result as a 'frame_processed' event."""
        data = self._extract_data(frame_result)
        return self._broadcast(flight_id, "frame_processed", data)

    def send_search_progress(self, flight_id: str, search_status: Any) -> bool:
        """Broadcast search progress as a 'search_expanded' event."""
        data = self._extract_data(search_status)
        return self._broadcast(flight_id, "search_expanded", data)

    def send_user_input_request(self, flight_id: str, request: Any) -> bool:
        """Broadcast a 'user_input_needed' event asking the client for input."""
        data = self._extract_data(request)
        return self._broadcast(flight_id, "user_input_needed", data)

    def send_refinement(self, flight_id: str, frame_id: int, updated_result: Any) -> bool:
        """Broadcast a 'frame_refined' event with ``refined: True`` added to the payload."""
        # Copy before annotating: _extract_data may return the caller's own dict,
        # and we must not mutate it.
        data = dict(self._extract_data(updated_result))
        data["refined"] = True
        return self._broadcast(flight_id, "frame_refined", data)

    def send_heartbeat(self, flight_id: str) -> bool:
        """Send a keep-alive SSE comment (not buffered, not ID'd)."""
        return self._broadcast(flight_id, "comment", {"msg": "heartbeat"})

    def send_generic_event(self, flight_id: str, event_type: str, data: Any) -> bool:
        """Broadcast an arbitrary event type with an arbitrary payload."""
        return self._broadcast(flight_id, event_type, self._extract_data(data))

    def close_stream(self, flight_id: str, client_id: str) -> bool:
        """Remove a client's connection and wake its generator so it exits.

        Returns True if the client was connected, False otherwise.
        """
        flight_conns = self._connections.get(flight_id)
        if not flight_conns or client_id not in flight_conns:
            return False
        del flight_conns[client_id]
        q = self._client_queues[flight_id].pop(client_id, None)
        if q is not None:
            try:
                # None is the close sentinel create_stream's loop breaks on;
                # without it a disconnected slow client's generator would await forever.
                q.put_nowait(None)
            except asyncio.QueueFull:
                pass  # queue is orphaned; generator is unreachable anyway
        # Drop empty per-flight maps so abandoned flights do not accumulate.
        if not flight_conns:
            del self._connections[flight_id]
            del self._client_queues[flight_id]
        return True

    def get_active_connections(self, flight_id: str) -> int:
        """Number of clients currently subscribed to ``flight_id``."""
        return len(self._connections.get(flight_id, {}))