33import comfy .model_management
44import nodes
55import torch
6- import comfy_extras .nodes_slg
6+ from typing_extensions import override
7+ from comfy_api .latest import ComfyExtension , io
8+ from comfy_extras .nodes_slg import SkipLayerGuidanceDiT
79
810
class TripleCLIPLoader(io.ComfyNode):
    """Node that loads three text-encoder checkpoints into a single CLIP object."""

    @classmethod
    def define_schema(cls):
        """Describe the node: three text-encoder file pickers in, one CLIP out."""
        # The file list is queried once per picker, so every input receives
        # its own options list rather than a shared object.
        encoder_pickers = [
            io.Combo.Input(input_id, options=folder_paths.get_filename_list("text_encoders"))
            for input_id in ("clip_name1", "clip_name2", "clip_name3")
        ]
        return io.Schema(
            node_id="TripleCLIPLoader",
            category="advanced/loaders",
            description="[Recipes]\n \n sd3: clip-l, clip-g, t5",
            inputs=encoder_pickers,
            outputs=[io.Clip.Output()],
        )

    @classmethod
    def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput:
        """Resolve the three selected file names and load them as one CLIP.

        Each name is resolved inside the "text_encoders" folder through
        ``folder_paths.get_full_path_or_raise`` (presumably raising when a
        name cannot be resolved -- confirm against folder_paths).
        """
        checkpoint_paths = [
            folder_paths.get_full_path_or_raise("text_encoders", selected)
            for selected in (clip_name1, clip_name2, clip_name3)
        ]
        embeddings_dirs = folder_paths.get_folder_paths("embeddings")
        loaded = comfy.sd.load_clip(ckpt_paths=checkpoint_paths, embedding_directory=embeddings_dirs)
        return io.NodeOutput(loaded)
2835
class EmptySD3LatentImage(io.ComfyNode):
    """Node that produces an all-zero latent batch with 16 channels."""

    @classmethod
    def define_schema(cls):
        """Describe the node: width, height and batch size in, one latent out."""
        # Width and height use exactly the same default, bounds and step.
        side_limits = {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}
        return io.Schema(
            node_id="EmptySD3LatentImage",
            category="latent/sd3",
            inputs=[
                io.Int.Input("width", **side_limits),
                io.Int.Input("height", **side_limits),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
            ],
            outputs=[io.Latent.Output()],
        )

    @classmethod
    def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
        """Return a zero tensor wrapped as ``{"samples": tensor}``.

        The tensor has ``batch_size`` entries of 16 channels, with each
        spatial side being the requested size floor-divided by 8. It is
        allocated on ``comfy.model_management.intermediate_device()``.
        """
        latent_shape = [batch_size, 16, height // 8, width // 8]
        target_device = comfy.model_management.intermediate_device()
        blank = torch.zeros(latent_shape, device=target_device)
        return io.NodeOutput({"samples": blank})
4658
59+ class CLIPTextEncodeSD3 (io .ComfyNode ):
60+ @classmethod
61+ def define_schema (cls ):
62+ return io .Schema (
63+ node_id = "CLIPTextEncodeSD3" ,
64+ category = "advanced/conditioning" ,
65+ inputs = [
66+ io .Clip .Input ("clip" ),
67+ io .String .Input ("clip_l" , multiline = True , dynamic_prompts = True ),
68+ io .String .Input ("clip_g" , multiline = True , dynamic_prompts = True ),
69+ io .String .Input ("t5xxl" , multiline = True , dynamic_prompts = True ),
70+ io .Combo .Input ("empty_padding" , options = ["none" , "empty_prompt" ]),
71+ ],
72+ outputs = [
73+ io .Conditioning .Output (),
74+ ],
75+ )
4776
48- class CLIPTextEncodeSD3 :
4977 @classmethod
50- def INPUT_TYPES (s ):
51- return {"required" : {
52- "clip" : ("CLIP" , ),
53- "clip_l" : ("STRING" , {"multiline" : True , "dynamicPrompts" : True }),
54- "clip_g" : ("STRING" , {"multiline" : True , "dynamicPrompts" : True }),
55- "t5xxl" : ("STRING" , {"multiline" : True , "dynamicPrompts" : True }),
56- "empty_padding" : (["none" , "empty_prompt" ], )
57- }}
58- RETURN_TYPES = ("CONDITIONING" ,)
59- FUNCTION = "encode"
60-
61- CATEGORY = "advanced/conditioning"
62-
63- def encode (self , clip , clip_l , clip_g , t5xxl , empty_padding ):
78+ def execute (cls , clip , clip_l , clip_g , t5xxl , empty_padding ) -> io .NodeOutput :
6479 no_padding = empty_padding == "none"
6580
6681 tokens = clip .tokenize (clip_g )
@@ -82,57 +97,106 @@ def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
8297 tokens ["l" ] += empty ["l" ]
8398 while len (tokens ["l" ]) > len (tokens ["g" ]):
8499 tokens ["g" ] += empty ["g" ]
85- return (clip .encode_from_tokens_scheduled (tokens ), )
100+ return io . NodeOutput (clip .encode_from_tokens_scheduled (tokens ))
86101
87102
class ControlNetApplySD3(io.ComfyNode):
    """Deprecated node that attaches a control net (with a VAE) to conditioning."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node's inputs and its two conditioning outputs."""
        conditioning_inputs = [
            io.Conditioning.Input("positive"),
            io.Conditioning.Input("negative"),
        ]
        model_inputs = [
            io.ControlNet.Input("control_net"),
            io.Vae.Input("vae"),
            io.Image.Input("image"),
        ]
        tuning_inputs = [
            io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
            io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
            io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
        ]
        return io.Schema(
            node_id="ControlNetApplySD3",
            display_name="Apply Controlnet with VAE",
            category="conditioning/controlnet",
            inputs=conditioning_inputs + model_inputs + tuning_inputs,
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
            ],
            is_deprecated=True,
        )

    @classmethod
    def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput:
        """Return copies of both conditionings with the control net attached.

        A strength of 0 returns the inputs untouched. Otherwise each
        conditioning entry gets a copied options dict whose 'control' key
        holds a configured copy of ``control_net`` chained onto whatever
        control the entry carried before.
        """
        if strength == 0:
            return io.NodeOutput(positive, negative)

        # Move the image's last dimension to position 1
        # (presumably channels-last -> channels-first; confirm against callers).
        hint = image.movedim(-1, 1)

        # One configured copy per distinct previous control; the cache is
        # shared by the positive and negative passes so both reuse the same copy.
        configured = {}

        def attach(conditioning):
            updated = []
            for entry in conditioning:
                options = entry[1].copy()
                previous = options.get('control', None)
                if previous not in configured:
                    linked = control_net.copy().set_cond_hint(
                        hint, strength, (start_percent, end_percent), vae=vae, extra_concat=[]
                    )
                    linked.set_previous_controlnet(previous)
                    configured[previous] = linked
                options['control'] = configured[previous]
                options['control_apply_to_uncond'] = False
                updated.append([entry[0], options])
            return updated

        # Arguments are evaluated left to right, so positive is processed first.
        return io.NodeOutput(attach(positive), attach(negative))
156+
157+
class SkipLayerGuidanceSD3(io.ComfyNode):
    '''
    Enhance guidance towards detailed structure by adding a second set of CFG negatives computed with skipped layers.
    Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
    Experimental implementation by Dango233@StabilityAI.
    '''

    @classmethod
    def define_schema(cls):
        """Describe the node: a model plus layer/scale/range settings in, a model out."""
        # NOTE(review): the description text refers to the generic DiT node;
        # confirm that this wording is intended for the SD3-specific node.
        settings = [
            io.Model.Input("model"),
            io.String.Input("layers", default="7, 8, 9", multiline=False),
            io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
            io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
            io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
        ]
        return io.Schema(
            node_id="SkipLayerGuidanceSD3",
            category="advanced/guidance",
            description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
            inputs=settings,
            outputs=[io.Model.Output()],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput:
        """Delegate to SkipLayerGuidanceDiT, passing ``layers`` as its ``double_layers``."""
        delegate = SkipLayerGuidanceDiT()
        return delegate.execute(
            model=model,
            scale=scale,
            start_percent=start_percent,
            end_percent=end_percent,
            double_layers=layers,
        )
187+
188+
class SD3Extension(ComfyExtension):
    """Extension that exposes the SD3 node classes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension, in a fixed order."""
        sd3_nodes: list[type[io.ComfyNode]] = [
            TripleCLIPLoader,
            EmptySD3LatentImage,
            CLIPTextEncodeSD3,
            ControlNetApplySD3,
            SkipLayerGuidanceSD3,
        ]
        return sd3_nodes
199+
200+
async def comfy_entrypoint() -> SD3Extension:
    """Create and return a new SD3Extension instance.

    Presumably the hook the host application looks up by name to discover
    this module's nodes -- confirm against the extension loader.
    """
    extension = SD3Extension()
    return extension
0 commit comments