
Commit c61cfec

conditionner: make text encoders optional for Flux
1 parent e0d0edb commit c61cfec

2 files changed: +92, -50 lines

conditioner.hpp

Lines changed: 91 additions & 50 deletions
@@ -1107,31 +1107,52 @@ struct FluxCLIPEmbedder : public Conditioner {
     std::shared_ptr<T5Runner> t5;
     size_t chunk_len = 256;
 
+    bool use_clip_l = false;
+    bool use_t5 = false;
+
     FluxCLIPEmbedder(ggml_backend_t backend,
                      bool offload_params_to_cpu,
                      const String2GGMLType& tensor_types = {}) {
-        clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
-        t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer");
+        clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
+        t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
     }
 
+
     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
-        clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
-        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        if (use_clip_l) {
+            clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
+        }
+        if (use_t5) {
+            t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        }
     }
 
     void alloc_params_buffer() {
-        clip_l->alloc_params_buffer();
-        t5->alloc_params_buffer();
+        if (use_clip_l) {
+            clip_l->alloc_params_buffer();
+        }
+        if (use_t5) {
+            t5->alloc_params_buffer();
+        }
     }
 
     void free_params_buffer() {
-        clip_l->free_params_buffer();
-        t5->free_params_buffer();
+        if (use_clip_l) {
+            clip_l->free_params_buffer();
+        }
+        if (use_t5) {
+            t5->free_params_buffer();
+        }
     }
 
     size_t get_params_buffer_size() {
-        size_t buffer_size = clip_l->get_params_buffer_size();
-        buffer_size += t5->get_params_buffer_size();
+        size_t buffer_size = 0;
+        if (use_clip_l) {
+            buffer_size += clip_l->get_params_buffer_size();
+        }
+        if (use_t5) {
+            buffer_size += t5->get_params_buffer_size();
+        }
         return buffer_size;
     }

@@ -1161,18 +1182,23 @@ struct FluxCLIPEmbedder : public Conditioner {
         for (const auto& item : parsed_attention) {
             const std::string& curr_text = item.first;
             float curr_weight = item.second;
-
-            std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
-            clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
-
-            curr_tokens = t5_tokenizer.Encode(curr_text, true);
-            t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+            if (use_clip_l) {
+                std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
+                clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
+            }
+            if (use_t5) {
+                std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
+                t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+            }
+        }
+        if (use_clip_l) {
+            clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
+        }
+        if (use_t5) {
+            t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
         }
-
-        clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
 
         // for (int i = 0; i < clip_l_tokens.size(); i++) {
         //     std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -1207,35 +1233,37 @@ struct FluxCLIPEmbedder : public Conditioner {
         struct ggml_tensor* pooled = NULL; // [768,]
         std::vector<float> hidden_states_vec;
 
-        size_t chunk_count = t5_tokens.size() / chunk_len;
+        size_t chunk_count = std::max(clip_l_tokens.size() > 0 ? chunk_len : 0, t5_tokens.size()) / chunk_len;
         for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
             // clip_l
             if (chunk_idx == 0) {
-                size_t chunk_len_l = 77;
-                std::vector<int> chunk_tokens(clip_l_tokens.begin(),
-                                              clip_l_tokens.begin() + chunk_len_l);
-                std::vector<float> chunk_weights(clip_l_weights.begin(),
-                                                 clip_l_weights.begin() + chunk_len_l);
+                if (use_clip_l) {
+                    size_t chunk_len_l = 77;
+                    std::vector<int> chunk_tokens(clip_l_tokens.begin(),
+                                                  clip_l_tokens.begin() + chunk_len_l);
+                    std::vector<float> chunk_weights(clip_l_weights.begin(),
+                                                     clip_l_weights.begin() + chunk_len_l);
 
-                auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
-                size_t max_token_idx = 0;
+                    auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
+                    size_t max_token_idx = 0;
 
-                auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
-                max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+                    auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
+                    max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
 
-                clip_l->compute(n_threads,
-                                input_ids,
-                                0,
-                                NULL,
-                                max_token_idx,
-                                true,
-                                clip_skip,
-                                &pooled,
-                                work_ctx);
+                    clip_l->compute(n_threads,
+                                    input_ids,
+                                    0,
+                                    NULL,
+                                    max_token_idx,
+                                    true,
+                                    clip_skip,
+                                    &pooled,
+                                    work_ctx);
+                }
             }
 
             // t5
-            {
+            if (use_t5) {
                 std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
                                               t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
                 std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
@@ -1263,8 +1291,12 @@ struct FluxCLIPEmbedder : public Conditioner {
                     float new_mean = ggml_tensor_mean(tensor);
                     ggml_tensor_scale(tensor, (original_mean / new_mean));
                 }
+            } else {
+                chunk_hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
+                ggml_set_f32(chunk_hidden_states, 0.f);
             }
 
+
             int64_t t1 = ggml_time_ms();
             LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
             if (zero_out_masked) {
@@ -1273,17 +1305,26 @@ struct FluxCLIPEmbedder : public Conditioner {
                     vec[i] = 0;
                 }
             }
-
+
             hidden_states_vec.insert(hidden_states_vec.end(),
-                                     (float*)chunk_hidden_states->data,
-                                     ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
+                                     (float*)chunk_hidden_states->data,
+                                     ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
+        }
+
+        if (hidden_states_vec.size() > 0) {
+            hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
+            hidden_states = ggml_reshape_2d(work_ctx,
+                                            hidden_states,
+                                            chunk_hidden_states->ne[0],
+                                            ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+        } else {
+            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
+            ggml_set_f32(hidden_states, 0.f);
+        }
+        if (pooled == NULL) {
+            pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
+            ggml_set_f32(pooled, 0.f);
         }
-
-        hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
-        hidden_states = ggml_reshape_2d(work_ctx,
-                                        hidden_states,
-                                        chunk_hidden_states->ne[0],
-                                        ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
         return SDCondition(hidden_states, pooled, NULL);
     }
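
The behavioral crux of this file's change is the new chunk_count expression plus the zero-filled fallbacks: when T5 is not in use each chunk contributes a zero [4096 x chunk_len] hidden-state block, when no chunk ran at all a zero [4096 x 256] tensor is created after the loop, and when CLIP-L is not in use a zero [768] pooled vector is substituted, so Flux always receives conditioning tensors of the expected shape. Below is a small standalone sketch, not part of the commit, showing how the chunk_count expression evaluates; the token counts are invented for illustration.

// chunk_count_sketch.cpp -- illustration only, not code from this repository.
// Evaluates the expression introduced in conditioner.hpp:
//   std::max(clip_l_tokens.size() > 0 ? chunk_len : 0, t5_tokens.size()) / chunk_len
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    const size_t chunk_len = 256;

    // {clip_l token count, t5 token count} after padding (made-up values)
    std::vector<std::pair<size_t, size_t>> cases = {
        {77, 256},  // both encoders in use
        {77, 0},    // CLIP-L only: t5_tokens stays empty
        {0, 512},   // T5 only, longer prompt padded to two chunks
        {0, 0},     // neither encoder: the chunk loop never runs
    };

    for (const auto& c : cases) {
        size_t clip_l_size = c.first;
        size_t t5_size     = c.second;
        // An empty CLIP-L token list contributes 0; a non-empty one guarantees
        // at least one chunk, so the single CLIP-L pass (chunk_idx == 0) still runs.
        size_t chunk_count = std::max(clip_l_size > 0 ? chunk_len : 0, t5_size) / chunk_len;
        std::printf("clip_l=%zu t5=%zu -> chunk_count=%zu\n", clip_l_size, t5_size, chunk_count);
    }
    return 0;
}

With only CLIP-L loaded the loop still runs once, which is what produces the zero-filled T5 placeholder for hidden_states; with neither encoder the loop is skipped and the post-loop fallbacks create the zero [4096 x 256] hidden states and the zero [768] pooled vector.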

stable-diffusion.cpp

Lines changed: 1 addition & 0 deletions
@@ -326,6 +326,7 @@ class StableDiffusionGGML {
         clip_backend = backend;
         bool use_t5xxl = false;
         if (sd_version_is_dit(version)) {
+            // TODO: check if t5 is actually loaded?
             use_t5xxl = true;
         }
         if (!clip_on_cpu && !ggml_backend_is_cpu(backend) && use_t5xxl) {
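
The TODO added here flags the remaining gap: use_t5xxl is still assumed true for every DiT model instead of being derived from what the checkpoint actually contains. One possible direction, shown as a hypothetical sketch rather than code from this repository (the helper name and the tensor-name map are assumptions), is to scan the loaded tensor names for the prefixes that FluxCLIPEmbedder loads:

// Hypothetical helper -- illustration only. Assumes the caller has collected a
// map of tensor names (e.g. while reading the checkpoint) keyed by full name.
#include <map>
#include <string>

struct ggml_tensor;  // only pointers are stored here, so no definition is needed

// Returns true if any loaded tensor name starts with `prefix`.
static bool has_tensor_with_prefix(const std::map<std::string, ggml_tensor*>& tensors,
                                   const std::string& prefix) {
    for (const auto& kv : tensors) {
        if (kv.first.compare(0, prefix.size(), prefix) == 0) {  // starts-with check
            return true;
        }
    }
    return false;
}

// Usage sketch, mirroring the prefixes used in conditioner.hpp:
//   bool use_t5xxl  = has_tensor_with_prefix(loaded_tensors, "text_encoders.t5xxl.");
//   bool use_clip_l = has_tensor_with_prefix(loaded_tensors, "text_encoders.clip_l.");

Something along these lines would let the use_clip_l / use_t5 flags introduced in conditioner.hpp be set from the checkpoint contents instead of hard-coding use_t5xxl = true for DiT models.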
