xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/source_remote_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Unit tests for source_remote."""
16
17import os
18import traceback
19
20import grpc
21
22from tensorflow.core.debug import debug_service_pb2
23from tensorflow.python.client import session
24from tensorflow.python.debug.lib import grpc_debug_test_server
25from tensorflow.python.debug.lib import source_remote
26from tensorflow.python.debug.lib import source_utils
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import test_util
29from tensorflow.python.ops import math_ops
30# Import resource_variable_ops for the variables-to-tensor implicit conversion.
31from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
32from tensorflow.python.ops import variables
33from tensorflow.python.platform import googletest
34from tensorflow.python.platform import test
35from tensorflow.python.util import tf_inspect
36
37
def line_number_above():
  """Returns the line number of the line just above the calling line.

  Tests invoke this on the line immediately following a statement of
  interest (e.g. ``a = ...`` then ``a_lineno = line_number_above()``), so the
  result is the line number of that preceding statement.
  """
  caller_info = tf_inspect.stack()[1]
  calling_lineno = caller_info[2]
  return calling_lineno - 1
40
41
class SendTracebacksTest(test_util.TensorFlowTestCase):
  """Tests for source_remote.send_graph_tracebacks / send_eager_tracebacks.

  Two gRPC debug test servers are started once for the whole class. Each test
  sends tracebacks to one or both of them and then queries what the servers
  recorded (op tracebacks, origin stack, source lines, call types/keys and
  graph versions).
  """

  @classmethod
  def setUpClass(cls):
    test_util.TensorFlowTestCase.setUpClass()
    # Primary test server, started with poll_server=True.
    (cls._server_port, cls._debug_server_url, cls._server_dump_dir,
     cls._server_thread,
     cls._server) = grpc_debug_test_server.start_server_on_separate_thread(
         poll_server=True)
    cls._server_address = "localhost:%d" % cls._server_port
    # Secondary test server, used when sending to multiple servers.
    (cls._server_port_2, cls._debug_server_url_2, cls._server_dump_dir_2,
     cls._server_thread_2,
     cls._server_2) = grpc_debug_test_server.start_server_on_separate_thread()
    cls._server_address_2 = "localhost:%d" % cls._server_port_2
    # Normalized absolute path of this file; compared against the file paths
    # in the tracebacks the servers record.
    cls._curr_file_path = os.path.normpath(os.path.abspath(__file__))

  @classmethod
  def tearDownClass(cls):
    # Stop both test servers and join their threads.
    cls._server.stop_server().wait()
    cls._server_thread.join()
    cls._server_2.stop_server().wait()
    cls._server_thread_2.join()
    test_util.TensorFlowTestCase.tearDownClass()

  def tearDown(self):
    # Reset graph and server-side state so tests do not observe each other.
    ops.reset_default_graph()
    self._server.clear_data()
    self._server_2.clear_data()
    super(SendTracebacksTest, self).tearDown()

  def _findFirstTraceInsideTensorFlowPyLibrary(self, op):
    """Find the first trace of an op that belongs to the TF Python library.

    Args:
      op: The op whose `traceback` frames are scanned in order.

    Returns:
      The first traceback frame whose filename is guessed to belong to the
      TensorFlow Python library.

    Fails the current test with a descriptive message if no frame qualifies,
    instead of returning None and letting callers crash on `.filename`.
    """
    for trace in op.traceback:
      if source_utils.guess_is_tensorflow_py_library(trace.filename):
        return trace
    self.fail("No traceback frame of op %r lies inside the TensorFlow Python "
              "library." % op.name)

  def testSendGraphTracebacksToSingleDebugServer(self):
    """Graph tracebacks sent to one server are fully recorded by it."""
    this_func_name = "testSendGraphTracebacksToSingleDebugServer"
    with session.Session() as sess:
      a = variables.Variable(21.0, name="a")
      a_lineno = line_number_above()
      b = variables.Variable(2.0, name="b")
      b_lineno = line_number_above()
      math_ops.add(a, b, name="x")
      x_lineno = line_number_above()

      send_stack = traceback.extract_stack()
      send_lineno = line_number_above()
      source_remote.send_graph_tracebacks(
          self._server_address, "dummy_run_key", send_stack, sess.graph)

      tb = self._server.query_op_traceback("a")
      self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
      tb = self._server.query_op_traceback("b")
      self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
      tb = self._server.query_op_traceback("x")
      self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)

      self.assertIn(
          (self._curr_file_path, send_lineno, this_func_name),
          self._server.query_origin_stack()[-1])

      # The expected text includes this method's exact indentation.
      self.assertEqual(
          "      a = variables.Variable(21.0, name=\"a\")",
          self._server.query_source_file_line(__file__, a_lineno))
      # Files in the TensorFlow code base should not have been sent.
      tf_trace = self._findFirstTraceInsideTensorFlowPyLibrary(a.op)
      tf_trace_file_path = tf_trace.filename
      with self.assertRaises(ValueError):
        self._server.query_source_file_line(tf_trace_file_path, 0)
      self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
                       self._server.query_call_types())
      self.assertEqual(["dummy_run_key"], self._server.query_call_keys())
      self.assertEqual(
          [sess.graph.version], self._server.query_graph_versions())

  def testSendGraphTracebacksToTwoDebugServers(self):
    """Graph tracebacks sent to two servers are recorded by both."""
    this_func_name = "testSendGraphTracebacksToTwoDebugServers"
    with session.Session() as sess:
      a = variables.Variable(21.0, name="two/a")
      a_lineno = line_number_above()
      b = variables.Variable(2.0, name="two/b")
      b_lineno = line_number_above()
      x = math_ops.add(a, b, name="two/x")
      x_lineno = line_number_above()

      send_traceback = traceback.extract_stack()
      send_lineno = line_number_above()

      # Wrap (not replace) grpc.insecure_channel to check the channel options.
      with test.mock.patch.object(
          grpc, "insecure_channel",
          wraps=grpc.insecure_channel) as mock_grpc_channel:
        source_remote.send_graph_tracebacks(
            [self._server_address, self._server_address_2],
            "dummy_run_key", send_traceback, sess.graph)
        mock_grpc_channel.assert_called_with(
            test.mock.ANY,
            options=[("grpc.max_receive_message_length", -1),
                     ("grpc.max_send_message_length", -1)])

      servers = [self._server, self._server_2]
      for server in servers:
        tb = server.query_op_traceback("two/a")
        self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
        tb = server.query_op_traceback("two/b")
        self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
        tb = server.query_op_traceback("two/x")
        self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)

        self.assertIn(
            (self._curr_file_path, send_lineno, this_func_name),
            server.query_origin_stack()[-1])

        # The expected text includes this method's exact indentation.
        self.assertEqual(
            "      x = math_ops.add(a, b, name=\"two/x\")",
            server.query_source_file_line(__file__, x_lineno))
        # Files in the TensorFlow code base should not have been sent.
        tf_trace = self._findFirstTraceInsideTensorFlowPyLibrary(a.op)
        tf_trace_file_path = tf_trace.filename
        with self.assertRaises(ValueError):
          server.query_source_file_line(tf_trace_file_path, 0)
        self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
                         server.query_call_types())
        self.assertEqual(["dummy_run_key"], server.query_call_keys())
        self.assertEqual([sess.graph.version], server.query_graph_versions())

  def testSendEagerTracebacksToSingleDebugServer(self):
    """Eager tracebacks are recorded with the EAGER_EXECUTION call type."""
    this_func_name = "testSendEagerTracebacksToSingleDebugServer"
    send_traceback = traceback.extract_stack()
    send_lineno = line_number_above()
    source_remote.send_eager_tracebacks(self._server_address, send_traceback)

    self.assertEqual([debug_service_pb2.CallTraceback.EAGER_EXECUTION],
                     self._server.query_call_types())
    self.assertIn((self._curr_file_path, send_lineno, this_func_name),
                  self._server.query_origin_stack()[-1])

  def testGRPCServerMessageSizeLimit(self):
    """Assert gRPC debug server is started with unlimited message size."""
    with test.mock.patch.object(
        grpc, "server", wraps=grpc.server) as mock_grpc_server:
      (_, _, _, server_thread,
       server) = grpc_debug_test_server.start_server_on_separate_thread(
           poll_server=True)
    try:
      # The mock keeps its recorded calls after the patch is undone.
      mock_grpc_server.assert_called_with(
          test.mock.ANY,
          options=[("grpc.max_receive_message_length", -1),
                   ("grpc.max_send_message_length", -1)])
    finally:
      # Always shut down this extra server, even if the assertion fails;
      # otherwise its thread would never be joined.
      server.stop_server().wait()
      server_thread.join()
192
193
if __name__ == "__main__":
  # Entry point when this test module is executed directly.
  googletest.main()
196