3
3
4
4
import torch
5
5
import torchvision .transforms .functional as FF
6
- from transformers import CLIPImageProcessor , CLIPTextModel , CLIPTokenizer
6
+ from transformers import CLIPImageProcessor , CLIPTextModel , CLIPTokenizer , CLIPVisionModelWithProjection
7
7
8
8
from diffusers import StableDiffusionPipeline
9
9
from diffusers .models import AutoencoderKL , UNet2DConditionModel
10
10
from diffusers .pipelines .stable_diffusion .safety_checker import StableDiffusionSafetyChecker
11
11
from diffusers .schedulers import KarrasDiffusionSchedulers
12
- from diffusers .utils import USE_PEFT_BACKEND
13
12
14
13
15
14
try :
16
15
from compel import Compel
17
16
except ImportError :
18
17
Compel = None
19
18
19
+ KBASE = "ADDBASE"
20
20
KCOMM = "ADDCOMM"
21
21
KBRK = "BREAK"
22
22
@@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
34
34
35
35
Optional
36
36
rp_args["save_mask"]: True/False (save masks in prompt mode)
37
+ rp_args["power"]: int (power for attention maps in prompt mode)
38
+ rp_args["base_ratio"]:
39
+ float (Sets the ratio of the base prompt)
40
+ ex) 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT)
41
+ [Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)
37
42
38
43
Pipeline for text-to-image generation using Stable Diffusion.
39
44
@@ -70,6 +75,7 @@ def __init__(
70
75
scheduler : KarrasDiffusionSchedulers ,
71
76
safety_checker : StableDiffusionSafetyChecker ,
72
77
feature_extractor : CLIPImageProcessor ,
78
+ image_encoder : CLIPVisionModelWithProjection = None ,
73
79
requires_safety_checker : bool = True ,
74
80
):
75
81
super ().__init__ (
@@ -80,6 +86,7 @@ def __init__(
80
86
scheduler ,
81
87
safety_checker ,
82
88
feature_extractor ,
89
+ image_encoder ,
83
90
requires_safety_checker ,
84
91
)
85
92
self .register_modules (
@@ -90,6 +97,7 @@ def __init__(
90
97
scheduler = scheduler ,
91
98
safety_checker = safety_checker ,
92
99
feature_extractor = feature_extractor ,
100
+ image_encoder = image_encoder ,
93
101
)
94
102
95
103
@torch .no_grad ()
@@ -110,17 +118,40 @@ def __call__(
110
118
rp_args : Dict [str , str ] = None ,
111
119
):
112
120
active = KBRK in prompt [0 ] if isinstance (prompt , list ) else KBRK in prompt
121
+ use_base = KBASE in prompt [0 ] if isinstance (prompt , list ) else KBASE in prompt
113
122
if negative_prompt is None :
114
123
negative_prompt = "" if isinstance (prompt , str ) else ["" ] * len (prompt )
115
124
116
125
device = self ._execution_device
117
126
regions = 0
118
127
128
+ self .base_ratio = float (rp_args ["base_ratio" ]) if "base_ratio" in rp_args else 0.0
119
129
self .power = int (rp_args ["power" ]) if "power" in rp_args else 1
120
130
121
131
prompts = prompt if isinstance (prompt , list ) else [prompt ]
122
- n_prompts = negative_prompt if isinstance (prompt , str ) else [negative_prompt ]
132
+ n_prompts = negative_prompt if isinstance (prompt , list ) else [negative_prompt ]
123
133
self .batch = batch = num_images_per_prompt * len (prompts )
134
+
135
+ if use_base :
136
+ bases = prompts .copy ()
137
+ n_bases = n_prompts .copy ()
138
+
139
+ for i , prompt in enumerate (prompts ):
140
+ parts = prompt .split (KBASE )
141
+ if len (parts ) == 2 :
142
+ bases [i ], prompts [i ] = parts
143
+ elif len (parts ) > 2 :
144
+ raise ValueError (f"Multiple instances of { KBASE } found in prompt: { prompt } " )
145
+ for i , prompt in enumerate (n_prompts ):
146
+ n_parts = prompt .split (KBASE )
147
+ if len (n_parts ) == 2 :
148
+ n_bases [i ], n_prompts [i ] = n_parts
149
+ elif len (n_parts ) > 2 :
150
+ raise ValueError (f"Multiple instances of { KBASE } found in negative prompt: { prompt } " )
151
+
152
+ all_bases_cn , _ = promptsmaker (bases , num_images_per_prompt )
153
+ all_n_bases_cn , _ = promptsmaker (n_bases , num_images_per_prompt )
154
+
124
155
all_prompts_cn , all_prompts_p = promptsmaker (prompts , num_images_per_prompt )
125
156
all_n_prompts_cn , _ = promptsmaker (n_prompts , num_images_per_prompt )
126
157
@@ -137,8 +168,16 @@ def getcompelembs(prps):
137
168
138
169
conds = getcompelembs (all_prompts_cn )
139
170
unconds = getcompelembs (all_n_prompts_cn )
140
- embs = getcompelembs (prompts )
141
- n_embs = getcompelembs (n_prompts )
171
+ base_embs = getcompelembs (all_bases_cn ) if use_base else None
172
+ base_n_embs = getcompelembs (all_n_bases_cn ) if use_base else None
173
+ # When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
174
+ embs = getcompelembs (prompts ) if not use_base else base_embs
175
+ n_embs = getcompelembs (n_prompts ) if not use_base else base_n_embs
176
+
177
+ if use_base and self .base_ratio > 0 :
178
+ conds = self .base_ratio * base_embs + (1 - self .base_ratio ) * conds
179
+ unconds = self .base_ratio * base_n_embs + (1 - self .base_ratio ) * unconds
180
+
142
181
prompt = negative_prompt = None
143
182
else :
144
183
conds = self .encode_prompt (prompts , device , 1 , True )[0 ]
@@ -147,6 +186,18 @@ def getcompelembs(prps):
147
186
if equal
148
187
else self .encode_prompt (all_n_prompts_cn , device , 1 , True )[0 ]
149
188
)
189
+
190
+ if use_base and self .base_ratio > 0 :
191
+ base_embs = self .encode_prompt (bases , device , 1 , True )[0 ]
192
+ base_n_embs = (
193
+ self .encode_prompt (n_bases , device , 1 , True )[0 ]
194
+ if equal
195
+ else self .encode_prompt (all_n_bases_cn , device , 1 , True )[0 ]
196
+ )
197
+
198
+ conds = self .base_ratio * base_embs + (1 - self .base_ratio ) * conds
199
+ unconds = self .base_ratio * base_n_embs + (1 - self .base_ratio ) * unconds
200
+
150
201
embs = n_embs = None
151
202
152
203
if not active :
@@ -225,8 +276,6 @@ def forward(
225
276
226
277
residual = hidden_states
227
278
228
- args = () if USE_PEFT_BACKEND else (scale ,)
229
-
230
279
if attn .spatial_norm is not None :
231
280
hidden_states = attn .spatial_norm (hidden_states , temb )
232
281
@@ -247,16 +296,15 @@ def forward(
247
296
if attn .group_norm is not None :
248
297
hidden_states = attn .group_norm (hidden_states .transpose (1 , 2 )).transpose (1 , 2 )
249
298
250
- args = () if USE_PEFT_BACKEND else (scale ,)
251
- query = attn .to_q (hidden_states , * args )
299
+ query = attn .to_q (hidden_states )
252
300
253
301
if encoder_hidden_states is None :
254
302
encoder_hidden_states = hidden_states
255
303
elif attn .norm_cross :
256
304
encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
257
305
258
- key = attn .to_k (encoder_hidden_states , * args )
259
- value = attn .to_v (encoder_hidden_states , * args )
306
+ key = attn .to_k (encoder_hidden_states )
307
+ value = attn .to_v (encoder_hidden_states )
260
308
261
309
inner_dim = key .shape [- 1 ]
262
310
head_dim = inner_dim // attn .heads
@@ -283,7 +331,7 @@ def forward(
283
331
hidden_states = hidden_states .to (query .dtype )
284
332
285
333
# linear proj
286
- hidden_states = attn .to_out [0 ](hidden_states , * args )
334
+ hidden_states = attn .to_out [0 ](hidden_states )
287
335
# dropout
288
336
hidden_states = attn .to_out [1 ](hidden_states )
289
337
@@ -410,9 +458,9 @@ def promptsmaker(prompts, batch):
410
458
add = ""
411
459
if KCOMM in prompt :
412
460
add , prompt = prompt .split (KCOMM )
413
- add = add + " "
414
- prompts = prompt .split (KBRK )
415
- out_p .append ([add + p for p in prompts ])
461
+ add = add . strip () + " "
462
+ prompts = [ p . strip () for p in prompt .split (KBRK )]
463
+ out_p .append ([add + p for i , p in enumerate ( prompts ) ])
416
464
out = [None ] * batch * len (out_p [0 ]) * len (out_p )
417
465
for p , prs in enumerate (out_p ): # inputs prompts
418
466
for r , pr in enumerate (prs ): # prompts for regions
@@ -449,7 +497,6 @@ def startend(cells, array):
449
497
add = []
450
498
startend (add , inratios [1 :])
451
499
icells .append (add )
452
-
453
500
return ocells , icells , sum (len (cell ) for cell in icells )
454
501
455
502
0 commit comments