diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index d77fb5b2871f..458e52766c82 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -17,6 +17,7 @@
 from accelerate.utils import set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, whoami
 from PIL import Image
 from torchvision import transforms
@@ -522,6 +523,15 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
 
+    if is_xformers_available():
+        try:
+            unet.enable_xformers_memory_efficient_attention()
+        except Exception as e:
+            logger.warning(
+                "Could not enable memory efficient attention. Make sure xformers is installed"
+                f" correctly and a GPU is available: {e}"
+            )
+
     vae.requires_grad_(False)
     if not args.train_text_encoder:
         text_encoder.requires_grad_(False)
diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 704ba00cad29..df394425b815 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -178,6 +178,38 @@ def disable_gradient_checkpointing(self):
         """
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
+
+
+    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+        # Recursively walk through all the children.
+        # Any child that exposes the set_use_memory_efficient_attention_xformers method
+        # receives the message.
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_mem_eff(module)
+
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+        When this option is enabled, you should observe lower GPU memory usage and a potential speedup at inference
+        time. A speedup at training time is not guaranteed.
+        Warning: when memory efficient attention and sliced attention are both enabled, memory efficient attention
+        takes precedence.
+        """
+        self.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.set_use_memory_efficient_attention_xformers(False)
 
     def save_pretrained(
         self,
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 7e759e987fb4..052d2643db9d 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -710,3 +710,37 @@ def progress_bar(self, iterable):
 
     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs
+
+
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+        When this option is enabled, you should observe lower GPU memory usage and a potential speedup at inference
+        time. A speedup at training time is not guaranteed.
+        Warning: when memory efficient attention and sliced attention are both enabled, memory efficient attention
+        takes precedence.
+        """
+        self.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.set_use_memory_efficient_attention_xformers(False)
+
+    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+        # Recursively walk through all the children.
+        # Any child that exposes the set_use_memory_efficient_attention_xformers method
+        # receives the message.
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        module_names, _, _ = self.extract_init_dict(dict(self.config))
+        for module_name in module_names:
+            module = getattr(self, module_name)
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_mem_eff(module)
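
Usage note for reviewers: a minimal sketch of how the new toggles compose, assuming xformers is installed and a CUDA GPU is available. Only the enable/disable method names come from this diff; the checkpoint id and prompt are illustrative placeholders.

# Hypothetical usage sketch (not part of the diff). The checkpoint name
# and prompt are placeholders, not prescribed by this PR.
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

if is_xformers_available():
    # Pipeline-level toggle: set_use_memory_efficient_attention_xformers walks
    # every nn.Module registered in the pipeline config (unet, vae, ...).
    pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of an astronaut riding a horse").images[0]

# The model-level toggle also exists, since ModelMixin gains the same
# recursive walk; e.g. turn the option back off for just the UNet:
pipe.unet.disable_xformers_memory_efficient_attention()

One design point worth noting: fn_recursive_set_mem_eff both invokes the hook and keeps descending into children, so nested attention blocks are reached even when a parent module also implements the hook.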