xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/debug_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow Debugger (tfdbg) Utilities."""
16
17import re
18
19
20
def add_debug_tensor_watch(run_options,
                           node_name,
                           output_slot=0,
                           debug_ops="DebugIdentity",
                           debug_urls=None,
                           tolerate_debug_op_creation_failures=False,
                           global_step=-1):
  """Register a debug watch for one `Tensor` on a `RunOptions` object.

  A new watch entry is appended to
  `run_options.debug_options.debug_tensor_watch_opts` and filled in from the
  arguments. As a side effect, `run_options.debug_options.global_step` is
  overwritten with `global_step` on every call.

  N.B.:
    1. Under certain circumstances, the `Tensor` may not get actually watched
      (e.g., if the node of the `Tensor` is constant-folded during runtime).
    2. For debugging purposes, the `parallel_iteration` attribute of all
      `tf.while_loop`s in the graph are set to 1 to prevent any node from
      being executed multiple times concurrently. This change does not affect
      subsequent non-debugged runs of the same `tf.while_loop`s.

  Args:
    run_options: An instance of `config_pb2.RunOptions` to be modified in
      place.
    node_name: (`str`) name of the node to watch.
    output_slot: (`int`) output slot index of the tensor from the watched node.
    debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s). A single
      `str` is treated as a one-element `list`.
      For debug op types with customizable attributes, each debug op string can
      optionally contain a list of attribute names, in the syntax of:
        debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
    debug_urls: (`str` or `list` of `str`) URL(s) to send debug values to,
      e.g., `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`. A single
      `str` is treated as a one-element `list`. If empty or `None`, no URL is
      attached to the watch.
    tolerate_debug_op_creation_failures: (`bool`) Whether to tolerate debug op
      creation failures by not throwing exceptions.
    global_step: (`int`) Optional global_step count for this debug tensor
      watch.
  """
  debug_options = run_options.debug_options
  all_watches = debug_options.debug_tensor_watch_opts
  debug_options.global_step = global_step

  new_watch = all_watches.add()
  new_watch.tolerate_debug_op_creation_failures = (
      tolerate_debug_op_creation_failures)
  new_watch.node_name = node_name
  new_watch.output_slot = output_slot

  # A bare string stands for a list holding just that string.
  op_names = [debug_ops] if isinstance(debug_ops, str) else debug_ops
  new_watch.debug_ops.extend(op_names)

  # Nothing more to record when no URL was supplied.
  if not debug_urls:
    return

  urls = [debug_urls] if isinstance(debug_urls, str) else debug_urls
  new_watch.debug_urls.extend(urls)
75
76
def watch_graph(run_options,
                graph,
                debug_ops="DebugIdentity",
                debug_urls=None,
                node_name_regex_allowlist=None,
                op_type_regex_allowlist=None,
                tensor_dtype_regex_allowlist=None,
                tolerate_debug_op_creation_failures=False,
                global_step=-1,
                reset_disk_byte_usage=False):
  """Add debug watches to `RunOptions` for a TensorFlow graph.

  Every output tensor of every operation in `graph` that passes all of the
  supplied allowlists gets a watch via `add_debug_tensor_watch()`. To watch
  all `Tensor`s on the graph, leave `node_name_regex_allowlist`,
  `op_type_regex_allowlist` and `tensor_dtype_regex_allowlist` at their
  default (`None`); in that case an extra wildcard watch (node name `"*"`,
  output slot `-1`) is also added.

  N.B.:
    1. Under certain circumstances, the `Tensor` may not get actually watched
      (e.g., if the node of the `Tensor` is constant-folded during runtime).
    2. For debugging purposes, the `parallel_iteration` attribute of all
      `tf.while_loop`s in the graph are set to 1 to prevent any node from
      being executed multiple times concurrently. This change does not affect
      subsequent non-debugged runs of the same `tf.while_loop`s.

  Args:
    run_options: An instance of `config_pb2.RunOptions` to be modified.
    graph: An instance of `ops.Graph`.
    debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use.
      For debug op types with customizable attributes, each debug op name string
      can optionally contain a list of attribute names, in the syntax of:
        debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
    debug_urls: URLs to send debug values to. Can be a list of strings or a
      single string. The case of a single string is equivalent to a list
      consisting of a single string, e.g., `file:///tmp/tfdbg_dump_1`,
      `grpc://localhost:12345`. Must not be empty or `None`.
    node_name_regex_allowlist: Regular-expression allowlist for node_name,
      e.g., `"(weight_[0-9]+|bias_.*)"`. Applied with `re.match`, i.e.,
      anchored at the start of the name.
    op_type_regex_allowlist: Regular-expression allowlist for the op type of
      nodes, e.g., `"(Variable|Add)"`.
      If both `node_name_regex_allowlist` and `op_type_regex_allowlist`
      are set, the two filtering operations will occur in a logical `AND`
      relation. In other words, a node will be included if and only if it
      hits both allowlists.
    tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor
      data type, e.g., `"^int.*"`.
      This allowlist operates in logical `AND` relations to the two allowlists
      above.
    tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
      failures (e.g., due to dtype incompatibility) are to be tolerated by not
      throwing exceptions.
    global_step: (`int`) Optional global_step count for this debug tensor
      watch.
    reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
      usage to zero (default: `False`).

  Raises:
    ValueError: If `debug_ops` or `debug_urls` is empty or `None`.
  """
  if not debug_ops:
    raise ValueError("debug_ops must not be empty or None.")
  if not debug_urls:
    raise ValueError("debug_urls must not be empty or None.")

  if isinstance(debug_ops, str):
    debug_ops = [debug_ops]

  # An unset (falsy) allowlist is represented by None, meaning "no filter".
  name_re = None
  if node_name_regex_allowlist:
    name_re = re.compile(node_name_regex_allowlist)
  type_re = None
  if op_type_regex_allowlist:
    type_re = re.compile(op_type_regex_allowlist)
  dtype_re = None
  if tensor_dtype_regex_allowlist:
    dtype_re = re.compile(tensor_dtype_regex_allowlist)

  # Keyword arguments common to every watch added below.
  shared_kwargs = dict(
      debug_ops=debug_ops,
      debug_urls=debug_urls,
      tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
      global_step=global_step)

  for op in graph.get_operations():
    # Nodes without any output tensors have nothing to watch.
    if not op.outputs:
      continue

    if name_re is not None and name_re.match(op.name) is None:
      continue
    if type_re is not None and type_re.match(op.type) is None:
      continue

    for slot, tensor in enumerate(op.outputs):
      if dtype_re is not None and dtype_re.match(tensor.dtype.name) is None:
        continue
      add_debug_tensor_watch(
          run_options, op.name, output_slot=slot, **shared_kwargs)

  # With no filter of any kind in effect, also add a wildcard node name, so
  # that all nodes, including the ones created internally by TensorFlow itself
  # (e.g., by Grappler), can be watched during debugging.
  if name_re is None and type_re is None and dtype_re is None:
    add_debug_tensor_watch(
        run_options, "*", output_slot=-1, **shared_kwargs)

  run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage
195
196
def watch_graph_with_denylists(run_options,
                               graph,
                               debug_ops="DebugIdentity",
                               debug_urls=None,
                               node_name_regex_denylist=None,
                               op_type_regex_denylist=None,
                               tensor_dtype_regex_denylist=None,
                               tolerate_debug_op_creation_failures=False,
                               global_step=-1,
                               reset_disk_byte_usage=False):
  """Add debug tensor watches, denylisting nodes and op types.

  This is similar to `watch_graph()`, but the node names and op types are
  denylisted, instead of allowlisted. Every output tensor of every operation
  in `graph` that hits none of the supplied denylists gets a watch via
  `add_debug_tensor_watch()`.

  N.B.:
    1. Under certain circumstances, the `Tensor` may not get actually watched
      (e.g., if the node of the `Tensor` is constant-folded during runtime).
    2. For debugging purposes, the `parallel_iteration` attribute of all
      `tf.while_loop`s in the graph are set to 1 to prevent any node from
      being executed multiple times concurrently. This change does not affect
      subsequent non-debugged runs of the same `tf.while_loop`s.

  Args:
    run_options: An instance of `config_pb2.RunOptions` to be modified.
    graph: An instance of `ops.Graph`.
    debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use. See
      the documentation of `watch_graph` for more details.
    debug_urls: URL(s) to send debug values to, e.g.,
      `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`.
    node_name_regex_denylist: Regular-expression denylist for node_name. This
      should be a string, e.g., `"(weight_[0-9]+|bias_.*)"`. Applied with
      `re.match`, i.e., anchored at the start of the name.
    op_type_regex_denylist: Regular-expression denylist for the op type of
      nodes, e.g., `"(Variable|Add)"`. If both node_name_regex_denylist and
      op_type_regex_denylist are set, the two filtering operations will occur in
      a logical `OR` relation. In other words, a node will be excluded if it
      hits either of the two denylists; a node will be included if and only if
      it hits neither of the denylists.
    tensor_dtype_regex_denylist: Regular-expression denylist for Tensor data
      type, e.g., `"^int.*"`. This denylist operates in logical `OR` relations
      to the two denylists above.
    tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
      failures (e.g., due to dtype incompatibility) are to be tolerated by not
      throwing exceptions.
    global_step: (`int`) Optional global_step count for this debug tensor watch.
    reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
      usage to zero (default: `False`). The value is always written to
      `run_options.debug_options`, whether or not any tensor ends up watched.
  """

  if isinstance(debug_ops, str):
    debug_ops = [debug_ops]

  # An unset (falsy) denylist is represented by None, meaning "deny nothing".
  node_name_pattern = (
      re.compile(node_name_regex_denylist)
      if node_name_regex_denylist else None)
  op_type_pattern = (
      re.compile(op_type_regex_denylist) if op_type_regex_denylist else None)
  tensor_dtype_pattern = (
      re.compile(tensor_dtype_regex_denylist)
      if tensor_dtype_regex_denylist else None)

  for op in graph.get_operations():
    # Skip nodes without any output tensors.
    if not op.outputs:
      continue

    # A node is excluded as soon as it hits either denylist.
    if node_name_pattern and node_name_pattern.match(op.name):
      continue
    if op_type_pattern and op_type_pattern.match(op.type):
      continue

    for slot, tensor in enumerate(op.outputs):
      if (tensor_dtype_pattern and
          tensor_dtype_pattern.match(tensor.dtype.name)):
        continue

      add_debug_tensor_watch(
          run_options,
          op.name,
          output_slot=slot,
          debug_ops=debug_ops,
          debug_urls=debug_urls,
          tolerate_debug_op_creation_failures=(
              tolerate_debug_op_creation_failures),
          global_step=global_step)

  # Set once, outside the loop, so the flag is honored even when the graph is
  # empty or every op was skipped above (mirrors `watch_graph`).
  run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage
287