Commit 01e15d0

Merge pull request BVLC#3069 from timmeinhardt/argmax
Add argmax_param "axis" to maximise output along the specified axis
2 parents: 942df00 + def3d3c

File tree: 4 files changed, +205 -28 lines


include/caffe/common_layers.hpp

+11 -3

@@ -21,7 +21,8 @@ namespace caffe {
  *
  * Intended for use after a classification layer to produce a prediction.
  * If parameter out_max_val is set to true, output is a vector of pairs
- * (max_ind, max_val) for each image.
+ * (max_ind, max_val) for each image. The axis parameter specifies an axis
+ * along which to maximise.
  *
  * NOTE: does not implement Backwards operation.
  */
@@ -34,7 +35,11 @@ class ArgMaxLayer : public Layer<Dtype> {
    *   - top_k (\b optional uint, default 1).
    *     the number @f$ K @f$ of maximal items to output.
    *   - out_max_val (\b optional bool, default false).
-   *     if set, output a vector of pairs (max_ind, max_val) for each image.
+   *     if set, output a vector of pairs (max_ind, max_val) unless axis is set,
+   *     in which case output max_val along the specified axis.
+   *   - axis (\b optional int).
+   *     if set, maximise along the specified axis, else maximise the flattened
+   *     trailing dimensions for each index of the first / num dimension.
    */
   explicit ArgMaxLayer(const LayerParameter& param)
       : Layer<Dtype>(param) {}
@@ -54,7 +59,8 @@ class ArgMaxLayer : public Layer<Dtype> {
    *      the inputs @f$ x @f$
    * @param top output Blob vector (length 1)
    *   -# @f$ (N \times 1 \times K \times 1) @f$ or, if out_max_val
-   *      @f$ (N \times 2 \times K \times 1) @f$
+   *      @f$ (N \times 2 \times K \times 1) @f$ unless axis is set, then e.g.
+   *      @f$ (N \times K \times H \times W) @f$ if axis == 1
    *      the computed outputs @f$
    *       y_n = \arg\max\limits_i x_{ni}
    *      @f$ (for @f$ K = 1 @f$).
@@ -68,6 +74,8 @@ class ArgMaxLayer : public Layer<Dtype> {
   }
   bool out_max_val_;
   size_t top_k_;
+  bool has_axis_;
+  int axis_;
 };

 /**
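
The documentation above reduces to a simple shape rule. The following is a minimal standalone C++ sketch (not part of the commit; ExpectedTopShape is a made-up helper for illustration) of how the top shape follows from the bottom shape, top_k, out_max_val and the optional axis, mirroring the Reshape logic in the next file. It assumes the bottom blob has at least three axes, as the legacy code path does.

// Sketch only: the documented shape rule, assuming a bottom blob with >= 3 axes.
#include <vector>

std::vector<int> ExpectedTopShape(const std::vector<int>& bottom_shape,
                                  bool out_max_val, int top_k,
                                  bool has_axis, int axis) {
  if (has_axis) {
    // e.g. (N, C, H, W) with axis == 1 and top_k == K becomes (N, K, H, W).
    std::vector<int> shape = bottom_shape;
    shape[axis] = top_k;
    return shape;
  }
  // Legacy behaviour: (N, 1, K, 1), or (N, 2, K, 1) when out_max_val is set.
  std::vector<int> shape(bottom_shape.size(), 1);
  shape[0] = bottom_shape[0];
  shape[1] = out_max_val ? 2 : 1;
  shape[2] = top_k;
  return shape;
}

For the (10, 10, 20, 20) bottom blob used in the tests below, axis = 2 with top_k = 5 gives a (10, 10, 5, 20) top blob, whether or not out_max_val is set.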

src/caffe/layers/argmax_layer.cpp

+57 -20

@@ -11,47 +11,84 @@ namespace caffe {
 template <typename Dtype>
 void ArgMaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
-  out_max_val_ = this->layer_param_.argmax_param().out_max_val();
-  top_k_ = this->layer_param_.argmax_param().top_k();
-  CHECK_GE(top_k_, 1) << " top k must not be less than 1.";
-  CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num())
-      << "top_k must be less than or equal to the number of classes.";
+  const ArgMaxParameter& argmax_param = this->layer_param_.argmax_param();
+  out_max_val_ = argmax_param.out_max_val();
+  top_k_ = argmax_param.top_k();
+  has_axis_ = argmax_param.has_axis();
+  CHECK_GE(top_k_, 1) << "top k must not be less than 1.";
+  if (has_axis_) {
+    axis_ = bottom[0]->CanonicalAxisIndex(argmax_param.axis());
+    CHECK_GE(axis_, 0) << "axis must not be less than 0.";
+    CHECK_LE(axis_, bottom[0]->num_axes()) <<
+        "axis must be less than or equal to the number of axes.";
+    CHECK_LE(top_k_, bottom[0]->shape(axis_))
+        << "top_k must be less than or equal to the dimension of the axis.";
+  } else {
+    CHECK_LE(top_k_, bottom[0]->count(1))
+        << "top_k must be less than or equal to"
+           " the dimension of the flattened bottom blob per instance.";
+  }
 }

 template <typename Dtype>
 void ArgMaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
-  if (out_max_val_) {
-    // Produces max_ind and max_val
-    top[0]->Reshape(bottom[0]->num(), 2, top_k_, 1);
+  std::vector<int> shape(bottom[0]->num_axes(), 1);
+  if (has_axis_) {
+    // Produces max_ind or max_val per axis
+    shape = bottom[0]->shape();
+    shape[axis_] = top_k_;
   } else {
-    // Produces only max_ind
-    top[0]->Reshape(bottom[0]->num(), 1, top_k_, 1);
+    shape[0] = bottom[0]->shape(0);
+    // Produces max_ind
+    shape[2] = top_k_;
+    if (out_max_val_) {
+      // Produces max_ind and max_val
+      shape[1] = 2;
+    }
   }
+  top[0]->Reshape(shape);
 }

 template <typename Dtype>
 void ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = top[0]->mutable_cpu_data();
-  int num = bottom[0]->num();
-  int dim = bottom[0]->count() / bottom[0]->num();
+  int dim, axis_dist;
+  if (has_axis_) {
+    dim = bottom[0]->shape(axis_);
+    // Distance between values of axis in blob
+    axis_dist = bottom[0]->count(axis_) / dim;
+  } else {
+    dim = bottom[0]->count(1);
+    axis_dist = 1;
+  }
+  int num = bottom[0]->count() / dim;
+  std::vector<std::pair<Dtype, int> > bottom_data_vector(dim);
   for (int i = 0; i < num; ++i) {
-    std::vector<std::pair<Dtype, int> > bottom_data_vector;
     for (int j = 0; j < dim; ++j) {
-      bottom_data_vector.push_back(
-          std::make_pair(bottom_data[i * dim + j], j));
+      bottom_data_vector[j] = std::make_pair(
+        bottom_data[(i / axis_dist * dim + j) * axis_dist + i % axis_dist], j);
     }
     std::partial_sort(
         bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
         bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());
     for (int j = 0; j < top_k_; ++j) {
-      top_data[top[0]->offset(i, 0, j)] = bottom_data_vector[j].second;
-    }
-    if (out_max_val_) {
-      for (int j = 0; j < top_k_; ++j) {
-        top_data[top[0]->offset(i, 1, j)] = bottom_data_vector[j].first;
+      if (out_max_val_) {
+        if (has_axis_) {
+          // Produces max_val per axis
+          top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist]
+            = bottom_data_vector[j].first;
+        } else {
+          // Produces max_ind and max_val
+          top_data[2 * i * top_k_ + j] = bottom_data_vector[j].second;
+          top_data[2 * i * top_k_ + top_k_ + j] = bottom_data_vector[j].first;
+        }
+      } else {
+        // Produces max_ind per axis
+        top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist]
+          = bottom_data_vector[j].second;
       }
     }
   }
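
The one non-obvious piece of the new Forward_cpu is the flattened index (i / axis_dist * dim + j) * axis_dist + i % axis_dist: i enumerates every (outer, inner) slice around the chosen axis, i / axis_dist recovers the outer index, i % axis_dist the inner offset, and j walks the axis itself. Below is a minimal standalone sketch (not part of the commit; the 2x3x4 shape and the value pattern are invented for illustration) that cross-checks this arithmetic against explicit (n, c, w) indexing for axis = 1.

#include <cstdio>
#include <vector>

int main() {
  const int N = 2, C = 3, W = 4;   // hypothetical blob shape, axis = 1
  const int dim = C;               // bottom[0]->shape(axis_)
  const int axis_dist = W;         // count(axis_) / dim, i.e. the inner size
  const int num = N * W;           // count() / dim, i.e. outer * inner slices
  std::vector<float> bottom(N * C * W);
  for (int idx = 0; idx < N * C * W; ++idx) bottom[idx] = (idx * 7) % 11;

  for (int i = 0; i < num; ++i) {
    // Argmax along the axis using the layer's flattened index.
    int arg = 0;
    for (int j = 1; j < dim; ++j) {
      int cand = (i / axis_dist * dim + j) * axis_dist + i % axis_dist;
      int best = (i / axis_dist * dim + arg) * axis_dist + i % axis_dist;
      if (bottom[cand] > bottom[best]) arg = j;
    }
    // Cross-check with explicit (n, c, w) row-major indexing.
    int n = i / axis_dist, w = i % axis_dist, naive = 0;
    for (int c = 1; c < C; ++c) {
      if (bottom[(n * C + c) * W + w] > bottom[(n * C + naive) * W + w]) naive = c;
    }
    std::printf("slice %d: layer argmax %d, naive argmax %d\n", i, arg, naive);
  }
  return 0;
}

Both computations agree for every slice, which is the invariant the TestCPUAxis* cases below assert via Blob::data_at.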

src/caffe/proto/caffe.proto

+5

@@ -443,6 +443,11 @@ message ArgMaxParameter {
   // If true produce pairs (argmax, maxval)
   optional bool out_max_val = 1 [default = false];
   optional uint32 top_k = 2 [default = 1];
+  // The axis along which to maximise -- may be negative to index from the
+  // end (e.g., -1 for the last axis).
+  // By default ArgMaxLayer maximizes over the flattened trailing dimensions
+  // for each index of the first / num dimension.
+  optional int32 axis = 3;
 }

 message ConcatParameter {
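
For reference, a minimal usage sketch (an assumption, not part of the commit; MakeArgMaxLayerParam is a made-up helper that simply mirrors the setters exercised by the tests below) showing how the new field is populated through the generated protobuf API:

#include "caffe/proto/caffe.pb.h"

caffe::LayerParameter MakeArgMaxLayerParam() {
  caffe::LayerParameter layer_param;
  caffe::ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
  argmax_param->set_axis(1);            // maximise along the channel axis
  argmax_param->set_top_k(5);           // keep the 5 largest entries per slice
  argmax_param->set_out_max_val(true);  // with axis set, output max_val instead of max_ind
  return layer_param;
}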

src/caffe/test/test_argmax_layer.cpp

+132 -5

@@ -16,7 +16,7 @@ template <typename Dtype>
 class ArgMaxLayerTest : public CPUDeviceTest<Dtype> {
  protected:
   ArgMaxLayerTest()
-      : blob_bottom_(new Blob<Dtype>(10, 20, 1, 1)),
+      : blob_bottom_(new Blob<Dtype>(10, 10, 20, 20)),
         blob_top_(new Blob<Dtype>()),
         top_k_(5) {
     Caffe::set_random_seed(1701);
@@ -55,6 +55,43 @@ TYPED_TEST(ArgMaxLayerTest, TestSetupMaxVal) {
   EXPECT_EQ(this->blob_top_->channels(), 2);
 }

+TYPED_TEST(ArgMaxLayerTest, TestSetupAxis) {
+  LayerParameter layer_param;
+  ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+  argmax_param->set_axis(0);
+  ArgMaxLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->shape(0), argmax_param->top_k());
+  EXPECT_EQ(this->blob_top_->shape(1), this->blob_bottom_->shape(0));
+  EXPECT_EQ(this->blob_top_->shape(2), this->blob_bottom_->shape(2));
+  EXPECT_EQ(this->blob_top_->shape(3), this->blob_bottom_->shape(3));
+}
+
+TYPED_TEST(ArgMaxLayerTest, TestSetupAxisNegativeIndexing) {
+  LayerParameter layer_param;
+  ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+  argmax_param->set_axis(-2);
+  ArgMaxLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->shape(0), this->blob_bottom_->shape(0));
+  EXPECT_EQ(this->blob_top_->shape(1), this->blob_bottom_->shape(1));
+  EXPECT_EQ(this->blob_top_->shape(2), argmax_param->top_k());
+  EXPECT_EQ(this->blob_top_->shape(3), this->blob_bottom_->shape(3));
+}
+
+TYPED_TEST(ArgMaxLayerTest, TestSetupAxisMaxVal) {
+  LayerParameter layer_param;
+  ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+  argmax_param->set_axis(2);
+  argmax_param->set_out_max_val(true);
+  ArgMaxLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->shape(0), this->blob_bottom_->shape(0));
+  EXPECT_EQ(this->blob_top_->shape(1), this->blob_bottom_->shape(1));
+  EXPECT_EQ(this->blob_top_->shape(2), argmax_param->top_k());
+  EXPECT_EQ(this->blob_top_->shape(3), this->blob_bottom_->shape(3));
+}
+
 TYPED_TEST(ArgMaxLayerTest, TestCPU) {
   LayerParameter layer_param;
   ArgMaxLayer<TypeParam> layer(layer_param);
@@ -112,6 +149,7 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUTopK) {
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
   // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
   int max_ind;
   TypeParam max_val;
   int num = this->blob_bottom_->num();
@@ -121,10 +159,10 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUTopK) {
     EXPECT_LE(this->blob_top_->data_at(i, 0, 0, 0), dim);
     for (int j = 0; j < this->top_k_; ++j) {
       max_ind = this->blob_top_->data_at(i, 0, j, 0);
-      max_val = this->blob_bottom_->data_at(i, max_ind, 0, 0);
+      max_val = bottom_data[i * dim + max_ind];
       int count = 0;
       for (int k = 0; k < dim; ++k) {
-        if (this->blob_bottom_->data_at(i, k, 0, 0) > max_val) {
+        if (bottom_data[i * dim + k] > max_val) {
           ++count;
         }
       }
@@ -142,6 +180,7 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) {
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
   // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
   int max_ind;
   TypeParam max_val;
   int num = this->blob_bottom_->num();
@@ -152,10 +191,10 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) {
     for (int j = 0; j < this->top_k_; ++j) {
       max_ind = this->blob_top_->data_at(i, 0, j, 0);
       max_val = this->blob_top_->data_at(i, 1, j, 0);
-      EXPECT_EQ(this->blob_bottom_->data_at(i, max_ind, 0, 0), max_val);
+      EXPECT_EQ(bottom_data[i * dim + max_ind], max_val);
       int count = 0;
       for (int k = 0; k < dim; ++k) {
-        if (this->blob_bottom_->data_at(i, k, 0, 0) > max_val) {
+        if (bottom_data[i * dim + k] > max_val) {
           ++count;
         }
       }
@@ -164,5 +203,93 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) {
   }
 }

+TYPED_TEST(ArgMaxLayerTest, TestCPUAxis) {
+  LayerParameter layer_param;
+  ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+  argmax_param->set_axis(0);
+  ArgMaxLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  int max_ind;
+  TypeParam max_val;
+  std::vector<int> shape = this->blob_bottom_->shape();
+  for (int i = 0; i < shape[1]; ++i) {
+    for (int j = 0; j < shape[2]; ++j) {
+      for (int k = 0; k < shape[3]; ++k) {
+        max_ind = this->blob_top_->data_at(0, i, j, k);
+        max_val = this->blob_bottom_->data_at(max_ind, i, j, k);
+        EXPECT_GE(max_ind, 0);
+        EXPECT_LE(max_ind, shape[0]);
+        for (int l = 0; l < shape[0]; ++l) {
+          EXPECT_LE(this->blob_bottom_->data_at(l, i, j, k), max_val);
+        }
+      }
+    }
+  }
+}
+
+TYPED_TEST(ArgMaxLayerTest, TestCPUAxisTopK) {
+  LayerParameter layer_param;
+  ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+  argmax_param->set_axis(2);
+  argmax_param->set_top_k(this->top_k_);
+  ArgMaxLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  int max_ind;
+  TypeParam max_val;
+  std::vector<int> shape = this->blob_bottom_->shape();
+  for (int i = 0; i < shape[0]; ++i) {
+    for (int j = 0; j < shape[1]; ++j) {
+      for (int k = 0; k < shape[3]; ++k) {
+        for (int m = 0; m < this->top_k_; ++m) {
+          max_ind = this->blob_top_->data_at(i, j, m, k);
+          max_val = this->blob_bottom_->data_at(i, j, max_ind, k);
+          EXPECT_GE(max_ind, 0);
+          EXPECT_LE(max_ind, shape[2]);
+          int count = 0;
+          for (int l = 0; l < shape[2]; ++l) {
+            if (this->blob_bottom_->data_at(i, j, l, k) > max_val) {
+              ++count;
+            }
+          }
+          EXPECT_EQ(m, count);
+        }
+      }
+    }
+  }
+}
+
+TYPED_TEST(ArgMaxLayerTest, TestCPUAxisMaxValTopK) {
+  LayerParameter layer_param;
+  ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+  argmax_param->set_axis(-1);
+  argmax_param->set_top_k(this->top_k_);
+  argmax_param->set_out_max_val(true);
+  ArgMaxLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  TypeParam max_val;
+  std::vector<int> shape = this->blob_bottom_->shape();
+  for (int i = 0; i < shape[0]; ++i) {
+    for (int j = 0; j < shape[1]; ++j) {
+      for (int k = 0; k < shape[2]; ++k) {
+        for (int m = 0; m < this->top_k_; ++m) {
+          max_val = this->blob_top_->data_at(i, j, k, m);
+          int count = 0;
+          for (int l = 0; l < shape[3]; ++l) {
+            if (this->blob_bottom_->data_at(i, j, k, l) > max_val) {
+              ++count;
+            }
+          }
+          EXPECT_EQ(m, count);
+        }
+      }
+    }
+  }
+}

 }  // namespace caffe
