From f673166f34f092aaa51d8303303ef05b2ce12dca Mon Sep 17 00:00:00 2001 From: Ftps Date: Thu, 16 Feb 2023 12:48:17 +0900 Subject: [PATCH 1/7] Add cpu-mps_support --- cldm/model.py | 4 +++- gradio_canny2image.py | 18 ++++++++++++++---- ldm/models/diffusion/ddim.py | 9 ++++++--- ldm/modules/encoders/modules.py | 19 ++++++++++++++----- 4 files changed, 37 insertions(+), 13 deletions(-) diff --git a/cldm/model.py b/cldm/model.py index fed3c31ac1..387938c101 100644 --- a/cldm/model.py +++ b/cldm/model.py @@ -9,8 +9,10 @@ def get_state_dict(d): return d.get('state_dict', d) -def load_state_dict(ckpt_path, location='cpu'): +def load_state_dict(ckpt_path, location): _, extension = os.path.splitext(ckpt_path) + if str(location) == "mps": + location = "cpu" if extension.lower() == ".safetensors": import safetensors.torch state_dict = safetensors.torch.load_file(ckpt_path, device=location) diff --git a/gradio_canny2image.py b/gradio_canny2image.py index f9b1e9f124..7b86e3ea28 100644 --- a/gradio_canny2image.py +++ b/gradio_canny2image.py @@ -15,12 +15,22 @@ from ldm.models.diffusion.ddim import DDIMSampler +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + apply_canny = CannyDetector() +device = get_device() model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold): @@ -31,7 +41,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = apply_canny(img, low_threshold, high_threshold) detected_map = HWC3(detected_map) - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 27ead0ea91..09c0c7f893 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -8,16 +8,19 @@ class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): + def __init__(self, model, device, schedule="linear", **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule + self.device = device def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if str(self.device) == 'mps': + attr = attr.to(self.device, torch.float32) + else: + attr = attr.to(self.device) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 4edd5496b9..2cd4e1dc9f 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -8,6 +8,15 @@ from ldm.util import default, count_params +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + 
else: + return 'cpu' + + class AbstractEncoder(nn.Module): def __init__(self): super().__init__() @@ -42,7 +51,7 @@ def forward(self, batch, key=None, disable_dropout=False): c = self.embedding(c) return c - def get_unconditional_conditioning(self, bs, device="cuda"): + def get_unconditional_conditioning(self, bs, device=get_device()): uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) uc = torch.ones((bs,), device=device) * uc_class uc = {self.key: uc} @@ -57,7 +66,7 @@ def disabled_train(self, mode=True): class FrozenT5Embedder(AbstractEncoder): """Uses the T5 transformer encoder for text""" - def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + def __init__(self, version="google/t5-v1_1-large", device=get_device(), max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() self.tokenizer = T5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) @@ -92,7 +101,7 @@ class FrozenCLIPEmbedder(AbstractEncoder): "pooled", "hidden" ] - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + def __init__(self, version="openai/clip-vit-large-patch14", device=get_device(), max_length=77, freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS @@ -140,7 +149,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder): "last", "penultimate" ] - def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=get_device(), max_length=77, freeze=True, layer="last"): super().__init__() assert layer in self.LAYERS @@ -194,7 +203,7 @@ def encode(self, text): class FrozenCLIPT5Encoder(AbstractEncoder): - def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", + def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=get_device(), clip_max_length=77, t5_max_length=77): super().__init__() self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) From 8d2ff6dea0dd3a163acf7804dc3d0473b0e25fc5 Mon Sep 17 00:00:00 2001 From: Ftps Date: Thu, 16 Feb 2023 14:43:09 +0900 Subject: [PATCH 2/7] Add support mps and cpu --- .DS_Store | Bin 0 -> 10244 bytes annotator/.DS_Store | Bin 0 -> 6148 bytes annotator/ckpts/.DS_Store | Bin 0 -> 6148 bytes annotator/hed/__init__.py | 9 +++++---- annotator/midas/__init__.py | 7 ++++--- annotator/mlsd/__init__.py | 4 ++-- gradio_depth2image.py | 20 +++++++++++++++----- gradio_fake_scribble2image.py | 21 ++++++++++++++++----- gradio_hed2image.py | 20 +++++++++++++++----- gradio_hough2image.py | 20 +++++++++++++++----- gradio_normal2image.py | 20 +++++++++++++++----- gradio_pose2image.py | 18 ++++++++++++++---- gradio_scribble2image.py | 17 +++++++++++++---- gradio_scribble2image_interactive.py | 18 ++++++++++++++---- 14 files changed, 128 insertions(+), 46 deletions(-) create mode 100644 .DS_Store create mode 100644 annotator/.DS_Store create mode 100644 annotator/ckpts/.DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..bbb57c8300036ba2657709dd0a6445ef433aadfa GIT binary patch literal 10244 zcmeHMOK;Oa5S~p#>nKna^`TxWStGnA7zX|Y2Jp@1VijX)u7&}_fMMW@ 
z0nQILD$8;r`>~X%0~fgkfXw5uENEjNAZ_eKmJ`{Jr8EU*yn7IdsYtdMA~wf+o7Ex9 ziR{NRHYX9AlSq0N$qq%x-eG4;brR)RnyX>JFfhmf@7;^otvl4A^VI&`Iq-sd%?oPK z;6=5Ky8c5Jw1^7W>-+cx)J3Zc8EO`v$6M6SqE6Dsms-bbMcq>AtH@2|Z%ofvGgi@h zQ)vcgm2S1$i*~BvE53Uk1YYe|yIOnEYMeQ93%h~et+xC|C=FU}11hgywtP1@+X;HU z8%f(r{K6_&1!r#g;^N_EY0Z9cf9rD1zSxS}k1j6@*5c~D$A`{|*Y<-Ca<(Lr#6!2@ z`>X3mv@*Cjt%SZ8_{Vt5mU(dNkxT6~M`8u<(yZ3Oq)$iGqGMWTuV@zBpK6f@Y5?>| z)BCRJ&EZtWr&HjDvWK?3@1l=$_^(Y~+CIl4SNkbqJT3USNj<8`c&Xw^BW)$4^gE*)u!_crt~xpkZ}@6-~eq)=3IyS;P4Q&Qo}Iegn^eZ+*l}Lq>KTtrpf*cKjh~ zQvK+@<+&e^W08HELYd0SCDL#26~E2U6FQ(>+NWn!p?z(ovS(E7X?Qe1oHc;0=MrXr zsO~k%;i>s+*}Fi8U~!xF=qcnnm~zu>3ih+gkcFsjA|r%D;sR?vrC*To6yeIz)kaKc zK9Dh>_An3K06$M=RTrLfkazgo(M7NPJmc%ZnN@sKVq|$N3gQJ70a62q*?8<-74J6m zbi8=BoEb7k#A`G6R^feVB*I9}WhAJh!RNyi!+>GHFkl!k3{1?xRIF(iA@_&>|9@gU zqi7f~4E%o#i2Pn$%!9^a&L(1cL-0*SyT~yXBvE+hw kIg$NX;t9&%|7U>U9m1^g_qf^rr}n?v?$+%8uigItCw~81t^fc4 literal 0 HcmV?d00001 diff --git a/annotator/.DS_Store b/annotator/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..253418f127a257d4c51370f4e949ca7b1348231e GIT binary patch literal 6148 zcmeHKPjAyO9R0cNXiA0D14z3dMdDhGZljZ?T}l}Tt`xxmP)M4d5dALv1AtDEfv2h#O6~WIrZA8tq z+yM&J$C!K?(-}qNZLn>Fb-+5XZ4U6-ZDX~1G^Goy-9PKsdzQ#?kjMZG-X0z!?fq1N z648!gr*BAc4Ge+#BboaMdVcTuZ1c5XGDrqg!bj?Uy_ zl>7NS8~U@?Tzes968tXv!OJ*Wc%8dXWSaYN8qJg-4kJ+Byo%FME{1ZRhMCfi^aH2m zw7kyla@jxVb=^mY$E&WpJl5r-)vD#(d+_k&xp$sSQu&E(z}APoPX&8fjL)dBOJ?jJ zN3%4M=?Ap3qOSJIu;K`7$}UeRrV#&2)w>@RW`ib_P)3y=Yy7@qkC(igc6k4w-*pfB zM}O(?<#;OoIwUdYBFxVzN*P6Er`3G^Qkc&znu8nDd$4ElGNM^U!SlMPXhM-gkuQJ3|EJgA|H~pnHoS|eHz_Cr9~U^DB$KXu?2U4-`- literal 0 HcmV?d00001 diff --git a/annotator/ckpts/.DS_Store b/annotator/ckpts/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..96f33f277a650111f7334030c53fd4998536f7e7 GIT binary patch literal 6148 zcmeHK%TB{E5FD3E6u2Pu00)rr14RXdxS>$NfioWvnl>$nKv5q8Zn^Ud`~%;|31-(8 z(6l{pK?rswdz_7Tyq?H*48XM}-CbZEK#fJPy2@&f$$Kd?RtcVEqEn6pcNk!V49QGy zbC?3Az;9DP)@~gsS8W%MT(zI;_eOmga_rDM#^qT)=%o1|WrS=O$1FQ$scKVxY(oBrhq9ht-z9hT$l6zbn*RvI>??(0aM^lDd4In z=O-;rDV(h{lasSHV!2=ulYWIq4Ph0IV{OP$yvd@*cab!RalpbOJv93fP#LT;1%6b4 EcZh(97ytkO literal 0 HcmV?d00001 diff --git a/annotator/hed/__init__.py b/annotator/hed/__init__.py index 56532c374d..f882115f15 100644 --- a/annotator/hed/__init__.py +++ b/annotator/hed/__init__.py @@ -93,20 +93,21 @@ def forward(self, tenInput): return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1)) -class HEDdetector: - def __init__(self): +class HEDdetector(): + def __init__(self, device): + self.device = device remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth" modelpath = os.path.join(annotator_ckpts_path, "network-bsds500.pth") if not os.path.exists(modelpath): from basicsr.utils.download_util import load_file_from_url load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) - self.netNetwork = Network(modelpath).cuda().eval() + self.netNetwork = Network(modelpath).to(device).eval() def __call__(self, input_image): assert input_image.ndim == 3 input_image = input_image[:, :, ::-1].copy() with torch.no_grad(): - image_hed = torch.from_numpy(input_image).float().cuda() + image_hed = torch.from_numpy(input_image).float().to(self.device) image_hed = image_hed / 255.0 image_hed = rearrange(image_hed, 'h w c -> 1 c h w') edge = self.netNetwork(image_hed)[0] diff --git a/annotator/midas/__init__.py b/annotator/midas/__init__.py index dc5ac03eea..7aa15e5dbb 100644 --- 
a/annotator/midas/__init__.py +++ b/annotator/midas/__init__.py @@ -7,14 +7,15 @@ class MidasDetector: - def __init__(self): - self.model = MiDaSInference(model_type="dpt_hybrid").cuda() + def __init__(self, device): + self.device = device + self.model = MiDaSInference(model_type="dpt_hybrid").to(device) def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1): assert input_image.ndim == 3 image_depth = input_image with torch.no_grad(): - image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = torch.from_numpy(image_depth).float().to(self.device) image_depth = image_depth / 127.5 - 1.0 image_depth = rearrange(image_depth, 'h w c -> 1 c h w') depth = self.model(image_depth)[0] diff --git a/annotator/mlsd/__init__.py b/annotator/mlsd/__init__.py index 42af28c682..cb447811b5 100644 --- a/annotator/mlsd/__init__.py +++ b/annotator/mlsd/__init__.py @@ -15,14 +15,14 @@ class MLSDdetector: - def __init__(self): + def __init__(self, device): model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth") if not os.path.exists(model_path): from basicsr.utils.download_util import load_file_from_url load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) model = MobileV2_MLSD_Large() model.load_state_dict(torch.load(model_path), strict=True) - self.model = model.cuda().eval() + self.model = model.to(device).eval() def __call__(self, input_image, thr_v, thr_d): assert input_image.ndim == 3 diff --git a/gradio_depth2image.py b/gradio_depth2image.py index 6a72de6da3..fbecadf6b9 100644 --- a/gradio_depth2image.py +++ b/gradio_depth2image.py @@ -15,12 +15,22 @@ from ldm.models.diffusion.ddim import DDIMSampler -apply_midas = MidasDetector() +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() +apply_midas = MidasDetector(device) model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, scale, seed, eta): @@ -33,7 +43,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_fake_scribble2image.py b/gradio_fake_scribble2image.py index fac13d9ec2..07a16cf230 100644 --- a/gradio_fake_scribble2image.py +++ b/gradio_fake_scribble2image.py @@ -15,12 +15,23 @@ from ldm.models.diffusion.ddim import DDIMSampler -apply_hed = HEDdetector() +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() +apply_hed = HEDdetector(device) + model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda')) -model = model.cuda() -ddim_sampler = 
DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, scale, seed, eta): @@ -37,7 +48,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map[detected_map > 4] = 255 detected_map[detected_map < 255] = 0 - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_hed2image.py b/gradio_hed2image.py index 13c181548c..89720b9824 100644 --- a/gradio_hed2image.py +++ b/gradio_hed2image.py @@ -15,12 +15,22 @@ from ldm.models.diffusion.ddim import DDIMSampler -apply_hed = HEDdetector() +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() +apply_hed = HEDdetector(device) model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_hed.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_hed.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, scale, seed, eta): @@ -33,7 +43,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_hough2image.py b/gradio_hough2image.py index 8223cca5d5..f6b94bd778 100644 --- a/gradio_hough2image.py +++ b/gradio_hough2image.py @@ -15,12 +15,22 @@ from ldm.models.diffusion.ddim import DDIMSampler -apply_mlsd = MLSDdetector() +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() +apply_mlsd = MLSDdetector(device) model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_mlsd.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_mlsd.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, scale, seed, eta, value_threshold, distance_threshold): @@ -33,7 +43,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST) - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in 
range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_normal2image.py b/gradio_normal2image.py index e9622e03ee..c5125321c8 100644 --- a/gradio_normal2image.py +++ b/gradio_normal2image.py @@ -15,12 +15,22 @@ from ldm.models.diffusion.ddim import DDIMSampler -apply_midas = MidasDetector() +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() +apply_midas = MidasDetector(device) model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_normal.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_normal.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, scale, seed, eta, bg_threshold): @@ -33,7 +43,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) - control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_pose2image.py b/gradio_pose2image.py index 92384ede79..0d3cdafc43 100644 --- a/gradio_pose2image.py +++ b/gradio_pose2image.py @@ -15,12 +15,22 @@ from ldm.models.diffusion.ddim import DDIMSampler +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() apply_openpose = OpenposeDetector() model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_openpose.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_openpose.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, scale, seed, eta): @@ -33,7 +43,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST) - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_scribble2image.py b/gradio_scribble2image.py index 44241d3ebc..4ef3f40cb7 100644 --- a/gradio_scribble2image.py +++ b/gradio_scribble2image.py @@ -14,10 +14,19 @@ from ldm.models.diffusion.ddim import DDIMSampler +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + +device = get_device() model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda')) -model = model.cuda() -ddim_sampler = 
DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta): @@ -28,7 +37,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = np.zeros_like(img, dtype=np.uint8) detected_map[np.min(img, axis=2) < 127] = 255 - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() diff --git a/gradio_scribble2image_interactive.py b/gradio_scribble2image_interactive.py index 97e75148d8..b5587ec9ab 100644 --- a/gradio_scribble2image_interactive.py +++ b/gradio_scribble2image_interactive.py @@ -14,10 +14,20 @@ from ldm.models.diffusion.ddim import DDIMSampler +def get_device(): + if(torch.cuda.is_available()): + return 'cuda' + elif(torch.backends.mps.is_available()): + return 'mps' + else: + return 'cpu' + + +device = get_device() model = create_model('./models/cldm_v15.yaml').cpu() -model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda')) -model = model.cuda() -ddim_sampler = DDIMSampler(model) +model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location=device)) +model = model.to(device) +ddim_sampler = DDIMSampler(model, device) def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta): @@ -28,7 +38,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti detected_map = np.zeros_like(img, dtype=np.uint8) detected_map[np.min(img, axis=2) > 127] = 255 - control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0 + control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0 control = torch.stack([control for _ in range(num_samples)], dim=0) control = einops.rearrange(control, 'b h w c -> b c h w').clone() From a12cc4bfd4d684c4d92dd08f243c1f70165a5cd8 Mon Sep 17 00:00:00 2001 From: Ftps Date: Thu, 16 Feb 2023 15:06:59 +0900 Subject: [PATCH 3/7] Delete unnecessary files --- .DS_Store | Bin 10244 -> 0 bytes annotator/.DS_Store | Bin 6148 -> 0 bytes annotator/ckpts/.DS_Store | Bin 6148 -> 0 bytes 3 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .DS_Store delete mode 100644 annotator/.DS_Store delete mode 100644 annotator/ckpts/.DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index bbb57c8300036ba2657709dd0a6445ef433aadfa..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10244 zcmeHMOK;Oa5S~p#>nKna^`TxWStGnA7zX|Y2Jp@1VijX)u7&}_fMMW@ z0nQILD$8;r`>~X%0~fgkfXw5uENEjNAZ_eKmJ`{Jr8EU*yn7IdsYtdMA~wf+o7Ex9 ziR{NRHYX9AlSq0N$qq%x-eG4;brR)RnyX>JFfhmf@7;^otvl4A^VI&`Iq-sd%?oPK z;6=5Ky8c5Jw1^7W>-+cx)J3Zc8EO`v$6M6SqE6Dsms-bbMcq>AtH@2|Z%ofvGgi@h zQ)vcgm2S1$i*~BvE53Uk1YYe|yIOnEYMeQ93%h~et+xC|C=FU}11hgywtP1@+X;HU z8%f(r{K6_&1!r#g;^N_EY0Z9cf9rD1zSxS}k1j6@*5c~D$A`{|*Y<-Ca<(Lr#6!2@ z`>X3mv@*Cjt%SZ8_{Vt5mU(dNkxT6~M`8u<(yZ3Oq)$iGqGMWTuV@zBpK6f@Y5?>| z)BCRJ&EZtWr&HjDvWK?3@1l=$_^(Y~+CIl4SNkbqJT3USNj<8`c&Xw^BW)$4^gE*)u!_crt~xpkZ}@6-~eq)=3IyS;P4Q&Qo}Iegn^eZ+*l}Lq>KTtrpf*cKjh~ zQvK+@<+&e^W08HELYd0SCDL#26~E2U6FQ(>+NWn!p?z(ovS(E7X?Qe1oHc;0=MrXr 
zsO~k%;i>s+*}Fi8U~!xF=qcnnm~zu>3ih+gkcFsjA|r%D;sR?vrC*To6yeIz)kaKc zK9Dh>_An3K06$M=RTrLfkazgo(M7NPJmc%ZnN@sKVq|$N3gQJ70a62q*?8<-74J6m zbi8=BoEb7k#A`G6R^feVB*I9}WhAJh!RNyi!+>GHFkl!k3{1?xRIF(iA@_&>|9@gU zqi7f~4E%o#i2Pn$%!9^a&L(1cL-0*SyT~yXBvE+hw kIg$NX;t9&%|7U>U9m1^g_qf^rr}n?v?$+%8uigItCw~81t^fc4 diff --git a/annotator/.DS_Store b/annotator/.DS_Store deleted file mode 100644 index 253418f127a257d4c51370f4e949ca7b1348231e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKPjAyO9R0cNXiA0D14z3dMdDhGZljZ?T}l}Tt`xxmP)M4d5dALv1AtDEfv2h#O6~WIrZA8tq z+yM&J$C!K?(-}qNZLn>Fb-+5XZ4U6-ZDX~1G^Goy-9PKsdzQ#?kjMZG-X0z!?fq1N z648!gr*BAc4Ge+#BboaMdVcTuZ1c5XGDrqg!bj?Uy_ zl>7NS8~U@?Tzes968tXv!OJ*Wc%8dXWSaYN8qJg-4kJ+Byo%FME{1ZRhMCfi^aH2m zw7kyla@jxVb=^mY$E&WpJl5r-)vD#(d+_k&xp$sSQu&E(z}APoPX&8fjL)dBOJ?jJ zN3%4M=?Ap3qOSJIu;K`7$}UeRrV#&2)w>@RW`ib_P)3y=Yy7@qkC(igc6k4w-*pfB zM}O(?<#;OoIwUdYBFxVzN*P6Er`3G^Qkc&znu8nDd$4ElGNM^U!SlMPXhM-gkuQJ3|EJgA|H~pnHoS|eHz_Cr9~U^DB$KXu?2U4-`- diff --git a/annotator/ckpts/.DS_Store b/annotator/ckpts/.DS_Store deleted file mode 100644 index 96f33f277a650111f7334030c53fd4998536f7e7..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK%TB{E5FD3E6u2Pu00)rr14RXdxS>$NfioWvnl>$nKv5q8Zn^Ud`~%;|31-(8 z(6l{pK?rswdz_7Tyq?H*48XM}-CbZEK#fJPy2@&f$$Kd?RtcVEqEn6pcNk!V49QGy zbC?3Az;9DP)@~gsS8W%MT(zI;_eOmga_rDM#^qT)=%o1|WrS=O$1FQ$scKVxY(oBrhq9ht-z9hT$l6zbn*RvI>??(0aM^lDd4In z=O-;rDV(h{lasSHV!2=ulYWIq4Ph0IV{OP$yvd@*cab!RalpbOJv93fP#LT;1%6b4 EcZh(97ytkO From 2e8ee087e359ed594ab3f11424ccf3543628c2eb Mon Sep 17 00:00:00 2001 From: Ftps <63702646+Tps-F@users.noreply.github.com> Date: Tue, 21 Feb 2023 12:48:03 +0900 Subject: [PATCH 4/7] Add mps-cpu support for readme.md --- README.md | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/README.md b/README.md index f1168264e4..48486c5e45 100644 --- a/README.md +++ b/README.md @@ -61,8 +61,14 @@ All test images can be found at the folder "test_imgs". Stable Diffusion 1.5 + ControlNet (using simple Canny edge detection) +##### CUDA, CPU + python gradio_canny2image.py +##### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_canny2image.py + The Gradio app also allows you to change the Canny edge thresholds. Just try it for more details. Prompt: "bird" @@ -75,8 +81,14 @@ Prompt: "cute dog" Stable Diffusion 1.5 + ControlNet (using simple M-LSD straight line detection) +##### CUDA, CPU + python gradio_hough2image.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_hough2image.py + The Gradio app also allows you to change the M-LSD thresholds. Just try it for more details. Prompt: "room" @@ -89,8 +101,14 @@ Prompt: "building" Stable Diffusion 1.5 + ControlNet (using soft HED Boundary) +#### CUDA, CPU + python gradio_hed2image.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_hed2image.py + The soft HED Boundary will preserve many details in input images, making this app suitable for recoloring and stylizing. Just try it for more details. Prompt: "oil painting of handsome old man, masterpiece" @@ -103,8 +121,12 @@ Prompt: "Cyberpunk robot" Stable Diffusion 1.5 + ControlNet (using Scribbles) +#### CUDA, CPU + python gradio_scribble2image.py +#### MPS + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_scribble2image.py Note that the UI is based on Gradio, and Gradio is somewhat difficult to customize. 
Right now you need to draw scribbles outside the UI (using your favorite drawing software, for example, MS Paint) and then import the scribble image to Gradio. Prompt: "turtle" @@ -117,8 +139,14 @@ Prompt: "hot air balloon" We actually provide an interactive interface +#### CUDA, CPU + python gradio_scribble2image_interactive.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_scribble2image_interactive.py + However, because gradio is very [buggy](https://github.com/gradio-app/gradio/issues/3166) and difficult to customize, right now, user need to first set canvas width and heights and then click "Open drawing canvas" to get a drawing area. Please do not upload image to that drawing canvas. Also, the drawing area is very small; it should be bigger. But I failed to find out how to make it larger. Again, gradio is really buggy. The below dog sketch is drawn by me. Perhaps we should draw a better dog for showcase. Prompt: "dog in a room" @@ -130,8 +158,14 @@ Stable Diffusion 1.5 + ControlNet (using fake scribbles) +#### CUDA, CPU + python gradio_fake_scribble2image.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_fake_scribble2image.py + Sometimes we are lazy, and we do not want to draw scribbles. This script use the exactly same scribble-based model but use a simple algorithm to synthesize scribbles from input images. Prompt: "bag" @@ -144,8 +178,12 @@ Prompt: "shose" (Note that "shose" is a typo; it should be "shoes". But it still Stable Diffusion 1.5 + ControlNet (using human pose) +#### CUDA, CPU + python gradio_pose2image.py +#### MPS + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_pose2image.py Apparently, this model deserves a better UI to directly manipulate pose skeleton. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image and then the Openpose will detect the pose for you. Prompt: "Chief in the kitchen" @@ -158,8 +196,13 @@ Prompt: "An astronaut on the moon" Stable Diffusion 1.5 + ControlNet (using semantic segmentation) +#### CUDA, CPU + python gradio_seg2image.py +#### MPS + Not Supported (Reason: aten::_slow_conv2d_forward is currently not supported by mps.) + This model use ADE20K's segmentation protocol. Again, this model deserves a better UI to directly draw the segmentations. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image and then a model called Uniformer will detect the segmentations for you. Just try it for more details. Prompt: "House" @@ -172,8 +215,14 @@ Prompt: "River" Stable Diffusion 1.5 + ControlNet (using depth map) +#### CUDA, CPU + python gradio_depth2image.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_depth2image.py + Great! Now SD 1.5 also have a depth control. FINALLY. So many possibilities (considering SD1.5 has much more community models than SD2). Note that different from Stability's model, the ControlNet receive the full 512×512 depth map, rather than 64×64 depth. Note that Stability's SD2 depth model use 64*64 depth maps. This means that the ControlNet will preserve more details in the depth map. Prompt: "Stormtrooper's lecture" @@ -187,8 +236,14 @@ Stable Diffusion 1.5 + ControlNet (using normal map) +#### CUDA, CPU + python gradio_normal2image.py +#### MPS + + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_normal2image.py + This model use normal map. 
Rightnow in the APP, the normal is computed from the midas depth map and a user threshold (to determine how many area is background with identity normal face to viewer, tune the "Normal background threshold" in the gradio app to get a feeling). Prompt: "Cute toy" From b2b455ff88044e04ddcaaeba6cfe0d10bcd71033 Mon Sep 17 00:00:00 2001 From: Ftps Date: Sat, 25 Feb 2023 10:10:39 +0900 Subject: [PATCH 5/7] Add environment-mps --- .DS_Store | Bin 0 -> 10244 bytes README.md | 7 +++++++ environment-mps.yaml | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+) create mode 100644 .DS_Store create mode 100644 environment-mps.yaml diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..f84627cb0dc851a4fc640526808fdacb67867d16 GIT binary patch literal 10244 zcmeHMO>Yx15FNJ>GzvmRg{qfId*HySqWpriw1-N-f!N}JM+e#$c~7Z^F?+ll8DG2I_3TY+$tKsXJ2VM4*F2~lCEgSfyh8)ATkgchzvvq{s#u|o6WY@vNGNy1CfEqz%K?^9};xRb}5%u zYU#jDw*Zh03~R$P_5o^>l(JpQrIl(5OuKt9rKu^m7{;38`-JL{?NTnSwB};0xtQ|J zrreMTH!r%-;G#N1Klz~zPa;k1 z^>1+%GRM;#avE5C4J;bK&n1(wuV;<}y(aVuAm7MH2695r0*lXqMFXe>tY(4U_dstS zClj~w5x7%*!&pDh@s2yhFqYi6-x&n<4fsUQ=DwbC^Tpw&1;0n&WGD;iDz_bs^aCsM z<60d+X9_dw{m<<)io~4Q7sDBH&^3nB_8MdIz0Npx-~<^<=pul|Hz|@$0!%-BP&ct zqs<`sd&RlLy1(+ewO!k%7oSWFRsS8Q3BQc5SRl3-WUP|NmRWNDMVHuzd`e#>Mr;IpQwQ)>jzd zS$l&19-WPPODoj`H$9Gr)Z=)3;Bovjbm~5*WW#g2luIiyg7y#o42b*xxc{%*uW!r# F|0j^4>WTmW literal 0 HcmV?d00001 diff --git a/README.md b/README.md index 7166cb08ef..15ecc40d33 100644 --- a/README.md +++ b/README.md @@ -48,9 +48,16 @@ Note that the way we connect layers is computational efficient. The original SD First create a new conda environment +CUDA, CPU + conda env create -f environment.yaml conda activate control +MPS + + conda env create -f environment-mps.yaml + conda activate control + All models and detectors can be downloaded from [our Hugging Face page](https://huggingface.co/lllyasviel/ControlNet). Make sure that SD models are put in "ControlNet/models" and detectors are put in "ControlNet/annotator/ckpts". Make sure that you download all necessary pretrained weights and detector models from that Hugging Face page, including HED edge detection model, Midas depth estimation model, Openpose, and so on. We provide 9 Gradio apps with these models. 
diff --git a/environment-mps.yaml b/environment-mps.yaml new file mode 100644 index 0000000000..ef87ae47b6 --- /dev/null +++ b/environment-mps.yaml @@ -0,0 +1,34 @@ +name: control +channels: + - pytorch + - defaults +dependencies: + - python=3.8 + - pip + - pytorch=1.12.1 + - torchvision=0.13.1 + - numpy=1.23.1 + - pip: + - gradio==3.16.2 + - albumentations==1.3.0 + - opencv-contrib-python + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==1.5.0 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - streamlit==1.12.1 + - einops==0.3.0 + - transformers==4.19.2 + - webdataset==0.2.5 + - kornia==0.6 + - open_clip_torch==2.0.2 + - invisible-watermark>=0.1.5 + - streamlit-drawable-canvas==0.8.0 + - torchmetrics==0.6.0 + - timm==0.6.12 + - addict==2.4.0 + - yapf==0.32.0 + - prettytable==3.6.0 + - safetensors==0.2.7 + - basicsr==1.4.2 From 5ebf07fbf3b24acb52c8ab40cf1fde9433e3ee72 Mon Sep 17 00:00:00 2001 From: Ftps Date: Sat, 25 Feb 2023 10:34:22 +0900 Subject: [PATCH 6/7] Fix attr --- cldm/ddim_hacked.py | 10 +++++++--- ldm/models/diffusion/dpm_solver/sampler.py | 10 +++++++--- ldm/models/diffusion/plms.py | 10 +++++++--- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/cldm/ddim_hacked.py b/cldm/ddim_hacked.py index 6c040b363b..2214964bb7 100644 --- a/cldm/ddim_hacked.py +++ b/cldm/ddim_hacked.py @@ -8,16 +8,20 @@ class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): + def __init__(self, model, device, schedule="linear", **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule + self.device = device def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != torch.device(self.device): + if str(self.device) == 'mps': + attr = attr.to(self.device, torch.float32) + else: + attr = attr.to(self.device) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8cf3..f0509b8a35 100644 --- a/ldm/models/diffusion/dpm_solver/sampler.py +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -11,16 +11,20 @@ class DPMSolverSampler(object): - def __init__(self, model, **kwargs): + def __init__(self, model, device, **kwargs): super().__init__() self.model = model + self.device = device to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != torch.device(self.device): + if str(self.device) == 'mps': + attr = attr.to(self.device, torch.float32) + else: + attr = attr.to(self.device) setattr(self, name, attr) @torch.no_grad() diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 7002a365d2..20be051b7b 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -9,17 +9,21 @@ from ldm.models.diffusion.sampling_util import norm_thresholding class PLMSSampler(object): - def __init__(self, model, schedule="linear", **kwargs): + def __init__(self, model, device, schedule="linear", **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule + self.device = device def register_buffer(self, name, 
attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != torch.device(self.device): + if str(self.device) == 'mps': + attr = attr.to(self.device, torch.float32) + else: + attr = attr.to(self.device) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): From 9fadc8684a7e1ccb938416297d90ece1a66f1cea Mon Sep 17 00:00:00 2001 From: Ftps Date: Sat, 18 Mar 2023 20:19:47 +0900 Subject: [PATCH 7/7] Support annotator(mlsd) --- README.md | 7 ++++--- annotator/mlsd/__init__.py | 3 ++- annotator/mlsd/utils.py | 4 ++-- annotator/uniformer/__init__.py | 4 ++-- gradio_annotator.py | 19 +++++++++++++++---- 5 files changed, 25 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index e6b38aa2c6..287e059864 100644 --- a/README.md +++ b/README.md @@ -50,12 +50,12 @@ Note that the way we connect layers is computational efficient. The original SD First create a new conda environment -CUDA, CPU +#### CUDA, CPU conda env create -f environment.yaml conda activate control -MPS +#### MPS conda env create -f environment-mps.yaml conda activate control @@ -194,7 +194,8 @@ Stable Diffusion 1.5 + ControlNet (using human pose) python gradio_pose2image.py #### MPS - PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_pose2image.py + PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_pose2image.py + Apparently, this model deserves a better UI to directly manipulate pose skeleton. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image and then the Openpose will detect the pose for you. Prompt: "Chief in the kitchen" diff --git a/annotator/mlsd/__init__.py b/annotator/mlsd/__init__.py index cb447811b5..3cf840d464 100644 --- a/annotator/mlsd/__init__.py +++ b/annotator/mlsd/__init__.py @@ -23,6 +23,7 @@ def __init__(self, device): model = MobileV2_MLSD_Large() model.load_state_dict(torch.load(model_path), strict=True) self.model = model.to(device).eval() + self.device = device def __call__(self, input_image, thr_v, thr_d): assert input_image.ndim == 3 @@ -30,7 +31,7 @@ def __call__(self, input_image, thr_v, thr_d): img_output = np.zeros_like(img) try: with torch.no_grad(): - lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) + lines = pred_lines(img, self.model, self.device, [img.shape[0], img.shape[1]], thr_v, thr_d) for line in lines: x_start, y_start, x_end, y_end = [int(val) for val in line] cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) diff --git a/annotator/mlsd/utils.py b/annotator/mlsd/utils.py index ae3cf9420a..b3ecdfb8b8 100644 --- a/annotator/mlsd/utils.py +++ b/annotator/mlsd/utils.py @@ -44,7 +44,7 @@ def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): return ptss, scores, displacement -def pred_lines(image, model, +def pred_lines(image, model, device, input_shape=[512, 512], score_thr=0.10, dist_thr=20.0): @@ -58,7 +58,7 @@ def pred_lines(image, model, batch_image = np.expand_dims(resized_image, axis=0).astype('float32') batch_image = (batch_image / 127.5) - 1.0 - batch_image = torch.from_numpy(batch_image).float().cuda() + batch_image = torch.from_numpy(batch_image).float().to(device) outputs = model(batch_image) pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) start = vmap[:, :, :2] diff --git a/annotator/uniformer/__init__.py b/annotator/uniformer/__init__.py index 6be429542e..d1f2d00bca 100644 --- 
a/annotator/uniformer/__init__.py +++ b/annotator/uniformer/__init__.py @@ -9,13 +9,13 @@ class UniformerDetector: - def __init__(self): + def __init__(self, device): modelpath = os.path.join(annotator_ckpts_path, "upernet_global_small.pth") if not os.path.exists(modelpath): from basicsr.utils.download_util import load_file_from_url load_file_from_url(checkpoint_file, model_dir=annotator_ckpts_path) config_file = os.path.join(os.path.dirname(annotator_ckpts_path), "uniformer", "exp", "upernet_global_small", "config.py") - self.model = init_segmentor(config_file, modelpath).cuda() + self.model = init_segmentor(config_file, modelpath).to(device) def __call__(self, img): result = inference_segmentor(self.model, img) diff --git a/gradio_annotator.py b/gradio_annotator.py index 2b1a29ebbe..2d33b52919 100644 --- a/gradio_annotator.py +++ b/gradio_annotator.py @@ -1,9 +1,20 @@ import gradio as gr +import torch from annotator.util import resize_image, HWC3 +def get_device(): + if torch.cuda.is_available(): + return 'cuda' + elif torch.backends.mps.is_available(): + return 'mps' + else: + return 'cpu' + + model_canny = None +device = get_device() def canny(img, res, l, h): @@ -24,7 +35,7 @@ def hed(img, res): global model_hed if model_hed is None: from annotator.hed import HEDdetector - model_hed = HEDdetector() + model_hed = HEDdetector(device) result = model_hed(img) return [result] @@ -37,7 +48,7 @@ def mlsd(img, res, thr_v, thr_d): global model_mlsd if model_mlsd is None: from annotator.mlsd import MLSDdetector - model_mlsd = MLSDdetector() + model_mlsd = MLSDdetector(device) result = model_mlsd(img, thr_v, thr_d) return [result] @@ -50,7 +61,7 @@ def midas(img, res, a): global model_midas if model_midas is None: from annotator.midas import MidasDetector - model_midas = MidasDetector() + model_midas = MidasDetector(device) results = model_midas(img, a) return results @@ -76,7 +87,7 @@ def uniformer(img, res): global model_uniformer if model_uniformer is None: from annotator.uniformer import UniformerDetector - model_uniformer = UniformerDetector() + model_uniformer = UniformerDetector(device) result = model_uniformer(img) return [result]
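
The common thread across the seven patches is a single device-selection step that replaces the hard-coded `.cuda()` calls: pick `cuda`, `mps`, or `cpu` once, pass it to `load_state_dict`, `model.to(device)`, the samplers, and the annotators, and cast buffers to float32 on MPS (which lacks float64 kernels). Below is a minimal, self-contained sketch of that pattern; the `controlnet_device.py` module name is hypothetical — the patches instead copy an equivalent `get_device()` into each `gradio_*.py` script — and the snippet only restates what the diffs above already do.

```python
# controlnet_device.py -- hypothetical shared helper summarizing the pattern
# the patch series applies in every gradio_*.py entry script.
import torch


def get_device() -> str:
    """Prefer CUDA, fall back to Apple's MPS backend, then to CPU."""
    if torch.cuda.is_available():
        return 'cuda'
    elif torch.backends.mps.is_available():
        return 'mps'
    else:
        return 'cpu'


def to_device(tensor: torch.Tensor, device) -> torch.Tensor:
    """Move a tensor or buffer to `device`; MPS has no float64, so force float32 there."""
    if str(device) == 'mps':
        return tensor.to(device, torch.float32)
    return tensor.to(device)
```

As the README changes note, running the apps on Apple silicon additionally requires `PYTORCH_ENABLE_MPS_FALLBACK=1`, so that operators without MPS kernels fall back to the CPU instead of raising.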