@@ -802,7 +802,7 @@ class StableDiffusionGGML {
802
802
SDCondition id_cond,
803
803
sd_slg_params_t slg_params = {NULL , 0 , 0 , 0 , 0 },
804
804
sd_apg_params_t apg_params = {1 , 0 , 0 },
805
- ggml_tensor* noise_mask = nullptr ) {
805
+ ggml_tensor* noise_mask = nullptr ) {
806
806
std::vector<int > skip_layers (slg_params.skip_layers , slg_params.skip_layers + slg_params.skip_layers_count );
807
807
808
808
LOG_DEBUG (" Sample" );
@@ -963,39 +963,41 @@ class StableDiffusionGGML {
963
963
float diff_norm = 0 ;
964
964
float cond_norm_sq = 0 ;
965
965
float dot = 0 ;
966
- for (int i = 0 ; i < ne_elements; i++) {
967
- float delta = positive_data[i] - negative_data[i];
968
- if (apg_params.momentum != 0 ) {
969
- delta += apg_params.momentum * apg_momentum_buffer[i];
970
- apg_momentum_buffer[i] = delta;
966
+ if (has_unconditioned) {
967
+ for (int i = 0 ; i < ne_elements; i++) {
968
+ float delta = positive_data[i] - negative_data[i];
969
+ if (apg_params.momentum != 0 ) {
970
+ delta += apg_params.momentum * apg_momentum_buffer[i];
971
+ apg_momentum_buffer[i] = delta;
972
+ }
973
+ if (apg_params.norm_treshold > 0 ) {
974
+ diff_norm += delta * delta;
975
+ }
976
+ if (apg_params.eta != 1 .0f ) {
977
+ cond_norm_sq += positive_data[i] * positive_data[i];
978
+ dot += positive_data[i] * delta;
979
+ }
980
+ deltas[i] = delta;
971
981
}
972
982
if (apg_params.norm_treshold > 0 ) {
973
- diff_norm += delta * delta;
983
+ diff_norm = std::sqrtf (diff_norm);
984
+ apg_scale_factor = std::min (1 .0f , apg_params.norm_treshold / diff_norm);
974
985
}
975
986
if (apg_params.eta != 1 .0f ) {
976
- cond_norm_sq += positive_data[i] * positive_data[i];
977
- dot += positive_data[i] * delta;
987
+ dot *= apg_scale_factor;
988
+ // pre-normalize (avoids one square root and ne_elements extra divs)
989
+ dot /= cond_norm_sq;
978
990
}
979
- deltas[i] = delta;
980
- }
981
- if (apg_params.norm_treshold > 0 ) {
982
- diff_norm = std::sqrtf (diff_norm);
983
- apg_scale_factor = std::min (1 .0f , apg_params.norm_treshold / diff_norm);
984
- }
985
- if (apg_params.eta != 1 .0f ) {
986
- dot *= apg_scale_factor;
987
- // pre-normalize (avoids one square root and ne_elements extra divs)
988
- dot /= cond_norm_sq;
989
- }
990
991
991
- for (int i = 0 ; i < ne_elements; i++) {
992
- deltas[i] *= apg_scale_factor;
993
- if (apg_params.eta != 1 .0f ) {
994
- float apg_parallel = dot * positive_data[i];
995
- float apg_orthogonal = deltas[i] - apg_parallel;
992
+ for (int i = 0 ; i < ne_elements; i++) {
993
+ deltas[i] *= apg_scale_factor;
994
+ if (apg_params.eta != 1 .0f ) {
995
+ float apg_parallel = dot * positive_data[i];
996
+ float apg_orthogonal = deltas[i] - apg_parallel;
996
997
997
- // tweak deltas
998
- deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
998
+ // tweak deltas
999
+ deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
1000
+ }
999
1001
}
1000
1002
}
1001
1003
0 commit comments