Skip to content

Commit

Permalink
Fix LLMGroundedDiffusionPipeline super class arguments (huggingface#5993
Browse files Browse the repository at this point in the history
)

* make `requires_safety_checker` a kwarg instead of a positional argument as it's more future-proof

* apply `make style` formatting edits

* add image_encoder to arguments and pass to super constructor
  • Loading branch information
KristianMischke authored Nov 30, 2023
1 parent f72b28c commit 141cd52
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions examples/community/llm_grounded_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention import Attention, GatedSelfAttentionDense
Expand Down Expand Up @@ -272,10 +272,19 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)

self.register_attn_hooks(unet)
Expand Down

0 comments on commit 141cd52

Please sign in to comment.