mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 21:46:35 +00:00
118 lines
4.2 KiB
Python
118 lines
4.2 KiB
Python
import datetime
|
|
import os
|
|
import msgpack
|
|
import json
|
|
import datetime
|
|
|
|
|
|
class Detection:
    """A single bounding-box detection belonging to an annotation.

    Attributes are stored exactly as given. ``x``, ``y``, ``w``, ``h`` are the
    box values (presumably normalized YOLO centre/size coordinates — TODO
    confirm against the producer). ``confidence`` is optional and may be None.
    """

    def __init__(self, annotation_name, cls, x, y, w, h, confidence=None):
        self.annotation_name = annotation_name
        self.cls = cls
        self.x = x
        self.y = y
        self.w = w
        self.h = h
        self.confidence = confidence

    def __str__(self):
        """Return ``'<cls>: x y w h, prob: <pct>'``; ``n/a`` when confidence is unset."""
        # confidence defaults to None (and is None whenever the JSON omits 'p'),
        # so it must not be multiplied unconditionally.
        if self.confidence is None:
            prob = 'n/a'
        else:
            prob = f'{(self.confidence * 100):.1f}%'
        return f'{self.cls}: {self.x:.2f} {self.y:.2f} {self.w:.2f} {self.h:.2f}, prob: {prob}'
|
|
|
|
|
|
class AnnotationCreatedMessageNarrow:
    """Minimal view of an annotation-created message.

    Only the annotation name and the creator's e-mail are extracted.
    """

    def __init__(self, msgpack_bytes):
        """Decode ``msgpack_bytes`` and read fields 1 and 2.

        The payload is decoded as-is, with no header skipping. It is expected
        to decode to a map keyed by integers, which is why
        ``strict_map_key=False`` is passed. Absent keys yield None.
        """
        fields = msgpack.unpackb(msgpack_bytes, strict_map_key=False)
        self.name, self.createdEmail = fields.get(1), fields.get(2)
|
|
|
|
|
|
class AnnotationCreatedMessage:
    """Full annotation-created message decoded from a RabbitMQ payload.

    The msgpack body is preceded by a header whose length is discovered by
    probing (see ``read_rabbit``). Fields are then read by index 0-10 and
    exposed as attributes.
    """

    # Offset at which the most recent payload decoded successfully. This is a
    # class attribute, so it is shared by all instances and reused as a fast
    # path for later messages.
    last_offset = None

    def __init__(self, msgpack_bytes):
        """Decode ``msgpack_bytes`` and populate the message attributes.

        Raises:
            Exception: if ``read_rabbit`` cannot decode the payload.
        """
        unpacked_data = self.read_rabbit(msgpack_bytes)

        # Field 0 exposes .seconds/.nanoseconds (a msgpack Timestamp). It is
        # converted to a naive UTC datetime: the same value utcfromtimestamp()
        # produced, without that API's Python 3.12 deprecation.
        ts = unpacked_data[0]
        self.createdDate = (
            datetime.datetime.fromtimestamp(ts.seconds, tz=datetime.timezone.utc).replace(tzinfo=None)
            + datetime.timedelta(microseconds=ts.nanoseconds / 1000)
        )
        self.name = unpacked_data[1]
        self.originalMediaName = unpacked_data[2]
        # The value is divided by 10 to obtain microseconds, so it is presumably
        # in 100-ns ticks (e.g. .NET TimeSpan.Ticks) — TODO confirm with sender.
        self.time = datetime.timedelta(microseconds=unpacked_data[3] / 10)
        self.imageExtension = unpacked_data[4]
        # Field 5 is a JSON string containing the list of detections.
        self.detections = self._parse_detections(unpacked_data[5])
        self.image = unpacked_data[6]
        self.createdRole = unpacked_data[7]
        self.createdEmail = unpacked_data[8]
        self.source = unpacked_data[9]
        self.status = unpacked_data[10]

    @staticmethod
    def read_rabbit(message_bytes):
        """Decode a msgpack payload that follows a header of unknown length.

        The previously successful offset is tried first. If it fails, offsets
        3..14 are probed and the first one that decodes cleanly is remembered
        in ``last_offset``.

        Returns:
            The decoded msgpack object.

        Raises:
            Exception: if no offset decodes. The last decoding error is
                chained as ``__cause__``.
        """
        cached_offset = AnnotationCreatedMessage.last_offset
        if cached_offset is not None:
            try:
                return msgpack.unpackb(message_bytes[cached_offset:], raw=False, strict_map_key=False)
            except Exception:
                # The header length may differ for this message; fall back to
                # probing below.
                pass

        last_error = None
        for offset in range(3, 15):
            try:
                unpacked_data = msgpack.unpackb(message_bytes[offset:], raw=False, strict_map_key=False)
            except Exception as e:
                # Best-effort probing: wrong offsets are expected to fail.
                # Keep the error so the final failure is diagnosable.
                last_error = e
                continue
            AnnotationCreatedMessage.last_offset = offset
            return unpacked_data
        raise Exception(f'Cannot read rabbit message! Bytes: {message_bytes}') from last_error

    def __str__(self):
        """Return ``'<name>: [<detections>]'``, or ``'<name>: [Empty]'`` when there are none."""
        if self.detections:
            detections_str = ", ".join(str(detection) for detection in self.detections)
            return f'{self.name}: [{detections_str}]'
        return f'{self.name}: [Empty]'

    @staticmethod
    def _parse_detections(detections_json_str):
        """Build ``Detection`` objects from a JSON array string.

        Each element is an object with short keys mapped positionally onto
        ``Detection``: 'an' (annotation_name), 'cl' (cls), 'x', 'y', 'w', 'h'
        and 'p' (confidence). Missing keys become None.

        Returns an empty list for an empty/None string or a JSON ``null``.
        """
        if not detections_json_str:
            return []
        # A JSON "null" decodes to None; treat it like "no detections".
        detections_list = json.loads(detections_json_str) or []
        return [
            Detection(d.get('an'), d.get('cl'), d.get('x'), d.get('y'), d.get('w'), d.get('h'), d.get('p'))
            for d in detections_list
        ]

    def save_annotation(self, save_folder):
        """Write the image and its label file under ``save_folder``.

        Creates ``images/`` and ``labels/`` if needed. Writes the raw image
        bytes to ``images/<name>.<imageExtension>``. Writes
        ``labels/<name>.txt`` with one ``"cls x y w h"`` line per detection,
        or an empty file when there are none.

        I/O errors are printed rather than raised. A failed image write does
        not prevent the label write.
        """
        image_folder = os.path.join(save_folder, 'images')
        labels_folder = os.path.join(save_folder, 'labels')
        os.makedirs(image_folder, exist_ok=True)
        os.makedirs(labels_folder, exist_ok=True)

        image_path = os.path.join(image_folder, f"{self.name}.{self.imageExtension}")
        label_path = os.path.join(labels_folder, f"{self.name}.txt")

        try:
            with open(image_path, 'wb') as image_file:
                image_file.write(self.image)
            print(f"Image saved to: {image_path}")
        except IOError as e:
            print(f"Error saving image: {e}")

        try:
            with open(label_path, 'w', encoding='utf-8') as label_file:
                # writelines() adds no separators, so each record must carry
                # its own newline. Without it, all detections collapse onto
                # a single line.
                label_file.writelines(
                    f'{detection.cls} {detection.x} {detection.y} {detection.w} {detection.h}\n'
                    for detection in (self.detections or [])
                )
            print(f'Label saved to: {label_path}')
        except IOError as e:
            print(f"Error saving label: {e}")
|