import os
import threading
from typing import Optional, Tuple

from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
from fastapi.responses import Response
from pydantic import BaseModel

from unlock_state import UnlockState

app = FastAPI(title="Azaion.Loader")

# Base URL of the resource API; used for the ApiClient and for key-fragment download.
RESOURCE_API_URL = os.environ.get("RESOURCE_API_URL", "https://api.azaion.com")
# Location of the encrypted images archive consumed by /unlock.
IMAGES_PATH = os.environ.get("IMAGES_PATH", "/opt/azaion/images.enc")
# Version tag passed to check_images_loaded to decide whether images are already present.
API_VERSION = os.environ.get("API_VERSION", "latest")

_api_client = None
_api_client_lock = threading.Lock()


def get_api_client():
    """Return the process-wide ApiClient, creating it lazily on first use.

    The unlocked ``is None`` test is a fast path; the second test under the lock
    guarantees only one thread constructs the client.
    """
    global _api_client
    if _api_client is None:
        with _api_client_lock:
            if _api_client is None:
                # Deferred import: api_client is only needed once a client is requested.
                from api_client import ApiClient

                _api_client = ApiClient(RESOURCE_API_URL)
    return _api_client


class LoginRequest(BaseModel):
    email: str
    password: str


class LoadRequest(BaseModel):
    filename: str
    folder: str


class HealthResponse(BaseModel):
    status: str


class StatusResponse(BaseModel):
    status: str
    authenticated: bool
    modelCacheDir: str


class _UnlockStateHolder:
    """Lock-protected container for the current unlock state and its last error text."""

    def __init__(self):
        self._state = UnlockState.idle
        self._error: Optional[str] = None
        self._lock = threading.Lock()

    def get(self) -> Tuple[UnlockState, Optional[str]]:
        """Return a consistent ``(state, error)`` snapshot."""
        with self._lock:
            return self._state, self._error

    def set(self, state: UnlockState, error: Optional[str] = None):
        """Replace both the state and the error text in one locked step."""
        with self._lock:
            self._state = state
            self._error = error

    def try_transition(
        self, allowed, new_state: UnlockState, error: Optional[str] = None
    ) -> Tuple[bool, UnlockState]:
        """Move to ``new_state`` only if the current state is in ``allowed``.

        The check and the update happen under one lock acquisition.

        Returns:
            ``(True, new_state)`` if the transition happened, otherwise
            ``(False, current_state)`` with nothing modified.
        """
        with self._lock:
            if self._state not in allowed:
                return False, self._state
            self._state = new_state
            self._error = error
            return True, new_state

    @property
    def state(self) -> UnlockState:
        """Current state (without the error text)."""
        with self._lock:
            return self._state


_unlock = _UnlockStateHolder()


@app.get("/health")
def health() -> HealthResponse:
    """Liveness probe; always reports ``healthy``."""
    return HealthResponse(status="healthy")


@app.get("/status")
def status() -> StatusResponse:
    """Report service health and whether the API client currently holds a token."""
    client = get_api_client()
    return StatusResponse(
        status="healthy",
        authenticated=client.token is not None,
        modelCacheDir="models",
    )


@app.post("/login")
def login(req: LoginRequest):
    """Store the given credentials on the shared API client.

    Raises:
        HTTPException(401): if setting the credentials raises anything.
    """
    try:
        client = get_api_client()
        client.set_credentials_from_dict(req.email, req.password)
        return {"status": "ok"}
    except Exception as e:
        raise HTTPException(status_code=401, detail=str(e)) from e


@app.post("/load/{filename}")
def load_resource(filename: str, req: LoadRequest):
    """Fetch a resource through the API client and return it as raw bytes.

    NOTE(review): the ``filename`` path parameter is not used; the resource name
    comes from ``req.filename``. Confirm which one callers intend.

    Raises:
        HTTPException(500): if loading the resource raises anything.
    """
    try:
        client = get_api_client()
        data = client.load_big_small_resource(req.filename, req.folder)
        return Response(content=data, media_type="application/octet-stream")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/upload/{filename}")
def upload_resource(
    filename: str,
    data: UploadFile = File(...),
    folder: str = Form("models"),
):
    """Read the uploaded file fully into memory and upload it as ``folder/filename``.

    Raises:
        HTTPException(500): if reading or uploading raises anything.
    """
    try:
        client = get_api_client()
        content = data.file.read()
        client.upload_big_small_resource(content, filename, folder)
        return {"status": "ok"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


def _decrypted_tar_path(enc_path: str) -> str:
    """Return where the decrypted tarball for ``enc_path`` is written.

    Only a trailing ``.enc`` extension is swapped for ``.tar``; any other path gets
    ``.tar`` appended, so the output can never be the encrypted input itself.
    """
    root, ext = os.path.splitext(enc_path)
    return root + ".tar" if ext == ".enc" else enc_path + ".tar"


def _remove_decrypted_archive(tar_path: str) -> None:
    """Best-effort delete of the decrypted tarball; failures are logged, not raised."""
    try:
        os.remove(tar_path)
    except FileNotFoundError:
        # Nothing to clean up (e.g. decryption failed before the file was created).
        pass
    except OSError as e:
        from loguru import logger

        logger.warning(f"Failed to remove {tar_path}: {e}")


def _run_unlock(email: str, password: str):
    """Background job: log in, fetch the key fragment, decrypt and docker-load the images.

    Progress is published through ``_unlock`` for ``/unlock/status``. Any exception
    moves the state to ``error`` with the exception text.
    """
    from binary_split import (
        download_key_fragment,
        decrypt_archive,
        docker_load,
        check_images_loaded,
    )

    try:
        if check_images_loaded(API_VERSION):
            # Images already present: skip the pipeline, keeping any previous error text.
            _, prev_err = _unlock.get()
            _unlock.set(UnlockState.ready, prev_err)
            return

        _unlock.set(UnlockState.authenticating)
        client = get_api_client()
        client.set_credentials_from_dict(email, password)
        client.login()
        token = client.token

        _unlock.set(UnlockState.downloading_key)
        key_fragment = download_key_fragment(RESOURCE_API_URL, token)

        _unlock.set(UnlockState.decrypting)
        tar_path = _decrypted_tar_path(IMAGES_PATH)
        try:
            decrypt_archive(IMAGES_PATH, key_fragment, tar_path)
            _unlock.set(UnlockState.loading_images)
            docker_load(tar_path)
        finally:
            # Never leave decrypted images on disk, including when decrypt/load fails.
            _remove_decrypted_archive(tar_path)

        _unlock.set(UnlockState.ready, None)
    except Exception as e:
        _unlock.set(UnlockState.error, str(e))


@app.post("/unlock")
def unlock(req: LoginRequest, background_tasks: BackgroundTasks):
    """Start the unlock pipeline in the background and return the resulting state.

    A new run starts only from ``idle`` or ``error``; any other state (including
    ``ready``) is returned unchanged. If the encrypted archive is missing but the
    images are already loaded, the state becomes ``ready``.

    Raises:
        HTTPException(404): if the archive is missing and the images are not loaded.
    """
    startable = (UnlockState.idle, UnlockState.error)
    state, _ = _unlock.get()
    if state not in startable:
        return {"state": state.value}

    if not os.path.exists(IMAGES_PATH):
        from binary_split import check_images_loaded

        if check_images_loaded(API_VERSION):
            _, prev_err = _unlock.get()
            _unlock.set(UnlockState.ready, prev_err)
            return {"state": _unlock.state.value}
        raise HTTPException(status_code=404, detail="Encrypted archive not found")

    # Re-check and claim in one locked step: concurrent requests may all have passed
    # the check above, but only the one that wins here schedules a background run.
    started, current = _unlock.try_transition(startable, UnlockState.authenticating, None)
    if started:
        background_tasks.add_task(_run_unlock, req.email, req.password)
    return {"state": current.value}


@app.get("/unlock/status")
def get_unlock_status():
    """Return the current unlock state and the last recorded error text (may be None)."""
    state, error = _unlock.get()
    return {"state": state.value, "error": error}