mirror of
https://github.com/azaion/loader.git
synced 2026-04-22 12:26:32 +00:00
[AZ-185][AZ-186] Batch 2
Made-with: Cursor
This commit is contained in:
@@ -11,6 +11,15 @@ from security_provider import create_security_provider
|
||||
|
||||
# FastAPI application instance for the loader service.
app = FastAPI(title="Azaion.Loader")
@app.on_event("startup")
def _startup_update_manager():
    """Best-effort launch of the background update loop at app startup."""
    try:
        from update_manager import maybe_start_update_background as _start
    except Exception:
        # The update manager is an optional component; skip silently
        # when it cannot be imported.
        return
    _start(get_api_client, RESOURCE_API_URL)
# Process-wide security provider instance used by the API.
security_provider = create_security_provider()


# Base URL of the resource API; overridable via environment variable.
RESOURCE_API_URL = os.environ.get("RESOURCE_API_URL", "https://api.azaion.com")
@@ -0,0 +1,266 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from download_manager import ResumableDownloadManager
|
||||
from version_collector import VersionCollector
|
||||
|
||||
|
||||
def _aes_key_from_encryption_field(encryption_key: Any) -> bytes:
|
||||
if isinstance(encryption_key, bytes):
|
||||
if len(encryption_key) == 32:
|
||||
return encryption_key
|
||||
raise ValueError("invalid encryption key")
|
||||
s = str(encryption_key).strip()
|
||||
if len(s) == 64 and all(c in "0123456789abcdefABCDEF" for c in s):
|
||||
return bytes.fromhex(s)
|
||||
return hashlib.sha256(s.encode("utf-8")).digest()
|
||||
|
||||
|
||||
def _sort_services_loader_last(services: List[str]) -> List[str]:
|
||||
head = sorted(s for s in services if s != "loader")
|
||||
tail = [s for s in services if s == "loader"]
|
||||
return head + tail
|
||||
|
||||
|
||||
def _sort_updates_loader_last(updates: List[dict]) -> List[dict]:
|
||||
rest = [u for u in updates if u.get("resourceName") != "loader"]
|
||||
rest.sort(key=lambda u: str(u.get("resourceName", "")))
|
||||
loader = [u for u in updates if u.get("resourceName") == "loader"]
|
||||
return rest + loader
|
||||
|
||||
|
||||
class UpdateManager:
    """Polls the resource API for available updates and applies them.

    Two resource kinds are handled: the detection model (decrypted into
    ``model_dir`` as a ``.trt`` file) and docker images (loaded with
    ``docker load`` and restarted via ``docker compose up -d``).
    Services whose restart has not completed are persisted in a JSON
    state file under the "pending_compose" key and retried on the next
    tick, so an interrupted update is eventually finished.
    """

    def __init__(
        self,
        api_url: str,
        get_token: Callable[[], Optional[str]],
        download_manager: ResumableDownloadManager,
        version_collector: VersionCollector,
        compose_file: str,
        model_dir: str,
        state_path: str,
        interval_seconds: float = 300.0,
        *,
        subprocess_run: Optional[Callable] = None,
        post_get_update: Optional[Callable[..., Any]] = None,
        head_content_length: Optional[Callable[..., int]] = None,
        stop_event: Optional[threading.Event] = None,
        wait_fn: Optional[Callable[[float], bool]] = None,
    ) -> None:
        # The keyword-only callables are injection points for tests; each
        # falls back to the real implementation defined below.
        self._api_url = api_url.rstrip("/")
        self._get_token = get_token
        self._download_manager = download_manager
        self._version_collector = version_collector
        self._compose_file = compose_file
        self._model_dir = model_dir
        self._state_path = state_path
        self._interval = interval_seconds
        self._subprocess_run = subprocess_run or subprocess.run
        self._post_get_update = post_get_update or self._default_post_get_update
        self._head_content_length = head_content_length or self._default_head_content_length
        self._stop_event = stop_event or threading.Event()
        self._wait_fn = wait_fn

    def _default_post_get_update(self, token: str, body: dict) -> Any:
        """POST the version payload to /get-update and return parsed JSON.

        Raises requests.HTTPError for non-2xx responses.
        """
        url = f"{self._api_url}/get-update"
        resp = requests.post(
            url,
            json=body,
            headers={"Authorization": f"Bearer {token}"},
            timeout=120,
        )
        resp.raise_for_status()
        return resp.json()

    def _default_head_content_length(self, url: str, token: str) -> int:
        """HEAD the artifact URL and return its Content-Length in bytes.

        Raises ValueError when the header is absent and requests.HTTPError
        for non-2xx responses.
        """
        headers = {}
        if token:
            headers["Authorization"] = f"Bearer {token}"
        resp = requests.head(url, headers=headers, allow_redirects=True, timeout=120)
        resp.raise_for_status()
        cl = resp.headers.get("Content-Length")
        if not cl:
            raise ValueError("missing Content-Length")
        return int(cl)

    def _load_state(self) -> dict:
        """Read persisted orchestrator state; always has a pending_compose list."""
        if not os.path.isfile(self._state_path):
            return {"pending_compose": []}
        with open(self._state_path, encoding="utf-8") as f:
            data = json.load(f)
        if "pending_compose" not in data:
            data["pending_compose"] = []
        return data

    def _save_state(self, data: dict) -> None:
        """Atomically persist state: write a temp file, then os.replace it."""
        directory = os.path.dirname(self._state_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp = self._state_path + ".tmp"
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, sort_keys=True)
        os.replace(tmp, self._state_path)

    def _drain_pending_compose(self) -> None:
        """Re-run `docker compose up -d` for services left pending.

        'loader' is restarted last so this process is not replaced before
        the other services come up. A compose failure propagates and
        leaves the pending list intact for the next attempt.
        """
        state = self._load_state()
        # dict.fromkeys de-duplicates while preserving insertion order.
        pending = list(dict.fromkeys(state.get("pending_compose") or []))
        if not pending:
            return
        for svc in _sort_services_loader_last(pending):
            self._subprocess_run(
                ["docker", "compose", "-f", self._compose_file, "up", "-d", svc],
                check=True,
            )
        state["pending_compose"] = []
        self._save_state(state)

    def _current_versions_payload(self) -> Dict[str, str]:
        """Map resource name -> installed version, from the collector."""
        rows = self._version_collector.collect()
        return {r.resource_name: r.version for r in rows}

    def _build_get_update_body(self) -> dict:
        """Request body for /get-update: stage, architecture, installed versions."""
        return {
            "dev_stage": os.environ.get("LOADER_DEV_STAGE", ""),
            "architecture": os.environ.get("LOADER_ARCH", ""),
            "current_versions": self._current_versions_payload(),
        }

    def _artifact_size(self, url: str, token: str) -> int:
        """Expected artifact size in bytes (the resumable download needs it)."""
        return self._head_content_length(url, token)

    def _apply_model(self, item: dict, token: str) -> None:
        """Download, decrypt and verify a detection-model update.

        The artifact lands at model_dir/azaion-<version>.trt; the version
        cache is invalidated afterwards so the new model is reported.
        Raises KeyError if a required field is missing from `item`.
        """
        name = str(item["resourceName"])
        version = str(item["version"])
        url = str(item["cdnUrl"])
        sha256 = str(item["sha256"])
        key = _aes_key_from_encryption_field(item["encryptionKey"])
        size = self._artifact_size(url, token)
        job_id = f"update-{name}-{version}"
        os.makedirs(self._model_dir, exist_ok=True)
        out_path = os.path.join(self._model_dir, f"azaion-{version}.trt")
        self._download_manager.fetch_decrypt_verify(
            job_id,
            url,
            sha256,
            size,
            key,
            out_path,
        )
        self._version_collector.invalidate()

    def _mark_pending_compose(self, service: str) -> None:
        """Record that `service` still needs a compose restart."""
        state = self._load_state()
        pending = list(state.get("pending_compose") or [])
        if service not in pending:
            pending.append(service)
        state["pending_compose"] = pending
        self._save_state(state)

    def _clear_pending_compose(self, service: str) -> None:
        """Remove `service` from the pending-restart list."""
        state = self._load_state()
        pending = [s for s in (state.get("pending_compose") or []) if s != service]
        state["pending_compose"] = pending
        self._save_state(state)

    def _apply_docker_image(self, item: dict, token: str) -> None:
        """Download/decrypt an image tar, docker-load it, restart its service.

        The service is marked pending BEFORE `compose up` and cleared only
        after it succeeds, so a crash mid-restart is retried on the next
        tick by _drain_pending_compose.
        """
        name = str(item["resourceName"])
        version = str(item["version"])
        url = str(item["cdnUrl"])
        sha256 = str(item["sha256"])
        key = _aes_key_from_encryption_field(item["encryptionKey"])
        size = self._artifact_size(url, token)
        job_id = f"update-{name}-{version}"
        # The decrypted tar is staged next to the state file.
        artifact_dir = os.path.dirname(self._state_path)
        os.makedirs(artifact_dir, exist_ok=True)
        out_tar = os.path.join(artifact_dir, f"{job_id}.plaintext.tar")
        self._download_manager.fetch_decrypt_verify(
            job_id,
            url,
            sha256,
            size,
            key,
            out_tar,
        )
        self._subprocess_run(["docker", "load", "-i", out_tar], check=True)
        self._version_collector.invalidate()
        self._mark_pending_compose(name)
        self._subprocess_run(
            ["docker", "compose", "-f", self._compose_file, "up", "-d", name],
            check=True,
        )
        self._clear_pending_compose(name)

    def _tick_once(self) -> None:
        """One polling cycle: drain leftovers, ask for updates, apply them.

        Silently does nothing when no token is available or when the API
        response is not a list. 'loader' updates are applied last.
        """
        token = self._get_token()
        if not token:
            return
        self._drain_pending_compose()
        body = self._build_get_update_body()
        updates = self._post_get_update(token, body)
        if not isinstance(updates, list):
            return
        for item in _sort_updates_loader_last(updates):
            rname = str(item.get("resourceName", ""))
            if rname == "detection_model":
                self._apply_model(item, token)
            else:
                self._apply_docker_image(item, token)

    def run_forever(self) -> None:
        """Loop until the stop event is set; tick errors are logged, never fatal.

        wait_fn (when injected) replaces stop_event.wait between ticks;
        a truthy return from either ends the loop immediately.
        """
        while not self._stop_event.is_set():
            try:
                self._drain_pending_compose()
                self._tick_once()
            except Exception as exc:
                logger.exception("update manager tick failed: {}", exc)
            if self._wait_fn is not None:
                if self._wait_fn(self._interval):
                    break
            elif self._stop_event.wait(self._interval):
                break
||||
|
||||
|
||||
def maybe_start_update_background(
    get_api_client: Callable[[], Any],
    api_url: str,
) -> None:
    """Start the update loop in a daemon thread when configured via env.

    No-op when LOADER_DOWNLOAD_STATE_DIR is unset. Any failure while
    reading configuration or constructing the components is logged and
    swallowed so this function can never break application startup.

    Args:
        get_api_client: Zero-arg callable returning the current API client;
            its ``token`` attribute (if any) supplies the auth token.
        api_url: Base URL of the resource API.
    """
    state_dir = os.environ.get("LOADER_DOWNLOAD_STATE_DIR")
    if not state_dir:
        return

    def token_getter() -> Optional[str]:
        client = get_api_client()
        return getattr(client, "token", None)

    try:
        model_dir = os.environ.get("LOADER_MODEL_DIR", "models")
        compose_file = os.environ.get("LOADER_COMPOSE_FILE", "docker-compose.yml")
        # Parsed inside the try: a malformed LOADER_UPDATE_INTERVAL_SEC
        # previously raised ValueError out of this function and crashed
        # app startup; now it is logged and the updater is skipped.
        interval = float(os.environ.get("LOADER_UPDATE_INTERVAL_SEC", "300"))
        orchestrator_path = os.environ.get(
            "LOADER_UPDATE_STATE_PATH",
            os.path.join(state_dir, "update_orchestrator.json"),
        )
        dm = ResumableDownloadManager(state_dir)
        vc = VersionCollector(model_dir)
        um = UpdateManager(
            api_url,
            token_getter,
            dm,
            vc,
            compose_file,
            model_dir,
            orchestrator_path,
            interval_seconds=interval,
        )
    except Exception as exc:
        logger.exception("update manager failed to start: {}", exc)
        return

    threading.Thread(target=um.run_forever, name="loader-updates", daemon=True).start()
||||
@@ -0,0 +1,91 @@
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
# Matches model files named "azaion-YYYY-MM-DD.trt" (case-insensitive);
# group(1) captures the date portion used as the model version.
TRT_DATE_PATTERN = re.compile(r"^azaion-(\d{4}-\d{2}-\d{2})\.trt$", re.IGNORECASE)
||||
|
||||
|
||||
@dataclass(frozen=True)
class ResourceVersion:
    """Immutable (resource name, version) pair reported to the update API."""

    # e.g. "detection_model" or a docker service name parsed from the image repo
    resource_name: str
    # date string for the model, image tag for docker resources
    version: str
||||
|
||||
|
||||
class VersionCollector:
    """Collects installed resource versions from disk and docker.

    Results are cached until invalidate() is called; callers always get
    a fresh list object so the cache cannot be mutated externally.
    """

    def __init__(
        self,
        model_dir: str,
        *,
        subprocess_run: Optional[Callable] = None,
    ) -> None:
        # subprocess_run is injectable so tests can fake `docker images`.
        self._model_dir = model_dir
        self._subprocess_run = subprocess_run or subprocess.run
        self._cache: Optional[List[ResourceVersion]] = None

    def invalidate(self) -> None:
        """Drop the cached result; the next collect() re-scans."""
        self._cache = None

    def collect(self) -> List[ResourceVersion]:
        """Return current versions, using the cache when it is warm."""
        if self._cache is None:
            self._cache = self._collect_uncached()
        return list(self._cache)

    def collect_as_dicts(self) -> List[dict]:
        """Same as collect() but as plain dicts (handy for JSON payloads)."""
        return [asdict(row) for row in self.collect()]

    def _collect_uncached(self) -> List[ResourceVersion]:
        """Scan model dir and docker: detection_model first, rest name-sorted."""
        model_version = self._best_trt_version()
        docker_rows = [
            row
            for row in self._docker_versions()
            if row.resource_name != "detection_model"
        ]
        docker_rows.sort(key=lambda row: row.resource_name)
        if model_version is None:
            return docker_rows
        return [ResourceVersion("detection_model", model_version)] + docker_rows

    def _best_trt_version(self) -> Optional[str]:
        """Latest date among azaion-YYYY-MM-DD.trt files, or None.

        ISO dates compare correctly as strings, so lexicographic max is
        the newest file.
        """
        if not os.path.isdir(self._model_dir):
            return None
        dates = [
            match.group(1)
            for match in map(TRT_DATE_PATTERN.match, os.listdir(self._model_dir))
            if match
        ]
        return max(dates) if dates else None

    def _docker_versions(self) -> List[ResourceVersion]:
        """Parse `docker images` output for azaion/* repos with real tags.

        Returns [] when docker is unavailable or the command fails.
        """
        try:
            proc = self._subprocess_run(
                ["docker", "images", "--format", "{{.Repository}}:{{.Tag}}"],
                capture_output=True,
                text=True,
                check=True,
            )
        except (OSError, subprocess.CalledProcessError):
            return []
        rows: List[ResourceVersion] = []
        for raw in proc.stdout.splitlines():
            entry = raw.strip()
            # Skip blanks, untagged images, and lines without a tag separator.
            if not entry or ":<none>" in entry or ":" not in entry:
                continue
            if not entry.startswith("azaion/"):
                continue
            repo, tag = entry.rsplit(":", 1)
            if not tag or tag == "<none>":
                continue
            pieces = repo.split("/", 1)
            if len(pieces) >= 2:
                rows.append(ResourceVersion(pieces[1], tag))
        return rows
||||
Reference in New Issue
Block a user