.load_ip_adapter in StableDiffusionXLAdapterPipeline (huggingface#6246)
* Added testing notebook and .load_ip_adapter to XLAdapterPipeline

* Added annotations

* deleted testing notebook

* Update src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py

Co-authored-by: YiYi Xu <[email protected]>

* code clean up

* Add feature_extractor and image_encoder to components

---------

Co-authored-by: YiYi Xu <[email protected]>
jquintanilla4 and yiyixuxu authored Jan 11, 2024
1 parent: 17cece0 · commit: da843b3
Showing 2 changed files with 84 additions and 17 deletions.
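In short: the SDXL T2I-Adapter pipeline gains the IPAdapterMixin, a feature_extractor/image_encoder pair among its components, and an ip_adapter_image argument in __call__, so T2I-Adapter structural control can be combined with IP-Adapter image prompting. A minimal usage sketch under stated assumptions (the checkpoint names below are illustrative examples of the expected layout, and the image URLs are placeholders):

    import torch
    from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
    from diffusers.utils import load_image
    from transformers import CLIPVisionModelWithProjection

    # Assumed checkpoints: any SDXL base model, SDXL T2I-Adapter, and an
    # SDXL IP-Adapter repo with the usual sdxl_models/ layout should work.
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=torch.float16
    )
    adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16)
    pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        adapter=adapter,
        image_encoder=image_encoder,
        torch_dtype=torch.float16,
    ).to("cuda")

    # The method this commit makes available on the pipeline:
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

    canny_image = load_image("https://example.com/canny.png")  # structural control (placeholder URL)
    style_image = load_image("https://example.com/style.png")  # image prompt (placeholder URL)

    result = pipe(
        prompt="a photo of a cat",
        image=canny_image,             # consumed by the T2I-Adapter
        ip_adapter_image=style_image,  # consumed by the IP-Adapter (new argument)
        num_inference_steps=30,
    ).images[0]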
src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -18,11 +18,22 @@
 import numpy as np
 import PIL.Image
 import torch
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
-from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -169,7 +180,11 @@ def retrieve_timesteps(
 
 
 class StableDiffusionXLAdapterPipeline(
-    DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+    DiffusionPipeline,
+    TextualInversionLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    IPAdapterMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
@@ -183,6 +198,7 @@ class StableDiffusionXLAdapterPipeline(
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
@@ -211,8 +227,15 @@ class StableDiffusionXLAdapterPipeline(
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
 
     def __init__(
         self,
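Note the offload order: image_encoder now sits between the text encoders and the UNet in model_cpu_offload_seq, matching the order in which the components run. A sketch of that path, reusing the objects from the usage example above (construct the pipeline without the manual .to("cuda") in this case):

    # Sequential model CPU offload keeps only the active component on the
    # GPU; the image encoder is now part of that rotation.
    pipe.enable_model_cpu_offload()
    result = pipe(
        prompt="a photo of a cat",
        image=canny_image,
        ip_adapter_image=style_image,
    ).images[0]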
@@ -225,6 +248,8 @@ def __init__(
         adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
         scheduler: KarrasDiffusionSchedulers,
         force_zeros_for_empty_prompt: bool = True,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
 
@@ -237,6 +262,8 @@ def __init__(
             unet=unet,
             adapter=adapter,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@@ -511,6 +538,31 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
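The two branches of encode_image serve different IP-Adapter projection heads: the else path returns pooled CLIP image_embeds (with a zero tensor as the unconditional half) for the plain ImageProjection head, while the hidden-states path returns penultimate hidden states, obtaining the unconditional half by encoding a zero image, for resampler-style heads. A rough sketch of the calling convention, reusing the pipeline from the usage example above (shape names are illustrative; exact sizes depend on the CLIP checkpoint):

    embeds, uncond = pipe.encode_image(
        style_image, "cuda", num_images_per_prompt=2, output_hidden_states=False
    )
    # embeds: (2, projection_dim) pooled CLIP embeddings; uncond is all zeros.

    hidden, uncond_hidden = pipe.encode_image(
        style_image, "cuda", num_images_per_prompt=2, output_hidden_states=True
    )
    # hidden: (2, num_tokens, hidden_size) penultimate hidden states; here
    # the unconditional branch encodes a zero image instead of zeroing the
    # output.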
@@ -768,7 +820,7 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
-        image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+        image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -785,6 +837,7 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -876,6 +929,7 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -991,7 +1045,7 @@ def __call__(
 
         device = self._execution_device
 
-        # 3. Encode input prompt
+        # 3.1 Encode input prompt
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -1012,6 +1066,15 @@ def __call__(
             clip_skip=clip_skip,
         )
 
+        # 3.2 Encode ip_adapter_image
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
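As with the prompt embeddings, the unconditional image embeddings are concatenated first so they align with the negative half of the duplicated latent batch under classifier-free guidance. A toy illustration of the resulting batch layout:

    import torch

    image_embeds = torch.ones(2, 4)  # conditional embeds for a batch of 2
    negative_image_embeds = torch.zeros_like(image_embeds)

    # Unconditional half first, matching the [negative, positive] prompt order.
    cfg_embeds = torch.cat([negative_image_embeds, image_embeds])
    assert cfg_embeds.shape == (4, 4)  # rows: [uncond, uncond, cond, cond]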

@@ -1028,10 +1091,10 @@ def __call__(
             latents,
         )
 
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 6.1 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 6.5 Optionally get Guidance Scale Embedding
+        # 6.2 Optionally get Guidance Scale Embedding
         timestep_cond = None
         if self.unet.config.time_cond_proj_dim is not None:
             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
@@ -1090,8 +1153,7 @@ def __call__(
 
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
-        # 7.1 Apply denoising_end
+        # Apply denoising_end
         if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
             discrete_timestep_cutoff = int(
                 round(
@@ -1109,9 +1171,12 @@ def __call__(
 
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
+                # predict the noise residual
                 if i < int(num_inference_steps * adapter_conditioning_factor):
                     down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
                 else:
@@ -1123,9 +1188,9 @@ def __call__(
                     encoder_hidden_states=prompt_embeds,
                     timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
-                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
+                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                 )[0]
 
                 # perform guidance
Second changed file (StableDiffusionXLAdapterPipeline tests; path not shown in the capture)
@@ -159,7 +159,8 @@ def get_dummy_components(self, adapter_type="full_adapter_xl", time_cond_proj_di
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
             # "safety_checker": None,
-            # "feature_extractor": None,
+            "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
@@ -265,7 +266,8 @@ def get_dummy_components_with_full_downscaling(self, adapter_type="full_adapter_
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
             # "safety_checker": None,
-            # "feature_extractor": None,
+            "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
 
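With both names in _optional_components, the dummy pipelines still instantiate when the new entries are None. A hypothetical check of that path inside one of these test classes (not an assertion this commit adds):

    # Both optional components are passed as None and stay None; the
    # pipeline constructs because the names appear in _optional_components.
    components = self.get_dummy_components()
    pipe = StableDiffusionXLAdapterPipeline(**components)
    assert pipe.feature_extractor is None and pipe.image_encoder is None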
