Skip to content

Commit e31e85f

Browse files
committed
apg: add experimental threshold smoothing parameter
1 parent 204a51b commit e31e85f

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

examples/cli/main.cpp

+16-2
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ struct SDParams {
133133
float apg_eta = 1.0f;
134134
float apg_momentum = 0.0f;
135135
float apg_norm_threshold = 0.0f;
136+
float apg_norm_smoothing = 0.0f;
136137
};
137138

138139
void print_params(SDParams params) {
@@ -220,6 +221,8 @@ void print_usage(int argc, const char* argv[]) {
220221
printf(" --apg-eta VALUE parallel projected guidance scale for APG (default: 1.0, recommended: between 0 and 1)\n");
221222
printf(" --apg-momentum VALUE CFG update direction momentum for APG (default: 0, recommended: around -0.5)\n");
222223
printf(" --apg-nt, --apg-rescale VALUE CFG update direction norm threshold for APG (default: 0 = disabled, recommended: 4-15)\n");
224+
printf(" --apg-nt-smoothing VALUE EXPERIMENTAL! Norm threshold smoothing for APG (default: 0 = disabled)\n");
225+
printf(" (replaces saturation with a smooth approximation)\n");
223226
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
224227
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
225228
printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n");
@@ -654,6 +657,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
654657
break;
655658
}
656659
params.apg_norm_threshold = std::stof(argv[i]);
660+
} else if (arg == "--apg-nt-smoothing") {
661+
if (++i >= argc) {
662+
invalid_arg = true;
663+
break;
664+
}
665+
params.apg_norm_smoothing = std::stof(argv[i]);
657666
} else {
658667
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
659668
print_usage(argc, argv);
@@ -752,6 +761,9 @@ std::string get_image_params(SDParams params, int64_t seed) {
752761
}
753762
if (params.apg_norm_threshold != 0) {
754763
parameter_string += "CFG normalization threshold: " + std::to_string(params.apg_norm_threshold) + ", ";
764+
if (params.apg_norm_smoothing != 0) {
765+
parameter_string += "CFG normalization threshold: " + std::to_string(params.apg_norm_smoothing) + ", ";
766+
}
755767
}
756768
if (params.slg_scale != 0 && params.skip_layers.size() != 0) {
757769
parameter_string += "SLG scale: " + std::to_string(params.cfg_scale) + ", ";
@@ -1004,7 +1016,8 @@ int main(int argc, const char* argv[]) {
10041016
params.skip_layer_end},
10051017
sd_apg_params_t{params.apg_eta,
10061018
params.apg_momentum,
1007-
params.apg_norm_threshold});
1019+
params.apg_norm_threshold,
1020+
params.apg_norm_smoothing});
10081021
} else {
10091022
sd_image_t input_image = {(uint32_t)params.width,
10101023
(uint32_t)params.height,
@@ -1076,7 +1089,8 @@ int main(int argc, const char* argv[]) {
10761089
params.skip_layer_end},
10771090
sd_apg_params_t{params.apg_eta,
10781091
params.apg_momentum,
1079-
params.apg_norm_threshold});
1092+
params.apg_norm_threshold,
1093+
params.apg_norm_smoothing});
10801094
}
10811095
}
10821096

stable-diffusion.cpp

+9-3
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ class StableDiffusionGGML {
801801
int start_merge_step,
802802
SDCondition id_cond,
803803
sd_slg_params_t slg_params = {NULL, 0, 0, 0, 0},
804-
sd_apg_params_t apg_params = {1, 0, 0},
804+
sd_apg_params_t apg_params = {1, 0, 0, 0},
805805
ggml_tensor* noise_mask = nullptr) {
806806
std::vector<int> skip_layers(slg_params.skip_layers, slg_params.skip_layers + slg_params.skip_layers_count);
807807

@@ -980,8 +980,14 @@ class StableDiffusionGGML {
980980
deltas[i] = delta;
981981
}
982982
if (apg_params.norm_treshold > 0) {
983-
diff_norm = sqrtf(diff_norm);
984-
apg_scale_factor = std::min(1.0f, apg_params.norm_treshold / diff_norm);
983+
diff_norm = sqrtf(diff_norm);
984+
if (apg_params.norm_treshold_smoothing <= 0) {
985+
apg_scale_factor = std::min(1.0f, apg_params.norm_treshold / diff_norm);
986+
} else {
987+
// Experimental: smooth saturate
988+
float x = apg_params.norm_treshold / diff_norm;
989+
apg_scale_factor = x / std::pow(1 + std::pow(x, 1.0 / apg_params.norm_treshold_smoothing), apg_params.norm_treshold_smoothing);
990+
}
985991
}
986992
if (apg_params.eta != 1.0f) {
987993
dot *= apg_scale_factor;

stable-diffusion.h

+1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ typedef struct {
131131
float eta;
132132
float momentum;
133133
float norm_treshold;
134+
float norm_treshold_smoothing;
134135
} sd_apg_params_t;
135136

136137
typedef struct {

0 commit comments

Comments
 (0)