@@ -291,14 +291,18 @@ def forward(self, x):
291291 return x
292292
293293
294+ def is_stem_deep (stem_type ):
295+ return any ([s in stem_type for s in ('deep' , 'tiered' )])
296+
297+
294298def create_resnetv2_stem (
295299 in_chs , out_chs = 64 , stem_type = '' , preact = True ,
296300 conv_layer = StdConv2d , norm_layer = partial (GroupNormAct , num_groups = 32 )):
297301 stem = OrderedDict ()
298302 assert stem_type in ('' , 'fixed' , 'same' , 'deep' , 'deep_fixed' , 'deep_same' , 'tiered' )
299303
300304 # NOTE conv padding mode can be changed by overriding the conv_layer def
301- if any ([ s in stem_type for s in ( 'deep' , 'tiered' )] ):
305+ if is_stem_deep ( stem_type ):
302306 # A 3 deep 3x3 conv stack as in ResNet V1D models
303307 if 'tiered' in stem_type :
304308 stem_chs = (3 * out_chs // 8 , out_chs // 2 ) # 'T' resnets in resnet.py
@@ -350,7 +354,7 @@ def __init__(
350354 stem_chs = make_div (stem_chs * wf )
351355 self .stem = create_resnetv2_stem (
352356 in_chans , stem_chs , stem_type , preact , conv_layer = conv_layer , norm_layer = norm_layer )
353- stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv' ) if preact else 'stem.norm'
357+ stem_feat = ('stem.conv3' if is_stem_deep ( stem_type ) else 'stem.conv' ) if preact else 'stem.norm'
354358 self .feature_info .append (dict (num_chs = stem_chs , reduction = 2 , module = stem_feat ))
355359
356360 prev_chs = stem_chs
0 commit comments