Skip to content

Commit 07c8cc6

Browse files
committed
support MPS (Mac M1) device; Happy Lunar New Year!
1 parent 7a584fd commit 07c8cc6

File tree

10 files changed

+75
-41
lines changed

10 files changed

+75
-41
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,4 +146,4 @@ This project is licensed under <a rel="license" href="https://github.com/sczhou/
146146
This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some codes are brought from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome works.
147147

148148
### Contact
149-
If you have any question, please feel free to reach me out at `[email protected]`.
149+
If you have any questions, please feel free to reach me out at `[email protected]`.

basicsr/setup.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import subprocess
77
import sys
88
import time
9-
import torch
109
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
10+
from utils.misc import gpu_is_available
1111

1212
version_file = './basicsr/version.py'
1313

@@ -87,7 +87,8 @@ def make_cuda_ext(name, module, sources, sources_cuda=None):
8787
define_macros = []
8888
extra_compile_args = {'cxx': []}
8989

90-
if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
90+
# if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
91+
if gpu_is_available or os.getenv('FORCE_CUDA', '0') == '1':
9192
define_macros += [('WITH_CUDA', None)]
9293
extension = CUDAExtension
9394
extra_compile_args['nvcc'] = [

basicsr/utils/misc.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,36 @@
1-
import numpy as np
21
import os
2+
import re
33
import random
44
import time
55
import torch
6+
import numpy as np
67
from os import path as osp
78

89
from .dist_util import master_only
910
from .logger import get_root_logger
1011

12+
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
13+
torch.__version__)[0][:3])] >= [1, 12, 0]
14+
15+
def gpu_is_available():
16+
if IS_HIGH_VERSION:
17+
if torch.backends.mps.is_available():
18+
return True
19+
return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
20+
21+
def get_device(gpu_id=None):
22+
if gpu_id is None:
23+
gpu_str = ''
24+
elif isinstance(gpu_id, int):
25+
gpu_str = f':{gpu_id}'
26+
else:
27+
raise TypeError('Input should be int value.')
28+
29+
if IS_HIGH_VERSION:
30+
if torch.backends.mps.is_available():
31+
return torch.device('mps'+gpu_str)
32+
return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
33+
1134

1235
def set_random_seed(seed):
1336
"""Set random seeds."""

basicsr/utils/realesrgan_utils.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import queue
66
import threading
77
import torch
8-
from basicsr.utils.download_util import load_file_from_url
98
from torch.nn import functional as F
9+
from basicsr.utils.download_util import load_file_from_url
10+
from basicsr.utils.misc import get_device
1011

1112
# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
1213

13-
1414
class RealESRGANer():
1515
"""A helper class for upsampling images with RealESRGAN.
1616
@@ -44,11 +44,14 @@ def __init__(self,
4444
self.half = half
4545

4646
# initialize model
47-
if gpu_id:
48-
self.device = torch.device(
49-
f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
50-
else:
51-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
47+
# if gpu_id:
48+
# self.device = torch.device(
49+
# f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
50+
# else:
51+
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
52+
53+
self.device = get_device(gpu_id) if device is None else device
54+
5255
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
5356
if model_path.startswith('https://'):
5457
model_path = load_file_from_url(

facelib/detection/retinaface/retinaface.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
1212
py_cpu_nms)
1313

14-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14+
from basicsr.utils.misc import get_device
15+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16+
device = get_device()
1517

1618

1719
def generate_config(network_name):
18-
20+
1921
cfg_mnet = {
2022
'name': 'mobilenet0.25',
2123
'min_sizes': [[16, 32], [64, 128], [256, 512]],

facelib/detection/yolov5face/face_detector.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
import copy
2-
import os
3-
from pathlib import Path
4-
51
import cv2
6-
import numpy as np
2+
import copy
3+
import re
74
import torch
8-
from torch import nn
5+
import numpy as np
96

10-
from facelib.detection.yolov5face.models.common import Conv
7+
from pathlib import Path
118
from facelib.detection.yolov5face.models.yolo import Model
129
from facelib.detection.yolov5face.utils.datasets import letterbox
1310
from facelib.detection.yolov5face.utils.general import (
@@ -17,7 +14,9 @@
1714
scale_coords_landmarks,
1815
)
1916

20-
IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9, 0)
17+
# IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9)
18+
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
19+
torch.__version__)[0][:3])] >= [1, 9, 0]
2120

2221

2322
def isListempty(inList):

facelib/utils/face_restoration_helper.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from facelib.detection import init_detection_model
88
from facelib.parsing import init_parsing_model
99
from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
10+
from basicsr.utils.misc import get_device
1011

1112

1213
def get_largest_face(det_faces, h, w):
@@ -97,7 +98,8 @@ def __init__(self,
9798
self.pad_input_imgs = []
9899

99100
if device is None:
100-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
101+
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
102+
self.device = get_device()
101103
else:
102104
self.device = device
103105

inference_codeformer.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from torchvision.transforms.functional import normalize
77
from basicsr.utils import imwrite, img2tensor, tensor2img
88
from basicsr.utils.download_util import load_file_from_url
9+
from basicsr.utils.misc import gpu_is_available, get_device
910
from facelib.utils.face_restoration_helper import FaceRestoreHelper
1011
from facelib.utils.misc import is_gray
11-
import torch.nn.functional as F
1212

1313
from basicsr.utils.registry import ARCH_REGISTRY
1414

@@ -19,9 +19,7 @@
1919
def set_realesrgan():
2020
from basicsr.archs.rrdbnet_arch import RRDBNet
2121
from basicsr.utils.realesrgan_utils import RealESRGANer
22-
23-
cuda_is_available = torch.cuda.is_available()
24-
half = True if cuda_is_available else False
22+
2523
model = RRDBNet(
2624
num_in_ch=3,
2725
num_out_ch=3,
@@ -37,10 +35,10 @@ def set_realesrgan():
3735
tile=args.bg_tile,
3836
tile_pad=40,
3937
pre_pad=0,
40-
half=half, # need to set False in CPU mode
38+
half=torch.cuda.is_available(), # need to set False in CPU/MPS mode
4139
)
4240

43-
if not cuda_is_available: # CPU
41+
if not gpu_is_available(): # CPU
4442
import warnings
4543
warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.'
4644
'The unoptimized RealESRGAN is slow on CPU. '
@@ -49,7 +47,8 @@ def set_realesrgan():
4947
return upsampler
5048

5149
if __name__ == '__main__':
52-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
50+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51+
device = get_device()
5352
parser = argparse.ArgumentParser()
5453

5554
parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
@@ -79,10 +78,10 @@ def set_realesrgan():
7978
# ------------------------ input & output ------------------------
8079
w = args.fidelity_weight
8180
input_video = False
82-
if args.input_path.endswith(('jpg', 'png')): # input single img path
81+
if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
8382
input_img_list = [args.input_path]
8483
result_root = f'results/test_img_{w}'
85-
elif args.input_path.endswith(('mp4', 'mov', 'avi')): # input video path
84+
elif args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
8685
from basicsr.utils.video_util import VideoReader, VideoWriter
8786
input_img_list = []
8887
vidreader = VideoReader(args.input_path)
@@ -100,7 +99,7 @@ def set_realesrgan():
10099
if args.input_path.endswith('/'): # solve when path ends with /
101100
args.input_path = args.input_path[:-1]
102101
# scan all the jpg and png images
103-
input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jp][pn]g')))
102+
input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
104103
result_root = f'results/{os.path.basename(args.input_path)}_{w}'
105104

106105
if not args.output_path is None: # set output path

web-demos/hugging_face/app.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,16 @@
1313

1414
from torchvision.transforms.functional import normalize
1515

16+
from basicsr.archs.rrdbnet_arch import RRDBNet
1617
from basicsr.utils import imwrite, img2tensor, tensor2img
1718
from basicsr.utils.download_util import load_file_from_url
18-
from facelib.utils.face_restoration_helper import FaceRestoreHelper
19-
from facelib.utils.misc import is_gray
20-
from basicsr.archs.rrdbnet_arch import RRDBNet
19+
from basicsr.utils.misc import gpu_is_available, get_device
2120
from basicsr.utils.realesrgan_utils import RealESRGANer
22-
2321
from basicsr.utils.registry import ARCH_REGISTRY
2422

23+
from facelib.utils.face_restoration_helper import FaceRestoreHelper
24+
from facelib.utils.misc import is_gray
25+
2526

2627
os.system("pip freeze")
2728

@@ -65,7 +66,8 @@ def imread(img_path):
6566

6667
# set enhancer with RealESRGAN
6768
def set_realesrgan():
68-
half = True if torch.cuda.is_available() else False
69+
# half = True if torch.cuda.is_available() else False
70+
half = True if gpu_is_available() else False
6971
model = RRDBNet(
7072
num_in_ch=3,
7173
num_out_ch=3,
@@ -86,7 +88,8 @@ def set_realesrgan():
8688
return upsampler
8789

8890
upsampler = set_realesrgan()
89-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
91+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
92+
device = get_device()
9093
codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
9194
dim_embd=512,
9295
codebook_size=1024,

web-demos/replicate/predict.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
except Exception:
1515
print('please install cog package')
1616

17-
from basicsr.utils import imwrite, img2tensor, tensor2img
1817
from basicsr.archs.rrdbnet_arch import RRDBNet
18+
from basicsr.utils import imwrite, img2tensor, tensor2img
1919
from basicsr.utils.realesrgan_utils import RealESRGANer
20+
from basicsr.utils.misc import gpu_is_available
2021
from basicsr.utils.registry import ARCH_REGISTRY
21-
from facelib.utils.face_restoration_helper import FaceRestoreHelper
2222

23+
from facelib.utils.face_restoration_helper import FaceRestoreHelper
2324

2425
class Predictor(BasePredictor):
2526
def setup(self):
@@ -159,7 +160,8 @@ def imread(img_path):
159160

160161

161162
def set_realesrgan():
162-
if not torch.cuda.is_available(): # CPU
163+
# if not torch.cuda.is_available(): # CPU
164+
if not gpu_is_available(): # CPU
163165
import warnings
164166

165167
warnings.warn(

0 commit comments

Comments
 (0)