From 66987f4d952f4c6eafa852df22a08ad29920aadd Mon Sep 17 00:00:00 2001 From: Alex Bezdieniezhnykh Date: Tue, 18 Jun 2024 21:32:15 +0300 Subject: [PATCH] add checkpoints and config system convert from bbox oriented and pascal xml fixes --- config.py | 22 ++++++++++ config.yaml | 3 ++ constants.py | 4 +- convert-annotations.py | 96 ++++++++++++++++++++++++++++++------------ preprocessing.py | 87 ++++++++++++++++++++++++++------------ train.py | 29 +++++++++++-- 6 files changed, 182 insertions(+), 59 deletions(-) create mode 100644 config.py create mode 100644 config.yaml diff --git a/config.py b/config.py new file mode 100644 index 0000000..5aab429 --- /dev/null +++ b/config.py @@ -0,0 +1,22 @@ +import yaml + +config_file = 'config.yaml' + + +class Config: + + def __init__(self): + with open(config_file, 'r') as f: + c = yaml.safe_load(f) + self.checkpoint = c['checkpoint'] + self.images_dir = c['images_dir'] + self.labels_dir = c['labels_dir'] + f.close() + + def write(self): + with open(config_file, 'w') as f: + d = dict(checkpoint=self.checkpoint, + images_dir=self.images_dir, + labels_dir=self.labels_dir) + yaml.safe_dump(d, f) + f.close() diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..b12ff48 --- /dev/null +++ b/config.yaml @@ -0,0 +1,3 @@ +checkpoint: 2024-06-18 19:14:02.080664 +images_dir: E:\images +labels_dir: E:\labels diff --git a/constants.py b/constants.py index 6734350..7f419eb 100644 --- a/constants.py +++ b/constants.py @@ -1,9 +1,9 @@ import os -from datetime import datetime - from dto.annotationClass import AnnotationClass current_dataset_dir = os.path.join('datasets', 'zombobase-current') current_images_dir = os.path.join(current_dataset_dir, 'images') current_labels_dir = os.path.join(current_dataset_dir, 'labels') annotation_classes = AnnotationClass.read_json() +prefix = 'zombobase-' +date_format = '%Y-%m-%d' diff --git a/convert-annotations.py b/convert-annotations.py index 110ead1..f907281 100644 --- a/convert-annotations.py +++ b/convert-annotations.py @@ -2,6 +2,7 @@ import os import shutil import xml.etree.cElementTree as et from pathlib import Path +import cv2 labels_dir = 'labels' images_dir = 'images' @@ -13,36 +14,50 @@ tag_bndbox = 'bndbox' name_class_map = {'Truck': 1, 'Car': 2, 'Taxi': 2} # 1 Вантажівка, 2 Машина легкова forbidden_classes = ['Motorcycle'] default_class = 1 +image_extensions = ['jpg', 'png', 'jpeg'] -def convert_xml(folder): +def convert(folder, read_annotations, ann_format): os.makedirs(images_dir, exist_ok=True) os.makedirs(labels_dir, exist_ok=True) for f in os.listdir(folder): - if not f.endswith('.jpg'): + if not f[-3:] in image_extensions: continue - label = f'{Path(f).stem}.xml' - lines = read_xml(folder, label) - if not lines: - print(f'Image {f} has only forbidden classes in annotations') + im = cv2.imread(os.path.join(folder, f)) + height = im.shape[0] + width = im.shape[1] + + label = f'{Path(f).stem}.{ann_format}' + try: + with open(os.path.join(folder, label), 'r') as label_file: + text = label_file.read() + lines = read_annotations(width, height, text) + except ValueError as val_err: + print(f'Image {f} annotations could not be converted. Error: {val_err}') continue + except Exception as e: + print(f'Error conversion for {f}. Error: {e}') shutil.copy(os.path.join(folder, f), os.path.join(images_dir, f)) - with open(os.path.join(labels_dir, f'{Path(label).stem}.txt'), 'w') as label_file: - label_file.writelines(lines) - label_file.close() + with open(os.path.join(labels_dir, f'{Path(label).stem}.txt'), 'w') as new_label_file: + new_label_file.writelines(lines) + new_label_file.close() print(f'Image {f} has been processed successfully') -def read_xml(folder, label): - tree = et.parse(os.path.join(folder, label)) - root = tree.getroot() +def minmax2yolo(width, height, xmin, xmax, ymin, ymax): + c_w = (xmax - xmin) / width + c_h = (ymax - ymin) / height + c_x = xmin / width + c_w / 2 + c_y = ymin / height + c_h / 2 + return round(c_x, 5), round(c_y, 5), round(c_w, 5), round(c_h, 5) + + +def read_pascal_voc(width, height, s): + root = et.fromstring(s) lines = [] - size_dict = {size_ch.tag: size_ch.text for size_ch in root.findall(f'{tag_size}/*')} - width = int(size_dict['width']) - height = int(size_dict['height']) - for node_object in tree.findall(tag_object): + for node_object in root.findall(tag_object): class_num = default_class c_x = c_y = c_w = c_h = 0 for node_object_ch in node_object: @@ -58,20 +73,47 @@ def read_xml(folder, label): class_num = default_class if node_object_ch.tag == tag_bndbox: bbox_dict = {bbox_ch.tag: bbox_ch.text for bbox_ch in node_object_ch} - xmin = int(bbox_dict['xmin']) - xmax = int(bbox_dict['xmax']) - ymin = int(bbox_dict['ymin']) - ymax = int(bbox_dict['ymax']) - c_w = (xmax - xmin) / width - c_h = (ymax - ymin) / height - c_x = xmin / width + c_w / 2 - c_y = ymin / height + c_h / 2 + c_x, c_y, c_w, c_h = minmax2yolo(width, height, + int(bbox_dict['xmin']), + int(bbox_dict['xmax']), + int(bbox_dict['ymin']), + int(bbox_dict['ymax'])) if class_num == -1: continue - if c_x != 0 and c_y != 0 and c_w != 0 and c_h != 0: - lines.append(f'{class_num} {round(c_x, 5)} {round(c_y, 5)} {round(c_w, 5)} {round(c_h, 5)}\n') + if c_x > 1 or c_y > 1 or c_w > 1 or c_h > 1: + print('Values are out of bounds') + else: + if c_x != 0 and c_y != 0 and c_w != 0 and c_h != 0: + lines.append(f'{class_num} {c_x} {c_y} {c_w} {c_h}\n') return lines +def read_bbox_oriented(width, height, s): + yolo_lines = [] + lines = s.split('\n', ) + for line in lines: + if line == '': + continue + vals = line.split(' ') + if len(vals) != 14: + raise ValueError('wrong format') + xmin = min(int(vals[6]), int(vals[7]), int(vals[8]), int(vals[9])) + xmax = max(int(vals[6]), int(vals[7]), int(vals[8]), int(vals[9])) + ymin = min(int(vals[10]), int(vals[11]), int(vals[12]), int(vals[13])) + ymax = max(int(vals[10]), int(vals[11]), int(vals[12]), int(vals[13])) + c_x, c_y, c_w, c_h = minmax2yolo(width, height, xmin, xmax, ymin, ymax) + if c_x > 1 or c_y > 1 or c_w > 1 or c_h > 1: + print('Values are out of bounds') + else: + yolo_lines.append(f'1 {c_x} {c_y} {c_w} {c_h}\n') + return yolo_lines + + +def rename_images(folder): + for f in os.listdir(folder): + shutil.move(os.path.join(folder, f), os.path.join(folder, f[:-7] + '.png')) + + if __name__ == '__main__': - convert_xml('datasets/others/UAVimages') + convert('datasets/others/UAVHeightImages', read_bbox_oriented, 'txt') + convert('datasets/others/UAVimages', read_pascal_voc, 'xml') diff --git a/preprocessing.py b/preprocessing.py index c5bd525..afccaeb 100644 --- a/preprocessing.py +++ b/preprocessing.py @@ -1,13 +1,18 @@ import os.path import time +from datetime import datetime from pathlib import Path import albumentations as A import cv2 -from constants import current_images_dir, current_labels_dir, annotation_classes +import numpy as np +from dateutil.relativedelta import relativedelta + +from config import Config +from constants import current_images_dir, current_labels_dir, annotation_classes, prefix, date_format, \ + current_dataset_dir from dto.imageLabel import ImageLabel -labels_dir = 'labels' -images_dir = 'images' +config = Config() def image_processing(img_ann: ImageLabel) -> [ImageLabel]: @@ -22,7 +27,7 @@ def image_processing(img_ann: ImageLabel) -> [ImageLabel]: A.RandomBrightnessContrast(always_apply=True)], bbox_params=A.BboxParams(format='yolo')), A.Compose([A.ShiftScaleRotate(scale_limit=0.2, always_apply=True), - A.VerticalFlip(always_apply=True),], + A.VerticalFlip(always_apply=True), ], bbox_params=A.BboxParams(format='yolo')), A.Compose([A.ShiftScaleRotate(scale_limit=0.2, always_apply=True)], bbox_params=A.BboxParams(format='yolo')), @@ -36,7 +41,7 @@ def image_processing(img_ann: ImageLabel) -> [ImageLabel]: try: res = transform(image=img_ann.image, bboxes=img_ann.labels) path = Path(img_ann.image_path) - name = f'{path.stem}_{i+1}' + name = f'{path.stem}_{i + 1}' img = ImageLabel( image=res['image'], labels=res['bboxes'], @@ -56,11 +61,12 @@ def write_result(img_ann: ImageLabel, show_image=False): if show_image: img_ann.visualize(annotation_classes) - cv2.imwrite(img_ann.image_path, img_ann.image) + cv2.imencode('.jpg', img_ann.image)[1].tofile(img_ann.image_path) print(f'{img_ann.image_path} written') with open(img_ann.labels_path, 'w') as f: - lines = [f'{ann[4]} {round(ann[0], 5)} {round(ann[1], 5)} {round(ann[2], 5)} {round(ann[3], 5)}\n' for ann in img_ann.labels] + lines = [f'{ann[4]} {round(ann[0], 5)} {round(ann[1], 5)} {round(ann[2], 5)} {round(ann[3], 5)}\n' for ann in + img_ann.labels] f.writelines(lines) f.close() print(f'{img_ann.labels_path} written') @@ -89,40 +95,69 @@ def process_image(img_ann): image_path=os.path.join(current_images_dir, Path(img_ann.image_path).name), labels_path=os.path.join(current_labels_dir, Path(img_ann.labels_path).name) )) - os.remove(img_ann.image_path) - os.remove(img_ann.labels_path) + # os.remove(img_ann.image_path) + # os.remove(img_ann.labels_path) + + +def get_checkpoint(): + if config.checkpoint is not None: + return config.checkpoint + + dates = [] + for directory in os.listdir('models'): + try: + dates.append(datetime.strptime(directory[len(prefix):], date_format)) + except: + continue + if len(dates) == 0: + return datetime.now() - relativedelta(years=1) + else: + return max(dates) def main(): + last_date = checkpoint = get_checkpoint() while True: - images = os.listdir(images_dir) - if len(images) == 0: - time.sleep(5) - continue + images = [] + with os.scandir(config.images_dir) as imd: + for image_file in imd: + if not image_file.is_file(): + continue + mod_time = datetime.fromtimestamp(image_file.stat().st_mtime) + if mod_time > checkpoint: + images.append(image_file) + last_date = max(last_date, mod_time) - for image in images: + for image_file in images: try: - image_path = os.path.join(images_dir, image) - labels_path = os.path.join(labels_dir, f'{Path(image_path).stem}.txt') + image_path = os.path.join(config.images_dir, image_file.name) + labels_path = os.path.join(config.labels_dir, f'{Path(image_path).stem}.txt') + image = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED) process_image(ImageLabel( image_path=image_path, - image=cv2.imread(image_path), + image=image, labels_path=labels_path, labels=read_labels(labels_path) )) except Exception as e: print(f'Error appeared {e}') + if last_date != checkpoint: + checkpoint = config.checkpoint = last_date + config.write() + time.sleep(5) - try: - os.remove(image_path) - except OSError: - pass - try: - os.remove(labels_path) - except OSError: - pass +def check_labels(): + for label in os.listdir(os.path.join(current_dataset_dir, 'labels')): + with open(os.path.join(current_dataset_dir, 'labels', label), 'r') as f: + lines = f.readlines() + for line in lines: + list_c = line.split(' ')[1:] + for l in list_c: + if float(l) > 1: + print('Error!') if __name__ == '__main__': - main() + check_labels() + # main() diff --git a/train.py b/train.py index bd6bc8f..2cd2daa 100644 --- a/train.py +++ b/train.py @@ -4,11 +4,10 @@ import shutil from datetime import datetime from pathlib import Path from ultralytics import YOLOv10 -from constants import current_images_dir, current_labels_dir, annotation_classes +from constants import current_images_dir, current_labels_dir, annotation_classes, prefix, date_format -prefix = 'zombobase-' latest_model = f'models/{prefix}latest.pt' -today_folder = f'{prefix}{datetime.now():%Y-%m-%d}' +today_folder = f'{prefix}{datetime.now():{date_format}}' train_set = 70 valid_set = 20 test_set = 10 @@ -38,13 +37,35 @@ def move_annotations(images, folder): image_path = path.join(current_images_dir, image_name) label_name = f'{Path(image_name).stem}.txt' label_path = path.join(current_labels_dir, label_name) - if not path.exists(label_path): + if not check_label(label_path): remove(image_path) else: replace(image_path, path.join(destination_images, image_name)) replace(label_path, path.join(destination_labels, label_name)) +def check_label(label_path): + lines_edited = False + if not path.exists(label_path): + return False + with open(label_path, 'r') as f: + lines = f.readlines() + for line in lines: + for val in line.split(' ')[1:]: + if float(val) > 1: + lines.remove(line) + lines_edited = True + if len(lines) == 0: + return False + if not lines_edited: + return True + + with open(label_path, 'w') as label_write: + label_write.writelines(lines) + label_write.close() + return True + + def create_yaml(): lines = ['names:'] for c in annotation_classes: