@@ -338,6 +338,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
338338 {" to_v" , " v" },
339339 {" to_out_0" , " proj_out" },
340340 {" group_norm" , " norm" },
341+ {" key" , " k" },
342+ {" query" , " q" },
343+ {" value" , " v" },
344+ {" proj_attn" , " proj_out" },
341345 },
342346 },
343347 {
@@ -362,6 +366,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
362366 {" to_v" , " v" },
363367 {" to_out.0" , " proj_out" },
364368 {" group_norm" , " norm" },
369+ {" key" , " k" },
370+ {" query" , " q" },
371+ {" value" , " v" },
372+ {" proj_attn" , " proj_out" },
365373 },
366374 },
367375 {
@@ -433,6 +441,10 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
433441 return format (" model%cdiffusion_model%ctime_embed%c" , seq, seq, seq) + std::to_string (std::stoi (m[0 ]) * 2 - 2 ) + m[1 ];
434442 }
435443
444+ if (match (m, std::regex (format (" unet%cadd_embedding%clinear_(\\ d+)(.*)" , seq, seq)), key)) {
445+ return format (" model%cdiffusion_model%clabel_emb%c0%c" , seq, seq, seq, seq) + std::to_string (std::stoi (m[0 ]) * 2 - 2 ) + m[1 ];
446+ }
447+
436448 if (match (m, std::regex (format (" unet%cdown_blocks%c(\\ d+)%c(attentions|resnets)%c(\\ d+)%c(.+)" , seq, seq, seq, seq, seq)), key)) {
437449 std::string suffix = get_converted_suffix (m[1 ], m[3 ]);
438450 // LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str());
@@ -470,6 +482,19 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
470482 return format (" cond_stage_model%ctransformer%ctext_model" , seq, seq) + m[0 ];
471483 }
472484
485+ // clip-g
486+ if (match (m, std::regex (format (" te%c1%ctext_model%cencoder%clayers%c(\\ d+)%c(.+)" , seq, seq, seq, seq, seq, seq)), key)) {
487+ return format (" cond_stage_model%c1%ctransformer%ctext_model%cencoder%clayers%c" , seq, seq, seq, seq, seq, seq) + m[0 ] + seq + m[1 ];
488+ }
489+
490+ if (match (m, std::regex (format (" te%c1%ctext_model(.*)" , seq, seq)), key)) {
491+ return format (" cond_stage_model%c1%ctransformer%ctext_model" , seq, seq, seq) + m[0 ];
492+ }
493+
494+ if (match (m, std::regex (format (" te%c1%ctext_projection" , seq, seq)), key)) {
495+ return format (" cond_stage_model%c1%ctransformer%ctext_model%ctext_projection" , seq, seq, seq, seq);
496+ }
497+
473498 // vae
474499 if (match (m, std::regex (format (" vae%c(.*)%cconv_norm_out(.*)" , seq, seq)), key)) {
475500 return format (" first_stage_model%c%s%cnorm_out%s" , seq, m[0 ].c_str (), seq, m[1 ].c_str ());
@@ -606,6 +631,8 @@ std::string convert_tensor_name(std::string name) {
606631 std::string new_key = convert_diffusers_name_to_compvis (name_without_network_parts, ' .' );
607632 if (new_key.empty ()) {
608633 new_name = name;
634+ } else if (new_key == " cond_stage_model.1.transformer.text_model.text_projection" ) {
635+ new_name = new_key;
609636 } else {
610637 new_name = new_key + " ." + network_part;
611638 }
@@ -1029,10 +1056,14 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
10291056 ttype = GGML_TYPE_F32;
10301057 } else if (dtype == " F32" ) {
10311058 ttype = GGML_TYPE_F32;
1059+ } else if (dtype == " F64" ) {
1060+ ttype = GGML_TYPE_F64;
10321061 } else if (dtype == " F8_E4M3" ) {
10331062 ttype = GGML_TYPE_F16;
10341063 } else if (dtype == " F8_E5M2" ) {
10351064 ttype = GGML_TYPE_F16;
1065+ } else if (dtype == " I64" ) {
1066+ ttype = GGML_TYPE_I64;
10361067 }
10371068 return ttype;
10381069}
@@ -1045,6 +1076,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10451076 std::ifstream file (file_path, std::ios::binary);
10461077 if (!file.is_open ()) {
10471078 LOG_ERROR (" failed to open '%s'" , file_path.c_str ());
1079+ file_paths_.pop_back ();
10481080 return false ;
10491081 }
10501082
@@ -1056,6 +1088,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10561088 // read header size
10571089 if (file_size_ <= ST_HEADER_SIZE_LEN) {
10581090 LOG_ERROR (" invalid safetensor file '%s'" , file_path.c_str ());
1091+ file_paths_.pop_back ();
10591092 return false ;
10601093 }
10611094
@@ -1069,6 +1102,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10691102 size_t header_size_ = read_u64 (header_size_buf);
10701103 if (header_size_ >= file_size_) {
10711104 LOG_ERROR (" invalid safetensor file '%s'" , file_path.c_str ());
1105+ file_paths_.pop_back ();
10721106 return false ;
10731107 }
10741108
@@ -1079,6 +1113,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10791113 file.read (header_buf.data (), header_size_);
10801114 if (!file) {
10811115 LOG_ERROR (" read safetensors header failed: '%s'" , file_path.c_str ());
1116+ file_paths_.pop_back ();
10821117 return false ;
10831118 }
10841119
@@ -1134,6 +1169,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
11341169 n_dims = 1 ;
11351170 }
11361171
1172+
11371173 TensorStorage tensor_storage (prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
11381174 tensor_storage.reverse_ne ();
11391175
@@ -1166,18 +1202,45 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
11661202/* ================================================= DiffusersModelLoader ==================================================*/
11671203
11681204bool ModelLoader::init_from_diffusers_file (const std::string& file_path, const std::string& prefix) {
1169- std::string unet_path = path_join (file_path, " unet/diffusion_pytorch_model.safetensors" );
1170- std::string vae_path = path_join (file_path, " vae/diffusion_pytorch_model.safetensors" );
1171- std::string clip_path = path_join (file_path, " text_encoder/model.safetensors" );
1205+ std::string unet_path = path_join (file_path, " unet/diffusion_pytorch_model.safetensors" );
1206+ std::string vae_path = path_join (file_path, " vae/diffusion_pytorch_model.safetensors" );
1207+ std::string clip_path = path_join (file_path, " text_encoder/model.safetensors" );
1208+ std::string clip_g_path = path_join (file_path, " text_encoder_2/model.safetensors" );
11721209
11731210 if (!init_from_safetensors_file (unet_path, " unet." )) {
11741211 return false ;
11751212 }
1213+ for (auto ts : tensor_storages) {
1214+ if (ts.name .find (" add_embedding" ) != std::string::npos || ts.name .find (" label_emb" ) != std::string::npos) {
1215+ // probably SDXL
1216+ LOG_DEBUG (" Fixing name for SDXL output blocks.2.2" );
1217+ for (auto & tensor_storage : tensor_storages) {
1218+ int len = 34 ;
1219+ auto pos = tensor_storage.name .find (" unet.up_blocks.0.upsamplers.0.conv" );
1220+ if (pos == std::string::npos) {
1221+ len = 44 ;
1222+ pos = tensor_storage.name .find (" model.diffusion_model.output_blocks.2.1.conv" );
1223+ }
1224+ if (pos != std::string::npos) {
1225+ tensor_storage.name = " model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name .substr (len);
1226+ LOG_DEBUG (" NEW NAME: %s" , tensor_storage.name .c_str ());
1227+ add_preprocess_tensor_storage_types (tensor_storages_types, tensor_storage.name , tensor_storage.type );
1228+ }
1229+ }
1230+ break ;
1231+ }
1232+ }
1233+
11761234 if (!init_from_safetensors_file (vae_path, " vae." )) {
1177- return false ;
1235+ LOG_WARN (" Couldn't find working VAE in %s" , file_path.c_str ());
1236+ // return false;
11781237 }
11791238 if (!init_from_safetensors_file (clip_path, " te." )) {
1180- return false ;
1239+ LOG_WARN (" Couldn't find working text encoder in %s" , file_path.c_str ());
1240+ // return false;
1241+ }
1242+ if (!init_from_safetensors_file (clip_g_path, " te.1." )) {
1243+ LOG_DEBUG (" Couldn't find working second text encoder in %s" , file_path.c_str ());
11811244 }
11821245 return true ;
11831246}
@@ -1571,7 +1634,7 @@ SDVersion ModelLoader::get_sd_version() {
15711634 if (tensor_storage.name .find (" model.diffusion_model.joint_blocks." ) != std::string::npos) {
15721635 return VERSION_SD3;
15731636 }
1574- if (tensor_storage.name .find (" model.diffusion_model.input_blocks." ) != std::string::npos) {
1637+ if (tensor_storage.name .find (" model.diffusion_model.input_blocks." ) != std::string::npos || tensor_storage. name . find ( " unet.down_blocks. " ) != std::string::npos ) {
15751638 is_unet = true ;
15761639 if (has_multiple_encoders) {
15771640 is_xl = true ;
@@ -1580,7 +1643,7 @@ SDVersion ModelLoader::get_sd_version() {
15801643 }
15811644 }
15821645 }
1583- if (tensor_storage.name .find (" conditioner.embedders.1" ) != std::string::npos || tensor_storage.name .find (" cond_stage_model.1" ) != std::string::npos) {
1646+ if (tensor_storage.name .find (" conditioner.embedders.1" ) != std::string::npos || tensor_storage.name .find (" cond_stage_model.1" ) != std::string::npos || tensor_storage. name . find ( " te.1 " ) != std::string::npos ) {
15841647 has_multiple_encoders = true ;
15851648 if (is_unet) {
15861649 is_xl = true ;
@@ -1602,7 +1665,7 @@ SDVersion ModelLoader::get_sd_version() {
16021665 token_embedding_weight = tensor_storage;
16031666 // break;
16041667 }
1605- if (tensor_storage.name == " model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == " model.diffusion_model.img_in.weight" ) {
1668+ if (tensor_storage.name == " model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == " model.diffusion_model.img_in.weight" || tensor_storage. name == " unet.conv_in.weight " ) {
16061669 input_block_weight = tensor_storage;
16071670 input_block_checked = true ;
16081671 if (found_family) {
@@ -1687,7 +1750,7 @@ ggml_type ModelLoader::get_diffusion_model_wtype() {
16871750 continue ;
16881751 }
16891752
1690- if (tensor_storage.name .find (" model.diffusion_model." ) == std::string::npos) {
1753+ if (tensor_storage.name .find (" model.diffusion_model." ) == std::string::npos && tensor_storage. name . find ( " unet. " ) == std::string::npos ) {
16911754 continue ;
16921755 }
16931756
0 commit comments