# mirror of https://github.com/azaion/ai-training.git
# synced 2026-04-22 21:46:35 +00:00
# (149 lines, 5.2 KiB, Python)
import random
import shutil
from datetime import datetime
from os import path, replace, remove, listdir, makedirs, scandir
from os.path import abspath
from pathlib import Path

from ultralytics import YOLO

from constants import (processed_images_dir,
                       processed_labels_dir,
                       annotation_classes,
                       prefix, date_format,
                       datasets_dir, models_dir,
                       corrupted_images_dir, corrupted_labels_dir)
|
|
|
|
# Name of today's dataset folder, e.g. '<prefix>2024-06-18' (uses the
# project-wide prefix and date_format from constants).
today_folder = f'{prefix}{datetime.now():{date_format}}'

# Absolute location of today's dataset under datasets_dir.
today_dataset = path.join(datasets_dir, today_folder)

# Dataset split percentages; train + valid + test should sum to 100.
# The test share is effectively "whatever is left" after train and valid.
train_set = 70

valid_set = 20

test_set = 10
|
|
|
|
|
|
def form_dataset(from_date: datetime):
    """Build today's dataset from processed images modified after *from_date*.

    Scans processed_images_dir, keeps regular files newer than *from_date*
    (or every file when *from_date* is None), shuffles them, splits them
    into train/valid/test subsets according to the module-level
    percentages, copies each subset with its labels, and finally writes
    the data.yaml description.
    """
    makedirs(today_dataset, exist_ok=True)

    selected = []
    with scandir(processed_images_dir) as entries:
        for entry in entries:
            if not entry.is_file():
                continue
            modified = datetime.fromtimestamp(entry.stat().st_mtime)
            if from_date is None or modified > from_date:
                selected.append(entry)

    print('shuffling images')
    random.shuffle(selected)

    n_train = int(len(selected) * train_set / 100.0)
    n_valid = int(len(selected) * valid_set / 100.0)

    print(f'copy train dataset, size: {n_train} annotations')
    copy_annotations(selected[:n_train], 'train')

    print(f'copy valid set, size: {n_valid} annotations')
    copy_annotations(selected[n_train:n_train + n_valid], 'valid')

    print(f'copy test set, size: {len(selected) - n_train - n_valid} annotations')
    copy_annotations(selected[n_train + n_valid:], 'test')

    print('creating yaml...')
    create_yaml()
|
|
|
|
|
|
def copy_annotations(images, folder):
    """Copy *images* (os.DirEntry objects) and their YOLO labels into
    ``today_dataset/<folder>/{images,labels}``.

    An image whose label file is missing or fails validation (see
    check_label) is diverted to the corrupted directories instead, so it
    never ends up in the training data.
    """
    destination_images = path.join(today_dataset, folder, 'images')
    makedirs(destination_images, exist_ok=True)

    destination_labels = path.join(today_dataset, folder, 'labels')
    makedirs(destination_labels, exist_ok=True)

    makedirs(corrupted_images_dir, exist_ok=True)
    makedirs(corrupted_labels_dir, exist_ok=True)

    for image in images:
        # Label file shares the image's base name, with a .txt extension.
        label_name = f'{Path(image.path).stem}.txt'
        label_path = path.join(processed_labels_dir, label_name)
        if check_label(label_path):
            shutil.copy(image.path, path.join(destination_images, image.name))
            shutil.copy(label_path, path.join(destination_labels, label_name))
        else:
            shutil.copy(image.path, path.join(corrupted_images_dir, image.name))
            # BUG FIX: check_label also returns False when the label file
            # does not exist; the old unconditional copy then raised
            # FileNotFoundError and aborted the whole dataset build.
            if path.exists(label_path):
                shutil.copy(label_path, path.join(corrupted_labels_dir, label_name))
            print(f'Label {label_path} is corrupted! Copy with its image to the corrupted directory ({corrupted_labels_dir})')
|
|
|
|
|
|
def check_label(label_path):
    """Return True when *label_path* looks like a valid YOLO annotation file.

    A label is valid when the file exists and every coordinate token
    (everything after the leading class id on each line) is a number in
    the normalized range [0, 1].  Missing files, non-numeric tokens and
    out-of-range coordinates all classify the label as corrupted (False).
    """
    if not path.exists(label_path):
        return False
    with open(label_path, 'r') as f:
        for line in f:
            # split() (not split(' ')) tolerates repeated/trailing whitespace;
            # the old split(' ') produced '' tokens and crashed on float('').
            for token in line.split()[1:]:
                try:
                    value = float(token)
                except ValueError:
                    # Garbage token: treat as corrupted instead of raising.
                    return False
                # YOLO coordinates are normalized; negatives are invalid too.
                if value < 0 or value > 1:
                    return False
    return True
|
|
|
|
|
|
def create_yaml():
    """Write the Ultralytics data.yaml for today's dataset.

    Emits the class names (in annotation_classes iteration order), the
    class count, and the relative train/valid/test image directories
    matching the layout produced by copy_annotations.
    """
    lines = ['names:']
    # annotation_classes maps a key to an object exposing a .name attribute.
    for c in annotation_classes:
        lines.append(f'- {annotation_classes[c].name}')
    lines.append(f'nc: {len(annotation_classes)}')

    # NOTE: YOLO expects the validation key to be 'val' even though the
    # folder itself is named 'valid'.  (F541 fix: these literals had
    # pointless f-string prefixes with no placeholders.)
    lines.append('test: test/images')
    lines.append('train: train/images')
    lines.append('val: valid/images')
    lines.append('')  # ensures the file ends with a newline

    today_yaml = abspath(path.join(today_dataset, 'data.yaml'))
    with open(today_yaml, 'w', encoding='utf-8') as f:
        f.writelines([f'{line}\n' for line in lines])
|
|
|
|
|
|
def revert_to_processed_data(date):
    """Move every image/label of the dataset built for *date* back into the
    processed directories, then delete that dataset folder entirely."""

    def _move_all(src_dir, dest_dir):
        # os.replace moves each file, overwriting any existing destination.
        for name in listdir(src_dir):
            replace(path.join(src_dir, name), path.join(dest_dir, name))

    date_dataset = path.join(datasets_dir, f'{prefix}{date}')
    makedirs(processed_images_dir, exist_ok=True)
    makedirs(processed_labels_dir, exist_ok=True)

    for subset in ['test', 'train', 'valid']:
        _move_all(path.join(date_dataset, subset, 'images'), processed_images_dir)
        _move_all(path.join(date_dataset, subset, 'labels'), processed_labels_dir)

    shutil.rmtree(date_dataset)
|
|
|
|
|
|
def get_latest_model():
    """Return (date, weights_path) of the most recently trained model.

    Model folders under models_dir are created by the __main__ block as
    f'{prefix}{date:date_format}' with weights at 'weights/best.pt'.
    Returns (None, None) when models_dir is missing or holds no folder
    matching that naming scheme.
    """
    def convert(d: str):
        # CONSISTENCY FIX: folders are named with the imported date_format
        # constant, so parse with the same constant instead of the old
        # hard-coded '%Y-%m-%d'.
        dir_date = datetime.strptime(d.replace(prefix, ''), date_format)
        dir_model_path = path.join(models_dir, d, 'weights', 'best.pt')
        return {'date': dir_date, 'path': dir_model_path}

    # ROBUSTNESS: before the first training run models_dir may not exist;
    # listdir would raise FileNotFoundError.
    if not path.isdir(models_dir):
        return None, None

    dates = []
    for d in listdir(models_dir):
        try:
            dates.append(convert(d))
        except ValueError:
            # Skip stray folders that don't follow the naming scheme.
            continue

    if not dates:
        return None, None
    latest = max(dates, key=lambda x: x['date'])
    return latest['date'], latest['path']
|
|
|
|
|
|
if __name__ == '__main__':
    latest_date, latest_model = get_latest_model()
    # form_dataset(latest_date)

    # Resume from the latest trained weights when they exist on disk,
    # otherwise start fresh from the yolov8m architecture definition.
    if latest_model is not None and path.isfile(latest_model):
        model_name = latest_model
    else:
        model_name = 'yolov8m.yaml'
    print(f'Initial model: {model_name}')
    model = YOLO(model_name)

    # cur_folder = path.join(datasets_dir, f'{prefix}2024-06-18')

    cur_folder = today_dataset
    yaml = abspath(path.join(cur_folder, 'data.yaml'))
    results = model.train(data=yaml, epochs=100, batch=60, imgsz=640)

    # Archive this run's output under models_dir, then drop the
    # ultralytics scratch directory.
    shutil.copytree(results.save_dir, path.join(models_dir, today_folder))
    shutil.rmtree('runs')
|