Skip to content

Commit c7f75eb

Browse files
Add some Tensorflow graph traversal utility functions.
PiperOrigin-RevId: 517108819
1 parent 043e8b5 commit c7f75eb

File tree

4 files changed

+164
-2
lines changed

4 files changed

+164
-2
lines changed

tensorflow_privacy/privacy/fast_gradient_clipping/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,20 @@ load("@rules_python//python:defs.bzl", "py_library", "py_test")
22

33
package(default_visibility = ["//visibility:public"])
44

5+
py_library(
6+
name = "tensorflow_graph_utils",
7+
srcs = ["tensorflow_graph_utils.py"],
8+
srcs_version = "PY3",
9+
)
10+
11+
py_test(
12+
name = "tensorflow_graph_utils_test",
13+
srcs = ["tensorflow_graph_utils_test.py"],
14+
python_version = "PY3",
15+
srcs_version = "PY3",
16+
deps = [":tensorflow_graph_utils"],
17+
)
18+
519
py_library(
620
name = "gradient_clipping_utils",
721
srcs = ["gradient_clipping_utils.py"],

tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
2424

25-
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
25+
GeneratorFunction = Callable[[Any, Tuple, Dict], Tuple[Any, Any]]
2626

2727

2828
def has_internal_compute_graph(input_object: Any):
@@ -52,7 +52,7 @@ def _get_internal_layers(
5252
def model_forward_pass(
5353
input_model: tf.keras.Model,
5454
inputs: InputTensor,
55-
generator_fn: GeneratorFunction = None,
55+
generator_fn: Optional[GeneratorFunction] = None,
5656
) -> Tuple[tf.Tensor, List[Any]]:
5757
"""Does a forward pass of a model and returns useful intermediates.
5858
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2022, The TensorFlow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Utility functions that help in traversing Tensorflow graphs."""
15+
16+
from typing import Any, Callable, Dict, Iterable, Optional, Set, Text, Union
17+
18+
import tensorflow as tf
19+
20+
PackedTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
21+
22+
LayerFunction = Callable[[tf.keras.layers.Layer], None]
23+
24+
25+
def depth_first_backward_pass(
26+
outputs: PackedTensor, layer_function: Optional[LayerFunction] = None
27+
):
28+
"""Performs a depth-first traversal on a given set of model outputs.
29+
30+
This function is simplified version of
31+
`tf.keras.engine.functional._build_map()` that allows additional side-effects
32+
performed by an (optional) layer function.
33+
34+
Args:
35+
outputs: A `PackedTensor` that should be generated by calling a
36+
`tf.keras.Model` on a set of non-eager inputs.
37+
layer_function: A callable that consumes a `tf.keras.layers.Layer`. This
38+
callable is applied to every layer in the DAG that generates `outputs`.
39+
"""
40+
41+
# Helper function that performs the traversal.
42+
def graph_crawler(
43+
tensor: tf.Tensor, finished_nodes: Set[Any], nodes_in_progress: Set[Any]
44+
):
45+
layer, node_index, _ = tensor._keras_history # pylint: disable=protected-access
46+
node = layer._inbound_nodes[node_index] # pylint: disable=protected-access
47+
# Avoid duplicating work on shared subgraphs.
48+
if node in finished_nodes:
49+
return
50+
# Check if we encountered a cycle.
51+
if node in nodes_in_progress:
52+
raise ValueError(
53+
f'Tensor {tensor} from layer "{layer.name}" is part of a cycle.'
54+
)
55+
# Apply side-effects and go to the next node (pre-order traversal).
56+
if layer_function is not None:
57+
layer_function(layer)
58+
nodes_in_progress.add(node)
59+
if not node.is_input:
60+
for tensor in node.keras_inputs:
61+
graph_crawler(tensor, finished_nodes, nodes_in_progress)
62+
finished_nodes.add(node)
63+
nodes_in_progress.remove(node)
64+
65+
# Traverse over the outputs.
66+
finished_nodes = set()
67+
nodes_in_progress = set()
68+
for output in tf.nest.flatten(outputs):
69+
graph_crawler(output, finished_nodes, nodes_in_progress)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2023, The TensorFlow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from absl.testing import parameterized
16+
import tensorflow as tf
17+
18+
from tensorflow_privacy.privacy.fast_gradient_clipping import tensorflow_graph_utils
19+
20+
21+
# ==============================================================================
22+
# Main tests.
23+
# ==============================================================================
24+
class DepthFirstBackwardPassTest(tf.test.TestCase, parameterized.TestCase):
25+
26+
@parameterized.product(
27+
input_packing_type=[None, tuple, list, dict],
28+
output_packing_type=[None, tuple, list, dict],
29+
)
30+
def test_layer_function(self, input_packing_type, output_packing_type):
31+
num_dims = 3
32+
num_inputs = 1 if input_packing_type is None else 2
33+
num_outputs = 1 if output_packing_type is None else 2
34+
sample_inputs = [tf.keras.Input((num_dims,)) for i in range(num_inputs)]
35+
temp_sum = tf.stack(sample_inputs, axis=0)
36+
sample_sum = [
37+
tf.multiply(temp_sum, float(i + 1.0)) for i in range(num_outputs)
38+
]
39+
sample_outputs = [tf.keras.layers.Dense(3)(t) for t in sample_sum]
40+
41+
# Pack inputs.
42+
if input_packing_type is None:
43+
inputs = sample_inputs[0]
44+
elif input_packing_type is not dict:
45+
inputs = input_packing_type(sample_inputs)
46+
else:
47+
inputs = {}
48+
keys = [str(i) for i in range(len(sample_inputs))]
49+
for k, v in zip(keys, sample_inputs):
50+
inputs[k] = v
51+
52+
# Pack outputs.
53+
if output_packing_type is None:
54+
outputs = sample_outputs[0]
55+
elif output_packing_type is not dict:
56+
outputs = output_packing_type(sample_outputs)
57+
else:
58+
outputs = {}
59+
keys = [str(i) for i in range(len(sample_outputs))]
60+
for k, v in zip(keys, sample_outputs):
61+
outputs[k] = v
62+
63+
# Append the trainable layers into a list.
64+
layer_list = []
65+
66+
def layer_function(layer):
67+
if layer.trainable_variables:
68+
layer_list.append(layer)
69+
70+
# Run the traversal and verify the outputs that are relevant to
71+
# the above layer function.
72+
tensorflow_graph_utils.depth_first_backward_pass(outputs, layer_function)
73+
self.assertLen(layer_list, num_outputs)
74+
for l in layer_list:
75+
self.assertIsInstance(l, tf.keras.layers.Dense)
76+
77+
78+
if __name__ == '__main__':
79+
tf.test.main()

0 commit comments

Comments
 (0)