@@ -342,62 +342,45 @@ def execute(self, x):
342
342
343
343
class BatchNorm1d(Module):
    """Batch normalization over a 2D (N, C) or 3D (N, C, L) input.

    Normalizes each feature channel using batch statistics in training
    mode and the tracked running statistics in evaluation mode.

    Args:
        num_features: number of feature channels C.
        eps: small constant added to the variance for numerical stability.
        momentum: update factor for the running mean/variance
            (running = running + (batch_stat - running) * momentum).
        affine: if truthy, learn per-channel scale (`weight`) and shift
            (`bias`); if falsy/None (default), the normalized output is
            returned unscaled.
        is_train: if True use batch statistics and update running stats;
            otherwise use the stored running statistics.
        sync: if True and running under MPI, all-reduce the batch
            statistics across workers (synchronized batch norm).
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
        self.sync = sync
        self.num_features = num_features
        self.is_train = is_train
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        # Scale/shift parameters exist only when affine is enabled;
        # execute() guards on self.affine before touching them.
        if affine:
            self.weight = init.constant((num_features,), "float32", 1.0)
            self.bias = init.constant((num_features,), "float32", 0.0)
        # Running statistics are buffers, not trainable parameters.
        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

    def execute(self, x):
        """Normalize x; x is (N, C) or (N, C, L) with C == num_features."""
        # Reduce over every axis except the channel axis (dim 1).
        if len(x.shape) == 3:
            dims = [0, 2]
        else:
            dims = [0]
        if self.is_train:
            xmean = jt.mean(x, dims=dims, keepdims=1)
            x2mean = jt.mean(x * x, dims=dims, keepdims=1)

            if self.sync and jt.in_mpi:
                xmean = xmean.mpi_all_reduce("mean")
                x2mean = x2mean.mpi_all_reduce("mean")

            # Var = E[x^2] - E[x]^2 suffers catastrophic cancellation when
            # the true variance is tiny; clamp at 0 so sqrt never sees a
            # (slightly) negative value and produces NaN.
            xvar = jt.maximum(x2mean - xmean * xmean, 0.0)
            norm_x = (x - xmean) / jt.sqrt(xvar + self.eps)
            # Exponential moving average update of the tracked statistics;
            # sum(dims) collapses the kept singleton axes back to shape (C,).
            self.running_mean.update(self.running_mean +
                (xmean.sum(dims) - self.running_mean) * self.momentum)
            self.running_var.update(self.running_var +
                (xvar.sum(dims) - self.running_var) * self.momentum)
        else:
            # Evaluation mode: use the stored running statistics.
            running_mean = self.running_mean.broadcast(x, dims)
            running_var = self.running_var.broadcast(x, dims)
            norm_x = (x - running_mean) / jt.sqrt(running_var + self.eps)

        if not self.affine:
            return norm_x
        w = self.weight.broadcast(x, dims)
        b = self.bias.broadcast(x, dims)
        return norm_x * w + b
403
386
class InstanceNorm2d (Module ):
0 commit comments