diff --git a/iopaint/download.py b/iopaint/download.py index b2ed07da..7908d534 100644 --- a/iopaint/download.py +++ b/iopaint/download.py @@ -61,7 +61,7 @@ def get_sd_model_type(model_abs_path: str) -> Optional[ModelType]: model_abs_path, load_safety_checker=False, num_in_channels=9, - original_config_file=get_config_files()['v1'] + original_config_file=get_config_files()["v1"], ) model_type = ModelType.DIFFUSERS_SD_INPAINT except ValueError as e: @@ -89,7 +89,7 @@ def get_sdxl_model_type(model_abs_path: str) -> Optional[ModelType]: model_abs_path, load_safety_checker=False, num_in_channels=9, - original_config_file=get_config_files()['xl'], + original_config_file=get_config_files()["xl"], ) if model.unet.config.in_channels == 9: # https://github.com/huggingface/diffusers/issues/6610 @@ -207,47 +207,49 @@ def scan_diffusers_models() -> List[ModelInfo]: cache_dir = Path(HF_HUB_CACHE) # logger.info(f"Scanning diffusers models in {cache_dir}") diffusers_model_names = [] - model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True) + model_index_files = glob.glob( + os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True + ) for it in model_index_files: it = Path(it) - with open(it, "r", encoding="utf-8") as f: - try: + try: + with open(it, "r", encoding="utf-8") as f: data = json.load(f) - except: - continue + except: + continue - _class_name = data["_class_name"] - name = folder_name_to_show_name(it.parent.parent.parent.name) - if name in diffusers_model_names: - continue - if "PowerPaint" in name: - model_type = ModelType.DIFFUSERS_OTHER - elif _class_name == DIFFUSERS_SD_CLASS_NAME: - model_type = ModelType.DIFFUSERS_SD - elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME: - model_type = ModelType.DIFFUSERS_SD_INPAINT - elif _class_name == DIFFUSERS_SDXL_CLASS_NAME: - model_type = ModelType.DIFFUSERS_SDXL - elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME: - model_type = ModelType.DIFFUSERS_SDXL_INPAINT - elif _class_name in [ - "StableDiffusionInstructPix2PixPipeline", - "PaintByExamplePipeline", - "KandinskyV22InpaintPipeline", - "AnyText", - ]: - model_type = ModelType.DIFFUSERS_OTHER - else: - continue + _class_name = data["_class_name"] + name = folder_name_to_show_name(it.parent.parent.parent.name) + if name in diffusers_model_names: + continue + if "PowerPaint" in name: + model_type = ModelType.DIFFUSERS_OTHER + elif _class_name == DIFFUSERS_SD_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SD + elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SD_INPAINT + elif _class_name == DIFFUSERS_SDXL_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SDXL + elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME: + model_type = ModelType.DIFFUSERS_SDXL_INPAINT + elif _class_name in [ + "StableDiffusionInstructPix2PixPipeline", + "PaintByExamplePipeline", + "KandinskyV22InpaintPipeline", + "AnyText", + ]: + model_type = ModelType.DIFFUSERS_OTHER + else: + continue - diffusers_model_names.append(name) - available_models.append( - ModelInfo( - name=name, - path=name, - model_type=model_type, - ) + diffusers_model_names.append(name) + available_models.append( + ModelInfo( + name=name, + path=name, + model_type=model_type, ) + ) return available_models @@ -255,7 +257,9 @@ def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]: cache_dir = Path(cache_dir) available_models = [] diffusers_model_names = [] - model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True) + model_index_files = glob.glob( + os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True + ) for it in model_index_files: it = Path(it) with open(it, "r", encoding="utf-8") as f: