diff --git a/.gitignore b/.gitignore index d5b0e44..d6fa44d 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ *images/ datasets/ runs/ +models/ *.pt \ No newline at end of file diff --git a/constants.py b/constants.py index 1f790ff..6734350 100644 --- a/constants.py +++ b/constants.py @@ -7,7 +7,3 @@ current_dataset_dir = os.path.join('datasets', 'zombobase-current') current_images_dir = os.path.join(current_dataset_dir, 'images') current_labels_dir = os.path.join(current_dataset_dir, 'labels') annotation_classes = AnnotationClass.read_json() - - -prefix = 'zombobase-' -today_dataset = os.path.join('datasets', f'{prefix}{datetime.now():%Y-%m-%d}') \ No newline at end of file diff --git a/train.py b/train.py index 7f0e6b8..bd6bc8f 100644 --- a/train.py +++ b/train.py @@ -4,23 +4,18 @@ import shutil from datetime import datetime from pathlib import Path from ultralytics import YOLOv10 -from constants import current_images_dir, current_labels_dir, annotation_classes, prefix +from constants import current_images_dir, current_labels_dir, annotation_classes - -def get_yaml_path(date_str): - date_dataset = path.join('datasets', f'{prefix}{date_str}') - date_yaml_path = abspath(path.join(date_dataset, 'data.yaml')) - return date_dataset, date_yaml_path - - -today_dataset, today_yaml = get_yaml_path(f'{datetime.now():%Y-%m-%d}') +prefix = 'zombobase-' +latest_model = f'models/{prefix}latest.pt' +today_folder = f'{prefix}{datetime.now():%Y-%m-%d}' train_set = 70 valid_set = 20 test_set = 10 def form_dataset(): - makedirs(today_dataset, exist_ok=True) + makedirs(path.join('datasets', today_folder), exist_ok=True) images = listdir(current_images_dir) train_size = int(len(images) * train_set / 100.0) @@ -34,6 +29,7 @@ def form_dataset(): def move_annotations(images, folder): + today_dataset = path.join('datasets', today_folder) destination_images = path.join(today_dataset, folder, 'images') makedirs(destination_images, exist_ok=True) destination_labels = path.join(today_dataset, folder, 'labels') @@ -60,6 +56,7 @@ def create_yaml(): lines.append(f'val: valid/images') lines.append('') + today_yaml = abspath(path.join('datasets', today_folder, 'data.yaml')) with open(today_yaml, 'w', encoding='utf-8') as f: f.writelines([f'{line}\n' for line in lines]) @@ -82,13 +79,12 @@ if __name__ == '__main__': # form_dataset() # create_yaml() - model = YOLOv10('datasets/zombobase-latest.pt' or 'yolov10x.yaml') - _, yaml = get_yaml_path('2024-06-09') - results = model.train(data=yaml, epochs=2, batch=10, imgsz=640) + model = YOLOv10(latest_model or 'yolov10x.yaml') - print(results) - res_model = path.join(results['save_dir'], '/weights/best.pt') - print(res_model) - shutil.copy(res_model, 'datasets/zombobase-latest2.pt') + yaml = abspath(path.join('datasets', today_folder, 'data.yaml')) + results = model.train(data=yaml, epochs=100, batch=10, imgsz=640) - pass + shutil.copy(f'{results.save_dir}/weights/best.pt', latest_model) + shutil.copytree(results.save_dir, f'models/{today_folder}') + shutil.rmtree('runs') + shutil.rmtree('models/zombobase-latest')