1414# ==============================================================================
1515"""Pruning Policy tests."""
1616
17+ import distutils .version as version
18+
1719import tensorflow as tf
1820
1921from tensorflow_model_optimization .python .core .sparsity .keras import prune
2527layers = keras .layers
2628
2729
30+ class CompatGlobalAveragePooling2D (layers .GlobalAveragePooling2D ):
31+ """GlobalAveragePooling2D in tf <= 2.5.0 doesn't support keepdims."""
32+
33+ def __init__ (self , * args , keepdims = False , ** kwargs ):
34+ self ._compat = False
35+ if version .LooseVersion (tf .__version__ ) > version .LooseVersion ('2.5.0' ):
36+ super (CompatGlobalAveragePooling2D , self ).__init__ (
37+ * args , keepdims = keepdims , ** kwargs )
38+ else :
39+ super (CompatGlobalAveragePooling2D , self ).__init__ (* args , ** kwargs )
40+ self ._compat = True
41+ self .keepdims = keepdims
42+
43+ def call (self , inputs ):
44+ if not self ._compat :
45+ return super (CompatGlobalAveragePooling2D , self ).call (inputs )
46+
47+ if self .data_format == 'channels_last' :
48+ return keras .backend .mean (inputs , axis = [1 , 2 ], keepdims = self .keepdims )
49+ else :
50+ return keras .backend .mean (inputs , axis = [2 , 3 ], keepdims = self .keepdims )
51+
52+
2853class PruningPolicyTest (tf .test .TestCase ):
2954 INVALID_TO_PRUNE_START_LAYER_ERROR = (
3055 'Could not find `Conv2D 3x3` layer with stride 2x2, `input filters == 3`'
@@ -69,7 +94,7 @@ def testPruneUnsupportedModelForLatencyOnXNNPackPolicyNoStartLayer(self):
6994 padding = 'same' ,
7095 )(i )
7196 x = layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ])(x )
72- o = layers . GlobalAveragePooling2D (keepdims = True )(x )
97+ o = CompatGlobalAveragePooling2D (keepdims = True )(x )
7398 model = keras .Model (inputs = [i ], outputs = [o ])
7499 with self .assertRaises (ValueError ) as e :
75100 _ = prune .prune_low_magnitude (
@@ -89,7 +114,7 @@ def testPruneUnsupportedModelForLatencyOnXNNPackPolicyNoStopLayer(self):
89114 padding = 'valid' ,
90115 )(x )
91116 x = layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ])(x )
92- o = layers . GlobalAveragePooling2D ()(x )
117+ o = CompatGlobalAveragePooling2D ()(x )
93118 model = keras .Model (inputs = [i ], outputs = [o ])
94119 with self .assertRaises (ValueError ) as e :
95120 _ = prune .prune_low_magnitude (
@@ -110,7 +135,7 @@ def testPruneUnsupportedModelForLatencyOnXNNPackPolicyMiddleLayer(self):
110135 )(x )
111136 x = layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ])(x )
112137 x = layers .MaxPooling2D (pool_size = (2 , 2 ), strides = (2 , 2 ))(x )
113- o = layers . GlobalAveragePooling2D (keepdims = True )(x )
138+ o = CompatGlobalAveragePooling2D (keepdims = True )(x )
114139 model = keras .Model (inputs = [i ], outputs = [o ])
115140 with self .assertRaises (ValueError ) as e :
116141 _ = prune .prune_low_magnitude (
@@ -135,7 +160,7 @@ def testPruneSequentialModelForLatencyOnXNNPackPolicy(self):
135160 ),
136161 layers .DepthwiseConv2D (kernel_size = (3 , 3 ), padding = 'same' ),
137162 layers .Conv2D (filters = 8 , kernel_size = [1 , 1 ]),
138- layers . GlobalAveragePooling2D (keepdims = True ),
163+ CompatGlobalAveragePooling2D (keepdims = True ),
139164 ])
140165 with self .assertRaises (ValueError ) as e :
141166 _ = prune .prune_low_magnitude (
@@ -156,7 +181,7 @@ def testPruneSequentialModelForLatencyOnXNNPackPolicy(self):
156181 ),
157182 layers .DepthwiseConv2D (kernel_size = (3 , 3 ), padding = 'same' ),
158183 layers .Conv2D (filters = 8 , kernel_size = [1 , 1 ]),
159- layers . GlobalAveragePooling2D (keepdims = True ),
184+ CompatGlobalAveragePooling2D (keepdims = True ),
160185 ])
161186 pruned_model = prune .prune_low_magnitude (
162187 model ,
@@ -178,7 +203,7 @@ def testPruneModelRecursivelyForLatencyOnXNNPackPolicy(self):
178203 layers .Conv2D (filters = 8 , kernel_size = [1 , 1 ]),
179204 layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ]),
180205 ]),
181- layers . GlobalAveragePooling2D (keepdims = True ),
206+ CompatGlobalAveragePooling2D (keepdims = True ),
182207 ])
183208 pruned_model = prune .prune_low_magnitude (
184209 original_model ,
@@ -199,7 +224,7 @@ def testPruneFunctionalModelWithLayerReusedForLatencyOnXNNPackPolicy(self):
199224 conv_layer = layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ])
200225 x = conv_layer (x )
201226 x = conv_layer (x )
202- o = layers . GlobalAveragePooling2D (keepdims = True )(x )
227+ o = CompatGlobalAveragePooling2D (keepdims = True )(x )
203228 model = keras .Model (inputs = [i ], outputs = [o ])
204229 pruned_model = prune .prune_low_magnitude (
205230 model ,
@@ -219,7 +244,7 @@ def testFunctionalModelNoPruningLayersForLatencyOnXNNPackPolicy(self):
219244 )(x )
220245 x = layers .DepthwiseConv2D (kernel_size = (3 , 3 ), padding = 'same' )(x )
221246 x = layers .Activation ('relu' )(x )
222- o = layers . GlobalAveragePooling2D (keepdims = True )(x )
247+ o = CompatGlobalAveragePooling2D (keepdims = True )(x )
223248 model = keras .Model (inputs = [i ], outputs = [o ])
224249
225250 pruned_model = prune .prune_low_magnitude (
@@ -256,7 +281,7 @@ def testFunctionalModelForLatencyOnXNNPackPolicy(self):
256281 strides = (2 , 2 ),
257282 padding = 'valid' ,
258283 )(x2 )
259- x2_1 = layers . GlobalAveragePooling2D (keepdims = True )(x2 )
284+ x2_1 = CompatGlobalAveragePooling2D (keepdims = True )(x2 )
260285 x2_1 = layers .Conv2D (filters = 32 , kernel_size = [1 , 1 ])(x2_1 )
261286 x2_1 = layers .Activation ('sigmoid' )(x2_1 )
262287 x2_2 = layers .Conv2D (filters = 32 , kernel_size = [1 , 1 ])(x2 )
@@ -265,7 +290,7 @@ def testFunctionalModelForLatencyOnXNNPackPolicy(self):
265290
266291 x2 = layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ])(x2 )
267292 x = layers .Add ()([x1 , x2 ])
268- x = layers . GlobalAveragePooling2D (keepdims = True )(x )
293+ x = CompatGlobalAveragePooling2D (keepdims = True )(x )
269294
270295 o1 = layers .Conv2D (filters = 7 , kernel_size = [1 , 1 ])(x )
271296 o2 = layers .Conv2D (filters = 3 , kernel_size = [1 , 1 ])(x )
@@ -289,7 +314,7 @@ def testPruneFunctionalModelAfterCloneForLatencyOnXNNPackPolicy(self):
289314 )(
290315 x )
291316 x = layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ])(x )
292- o = layers . GlobalAveragePooling2D (keepdims = True )(x )
317+ o = CompatGlobalAveragePooling2D (keepdims = True )(x )
293318 original_model = keras .Model (inputs = [i ], outputs = [o ])
294319
295320 cloned_model = tf .keras .models .clone_model (
@@ -315,7 +340,7 @@ def testFunctionalModelWithTFOpsForLatencyOnXNNPackPolicy(self):
315340 x = x - residual
316341 x = x * residual
317342 x = tf .identity (x )
318- o = layers . GlobalAveragePooling2D (keepdims = True )(x )
343+ o = CompatGlobalAveragePooling2D (keepdims = True )(x )
319344 model = keras .Model (inputs = [i ], outputs = [o ])
320345
321346 pruned_model = prune .prune_low_magnitude (
0 commit comments