@@ -1107,31 +1107,52 @@ struct FluxCLIPEmbedder : public Conditioner {
11071107 std::shared_ptr<T5Runner> t5;
11081108 size_t chunk_len = 256 ;
11091109
// When false, the parameter-buffer and tensor-registration methods of this
// struct skip the clip_l runner entirely.
// NOTE(review): nothing in the visible code sets this to true - confirm where
// it is enabled.
bool use_clip_l{false};
// Same switch for the t5 runner.
bool use_t5{false};
1112+
11101113 FluxCLIPEmbedder (ggml_backend_t backend,
11111114 bool offload_params_to_cpu,
11121115 const String2GGMLType& tensor_types = {}) {
1113- clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, " text_encoders.clip_l.transformer.text_model" , OPENAI_CLIP_VIT_L_14, true );
1114- t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, " text_encoders.t5xxl.transformer" );
1116+ clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " text_encoders.clip_l.transformer.text_model" , OPENAI_CLIP_VIT_L_14, true );
1117+ t5 = std::make_shared<T5Runner>(backend, tensor_types, " text_encoders.t5xxl.transformer" );
11151118 }
11161119
1120+
11171121 void get_param_tensors (std::map<std::string, struct ggml_tensor *>& tensors) {
1118- clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
1119- t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
1122+ if (use_clip_l) {
1123+ clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
1124+ }
1125+ if (use_t5) {
1126+ t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
1127+ }
11201128 }
11211129
11221130 void alloc_params_buffer () {
1123- clip_l->alloc_params_buffer ();
1124- t5->alloc_params_buffer ();
1131+ if (use_clip_l) {
1132+ clip_l->alloc_params_buffer ();
1133+ }
1134+ if (use_t5) {
1135+ t5->alloc_params_buffer ();
1136+ }
11251137 }
11261138
11271139 void free_params_buffer () {
1128- clip_l->free_params_buffer ();
1129- t5->free_params_buffer ();
1140+ if (use_clip_l) {
1141+ clip_l->free_params_buffer ();
1142+ }
1143+ if (use_t5) {
1144+ t5->free_params_buffer ();
1145+ }
11301146 }
11311147
11321148 size_t get_params_buffer_size () {
1133- size_t buffer_size = clip_l->get_params_buffer_size ();
1134- buffer_size += t5->get_params_buffer_size ();
1149+ size_t buffer_size = 0 ;
1150+ if (use_clip_l) {
1151+ buffer_size += clip_l->get_params_buffer_size ();
1152+ }
1153+ if (use_t5) {
1154+ buffer_size += t5->get_params_buffer_size ();
1155+ }
11351156 return buffer_size;
11361157 }
11371158
@@ -1161,18 +1182,23 @@ struct FluxCLIPEmbedder : public Conditioner {
11611182 for (const auto & item : parsed_attention) {
11621183 const std::string& curr_text = item.first ;
11631184 float curr_weight = item.second ;
1164-
1165- std::vector<int > curr_tokens = clip_l_tokenizer.encode (curr_text, on_new_token_cb);
1166- clip_l_tokens.insert (clip_l_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1167- clip_l_weights.insert (clip_l_weights.end (), curr_tokens.size (), curr_weight);
1168-
1169- curr_tokens = t5_tokenizer.Encode (curr_text, true );
1170- t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1171- t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
1185+ if (use_clip_l) {
1186+ std::vector<int > curr_tokens = clip_l_tokenizer.encode (curr_text, on_new_token_cb);
1187+ clip_l_tokens.insert (clip_l_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1188+ clip_l_weights.insert (clip_l_weights.end (), curr_tokens.size (), curr_weight);
1189+ }
1190+ if (use_t5) {
1191+ std::vector<int > curr_tokens = t5_tokenizer.Encode (curr_text, true );
1192+ t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1193+ t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
1194+ }
1195+ }
1196+ if (use_clip_l) {
1197+ clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, 77 , padding);
1198+ }
1199+ if (use_t5) {
1200+ t5_tokenizer.pad_tokens (t5_tokens, t5_weights, NULL , max_length, padding);
11721201 }
1173-
1174- clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, 77 , padding);
1175- t5_tokenizer.pad_tokens (t5_tokens, t5_weights, NULL , max_length, padding);
11761202
11771203 // for (int i = 0; i < clip_l_tokens.size(); i++) {
11781204 // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -1207,35 +1233,37 @@ struct FluxCLIPEmbedder : public Conditioner {
12071233 struct ggml_tensor * pooled = NULL ; // [768,]
12081234 std::vector<float > hidden_states_vec;
12091235
1210- size_t chunk_count = t5_tokens.size () / chunk_len;
1236+ size_t chunk_count = std::max (clip_l_tokens. size () > 0 ? chunk_len : 0 , t5_tokens.size () ) / chunk_len;
12111237 for (int chunk_idx = 0 ; chunk_idx < chunk_count; chunk_idx++) {
12121238 // clip_l
12131239 if (chunk_idx == 0 ) {
1214- size_t chunk_len_l = 77 ;
1215- std::vector<int > chunk_tokens (clip_l_tokens.begin (),
1216- clip_l_tokens.begin () + chunk_len_l);
1217- std::vector<float > chunk_weights (clip_l_weights.begin (),
1218- clip_l_weights.begin () + chunk_len_l);
1240+ if (use_clip_l) {
1241+ size_t chunk_len_l = 77 ;
1242+ std::vector<int > chunk_tokens (clip_l_tokens.begin (),
1243+ clip_l_tokens.begin () + chunk_len_l);
1244+ std::vector<float > chunk_weights (clip_l_weights.begin (),
1245+ clip_l_weights.begin () + chunk_len_l);
12191246
1220- auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
1221- size_t max_token_idx = 0 ;
1247+ auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
1248+ size_t max_token_idx = 0 ;
12221249
1223- auto it = std::find (chunk_tokens.begin (), chunk_tokens.end (), clip_l_tokenizer.EOS_TOKEN_ID );
1224- max_token_idx = std::min<size_t >(std::distance (chunk_tokens.begin (), it), chunk_tokens.size () - 1 );
1250+ auto it = std::find (chunk_tokens.begin (), chunk_tokens.end (), clip_l_tokenizer.EOS_TOKEN_ID );
1251+ max_token_idx = std::min<size_t >(std::distance (chunk_tokens.begin (), it), chunk_tokens.size () - 1 );
12251252
1226- clip_l->compute (n_threads,
1227- input_ids,
1228- 0 ,
1229- NULL ,
1230- max_token_idx,
1231- true ,
1232- clip_skip,
1233- &pooled,
1234- work_ctx);
1253+ clip_l->compute (n_threads,
1254+ input_ids,
1255+ 0 ,
1256+ NULL ,
1257+ max_token_idx,
1258+ true ,
1259+ clip_skip,
1260+ &pooled,
1261+ work_ctx);
1262+ }
12351263 }
12361264
12371265 // t5
1238- {
1266+ if (use_t5) {
12391267 std::vector<int > chunk_tokens (t5_tokens.begin () + chunk_idx * chunk_len,
12401268 t5_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
12411269 std::vector<float > chunk_weights (t5_weights.begin () + chunk_idx * chunk_len,
@@ -1263,8 +1291,12 @@ struct FluxCLIPEmbedder : public Conditioner {
12631291 float new_mean = ggml_tensor_mean (tensor);
12641292 ggml_tensor_scale (tensor, (original_mean / new_mean));
12651293 }
1294+ } else {
1295+ chunk_hidden_states = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , chunk_len);
1296+ ggml_set_f32 (chunk_hidden_states, 0 .f );
12661297 }
12671298
1299+
12681300 int64_t t1 = ggml_time_ms ();
12691301 LOG_DEBUG (" computing condition graph completed, taking %" PRId64 " ms" , t1 - t0);
12701302 if (zero_out_masked) {
@@ -1273,17 +1305,26 @@ struct FluxCLIPEmbedder : public Conditioner {
12731305 vec[i] = 0 ;
12741306 }
12751307 }
1276-
1308+
12771309 hidden_states_vec.insert (hidden_states_vec.end (),
1278- (float *)chunk_hidden_states->data ,
1279- ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
1310+ (float *)chunk_hidden_states->data ,
1311+ ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
1312+ }
1313+
1314+ if (hidden_states_vec.size () > 0 ) {
1315+ hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
1316+ hidden_states = ggml_reshape_2d (work_ctx,
1317+ hidden_states,
1318+ chunk_hidden_states->ne [0 ],
1319+ ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
1320+ } else {
1321+ hidden_states = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , 256 );
1322+ ggml_set_f32 (hidden_states, 0 .f );
1323+ }
1324+ if (pooled == NULL ) {
1325+ pooled = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 768 );
1326+ ggml_set_f32 (pooled, 0 .f );
12801327 }
1281-
1282- hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
1283- hidden_states = ggml_reshape_2d (work_ctx,
1284- hidden_states,
1285- chunk_hidden_states->ne [0 ],
1286- ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
12871328 return SDCondition (hidden_states, pooled, NULL );
12881329 }
12891330
0 commit comments