xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/wrappers/framework.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Framework of debug wrapper sessions.
16
17A debug wrapper session is a wrapper around a TensorFlow Python Session.
18The wrapper preserves the Session interface, most importantly the run() method,
19while providing abilities to:
20a) Intercept a run() call to a wrapped session and insert debug tensor watches
21   according to externally-specified debug URLs.
22
23b) Release control to an external (i.e., non-Session) object before and after
24   the run() call, so that the external object can perform actions such as
25   launching a UI to let users inspect the intermediate tensors and partition
26   graphs from the run() call.
27
28c) (To be implemented in a future CL) Enter an instruction loop to let an
29   external object (e.g., remote client) launch run() and cont() calls
30   remotely.
31
32*** The lifetime of a debug wrapper session: ***
33
341) The wrapper session is created by calling the constructor with a
35   wrapped (normal) session as the argument:
36     wrapper = FooDebugWrapperSession(sess)
37   wherein FooDebugWrapperSession is a concrete subclass implementing the
38   abstract BaseDebugWrapperSession class below.
39
402) Near the end of the constructor call, the on_session_init() callback is
   invoked, with an OnSessionInitRequest object as the argument. The object
42   carries the wrapped (normal) session object.
43
3) The callback handles the request and returns an OnSessionInitResponse
45   object with an action field, directing the wrapper session what to do next.
46
47If the action field in the OnSessionInitResponse is PROCEED, the constructor
48returns. Control is released back to the caller of the constructor, which can
49invoke run() method of wrapper session with the same syntax as a non-wrapped
50session, e.g.,:
51  wrapper.run(fetches, feed_dict=feeds, options=run_options)
52
53Below, A1 - A2 is the lifetime of a wrapper run() call if the action is
54PROCEED:
55
56A1) Right at the start of each run() call, the on_run_start() callback is
57    invoked, with an OnRunStartRequest object carrying information such as
58    the fetches, the feed dict, the run options and run metadata used in
    this run call, along with a count of how many run calls have occurred
    on this wrapper session. The callback then returns an OnRunStartResponse
    object, whose action field directs what the wrapper session will
    actually do for the run() call.
63
64    If the action is DEBUG_RUN, a debugged (tensor-watched) run will ensue,
65    with the debug URLs supplied in the debug_urls field of the response.
66    These can be file:// or grpc:// URLs, for example.
67
68    If the action is NON_DEBUG_RUN, a non-debug (normal) run will ensue.
69
70A2) Right before the run() returns, the on_run_end() callback is invoked,
71    with an OnRunEndRequest object as the argument, which carries information
72    including the actual action performed in the wrapper run() call and the
73    run_metadata from the run() call.
74
75However, if the action field in OnSessionInitResponse is
76REMOTE_INSTR_LOOP, the constructor will automatically invoke an instruction loop
77that gives the control to a remote caller.
78
79In the remote instruction loop, the following steps will happen:
80
81B1) Callback on_instr_start() is invoked. The callback will return an
82    OnInstrStartResponse object with an action field which can order one of
83    the following actions:
84        i) a run() call with fetches, feeds and debug_urls specified.
85       ii) exit the instruction loop.
86
87B2) The wrapper session carries out the action specified above.
88
89B3) If still in the instruction loop, the wrapper session invokes the
90    on_instr_end() callback. After the on_instr_end() callback returns, jump
91    back to B1.
92
TODO(cais): Implement the instruction loop in B1 - B3.
94
95"""
96
97import abc
98import re
99import threading
100
101from tensorflow.core.protobuf import config_pb2
102from tensorflow.python.client import session
103from tensorflow.python.debug.lib import debug_utils
104from tensorflow.python.framework import errors
105from tensorflow.python.framework import ops
106from tensorflow.python.platform import tf_logging
107from tensorflow.python.training import monitored_session
108from tensorflow.python.util import nest
109from tensorflow.python.util.compat import collections_abc
110
111
112# Helper function.
def _check_type(obj, expected_types):
  """Check if an object is of the expected type.

  Args:
    obj: The object being checked.
    expected_types: (`type` or an iterable of `type`s) The expected `type`(s)
      of obj. A `list`, `set` or `frozenset` of types is converted to a
      `tuple` first, because `isinstance()` itself only accepts a single type
      or a tuple of types.

  Raises:
    TypeError: If obj is not an instance of (any of) `expected_types`.
  """
  # Honor the documented "iterable of types" contract: isinstance() would
  # otherwise reject a list/set with its own, unrelated TypeError.
  if isinstance(expected_types, (list, set, frozenset)):
    expected_types = tuple(expected_types)

  if not isinstance(obj, expected_types):
    raise TypeError("Expected type %s; got type %s" %
                    (expected_types, type(obj)))
127
128
class OnSessionInitRequest:
  """Request to an on-session-init callback.

  This callback is invoked during the __init__ call to a debug-wrapper session.
  The request exposes the session passed to the constructor as the `session`
  attribute.
  """

  def __init__(self, sess):
    """Constructor.

    Args:
      sess: A tensorflow Session object. Must be an instance of
        `session.BaseSession` or `monitored_session.MonitoredSession`.

    Raises:
      TypeError: If `sess` is an instance of neither accepted type.
    """

    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
    # Stored under the public name `session` for use by the callback.
    self.session = sess
144
145
class OnSessionInitAction:
  """Enum-like values for possible action to take on session init.

  The values are plain strings; `OnSessionInitResponse` type-checks its
  `action` argument against `str`.
  """

  # Proceed, without special actions, in the wrapper session initialization.
  # What action the wrapper session performs next is determined by the caller
  # of the wrapper session. E.g., it can call run().
  PROCEED = "proceed"

  # Instead of letting the caller of the wrapper session determine what actions
  # the wrapper session will perform next, enter a loop to receive instructions
  # from a remote client.
  # For example, TensorBoard visual debugger can use this action so that it can
  # launch session.run() calls remotely.
  # NOTE: BaseDebugWrapperSession.__init__ currently raises NotImplementedError
  # when it receives this action.
  REMOTE_INSTR_LOOP = "remote_instr_loop"
160
161
class OnSessionInitResponse:
  """Response from an on-session-init callback.

  Carries a single `action` attribute telling the wrapper session what to do
  once the callback has returned.
  """

  def __init__(self, action):
    """Constructor.

    Args:
      action: (`OnSessionInitAction`) Debugger action to take on session init.
        Must be a `str`, i.e., one of the `OnSessionInitAction` constants.

    Raises:
      TypeError: If `action` is not a `str`.
    """
    # Only the type is validated here; the value itself is checked by the
    # wrapper session that consumes this response.
    _check_type(action, str)
    self.action = action
173
174
class OnRunStartRequest:
  """Request to an on-run-start callback.

  The on-run-start callback is called from within the debug-wrapper session's
  run() method, right after the run() call counter has been incremented. This
  object is a plain container for the arguments of that run() call.
  """

  def __init__(self, fetches, feed_dict, run_options, run_metadata,
               run_call_count, is_callable_runner=False):
    """Constructor of `OnRunStartRequest`.

    The first four arguments mirror the input arguments of the run() method
    of a non-wrapped TensorFlow session. No validation is performed on any
    argument; each one is stored under an attribute of the same name.

    Args:
      fetches: Fetch targets of the run() call.
      feed_dict: The feed dictionary to the run() call.
      run_options: RunOptions input to the run() call.
      run_metadata: RunMetadata input to the run() call.
      run_call_count: 1-based count of how many run calls (including this one)
        have been invoked.
      is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.
    """
    # What is being run.
    self.fetches, self.feed_dict = fetches, feed_dict
    # How it is being run.
    self.run_options, self.run_metadata = run_options, run_metadata
    # Bookkeeping about the call itself.
    self.run_call_count = run_call_count
    self.is_callable_runner = is_callable_runner
204
205
class OnRunStartAction:
  """Enum-like values for possible action to take on start of a run() call.

  The values are plain strings; `OnRunStartResponse` type-checks its `action`
  argument against `str`.
  """

  # Run once with debug tensor-watching.
  DEBUG_RUN = "debug_run"

  # Run once with profiler.
  PROFILE_RUN = "profile_run"

  # Run without debug tensor-watching.
  NON_DEBUG_RUN = "non_debug_run"
217
218
219
class OnRunStartResponse:
  """Response from an on-run-start callback.

  The callback returns this object to specify what action the debug-wrapper
  session actually takes on the run() call, and with which watch settings.
  """

  def __init__(self,
               action,
               debug_urls,
               debug_ops="DebugIdentity",
               node_name_regex_allowlist=None,
               op_type_regex_allowlist=None,
               tensor_dtype_regex_allowlist=None,
               tolerate_debug_op_creation_failures=False):
    """Constructor of `OnRunStartResponse`.

    Args:
      action: (`OnRunStartAction`) the action actually taken by the wrapped
        session for the run() call. Must be a `str`.
      debug_urls: (`list` of `str`) debug_urls used in watching the tensors
        during the run() call. Must be a `list`; the element types are not
        checked here.
      debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the
        debugger.
      node_name_regex_allowlist: Regular-expression allowlist for node
        name.
      op_type_regex_allowlist: Regular-expression allowlist for op type.
      tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
        dtype.
      tolerate_debug_op_creation_failures: Whether debug op creation failures
        are to be tolerated.

    Raises:
      TypeError: If `action` is not a `str` or `debug_urls` is not a `list`.
    """

    _check_type(action, str)
    self.action = action

    _check_type(debug_urls, list)
    self.debug_urls = debug_urls

    # The remaining settings are stored as-is, without validation.
    self.debug_ops = debug_ops

    self.node_name_regex_allowlist = node_name_regex_allowlist
    self.op_type_regex_allowlist = op_type_regex_allowlist
    self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)
266
267
class OnRunEndRequest:
  """Request to an on-run-end callback.

  The callback is invoked immediately before the wrapped run() call ends.
  """

  def __init__(self,
               performed_action,
               run_metadata=None,
               client_graph_def=None,
               tf_error=None):
    """Constructor for `OnRunEndRequest`.

    Args:
      performed_action: (`OnRunStartAction`) Actually-performed action by the
        debug-wrapper session. Must be a `str`.
      run_metadata: run_metadata output from the run() call (if any). When
        given, it must be a `config_pb2.RunMetadata`.
      client_graph_def: (GraphDef) GraphDef from the client side, i.e., from
        the python front end of TensorFlow. Can be obtained with
        session.graph.as_graph_def().
      tf_error: (errors.OpError subtypes) TensorFlow OpError that occurred
        during the run (if any).

    Raises:
      TypeError: If `performed_action` is not a `str`, or if `run_metadata` is
        neither `None` nor a `config_pb2.RunMetadata`.
    """

    _check_type(performed_action, str)
    self.performed_action = performed_action

    # `None` is an accepted value (no metadata); only type-check real values.
    if run_metadata is not None:
      _check_type(run_metadata, config_pb2.RunMetadata)
    self.run_metadata = run_metadata
    # Stored as-is, without validation.
    self.client_graph_def = client_graph_def
    self.tf_error = tf_error
300
301
class OnRunEndResponse:
  """Response from an on-run-end callback.

  Currently only a placeholder: it has no fields and stores no state.
  """

  def __init__(self):
    """Constructor. Takes no arguments and sets no attributes."""
309
310
311class BaseDebugWrapperSession(session.SessionInterface, metaclass=abc.ABCMeta):
312  """Base class of debug-wrapper session classes.
313
314  Concrete classes that inherit from this class need to implement the abstract
315  methods such as on_session_init, on_run_start and on_run_end.
316  """
317
318  def __init__(self, sess, thread_name_filter=None,
319               pass_through_operrors=False):
320    """Constructor of `BaseDebugWrapperSession`.
321
322    Args:
323      sess: An (unwrapped) TensorFlow session instance. It should be a subtype
324        of `BaseSession` or `tf.MonitoredSession`.
325      thread_name_filter: Regular-expression filter (allowlist) for name(s) of
326        thread(s) on which the wrapper session will be active. This regular
327        expression is used in a start-anchored fashion on the thread name, i.e.,
328        by applying the `match` method of the compiled pattern. The default
329        `None` means that the wrapper session will be active on all threads.
330        E.g., r"MainThread$", r"QueueRunnerThread.*".
331      pass_through_operrors: If True, all captured OpErrors will be
332        propagated.  By default this captures all OpErrors.
333
334    Raises:
335      ValueError: On invalid `OnSessionInitAction` value.
336      NotImplementedError: If a non-DirectSession sess object is received.
337    """
338
339    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
340
341    # The session being wrapped.
342    self._sess = sess
343    self._thread_name_filter_pattern = (re.compile(thread_name_filter)
344                                        if thread_name_filter else None)
345    # TODO(cais/kstevens): Unittest this pass through feature.
346    self._pass_through_operrors = pass_through_operrors
347
348    # Keeps track of number of run calls that have been performed on this
349    # debug-wrapper session. The count can be used for purposes such as
350    # displaying the state of the Session in a UI and determining a run
351    # number-dependent debug URL.
352    self._run_call_count = 0
353
354    # Invoke on-session-init callback.
355    response = self.on_session_init(OnSessionInitRequest(self._sess))
356    _check_type(response, OnSessionInitResponse)
357
358    if response.action == OnSessionInitAction.PROCEED:
359      pass
360    elif response.action == OnSessionInitAction.REMOTE_INSTR_LOOP:
361      # TODO(cais): Implement REMOTE_INSTR_LOOP
362      raise NotImplementedError(
363          "OnSessionInitAction REMOTE_INSTR_LOOP has not been "
364          "implemented.")
365    else:
366      raise ValueError(
367          "Invalid OnSessionInitAction value: %s" % response.action)
368
369    self._default_session_context_manager = None
370
371    # A cache for callables created from CallableOptions.
372    self._cached_callables_from_options = {}
373
374  @property
375  def graph(self):
376    return self._sess.graph
377
378  @property
379  def graph_def(self):
380    return self._sess.graph_def
381
382  @property
383  def sess_str(self):
384    return self._sess.sess_str
385
386  @property
387  def session(self):
388    return self._sess
389
390  def run(self,
391          fetches,
392          feed_dict=None,
393          options=None,
394          run_metadata=None,
395          callable_runner=None,
396          callable_runner_args=None,
397          callable_options=None):
398    """Wrapper around Session.run() that inserts tensor watch options.
399
400    Args:
401      fetches: Same as the `fetches` arg to regular `Session.run()`.
402      feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
403      options: Same as the `options` arg to regular `Session.run()`.
404      run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
405      callable_runner: A `callable` returned by `Session.make_callable()`.
406        If not `None`, `fetches` and `feed_dict` must both be `None`.
407        Mutually exclusive with `callable_options`.
408      callable_runner_args: An optional list of arguments to `callable_runner`
409        or for `callable_options`.
410      callable_options: An instance of `config_pb2.CallableOptions`, to be
411        used with `Session._make_callable_from_options()`. Mutually exclusive
412        with `callable_runner`.
413
414    Returns:
415      Simply forwards the output of the wrapped `Session.run()` call.
416
417    Raises:
418      ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner`
419        is not `None` and either or both of `fetches` and `feed_dict` is `None`.
420    """
421    if callable_runner and callable_options:
422      raise ValueError(
423          "callable_runner and callable_options are mutually exclusive, but "
424          "are both specified in this call to BaseDebugWrapperSession.run().")
425
426    if callable_runner and (fetches or feed_dict):
427      raise ValueError(
428          "callable_runner and fetches/feed_dict are mutually exclusive, "
429          "but are used simultaneously.")
430    elif callable_options and (fetches or feed_dict):
431      raise ValueError(
432          "callable_options and fetches/feed_dict are mutually exclusive, "
433          "but are used simultaneously.")
434
435    self.increment_run_call_count()
436
437    def is_empty(x):
438      """Check whether a possibly nested structure is empty."""
439      if not nest.is_nested(x):
440        return False
441      if isinstance(x, collections_abc.Mapping):
442        return is_empty(list(x.values()))
443      for item in x:
444        if not is_empty(item):
445          return False
446      return True
447
448    empty_fetches = is_empty(fetches)
449    if empty_fetches:
450      tf_logging.info(
451          "Due to empty fetches, tfdbg Session wrapper is letting a "
452          "Session.run pass through without any debugging actions.")
453    if self._is_disabled_thread() or empty_fetches:
454      if callable_runner:
455        return callable_runner(*callable_runner_args)
456      elif callable_options:
457        # pylint:disable=protected-access
458        return self._sess._make_callable_from_options(
459            callable_options)(*callable_runner_args)
460        # pylint:enable=protected-access
461      else:
462        return self._sess.run(fetches,
463                              feed_dict=feed_dict,
464                              options=options,
465                              run_metadata=run_metadata)
466
467    # Invoke on-run-start callback and obtain response.
468    run_start_resp = self.on_run_start(
469        OnRunStartRequest(fetches, feed_dict, options, run_metadata,
470                          self._run_call_count,
471                          is_callable_runner=bool(callable_runner)))
472    _check_type(run_start_resp, OnRunStartResponse)
473
474    if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
475      retvals, run_end_req = self._run_with_debugging(
476          run_start_resp, fetches, feed_dict, options, run_metadata,
477          callable_runner, callable_runner_args, callable_options)
478    elif run_start_resp.action == OnRunStartAction.PROFILE_RUN:
479      retvals, run_end_req = self._run_with_profiling(
480          run_start_resp, fetches, feed_dict, options, run_metadata,
481          callable_runner, callable_runner_args, callable_options)
482    elif run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN:
483      # Invoke run() method of the wrapped session.
484      if callable_runner:
485        retvals = callable_runner(*callable_runner_args)
486      elif callable_options:
487        # pylint:disable=protected-access
488        callable_object = self._sess._make_callable_from_options(
489            callable_options)
490        # pylint:enable=protected-access
491        retvals = callable_object(*callable_runner_args)
492      else:
493        retvals = self._sess.run(
494            fetches,
495            feed_dict=feed_dict,
496            options=options,
497            run_metadata=run_metadata)
498
499      # Prepare arg for the on-run-end callback.
500      run_end_req = OnRunEndRequest(run_start_resp.action)
501    else:
502      raise ValueError(
503          "Invalid OnRunStartAction value: %s" % run_start_resp.action)
504
505    # Invoke on-run-end callback and obtain response.
506    run_end_resp = self.on_run_end(run_end_req)
507    _check_type(run_end_resp, OnRunEndResponse)
508    # Currently run_end_resp is only a placeholder. No action is taken on it.
509
510    return retvals
511
  def _run_with_debugging(self,
                          run_start_resp,
                          fetches,
                          feed_dict,
                          options,
                          run_metadata,
                          callable_runner,
                          callable_runner_args,
                          callable_options):
    """Perform a session.run() or callable with debugging.

    Args:
      run_start_resp: (`OnRunStartResponse`) Supplies the debug URLs, debug
        ops, allowlists and failure-tolerance flag used to decorate the run
        options, and the action recorded in the returned request.
      fetches: Forwarded to the wrapped session's `run()` when neither
        `callable_runner` nor `callable_options` is given.
      feed_dict: Forwarded together with `fetches`.
      options: RunOptions to decorate. A new `config_pb2.RunOptions` is used
        if this is falsy. Not read when `callable_options` is given.
      run_metadata: RunMetadata to fill in. A new `config_pb2.RunMetadata` is
        used if this is falsy.
      callable_runner: If truthy, called instead of `run()`, with the
        decorated options and the run metadata as keyword arguments.
      callable_runner_args: Positional arguments for `callable_runner` or for
        the callable built from `callable_options`.
      callable_options: If truthy (and `callable_runner` is not), a callable
        is built from a decorated copy of it and cached in
        `self._cached_callables_from_options`.

    Returns:
      A tuple `(retvals, run_end_req)`. `retvals` is the output of the run,
      or the caught `errors.OpError` instance if the run raised one and
      pass-through is off. `run_end_req` is an `OnRunEndRequest` carrying the
      action, the run metadata, the client graph def and the caught error
      (or `None`).

    Raises:
      errors.OpError: Re-raised from the run only if
        `self._pass_through_operrors` is truthy.
    """
    # Decorate RunOption to fill in debugger tensor watch specifications.
    decorated_run_options = None
    if callable_options:
      # NOTE(review): the cache is keyed by id(callable_options) and does not
      # appear to keep a reference to the object itself; presumably callers
      # keep it alive, otherwise an id could be reused -- TODO confirm.
      callable_options_id = id(callable_options)
      if callable_options_id not in self._cached_callables_from_options:
        # Make a copy of callable_options to avoid mutating it.
        new_callable_options = config_pb2.CallableOptions()
        new_callable_options.CopyFrom(callable_options)
        decorated_run_options = new_callable_options.run_options
      # On a cache hit `decorated_run_options` stays None: the decoration step
      # below is skipped and the previously built callable is reused, so the
      # watch settings of the current `run_start_resp` are not applied to it.
    else:
      decorated_run_options = options or config_pb2.RunOptions()

    run_metadata = run_metadata or config_pb2.RunMetadata()

    if decorated_run_options:
      self._decorate_run_options_for_debug(
          decorated_run_options,
          run_start_resp.debug_urls,
          debug_ops=run_start_resp.debug_ops,
          node_name_regex_allowlist=(run_start_resp.node_name_regex_allowlist),
          op_type_regex_allowlist=run_start_resp.op_type_regex_allowlist,
          tensor_dtype_regex_allowlist=(
              run_start_resp.tensor_dtype_regex_allowlist),
          tolerate_debug_op_creation_failures=(
              run_start_resp.tolerate_debug_op_creation_failures))

    # Invoke the run() method of the wrapped Session. Catch any TensorFlow
    # runtime errors.
    tf_error = None
    try:
      if callable_runner:
        retvals = callable_runner(*callable_runner_args,
                                  options=decorated_run_options,
                                  run_metadata=run_metadata)
      elif callable_options:
        # pylint:disable=protected-access
        if callable_options_id in self._cached_callables_from_options:
          callable_object = self._cached_callables_from_options[
              callable_options_id]
        else:
          # `new_callable_options` is only bound on a cache miss, which is
          # exactly the case handled by this branch.
          callable_object = self._sess._make_callable_from_options(
              new_callable_options)
          self._cached_callables_from_options[
              callable_options_id] = callable_object
        # pylint:enable=protected-access
        retvals = callable_object(
            *callable_runner_args, run_metadata=run_metadata)
      else:
        retvals = self._sess.run(fetches,
                                 feed_dict=feed_dict,
                                 options=decorated_run_options,
                                 run_metadata=run_metadata)
    except errors.OpError as op_error:
      if self._pass_through_operrors:
        raise op_error
      # Captured error: it is both returned in place of the run output and
      # attached to the on-run-end request.
      tf_error = op_error
      retvals = op_error

    return retvals, OnRunEndRequest(
        run_start_resp.action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def(),
        tf_error=tf_error)
585
586  def _run_with_profiling(self,
587                          run_start_resp,
588                          fetches,
589                          feed_dict,
590                          options,
591                          run_metadata,
592                          callable_runner,
593                          callable_runner_args,
594                          callable_options):
595    """Perform a session.run() or callable with profiling."""
596    # Decorate RunOption to fill in debugger tensor watch specifications.
597    decorated_run_options = None
598    if callable_options:
599      callable_options_id = id(callable_options)
600      if callable_options_id not in self._cached_callables_from_options:
601        # Make a copy of callable_options to avoid mutating it.
602        new_callable_options = config_pb2.CallableOptions()
603        new_callable_options.CopyFrom(callable_options)
604        decorated_run_options = new_callable_options.run_options
605    else:
606      decorated_run_options = options or config_pb2.RunOptions()
607    self._decorate_run_options_for_profile(decorated_run_options)
608
609    run_metadata = run_metadata or config_pb2.RunMetadata()
610    if callable_runner:
611      retvals = callable_runner(*callable_runner_args,
612                                options=decorated_run_options,
613                                run_metadata=run_metadata)
614    elif callable_options:
615      # pylint:disable=protected-access
616      callable_object = self._sess._make_callable_from_options(
617          new_callable_options)
618      # pylint:enable=protected-access
619      retvals = callable_object(
620          *callable_runner_args, run_metadata=run_metadata)
621    else:
622      retvals = self._sess.run(fetches,
623                               feed_dict=feed_dict,
624                               options=decorated_run_options,
625                               run_metadata=run_metadata)
626    return retvals, OnRunEndRequest(
627        run_start_resp.action,
628        run_metadata=run_metadata,
629        client_graph_def=self._sess.graph.as_graph_def())
630
631  def _is_disabled_thread(self):
632    thread_name = threading.current_thread().name or ""
633    return (self._thread_name_filter_pattern and
634            not self._thread_name_filter_pattern.match(thread_name))
635
636  def run_step_fn(self, step_fn):
637    return step_fn(
638        monitored_session.MonitoredSession.StepContext(self._sess, self.run))
639
640  def partial_run_setup(self, fetches, feeds=None):
641    """Sets up the feeds and fetches for partial runs in the session."""
642    raise NotImplementedError(
643        "partial_run_setup is not implemented for debug-wrapper sessions.")
644
645  def partial_run(self, handle, fetches, feed_dict=None):
646    raise NotImplementedError(
647        "partial_run is not implemented for debug-wrapper sessions.")
648
649  def list_devices(self, *args, **kwargs):
650    return self._sess.list_devices(*args, **kwargs)
651
652  def reset(self, *args, **kwargs):
653    return self._sess.reset(*args, **kwargs)
654
655  def make_callable(self,
656                    fetches,
657                    feed_list=None,
658                    accept_options=False):
659    runner = self._sess.make_callable(
660        fetches, feed_list=feed_list, accept_options=True)
661    def wrapped_runner(*runner_args, **kwargs):
662      return self.run(None,
663                      feed_dict=None,
664                      options=kwargs.get("options", None),
665                      run_metadata=kwargs.get("run_metadata", None),
666                      callable_runner=runner,
667                      callable_runner_args=runner_args)
668    return wrapped_runner
669
670  def _make_callable_from_options(self, callable_options):
671    def wrapped_runner(*feed_values, **kwargs):
672      return self.run(None,
673                      run_metadata=kwargs.get("run_metadata", None),
674                      callable_options=callable_options,
675                      callable_runner_args=feed_values)
676    return wrapped_runner
677
678  @property
679  def run_call_count(self):
680    return self._run_call_count
681
682  def increment_run_call_count(self):
683    self._run_call_count += 1
684
  def _is_disk_usage_reset_each_run(self):
    """Indicates whether disk usage is reset after each Session.run.

    Subclasses that clean up the disk usage after every run should
    override this protected method. The result is used when decorating run
    options for a debug run, as part of the `reset_disk_byte_usage` flag.

    Returns:
      (`bool`) Whether the disk usage amount is reset to zero after
        each Session.run. This base implementation always returns `False`.
    """
    return False
696
697  def _decorate_run_options_for_debug(
698      self,
699      run_options,
700      debug_urls,
701      debug_ops="DebugIdentity",
702      node_name_regex_allowlist=None,
703      op_type_regex_allowlist=None,
704      tensor_dtype_regex_allowlist=None,
705      tolerate_debug_op_creation_failures=False):
706    """Modify a RunOptions object for debug tensor watching.
707
708    Specifies request for outputting partition graphs. Adds
709    debug_tensor_watch_opts with proper debug URLs.
710
711    Args:
712      run_options: (RunOptions) the modified RunOptions object.
713      debug_urls: (list of str) debug URLs to be entered in run_options.
714        debug_tensor_watch_opts.
715      debug_ops: (str or list of str) debug op(s) to be used by the debugger.
716      node_name_regex_allowlist: Regular-expression allowlist for node
717        name.
718      op_type_regex_allowlist: Regular-expression allowlist for op type.
719      tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
720        dtype.
721      tolerate_debug_op_creation_failures: Whether debug op creation failures
722        are to be tolerated.
723    """
724
725    run_options.output_partition_graphs = True
726    debug_utils.watch_graph(
727        run_options,
728        self._sess.graph,
729        debug_urls=debug_urls,
730        debug_ops=debug_ops,
731        node_name_regex_allowlist=node_name_regex_allowlist,
732        op_type_regex_allowlist=op_type_regex_allowlist,
733        tensor_dtype_regex_allowlist=tensor_dtype_regex_allowlist,
734        tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
735        reset_disk_byte_usage=(self._run_call_count == 1 or
736                               self._is_disk_usage_reset_each_run()))
737
738  def _decorate_run_options_for_profile(self, run_options):
739    """Modify a RunOptions object for profiling TensorFlow graph execution.
740
741    Args:
742      run_options: (RunOptions) the modified RunOptions object.
743    """
744
745    run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
746
  @abc.abstractmethod
  def on_session_init(self, request):
    """Callback invoked during construction of the debug-wrapper session.

    This is a blocking callback.
    The invocation happens near the end of the constructor: after the wrapped
    session, thread-name filter, OpError pass-through flag and run call
    counter have been set, but before the default-session context manager
    slot and the callable cache are initialized.

    Args:
      request: (`OnSessionInitRequest`) callback request carrying information
        such as the session being wrapped.

    Returns:
      An instance of `OnSessionInitResponse`. The constructor raises
        `TypeError` for any other return type.
    """
761
  @abc.abstractmethod
  def on_run_start(self, request):
    """Callback invoked on run() calls to the debug-wrapper session.

    This is a blocking callback.
    The invocation happens after the wrapper's run() call is entered,
    after an increment of run call counter. It is skipped when the run is
    passed straight through to the wrapped session, i.e., on threads excluded
    by the thread-name filter or when the fetches are empty.

    Args:
      request: (`OnRunStartRequest`) callback request object carrying
        information about the run call such as the fetches, feed dict, run
        options, run metadata, and how many `run()` calls to this wrapper
        session have occurred.

    Returns:
      An instance of `OnRunStartResponse`, carrying the action to take for
        this run and the debug URLs used to watch the tensors.
    """
780
781  @abc.abstractmethod
782  def on_run_end(self, request):
783    """Callback invoked on run() calls to the debug-wrapper session.
784
785    This is a blocking callback.
786    The invocation happens right before the wrapper exits its run() call.
787
788    Args:
789      request: (`OnRunEndRequest`) callback request object carrying information
790        such as the actual action performed by the session wrapper for the
791        run() call.
792
793    Returns:
794      An instance of `OnRunStartResponse`.
795    """
796
797  def as_default(self):
798    return ops.default_session(self)
799
800  def __enter__(self):
801    if self._default_session_context_manager is None:
802      self._default_session_context_manager = self.as_default()
803    return self._default_session_context_manager.__enter__()
804
805  def __exit__(self, exec_type, exec_value, exec_tb):
806    self._default_session_context_manager.__exit__(
807        exec_type, exec_value, exec_tb)
808
809  def __del__(self):
810    if hasattr(self._sess, "__del__"):
811      self._sess.__del__()
812
813  def close(self):
814    self._sess.close()
815
816  # TODO(cais): Add _node_name_regex_allowlist and
817  #   _node_op_type_regex_allowlist.
818
819  def should_stop(self):
820    if hasattr(self._sess, "should_stop"):
821      return self._sess.should_stop()
822    else:
823      raise ValueError(
824          "The wrapped session %r does not have a method called 'should_stop'. "
825          "Do you intend to wrap a tf.MonitoredSession instead?" % self._sess)
826
827
class WatchOptions:
  """Debug watch options; the type a `watch_fn` is expected to return."""

  def __init__(self,
               debug_ops=None,
               node_name_regex_allowlist=None,
               op_type_regex_allowlist=None,
               tensor_dtype_regex_allowlist=None,
               tolerate_debug_op_creation_failures=False):
    """Constructor of WatchOptions.

    The three allowlists combine with logical `AND`: a tensor is included only
    if it passes every allowlist that is set.

    Args:
      debug_ops: (`str` or `list of str`) Debug ops to be used. Any falsy value
        (e.g., `None` or an empty list) selects `["DebugIdentity"]`.
      node_name_regex_allowlist: Regular-expression allowlist for node names,
        e.g., `"(weight_[0-9]+|bias_.*)"`.
      op_type_regex_allowlist: Regular-expression allowlist for the op type of
        nodes, e.g., `"(Variable|Add)"`.
      tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor
        data type, e.g., `"^int.*"`.
      tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
        failures (e.g., due to dtype incompatibility) are to be tolerated by
        not throwing exceptions.
    """
    self.debug_ops = debug_ops or ["DebugIdentity"]
    self.node_name_regex_allowlist = node_name_regex_allowlist
    self.op_type_regex_allowlist = op_type_regex_allowlist
    self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)

  def __repr__(self):
    """Return a constructor-like string listing every option value."""
    field_values = (
        self.debug_ops,
        self.node_name_regex_allowlist,
        self.op_type_regex_allowlist,
        self.tensor_dtype_regex_allowlist,
        self.tolerate_debug_op_creation_failures,
    )
    template = ("WatchOptions(debug_ops=%r, node_name_regex_allowlist=%r, "
                "op_type_regex_allowlist=%r, tensor_dtype_regex_allowlist=%r, "
                "tolerate_debug_op_creation_failures=%r)")
    return template % field_values
876
877
class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
  """Base class for non-interactive (i.e., non-CLI) debug wrapper sessions."""

  def __init__(self, sess, watch_fn=None, thread_name_filter=None,
               pass_through_operrors=False):
    """Constructor of NonInteractiveDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      watch_fn: (`Callable`) Optional callable that maps the fetches and feeds
        of a debugged `Session.run()` call to `WatchOptions`.
        * Args:
          * `fetches`: the fetches to the `Session.run()` call.
          * `feeds`: the feeds to the `Session.run()` call.

        * Returns:
         (`tf_debug.WatchOptions`) An object containing debug options including
           the debug ops to use, the node names, op types and/or tensor data
           types to watch, etc. See the documentation of `tf_debug.WatchOptions`
           for more details. A plain tuple is also accepted (legacy); it is
           expanded positionally into the `WatchOptions` constructor.
      thread_name_filter: Regular-expression allowlist for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      pass_through_operrors: If true, all captured OpErrors will be
        propagated.  By default this captures all OpErrors.
    Raises:
       TypeError: If a non-None `watch_fn` is specified and it is not callable.
    """

    BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter,
        pass_through_operrors=pass_through_operrors)

    self._watch_fn = None
    if watch_fn is None:
      return
    if not callable(watch_fn):
      raise TypeError("watch_fn is not callable")
    self._watch_fn = watch_fn

  def on_session_init(self, request):
    """See doc of BaseDebugWrapperSession.on_session_init."""

    proceed_response = OnSessionInitResponse(OnSessionInitAction.PROCEED)
    return proceed_response

  @abc.abstractmethod
  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Prepare the debug URL(s) for one run; implemented by subclasses.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`
      feed_dict: Same as the `feed_dict` argument to `Session.run()`

    Returns:
      debug_urls: (`str` or `list` of `str`) Debug URLs to be used in
        this `Session.run()` call.
    """

  def on_run_start(self, request):
    """See doc of BaseDebugWrapperSession.on_run_start."""

    urls, options = self._prepare_run_watch_config(
        request.fetches, request.feed_dict)

    # Every run is a debug run; the watch options are forwarded verbatim.
    watch_kwargs = dict(
        debug_ops=options.debug_ops,
        node_name_regex_allowlist=options.node_name_regex_allowlist,
        op_type_regex_allowlist=options.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=options.tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=(
            options.tolerate_debug_op_creation_failures))
    return OnRunStartResponse(OnRunStartAction.DEBUG_RUN, urls, **watch_kwargs)

  def _prepare_run_watch_config(self, fetches, feed_dict):
    """Get the debug_urls, and node/op allowlists for the current run() call.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`.
      feed_dict: Same as the `feed_dict` argument to `Session.run()`.

    Returns:
      debug_urls: (str or list of str) Debug URLs for the current run() call,
        as produced by `prepare_run_debug_urls`.
      watch_options: (WatchOptions) The return value of the watch_fn, or a
        default-constructed `WatchOptions` when no watch_fn was supplied.
    """

    # The URLs are prepared first, before the watch_fn (if any) is consulted.
    urls = self.prepare_run_debug_urls(fetches, feed_dict)
    if self._watch_fn is None:
      return urls, WatchOptions()

    options = self._watch_fn(fetches, feed_dict)
    if isinstance(options, tuple):
      # For legacy return type (tuples).
      options = WatchOptions(*options)
    return urls, options

  def on_run_end(self, request):
    """See doc of BaseDebugWrapperSession.on_run_end."""

    end_response = OnRunEndResponse()
    return end_response
982