Skip to content

Commit f26d82e

Browse files
authored
Add yolo models (#228)
* add yolo v8 models * add yolo v11 and v12 models * add a selection of segmentation models
1 parent a20adbe commit f26d82e

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

.github/workflows/models.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ jobs:
2121
python-version: '3.10'
2222
- run: |
2323
python -m pip install --upgrade pip
24-
pip install torch torchvision boto3
24+
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
25+
pip install boto3 ultralytics
2526
- run: |
2627
python tools/convert-models.py

tools/convert-models.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import boto3
55
from botocore.exceptions import ClientError
6+
from ultralytics import settings
67

78

89
def upload_blob(bucket_name, source_file_name, destination_blob_name):
@@ -110,9 +111,34 @@ def blob_exist(bucket_name, blob_name):
110111
'fasterrcnn_resnet50_v2': 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth',
111112
'fasterrcnn_mobilenet_v3_large': 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth',
112113
'fasterrcnn_mobilenet_v3_large_320': 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth',
114+
'yolo_v8_l': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8l.pt',
115+
'yolo_v8_l_seg': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8l-seg.pt',
116+
'yolo_v8_m': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8m.pt',
117+
'yolo_v8_m_seg': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8m-seg.pt',
118+
'yolo_v8_n': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt',
119+
'yolo_v8_s': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8s.pt',
120+
'yolo_v8_s_seg': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8s-seg.pt',
121+
'yolo_v8_x': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8x.pt',
122+
'yolo_v11_l': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11l.pt',
123+
'yolo_v11_m': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11m.pt',
124+
'yolo_v11_n': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt',
125+
'yolo_v11_s': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s.pt',
126+
'yolo_v11_s_cls': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s-cls.pt',
127+
'yolo_v11_s_obb': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s-obb.pt',
128+
'yolo_v11_s_pose': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s-pose.pt',
129+
'yolo_v11_s_seg': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11s-seg.pt',
130+
'yolo_v11_x': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11x.pt',
131+
'yolo_v12_l': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12l.pt',
132+
'yolo_v12_m': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12m.pt',
133+
'yolo_v12_n': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12n.pt',
134+
'yolo_v12_s': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12s.pt',
135+
'yolo_v12_x': 'https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo12x.pt',
113136
}
114137

115138
os.makedirs("models", exist_ok=True)
139+
# yolo specifics
140+
os.makedirs("runs", exist_ok=True)
141+
settings.update({"runs_dir": "runs/", "weights_dir": "models/", "sync": False})
116142

117143
for name, url in models.items():
118144
fpath = "models/" + name + ".pth"
@@ -124,6 +150,11 @@ def blob_exist(bucket_name, blob_name):
124150
# download from url, convert and upload the converted weights
125151
m = load_state_dict_from_url(url, progress=False)
126152
converted = {}
153+
154+
# yolo models weights are embedded in a BaseModel per https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/tasks.py#L309
155+
if name.startswith("yolo_"):
156+
m = m["model"].model.float().state_dict()
157+
127158
for nm, par in m.items():
128159
converted.update([(nm, par.clone())])
129160
torch.save(converted, fpath, _use_new_zipfile_serialization=True)

0 commit comments

Comments
 (0)