small refactoring

This commit is contained in:
Oleksandr Bezdieniezhnykh
2024-06-16 12:21:38 +03:00
parent 4161078f40
commit b7b8b8fd27
+34 -13
View File
@@ -10,21 +10,32 @@ tag_size = 'size'
tag_object = 'object' tag_object = 'object'
tag_name = 'name' tag_name = 'name'
tag_bndbox = 'bndbox' tag_bndbox = 'bndbox'
# 1 Вантажівка, 2 Машина легкова name_class_map = {'Truck': 1, 'Car': 2, 'Taxi': 2} # 1 Вантажівка, 2 Машина легкова
name_class_map = {'Truck': 1, 'Car': 2, 'Taxi': 2} forbidden_classes = ['Motorcycle']
default_class = 1
def convert_xml(folder): def convert_xml(folder):
for f in os.listdir(folder):
if not f.endswith('.jpg'):
continue
os.makedirs(images_dir, exist_ok=True) os.makedirs(images_dir, exist_ok=True)
os.makedirs(labels_dir, exist_ok=True) os.makedirs(labels_dir, exist_ok=True)
shutil.copy(os.path.join(folder, f), os.path.join(images_dir, f)) for f in os.listdir(folder):
if not f.endswith('.jpg'):
continue
label = f'{Path(f).stem}.xml' label = f'{Path(f).stem}.xml'
lines = read_xml(folder, label)
if not lines:
print(f'Image {f} has only forbidden classes in annotations')
continue
shutil.copy(os.path.join(folder, f), os.path.join(images_dir, f))
with open(os.path.join(labels_dir, f'{Path(label).stem}.txt'), 'w') as label_file:
label_file.writelines(lines)
label_file.close()
print(f'Image {f} has been processed successfully')
def read_xml(folder, label):
tree = et.parse(os.path.join(folder, label)) tree = et.parse(os.path.join(folder, label))
root = tree.getroot() root = tree.getroot()
lines = [] lines = []
@@ -32,9 +43,19 @@ def convert_xml(folder):
width = int(size_dict['width']) width = int(size_dict['width'])
height = int(size_dict['height']) height = int(size_dict['height'])
for node_object in tree.findall(tag_object): for node_object in tree.findall(tag_object):
class_num = default_class
c_x = c_y = c_w = c_h = 0
for node_object_ch in node_object: for node_object_ch in node_object:
if node_object_ch.tag == tag_name: if node_object_ch.tag == tag_name:
class_num = name_class_map[node_object_ch.text] key = node_object_ch.text
if key in name_class_map:
class_num = name_class_map[key]
else:
if key in forbidden_classes:
class_num = -1
continue
else:
class_num = default_class
if node_object_ch.tag == tag_bndbox: if node_object_ch.tag == tag_bndbox:
bbox_dict = {bbox_ch.tag: bbox_ch.text for bbox_ch in node_object_ch} bbox_dict = {bbox_ch.tag: bbox_ch.text for bbox_ch in node_object_ch}
xmin = int(bbox_dict['xmin']) xmin = int(bbox_dict['xmin'])
@@ -45,11 +66,11 @@ def convert_xml(folder):
c_h = (ymax - ymin) / height c_h = (ymax - ymin) / height
c_x = xmin / width + c_w / 2 c_x = xmin / width + c_w / 2
c_y = ymin / height + c_h / 2 c_y = ymin / height + c_h / 2
lines.append(f'{class_num} {c_x} {c_y} {c_w} {c_h}') if class_num == -1:
continue
with open(os.path.join(labels_dir, f'{Path(label).stem}.txt'), 'w') as f: if c_x != 0 and c_y != 0 and c_w != 0 and c_h != 0:
f.writelines(lines) lines.append(f'{class_num} {round(c_x, 5)} {round(c_y, 5)} {round(c_w, 5)} {round(c_h, 5)}\n')
f.close() return lines
if __name__ == '__main__': if __name__ == '__main__':