Skip to content

Commit c587a43

Browse files
stduhpfleejet
andauthored
feat: support incrementing ref image index (omni-kontext) (leejet#755)
* kontext: support ref images indices * lora: support x_embedder * update help message * Support for negative indices * support for OmniControl (offsets at index 0) * c++11 compat * add --increase-ref-index option * simplify the logic and fix some issues * update README.md * remove unused variable --------- Co-authored-by: leejet <[email protected]>
1 parent f8fe4e7 commit c587a43

File tree

8 files changed

+48
-12
lines changed

8 files changed

+48
-12
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ arguments:
319319
-i, --end-img [IMAGE] path to the end image, required by flf2v
320320
--control-image [IMAGE] path to image condition, control net
321321
-r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times)
322+
--increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).
322323
-o, --output OUTPUT path to write result image to (default: ./output.png)
323324
-p, --prompt [PROMPT] the prompt to render
324325
-n, --negative-prompt PROMPT the negative prompt (default: "")

diffusion_model.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ struct DiffusionModel {
1616
struct ggml_tensor* y,
1717
struct ggml_tensor* guidance,
1818
std::vector<ggml_tensor*> ref_latents = {},
19+
bool increase_ref_index = false,
1920
int num_video_frames = -1,
2021
std::vector<struct ggml_tensor*> controls = {},
2122
float control_strength = 0.f,
@@ -77,6 +78,7 @@ struct UNetModel : public DiffusionModel {
7778
struct ggml_tensor* y,
7879
struct ggml_tensor* guidance,
7980
std::vector<ggml_tensor*> ref_latents = {},
81+
bool increase_ref_index = false,
8082
int num_video_frames = -1,
8183
std::vector<struct ggml_tensor*> controls = {},
8284
float control_strength = 0.f,
@@ -133,6 +135,7 @@ struct MMDiTModel : public DiffusionModel {
133135
struct ggml_tensor* y,
134136
struct ggml_tensor* guidance,
135137
std::vector<ggml_tensor*> ref_latents = {},
138+
bool increase_ref_index = false,
136139
int num_video_frames = -1,
137140
std::vector<struct ggml_tensor*> controls = {},
138141
float control_strength = 0.f,
@@ -191,13 +194,14 @@ struct FluxModel : public DiffusionModel {
191194
struct ggml_tensor* y,
192195
struct ggml_tensor* guidance,
193196
std::vector<ggml_tensor*> ref_latents = {},
197+
bool increase_ref_index = false,
194198
int num_video_frames = -1,
195199
std::vector<struct ggml_tensor*> controls = {},
196200
float control_strength = 0.f,
197201
struct ggml_tensor** output = NULL,
198202
struct ggml_context* output_ctx = NULL,
199203
std::vector<int> skip_layers = std::vector<int>()) {
200-
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, ref_latents, output, output_ctx, skip_layers);
204+
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, output, output_ctx, skip_layers);
201205
}
202206
};
203207

@@ -250,6 +254,7 @@ struct WanModel : public DiffusionModel {
250254
struct ggml_tensor* y,
251255
struct ggml_tensor* guidance,
252256
std::vector<ggml_tensor*> ref_latents = {},
257+
bool increase_ref_index = false,
253258
int num_video_frames = -1,
254259
std::vector<struct ggml_tensor*> controls = {},
255260
float control_strength = 0.f,

examples/cli/main.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ struct SDParams {
7474
std::string mask_image_path;
7575
std::string control_image_path;
7676
std::vector<std::string> ref_image_paths;
77+
bool increase_ref_index = false;
7778

7879
std::string prompt;
7980
std::string negative_prompt;
@@ -156,6 +157,7 @@ void print_params(SDParams params) {
156157
for (auto& path : params.ref_image_paths) {
157158
printf(" %s\n", path.c_str());
158159
};
160+
printf(" increase_ref_index: %s\n", params.increase_ref_index ? "true" : "false");
159161
printf(" offload_params_to_cpu: %s\n", params.offload_params_to_cpu ? "true" : "false");
160162
printf(" clip_on_cpu: %s\n", params.clip_on_cpu ? "true" : "false");
161163
printf(" control_net_cpu: %s\n", params.control_net_cpu ? "true" : "false");
@@ -222,6 +224,7 @@ void print_usage(int argc, const char* argv[]) {
222224
printf(" -i, --end-img [IMAGE] path to the end image, required by flf2v\n");
223225
printf(" --control-image [IMAGE] path to image condition, control net\n");
224226
printf(" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) \n");
227+
printf(" --increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).\n");
225228
printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
226229
printf(" -p, --prompt [PROMPT] the prompt to render\n");
227230
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
@@ -536,6 +539,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
536539
{"", "--color", "", true, &params.color},
537540
{"", "--chroma-disable-dit-mask", "", false, &params.chroma_use_dit_mask},
538541
{"", "--chroma-enable-t5-mask", "", true, &params.chroma_use_t5_mask},
542+
{"", "--increase-ref-index", "", true, &params.increase_ref_index},
539543
};
540544

541545
auto on_mode_arg = [&](int argc, const char** argv, int index) {
@@ -1207,6 +1211,7 @@ int main(int argc, const char* argv[]) {
12071211
init_image,
12081212
ref_images.data(),
12091213
(int)ref_images.size(),
1214+
params.increase_ref_index,
12101215
mask_image,
12111216
params.width,
12121217
params.height,

flux.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,7 @@ namespace Flux {
960960
struct ggml_tensor* y,
961961
struct ggml_tensor* guidance,
962962
std::vector<ggml_tensor*> ref_latents = {},
963+
bool increase_ref_index = false,
963964
std::vector<int> skip_layers = {}) {
964965
GGML_ASSERT(x->ne[3] == 1);
965966
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
@@ -999,6 +1000,7 @@ namespace Flux {
9991000
x->ne[3],
10001001
context->ne[1],
10011002
ref_latents,
1003+
increase_ref_index,
10021004
flux_params.theta,
10031005
flux_params.axes_dim);
10041006
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
@@ -1035,6 +1037,7 @@ namespace Flux {
10351037
struct ggml_tensor* y,
10361038
struct ggml_tensor* guidance,
10371039
std::vector<ggml_tensor*> ref_latents = {},
1040+
bool increase_ref_index = false,
10381041
struct ggml_tensor** output = NULL,
10391042
struct ggml_context* output_ctx = NULL,
10401043
std::vector<int> skip_layers = std::vector<int>()) {
@@ -1044,7 +1047,7 @@ namespace Flux {
10441047
// y: [N, adm_in_channels] or [1, adm_in_channels]
10451048
// guidance: [N, ]
10461049
auto get_graph = [&]() -> struct ggml_cgraph* {
1047-
return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, skip_layers);
1050+
return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers);
10481051
};
10491052

10501053
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -1084,7 +1087,7 @@ namespace Flux {
10841087
struct ggml_tensor* out = NULL;
10851088

10861089
int t0 = ggml_time_ms();
1087-
compute(8, x, timesteps, context, NULL, y, guidance, {}, &out, work_ctx);
1090+
compute(8, x, timesteps, context, NULL, y, guidance, {}, false, &out, work_ctx);
10881091
int t1 = ggml_time_ms();
10891092

10901093
print_ggml_tensor(out);

lora.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ struct LoraModel : public GGMLRunner {
5858
{"x_block.attn.proj", "attn.to_out.0"},
5959
{"x_block.attn2.proj", "attn2.to_out.0"},
6060
// flux
61+
{"img_in", "x_embedder"},
6162
// singlestream
6263
{"linear2", "proj_out"},
6364
{"modulation.lin", "norm.linear"},

rope.hpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,25 +156,33 @@ struct Rope {
156156
int patch_size,
157157
int bs,
158158
int context_len,
159-
std::vector<ggml_tensor*> ref_latents) {
159+
std::vector<ggml_tensor*> ref_latents,
160+
bool increase_ref_index) {
160161
auto txt_ids = gen_txt_ids(bs, context_len);
161162
auto img_ids = gen_img_ids(h, w, patch_size, bs);
162163

163164
auto ids = concat_ids(txt_ids, img_ids, bs);
164165
uint64_t curr_h_offset = 0;
165166
uint64_t curr_w_offset = 0;
167+
int index = 1;
166168
for (ggml_tensor* ref : ref_latents) {
167169
uint64_t h_offset = 0;
168170
uint64_t w_offset = 0;
169-
if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
170-
w_offset = curr_w_offset;
171-
} else {
172-
h_offset = curr_h_offset;
171+
if (!increase_ref_index) {
172+
if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
173+
w_offset = curr_w_offset;
174+
} else {
175+
h_offset = curr_h_offset;
176+
}
173177
}
174178

175-
auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset);
179+
auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, index, h_offset, w_offset);
176180
ids = concat_ids(ids, ref_ids, bs);
177181

182+
if (increase_ref_index) {
183+
index++;
184+
}
185+
178186
curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset);
179187
curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset);
180188
}
@@ -188,9 +196,10 @@ struct Rope {
188196
int bs,
189197
int context_len,
190198
std::vector<ggml_tensor*> ref_latents,
199+
bool increase_ref_index,
191200
int theta,
192201
const std::vector<int>& axes_dim) {
193-
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents);
202+
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
194203
return embed_nd(ids, bs, theta, axes_dim);
195204
}
196205

stable-diffusion.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -775,7 +775,7 @@ class StableDiffusionGGML {
775775

776776
int64_t t0 = ggml_time_ms();
777777
struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t);
778-
diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, {}, -1, {}, 0.f, &out);
778+
diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, {}, false, -1, {}, 0.f, &out);
779779
diffusion_model->free_compute_buffer();
780780

781781
double result = 0.f;
@@ -1032,6 +1032,7 @@ class StableDiffusionGGML {
10321032
int start_merge_step,
10331033
SDCondition id_cond,
10341034
std::vector<ggml_tensor*> ref_latents = {},
1035+
bool increase_ref_index = false,
10351036
ggml_tensor* denoise_mask = nullptr) {
10361037
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
10371038

@@ -1126,6 +1127,7 @@ class StableDiffusionGGML {
11261127
cond.c_vector,
11271128
guidance_tensor,
11281129
ref_latents,
1130+
increase_ref_index,
11291131
-1,
11301132
controls,
11311133
control_strength,
@@ -1139,6 +1141,7 @@ class StableDiffusionGGML {
11391141
id_cond.c_vector,
11401142
guidance_tensor,
11411143
ref_latents,
1144+
increase_ref_index,
11421145
-1,
11431146
controls,
11441147
control_strength,
@@ -1160,6 +1163,7 @@ class StableDiffusionGGML {
11601163
uncond.c_vector,
11611164
guidance_tensor,
11621165
ref_latents,
1166+
increase_ref_index,
11631167
-1,
11641168
controls,
11651169
control_strength,
@@ -1177,6 +1181,7 @@ class StableDiffusionGGML {
11771181
img_cond.c_vector,
11781182
guidance_tensor,
11791183
ref_latents,
1184+
increase_ref_index,
11801185
-1,
11811186
controls,
11821187
control_strength,
@@ -1198,6 +1203,7 @@ class StableDiffusionGGML {
11981203
cond.c_vector,
11991204
guidance_tensor,
12001205
ref_latents,
1206+
increase_ref_index,
12011207
-1,
12021208
controls,
12031209
control_strength,
@@ -1710,6 +1716,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
17101716
"\n"
17111717
"batch_count: %d\n"
17121718
"ref_images_count: %d\n"
1719+
"increase_ref_index: %s\n"
17131720
"control_strength: %.2f\n"
17141721
"style_strength: %.2f\n"
17151722
"normalize_input: %s\n"
@@ -1724,6 +1731,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
17241731
sd_img_gen_params->seed,
17251732
sd_img_gen_params->batch_count,
17261733
sd_img_gen_params->ref_images_count,
1734+
BOOL_STR(sd_img_gen_params->increase_ref_index),
17271735
sd_img_gen_params->control_strength,
17281736
sd_img_gen_params->style_strength,
17291737
BOOL_STR(sd_img_gen_params->normalize_input),
@@ -1797,6 +1805,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
17971805
bool normalize_input,
17981806
std::string input_id_images_path,
17991807
std::vector<ggml_tensor*> ref_latents,
1808+
bool increase_ref_index,
18001809
ggml_tensor* concat_latent = NULL,
18011810
ggml_tensor* denoise_mask = NULL) {
18021811
if (seed < 0) {
@@ -2054,6 +2063,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
20542063
start_merge_step,
20552064
id_cond,
20562065
ref_latents,
2066+
increase_ref_index,
20572067
denoise_mask);
20582068
// print_ggml_tensor(x_0);
20592069
int64_t sampling_end = ggml_time_ms();
@@ -2304,7 +2314,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
23042314
LOG_INFO("EDIT mode");
23052315
}
23062316

2307-
std::vector<struct ggml_tensor*> ref_latents;
2317+
std::vector<ggml_tensor*> ref_latents;
23082318
for (int i = 0; i < sd_img_gen_params->ref_images_count; i++) {
23092319
ggml_tensor* img = ggml_new_tensor_4d(work_ctx,
23102320
GGML_TYPE_F32,
@@ -2359,6 +2369,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
23592369
sd_img_gen_params->normalize_input,
23602370
sd_img_gen_params->input_id_images_path,
23612371
ref_latents,
2372+
sd_img_gen_params->increase_ref_index,
23622373
concat_latent,
23632374
denoise_mask);
23642375

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ typedef struct {
182182
sd_image_t init_image;
183183
sd_image_t* ref_images;
184184
int ref_images_count;
185+
bool increase_ref_index;
185186
sd_image_t mask_image;
186187
int width;
187188
int height;

0 commit comments

Comments
 (0)