mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 22:16:31 +00:00
27f4aceb52
- Updated the `Inference` class to replace the `get_onnx_engine_bytes` method with `download_model`, allowing for dynamic model loading based on a specified filename. - Modified the `convert_and_upload_model` method to accept `source_bytes` instead of `onnx_engine_bytes`, enhancing flexibility in model conversion. - Introduced a new property `engine_name` to the `Inference` class for better access to engine details. - Adjusted the `AIRecognitionConfig` structure to include a new method pointer `from_dict`, improving configuration handling. - Updated various test cases to reflect changes in model paths and timeout settings, ensuring consistency and reliability in testing.
216 lines
6.6 KiB
Python
216 lines
6.6 KiB
Python
import io
|
|
import json
|
|
import os
|
|
import struct
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
# Value passed as ``timeout=`` on the heavier /detect requests in this module
# (presumably seconds -- confirm against the http_client fixture).
_DETECT_SLOW_TIMEOUT = 120

# Directory that holds test media such as classes.json; the MEDIA_DIR
# environment variable overrides the default.
_MEDIA = os.getenv("MEDIA_DIR", "/media")

# Absolute tolerance added to floating-point comparisons in assertions.
_EPS = 1e-6

# Modulus applied to a detection's classNum to recover the base class id.
_WEATHER_CLASS_STRIDE = 20
|
|
|
|
|
|
# JPEG start-of-frame markers: every SOFn carries the frame dimensions.
# 0xC4 (DHT), 0xC8 (reserved JPG) and 0xCC (DAC) share the range but are
# not frame headers, so they are deliberately left out.
_JPEG_SOF_MARKERS = frozenset(
    {
        0xC0, 0xC1, 0xC2, 0xC3,
        0xC5, 0xC6, 0xC7,
        0xC9, 0xCA, 0xCB,
        0xCD, 0xCE, 0xCF,
    }
)

# Marker codes that are NOT followed by a two-byte length field:
# 0x00 (byte stuffing after 0xFF), 0x01 (TEM), 0xD0-0xD7 (RSTn),
# 0xD8 (SOI) and 0xD9 (EOI).
_JPEG_STANDALONE_MARKERS = frozenset({0x00, 0x01, *range(0xD0, 0xDA)})


def _image_width_height(data):
    """Return ``(width, height)`` in pixels read from PNG or JPEG bytes.

    Args:
        data: Raw image file content as a bytes-like object.

    Returns:
        A ``(width, height)`` tuple, or ``None`` when the data is neither
        PNG nor JPEG, is truncated, or contains no start-of-frame segment.
    """
    size = len(data)

    # PNG: 8-byte signature, 4-byte chunk length, 4-byte chunk type, then
    # width and height as big-endian 32-bit integers.
    if size >= 24 and data[:8] == b"\x89PNG\r\n\x1a\n":
        w, h = struct.unpack(">II", data[16:24])
        return w, h

    # Anything that does not start with the JPEG SOI marker is unsupported.
    if size < 2 or data[:2] != b"\xff\xd8":
        return None

    i = 2
    while i + 1 < size:
        if data[i] != 0xFF:
            i += 1
            continue
        # A marker may be preceded by any number of 0xFF fill bytes.
        while i < size and data[i] == 0xFF:
            i += 1
        if i >= size:
            break
        marker = data[i]
        i += 1
        if marker in _JPEG_STANDALONE_MARKERS:
            # No length field follows; resume scanning right after the marker.
            continue
        if i + 2 > size:
            break
        seg_len = (data[i] << 8) | data[i + 1]
        if marker in _JPEG_SOF_MARKERS:
            # Layout after the marker: length(2) precision(1) height(2) width(2).
            if i + 7 > size:
                return None
            h, w = struct.unpack_from(">HH", data, i + 3)
            return w, h
        # The length includes its own two bytes; always advance at least
        # past them so a bogus length of 0 or 1 cannot stall the scan.
        i += max(2, seg_len)
    return None
|
|
|
|
|
|
def _axis_overlap(center_a, size_a, center_b, size_b):
    """Return the length of the intersection of two centred 1-D intervals."""
    # Half the summed extents minus the centre distance is only correct for
    # partially overlapping intervals; when one interval lies inside the
    # other the intersection is the smaller extent, hence the cap.
    reach = 0.5 * (size_a + size_b) - abs(center_a - center_b)
    return max(0.0, min(reach, size_a, size_b))


def _overlap_to_min_area_ratio(a, b):
    """Return the intersection area of two boxes divided by the smaller box area.

    Args:
        a: Mapping with ``centerX``, ``centerY``, ``width`` and ``height``.
        b: Mapping with the same keys as ``a``.

    Returns:
        The ratio as a float; ``0.0`` when the smaller box has no area.
        For boxes with positive size the result lies in ``[0, 1]`` up to
        floating-point rounding.
    """
    ox = _axis_overlap(a["centerX"], a["width"], b["centerX"], b["width"])
    oy = _axis_overlap(a["centerY"], a["height"], b["centerY"], b["height"])
    overlap_area = ox * oy
    smaller_area = min(a["width"] * a["height"], b["width"] * b["height"])
    if smaller_area <= 0:
        return 0.0
    return overlap_area / smaller_area
|
|
|
|
|
|
def _load_classes_media():
    """Load class metadata from ``classes.json`` in the media directory.

    Calls ``pytest.skip`` when the file does not exist.

    Returns:
        A ``(by_id, names)`` tuple: ``by_id`` maps each row's ``Id`` to its
        ``MaxSizeM`` converted to float, and ``names`` lists each row's
        ``Name`` in file order.
    """
    classes_path = Path(_MEDIA) / "classes.json"
    if not classes_path.is_file():
        pytest.skip(f"missing {classes_path}")
    # JSON files are UTF-8; decode explicitly so non-ASCII class names do not
    # depend on the locale's default encoding.
    rows = json.loads(classes_path.read_text(encoding="utf-8"))
    by_id = {row["Id"]: float(row["MaxSizeM"]) for row in rows}
    names = [row["Name"] for row in rows]
    return by_id, names
|
|
|
|
|
|
def _weather_label_ok(label, base_names):
    """Tell whether ``label`` is a base class name or one of its variants.

    A label is accepted when it equals an entry of ``base_names`` exactly,
    or that entry followed by ``(Wint)`` or ``(Night)``.
    """
    suffixes = ("", "(Wint)", "(Night)")
    return any(
        label == name + suffix
        for name in base_names
        for suffix in suffixes
    )
|
|
|
|
|
|
@pytest.mark.slow
def test_ft_p_03_detection_response_structure_ac1(http_client, image_small, warm_engine):
    """POST a small image to /detect and validate the shape of every detection.

    Expects HTTP 200 and a JSON list whose items carry numeric ``centerX``,
    ``centerY``, ``width`` and ``height`` within [0, 1], an int ``classNum``,
    a str ``label`` and a numeric ``confidence`` within [0, 1].
    """
    # NOTE(review): warm_engine is requested but never referenced here;
    # presumably it ensures the engine is loaded -- confirm in conftest.
    response = http_client.post(
        "/detect",
        files={"file": ("img.jpg", image_small, "image/jpeg")},
    )
    assert response.status_code == 200

    detections = response.json()
    assert isinstance(detections, list)

    numeric = (int, float)
    geometry_keys = ("centerX", "centerY", "width", "height")
    for det in detections:
        # Type checks for all geometry fields come before the range checks.
        for key in geometry_keys:
            assert isinstance(det[key], numeric)
        for key in geometry_keys:
            assert 0.0 <= float(det[key]) <= 1.0
        assert isinstance(det["classNum"], int)
        assert isinstance(det["label"], str)
        assert isinstance(det["confidence"], numeric)
        assert 0.0 <= float(det["confidence"]) <= 1.0
|
|
|
|
|
|
@pytest.mark.slow
def test_ft_p_05_confidence_filtering_ac2(http_client, image_small, warm_engine):
    """Check that ``probability_threshold`` filters detections by confidence.

    With a threshold of 0.8 every returned detection must have confidence of
    at least 0.8 (within ``_EPS``); with 0.1 the response must contain at
    least as many detections as the 0.8 run.
    """

    def detect_with_threshold(threshold):
        # One /detect round trip with the given probability threshold.
        config = json.dumps({"probability_threshold": threshold})
        response = http_client.post(
            "/detect",
            files={"file": ("img.jpg", image_small, "image/jpeg")},
            data={"config": config},
        )
        assert response.status_code == 200
        detections = response.json()
        assert isinstance(detections, list)
        return detections

    hi = detect_with_threshold(0.8)
    for det in hi:
        assert float(det["confidence"]) + _EPS >= 0.8

    lo = detect_with_threshold(0.1)
    assert len(lo) >= len(hi)
|
|
|
|
|
|
@pytest.mark.slow
def test_ft_p_06_overlap_deduplication_ac3(http_client, image_dense, warm_engine):
    """Check that ``tracking_intersection_threshold`` limits same-label overlap.

    With a threshold of 0.6, no two detections sharing a label may overlap by
    more than 0.6 of the smaller box (within ``_EPS``). With 0.01 the response
    must not contain more detections than the 0.6 run.
    """

    def detect_with_overlap_threshold(threshold):
        # One /detect round trip with the given intersection threshold.
        config = json.dumps({"tracking_intersection_threshold": threshold})
        response = http_client.post(
            "/detect",
            files={"file": ("img.jpg", image_dense, "image/jpeg")},
            data={"config": config},
            timeout=_DETECT_SLOW_TIMEOUT,
        )
        assert response.status_code == 200
        detections = response.json()
        assert isinstance(detections, list)
        return detections

    dets = detect_with_overlap_threshold(0.6)

    # Bucket detections by label; only same-label pairs are compared.
    by_label = {}
    for det in dets:
        by_label.setdefault(det["label"], []).append(det)

    for label, group in by_label.items():
        for idx, first in enumerate(group):
            for second in group[idx + 1:]:
                ratio = _overlap_to_min_area_ratio(first, second)
                assert ratio <= 0.6 + _EPS, (label, ratio)

    strict = detect_with_overlap_threshold(0.01)
    assert len(strict) <= len(dets)
|
|
|
|
|
|
@pytest.mark.slow
def test_ft_p_07_physical_size_filtering_ac4(http_client, image_small, warm_engine):
    """Check that detections respect each class's maximum physical size.

    Sends camera parameters (altitude, focal length, sensor width) with the
    image, then verifies that every detection's width, converted through the
    ground sample distance computed here, does not exceed the ``MaxSizeM``
    of its base class (within ``_EPS``).
    """
    max_size_by_id, _ = _load_classes_media()

    dimensions = _image_width_height(image_small)
    assert dimensions is not None
    image_width_px, _ = dimensions

    camera = {
        "altitude": 400.0,
        "focal_length": 24.0,
        "sensor_width": 23.5,
    }
    # Ground sample distance: size covered by one pixel (presumably metres
    # per pixel, with millimetre optics and metre altitude -- confirm).
    gsd = (camera["sensor_width"] * camera["altitude"]) / (
        camera["focal_length"] * image_width_px
    )

    response = http_client.post(
        "/detect",
        files={"file": ("img.jpg", image_small, "image/jpeg")},
        data={"config": json.dumps(camera)},
        timeout=_DETECT_SLOW_TIMEOUT,
    )
    assert response.status_code == 200

    detections = response.json()
    assert isinstance(detections, list)

    for det in detections:
        # classNum modulo the stride yields the id used in classes.json.
        base_id = det["classNum"] % _WEATHER_CLASS_STRIDE
        assert base_id in max_size_by_id
        physical_width = float(det["width"]) * image_width_px * gsd
        assert physical_width <= max_size_by_id[base_id] + _EPS
|
|
|
|
|
|
@pytest.mark.slow
def test_ft_p_13_weather_mode_class_variants_ac5(
    http_client, image_different_types, warm_engine
):
    """Check that every detection label is a known class name or variant.

    Each label must be a non-empty string equal to a name from classes.json,
    optionally followed by ``(Wint)`` or ``(Night)``.
    """
    _, known_names = _load_classes_media()

    response = http_client.post(
        "/detect",
        files={"file": ("img.jpg", image_different_types, "image/jpeg")},
        timeout=_DETECT_SLOW_TIMEOUT,
    )
    assert response.status_code == 200

    detections = response.json()
    assert isinstance(detections, list)

    for det in detections:
        label = det["label"]
        assert isinstance(label, str)
        assert len(label) > 0
        assert _weather_label_ok(label, known_names)
|