Skip to content

Commit 83078c2

Browse files
committed
Fix cfg 1 crash
1 parent 34e7b93 commit 83078c2

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

stable-diffusion.cpp

+29-27
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,7 @@ class StableDiffusionGGML {
802802
SDCondition id_cond,
803803
sd_slg_params_t slg_params = {NULL, 0, 0, 0, 0},
804804
sd_apg_params_t apg_params = {1, 0, 0},
805-
ggml_tensor* noise_mask = nullptr) {
805+
ggml_tensor* noise_mask = nullptr) {
806806
std::vector<int> skip_layers(slg_params.skip_layers, slg_params.skip_layers + slg_params.skip_layers_count);
807807

808808
LOG_DEBUG("Sample");
@@ -963,39 +963,41 @@ class StableDiffusionGGML {
963963
float diff_norm = 0;
964964
float cond_norm_sq = 0;
965965
float dot = 0;
966-
for (int i = 0; i < ne_elements; i++) {
967-
float delta = positive_data[i] - negative_data[i];
968-
if (apg_params.momentum != 0) {
969-
delta += apg_params.momentum * apg_momentum_buffer[i];
970-
apg_momentum_buffer[i] = delta;
966+
if (has_unconditioned) {
967+
for (int i = 0; i < ne_elements; i++) {
968+
float delta = positive_data[i] - negative_data[i];
969+
if (apg_params.momentum != 0) {
970+
delta += apg_params.momentum * apg_momentum_buffer[i];
971+
apg_momentum_buffer[i] = delta;
972+
}
973+
if (apg_params.norm_treshold > 0) {
974+
diff_norm += delta * delta;
975+
}
976+
if (apg_params.eta != 1.0f) {
977+
cond_norm_sq += positive_data[i] * positive_data[i];
978+
dot += positive_data[i] * delta;
979+
}
980+
deltas[i] = delta;
971981
}
972982
if (apg_params.norm_treshold > 0) {
973-
diff_norm += delta * delta;
983+
diff_norm = std::sqrtf(diff_norm);
984+
apg_scale_factor = std::min(1.0f, apg_params.norm_treshold / diff_norm);
974985
}
975986
if (apg_params.eta != 1.0f) {
976-
cond_norm_sq += positive_data[i] * positive_data[i];
977-
dot += positive_data[i] * delta;
987+
dot *= apg_scale_factor;
988+
// pre-normalize (avoids one square root and ne_elements extra divs)
989+
dot /= cond_norm_sq;
978990
}
979-
deltas[i] = delta;
980-
}
981-
if (apg_params.norm_treshold > 0) {
982-
diff_norm = std::sqrtf(diff_norm);
983-
apg_scale_factor = std::min(1.0f, apg_params.norm_treshold / diff_norm);
984-
}
985-
if (apg_params.eta != 1.0f) {
986-
dot *= apg_scale_factor;
987-
// pre-normalize (avoids one square root and ne_elements extra divs)
988-
dot /= cond_norm_sq;
989-
}
990991

991-
for (int i = 0; i < ne_elements; i++) {
992-
deltas[i] *= apg_scale_factor;
993-
if (apg_params.eta != 1.0f) {
994-
float apg_parallel = dot * positive_data[i];
995-
float apg_orthogonal = deltas[i] - apg_parallel;
992+
for (int i = 0; i < ne_elements; i++) {
993+
deltas[i] *= apg_scale_factor;
994+
if (apg_params.eta != 1.0f) {
995+
float apg_parallel = dot * positive_data[i];
996+
float apg_orthogonal = deltas[i] - apg_parallel;
996997

997-
// tweak deltas
998-
deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
998+
// tweak deltas
999+
deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
1000+
}
9991001
}
10001002
}
10011003

0 commit comments

Comments
 (0)