"""Tests for ``download_manager``.

Covers exponential backoff, download-state persistence, HTTP ``Range``
resumption after a dropped connection, re-download on SHA-256 mismatch, and
AES-CBC decryption of the downloaded payload.  Network access is replaced by
the in-memory ``_MockSession`` / ``_StreamResponse`` fakes below.
"""

import hashlib
import os
import shutil
import tempfile
import unittest

import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

from download_manager import (
    DownloadState,
    ResumableDownloadManager,
    backoff_seconds,
    decrypt_cbc_file,
    load_download_state,
    save_download_state,
)


def _encrypt_cbc(plaintext: bytes, aes_key: bytes) -> bytes:
    """Encrypt *plaintext* with AES-CBC and PKCS7 padding.

    Returns the random 16-byte IV followed by the ciphertext, i.e.
    ``iv + ciphertext``.
    """
    iv = os.urandom(16)
    padder = padding.PKCS7(128).padder()
    padded = padder.update(plaintext) + padder.finalize()
    cipher = Cipher(algorithms.AES(aes_key), modes.CBC(iv), backend=default_backend())
    encryptor = cipher.encryptor()
    ciphertext = encryptor.update(padded) + encryptor.finalize()
    return iv + ciphertext


def _make_temp_dir(test_case: unittest.TestCase) -> str:
    """Create a temporary directory that is removed when *test_case* finishes."""
    path = tempfile.mkdtemp()
    test_case.addCleanup(shutil.rmtree, path, ignore_errors=True)
    return path


def _range_start(headers: dict) -> int:
    """Return the start offset of a ``Range: bytes=N-`` header, or 0 if absent."""
    range_header = headers.get("Range")
    if not range_header:
        return 0
    return int(range_header.split("=", 1)[1].split("-", 1)[0])


class _StreamResponse:
    """Minimal stand-in for a streaming ``requests.Response``.

    ``chunk_source`` is a zero-argument callable that returns an iterable of
    byte chunks.  It is invoked lazily from ``iter_content``, so a generator
    can raise part-way through to simulate a dropped connection.
    """

    def __init__(self, status_code: int, chunk_source):
        self.status_code = status_code
        self.headers = {}
        self._chunk_source = chunk_source

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def raise_for_status(self):
        if self.status_code >= 400:
            raise requests.HTTPError(response=self)

    def iter_content(self, chunk_size=1024 * 1024):
        # chunk_size is accepted for API compatibility but ignored: chunks are
        # yielded exactly as the test's chunk_source produces them.
        yield from self._chunk_source()


class _MockSession:
    """Fake session whose ``get`` delegates to ``handler(url, headers=...)``."""

    def __init__(self, handler):
        self._handler = handler

    def get(self, url, headers=None, stream=True, timeout=None):
        # stream/timeout are accepted for signature compatibility and ignored.
        return self._handler(url, headers=headers or {})


class TestBackoff(unittest.TestCase):
    def test_ac5_exponential_backoff_sequence(self):
        # Arrange
        expected = (60, 300, 900, 3600, 14400)

        # Act
        values = [backoff_seconds(i) for i in range(6)]

        # Assert: the sixth attempt is capped at the last step of the schedule.
        self.assertEqual(values, [*expected, expected[-1]])

    def test_ac5_sleep_invoked_with_backoff_on_repeated_failures(self):
        # Arrange
        sleeps = []

        def fake_sleep(seconds):
            sleeps.append(seconds)

        key = hashlib.sha256(b"k").digest()
        ciphertext = _encrypt_cbc(b"x" * 100, key)
        sha = hashlib.sha256(ciphertext).hexdigest()
        failures_left = [3]

        def handler(url, headers):
            start = _range_start(headers)
            if failures_left[0] > 0:
                failures_left[0] -= 1

                def chunks():
                    # Deliver a few bytes, then simulate a dropped connection.
                    yield ciphertext[start : start + 8]
                    raise requests.ConnectionError("drop")

                return _StreamResponse(206 if start else 200, chunks)

            def chunks_final():
                yield ciphertext[start:]

            return _StreamResponse(206 if start else 200, chunks_final)

        tmp = _make_temp_dir(self)
        out = os.path.join(tmp, "out.bin")
        mgr = ResumableDownloadManager(
            state_directory=tmp,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=fake_sleep,
            chunk_size=16,
        )

        # Act
        mgr.fetch_decrypt_verify("job-backoff", "http://x", sha, len(ciphertext), key, out)

        # Assert: one backoff sleep per dropped connection, in schedule order.
        self.assertEqual(sleeps, [60, 300, 900])


class TestStatePersistence(unittest.TestCase):
    def test_ac4_state_file_reload_restores_offset(self):
        # Arrange
        tmp = _make_temp_dir(self)
        tf = os.path.join(tmp, "partial.cipher.tmp")
        with open(tf, "wb") as f:
            f.write(b"a" * 400)
        state = DownloadState(
            url="http://example/a",
            expected_sha256="ab" * 32,
            expected_size=1000,
            bytes_downloaded=400,
            temp_file_path=tf,
            phase="paused",
        )
        path = os.path.join(tmp, "state.json")
        save_download_state(path, state)

        # Act
        loaded = load_download_state(path)

        # Assert
        self.assertEqual(loaded.bytes_downloaded, 400)
        self.assertEqual(loaded.expected_size, 1000)
        self.assertEqual(loaded.temp_file_path, tf)

    def test_ac4_manager_resumes_from_persisted_progress(self):
        # Arrange
        tmp = _make_temp_dir(self)
        key = hashlib.sha256(b"k2").digest()
        plaintext = b"full-plaintext-payload-xyz"
        ciphertext = _encrypt_cbc(plaintext, key)
        sha = hashlib.sha256(ciphertext).hexdigest()
        partial = int(len(ciphertext) * 0.4)
        safe_job = "job_resume"

        # Pre-seed a partial temp file and matching state file, as if a
        # previous run had been interrupted after `partial` bytes.
        tf = os.path.join(tmp, f"{safe_job}.cipher.tmp")
        with open(tf, "wb") as f:
            f.write(ciphertext[:partial])
        state = DownloadState(
            url="http://cdn/blob",
            expected_sha256=sha,
            expected_size=len(ciphertext),
            bytes_downloaded=partial,
            temp_file_path=tf,
            phase="paused",
        )
        save_download_state(os.path.join(tmp, f"{safe_job}.json"), state)

        seen_ranges = []

        def handler(url, headers):
            seen_ranges.append(headers.get("Range"))
            rest = ciphertext[partial:]

            def chunks():
                yield rest

            return _StreamResponse(206, chunks)

        out = os.path.join(tmp, "plain.out")
        mgr = ResumableDownloadManager(
            state_directory=tmp,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=lambda s: None,
        )

        # Act
        mgr.fetch_decrypt_verify(safe_job, "http://cdn/blob", sha, len(ciphertext), key, out)

        # Assert
        self.assertEqual(seen_ranges[0], f"bytes={partial}-")
        with open(out, "rb") as f:
            self.assertEqual(f.read(), plaintext)


class TestResumeAfterDrop(unittest.TestCase):
    def test_ac1_resume_uses_range_after_partial_transfer(self):
        # Arrange
        tmp = _make_temp_dir(self)
        key = hashlib.sha256(b"k3").digest()
        body = b"q" * 100
        ciphertext = _encrypt_cbc(body, key)
        sha = hashlib.sha256(ciphertext).hexdigest()
        cut = 60
        headers_log = []

        def handler(url, headers):
            headers_log.append(dict(headers))
            if len(headers_log) == 1:
                # First request: deliver `cut` bytes, then drop the connection.
                def chunks():
                    yield ciphertext[:cut]
                    raise requests.ConnectionError("starlink drop")

                return _StreamResponse(200, chunks)

            def chunks2():
                yield ciphertext[cut:]

            return _StreamResponse(206, chunks2)

        out = os.path.join(tmp, "p.out")
        mgr = ResumableDownloadManager(
            state_directory=tmp,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=lambda s: None,
            chunk_size=32,
        )

        # Act
        mgr.fetch_decrypt_verify("ac1", "http://s3/o", sha, len(ciphertext), key, out)

        # Assert
        self.assertNotIn("Range", headers_log[0])
        self.assertEqual(headers_log[1].get("Range"), f"bytes={cut}-")
        with open(out, "rb") as f:
            self.assertEqual(f.read(), body)


class TestShaMismatchRedownload(unittest.TestCase):
    def test_ac2_corrupt_hash_deletes_file_and_redownloads(self):
        # Arrange
        tmp = _make_temp_dir(self)
        key = hashlib.sha256(b"k4").digest()
        # Equal-length plaintexts give equal-length ciphertexts, so the bad
        # download passes any size check and can only be caught by the hash.
        good_plain = b"same-len-pt-a!"
        bad_plain = b"same-len-pt-b!"
        good_ct = _encrypt_cbc(good_plain, key)
        bad_ct = _encrypt_cbc(bad_plain, key)
        sha_good = hashlib.sha256(good_ct).hexdigest()
        calls = {"n": 0}

        def handler(url, headers):
            calls["n"] += 1
            data = bad_ct if calls["n"] == 1 else good_ct

            def chunks():
                yield data

            return _StreamResponse(200, chunks)

        out = os.path.join(tmp, "good.out")
        mgr = ResumableDownloadManager(
            state_directory=tmp,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=lambda s: None,
        )

        # Act
        mgr.fetch_decrypt_verify("ac2", "http://x", sha_good, len(good_ct), key, out)

        # Assert
        self.assertEqual(calls["n"], 2)
        with open(out, "rb") as f:
            self.assertEqual(f.read(), good_plain)


class TestDecryptRoundTrip(unittest.TestCase):
    def test_ac3_decrypt_matches_original_plaintext(self):
        # Arrange
        tmp = _make_temp_dir(self)
        key = hashlib.sha256(b"artifact-key").digest()
        original = b"payload-for-roundtrip-check"
        ciphertext = _encrypt_cbc(original, key)
        sha = hashlib.sha256(ciphertext).hexdigest()

        def handler(url, headers):
            return _StreamResponse(200, lambda: [ciphertext])

        out = os.path.join(tmp, "decrypted.bin")
        mgr = ResumableDownloadManager(
            state_directory=tmp,
            session_factory=lambda: _MockSession(handler),
            sleep_fn=lambda s: None,
        )

        # Act
        mgr.fetch_decrypt_verify("ac3", "http://blob", sha, len(ciphertext), key, out)

        # Assert
        with open(out, "rb") as f:
            self.assertEqual(f.read(), original)

    def test_decrypt_cbc_file_matches_encrypt_helper(self):
        # Arrange
        tmp = _make_temp_dir(self)
        key = hashlib.sha256(b"x").digest()
        plain = b"abc" * 500
        ct = _encrypt_cbc(plain, key)
        enc_path = os.path.join(tmp, "e.bin")
        with open(enc_path, "wb") as f:
            f.write(ct)
        out_path = os.path.join(tmp, "d.bin")

        # Act
        decrypt_cbc_file(enc_path, key, out_path)

        # Assert
        with open(out_path, "rb") as f:
            self.assertEqual(f.read(), plain)


if __name__ == "__main__":
    unittest.main()