xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/session_debug_file_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for debugger functionalities in tf.Session with file:// URLs."""
16import os
17import tempfile
18
19from tensorflow.core.protobuf import config_pb2
20from tensorflow.python.client import session
21from tensorflow.python.debug.lib import debug_data
22from tensorflow.python.debug.lib import debug_utils
23from tensorflow.python.debug.lib import session_debug_testlib
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import test_util
27from tensorflow.python.lib.io import file_io
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import variables
30from tensorflow.python.platform import googletest
31
32
@test_util.run_v1_only("b/120545219")
class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
  """Runs the SessionDebugTestBase suite with debug dumps sent to file:// URLs.

  NOTE(review): `self._dump_root` and `self._expected_partition_graph_count`
  are read here but assigned elsewhere (presumably by SessionDebugTestBase) —
  confirm against session_debug_testlib.
  """

  def _debug_urls(self, run_number=None):
    """Builds the debug URL list for one run.

    Args:
      run_number: Optional integer index of the run. None selects the dump
        root itself; an integer selects that run's own subdirectory.

    Returns:
      A one-element list holding a file:// URL for the dump directory.
    """
    dump_dir = self._debug_dump_dir(run_number=run_number)
    return ["file://%s" % dump_dir]

  def _debug_dump_dir(self, run_number=None):
    """Resolves the on-disk dump directory for one run.

    Args:
      run_number: Optional integer index of the run.

    Returns:
      `self._dump_root` when `run_number` is None; otherwise the
      `run_<run_number>` subdirectory beneath it.
    """
    if run_number is not None:
      return os.path.join(self._dump_root, "run_%d" % run_number)
    return self._dump_root

  def testAllowsDifferentWatchesOnDifferentRuns(self):
    """Test watching different tensors on different runs of the same graph."""

    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      u_init_val = [[5.0, 3.0], [-1.0, 0.0]]
      v_init_val = [[2.0], [-1.0]]

      # Both variables sit under the same "diff_Watch" name scope, so their
      # dumps share a parent directory. This exercises concurrent, non-racing
      # creation of that directory.
      u_name = "diff_Watch/u"
      v_name = "diff_Watch/v"

      u = variables.VariableV1(
          constant_op.constant(u_init_val, shape=[2, 2]), name=u_name)
      v = variables.VariableV1(
          constant_op.constant(v_init_val, shape=[2, 1]), name=v_name)
      w = math_ops.matmul(u, v, name="diff_Watch/matmul")

      u.initializer.run()
      v.initializer.run()

      # Run 0 watches only u's read op; run 1 watches only v's read op.
      watch_plan = (
          ("%s/read" % u_name, u_init_val),
          ("%s/read" % v_name, v_init_val),
      )
      for run_index, (watched_node, expected_value) in enumerate(watch_plan):
        run_dump_root = self._debug_dump_dir(run_number=run_index)
        run_options = config_pb2.RunOptions(output_partition_graphs=True)
        debug_utils.add_debug_tensor_watch(
            run_options,
            watched_node,
            0,
            debug_urls=self._debug_urls(run_number=run_index))

        run_metadata = config_pb2.RunMetadata()
        sess.run(w, options=run_options, run_metadata=run_metadata)
        self.assertEqual(self._expected_partition_graph_count,
                         len(run_metadata.partition_graphs))

        dump = debug_data.DebugDumpDir(
            run_dump_root, partition_graphs=run_metadata.partition_graphs)
        self.assertTrue(dump.loaded_partition_graphs())

        # Only the one watched tensor may appear in this run's dump, not two.
        self.assertEqual(1, dump.size)
        self.assertAllClose(
            [expected_value],
            dump.get_tensors(watched_node, 0, "DebugIdentity"))
        first_rel_timestamp = dump.get_rel_timestamps(
            watched_node, 0, "DebugIdentity")[0]
        self.assertGreaterEqual(first_rel_timestamp, 0)
112
113
class SessionDebugConcurrentTest(
    session_debug_testlib.DebugConcurrentRunCallsTest):
  """file:// variant of session_debug_testlib.DebugConcurrentRunCallsTest.

  Each concurrent run gets its own temporary dump directory.
  """

  def setUp(self):
    """Creates one fresh temporary dump directory per concurrent run."""
    # Chain to the base fixture first so whatever setup the parent test-case
    # classes define still runs; the previous override skipped it entirely.
    super().setUp()
    self._num_concurrent_runs = 3
    self._dump_roots = [
        tempfile.mkdtemp() for _ in range(self._num_concurrent_runs)
    ]

  def tearDown(self):
    """Resets the default graph and deletes every dump directory."""
    ops.reset_default_graph()
    for dump_root in self._dump_roots:
      # A directory may already be gone; only delete what still exists.
      if os.path.isdir(dump_root):
        file_io.delete_recursively(dump_root)
    # Run the parent fixture's teardown last, after our own cleanup.
    super().tearDown()

  def _get_concurrent_debug_urls(self):
    """Returns one file:// debug URL per dump root created in setUp()."""
    return ["file://%s" % dump_root for dump_root in self._dump_roots]
131
132
if __name__ == "__main__":
  # Executed directly (not imported): hand control to googletest's test runner.
  googletest.main()
135