mirror of
https://github.com/azaion/annotations.git
synced 2026-04-22 22:46:30 +00:00
use nms in the model itself, simplify and make postprocess faster.
make inference in batches, fix c# handling, add overlap handling
This commit is contained in:
@@ -7,6 +7,7 @@ using System.Windows.Controls.Primitives;
|
|||||||
using System.Windows.Input;
|
using System.Windows.Input;
|
||||||
using System.Windows.Media;
|
using System.Windows.Media;
|
||||||
using Azaion.Annotator.DTO;
|
using Azaion.Annotator.DTO;
|
||||||
|
using Azaion.Common;
|
||||||
using Azaion.Common.Database;
|
using Azaion.Common.Database;
|
||||||
using Azaion.Common.DTO;
|
using Azaion.Common.DTO;
|
||||||
using Azaion.Common.DTO.Config;
|
using Azaion.Common.DTO.Config;
|
||||||
@@ -39,11 +40,12 @@ public partial class Annotator
|
|||||||
private readonly AnnotationService _annotationService;
|
private readonly AnnotationService _annotationService;
|
||||||
private readonly IDbFactory _dbFactory;
|
private readonly IDbFactory _dbFactory;
|
||||||
private readonly IInferenceService _inferenceService;
|
private readonly IInferenceService _inferenceService;
|
||||||
private readonly CancellationTokenSource _ctSource = new();
|
|
||||||
|
|
||||||
private ObservableCollection<DetectionClass> AnnotationClasses { get; set; } = new();
|
private ObservableCollection<DetectionClass> AnnotationClasses { get; set; } = new();
|
||||||
private bool _suspendLayout;
|
private bool _suspendLayout;
|
||||||
|
|
||||||
|
public readonly CancellationTokenSource MainCancellationSource = new();
|
||||||
|
public CancellationTokenSource DetectionCancellationSource = new();
|
||||||
public bool FollowAI = false;
|
public bool FollowAI = false;
|
||||||
public bool IsInferenceNow = false;
|
public bool IsInferenceNow = false;
|
||||||
|
|
||||||
@@ -310,7 +312,7 @@ public partial class Annotator
|
|||||||
var annotations = await _dbFactory.Run(async db =>
|
var annotations = await _dbFactory.Run(async db =>
|
||||||
await db.Annotations.LoadWith(x => x.Detections)
|
await db.Annotations.LoadWith(x => x.Detections)
|
||||||
.Where(x => x.OriginalMediaName == _formState.VideoName)
|
.Where(x => x.OriginalMediaName == _formState.VideoName)
|
||||||
.ToListAsync(token: _ctSource.Token));
|
.ToListAsync(token: MainCancellationSource.Token));
|
||||||
|
|
||||||
TimedAnnotations.Clear();
|
TimedAnnotations.Clear();
|
||||||
_formState.AnnotationResults.Clear();
|
_formState.AnnotationResults.Clear();
|
||||||
@@ -395,6 +397,8 @@ public partial class Annotator
|
|||||||
|
|
||||||
private void OnFormClosed(object? sender, EventArgs e)
|
private void OnFormClosed(object? sender, EventArgs e)
|
||||||
{
|
{
|
||||||
|
MainCancellationSource.Cancel();
|
||||||
|
DetectionCancellationSource.Cancel();
|
||||||
_mediaPlayer.Stop();
|
_mediaPlayer.Stop();
|
||||||
_mediaPlayer.Dispose();
|
_mediaPlayer.Dispose();
|
||||||
_libVLC.Dispose();
|
_libVLC.Dispose();
|
||||||
@@ -490,6 +494,20 @@ public partial class Annotator
|
|||||||
|
|
||||||
private (TimeSpan Time, List<Detection> Detections)? _previousDetection;
|
private (TimeSpan Time, List<Detection> Detections)? _previousDetection;
|
||||||
|
|
||||||
|
private List<string> GetLvFiles()
|
||||||
|
{
|
||||||
|
return Dispatcher.Invoke(() =>
|
||||||
|
{
|
||||||
|
var source = LvFiles.ItemsSource as IEnumerable<MediaFileInfo>;
|
||||||
|
var items = source?.Skip(LvFiles.SelectedIndex)
|
||||||
|
.Take(Constants.DETECTION_BATCH_SIZE)
|
||||||
|
.Select(x => x.Path)
|
||||||
|
.ToList();
|
||||||
|
|
||||||
|
return items ?? new List<string>();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
public void AutoDetect(object sender, RoutedEventArgs e)
|
public void AutoDetect(object sender, RoutedEventArgs e)
|
||||||
{
|
{
|
||||||
if (IsInferenceNow)
|
if (IsInferenceNow)
|
||||||
@@ -503,36 +521,25 @@ public partial class Annotator
|
|||||||
if (LvFiles.SelectedIndex == -1)
|
if (LvFiles.SelectedIndex == -1)
|
||||||
LvFiles.SelectedIndex = 0;
|
LvFiles.SelectedIndex = 0;
|
||||||
|
|
||||||
var mct = new CancellationTokenSource();
|
|
||||||
var token = mct.Token;
|
|
||||||
Dispatcher.Invoke(() => Editor.ResetBackground());
|
Dispatcher.Invoke(() => Editor.ResetBackground());
|
||||||
|
|
||||||
IsInferenceNow = true;
|
IsInferenceNow = true;
|
||||||
FollowAI = true;
|
FollowAI = true;
|
||||||
|
DetectionCancellationSource = new CancellationTokenSource();
|
||||||
|
var ct = DetectionCancellationSource.Token;
|
||||||
_ = Task.Run(async () =>
|
_ = Task.Run(async () =>
|
||||||
{
|
{
|
||||||
var mediaInfo = Dispatcher.Invoke(() => (MediaFileInfo)LvFiles.SelectedItem);
|
var files = GetLvFiles();
|
||||||
while (mediaInfo != null && !token.IsCancellationRequested)
|
while (files.Any() && !ct.IsCancellationRequested)
|
||||||
{
|
{
|
||||||
await Dispatcher.Invoke(async () =>
|
await Dispatcher.Invoke(async () =>
|
||||||
{
|
{
|
||||||
await _mediator.Publish(new AnnotatorControlEvent(PlaybackControlEnum.Play), token);
|
await _mediator.Publish(new AnnotatorControlEvent(PlaybackControlEnum.Play), ct);
|
||||||
await ReloadAnnotations();
|
await ReloadAnnotations();
|
||||||
});
|
});
|
||||||
|
|
||||||
await _inferenceService.RunInference(mediaInfo.Path, async annotationImage =>
|
await _inferenceService.RunInference(files, async annotationImage => await ProcessDetection(annotationImage), ct);
|
||||||
{
|
files = GetLvFiles();
|
||||||
annotationImage.OriginalMediaName = mediaInfo.FName;
|
|
||||||
await ProcessDetection(annotationImage);
|
|
||||||
});
|
|
||||||
|
|
||||||
mediaInfo = Dispatcher.Invoke(() =>
|
|
||||||
{
|
|
||||||
if (LvFiles.SelectedIndex == LvFiles.Items.Count - 1)
|
|
||||||
return null;
|
|
||||||
LvFiles.SelectedIndex += 1;
|
|
||||||
return (MediaFileInfo)LvFiles.SelectedItem;
|
|
||||||
});
|
|
||||||
Dispatcher.Invoke(() => LvFiles.Items.Refresh());
|
Dispatcher.Invoke(() => LvFiles.Items.Refresh());
|
||||||
}
|
}
|
||||||
Dispatcher.Invoke(() =>
|
Dispatcher.Invoke(() =>
|
||||||
@@ -541,7 +548,7 @@ public partial class Annotator
|
|||||||
IsInferenceNow = false;
|
IsInferenceNow = false;
|
||||||
FollowAI = false;
|
FollowAI = false;
|
||||||
});
|
});
|
||||||
}, token);
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task ProcessDetection(AnnotationImage annotationImage)
|
private async Task ProcessDetection(AnnotationImage annotationImage)
|
||||||
@@ -551,6 +558,8 @@ public partial class Annotator
|
|||||||
try
|
try
|
||||||
{
|
{
|
||||||
var annotation = await _annotationService.SaveAnnotation(annotationImage);
|
var annotation = await _annotationService.SaveAnnotation(annotationImage);
|
||||||
|
if (annotation.OriginalMediaName != _formState.CurrentMedia.FName)
|
||||||
|
return;
|
||||||
AddAnnotation(annotation);
|
AddAnnotation(annotation);
|
||||||
|
|
||||||
if (FollowAI)
|
if (FollowAI)
|
||||||
|
|||||||
@@ -139,6 +139,7 @@ public class AnnotatorEventHandler(
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case PlaybackControlEnum.Stop:
|
case PlaybackControlEnum.Stop:
|
||||||
|
await mainWindow.DetectionCancellationSource.CancelAsync();
|
||||||
mediaPlayer.Stop();
|
mediaPlayer.Stop();
|
||||||
break;
|
break;
|
||||||
case PlaybackControlEnum.PreviousFrame:
|
case PlaybackControlEnum.PreviousFrame:
|
||||||
@@ -294,7 +295,7 @@ public class AnnotatorEventHandler(
|
|||||||
media.HasAnnotations = false;
|
media.HasAnnotations = false;
|
||||||
mainWindow.LvFiles.Items.Refresh();
|
mainWindow.LvFiles.Items.Refresh();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
await Task.CompletedTask;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ public class Constants
|
|||||||
public const double TRACKING_INTERSECTION_THRESHOLD = 0.8;
|
public const double TRACKING_INTERSECTION_THRESHOLD = 0.8;
|
||||||
public const int DEFAULT_FRAME_PERIOD_RECOGNITION = 4;
|
public const int DEFAULT_FRAME_PERIOD_RECOGNITION = 4;
|
||||||
|
|
||||||
|
public const int DETECTION_BATCH_SIZE = 4;
|
||||||
# endregion AIRecognitionConfig
|
# endregion AIRecognitionConfig
|
||||||
|
|
||||||
#region Thumbnails
|
#region Thumbnails
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ public class YoloLabel : Label
|
|||||||
[MessagePackObject]
|
[MessagePackObject]
|
||||||
public class Detection : YoloLabel
|
public class Detection : YoloLabel
|
||||||
{
|
{
|
||||||
[IgnoreMember]public string AnnotationName { get; set; } = null!;
|
[Key("an")] public string AnnotationName { get; set; } = null!;
|
||||||
[Key("p")] public double? Probability { get; set; }
|
[Key("p")] public double? Probability { get; set; }
|
||||||
|
|
||||||
//For db & serialization
|
//For db & serialization
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ public class Annotation
|
|||||||
_thumbDir = config.ThumbnailsDirectory;
|
_thumbDir = config.ThumbnailsDirectory;
|
||||||
}
|
}
|
||||||
|
|
||||||
[IgnoreMember]public string Name { get; set; } = null!;
|
[Key("n")] public string Name { get; set; } = null!;
|
||||||
[IgnoreMember]public string OriginalMediaName { get; set; } = null!;
|
[Key("mn")] public string OriginalMediaName { get; set; } = null!;
|
||||||
[IgnoreMember]public TimeSpan Time { get; set; }
|
[IgnoreMember]public TimeSpan Time { get; set; }
|
||||||
[IgnoreMember]public string ImageExtension { get; set; } = null!;
|
[IgnoreMember]public string ImageExtension { get; set; } = null!;
|
||||||
[IgnoreMember]public DateTime CreatedDate { get; set; }
|
[IgnoreMember]public DateTime CreatedDate { get; set; }
|
||||||
|
|||||||
@@ -105,9 +105,6 @@ public class AnnotationService : INotificationHandler<AnnotationsDeletedEvent>
|
|||||||
public async Task<Annotation> SaveAnnotation(AnnotationImage a, CancellationToken cancellationToken = default)
|
public async Task<Annotation> SaveAnnotation(AnnotationImage a, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
a.Time = TimeSpan.FromMilliseconds(a.Milliseconds);
|
a.Time = TimeSpan.FromMilliseconds(a.Milliseconds);
|
||||||
a.Name = a.OriginalMediaName.ToTimeName(a.Time);
|
|
||||||
foreach (var det in a.Detections)
|
|
||||||
det.AnnotationName = a.Name;
|
|
||||||
return await SaveAnnotationInner(DateTime.Now, a.OriginalMediaName, a.Time, ".jpg", a.Detections.ToList(),
|
return await SaveAnnotationInner(DateTime.Now, a.OriginalMediaName, a.Time, ".jpg", a.Detections.ToList(),
|
||||||
a.Source, new MemoryStream(a.Image), a.CreatedRole, a.CreatedEmail, generateThumbnail: true, cancellationToken);
|
a.Source, new MemoryStream(a.Image), a.CreatedRole, a.CreatedEmail, generateThumbnail: true, cancellationToken);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,23 +3,23 @@ using Azaion.Common.Database;
|
|||||||
using Azaion.Common.DTO.Config;
|
using Azaion.Common.DTO.Config;
|
||||||
using Azaion.CommonSecurity;
|
using Azaion.CommonSecurity;
|
||||||
using Azaion.CommonSecurity.DTO.Commands;
|
using Azaion.CommonSecurity.DTO.Commands;
|
||||||
using Azaion.CommonSecurity.Services;
|
|
||||||
using MessagePack;
|
using MessagePack;
|
||||||
using Microsoft.Extensions.Logging;
|
using Microsoft.Extensions.Logging;
|
||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
using NetMQ;
|
using NetMQ;
|
||||||
using NetMQ.Sockets;
|
using NetMQ.Sockets;
|
||||||
|
using Newtonsoft.Json;
|
||||||
|
|
||||||
namespace Azaion.Common.Services;
|
namespace Azaion.Common.Services;
|
||||||
|
|
||||||
public interface IInferenceService
|
public interface IInferenceService
|
||||||
{
|
{
|
||||||
Task RunInference(string mediaPath, Func<AnnotationImage, Task> processAnnotation);
|
Task RunInference(List<string> mediaPaths, Func<AnnotationImage, Task> processAnnotation, CancellationToken ct = default);
|
||||||
}
|
}
|
||||||
|
|
||||||
public class PythonInferenceService(ILogger<PythonInferenceService> logger, IOptions<AIRecognitionConfig> aiConfigOptions) : IInferenceService
|
public class PythonInferenceService(ILogger<PythonInferenceService> logger, IOptions<AIRecognitionConfig> aiConfigOptions) : IInferenceService
|
||||||
{
|
{
|
||||||
public async Task RunInference(string mediaPath, Func<AnnotationImage, Task> processAnnotation)
|
public async Task RunInference(List<string> mediaPaths, Func<AnnotationImage, Task> processAnnotation, CancellationToken ct = default)
|
||||||
{
|
{
|
||||||
using var dealer = new DealerSocket();
|
using var dealer = new DealerSocket();
|
||||||
var clientId = Guid.NewGuid();
|
var clientId = Guid.NewGuid();
|
||||||
@@ -27,13 +27,14 @@ public class PythonInferenceService(ILogger<PythonInferenceService> logger, IOpt
|
|||||||
dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}");
|
dealer.Connect($"tcp://{SecurityConstants.ZMQ_HOST}:{SecurityConstants.ZMQ_PORT}");
|
||||||
|
|
||||||
var data = MessagePackSerializer.Serialize(aiConfigOptions.Value);
|
var data = MessagePackSerializer.Serialize(aiConfigOptions.Value);
|
||||||
dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Inference, mediaPath, data)));
|
var filename = JsonConvert.SerializeObject(mediaPaths);
|
||||||
|
dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Inference, filename, data)));
|
||||||
|
|
||||||
while (true)
|
while (!ct.IsCancellationRequested)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
var annotationStream = dealer.Get<AnnotationImage>(bytes => bytes.Length == 4 && Encoding.UTF8.GetString(bytes) == "DONE");
|
var annotationStream = dealer.Get<AnnotationImage>(bytes => bytes.Length == 4 && Encoding.UTF8.GetString(bytes) == "DONE", ct: ct);
|
||||||
if (annotationStream == null)
|
if (annotationStream == null)
|
||||||
break;
|
break;
|
||||||
|
|
||||||
@@ -42,6 +43,7 @@ public class PythonInferenceService(ILogger<PythonInferenceService> logger, IOpt
|
|||||||
catch (Exception e)
|
catch (Exception e)
|
||||||
{
|
{
|
||||||
logger.LogError(e, e.Message);
|
logger.LogError(e, e.Message);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ public class PythonResourceLoader : IResourceLoader, IAuthProvider
|
|||||||
{
|
{
|
||||||
_dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Load, fileName)));
|
_dealer.SendFrame(MessagePackSerializer.Serialize(new RemoteCommand(CommandType.Load, fileName)));
|
||||||
|
|
||||||
if (!_dealer.TryReceiveFrameBytes(TimeSpan.FromMilliseconds(1000), out var bytes))
|
if (!_dealer.TryReceiveFrameBytes(TimeSpan.FromSeconds(3), out var bytes))
|
||||||
throw new Exception($"Unable to receive {fileName}");
|
throw new Exception($"Unable to receive {fileName}");
|
||||||
|
|
||||||
return new MemoryStream(bytes);
|
return new MemoryStream(bytes);
|
||||||
|
|||||||
@@ -6,12 +6,23 @@ namespace Azaion.CommonSecurity;
|
|||||||
|
|
||||||
public static class ZeroMqExtensions
|
public static class ZeroMqExtensions
|
||||||
{
|
{
|
||||||
public static T? Get<T>(this DealerSocket dealer, Func<byte[], bool>? shouldInterceptFn = null) where T : class
|
public static T? Get<T>(this DealerSocket dealer, Func<byte[], bool>? shouldInterceptFn = null, int retries = 24, int tryTimeoutSeconds = 5, CancellationToken ct = default) where T : class
|
||||||
{
|
{
|
||||||
if (!dealer.TryReceiveFrameBytes(TimeSpan.FromMinutes(2), out var bytes))
|
var tryNum = 0;
|
||||||
throw new Exception($"Unable to get {typeof(T).Name}");
|
while (!ct.IsCancellationRequested && tryNum++ < retries)
|
||||||
|
{
|
||||||
|
if (!dealer.TryReceiveFrameBytes(TimeSpan.FromSeconds(tryTimeoutSeconds), out var bytes))
|
||||||
|
continue;
|
||||||
|
|
||||||
if (shouldInterceptFn != null && shouldInterceptFn(bytes))
|
if (shouldInterceptFn != null && shouldInterceptFn(bytes))
|
||||||
return null;
|
return null;
|
||||||
|
|
||||||
return MessagePackSerializer.Deserialize<T>(bytes);
|
return MessagePackSerializer.Deserialize<T>(bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!ct.IsCancellationRequested)
|
||||||
|
throw new Exception($"Unable to get {typeof(T).Name} after {tryNum} retries, {tryTimeoutSeconds} seconds each");
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -13,6 +13,17 @@ Results (file or annotations) is putted to the other queue, or the same socket,
|
|||||||
|
|
||||||
<h2>Installation</h2>
|
<h2>Installation</h2>
|
||||||
|
|
||||||
|
Prepare correct onnx model from YOLO:
|
||||||
|
```python
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import netron
|
||||||
|
|
||||||
|
model = YOLO("azaion.pt")
|
||||||
|
model.export(format="onnx", imgsz=1280, nms=True, batch=4)
|
||||||
|
netron.start('azaion.onnx')
|
||||||
|
```
|
||||||
|
Read carefully about [export arguments](https://docs.ultralytics.com/modes/export/), you have to use nms=True, and batching with a proper batch size
|
||||||
|
|
||||||
<h3>Install libs</h3>
|
<h3>Install libs</h3>
|
||||||
https://www.python.org/downloads/
|
https://www.python.org/downloads/
|
||||||
|
|
||||||
@@ -45,7 +56,7 @@ This is crucial for the build because build needs Python.h header and other file
|
|||||||
|
|
||||||
```
|
```
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install opencv-python cython msgpack cryptography rstream pika zmq pyjwt pyinstaller tensorboard
|
pip install requirements.txt
|
||||||
```
|
```
|
||||||
In case of fbgemm.dll error (Windows specific):
|
In case of fbgemm.dll error (Windows specific):
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,17 @@
|
|||||||
cdef class Detection:
|
cdef class Detection:
|
||||||
cdef public double x, y, w, h, confidence
|
cdef public double x, y, w, h, confidence
|
||||||
|
cdef public str annotation_name
|
||||||
cdef public int cls
|
cdef public int cls
|
||||||
|
|
||||||
|
cdef public overlaps(self, Detection det2)
|
||||||
|
|
||||||
cdef class Annotation:
|
cdef class Annotation:
|
||||||
cdef bytes image
|
cdef public str name
|
||||||
|
cdef public str original_media_name
|
||||||
cdef long time
|
cdef long time
|
||||||
cdef public list[Detection] detections
|
cdef public list[Detection] detections
|
||||||
|
cdef public bytes image
|
||||||
|
|
||||||
|
cdef format_time(self, ms)
|
||||||
cdef bytes serialize(self)
|
cdef bytes serialize(self)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import msgpack
|
import msgpack
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
cdef class Detection:
|
cdef class Detection:
|
||||||
def __init__(self, double x, double y, double w, double h, int cls, double confidence):
|
def __init__(self, double x, double y, double w, double h, int cls, double confidence):
|
||||||
|
self.annotation_name = None
|
||||||
self.x = x
|
self.x = x
|
||||||
self.y = y
|
self.y = y
|
||||||
self.w = w
|
self.w = w
|
||||||
@@ -12,18 +14,44 @@ cdef class Detection:
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'{self.cls}: {self.x:.2f} {self.y:.2f} {self.w:.2f} {self.h:.2f}, prob: {(self.confidence*100):.1f}%'
|
return f'{self.cls}: {self.x:.2f} {self.y:.2f} {self.w:.2f} {self.h:.2f}, prob: {(self.confidence*100):.1f}%'
|
||||||
|
|
||||||
|
cdef overlaps(self, Detection det2):
|
||||||
|
cdef double overlap_x = 0.5 * (self.w + det2.w) - abs(self.x - det2.x)
|
||||||
|
cdef double overlap_y = 0.5 * (self.h + det2.h) - abs(self.y - det2.y)
|
||||||
|
cdef double overlap_area = max(0.0, overlap_x) * max(0.0, overlap_y)
|
||||||
|
cdef double min_area = min(self.w * self.h, det2.w * det2.h)
|
||||||
|
|
||||||
|
return overlap_area / min_area > 0.6
|
||||||
|
|
||||||
cdef class Annotation:
|
cdef class Annotation:
|
||||||
def __init__(self, long time, list[Detection] detections):
|
def __init__(self, str name, long ms, list[Detection] detections):
|
||||||
self.time = time
|
self.original_media_name = Path(<str>name).stem.replace(" ", "")
|
||||||
|
self.name = f'{self.original_media_name}_{self.format_time(ms)}'
|
||||||
|
self.time = ms
|
||||||
self.detections = detections if detections is not None else []
|
self.detections = detections if detections is not None else []
|
||||||
|
for d in self.detections:
|
||||||
|
d.annotation_name = self.name
|
||||||
self.image = b''
|
self.image = b''
|
||||||
|
|
||||||
|
cdef format_time(self, ms):
|
||||||
|
# Calculate hours, minutes, seconds, and hundreds of milliseconds.
|
||||||
|
h = ms // 3600000 # Total full hours.
|
||||||
|
ms_remaining = ms % 3600000
|
||||||
|
m = ms_remaining // 60000 # Full minutes.
|
||||||
|
ms_remaining %= 60000
|
||||||
|
s = ms_remaining // 1000 # Full seconds.
|
||||||
|
f = (ms_remaining % 1000) // 100 # Hundreds of milliseconds.
|
||||||
|
h = h % 10
|
||||||
|
return f"{h}{m:02}{s:02}{f}"
|
||||||
|
|
||||||
cdef bytes serialize(self):
|
cdef bytes serialize(self):
|
||||||
return msgpack.packb({
|
return msgpack.packb({
|
||||||
|
"n": self.name,
|
||||||
|
"mn": self.original_media_name,
|
||||||
"i": self.image, # "i" = image
|
"i": self.image, # "i" = image
|
||||||
"t": self.time, # "t" = time
|
"t": self.time, # "t" = time
|
||||||
"d": [ # "d" = detections
|
"d": [ # "d" = detections
|
||||||
{
|
{
|
||||||
|
"an": det.annotation_name,
|
||||||
"x": det.x,
|
"x": det.x,
|
||||||
"y": det.y,
|
"y": det.y,
|
||||||
"w": det.w,
|
"w": det.w,
|
||||||
|
|||||||
@@ -10,3 +10,4 @@ cdef str QUEUE_CONFIG_FILENAME # queue config filename to load from api
|
|||||||
cdef str AI_MODEL_FILE # AI Model file
|
cdef str AI_MODEL_FILE # AI Model file
|
||||||
|
|
||||||
cdef bytes DONE_SIGNAL
|
cdef bytes DONE_SIGNAL
|
||||||
|
cdef int MODEL_BATCH_SIZE
|
||||||
@@ -10,3 +10,4 @@ cdef str QUEUE_CONFIG_FILENAME = "secured-config.json"
|
|||||||
cdef str AI_MODEL_FILE = "azaion.onnx"
|
cdef str AI_MODEL_FILE = "azaion.onnx"
|
||||||
|
|
||||||
cdef bytes DONE_SIGNAL = b"DONE"
|
cdef bytes DONE_SIGNAL = b"DONE"
|
||||||
|
cdef int MODEL_BATCH_SIZE = 4
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from remote_command cimport RemoteCommand
|
from remote_command cimport RemoteCommand
|
||||||
from annotation cimport Annotation
|
from annotation cimport Annotation, Detection
|
||||||
from ai_config cimport AIRecognitionConfig
|
from ai_config cimport AIRecognitionConfig
|
||||||
|
|
||||||
cdef class Inference:
|
cdef class Inference:
|
||||||
@@ -14,14 +14,14 @@ cdef class Inference:
|
|||||||
cdef int model_height
|
cdef int model_height
|
||||||
|
|
||||||
cdef bint is_video(self, str filepath)
|
cdef bint is_video(self, str filepath)
|
||||||
cdef run_inference(self, RemoteCommand cmd, int batch_size=?)
|
cdef run_inference(self, RemoteCommand cmd)
|
||||||
cdef _process_video(self, RemoteCommand cmd, int batch_size)
|
cdef _process_video(self, RemoteCommand cmd, str video_name)
|
||||||
cdef _process_image(self, RemoteCommand cmd)
|
cdef _process_images(self, RemoteCommand cmd, list[str] image_paths)
|
||||||
cdef stop(self)
|
cdef stop(self)
|
||||||
|
|
||||||
cdef preprocess(self, frame)
|
cdef preprocess(self, frames)
|
||||||
cdef postprocess(self, output, int img_width, int img_height)
|
cdef remove_overlapping_detections(self, list[Detection] detections)
|
||||||
|
cdef postprocess(self, output)
|
||||||
|
cdef split_list_extend(self, lst, chunk_size)
|
||||||
|
|
||||||
|
|
||||||
cdef detect_frame(self, frame, long time)
|
|
||||||
cdef bint is_valid_annotation(self, Annotation annotation)
|
cdef bint is_valid_annotation(self, Annotation annotation)
|
||||||
|
|||||||
+126
-75
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -5,6 +6,7 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as onnx
|
import onnxruntime as onnx
|
||||||
|
|
||||||
|
cimport constants
|
||||||
from remote_command cimport RemoteCommand
|
from remote_command cimport RemoteCommand
|
||||||
from annotation cimport Detection, Annotation
|
from annotation cimport Detection, Annotation
|
||||||
from ai_config cimport AIRecognitionConfig
|
from ai_config cimport AIRecognitionConfig
|
||||||
@@ -26,68 +28,117 @@ cdef class Inference:
|
|||||||
model_meta = self.session.get_modelmeta()
|
model_meta = self.session.get_modelmeta()
|
||||||
print("Metadata:", model_meta.custom_metadata_map)
|
print("Metadata:", model_meta.custom_metadata_map)
|
||||||
|
|
||||||
cdef preprocess(self, frame):
|
cdef preprocess(self, frames):
|
||||||
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
blobs = [cv2.dnn.blobFromImage(frame,
|
||||||
img = cv2.resize(img, (self.model_width, self.model_height))
|
scalefactor=1.0 / 255.0,
|
||||||
image_data = np.array(img) / 255.0
|
size=(self.model_width, self.model_height),
|
||||||
image_data = np.transpose(image_data, (2, 0, 1)) # Channel first
|
mean=(0, 0, 0),
|
||||||
image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
|
swapRB=True,
|
||||||
return image_data
|
crop=False)
|
||||||
|
for frame in frames]
|
||||||
|
return np.vstack(blobs)
|
||||||
|
|
||||||
cdef postprocess(self, output, int img_width, int img_height):
|
|
||||||
outputs = np.transpose(np.squeeze(output[0]))
|
|
||||||
rows = outputs.shape[0]
|
|
||||||
|
|
||||||
boxes = []
|
cdef postprocess(self, output):
|
||||||
scores = []
|
cdef list[Detection] detections = []
|
||||||
class_ids = []
|
cdef int ann_index
|
||||||
|
cdef float x1, y1, x2, y2, conf, cx, cy, w, h
|
||||||
|
cdef int class_id
|
||||||
|
cdef list[list[Detection]] results = []
|
||||||
|
|
||||||
x_factor = img_width / self.model_width
|
for ann_index in range(len(output[0])):
|
||||||
y_factor = img_height / self.model_height
|
detections.clear()
|
||||||
|
for det in output[0][ann_index]:
|
||||||
|
if det[4] == 0: # if confidence is 0 then valid points are over.
|
||||||
|
break
|
||||||
|
x1 = det[0] / self.model_width
|
||||||
|
y1 = det[1] / self.model_height
|
||||||
|
x2 = det[2] / self.model_width
|
||||||
|
y2 = det[3] / self.model_height
|
||||||
|
conf = round(det[4], 2)
|
||||||
|
class_id = int(det[5])
|
||||||
|
|
||||||
for i in range(rows):
|
x = (x1 + x2) / 2
|
||||||
classes_scores = outputs[i][4:]
|
y = (y1 + y2) / 2
|
||||||
max_score = np.amax(classes_scores)
|
w = x2 - x1
|
||||||
|
h = y2 - y1
|
||||||
|
detections.append(Detection(x, y, w, h, class_id, conf))
|
||||||
|
filtered_detections = self.remove_overlapping_detections(detections)
|
||||||
|
results.append(filtered_detections)
|
||||||
|
return results
|
||||||
|
|
||||||
if max_score >= self.ai_config.probability_threshold:
|
cdef remove_overlapping_detections(self, list[Detection] detections):
|
||||||
class_id = np.argmax(classes_scores)
|
cdef Detection det1, det2
|
||||||
x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]
|
filtered_output = []
|
||||||
|
filtered_out_indexes = []
|
||||||
|
|
||||||
left = int((x - w / 2) * x_factor)
|
for det1_index in range(len(detections)):
|
||||||
top = int((y - h / 2) * y_factor)
|
if det1_index in filtered_out_indexes:
|
||||||
width = int(w * x_factor)
|
continue
|
||||||
height = int(h * y_factor)
|
det1 = detections[det1_index]
|
||||||
|
print(f'det1 size: {det1.w}, {det1.h}')
|
||||||
class_ids.append(class_id)
|
res = det1_index
|
||||||
scores.append(max_score)
|
for det2_index in range(det1_index + 1, len(detections)):
|
||||||
boxes.append([left, top, width, height])
|
det2 = detections[det2_index]
|
||||||
indices = cv2.dnn.NMSBoxes(boxes, scores, self.ai_config.probability_threshold, 0.45)
|
print(f'det2 size: {det2.w}, {det2.h}')
|
||||||
detections = []
|
if det1.overlaps(det2):
|
||||||
for i in indices:
|
if det1.confidence > det2.confidence or (
|
||||||
x, y, w, h = boxes[i]
|
det1.confidence == det2.confidence and det1.cls < det2.cls): # det1 has higher confidence or lower class_id
|
||||||
detections.append(Detection(x, y, w, h, class_ids[i], scores[i]))
|
filtered_out_indexes.append(det2_index)
|
||||||
return detections
|
else:
|
||||||
|
filtered_out_indexes.append(res)
|
||||||
|
res = det2_index
|
||||||
|
filtered_output.append(detections[res])
|
||||||
|
filtered_out_indexes.append(res)
|
||||||
|
return filtered_output
|
||||||
|
|
||||||
cdef bint is_video(self, str filepath):
|
cdef bint is_video(self, str filepath):
|
||||||
mime_type, _ = mimetypes.guess_type(<str>filepath)
|
mime_type, _ = mimetypes.guess_type(<str>filepath)
|
||||||
return mime_type and mime_type.startswith("video")
|
return mime_type and mime_type.startswith("video")
|
||||||
|
|
||||||
cdef run_inference(self, RemoteCommand cmd, int batch_size=8):
|
cdef split_list_extend(self, lst, chunk_size):
|
||||||
print('run inference..')
|
chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
|
||||||
|
|
||||||
|
# If the last chunk is smaller than the desired chunk_size, extend it by duplicating its last element.
|
||||||
|
last_chunk = chunks[len(chunks) - 1]
|
||||||
|
if len(last_chunk) < chunk_size:
|
||||||
|
last_elem = last_chunk[len(last_chunk)-1]
|
||||||
|
while len(last_chunk) < chunk_size:
|
||||||
|
last_chunk.append(last_elem)
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
cdef run_inference(self, RemoteCommand cmd):
|
||||||
|
cdef list[str] medias = json.loads(<str> cmd.filename)
|
||||||
|
cdef list[str] videos = []
|
||||||
|
cdef list[str] images = []
|
||||||
|
|
||||||
self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data)
|
self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data)
|
||||||
self.stop_signal = False
|
self.stop_signal = False
|
||||||
if self.is_video(cmd.filename):
|
|
||||||
self._process_video(cmd, batch_size)
|
for m in medias:
|
||||||
|
if self.is_video(m):
|
||||||
|
videos.append(m)
|
||||||
else:
|
else:
|
||||||
self._process_image(cmd)
|
images.append(m)
|
||||||
|
|
||||||
cdef _process_video(self, RemoteCommand cmd, int batch_size):
|
# images first, it's faster
|
||||||
frame_count = 0
|
if len(images) > 0:
|
||||||
batch_frame = []
|
for chunk in self.split_list_extend(images, constants.MODEL_BATCH_SIZE):
|
||||||
|
print(f'run inference on {" ".join(chunk)}...')
|
||||||
|
self._process_images(cmd, chunk)
|
||||||
|
if len(videos) > 0:
|
||||||
|
for v in videos:
|
||||||
|
print(f'run inference on {v}...')
|
||||||
|
self._process_video(cmd, v)
|
||||||
|
|
||||||
|
|
||||||
|
cdef _process_video(self, RemoteCommand cmd, str video_name):
|
||||||
|
cdef int frame_count = 0
|
||||||
|
cdef list batch_frames = []
|
||||||
|
cdef list[int] batch_timestamps = []
|
||||||
self._previous_annotation = None
|
self._previous_annotation = None
|
||||||
self.start_video_time = time.time()
|
|
||||||
|
|
||||||
v_input = cv2.VideoCapture(<str>cmd.filename)
|
v_input = cv2.VideoCapture(<str>video_name)
|
||||||
while v_input.isOpened():
|
while v_input.isOpened():
|
||||||
ret, frame = v_input.read()
|
ret, frame = v_input.read()
|
||||||
if not ret or frame is None:
|
if not ret or frame is None:
|
||||||
@@ -95,43 +146,43 @@ cdef class Inference:
|
|||||||
|
|
||||||
frame_count += 1
|
frame_count += 1
|
||||||
if frame_count % self.ai_config.frame_period_recognition == 0:
|
if frame_count % self.ai_config.frame_period_recognition == 0:
|
||||||
ms = int(v_input.get(cv2.CAP_PROP_POS_MSEC))
|
batch_frames.append(frame)
|
||||||
annotation = self.detect_frame(frame, ms)
|
batch_timestamps.append(int(v_input.get(cv2.CAP_PROP_POS_MSEC)))
|
||||||
if annotation is not None:
|
|
||||||
self._previous_annotation = annotation
|
|
||||||
self.on_annotation(annotation)
|
|
||||||
|
|
||||||
|
if len(batch_frames) == constants.MODEL_BATCH_SIZE:
|
||||||
cdef detect_frame(self, frame, long time):
|
input_blob = self.preprocess(batch_frames)
|
||||||
cdef Annotation annotation
|
outputs = self.session.run(None, {self.model_input: input_blob})
|
||||||
img_height, img_width = frame.shape[:2]
|
list_detections = self.postprocess(outputs)
|
||||||
|
for i in range(len(list_detections)):
|
||||||
start_time = time.time()
|
detections = list_detections[i]
|
||||||
img_data = self.preprocess(frame)
|
annotation = Annotation(video_name, batch_timestamps[i], detections)
|
||||||
preprocess_time = time.time()
|
|
||||||
outputs = self.session.run(None, {self.model_input: img_data})
|
|
||||||
inference_time = time.time()
|
|
||||||
detections = self.postprocess(outputs, img_width, img_height)
|
|
||||||
postprocess_time = time.time()
|
|
||||||
print(f'video time, ms: {time / 1000:.3f}. total time, s : {postprocess_time - self.start_video_time:.3f} '
|
|
||||||
f'preprocess time: {preprocess_time - start_time:.3f}, inference time: {inference_time - preprocess_time:.3f},'
|
|
||||||
f' postprocess time: {postprocess_time - inference_time:.3f}, total time: {postprocess_time - start_time:.3f}')
|
|
||||||
if len(detections) > 0:
|
|
||||||
annotation = Annotation(frame, time, detections)
|
|
||||||
if self.is_valid_annotation(annotation):
|
if self.is_valid_annotation(annotation):
|
||||||
_, image = cv2.imencode('.jpg', frame)
|
_, image = cv2.imencode('.jpg', frame)
|
||||||
annotation.image = image.tobytes()
|
annotation.image = image.tobytes()
|
||||||
return annotation
|
self.on_annotation(cmd, annotation)
|
||||||
return None
|
self._previous_annotation = annotation
|
||||||
|
|
||||||
|
batch_frames.clear()
|
||||||
|
batch_timestamps.clear()
|
||||||
|
v_input.release()
|
||||||
|
|
||||||
|
|
||||||
cdef _process_image(self, RemoteCommand cmd):
|
cdef _process_images(self, RemoteCommand cmd, list[str] image_paths):
|
||||||
|
cdef list frames = []
|
||||||
|
cdef list timestamps = []
|
||||||
self._previous_annotation = None
|
self._previous_annotation = None
|
||||||
frame = cv2.imread(<str>cmd.filename)
|
for image in image_paths:
|
||||||
annotation = self.detect_frame(frame, 0)
|
frame = cv2.imread(image)
|
||||||
if annotation is None:
|
frames.append(frame)
|
||||||
_, image = cv2.imencode('.jpg', frame)
|
timestamps.append(0)
|
||||||
annotation = Annotation(frame, time, [])
|
|
||||||
|
input_blob = self.preprocess(frames)
|
||||||
|
outputs = self.session.run(None, {self.model_input: input_blob})
|
||||||
|
list_detections = self.postprocess(outputs)
|
||||||
|
for i in range(len(list_detections)):
|
||||||
|
detections = list_detections[i]
|
||||||
|
annotation = Annotation(image_paths[i], timestamps[i], detections)
|
||||||
|
_, image = cv2.imencode('.jpg', frames[i])
|
||||||
annotation.image = image.tobytes()
|
annotation.image = image.tobytes()
|
||||||
self.on_annotation(cmd, annotation)
|
self.on_annotation(cmd, annotation)
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ cdef class RemoteCommand:
|
|||||||
10: "GET_USER",
|
10: "GET_USER",
|
||||||
20: "LOAD",
|
20: "LOAD",
|
||||||
30: "INFERENCE",
|
30: "INFERENCE",
|
||||||
40: "STOP INFERENCE",
|
40: "STOP_INFERENCE",
|
||||||
100: "EXIT"
|
100: "EXIT"
|
||||||
}
|
}
|
||||||
data_str = f'. Data: {len(self.data)} bytes' if self.data else ''
|
data_str = f'. Data: {len(self.data)} bytes' if self.data else ''
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ extensions = [
|
|||||||
Extension('api_client', ['api_client.pyx']),
|
Extension('api_client', ['api_client.pyx']),
|
||||||
Extension('secure_model', ['secure_model.pyx']),
|
Extension('secure_model', ['secure_model.pyx']),
|
||||||
Extension('ai_config', ['ai_config.pyx']),
|
Extension('ai_config', ['ai_config.pyx']),
|
||||||
Extension('inference', ['inference.pyx']),
|
Extension('inference', ['inference.pyx'], include_dirs=[np.get_include()]),
|
||||||
Extension('main', ['main.pyx']),
|
Extension('main', ['main.pyx']),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1laWQiOiJkOTBhMzZjYS1lMjM3LTRmYmQtOWM3Yy0xMjcwNDBhYzg1NTYiLCJ1bmlxdWVfbmFtZSI6ImFkbWluQGF6YWlvbi5jb20iLCJyb2xlIjoiQXBpQWRtaW4iLCJuYmYiOjE3Mzg4Mjk0NTMsImV4cCI6MTczODg0Mzg1MywiaWF0IjoxNzM4ODI5NDUzLCJpc3MiOiJBemFpb25BcGkiLCJhdWQiOiJBbm5vdGF0b3JzL09yYW5nZVBpL0FkbWlucyJ9.t6ImX8KkH5IQ4zNNY5IbXESSI6uia4iuzyMhodvM7AA
|
|
||||||
Reference in New Issue
Block a user