Files
ai-training/train.py
T
Alex Bezdieniezhnykh 70abaded05 small refactoring
simplify learning, improve readme
2024-06-09 18:09:26 +03:00

95 lines
3.1 KiB
Python

from os import path, replace, remove, listdir, makedirs
from os.path import abspath
import shutil
from datetime import datetime
from pathlib import Path
from ultralytics import YOLOv10
from constants import current_images_dir, current_labels_dir, annotation_classes, prefix
def get_yaml_path(date_str):
    """Return the dataset folder and the absolute data.yaml path for a date.

    The folder is the relative path ``datasets/<prefix><date_str>``; the YAML
    path is the absolute path of ``data.yaml`` inside that folder.

    Returns:
        A ``(dataset_dir, yaml_path)`` tuple of strings.
    """
    folder_name = f'{prefix}{date_str}'
    dataset_dir = path.join('datasets', folder_name)
    yaml_file = path.join(dataset_dir, 'data.yaml')
    return dataset_dir, abspath(yaml_file)
# Dataset folder and data.yaml path for the current date (YYYY-MM-DD),
# resolved once when the module is imported.
today_dataset, today_yaml = get_yaml_path(datetime.now().strftime('%Y-%m-%d'))

# Split sizes in percent of the available images. form_dataset() reads only
# train_set and valid_set; whatever is left after those two goes to the test
# subset, so test_set is not read anywhere in the visible code.
train_set, valid_set, test_set = 70, 20, 10
def form_dataset():
    """Split the current images into train/valid/test folders for today.

    Creates today's dataset folder, slices the listing of
    ``current_images_dir`` by the ``train_set`` / ``valid_set`` percentages
    (the remainder becomes the test subset), moves each slice with
    ``move_annotations`` and finally writes ``data.yaml``.
    """
    makedirs(today_dataset, exist_ok=True)

    image_names = listdir(current_images_dir)
    total = len(image_names)
    train_end = int(total * train_set / 100.0)
    valid_end = train_end + int(total * valid_set / 100.0)

    # Order matters only in that every image lands in exactly one subset.
    subsets = (
        ('train', image_names[:train_end]),
        ('valid', image_names[train_end:valid_end]),
        ('test', image_names[valid_end:]),
    )
    for folder, names in subsets:
        move_annotations(names, folder)

    create_yaml()
def move_annotations(images, folder):
    """Move images and their label files into a subset of today's dataset.

    Args:
        images: File names (not paths) located in ``current_images_dir``.
        folder: Subset folder name under today's dataset, e.g. ``'train'``.

    For every image the label ``<stem>.txt`` is looked up in
    ``current_labels_dir``. Image and label are moved together into
    ``<today_dataset>/<folder>/images`` and ``.../labels``; an image without
    a label file is deleted from ``current_images_dir``.
    """
    images_dest = path.join(today_dataset, folder, 'images')
    labels_dest = path.join(today_dataset, folder, 'labels')
    makedirs(images_dest, exist_ok=True)
    makedirs(labels_dest, exist_ok=True)

    for image_name in images:
        src_image = path.join(current_images_dir, image_name)
        label_name = Path(image_name).stem + '.txt'
        src_label = path.join(current_labels_dir, label_name)

        if path.exists(src_label):
            replace(src_image, path.join(images_dest, image_name))
            replace(src_label, path.join(labels_dest, label_name))
            continue

        # No annotation for this image: it is removed rather than moved.
        remove(src_image)
def create_yaml():
    """Write today's data.yaml describing classes and subset image folders.

    The file lists one ``- <name>`` entry per key of ``annotation_classes``
    under ``names:``, followed by the class count (``nc``) and the relative
    ``test`` / ``train`` / ``val`` image directories. It is written to
    ``today_yaml`` as UTF-8 and ends with one empty line.
    """
    class_entries = [f'- {annotation_classes[key].name}' for key in annotation_classes]
    content = [
        'names:',
        *class_entries,
        f'nc: {len(annotation_classes)}',
        'test: test/images',
        'train: train/images',
        'val: valid/images',
        '',  # produces the trailing blank line
    ]
    with open(today_yaml, 'w', encoding='utf-8') as out:
        out.write(''.join(f'{entry}\n' for entry in content))
def revert_to_current(date):
    """Move a dated dataset's files back into the current dataset and delete it.

    Args:
        date: Date suffix of the dataset folder, i.e. the folder is
            ``datasets/<prefix><date>``.

    Every file from the ``images`` and ``labels`` folders of the ``test``,
    ``train`` and ``valid`` subsets is moved into
    ``datasets/<prefix>current/images`` and ``.../labels`` respectively;
    afterwards the dated dataset folder is removed entirely.
    """
    dated_root = path.join('datasets', f'{prefix}{date}')
    current_root = path.join('datasets', f'{prefix}current')

    def move_all(src_dir, dest_dir):
        # Move every entry of src_dir into dest_dir, overwriting same-named files.
        for entry in listdir(src_dir):
            replace(path.join(src_dir, entry), path.join(dest_dir, entry))

    for subset in ('test', 'train', 'valid'):
        for kind in ('images', 'labels'):
            move_all(path.join(dated_root, subset, kind), path.join(current_root, kind))

    # Only after all subsets are emptied is the dated folder dropped.
    shutil.rmtree(dated_root)
if __name__ == '__main__':
    # Script entry point: train a YOLOv10 model on a dated dataset and copy
    # the best weights next to the datasets.

    # Dataset preparation is currently disabled; uncomment to run it first.
    # form_dataset()
    # create_yaml()

    # Start from the latest checkpoint when it exists, otherwise from the
    # model config. The previous `'...pt' or '...yaml'` expression always
    # evaluated to the first (non-empty) string, so the fallback was dead code.
    latest_weights = 'datasets/zombobase-latest.pt'
    model = YOLOv10(latest_weights if path.exists(latest_weights) else 'yolov10x.yaml')

    _, yaml = get_yaml_path('2024-06-09')
    results = model.train(data=yaml, epochs=2, batch=10, imgsz=640)
    print(results)

    # The components must be relative: a leading '/' makes os.path.join
    # discard save_dir and yield '/weights/best.pt'.
    # NOTE(review): assumes the training result supports ['save_dir'] —
    # confirm against the return type of YOLOv10.train().
    res_model = path.join(results['save_dir'], 'weights', 'best.pt')
    print(res_model)
    shutil.copy(res_model, 'datasets/zombobase-latest2.pt')