Files
loader/main.py
T
Oleksandr Bezdieniezhnykh 4eaf218f09 Quality cleanup refactoring
Made-with: Cursor
2026-04-13 06:21:26 +03:00

199 lines
5.1 KiB
Python

import os
import threading
from typing import Optional
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
from fastapi.responses import Response
from pydantic import BaseModel
from unlock_state import UnlockState
app = FastAPI(title="Azaion.Loader")

# Runtime configuration; each value can be overridden via the environment.
# Base URL handed to ApiClient and to download_key_fragment.
RESOURCE_API_URL = os.getenv("RESOURCE_API_URL", "https://api.azaion.com")
# Location of the encrypted image archive consumed by the unlock flow.
IMAGES_PATH = os.getenv("IMAGES_PATH", "/opt/azaion/images.enc")
# Version tag passed to check_images_loaded.
API_VERSION = os.getenv("API_VERSION", "latest")

# Process-wide ApiClient, created on first use by get_api_client().
_api_client_lock = threading.Lock()
_api_client = None
def get_api_client():
    """Return the shared ApiClient, constructing it on the first call.

    The fast path returns the cached instance without locking; creation is
    performed under ``_api_client_lock`` with a second check so that only one
    instance is ever built even if several threads race on first use.
    """
    global _api_client
    if _api_client is not None:
        return _api_client
    with _api_client_lock:
        if _api_client is None:
            # Imported lazily so the module loads without api_client's deps.
            from api_client import ApiClient
            _api_client = ApiClient(RESOURCE_API_URL)
    return _api_client
# JSON body carrying user credentials; accepted by POST /login and POST /unlock.
class LoginRequest(BaseModel):
    email: str
    password: str
# JSON body for POST /load/{filename}; both fields are forwarded to
# ApiClient.load_big_small_resource.
class LoadRequest(BaseModel):
    filename: str
    folder: str
# Response model for GET /health.
class HealthResponse(BaseModel):
    status: str
# Response model for GET /status. Field names (including camelCase
# modelCacheDir) are the wire format, so they must not be renamed.
class StatusResponse(BaseModel):
    status: str
    # True when the shared ApiClient currently holds a token.
    authenticated: bool
    modelCacheDir: str
class _UnlockStateHolder:
    """Lock-guarded ``(state, error)`` pair tracking unlock progress.

    Read by the HTTP handlers and written by the ``_run_unlock`` background
    task; every access goes through ``self._lock`` so the two fields are
    always observed and updated together.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._state = UnlockState.idle
        self._error: Optional[str] = None

    def get(self):
        """Return a consistent ``(state, error)`` snapshot."""
        with self._lock:
            snapshot = (self._state, self._error)
        return snapshot

    def set(self, state: UnlockState, error: Optional[str] = None):
        """Replace the state and error message in one locked step."""
        with self._lock:
            self._state, self._error = state, error

    @property
    def state(self):
        """The current state alone, without the error message."""
        with self._lock:
            current = self._state
        return current


# Module-wide unlock progress shared by the endpoints and the background task.
_unlock = _UnlockStateHolder()
@app.get("/health")
def health() -> HealthResponse:
    """Report that the service is up; the status is always ``healthy``."""
    response = HealthResponse(status="healthy")
    return response
@app.get("/status")
def status() -> StatusResponse:
    """Report service status and whether the shared API client has a token.

    ``modelCacheDir`` is the fixed string ``"models"``.
    """
    has_token = get_api_client().token is not None
    return StatusResponse(
        status="healthy",
        authenticated=has_token,
        modelCacheDir="models",
    )
@app.post("/login")
def login(req: LoginRequest):
    """Hand the supplied credentials to the shared API client.

    Returns ``{"status": "ok"}`` on success. Any exception (including one from
    constructing the client) is reported as HTTP 401 with the exception text
    as ``detail``.
    """
    # NOTE(review): unlike _run_unlock, this does not call client.login();
    # confirm whether set_credentials_from_dict authenticates on its own.
    try:
        get_api_client().set_credentials_from_dict(req.email, req.password)
    except Exception as e:
        raise HTTPException(status_code=401, detail=str(e))
    return {"status": "ok"}
@app.post("/load/{filename}")
def load_resource(filename: str, req: LoadRequest):
    """Fetch a resource via the API client and return it as raw bytes.

    The resource is selected by ``req.filename`` and ``req.folder`` from the
    JSON body. Any exception becomes HTTP 500 with the exception text.
    """
    # NOTE(review): the ``filename`` path parameter is never used -- the body's
    # ``req.filename`` decides what is loaded. Confirm which one is intended.
    try:
        api = get_api_client()
        payload = api.load_big_small_resource(req.filename, req.folder)
        # Built inside the try so a rendering failure also maps to HTTP 500.
        return Response(content=payload, media_type="application/octet-stream")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/upload/{filename}")
def upload_resource(
    filename: str,
    data: UploadFile = File(...),
    folder: str = Form("models"),
):
    """Upload a multipart file through the API client under ``filename``.

    ``folder`` is a form field defaulting to ``"models"``. Returns
    ``{"status": "ok"}``; any exception becomes HTTP 500 with its text.
    """
    try:
        api = get_api_client()
        # The entire upload is read into memory before being forwarded.
        payload = data.file.read()
        api.upload_big_small_resource(payload, filename, folder)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "ok"}
def _run_unlock(email: str, password: str):
    """Background task: authenticate, fetch the key, decrypt and load images.

    Advances the shared ``_unlock`` state through ``authenticating`` ->
    ``downloading_key`` -> ``decrypting`` -> ``loading_images`` -> ``ready``.
    Every exception, including a failure to import ``binary_split``, is
    caught and recorded as ``UnlockState.error`` with its message. Nothing
    propagates to the caller, so the state can never be left stuck mid-flow.

    If ``check_images_loaded(API_VERSION)`` is already true, the state is set
    to ``ready`` (keeping any previously recorded error) and no work is done.

    The ``.tar`` produced by ``decrypt_archive`` is removed best-effort once
    loading finishes. This also happens when decryption or ``docker_load``
    fails, so the decrypted output is not left behind.
    """
    try:
        # Inside the try: an import failure must be reported as an error
        # state, not leave the state at the "authenticating" value set by
        # the /unlock handler.
        from binary_split import (
            download_key_fragment,
            decrypt_archive,
            docker_load,
            check_images_loaded,
        )

        if check_images_loaded(API_VERSION):
            _, prev_err = _unlock.get()
            _unlock.set(UnlockState.ready, prev_err)
            return

        _unlock.set(UnlockState.authenticating)
        client = get_api_client()
        client.set_credentials_from_dict(email, password)
        client.login()
        token = client.token

        _unlock.set(UnlockState.downloading_key)
        key_fragment = download_key_fragment(RESOURCE_API_URL, token)

        # Swap only a trailing ".enc" suffix. str.replace would rewrite ".enc"
        # anywhere in the path, and with no ".enc" at all it would make the
        # output path identical to the input path given to decrypt_archive.
        root, ext = os.path.splitext(IMAGES_PATH)
        tar_path = (root if ext == ".enc" else IMAGES_PATH) + ".tar"

        try:
            _unlock.set(UnlockState.decrypting)
            decrypt_archive(IMAGES_PATH, key_fragment, tar_path)
            _unlock.set(UnlockState.loading_images)
            docker_load(tar_path)
        finally:
            # Clean up on success and failure alike.
            try:
                os.remove(tar_path)
            except FileNotFoundError:
                pass  # nothing was written (e.g. decryption failed early)
            except OSError as e:
                from loguru import logger
                logger.warning(f"Failed to remove {tar_path}: {e}")

        _unlock.set(UnlockState.ready, None)
    except Exception as e:
        _unlock.set(UnlockState.error, str(e))
# Serialises the check-and-start section of ``unlock``. Sync FastAPI handlers
# run in a threadpool, so without this two concurrent requests could both see
# an idle/error state and each schedule its own ``_run_unlock`` task.
_unlock_start_lock = threading.Lock()


@app.post("/unlock")
def unlock(req: LoginRequest, background_tasks: BackgroundTasks):
    """Start the background unlock flow, or report the current state.

    Returns ``{"state": <value>}`` right away when the state is anything other
    than ``idle`` or ``error`` (an unlock in progress, or ``ready``).

    If the encrypted archive at ``IMAGES_PATH`` is missing:
    - reports ``ready`` when ``check_images_loaded`` says the images are
      already present;
    - otherwise raises HTTP 404.

    In every other case, sets the state to ``authenticating`` and schedules
    ``_run_unlock`` with the supplied credentials.
    """
    with _unlock_start_lock:
        state, _ = _unlock.get()
        if state not in (UnlockState.idle, UnlockState.error):
            return {"state": state.value}
        if not os.path.exists(IMAGES_PATH):
            from binary_split import check_images_loaded
            if check_images_loaded(API_VERSION):
                _, prev_err = _unlock.get()
                _unlock.set(UnlockState.ready, prev_err)
                return {"state": _unlock.state.value}
            raise HTTPException(status_code=404, detail="Encrypted archive not found")
        # Claim the run while still holding the lock, so a concurrent request
        # sees a non-idle state and returns instead of starting a second run.
        _unlock.set(UnlockState.authenticating, None)
    background_tasks.add_task(_run_unlock, req.email, req.password)
    return {"state": _unlock.state.value}
@app.get("/unlock/status")
def get_unlock_status():
    """Return the unlock state's value and the last error message (or None)."""
    current, message = _unlock.get()
    return dict(state=current.value, error=message)