small refactoring

simplify learning, improve readme
This commit is contained in:
Alex Bezdieniezhnykh
2024-06-09 18:09:26 +03:00
parent a9b777acc4
commit 70abaded05
2 changed files with 56 additions and 50 deletions
+42 -48
View File
@@ -1,21 +1,27 @@
import os
from os import path, replace, remove, listdir, makedirs
from os.path import abspath
import shutil
from datetime import datetime
from pathlib import Path
from ultralytics import YOLOv10
from constants import current_images_dir, current_labels_dir, annotation_classes, prefix
today_dataset = os.path.join('datasets', f'{prefix}{datetime.now():%Y-%m-%d}')
yaml_path = os.path.join(today_dataset, 'data.yaml')
def get_yaml_path(date_str):
date_dataset = path.join('datasets', f'{prefix}{date_str}')
date_yaml_path = abspath(path.join(date_dataset, 'data.yaml'))
return date_dataset, date_yaml_path
today_dataset, today_yaml = get_yaml_path(f'{datetime.now():%Y-%m-%d}')
train_set = 70
valid_set = 20
test_set = 10
def form_dataset():
os.makedirs(today_dataset, exist_ok=True)
images = os.listdir(current_images_dir)
makedirs(today_dataset, exist_ok=True)
images = listdir(current_images_dir)
train_size = int(len(images) * train_set / 100.0)
valid_size = int(len(images) * valid_set / 100.0)
@@ -28,19 +34,19 @@ def form_dataset():
def move_annotations(images, folder):
destination_images = os.path.join(today_dataset, folder, 'images')
os.makedirs(destination_images, exist_ok=True)
destination_labels = os.path.join(today_dataset, folder, 'labels')
os.makedirs(destination_labels, exist_ok=True)
destination_images = path.join(today_dataset, folder, 'images')
makedirs(destination_images, exist_ok=True)
destination_labels = path.join(today_dataset, folder, 'labels')
makedirs(destination_labels, exist_ok=True)
for image_name in images:
image_path = os.path.join(current_images_dir, image_name)
image_path = path.join(current_images_dir, image_name)
label_name = f'{Path(image_name).stem}.txt'
label_path = os.path.join(current_labels_dir, label_name)
if not os.path.exists(label_path):
os.remove(image_path)
label_path = path.join(current_labels_dir, label_name)
if not path.exists(label_path):
remove(image_path)
else:
os.replace(image_path, os.path.join(destination_images, image_name))
os.replace(label_path, os.path.join(destination_labels, label_name))
replace(image_path, path.join(destination_images, image_name))
replace(label_path, path.join(destination_labels, label_name))
def create_yaml():
@@ -48,53 +54,41 @@ def create_yaml():
for c in annotation_classes:
lines.append(f'- {annotation_classes[c].name}')
lines.append(f'nc: {len(annotation_classes)}')
main_dir = f'../../{prefix}{datetime.now():%Y-%m-%d}'
lines.append(f'test: {main_dir}/test/images')
lines.append(f'train: {main_dir}/train/images')
lines.append(f'val: {main_dir}/valid/images')
lines.append(f'test: test/images')
lines.append(f'train: train/images')
lines.append(f'val: valid/images')
lines.append('')
with open(yaml_path, 'w', encoding='utf-8') as f:
with open(today_yaml, 'w', encoding='utf-8') as f:
f.writelines([f'{line}\n' for line in lines])
def get_recent_model():
cur_model = None
cur_date = None
for d in os.listdir('datasets'):
date_str = d.replace(prefix, '')
if date_str == 'current' or date_str == f'{datetime.now():%Y-%m-%d}':
continue
date = datetime.strptime(date_str, '%Y-%m-%d')
for file in os.listdir(os.path.join('datasets', d)):
if file.endswith('pt') and (cur_date is None or cur_date < date):
cur_model = os.path.join('datasets', d, file)
cur_date = date
return cur_model
def revert_to_current(date):
def revert_dir(src_dir, dest_dir):
for file in os.listdir(src_dir):
s = os.path.join(src_dir, file)
d = os.path.join(dest_dir, file)
os.replace(s, d)
date_dataset = os.path.join('datasets', f'{prefix}{date}')
current_dataset = os.path.join('datasets', f'{prefix}current')
for file in listdir(src_dir):
s = path.join(src_dir, file)
d = path.join(dest_dir, file)
replace(s, d)
date_dataset = path.join('datasets', f'{prefix}{date}')
current_dataset = path.join('datasets', f'{prefix}current')
for subset in ['test', 'train', 'valid']:
revert_dir(os.path.join(date_dataset, subset, 'images'), os.path.join(current_dataset, 'images'))
revert_dir(os.path.join(date_dataset, subset, 'labels'), os.path.join(current_dataset, 'labels'))
revert_dir(path.join(date_dataset, subset, 'images'), path.join(current_dataset, 'images'))
revert_dir(path.join(date_dataset, subset, 'labels'), path.join(current_dataset, 'labels'))
shutil.rmtree(date_dataset)
if __name__ == '__main__':
# revert_to_current('2024-06-06')
# form_dataset()
# create_yaml()
model = get_recent_model() or 'yolov10x.yaml'
model = YOLOv10(model=model, task='detect').to('cuda')
model = YOLOv10('datasets/zombobase-latest.pt' or 'yolov10x.yaml')
_, yaml = get_yaml_path('2024-06-09')
results = model.train(data=yaml, epochs=2, batch=10, imgsz=640)
print(results)
res_model = path.join(results['save_dir'], '/weights/best.pt')
print(res_model)
shutil.copy(res_model, 'datasets/zombobase-latest2.pt')
results = model.train(data=yaml_path, epochs=200, imgsz=1280, save=True, cache=True)
pass