Commit c0af35d

convert nodes_sd3.py and nodes_slg.py to V3 schema
1 parent 0e9d172 commit c0af35d
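
This commit applies ComfyUI's V1-to-V3 node API migration: class-attribute metadata (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY, DESCRIPTION) becomes a declarative io.Schema returned from define_schema, the instance entrypoint method becomes a classmethod execute returning io.NodeOutput, and the module-level NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS dicts are replaced by a ComfyExtension registered through an async comfy_entrypoint. A minimal before/after sketch of the pattern follows; ExampleLatentV1, ExampleLatent, and ExampleExtension are illustrative names, not part of this commit, and the literal 16384 stands in for nodes.MAX_RESOLUTION, but every API construct used (io.ComfyNode, io.Schema, io.Int.Input, io.Latent.Output, io.NodeOutput, ComfyExtension, comfy_entrypoint) appears in the diff below.

    # Before/after sketch of the V1 -> V3 conversion this commit performs.
    # Names marked "illustrative" are hypothetical; API calls mirror the diff below.
    import torch
    from typing_extensions import override
    from comfy_api.latest import ComfyExtension, io


    class ExampleLatentV1:  # illustrative legacy (V1) node
        @classmethod
        def INPUT_TYPES(s):
            return {"required": {"width": ("INT", {"default": 1024, "min": 16, "max": 16384, "step": 16})}}
        RETURN_TYPES = ("LATENT",)
        FUNCTION = "generate"
        CATEGORY = "latent/example"

        def generate(self, width):
            # V1 nodes return a tuple of outputs
            return ({"samples": torch.zeros([1, 16, 128, width // 8])},)


    class ExampleLatent(io.ComfyNode):  # the same node in the V3 schema
        @classmethod
        def define_schema(cls):
            return io.Schema(
                node_id="ExampleLatent",
                category="latent/example",
                inputs=[io.Int.Input("width", default=1024, min=16, max=16384, step=16)],
                outputs=[io.Latent.Output()],
            )

        @classmethod
        def execute(cls, width) -> io.NodeOutput:
            # V3 nodes wrap outputs in io.NodeOutput
            return io.NodeOutput({"samples": torch.zeros([1, 16, 128, width // 8])})


    class ExampleExtension(ComfyExtension):  # replaces NODE_CLASS_MAPPINGS
        @override
        async def get_node_list(self) -> list[type[io.ComfyNode]]:
            return [ExampleLatent]


    async def comfy_entrypoint() -> ExampleExtension:
        return ExampleExtension()

Because execute is a classmethod in V3, nodes carry no per-instance state; this is why EmptySD3LatentImage below drops its __init__/self.device in favor of calling comfy.model_management.intermediate_device() inline.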

File tree

2 files changed (+214, -131 lines)

comfy_extras/nodes_sd3.py

151 additions, 87 deletions
@@ -3,64 +3,79 @@
 import comfy.model_management
 import nodes
 import torch
-import comfy_extras.nodes_slg
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+from comfy_extras.nodes_slg import SkipLayerGuidanceDiT
 
 
-class TripleCLIPLoader:
+class TripleCLIPLoader(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
-                              }}
-    RETURN_TYPES = ("CLIP",)
-    FUNCTION = "load_clip"
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TripleCLIPLoader",
+            category="advanced/loaders",
+            description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
+            inputs=[
+                io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
+                io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
+                io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
+            ],
+            outputs=[
+                io.Clip.Output(),
+            ],
+        )
 
-    CATEGORY = "advanced/loaders"
-
-    DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
-
-    def load_clip(self, clip_name1, clip_name2, clip_name3):
+    @classmethod
+    def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput:
         clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
         clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
         clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
-        return (clip,)
-
+        return io.NodeOutput(clip)
 
-class EmptySD3LatentImage:
-    def __init__(self):
-        self.device = comfy.model_management.intermediate_device()
 
+class EmptySD3LatentImage(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                              "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "generate"
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptySD3LatentImage",
+            category="latent/sd3",
+            inputs=[
+                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
 
-    CATEGORY = "latent/sd3"
+    @classmethod
+    def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
+        latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        return io.NodeOutput({"samples":latent})
 
-    def generate(self, width, height, batch_size=1):
-        latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
-        return ({"samples":latent}, )
 
+class CLIPTextEncodeSD3(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CLIPTextEncodeSD3",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
+                io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
+                io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
+                io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
 
-class CLIPTextEncodeSD3:
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "clip": ("CLIP", ),
-            "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "empty_padding": (["none", "empty_prompt"], )
-            }}
-    RETURN_TYPES = ("CONDITIONING",)
-    FUNCTION = "encode"
-
-    CATEGORY = "advanced/conditioning"
-
-    def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
+    def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput:
         no_padding = empty_padding == "none"
 
         tokens = clip.tokenize(clip_g)
@@ -82,57 +97,106 @@ def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
                 tokens["l"] += empty["l"]
             while len(tokens["l"]) > len(tokens["g"]):
                 tokens["g"] += empty["g"]
-        return (clip.encode_from_tokens_scheduled(tokens), )
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
 
 
-class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
+class ControlNetApplySD3(io.ComfyNode):
+    @classmethod
+    def define_schema(cls) -> io.Schema:
+        return io.Schema(
+            node_id="ControlNetApplySD3",
+            display_name="Apply Controlnet with VAE",
+            category="conditioning/controlnet",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.ControlNet.Input("control_net"),
+                io.Vae.Input("vae"),
+                io.Image.Input("image"),
+                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
+                io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+            ],
+            is_deprecated=True,
+        )
+
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "control_net": ("CONTROL_NET", ),
-                             "vae": ("VAE", ),
-                             "image": ("IMAGE", ),
-                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
-                             "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
-                             "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
-                             }}
-    CATEGORY = "conditioning/controlnet"
-    DEPRECATED = True
-
-
-class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
+    def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput:
+        if strength == 0:
+            return io.NodeOutput(positive, negative)
+
+        control_hint = image.movedim(-1, 1)
+        cnets = {}
+
+        out = []
+        for conditioning in [positive, negative]:
+            c = []
+            for t in conditioning:
+                d = t[1].copy()
+
+                prev_cnet = d.get('control', None)
+                if prev_cnet in cnets:
+                    c_net = cnets[prev_cnet]
+                else:
+                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent),
+                                                             vae=vae, extra_concat=[])
+                    c_net.set_previous_controlnet(prev_cnet)
+                    cnets[prev_cnet] = c_net
+
+                d['control'] = c_net
+                d['control_apply_to_uncond'] = False
+                n = [t[0], d]
+                c.append(n)
+            out.append(c)
+        return io.NodeOutput(out[0], out[1])
+
+
+class SkipLayerGuidanceSD3(io.ComfyNode):
     '''
     Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers.
     Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
     Experimental implementation by Dango233@StabilityAI.
     '''
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SkipLayerGuidanceSD3",
+            category="advanced/guidance",
+            description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
+            inputs=[
+                io.Model.Input("model"),
+                io.String.Input("layers", default="7, 8, 9", multiline=False),
+                io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
+                io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+            is_experimental=True,
+        )
+
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"model": ("MODEL", ),
-                             "layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
-                             "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
-                             "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
-                             "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
-                             }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "skip_guidance_sd3"
-
-    CATEGORY = "advanced/guidance"
-
-    def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
-        return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
-
-
-NODE_CLASS_MAPPINGS = {
-    "TripleCLIPLoader": TripleCLIPLoader,
-    "EmptySD3LatentImage": EmptySD3LatentImage,
-    "CLIPTextEncodeSD3": CLIPTextEncodeSD3,
-    "ControlNetApplySD3": ControlNetApplySD3,
-    "SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
-}
-
-NODE_DISPLAY_NAME_MAPPINGS = {
-    # Sampling
-    "ControlNetApplySD3": "Apply Controlnet with VAE",
-}
+    def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput:
+        return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
+
+
+class SD3Extension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            TripleCLIPLoader,
+            EmptySD3LatentImage,
+            CLIPTextEncodeSD3,
+            ControlNetApplySD3,
+            SkipLayerGuidanceSD3,
+        ]
+
+
+async def comfy_entrypoint() -> SD3Extension:
+    return SD3Extension()
