diff --git a/constants.py b/constants.py index 36700a1..8f0c1e8 100644 --- a/constants.py +++ b/constants.py @@ -1,21 +1,26 @@ import os from dto.annotationClass import AnnotationClass +azaion = '/azaion' prefix = 'azaion-' images = 'images' labels = 'labels' -data_dir = '/azaion/data/raw' +data_dir = os.path.join(azaion, 'data') data_images_dir = os.path.join(data_dir, images) data_labels_dir = os.path.join(data_dir, labels) -processed_dir = '/azaion/data/processed' +processed_dir = os.path.join(azaion, 'data-processed') processed_images_dir = os.path.join(processed_dir, images) processed_labels_dir = os.path.join(processed_dir, labels) +corrupted_dir = os.path.join(azaion, 'data-corrupted') +corrupted_images_dir = os.path.join(corrupted_dir, images) +corrupted_labels_dir = os.path.join(corrupted_dir, labels) -datasets_dir = '/azaion/datasets' -models_dir = '/azaion/models' + +datasets_dir = os.path.join(azaion, 'datasets') +models_dir = os.path.join(azaion, 'models') annotation_classes = AnnotationClass.read_json() date_format = '%Y-%m-%d' diff --git a/dataset-visualiser.py b/dataset-visualiser.py index 04b8229..6f81f6e 100644 --- a/dataset-visualiser.py +++ b/dataset-visualiser.py @@ -6,25 +6,47 @@ from dto.imageLabel import ImageLabel from preprocessing import read_labels from matplotlib import pyplot as plt -from constants import datasets_dir, prefix - - +from constants import datasets_dir, prefix, processed_images_dir, processed_labels_dir annotation_classes = AnnotationClass.read_json() -cur_dataset = os.path.join(datasets_dir, f'{prefix}2024-06-18', 'train') -images_dir = os.path.join(cur_dataset, 'images') -labels_dir = os.path.join(cur_dataset, 'labels') -for f in os.listdir(images_dir)[35247:]: - image_path = os.path.join(images_dir, f) - labels_path = os.path.join(labels_dir, f'{Path(f).stem}.txt') - img = ImageLabel( - image_path=image_path, - image=cv2.imread(image_path), - labels_path=labels_path, - labels=read_labels(labels_path) - ) - img.visualize(annotation_classes) - print(f'visualizing {image_path}') - plt.close() - key = input('Press any key to continue') +def visualise_dataset(): + cur_dataset = os.path.join(datasets_dir, f'{prefix}2024-06-18', 'train') + images_dir = os.path.join(cur_dataset, 'images') + labels_dir = os.path.join(cur_dataset, 'labels') + + for f in os.listdir(images_dir)[35247:]: + image_path = os.path.join(images_dir, f) + labels_path = os.path.join(labels_dir, f'{Path(f).stem}.txt') + img = ImageLabel( + image_path=image_path, + image=cv2.imread(image_path), + labels_path=labels_path, + labels=read_labels(labels_path) + ) + img.visualize(annotation_classes) + print(f'visualizing {image_path}') + plt.close() + key = input('Press any key to continue') + + +def visualise_processed_folder(): + + def show_image(img): + image_path = os.path.join(processed_images_dir, img) + labels_path = os.path.join(processed_labels_dir, f'{Path(img).stem}.txt') + img = ImageLabel( + image_path=image_path, + image=cv2.imread(image_path), + labels_path=labels_path, + labels=read_labels(labels_path) + ) + img.visualize(annotation_classes) + images = os.listdir(processed_images_dir) + cur = 0 + show_image(images[cur]) + pass + + +if __name__ == '__main__': + visualise_processed_folder() diff --git a/preprocessing.py b/preprocessing.py index 860ef59..672c9fd 100644 --- a/preprocessing.py +++ b/preprocessing.py @@ -50,13 +50,10 @@ def image_processing(img_ann: ImageLabel) -> [ImageLabel]: return results -def write_result(img_ann: ImageLabel, show_image=False): +def write_result(img_ann: ImageLabel): os.makedirs(os.path.dirname(img_ann.image_path), exist_ok=True) os.makedirs(os.path.dirname(img_ann.labels_path), exist_ok=True) - if show_image: - img_ann.visualize(annotation_classes) - cv2.imencode('.jpg', img_ann.image)[1].tofile(img_ann.image_path) print(f'{img_ann.image_path} written') @@ -92,28 +89,16 @@ def process_image(img_ann): image_path=os.path.join(processed_images_dir, Path(img_ann.image_path).name), labels_path=os.path.join(processed_labels_dir, Path(img_ann.labels_path).name) )) - # os.remove(img_ann.image_path) - # os.remove(img_ann.labels_path) def main(): - checkpoint = datetime.now() - timedelta(days=720) - try: - with open(checkpoint_file, 'r') as f: - checkpoint = datetime.strptime(f.read(), checkpoint_date_format) - except: - pass - last_date = checkpoint while True: + processed_images = set(f.name for f in os.scandir(processed_images_dir)) images = [] with os.scandir(data_images_dir) as imd: for image_file in imd: - if not image_file.is_file(): - continue - mod_time = datetime.fromtimestamp(image_file.stat().st_mtime) - if mod_time > checkpoint: + if image_file.is_file() and image_file.name not in processed_images: images.append(image_file) - last_date = max(last_date, mod_time) for image_file in images: try: @@ -128,14 +113,8 @@ def main(): )) except Exception as e: print(f'Error appeared {e}') - if last_date != checkpoint: - checkpoint = last_date - try: - with open(checkpoint_file, 'w') as f: - f.write(datetime.strftime(checkpoint, checkpoint_date_format)) - except: - pass - time.sleep(5) + print('All processed, waiting for 2 minutes...') + time.sleep(120) if __name__ == '__main__': diff --git a/train.py b/train.py index 696c4cf..534bd3a 100644 --- a/train.py +++ b/train.py @@ -1,12 +1,17 @@ -from os import path, replace, remove, listdir, makedirs +import random +from os import path, replace, remove, listdir, makedirs, scandir from os.path import abspath import shutil from datetime import datetime from pathlib import Path from ultralytics import YOLO -from constants import processed_images_dir, processed_labels_dir, annotation_classes, prefix, date_format, datasets_dir, models_dir +from constants import (processed_images_dir, + processed_labels_dir, + annotation_classes, + prefix, date_format, + datasets_dir, models_dir, + corrupted_images_dir, corrupted_labels_dir) -latest_model = path.join(models_dir, f'{prefix}latest.pt') today_folder = f'{prefix}{datetime.now():{date_format}}' today_dataset = path.join(datasets_dir, today_folder) train_set = 70 @@ -14,38 +19,61 @@ valid_set = 20 test_set = 10 -def form_dataset(): +def form_dataset(set_date: datetime): makedirs(today_dataset, exist_ok=True) - images = listdir(processed_images_dir) + images = [] + with scandir(processed_images_dir) as imd: + for image_file in imd: + if not image_file.is_file(): + continue + mod_time = datetime.fromtimestamp(image_file.stat().st_mtime) + if set_date is None: + images.append(image_file) + elif mod_time > set_date: + images.append(image_file) + + print('shuffling images') + random.shuffle(images) train_size = int(len(images) * train_set / 100.0) valid_size = int(len(images) * valid_set / 100.0) - move_annotations(images[:train_size], 'train') - move_annotations(images[train_size:train_size + valid_size], 'valid') - move_annotations(images[train_size + valid_size:], 'test') + print(f'copy train dataset, size: {train_size} annotations') + copy_annotations(images[:train_size], 'train') + print(f'copy valid set, size: {valid_size} annotations') + copy_annotations(images[train_size:train_size + valid_size], 'valid') + + print(f'copy test set, size: {len(images) - train_size - valid_size} annotations') + copy_annotations(images[train_size + valid_size:], 'test') + + print('creating yaml...') create_yaml() -def move_annotations(images, folder): +def copy_annotations(images, folder): destination_images = path.join(today_dataset, folder, 'images') makedirs(destination_images, exist_ok=True) + destination_labels = path.join(today_dataset, folder, 'labels') makedirs(destination_labels, exist_ok=True) - for image_name in images: - image_path = path.join(processed_images_dir, image_name) - label_name = f'{Path(image_name).stem}.txt' + + makedirs(corrupted_images_dir, exist_ok=True) + makedirs(corrupted_labels_dir, exist_ok=True) + + for image in images: + label_name = f'{Path(image.path).stem}.txt' label_path = path.join(processed_labels_dir, label_name) - if not check_label(label_path): - remove(image_path) + if check_label(label_path): + shutil.copy(image.path, path.join(destination_images, image.name)) + shutil.copy(label_path, path.join(destination_labels, label_name)) else: - replace(image_path, path.join(destination_images, image_name)) - replace(label_path, path.join(destination_labels, label_name)) + shutil.copy(image.path, path.join(corrupted_images_dir, image.name)) + shutil.copy(label_path, path.join(corrupted_labels_dir, label_name)) + print(f'Label {label_path} is corrupted! Copy with its image to the corrupted directory ({corrupted_labels_dir})') def check_label(label_path): - lines_edited = False if not path.exists(label_path): return False with open(label_path, 'r') as f: @@ -53,16 +81,7 @@ def check_label(label_path): for line in lines: for val in line.split(' ')[1:]: if float(val) > 1: - lines.remove(line) - lines_edited = True - if len(lines) == 0: - return False - if not lines_edited: - return True - - with open(label_path, 'w') as label_write: - label_write.writelines(lines) - label_write.close() + return False return True @@ -97,10 +116,25 @@ def revert_to_processed_data(date): shutil.rmtree(date_dataset) -if __name__ == '__main__': - # form_dataset() +def get_latest_model(): + def convert(d: str): + dir_date = datetime.strptime(d.replace(prefix, ''), '%Y-%m-%d') + dir_model_path = path.join(models_dir, d, 'weights', 'best.pt') + return {'date': dir_date, 'path': dir_model_path} - model_name = latest_model if path.isfile(latest_model) else 'yolov8m.yaml' + dates = [convert(d) for d in listdir(models_dir)] + sorted_dates = list(sorted(dates, key=lambda x: x['date'])) + if len(sorted_dates) == 0: + return None, None + last_model = sorted_dates[-1] + return last_model['date'], last_model['path'] + + +if __name__ == '__main__': + latest_date, latest_model = get_latest_model() + # form_dataset(latest_date) + + model_name = latest_model if latest_model is not None and path.isfile(latest_model) else 'yolov8m.yaml' print(f'Initial model: {model_name}') model = YOLO(model_name) @@ -108,7 +142,7 @@ if __name__ == '__main__': cur_folder = today_dataset yaml = abspath(path.join(cur_folder, 'data.yaml')) - results = model.train(data=yaml, epochs=100, batch=55, imgsz=640, save_period=1) + results = model.train(data=yaml, epochs=100, batch=60, imgsz=640, save_period=1) shutil.copy(f'{results.save_dir}/weights/best.pt', latest_model) shutil.copytree(results.save_dir, path.join(models_dir, cur_folder))