xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the reconstruction of non-debugger-decorated GraphDefs."""
16import tempfile
17
18from tensorflow.core.framework import graph_pb2
19from tensorflow.core.protobuf import config_pb2
20from tensorflow.core.protobuf import rewriter_config_pb2
21from tensorflow.python.client import session
22from tensorflow.python.debug.lib import debug_data
23from tensorflow.python.debug.lib import debug_graphs
24from tensorflow.python.debug.lib import debug_utils
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import test_util
28from tensorflow.python.lib.io import file_io
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import variables
32from tensorflow.python.platform import test
33from tensorflow.python.training import gradient_descent
34
35
class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
  """Tests reconstruction of non-debug partition graphs from debug dumps.

  Each test builds a small graph and runs it twice: once without debug
  watches and once with them. It then checks that the partition graphs
  reconstructed from the debug run match the partition graphs of the plain
  run, ignoring denylisted op types.
  """

  # Op types removed from both GraphDefs before they are compared.
  _OP_TYPE_DENYLIST = ("_Send", "_Recv", "_HostSend", "_HostRecv", "_Retval")

  def _no_rewrite_session_config(self):
    """Returns a ConfigProto with selected graph-rewriter passes turned off.

    Dependency optimization and pin-to-host optimization are set to OFF, and
    `min_graph_nodes` is set to -1. Presumably this keeps the debug and
    non-debug runs from being rewritten differently; confirm the exact
    `min_graph_nodes` semantics against rewriter_config.proto.
    """
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        min_graph_nodes=-1)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
    return config_pb2.ConfigProto(graph_options=graph_options)

  def setUp(self):
    """Creates a fresh dump directory and resets the default graph."""
    super().setUp()
    self._dump_dir = tempfile.mkdtemp()
    self._debug_url = "file://" + self._dump_dir
    ops.reset_default_graph()

  def tearDown(self):
    """Removes the dump directory created in setUp."""
    file_io.delete_recursively(self._dump_dir)
    super().tearDown()

  def _graphDefWithoutDenylistedNodes(self, graph_def):
    """Returns a normalized copy of `graph_def` suitable for comparison.

    Nodes whose op type is in `_OP_TYPE_DENYLIST` are dropped. Among the kept
    copies, the `parallel_iterations` attr of Enter nodes is forced to 1, and
    the inputs of Switch and Identity nodes are cleared. The input
    `graph_def` is not modified.

    Args:
      graph_def: A `GraphDef` proto.

    Returns:
      A new `GraphDef` proto.
    """
    output_graph_def = graph_pb2.GraphDef()
    for node in graph_def.node:
      if node.op in self._OP_TYPE_DENYLIST:
        continue
      new_node = output_graph_def.node.add()
      new_node.CopyFrom(node)

      if new_node.op == "Enter":
        # The debugger sets the parallel_iterations attribute of while-loop
        # Enter nodes to 1 for debugging, so normalize it on both sides.
        # The membership test avoids inserting the key when it is absent.
        if "parallel_iterations" in new_node.attr:
          new_node.attr["parallel_iterations"].i = 1
      elif new_node.op in ("Switch", "Identity"):
        # We don't check the inputs to Switch or Identity ops as their inputs
        # may be Send/Recv nodes.
        del new_node.input[:]

    return output_graph_def

  def _compareOriginalAndReconstructedGraphDefs(self,
                                                sess,
                                                fetches,
                                                feed_dict=None,
                                                expected_output=None):
    """Runs `fetches` without and with debug watches and compares graphs.

    Both `DebugDumpDir.reconstructed_non_debug_partition_graphs` and
    `debug_graphs.reconstruct_non_debug_graph_def` must reproduce the
    partition graphs of the undecorated run, modulo
    `_graphDefWithoutDenylistedNodes` normalization.

    Args:
      sess: The `Session` in which to run `fetches`.
      fetches: Fetches passed to `sess.run`.
      feed_dict: Optional feed dict passed to `sess.run`.
      expected_output: If not None, the output of both runs is checked
        against this value with `assertAllClose`.
    """
    run_options = config_pb2.RunOptions(output_partition_graphs=True)

    # First run: no debug watches. These are the reference partition graphs.
    run_metadata = config_pb2.RunMetadata()
    output = sess.run(fetches, feed_dict=feed_dict, options=run_options,
                      run_metadata=run_metadata)
    if expected_output is not None:
      self.assertAllClose(expected_output, output)
    non_debug_graph_defs = run_metadata.partition_graphs

    # Second run: add debug watches to the same RunOptions. Tensors are dumped
    # to self._dump_dir and the emitted partition graphs are debug-decorated.
    debug_utils.watch_graph(
        run_options, sess.graph, debug_urls=self._debug_url)
    run_metadata = config_pb2.RunMetadata()
    output = sess.run(fetches, feed_dict=feed_dict, options=run_options,
                      run_metadata=run_metadata)
    if expected_output is not None:
      self.assertAllClose(expected_output, output)
    debug_graph_defs = run_metadata.partition_graphs

    dump = debug_data.DebugDumpDir(
        self._dump_dir, partition_graphs=debug_graph_defs, validate=True)
    reconstructed = dump.reconstructed_non_debug_partition_graphs()

    self.assertEqual(len(non_debug_graph_defs), len(reconstructed))

    # Pair the two runs' partition graphs by inferred device name, the same
    # way `reconstructed` is keyed, instead of relying on both runs listing
    # their partition graphs in the same order.
    debug_graph_defs_by_device = {
        debug_graphs._infer_device_name(graph_def): graph_def
        for graph_def in debug_graph_defs
    }
    for non_debug_graph_def in non_debug_graph_defs:
      device_name = debug_graphs._infer_device_name(non_debug_graph_def)
      expected_graph_def = self._graphDefWithoutDenylistedNodes(
          non_debug_graph_def)
      test_util.assert_equal_graph_def(
          self._graphDefWithoutDenylistedNodes(reconstructed[device_name]),
          expected_graph_def)

      # Test debug_graphs.reconstruct_non_debug_graph_def.
      reconstructed_again = debug_graphs.reconstruct_non_debug_graph_def(
          debug_graph_defs_by_device[device_name])
      test_util.assert_equal_graph_def(
          self._graphDefWithoutDenylistedNodes(reconstructed_again),
          expected_graph_def)

  def testReconstructSimpleGraph(self):
    with session.Session() as sess:
      u = variables.Variable([12.0], name="u")
      v = variables.Variable([30.0], name="v")
      w = math_ops.add(u, v, name="w")
      self.evaluate(u.initializer)
      self.evaluate(v.initializer)

      self._compareOriginalAndReconstructedGraphDefs(
          sess, w, expected_output=[42.0])

  def testReconstructGraphWithControlEdge(self):
    with session.Session() as sess:
      a = variables.Variable(10.0, name="a")
      with ops.control_dependencies([a]):
        b = math_ops.add(a, a, name="b")
      with ops.control_dependencies([a, b]):
        c = math_ops.multiply(b, b, name="c")
      self.evaluate(a.initializer)

      self._compareOriginalAndReconstructedGraphDefs(
          sess, c, expected_output=400.0)

  def testReconstructGraphWithCond(self):
    with session.Session(config=self._no_rewrite_session_config()) as sess:
      x = variables.Variable(10.0, name="x")
      y = variables.Variable(20.0, name="y")
      cond = control_flow_ops.cond(
          x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
      self.evaluate(x.initializer)
      self.evaluate(y.initializer)

      self._compareOriginalAndReconstructedGraphDefs(
          sess, cond, expected_output=21.0)

  def testReconstructGraphWithWhileLoop(self):
    with session.Session(config=self._no_rewrite_session_config()) as sess:
      loop_body = lambda i: math_ops.add(i, 2)
      loop_cond = lambda i: math_ops.less(i, 16)
      i = constant_op.constant(10, name="i")
      loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])

      # 10 -> 12 -> 14 -> 16, at which point `i < 16` is false.
      self._compareOriginalAndReconstructedGraphDefs(
          sess, loop, expected_output=16)

  def testReconstructGraphWithGradients(self):
    with session.Session(config=self._no_rewrite_session_config()) as sess:
      u = variables.Variable(12.0, name="u")
      v = variables.Variable(30.0, name="v")
      x = constant_op.constant(1.1, name="x")
      toy_loss = x * (u - v)
      train_op = gradient_descent.GradientDescentOptimizer(
          learning_rate=0.1).minimize(toy_loss, name="train_op")
      self.evaluate(u.initializer)
      self.evaluate(v.initializer)

      self._compareOriginalAndReconstructedGraphDefs(sess, train_op)
176
if __name__ == "__main__":
  # Delegate to the TensorFlow test entry point imported as `test`
  # (tensorflow.python.platform.test).
  test.main()
179