mirror of
https://github.com/azaion/loader.git
synced 2026-04-22 12:16:32 +00:00
[AZ-182][AZ-184][AZ-187] Batch 1
Made-with: Cursor
This commit is contained in:
@@ -0,0 +1,280 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import requests
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def backoff_seconds(failure_index: int) -> int:
    """Return the retry delay, in seconds, for the given failure count.

    Delays escalate 60 -> 300 -> 900 -> 3600 -> 14400 and cap at the last
    entry; negative indices clamp to the first entry.
    """
    delays = (60, 300, 900, 3600, 14400)
    clamped = min(max(failure_index, 0), len(delays) - 1)
    return delays[clamped]
|
||||
|
||||
|
||||
@dataclass
class DownloadState:
    """Serializable snapshot of an in-progress resumable download."""

    url: str
    expected_sha256: str
    expected_size: int
    bytes_downloaded: int
    temp_file_path: str
    phase: str

    def to_json_dict(self) -> dict:
        """Return a plain-dict representation suitable for json.dumps."""
        return asdict(self)

    @classmethod
    def from_json_dict(cls, data: dict) -> "DownloadState":
        """Rebuild a state object from a dict produced by to_json_dict.

        Numeric fields are coerced with int() so values that round-tripped
        through JSON as strings are still accepted.
        """
        numeric = {key: int(data[key]) for key in ("expected_size", "bytes_downloaded")}
        return cls(
            url=data["url"],
            expected_sha256=data["expected_sha256"],
            temp_file_path=data["temp_file_path"],
            phase=data["phase"],
            **numeric,
        )
|
||||
|
||||
|
||||
def load_download_state(path: str) -> DownloadState:
    """Read the JSON state file at *path* and deserialize it."""
    with open(path, encoding="utf-8") as handle:
        raw = json.load(handle)
    return DownloadState.from_json_dict(raw)
|
||||
|
||||
|
||||
def save_download_state(path: str, state: DownloadState) -> None:
    """Atomically persist *state* as pretty-printed JSON at *path*.

    The payload is written to a temp file in the destination directory and
    then moved into place with os.replace, so readers never observe a
    partially written state file.  The temp file is removed on failure.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    serialized = json.dumps(state.to_json_dict(), indent=2, sort_keys=True)
    handle, scratch = tempfile.mkstemp(
        dir=parent or None,
        prefix=".download_state_",
        suffix=".tmp",
    )
    try:
        with os.fdopen(handle, "w", encoding="utf-8") as out:
            out.write(serialized)
        os.replace(scratch, path)
    except Exception:
        # Best-effort cleanup of the scratch file before re-raising.
        try:
            os.unlink(scratch)
        except OSError:
            pass
        raise
|
||||
|
||||
|
||||
def _sha256_file(path: str, chunk_size: int = 1024 * 1024) -> str:
|
||||
h = hashlib.sha256()
|
||||
with open(path, "rb") as f:
|
||||
while True:
|
||||
block = f.read(chunk_size)
|
||||
if not block:
|
||||
break
|
||||
h.update(block)
|
||||
return h.hexdigest().lower()
|
||||
|
||||
|
||||
def _safe_job_id(job_id: str) -> str:
|
||||
return "".join(c if c.isalnum() or c in "-_" else "_" for c in job_id)
|
||||
|
||||
|
||||
def decrypt_cbc_file(encrypted_path: str, aes_key: bytes, output_path: str) -> None:
    """Decrypt an AES-CBC file whose first 16 bytes are the IV.

    Streams *encrypted_path* through the cipher in 64 KiB chunks, strips
    PKCS#7 padding, and writes the plaintext to *output_path*.

    Raises:
        ValueError: if the file is shorter than one 16-byte IV block.
    """
    with open(encrypted_path, "rb") as src:
        iv = src.read(16)
        if len(iv) != 16:
            raise ValueError("invalid ciphertext: missing iv")
        decryptor = Cipher(
            algorithms.AES(aes_key), modes.CBC(iv), backend=default_backend()
        ).decryptor()
        unpadder = padding.PKCS7(128).unpadder()
        with open(output_path, "wb") as dst:
            for chunk in iter(lambda: src.read(64 * 1024), b""):
                plain = decryptor.update(chunk)
                if plain:
                    dst.write(unpadder.update(plain))
            # Flush the cipher, then finalize the unpadder to drop padding.
            tail = unpadder.update(decryptor.finalize()) + unpadder.finalize()
            dst.write(tail)
|
||||
|
||||
|
||||
class ResumableDownloadManager:
    """Downloads, verifies, and decrypts large files with crash-safe resume.

    Progress is checkpointed to a per-job JSON state file in
    ``state_directory`` so an interrupted download can continue from the
    last persisted byte offset.  Network failures are retried with the
    escalating ``backoff_seconds`` schedule; hash mismatches discard the
    temp file and restart the download from scratch.
    """

    def __init__(
        self,
        state_directory: Optional[str] = None,
        *,
        session_factory: Optional[Callable[[], requests.Session]] = None,
        sleep_fn: Optional[Callable[[float], None]] = None,
        chunk_size: int = 1024 * 1024,
    ) -> None:
        """Create a manager rooted at *state_directory*.

        Args:
            state_directory: Where state files and cipher temp files live;
                falls back to the LOADER_DOWNLOAD_STATE_DIR env var.
            session_factory: Injectable requests.Session factory (for tests).
            sleep_fn: Injectable sleep (for tests); defaults to time.sleep.
            chunk_size: Streaming read size and persist granularity in bytes.

        Raises:
            ValueError: if no state directory is supplied or configured.
        """
        resolved = state_directory or os.environ.get("LOADER_DOWNLOAD_STATE_DIR")
        if not resolved:
            raise ValueError("state_directory or LOADER_DOWNLOAD_STATE_DIR is required")
        self._state_directory = resolved
        self._session_factory = session_factory or requests.Session
        self._sleep = sleep_fn or time.sleep
        self._chunk_size = chunk_size
        os.makedirs(self._state_directory, exist_ok=True)

    def _state_path(self, job_id: str) -> str:
        """Return the JSON state-file path for *job_id* (sanitized)."""
        safe = _safe_job_id(job_id)
        return os.path.join(self._state_directory, f"{safe}.json")

    def _persist(self, path: str, state: DownloadState) -> None:
        """Checkpoint *state* to disk (atomic write via save_download_state)."""
        save_download_state(path, state)

    def fetch_decrypt_verify(
        self,
        job_id: str,
        url: str,
        expected_sha256: str,
        expected_size: int,
        decryption_key: bytes,
        output_plaintext_path: str,
    ) -> None:
        """Download *url*, verify its SHA-256, and decrypt it to plaintext.

        Resumes from a previous run of the same *job_id* when a state file
        and partial temp file exist.  Blocks (with backoff sleeps) until the
        download verifies, then decrypts; any unexpected error marks the
        state "failed" and re-raises.

        Raises:
            ValueError: if an existing state file was created for a
                different URL.
            requests.HTTPError: on unrecoverable HTTP responses.
        """
        state_path = self._state_path(job_id)
        safe = _safe_job_id(job_id)
        temp_file_path = os.path.join(self._state_directory, f"{safe}.cipher.tmp")
        if os.path.isfile(state_path):
            state = load_download_state(state_path)
            # A state file for a different URL means the job id was reused;
            # refuse rather than silently mixing artifacts.
            if state.url != url:
                raise ValueError("state url mismatch")
        else:
            state = DownloadState(
                url=url,
                expected_sha256=expected_sha256,
                expected_size=expected_size,
                bytes_downloaded=0,
                temp_file_path=temp_file_path,
                phase="pending",
            )
            self._persist(state_path, state)

        # Caller-supplied values win over whatever was persisted earlier.
        state.expected_sha256 = expected_sha256
        state.expected_size = expected_size
        state.temp_file_path = temp_file_path
        # Trust the bytes actually on disk over the persisted counter
        # (a crash may have left them out of sync).
        if os.path.isfile(state.temp_file_path):
            on_disk = os.path.getsize(state.temp_file_path)
            state.bytes_downloaded = min(on_disk, state.expected_size)
        else:
            state.bytes_downloaded = 0

        network_failures = 0
        session = self._session_factory()

        try:
            # Outer loop: repeats download+verify until the hash matches.
            while True:
                # Inner loop: keep streaming (with backoff on network
                # errors) until the expected byte count is on disk.
                while state.bytes_downloaded < state.expected_size:
                    state.phase = "downloading"
                    self._persist(state_path, state)
                    try:
                        self._stream_download(session, state, state_path)
                        network_failures = 0
                    except requests.RequestException as exc:
                        logger.exception("download request failed: {}", exc)
                        state.phase = "paused"
                        self._persist(state_path, state)
                        wait_s = backoff_seconds(network_failures)
                        self._sleep(wait_s)
                        network_failures += 1

                state.phase = "verifying"
                self._persist(state_path, state)
                if _sha256_file(state.temp_file_path) != state.expected_sha256.lower().strip():
                    # Corrupt download: discard and restart from byte 0.
                    try:
                        os.remove(state.temp_file_path)
                    except OSError as exc:
                        logger.exception("failed to remove corrupt download: {}", exc)
                    state.bytes_downloaded = 0
                    state.phase = "downloading"
                    self._persist(state_path, state)
                    continue

                state.phase = "decrypting"
                self._persist(state_path, state)
                decrypt_cbc_file(state.temp_file_path, decryption_key, output_plaintext_path)
                state.phase = "complete"
                self._persist(state_path, state)
                return
        except Exception:
            # Record the failure for observability, but never let the
            # persist itself mask the original error.
            state.phase = "failed"
            try:
                self._persist(state_path, state)
            except Exception as persist_exc:
                logger.exception("failed to persist failed state: {}", persist_exc)
            raise

    def _stream_download(
        self,
        session: requests.Session,
        state: DownloadState,
        state_path: str,
    ) -> None:
        """Perform one HTTP GET, resuming with a Range header when possible.

        If the server ignores the Range request (responds 200 instead of
        206), the partial temp file is discarded and the whole file is
        re-fetched from scratch.
        """
        headers = {}
        if state.bytes_downloaded > 0:
            headers["Range"] = f"bytes={state.bytes_downloaded}-"
        with session.get(
            state.url,
            headers=headers,
            stream=True,
            timeout=(30, 120),
        ) as resp:
            if state.bytes_downloaded > 0 and resp.status_code == 200:
                # Server does not support ranges: restart from zero.
                try:
                    os.remove(state.temp_file_path)
                except OSError:
                    pass
                state.bytes_downloaded = 0
                self._persist(state_path, state)
                with session.get(
                    state.url,
                    headers={},
                    stream=True,
                    timeout=(30, 120),
                ) as resp_full:
                    self._write_response_stream(resp_full, state, state_path, append=False)
                return
            if state.bytes_downloaded > 0 and resp.status_code != 206:
                # raise_for_status covers 4xx/5xx; the explicit raise covers
                # any other non-206 success code while resuming.
                resp.raise_for_status()
                raise requests.HTTPError("expected 206 Partial Content when resuming")
            if state.bytes_downloaded == 0 and resp.status_code not in (200, 206):
                resp.raise_for_status()
            append = state.bytes_downloaded > 0
            self._write_response_stream(resp, state, state_path, append=append)

    def _write_response_stream(
        self,
        resp: requests.Response,
        state: DownloadState,
        state_path: str,
        *,
        append: bool,
    ) -> None:
        """Write the response body to the temp file, checkpointing progress.

        Bytes beyond ``expected_size`` are truncated.  State is persisted
        roughly once per ``chunk_size`` bytes written, and once more at the
        end for any remainder.
        """
        mode = "ab" if append else "wb"
        written_since_persist = 0
        with open(state.temp_file_path, mode) as out:
            for chunk in resp.iter_content(chunk_size=self._chunk_size):
                if not chunk:
                    continue
                room = state.expected_size - state.bytes_downloaded
                if room <= 0:
                    break
                if len(chunk) > room:
                    # Never write past the expected size.
                    chunk = chunk[:room]
                out.write(chunk)
                state.bytes_downloaded += len(chunk)
                written_since_persist += len(chunk)
                if written_since_persist >= self._chunk_size:
                    self._persist(state_path, state)
                    written_since_persist = 0
                if state.bytes_downloaded >= state.expected_size:
                    break
        if written_since_persist:
            self._persist(state_path, state)
|
||||
@@ -0,0 +1,37 @@
|
||||
from credentials import Credentials
|
||||
from security import (
|
||||
security_calc_hash,
|
||||
security_decrypt_to,
|
||||
security_encrypt_to,
|
||||
security_get_api_encryption_key,
|
||||
security_get_hw_hash,
|
||||
security_get_resource_encryption_key,
|
||||
)
|
||||
from security_provider import SecurityProvider
|
||||
|
||||
|
||||
class LegacySecurityProvider(SecurityProvider):
    """SecurityProvider that delegates to the original `security` routines."""

    @property
    def kind(self) -> str:
        """Backend discriminator used in provider-selection logging."""
        return "legacy"

    def encrypt_to(self, input_bytes: bytes, key: str) -> bytes:
        """Delegate encryption to security_encrypt_to."""
        return security_encrypt_to(input_bytes, key)

    def decrypt_to(self, ciphertext_with_iv_bytes: bytes, key: str) -> bytes:
        """Delegate decryption to security_decrypt_to."""
        return security_decrypt_to(ciphertext_with_iv_bytes, key)

    def get_hw_hash(self, hardware: str) -> str:
        """Delegate hardware hashing to security_get_hw_hash."""
        return security_get_hw_hash(hardware)

    def get_api_encryption_key(
        self, creds_email: str, creds_password: str, hardware_hash: str
    ) -> str:
        """Build a Credentials object and derive the API key via the legacy routine."""
        return security_get_api_encryption_key(
            Credentials(creds_email, creds_password), hardware_hash
        )

    def get_resource_encryption_key(self) -> str:
        """Delegate to security_get_resource_encryption_key."""
        return security_get_resource_encryption_key()

    def calc_hash(self, key: str) -> str:
        """Delegate key hashing to security_calc_hash."""
        return security_calc_hash(key)
|
||||
@@ -7,9 +7,12 @@ from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
|
||||
from unlock_state import UnlockState
|
||||
from security_provider import create_security_provider
|
||||
|
||||
# FastAPI application object served by the loader.
app = FastAPI(title="Azaion.Loader")

# Crypto backend chosen once at import time (see create_security_provider:
# TPM-backed when available, legacy otherwise).
security_provider = create_security_provider()

# Base URL of the remote resource API; overridable via environment.
RESOURCE_API_URL = os.environ.get("RESOURCE_API_URL", "https://api.azaion.com")
# On-disk location of the encrypted images bundle — presumably consumed by
# the decrypt/serve endpoints below (not visible here); verify against usage.
IMAGES_PATH = os.environ.get("IMAGES_PATH", "/opt/azaion/images.enc")
# API version label sent to / requested from the resource API.
API_VERSION = os.environ.get("API_VERSION", "latest")
|
||||
|
||||
@@ -61,3 +61,27 @@ cdef class Security:
|
||||
hash_bytes = sha384(str_bytes).digest()
|
||||
cdef str h = base64.b64encode(hash_bytes).decode('utf-8')
|
||||
return h
|
||||
|
||||
|
||||
cpdef bytes security_encrypt_to(bytes input_bytes, str key):
    """Module-level wrapper around Security.encrypt_to for Python callers."""
    return Security.encrypt_to(input_bytes, key)
|
||||
|
||||
|
||||
cpdef bytes security_decrypt_to(bytes ciphertext_with_iv_bytes, str key):
    """Module-level wrapper around Security.decrypt_to for Python callers."""
    return Security.decrypt_to(ciphertext_with_iv_bytes, key)
|
||||
|
||||
|
||||
cpdef str security_get_hw_hash(str hardware):
    """Module-level wrapper around Security.get_hw_hash for Python callers."""
    return Security.get_hw_hash(hardware)
|
||||
|
||||
|
||||
cpdef str security_get_api_encryption_key(Credentials credentials, str hardware_hash):
    """Module-level wrapper around Security.get_api_encryption_key."""
    return Security.get_api_encryption_key(credentials, hardware_hash)
|
||||
|
||||
|
||||
cpdef str security_get_resource_encryption_key():
    """Module-level wrapper around Security.get_resource_encryption_key."""
    return Security.get_resource_encryption_key()
|
||||
|
||||
|
||||
cpdef str security_calc_hash(str key):
    """Module-level wrapper around Security.calc_hash for Python callers."""
    return Security.calc_hash(key)
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Mapping, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def _security_provider_override(environ: Mapping[str, str]) -> Optional[str]:
|
||||
raw = environ.get("SECURITY_PROVIDER")
|
||||
if raw is None:
|
||||
return None
|
||||
s = raw.strip().lower()
|
||||
return s if s else None
|
||||
|
||||
|
||||
def _tpm_device_visible(path_exists: Callable[[str], bool]) -> bool:
|
||||
return path_exists("/dev/tpm0") or path_exists("/dev/tpmrm0")
|
||||
|
||||
|
||||
def _tpm_transport_configured(environ: Mapping[str, str]) -> bool:
|
||||
t = (environ.get("TSS2_TCTI") or environ.get("TPM2TOOLS_TCTI") or "").strip()
|
||||
if t:
|
||||
return True
|
||||
return bool((environ.get("TSS2_FAPICONF") or "").strip())
|
||||
|
||||
|
||||
def should_attempt_tpm(
    environ: Mapping[str, str],
    path_exists: Callable[[str], bool],
) -> bool:
    """Decide whether TPM initialization is worth attempting.

    True when either a TPM device node is visible on disk or a TPM
    transport is configured through environment variables.
    """
    if _tpm_device_visible(path_exists):
        return True
    return _tpm_transport_configured(environ)
|
||||
|
||||
|
||||
class SecurityProvider(ABC):
    """Abstract interface for the loader's crypto backends."""

    @property
    @abstractmethod
    def kind(self) -> str:
        """Short backend identifier (e.g. 'legacy' or 'tpm')."""

    @abstractmethod
    def encrypt_to(self, input_bytes: bytes, key: str) -> bytes:
        """Encrypt *input_bytes* under *key* and return the ciphertext."""

    @abstractmethod
    def decrypt_to(self, ciphertext_with_iv_bytes: bytes, key: str) -> bytes:
        """Decrypt IV-prefixed ciphertext under *key* and return the plaintext."""

    @abstractmethod
    def get_hw_hash(self, hardware: str) -> str:
        """Derive a hash string from the given hardware identifier."""

    @abstractmethod
    def get_api_encryption_key(
        self, creds_email: str, creds_password: str, hardware_hash: str
    ) -> str:
        """Derive the API encryption key from credentials plus hardware hash."""

    @abstractmethod
    def get_resource_encryption_key(self) -> str:
        """Return the key used for resource encryption/decryption."""

    @abstractmethod
    def calc_hash(self, key: str) -> str:
        """Return a hash of *key* (format defined by the backend)."""

    def seal(self, object_path: str, data: bytes) -> None:
        """Persist *data* under *object_path*; optional, backend-specific."""
        raise NotImplementedError

    def unseal(self, object_path: str) -> bytes:
        """Retrieve bytes previously sealed at *object_path*; optional."""
        raise NotImplementedError
|
||||
|
||||
|
||||
def create_security_provider(
    *,
    environ: Optional[Mapping[str, str]] = None,
    path_exists: Optional[Callable[[str], bool]] = None,
) -> SecurityProvider:
    """Select and construct the security backend.

    Selection order: an explicit ``SECURITY_PROVIDER=legacy`` override
    wins; otherwise legacy is used when no TPM device or transport is
    detectable; otherwise the TPM provider is attempted, falling back to
    legacy if its construction raises.

    Args:
        environ: Environment mapping (defaults to os.environ); injectable
            for tests.
        path_exists: Path-existence predicate (defaults to os.path.exists);
            injectable for tests.
    """
    # Imported lazily to avoid a circular import at module load time.
    from legacy_security_provider import LegacySecurityProvider

    checker = os.path.exists if path_exists is None else path_exists
    env = os.environ if environ is None else environ

    if _security_provider_override(env) == "legacy":
        logger.info("security provider: legacy (SECURITY_PROVIDER override)")
        return LegacySecurityProvider()
    if not should_attempt_tpm(env, checker):
        logger.info("security provider: legacy (no TPM device or TCTI)")
        return LegacySecurityProvider()
    try:
        from tpm_security_provider import TpmSecurityProvider

        provider = TpmSecurityProvider()
        logger.info("security provider: tpm")
        return provider
    except Exception as e:
        logger.warning("TPM security provider failed ({}), using legacy", e)
        return LegacySecurityProvider()
|
||||
@@ -0,0 +1,57 @@
|
||||
from security import (
|
||||
security_calc_hash,
|
||||
security_decrypt_to,
|
||||
security_encrypt_to,
|
||||
security_get_api_encryption_key,
|
||||
security_get_hw_hash,
|
||||
security_get_resource_encryption_key,
|
||||
)
|
||||
from credentials import Credentials
|
||||
from security_provider import SecurityProvider
|
||||
|
||||
|
||||
class TpmSecurityProvider(SecurityProvider):
    """SecurityProvider that requires a working tpm2-pytss FAPI stack.

    NOTE(review): the crypto operations currently delegate to the same
    legacy `security` routines as LegacySecurityProvider; the TPM itself
    is used for the construction-time liveness check and for seal/unseal.
    """

    def __init__(self):
        """Initialize the FAPI context and verify the TPM is usable.

        Raises:
            RuntimeError: when tpm2-pytss (or its FAPI binding) is not
                importable on this system.
        """
        try:
            from tpm2_pytss import FAPI
            from tpm2_pytss import TSS2_Exception
        except (ImportError, NotImplementedError) as e:
            raise RuntimeError("tpm2-pytss FAPI is not available") from e
        self._TSS2_Exception = TSS2_Exception
        self._fapi = FAPI()
        try:
            # is_provisioned_ok=True: an already-provisioned TPM is fine.
            self._fapi.provision(is_provisioned_ok=True)
        except TSS2_Exception:
            # Provisioning errors are deliberately swallowed here; the
            # get_random probe below decides whether the TPM is usable.
            pass
        # Liveness probe: raises if the TPM/FAPI stack is not functional.
        self._fapi.get_random(1)

    @property
    def kind(self) -> str:
        """Backend discriminator used in provider-selection logging."""
        return "tpm"

    def encrypt_to(self, input_bytes: bytes, key: str) -> bytes:
        """Delegate encryption to the legacy security_encrypt_to routine."""
        return security_encrypt_to(input_bytes, key)

    def decrypt_to(self, ciphertext_with_iv_bytes: bytes, key: str) -> bytes:
        """Delegate decryption to the legacy security_decrypt_to routine."""
        return security_decrypt_to(ciphertext_with_iv_bytes, key)

    def get_hw_hash(self, hardware: str) -> str:
        """Delegate hardware hashing to security_get_hw_hash."""
        return security_get_hw_hash(hardware)

    def get_api_encryption_key(
        self, creds_email: str, creds_password: str, hardware_hash: str
    ) -> str:
        """Build a Credentials object and derive the API key via the legacy routine."""
        creds = Credentials(creds_email, creds_password)
        return security_get_api_encryption_key(creds, hardware_hash)

    def get_resource_encryption_key(self) -> str:
        """Delegate to security_get_resource_encryption_key."""
        return security_get_resource_encryption_key()

    def calc_hash(self, key: str) -> str:
        """Delegate key hashing to security_calc_hash."""
        return security_calc_hash(key)

    def seal(self, object_path: str, data: bytes) -> None:
        """Seal *data* into the TPM's FAPI keystore at *object_path*."""
        self._fapi.create_seal(object_path, data=data, exists_ok=True)

    def unseal(self, object_path: str) -> bytes:
        """Unseal and return the bytes stored at *object_path*."""
        return self._fapi.unseal(object_path)
|
||||
Reference in New Issue
Block a user