mirror of
https://github.com/azaion/loader.git
synced 2026-04-22 08:16:33 +00:00
941b8199aa
Made-with: Cursor
188 lines
4.8 KiB
Python
188 lines
4.8 KiB
Python
import os
|
|
import threading
|
|
from typing import Optional
|
|
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
|
|
from fastapi.responses import Response
|
|
from pydantic import BaseModel
|
|
|
|
from unlock_state import UnlockState
|
|
|
|
# ASGI application object; the route handlers in this module register on it.
app = FastAPI(title="Azaion.Loader")

# Deployment settings, each overridable through the environment variable of
# the same name.
RESOURCE_API_URL = os.getenv("RESOURCE_API_URL", "https://api.azaion.com")
# Location of the encrypted image archive consumed by the unlock flow.
IMAGES_PATH = os.getenv("IMAGES_PATH", "/opt/azaion/images.enc")
# Version handed to check_images_loaded() when probing for loaded images.
API_VERSION = os.getenv("API_VERSION", "latest")

# Shared API client instance; remains None until get_api_client() builds it.
api_client = None
|
|
|
|
|
|
def get_api_client():
    """Return the module-wide ApiClient, constructing it on the first call.

    The ``api_client`` module is imported only when the client is first
    needed, and the instance is created against ``RESOURCE_API_URL``.
    """
    global api_client
    if api_client is not None:
        return api_client

    from api_client import ApiClient

    api_client = ApiClient(RESOURCE_API_URL)
    return api_client
|
|
|
|
|
|
# Request body carrying user credentials; used by the /login and /unlock routes.
class LoginRequest(BaseModel):
    email: str  # account e-mail forwarded to the API client
    password: str  # account password forwarded to the API client
|
|
|
|
|
|
# Request body of the /load route: identifies the resource to fetch.
class LoadRequest(BaseModel):
    filename: str  # resource name passed to load_big_small_resource()
    folder: str  # folder passed to load_big_small_resource()
|
|
|
|
|
|
# Response body of the /health route.
class HealthResponse(BaseModel):
    status: str  # the health handler always sets this to "healthy"
|
|
|
|
|
|
# Response body of the /status route.
class StatusResponse(BaseModel):
    status: str  # service status string
    authenticated: bool  # True when the shared API client holds a token
    modelCacheDir: str  # camelCase is the wire field name; keep as-is
|
|
|
|
|
|
# Progress of the unlock workflow, shared between the background unlock task
# and the /unlock routes. All writes to unlock_state / unlock_error happen
# while holding unlock_lock.
unlock_lock = threading.Lock()
unlock_state = UnlockState.idle
# Text of the last unlock failure, or None when there is nothing to report.
unlock_error: Optional[str] = None
|
|
|
|
|
|
@app.get("/health")
def health() -> HealthResponse:
    # Liveness probe: answers with a fixed "healthy" payload and consults no
    # other component.
    payload = HealthResponse(status="healthy")
    return payload
|
|
|
|
|
|
@app.get("/status")
def status() -> StatusResponse:
    # "authenticated" reflects whether the shared API client currently holds
    # a token; the other two fields are fixed values.
    has_token = get_api_client().token is not None
    return StatusResponse(
        status="healthy",
        authenticated=has_token,
        modelCacheDir="models",
    )
|
|
|
|
|
|
@app.post("/login")
def login(req: LoginRequest):
    # Stores the supplied credentials on the shared API client. Any exception
    # raised while doing so is reported as HTTP 401 with the exception text.
    try:
        get_api_client().set_credentials_from_dict(req.email, req.password)
    except Exception as e:
        raise HTTPException(status_code=401, detail=str(e))
    return {"status": "ok"}
|
|
|
|
|
|
@app.post("/load/{filename}")
def load_resource(filename: str, req: LoadRequest):
    # Fetches a resource through the API client and returns it as raw bytes.
    # Any failure is reported as HTTP 500 with the exception text.
    # NOTE(review): the path parameter `filename` is never read; the resource
    # is selected from the request body (req.filename / req.folder). Confirm
    # this is intended.
    try:
        payload = get_api_client().load_big_small_resource(req.filename, req.folder)
        return Response(content=payload, media_type="application/octet-stream")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.post("/upload/{filename}")
def upload_resource(
    filename: str,
    data: UploadFile = File(...),
    folder: str = Form("models"),
):
    # Reads the whole uploaded file into memory and hands it to the API
    # client under the given filename/folder. Any failure is reported as
    # HTTP 500 with the exception text.
    try:
        client = get_api_client()
        client.upload_big_small_resource(data.file.read(), filename, folder)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "ok"}
|
|
|
|
|
|
def _decrypted_tar_path(images_path: str) -> str:
    """Return the path the decrypted image archive is written to.

    A trailing ``.enc`` extension is swapped for ``.tar``. Any other name
    gets ``.tar`` appended, so the result never equals ``images_path`` and
    directory components containing ``.enc`` are left alone.
    """
    root, ext = os.path.splitext(images_path)
    if ext == ".enc":
        return root + ".tar"
    return images_path + ".tar"


def _run_unlock(email: str, password: str):
    """Run the unlock workflow and publish its progress in ``unlock_state``.

    If ``check_images_loaded(API_VERSION)`` is already true the state goes
    straight to ``ready``. Otherwise the steps are: log in with the given
    credentials, download the key fragment, decrypt ``IMAGES_PATH`` into a
    temporary tar, load it with ``docker_load``, and delete the tar.

    Every state change is made while holding ``unlock_lock``. Any exception
    is caught: the state becomes ``error`` and ``unlock_error`` receives the
    exception text. Nothing is raised to the caller.

    Args:
        email: Account e-mail used to authenticate against the resource API.
        password: Account password used to authenticate.
    """
    global unlock_state, unlock_error

    from binary_split import (
        download_key_fragment,
        decrypt_archive,
        docker_load,
        check_images_loaded,
    )

    try:
        if check_images_loaded(API_VERSION):
            with unlock_lock:
                unlock_state = UnlockState.ready
            return

        with unlock_lock:
            unlock_state = UnlockState.authenticating

        client = get_api_client()
        client.set_credentials_from_dict(email, password)
        client.login()
        token = client.token

        with unlock_lock:
            unlock_state = UnlockState.downloading_key

        key_fragment = download_key_fragment(RESOURCE_API_URL, token)

        with unlock_lock:
            unlock_state = UnlockState.decrypting

        # Derived from the extension only, so the output can never be the
        # encrypted archive itself (which the cleanup below would delete).
        tar_path = _decrypted_tar_path(IMAGES_PATH)
        try:
            decrypt_archive(IMAGES_PATH, key_fragment, tar_path)

            with unlock_lock:
                unlock_state = UnlockState.loading_images

            docker_load(tar_path)
        finally:
            # Never leave the decrypted archive on disk, even when decrypting
            # or loading failed. Removal is best-effort: the file may not
            # exist if decryption failed before creating it.
            try:
                os.remove(tar_path)
            except OSError:
                pass

        with unlock_lock:
            unlock_state = UnlockState.ready
            unlock_error = None

    except Exception as e:
        with unlock_lock:
            unlock_state = UnlockState.error
            unlock_error = str(e)
|
|
|
|
|
|
@app.post("/unlock")
def unlock(req: LoginRequest, background_tasks: BackgroundTasks):
    # Starts the unlock workflow in the background and returns the current
    # state as {"state": <value>}.
    #
    # - If the state is anything other than idle/error (already ready or in
    #   progress), that state is returned and nothing is scheduled.
    # - If the encrypted archive is missing, the route answers "ready" when
    #   check_images_loaded(API_VERSION) is true and raises HTTP 404 otherwise.
    # - Otherwise the state is set to "authenticating", _run_unlock is
    #   scheduled with the supplied credentials, and that state is returned.
    global unlock_state, unlock_error

    startable = (UnlockState.idle, UnlockState.error)

    # Cheap early exit so requests made while unlocked or in progress never
    # touch the filesystem or probe for loaded images.
    with unlock_lock:
        if unlock_state not in startable:
            return {"state": unlock_state.value}

    if not os.path.exists(IMAGES_PATH):
        from binary_split import check_images_loaded

        if check_images_loaded(API_VERSION):
            with unlock_lock:
                unlock_state = UnlockState.ready
                # Clear any message left by an earlier failed attempt so the
                # status route does not report "ready" alongside an old error.
                unlock_error = None
            return {"state": UnlockState.ready.value}
        raise HTTPException(status_code=404, detail="Encrypted archive not found")

    # Re-check and transition in one critical section: the lock was released
    # during the checks above, so a concurrent request may already have
    # started the workflow. Only one request may schedule it.
    with unlock_lock:
        if unlock_state not in startable:
            return {"state": unlock_state.value}
        unlock_state = UnlockState.authenticating
        unlock_error = None
        reported_state = unlock_state.value

    background_tasks.add_task(_run_unlock, req.email, req.password)
    return {"state": reported_state}
|
|
|
|
|
|
@app.get("/unlock/status")
def get_unlock_status():
    # Snapshot both fields under the lock so the reported state and error
    # belong to the same moment.
    with unlock_lock:
        state_value = unlock_state.value
        last_error = unlock_error
    return {"state": state_value, "error": last_error}
|