# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for source_remote."""

import os
import traceback

import grpc

from tensorflow.core.debug import debug_service_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import grpc_debug_test_server
from tensorflow.python.debug.lib import source_remote
from tensorflow.python.debug.lib import source_utils
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect


def line_number_above():
  """Returns the line number of the line directly above the call site.

  Reads element [2] of the caller's record from `tf_inspect.stack()` and
  subtracts one, so a call must sit on the line immediately below the
  statement whose line number is wanted.
  """
  return tf_inspect.stack()[1][2] - 1


class SendTracebacksTest(test_util.TensorFlowTestCase):
  """Tests sending graph and eager tracebacks to gRPC debug test servers."""

  @classmethod
  def setUpClass(cls):
    """Starts the two debug test servers shared by all tests in the class."""
    test_util.TensorFlowTestCase.setUpClass()
    (cls._server_port, cls._debug_server_url, cls._server_dump_dir,
     cls._server_thread,
     cls._server) = grpc_debug_test_server.start_server_on_separate_thread(
         poll_server=True)
    cls._server_address = "localhost:%d" % cls._server_port
    (cls._server_port_2, cls._debug_server_url_2, cls._server_dump_dir_2,
     cls._server_thread_2,
     cls._server_2) = grpc_debug_test_server.start_server_on_separate_thread()
    cls._server_address_2 = "localhost:%d" % cls._server_port_2
    cls._curr_file_path = os.path.normpath(os.path.abspath(__file__))

  @classmethod
  def tearDownClass(cls):
    """Stops both test servers and joins their threads."""
    # Stop the test server and join the thread.
    cls._server.stop_server().wait()
    cls._server_thread.join()
    cls._server_2.stop_server().wait()
    cls._server_thread_2.join()
    test_util.TensorFlowTestCase.tearDownClass()

  def tearDown(self):
    """Resets the default graph and clears data held by both servers."""
    ops.reset_default_graph()
    self._server.clear_data()
    self._server_2.clear_data()
    super(SendTracebacksTest, self).tearDown()

  def _findFirstTraceInsideTensorFlowPyLibrary(self, op):
    """Find the first trace of an op that belongs to the TF Python library.

    Args:
      op: An op whose `traceback` entries expose a `filename` attribute.

    Returns:
      The first entry of `op.traceback` for which
      `source_utils.guess_is_tensorflow_py_library` is true, or `None` if no
      entry qualifies.
    """
    for trace in op.traceback:
      if source_utils.guess_is_tensorflow_py_library(trace.filename):
        return trace
    return None

  def testSendGraphTracebacksToSingleDebugServer(self):
    """Sends graph tracebacks to one server and checks what it recorded."""
    this_func_name = "testSendGraphTracebacksToSingleDebugServer"
    with session.Session() as sess:
      # NOTE: every line_number_above() call must stay on the line directly
      # below the statement it measures; do not insert anything in between.
      a = variables.Variable(21.0, name="a")
      a_lineno = line_number_above()
      b = variables.Variable(2.0, name="b")
      b_lineno = line_number_above()
      math_ops.add(a, b, name="x")
      x_lineno = line_number_above()

      send_stack = traceback.extract_stack()
      send_lineno = line_number_above()
      source_remote.send_graph_tracebacks(
          self._server_address, "dummy_run_key", send_stack, sess.graph)

      tb = self._server.query_op_traceback("a")
      self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
      tb = self._server.query_op_traceback("b")
      self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
      tb = self._server.query_op_traceback("x")
      self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)

      self.assertIn(
          (self._curr_file_path, send_lineno, this_func_name),
          self._server.query_origin_stack()[-1])

      # The expected text is this file's own source line for `a`, so its
      # leading whitespace must match the indentation used above.
      self.assertEqual(
          "      a = variables.Variable(21.0, name=\"a\")",
          self._server.query_source_file_line(__file__, a_lineno))
      # Files in the TensorFlow code base should not have been sent.
      tf_trace = self._findFirstTraceInsideTensorFlowPyLibrary(a.op)
      # Fail with a clear message instead of an AttributeError on `.filename`.
      self.assertIsNotNone(tf_trace)
      tf_trace_file_path = tf_trace.filename
      with self.assertRaises(ValueError):
        self._server.query_source_file_line(tf_trace_file_path, 0)
      self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
                       self._server.query_call_types())
      self.assertEqual(["dummy_run_key"], self._server.query_call_keys())
      self.assertEqual(
          [sess.graph.version], self._server.query_graph_versions())

  def testSendGraphTracebacksToTwoDebugServers(self):
    """Sends graph tracebacks to two servers and checks both recorded them."""
    this_func_name = "testSendGraphTracebacksToTwoDebugServers"
    with session.Session() as sess:
      # NOTE: every line_number_above() call must stay on the line directly
      # below the statement it measures; do not insert anything in between.
      a = variables.Variable(21.0, name="two/a")
      a_lineno = line_number_above()
      b = variables.Variable(2.0, name="two/b")
      b_lineno = line_number_above()
      x = math_ops.add(a, b, name="two/x")
      x_lineno = line_number_above()

      send_traceback = traceback.extract_stack()
      send_lineno = line_number_above()

      with test.mock.patch.object(
          grpc, "insecure_channel",
          wraps=grpc.insecure_channel) as mock_grpc_channel:
        source_remote.send_graph_tracebacks(
            [self._server_address, self._server_address_2],
            "dummy_run_key", send_traceback, sess.graph)
        mock_grpc_channel.assert_called_with(
            test.mock.ANY,
            options=[("grpc.max_receive_message_length", -1),
                     ("grpc.max_send_message_length", -1)])

      # The TF-library trace depends only on `a.op`, so look it up once
      # rather than once per server.
      tf_trace = self._findFirstTraceInsideTensorFlowPyLibrary(a.op)
      # Fail with a clear message instead of an AttributeError on `.filename`.
      self.assertIsNotNone(tf_trace)
      tf_trace_file_path = tf_trace.filename

      servers = [self._server, self._server_2]
      for server in servers:
        tb = server.query_op_traceback("two/a")
        self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
        tb = server.query_op_traceback("two/b")
        self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
        tb = server.query_op_traceback("two/x")
        self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)

        self.assertIn(
            (self._curr_file_path, send_lineno, this_func_name),
            server.query_origin_stack()[-1])

        # The expected text is this file's own source line for `x`, so its
        # leading whitespace must match the indentation used above.
        self.assertEqual(
            "      x = math_ops.add(a, b, name=\"two/x\")",
            server.query_source_file_line(__file__, x_lineno))
        with self.assertRaises(ValueError):
          server.query_source_file_line(tf_trace_file_path, 0)
        self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
                         server.query_call_types())
        self.assertEqual(["dummy_run_key"], server.query_call_keys())
        self.assertEqual([sess.graph.version], server.query_graph_versions())

  def testSendEagerTracebacksToSingleDebugServer(self):
    """Sends an eager traceback to one server and checks what it recorded."""
    this_func_name = "testSendEagerTracebacksToSingleDebugServer"
    # NOTE: line_number_above() must stay directly below extract_stack().
    send_traceback = traceback.extract_stack()
    send_lineno = line_number_above()
    source_remote.send_eager_tracebacks(self._server_address, send_traceback)

    self.assertEqual([debug_service_pb2.CallTraceback.EAGER_EXECUTION],
                     self._server.query_call_types())
    self.assertIn((self._curr_file_path, send_lineno, this_func_name),
                  self._server.query_origin_stack()[-1])

  def testGRPCServerMessageSizeLimit(self):
    """Assert gRPC debug server is started with unlimited message size."""
    with test.mock.patch.object(
        grpc, "server", wraps=grpc.server) as mock_grpc_server:
      (_, _, _, server_thread,
       server) = grpc_debug_test_server.start_server_on_separate_thread(
           poll_server=True)
      try:
        mock_grpc_server.assert_called_with(
            test.mock.ANY,
            options=[("grpc.max_receive_message_length", -1),
                     ("grpc.max_send_message_length", -1)])
      finally:
        # Always shut the extra server down, even when the assertion fails,
        # so a failing test does not leave its server thread running.
        server.stop_server().wait()
        server_thread.join()


if __name__ == "__main__":
  googletest.main()