mirror of
https://github.com/azaion/loader.git
synced 2026-04-22 19:16:37 +00:00
Quality cleanup refactoring
Made-with: Cursor
This commit is contained in:
@@ -14,15 +14,18 @@ RESOURCE_API_URL = os.environ.get("RESOURCE_API_URL", "https://api.azaion.com")
|
||||
IMAGES_PATH = os.environ.get("IMAGES_PATH", "/opt/azaion/images.enc")
|
||||
API_VERSION = os.environ.get("API_VERSION", "latest")
|
||||
|
||||
api_client = None
|
||||
_api_client = None
|
||||
_api_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_api_client():
|
||||
global api_client
|
||||
if api_client is None:
|
||||
from api_client import ApiClient
|
||||
api_client = ApiClient(RESOURCE_API_URL)
|
||||
return api_client
|
||||
global _api_client
|
||||
if _api_client is None:
|
||||
with _api_client_lock:
|
||||
if _api_client is None:
|
||||
from api_client import ApiClient
|
||||
_api_client = ApiClient(RESOURCE_API_URL)
|
||||
return _api_client
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
@@ -45,9 +48,28 @@ class StatusResponse(BaseModel):
|
||||
modelCacheDir: str
|
||||
|
||||
|
||||
unlock_state = UnlockState.idle
|
||||
unlock_error: Optional[str] = None
|
||||
unlock_lock = threading.Lock()
|
||||
class _UnlockStateHolder:
|
||||
def __init__(self):
|
||||
self._state = UnlockState.idle
|
||||
self._error: Optional[str] = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get(self):
|
||||
with self._lock:
|
||||
return self._state, self._error
|
||||
|
||||
def set(self, state: UnlockState, error: Optional[str] = None):
|
||||
with self._lock:
|
||||
self._state = state
|
||||
self._error = error
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
with self._lock:
|
||||
return self._state
|
||||
|
||||
|
||||
_unlock = _UnlockStateHolder()
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
@@ -101,8 +123,6 @@ def upload_resource(
|
||||
|
||||
|
||||
def _run_unlock(email: str, password: str):
|
||||
global unlock_state, unlock_error
|
||||
|
||||
from binary_split import (
|
||||
download_key_fragment,
|
||||
decrypt_archive,
|
||||
@@ -112,76 +132,67 @@ def _run_unlock(email: str, password: str):
|
||||
|
||||
try:
|
||||
if check_images_loaded(API_VERSION):
|
||||
with unlock_lock:
|
||||
unlock_state = UnlockState.ready
|
||||
_, prev_err = _unlock.get()
|
||||
_unlock.set(UnlockState.ready, prev_err)
|
||||
return
|
||||
|
||||
with unlock_lock:
|
||||
unlock_state = UnlockState.authenticating
|
||||
_unlock.set(UnlockState.authenticating)
|
||||
|
||||
client = get_api_client()
|
||||
client.set_credentials_from_dict(email, password)
|
||||
client.login()
|
||||
token = client.token
|
||||
|
||||
with unlock_lock:
|
||||
unlock_state = UnlockState.downloading_key
|
||||
_unlock.set(UnlockState.downloading_key)
|
||||
|
||||
key_fragment = download_key_fragment(RESOURCE_API_URL, token)
|
||||
|
||||
with unlock_lock:
|
||||
unlock_state = UnlockState.decrypting
|
||||
_unlock.set(UnlockState.decrypting)
|
||||
|
||||
tar_path = IMAGES_PATH.replace(".enc", ".tar")
|
||||
decrypt_archive(IMAGES_PATH, key_fragment, tar_path)
|
||||
|
||||
with unlock_lock:
|
||||
unlock_state = UnlockState.loading_images
|
||||
_unlock.set(UnlockState.loading_images)
|
||||
|
||||
docker_load(tar_path)
|
||||
|
||||
try:
|
||||
os.remove(tar_path)
|
||||
except OSError:
|
||||
pass
|
||||
except OSError as e:
|
||||
from loguru import logger
|
||||
|
||||
with unlock_lock:
|
||||
unlock_state = UnlockState.ready
|
||||
unlock_error = None
|
||||
logger.warning(f"Failed to remove {tar_path}: {e}")
|
||||
|
||||
_unlock.set(UnlockState.ready, None)
|
||||
|
||||
except Exception as e:
|
||||
with unlock_lock:
|
||||
unlock_state = UnlockState.error
|
||||
unlock_error = str(e)
|
||||
_unlock.set(UnlockState.error, str(e))
|
||||
|
||||
|
||||
@app.post("/unlock")
|
||||
def unlock(req: LoginRequest, background_tasks: BackgroundTasks):
|
||||
global unlock_state, unlock_error
|
||||
|
||||
with unlock_lock:
|
||||
if unlock_state == UnlockState.ready:
|
||||
return {"state": unlock_state.value}
|
||||
if unlock_state not in (UnlockState.idle, UnlockState.error):
|
||||
return {"state": unlock_state.value}
|
||||
state, _ = _unlock.get()
|
||||
if state == UnlockState.ready:
|
||||
return {"state": state.value}
|
||||
if state not in (UnlockState.idle, UnlockState.error):
|
||||
return {"state": state.value}
|
||||
|
||||
if not os.path.exists(IMAGES_PATH):
|
||||
from binary_split import check_images_loaded
|
||||
|
||||
if check_images_loaded(API_VERSION):
|
||||
with unlock_lock:
|
||||
unlock_state = UnlockState.ready
|
||||
return {"state": unlock_state.value}
|
||||
_, prev_err = _unlock.get()
|
||||
_unlock.set(UnlockState.ready, prev_err)
|
||||
return {"state": _unlock.state.value}
|
||||
raise HTTPException(status_code=404, detail="Encrypted archive not found")
|
||||
|
||||
with unlock_lock:
|
||||
unlock_state = UnlockState.authenticating
|
||||
unlock_error = None
|
||||
_unlock.set(UnlockState.authenticating, None)
|
||||
|
||||
background_tasks.add_task(_run_unlock, req.email, req.password)
|
||||
return {"state": unlock_state.value}
|
||||
return {"state": _unlock.state.value}
|
||||
|
||||
|
||||
@app.get("/unlock/status")
|
||||
def get_unlock_status():
|
||||
with unlock_lock:
|
||||
return {"state": unlock_state.value, "error": unlock_error}
|
||||
state, error = _unlock.get()
|
||||
return {"state": state.value, "error": error}
|
||||
|
||||
Reference in New Issue
Block a user