mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 09:16:36 +00:00
add readme
update train.py with yolov10 fix generation of data.yaml
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
.idea/
|
||||
*labels/
|
||||
*images/
|
||||
datasets/
|
||||
@@ -0,0 +1,8 @@
|
||||
1. Install dependencies first
|
||||
```
|
||||
python -m pip install --upgrade pip
|
||||
pip install --upgrade huggingface_hub
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
pip install -q git+https://github.com/THU-MIG/yolov10.git
|
||||
pip install albumentations
|
||||
```
|
||||
@@ -3,11 +3,10 @@ import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics import YOLO
|
||||
from ultralytics import YOLOv10
|
||||
from constants import current_images_dir, current_labels_dir, annotation_classes, today_dataset, prefix
|
||||
|
||||
yaml_name = 'data.yaml'
|
||||
yaml_path = os.path.join(today_dataset, yaml_name)
|
||||
yaml_path = os.path.join(today_dataset, 'data.yaml')
|
||||
train_set = 70
|
||||
valid_set = 20
|
||||
test_set = 10
|
||||
@@ -48,9 +47,11 @@ def create_yaml():
|
||||
for c in annotation_classes:
|
||||
lines.append(f'- {annotation_classes[c].name}')
|
||||
lines.append(f'nc: {len(annotation_classes)}')
|
||||
lines.append(f'test: test/images')
|
||||
lines.append(f'train: train/images')
|
||||
lines.append(f'val: valid/images')
|
||||
main_dir = f'../../{prefix}{datetime.now():%Y-%m-%d}'
|
||||
|
||||
lines.append(f'test: {main_dir}/test/images')
|
||||
lines.append(f'train: {main_dir}/train/images')
|
||||
lines.append(f'val: {main_dir}/valid/images')
|
||||
lines.append('')
|
||||
|
||||
with open(yaml_path, 'w', encoding='utf-8') as f:
|
||||
@@ -66,16 +67,12 @@ def get_recent_model():
|
||||
continue
|
||||
date = datetime.strptime(date_str, '%Y-%m-%d')
|
||||
for file in os.listdir(os.path.join('datasets', d)):
|
||||
if file.endswith('pt') and cur_date is None or cur_date < date:
|
||||
if file.endswith('pt') and (cur_date is None or cur_date < date):
|
||||
cur_model = os.path.join('datasets', d, file)
|
||||
cur_date = date
|
||||
return cur_model
|
||||
|
||||
|
||||
def retrain():
|
||||
model = YOLO(get_recent_model() or 'yolov10x.yaml')
|
||||
model.train(data=yaml_path, save=True, cache=True)
|
||||
|
||||
|
||||
def revert_to_current(date):
|
||||
def revert_dir(src_dir, dest_dir):
|
||||
for file in os.listdir(src_dir):
|
||||
@@ -90,6 +87,13 @@ def revert_to_current(date):
|
||||
shutil.rmtree(date_dataset)
|
||||
|
||||
|
||||
form_dataset()
|
||||
# revert_to_current('2024-06-06')
|
||||
retrain()
|
||||
if __name__ == '__main__':
|
||||
# revert_to_current('2024-06-06')
|
||||
# form_dataset()
|
||||
# create_yaml()
|
||||
|
||||
model = get_recent_model() or 'yolov10x.yaml'
|
||||
model = YOLOv10(model=model, task='detect').to('cuda')
|
||||
|
||||
results = model.train(data=yaml_path, epochs=200, imgsz=1280, save=True, cache=True)
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user