From c1b5b5fee2eeba013be93813fd48ba1a84bd8974 Mon Sep 17 00:00:00 2001 From: Alex Bezdieniezhnykh Date: Mon, 10 Feb 2025 14:55:00 +0200 Subject: [PATCH] use nms in the model itself, simplify and make postprocess faster. make inference in batches, fix c# handling, add overlap handling --- Azaion.Annotator/Annotator.xaml.cs | 51 +++-- Azaion.Annotator/AnnotatorEventHandler.cs | 3 +- Azaion.Common/Constants.cs | 1 + Azaion.Common/DTO/Label.cs | 4 +- Azaion.Common/Database/Annotation.cs | 4 +- Azaion.Common/Services/AnnotationService.cs | 3 - Azaion.Common/Services/InferenceService.cs | 14 +- .../Services/PythonResourceLoader.cs | 2 +- Azaion.CommonSecurity/ZeroMQExtensions.cs | 23 +- Azaion.Inference/README.md | 13 +- Azaion.Inference/annotation.pxd | 9 +- Azaion.Inference/annotation.pyx | 32 ++- Azaion.Inference/constants.pxd | 3 +- Azaion.Inference/constants.pyx | 3 +- Azaion.Inference/inference.pxd | 16 +- Azaion.Inference/inference.pyx | 213 +++++++++++------- Azaion.Inference/remote_command.pyx | 2 +- Azaion.Inference/setup.py | 2 +- Azaion.Inference/token | 1 - 19 files changed, 259 insertions(+), 140 deletions(-) delete mode 100644 Azaion.Inference/token diff --git a/Azaion.Annotator/Annotator.xaml.cs b/Azaion.Annotator/Annotator.xaml.cs index 54e43b7..642e7c6 100644 --- a/Azaion.Annotator/Annotator.xaml.cs +++ b/Azaion.Annotator/Annotator.xaml.cs @@ -7,6 +7,7 @@ using System.Windows.Controls.Primitives; using System.Windows.Input; using System.Windows.Media; using Azaion.Annotator.DTO; +using Azaion.Common; using Azaion.Common.Database; using Azaion.Common.DTO; using Azaion.Common.DTO.Config; @@ -39,11 +40,12 @@ public partial class Annotator private readonly AnnotationService _annotationService; private readonly IDbFactory _dbFactory; private readonly IInferenceService _inferenceService; - private readonly CancellationTokenSource _ctSource = new(); private ObservableCollection AnnotationClasses { get; set; } = new(); private bool _suspendLayout; + public 
readonly CancellationTokenSource MainCancellationSource = new(); + public CancellationTokenSource DetectionCancellationSource = new(); public bool FollowAI = false; public bool IsInferenceNow = false; @@ -310,7 +312,7 @@ public partial class Annotator var annotations = await _dbFactory.Run(async db => await db.Annotations.LoadWith(x => x.Detections) .Where(x => x.OriginalMediaName == _formState.VideoName) - .ToListAsync(token: _ctSource.Token)); + .ToListAsync(token: MainCancellationSource.Token)); TimedAnnotations.Clear(); _formState.AnnotationResults.Clear(); @@ -395,6 +397,8 @@ public partial class Annotator private void OnFormClosed(object? sender, EventArgs e) { + MainCancellationSource.Cancel(); + DetectionCancellationSource.Cancel(); _mediaPlayer.Stop(); _mediaPlayer.Dispose(); _libVLC.Dispose(); @@ -490,6 +494,20 @@ public partial class Annotator private (TimeSpan Time, List Detections)? _previousDetection; + private List GetLvFiles() + { + return Dispatcher.Invoke(() => + { + var source = LvFiles.ItemsSource as IEnumerable; + var items = source?.Skip(LvFiles.SelectedIndex) + .Take(Constants.DETECTION_BATCH_SIZE) + .Select(x => x.Path) + .ToList(); + + return items ?? 
new List(); + }); + } + public void AutoDetect(object sender, RoutedEventArgs e) { if (IsInferenceNow) @@ -503,36 +521,25 @@ public partial class Annotator if (LvFiles.SelectedIndex == -1) LvFiles.SelectedIndex = 0; - var mct = new CancellationTokenSource(); - var token = mct.Token; Dispatcher.Invoke(() => Editor.ResetBackground()); IsInferenceNow = true; FollowAI = true; + DetectionCancellationSource = new CancellationTokenSource(); + var ct = DetectionCancellationSource.Token; _ = Task.Run(async () => { - var mediaInfo = Dispatcher.Invoke(() => (MediaFileInfo)LvFiles.SelectedItem); - while (mediaInfo != null && !token.IsCancellationRequested) + var files = GetLvFiles(); + while (files.Any() && !ct.IsCancellationRequested) { await Dispatcher.Invoke(async () => { - await _mediator.Publish(new AnnotatorControlEvent(PlaybackControlEnum.Play), token); + await _mediator.Publish(new AnnotatorControlEvent(PlaybackControlEnum.Play), ct); await ReloadAnnotations(); }); - await _inferenceService.RunInference(mediaInfo.Path, async annotationImage => - { - annotationImage.OriginalMediaName = mediaInfo.FName; - await ProcessDetection(annotationImage); - }); - - mediaInfo = Dispatcher.Invoke(() => - { - if (LvFiles.SelectedIndex == LvFiles.Items.Count - 1) - return null; - LvFiles.SelectedIndex += 1; - return (MediaFileInfo)LvFiles.SelectedItem; - }); + await _inferenceService.RunInference(files, async annotationImage => await ProcessDetection(annotationImage), ct); + files = GetLvFiles(); Dispatcher.Invoke(() => LvFiles.Items.Refresh()); } Dispatcher.Invoke(() => @@ -541,7 +548,7 @@ public partial class Annotator IsInferenceNow = false; FollowAI = false; }); - }, token); + }); } private async Task ProcessDetection(AnnotationImage annotationImage) @@ -551,6 +558,8 @@ public partial class Annotator try { var annotation = await _annotationService.SaveAnnotation(annotationImage); + if (annotation.OriginalMediaName != _formState.CurrentMedia.FName) + return; 
AddAnnotation(annotation); if (FollowAI) diff --git a/Azaion.Annotator/AnnotatorEventHandler.cs b/Azaion.Annotator/AnnotatorEventHandler.cs index 7097a0e..addb741 100644 --- a/Azaion.Annotator/AnnotatorEventHandler.cs +++ b/Azaion.Annotator/AnnotatorEventHandler.cs @@ -139,6 +139,7 @@ public class AnnotatorEventHandler( } break; case PlaybackControlEnum.Stop: + await mainWindow.DetectionCancellationSource.CancelAsync(); mediaPlayer.Stop(); break; case PlaybackControlEnum.PreviousFrame: @@ -294,7 +295,7 @@ public class AnnotatorEventHandler( media.HasAnnotations = false; mainWindow.LvFiles.Items.Refresh(); } - } + await Task.CompletedTask; } } diff --git a/Azaion.Common/Constants.cs b/Azaion.Common/Constants.cs index d7161db..67e7d14 100644 --- a/Azaion.Common/Constants.cs +++ b/Azaion.Common/Constants.cs @@ -53,6 +53,7 @@ public class Constants public const double TRACKING_INTERSECTION_THRESHOLD = 0.8; public const int DEFAULT_FRAME_PERIOD_RECOGNITION = 4; + public const int DETECTION_BATCH_SIZE = 4; # endregion AIRecognitionConfig #region Thumbnails diff --git a/Azaion.Common/DTO/Label.cs b/Azaion.Common/DTO/Label.cs index 4d33366..546154d 100644 --- a/Azaion.Common/DTO/Label.cs +++ b/Azaion.Common/DTO/Label.cs @@ -188,8 +188,8 @@ public class YoloLabel : Label [MessagePackObject] public class Detection : YoloLabel { - [IgnoreMember]public string AnnotationName { get; set; } = null!; - [Key("p")] public double? Probability { get; set; } + [Key("an")] public string AnnotationName { get; set; } = null!; + [Key("p")] public double? 
Probability { get; set; } //For db & serialization public Detection(){} diff --git a/Azaion.Common/Database/Annotation.cs b/Azaion.Common/Database/Annotation.cs index 345be16..00c8dbf 100644 --- a/Azaion.Common/Database/Annotation.cs +++ b/Azaion.Common/Database/Annotation.cs @@ -21,8 +21,8 @@ public class Annotation _thumbDir = config.ThumbnailsDirectory; } - [IgnoreMember]public string Name { get; set; } = null!; - [IgnoreMember]public string OriginalMediaName { get; set; } = null!; + [Key("n")] public string Name { get; set; } = null!; + [Key("mn")] public string OriginalMediaName { get; set; } = null!; [IgnoreMember]public TimeSpan Time { get; set; } [IgnoreMember]public string ImageExtension { get; set; } = null!; [IgnoreMember]public DateTime CreatedDate { get; set; } diff --git a/Azaion.Common/Services/AnnotationService.cs b/Azaion.Common/Services/AnnotationService.cs index 387670f..102b44f 100644 --- a/Azaion.Common/Services/AnnotationService.cs +++ b/Azaion.Common/Services/AnnotationService.cs @@ -105,9 +105,6 @@ public class AnnotationService : INotificationHandler public async Task SaveAnnotation(AnnotationImage a, CancellationToken cancellationToken = default) { a.Time = TimeSpan.FromMilliseconds(a.Milliseconds); - a.Name = a.OriginalMediaName.ToTimeName(a.Time); - foreach (var det in a.Detections) - det.AnnotationName = a.Name; return await SaveAnnotationInner(DateTime.Now, a.OriginalMediaName, a.Time, ".jpg", a.Detections.ToList(), a.Source, new MemoryStream(a.Image), a.CreatedRole, a.CreatedEmail, generateThumbnail: true, cancellationToken); } diff --git a/Azaion.Common/Services/InferenceService.cs b/Azaion.Common/Services/InferenceService.cs index d8e99d6..30164a5 100644 --- a/Azaion.Common/Services/InferenceService.cs +++ b/Azaion.Common/Services/InferenceService.cs @@ -3,23 +3,23 @@ using Azaion.Common.Database; using Azaion.Common.DTO.Config; using Azaion.CommonSecurity; using Azaion.CommonSecurity.DTO.Commands; -using 
Azaion.CommonSecurity.Services; using MessagePack; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using NetMQ; using NetMQ.Sockets; +using Newtonsoft.Json; namespace Azaion.Common.Services; public interface IInferenceService { - Task RunInference(string mediaPath, Func processAnnotation); + Task RunInference(List mediaPaths, Func processAnnotation, CancellationToken ct = default); } public class PythonInferenceService(ILogger logger, IOptions aiConfigOptions) : IInferenceService { - public async Task RunInference(string mediaPath, Func processAnnotation) + public async Task RunInference(List mediaPaths, Func processAnnotation, CancellationToken ct = default) { using var dealer = new DealerSocket(); var clientId = Guid.NewGuid(); @@ -27,13 +27,14 @@ public class PythonInferenceService(ILogger logger, IOpt dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}"); var data = MessagePackSerializer.Serialize(aiConfigOptions.Value); - dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Inference, mediaPath, data))); + var filename = JsonConvert.SerializeObject(mediaPaths); + dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Inference, filename, data))); - while (true) + while (!ct.IsCancellationRequested) { try { - var annotationStream = dealer.Get(bytes => bytes.Length == 4 && Encoding.UTF8.GetString(bytes) == "DONE"); + var annotationStream = dealer.Get(bytes => bytes.Length == 4 && Encoding.UTF8.GetString(bytes) == "DONE", ct: ct); if (annotationStream == null) break; @@ -42,6 +43,7 @@ public class PythonInferenceService(ILogger logger, IOpt catch (Exception e) { logger.LogError(e, e.Message); + break; } } } diff --git a/Azaion.CommonSecurity/Services/PythonResourceLoader.cs b/Azaion.CommonSecurity/Services/PythonResourceLoader.cs index f3addea..5fe22a7 100644 --- a/Azaion.CommonSecurity/Services/PythonResourceLoader.cs +++ 
b/Azaion.CommonSecurity/Services/PythonResourceLoader.cs @@ -82,7 +82,7 @@ public class PythonResourceLoader : IResourceLoader, IAuthProvider { _dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Load, fileName))); - if (!_dealer.TryReceiveFrameBytes(TimeSpan.FromMilliseconds(1000), out var bytes)) + if (!_dealer.TryReceiveFrameBytes(TimeSpan.FromSeconds(3), out var bytes)) throw new Exception($"Unable to receive {fileName}"); return new MemoryStream(bytes); diff --git a/Azaion.CommonSecurity/ZeroMQExtensions.cs b/Azaion.CommonSecurity/ZeroMQExtensions.cs index f63fc94..51cff4a 100644 --- a/Azaion.CommonSecurity/ZeroMQExtensions.cs +++ b/Azaion.CommonSecurity/ZeroMQExtensions.cs @@ -6,12 +6,23 @@ namespace Azaion.CommonSecurity; public static class ZeroMqExtensions { - public static T? Get(this DealerSocket dealer, Func? shouldInterceptFn = null) where T : class + public static T? Get(this DealerSocket dealer, Func? shouldInterceptFn = null, int retries = 24, int tryTimeoutSeconds = 5, CancellationToken ct = default) where T : class { - if (!dealer.TryReceiveFrameBytes(TimeSpan.FromMinutes(2), out var bytes)) - throw new Exception($"Unable to get {typeof(T).Name}"); - if (shouldInterceptFn != null && shouldInterceptFn(bytes)) - return null; - return MessagePackSerializer.Deserialize(bytes); + var tryNum = 0; + while (!ct.IsCancellationRequested && tryNum++ < retries) + { + if (!dealer.TryReceiveFrameBytes(TimeSpan.FromSeconds(tryTimeoutSeconds), out var bytes)) + continue; + + if (shouldInterceptFn != null && shouldInterceptFn(bytes)) + return null; + + return MessagePackSerializer.Deserialize(bytes); + } + + if (!ct.IsCancellationRequested) + throw new Exception($"Unable to get {typeof(T).Name} after {tryNum} retries, {tryTimeoutSeconds} seconds each"); + + return null; } } \ No newline at end of file diff --git a/Azaion.Inference/README.md b/Azaion.Inference/README.md index e1497a3..5307e74 100644 --- a/Azaion.Inference/README.md +++ 
b/Azaion.Inference/README.md @@ -13,6 +13,17 @@ Results (file or annotations) is putted to the other queue, or the same socket,

Installation

+Prepare correct onnx model from YOLO: +```python +from ultralytics import YOLO +import netron + +model = YOLO("azaion.pt") +model.export(format="onnx", imgsz=1280, nms=True, batch=4) +netron.start('azaion.onnx') +``` +Read carefully about [export arguments](https://docs.ultralytics.com/modes/export/); you must use nms=True and export with a proper batch size +

Install libs

https://www.python.org/downloads/ @@ -45,7 +56,7 @@ This is crucial for the build because build needs Python.h header and other file ``` python -m pip install --upgrade pip - pip install opencv-python cython msgpack cryptography rstream pika zmq pyjwt pyinstaller tensorboard + pip install -r requirements.txt ``` In case of fbgemm.dll error (Windows specific): diff --git a/Azaion.Inference/annotation.pxd b/Azaion.Inference/annotation.pxd index 796724e..b6c4c6b 100644 --- a/Azaion.Inference/annotation.pxd +++ b/Azaion.Inference/annotation.pxd @@ -1,10 +1,17 @@ cdef class Detection: cdef public double x, y, w, h, confidence + cdef public str annotation_name cdef public int cls + cdef public overlaps(self, Detection det2) + cdef class Annotation: - cdef bytes image + cdef public str name + cdef public str original_media_name cdef long time cdef public list[Detection] detections + cdef public bytes image + + cdef format_time(self, ms) cdef bytes serialize(self)
- self.time = time + def __init__(self, str name, long ms, list[Detection] detections): + self.original_media_name = Path(name).stem.replace(" ", "") + self.name = f'{self.original_media_name}_{self.format_time(ms)}' + self.time = ms self.detections = detections if detections is not None else [] + for d in self.detections: + d.annotation_name = self.name self.image = b'' + cdef format_time(self, ms): + # Calculate hours, minutes, seconds, and hundreds of milliseconds. + h = ms // 3600000 # Total full hours. + ms_remaining = ms % 3600000 + m = ms_remaining // 60000 # Full minutes. + ms_remaining %= 60000 + s = ms_remaining // 1000 # Full seconds. + f = (ms_remaining % 1000) // 100 # Hundreds of milliseconds. + h = h % 10 + return f"{h}{m:02}{s:02}{f}" + cdef bytes serialize(self): return msgpack.packb({ + "n": self.name, + "mn": self.original_media_name, "i": self.image, # "i" = image "t": self.time, # "t" = time "d": [ # "d" = detections { + "an": det.annotation_name, "x": det.x, "y": det.y, "w": det.w, diff --git a/Azaion.Inference/constants.pxd b/Azaion.Inference/constants.pxd index 1cbf0cc..bce3550 100644 --- a/Azaion.Inference/constants.pxd +++ b/Azaion.Inference/constants.pxd @@ -9,4 +9,5 @@ cdef str TOKEN_FILE # Name of the token file where temporary token wo cdef str QUEUE_CONFIG_FILENAME # queue config filename to load from api cdef str AI_MODEL_FILE # AI Model file -cdef bytes DONE_SIGNAL \ No newline at end of file +cdef bytes DONE_SIGNAL +cdef int MODEL_BATCH_SIZE \ No newline at end of file diff --git a/Azaion.Inference/constants.pyx b/Azaion.Inference/constants.pyx index 6e24803..a1cc9f4 100644 --- a/Azaion.Inference/constants.pyx +++ b/Azaion.Inference/constants.pyx @@ -9,4 +9,5 @@ cdef str TOKEN_FILE = "token" cdef str QUEUE_CONFIG_FILENAME = "secured-config.json" cdef str AI_MODEL_FILE = "azaion.onnx" -cdef bytes DONE_SIGNAL = b"DONE" \ No newline at end of file +cdef bytes DONE_SIGNAL = b"DONE" +cdef int MODEL_BATCH_SIZE = 4 \ No newline at end of 
file diff --git a/Azaion.Inference/inference.pxd b/Azaion.Inference/inference.pxd index 90f64b2..21a3c82 100644 --- a/Azaion.Inference/inference.pxd +++ b/Azaion.Inference/inference.pxd @@ -1,5 +1,5 @@ from remote_command cimport RemoteCommand -from annotation cimport Annotation +from annotation cimport Annotation, Detection from ai_config cimport AIRecognitionConfig cdef class Inference: @@ -14,14 +14,14 @@ cdef class Inference: cdef int model_height cdef bint is_video(self, str filepath) - cdef run_inference(self, RemoteCommand cmd, int batch_size=?) - cdef _process_video(self, RemoteCommand cmd, int batch_size) - cdef _process_image(self, RemoteCommand cmd) + cdef run_inference(self, RemoteCommand cmd) + cdef _process_video(self, RemoteCommand cmd, str video_name) + cdef _process_images(self, RemoteCommand cmd, list[str] image_paths) cdef stop(self) - cdef preprocess(self, frame) - cdef postprocess(self, output, int img_width, int img_height) + cdef preprocess(self, frames) + cdef remove_overlapping_detections(self, list[Detection] detections) + cdef postprocess(self, output) + cdef split_list_extend(self, lst, chunk_size) - - cdef detect_frame(self, frame, long time) cdef bint is_valid_annotation(self, Annotation annotation) diff --git a/Azaion.Inference/inference.pyx b/Azaion.Inference/inference.pyx index 44d74ad..d26a073 100644 --- a/Azaion.Inference/inference.pyx +++ b/Azaion.Inference/inference.pyx @@ -1,3 +1,4 @@ +import json import mimetypes import time @@ -5,6 +6,7 @@ import cv2 import numpy as np import onnxruntime as onnx +cimport constants from remote_command cimport RemoteCommand from annotation cimport Detection, Annotation from ai_config cimport AIRecognitionConfig @@ -26,68 +28,117 @@ cdef class Inference: model_meta = self.session.get_modelmeta() print("Metadata:", model_meta.custom_metadata_map) - cdef preprocess(self, frame): - img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - img = cv2.resize(img, (self.model_width, self.model_height)) - 
image_data = np.array(img) / 255.0 - image_data = np.transpose(image_data, (2, 0, 1)) # Channel first - image_data = np.expand_dims(image_data, axis=0).astype(np.float32) - return image_data + cdef preprocess(self, frames): + blobs = [cv2.dnn.blobFromImage(frame, + scalefactor=1.0 / 255.0, + size=(self.model_width, self.model_height), + mean=(0, 0, 0), + swapRB=True, + crop=False) + for frame in frames] + return np.vstack(blobs) - cdef postprocess(self, output, int img_width, int img_height): - outputs = np.transpose(np.squeeze(output[0])) - rows = outputs.shape[0] - boxes = [] - scores = [] - class_ids = [] + cdef postprocess(self, output): + cdef list[Detection] detections = [] + cdef int ann_index + cdef float x1, y1, x2, y2, conf, cx, cy, w, h + cdef int class_id + cdef list[list[Detection]] results = [] - x_factor = img_width / self.model_width - y_factor = img_height / self.model_height + for ann_index in range(len(output[0])): + detections.clear() + for det in output[0][ann_index]: + if det[4] == 0: # if confidence is 0 then valid points are over. 
+ break + x1 = det[0] / self.model_width + y1 = det[1] / self.model_height + x2 = det[2] / self.model_width + y2 = det[3] / self.model_height + conf = round(det[4], 2) + class_id = int(det[5]) - for i in range(rows): - classes_scores = outputs[i][4:] - max_score = np.amax(classes_scores) + x = (x1 + x2) / 2 + y = (y1 + y2) / 2 + w = x2 - x1 + h = y2 - y1 + detections.append(Detection(x, y, w, h, class_id, conf)) + filtered_detections = self.remove_overlapping_detections(detections) + results.append(filtered_detections) + return results - if max_score >= self.ai_config.probability_threshold: - class_id = np.argmax(classes_scores) - x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3] + cdef remove_overlapping_detections(self, list[Detection] detections): + cdef Detection det1, det2 + filtered_output = [] + filtered_out_indexes = [] - left = int((x - w / 2) * x_factor) - top = int((y - h / 2) * y_factor) - width = int(w * x_factor) - height = int(h * y_factor) - - class_ids.append(class_id) - scores.append(max_score) - boxes.append([left, top, width, height]) - indices = cv2.dnn.NMSBoxes(boxes, scores, self.ai_config.probability_threshold, 0.45) - detections = [] - for i in indices: - x, y, w, h = boxes[i] - detections.append(Detection(x, y, w, h, class_ids[i], scores[i])) - return detections + for det1_index in range(len(detections)): + if det1_index in filtered_out_indexes: + continue + det1 = detections[det1_index] + print(f'det1 size: {det1.w}, {det1.h}') + res = det1_index + for det2_index in range(det1_index + 1, len(detections)): + det2 = detections[det2_index] + print(f'det2 size: {det2.w}, {det2.h}') + if det1.overlaps(det2): + if det1.confidence > det2.confidence or ( + det1.confidence == det2.confidence and det1.cls < det2.cls): # det1 has higher confidence or lower class_id + filtered_out_indexes.append(det2_index) + else: + filtered_out_indexes.append(res) + res = det2_index + filtered_output.append(detections[res]) + 
filtered_out_indexes.append(res) + return filtered_output cdef bint is_video(self, str filepath): mime_type, _ = mimetypes.guess_type(filepath) return mime_type and mime_type.startswith("video") - cdef run_inference(self, RemoteCommand cmd, int batch_size=8): - print('run inference..') + cdef split_list_extend(self, lst, chunk_size): + chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)] + + # If the last chunk is smaller than the desired chunk_size, extend it by duplicating its last element. + last_chunk = chunks[len(chunks) - 1] + if len(last_chunk) < chunk_size: + last_elem = last_chunk[len(last_chunk)-1] + while len(last_chunk) < chunk_size: + last_chunk.append(last_elem) + return chunks + + cdef run_inference(self, RemoteCommand cmd): + cdef list[str] medias = json.loads( cmd.filename) + cdef list[str] videos = [] + cdef list[str] images = [] + self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data) self.stop_signal = False - if self.is_video(cmd.filename): - self._process_video(cmd, batch_size) - else: - self._process_image(cmd) - cdef _process_video(self, RemoteCommand cmd, int batch_size): - frame_count = 0 - batch_frame = [] + for m in medias: + if self.is_video(m): + videos.append(m) + else: + images.append(m) + + # images first, it's faster + if len(images) > 0: + for chunk in self.split_list_extend(images, constants.MODEL_BATCH_SIZE): + print(f'run inference on {" ".join(chunk)}...') + self._process_images(cmd, chunk) + if len(videos) > 0: + for v in videos: + print(f'run inference on {v}...') + self._process_video(cmd, v) + + + cdef _process_video(self, RemoteCommand cmd, str video_name): + cdef int frame_count = 0 + cdef list batch_frames = [] + cdef list[int] batch_timestamps = [] self._previous_annotation = None - self.start_video_time = time.time() - v_input = cv2.VideoCapture(cmd.filename) + v_input = cv2.VideoCapture(video_name) while v_input.isOpened(): ret, frame = v_input.read() if not ret or frame is None: @@ -95,45 
+146,45 @@ cdef class Inference: frame_count += 1 if frame_count % self.ai_config.frame_period_recognition == 0: - ms = int(v_input.get(cv2.CAP_PROP_POS_MSEC)) - annotation = self.detect_frame(frame, ms) - if annotation is not None: - self._previous_annotation = annotation - self.on_annotation(annotation) + batch_frames.append(frame) + batch_timestamps.append(int(v_input.get(cv2.CAP_PROP_POS_MSEC))) + + if len(batch_frames) == constants.MODEL_BATCH_SIZE: + input_blob = self.preprocess(batch_frames) + outputs = self.session.run(None, {self.model_input: input_blob}) + list_detections = self.postprocess(outputs) + for i in range(len(list_detections)): + detections = list_detections[i] + annotation = Annotation(video_name, batch_timestamps[i], detections) + if self.is_valid_annotation(annotation): + _, image = cv2.imencode('.jpg', frame) + annotation.image = image.tobytes() + self.on_annotation(cmd, annotation) + self._previous_annotation = annotation + + batch_frames.clear() + batch_timestamps.clear() + v_input.release() - cdef detect_frame(self, frame, long time): - cdef Annotation annotation - img_height, img_width = frame.shape[:2] - - start_time = time.time() - img_data = self.preprocess(frame) - preprocess_time = time.time() - outputs = self.session.run(None, {self.model_input: img_data}) - inference_time = time.time() - detections = self.postprocess(outputs, img_width, img_height) - postprocess_time = time.time() - print(f'video time, ms: {time / 1000:.3f}. 
total time, s : {postprocess_time - self.start_video_time:.3f} ' - f'preprocess time: {preprocess_time - start_time:.3f}, inference time: {inference_time - preprocess_time:.3f},' - f' postprocess time: {postprocess_time - inference_time:.3f}, total time: {postprocess_time - start_time:.3f}') - if len(detections) > 0: - annotation = Annotation(frame, time, detections) - if self.is_valid_annotation(annotation): - _, image = cv2.imencode('.jpg', frame) - annotation.image = image.tobytes() - return annotation - return None - - - cdef _process_image(self, RemoteCommand cmd): + cdef _process_images(self, RemoteCommand cmd, list[str] image_paths): + cdef list frames = [] + cdef list timestamps = [] self._previous_annotation = None - frame = cv2.imread(cmd.filename) - annotation = self.detect_frame(frame, 0) - if annotation is None: - _, image = cv2.imencode('.jpg', frame) - annotation = Annotation(frame, time, []) + for image in image_paths: + frame = cv2.imread(image) + frames.append(frame) + timestamps.append(0) + + input_blob = self.preprocess(frames) + outputs = self.session.run(None, {self.model_input: input_blob}) + list_detections = self.postprocess(outputs) + for i in range(len(list_detections)): + detections = list_detections[i] + annotation = Annotation(image_paths[i], timestamps[i], detections) + _, image = cv2.imencode('.jpg', frames[i]) annotation.image = image.tobytes() - self.on_annotation(cmd, annotation) + self.on_annotation(cmd, annotation) cdef stop(self): diff --git a/Azaion.Inference/remote_command.pyx b/Azaion.Inference/remote_command.pyx index 0ad91a5..1cbc677 100644 --- a/Azaion.Inference/remote_command.pyx +++ b/Azaion.Inference/remote_command.pyx @@ -11,7 +11,7 @@ cdef class RemoteCommand: 10: "GET_USER", 20: "LOAD", 30: "INFERENCE", - 40: "STOP INFERENCE", + 40: "STOP_INFERENCE", 100: "EXIT" } data_str = f'. 
Data: {len(self.data)} bytes' if self.data else '' diff --git a/Azaion.Inference/setup.py b/Azaion.Inference/setup.py index c8e7451..d9ab7ec 100644 --- a/Azaion.Inference/setup.py +++ b/Azaion.Inference/setup.py @@ -13,7 +13,7 @@ extensions = [ Extension('api_client', ['api_client.pyx']), Extension('secure_model', ['secure_model.pyx']), Extension('ai_config', ['ai_config.pyx']), - Extension('inference', ['inference.pyx']), + Extension('inference', ['inference.pyx'], include_dirs=[np.get_include()]), Extension('main', ['main.pyx']), ] diff --git a/Azaion.Inference/token b/Azaion.Inference/token deleted file mode 100644 index 7a7dfd6..0000000 --- a/Azaion.Inference/token +++ /dev/null @@ -1 +0,0 @@ -eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1laWQiOiJkOTBhMzZjYS1lMjM3LTRmYmQtOWM3Yy0xMjcwNDBhYzg1NTYiLCJ1bmlxdWVfbmFtZSI6ImFkbWluQGF6YWlvbi5jb20iLCJyb2xlIjoiQXBpQWRtaW4iLCJuYmYiOjE3Mzg4Mjk0NTMsImV4cCI6MTczODg0Mzg1MywiaWF0IjoxNzM4ODI5NDUzLCJpc3MiOiJBemFpb25BcGkiLCJhdWQiOiJBbm5vdGF0b3JzL09yYW5nZVBpL0FkbWlucyJ9.t6ImX8KkH5IQ4zNNY5IbXESSI6uia4iuzyMhodvM7AA \ No newline at end of file