mirror of
https://github.com/azaion/detections.git
synced 2026-04-22 09:16:33 +00:00
[AZ-180] Add Jetson Orin Nano support with INT8 TensorRT engine
- Dockerfile.jetson: JetPack 6.x L4T base image (aarch64), TensorRT and PyCUDA from apt
- requirements-jetson.txt: derived from requirements.txt, no pip tensorrt/pycuda
- docker-compose.jetson.yml: runtime: nvidia for NVIDIA Container Runtime
- tensorrt_engine.pyx: convert_from_source accepts optional calib_cache_path; INT8 used when cache present, FP16 fallback; get_engine_filename encodes precision suffix to avoid engine cache confusion
- inference.pyx: init_ai tries INT8 engine then FP16 on lookup; downloads calibration cache before conversion thread; passes cache path through to convert_from_source
- constants_inf: add INT8_CALIB_CACHE_FILE constant
- Unit tests for AC-3 (INT8 flag set when cache provided) and AC-4 (FP16 when no cache)

Made-with: Cursor
@@ -0,0 +1,23 @@
FROM nvcr.io/nvidia/l4t-base:r36.3.0

RUN apt-get update && apt-get install -y \
    python3 python3-pip python3-dev gcc \
    libgl1 libglib2.0-0 \
    python3-libnvinfer python3-libnvinfer-dev \
    python3-pycuda \
    && rm -rf /var/lib/apt/lists/*

RUN python3 -c "import tensorrt" || \
    (echo "TensorRT Python bindings not found; check PYTHONPATH for JetPack installation" && exit 1)

WORKDIR /app
COPY requirements-jetson.txt ./
RUN pip3 install --no-cache-dir -r requirements-jetson.txt
COPY . .
RUN python3 setup.py build_ext --inplace
ENV PYTHONPATH=/app/src
RUN adduser --disabled-password --no-create-home --gecos "" appuser \
    && chown -R appuser /app
USER appuser
EXPOSE 8080
CMD ["python3", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
@@ -34,3 +34,4 @@
| Task | Name | Complexity | Dependencies | Epic | Status |
|------|------|-----------|-------------|------|--------|
| AZ-177 | remove_redundant_video_prewrite | 2 | AZ-173 | AZ-172 | todo |
| AZ-180 | jetson_orin_nano_support | 5 | None | pending | todo |

@@ -0,0 +1,121 @@
# Jetson Orin Nano Support (TensorRT + INT8)

**Task**: AZ-180_jetson_orin_nano_support
**Name**: Jetson Orin Nano Support
**Description**: Run the detection service on NVIDIA Jetson Orin Nano with a JetPack 6.x container image, INT8 engine conversion using a pre-generated calibration cache, and docker-compose configuration.
**Complexity**: 5 points
**Dependencies**: None
**Component**: Deployment + Inference Engine
**Tracker**: AZ-180
**Epic**: pending

## Problem

The detection service cannot run on NVIDIA Jetson Orin Nano for two reasons:

1. The existing `Dockerfile.gpu` and `requirements-gpu.txt` are x86-specific: TensorRT and PyCUDA are pip-installed, but on Jetson they ship with JetPack and cannot be installed via pip on aarch64.
2. The ONNX→TensorRT conversion in `convert_from_source()` only supports FP16/FP32. On Jetson Orin Nano (8 GB shared RAM, 40 TOPS INT8 vs 20 TFLOPS FP16), INT8 gives roughly 1.5–2× throughput over FP16, but it requires a calibration cache to quantize per-layer activations.
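
The precision decision in point 2 reduces to a small rule, sketched here in plain Python (the helper name is illustrative; the real implementation sets TensorRT builder flags rather than returning a string):

```python
import os

def choose_precision(calib_cache_path, platform_has_fast_fp16):
    """Sketch of the intended selection: INT8 only when a calibration
    cache file actually exists on disk; otherwise FP16 when the
    platform supports it, else FP32."""
    if calib_cache_path is not None and os.path.isfile(calib_cache_path):
        return "int8"
    if platform_has_fast_fp16:
        return "fp16"
    return "fp32"
```

Checking for the file on disk (not just a non-None path) keeps a failed cache download from silently producing a broken INT8 build.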

## Outcome

- A `Dockerfile.jetson` that builds and runs on Jetson Orin Nano (aarch64, JetPack 6.x)
- A `requirements-jetson.txt` that installs Python dependencies without pip-installing tensorrt or pycuda
- A `docker-compose.jetson.yml` with NVIDIA Container Runtime configuration
- `convert_from_source()` in `tensorrt_engine.pyx` extended to accept an optional INT8 calibration cache path; if the cache is present, INT8 is used, otherwise FP16 fallback
- `init_ai()` in `inference.pyx` extended to try downloading the calibration cache from the Loader service before starting the conversion thread
- The service reports `engineType: tensorrt` on `GET /health` on Jetson

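The health contract above can be captured in a tiny readiness check (a sketch; the exact JSON schema of `/health` is an assumption based on the field names in AC-2):

```python
import json

def is_trt_ready(health_body: str) -> bool:
    # Field names follow the spec's AC-2 wording; the surrounding
    # shape of the /health payload is assumed, not confirmed.
    doc = json.loads(health_body)
    return (doc.get("aiAvailability") == "Enabled"
            and doc.get("engineType") == "tensorrt")
```

On Jetson this would be polled after engine warm-up, since ONNX→TensorRT conversion runs in a background thread and the status passes through DOWNLOADING/CONVERTING first.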
## Scope

### Included

- `Dockerfile.jetson` using a JetPack 6.x L4T base image with pre-installed TensorRT and PyCUDA
- `requirements-jetson.txt` derived from `requirements.txt`, excluding tensorrt and pycuda
- `docker-compose.jetson.yml` with `runtime: nvidia`
- `tensorrt_engine.pyx`: extend `convert_from_source(bytes onnx_model, str calib_cache_path=None)`; set the `INT8` flag and load the cache when a path is provided, fall back to FP16 when not
- `inference.pyx`: extend `init_ai()` to attempt download of `azaion.int8_calib.cache` from Loader before spawning the conversion thread; pass the local path to `convert_from_source()`
- Update `_docs/04_deploy/containerization.md` with the Jetson image variant

### Excluded

- DLA (Deep Learning Accelerator) targeting
- Calibration cache generation tooling (the cache is produced offline and uploaded to Loader manually)
- CI/CD pipeline for Jetson (no build agents available)
- ONNX fallback in the Jetson image

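Since `requirements-jetson.txt` is meant to exclude JetPack-provided packages, a small guard (a hypothetical helper, not part of the task deliverables) can flag any pip pins that would shadow them:

```python
def forbidden_pins(requirements_text, forbidden=("tensorrt", "pycuda")):
    """Return requirement lines that would pip-install packages JetPack
    already provides. Intended as a CI-style sanity check."""
    hits = []
    for line in requirements_text.splitlines():
        spec = line.split("#", 1)[0].strip().lower()
        # Skip blanks, includes (-r ...) and pip options (--...).
        if not spec or spec.startswith(("-r", "--")):
            continue
        name = spec.split("==")[0].split(">=")[0].split("[")[0].strip()
        if name in forbidden:
            hits.append(line)
    return hits
```

An empty result means the Jetson image will rely solely on the apt-installed `python3-libnvinfer` and `python3-pycuda` bindings.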
## Acceptance Criteria

**AC-1: Jetson image builds**
Given a machine with Docker and aarch64 buildx (or a Jetson device)
When `docker build -f Dockerfile.jetson .` is executed
Then the image builds without error and Cython extensions compile for aarch64

**AC-2: Service starts with TensorRT engine on Jetson**
Given the container is running on a Jetson Orin Nano with JetPack 6.x
When `GET /health` is called after engine initialization
Then `aiAvailability` is `Enabled` and `engineType` is `tensorrt`

**AC-3: INT8 conversion used when calibration cache is available**
Given a valid `azaion.int8_calib.cache` is accessible on the Loader service
When the service initializes and no cached engine exists
Then ONNX→TensorRT conversion uses INT8 precision

**AC-4: FP16 fallback when calibration cache is absent**
Given no `azaion.int8_calib.cache` is available on the Loader service
When the service initializes and no cached engine exists
Then ONNX→TensorRT conversion falls back to FP16 (existing behavior)

**AC-5: Detection endpoint functions on Jetson**
Given the container is running and a valid model is accessible via the Loader service
When `POST /detect/image` is called with a test image
Then detections are returned with the same structure as on x86

**AC-6: Compose brings up the service on Jetson**
Given a Jetson device with docker-compose and NVIDIA Container Runtime installed
When `docker compose -f docker-compose.jetson.yml up` is executed
Then the detections service is reachable on port 8080

## Non-Functional Requirements

**Compatibility**
- JetPack 6.x (CUDA 12.2, TensorRT 10.x)
- Jetson Orin Nano (aarch64, SM 8.7)

**Reliability**
- The engine filename auto-encodes compute capability and SM count, so the Jetson engine file is distinct from any x86-cached engine
- INT8 conversion is best-effort: a calibration cache download failure is non-fatal, and the service falls back to FP16

## Unit Tests

| AC Ref | What to Test | Required Outcome |
|--------|-------------|-----------------|
| AC-3 | `convert_from_source()` with a valid calib cache path | INT8 flag set in builder config |
| AC-4 | `convert_from_source()` with `calib_cache_path=None` | FP16 flag set, no INT8 flag |

## Blackbox Tests

| AC Ref | Initial Data/Conditions | What to Test | Expected Behavior | NFR References |
|--------|------------------------|-------------|-------------------|----------------|
| AC-2 | Jetson device, container running, model accessible | GET /health after warm-up | engineType=tensorrt, aiAvailability=Enabled | — |
| AC-5 | Jetson device, test image | POST /detect/image | Valid DetectionDto list | — |
| AC-6 | Jetson device, docker-compose.jetson.yml | docker compose up | Service healthy on :8080 | — |

Note: AC-2, AC-5, and AC-6 require physical Jetson hardware and cannot run in standard CI.

## Constraints

- TensorRT and PyCUDA must NOT be pip-installed; they are provided by JetPack in the base image
- The base image must be a JetPack 6.x L4T image, not a generic CUDA image
- A calibration cache download failure must be non-fatal: log a warning and fall back to FP16
- INT8 and FP16 conversion produce different engine files (different filenames) so cached engines are not confused

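The last constraint can be sketched as a pure function mirroring the naming scheme described in this spec (device introspection omitted; the CC/SM values below are illustrative):

```python
def engine_filename(cc_major, cc_minor, sm_count, precision="fp16"):
    """Compute capability + SM count, plus a precision suffix only for
    INT8, so INT8 and FP16 engines never collide in the Loader cache."""
    base = f"azaion.cc_{cc_major}.{cc_minor}_sm_{sm_count}"
    return f"{base}.int8.engine" if precision == "int8" else f"{base}.engine"
```

Because the FP16 name carries no suffix, existing x86 and Jetson FP16 caches remain valid after this change.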
## Risks & Mitigation

**Risk 1: JetPack TensorRT Python binding path**
- *Risk*: The L4T base image may put TensorRT Python bindings in a non-standard location
- *Mitigation*: Verify `python3 -c "import tensorrt"` in the base image layer; adjust `PYTHONPATH` if needed

**Risk 2: PyCUDA availability in base image**
- *Risk*: Some L4T images do not include pycuda
- *Mitigation*: Fall back to `apt-get install python3-pycuda` or a source build with `CUDA_ROOT` set

**Risk 3: INT8 accuracy degradation**
- *Risk*: Without a representative calibration dataset, mAP may drop by more than 1 point
- *Mitigation*: The calibration cache is generated offline with controlled data; the FP16 fallback is always available

@@ -115,6 +115,30 @@ networks:

Already exists: `e2e/docker-compose.test.yml`. No changes needed; it supports both `cpu` and `gpu` profiles with mock services and a test runner.

### detections-jetson (Dockerfile.jetson)

| Aspect | Specification |
|--------|--------------|
| Base image | `nvcr.io/nvidia/l4t-base:r36.3.0` (JetPack 6.x, aarch64) |
| TensorRT | Pre-installed via JetPack, `python3-libnvinfer` apt package (NOT pip) |
| PyCUDA | Pre-installed via JetPack, `python3-pycuda` apt package (NOT pip) |
| Build stages | Single stage (Cython compile requires gcc) |
| Non-root user | `adduser --disabled-password --gecos '' appuser` + `USER appuser` |
| Exposed ports | 8080 |
| Entrypoint | `uvicorn main:app --host 0.0.0.0 --port 8080` |
| Runtime | Requires NVIDIA Container Runtime (`runtime: nvidia` in docker-compose) |

**Jetson-specific behaviour**:
- `requirements-jetson.txt` derives from `requirements.txt`; `tensorrt` and `pycuda` are excluded from pip and provided by JetPack
- The engine filename auto-encodes CC+SM (e.g. `azaion.cc_8.7_sm_16.engine` for Orin Nano), ensuring the Jetson engine is distinct from any x86-cached engine
- INT8 is used when `azaion.int8_calib.cache` is available on the Loader service; a precision suffix is appended to the engine filename (`*.int8.engine`); FP16 fallback when the cache is absent
- `docker-compose.jetson.yml` uses `runtime: nvidia` for the NVIDIA Container Runtime

**Compose usage on Jetson**:
```bash
docker compose -f docker-compose.jetson.yml up
```

## Image Tagging Strategy

| Context | Tag Format | Example |

@@ -1,17 +1,22 @@
# Autopilot State
## Current Step
flow: existing-code
-step: 14
-name: Deploy
+step: 9
+name: Implement
status: in_progress
-sub_step: 0
+sub_step: batch_01
retry_count: 0

## Cycle Notes
Previous full cycle (steps 1–14) completed. New cycle started for AZ-178.
AZ-178 cycle (steps 8–14) completed 2026-04-02.
step: 8 (New Task) — DONE (AZ-178 defined)
step: 9 (Implement) — DONE (implementation_report_streaming_video.md, 67/67 tests pass)
step: 10 (Run Tests) — DONE (67 passed, 0 failed)
step: 11 (Update Docs) — DONE (docs updated during step 9 implementation)
step: 12 (Security Audit) — DONE (Critical/High findings remediated 2026-04-01; 64/64 tests pass)
step: 13 (Performance Test) — SKIPPED (500ms latency validated by real-video integration test)
step: 14 (Deploy) — DONE (all artifacts + 5 scripts created)
+
+AZ-180 cycle started 2026-04-02.
+step: 8 (New Task) — DONE (AZ-180: Jetson Orin Nano support + INT8)
+step: 9 (Implement) — NOT STARTED
@@ -0,0 +1,20 @@
name: detections-jetson

services:
  detections:
    build:
      context: .
      dockerfile: Dockerfile.jetson
    ports:
      - "8080:8080"
    runtime: nvidia
    environment:
      LOADER_URL: ${LOADER_URL}
      ANNOTATIONS_URL: ${ANNOTATIONS_URL}
    env_file: .env
    volumes:
      - detections-logs:/app/Logs
    shm_size: 512m

volumes:
  detections-logs:
@@ -0,0 +1 @@
-r requirements.txt
@@ -1,6 +1,7 @@
cdef str CONFIG_FILE

cdef str AI_ONNX_MODEL_FILE
+cdef str INT8_CALIB_CACHE_FILE

cdef str CDN_CONFIG
cdef str MODELS_FOLDER
@@ -6,6 +6,7 @@ from loguru import logger

cdef str CONFIG_FILE = "config.yaml"
cdef str AI_ONNX_MODEL_FILE = "azaion.onnx"
+cdef str INT8_CALIB_CACHE_FILE = "azaion.int8_calib.cache"

cdef str CDN_CONFIG = "cdn.yaml"
cdef str MODELS_FOLDER = "models"
@@ -4,11 +4,31 @@ import pycuda.driver as cuda  # pyright: ignore[reportMissingImports]
import pycuda.autoinit  # pyright: ignore[reportMissingImports]
import pynvml
import numpy as np
+import os
cimport constants_inf

GPU_MEMORY_FRACTION = 0.8


+class _CacheCalibrator(trt.IInt8EntropyCalibrator2):
+    def __init__(self, path):
+        super().__init__()
+        self._path = path
+
+    def get_batch_size(self):
+        return 1
+
+    def get_batch(self, names):
+        return None
+
+    def read_calibration_cache(self):
+        with open(self._path, 'rb') as f:
+            return f.read()
+
+    def write_calibration_cache(self, cache):
+        pass
+
+
cdef class TensorRTEngine(InferenceEngine):
    def __init__(self, model_bytes: bytes, max_batch_size: int = 8, **kwargs):
        InferenceEngine.__init__(self, model_bytes, max_batch_size, engine_name="tensorrt")
@@ -80,13 +100,16 @@ cdef class TensorRTEngine(InferenceEngine):
        return 2 * 1024 * 1024 * 1024 if total_memory is None else total_memory

    @staticmethod
-    def get_engine_filename():
+    def get_engine_filename(str precision="fp16"):
        try:
            from engines import tensor_gpu_index
            device = cuda.Device(max(tensor_gpu_index, 0))
            sm_count = device.multiprocessor_count
            cc_major, cc_minor = device.compute_capability()
-            return f"azaion.cc_{cc_major}.{cc_minor}_sm_{sm_count}.engine"
+            base = f"azaion.cc_{cc_major}.{cc_minor}_sm_{sm_count}"
+            if precision == "int8":
+                return f"{base}.int8.engine"
+            return f"{base}.engine"
        except Exception:
            return None

@@ -96,7 +119,7 @@ cdef class TensorRTEngine(InferenceEngine):
        return constants_inf.AI_ONNX_MODEL_FILE

    @staticmethod
-    def convert_from_source(bytes onnx_model):
+    def convert_from_source(bytes onnx_model, str calib_cache_path=None):
        gpu_mem = TensorRTEngine.get_gpu_memory_bytes(0)
        workspace_bytes = int(gpu_mem * 0.9)

@@ -130,7 +153,13 @@ cdef class TensorRTEngine(InferenceEngine):
        )
        config.add_optimization_profile(profile)

-        if builder.platform_has_fast_fp16:
+        use_int8 = calib_cache_path is not None and os.path.isfile(calib_cache_path)
+        if use_int8:
+            constants_inf.log(<str>'Converting to INT8 with calibration cache')
+            calibrator = _CacheCalibrator(calib_cache_path)
+            config.set_flag(trt.BuilderFlag.INT8)
+            config.int8_calibrator = calibrator
+        elif builder.platform_has_fast_fp16:
            constants_inf.log(<str>'Converting to supported fp16')
            config.set_flag(trt.BuilderFlag.FP16)
        else:
@@ -1,4 +1,6 @@
import io
+import os
+import tempfile
import threading

import av
@@ -74,11 +76,11 @@ cdef class Inference:
            raise Exception(res.err)
        return <bytes>res.data

-    cdef convert_and_upload_model(self, bytes source_bytes, str engine_filename):
+    cdef convert_and_upload_model(self, bytes source_bytes, str engine_filename, str calib_cache_path):
        try:
            self.ai_availability_status.set_status(AIAvailabilityEnum.CONVERTING)
            models_dir = constants_inf.MODELS_FOLDER
-            model_bytes = EngineClass.convert_from_source(source_bytes)
+            model_bytes = EngineClass.convert_from_source(source_bytes, calib_cache_path)

            self.ai_availability_status.set_status(AIAvailabilityEnum.UPLOADING)
            res = self.loader_client.upload_big_small_resource(model_bytes, engine_filename, models_dir)
@@ -92,6 +94,11 @@ cdef class Inference:
            self._converted_model_bytes = <bytes>None
        finally:
            self.is_building_engine = <bint>False
+            if calib_cache_path is not None:
+                try:
+                    os.unlink(calib_cache_path)
+                except Exception:
+                    pass

    cdef init_ai(self):
        constants_inf.log(<str> 'init AI...')
@@ -112,25 +119,32 @@ cdef class Inference:
            return

        models_dir = constants_inf.MODELS_FOLDER
-        engine_filename = EngineClass.get_engine_filename()
-        if engine_filename is not None:
+        engine_filename_fp16 = EngineClass.get_engine_filename()
+        if engine_filename_fp16 is not None:
+            engine_filename_int8 = EngineClass.get_engine_filename(<str>"int8")
+            for candidate in [engine_filename_int8, engine_filename_fp16]:
                try:
                    self.ai_availability_status.set_status(AIAvailabilityEnum.DOWNLOADING)
-                    res = self.loader_client.load_big_small_resource(engine_filename, models_dir)
+                    res = self.loader_client.load_big_small_resource(candidate, models_dir)
                    if res.err is not None:
                        raise Exception(res.err)
                    self.engine = EngineClass(res.data)
                    self.ai_availability_status.set_status(AIAvailabilityEnum.ENABLED)
-                except Exception as e:
+                    return
+                except Exception:
+                    pass

            source_filename = EngineClass.get_source_filename()
            if source_filename is None:
-                self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>f"Pre-built engine not found: {str(e)}")
+                self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>"Pre-built engine not found and no source available")
                return
-            self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>str(e))
+            self.ai_availability_status.set_status(AIAvailabilityEnum.WARNING, <str>"Cached engine not found, converting from source")
            source_bytes = self.download_model(source_filename)
+            calib_cache_path = self._try_download_calib_cache(models_dir)
+            target_engine_filename = EngineClass.get_engine_filename(<str>"int8") if calib_cache_path is not None else engine_filename_fp16
            self.is_building_engine = <bint>True

-            thread = Thread(target=self.convert_and_upload_model, args=(source_bytes, engine_filename))
+            thread = Thread(target=self.convert_and_upload_model, args=(source_bytes, target_engine_filename, calib_cache_path))
            thread.daemon = True
            thread.start()
            return
@@ -142,6 +156,21 @@ cdef class Inference:
            self.ai_availability_status.set_status(AIAvailabilityEnum.ERROR, <str>str(e))
            self.is_building_engine = <bint>False

+    cdef str _try_download_calib_cache(self, str models_dir):
+        try:
+            res = self.loader_client.load_big_small_resource(constants_inf.INT8_CALIB_CACHE_FILE, models_dir)
+            if res.err is not None:
+                constants_inf.log(<str>f"INT8 calibration cache not available: {res.err}")
+                return <str>None
+            fd, path = tempfile.mkstemp(suffix='.cache')
+            with os.fdopen(fd, 'wb') as f:
+                f.write(res.data)
+            constants_inf.log(<str>'INT8 calibration cache downloaded')
+            return <str>path
+        except Exception as e:
+            constants_inf.log(<str>f"INT8 calibration cache download failed: {str(e)}")
+            return <str>None
+
    cpdef run_detect_image(self, bytes image_bytes, AIRecognitionConfig ai_config, str media_name,
                           object annotation_callback, object status_callback=None):
        cdef list all_frame_data = []
@@ -0,0 +1,97 @@
import os
import tempfile
from unittest.mock import MagicMock, call, patch

import pytest

try:
    import tensorrt  # noqa: F401
    import pycuda.driver  # noqa: F401
    HAS_TENSORRT = True
except ImportError:
    HAS_TENSORRT = False

requires_tensorrt = pytest.mark.skipif(
    not HAS_TENSORRT,
    reason="TensorRT and PyCUDA required (GPU / Jetson environment)",
)


def _make_mock_trt():
    mock_trt = MagicMock()
    mock_trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH = 0
    mock_trt.Logger.WARNING = "WARNING"
    mock_trt.MemoryPoolType.WORKSPACE = "WORKSPACE"
    mock_trt.BuilderFlag.INT8 = "INT8"
    mock_trt.BuilderFlag.FP16 = "FP16"

    mock_builder = MagicMock()
    mock_builder.platform_has_fast_fp16 = True
    mock_config = MagicMock()
    mock_network = MagicMock()
    mock_parser = MagicMock()
    mock_parser.parse.return_value = True
    mock_input = MagicMock()
    mock_input.shape = [1, 3, 640, 640]
    mock_input.name = "images"
    mock_network.get_input.return_value = mock_input
    mock_builder.create_network.return_value.__enter__ = MagicMock(return_value=mock_network)
    mock_builder.create_network.return_value.__exit__ = MagicMock(return_value=False)
    mock_builder.create_builder_config.return_value.__enter__ = MagicMock(return_value=mock_config)
    mock_builder.create_builder_config.return_value.__exit__ = MagicMock(return_value=False)
    mock_builder.__enter__ = MagicMock(return_value=mock_builder)
    mock_builder.__exit__ = MagicMock(return_value=False)
    mock_trt.Builder.return_value = mock_builder

    mock_onnx_parser = MagicMock()
    mock_onnx_parser.__enter__ = MagicMock(return_value=mock_parser)
    mock_onnx_parser.__exit__ = MagicMock(return_value=False)
    mock_trt.OnnxParser.return_value = mock_onnx_parser

    mock_trt.IInt8EntropyCalibrator2 = object
    mock_builder.build_serialized_network.return_value = b"engine_bytes"

    return mock_trt, mock_builder, mock_config


@requires_tensorrt
def test_convert_from_source_uses_int8_when_cache_provided():
    # Arrange
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, mock_builder, mock_config = _make_mock_trt()
    with tempfile.NamedTemporaryFile(suffix=".cache", delete=False) as f:
        f.write(b"calibration_cache_data")
        cache_path = f.name

    try:
        with patch.object(trt_mod, "trt", mock_trt), \
             patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3):
            # Act
            TensorRTEngine.convert_from_source(b"onnx_model", cache_path)

        # Assert
        mock_config.set_flag.assert_any_call("INT8")
        assert mock_config.int8_calibrator is not None
    finally:
        os.unlink(cache_path)


@requires_tensorrt
def test_convert_from_source_uses_fp16_when_no_cache():
    # Arrange
    from engines.tensorrt_engine import TensorRTEngine
    import engines.tensorrt_engine as trt_mod

    mock_trt, mock_builder, mock_config = _make_mock_trt()

    with patch.object(trt_mod, "trt", mock_trt), \
         patch.object(TensorRTEngine, "get_gpu_memory_bytes", return_value=4 * 1024**3):
        # Act
        TensorRTEngine.convert_from_source(b"onnx_model", None)

    # Assert
    mock_config.set_flag.assert_any_call("FP16")
    int8_calls = [c for c in mock_config.set_flag.call_args_list if c == call("INT8")]
    assert len(int8_calls) == 0