# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Framework of debug wrapper sessions.

A debug wrapper session is a wrapper around a TensorFlow Python Session.
The wrapper preserves the Session interface, most importantly the run() method,
while providing abilities to:
a) Intercept a run() call to a wrapped session and insert debug tensor watches
   according to externally-specified debug URLs.

b) Release control to an external (i.e., non-Session) object before and after
   the run() call, so that the external object can perform actions such as
   launching a UI to let users inspect the intermediate tensors and partition
   graphs from the run() call.

c) (To be implemented in a future CL) Enter an instruction loop to let an
   external object (e.g., remote client) launch run() and cont() calls
   remotely.

*** The lifetime of a debug wrapper session: ***

1) The wrapper session is created by calling the constructor with a
   wrapped (normal) session as the argument:
     wrapper = FooDebugWrapperSession(sess)
   wherein FooDebugWrapperSession is a concrete subclass implementing the
   abstract BaseDebugWrapperSession class below.

2) Near the end of the constructor call, the on_session_init() callback is
   invoked, with an OnSessionInitRequest object as the argument. The object
   carries the wrapped (normal) session object.

3) The callback handles the request and returns an OnSessionInitResponse
   object with an action field, directing the wrapper session what to do next.

If the action field in the OnSessionInitResponse is PROCEED, the constructor
returns. Control is released back to the caller of the constructor, which can
invoke run() method of wrapper session with the same syntax as a non-wrapped
session, e.g.,:
  wrapper.run(fetches, feed_dict=feeds, options=run_options)

Below, A1 - A2 is the lifetime of a wrapper run() call if the action is
PROCEED:

A1) Right at the start of each run() call, the on_run_start() callback is
    invoked, with an OnRunStartRequest object carrying information such as
    the fetches, the feed dict, the run options and run metadata used in
    this run call, along with a count of how many run calls have occurred
    on this wrapper session. The callback then returns an OnRunStartResponse
    object, of which the action field directs what the wrapper session
    will actually do for the run() call.

    If the action is DEBUG_RUN, a debugged (tensor-watched) run will ensue,
    with the debug URLs supplied in the debug_urls field of the response.
    These can be file:// or grpc:// URLs, for example.

    If the action is NON_DEBUG_RUN, a non-debug (normal) run will ensue.

A2) Right before the run() returns, the on_run_end() callback is invoked,
    with an OnRunEndRequest object as the argument, which carries information
    including the actual action performed in the wrapper run() call and the
    run_metadata from the run() call.

However, if the action field in OnSessionInitResponse is
REMOTE_INSTR_LOOP, the constructor will automatically invoke an instruction loop
that gives the control to a remote caller.

In the remote instruction loop, the following steps will happen:

B1) Callback on_instr_start() is invoked. The callback will return an
    OnInstrStartResponse object with an action field which can order one of
    the following actions:
        i) a run() call with fetches, feeds and debug_urls specified.
       ii) exit the instruction loop.

B2) The wrapper session carries out the action specified above.

B3) If still in the instruction loop, the wrapper session invokes the
    on_instr_end() callback. After the on_instr_end() callback returns, jump
    back to B1.

TODO(cais): Implement the instruction loop in B1 - B3.

"""

import abc
import re
import threading

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import monitored_session
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc


# Helper function.
def _check_type(obj, expected_types):
  """Check if an object is of the expected type.

  Args:
    obj: The object being checked.
    expected_types: (`type` or a tuple of `type`s) The expected `type`(s)
      of obj, in any form accepted by `isinstance`.

  Raises:
    TypeError: If obj is not an instance of `expected_types`.
  """
  if not isinstance(obj, expected_types):
    raise TypeError("Expected type %s; got type %s" %
                    (expected_types, type(obj)))


class OnSessionInitRequest:
  """Request to an on-session-init callback.

  This callback is invoked during the __init__ call to a debug-wrapper session.
  """

  def __init__(self, sess):
    """Constructor.

    Args:
      sess: A tensorflow Session object.

    Raises:
      TypeError: If `sess` is neither a `BaseSession` nor a
        `MonitoredSession`.
    """

    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
    self.session = sess


class OnSessionInitAction:
  """Enum-like values for possible action to take on session init."""

  # Proceed, without special actions, in the wrapper session initialization.
  # What action the wrapper session performs next is determined by the caller
  # of the wrapper session. E.g., it can call run().
  PROCEED = "proceed"

  # Instead of letting the caller of the wrapper session determine what actions
  # the wrapper session will perform next, enter a loop to receive instructions
  # from a remote client.
  # For example, TensorBoard visual debugger can use this action so that it can
  # launch session.run() calls remotely.
  REMOTE_INSTR_LOOP = "remote_instr_loop"


class OnSessionInitResponse:
  """Response from an on-session-init callback."""

  def __init__(self, action):
    """Constructor.

    Args:
      action: (`OnSessionInitAction`) Debugger action to take on session init.

    Raises:
      TypeError: If `action` is not a `str`.
    """
    _check_type(action, str)
    self.action = action


class OnRunStartRequest:
  """Request to an on-run-start callback.

  This callback is invoked during a run() call of the debug-wrapper
  session, immediately after the run() call counter is incremented.
  """

  def __init__(self, fetches, feed_dict, run_options, run_metadata,
               run_call_count, is_callable_runner=False):
    """Constructor of `OnRunStartRequest`.

    Args:
      fetches: Fetch targets of the run() call.
      feed_dict: The feed dictionary to the run() call.
      run_options: RunOptions input to the run() call.
      run_metadata: RunMetadata input to the run() call.
        The above four arguments are identical to the input arguments to the
        run() method of a non-wrapped TensorFlow session.
      run_call_count: 1-based count of how many run calls (including this one)
        have been invoked.
      is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.
    """
    self.fetches = fetches
    self.feed_dict = feed_dict
    self.run_options = run_options
    self.run_metadata = run_metadata
    self.run_call_count = run_call_count
    self.is_callable_runner = is_callable_runner


class OnRunStartAction:
  """Enum-like values for possible action to take on start of a run() call."""

  # Run once with debug tensor-watching.
  DEBUG_RUN = "debug_run"

  # Run once with profiler.
  PROFILE_RUN = "profile_run"

  # Run without debug tensor-watching.
  NON_DEBUG_RUN = "non_debug_run"


class OnRunStartResponse:
  """Response from an on-run-start callback.

  The caller of the callback can use this response object to specify what
  action the debug-wrapper session actually takes on the run() call.
  """

  def __init__(self,
               action,
               debug_urls,
               debug_ops="DebugIdentity",
               node_name_regex_allowlist=None,
               op_type_regex_allowlist=None,
               tensor_dtype_regex_allowlist=None,
               tolerate_debug_op_creation_failures=False):
    """Constructor of `OnRunStartResponse`.

    Args:
      action: (`OnRunStartAction`) the action actually taken by the wrapped
        session for the run() call.
      debug_urls: (`list` of `str`) debug_urls used in watching the tensors
        during the run() call.
      debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the
        debugger.
      node_name_regex_allowlist: Regular-expression allowlist for node
        name.
      op_type_regex_allowlist: Regular-expression allowlist for op type.
      tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
        dtype.
      tolerate_debug_op_creation_failures: Whether debug op creation failures
        are to be tolerated.

    Raises:
      TypeError: If `action` is not a `str` or `debug_urls` is not a `list`.
    """

    _check_type(action, str)
    self.action = action

    _check_type(debug_urls, list)
    self.debug_urls = debug_urls

    self.debug_ops = debug_ops

    self.node_name_regex_allowlist = node_name_regex_allowlist
    self.op_type_regex_allowlist = op_type_regex_allowlist
    self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)


class OnRunEndRequest:
  """Request to an on-run-end callback.

  The callback is invoked immediately before the wrapped run() call ends.
  """

  def __init__(self,
               performed_action,
               run_metadata=None,
               client_graph_def=None,
               tf_error=None):
    """Constructor for `OnRunEndRequest`.

    Args:
      performed_action: (`OnRunStartAction`) Actually-performed action by the
        debug-wrapper session.
      run_metadata: run_metadata output from the run() call (if any).
      client_graph_def: (GraphDef) GraphDef from the client side, i.e., from
        the python front end of TensorFlow. Can be obtained with
        session.graph.as_graph_def().
      tf_error: (errors.OpError subtypes) TensorFlow OpError that occurred
        during the run (if any).

    Raises:
      TypeError: If `performed_action` is not a `str`, or if a non-None
        `run_metadata` is not a `RunMetadata` proto.
    """

    _check_type(performed_action, str)
    self.performed_action = performed_action

    if run_metadata is not None:
      _check_type(run_metadata, config_pb2.RunMetadata)
    self.run_metadata = run_metadata
    self.client_graph_def = client_graph_def
    self.tf_error = tf_error


class OnRunEndResponse:
  """Response from an on-run-end callback."""

  def __init__(self):

    # Currently only a placeholder.
    pass


class BaseDebugWrapperSession(session.SessionInterface, metaclass=abc.ABCMeta):
  """Base class of debug-wrapper session classes.

  Concrete classes that inherit from this class need to implement the abstract
  methods such as on_session_init, on_run_start and on_run_end.
  """

  def __init__(self, sess, thread_name_filter=None,
               pass_through_operrors=False):
    """Constructor of `BaseDebugWrapperSession`.

    Args:
      sess: An (unwrapped) TensorFlow session instance. It should be a subtype
        of `BaseSession` or `tf.MonitoredSession`.
      thread_name_filter: Regular-expression filter (allowlist) for name(s) of
        thread(s) on which the wrapper session will be active. This regular
        expression is used in a start-anchored fashion on the thread name, i.e.,
        by applying the `match` method of the compiled pattern. The default
        `None` means that the wrapper session will be active on all threads.
        E.g., r"MainThread$", r"QueueRunnerThread.*".
      pass_through_operrors: If True, OpErrors raised during a debugged run
        are re-raised to the caller. By default (False) they are captured and
        handed to the on_run_end() callback instead.

    Raises:
      TypeError: If `sess` is not of a supported session type, or if
        on_session_init() does not return an `OnSessionInitResponse`.
      ValueError: On invalid `OnSessionInitAction` value.
      NotImplementedError: If the on_session_init() callback requests the
        (not yet implemented) REMOTE_INSTR_LOOP action.
    """

    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))

    # The session being wrapped.
    self._sess = sess
    self._thread_name_filter_pattern = (re.compile(thread_name_filter)
                                        if thread_name_filter else None)
    # TODO(cais/kstevens): Unittest this pass through feature.
    self._pass_through_operrors = pass_through_operrors

    # Keeps track of number of run calls that have been performed on this
    # debug-wrapper session. The count can be used for purposes such as
    # displaying the state of the Session in a UI and determining a run
    # number-dependent debug URL.
    self._run_call_count = 0

    # Invoke on-session-init callback.
    response = self.on_session_init(OnSessionInitRequest(self._sess))
    _check_type(response, OnSessionInitResponse)

    if response.action == OnSessionInitAction.PROCEED:
      pass
    elif response.action == OnSessionInitAction.REMOTE_INSTR_LOOP:
      # TODO(cais): Implement REMOTE_INSTR_LOOP
      raise NotImplementedError(
          "OnSessionInitAction REMOTE_INSTR_LOOP has not been "
          "implemented.")
    else:
      raise ValueError(
          "Invalid OnSessionInitAction value: %s" % response.action)

    self._default_session_context_manager = None

    # A cache for callables created from CallableOptions.
    # NOTE(review): keys are id() values of the caller's CallableOptions
    # objects; an id can be recycled once that object is garbage-collected.
    self._cached_callables_from_options = {}

  @property
  def graph(self):
    """The graph of the wrapped session."""
    return self._sess.graph

  @property
  def graph_def(self):
    """The graph_def of the wrapped session."""
    return self._sess.graph_def

  @property
  def sess_str(self):
    """The sess_str of the wrapped session."""
    return self._sess.sess_str

  @property
  def session(self):
    """The wrapped (non-debug) session object."""
    return self._sess

  def run(self,
          fetches,
          feed_dict=None,
          options=None,
          run_metadata=None,
          callable_runner=None,
          callable_runner_args=None,
          callable_options=None):
    """Wrapper around Session.run() that inserts tensor watch options.

    Args:
      fetches: Same as the `fetches` arg to regular `Session.run()`.
      feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
      options: Same as the `options` arg to regular `Session.run()`.
      run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
      callable_runner: A `callable` returned by `Session.make_callable()`.
        If not `None`, `fetches` and `feed_dict` must both be `None`.
        Mutually exclusive with `callable_options`.
      callable_runner_args: An optional list of arguments to `callable_runner`
        or for `callable_options`. `None` is treated as no arguments.
      callable_options: An instance of `config_pb2.CallableOptions`, to be
        used with `Session._make_callable_from_options()`. Mutually exclusive
        with `callable_runner`.

    Returns:
      Simply forwards the output of the wrapped `Session.run()` call.

    Raises:
      ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner`
        and `callable_options` are both specified. Or if `callable_runner` or
        `callable_options` is specified together with a non-empty `fetches`
        or `feed_dict`.
      TypeError: If a callback returns an object of an unexpected type.
    """
    if callable_runner and callable_options:
      raise ValueError(
          "callable_runner and callable_options are mutually exclusive, but "
          "are both specified in this call to BaseDebugWrapperSession.run().")

    if callable_runner and (fetches or feed_dict):
      raise ValueError(
          "callable_runner and fetches/feed_dict are mutually exclusive, "
          "but are used simultaneously.")
    elif callable_options and (fetches or feed_dict):
      raise ValueError(
          "callable_options and fetches/feed_dict are mutually exclusive, "
          "but are used simultaneously.")

    # The argument is documented as optional; unpacking `None` with `*` would
    # raise a TypeError, so normalize it to "no arguments".
    if callable_runner_args is None:
      callable_runner_args = ()

    self.increment_run_call_count()

    def is_empty(x):
      """Check whether a possibly nested structure is empty."""
      if not nest.is_nested(x):
        return False
      if isinstance(x, collections_abc.Mapping):
        return is_empty(list(x.values()))
      for item in x:
        if not is_empty(item):
          return False
      return True

    empty_fetches = is_empty(fetches)
    if empty_fetches:
      tf_logging.info(
          "Due to empty fetches, tfdbg Session wrapper is letting a "
          "Session.run pass through without any debugging actions.")
    if self._is_disabled_thread() or empty_fetches:
      # Pass-through: no callbacks are invoked for this run.
      return self._run_without_debugging(
          fetches, feed_dict, options, run_metadata,
          callable_runner, callable_runner_args, callable_options)

    # Invoke on-run-start callback and obtain response.
    run_start_resp = self.on_run_start(
        OnRunStartRequest(fetches, feed_dict, options, run_metadata,
                          self._run_call_count,
                          is_callable_runner=bool(callable_runner)))
    _check_type(run_start_resp, OnRunStartResponse)

    if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
      retvals, run_end_req = self._run_with_debugging(
          run_start_resp, fetches, feed_dict, options, run_metadata,
          callable_runner, callable_runner_args, callable_options)
    elif run_start_resp.action == OnRunStartAction.PROFILE_RUN:
      retvals, run_end_req = self._run_with_profiling(
          run_start_resp, fetches, feed_dict, options, run_metadata,
          callable_runner, callable_runner_args, callable_options)
    elif run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN:
      # Invoke run() method of the wrapped session.
      retvals = self._run_without_debugging(
          fetches, feed_dict, options, run_metadata,
          callable_runner, callable_runner_args, callable_options)

      # Prepare arg for the on-run-end callback.
      run_end_req = OnRunEndRequest(run_start_resp.action)
    else:
      raise ValueError(
          "Invalid OnRunStartAction value: %s" % run_start_resp.action)

    # Invoke on-run-end callback and obtain response.
    run_end_resp = self.on_run_end(run_end_req)
    _check_type(run_end_resp, OnRunEndResponse)
    # Currently run_end_resp is only a placeholder. No action is taken on it.

    return retvals

  def _run_without_debugging(self,
                             fetches,
                             feed_dict,
                             options,
                             run_metadata,
                             callable_runner,
                             callable_runner_args,
                             callable_options):
    """Invoke the wrapped session or callable without any decoration.

    Args:
      fetches: Same as the `fetches` arg to `run()`.
      feed_dict: Same as the `feed_dict` arg to `run()`.
      options: Same as the `options` arg to `run()`.
      run_metadata: Same as the `run_metadata` arg to `run()`.
      callable_runner: Same as the `callable_runner` arg to `run()`.
      callable_runner_args: Positional arguments for the callable.
      callable_options: Same as the `callable_options` arg to `run()`.

    Returns:
      The return value of the underlying run or callable invocation.
    """
    if callable_runner:
      return callable_runner(*callable_runner_args)
    if callable_options:
      # pylint:disable=protected-access
      callable_object = self._sess._make_callable_from_options(
          callable_options)
      # pylint:enable=protected-access
      return callable_object(*callable_runner_args)
    return self._sess.run(fetches,
                          feed_dict=feed_dict,
                          options=options,
                          run_metadata=run_metadata)

  def _run_with_debugging(self,
                          run_start_resp,
                          fetches,
                          feed_dict,
                          options,
                          run_metadata,
                          callable_runner,
                          callable_runner_args,
                          callable_options):
    """Perform a session.run() or callable with debugging.

    Args:
      run_start_resp: (`OnRunStartResponse`) response of the on-run-start
        callback, supplying the debug URLs, debug ops and allowlists.
      fetches: Same as the `fetches` arg to `run()`.
      feed_dict: Same as the `feed_dict` arg to `run()`.
      options: Same as the `options` arg to `run()`.
      run_metadata: Same as the `run_metadata` arg to `run()`. A new
        `RunMetadata` is created if this is falsy.
      callable_runner: Same as the `callable_runner` arg to `run()`.
      callable_runner_args: Positional arguments for the callable.
      callable_options: Same as the `callable_options` arg to `run()`.

    Returns:
      A tuple `(retvals, run_end_req)`: the result of the run (or the captured
      `OpError` if one occurred and pass-through is off) and the
      `OnRunEndRequest` to hand to the on-run-end callback.

    Raises:
      errors.OpError: If the run fails and `pass_through_operrors` is set.
    """
    # Decorate RunOption to fill in debugger tensor watch specifications.
    decorated_run_options = None
    if callable_options:
      callable_options_id = id(callable_options)
      if callable_options_id not in self._cached_callables_from_options:
        # Make a copy of callable_options to avoid mutating it.
        new_callable_options = config_pb2.CallableOptions()
        new_callable_options.CopyFrom(callable_options)
        decorated_run_options = new_callable_options.run_options
      # Otherwise the cached callable is reused below and nothing is
      # decorated for this call.
    else:
      decorated_run_options = options or config_pb2.RunOptions()

    run_metadata = run_metadata or config_pb2.RunMetadata()

    if decorated_run_options:
      self._decorate_run_options_for_debug(
          decorated_run_options,
          run_start_resp.debug_urls,
          debug_ops=run_start_resp.debug_ops,
          node_name_regex_allowlist=(run_start_resp.node_name_regex_allowlist),
          op_type_regex_allowlist=run_start_resp.op_type_regex_allowlist,
          tensor_dtype_regex_allowlist=(
              run_start_resp.tensor_dtype_regex_allowlist),
          tolerate_debug_op_creation_failures=(
              run_start_resp.tolerate_debug_op_creation_failures))

    # Invoke the run() method of the wrapped Session. Catch any TensorFlow
    # runtime errors.
    tf_error = None
    try:
      if callable_runner:
        retvals = callable_runner(*callable_runner_args,
                                  options=decorated_run_options,
                                  run_metadata=run_metadata)
      elif callable_options:
        # pylint:disable=protected-access
        if callable_options_id in self._cached_callables_from_options:
          callable_object = self._cached_callables_from_options[
              callable_options_id]
        else:
          callable_object = self._sess._make_callable_from_options(
              new_callable_options)
          self._cached_callables_from_options[
              callable_options_id] = callable_object
        # pylint:enable=protected-access
        retvals = callable_object(
            *callable_runner_args, run_metadata=run_metadata)
      else:
        retvals = self._sess.run(fetches,
                                 feed_dict=feed_dict,
                                 options=decorated_run_options,
                                 run_metadata=run_metadata)
    except errors.OpError as op_error:
      if self._pass_through_operrors:
        raise
      # The error is reported to the on-run-end callback and also returned in
      # place of the run's result.
      tf_error = op_error
      retvals = op_error

    return retvals, OnRunEndRequest(
        run_start_resp.action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def(),
        tf_error=tf_error)

  def _run_with_profiling(self,
                          run_start_resp,
                          fetches,
                          feed_dict,
                          options,
                          run_metadata,
                          callable_runner,
                          callable_runner_args,
                          callable_options):
    """Perform a session.run() or callable with profiling.

    Args:
      run_start_resp: (`OnRunStartResponse`) response of the on-run-start
        callback; only its `action` field is used here.
      fetches: Same as the `fetches` arg to `run()`.
      feed_dict: Same as the `feed_dict` arg to `run()`.
      options: Same as the `options` arg to `run()`.
      run_metadata: Same as the `run_metadata` arg to `run()`. A new
        `RunMetadata` is created if this is falsy.
      callable_runner: Same as the `callable_runner` arg to `run()`.
      callable_runner_args: Positional arguments for the callable.
      callable_options: Same as the `callable_options` arg to `run()`.

    Returns:
      A tuple `(retvals, run_end_req)`: the result of the run and the
      `OnRunEndRequest` to hand to the on-run-end callback.
    """
    # Decorate RunOption to request a full trace of the graph execution.
    new_callable_options = None
    if callable_options:
      # Always work on a copy to avoid mutating the caller's options. The
      # profiling path never consults the debug-run callable cache, so the
      # copy must not be skipped for options that a debug run has cached.
      new_callable_options = config_pb2.CallableOptions()
      new_callable_options.CopyFrom(callable_options)
      decorated_run_options = new_callable_options.run_options
    else:
      decorated_run_options = options or config_pb2.RunOptions()
    self._decorate_run_options_for_profile(decorated_run_options)

    run_metadata = run_metadata or config_pb2.RunMetadata()
    if callable_runner:
      retvals = callable_runner(*callable_runner_args,
                                options=decorated_run_options,
                                run_metadata=run_metadata)
    elif callable_options:
      # pylint:disable=protected-access
      callable_object = self._sess._make_callable_from_options(
          new_callable_options)
      # pylint:enable=protected-access
      retvals = callable_object(
          *callable_runner_args, run_metadata=run_metadata)
    else:
      retvals = self._sess.run(fetches,
                               feed_dict=feed_dict,
                               options=decorated_run_options,
                               run_metadata=run_metadata)
    return retvals, OnRunEndRequest(
        run_start_resp.action,
        run_metadata=run_metadata,
        client_graph_def=self._sess.graph.as_graph_def())

  def _is_disabled_thread(self):
    """Whether the wrapper is inactive on the current thread.

    Returns:
      A truthy value if a thread-name filter is set and the current thread's
      name does not match it; a falsy value otherwise.
    """
    thread_name = threading.current_thread().name or ""
    return (self._thread_name_filter_pattern and
            not self._thread_name_filter_pattern.match(thread_name))

  def run_step_fn(self, step_fn):
    """Call `step_fn` with a StepContext whose run function is `self.run`."""
    return step_fn(
        monitored_session.MonitoredSession.StepContext(self._sess, self.run))

  def partial_run_setup(self, fetches, feeds=None):
    """Sets up the feeds and fetches for partial runs in the session."""
    raise NotImplementedError(
        "partial_run_setup is not implemented for debug-wrapper sessions.")

  def partial_run(self, handle, fetches, feed_dict=None):
    """Not supported by debug-wrapper sessions; always raises."""
    raise NotImplementedError(
        "partial_run is not implemented for debug-wrapper sessions.")

  def list_devices(self, *args, **kwargs):
    """Forward to `list_devices` of the wrapped session."""
    return self._sess.list_devices(*args, **kwargs)

  def reset(self, *args, **kwargs):
    """Forward to `reset` of the wrapped session."""
    return self._sess.reset(*args, **kwargs)

  def make_callable(self,
                    fetches,
                    feed_list=None,
                    accept_options=False):
    """Create a callable that routes through this wrapper's `run()`.

    Args:
      fetches: Same as the `fetches` arg to `Session.make_callable()`.
      feed_list: Same as the `feed_list` arg to `Session.make_callable()`.
      accept_options: Not used. The underlying callable is always created
        with `accept_options=True` and the returned wrapper always accepts
        the `options` and `run_metadata` keyword arguments.

    Returns:
      A callable taking the feed values as positional arguments.
    """
    runner = self._sess.make_callable(
        fetches, feed_list=feed_list, accept_options=True)
    def wrapped_runner(*runner_args, **kwargs):
      return self.run(None,
                      feed_dict=None,
                      options=kwargs.get("options", None),
                      run_metadata=kwargs.get("run_metadata", None),
                      callable_runner=runner,
                      callable_runner_args=runner_args)
    return wrapped_runner

  def _make_callable_from_options(self, callable_options):
    """Create a callable from CallableOptions that routes through `run()`.

    Args:
      callable_options: A `config_pb2.CallableOptions` instance.

    Returns:
      A callable taking the feed values as positional arguments and an
      optional `run_metadata` keyword argument.
    """
    def wrapped_runner(*feed_values, **kwargs):
      return self.run(None,
                      run_metadata=kwargs.get("run_metadata", None),
                      callable_options=callable_options,
                      callable_runner_args=feed_values)
    return wrapped_runner

  @property
  def run_call_count(self):
    """Number of run() calls made on this wrapper session so far."""
    return self._run_call_count

  def increment_run_call_count(self):
    """Increment the run() call counter by one."""
    self._run_call_count += 1

  def _is_disk_usage_reset_each_run(self):
    """Indicates whether disk usage is reset after each Session.run.

    Subclasses that clean up the disk usage after every run should
    override this protected method.

    Returns:
      (`bool`) Whether the disk usage amount is reset to zero after
        each Session.run.
    """
    return False

  def _decorate_run_options_for_debug(
      self,
      run_options,
      debug_urls,
      debug_ops="DebugIdentity",
      node_name_regex_allowlist=None,
      op_type_regex_allowlist=None,
      tensor_dtype_regex_allowlist=None,
      tolerate_debug_op_creation_failures=False):
    """Modify a RunOptions object for debug tensor watching.

    Specifies request for outputting partition graphs. Adds
    debug_tensor_watch_opts with proper debug URLs.

    Args:
      run_options: (RunOptions) the modified RunOptions object.
      debug_urls: (list of str) debug URLs to be entered in run_options.
        debug_tensor_watch_opts.
      debug_ops: (str or list of str) debug op(s) to be used by the debugger.
      node_name_regex_allowlist: Regular-expression allowlist for node
        name.
      op_type_regex_allowlist: Regular-expression allowlist for op type.
      tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
        dtype.
      tolerate_debug_op_creation_failures: Whether debug op creation failures
        are to be tolerated.
    """

    run_options.output_partition_graphs = True
    debug_utils.watch_graph(
        run_options,
        self._sess.graph,
        debug_urls=debug_urls,
        debug_ops=debug_ops,
        node_name_regex_allowlist=node_name_regex_allowlist,
        op_type_regex_allowlist=op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
        # Reset on the first run of this wrapper, or on every run when the
        # subclass says it cleans up after each run.
        reset_disk_byte_usage=(self._run_call_count == 1 or
                               self._is_disk_usage_reset_each_run()))

  def _decorate_run_options_for_profile(self, run_options):
    """Modify a RunOptions object for profiling TensorFlow graph execution.

    Args:
      run_options: (RunOptions) the modified RunOptions object.
    """

    run_options.trace_level = config_pb2.RunOptions.FULL_TRACE

  @abc.abstractmethod
  def on_session_init(self, request):
    """Callback invoked during construction of the debug-wrapper session.

    This is a blocking callback.
    The invocation happens right before the constructor ends.

    Args:
      request: (`OnSessionInitRequest`) callback request carrying information
        such as the session being wrapped.

    Returns:
      An instance of `OnSessionInitResponse`.
    """

  @abc.abstractmethod
  def on_run_start(self, request):
    """Callback invoked on run() calls to the debug-wrapper session.

    This is a blocking callback.
    The invocation happens after the wrapper's run() call is entered,
    after an increment of run call counter.

    Args:
      request: (`OnRunStartRequest`) callback request object carrying
        information about the run call such as the fetches, feed dict, run
        options, run metadata, and how many `run()` calls to this wrapper
        session have occurred.

    Returns:
      An instance of `OnRunStartResponse`, carrying information to
        debug URLs used to watch the tensors.
    """

  @abc.abstractmethod
  def on_run_end(self, request):
    """Callback invoked on run() calls to the debug-wrapper session.

    This is a blocking callback.
    The invocation happens right before the wrapper exits its run() call.

    Args:
      request: (`OnRunEndRequest`) callback request object carrying information
        such as the actual action performed by the session wrapper for the
        run() call.

    Returns:
      An instance of `OnRunEndResponse`.
    """

  def as_default(self):
    """Return a context manager making this wrapper the default session."""
    return ops.default_session(self)

  def __enter__(self):
    if self._default_session_context_manager is None:
      self._default_session_context_manager = self.as_default()
    return self._default_session_context_manager.__enter__()

  def __exit__(self, exec_type, exec_value, exec_tb):
    self._default_session_context_manager.__exit__(
        exec_type, exec_value, exec_tb)

  def __del__(self):
    # `_sess` is missing if __init__ raised before assigning it (e.g. on a
    # wrong `sess` type); the finalizer must not fail on such instances.
    sess = getattr(self, "_sess", None)
    if hasattr(sess, "__del__"):
      sess.__del__()

  def close(self):
    """Close the wrapped session."""
    self._sess.close()

  # TODO(cais): Add _node_name_regex_allowlist and
  #   _node_op_type_regex_allowlist.

  def should_stop(self):
    """Forward to `should_stop` of the wrapped session.

    Returns:
      The return value of the wrapped session's `should_stop()`.

    Raises:
      ValueError: If the wrapped session has no `should_stop` attribute.
    """
    if hasattr(self._sess, "should_stop"):
      return self._sess.should_stop()
    else:
      raise ValueError(
          "The wrapped session %r does not have a method called 'should_stop'. "
          "Do you intend to wrap a tf.MonitoredSession instead?" % self._sess)


class WatchOptions:
  """Type for return values of watch_fn."""

  def __init__(self,
               debug_ops=None,
               node_name_regex_allowlist=None,
               op_type_regex_allowlist=None,
               tensor_dtype_regex_allowlist=None,
               tolerate_debug_op_creation_failures=False):
    """Constructor of WatchOptions: Debug watch options.

    Used as return values of `watch_fn`s.

    Args:
      debug_ops: (`str` or `list of str`) Debug ops to be used. A falsy value
        selects the default `["DebugIdentity"]`.
      node_name_regex_allowlist: Regular-expression allowlist for node_name,
        e.g., `"(weight_[0-9]+|bias_.*)"`
      op_type_regex_allowlist: Regular-expression allowlist for the op type of
        nodes, e.g., `"(Variable|Add)"`.
        If both `node_name_regex_allowlist` and `op_type_regex_allowlist`
        are set, the two filtering operations will occur in a logical `AND`
        relation. In other words, a node will be included if and only if it
        hits both allowlists.
      tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor
        data type, e.g., `"^int.*"`.
        This allowlist operates in logical `AND` relations to the two allowlists
        above.
      tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
        failures (e.g., due to dtype incompatibility) are to be tolerated by not
        throwing exceptions.
    """
    if debug_ops:
      self.debug_ops = debug_ops
    else:
      self.debug_ops = ["DebugIdentity"]
    self.node_name_regex_allowlist = node_name_regex_allowlist
    self.op_type_regex_allowlist = op_type_regex_allowlist
    self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)

  def __repr__(self):
    return ("WatchOptions(debug_ops=%r, node_name_regex_allowlist=%r, "
            "op_type_regex_allowlist=%r, tensor_dtype_regex_allowlist=%r, "
            "tolerate_debug_op_creation_failures=%r)" %
            (self.debug_ops, self.node_name_regex_allowlist,
             self.op_type_regex_allowlist, self.tensor_dtype_regex_allowlist,
             self.tolerate_debug_op_creation_failures))


class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
  """Base class for non-interactive (i.e., non-CLI) debug wrapper sessions."""

  def __init__(self, sess, watch_fn=None, thread_name_filter=None,
               pass_through_operrors=False):
    """Constructor of NonInteractiveDebugWrapperSession.

    Args:
      sess: The TensorFlow `Session` object being wrapped.
      watch_fn: (`Callable`) A Callable that maps the fetches and feeds of a
        debugged `Session.run()` call to `WatchOptions.`
        * Args:
          * `fetches`: the fetches to the `Session.run()` call.
          * `feeds`: the feeds to the `Session.run()` call.

        * Returns:
         (`tf_debug.WatchOptions`) An object containing debug options including
           the debug ops to use, the node names, op types and/or tensor data
           types to watch, etc. See the documentation of `tf_debug.WatchOptions`
           for more details.
      thread_name_filter: Regular-expression allowlist for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      pass_through_operrors: If True, OpErrors raised during a debugged run
        are re-raised to the caller. By default (False) they are captured.

    Raises:
       TypeError: If a non-None `watch_fn` is specified and it is not callable.
    """

    BaseDebugWrapperSession.__init__(
        self, sess, thread_name_filter=thread_name_filter,
        pass_through_operrors=pass_through_operrors)

    self._watch_fn = None
    if watch_fn is not None:
      if not callable(watch_fn):
        raise TypeError("watch_fn is not callable")
      self._watch_fn = watch_fn

  def on_session_init(self, request):
    """See doc of BaseDebugWrapperSession.on_session_init."""

    return OnSessionInitResponse(OnSessionInitAction.PROCEED)

  @abc.abstractmethod
  def prepare_run_debug_urls(self, fetches, feed_dict):
    """Abstract method to be implemented by concrete subclasses.

    This method prepares the run-specific debug URL(s).

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`
      feed_dict: Same as the `feed_dict` argument to `Session.run()`

    Returns:
      debug_urls: (`str` or `list` of `str`) Debug URLs to be used in
        this `Session.run()` call.
    """

  def on_run_start(self, request):
    """See doc of BaseDebugWrapperSession.on_run_start."""

    debug_urls, watch_opts = self._prepare_run_watch_config(
        request.fetches, request.feed_dict)

    # prepare_run_debug_urls() is allowed to return a single URL string, but
    # OnRunStartResponse requires a list.
    if isinstance(debug_urls, str):
      debug_urls = [debug_urls]

    return OnRunStartResponse(
        OnRunStartAction.DEBUG_RUN,
        debug_urls,
        debug_ops=watch_opts.debug_ops,
        node_name_regex_allowlist=watch_opts.node_name_regex_allowlist,
        op_type_regex_allowlist=watch_opts.op_type_regex_allowlist,
        tensor_dtype_regex_allowlist=watch_opts.tensor_dtype_regex_allowlist,
        tolerate_debug_op_creation_failures=(
            watch_opts.tolerate_debug_op_creation_failures))

  def _prepare_run_watch_config(self, fetches, feed_dict):
    """Get the debug_urls, and node/op allowlists for the current run() call.

    Args:
      fetches: Same as the `fetches` argument to `Session.run()`.
      feed_dict: Same as the `feed_dict argument` to `Session.run()`.

    Returns:
      debug_urls: (str or list of str) Debug URLs for the current run() call,
        exactly as returned by `prepare_run_debug_urls()`.
      watch_options: (WatchOptions) The return value of a watch_fn, containing
        options including debug_ops, and allowlists. Defaults to
        `WatchOptions()` if no watch_fn was supplied.
    """

    debug_urls = self.prepare_run_debug_urls(fetches, feed_dict)
    if self._watch_fn is None:
      watch_options = WatchOptions()
    else:
      watch_options = self._watch_fn(fetches, feed_dict)
      if isinstance(watch_options, tuple):
        # For legacy return type (tuples).
        watch_options = WatchOptions(*watch_options)

    return debug_urls, watch_options

  def on_run_end(self, request):
    """See doc of BaseDebugWrapperSession.on_run_end."""

    return OnRunEndResponse()