55import  os 
66import  logging 
77from  enum  import  Enum 
8+ from  typing_extensions  import  override 
9+ from  comfy_api .latest  import  ComfyExtension , io 
810
911CLAMP_QUANTILE  =  0.99 
1012
@@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
7173            output_sd ["{}{}.diff_b" .format (prefix_lora , k [len (prefix_model ):- 5 ])] =  sd [k ].contiguous ().half ().cpu ()
7274    return  output_sd 
7375
74- class  LoraSave :
75-     def  __init__ (self ):
76-         self .output_dir  =  folder_paths .get_output_directory ()
76+ class  LoraSave (io .ComfyNode ):
77+     @classmethod  
78+     def  define_schema (cls ):
79+         return  io .Schema (
80+             node_id = "LoraSave" ,
81+             display_name = "Extract and Save Lora" ,
82+             category = "_for_testing" ,
83+             inputs = [
84+                 io .String .Input ("filename_prefix" , default = "loras/ComfyUI_extracted_lora" ),
85+                 io .Int .Input ("rank" , default = 8 , min = 1 , max = 4096 , step = 1 ),
86+                 io .Combo .Input ("lora_type" , options = tuple (LORA_TYPES .keys ())),
87+                 io .Boolean .Input ("bias_diff" , default = True ),
88+                 io .Model .Input (
89+                     "model_diff" ,
90+                     tooltip = "The ModelSubtract output to be converted to a lora." ,
91+                     optional = True ,
92+                 ),
93+                 io .Clip .Input (
94+                   "text_encoder_diff" ,
95+                     tooltip = "The CLIPSubtract output to be converted to a lora." ,
96+                     optional = True ,
97+                 ),
98+             ],
99+             is_experimental = True ,
100+             is_output_node = True ,
101+         )
77102
78103    @classmethod  
79-     def  INPUT_TYPES (s ):
80-         return  {"required" : {"filename_prefix" : ("STRING" , {"default" : "loras/ComfyUI_extracted_lora" }),
81-                               "rank" : ("INT" , {"default" : 8 , "min" : 1 , "max" : 4096 , "step" : 1 }),
82-                               "lora_type" : (tuple (LORA_TYPES .keys ()),),
83-                               "bias_diff" : ("BOOLEAN" , {"default" : True }),
84-                             },
85-                 "optional" : {"model_diff" : ("MODEL" , {"tooltip" : "The ModelSubtract output to be converted to a lora." }),
86-                              "text_encoder_diff" : ("CLIP" , {"tooltip" : "The CLIPSubtract output to be converted to a lora." })},
87-     }
88-     RETURN_TYPES  =  ()
89-     FUNCTION  =  "save" 
90-     OUTPUT_NODE  =  True 
91- 
92-     CATEGORY  =  "_for_testing" 
93- 
94-     def  save (self , filename_prefix , rank , lora_type , bias_diff , model_diff = None , text_encoder_diff = None ):
104+     def  execute (cls , filename_prefix , rank , lora_type , bias_diff , model_diff = None , text_encoder_diff = None ) ->  io .NodeOutput :
95105        if  model_diff  is  None  and  text_encoder_diff  is  None :
96-             return  {} 
106+             return  io . NodeOutput () 
97107
98108        lora_type  =  LORA_TYPES .get (lora_type )
99-         full_output_folder , filename , counter , subfolder , filename_prefix  =  folder_paths .get_save_image_path (filename_prefix , self . output_dir )
109+         full_output_folder , filename , counter , subfolder , filename_prefix  =  folder_paths .get_save_image_path (filename_prefix , folder_paths . get_output_directory () )
100110
101111        output_sd  =  {}
102112        if  model_diff  is  not   None :
@@ -108,12 +118,16 @@ def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, tex
108118        output_checkpoint  =  os .path .join (full_output_folder , output_checkpoint )
109119
110120        comfy .utils .save_torch_file (output_sd , output_checkpoint , metadata = None )
111-         return  {}
121+         return  io .NodeOutput ()
122+ 
123+ 
124+ class  LoraSaveExtension (ComfyExtension ):
125+     @override  
126+     async  def  get_node_list (self ) ->  list [type [io .ComfyNode ]]:
127+         return  [
128+             LoraSave ,
129+         ]
112130
113- NODE_CLASS_MAPPINGS  =  {
114-     "LoraSave" : LoraSave 
115- }
116131
117- NODE_DISPLAY_NAME_MAPPINGS  =  {
118-     "LoraSave" : "Extract and Save Lora" 
119- }
132+ async  def  comfy_entrypoint () ->  LoraSaveExtension :
133+     return  LoraSaveExtension ()
0 commit comments