# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""GRPC debug server for testing."""
import collections
import errno
import functools
import hashlib
import json
import os
import re
import tempfile
import threading
import time

import portpicker

from tensorflow.core.debug import debug_service_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import grpc_debug_server
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.util import compat


def _get_dump_file_path(dump_root, device_name, debug_node_name):
  """Get the file path of the dump file for a debug node.

  Args:
    dump_root: (str) Root dump directory.
    device_name: (str) Name of the device that the debug node resides on.
    debug_node_name: (str) Name of the debug node, e.g.,
      cross_entropy/Log:0:DebugIdentity.

  Returns:
    (str) Full path of the dump file.
  """

  dump_root = os.path.join(
      dump_root, debug_data.device_name_to_device_path(device_name))
  if "/" in debug_node_name:
    dump_dir = os.path.join(dump_root, os.path.dirname(debug_node_name))
    dump_file_name = re.sub(":", "_", os.path.basename(debug_node_name))
  else:
    dump_dir = dump_root
    dump_file_name = re.sub(":", "_", debug_node_name)

  now_microsec = int(round(time.time() * 1000 * 1000))
  dump_file_name += "_%d" % now_microsec

  return os.path.join(dump_dir, dump_file_name)
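
# Example (illustrative sketch; the concrete values are hypothetical): for
# dump_root="/tmp/tfdbg_dump",
# device_name="/job:localhost/replica:0/task:0/cpu:0" and
# debug_node_name="cross_entropy/Log:0:DebugIdentity", the helper above
# yields a path of the form
#   /tmp/tfdbg_dump/<device_path>/cross_entropy/Log_0_DebugIdentity_<microsec>
# where <device_path> comes from debug_data.device_name_to_device_path().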


class EventListenerTestStreamHandler(
    grpc_debug_server.EventListenerBaseStreamHandler):
  """Implementation of EventListenerBaseStreamHandler that dumps to file."""

  def __init__(self, dump_dir, event_listener_servicer):
    super(EventListenerTestStreamHandler, self).__init__()
    self._dump_dir = dump_dir
    self._event_listener_servicer = event_listener_servicer
    if self._dump_dir:
      self._try_makedirs(self._dump_dir)

    self._grpc_path = None
    self._cached_graph_defs = []
    self._cached_graph_def_device_names = []
    self._cached_graph_def_wall_times = []

  def on_core_metadata_event(self, event):
    self._event_listener_servicer.toggle_watch()

    core_metadata = json.loads(event.log_message.message)

    if not self._grpc_path:
      grpc_path = core_metadata["grpc_path"]
      if grpc_path:
        if grpc_path.startswith("/"):
          grpc_path = grpc_path[1:]
      if self._dump_dir:
        self._dump_dir = os.path.join(self._dump_dir, grpc_path)

        # Write cached graph defs to filesystem.
        for graph_def, device_name, wall_time in zip(
            self._cached_graph_defs,
            self._cached_graph_def_device_names,
            self._cached_graph_def_wall_times):
          self._write_graph_def(graph_def, device_name, wall_time)

    if self._dump_dir:
      self._write_core_metadata_event(event)
    else:
      self._event_listener_servicer.core_metadata_json_strings.append(
          event.log_message.message)

  def on_graph_def(self, graph_def, device_name, wall_time):
    """Implementation of the GraphDef-carrying Event proto callback.

    Args:
      graph_def: A GraphDef object.
      device_name: Name of the device on which the graph was created.
      wall_time: An epoch timestamp (in microseconds) for the graph.
    """
    if self._dump_dir:
      if self._grpc_path:
        self._write_graph_def(graph_def, device_name, wall_time)
      else:
        self._cached_graph_defs.append(graph_def)
        self._cached_graph_def_device_names.append(device_name)
        self._cached_graph_def_wall_times.append(wall_time)
    else:
      self._event_listener_servicer.partition_graph_defs.append(graph_def)

  def on_value_event(self, event):
    """Implementation of the tensor value-carrying Event proto callback.

    Writes the Event proto to the file system for testing. The path written to
    follows the same pattern as the file:// debug URLs of tfdbg, i.e., the
    name scope of the op becomes the directory structure under the dump root
    directory.

    Args:
      event: The Event proto carrying a tensor value.

    Returns:
      If the debug node belongs to the set of currently activated breakpoints,
      an `EventReply` proto will be returned.
    """
    if self._dump_dir:
      self._write_value_event(event)
    else:
      value = event.summary.value[0]
      tensor_value = debug_data.load_tensor_from_event(event)
      self._event_listener_servicer.debug_tensor_values[value.node_name].append(
          tensor_value)

      items = event.summary.value[0].node_name.split(":")
      node_name = items[0]
      output_slot = int(items[1])
      debug_op = items[2]
      if ((node_name, output_slot, debug_op) in
          self._event_listener_servicer.breakpoints):
        return debug_service_pb2.EventReply()

  def _try_makedirs(self, dir_path):
    if not os.path.isdir(dir_path):
      try:
        os.makedirs(dir_path)
      except OSError as error:
        if error.errno != errno.EEXIST:
          raise

  def _write_core_metadata_event(self, event):
    core_metadata_path = os.path.join(
        self._dump_dir,
        debug_data.METADATA_FILE_PREFIX + debug_data.CORE_METADATA_TAG +
        "_%d" % event.wall_time)
    self._try_makedirs(self._dump_dir)
    with open(core_metadata_path, "wb") as f:
      f.write(event.SerializeToString())

  def _write_graph_def(self, graph_def, device_name, wall_time):
    encoded_graph_def = graph_def.SerializeToString()
    graph_hash = int(hashlib.sha1(encoded_graph_def).hexdigest(), 16)
    event = event_pb2.Event(graph_def=encoded_graph_def, wall_time=wall_time)
    graph_file_path = os.path.join(
        self._dump_dir,
        debug_data.device_name_to_device_path(device_name),
        debug_data.METADATA_FILE_PREFIX + debug_data.GRAPH_FILE_TAG +
        debug_data.HASH_TAG + "%d_%d" % (graph_hash, wall_time))
    self._try_makedirs(os.path.dirname(graph_file_path))
    with open(graph_file_path, "wb") as f:
      f.write(event.SerializeToString())

  def _write_value_event(self, event):
    value = event.summary.value[0]

    # Obtain the device name from the metadata.
    summary_metadata = event.summary.value[0].metadata
    if not summary_metadata.plugin_data:
      raise ValueError("The value lacks plugin data.")
    try:
      content = json.loads(compat.as_text(summary_metadata.plugin_data.content))
    except ValueError as err:
      # `content` is unbound if json.loads() fails, so report the raw
      # plugin_data content instead.
      raise ValueError("Could not parse content into JSON: %r, %r" %
                       (summary_metadata.plugin_data.content, err))
    device_name = content["device"]

    dump_full_path = _get_dump_file_path(
        self._dump_dir, device_name, value.node_name)
    self._try_makedirs(os.path.dirname(dump_full_path))
    with open(dump_full_path, "wb") as f:
      f.write(event.SerializeToString())
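
  # Note (illustrative sketch): the plugin_data content parsed above is
  # expected to be a JSON string carrying at least a "device" key, e.g.
  #   {"device": "/job:localhost/replica:0/task:0/device:CPU:0"}
  # Only the "device" entry is consumed by _write_value_event.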


class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
  """An implementation of EventListenerBaseServicer for testing."""

  def __init__(self, server_port, dump_dir, toggle_watch_on_core_metadata=None):
    """Constructor of EventListenerTestServicer.

    Args:
      server_port: (int) The server port number.
      dump_dir: (str) The root directory to which the data files will be
        dumped. If empty or None, the received debug data will not be dumped
        to the file system: they will be stored in memory instead.
      toggle_watch_on_core_metadata: A list of
        (node_name, output_slot, debug_op) tuples to toggle the
        watchpoint status during the on_core_metadata calls (optional).
    """
    self.core_metadata_json_strings = []
    self.partition_graph_defs = []
    self.debug_tensor_values = collections.defaultdict(list)
    self._initialize_toggle_watch_state(toggle_watch_on_core_metadata)

    grpc_debug_server.EventListenerBaseServicer.__init__(
        self, server_port,
        functools.partial(EventListenerTestStreamHandler, dump_dir, self))

    # Members for storing the graph ops traceback and source files.
    self._call_types = []
    self._call_keys = []
    self._origin_stacks = []
    self._origin_id_to_strings = []
    self._graph_tracebacks = []
    self._graph_versions = []
    self._source_files = []

  def _initialize_toggle_watch_state(self, toggle_watches):
    self._toggle_watches = toggle_watches
    self._toggle_watch_state = {}
    if self._toggle_watches:
      for watch_key in self._toggle_watches:
        self._toggle_watch_state[watch_key] = False

  def toggle_watch(self):
    for watch_key in self._toggle_watch_state:
      node_name, output_slot, debug_op = watch_key
      if self._toggle_watch_state[watch_key]:
        self.request_unwatch(node_name, output_slot, debug_op)
      else:
        self.request_watch(node_name, output_slot, debug_op)
      self._toggle_watch_state[watch_key] = (
          not self._toggle_watch_state[watch_key])
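
  # Example (illustrative sketch; the node name is hypothetical): constructing
  # the servicer with
  #   toggle_watch_on_core_metadata=[("delta", 0, "DebugIdentity")]
  # makes each incoming core-metadata event alternately call request_watch()
  # and request_unwatch() for the watch key ("delta", 0, "DebugIdentity").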

  def clear_data(self):
    self.core_metadata_json_strings = []
    self.partition_graph_defs = []
    self.debug_tensor_values = collections.defaultdict(list)
    self._call_types = []
    self._call_keys = []
    self._origin_stacks = []
    self._origin_id_to_strings = []
    self._graph_tracebacks = []
    self._graph_versions = []
    self._source_files = []

  def SendTracebacks(self, request, context):
    self._call_types.append(request.call_type)
    self._call_keys.append(request.call_key)
    self._origin_stacks.append(request.origin_stack)
    self._origin_id_to_strings.append(request.origin_id_to_string)
    self._graph_tracebacks.append(request.graph_traceback)
    self._graph_versions.append(request.graph_version)
    return debug_service_pb2.EventReply()

  def SendSourceFiles(self, request, context):
    self._source_files.append(request)
    return debug_service_pb2.EventReply()

  def query_op_traceback(self, op_name):
    """Query the traceback of an op.

    Args:
      op_name: Name of the op to query.

    Returns:
      The traceback of the op, as a list of 3-tuples:
        (filename, lineno, function_name)

    Raises:
      ValueError: If the op cannot be found in the tracebacks received by the
        server so far.
    """
    for op_log_proto in self._graph_tracebacks:
      for log_entry in op_log_proto.log_entries:
        if log_entry.name == op_name:
          return self._code_def_to_traceback(log_entry.code_def,
                                             op_log_proto.id_to_string)
    raise ValueError(
        "Op '%s' does not exist in the tracebacks received by the debug "
        "server." % op_name)

  def query_origin_stack(self):
    """Query the stack of the origin of the execution call.

    Returns:
      A `list` of all tracebacks. Each item corresponds to an execution call,
        i.e., a `SendTracebacks` request. Each item is a `list` of 3-tuples:
        (filename, lineno, function_name).
    """
    ret = []
    for stack, id_to_string in zip(
        self._origin_stacks, self._origin_id_to_strings):
      ret.append(self._code_def_to_traceback(stack, id_to_string))
    return ret

  def query_call_types(self):
    return self._call_types

  def query_call_keys(self):
    return self._call_keys

  def query_graph_versions(self):
    return self._graph_versions

  def query_source_file_line(self, file_path, lineno):
    """Query the content of a given line in a source file.

    Args:
      file_path: Path to the source file.
      lineno: Line number as an `int`.

    Returns:
      Content of the line as a string.

    Raises:
      ValueError: If no source file is found at the given file_path.
    """
    if not self._source_files:
      raise ValueError(
          "This debug server has not received any source file contents yet.")
    for source_files in self._source_files:
      for source_file_proto in source_files.source_files:
        if source_file_proto.file_path == file_path:
          return source_file_proto.lines[lineno - 1]
    raise ValueError(
        "Source file at path %s has not been received by the debug server." %
        file_path)

  def _code_def_to_traceback(self, code_def, id_to_string):
    return [(id_to_string[trace.file_id],
             trace.lineno,
             id_to_string[trace.function_id]) for trace in code_def.traces]


def start_server_on_separate_thread(dump_to_filesystem=True,
                                    server_start_delay_sec=0.0,
                                    poll_server=False,
                                    blocking=True,
                                    toggle_watch_on_core_metadata=None):
  """Create a test gRPC debug server and run on a separate thread.

  Args:
    dump_to_filesystem: (bool) whether the debug server will dump debug data
      to the filesystem.
    server_start_delay_sec: (float) amount of time (in sec) to delay the server
      start up for.
    poll_server: (bool) whether the server will be polled till success on
      startup.
    blocking: (bool) whether the server should be started in a blocking mode.
    toggle_watch_on_core_metadata: A list of
        (node_name, output_slot, debug_op) tuples to toggle the
        watchpoint status during the on_core_metadata calls (optional).

  Returns:
    server_port: (int) Port on which the server runs.
    debug_server_url: (str) grpc:// URL to the server.
    server_dump_dir: (str) The debug server's dump directory.
    server_thread: The server Thread object.
    server: The `EventListenerTestServicer` object.

  Raises:
    ValueError: If polling the server process for ready state is not successful
      within maximum polling count.
  """
  server_port = portpicker.pick_unused_port()
  debug_server_url = "grpc://localhost:%d" % server_port

  server_dump_dir = tempfile.mkdtemp() if dump_to_filesystem else None
  server = EventListenerTestServicer(
      server_port=server_port,
      dump_dir=server_dump_dir,
      toggle_watch_on_core_metadata=toggle_watch_on_core_metadata)

  def delay_then_run_server():
    time.sleep(server_start_delay_sec)
    server.run_server(blocking=blocking)

  server_thread = threading.Thread(target=delay_then_run_server)
  server_thread.start()

  if poll_server:
    if not _poll_server_till_success(
        50,
        0.2,
        debug_server_url,
        server_dump_dir,
        server,
        gpu_memory_fraction=0.1):
      raise ValueError(
          "Failed to start test gRPC debug server at port %d" % server_port)
    server.clear_data()
  return server_port, debug_server_url, server_dump_dir, server_thread, server
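
# Typical usage from a test (illustrative sketch; the shutdown calls assume the
# stop_server() method of the base servicer):
#
#   (server_port, debug_server_url, server_dump_dir,
#    server_thread, server) = start_server_on_separate_thread(
#        dump_to_filesystem=False, poll_server=True)
#   # ... run sessions with debug_urls=[debug_server_url], then inspect
#   # server.debug_tensor_values and server.partition_graph_defs ...
#   server.stop_server().wait()
#   server_thread.join()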


def _poll_server_till_success(max_attempts,
                              sleep_per_poll_sec,
                              debug_server_url,
                              dump_dir,
                              server,
                              gpu_memory_fraction=1.0):
  """Poll server until success or exceeding max polling count.

  Args:
    max_attempts: (int) How many times to poll at maximum.
    sleep_per_poll_sec: (float) How many seconds to sleep for after each
      unsuccessful poll.
    debug_server_url: (str) gRPC URL to the debug server.
    dump_dir: (str) Dump directory to look for files in. If None, will directly
      check data from the server object.
    server: The server object.
    gpu_memory_fraction: (float) Fraction of GPU memory to be
      allocated for the Session used in server polling.

  Returns:
    (bool) Whether the polling succeeded within max_attempts attempts.
  """
  poll_count = 0

  config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
      per_process_gpu_memory_fraction=gpu_memory_fraction))
  with session.Session(config=config) as sess:
    for poll_count in range(max_attempts):
      server.clear_data()
      print("Polling: poll_count = %d" % poll_count)

      x_init_name = "x_init_%d" % poll_count
      x_init = constant_op.constant([42.0], shape=[1], name=x_init_name)
      x = variables.Variable(x_init, name=x_init_name)

      run_options = config_pb2.RunOptions()
      debug_utils.add_debug_tensor_watch(
          run_options, x_init_name, 0, debug_urls=[debug_server_url])
      try:
        sess.run(x.initializer, options=run_options)
      except errors.FailedPreconditionError:
        pass

      if dump_dir:
        if os.path.isdir(
            dump_dir) and debug_data.DebugDumpDir(dump_dir).size > 0:
          file_io.delete_recursively(dump_dir)
          print("Poll succeeded.")
          return True
        else:
          print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec)
          time.sleep(sleep_per_poll_sec)
      else:
        if server.debug_tensor_values:
          print("Poll succeeded.")
          return True
        else:
          print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec)
          time.sleep(sleep_per_poll_sec)

    return False
485