From b2bb359d4730136c8c7328d9f3db4c1cc02fee5f Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 16 Jan 2025 02:30:28 +0000
Subject: [PATCH] Update the model loading logic for several of the large
 FLUX-related models to ensure that the model is initialized on the meta
 device prior to loading the state dict into it. This helps to keep peak
 memory down.

---
 .../model_manager/load/model_loaders/flux.py | 55 ++++++++++---------
 1 file changed, 29 insertions(+), 26 deletions(-)
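Note on the pattern used throughout this patch: constructing the module under
accelerate.init_empty_weights() puts every parameter on the meta device, so no
host RAM is allocated at instantiation time, and load_state_dict(..., assign=True)
then adopts the checkpoint tensors directly instead of copying them into
preallocated storage. With the old SilenceWarnings() path, the weights existed
twice at the peak (the freshly initialized module plus the loaded state dict);
with the meta-device path, the state dict is the only full copy ever
materialized. A minimal, self-contained sketch of the idea (TinyModel is a
hypothetical stand-in for AutoEncoder/Flux; assumes torch >= 2.1 for
assign=True, plus the accelerate and safetensors packages):

    import accelerate
    import torch
    from safetensors.torch import load_file, save_file


    class TinyModel(torch.nn.Module):
        """Hypothetical stand-in for a large model such as AutoEncoder or Flux."""

        def __init__(self) -> None:
            super().__init__()
            self.proj = torch.nn.Linear(16, 16)


    # Write a small demo checkpoint so the sketch is runnable end-to-end.
    save_file(TinyModel().state_dict(), "demo.safetensors")

    # 1. Instantiate on the meta device: parameters are shape/dtype
    #    placeholders and consume no host RAM.
    with accelerate.init_empty_weights():
        model = TinyModel()

    # 2. Read the checkpoint into an ordinary CPU state dict.
    sd = load_file("demo.safetensors")

    # 3. assign=True swaps the loaded tensors into the module in place of
    #    the meta parameters, rather than copying them into preallocated
    #    storage, so the state dict is the only full copy of the weights
    #    in memory at the peak.
    model.load_state_dict(sd, assign=True)
    assert model.proj.weight.device.type == "cpu"
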
diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py
index edf14ec48cc..d44cc014431 100644
--- a/invokeai/backend/model_manager/load/model_loaders/flux.py
+++ b/invokeai/backend/model_manager/load/model_loaders/flux.py
@@ -80,19 +80,19 @@ def _load_model(
             raise ValueError("Only VAECheckpointConfig models are currently supported here.")
         model_path = Path(config.path)
 
-        with SilenceWarnings():
+        with accelerate.init_empty_weights():
             model = AutoEncoder(ae_params[config.config_path])
-            sd = load_file(model_path)
-            model.load_state_dict(sd, assign=True)
-            # VAE is broken in float16, which mps defaults to
-            if self._torch_dtype == torch.float16:
-                try:
-                    vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
-                except TypeError:
-                    vae_dtype = torch.float32
-            else:
-                vae_dtype = self._torch_dtype
-            model.to(vae_dtype)
+        sd = load_file(model_path)
+        model.load_state_dict(sd, assign=True)
+        # VAE is broken in float16, which mps defaults to
+        if self._torch_dtype == torch.float16:
+            try:
+                vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
+            except TypeError:
+                vae_dtype = torch.float32
+        else:
+            vae_dtype = self._torch_dtype
+        model.to(vae_dtype)
 
         return model
 
@@ -183,7 +183,9 @@ def _load_model(
             case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
                 return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
             case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
-                return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
+                return T5EncoderModel.from_pretrained(
+                    Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True
+                )
 
         raise ValueError(
             f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -217,17 +219,18 @@ def _load_from_singlefile(
         assert isinstance(config, MainCheckpointConfig)
         model_path = Path(config.path)
 
-        with SilenceWarnings():
+        with accelerate.init_empty_weights():
             model = Flux(params[config.config_path])
-            sd = load_file(model_path)
-            if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
-                sd = convert_bundle_to_flux_transformer_checkpoint(sd)
-            new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
-            self._ram_cache.make_room(new_sd_size)
-            for k in sd.keys():
-                # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
-                sd[k] = sd[k].to(torch.bfloat16)
-            model.load_state_dict(sd, assign=True)
+
+        sd = load_file(model_path)
+        if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
+            sd = convert_bundle_to_flux_transformer_checkpoint(sd)
+        new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
+        self._ram_cache.make_room(new_sd_size)
+        for k in sd.keys():
+            # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
+            sd[k] = sd[k].to(torch.bfloat16)
+        model.load_state_dict(sd, assign=True)
 
         return model
 
@@ -258,11 +261,11 @@ def _load_from_singlefile(
         assert isinstance(config, MainGGUFCheckpointConfig)
         model_path = Path(config.path)
 
-        with SilenceWarnings():
+        with accelerate.init_empty_weights():
             model = Flux(params[config.config_path])
-            # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
-            sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
+        # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
+        sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
 
         # HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight.
         # We override the shape here to fix the issue.
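
One footnote on the @@ -217 hunk above: make_room() (InvokeAI's RAM cache API)
is handed the post-cast size of the state dict before the bfloat16 conversion
begins, computed from torch.dtype.itemsize. A rough sketch of that accounting
(bf16_state_dict_nbytes is a hypothetical helper; torch.dtype.itemsize assumes
torch >= 2.1, which the hunk itself already relies on):

    import torch


    def bf16_state_dict_nbytes(sd: dict[str, torch.Tensor]) -> int:
        # Bytes the state dict will occupy once every tensor is bfloat16.
        return sum(t.nelement() * torch.bfloat16.itemsize for t in sd.values())


    sd = {"w": torch.zeros(1024, 1024, dtype=torch.float32)}
    print(bf16_state_dict_nbytes(sd))  # 2097152 bytes (2 MiB), vs 4 MiB in float32

    # Casting tensor-by-tensor, as the hunk does, frees each float32 tensor
    # as soon as it is rebound, so the whole state dict is never duplicated
    # at full precision during the conversion.
    for k in sd.keys():
        sd[k] = sd[k].to(torch.bfloat16)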