diff --git a/.gitignore b/.gitignore index fb9edd7..5f6d828 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .idea/ *labels/ -*images/ \ No newline at end of file +*images/ +datasets/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..cd04aee --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +1. Install dependencies first +``` + python -m pip install --upgrade pip + pip install --upgrade huggingface_hub + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 + pip install -q git+https://github.com/THU-MIG/yolov10.git + pip install albumentations +``` \ No newline at end of file diff --git a/train.py b/train.py index c3d3ad4..b724869 100644 --- a/train.py +++ b/train.py @@ -3,11 +3,10 @@ import shutil from datetime import datetime from pathlib import Path -from ultralytics import YOLO +from ultralytics import YOLOv10 from constants import current_images_dir, current_labels_dir, annotation_classes, today_dataset, prefix -yaml_name = 'data.yaml' -yaml_path = os.path.join(today_dataset, yaml_name) +yaml_path = os.path.join(today_dataset, 'data.yaml') train_set = 70 valid_set = 20 test_set = 10 @@ -48,9 +47,11 @@ def create_yaml(): for c in annotation_classes: lines.append(f'- {annotation_classes[c].name}') lines.append(f'nc: {len(annotation_classes)}') - lines.append(f'test: test/images') - lines.append(f'train: train/images') - lines.append(f'val: valid/images') + main_dir = f'../../{prefix}{datetime.now():%Y-%m-%d}' + + lines.append(f'test: {main_dir}/test/images') + lines.append(f'train: {main_dir}/train/images') + lines.append(f'val: {main_dir}/valid/images') lines.append('') with open(yaml_path, 'w', encoding='utf-8') as f: @@ -66,16 +67,12 @@ def get_recent_model(): continue date = datetime.strptime(date_str, '%Y-%m-%d') for file in os.listdir(os.path.join('datasets', d)): - if file.endswith('pt') and cur_date is None or cur_date < date: + if file.endswith('pt') and (cur_date is None or cur_date < date): cur_model = os.path.join('datasets', d, file) + cur_date = date return cur_model -def retrain(): - model = YOLO(get_recent_model() or 'yolov10x.yaml') - model.train(data=yaml_path, save=True, cache=True) - - def revert_to_current(date): def revert_dir(src_dir, dest_dir): for file in os.listdir(src_dir): @@ -90,6 +87,13 @@ def revert_to_current(date): shutil.rmtree(date_dataset) -form_dataset() -# revert_to_current('2024-06-06') -retrain() +if __name__ == '__main__': + # revert_to_current('2024-06-06') + # form_dataset() + # create_yaml() + + model = get_recent_model() or 'yolov10x.yaml' + model = YOLOv10(model=model, task='detect').to('cuda') + + results = model.train(data=yaml_path, epochs=200, imgsz=1280, save=True, cache=True) + pass