Skip to content

Commit e5b3d6c

Browse files
bigcat88adlerfaulkner
authored andcommitted
convert nodes_sd3.py and nodes_slg.py to V3 schema (comfyanonymous#10162)
1 parent 79e1d44 commit e5b3d6c

File tree

2 files changed

+228
-130
lines changed

2 files changed

+228
-130
lines changed

comfy_extras/nodes_sd3.py

Lines changed: 160 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -3,64 +3,83 @@
33
import comfy.model_management
44
import nodes
55
import torch
6-
import comfy_extras.nodes_slg
6+
from typing_extensions import override
7+
from comfy_api.latest import ComfyExtension, io
8+
from comfy_extras.nodes_slg import SkipLayerGuidanceDiT
79

810

9-
class TripleCLIPLoader:
11+
class TripleCLIPLoader(io.ComfyNode):
1012
@classmethod
11-
def INPUT_TYPES(s):
12-
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
13-
}}
14-
RETURN_TYPES = ("CLIP",)
15-
FUNCTION = "load_clip"
13+
def define_schema(cls):
14+
return io.Schema(
15+
node_id="TripleCLIPLoader",
16+
category="advanced/loaders",
17+
description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
18+
inputs=[
19+
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
20+
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
21+
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
22+
],
23+
outputs=[
24+
io.Clip.Output(),
25+
],
26+
)
1627

17-
CATEGORY = "advanced/loaders"
18-
19-
DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
20-
21-
def load_clip(self, clip_name1, clip_name2, clip_name3):
28+
@classmethod
29+
def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput:
2230
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
2331
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
2432
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
2533
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
26-
return (clip,)
34+
return io.NodeOutput(clip)
2735

36+
load_clip = execute # TODO: remove
2837

29-
class EmptySD3LatentImage:
30-
def __init__(self):
31-
self.device = comfy.model_management.intermediate_device()
3238

39+
class EmptySD3LatentImage(io.ComfyNode):
3340
@classmethod
34-
def INPUT_TYPES(s):
35-
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
36-
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
37-
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
38-
RETURN_TYPES = ("LATENT",)
39-
FUNCTION = "generate"
41+
def define_schema(cls):
42+
return io.Schema(
43+
node_id="EmptySD3LatentImage",
44+
category="latent/sd3",
45+
inputs=[
46+
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
47+
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
48+
io.Int.Input("batch_size", default=1, min=1, max=4096),
49+
],
50+
outputs=[
51+
io.Latent.Output(),
52+
],
53+
)
4054

41-
CATEGORY = "latent/sd3"
55+
@classmethod
56+
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
57+
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
58+
return io.NodeOutput({"samples":latent})
4259

43-
def generate(self, width, height, batch_size=1):
44-
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
45-
return ({"samples":latent}, )
60+
generate = execute # TODO: remove
4661

4762

48-
class CLIPTextEncodeSD3:
63+
class CLIPTextEncodeSD3(io.ComfyNode):
4964
@classmethod
50-
def INPUT_TYPES(s):
51-
return {"required": {
52-
"clip": ("CLIP", ),
53-
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
54-
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
55-
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
56-
"empty_padding": (["none", "empty_prompt"], )
57-
}}
58-
RETURN_TYPES = ("CONDITIONING",)
59-
FUNCTION = "encode"
60-
61-
CATEGORY = "advanced/conditioning"
62-
63-
def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
65+
def define_schema(cls):
66+
return io.Schema(
67+
node_id="CLIPTextEncodeSD3",
68+
category="advanced/conditioning",
69+
inputs=[
70+
io.Clip.Input("clip"),
71+
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
72+
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
73+
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
74+
io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
75+
],
76+
outputs=[
77+
io.Conditioning.Output(),
78+
],
79+
)
80+
81+
@classmethod
82+
def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput:
6483
no_padding = empty_padding == "none"
6584

6685
tokens = clip.tokenize(clip_g)
@@ -82,57 +101,112 @@ def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
82101
tokens["l"] += empty["l"]
83102
while len(tokens["l"]) > len(tokens["g"]):
84103
tokens["g"] += empty["g"]
85-
return (clip.encode_from_tokens_scheduled(tokens), )
104+
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
105+
106+
encode = execute # TODO: remove
86107

87108

88-
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
109+
class ControlNetApplySD3(io.ComfyNode):
110+
@classmethod
111+
def define_schema(cls) -> io.Schema:
112+
return io.Schema(
113+
node_id="ControlNetApplySD3",
114+
display_name="Apply Controlnet with VAE",
115+
category="conditioning/controlnet",
116+
inputs=[
117+
io.Conditioning.Input("positive"),
118+
io.Conditioning.Input("negative"),
119+
io.ControlNet.Input("control_net"),
120+
io.Vae.Input("vae"),
121+
io.Image.Input("image"),
122+
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
123+
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
124+
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
125+
],
126+
outputs=[
127+
io.Conditioning.Output(display_name="positive"),
128+
io.Conditioning.Output(display_name="negative"),
129+
],
130+
is_deprecated=True,
131+
)
132+
89133
@classmethod
90-
def INPUT_TYPES(s):
91-
return {"required": {"positive": ("CONDITIONING", ),
92-
"negative": ("CONDITIONING", ),
93-
"control_net": ("CONTROL_NET", ),
94-
"vae": ("VAE", ),
95-
"image": ("IMAGE", ),
96-
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
97-
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
98-
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
99-
}}
100-
CATEGORY = "conditioning/controlnet"
101-
DEPRECATED = True
102-
103-
104-
class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
134+
def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput:
135+
if strength == 0:
136+
return io.NodeOutput(positive, negative)
137+
138+
control_hint = image.movedim(-1, 1)
139+
cnets = {}
140+
141+
out = []
142+
for conditioning in [positive, negative]:
143+
c = []
144+
for t in conditioning:
145+
d = t[1].copy()
146+
147+
prev_cnet = d.get('control', None)
148+
if prev_cnet in cnets:
149+
c_net = cnets[prev_cnet]
150+
else:
151+
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent),
152+
vae=vae, extra_concat=[])
153+
c_net.set_previous_controlnet(prev_cnet)
154+
cnets[prev_cnet] = c_net
155+
156+
d['control'] = c_net
157+
d['control_apply_to_uncond'] = False
158+
n = [t[0], d]
159+
c.append(n)
160+
out.append(c)
161+
return io.NodeOutput(out[0], out[1])
162+
163+
apply_controlnet = execute # TODO: remove
164+
165+
166+
class SkipLayerGuidanceSD3(io.ComfyNode):
105167
'''
106168
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
107169
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
108170
Experimental implementation by Dango233@StabilityAI.
109171
'''
172+
173+
@classmethod
174+
def define_schema(cls):
175+
return io.Schema(
176+
node_id="SkipLayerGuidanceSD3",
177+
category="advanced/guidance",
178+
description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
179+
inputs=[
180+
io.Model.Input("model"),
181+
io.String.Input("layers", default="7, 8, 9", multiline=False),
182+
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
183+
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
184+
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
185+
],
186+
outputs=[
187+
io.Model.Output(),
188+
],
189+
is_experimental=True,
190+
)
191+
110192
@classmethod
111-
def INPUT_TYPES(s):
112-
return {"required": {"model": ("MODEL", ),
113-
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
114-
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
115-
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
116-
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
117-
}}
118-
RETURN_TYPES = ("MODEL",)
119-
FUNCTION = "skip_guidance_sd3"
120-
121-
CATEGORY = "advanced/guidance"
122-
123-
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
124-
return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
125-
126-
127-
NODE_CLASS_MAPPINGS = {
128-
"TripleCLIPLoader": TripleCLIPLoader,
129-
"EmptySD3LatentImage": EmptySD3LatentImage,
130-
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
131-
"ControlNetApplySD3": ControlNetApplySD3,
132-
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
133-
}
134-
135-
NODE_DISPLAY_NAME_MAPPINGS = {
136-
# Sampling
137-
"ControlNetApplySD3": "Apply Controlnet with VAE",
138-
}
193+
def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput:
194+
return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
195+
196+
skip_guidance_sd3 = execute # TODO: remove
197+
198+
199+
class SD3Extension(ComfyExtension):
200+
@override
201+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
202+
return [
203+
TripleCLIPLoader,
204+
EmptySD3LatentImage,
205+
CLIPTextEncodeSD3,
206+
ControlNetApplySD3,
207+
SkipLayerGuidanceSD3,
208+
]
209+
210+
211+
async def comfy_entrypoint() -> SD3Extension:
212+
return SD3Extension()

0 commit comments

Comments
 (0)