 import os
 import logging
 from enum import Enum
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
 
 CLAMP_QUANTILE = 0.99
 
@@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
             output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
     return output_sd
 
-class LoraSave:
-    def __init__(self):
-        self.output_dir = folder_paths.get_output_directory()
+class LoraSave(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LoraSave",
+            display_name="Extract and Save Lora",
+            category="_for_testing",
+            inputs=[
+                io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
+                io.Int.Input("rank", default=8, min=1, max=4096, step=1),
+                io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
+                io.Boolean.Input("bias_diff", default=True),
+                io.Model.Input(
+                    "model_diff",
+                    tooltip="The ModelSubtract output to be converted to a lora.",
+                    optional=True,
+                ),
+                io.Clip.Input(
+                    "text_encoder_diff",
+                    tooltip="The CLIPSubtract output to be converted to a lora.",
+                    optional=True,
+                ),
+            ],
+            is_experimental=True,
+            is_output_node=True,
+        )
 
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
-                             "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
-                             "lora_type": (tuple(LORA_TYPES.keys()),),
-                             "bias_diff": ("BOOLEAN", {"default": True}),
-                            },
-                "optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
-                             "text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
-               }
-    RETURN_TYPES = ()
-    FUNCTION = "save"
-    OUTPUT_NODE = True
-
-    CATEGORY = "_for_testing"
-
-    def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
+    def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
         if model_diff is None and text_encoder_diff is None:
-            return {}
+            return io.NodeOutput()
 
         lora_type = LORA_TYPES.get(lora_type)
-        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
 
         output_sd = {}
         if model_diff is not None:
@@ -108,12 +118,16 @@ def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, tex
         output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
 
         comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
-        return {}
+        return io.NodeOutput()
+
+
+class LoraSaveExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            LoraSave,
+        ]
 
-NODE_CLASS_MAPPINGS = {
-    "LoraSave": LoraSave
-}
 
-NODE_DISPLAY_NAME_MAPPINGS = {
-    "LoraSave": "Extract and Save Lora"
-}
+async def comfy_entrypoint() -> LoraSaveExtension:
+    return LoraSaveExtension()
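
Note on the extraction logic the diff leaves untouched: the first hunk only shows the tail of calc_lora_model, and the CLAMP_QUANTILE = 0.99 constant kept at the top of the file is consumed by the SVD-based weight extraction elsewhere in this file. The snippet below is a minimal, hypothetical sketch of that technique for plain 2-D weights (extract_lora_sketch and its shapes are illustrative, not the file's actual helper): take a truncated rank-r SVD of the weight difference and clamp the resulting factors at the CLAMP_QUANTILE quantile of their values.

import torch

CLAMP_QUANTILE = 0.99

def extract_lora_sketch(weight_diff: torch.Tensor, rank: int):
    # Keep the requested rank within the matrix dimensions.
    out_dim, in_dim = weight_diff.shape
    rank = min(rank, in_dim, out_dim)

    # Rank-r factorization of the difference:
    # U @ diag(S) plays the role of lora_up, Vh the role of lora_down.
    U, S, Vh = torch.linalg.svd(weight_diff.float())
    up = U[:, :rank] @ torch.diag(S[:rank])
    down = Vh[:rank, :]

    # Clamp both factors to the CLAMP_QUANTILE quantile of their combined values,
    # a common trick to suppress rare outliers before saving in half precision.
    vals = torch.cat([up.flatten(), down.flatten()])
    hi = torch.quantile(vals, CLAMP_QUANTILE)
    return up.clamp(-hi, hi), down.clamp(-hi, hi)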