Skip to content

Commit 8500830

Browse files
committed
update face detector.
1 parent 3be238a commit 8500830

35 files changed

+4095
-312
lines changed

.gitignore

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ version.py
88
# *.png
99
# *.jpeg
1010
# *.jpg
11+
*.pt
1112
*.gif
1213
*.pth
1314
*.dat
@@ -121,4 +122,5 @@ venv.bak/
121122
.mypy_cache/
122123

123124
# project
124-
results/
125+
results/
126+
*_old*

README.md

+7-7
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ S-Lab, Nanyang Technological University
1616

1717
### Updates
1818

19+
- **2022.07.29**: The face detector is upgraded with the family of `['YOLOv5', 'RetinaFace']`. :hugs:
1920
- **2022.07.17**: The Colab demo of CodeFormer is available now. <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
20-
- **2022.07.16**: Test code for face restoration is released. :blush:
21-
- **2022.06.21**: This repo is created.
21+
- **2022.07.16**: Test code for face restoration is released. :blush:
22+
- **2022.06.21**: This repo is created.
2223

2324

2425

@@ -54,16 +55,16 @@ source activate codeformer
5455
# install python dependencies
5556
pip3 install -r requirements.txt
5657
python basicsr/setup.py develop
57-
conda install -c conda-forge dlib
5858
```
59+
<!-- conda install -c conda-forge dlib -->
5960

6061
### Quick Inference
6162

6263
##### Download Pre-trained Models:
63-
Download the dlib pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1YCqeuNDGCsJBAm90eGh7M_WWKTt19yIY?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/Em2BaKU2OjhDolr11ngbrUgBu8q6SPn8E0jW-AC7nJF0Ig?e=HkjYrF)] to the `weights/dlib` folder.
64+
Download the facelib pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder.
6465
You can download by run the following command OR manually download the pretrained models.
6566
```
66-
python scripts/download_pretrained_models.py dlib
67+
python scripts/download_pretrained_models.py facelib
6768
```
6869

6970
Download the CodeFormer pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder.
@@ -82,8 +83,7 @@ You can put the testing images in the `inputs/TestWhole` folder. If you would li
8283
python inference_codeformer.py --w 0.5 --has_aligned --test_path [input folder]
8384
8485
# For the whole images
85-
# Please set `--upsample_num_times 2` when faces are small and failed detected
86-
python inference_codeformer.py --w 0.7 --upsample_num_times 1 --test_path [input folder]
86+
python inference_codeformer.py --w 0.7 --test_path [input folder]
8787
```
8888

8989
NOTE that *w* is in [0, 1]. Generally, smaller *w* tends to produce a higher-quality result, while larger *w* yields a higher-fidelity result.

basicsr/archs/vqgan_arch.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def forward(self, x):
228228

229229

230230
class Encoder(nn.Module):
231-
def __init__(self, in_channels, nf, out_channels, ch_mult, num_res_blocks, resolution, attn_resolutions):
231+
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
232232
super().__init__()
233233
self.nf = nf
234234
self.num_resolutions = len(ch_mult)
@@ -264,7 +264,7 @@ def __init__(self, in_channels, nf, out_channels, ch_mult, num_res_blocks, resol
264264

265265
# normalise and convert to latent size
266266
blocks.append(normalize(block_in_ch))
267-
blocks.append(nn.Conv2d(block_in_ch, out_channels, kernel_size=3, stride=1, padding=1))
267+
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
268268
self.blocks = nn.ModuleList(blocks)
269269

270270
def forward(self, x):
@@ -275,7 +275,7 @@ def forward(self, x):
275275

276276

277277
class Generator(nn.Module):
278-
def __init__(self, nf, ch_mult, res_blocks, img_size, attn_resolutions, emb_dim):
278+
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
279279
super().__init__()
280280
self.nf = nf
281281
self.ch_mult = ch_mult
@@ -362,7 +362,14 @@ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, att
362362
self.straight_through,
363363
self.kl_weight
364364
)
365-
self.generator = Generator(nf, ch_mult, res_blocks, img_size, attn_resolutions, emb_dim)
365+
self.generator = Generator(
366+
self.nf,
367+
self.embed_dim,
368+
self.ch_mult,
369+
self.n_blocks,
370+
self.resolution,
371+
self.attn_resolutions
372+
)
366373

367374
if model_path is not None:
368375
chkpt = torch.load(model_path, map_location='cpu')

basicsr/utils/face_util.py

-220
This file was deleted.

facelib/detection/__init__.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import os
2+
import torch
3+
from torch import nn
4+
from copy import deepcopy
5+
6+
from facelib.utils import load_file_from_url
7+
from facelib.utils import download_pretrained_models
8+
from facelib.detection.yolov5face.models.common import Conv
9+
10+
from .retinaface.retinaface import RetinaFace
11+
from .yolov5face.face_detector import YoloDetector
12+
13+
14+
def init_detection_model(model_name, half=False, device='cuda'):
15+
if 'retinaface' in model_name:
16+
model = init_retinaface_model(model_name, half, device)
17+
elif 'YOLOv5' in model_name:
18+
model = init_yolov5face_model(model_name, device)
19+
else:
20+
raise NotImplementedError(f'{model_name} is not implemented.')
21+
22+
return model
23+
24+
25+
def init_retinaface_model(model_name, half=False, device='cuda'):
26+
if model_name == 'retinaface_resnet50':
27+
model = RetinaFace(network_name='resnet50', half=half)
28+
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth'
29+
elif model_name == 'retinaface_mobile0.25':
30+
model = RetinaFace(network_name='mobile0.25', half=half)
31+
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
32+
else:
33+
raise NotImplementedError(f'{model_name} is not implemented.')
34+
35+
model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
36+
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
37+
# remove unnecessary 'module.'
38+
for k, v in deepcopy(load_net).items():
39+
if k.startswith('module.'):
40+
load_net[k[7:]] = v
41+
load_net.pop(k)
42+
model.load_state_dict(load_net, strict=True)
43+
model.eval()
44+
model = model.to(device)
45+
46+
return model
47+
48+
49+
def init_yolov5face_model(model_name, device='cuda'):
50+
if model_name == 'YOLOv5l':
51+
model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
52+
f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'}
53+
elif model_name == 'YOLOv5n':
54+
model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
55+
f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'}
56+
else:
57+
raise NotImplementedError(f'{model_name} is not implemented.')
58+
59+
model_path = os.path.join('weights/facelib', list(f_id.keys())[0])
60+
if not os.path.exists(model_path):
61+
download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib')
62+
63+
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
64+
model.detector.load_state_dict(load_net, strict=True)
65+
model.detector.eval()
66+
model.detector = model.detector.to(device).float()
67+
68+
for m in model.detector.modules():
69+
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
70+
m.inplace = True # pytorch 1.7.0 compatibility
71+
elif isinstance(m, Conv):
72+
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
73+
74+
return model

0 commit comments

Comments
 (0)