
Commit 46623a7

conditioner: make t5 optional for chroma

Chroma's PixArtCLIPEmbedder previously constructed its T5Runner unconditionally. It now scans the loaded tensor names for text_encoders.t5xxl and, when no T5 weights are present, skips tokenization and encoding, producing zero hidden states and a fully masked attention mask instead.
1 parent 51fcd62

1 file changed: conditioner.hpp (+76 additions, -50 deletions)
@@ -1109,7 +1109,6 @@ struct FluxCLIPEmbedder : public Conditioner {
     FluxCLIPEmbedder(ggml_backend_t backend,
                      std::map<std::string, enum ggml_type>& tensor_types,
                      int clip_skip = -1) {
-
         for (auto pair : tensor_types) {
             if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
                 use_clip_l = true;

@@ -1319,7 +1318,6 @@ struct FluxCLIPEmbedder : public Conditioner {
                 ggml_set_f32(chunk_hidden_states, 0.f);
             }
 
-
             int64_t t1 = ggml_time_ms();
             LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
             if (force_zero_embeddings) {

@@ -1328,12 +1326,12 @@ struct FluxCLIPEmbedder : public Conditioner {
                     vec[i] = 0;
                 }
             }
-
+
             hidden_states_vec.insert(hidden_states_vec.end(),
-                                     (float*)chunk_hidden_states->data,
-                                     ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
+                                     (float*)chunk_hidden_states->data,
+                                     ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
         }
-
+
         if (hidden_states_vec.size() > 0) {
             hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
             hidden_states = ggml_reshape_2d(work_ctx,

@@ -1388,35 +1386,54 @@ struct PixArtCLIPEmbedder : public Conditioner {
     bool use_mask = false;
     int mask_pad = 1;
 
+    bool use_t5 = false;
+
     PixArtCLIPEmbedder(ggml_backend_t backend,
                        std::map<std::string, enum ggml_type>& tensor_types,
                        int clip_skip = -1,
                        bool use_mask = false,
                        int mask_pad = 1)
         : use_mask(use_mask), mask_pad(mask_pad) {
-        t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+        for (auto pair : tensor_types) {
+            if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
+                use_t5 = true;
+            }
+        }
+
+        if (!use_t5) {
+            LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
+            return;
+        } else {
+            t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
+        }
     }
 
     void set_clip_skip(int clip_skip) {
     }
 
     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
-        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        if (use_t5) {
+            t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        }
     }
 
     void alloc_params_buffer() {
-        t5->alloc_params_buffer();
+        if (use_t5) {
+            t5->alloc_params_buffer();
+        }
     }
 
     void free_params_buffer() {
-        t5->free_params_buffer();
+        if (use_t5) {
+            t5->free_params_buffer();
+        }
     }
 
     size_t get_params_buffer_size() {
         size_t buffer_size = 0;
-
-        buffer_size += t5->get_params_buffer_size();
-
+        if (use_t5) {
+            buffer_size += t5->get_params_buffer_size();
+        }
         return buffer_size;
     }
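
Note on the detection loop above: T5 availability is inferred from the names of the tensors that were actually loaded, not from a user-facing flag. A minimal self-contained sketch of the same pattern (the helper name is hypothetical, and a plain int stands in for enum ggml_type so it compiles on its own):

    #include <map>
    #include <string>

    // Hypothetical helper mirroring the constructor's scan over tensor_types:
    // true if any loaded tensor name contains `needle`.
    static bool has_tensor_matching(const std::map<std::string, int>& tensor_types,
                                    const std::string& needle) {
        for (const auto& pair : tensor_types) {
            if (pair.first.find(needle) != std::string::npos) {
                return true;  // the checkpoint shipped weights for this encoder
            }
        }
        return false;
    }

    // e.g.: use_t5 = has_tensor_matching(tensor_types, "text_encoders.t5xxl");

Since t5 stays null when nothing matches, every member function that would dereference it is guarded with use_t5, which is exactly what the remaining hunks do.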

@@ -1442,17 +1459,18 @@ struct PixArtCLIPEmbedder : public Conditioner {
         std::vector<int> t5_tokens;
         std::vector<float> t5_weights;
         std::vector<float> t5_mask;
-        for (const auto& item : parsed_attention) {
-            const std::string& curr_text = item.first;
-            float curr_weight = item.second;
-
-            std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
-            t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
-        }
+        if (use_t5) {
+            for (const auto& item : parsed_attention) {
+                const std::string& curr_text = item.first;
+                float curr_weight = item.second;
 
-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
+                std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
+                t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+            }
 
+            t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
+        }
         return {t5_tokens, t5_weights, t5_mask};
     }
 
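With use_t5 == false, tokenize returns three empty vectors. Downstream, the chunked encoding loop then runs zero times, because the chunk count comes from integer division. A tiny illustration (a chunk_len of 256 is assumed here, matching the hard-coded value the final hunk replaces):

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> t5_tokens;  // empty: nothing was tokenized
        const size_t chunk_len = 256;
        size_t chunk_count = t5_tokens.size() / chunk_len;  // 0 / 256 == 0
        printf("chunk_count = %zu\n", chunk_count);         // prints 0
        return 0;  // so the per-chunk t5->compute() path is never reached
    }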

@@ -1489,38 +1507,44 @@ struct PixArtCLIPEmbedder : public Conditioner {
         std::vector<float> hidden_states_vec;
 
         size_t chunk_count = t5_tokens.size() / chunk_len;
-
         for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
             // t5
-            std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
-                                          t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
-            std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
-                                             t5_weights.begin() + (chunk_idx + 1) * chunk_len);
-            std::vector<float> chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len,
-                                          t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
-
-            auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
-            auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL;
-
-            t5->compute(n_threads,
-                        input_ids,
-                        t5_attn_mask_chunk,
-                        &chunk_hidden_states,
-                        work_ctx);
-            {
-                auto tensor = chunk_hidden_states;
-                float original_mean = ggml_tensor_mean(tensor);
-                for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
-                    for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
-                        for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
-                            float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
-                            value *= chunk_weights[i1];
-                            ggml_tensor_set_f32(tensor, value, i0, i1, i2);
+
+            if (use_t5) {
+                std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
+                                              t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
+                std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
+                                                 t5_weights.begin() + (chunk_idx + 1) * chunk_len);
+                std::vector<float> chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len,
+                                              t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
+
+                auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
+                auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL;
+                t5->compute(n_threads,
+                            input_ids,
+                            t5_attn_mask_chunk,
+                            &chunk_hidden_states,
+                            work_ctx);
+                {
+                    auto tensor = chunk_hidden_states;
+                    float original_mean = ggml_tensor_mean(tensor);
+                    for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+                        for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                            for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                                float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
+                                value *= chunk_weights[i1];
+                                ggml_tensor_set_f32(tensor, value, i0, i1, i2);
+                            }
                         }
                     }
+                    float new_mean = ggml_tensor_mean(tensor);
+                    ggml_tensor_scale(tensor, (original_mean / new_mean));
                 }
-                float new_mean = ggml_tensor_mean(tensor);
-                ggml_tensor_scale(tensor, (original_mean / new_mean));
+            } else {
+                chunk_hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
+                ggml_set_f32(chunk_hidden_states, 0.f);
+                t5_attn_mask = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, chunk_len);
+                ggml_set_f32(t5_attn_mask, -HUGE_VALF);
             }
 
             int64_t t1 = ggml_time_ms();
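
The else branch above substitutes neutral conditioning: zero hidden states at T5-XXL's hidden size of 4096, plus an additive attention mask filled with -HUGE_VALF so every dummy position is masked out. A standalone sketch of those two tensors through the ggml API (the function name is hypothetical, and the ggml_context is assumed to be set up by the caller, as work_ctx is in the embedder):

    #include <cmath>  // HUGE_VALF
    #include "ggml.h"

    // Build "no text encoder" conditioning for one chunk of chunk_len tokens.
    static struct ggml_tensor* zero_conditioning(struct ggml_context* ctx,
                                                 int64_t chunk_len,
                                                 struct ggml_tensor** out_mask) {
        // [4096, chunk_len] hidden states, all zeros: carries no prompt signal.
        struct ggml_tensor* hidden = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, chunk_len);
        ggml_set_f32(hidden, 0.f);

        // Additive mask of -inf at every position, suppressing attention to
        // the dummy tokens entirely.
        struct ggml_tensor* mask = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, chunk_len);
        ggml_set_f32(mask, -HUGE_VALF);

        *out_mask = mask;
        return hidden;
    }

The later modify_mask_to_attend_padding call then appears to re-open mask_pad padding positions, so attention has at least one valid target even in an all-masked chunk.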
@@ -1544,8 +1568,10 @@ struct PixArtCLIPEmbedder : public Conditioner {
                                             chunk_hidden_states->ne[0],
                                             ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
         } else {
-            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
+            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
             ggml_set_f32(hidden_states, 0.f);
+            t5_attn_mask = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, chunk_len);
+            ggml_set_f32(t5_attn_mask, -HUGE_VALF);
         }
 
         modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
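
The final hunk applies the same fallback at the outer level: when no chunks produced any hidden states (as happens when T5 is absent and t5_tokens is empty), hidden_states is sized by chunk_len rather than a hard-coded 256, and t5_attn_mask is now created and fully masked here as well, so the unconditional modify_mask_to_attend_padding call that follows always receives a valid tensor.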
