Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ protoc:
echo "Unsupported OS: $$OS_NAME"; exit 1; \
fi; \
URL=https://github.com/protocolbuffers/protobuf/releases/download/v31.1/$$FILE; \
curl -L -s $$URL -o protoc.zip && \
curl -L $$URL -o protoc.zip && \
unzip -j -d $(CURDIR) protoc.zip bin/protoc && rm protoc.zip

.PHONY: protogen-go
Expand Down
2 changes: 1 addition & 1 deletion backend/go/stablediffusion-ggml/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)

# stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=5865b5e7034801af1a288a9584631730b25272c6
STABLEDIFFUSION_GGML_VERSION?=8823dc48bcc1598eb9671da7b69e45338d0cc5a5

CMAKE_ARGS+=-DGGML_MAX_NAME=128

Expand Down
84 changes: 79 additions & 5 deletions backend/go/stablediffusion-ggml/gosd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,10 @@ const char* rng_type_str[] = {
static_assert(std::size(rng_type_str) == RNG_TYPE_COUNT, "rng type mismatch");

const char* prediction_str[] = {
"default",
"epsilon",
"v",
"edm_v",
"sd3_flow",
"flow",
"flux_flow",
"flux2_flow",
};
Expand Down Expand Up @@ -129,6 +128,64 @@ sd_ctx_t* sd_c;
scheduler_t scheduler = SCHEDULER_COUNT;
sample_method_t sample_method = SAMPLE_METHOD_COUNT;

// Storage for embeddings (needs to persist for the lifetime of ctx_params)
static std::vector<sd_embedding_t> embedding_vec;
// Storage for embedding strings (needs to persist as long as embedding_vec references them)
static std::vector<std::string> embedding_strings;

// Build embeddings vector from directory, similar to upstream CLI
static void build_embedding_vec(const char* embedding_dir) {
embedding_vec.clear();
embedding_strings.clear();

if (!embedding_dir || strlen(embedding_dir) == 0) {
return;
}

if (!std::filesystem::exists(embedding_dir) || !std::filesystem::is_directory(embedding_dir)) {
fprintf(stderr, "Embedding directory does not exist or is not a directory: %s\n", embedding_dir);
return;
}

static const std::vector<std::string> valid_ext = {".pt", ".safetensors", ".gguf"};

for (const auto& entry : std::filesystem::directory_iterator(embedding_dir)) {
if (!entry.is_regular_file()) {
continue;
}

auto path = entry.path();
std::string ext = path.extension().string();

bool valid = false;
for (const auto& e : valid_ext) {
if (ext == e) {
valid = true;
break;
}
}
if (!valid) {
continue;
}

std::string name = path.stem().string();
std::string full_path = path.string();

// Store strings in persistent storage
embedding_strings.push_back(name);
embedding_strings.push_back(full_path);

sd_embedding_t item;
item.name = embedding_strings[embedding_strings.size() - 2].c_str();
item.path = embedding_strings[embedding_strings.size() - 1].c_str();

embedding_vec.push_back(item);
fprintf(stderr, "Found embedding: %s -> %s\n", item.name, item.path);
}

fprintf(stderr, "Loaded %zu embeddings from %s\n", embedding_vec.size(), embedding_dir);
}

// Copied from the upstream CLI
static void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
//SDParams* params = (SDParams*)data;
Expand Down Expand Up @@ -196,7 +253,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
enum sd_type_t wtype = SD_TYPE_COUNT;
enum rng_type_t rng_type = CUDA_RNG;
enum rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
enum prediction_t prediction = DEFAULT_PRED;
enum prediction_t prediction = PREDICTION_COUNT;
enum lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
bool offload_params_to_cpu = false;
bool keep_clip_on_cpu = false;
Expand Down Expand Up @@ -262,7 +319,19 @@ int load_model(const char *model, char *model_path, char* options[], int threads
if (!strcmp(optname, "high_noise_diffusion_model_path")) high_noise_diffusion_model_path = strdup(optval);
if (!strcmp(optname, "taesd_path")) taesd_path = strdup(optval);
if (!strcmp(optname, "control_net_path")) control_net_path = strdup(optval);
if (!strcmp(optname, "embedding_dir")) embedding_dir = strdup(optval);
if (!strcmp(optname, "embedding_dir")) {
// Path join with model dir
if (model_path && strlen(model_path) > 0) {
std::filesystem::path model_path_str(model_path);
std::filesystem::path embedding_path(optval);
std::filesystem::path full_embedding_path = model_path_str / embedding_path;
embedding_dir = strdup(full_embedding_path.string().c_str());
fprintf(stderr, "Embedding dir resolved to: %s\n", embedding_dir);
} else {
embedding_dir = strdup(optval);
fprintf(stderr, "No model path provided, using embedding dir as-is: %s\n", embedding_dir);
}
}
if (!strcmp(optname, "photo_maker_path")) photo_maker_path = strdup(optval);
if (!strcmp(optname, "tensor_type_rules")) tensor_type_rules = strdup(optval);

Expand Down Expand Up @@ -363,6 +432,9 @@ int load_model(const char *model, char *model_path, char* options[], int threads

fprintf(stderr, "parsed options\n");

// Build embeddings vector from directory if provided
build_embedding_vec(embedding_dir);

fprintf (stderr, "Creating context\n");
sd_ctx_params_init(&ctx_params);
ctx_params.model_path = model;
Expand All @@ -378,7 +450,9 @@ int load_model(const char *model, char *model_path, char* options[], int threads
ctx_params.taesd_path = taesd_path;
ctx_params.control_net_path = control_net_path;
ctx_params.lora_model_dir = lora_dir;
ctx_params.embedding_dir = embedding_dir;
// Set embeddings array and count
ctx_params.embeddings = embedding_vec.empty() ? NULL : embedding_vec.data();
ctx_params.embedding_count = static_cast<uint32_t>(embedding_vec.size());
ctx_params.photo_maker_path = photo_maker_path;
ctx_params.tensor_type_rules = tensor_type_rules;
ctx_params.vae_decode_only = vae_decode_only;
Expand Down
Loading