Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Sanster committed Nov 14, 2024
1 parent 668733c commit 14b334a
Showing 1 changed file with 42 additions and 38 deletions.
80 changes: 42 additions & 38 deletions iopaint/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_sd_model_type(model_abs_path: str) -> Optional[ModelType]:
model_abs_path,
load_safety_checker=False,
num_in_channels=9,
original_config_file=get_config_files()['v1']
original_config_file=get_config_files()["v1"],
)
model_type = ModelType.DIFFUSERS_SD_INPAINT
except ValueError as e:
Expand Down Expand Up @@ -89,7 +89,7 @@ def get_sdxl_model_type(model_abs_path: str) -> Optional[ModelType]:
model_abs_path,
load_safety_checker=False,
num_in_channels=9,
original_config_file=get_config_files()['xl'],
original_config_file=get_config_files()["xl"],
)
if model.unet.config.in_channels == 9:
# https://github.com/huggingface/diffusers/issues/6610
Expand Down Expand Up @@ -207,55 +207,59 @@ def scan_diffusers_models() -> List[ModelInfo]:
cache_dir = Path(HF_HUB_CACHE)
# logger.info(f"Scanning diffusers models in {cache_dir}")
diffusers_model_names = []
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
model_index_files = glob.glob(
os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True
)
for it in model_index_files:
it = Path(it)
with open(it, "r", encoding="utf-8") as f:
try:
try:
with open(it, "r", encoding="utf-8") as f:
data = json.load(f)
except:
continue
except:
continue

_class_name = data["_class_name"]
name = folder_name_to_show_name(it.parent.parent.parent.name)
if name in diffusers_model_names:
continue
if "PowerPaint" in name:
model_type = ModelType.DIFFUSERS_OTHER
elif _class_name == DIFFUSERS_SD_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD
elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD_INPAINT
elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL
elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
elif _class_name in [
"StableDiffusionInstructPix2PixPipeline",
"PaintByExamplePipeline",
"KandinskyV22InpaintPipeline",
"AnyText",
]:
model_type = ModelType.DIFFUSERS_OTHER
else:
continue
_class_name = data["_class_name"]
name = folder_name_to_show_name(it.parent.parent.parent.name)
if name in diffusers_model_names:
continue
if "PowerPaint" in name:
model_type = ModelType.DIFFUSERS_OTHER
elif _class_name == DIFFUSERS_SD_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD
elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD_INPAINT
elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL
elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
elif _class_name in [
"StableDiffusionInstructPix2PixPipeline",
"PaintByExamplePipeline",
"KandinskyV22InpaintPipeline",
"AnyText",
]:
model_type = ModelType.DIFFUSERS_OTHER
else:
continue

diffusers_model_names.append(name)
available_models.append(
ModelInfo(
name=name,
path=name,
model_type=model_type,
)
diffusers_model_names.append(name)
available_models.append(
ModelInfo(
name=name,
path=name,
model_type=model_type,
)
)
return available_models


def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
cache_dir = Path(cache_dir)
available_models = []
diffusers_model_names = []
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
model_index_files = glob.glob(
os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True
)
for it in model_index_files:
it = Path(it)
with open(it, "r", encoding="utf-8") as f:
Expand Down

0 comments on commit 14b334a

Please sign in to comment.