Skip to content

Commit 0bb8869

Browse files
authored
Merge pull request #191 from huggingface/main
Merge changes
2 parents 127fe9b + 069186f commit 0bb8869

15 files changed

+2546
-215
lines changed

docs/source/en/api/pipelines/cogvideox.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,17 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
3030
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
3131

3232
There are three official CogVideoX checkpoints for text-to-video and video-to-video.
33+
3334
| checkpoints | recommended inference dtype |
34-
|---|---|
35+
|:---:|:---:|
3536
| [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 |
3637
| [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 |
3738
| [`THUDM/CogVideoX1.5-5b`](https://huggingface.co/THUDM/CogVideoX1.5-5b) | torch.bfloat16 |
3839

3940
There are two official CogVideoX checkpoints available for image-to-video.
41+
4042
| checkpoints | recommended inference dtype |
41-
|---|---|
43+
|:---:|:---:|
4244
| [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 |
4345
| [`THUDM/CogVideoX1.5-5B-I2V`](https://huggingface.co/THUDM/CogVideoX1.5-5B-I2V) | torch.bfloat16 |
4446

@@ -48,8 +50,9 @@ For the CogVideoX 1.5 series:
4850
- Both T2V and I2V models support generation with 81 and 161 frames and work best at these values. Exporting videos at 16 FPS is recommended.
4951

5052
There are two official CogVideoX checkpoints that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team).
53+
5154
| checkpoints | recommended inference dtype |
52-
|---|---|
55+
|:---:|:---:|
5356
| [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 |
5457
| [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 |
5558

docs/source/en/api/pipelines/flux.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,15 @@ image.save("flux-fp8-dev.png")
333333
[[autodoc]] FluxControlImg2ImgPipeline
334334
- all
335335
- __call__
336+
337+
## FluxPriorReduxPipeline
338+
339+
[[autodoc]] FluxPriorReduxPipeline
340+
- all
341+
- __call__
342+
343+
## FluxFillPipeline
344+
345+
[[autodoc]] FluxFillPipeline
346+
- all
347+
- __call__

examples/community/README.md

Lines changed: 114 additions & 14 deletions
Large diffs are not rendered by default.

examples/community/regional_prompting_stable_diffusion.py

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,20 @@
33

44
import torch
55
import torchvision.transforms.functional as FF
6-
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
6+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
77

88
from diffusers import StableDiffusionPipeline
99
from diffusers.models import AutoencoderKL, UNet2DConditionModel
1010
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
1111
from diffusers.schedulers import KarrasDiffusionSchedulers
12-
from diffusers.utils import USE_PEFT_BACKEND
1312

1413

1514
try:
1615
from compel import Compel
1716
except ImportError:
1817
Compel = None
1918

19+
KBASE = "ADDBASE"
2020
KCOMM = "ADDCOMM"
2121
KBRK = "BREAK"
2222

@@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
3434
3535
Optional
3636
rp_args["save_mask"]: True/False (save masks in prompt mode)
37+
rp_args["power"]: int (power for attention maps in prompt mode)
38+
rp_args["base_ratio"]:
39+
float (Sets the ratio of the base prompt)
40+
e.g. 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT)
41+
[Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)
3742
3843
Pipeline for text-to-image generation using Stable Diffusion.
3944
@@ -70,6 +75,7 @@ def __init__(
7075
scheduler: KarrasDiffusionSchedulers,
7176
safety_checker: StableDiffusionSafetyChecker,
7277
feature_extractor: CLIPImageProcessor,
78+
image_encoder: CLIPVisionModelWithProjection = None,
7379
requires_safety_checker: bool = True,
7480
):
7581
super().__init__(
@@ -80,6 +86,7 @@ def __init__(
8086
scheduler,
8187
safety_checker,
8288
feature_extractor,
89+
image_encoder,
8390
requires_safety_checker,
8491
)
8592
self.register_modules(
@@ -90,6 +97,7 @@ def __init__(
9097
scheduler=scheduler,
9198
safety_checker=safety_checker,
9299
feature_extractor=feature_extractor,
100+
image_encoder=image_encoder,
93101
)
94102

95103
@torch.no_grad()
@@ -110,17 +118,40 @@ def __call__(
110118
rp_args: Dict[str, str] = None,
111119
):
112120
active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
121+
use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt
113122
if negative_prompt is None:
114123
negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)
115124

116125
device = self._execution_device
117126
regions = 0
118127

128+
self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0
119129
self.power = int(rp_args["power"]) if "power" in rp_args else 1
120130

121131
prompts = prompt if isinstance(prompt, list) else [prompt]
122-
n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
132+
n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
123133
self.batch = batch = num_images_per_prompt * len(prompts)
134+
135+
if use_base:
136+
bases = prompts.copy()
137+
n_bases = n_prompts.copy()
138+
139+
for i, prompt in enumerate(prompts):
140+
parts = prompt.split(KBASE)
141+
if len(parts) == 2:
142+
bases[i], prompts[i] = parts
143+
elif len(parts) > 2:
144+
raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}")
145+
for i, prompt in enumerate(n_prompts):
146+
n_parts = prompt.split(KBASE)
147+
if len(n_parts) == 2:
148+
n_bases[i], n_prompts[i] = n_parts
149+
elif len(n_parts) > 2:
150+
raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}")
151+
152+
all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)
153+
all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)
154+
124155
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
125156
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
126157

@@ -137,8 +168,16 @@ def getcompelembs(prps):
137168

138169
conds = getcompelembs(all_prompts_cn)
139170
unconds = getcompelembs(all_n_prompts_cn)
140-
embs = getcompelembs(prompts)
141-
n_embs = getcompelembs(n_prompts)
171+
base_embs = getcompelembs(all_bases_cn) if use_base else None
172+
base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None
173+
# When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
174+
embs = getcompelembs(prompts) if not use_base else base_embs
175+
n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs
176+
177+
if use_base and self.base_ratio > 0:
178+
conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
179+
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds
180+
142181
prompt = negative_prompt = None
143182
else:
144183
conds = self.encode_prompt(prompts, device, 1, True)[0]
@@ -147,6 +186,18 @@ def getcompelembs(prps):
147186
if equal
148187
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
149188
)
189+
190+
if use_base and self.base_ratio > 0:
191+
base_embs = self.encode_prompt(bases, device, 1, True)[0]
192+
base_n_embs = (
193+
self.encode_prompt(n_bases, device, 1, True)[0]
194+
if equal
195+
else self.encode_prompt(all_n_bases_cn, device, 1, True)[0]
196+
)
197+
198+
conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
199+
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds
200+
150201
embs = n_embs = None
151202

152203
if not active:
@@ -225,8 +276,6 @@ def forward(
225276

226277
residual = hidden_states
227278

228-
args = () if USE_PEFT_BACKEND else (scale,)
229-
230279
if attn.spatial_norm is not None:
231280
hidden_states = attn.spatial_norm(hidden_states, temb)
232281

@@ -247,16 +296,15 @@ def forward(
247296
if attn.group_norm is not None:
248297
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
249298

250-
args = () if USE_PEFT_BACKEND else (scale,)
251-
query = attn.to_q(hidden_states, *args)
299+
query = attn.to_q(hidden_states)
252300

253301
if encoder_hidden_states is None:
254302
encoder_hidden_states = hidden_states
255303
elif attn.norm_cross:
256304
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
257305

258-
key = attn.to_k(encoder_hidden_states, *args)
259-
value = attn.to_v(encoder_hidden_states, *args)
306+
key = attn.to_k(encoder_hidden_states)
307+
value = attn.to_v(encoder_hidden_states)
260308

261309
inner_dim = key.shape[-1]
262310
head_dim = inner_dim // attn.heads
@@ -283,7 +331,7 @@ def forward(
283331
hidden_states = hidden_states.to(query.dtype)
284332

285333
# linear proj
286-
hidden_states = attn.to_out[0](hidden_states, *args)
334+
hidden_states = attn.to_out[0](hidden_states)
287335
# dropout
288336
hidden_states = attn.to_out[1](hidden_states)
289337

@@ -410,9 +458,9 @@ def promptsmaker(prompts, batch):
410458
add = ""
411459
if KCOMM in prompt:
412460
add, prompt = prompt.split(KCOMM)
413-
add = add + " "
414-
prompts = prompt.split(KBRK)
415-
out_p.append([add + p for p in prompts])
461+
add = add.strip() + " "
462+
prompts = [p.strip() for p in prompt.split(KBRK)]
463+
out_p.append([add + p for i, p in enumerate(prompts)])
416464
out = [None] * batch * len(out_p[0]) * len(out_p)
417465
for p, prs in enumerate(out_p): # inputs prompts
418466
for r, pr in enumerate(prs): # prompts for regions
@@ -449,7 +497,6 @@ def startend(cells, array):
449497
add = []
450498
startend(add, inratios[1:])
451499
icells.append(add)
452-
453500
return ocells, icells, sum(len(cell) for cell in icells)
454501

455502

0 commit comments

Comments
 (0)