# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Debugger (tfdbg) Utilities."""

import re


def _compile_optional_regex(regex):
  """Returns `re.compile(regex)`, or `None` when `regex` is `None` or empty."""
  return re.compile(regex) if regex else None


def add_debug_tensor_watch(run_options,
                           node_name,
                           output_slot=0,
                           debug_ops="DebugIdentity",
                           debug_urls=None,
                           tolerate_debug_op_creation_failures=False,
                           global_step=-1):
  """Add watch on a `Tensor` to `RunOptions`.

  N.B.:
    1. Under certain circumstances, the `Tensor` may not get actually watched
      (e.g., if the node of the `Tensor` is constant-folded during runtime).
    2. For debugging purposes, the `parallel_iterations` attribute of all
      `tf.while_loop`s in the graph are set to 1 to prevent any node from
      being executed multiple times concurrently. This change does not affect
      subsequent non-debugged runs of the same `tf.while_loop`s.

  Args:
    run_options: An instance of `config_pb2.RunOptions` to be modified.
    node_name: (`str`) name of the node to watch.
    output_slot: (`int`) output slot index of the tensor from the watched node.
    debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s). Can be a
      `list` of `str` or a single `str`. The latter case is equivalent to a
      `list` of `str` with only one element.
      For debug op types with customizable attributes, each debug op string can
      optionally contain a list of attribute names, in the syntax of:
        debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
    debug_urls: (`str` or `list` of `str`) URL(s) to send debug values to,
      e.g., `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`. If `None` or
      empty, no URL is attached to this watch.
    tolerate_debug_op_creation_failures: (`bool`) Whether to tolerate debug op
      creation failures by not throwing exceptions.
    global_step: (`int`) Optional global_step count for this debug tensor
      watch. It is stored on `run_options.debug_options`, so it is shared by
      all watches in `run_options` (the most recent call wins).
  """
  debug_options = run_options.debug_options
  debug_options.global_step = global_step

  watch = debug_options.debug_tensor_watch_opts.add()
  watch.tolerate_debug_op_creation_failures = (
      tolerate_debug_op_creation_failures)
  watch.node_name = node_name
  watch.output_slot = output_slot

  # A bare string is shorthand for a single-element list.
  if isinstance(debug_ops, str):
    debug_ops = [debug_ops]
  watch.debug_ops.extend(debug_ops)

  if debug_urls:
    if isinstance(debug_urls, str):
      debug_urls = [debug_urls]
    watch.debug_urls.extend(debug_urls)


def watch_graph(run_options,
                graph,
                debug_ops="DebugIdentity",
                debug_urls=None,
                node_name_regex_allowlist=None,
                op_type_regex_allowlist=None,
                tensor_dtype_regex_allowlist=None,
                tolerate_debug_op_creation_failures=False,
                global_step=-1,
                reset_disk_byte_usage=False):
  """Add debug watches to `RunOptions` for a TensorFlow graph.

  To watch all `Tensor`s on the graph, leave `node_name_regex_allowlist`,
  `op_type_regex_allowlist` and `tensor_dtype_regex_allowlist` at their
  default (`None`). In that case an additional wildcard watch (node name `"*"`,
  output slot `-1`) is also added, so that nodes not present in `graph`
  (e.g., ones created internally by TensorFlow, such as by Grappler) can be
  watched as well.

  N.B.:
    1. Under certain circumstances, the `Tensor` may not get actually watched
      (e.g., if the node of the `Tensor` is constant-folded during runtime).
    2. For debugging purposes, the `parallel_iterations` attribute of all
      `tf.while_loop`s in the graph are set to 1 to prevent any node from
      being executed multiple times concurrently. This change does not affect
      subsequent non-debugged runs of the same `tf.while_loop`s.

  Args:
    run_options: An instance of `config_pb2.RunOptions` to be modified.
    graph: An instance of `ops.Graph`.
    debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use.
      Must not be empty or `None`.
      For debug op types with customizable attributes, each debug op name string
      can optionally contain a list of attribute names, in the syntax of:
        debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
    debug_urls: URL(s) to send debug values to, e.g.,
      `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`. Can be a list of
      strings or a single string; a single string is equivalent to a list
      consisting of that one string. Must not be empty or `None`.
    node_name_regex_allowlist: Regular-expression allowlist for node_name,
      e.g., `"(weight_[0-9]+|bias_.*)"`.
    op_type_regex_allowlist: Regular-expression allowlist for the op type of
      nodes, e.g., `"(Variable|Add)"`.
      If both `node_name_regex_allowlist` and `op_type_regex_allowlist`
      are set, the two filtering operations will occur in a logical `AND`
      relation. In other words, a node will be included if and only if it
      hits both allowlists.
    tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor
      data type, e.g., `"^int.*"`.
      This allowlist operates in logical `AND` relations to the two allowlists
      above.
      All three allowlists are applied with `re.match`, i.e., they are
      anchored at the start of the string but not at its end.
    tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
      failures (e.g., due to dtype incompatibility) are to be tolerated by not
      throwing exceptions.
    global_step: (`int`) Optional global_step count for this debug tensor
      watch.
    reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
      usage to zero (default: `False`).

  Raises:
    ValueError: If `debug_ops` or `debug_urls` is empty or `None`.
  """
  if not debug_ops:
    raise ValueError("debug_ops must not be empty or None.")
  if not debug_urls:
    raise ValueError("debug_urls must not be empty or None.")

  if isinstance(debug_ops, str):
    debug_ops = [debug_ops]

  node_name_pattern = _compile_optional_regex(node_name_regex_allowlist)
  op_type_pattern = _compile_optional_regex(op_type_regex_allowlist)
  tensor_dtype_pattern = _compile_optional_regex(tensor_dtype_regex_allowlist)

  for op in graph.get_operations():
    # Skip nodes without any output tensors.
    if not op.outputs:
      continue

    node_name = op.name
    op_type = op.type

    # Allowlist semantics: a set pattern must match for the node to be kept.
    if node_name_pattern and not node_name_pattern.match(node_name):
      continue
    if op_type_pattern and not op_type_pattern.match(op_type):
      continue

    for slot, tensor in enumerate(op.outputs):
      if (tensor_dtype_pattern and
          not tensor_dtype_pattern.match(tensor.dtype.name)):
        continue

      add_debug_tensor_watch(
          run_options,
          node_name,
          output_slot=slot,
          debug_ops=debug_ops,
          debug_urls=debug_urls,
          tolerate_debug_op_creation_failures=(
              tolerate_debug_op_creation_failures),
          global_step=global_step)

  # If no filter for node or tensor is used, will add a wildcard node name, so
  # that all nodes, including the ones created internally by TensorFlow itself
  # (e.g., by Grappler), can be watched during debugging.
  use_node_name_wildcard = (not node_name_pattern and
                            not op_type_pattern and
                            not tensor_dtype_pattern)
  if use_node_name_wildcard:
    add_debug_tensor_watch(
        run_options,
        "*",
        output_slot=-1,
        debug_ops=debug_ops,
        debug_urls=debug_urls,
        tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
        global_step=global_step)

  run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage


def watch_graph_with_denylists(run_options,
                               graph,
                               debug_ops="DebugIdentity",
                               debug_urls=None,
                               node_name_regex_denylist=None,
                               op_type_regex_denylist=None,
                               tensor_dtype_regex_denylist=None,
                               tolerate_debug_op_creation_failures=False,
                               global_step=-1,
                               reset_disk_byte_usage=False):
  """Add debug tensor watches, denylisting nodes and op types.

  This is similar to `watch_graph()`, but the node names and op types are
  denylisted, instead of allowlisted. Unlike `watch_graph()`, no wildcard
  watch is added and `debug_urls` may be `None`.

  N.B.:
    1. Under certain circumstances, the `Tensor` may not get actually watched
      (e.g., if the node of the `Tensor` is constant-folded during runtime).
    2. For debugging purposes, the `parallel_iterations` attribute of all
      `tf.while_loop`s in the graph are set to 1 to prevent any node from
      being executed multiple times concurrently. This change does not affect
      subsequent non-debugged runs of the same `tf.while_loop`s.

  Args:
    run_options: An instance of `config_pb2.RunOptions` to be modified.
    graph: An instance of `ops.Graph`.
    debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use. See
      the documentation of `watch_graph` for more details.
    debug_urls: URL(s) to send debug values to, e.g.,
      `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`. If `None` or empty,
      the watches carry no URL.
    node_name_regex_denylist: Regular-expression denylist for node_name. This
      should be a string, e.g., `"(weight_[0-9]+|bias_.*)"`.
    op_type_regex_denylist: Regular-expression denylist for the op type of
      nodes, e.g., `"(Variable|Add)"`. If both node_name_regex_denylist and
      op_type_regex_denylist are set, the two filtering operations will occur in
      a logical `OR` relation. In other words, a node will be excluded if it
      hits either of the two denylists; a node will be included if and only if
      it hits neither of the denylists.
    tensor_dtype_regex_denylist: Regular-expression denylist for Tensor data
      type, e.g., `"^int.*"`. This denylist operates in logical `OR` relations
      to the two denylists above: a tensor is excluded if its dtype hits this
      denylist, even if its node hits neither of the other two.
      All three denylists are applied with `re.match`, i.e., they are anchored
      at the start of the string but not at its end.
    tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
      failures (e.g., due to dtype incompatibility) are to be tolerated by not
      throwing exceptions.
    global_step: (`int`) Optional global_step count for this debug tensor watch.
    reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
      usage to zero (default: `False`).
  """
  if isinstance(debug_ops, str):
    debug_ops = [debug_ops]

  node_name_pattern = _compile_optional_regex(node_name_regex_denylist)
  op_type_pattern = _compile_optional_regex(op_type_regex_denylist)
  tensor_dtype_pattern = _compile_optional_regex(tensor_dtype_regex_denylist)

  for op in graph.get_operations():
    # Skip nodes without any output tensors.
    if not op.outputs:
      continue

    node_name = op.name
    op_type = op.type

    # Denylist semantics: any set pattern that matches excludes the node.
    if node_name_pattern and node_name_pattern.match(node_name):
      continue
    if op_type_pattern and op_type_pattern.match(op_type):
      continue

    for slot, tensor in enumerate(op.outputs):
      if (tensor_dtype_pattern and
          tensor_dtype_pattern.match(tensor.dtype.name)):
        continue

      add_debug_tensor_watch(
          run_options,
          node_name,
          output_slot=slot,
          debug_ops=debug_ops,
          debug_urls=debug_urls,
          tolerate_debug_op_creation_failures=(
              tolerate_debug_op_creation_failures),
          global_step=global_step)

  # Applied once, after the loop, so it takes effect even when the graph has no
  # output-bearing ops (consistent with `watch_graph`).
  run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage