fix inference

fix small issues
This commit is contained in:
Alex Bezdieniezhnykh
2025-02-14 09:00:04 +02:00
parent cfd5483a18
commit 961d2499de
15 changed files with 42 additions and 29 deletions
+11 -6
View File
@@ -520,21 +520,26 @@ public partial class Annotator
var files = new List<string>(); var files = new List<string>();
await Dispatcher.Invoke(async () => await Dispatcher.Invoke(async () =>
{ {
//Take not annotated medias
files = (LvFiles.ItemsSource as IEnumerable<MediaFileInfo>)?.Skip(LvFiles.SelectedIndex) files = (LvFiles.ItemsSource as IEnumerable<MediaFileInfo>)?.Skip(LvFiles.SelectedIndex)
.Take(Constants.DETECTION_BATCH_SIZE) .Take(Constants.DETECTION_BATCH_SIZE)
.Where(x => !x.HasAnnotations)
.Select(x => x.Path) .Select(x => x.Path)
.ToList(); .ToList() ?? [];
await _mediator.Publish(new AnnotatorControlEvent(PlaybackControlEnum.Play), ct); if (files.Count != 0)
await ReloadAnnotations(); {
await _mediator.Publish(new AnnotatorControlEvent(PlaybackControlEnum.Play), ct);
await ReloadAnnotations();
}
}); });
if (files.Count == 0)
break;
await _inferenceService.RunInference(files, async annotationImage => await ProcessDetection(annotationImage), ct); await _inferenceService.RunInference(files, async annotationImage => await ProcessDetection(annotationImage), ct);
Dispatcher.Invoke(() => Dispatcher.Invoke(() =>
{ {
if (LvFiles.SelectedIndex + files.Count >= LvFiles.Items.Count)
DetectionCancellationSource.Cancel();
LvFiles.SelectedIndex += files.Count; LvFiles.SelectedIndex += files.Count;
LvFiles.Items.Refresh(); LvFiles.Items.Refresh();
}); });
@@ -13,6 +13,6 @@ public class AIRecognitionConfig
[Key(nameof(TrackingProbabilityIncrease))] public double TrackingProbabilityIncrease { get; set; } [Key(nameof(TrackingProbabilityIncrease))] public double TrackingProbabilityIncrease { get; set; }
[Key(nameof(TrackingIntersectionThreshold))] public double TrackingIntersectionThreshold { get; set; } [Key(nameof(TrackingIntersectionThreshold))] public double TrackingIntersectionThreshold { get; set; }
[Key(nameof(Data))] public byte[] Data { get; set; } [Key(nameof(Data))] public byte[] Data { get; set; } = null!;
[Key(nameof(Paths))] public List<string> Paths { get; set; } [Key(nameof(Paths))] public List<string> Paths { get; set; } = null!;
} }
+1 -1
View File
@@ -49,7 +49,7 @@ public class Annotation
[MessagePackObject] [MessagePackObject]
public class AnnotationImage : Annotation public class AnnotationImage : Annotation
{ {
[Key("i")] public byte[] Image { get; set; } [Key("i")] public byte[] Image { get; set; } = null!;
} }
public enum AnnotationStatus public enum AnnotationStatus
+1
View File
@@ -105,6 +105,7 @@ public class DbFactory : IDbFactory
await db.Detections.DeleteAsync(x => names.Contains(x.AnnotationName), token: cancellationToken); await db.Detections.DeleteAsync(x => names.Contains(x.AnnotationName), token: cancellationToken);
await db.Annotations.DeleteAsync(x => names.Contains(x.Name), token: cancellationToken); await db.Annotations.DeleteAsync(x => names.Contains(x.Name), token: cancellationToken);
}); });
SaveToDisk();
} }
} }
@@ -30,6 +30,7 @@ public static class ThrottleExt
await func(); await func();
_throttleRunAfter = false; _throttleRunAfter = false;
}, cancellationToken); }, cancellationToken);
await Task.CompletedTask;
} }
} }
@@ -25,7 +25,7 @@ public class PythonResourceLoader : IResourceLoader, IAuthProvider
private readonly DealerSocket _dealer = new(); private readonly DealerSocket _dealer = new();
private readonly Guid _clientId = Guid.NewGuid(); private readonly Guid _clientId = Guid.NewGuid();
public User CurrentUser { get; set; } public User CurrentUser { get; set; } = null!;
public PythonResourceLoader() public PythonResourceLoader()
{ {
+3 -1
View File
@@ -22,7 +22,7 @@ public partial class DatasetExplorer
private readonly AnnotationConfig _annotationConfig; private readonly AnnotationConfig _annotationConfig;
private readonly DirectoriesConfig _directoriesConfig; private readonly DirectoriesConfig _directoriesConfig;
private Dictionary<int, List<Annotation>> _annotationsDict; private Dictionary<int, List<Annotation>> _annotationsDict = new();
private readonly CancellationTokenSource _cts = new(); private readonly CancellationTokenSource _cts = new();
public ObservableCollection<DetectionClass> AllDetectionClasses { get; set; } = new(); public ObservableCollection<DetectionClass> AllDetectionClasses { get; set; } = new();
@@ -191,6 +191,7 @@ public partial class DatasetExplorer
ClassDistribution.Plot.FigureBackground.Color = new("#888888"); ClassDistribution.Plot.FigureBackground.Color = new("#888888");
ClassDistribution.Refresh(); ClassDistribution.Refresh();
await Task.CompletedTask;
} }
private async void RefreshThumbnailsBtnClick(object sender, RoutedEventArgs e) private async void RefreshThumbnailsBtnClick(object sender, RoutedEventArgs e)
@@ -294,6 +295,7 @@ public partial class DatasetExplorer
SelectedAnnotations.Add(annThumb); SelectedAnnotations.Add(annThumb);
SelectedAnnotationDict.Add(annThumb.Annotation.Name, annThumb); SelectedAnnotationDict.Add(annThumb.Annotation.Name, annThumb);
} }
await Task.CompletedTask;
} }
private async void ValidateAnnotationsClick(object sender, RoutedEventArgs e) private async void ValidateAnnotationsClick(object sender, RoutedEventArgs e)
@@ -100,6 +100,7 @@ public class DatasetExplorerEventHandler(
datasetExplorer.SelectedAnnotations.Add(annThumb); datasetExplorer.SelectedAnnotations.Add(annThumb);
datasetExplorer.SelectedAnnotationDict.Add(annThumb.Annotation.Name, annThumb); datasetExplorer.SelectedAnnotationDict.Add(annThumb.Annotation.Name, annThumb);
} }
await Task.CompletedTask;
} }
public async Task Handle(AnnotationsDeletedEvent notification, CancellationToken cancellationToken) public async Task Handle(AnnotationsDeletedEvent notification, CancellationToken cancellationToken)
@@ -114,5 +115,6 @@ public class DatasetExplorerEventHandler(
datasetExplorer.SelectedAnnotations.Remove(annThumb); datasetExplorer.SelectedAnnotations.Remove(annThumb);
datasetExplorer.SelectedAnnotationDict.Remove(annThumb.Annotation.Name); datasetExplorer.SelectedAnnotationDict.Remove(annThumb.Annotation.Name);
} }
await Task.CompletedTask;
} }
} }
+1
View File
@@ -8,6 +8,7 @@ cdef class AIRecognitionConfig:
cdef public double tracking_intersection_threshold cdef public double tracking_intersection_threshold
cdef public bytes file_data cdef public bytes file_data
cdef public list[str] paths
@staticmethod @staticmethod
cdef from_msgpack(bytes data) cdef from_msgpack(bytes data)
+10 -4
View File
@@ -10,7 +10,8 @@ cdef class AIRecognitionConfig:
tracking_probability_increase, tracking_probability_increase,
tracking_intersection_threshold, tracking_intersection_threshold,
file_data file_data,
paths
): ):
self.frame_period_recognition = frame_period_recognition self.frame_period_recognition = frame_period_recognition
self.frame_recognition_seconds = frame_recognition_seconds self.frame_recognition_seconds = frame_recognition_seconds
@@ -21,10 +22,14 @@ cdef class AIRecognitionConfig:
self.tracking_intersection_threshold = tracking_intersection_threshold self.tracking_intersection_threshold = tracking_intersection_threshold
self.file_data = file_data self.file_data = file_data
self.paths = paths
def __str__(self): def __str__(self):
return (f'frame_seconds : {self.frame_recognition_seconds}, distance_confidence : {self.tracking_distance_confidence}, ' return (f'frame_seconds : {self.frame_recognition_seconds}, distance_confidence : {self.tracking_distance_confidence}, '
f'probability_increase : {self.tracking_probability_increase}, intersection_threshold : {self.tracking_intersection_threshold}, frame_period_recognition : {self.frame_period_recognition}') f'probability_increase : {self.tracking_probability_increase}, '
f'intersection_threshold : {self.tracking_intersection_threshold}, '
f'frame_period_recognition : {self.frame_period_recognition}, '
f'paths: {self.paths}')
@staticmethod @staticmethod
cdef from_msgpack(bytes data): cdef from_msgpack(bytes data):
@@ -38,5 +43,6 @@ cdef class AIRecognitionConfig:
unpacked.get("TrackingProbabilityIncrease", 0.0), unpacked.get("TrackingProbabilityIncrease", 0.0),
unpacked.get("TrackingIntersectionThreshold", 0.0), unpacked.get("TrackingIntersectionThreshold", 0.0),
unpacked.get("Data", b''),
unpacked.get("Data", b'')) unpacked.get("Paths", []),
)
+1 -1
View File
@@ -12,7 +12,7 @@ cdef class ApiClient:
cdef set_token(self, str token) cdef set_token(self, str token)
cdef get_user(self) cdef get_user(self)
cdef load_bytes(self, FileData file_data) cdef load_bytes(self, str filename, str folder=*)
cdef load_ai_model(self) cdef load_ai_model(self)
cdef load_queue_config(self) cdef load_queue_config(self)
+4 -4
View File
@@ -60,8 +60,8 @@ cdef class ApiClient:
self.login() self.login()
return self.user return self.user
cdef load_bytes(self, FileData file_data): cdef load_bytes(self, str filename, str folder=None):
folder = file_data.folder or self.credentials.folder folder = folder or self.credentials.folder
hardware_service = HardwareService() hardware_service = HardwareService()
cdef HardwareInfo hardware = hardware_service.get_hardware_info() cdef HardwareInfo hardware = hardware_service.get_hardware_info()
@@ -78,7 +78,7 @@ cdef class ApiClient:
{ {
"password": self.credentials.password, "password": self.credentials.password,
"hardware": hardware.to_json_object(), "hardware": hardware.to_json_object(),
"fileName": file_data.filename "fileName": filename
}, indent=4) }, indent=4)
response = requests.post(url, data=payload, headers=headers, stream=True) response = requests.post(url, data=payload, headers=headers, stream=True)
@@ -97,7 +97,7 @@ cdef class ApiClient:
stream = BytesIO(response.raw.read()) stream = BytesIO(response.raw.read())
data = Security.decrypt_to(stream, key) data = Security.decrypt_to(stream, key)
print(f'loaded file: {file_data.filename}, {len(data)} bytes') print(f'loaded file: {filename}, {len(data)} bytes')
return data return data
cdef load_ai_model(self): cdef load_ai_model(self):
+2 -3
View File
@@ -20,7 +20,7 @@ cdef class Inference:
self.model_width = 0 self.model_width = 0
self.model_height = 0 self.model_height = 0
self.class_names = None self.class_names = None
self.ai_config = AIRecognitionConfig(4, 2, 0.25, 0.15, 15, 0.8, b'') self.ai_config = AIRecognitionConfig(4, 2, 0.25, 0.15, 15, 0.8, b'', [])
def init_ai(self): def init_ai(self):
model_bytes = self.api_client.load_ai_model() model_bytes = self.api_client.load_ai_model()
@@ -114,7 +114,6 @@ cdef class Inference:
return chunks return chunks
cdef run_inference(self, RemoteCommand cmd): cdef run_inference(self, RemoteCommand cmd):
cdef list[str] medias = json.loads(<str> cmd.filename)
cdef list[str] videos = [] cdef list[str] videos = []
cdef list[str] images = [] cdef list[str] images = []
@@ -123,7 +122,7 @@ cdef class Inference:
if self.session is None: if self.session is None:
self.init_ai() self.init_ai()
for m in medias: for m in self.ai_config.paths:
if self.is_video(m): if self.is_video(m):
videos.append(m) videos.append(m)
else: else:
+2 -1
View File
@@ -66,7 +66,8 @@ cdef class CommandProcessor:
self.remote_handler.send(command.client_id, user.serialize()) self.remote_handler.send(command.client_id, user.serialize())
cdef load_file(self, RemoteCommand command): cdef load_file(self, RemoteCommand command):
response = self.api_client.load_bytes(FileData.from_msgpack(command.data)) cdef FileData file_data = FileData.from_msgpack(command.data)
response = self.api_client.load_bytes(file_data.filename, file_data.folder)
self.remote_handler.send(command.client_id, response) self.remote_handler.send(command.client_id, response)
cdef on_annotation(self, RemoteCommand cmd, Annotation annotation): cdef on_annotation(self, RemoteCommand cmd, Annotation annotation):
-5
View File
@@ -1,5 +0,0 @@
namespace Azaion.Annotator;
public interface IAIDetector;
public class YOLODetector : IAIDetector;