Files
ai-training/train.py
T
Alex Bezdieniezhnykh 2325fd0916 update gitignore
simplify paths
remove runs folder
2024-06-10 09:07:13 +03:00

91 lines
3.2 KiB
Python

from os import path, replace, remove, listdir, makedirs
from os.path import abspath
import shutil
from datetime import datetime
from pathlib import Path
from ultralytics import YOLOv10
from constants import current_images_dir, current_labels_dir, annotation_classes
# Common name prefix for dataset folders and model checkpoints.
prefix = 'zombobase-'
# Checkpoint that each training run starts from and then overwrites with its best weights.
latest_model = 'models/' + prefix + 'latest.pt'
# Folder name for the dataset/model produced today, e.g. 'zombobase-2024-06-10'.
today_folder = prefix + datetime.now().strftime('%Y-%m-%d')

# Dataset split in percent. form_dataset() uses only train_set and valid_set;
# the test subset receives whatever remains, so test_set is informational.
train_set, valid_set, test_set = 70, 20, 10
def form_dataset():
    """Split the current images into train/valid/test subsets of today's dataset.

    The images listed in ``current_images_dir`` are taken in ``os.listdir``
    order (no shuffling). The first ``train_set`` percent go to ``train``, the
    next ``valid_set`` percent go to ``valid``, and the remainder goes to
    ``test``. Each subset is handed to ``move_annotations``, and then
    ``create_yaml`` writes the dataset's ``data.yaml``.

    Split sizes are computed before ``move_annotations`` deletes images that
    have no label, so the final subset sizes can be smaller than the
    percentages suggest.
    """
    dataset_root = path.join('datasets', today_folder)
    makedirs(dataset_root, exist_ok=True)

    all_images = listdir(current_images_dir)
    total = len(all_images)
    train_end = int(total * train_set / 100.0)
    valid_end = train_end + int(total * valid_set / 100.0)

    splits = (
        ('train', all_images[:train_end]),
        ('valid', all_images[train_end:valid_end]),
        ('test', all_images[valid_end:]),
    )
    for subset, subset_images in splits:
        move_annotations(subset_images, subset)

    create_yaml()
def move_annotations(images, folder):
    """Move images and their label files into one subset of today's dataset.

    For each file name in ``images`` (relative to ``current_images_dir``),
    the label ``<stem>.txt`` is looked up in ``current_labels_dir``:

    * if the label exists, the image and the label are moved into
      ``datasets/<today_folder>/<folder>/images`` and ``.../labels``;
    * if it does not, the image is deleted from ``current_images_dir``.

    Args:
        images: Image file names to process.
        folder: Subset name, e.g. ``'train'``, ``'valid'`` or ``'test'``.
    """
    subset_dir = path.join('datasets', today_folder, folder)
    images_out = path.join(subset_dir, 'images')
    labels_out = path.join(subset_dir, 'labels')
    makedirs(images_out, exist_ok=True)
    makedirs(labels_out, exist_ok=True)

    for name in images:
        src_image = path.join(current_images_dir, name)
        label = Path(name).stem + '.txt'
        src_label = path.join(current_labels_dir, label)
        if not path.exists(src_label):
            # No matching label: the image is removed, not moved (destructive).
            remove(src_image)
            continue
        # os.replace overwrites any existing file at the destination.
        replace(src_image, path.join(images_out, name))
        replace(src_label, path.join(labels_out, label))
def create_yaml():
    """Write ``datasets/<today_folder>/data.yaml`` for Ultralytics training.

    The file contains:

    * ``names``: one list entry per ``annotation_classes`` value, in the
      dict's iteration order, using each value's ``.name``;
    * ``nc``: the number of classes;
    * ``test``, ``train`` and ``val``: image directories, relative to the
      YAML file's folder.

    An existing file is overwritten.
    """
    # NOTE(review): class ids are assumed to equal the position in
    # annotation_classes iteration order — confirm against the label files.
    class_names = [annotation_classes[key].name for key in annotation_classes]
    yaml_lines = (
        ['names:']
        + [f'- {name}' for name in class_names]
        + [
            f'nc: {len(annotation_classes)}',
            'test: test/images',
            'train: train/images',
            'val: valid/images',
            '',
        ]
    )
    target = abspath(path.join('datasets', today_folder, 'data.yaml'))
    with open(target, 'w', encoding='utf-8') as yaml_file:
        yaml_file.write(''.join(f'{line}\n' for line in yaml_lines))
def revert_to_current(date):
    """Undo ``form_dataset`` for the dataset folder built on ``date``.

    Every image and label in the ``test``/``train``/``valid`` subsets of
    ``datasets/<prefix><date>`` is moved back into ``current_images_dir`` /
    ``current_labels_dir``. These are the directories ``form_dataset`` takes
    them from. The dated dataset folder, including its ``data.yaml``, is then
    deleted.

    Args:
        date: Date suffix of the dataset folder, e.g. ``'2024-06-10'``.
            A ``datetime.date`` also works, because it formats as ``YYYY-MM-DD``.

    Raises:
        FileNotFoundError: if ``datasets/<prefix><date>`` does not exist
            (raised by ``shutil.rmtree``).
    """
    def revert_dir(src_dir, dest_dir):
        # A missing subset directory simply has nothing to move back.
        if not path.isdir(src_dir):
            return
        makedirs(dest_dir, exist_ok=True)
        for file in listdir(src_dir):
            # os.replace overwrites a same-named file already in dest_dir.
            replace(path.join(src_dir, file), path.join(dest_dir, file))

    date_dataset = path.join('datasets', f'{prefix}{date}')
    for subset in ('test', 'train', 'valid'):
        revert_dir(path.join(date_dataset, subset, 'images'), current_images_dir)
        revert_dir(path.join(date_dataset, subset, 'labels'), current_labels_dir)
    shutil.rmtree(date_dataset)
if __name__ == '__main__':
    # Dataset preparation is run manually before training (it moves files):
    # form_dataset()
    # create_yaml()

    # latest_model is always a non-empty string, so decide the fallback by
    # checking the file on disk: fine-tune the latest checkpoint when it
    # exists, otherwise build YOLOv10-X from its architecture config.
    base_model = latest_model if path.exists(latest_model) else 'yolov10x.yaml'
    model = YOLOv10(base_model)
    data_yaml = abspath(path.join('datasets', today_folder, 'data.yaml'))
    results = model.train(data=data_yaml, epochs=100, batch=10, imgsz=640)

    # Publish the best weights as the new "latest" checkpoint and archive the
    # whole run. dirs_exist_ok lets a second run on the same day succeed.
    makedirs(path.dirname(latest_model), exist_ok=True)
    shutil.copy(f'{results.save_dir}/weights/best.pt', latest_model)
    shutil.copytree(results.save_dir, f'models/{today_folder}', dirs_exist_ok=True)

    # Best-effort cleanup: a missing directory must not crash the script
    # after the results have already been saved.
    shutil.rmtree('runs', ignore_errors=True)
    # NOTE(review): unclear what creates models/<prefix>latest; removed if present.
    shutil.rmtree(f'models/{prefix}latest', ignore_errors=True)