33import comfy .model_management
44import nodes
55import torch
6- import comfy_extras .nodes_slg
6+ from typing_extensions import override
7+ from comfy_api .latest import ComfyExtension , io
8+ from comfy_extras .nodes_slg import SkipLayerGuidanceDiT
79
810
9- class TripleCLIPLoader :
11+ class TripleCLIPLoader ( io . ComfyNode ) :
1012 @classmethod
11- def INPUT_TYPES (s ):
12- return {"required" : { "clip_name1" : (folder_paths .get_filename_list ("text_encoders" ), ), "clip_name2" : (folder_paths .get_filename_list ("text_encoders" ), ), "clip_name3" : (folder_paths .get_filename_list ("text_encoders" ), )
13- }}
14- RETURN_TYPES = ("CLIP" ,)
15- FUNCTION = "load_clip"
13+ def define_schema (cls ):
14+ return io .Schema (
15+ node_id = "TripleCLIPLoader" ,
16+ category = "advanced/loaders" ,
17+ description = "[Recipes]\n \n sd3: clip-l, clip-g, t5" ,
18+ inputs = [
19+ io .Combo .Input ("clip_name1" , options = folder_paths .get_filename_list ("text_encoders" )),
20+ io .Combo .Input ("clip_name2" , options = folder_paths .get_filename_list ("text_encoders" )),
21+ io .Combo .Input ("clip_name3" , options = folder_paths .get_filename_list ("text_encoders" )),
22+ ],
23+ outputs = [
24+ io .Clip .Output (),
25+ ],
26+ )
1627
17- CATEGORY = "advanced/loaders"
18-
19- DESCRIPTION = "[Recipes]\n \n sd3: clip-l, clip-g, t5"
20-
21- def load_clip (self , clip_name1 , clip_name2 , clip_name3 ):
28+ @classmethod
29+ def execute (cls , clip_name1 , clip_name2 , clip_name3 ) -> io .NodeOutput :
2230 clip_path1 = folder_paths .get_full_path_or_raise ("text_encoders" , clip_name1 )
2331 clip_path2 = folder_paths .get_full_path_or_raise ("text_encoders" , clip_name2 )
2432 clip_path3 = folder_paths .get_full_path_or_raise ("text_encoders" , clip_name3 )
2533 clip = comfy .sd .load_clip (ckpt_paths = [clip_path1 , clip_path2 , clip_path3 ], embedding_directory = folder_paths .get_folder_paths ("embeddings" ))
26- return (clip , )
34+ return io . NodeOutput (clip )
2735
36+ load_clip = execute # TODO: remove
2837
29- class EmptySD3LatentImage :
30- def __init__ (self ):
31- self .device = comfy .model_management .intermediate_device ()
3238
39+ class EmptySD3LatentImage (io .ComfyNode ):
3340 @classmethod
34- def INPUT_TYPES (s ):
35- return {"required" : { "width" : ("INT" , {"default" : 1024 , "min" : 16 , "max" : nodes .MAX_RESOLUTION , "step" : 16 }),
36- "height" : ("INT" , {"default" : 1024 , "min" : 16 , "max" : nodes .MAX_RESOLUTION , "step" : 16 }),
37- "batch_size" : ("INT" , {"default" : 1 , "min" : 1 , "max" : 4096 })}}
38- RETURN_TYPES = ("LATENT" ,)
39- FUNCTION = "generate"
41+ def define_schema (cls ):
42+ return io .Schema (
43+ node_id = "EmptySD3LatentImage" ,
44+ category = "latent/sd3" ,
45+ inputs = [
46+ io .Int .Input ("width" , default = 1024 , min = 16 , max = nodes .MAX_RESOLUTION , step = 16 ),
47+ io .Int .Input ("height" , default = 1024 , min = 16 , max = nodes .MAX_RESOLUTION , step = 16 ),
48+ io .Int .Input ("batch_size" , default = 1 , min = 1 , max = 4096 ),
49+ ],
50+ outputs = [
51+ io .Latent .Output (),
52+ ],
53+ )
4054
41- CATEGORY = "latent/sd3"
55+ @classmethod
56+ def execute (cls , width , height , batch_size = 1 ) -> io .NodeOutput :
57+ latent = torch .zeros ([batch_size , 16 , height // 8 , width // 8 ], device = comfy .model_management .intermediate_device ())
58+ return io .NodeOutput ({"samples" :latent })
4259
43- def generate (self , width , height , batch_size = 1 ):
44- latent = torch .zeros ([batch_size , 16 , height // 8 , width // 8 ], device = self .device )
45- return ({"samples" :latent }, )
60+ generate = execute # TODO: remove
4661
4762
48- class CLIPTextEncodeSD3 :
63+ class CLIPTextEncodeSD3 ( io . ComfyNode ) :
4964 @classmethod
50- def INPUT_TYPES (s ):
51- return {"required" : {
52- "clip" : ("CLIP" , ),
53- "clip_l" : ("STRING" , {"multiline" : True , "dynamicPrompts" : True }),
54- "clip_g" : ("STRING" , {"multiline" : True , "dynamicPrompts" : True }),
55- "t5xxl" : ("STRING" , {"multiline" : True , "dynamicPrompts" : True }),
56- "empty_padding" : (["none" , "empty_prompt" ], )
57- }}
58- RETURN_TYPES = ("CONDITIONING" ,)
59- FUNCTION = "encode"
60-
61- CATEGORY = "advanced/conditioning"
62-
63- def encode (self , clip , clip_l , clip_g , t5xxl , empty_padding ):
65+ def define_schema (cls ):
66+ return io .Schema (
67+ node_id = "CLIPTextEncodeSD3" ,
68+ category = "advanced/conditioning" ,
69+ inputs = [
70+ io .Clip .Input ("clip" ),
71+ io .String .Input ("clip_l" , multiline = True , dynamic_prompts = True ),
72+ io .String .Input ("clip_g" , multiline = True , dynamic_prompts = True ),
73+ io .String .Input ("t5xxl" , multiline = True , dynamic_prompts = True ),
74+ io .Combo .Input ("empty_padding" , options = ["none" , "empty_prompt" ]),
75+ ],
76+ outputs = [
77+ io .Conditioning .Output (),
78+ ],
79+ )
80+
81+ @classmethod
82+ def execute (cls , clip , clip_l , clip_g , t5xxl , empty_padding ) -> io .NodeOutput :
6483 no_padding = empty_padding == "none"
6584
6685 tokens = clip .tokenize (clip_g )
@@ -82,57 +101,112 @@ def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
82101 tokens ["l" ] += empty ["l" ]
83102 while len (tokens ["l" ]) > len (tokens ["g" ]):
84103 tokens ["g" ] += empty ["g" ]
85- return (clip .encode_from_tokens_scheduled (tokens ), )
104+ return io .NodeOutput (clip .encode_from_tokens_scheduled (tokens ))
105+
106+ encode = execute # TODO: remove
86107
87108
88- class ControlNetApplySD3 (nodes .ControlNetApplyAdvanced ):
109+ class ControlNetApplySD3 (io .ComfyNode ):
110+ @classmethod
111+ def define_schema (cls ) -> io .Schema :
112+ return io .Schema (
113+ node_id = "ControlNetApplySD3" ,
114+ display_name = "Apply Controlnet with VAE" ,
115+ category = "conditioning/controlnet" ,
116+ inputs = [
117+ io .Conditioning .Input ("positive" ),
118+ io .Conditioning .Input ("negative" ),
119+ io .ControlNet .Input ("control_net" ),
120+ io .Vae .Input ("vae" ),
121+ io .Image .Input ("image" ),
122+ io .Float .Input ("strength" , default = 1.0 , min = 0.0 , max = 10.0 , step = 0.01 ),
123+ io .Float .Input ("start_percent" , default = 0.0 , min = 0.0 , max = 1.0 , step = 0.001 ),
124+ io .Float .Input ("end_percent" , default = 1.0 , min = 0.0 , max = 1.0 , step = 0.001 ),
125+ ],
126+ outputs = [
127+ io .Conditioning .Output (display_name = "positive" ),
128+ io .Conditioning .Output (display_name = "negative" ),
129+ ],
130+ is_deprecated = True ,
131+ )
132+
89133 @classmethod
90- def INPUT_TYPES (s ):
91- return {"required" : {"positive" : ("CONDITIONING" , ),
92- "negative" : ("CONDITIONING" , ),
93- "control_net" : ("CONTROL_NET" , ),
94- "vae" : ("VAE" , ),
95- "image" : ("IMAGE" , ),
96- "strength" : ("FLOAT" , {"default" : 1.0 , "min" : 0.0 , "max" : 10.0 , "step" : 0.01 }),
97- "start_percent" : ("FLOAT" , {"default" : 0.0 , "min" : 0.0 , "max" : 1.0 , "step" : 0.001 }),
98- "end_percent" : ("FLOAT" , {"default" : 1.0 , "min" : 0.0 , "max" : 1.0 , "step" : 0.001 })
99- }}
100- CATEGORY = "conditioning/controlnet"
101- DEPRECATED = True
102-
103-
104- class SkipLayerGuidanceSD3 (comfy_extras .nodes_slg .SkipLayerGuidanceDiT ):
134+ def execute (cls , positive , negative , control_net , image , strength , start_percent , end_percent , vae = None ) -> io .NodeOutput :
135+ if strength == 0 :
136+ return io .NodeOutput (positive , negative )
137+
138+ control_hint = image .movedim (- 1 , 1 )
139+ cnets = {}
140+
141+ out = []
142+ for conditioning in [positive , negative ]:
143+ c = []
144+ for t in conditioning :
145+ d = t [1 ].copy ()
146+
147+ prev_cnet = d .get ('control' , None )
148+ if prev_cnet in cnets :
149+ c_net = cnets [prev_cnet ]
150+ else :
151+ c_net = control_net .copy ().set_cond_hint (control_hint , strength , (start_percent , end_percent ),
152+ vae = vae , extra_concat = [])
153+ c_net .set_previous_controlnet (prev_cnet )
154+ cnets [prev_cnet ] = c_net
155+
156+ d ['control' ] = c_net
157+ d ['control_apply_to_uncond' ] = False
158+ n = [t [0 ], d ]
159+ c .append (n )
160+ out .append (c )
161+ return io .NodeOutput (out [0 ], out [1 ])
162+
163+ apply_controlnet = execute # TODO: remove
164+
165+
166+ class SkipLayerGuidanceSD3 (io .ComfyNode ):
105167 '''
106168 Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
107169 Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
108170 Experimental implementation by Dango233@StabilityAI.
109171 '''
172+
173+ @classmethod
174+ def define_schema (cls ):
175+ return io .Schema (
176+ node_id = "SkipLayerGuidanceSD3" ,
177+ category = "advanced/guidance" ,
178+ description = "Generic version of SkipLayerGuidance node that can be used on every DiT model." ,
179+ inputs = [
180+ io .Model .Input ("model" ),
181+ io .String .Input ("layers" , default = "7, 8, 9" , multiline = False ),
182+ io .Float .Input ("scale" , default = 3.0 , min = 0.0 , max = 10.0 , step = 0.1 ),
183+ io .Float .Input ("start_percent" , default = 0.01 , min = 0.0 , max = 1.0 , step = 0.001 ),
184+ io .Float .Input ("end_percent" , default = 0.15 , min = 0.0 , max = 1.0 , step = 0.001 ),
185+ ],
186+ outputs = [
187+ io .Model .Output (),
188+ ],
189+ is_experimental = True ,
190+ )
191+
110192 @classmethod
111- def INPUT_TYPES (s ):
112- return {"required" : {"model" : ("MODEL" , ),
113- "layers" : ("STRING" , {"default" : "7, 8, 9" , "multiline" : False }),
114- "scale" : ("FLOAT" , {"default" : 3.0 , "min" : 0.0 , "max" : 10.0 , "step" : 0.1 }),
115- "start_percent" : ("FLOAT" , {"default" : 0.01 , "min" : 0.0 , "max" : 1.0 , "step" : 0.001 }),
116- "end_percent" : ("FLOAT" , {"default" : 0.15 , "min" : 0.0 , "max" : 1.0 , "step" : 0.001 })
117- }}
118- RETURN_TYPES = ("MODEL" ,)
119- FUNCTION = "skip_guidance_sd3"
120-
121- CATEGORY = "advanced/guidance"
122-
123- def skip_guidance_sd3 (self , model , layers , scale , start_percent , end_percent ):
124- return self .skip_guidance (model = model , scale = scale , start_percent = start_percent , end_percent = end_percent , double_layers = layers )
125-
126-
127- NODE_CLASS_MAPPINGS = {
128- "TripleCLIPLoader" : TripleCLIPLoader ,
129- "EmptySD3LatentImage" : EmptySD3LatentImage ,
130- "CLIPTextEncodeSD3" : CLIPTextEncodeSD3 ,
131- "ControlNetApplySD3" : ControlNetApplySD3 ,
132- "SkipLayerGuidanceSD3" : SkipLayerGuidanceSD3 ,
133- }
134-
135- NODE_DISPLAY_NAME_MAPPINGS = {
136- # Sampling
137- "ControlNetApplySD3" : "Apply Controlnet with VAE" ,
138- }
193+ def execute (cls , model , layers , scale , start_percent , end_percent ) -> io .NodeOutput :
194+ return SkipLayerGuidanceDiT ().execute (model = model , scale = scale , start_percent = start_percent , end_percent = end_percent , double_layers = layers )
195+
196+ skip_guidance_sd3 = execute # TODO: remove
197+
198+
199+ class SD3Extension (ComfyExtension ):
200+ @override
201+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
202+ return [
203+ TripleCLIPLoader ,
204+ EmptySD3LatentImage ,
205+ CLIPTextEncodeSD3 ,
206+ ControlNetApplySD3 ,
207+ SkipLayerGuidanceSD3 ,
208+ ]
209+
210+
211+ async def comfy_entrypoint () -> SD3Extension :
212+ return SD3Extension ()
0 commit comments