[AZ-185][AZ-186] Batch 2

Made-with: Cursor
This commit is contained in:
Oleksandr Bezdieniezhnykh
2026-04-15 07:32:37 +03:00
parent d244799f02
commit 9a0248af72
18 changed files with 1857 additions and 26 deletions
+9
View File
@@ -11,6 +11,15 @@ from security_provider import create_security_provider
# FastAPI application entry point for the loader service.
app = FastAPI(title="Azaion.Loader")
# NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
# favor of lifespan handlers — confirm the pinned FastAPI version before migrating.
@app.on_event("startup")
def _startup_update_manager():
    """Kick off the background update loop once the app starts.

    Best-effort: if the update_manager module cannot be imported the service
    still boots; the failure is deliberately swallowed.
    """
    try:
        from update_manager import maybe_start_update_background
    except Exception:
        return
    # get_api_client / RESOURCE_API_URL are module-level names resolved when
    # the startup hook runs, so their definition order below does not matter.
    maybe_start_update_background(get_api_client, RESOURCE_API_URL)
security_provider = create_security_provider()
# Defaults to the production API; override via environment for staging/dev.
RESOURCE_API_URL = os.environ.get("RESOURCE_API_URL", "https://api.azaion.com")
+266
View File
@@ -0,0 +1,266 @@
import hashlib
import json
import os
import subprocess
import threading
from typing import Any, Callable, Dict, List, Optional
import requests
from loguru import logger
from download_manager import ResumableDownloadManager
from version_collector import VersionCollector
def _aes_key_from_encryption_field(encryption_key: Any) -> bytes:
if isinstance(encryption_key, bytes):
if len(encryption_key) == 32:
return encryption_key
raise ValueError("invalid encryption key")
s = str(encryption_key).strip()
if len(s) == 64 and all(c in "0123456789abcdefABCDEF" for c in s):
return bytes.fromhex(s)
return hashlib.sha256(s.encode("utf-8")).digest()
def _sort_services_loader_last(services: List[str]) -> List[str]:
head = sorted(s for s in services if s != "loader")
tail = [s for s in services if s == "loader"]
return head + tail
def _sort_updates_loader_last(updates: List[dict]) -> List[dict]:
rest = [u for u in updates if u.get("resourceName") != "loader"]
rest.sort(key=lambda u: str(u.get("resourceName", "")))
loader = [u for u in updates if u.get("resourceName") == "loader"]
return rest + loader
class UpdateManager:
    """Periodically fetches updates from the resource API and applies them.

    Two artifact kinds are handled: "detection_model" (a .trt file dropped
    into the model directory) and docker images (loaded via `docker load`,
    then restarted through docker compose). Interrupted compose restarts
    survive crashes via a small JSON state file.

    All external effects (subprocess, HTTP POST/HEAD, waiting) are injectable
    keyword-only callables so the manager can be driven in tests.
    """

    def __init__(
        self,
        api_url: str,
        get_token: Callable[[], Optional[str]],
        download_manager: ResumableDownloadManager,
        version_collector: VersionCollector,
        compose_file: str,
        model_dir: str,
        state_path: str,
        interval_seconds: float = 300.0,
        *,
        subprocess_run: Optional[Callable] = None,
        post_get_update: Optional[Callable[..., Any]] = None,
        head_content_length: Optional[Callable[..., int]] = None,
        stop_event: Optional[threading.Event] = None,
        wait_fn: Optional[Callable[[float], bool]] = None,
    ) -> None:
        self._api_url = api_url.rstrip("/")
        self._get_token = get_token
        self._download_manager = download_manager
        self._version_collector = version_collector
        self._compose_file = compose_file
        self._model_dir = model_dir
        self._state_path = state_path
        self._interval = interval_seconds
        self._subprocess_run = subprocess_run or subprocess.run
        self._post_get_update = post_get_update or self._default_post_get_update
        self._head_content_length = head_content_length or self._default_head_content_length
        self._stop_event = stop_event or threading.Event()
        self._wait_fn = wait_fn

    def _default_post_get_update(self, token: str, body: dict) -> Any:
        """POST the current-version payload to /get-update; return the JSON reply."""
        url = f"{self._api_url}/get-update"
        resp = requests.post(
            url,
            json=body,
            headers={"Authorization": f"Bearer {token}"},
            timeout=120,
        )
        resp.raise_for_status()
        return resp.json()

    def _default_head_content_length(self, url: str, token: str) -> int:
        """HEAD the artifact URL and return its Content-Length in bytes.

        Raises ValueError when the server does not report a length, since the
        download manager needs the exact size up front.
        """
        headers = {}
        if token:
            headers["Authorization"] = f"Bearer {token}"
        resp = requests.head(url, headers=headers, allow_redirects=True, timeout=120)
        resp.raise_for_status()
        cl = resp.headers.get("Content-Length")
        if not cl:
            raise ValueError("missing Content-Length")
        return int(cl)

    def _load_state(self) -> dict:
        """Load orchestrator state; a missing or corrupt file yields a fresh default.

        Recovering from a corrupt state file (instead of raising) keeps one
        bad write from permanently failing every subsequent tick.
        """
        if not os.path.isfile(self._state_path):
            return {"pending_compose": []}
        try:
            with open(self._state_path, encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # json.JSONDecodeError is a ValueError subclass.
            return {"pending_compose": []}
        if "pending_compose" not in data:
            data["pending_compose"] = []
        return data

    def _save_state(self, data: dict) -> None:
        """Atomically persist state via write-to-temp + os.replace."""
        directory = os.path.dirname(self._state_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp = self._state_path + ".tmp"
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, sort_keys=True)
        os.replace(tmp, self._state_path)

    def _drain_pending_compose(self) -> None:
        """Re-run `docker compose up -d` for services whose restart was interrupted."""
        state = self._load_state()
        # dict.fromkeys dedupes while preserving insertion order.
        pending = list(dict.fromkeys(state.get("pending_compose") or []))
        if not pending:
            return
        for svc in _sort_services_loader_last(pending):
            self._subprocess_run(
                ["docker", "compose", "-f", self._compose_file, "up", "-d", svc],
                check=True,
            )
        state["pending_compose"] = []
        self._save_state(state)

    def _current_versions_payload(self) -> Dict[str, str]:
        """Map resource name -> installed version for the /get-update request."""
        rows = self._version_collector.collect()
        return {r.resource_name: r.version for r in rows}

    def _build_get_update_body(self) -> dict:
        """Assemble the /get-update request body from env config + local versions."""
        return {
            "dev_stage": os.environ.get("LOADER_DEV_STAGE", ""),
            "architecture": os.environ.get("LOADER_ARCH", ""),
            "current_versions": self._current_versions_payload(),
        }

    def _artifact_size(self, url: str, token: str) -> int:
        """Resolve the artifact's byte size (via HEAD Content-Length by default)."""
        return self._head_content_length(url, token)

    def _apply_model(self, item: dict, token: str) -> None:
        """Download, decrypt and verify a detection-model artifact into model_dir."""
        name = str(item["resourceName"])
        version = str(item["version"])
        url = str(item["cdnUrl"])
        sha256 = str(item["sha256"])
        key = _aes_key_from_encryption_field(item["encryptionKey"])
        size = self._artifact_size(url, token)
        job_id = f"update-{name}-{version}"
        os.makedirs(self._model_dir, exist_ok=True)
        # Version is expected to be a date; the version collector scans for
        # exactly this azaion-<version>.trt naming pattern.
        out_path = os.path.join(self._model_dir, f"azaion-{version}.trt")
        self._download_manager.fetch_decrypt_verify(
            job_id,
            url,
            sha256,
            size,
            key,
            out_path,
        )
        self._version_collector.invalidate()

    def _mark_pending_compose(self, service: str) -> None:
        """Record that `service` still needs a compose restart (idempotent)."""
        state = self._load_state()
        pending = list(state.get("pending_compose") or [])
        if service not in pending:
            pending.append(service)
        state["pending_compose"] = pending
        self._save_state(state)

    def _clear_pending_compose(self, service: str) -> None:
        """Remove `service` from the pending-restart list."""
        state = self._load_state()
        pending = [s for s in (state.get("pending_compose") or []) if s != service]
        state["pending_compose"] = pending
        self._save_state(state)

    def _apply_docker_image(self, item: dict, token: str) -> None:
        """Download + decrypt a docker image tar, load it, and restart the service.

        The service is marked pending before `compose up` so an interrupted
        restart is retried on the next tick via _drain_pending_compose.
        """
        name = str(item["resourceName"])
        version = str(item["version"])
        url = str(item["cdnUrl"])
        sha256 = str(item["sha256"])
        key = _aes_key_from_encryption_field(item["encryptionKey"])
        size = self._artifact_size(url, token)
        job_id = f"update-{name}-{version}"
        artifact_dir = os.path.dirname(self._state_path)
        os.makedirs(artifact_dir, exist_ok=True)
        out_tar = os.path.join(artifact_dir, f"{job_id}.plaintext.tar")
        self._download_manager.fetch_decrypt_verify(
            job_id,
            url,
            sha256,
            size,
            key,
            out_tar,
        )
        try:
            self._subprocess_run(["docker", "load", "-i", out_tar], check=True)
        finally:
            # The tar holds decrypted plaintext; never leave it on disk once
            # docker has consumed (or failed to consume) it.
            try:
                os.remove(out_tar)
            except OSError:
                pass
        self._version_collector.invalidate()
        self._mark_pending_compose(name)
        self._subprocess_run(
            ["docker", "compose", "-f", self._compose_file, "up", "-d", name],
            check=True,
        )
        self._clear_pending_compose(name)

    def _tick_once(self) -> None:
        """One poll cycle: drain pending restarts, fetch updates, apply each one."""
        token = self._get_token()
        if not token:
            # Not authenticated yet; try again next interval.
            return
        self._drain_pending_compose()
        body = self._build_get_update_body()
        updates = self._post_get_update(token, body)
        if not isinstance(updates, list):
            return
        # The loader updates itself last so it does not kill its own run.
        for item in _sort_updates_loader_last(updates):
            rname = str(item.get("resourceName", ""))
            if rname == "detection_model":
                self._apply_model(item, token)
            else:
                self._apply_docker_image(item, token)

    def run_forever(self) -> None:
        """Loop until the stop event fires; tick failures are logged, never fatal."""
        while not self._stop_event.is_set():
            try:
                self._drain_pending_compose()
                self._tick_once()
            except Exception as exc:
                logger.exception("update manager tick failed: {}", exc)
            # wait_fn (test hook) or the stop event returns True on shutdown.
            if self._wait_fn is not None:
                if self._wait_fn(self._interval):
                    break
            elif self._stop_event.wait(self._interval):
                break
def maybe_start_update_background(
    get_api_client: Callable[[], Any],
    api_url: str,
) -> None:
    """Start the update loop in a daemon thread when the feature is configured.

    A missing LOADER_DOWNLOAD_STATE_DIR disables updates entirely; any failure
    while wiring up the manager is logged and the service keeps running.
    """
    state_dir = os.environ.get("LOADER_DOWNLOAD_STATE_DIR")
    if not state_dir:
        return

    env = os.environ.get
    model_dir = env("LOADER_MODEL_DIR", "models")
    compose_file = env("LOADER_COMPOSE_FILE", "docker-compose.yml")
    interval = float(env("LOADER_UPDATE_INTERVAL_SEC", "300"))
    default_state_path = os.path.join(state_dir, "update_orchestrator.json")
    orchestrator_path = env("LOADER_UPDATE_STATE_PATH", default_state_path)

    def token_getter() -> Optional[str]:
        # The API client carries the current bearer token, if any.
        return getattr(get_api_client(), "token", None)

    try:
        manager = UpdateManager(
            api_url,
            token_getter,
            ResumableDownloadManager(state_dir),
            VersionCollector(model_dir),
            compose_file,
            model_dir,
            orchestrator_path,
            interval_seconds=interval,
        )
    except Exception as exc:
        logger.exception("update manager failed to start: {}", exc)
        return
    worker = threading.Thread(target=manager.run_forever, name="loader-updates", daemon=True)
    worker.start()
+91
View File
@@ -0,0 +1,91 @@
import os
import re
import subprocess
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional
# Matches model artifacts named like "azaion-2026-04-15.trt"; group(1)
# captures the ISO date that serves as the model's version string.
TRT_DATE_PATTERN = re.compile(r"^azaion-(\d{4}-\d{2}-\d{2})\.trt$", re.IGNORECASE)
@dataclass(frozen=True)
class ResourceVersion:
    # Logical resource identifier, e.g. "detection_model" or a docker image name.
    resource_name: str
    # Version string: an ISO date for models, an image tag for docker images.
    version: str
class VersionCollector:
    """Discovers installed resource versions: TRT model files and azaion docker images.

    Results are cached until invalidate() is called; callers always receive a
    fresh list copy so the cached snapshot cannot be mutated from outside.
    """

    def __init__(
        self,
        model_dir: str,
        *,
        subprocess_run: Optional[Callable] = None,
    ) -> None:
        self._model_dir = model_dir
        # Injectable for tests; defaults to the real subprocess.run.
        self._subprocess_run = subprocess_run or subprocess.run
        self._cache: Optional[List[ResourceVersion]] = None

    def invalidate(self) -> None:
        """Drop the cached snapshot; the next collect() re-scans."""
        self._cache = None

    def collect(self) -> List[ResourceVersion]:
        """Return current resource versions, recomputing only when the cache is stale."""
        if self._cache is None:
            self._cache = self._collect_uncached()
        return list(self._cache)

    def collect_as_dicts(self) -> List[dict]:
        """Same as collect(), serialized to plain dicts for JSON payloads."""
        return [asdict(row) for row in self.collect()]

    def _collect_uncached(self) -> List[ResourceVersion]:
        """Scan disk and docker; detection_model first, the rest sorted by name."""
        model_version = self._best_trt_version()
        # Any docker-derived row claiming the model's name is dropped, so
        # "detection_model" can only ever come from the .trt scan.
        others = sorted(
            (row for row in self._docker_versions() if row.resource_name != "detection_model"),
            key=lambda row: row.resource_name,
        )
        if model_version is None:
            return others
        return [ResourceVersion("detection_model", model_version)] + others

    def _best_trt_version(self) -> Optional[str]:
        """Return the newest model date found in model_dir, or None if absent."""
        if not os.path.isdir(self._model_dir):
            return None
        matches = (TRT_DATE_PATTERN.match(name) for name in os.listdir(self._model_dir))
        dates = [m.group(1) for m in matches if m]
        # ISO dates compare correctly as strings, so max() picks the newest.
        return max(dates, default=None)

    def _docker_versions(self) -> List[ResourceVersion]:
        """List azaion/* docker images as (name, tag) rows; [] when docker is unusable."""
        try:
            proc = self._subprocess_run(
                ["docker", "images", "--format", "{{.Repository}}:{{.Tag}}"],
                capture_output=True,
                text=True,
                check=True,
            )
        except (OSError, subprocess.CalledProcessError):
            # Docker missing or failing simply means "no docker-managed resources".
            return []
        rows: List[ResourceVersion] = []
        for raw in proc.stdout.splitlines():
            entry = raw.strip()
            if not entry or ":<none>" in entry or ":" not in entry:
                continue
            if not entry.startswith("azaion/"):
                continue
            repo, _, tag = entry.rpartition(":")
            if not tag or tag == "<none>":
                continue
            pieces = repo.split("/", 1)
            if len(pieces) == 2:
                rows.append(ResourceVersion(pieces[1], tag))
        return rows