
Commit b176cb4

SINGA-386 Implement RNN operation for autograd
- fix bugs in the C++ parts so that the code builds without errors.

1 parent 33ddc2d

File tree

3 files changed (+238, -84 lines)


python/singa/autograd.py

Lines changed: 39 additions & 1 deletion
@@ -772,7 +772,7 @@ def __call__(self, x):
         self.handle.device_id = x.device.id()

         y = batchnorm_2d(self.handle, x, self.scale, self.bias,
-                        self.running_mean, self.running_var)
+                         self.running_mean, self.running_var)
         return y

@@ -936,3 +936,41 @@ def __init__(self, kernel_size, stride=None, padding=0):
             stride = kernel_size
         super(MaxPool2d, self).__init__(
             (1, kernel_size), (0, stride), (0, padding), False)
+
+
+class _RNN(Operation):
+
+    def __init__(self, handle):
+        self.handle = handle
+
+    def forward(self, X, W):
+        if self.handle.device_id == -1:
+            # no CPU kernel yet; only the cuDNN path is implemented
+            raise NotImplementedError
+        else:
+            if training:
+                # training mode caches the cuDNN workspace for backward
+                out, self.cache = singa.GpuRNNForwardTraining(
+                    self.handle, X, W)
+            else:
+                out = singa.GpuRNNForwardInference(self.handle, X, W)
+        return out
+
+    def backward(self, dY):
+        assert training is True and hasattr(
+            self, 'cache'), 'Please set training to True before doing BP.'
+
+        # gradients may arrive from an operation that ran on another device
+        if dY.device().id() != self.handle.device_id:
+            dY.ToDevice(self.inputs[0].device())
+
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        else:
+            dX, dW = singa.GpuRNNBackward(self.handle, dY, self.cache)
+            return dX, dW
+
+
+def rnn():
+    pass
+
+
+class RNN(Layer):
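
For context, training in the new code is the module-level flag in singa.autograd that switches the package between training and inference: forward caches state only when the flag is set, and backward consumes that cache. Below is a minimal stand-alone sketch of this flag-gated pattern. It has no SINGA dependency; _ToyRNN, its multiply "kernels", and the simplified Operation base class are illustrative stand-ins, not SINGA APIs.

# Stand-alone sketch of the training-flag pattern used by _RNN above.
# Everything here is a toy stand-in, not a SINGA API.
training = False  # module-level switch, like singa.autograd.training


class Operation:
    def __call__(self, *xs):
        self.inputs = xs
        return self.forward(*xs)


class _ToyRNN(Operation):
    def forward(self, X, W):
        if training:
            # training path keeps what backward needs, like self.cache above
            out, self.cache = X * W, (X, W)
        else:
            out = X * W  # inference path caches nothing
        return out

    def backward(self, dY):
        assert training and hasattr(self, 'cache'), \
            'Please set training to True before doing BP.'
        X, W = self.cache
        return dY * W, dY * X  # dX, dW


training = True
op = _ToyRNN()
y = op(3.0, 2.0)           # forward caches (X, W) because training is True
dX, dW = op.backward(1.0)
print(y, dX, dW)           # 6.0 2.0 3.0

The real backward additionally moves dY onto the handle's device before calling singa.GpuRNNBackward, since the incoming gradient may live on a different device than the one the cuDNN handle was created for.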
