Skip to content

Commit ac74eef

Browse files
committed
inplace HardTanh, subclass ReLU6
1 parent be1a51a commit ac74eef

File tree

8 files changed

+98
-123
lines changed

8 files changed

+98
-123
lines changed

HardTanh.lua

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
local HardTanh, parent = torch.class('nn.HardTanh', 'nn.Module')
22

3-
function HardTanh:__init(min_value, max_value)
3+
function HardTanh:__init(min_value, max_value, inplace)
44
parent.__init(self)
55
self.min_val = min_value or -1
66
self.max_val = max_value or 1
7+
self.inplace = inplace or false
8+
if (inplace and type(inplace) ~= 'boolean') then
9+
error('in-place flag must be boolean')
10+
end
711
assert(self.max_val>self.min_val, 'max_value must be larger than min_value')
812
end
913

@@ -14,7 +18,8 @@ function HardTanh:updateOutput(input)
1418
input:cdata(),
1519
self.output:cdata(),
1620
self.min_val,
17-
self.max_val
21+
self.max_val,
22+
self.inplace or false
1823
)
1924
return self.output
2025
end
@@ -25,7 +30,8 @@ function HardTanh:updateGradInput(input, gradOutput)
2530
gradOutput:cdata(),
2631
self.gradInput:cdata(),
2732
self.min_val,
28-
self.max_val
33+
self.max_val,
34+
self.inplace or false
2935
)
3036
return self.gradInput
3137
end

ReLU6.lua

+4-7
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,18 @@ function ReLU6:__init(inplace)
1515
end
1616

1717
function ReLU6:updateOutput(input)
18-
input.THNN.ReLU6_updateOutput(
18+
input.THNN.HardTanh_updateOutput(
1919
input:cdata(),
2020
self.output:cdata(),
21-
self.inplace
22-
)
21+
0, 6, self.inplace)
2322
return self.output
2423
end
2524

2625
function ReLU6:updateGradInput(input, gradOutput)
27-
input.THNN.ReLU6_updateGradInput(
26+
input.THNN.HardTanh_updateGradInput(
2827
input:cdata(),
2928
gradOutput:cdata(),
3029
self.gradInput:cdata(),
31-
self.inplace
32-
)
30+
0, 6, self.inplace)
3331
return self.gradInput
3432
end
35-

doc/transfer.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ thus outputting a Tensor of the same dimension.
1515
* `f(x)` = `x,` `otherwise.`
1616

1717
The range of the linear region `[-1 1]` can be adjusted by specifying arguments in declaration, for example `nn.HardTanh(min_value, max_value)`.
18-
Otherwise, `[min_value max_value]` is set to `[-1 1]` by default.
18+
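Otherwise, `[min_value max_value]` is set to `[-1 1]` by default. The operation can be performed in-place by passing `true` as the optional third (boolean) argument, for example `nn.HardTanh(min_value, max_value, true)`; it defaults to `false`.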
1919

2020

2121
```lua

lib/THNN/generic/HardTanh.c

+78-35
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,59 @@ void THNN_(HardTanh_updateOutput)(
77
THTensor *input,
88
THTensor *output,
99
real min_val,
10-
real max_val)
10+
real max_val,
11+
bool inplace)
1112
{
12-
THTensor_(resizeAs)(output, input);
13+
if (inplace)
14+
THTensor_(set)(output, input);
15+
else
16+
THTensor_(resizeAs)(output, input);
1317

1418
if (input->nDimension == 1 || !THTensor_(isContiguous)(input) || !THTensor_(isContiguous)(output))
1519
{
16-
TH_TENSOR_APPLY2(real, output, real, input,
17-
if (*input_data < min_val)
18-
*output_data = min_val;
19-
else if (*input_data <= max_val)
20-
*output_data = *input_data;
21-
else
22-
*output_data = max_val;
23-
);
20+
if (inplace)
21+
TH_TENSOR_APPLY(real, input,
22+
if (*input_data < min_val)
23+
*input_data = min_val;
24+
else if (*input_data > max_val)
25+
*input_data = max_val;
26+
);
27+
TH_TENSOR_APPLY2(real, output, real, input,
28+
if (*input_data < min_val)
29+
*output_data = min_val;
30+
else if (*input_data <= max_val)
31+
*output_data = *input_data;
32+
else
33+
*output_data = max_val;
34+
);
2435
}
2536
else
2637
{
27-
real* ptr_output = THTensor_(data)(output);
2838
real* ptr_input = THTensor_(data)(input);
39+
real* ptr_output = THTensor_(data)(output);
2940
long i;
41+
long n = THTensor_(nElement)(input);
3042

43+
if (inplace)
3144
#pragma omp parallel for private(i)
32-
for (i = 0; i < THTensor_(nElement)(input); i++)
33-
{
34-
if (ptr_input[i] < min_val)
35-
ptr_output[i] = min_val;
36-
else if (ptr_input[i] <= max_val)
37-
ptr_output[i] = ptr_input[i];
38-
else
39-
ptr_output[i] = max_val;
40-
}
45+
for (i = 0; i < n; i++)
46+
{
47+
if (ptr_input[i] < min_val)
48+
ptr_input[i] = min_val;
49+
else if (ptr_input[i] > max_val)
50+
ptr_input[i] = max_val;
51+
}
52+
else
53+
#pragma omp parallel for private(i)
54+
for (i = 0; i < n; i++)
55+
{
56+
if (ptr_input[i] < min_val)
57+
ptr_output[i] = min_val;
58+
else if (ptr_input[i] <= max_val)
59+
ptr_output[i] = ptr_input[i];
60+
else
61+
ptr_output[i] = max_val;
62+
}
4163
}
4264
}
4365

@@ -47,37 +69,58 @@ void THNN_(HardTanh_updateGradInput)(
4769
THTensor *gradOutput,
4870
THTensor *gradInput,
4971
real min_val,
50-
real max_val)
72+
real max_val,
73+
bool inplace)
5174
{
52-
THTensor_(resizeAs)(gradInput, input);
75+
if (inplace)
76+
THTensor_(set)(gradInput, gradOutput);
77+
else
78+
THTensor_(resizeAs)(gradInput, input);
5379

5480
if (input->nDimension == 1 ||
5581
!THTensor_(isContiguous)(input) ||
5682
!THTensor_(isContiguous)(gradOutput) ||
5783
!THTensor_(isContiguous)(gradInput))
5884
{
59-
TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
60-
if (*input_data < min_val || *input_data > max_val)
61-
*gradInput_data = 0;
62-
else
63-
*gradInput_data = *gradOutput_data;
64-
);
85+
if (inplace)
86+
{
87+
TH_TENSOR_APPLY2(real, gradOutput, real, input,
88+
if (*input_data < min_val || *input_data > max_val)
89+
*gradOutput_data = 0;
90+
);
91+
}
92+
else
93+
TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
94+
if (*input_data < min_val || *input_data > max_val)
95+
*gradInput_data = 0;
96+
else
97+
*gradInput_data = *gradOutput_data;
98+
);
6599
}
66100
else
67101
{
68102
real* ptr_gradOutput = THTensor_(data)(gradOutput);
69103
real* ptr_gradInput = THTensor_(data)(gradInput);
70104
real* ptr_input = THTensor_(data)(input);
71105
long i;
106+
long n = THTensor_(nElement)(input);
72107

108+
if (inplace)
73109
#pragma omp parallel for private(i)
74-
for (i = 0; i < THTensor_(nElement)(input); i++)
75-
{
76-
if (ptr_input[i] < min_val || ptr_input[i] > max_val)
77-
ptr_gradInput[i] = 0;
78-
else
79-
ptr_gradInput[i] = ptr_gradOutput[i];
80-
}
110+
for (i = 0; i < n; i++)
111+
{
112+
if (ptr_input[i] <= min_val || ptr_input[i] >= max_val)
113+
ptr_gradInput[i] = 0;
114+
}
115+
else
116+
#pragma omp parallel for private(i)
117+
for (i = 0; i < n; i++)
118+
{
119+
if (ptr_input[i] < min_val || ptr_input[i] > max_val)
120+
ptr_gradInput[i] = 0;
121+
else
122+
ptr_gradInput[i] = ptr_gradOutput[i];
123+
}
81124
}
82125
}
83126

lib/THNN/generic/ReLU6.c

-58
This file was deleted.

lib/THNN/generic/THNN.h

+4-14
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,16 @@ TH_API void THNN_(HardTanh_updateOutput)(
106106
THTensor *input, // input tensor
107107
THTensor *output, // [OUT] output tensor
108108
real min_val, // lower threshold
109-
real max_val); // upper threshold
109+
real max_val, // upper threshold
110+
bool inplace); // if true, output is set to share input's storage (in-place)
110111
TH_API void THNN_(HardTanh_updateGradInput)(
111112
THNNState *state, // library's state
112113
THTensor *input, // input tensor
113114
THTensor *gradOutput, // gradient w.r.t. module's output
114115
THTensor *gradInput, // [OUT] gradient w.r.t. the input
115116
real min_val, // lower threshold
116-
real max_val); // upper threshold
117+
real max_val, // upper threshold
118+
bool inplace); // if true, gradInput is set to share gradOutput's storage (in-place)
117119

118120
TH_API void THNN_(L1Cost_updateOutput)(
119121
THNNState *state, // library's state
@@ -472,18 +474,6 @@ TH_API void THNN_(Threshold_updateGradInput)(
472474
real threshold,
473475
bool inplace);
474476

475-
TH_API void THNN_(ReLU6_updateOutput)(
476-
THNNState *state,
477-
THTensor *input,
478-
THTensor *output,
479-
bool inplace);
480-
TH_API void THNN_(ReLU6_updateGradInput)(
481-
THNNState *state,
482-
THTensor *input,
483-
THTensor *gradOutput,
484-
THTensor *gradInput,
485-
bool inplace);
486-
487477
TH_API void THNN_(TemporalConvolution_updateOutput)(
488478
THNNState *state,
489479
THTensor *input,

lib/THNN/init.c

-3
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,6 @@
9494
#include "generic/Threshold.c"
9595
#include "THGenerateFloatTypes.h"
9696

97-
#include "generic/ReLU6.c"
98-
#include "THGenerateFloatTypes.h"
99-
10097
#include "generic/TemporalConvolution.c"
10198
#include "THGenerateFloatTypes.h"
10299

test.lua

+2-2
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,10 @@ function nntest.ReLU6()
273273
local lt = input:clone():lt(input, 6)
274274
local output2 = gt:clone():cmul(lt):cmul(input)
275275
output2:add(6, input:clone():gt(input, 6))
276-
mytester:assertTensorEq(output, output2, 0.000001, 'ReLU6 output')
276+
mytester:assertTensorEq(output, output2, 0.000001, 'ReLU6 output '..(inplace and '(inplace)' or '') )
277277
local gradInput = module:backward(input, gradOutput:clone())
278278
local gradInput2 = gt:clone():cmul(lt):cmul(gradOutput)
279-
mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'ReLU gradInput')
279+
mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'ReLU gradInput '..(inplace and '(inplace)' or '') )
280280
end
281281
end
282282

0 commit comments

Comments
 (0)