@@ -66,11 +66,10 @@ const char* rng_type_str[] = {
6666static_assert (std::size(rng_type_str) == RNG_TYPE_COUNT, " rng type mismatch" );
6767
6868const char * prediction_str[] = {
69- " default" ,
7069 " epsilon" ,
7170 " v" ,
7271 " edm_v" ,
73- " sd3_flow " ,
72+ " flow " ,
7473 " flux_flow" ,
7574 " flux2_flow" ,
7675};
@@ -129,6 +128,64 @@ sd_ctx_t* sd_c;
129128scheduler_t scheduler = SCHEDULER_COUNT;
130129sample_method_t sample_method = SAMPLE_METHOD_COUNT;
131130
131+ // Storage for embeddings (needs to persist for the lifetime of ctx_params)
132+ static std::vector<sd_embedding_t > embedding_vec;
133+ // Storage for embedding strings (needs to persist as long as embedding_vec references them)
134+ static std::vector<std::string> embedding_strings;
135+
136+ // Build embeddings vector from directory, similar to upstream CLI
137+ static void build_embedding_vec (const char * embedding_dir) {
138+ embedding_vec.clear ();
139+ embedding_strings.clear ();
140+
141+ if (!embedding_dir || strlen (embedding_dir) == 0 ) {
142+ return ;
143+ }
144+
145+ if (!std::filesystem::exists (embedding_dir) || !std::filesystem::is_directory (embedding_dir)) {
146+ fprintf (stderr, " Embedding directory does not exist or is not a directory: %s\n " , embedding_dir);
147+ return ;
148+ }
149+
150+ static const std::vector<std::string> valid_ext = {" .pt" , " .safetensors" , " .gguf" };
151+
152+ for (const auto & entry : std::filesystem::directory_iterator (embedding_dir)) {
153+ if (!entry.is_regular_file ()) {
154+ continue ;
155+ }
156+
157+ auto path = entry.path ();
158+ std::string ext = path.extension ().string ();
159+
160+ bool valid = false ;
161+ for (const auto & e : valid_ext) {
162+ if (ext == e) {
163+ valid = true ;
164+ break ;
165+ }
166+ }
167+ if (!valid) {
168+ continue ;
169+ }
170+
171+ std::string name = path.stem ().string ();
172+ std::string full_path = path.string ();
173+
174+ // Store strings in persistent storage
175+ embedding_strings.push_back (name);
176+ embedding_strings.push_back (full_path);
177+
178+ sd_embedding_t item;
179+ item.name = embedding_strings[embedding_strings.size () - 2 ].c_str ();
180+ item.path = embedding_strings[embedding_strings.size () - 1 ].c_str ();
181+
182+ embedding_vec.push_back (item);
183+ fprintf (stderr, " Found embedding: %s -> %s\n " , item.name , item.path );
184+ }
185+
186+ fprintf (stderr, " Loaded %zu embeddings from %s\n " , embedding_vec.size (), embedding_dir);
187+ }
188+
132189// Copied from the upstream CLI
133190static void sd_log_cb (enum sd_log_level_t level, const char * log, void * data) {
134191 // SDParams* params = (SDParams*)data;
@@ -196,7 +253,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
196253 enum sd_type_t wtype = SD_TYPE_COUNT;
197254 enum rng_type_t rng_type = CUDA_RNG;
198255 enum rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
199- enum prediction_t prediction = DEFAULT_PRED ;
256+ enum prediction_t prediction = PREDICTION_COUNT ;
200257 enum lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
201258 bool offload_params_to_cpu = false ;
202259 bool keep_clip_on_cpu = false ;
@@ -262,7 +319,19 @@ int load_model(const char *model, char *model_path, char* options[], int threads
262319 if (!strcmp (optname, " high_noise_diffusion_model_path" )) high_noise_diffusion_model_path = strdup (optval);
263320 if (!strcmp (optname, " taesd_path" )) taesd_path = strdup (optval);
264321 if (!strcmp (optname, " control_net_path" )) control_net_path = strdup (optval);
265- if (!strcmp (optname, " embedding_dir" )) embedding_dir = strdup (optval);
322+ if (!strcmp (optname, " embedding_dir" )) {
323+ // Path join with model dir
324+ if (model_path && strlen (model_path) > 0 ) {
325+ std::filesystem::path model_path_str (model_path);
326+ std::filesystem::path embedding_path (optval);
327+ std::filesystem::path full_embedding_path = model_path_str / embedding_path;
328+ embedding_dir = strdup (full_embedding_path.string ().c_str ());
329+ fprintf (stderr, " Embedding dir resolved to: %s\n " , embedding_dir);
330+ } else {
331+ embedding_dir = strdup (optval);
332+ fprintf (stderr, " No model path provided, using embedding dir as-is: %s\n " , embedding_dir);
333+ }
334+ }
266335 if (!strcmp (optname, " photo_maker_path" )) photo_maker_path = strdup (optval);
267336 if (!strcmp (optname, " tensor_type_rules" )) tensor_type_rules = strdup (optval);
268337
@@ -363,6 +432,9 @@ int load_model(const char *model, char *model_path, char* options[], int threads
363432
364433 fprintf (stderr, " parsed options\n " );
365434
435+ // Build embeddings vector from directory if provided
436+ build_embedding_vec (embedding_dir);
437+
366438 fprintf (stderr, " Creating context\n " );
367439 sd_ctx_params_init (&ctx_params);
368440 ctx_params.model_path = model;
@@ -378,7 +450,9 @@ int load_model(const char *model, char *model_path, char* options[], int threads
378450 ctx_params.taesd_path = taesd_path;
379451 ctx_params.control_net_path = control_net_path;
380452 ctx_params.lora_model_dir = lora_dir;
381- ctx_params.embedding_dir = embedding_dir;
453+ // Set embeddings array and count
454+ ctx_params.embeddings = embedding_vec.empty () ? NULL : embedding_vec.data ();
455+ ctx_params.embedding_count = static_cast <uint32_t >(embedding_vec.size ());
382456 ctx_params.photo_maker_path = photo_maker_path;
383457 ctx_params.tensor_type_rules = tensor_type_rules;
384458 ctx_params.vae_decode_only = vae_decode_only;
0 commit comments