Initial commit

This commit is contained in:
Denys Zaitsev
2026-04-03 23:25:54 +03:00
parent 531a1301d5
commit d7e1066c60
3843 changed files with 1554468 additions and 0 deletions
+193
View File
@@ -0,0 +1,193 @@
import asyncio
import json
import logging
import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any, AsyncGenerator

from pydantic import BaseModel
logger = logging.getLogger(__name__)
# --- Data Models ---
class StreamConnection(BaseModel):
stream_id: str
flight_id: str
client_id: str
created_at: datetime
last_event_id: Optional[str] = None
class SSEEvent(BaseModel):
event: str
id: Optional[str]
data: str
# --- Interface ---
class ISSEEventStreamer(ABC):
@abstractmethod
def create_stream(self, flight_id: str, client_id: str, last_event_id: Optional[str] = None, event_types: Optional[List[str]] = None) -> AsyncGenerator[dict, None]: pass
@abstractmethod
def send_frame_result(self, flight_id: str, frame_result: Any) -> bool: pass
@abstractmethod
def send_search_progress(self, flight_id: str, search_status: Any) -> bool: pass
@abstractmethod
def send_user_input_request(self, flight_id: str, request: Any) -> bool: pass
@abstractmethod
def send_refinement(self, flight_id: str, frame_id: int, updated_result: Any) -> bool: pass
@abstractmethod
def send_heartbeat(self, flight_id: str) -> bool: pass
@abstractmethod
def send_generic_event(self, flight_id: str, event_type: str, data: Any) -> bool: pass
@abstractmethod
def close_stream(self, flight_id: str, client_id: str) -> bool: pass
@abstractmethod
def get_active_connections(self, flight_id: str) -> int: pass
# --- Implementation ---
class SSEEventStreamer(ISSEEventStreamer):
"""
F15: SSE Event Streamer
Manages real-time asynchronous data broadcasting to connected clients.
Supports event buffering, replaying on reconnection, and filtering.
"""
def __init__(self, max_buffer_size: int = 1000, queue_maxsize: int = 100):
self.max_buffer_size = max_buffer_size
self.queue_maxsize = queue_maxsize
# flight_id -> client_id -> connection/queue
self._connections: Dict[str, Dict[str, StreamConnection]] = {}
self._client_queues: Dict[str, Dict[str, asyncio.Queue]] = {}
# flight_id -> historical events buffer
self._event_buffers: Dict[str, List[SSEEvent]] = {}
self._event_counters: Dict[str, int] = {}
def _extract_data(self, model: Any) -> dict:
"""Helper to serialize incoming Pydantic models or dicts to JSON-ready dicts."""
if hasattr(model, "model_dump"):
return model.model_dump(mode="json")
elif hasattr(model, "dict"):
return model.dict()
elif isinstance(model, dict):
return model
return {"data": str(model)}
def _broadcast(self, flight_id: str, event_type: str, data: dict) -> bool:
"""Core broadcasting logic: generates ID, buffers, and distributes to queues."""
if flight_id not in self._event_counters:
self._event_counters[flight_id] = 0
self._event_buffers[flight_id] = []
self._event_counters[flight_id] += 1
event_id = str(self._event_counters[flight_id])
# Heartbeats have special treatment (empty payload, SSE comment)
if event_type == "comment":
sse_event = SSEEvent(event="comment", id=None, data=json.dumps(data) if data else "")
else:
sse_event = SSEEvent(event=event_type, id=event_id, data=json.dumps(data))
# Buffer standard events
self._event_buffers[flight_id].append(sse_event)
if len(self._event_buffers[flight_id]) > self.max_buffer_size:
self._event_buffers[flight_id].pop(0)
# Distribute to active client queues
if flight_id in self._client_queues:
for client_id, q in list(self._client_queues[flight_id].items()):
try:
q.put_nowait(sse_event)
except asyncio.QueueFull:
logger.warning(f"Slow client {client_id} on flight {flight_id}. Closing connection.")
self.close_stream(flight_id, client_id)
return True
async def create_stream(self, flight_id: str, client_id: str, last_event_id: Optional[str] = None, event_types: Optional[List[str]] = None) -> AsyncGenerator[dict, None]:
"""Creates an async generator yielding SSE dictionaries formatted for sse_starlette."""
stream_id = str(uuid.uuid4())
conn = StreamConnection(stream_id=stream_id, flight_id=flight_id, client_id=client_id, created_at=datetime.utcnow(), last_event_id=last_event_id)
if flight_id not in self._connections:
self._connections[flight_id] = {}
self._client_queues[flight_id] = {}
self._connections[flight_id][client_id] = conn
q: asyncio.Queue = asyncio.Queue(maxsize=self.queue_maxsize)
self._client_queues[flight_id][client_id] = q
# Replay buffered events if the client is reconnecting
if last_event_id and flight_id in self._event_buffers:
try:
last_id_int = int(last_event_id)
for ev in self._event_buffers[flight_id]:
if ev.id and int(ev.id) > last_id_int:
if not event_types or ev.event in event_types:
q.put_nowait(ev)
except (ValueError, asyncio.QueueFull):
pass
try:
while True:
event = await q.get()
if event is None: # Sentinel value to cleanly close
break
if event_types and event.event not in event_types and event.event != "comment":
continue
if event.event == "comment":
yield {"comment": event.data}
else:
yield {
"event": event.event,
"id": event.id,
"data": event.data
}
finally:
self.close_stream(flight_id, client_id)
def send_frame_result(self, flight_id: str, frame_result: Any) -> bool:
data = self._extract_data(frame_result)
return self._broadcast(flight_id, "frame_processed", data)
def send_search_progress(self, flight_id: str, search_status: Any) -> bool:
data = self._extract_data(search_status)
return self._broadcast(flight_id, "search_expanded", data)
def send_user_input_request(self, flight_id: str, request: Any) -> bool:
data = self._extract_data(request)
return self._broadcast(flight_id, "user_input_needed", data)
def send_refinement(self, flight_id: str, frame_id: int, updated_result: Any) -> bool:
data = self._extract_data(updated_result)
# Match specific structure typically requested
data["refined"] = True
return self._broadcast(flight_id, "frame_refined", data)
def send_heartbeat(self, flight_id: str) -> bool:
return self._broadcast(flight_id, "comment", {"msg": "heartbeat"})
def send_generic_event(self, flight_id: str, event_type: str, data: Any) -> bool:
return self._broadcast(flight_id, event_type, self._extract_data(data))
def close_stream(self, flight_id: str, client_id: str) -> bool:
if flight_id in self._connections and client_id in self._connections[flight_id]:
del self._connections[flight_id][client_id]
del self._client_queues[flight_id][client_id]
return True
return False
def get_active_connections(self, flight_id: str) -> int:
return len(self._connections.get(flight_id, {}))