File tree Expand file tree Collapse file tree 4 files changed +8
-39
lines changed
tensorflow_model_optimization/python
examples/sparsity/keras/mnist Expand file tree Collapse file tree 4 files changed +8
-39
lines changed Original file line number Diff line number Diff line change 1919from __future__ import print_function
2020
2121import collections
22+ import os
2223import weakref
2324
2425import tensorflow as tf
2526
2627
2728def _get_keras_instance ():
28- from pkg_resources import parse_version
29-
30- required_tensorflow_version = '2.16.0'
31- if parse_version (tf .__version__ ) < parse_version (required_tensorflow_version ):
32- return tf .keras
29+ # Keep using keras-2 (tf-keras) rather than keras-3 (keras).
30+ os .environ ['TF_USE_LEGACY_KERAS' ] = '1'
3331
32+ # Use Keras 2.
3433 version_fn = getattr (tf .keras , 'version' , None )
3534 if version_fn and version_fn ().startswith ('3.' ):
36- try :
37- import tf_keras as keras
38- except ImportError :
39- pass
40- return tf .keras
35+ import tf_keras as keras_internal # pylint: disable=g-import-not-at-top,unused-import
36+ else :
37+ keras_internal = tf .keras
38+ return keras_internal
4139
4240
4341keras = _get_keras_instance ()
Original file line number Diff line number Diff line change @@ -326,7 +326,6 @@ py_strict_test(
326326 srcs = ["quantize_models_test.py" ],
327327 flaky = True ,
328328 python_version = "PY3" ,
329- shard_count = 10 ,
330329 deps = [
331330 ":quantize" ,
332331 ":utils" ,
@@ -343,8 +342,6 @@ py_strict_test(
343342 size = "large" ,
344343 srcs = ["quantize_functional_test.py" ],
345344 python_version = "PY3" ,
346- # To match parallel runs of run_all_keras_modes.
347- shard_count = 4 ,
348345 deps = [
349346 ":quantize" ,
350347 ":utils" ,
Original file line number Diff line number Diff line change 2121
2222
2323try :
24- # OSS case.
25- import keras # pylint: disable=g-import-not-at-top
2624 if hasattr (keras , 'src' ):
2725 # Path as seen in pip packages as of TF/Keras 2.13.
2826 from keras .src .engine import base_layer # pylint: disable=g-import-not-at-top,g-importing-member
Original file line number Diff line number Diff line change @@ -11,30 +11,6 @@ filegroup(
1111 srcs = glob (["**" ]),
1212)
1313
14- py_strict_binary (
15- name = "mnist_estimator" ,
16- srcs = [
17- "dataset.py" ,
18- "mnist_estimator.py" ,
19- ],
20- python_version = "PY3" ,
21- deps = [
22- # absl/flags dep1,
23- # google/protobuf:use_fast_cpp_protos dep1, # Automatically added
24- # numpy dep1,
25- # six dep1,
26- # tensorflow dep1,
27- # tensorflow:tensorflow_compat_v1_estimator dep1,
28- "//tensorflow_model_optimization/python/core/keras:compat" ,
29- "//tensorflow_model_optimization/python/core/sparsity/keras:estimator_utils" ,
30- "//tensorflow_model_optimization/python/core/sparsity/keras:prune" ,
31- "//tensorflow_model_optimization/python/core/sparsity/keras:pruning_schedule" ,
32- "//third_party/tensorflow_models/official/common:distribute_utils" ,
33- "//third_party/tensorflow_models/official/r1/utils/logs:hooks_helper" ,
34- "//third_party/tensorflow_models/official/utils" ,
35- ],
36- )
37-
3814py_strict_binary (
3915 name = "mnist_cnn" ,
4016 srcs = [
You can’t perform that action at this time.
0 commit comments