Skip to content

Commit 19fbfd8

Browse files
authored
feat: override text encoders for unet models (leejet#682)
1 parent 76c7262 commit 19fbfd8

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

model.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,6 +1539,15 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
15391539
return true;
15401540
}
15411541

1542+
bool ModelLoader::model_is_unet() {
1543+
for (auto& tensor_storage : tensor_storages) {
1544+
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
1545+
return true;
1546+
}
1547+
}
1548+
return false;
1549+
}
1550+
15421551
SDVersion ModelLoader::get_sd_version() {
15431552
TensorStorage token_embedding_weight, input_block_weight;
15441553
bool input_block_checked = false;

model.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ class ModelLoader {
210210
std::map<std::string, enum ggml_type> tensor_storages_types;
211211

212212
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
213+
bool model_is_unet();
213214
SDVersion get_sd_version();
214215
ggml_type get_sd_wtype();
215216
ggml_type get_conditioner_wtype();

stable-diffusion.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -213,16 +213,25 @@ class StableDiffusionGGML {
213213
}
214214
}
215215

216+
if (diffusion_model_path.size() > 0) {
217+
LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str());
218+
if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) {
219+
LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str());
220+
}
221+
}
222+
223+
bool is_unet = model_loader.model_is_unet();
224+
216225
if (clip_l_path.size() > 0) {
217226
LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str());
218-
if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.transformer.")) {
227+
if (!model_loader.init_from_file(clip_l_path, is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer.")) {
219228
LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str());
220229
}
221230
}
222231

223232
if (clip_g_path.size() > 0) {
224233
LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str());
225-
if (!model_loader.init_from_file(clip_g_path, "text_encoders.clip_g.transformer.")) {
234+
if (!model_loader.init_from_file(clip_g_path, is_unet ? "cond_stage_model.1.transformer." : "text_encoders.clip_g.transformer.")) {
226235
LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str());
227236
}
228237
}
@@ -234,13 +243,6 @@ class StableDiffusionGGML {
234243
}
235244
}
236245

237-
if (diffusion_model_path.size() > 0) {
238-
LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str());
239-
if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) {
240-
LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str());
241-
}
242-
}
243-
244246
if (vae_path.size() > 0) {
245247
LOG_INFO("loading vae from '%s'", vae_path.c_str());
246248
if (!model_loader.init_from_file(vae_path, "vae.")) {

0 commit comments

Comments
 (0)