Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Sanster committed Dec 22, 2023
1 parent eb97641 commit 61d5628
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 15 deletions.
3 changes: 3 additions & 0 deletions lama_cleaner/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,9 @@ def _scaled_pad_forward(self, image, mask, config: Config):
def set_scheduler(self, config: Config):
scheduler_config = self.model.scheduler.config
sd_sampler = config.sd_sampler
if config.sd_lcm_lora:
sd_sampler = SDSampler.lcm
logger.info(f"LCM Lora enabled, use {sd_sampler} sampler")
scheduler = get_scheduler(sd_sampler, scheduler_config)
self.model.scheduler = scheduler

Expand Down
10 changes: 9 additions & 1 deletion lama_cleaner/model/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import cv2
import numpy as np
import torch
from diffusers import ControlNetModel
from diffusers import ControlNetModel, DiffusionPipeline
from loguru import logger

from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
Expand Down Expand Up @@ -69,6 +69,7 @@ def init_model(self, device: torch.device, **kwargs):

use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
self.torch_dtype = torch_dtype

if model_info.model_type in [
ModelType.DIFFUSERS_SD,
Expand Down Expand Up @@ -131,6 +132,13 @@ def init_model(self, device: torch.device, **kwargs):

self.callback = kwargs.pop("callback", None)

def switch_controlnet_method(self, new_method: str):
self.sd_controlnet_method = new_method
controlnet = ControlNetModel.from_pretrained(
new_method, torch_dtype=self.torch_dtype, resume_download=True
).to(self.model.device)
self.model.controlnet = controlnet

def _get_control_image(self, image, mask):
if "canny" in self.sd_controlnet_method:
control_image = make_canny_control_image(image)
Expand Down
19 changes: 10 additions & 9 deletions lama_cleaner/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,18 @@ def switch_controlnet_method(self, config):
if not self.available_models[self.name].support_controlnet:
return

if self.sd_controlnet != config.controlnet_enabled or (
self.sd_controlnet and self.sd_controlnet_method != config.controlnet_method
if (
self.sd_controlnet
and config.controlnet_method
and self.sd_controlnet_method != config.controlnet_method
):
# 可能关闭/开启 controlnet
# 可能开启了 controlnet,切换 controlnet 的方法
old_sd_controlnet = self.sd_controlnet
old_sd_controlnet_method = self.sd_controlnet_method
self.sd_controlnet_method = config.controlnet_method
self.model.switch_controlnet_method(config.controlnet_method)
logger.info(
f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}"
)
elif self.sd_controlnet != config.controlnet_enabled:
self.sd_controlnet = config.controlnet_enabled
self.sd_controlnet_method = config.controlnet_method

Expand All @@ -120,10 +125,6 @@ def switch_controlnet_method(self, config):
)
if not config.controlnet_enabled:
logger.info(f"Disable controlnet")
elif old_sd_controlnet_method != config.controlnet_method:
logger.info(
f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}"
)
else:
logger.info(f"Enable controlnet: {config.controlnet_method}")

Expand Down
7 changes: 7 additions & 0 deletions lama_cleaner/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,13 @@ def process():
croper_y=form["croperY"],
croper_height=form["croperHeight"],
croper_width=form["croperWidth"],

use_extender=form["useExtender"],
extender_x=form["extenderX"],
extender_y=form["extenderY"],
extender_height=form["extenderHeight"],
extender_width=form["extenderWidth"],

sd_scale=form["sdScale"],
sd_mask_blur=form["sdMaskBlur"],
sd_strength=form["sdStrength"],
Expand Down
9 changes: 5 additions & 4 deletions lama_cleaner/tests/test_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
save_dir.mkdir(exist_ok=True, parents=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
model_name = "runwayml/stable-diffusion-inpainting"


@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
Expand All @@ -35,7 +36,7 @@ def test_runway_sd_1_5(

sd_steps = 1 if sd_device == "cpu" else 30
model = ModelManager(
name="sd1.5",
name=model_name,
sd_controlnet=True,
device=torch.device(sd_device),
hf_access_token="",
Expand Down Expand Up @@ -83,7 +84,7 @@ def test_local_file_path(sd_device, sampler):

sd_steps = 1 if sd_device == "cpu" else 30
model = ModelManager(
name="sd1.5",
name=model_name,
sd_controlnet=True,
device=torch.device(sd_device),
hf_access_token="",
Expand Down Expand Up @@ -121,7 +122,7 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):

sd_steps = 1 if sd_device == "cpu" else 30
model = ModelManager(
name="sd1.5",
name=model_name,
sd_controlnet=True,
device=torch.device(sd_device),
hf_access_token="",
Expand Down Expand Up @@ -162,7 +163,7 @@ def test_controlnet_switch(sd_device, sampler):

sd_steps = 1 if sd_device == "cpu" else 30
model = ModelManager(
name="sd1.5",
name=model_name,
sd_controlnet=True,
device=torch.device(sd_device),
hf_access_token="",
Expand Down
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
wheel
twine
twine
pytest-loguru

0 comments on commit 61d5628

Please sign in to comment.