@@ -26,26 +26,24 @@ __global__ void LRNFillScale(const int nthreads, const Dtype* in,
   Dtype accum_scale = 0;
   // fill the scale at [n, :, h, w]
   // accumulate values
-  while (head < post_pad) {
+  while (head < post_pad && head < channels) {
     accum_scale += in[head * step] * in[head * step];
     ++head;
   }
-  // until we reach size, nothing needs to be subtracted
-  while (head < size) {
-    accum_scale += in[head * step] * in[head * step];
-    scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
-    ++head;
-  }
   // both add and subtract
   while (head < channels) {
     accum_scale += in[head * step] * in[head * step];
-    accum_scale -= in[(head - size) * step] * in[(head - size) * step];
+    if (head - size >= 0) {
+      accum_scale -= in[(head - size) * step] * in[(head - size) * step];
+    }
     scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
     ++head;
   }
   // subtract only
   while (head < channels + post_pad) {
-    accum_scale -= in[(head - size) * step] * in[(head - size) * step];
+    if (head - size >= 0) {
+      accum_scale -= in[(head - size) * step] * in[(head - size) * step];
+    }
     scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
     ++head;
   }
@@ -143,35 +141,30 @@ __global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data,
   int post_pad = size - pre_pad - 1;
   Dtype accum_ratio = 0;
   // accumulate values
-  while (head < post_pad) {
+  while (head < post_pad && head < channels) {
     accum_ratio += top_diff[head * step] * top_data[head * step] /
         scale[head * step];
     ++head;
   }
-  // until we reach size, nothing needs to be subtracted
-  while (head < size) {
-    accum_ratio += top_diff[head * step] * top_data[head * step] /
-        scale[head * step];
-    bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
-        * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
-        bottom_data[(head - post_pad) * step] * accum_ratio;
-    ++head;
-  }
   // both add and subtract
   while (head < channels) {
     accum_ratio += top_diff[head * step] * top_data[head * step] /
         scale[head * step];
-    accum_ratio -= top_diff[(head - size) * step] *
-        top_data[(head - size) * step] / scale[(head - size) * step];
+    if (head - size >= 0) {
+      accum_ratio -= top_diff[(head - size) * step] *
+          top_data[(head - size) * step] / scale[(head - size) * step];
+    }
     bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
         * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
         bottom_data[(head - post_pad) * step] * accum_ratio;
     ++head;
   }
   // subtract only
   while (head < channels + post_pad) {
-    accum_ratio -= top_diff[(head - size) * step] *
-        top_data[(head - size) * step] / scale[(head - size) * step];
+    if (head - size >= 0) {
+      accum_ratio -= top_diff[(head - size) * step] *
+          top_data[(head - size) * step] / scale[(head - size) * step];
+    }
     bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
         * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
         bottom_data[(head - post_pad) * step] * accum_ratio;