xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/debug_utils_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for TensorFlow Debugger (tfdbg) Utilities."""
16import numpy as np
17
18from tensorflow.core.protobuf import config_pb2
19from tensorflow.python.client import session
20from tensorflow.python.debug.lib import debug_utils
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import math_ops
24# Import resource_variable_ops for the variables-to-tensor implicit conversion.
25from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
26from tensorflow.python.ops import variables
27from tensorflow.python.platform import googletest
28
29
@test_util.run_v1_only("Requires tf.Session")
class DebugUtilsTest(test_util.TensorFlowTestCase):
  """Tests for the debug-tensor-watch helpers in `debug_utils`.

  A small graph (two variables, a constant, a matmul and an add) is built once
  per class. Each test then populates a fresh `RunOptions` proto through
  `debug_utils` and checks the resulting `DebugTensorWatch` entries.
  """

  @classmethod
  def setUpClass(cls):
    # Chain up so base-class (TensorFlowTestCase / absltest) class-level setup
    # still runs.
    super().setUpClass()
    cls._sess = session.Session()
    # NOTE(review): entering the session is relied on to make its graph the
    # default graph, so the ops below are created in `cls._sess.graph` (the
    # graph every test inspects).
    with cls._sess:
      cls._a_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
      cls._b_init_val = np.array([[2.0], [-1.0]])
      cls._c_val = np.array([[-4.0], [np.nan]])

      cls._a_init = constant_op.constant(
          cls._a_init_val, shape=[2, 2], name="a1_init")
      cls._b_init = constant_op.constant(
          cls._b_init_val, shape=[2, 1], name="b_init")

      cls._a = variables.VariableV1(cls._a_init, name="a1")
      cls._b = variables.VariableV1(cls._b_init, name="b")
      cls._c = constant_op.constant(cls._c_val, shape=[2, 1], name="c")

      # Matrix product of a1 and b.
      cls._p = math_ops.matmul(cls._a, cls._b, name="p1")

      # Sum of two vectors.
      cls._s = math_ops.add(cls._p, cls._c, name="s")

    # Only the graph is used from here on; the tests never run the session.
    cls._graph = cls._sess.graph

    # Number of watches expected when the whole graph is watched:
    #   - Two variables (a1, b), each contributing four nodes (init constant,
    #     Variable, Assign, read).
    #   - One constant (c).
    #   - One matmul (p1) and one add (s).
    #   - One wildcard node name ("*") that covers nodes created internally
    #     by TensorFlow itself (e.g., Grappler).
    cls._expected_num_nodes = 4 * 2 + 1 + 1 + 1 + 1

  def setUp(self):
    # Run TensorFlowTestCase's per-test setup (graph/seed/session-cache resets)
    # before adding this class's own state.
    super().setUp()
    # A fresh, empty RunOptions per test, so watches added by one test never
    # leak into another.
    self._run_options = config_pb2.RunOptions()

  def _verify_watches(self, watch_opts, expected_output_slot,
                      expected_debug_ops, expected_debug_urls):
    """Verify a list of debug tensor watches.

    Every watch must carry exactly `expected_debug_ops` and
    `expected_debug_urls`. A watch whose node name is the wildcard "*" must
    have output slot -1; every other watch must have `expected_output_slot`.

    Args:
      watch_opts: Repeated protobuf field of DebugTensorWatch.
      expected_output_slot: Expected output slot index of non-wildcard
        watches, as an integer.
      expected_debug_ops: Expected debug ops, as a list of strings.
      expected_debug_urls: Expected debug URLs, as a list of strings.

    Returns:
      List of node names from the list of debug tensor watches, in order.
    """
    node_names = []
    for watch in watch_opts:
      node_names.append(watch.node_name)

      # The wildcard watch is not tied to a specific output slot.
      if watch.node_name == "*":
        self.assertEqual(-1, watch.output_slot)
      else:
        self.assertEqual(expected_output_slot, watch.output_slot)
      self.assertEqual(expected_debug_ops, watch.debug_ops)
      self.assertEqual(expected_debug_urls, watch.debug_urls)

    return node_names

  def testAddDebugTensorWatches_defaultDebugOp(self):
    debug_utils.add_debug_tensor_watch(
        self._run_options, "foo/node_a", 1, debug_urls="file:///tmp/tfdbg_1")
    debug_utils.add_debug_tensor_watch(
        self._run_options, "foo/node_b", 0, debug_urls="file:///tmp/tfdbg_2")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertLen(debug_watch_opts, 2)

    watch_0 = debug_watch_opts[0]
    watch_1 = debug_watch_opts[1]

    self.assertEqual("foo/node_a", watch_0.node_name)
    self.assertEqual(1, watch_0.output_slot)
    self.assertEqual("foo/node_b", watch_1.node_name)
    self.assertEqual(0, watch_1.output_slot)
    # No debug_ops argument was given, so the default op is used.
    self.assertEqual(["DebugIdentity"], watch_0.debug_ops)
    self.assertEqual(["DebugIdentity"], watch_1.debug_ops)

    # Verify debug URLs.
    self.assertEqual(["file:///tmp/tfdbg_1"], watch_0.debug_urls)
    self.assertEqual(["file:///tmp/tfdbg_2"], watch_1.debug_urls)

  def testAddDebugTensorWatches_explicitDebugOp(self):
    debug_utils.add_debug_tensor_watch(
        self._run_options,
        "foo/node_a",
        0,
        debug_ops="DebugNanCount",
        debug_urls="file:///tmp/tfdbg_1")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertLen(debug_watch_opts, 1)

    watch_0 = debug_watch_opts[0]

    self.assertEqual("foo/node_a", watch_0.node_name)
    self.assertEqual(0, watch_0.output_slot)

    # A single debug op given as a string is stored as a one-element list.
    self.assertEqual(["DebugNanCount"], watch_0.debug_ops)

    # Verify debug URLs.
    self.assertEqual(["file:///tmp/tfdbg_1"], watch_0.debug_urls)

  def testAddDebugTensorWatches_multipleDebugOps(self):
    debug_utils.add_debug_tensor_watch(
        self._run_options,
        "foo/node_a",
        0,
        debug_ops=["DebugNanCount", "DebugIdentity"],
        debug_urls="file:///tmp/tfdbg_1")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertLen(debug_watch_opts, 1)

    watch_0 = debug_watch_opts[0]

    self.assertEqual("foo/node_a", watch_0.node_name)
    self.assertEqual(0, watch_0.output_slot)

    # Explicit debug ops are stored as given, preserving order.
    self.assertEqual(["DebugNanCount", "DebugIdentity"], watch_0.debug_ops)

    # Verify debug URLs.
    self.assertEqual(["file:///tmp/tfdbg_1"], watch_0.debug_urls)

  def testAddDebugTensorWatches_multipleURLs(self):
    debug_utils.add_debug_tensor_watch(
        self._run_options,
        "foo/node_a",
        0,
        debug_ops="DebugNanCount",
        debug_urls=["file:///tmp/tfdbg_1", "file:///tmp/tfdbg_2"])

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertLen(debug_watch_opts, 1)

    watch_0 = debug_watch_opts[0]

    self.assertEqual("foo/node_a", watch_0.node_name)
    self.assertEqual(0, watch_0.output_slot)

    # Verify the explicit debug op name.
    self.assertEqual(["DebugNanCount"], watch_0.debug_ops)

    # Multiple debug URLs are stored as given, preserving order.
    self.assertEqual(["file:///tmp/tfdbg_1", "file:///tmp/tfdbg_2"],
                     watch_0.debug_urls)

  def testWatchGraph_allNodes(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_ops=["DebugIdentity", "DebugNanCount"],
        debug_urls="file:///tmp/tfdbg_1")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertLen(debug_watch_opts, self._expected_num_nodes)

    # Every node in the graph that has output tensors should get a debug
    # tensor watch on slot 0 with the requested ops and URL.
    node_names = self._verify_watches(debug_watch_opts, 0,
                                      ["DebugIdentity", "DebugNanCount"],
                                      ["file:///tmp/tfdbg_1"])

    # Verify the node names.
    self.assertIn("a1_init", node_names)
    self.assertIn("a1", node_names)
    self.assertIn("a1/Assign", node_names)
    self.assertIn("a1/read", node_names)

    self.assertIn("b_init", node_names)
    self.assertIn("b", node_names)
    self.assertIn("b/Assign", node_names)
    self.assertIn("b/read", node_names)

    self.assertIn("c", node_names)
    self.assertIn("p1", node_names)
    self.assertIn("s", node_names)

    # Assert that the wildcard node name has been created.
    self.assertIn("*", node_names)

  def testWatchGraph_nodeNameAllowlist(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_allowlist="(a1$|a1_init$|a1/.*|p1$)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertCountEqual(
        ["a1_init", "a1", "a1/Assign", "a1/read", "p1"], node_names)

  def testWatchGraph_opTypeAllowlist(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        op_type_regex_allowlist="(Variable|MatMul)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertCountEqual(["a1", "b", "p1"], node_names)

  def testWatchGraph_nodeNameAndOpTypeAllowlists(self):
    # Only nodes matching BOTH allowlists are watched: the name regex alone
    # matches "a1" and "p1", but only "p1" is a MatMul.
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_allowlist="([a-z]+1$)",
        op_type_regex_allowlist="(MatMul)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertEqual(["p1"], node_names)

  def testWatchGraph_tensorDTypeAllowlist(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        tensor_dtype_regex_allowlist=".*_ref")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertCountEqual(["a1", "a1/Assign", "b", "b/Assign"], node_names)

  def testWatchGraph_nodeNameAndTensorDTypeAllowlists(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_allowlist="^a.*",
        tensor_dtype_regex_allowlist=".*_ref")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertCountEqual(["a1", "a1/Assign"], node_names)

  def testWatchGraph_nodeNameDenylist(self):
    debug_utils.watch_graph_with_denylists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_denylist="(a1$|a1_init$|a1/.*|p1$)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertCountEqual(
        ["b_init", "b", "b/Assign", "b/read", "c", "s"], node_names)

  def testWatchGraph_opTypeDenylist(self):
    debug_utils.watch_graph_with_denylists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        op_type_regex_denylist="(Variable|Identity|Assign|Const)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertCountEqual(["p1", "s"], node_names)

  def testWatchGraph_nodeNameAndOpTypeDenylists(self):
    # A node is excluded if it matches EITHER denylist.
    debug_utils.watch_graph_with_denylists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_denylist="p1$",
        op_type_regex_denylist="(Variable|Identity|Assign|Const)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertEqual(["s"], node_names)

  def testWatchGraph_tensorDTypeDenylists(self):
    debug_utils.watch_graph_with_denylists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        tensor_dtype_regex_denylist=".*_ref")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertNotIn("a1", node_names)
    self.assertNotIn("a1/Assign", node_names)
    self.assertNotIn("b", node_names)
    self.assertNotIn("b/Assign", node_names)
    self.assertIn("s", node_names)

  def testWatchGraph_nodeNameAndTensorDTypeDenylists(self):
    debug_utils.watch_graph_with_denylists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_denylist="^s$",
        tensor_dtype_regex_denylist=".*_ref")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertNotIn("a1", node_names)
    self.assertNotIn("a1/Assign", node_names)
    self.assertNotIn("b", node_names)
    self.assertNotIn("b/Assign", node_names)
    self.assertNotIn("s", node_names)
361
362
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's googletest runner.
  googletest.main()
365