xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/source_remote.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Communicating tracebacks and source code with debug server."""
16
17import socket
18
19import grpc
20
21from tensorflow.core.debug import debug_service_pb2
22from tensorflow.core.protobuf import debug_pb2
23from tensorflow.python.debug.lib import common
24from tensorflow.python.debug.lib import debug_service_pb2_grpc
25from tensorflow.python.debug.lib import source_utils
26from tensorflow.python.platform import gfile
27from tensorflow.python.profiler import tfprof_logger
28
29
def _load_debugged_source_file(file_path, source_file_proto):
  """Populate a debugged-source-file proto from a file on disk.

  File metadata (host name, path, modification time and size) is always
  filled in. The text lines of the file are added on a best-effort basis: if
  the file cannot be opened, read, or decoded as text, `lines` is simply left
  empty instead of aborting the transmission of the rest of the debug payload.

  Args:
    file_path: Path of the source file, accessed through `gfile`.
    source_file_proto: The proto to populate (an element of
      `DebuggedSourceFiles.source_files`). Its `host`, `file_path`,
      `last_modified`, `bytes` and `lines` fields are written.
  """
  file_stat = gfile.Stat(file_path)
  source_file_proto.host = socket.gethostname()
  source_file_proto.file_path = file_path
  source_file_proto.last_modified = file_stat.mtime_nsec
  source_file_proto.bytes = file_stat.length
  try:
    with gfile.Open(file_path, "r") as f:
      source_file_proto.lines.extend(f.read().splitlines())
  except (IOError, UnicodeDecodeError):
    # Best effort: an unreadable file, or one whose bytes cannot be decoded as
    # text (UnicodeDecodeError is a ValueError, not an IOError), must not
    # prevent the tracebacks and other source files from being sent.
    pass
41
42
43def _string_to_id(string, string_to_id):
44  if string not in string_to_id:
45    string_to_id[string] = len(string_to_id)
46  return string_to_id[string]
47
48
def _format_origin_stack(origin_stack, call_traceback_proto):
  """Fill in the origin stack of a `CallTraceback` proto.

  File paths, function names and source-line texts are interned as integer
  ids. Id 0 is reserved for `None`, which is written out as the empty string
  in the id-to-string table.

  Args:
    origin_stack: The stack list as returned by `traceback.extract_stack()`.
      Each frame unpacks to `(file_path, lineno, func_name, line_text)`.
    call_traceback_proto: A `CallTraceback` proto whose `origin_stack.traces`
      and `origin_id_to_string` fields are to be populated.
  """
  # Seed the table so that None always maps to id 0.
  interned = {None: 0}
  for path, line_number, function_name, source_text in origin_stack:
    call_traceback_proto.origin_stack.traces.add(
        file_id=_string_to_id(path, interned),
        lineno=line_number,
        function_id=_string_to_id(function_name, interned),
        line_id=_string_to_id(source_text, interned))

  # Invert the interning table into the proto's id -> string map.
  lookup = call_traceback_proto.origin_id_to_string
  for text, text_id in interned.items():
    lookup[text_id] = "" if text is None else text
70
71
def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string):
  """Collect paths of existing source files that are not TensorFlow's own.

  The file ids referenced by all traces are gathered (deduplicated) eagerly.
  The id-to-path lookup and the filtering happen lazily, as the result is
  iterated.

  Args:
    code_defs: An iterable of `CodeDef` protos, i.e., an iterable of stack
      traces.
    id_to_string: A proto map from integer ids to strings.

  Returns:
    A generator of source file paths that `source_utils` does not guess to be
    part of the TensorFlow Python library and that exist according to `gfile`.
  """
  referenced_ids = {
      trace.file_id for code_def in code_defs for trace in code_def.traces}
  candidate_paths = (id_to_string[file_id] for file_id in referenced_ids)
  return (
      path for path in candidate_paths
      if gfile.Exists(path)
      and not source_utils.guess_is_tensorflow_py_library(path)
  ) if False else (
      path for path in candidate_paths
      if not source_utils.guess_is_tensorflow_py_library(path)
      and gfile.Exists(path))
92
93
def _send_call_tracebacks(destinations,
                          origin_stack,
                          is_eager_execution=False,
                          call_key=None,
                          graph=None,
                          send_source=True):
  """Send the tracebacks of a TensorFlow execution call to gRPC debug server(s).

  This applies to graph execution (`tf.Session.run()`) calls and eager
  execution calls.

  If `send_source`, also sends the underlying source files outside the
  TensorFlow library.

  Args:
    destinations: gRPC destination addresses, a `str` or a `list` of `str`s,
      e.g., "localhost:4242". If a `list`, gRPC requests containing the same
      `CallTraceback` proto payload will be sent to all the destinations.
      A leading `common.GRPC_URL_PREFIX` is stripped from each address.
    origin_stack: The traceback stack for the origin of the execution call. For
      graph execution, this is the traceback of the `tf.Session.run()`
      invocation. For eager execution, this is the traceback of the Python
      line that executes the eager operation.
    is_eager_execution: (`bool`) whether an eager execution call (i.e., not a
      `tf.Session.run` or derived methods) is being sent.
    call_key: The key of the execution call, as a string. For graph execution,
      this is a string describing the feeds, fetches (and targets) names of the
      `tf.Session.run` call. For eager execution, this is ignored.
    graph: A Python `tf.Graph` object (i.e., *not* a `tf.compat.v1.GraphDef`),
      which contains op tracebacks, if applicable.
    send_source: Whether the source files involved in the op tracebacks but
      outside the TensorFlow library are to be sent.
  """
  if not isinstance(destinations, list):
    destinations = [destinations]
  # Strip grpc:// prefix, if any is present.
  destinations = [
      dest[len(common.GRPC_URL_PREFIX):]
      if dest.startswith(common.GRPC_URL_PREFIX) else dest
      for dest in destinations]

  call_type = (debug_service_pb2.CallTraceback.EAGER_EXECUTION
               if is_eager_execution
               else debug_service_pb2.CallTraceback.GRAPH_EXECUTION)
  graph_traceback = tfprof_logger.merge_default_with_oplog(
      graph, add_trainable_var=False) if graph else None
  call_traceback = debug_service_pb2.CallTraceback(
      call_type=call_type, call_key=call_key, graph_traceback=graph_traceback,
      graph_version=graph.version if graph else None)

  _format_origin_stack(origin_stack, call_traceback)

  # Source-file payloads are built once and reused for every destination; the
  # list stays empty when `send_source` is False.
  debugged_source_files = []
  if send_source:
    source_file_paths = set()
    source_file_paths.update(_source_file_paths_outside_tensorflow_py_library(
        (log_entry.code_def for log_entry
         in call_traceback.graph_traceback.log_entries),
        call_traceback.graph_traceback.id_to_string))
    source_file_paths.update(_source_file_paths_outside_tensorflow_py_library(
        [call_traceback.origin_stack], call_traceback.origin_id_to_string))

    for file_path in source_file_paths:
      source_files = debug_pb2.DebuggedSourceFiles()
      _load_debugged_source_file(
          file_path, source_files.source_files.add())
      debugged_source_files.append(source_files)

  # Loop-invariant channel options: lift the gRPC message-size limits, since
  # tracebacks and source files can be large.
  no_max_message_sizes = [("grpc.max_receive_message_length", -1),
                          ("grpc.max_send_message_length", -1)]
  for destination in destinations:
    # Close each channel once its requests have completed; previously the
    # channel was never closed, leaking one channel per destination per call.
    with grpc.insecure_channel(
        destination, options=no_max_message_sizes) as channel:
      stub = debug_service_pb2_grpc.EventListenerStub(channel)
      stub.SendTracebacks(call_traceback)
      for source_files in debugged_source_files:
        stub.SendSourceFiles(source_files)
170
171
def send_graph_tracebacks(destinations,
                          run_key,
                          origin_stack,
                          graph,
                          send_source=True):
  """Send the tracebacks of a `tf.Session.run()` call to debug server(s).

  This is the graph-execution entry point; it forwards to the shared
  traceback sender with `is_eager_execution=False`.

  Args:
    destinations: gRPC destination addresses, a `str` or a `list` of `str`s,
      e.g., "localhost:4242". If a `list`, gRPC requests containing the same
      `CallTraceback` proto payload will be sent to all the destinations.
    run_key: A string describing the feeds, fetches (and targets) names of the
      `tf.Session.run` call. Sent as the call key.
    origin_stack: The traceback of the `tf.Session.run()` invocation.
    graph: A Python `tf.Graph` object (i.e., *not* a `tf.compat.v1.GraphDef`),
      which contains op tracebacks.
    send_source: Whether the source files involved in the op tracebacks but
      outside the TensorFlow library are to be sent.
  """
  _send_call_tracebacks(
      destinations=destinations,
      origin_stack=origin_stack,
      is_eager_execution=False,
      call_key=run_key,
      graph=graph,
      send_source=send_source,
  )
194
195
def send_eager_tracebacks(destinations,
                          origin_stack,
                          send_source=True):
  """Send the tracebacks of an eager execution call to debug server(s).

  No call key or graph is sent for eager execution.

  Args:
    destinations: gRPC destination addresses, a `str` or a `list` of `str`s,
      e.g., "localhost:4242". If a `list`, gRPC requests containing the same
      `CallTraceback` proto payload will be sent to all the destinations.
    origin_stack: The traceback of the eager operation invocation.
    send_source: Whether the source files involved in the op tracebacks but
      outside the TensorFlow library are to be sent.
  """
  _send_call_tracebacks(
      destinations, origin_stack, is_eager_execution=True,
      send_source=send_source)
211