Skip to content

Commit 0adfe17

Browse files
authored
convert nodes_lora_extract.py to V3 schema (comfyanonymous#10182)
1 parent b879807 commit 0adfe17

File tree

1 file changed

+42
-28
lines changed

1 file changed

+42
-28
lines changed

comfy_extras/nodes_lora_extract.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import os
66
import logging
77
from enum import Enum
8+
from typing_extensions import override
9+
from comfy_api.latest import ComfyExtension, io
810

911
CLAMP_QUANTILE = 0.99
1012

@@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
7173
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
7274
return output_sd
7375

74-
class LoraSave:
75-
def __init__(self):
76-
self.output_dir = folder_paths.get_output_directory()
76+
class LoraSave(io.ComfyNode):
77+
@classmethod
78+
def define_schema(cls):
79+
return io.Schema(
80+
node_id="LoraSave",
81+
display_name="Extract and Save Lora",
82+
category="_for_testing",
83+
inputs=[
84+
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
85+
io.Int.Input("rank", default=8, min=1, max=4096, step=1),
86+
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
87+
io.Boolean.Input("bias_diff", default=True),
88+
io.Model.Input(
89+
"model_diff",
90+
tooltip="The ModelSubtract output to be converted to a lora.",
91+
optional=True,
92+
),
93+
io.Clip.Input(
94+
"text_encoder_diff",
95+
tooltip="The CLIPSubtract output to be converted to a lora.",
96+
optional=True,
97+
),
98+
],
99+
is_experimental=True,
100+
is_output_node=True,
101+
)
77102

78103
@classmethod
79-
def INPUT_TYPES(s):
80-
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
81-
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
82-
"lora_type": (tuple(LORA_TYPES.keys()),),
83-
"bias_diff": ("BOOLEAN", {"default": True}),
84-
},
85-
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
86-
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
87-
}
88-
RETURN_TYPES = ()
89-
FUNCTION = "save"
90-
OUTPUT_NODE = True
91-
92-
CATEGORY = "_for_testing"
93-
94-
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
104+
def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
95105
if model_diff is None and text_encoder_diff is None:
96-
return {}
106+
return io.NodeOutput()
97107

98108
lora_type = LORA_TYPES.get(lora_type)
99-
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
109+
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
100110

101111
output_sd = {}
102112
if model_diff is not None:
@@ -108,12 +118,16 @@ def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, tex
108118
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
109119

110120
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
111-
return {}
121+
return io.NodeOutput()
122+
123+
124+
class LoraSaveExtension(ComfyExtension):
125+
@override
126+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
127+
return [
128+
LoraSave,
129+
]
112130

113-
NODE_CLASS_MAPPINGS = {
114-
"LoraSave": LoraSave
115-
}
116131

117-
NODE_DISPLAY_NAME_MAPPINGS = {
118-
"LoraSave": "Extract and Save Lora"
119-
}
132+
async def comfy_entrypoint() -> LoraSaveExtension:
133+
return LoraSaveExtension()

0 commit comments

Comments
 (0)