# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the reconstruction of non-debugger-decorated GraphDefs."""
import tempfile

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_graphs
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
  """Checks that debugger-decorated partition graphs can be un-decorated.

  Each test builds a small graph and runs it twice: once plainly and once
  with debug watches on every node. The partition graphs of the debugged run
  are reconstructed into non-debug form and compared, after normalization,
  against the partition graphs of the plain run.
  """

  # Op types removed from both sides before comparing. NOTE(review): these
  # presumably are partition-boundary / fetch ops whose presence may differ
  # between the plain and the debugged run -- confirm.
  _OP_TYPE_DENYLIST = ("_Send", "_Recv", "_HostSend", "_HostRecv", "_Retval")

  def _no_rewrite_session_config(self):
    """Returns a ConfigProto with selected Grappler rewrites turned off.

    Dependency optimization and pin-to-host optimization are set to OFF, and
    `min_graph_nodes` is set to -1. Per the RewriterConfig docs, a negative
    value means small graphs are not skipped by the optimizer.

    Returns:
      A `config_pb2.ConfigProto` carrying the rewriter settings above.
    """
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        min_graph_nodes=-1)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
    return config_pb2.ConfigProto(graph_options=graph_options)

  def setUp(self):
    """Creates a fresh dump directory and resets the default graph."""
    super().setUp()
    self._dump_dir = tempfile.mkdtemp()
    self._debug_url = "file://" + self._dump_dir
    ops.reset_default_graph()

  def tearDown(self):
    """Deletes the dump directory created in `setUp`."""
    file_io.delete_recursively(self._dump_dir)
    super().tearDown()

  def _graphDefWithoutDenylistedNodes(self, graph_def):
    """Returns a normalized copy of `graph_def` suitable for comparison.

    Nodes whose op type is in `_OP_TYPE_DENYLIST` are dropped. Among the kept
    nodes, `Enter` nodes get their `parallel_iterations` attr (when present)
    forced to 1, and `Switch` / `Identity` nodes have their inputs cleared.

    Args:
      graph_def: The `GraphDef` to normalize. It is only read, never modified.

    Returns:
      A new `graph_pb2.GraphDef`.
    """
    output_graph_def = graph_pb2.GraphDef()
    for node in graph_def.node:
      if node.op in self._OP_TYPE_DENYLIST:
        continue
      new_node = output_graph_def.node.add()
      new_node.CopyFrom(node)

      if new_node.op == "Enter":
        # The debugger sets parallel_iterations attribute of while-loop Enter
        # nodes to 1 for debugging, so normalize it on both sides. The
        # membership check is required: indexing a protobuf map with a
        # missing key would insert that key.
        if "parallel_iterations" in new_node.attr:
          new_node.attr["parallel_iterations"].i = 1
      elif new_node.op in ("Switch", "Identity"):
        # We don't check the inputs to Switch or Identity ops as their inputs
        # may be Send/Recv nodes.
        del new_node.input[:]

    return output_graph_def

  def _compareOriginalAndReconstructedGraphDefs(self,
                                                sess,
                                                fetches,
                                                feed_dict=None,
                                                expected_output=None):
    """Asserts that reconstructed non-debug graphs match the original ones.

    `fetches` is run twice: once plainly, then again after
    `debug_utils.watch_graph` has added debug watches to the run options.
    The debugged run's partition graphs are turned back into non-debug form
    in two ways: via `DebugDumpDir.reconstructed_non_debug_partition_graphs`
    and via `debug_graphs.reconstruct_non_debug_graph_def`. Each result is
    compared per device, after normalization, against the plain run's
    partition graph.

    Args:
      sess: The `Session` to run in.
      fetches: The fetches passed to `sess.run`.
      feed_dict: Optional feed dict passed to `sess.run`.
      expected_output: If not None, both runs' outputs must be all-close to
        this value.
    """
    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    run_metadata = config_pb2.RunMetadata()
    output = sess.run(fetches, feed_dict=feed_dict, options=run_options,
                      run_metadata=run_metadata)
    if expected_output is not None:
      self.assertAllClose(expected_output, output)
    non_debug_graph_defs = run_metadata.partition_graphs

    # watch_graph mutates run_options in place, so the second run is debugged.
    debug_utils.watch_graph(
        run_options, sess.graph, debug_urls=self._debug_url)
    run_metadata = config_pb2.RunMetadata()
    output = sess.run(fetches, feed_dict=feed_dict, options=run_options,
                      run_metadata=run_metadata)
    if expected_output is not None:
      self.assertAllClose(expected_output, output)

    dump = debug_data.DebugDumpDir(
        self._dump_dir, partition_graphs=run_metadata.partition_graphs,
        validate=True)
    reconstructed = dump.reconstructed_non_debug_partition_graphs()

    # Pair partitions by device name, the same way `reconstructed` is looked
    # up below, instead of assuming that the plain and debugged runs list
    # their partition graphs in the same order.
    debug_graph_defs_by_device = {
        debug_graphs._infer_device_name(graph_def): graph_def
        for graph_def in run_metadata.partition_graphs
    }

    self.assertEqual(len(non_debug_graph_defs), len(reconstructed))
    for non_debug_graph_def in non_debug_graph_defs:
      device_name = debug_graphs._infer_device_name(non_debug_graph_def)
      expected_graph_def = self._graphDefWithoutDenylistedNodes(
          non_debug_graph_def)
      test_util.assert_equal_graph_def(
          self._graphDefWithoutDenylistedNodes(reconstructed[device_name]),
          expected_graph_def)

      # Test debug_graphs.reconstruct_non_debug_graph_def.
      self.assertIn(device_name, debug_graph_defs_by_device)
      reconstructed_again = debug_graphs.reconstruct_non_debug_graph_def(
          debug_graph_defs_by_device[device_name])
      test_util.assert_equal_graph_def(
          self._graphDefWithoutDenylistedNodes(reconstructed_again),
          expected_graph_def)

  def testReconstructSimpleGraph(self):
    """Reconstructs a graph that adds two variables (12 + 30 = 42)."""
    with session.Session() as sess:
      u = variables.Variable([12.0], name="u")
      v = variables.Variable([30.0], name="v")
      w = math_ops.add(u, v, name="w")
      self.evaluate(u.initializer)
      self.evaluate(v.initializer)

      self._compareOriginalAndReconstructedGraphDefs(
          sess, w, expected_output=[42.0])

  def testReconstructGraphWithControlEdge(self):
    """Reconstructs a graph containing explicit control dependencies."""
    with session.Session() as sess:
      a = variables.Variable(10.0, name="a")
      with ops.control_dependencies([a]):
        b = math_ops.add(a, a, name="b")
      with ops.control_dependencies([a, b]):
        c = math_ops.multiply(b, b, name="c")
      self.evaluate(a.initializer)

      self._compareOriginalAndReconstructedGraphDefs(
          sess, c, expected_output=400.0)

  def testReconstructGraphWithCond(self):
    """Reconstructs a cond graph; x (10) < y (20), so y + 1 = 21 is taken."""
    with session.Session(config=self._no_rewrite_session_config()) as sess:
      x = variables.Variable(10.0, name="x")
      y = variables.Variable(20.0, name="y")
      cond = control_flow_ops.cond(
          x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
      self.evaluate(x.initializer)
      self.evaluate(y.initializer)

      self._compareOriginalAndReconstructedGraphDefs(
          sess, cond, expected_output=21.0)

  def testReconstructGraphWithWhileLoop(self):
    """Reconstructs a while-loop graph: 10 -> 12 -> 14 -> 16, then stops."""
    with session.Session(config=self._no_rewrite_session_config()) as sess:
      loop_body = lambda i: math_ops.add(i, 2)
      loop_cond = lambda i: math_ops.less(i, 16)
      i = constant_op.constant(10, name="i")
      loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])

      self._compareOriginalAndReconstructedGraphDefs(
          sess, loop, expected_output=16)

  def testReconstructGraphWithGradients(self):
    """Reconstructs a graph containing a gradient-descent training op."""
    with session.Session(config=self._no_rewrite_session_config()) as sess:
      u = variables.Variable(12.0, name="u")
      v = variables.Variable(30.0, name="v")
      x = constant_op.constant(1.1, name="x")
      toy_loss = x * (u - v)
      train_op = gradient_descent.GradientDescentOptimizer(
          learning_rate=0.1).minimize(toy_loss, name="train_op")
      self.evaluate(u.initializer)
      self.evaluate(v.initializer)

      self._compareOriginalAndReconstructedGraphDefs(sess, train_op)


if __name__ == "__main__":
  test.main()