
Commit be1a51a

Merge pull request torch#849 from jonathantompson/relu6
Added ReLU6 layer, test and doc.
2 parents d23a8f5 + 3113c82 commit be1a51a

File tree

8 files changed (+149, -0 lines)

ReLU6.lua (new file, +35 lines)

```lua
local ReLU6, parent = torch.class('nn.ReLU6', 'nn.Module')

function ReLU6:__init(inplace)
   parent.__init(self)

   if inplace == nil then
      self.inplace = false
   else
      self.inplace = inplace
   end

   if (inplace and type(inplace) ~= 'boolean') then
      error('in-place flag must be boolean')
   end
end

function ReLU6:updateOutput(input)
   input.THNN.ReLU6_updateOutput(
      input:cdata(),
      self.output:cdata(),
      self.inplace
   )
   return self.output
end

function ReLU6:updateGradInput(input, gradOutput)
   input.THNN.ReLU6_updateGradInput(
      input:cdata(),
      gradOutput:cdata(),
      self.gradInput:cdata(),
      self.inplace
   )
   return self.gradInput
end
```
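
A minimal usage sketch (values chosen here for illustration, not taken from the commit): the forward pass behaves like an elementwise clamp to `[0, 6]`.

```lua
require 'nn'

local m = nn.ReLU6()                    -- out-of-place by default
local x = torch.Tensor{-2, 0, 3, 6, 9}
local y = m:forward(x)                  -- expected: 0, 0, 3, 6, 6
-- same result as a plain clamp, for comparison:
print(y)
print(x:clone():clamp(0, 6))
```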

doc/image/relu6.png (binary image, 19.6 KB)

doc/transfer.md (+23 lines)

The following section is inserted after the existing `ReLU` section, just before `PReLU`:

<a name="nn.ReLU6"></a>
## ReLU6 ##

Same as `ReLU` except that the rectifying function `f(x)` saturates at `x = 6`. This layer is useful for training networks that should not lose precision (due to FP saturation) when implemented in FP16.

`ReLU6` is defined as `f(x) = min(max(0, x), 6)`.

Can optionally do its operation in-place without using extra state memory:

```lua
m=nn.ReLU6(true) -- true = in-place, false = keeping separate state.
```

```lua
ii=torch.linspace(-3, 9)
m=nn.ReLU6()
oo=m:forward(ii)
go=torch.ones(100)
gi=m:backward(ii,go)
gnuplot.plot({'f(x)',ii,oo,'+-'},{'df/dx',ii,gi,'+-'})
gnuplot.grid(true)
```
![](image/relu6.png)
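
One property of the in-place option worth spelling out (a small sketch, assuming the storage aliasing implied by the `THTensor_(set)` call in the C file below): the output tensor shares the input's storage, so the input itself is overwritten.

```lua
require 'nn'

local m = nn.ReLU6(true)          -- in-place mode
local x = torch.Tensor{-2, 4, 8}
local y = m:forward(x)
print(x)                          -- x is now 0, 4, 6: clobbered in place
-- y aliases x's storage, so no extra memory is allocated:
print(torch.pointer(y:storage()) == torch.pointer(x:storage()))  -- true
```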

init.lua (+1 line)

```diff
@@ -83,6 +83,7 @@ require('nn.HardShrink')
 require('nn.SoftShrink')
 require('nn.Threshold')
 require('nn.ReLU')
+require('nn.ReLU6')
 require('nn.PReLU')
 require('nn.LeakyReLU')
 require('nn.SpatialSoftMax')
```

lib/THNN/generic/ReLU6.c (new file, +58 lines)

```c
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/ReLU6.c"
#else

void THNN_(ReLU6_updateOutput)(
          THNNState *state,
          THTensor *input,
          THTensor *output,
          bool inplace)
{
  if (inplace)
  {
    TH_TENSOR_APPLY(real, input,
      if (*input_data <= 0)
        *input_data = 0;
      else if (*input_data >= 6)
        *input_data = 6;
    );
    THTensor_(set)(output, input);
  }
  else
  {
    THTensor_(resizeAs)(output, input);
    TH_TENSOR_APPLY2(real, output, real, input,
      *output_data =
        (*input_data > 0) ? ((*input_data < 6) ? *input_data : 6) : 0;
    );
  }
}

void THNN_(ReLU6_updateGradInput)(
          THNNState *state,
          THTensor *input,
          THTensor *gradOutput,
          THTensor *gradInput,
          bool inplace)
{
  if (inplace)
  {
    TH_TENSOR_APPLY2(real, gradOutput, real, input,
      if ((*input_data) <= 0 || (*input_data) >= 6)
        *gradOutput_data = 0;
    );
    THTensor_(set)(gradInput, gradOutput);
  }
  else
  {
    THTensor_(resizeAs)(gradInput, input);
    TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
      if ((*input_data) > 0 && (*input_data) < 6)
        *gradInput_data = *gradOutput_data;
      else
        *gradInput_data = 0;
    );
  }
}

#endif
```
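
The gradient rule above passes `gradOutput` through only where `0 < input < 6` and zeroes it in both saturated regions. A quick check from the Lua side (illustrative values, not part of the commit):

```lua
require 'nn'

local m = nn.ReLU6()
local x = torch.Tensor{-1, 3, 7}
m:forward(x)                             -- forward must run before backward
local gi = m:backward(x, torch.ones(3))
print(gi)                                -- expected: 0, 1, 0
```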

lib/THNN/generic/THNN.h (+12 lines)

```diff
@@ -472,6 +472,18 @@ TH_API void THNN_(Threshold_updateGradInput)(
           real threshold,
           bool inplace);
 
+TH_API void THNN_(ReLU6_updateOutput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *output,
+          bool inplace);
+TH_API void THNN_(ReLU6_updateGradInput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *gradOutput,
+          THTensor *gradInput,
+          bool inplace);
+
 TH_API void THNN_(TemporalConvolution_updateOutput)(
           THNNState *state,
           THTensor *input,
```
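
These declarations are what the Lua module reaches through the per-type `THNN` function table; a rough sketch of the raw call path (the binding itself lives in nn's `THNN.lua` wrapper, which is not part of this diff):

```lua
require 'nn'

local input  = torch.FloatTensor{-1, 3, 7}
local output = torch.FloatTensor()
-- input.THNN dispatches on the tensor type, so for a FloatTensor this
-- call resolves to the C symbol THNN_FloatReLU6_updateOutput declared above
input.THNN.ReLU6_updateOutput(input:cdata(), output:cdata(), false)
print(output)  -- expected: 0, 3, 6
```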

lib/THNN/init.c (+3 lines)

```diff
@@ -94,6 +94,9 @@
 #include "generic/Threshold.c"
 #include "THGenerateFloatTypes.h"
 
+#include "generic/ReLU6.c"
+#include "THGenerateFloatTypes.h"
+
 #include "generic/TemporalConvolution.c"
 #include "THGenerateFloatTypes.h"
```

test.lua (+17 lines)

```diff
@@ -263,6 +263,23 @@ function nntest.ReLU()
    mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'ReLU gradInput')
 end
 
+function nntest.ReLU6()
+   for inplace = 0, 1 do
+      local input = torch.randn(3, 4):mul(6)
+      local gradOutput = torch.randn(3, 4)
+      local module = nn.ReLU6(inplace == 1)
+      local output = module:forward(input:clone())
+      local gt = input:clone():gt(input, 0)
+      local lt = input:clone():lt(input, 6)
+      local output2 = gt:clone():cmul(lt):cmul(input)
+      output2:add(6, input:clone():gt(input, 6))
+      mytester:assertTensorEq(output, output2, 0.000001, 'ReLU6 output')
+      local gradInput = module:backward(input, gradOutput:clone())
+      local gradInput2 = gt:clone():cmul(lt):cmul(gradOutput)
+      mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'ReLU6 gradInput')
+   end
+end
+
 function nntest.Exp()
    local ini = math.random(3,5)
    local inj = math.random(3,5)
```
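
The reference tensors in this test are assembled from comparison masks; for readers decoding that construction, a commented restatement (with the shortcut that the expected forward output is simply a clamp):

```lua
require 'torch'

local input = torch.randn(3, 4):mul(6)
local gt = input:clone():gt(input, 0)            -- 1 where input > 0
local lt = input:clone():lt(input, 6)            -- 1 where input < 6
local output2 = gt:clone():cmul(lt):cmul(input)  -- keep the linear region
output2:add(6, input:clone():gt(input, 6))       -- set the saturated region to 6
-- equivalent, and simpler:
local expected = input:clone():clamp(0, 6)
print((output2 - expected):abs():max())          -- 0
```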
