mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 11:06:35 +00:00
reorganizing files
add train some catches
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
import json
|
||||
from os.path import dirname, join
|
||||
|
||||
|
||||
class AnnotationClass:
|
||||
def __init__(self, id, name, color):
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.color = color
|
||||
|
||||
@staticmethod
|
||||
def read_json():
|
||||
classes_path = join(dirname(dirname(__file__)), 'classes.json')
|
||||
with open(classes_path, 'r', encoding='utf-8') as f:
|
||||
j = json.loads(f.read())
|
||||
return {cl['Id']: AnnotationClass(id=cl['Id'], name=cl['Name'], color=cl['Color']) for cl in j}
|
||||
|
||||
@property
|
||||
def color_tuple(self):
|
||||
color = self.color[3:]
|
||||
lv = len(color)
|
||||
xx = range(0, lv, lv // 3)
|
||||
return tuple(int(color[i:i + lv // 3], 16) for i in xx)
|
||||
@@ -0,0 +1,32 @@
|
||||
import cv2
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
class ImageLabel:
|
||||
def __init__(self, image_path, image, labels_path, labels):
|
||||
self.image_path = image_path
|
||||
self.image = image
|
||||
self.labels_path = labels_path
|
||||
self.labels = labels
|
||||
|
||||
def visualize(self, annotation_classes):
|
||||
img = cv2.cvtColor(self.image.copy(), cv2.COLOR_BGR2RGB)
|
||||
height, width, channels = img.shape
|
||||
for label in self.labels:
|
||||
class_num = int(label[-1])
|
||||
x_c = float(label[0])
|
||||
y_c = float(label[1])
|
||||
w = float(label[2])
|
||||
h = float(label[3])
|
||||
x_min = x_c - w / 2
|
||||
y_min = y_c - h / 2
|
||||
x_max = x_min + w
|
||||
y_max = y_min + h
|
||||
color = annotation_classes[class_num].color_tuple
|
||||
|
||||
cv2.rectangle(img, (int(x_min * width), int(y_min * height)), (int(x_max * width), int(y_max * height)),
|
||||
color=color, thickness=3)
|
||||
plt.figure(figsize=(12, 12))
|
||||
plt.axis('off')
|
||||
plt.imshow(img)
|
||||
plt.show()
|
||||
Reference in New Issue
Block a user