@@ -772,7 +772,7 @@ def __call__(self, x):
772772 self .handle .device_id = x .device .id ()
773773
774774 y = batchnorm_2d (self .handle , x , self .scale , self .bias ,
775- self .running_mean , self .running_var )
775+ self .running_mean , self .running_var )
776776 return y
777777
778778
@@ -936,3 +936,41 @@ def __init__(self, kernel_size, stride=None, padding=0):
936936 stride = kernel_size
937937 super (MaxPool2d , self ).__init__ (
938938 (1 , kernel_size ), (0 , stride ), (0 , padding ), False )
939+
940+
941+ class _RNN (Operation ):
942+
943+ def __init__ (self , handle ):
944+ self .handle = handle
945+
946+ def forward (self , X , W ):
947+
948+ if self .handle .device_id == - 1 :
949+ raise NotImplementedError
950+ else :
951+ if training :
952+ out , self .cache = singa .GpuRNNForwardTraining (
953+ self .handle , X , W )
954+ else :
955+ out = singa .GpuRNNForwardInference (self .handle , X , W )
956+ return out
957+
958+ def backward (self , dY ):
959+ assert training is True and hasattr (
960+ self , 'cache' ), 'Please set training as True before do BP. '
961+
962+ if dY .device ().id () != self .handle .device_id :
963+ dY .ToDevice (self .inputs [0 ].device ())
964+
965+ if self .handle .device_id == - 1 :
966+ raise NotImplementedError
967+ else :
968+ dX , dW = singa .GpuRNNBackward (self .handle , dY , self .cache )
969+ return dX , dW
970+
971+
972+ def rnn ():
973+ pass
974+
975+
976+ class RNN (Layer ):
0 commit comments