fix ai detection bugs #1

This commit is contained in:
Alex Bezdieniezhnykh
2024-11-04 21:12:45 +02:00
parent addf7ccc11
commit d8f60d7491
10 changed files with 172 additions and 66 deletions
@@ -17,7 +17,6 @@ public class AnnotationControl : Border
private readonly TextBlock _classNameLabel;
private readonly Label _probabilityLabel;
public TimeSpan? Time { get; set; }
public double? Probability { get; set; }
private AnnotationClass _annotationClass = null!;
public AnnotationClass AnnotationClass
@@ -84,7 +83,6 @@ public class AnnotationControl : Border
{
_selectionFrame,
_classNameLabel,
_probabilityLabel,
AddRect("rLT", HorizontalAlignment.Left, VerticalAlignment.Top, Cursors.SizeNWSE),
AddRect("rCT", HorizontalAlignment.Center, VerticalAlignment.Top, Cursors.SizeNS),
AddRect("rRT", HorizontalAlignment.Right, VerticalAlignment.Top, Cursors.SizeNESW),
@@ -95,6 +93,8 @@ public class AnnotationControl : Border
AddRect("rRB", HorizontalAlignment.Right, VerticalAlignment.Bottom, Cursors.SizeNWSE)
}
};
if (probability.HasValue)
_grid.Children.Add(_probabilityLabel);
Child = _grid;
Cursor = Cursors.SizeAll;
AnnotationClass = annotationClass;
+4 -1
View File
@@ -46,6 +46,7 @@ public class AIRecognitionConfig
public double FrameRecognitionSeconds { get; set; }
public double TrackingDistanceConfidence { get; set; }
public double TrackingProbabilityIncrease { get; set; }
public double TrackingIntersectionThreshold { get; set; }
}
public class WindowConfig
@@ -83,6 +84,7 @@ public class FileConfigRepository(ILogger<FileConfigRepository> logger) : IConfi
private const double DEFAULT_FRAME_RECOGNITION_SECONDS = 2;
private const double TRACKING_DISTANCE_CONFIDENCE = 0.15;
private const double TRACKING_PROBABILITY_INCREASE = 15;
private const double TRACKING_INTERSECTION_THRESHOLD = 0.8;
private static readonly Size DefaultWindowSize = new(1280, 720);
private static readonly Point DefaultWindowLocation = new(100, 100);
@@ -131,7 +133,8 @@ public class FileConfigRepository(ILogger<FileConfigRepository> logger) : IConfi
AIModelPath = "azaion.onnx",
FrameRecognitionSeconds = DEFAULT_FRAME_RECOGNITION_SECONDS,
TrackingDistanceConfidence = TRACKING_DISTANCE_CONFIDENCE,
TrackingProbabilityIncrease = TRACKING_PROBABILITY_INCREASE
TrackingProbabilityIncrease = TRACKING_PROBABILITY_INCREASE,
TrackingIntersectionThreshold = TRACKING_INTERSECTION_THRESHOLD
}
};
}
+1
View File
@@ -15,6 +15,7 @@ public class FormState
public Size CurrentVideoSize { get; set; }
public TimeSpan CurrentVideoLength { get; set; }
public bool BackgroundShown { get; set; }
public int CurrentVolume { get; set; } = 100;
public ObservableCollection<AnnotationResult> AnnotationResults { get; set; } = [];
public WindowsEnum ActiveWindow { get; set; }
+20 -2
View File
@@ -1,7 +1,8 @@
using System.Globalization;
using System.Drawing;
using System.Globalization;
using System.IO;
using System.Windows;
using Newtonsoft.Json;
using Size = System.Windows.Size;
namespace Azaion.Annotator.DTO;
@@ -98,6 +99,9 @@ public class YoloLabel : Label
Height = height;
}
public RectangleF ToRectangle() =>
new((float)(CenterX - Width / 2.0), (float)(CenterY - Height / 2.0), (float)Width, (float)Height);
public YoloLabel(CanvasLabel canvasLabel, Size canvasSize, Size videoSize)
{
var cw = canvasSize.Width;
@@ -168,4 +172,18 @@ public class YoloLabel : Label
}
public override string ToString() => $"{ClassNumber} {CenterX:F5} {CenterY:F5} {Width:F5} {Height:F5}".Replace(',', '.');
}
public class Detection : YoloLabel
{
public Detection(YoloLabel label, double? probability = null)
{
ClassNumber = label.ClassNumber;
CenterX = label.CenterX;
CenterY = label.CenterY;
Height = label.Height;
Width = label.Width;
Probability = probability;
}
public double? Probability { get; set; }
}
@@ -0,0 +1,8 @@
using System.Drawing;
namespace Azaion.Annotator.Extensions;
public static class RectangleFExtensions
{
public static double Area(this RectangleF rectangle) => rectangle.Width * rectangle.Height;
}
@@ -36,13 +36,13 @@ public class VLCFrameExtractor(LibVLC libVLC)
public async IAsyncEnumerable<(TimeSpan Time, Stream Stream)> ExtractFrames(string mediaPath,
[EnumeratorCancellation] CancellationToken manualCancellationToken = default)
{
var videoFinishedCancellationToken = new CancellationTokenSource();
var videoFinishedCancellationSource = new CancellationTokenSource();
_mediaPlayer = new MediaPlayer(libVLC);
_mediaPlayer.Stopped += (s, e) => videoFinishedCancellationToken.CancelAfter(1);
_mediaPlayer.Stopped += (s, e) => videoFinishedCancellationSource.CancelAfter(1);
using var media = new Media(libVLC, mediaPath);
await media.Parse(cancellationToken: videoFinishedCancellationToken.Token);
await media.Parse(cancellationToken: videoFinishedCancellationSource.Token);
var videoTrack = media.Tracks.FirstOrDefault(x => x.Data.Video.Width != 0);
_width = videoTrack.Data.Video.Width;
_height = videoTrack.Data.Video.Height;
@@ -58,9 +58,9 @@ public class VLCFrameExtractor(LibVLC libVLC)
_mediaPlayer.Play(media);
_frameCounter = 0;
var surface = SKSurface.Create(new SKImageInfo((int) _width, (int) _height));
var token = videoFinishedCancellationToken.Token;
var videoFinishedCT = videoFinishedCancellationSource.Token;
while (!(FramesQueue.IsEmpty && token.IsCancellationRequested) && !manualCancellationToken.IsCancellationRequested)
while ( !(FramesQueue.IsEmpty && videoFinishedCT.IsCancellationRequested || manualCancellationToken.IsCancellationRequested))
{
if (FramesQueue.TryDequeue(out var frameInfo))
{
@@ -80,9 +80,10 @@ public class VLCFrameExtractor(LibVLC libVLC)
}
else
{
await Task.Delay(TimeSpan.FromSeconds(1), token);
await Task.Delay(TimeSpan.FromSeconds(1), videoFinishedCT);
}
}
FramesQueue.Clear(); //clear queue in case of manual stop
_mediaPlayer.Stop();
_mediaPlayer.Dispose();
}
+59 -40
View File
@@ -18,6 +18,7 @@ using Size = System.Windows.Size;
using IntervalTree;
using Microsoft.Extensions.Logging;
using OpenTK.Graphics.OpenGL;
using ScottPlot.TickGenerators.TimeUnits;
using Serilog;
using MediaPlayer = LibVLCSharp.Shared.MediaPlayer;
@@ -155,6 +156,7 @@ public partial class MainWindow
{
VideoView.MediaPlayer = _mediaPlayer;
//On start playing media
_mediaPlayer.Playing += async (sender, args) =>
{
if (_formState.CurrentMrl == _mediaPlayer.Media?.Mrl)
@@ -167,13 +169,13 @@ public partial class MainWindow
_formState.CurrentVideoLength = TimeSpan.FromMilliseconds(_mediaPlayer.Length);
await Dispatcher.Invoke(async () => await ReloadAnnotations(_cancellationTokenSource.Token));
if (_formState.CurrentMedia?.MediaType != MediaTypes.Image)
return;
//if image show annotations, give 100ms to load the frame and set on pause
await Task.Delay(100);
ShowCurrentAnnotations();
_mediaPlayer.SetPause(true);
if (_formState.CurrentMedia?.MediaType == MediaTypes.Image)
{
await Task.Delay(100); //wait to load the frame and set on pause
ShowTimeAnnotations(TimeSpan.FromMilliseconds(_mediaPlayer.Time));
_mediaPlayer.SetPause(true);
}
};
LvFiles.MouseDoubleClick += async (_, _) => await _mediator.Publish(new PlaybackControlEvent(PlaybackControlEnum.Play));
@@ -203,13 +205,12 @@ public partial class MainWindow
DgAnnotations.MouseDoubleClick += (sender, args) =>
{
Editor.RemoveAllAnns();
var dgRow = ItemsControl.ContainerFromElement((DataGrid)sender, (args.OriginalSource as DependencyObject)!) as DataGridRow;
var res = (AnnotationResult)dgRow!.Item;
_mediaPlayer.SetPause(true);
Editor.RemoveAllAnns();
_mediaPlayer.Time = (long)res.Time.TotalMilliseconds;
ShowTimeAnnotations(res.Time);
ShowTimeAnnotations(res.Time, showImage: true);
};
DgAnnotations.KeyUp += (sender, args) =>
@@ -255,27 +256,47 @@ public partial class MainWindow
}, TimeSpan.FromSeconds(5));
}
public void ShowCurrentAnnotations() => ShowTimeAnnotations(TimeSpan.FromMilliseconds(_mediaPlayer.Time));
private void ShowTimeAnnotations(TimeSpan time)
private void ShowTimeAnnotations(TimeSpan time, bool showImage = false)
{
Dispatcher.Invoke(() => VideoSlider.Value = _mediaPlayer.Position * VideoSlider.Maximum);
Dispatcher.Invoke(() => StatusClock.Text = $"{TimeSpan.FromMilliseconds(_mediaPlayer.Time):mm\\:ss} / {_formState.CurrentVideoLength:mm\\:ss}");
Dispatcher.Invoke(() =>
{
VideoSlider.Value = _mediaPlayer.Position * VideoSlider.Maximum;
StatusClock.Text = $"{TimeSpan.FromMilliseconds(_mediaPlayer.Time):mm\\:ss} / {_formState.CurrentVideoLength:mm\\:ss}";
Editor.ClearExpiredAnnotations(time);
});
Dispatcher.Invoke(() => Editor.ClearExpiredAnnotations(time));
var annotations = Annotations.Query(time).SelectMany(x => x).ToList();
foreach (var ann in annotations)
AddAnnotationToCanvas(time, new CanvasLabel(ann, Editor.RenderSize, _formState.CurrentVideoSize));
var annotations = Annotations.Query(time).SelectMany(x => x).Select(x => new Detection(x));
AddAnnotationsToCanvas(time, annotations, showImage);
}
private void AddAnnotationToCanvas(TimeSpan? time, CanvasLabel canvasLabel)
private void AddAnnotationsToCanvas(TimeSpan? time, IEnumerable<Detection> labels, bool showImage = false)
{
var annClass = _config.AnnotationClasses[canvasLabel.ClassNumber];
Dispatcher.Invoke(() => Editor.CreateAnnotation(annClass, time, canvasLabel));
Dispatcher.Invoke(async () =>
{
var canvasSize = Editor.RenderSize;
var videoSize = _formState.CurrentVideoSize;
if (showImage)
{
var fName = _formState.GetTimeName(time);
var imgPath = Path.Combine(_config.ImagesDirectory, $"{fName}.jpg");
if (File.Exists(imgPath))
{
Editor.Background = new ImageBrush { ImageSource = await imgPath.OpenImage() };
_formState.BackgroundShown = true;
videoSize = Editor.RenderSize;
}
}
foreach (var label in labels)
{
var annClass = _config.AnnotationClasses[label.ClassNumber];
var canvasLabel = new CanvasLabel(label, canvasSize, videoSize, label.Probability);
Editor.CreateAnnotation(annClass, time, canvasLabel);
}
});
}
private async Task ReloadAnnotations(CancellationToken cancellationToken)
private async Task ReloadAnnotations(CancellationToken ct = default)
{
_formState.AnnotationResults.Clear();
Annotations.Clear();
@@ -290,11 +311,11 @@ public partial class MainWindow
{
var name = Path.GetFileNameWithoutExtension(file.Name);
var time = _formState.GetTime(name);
await AddAnnotations(time, await YoloLabel.ReadFromFile(file.FullName, cancellationToken));
await AddAnnotations(time, await YoloLabel.ReadFromFile(file.FullName, ct), ct);
}
}
public async Task AddAnnotations(TimeSpan? time, List<YoloLabel> annotations)
public async Task AddAnnotations(TimeSpan? time, List<YoloLabel> annotations, CancellationToken ct = default)
{
var timeValue = time ?? TimeSpan.FromMinutes(0);
var previousAnnotations = Annotations.Query(timeValue);
@@ -315,7 +336,7 @@ public partial class MainWindow
.FirstOrDefault();
_formState.AnnotationResults.Insert(index, new AnnotationResult(timeValue, _formState.GetTimeName(time), annotations, _config));
await File.WriteAllTextAsync($"{_config.ResultsDirectory}/{_formState.VideoName}.json", JsonConvert.SerializeObject(_formState.AnnotationResults));
await File.WriteAllTextAsync($"{_config.ResultsDirectory}/{_formState.VideoName}.json", JsonConvert.SerializeObject(_formState.AnnotationResults), ct);
}
private void ReloadFiles()
@@ -481,7 +502,7 @@ public partial class MainWindow
LvFilesContextMenu.DataContext = listItem.DataContext;
}
private (TimeSpan Time, List<(YoloLabel Label, float Probability)> Detections)? _previousDetection;
private (TimeSpan Time, List<Detection> Detections)? _previousDetection;
public void AutoDetect(object sender, RoutedEventArgs e)
{
@@ -540,14 +561,14 @@ public partial class MainWindow
await manualCancellationSource.CancelAsync();
}
}
_autoDetectDialog.Close();
Dispatcher.Invoke(() => _autoDetectDialog.Close());
}, token);
_autoDetectDialog.ShowDialog();
Dispatcher.Invoke(() => Editor.Background = new SolidColorBrush(Color.FromArgb(1, 0, 0, 0)));
}
private bool IsValidDetection(TimeSpan time, List<(YoloLabel Label, float Probability)> detections)
private bool IsValidDetection(TimeSpan time, List<Detection> detections)
{
// No AI detection, forbid
if (detections.Count == 0)
@@ -572,12 +593,12 @@ public partial class MainWindow
foreach (var det in detections)
{
var point = new Point(det.Label.CenterX, det.Label.CenterY);
var point = new Point(det.CenterX, det.CenterY);
var closestObject = prev.Detections
.Select(p => new
{
Point = p,
Distance = point.SqrDistance(new Point(p.Label.CenterX, p.Label.CenterY))
Distance = point.SqrDistance(new Point(p.CenterX, p.CenterY))
})
.OrderBy(x => x.Distance)
.First();
@@ -594,7 +615,7 @@ public partial class MainWindow
return false;
}
private async Task ProcessDetection((TimeSpan Time, Stream Stream) timeframe, List<(YoloLabel Label, float Probability)> detections, CancellationToken token = default)
private async Task ProcessDetection((TimeSpan Time, Stream Stream) timeframe, List<Detection> detections, CancellationToken token = default)
{
_previousDetection = (timeframe.Time, detections);
await Dispatcher.Invoke(async () =>
@@ -602,31 +623,29 @@ public partial class MainWindow
try
{
var time = timeframe.Time;
var labels = detections.Select(x => x.Label).ToList();
var fName = _formState.GetTimeName(timeframe.Time);
var imgPath = Path.Combine(_config.ImagesDirectory, $"{fName}.jpg");
var img = System.Drawing.Image.FromStream(timeframe.Stream);
img.Save(imgPath, ImageFormat.Jpeg);
await YoloLabel.WriteToFile(labels, Path.Combine(_config.LabelsDirectory, $"{fName}.txt"), token);
await YoloLabel.WriteToFile(detections, Path.Combine(_config.LabelsDirectory, $"{fName}.txt"), token);
Editor.Background = new ImageBrush { ImageSource = await imgPath.OpenImage() };
Editor.RemoveAllAnns();
foreach (var (label, probability) in detections)
AddAnnotationToCanvas(time, new CanvasLabel(label, Editor.RenderSize, Editor.RenderSize, probability));
await AddAnnotations(timeframe.Time, labels);
AddAnnotationsToCanvas(time, detections, true);
await AddAnnotations(timeframe.Time, detections.Cast<YoloLabel>().ToList(), token);
var log = string.Join(Environment.NewLine, detections.Select(det =>
$"{_config.AnnotationClassesDict[det.Label.ClassNumber].Name}: " +
$"xy=({det.Label.CenterX:F2},{det.Label.CenterY:F2}), " +
$"size=({det.Label.Width:F2}, {det.Label.Height:F2}), " +
$"{_config.AnnotationClassesDict[det.ClassNumber].Name}: " +
$"xy=({det.CenterX:F2},{det.CenterY:F2}), " +
$"size=({det.Width:F2}, {det.Height:F2}), " +
$"prob: {det.Probability:F1}%"));
Dispatcher.Invoke(() => _autoDetectDialog.Log(log));
var thumbnailDto = await _galleryManager.CreateThumbnail(imgPath, token);
if (thumbnailDto != null)
_datasetExplorer.AddThumbnail(thumbnailDto, labels.Select(x => x.ClassNumber));
_datasetExplorer.AddThumbnail(thumbnailDto, detections.Select(x => x.ClassNumber));
}
catch (Exception e)
{
+21 -6
View File
@@ -2,10 +2,12 @@
using System.Windows;
using System.Windows.Controls;
using System.Windows.Input;
using System.Windows.Media;
using Azaion.Annotator.DTO;
using LibVLCSharp.Shared;
using MediatR;
using Microsoft.Extensions.Logging;
using MediaPlayer = LibVLCSharp.Shared.MediaPlayer;
namespace Azaion.Annotator;
@@ -148,6 +150,11 @@ public class MainWindowEventHandler :
_mediaPlayer.Pause();
if (!_mediaPlayer.IsPlaying)
_mainWindow.BlinkHelp(HelpTexts.HelpTextsDict[HelpTextEnum.AnnotationHelp]);
if (_formState.BackgroundShown)
{
_mainWindow.Editor.Background = new SolidColorBrush(Color.FromArgb(1, 0, 0, 0));
_formState.BackgroundShown = false;
}
break;
case PlaybackControlEnum.Stop:
_mediaPlayer.Stop();
@@ -240,16 +247,13 @@ public class MainWindowEventHandler :
return;
var time = TimeSpan.FromMilliseconds(_mediaPlayer.Time);
var fName = _formState.GetTimeName(time);
var fName = _formState.GetTimeName(time);
var currentAnns = _mainWindow.Editor.CurrentAnns
.Select(x => new YoloLabel(x.Info, _mainWindow.Editor.RenderSize, _formState.CurrentVideoSize))
.Select(x => new YoloLabel(x.Info, _mainWindow.Editor.RenderSize, _formState.BackgroundShown ? _mainWindow.Editor.RenderSize : _formState.CurrentVideoSize))
.ToList();
await YoloLabel.WriteToFile(currentAnns, Path.Combine(_config.LabelsDirectory, $"{fName}.txt"));
var resultHeight = (uint)Math.Round(RESULT_WIDTH / _formState.CurrentVideoSize.Width * _formState.CurrentVideoSize.Height);
await _mainWindow.AddAnnotations(time, currentAnns);
_formState.CurrentMedia.HasAnnotations = _mainWindow.Annotations.Count != 0;
@@ -261,7 +265,18 @@ public class MainWindowEventHandler :
_mainWindow.Editor.RemoveAllAnns();
if (isVideo)
{
_mediaPlayer.TakeSnapshot(0, destinationPath, RESULT_WIDTH, resultHeight);
if (_formState.BackgroundShown)
{
//no need to save image, it's already there, just remove background
_mainWindow.Editor.Background = new SolidColorBrush(Color.FromArgb(1, 0, 0, 0));
_formState.BackgroundShown = false;
}
else
{
var resultHeight = (uint)Math.Round(RESULT_WIDTH / _formState.CurrentVideoSize.Width * _formState.CurrentVideoSize.Height);
_mediaPlayer.TakeSnapshot(0, destinationPath, RESULT_WIDTH, resultHeight);
}
_mediaPlayer.Play();
}
else
+47 -7
View File
@@ -1,23 +1,24 @@
using System.Drawing.Imaging;
using System.IO;
using System.IO;
using Azaion.Annotator.DTO;
using Azaion.Annotator.Extensions;
using Compunet.YoloV8;
using Compunet.YoloV8.Data;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.Formats.Jpeg;
using SixLabors.ImageSharp.PixelFormats;
using Detection = Azaion.Annotator.DTO.Detection;
namespace Azaion.Annotator;
public interface IAIDetector
{
List<(YoloLabel Label, float Probability)> Detect(Stream stream);
List<Detection> Detect(Stream stream);
}
public class YOLODetector(Config config) : IAIDetector, IDisposable
{
private readonly YoloPredictor _predictor = new(config.AIRecognitionConfig.AIModelPath);
public List<(YoloLabel Label, float Probability)> Detect(Stream stream)
public List<Detection> Detect(Stream stream)
{
stream.Seek(0, SeekOrigin.Begin);
var image = Image.Load<Rgb24>(stream);
@@ -25,11 +26,50 @@ public class YOLODetector(Config config) : IAIDetector, IDisposable
var imageSize = new System.Windows.Size(image.Width, image.Height);
return result.Select(d =>
var detections = result.Select(d =>
{
var label = new YoloLabel(new CanvasLabel(d.Name.Id, d.Bounds.X, d.Bounds.Y, d.Bounds.Width, d.Bounds.Height), imageSize, imageSize);
return (label, d.Confidence * 100);
return new Detection(label, (double?)d.Confidence * 100);
}).ToList();
return FilterOverlapping(detections);
}
private List<Detection> FilterOverlapping(List<Detection> detections)
{
var k = config.AIRecognitionConfig.TrackingIntersectionThreshold;
var filteredDetections = new List<Detection>();
for (var i = 0; i < detections.Count; i++)
{
var detectionSelected = false;
for (var j = i + 1; j < detections.Count; j++)
{
var intersect = detections[i].ToRectangle();
intersect.Intersect(detections[j].ToRectangle());
var maxArea = Math.Max(detections[i].ToRectangle().Area(), detections[j].ToRectangle().Area());
if (intersect.Area() > k * maxArea)
{
if (detections[i].Probability > detections[j].Probability)
{
filteredDetections.Add(detections[i]);
detections.RemoveAt(j);
}
else
{
filteredDetections.Add(detections[j]);
detections.RemoveAt(i);
}
detectionSelected = true;
break;
}
}
if (!detectionSelected)
filteredDetections.Add(detections[i]);
}
return filteredDetections;
}
public void Dispose() => _predictor.Dispose();
+3 -2
View File
@@ -1,5 +1,5 @@
{
"VideosDirectory": "E:\\Azaion3\\VideosTest",
"VideosDirectory": "E:\\Azaion1\\Videos",
"LabelsDirectory": "E:\\labels",
"ImagesDirectory": "E:\\images",
"ThumbnailsDirectory": "E:\\thumbnails",
@@ -40,6 +40,7 @@
"AIModelPath": "azaion.onnx",
"FrameRecognitionSeconds": 2,
"TrackingDistanceConfidence": 0.15,
"TrackingProbabilityIncrease": 15
"TrackingProbabilityIncrease": 15,
"TrackingIntersectionThreshold": 0.8
}
}