# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for debugger functionalities in tf.Session with file:// URLs."""
import os
import tempfile

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import session_debug_testlib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest


@test_util.run_v1_only("b/120545219")
class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
  """Runs the shared Session debug tests against file:// debug URLs."""

  def _debug_urls(self, run_number=None):
    """Returns a one-element list holding the file:// URL for a dump dir.

    Args:
      run_number: Optional run index; see `_debug_dump_dir`.

    Returns:
      A list containing a single "file://<dump dir>" string.
    """
    return ["file://%s" % self._debug_dump_dir(run_number=run_number)]

  def _debug_dump_dir(self, run_number=None):
    """Returns the dump directory for a (possibly numbered) debug run.

    Args:
      run_number: Optional integer. If None, the dump root itself is returned;
        otherwise a "run_<n>" subdirectory of the dump root.

    Returns:
      The dump directory path as a string.
    """
    # NOTE(review): self._dump_root is assumed to be set by
    # SessionDebugTestBase.setUp (not visible in this file) — confirm.
    if run_number is None:
      return self._dump_root
    return os.path.join(self._dump_root, "run_%d" % run_number)

  def testAllowsDifferentWatchesOnDifferentRuns(self):
    """Test watching different tensors on different runs of the same graph."""

    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      u_init_val = [[5.0, 3.0], [-1.0, 0.0]]
      v_init_val = [[2.0], [-1.0]]

      # Use node names with overlapping namespace (i.e., parent directory) to
      # test concurrent, non-racing directory creation.
      u_name = "diff_Watch/u"
      v_name = "diff_Watch/v"

      u_init = constant_op.constant(u_init_val, shape=[2, 2])
      u = variables.VariableV1(u_init, name=u_name)
      v_init = constant_op.constant(v_init_val, shape=[2, 1])
      v = variables.VariableV1(v_init, name=v_name)

      w = math_ops.matmul(u, v, name="diff_Watch/matmul")

      u.initializer.run()
      v.initializer.run()

      # Run 0 watches only u's read tensor; run 1 watches only v's. Each entry
      # pairs the watched variable name with the value its dump should hold.
      watch_plan = [(u_name, u_init_val), (v_name, v_init_val)]
      for i, (watched_name, expected_val) in enumerate(watch_plan):
        run_options = config_pb2.RunOptions(output_partition_graphs=True)

        run_dump_root = self._debug_dump_dir(run_number=i)
        debug_urls = self._debug_urls(run_number=i)

        watched_node = "%s/read" % watched_name
        debug_utils.add_debug_tensor_watch(
            run_options, watched_node, 0, debug_urls=debug_urls)

        run_metadata = config_pb2.RunMetadata()

        # Invoke Session.run().
        sess.run(w, options=run_options, run_metadata=run_metadata)

        # NOTE(review): _expected_partition_graph_count is assumed to be
        # provided by SessionDebugTestBase — confirm.
        self.assertEqual(self._expected_partition_graph_count,
                         len(run_metadata.partition_graphs))

        dump = debug_data.DebugDumpDir(
            run_dump_root, partition_graphs=run_metadata.partition_graphs)
        self.assertTrue(dump.loaded_partition_graphs())

        # Each run should have generated only one dumped tensor, not two.
        self.assertEqual(1, dump.size)

        self.assertAllClose(
            [expected_val],
            dump.get_tensors(watched_node, 0, "DebugIdentity"))
        self.assertGreaterEqual(
            dump.get_rel_timestamps(watched_node, 0, "DebugIdentity")[0], 0)


class SessionDebugConcurrentTest(
    session_debug_testlib.DebugConcurrentRunCallsTest):
  """Runs the concurrent Session.run() debug tests with file:// URLs."""

  def setUp(self):
    """Creates one fresh temporary dump directory per concurrent run."""
    # Let the base test case perform its own per-test initialization first.
    super().setUp()
    # NOTE(review): _num_concurrent_runs is presumably read by the base class's
    # test methods to decide how many runs to launch — confirm.
    self._num_concurrent_runs = 3
    self._dump_roots = [
        tempfile.mkdtemp() for _ in range(self._num_concurrent_runs)]

  def tearDown(self):
    """Resets the default graph and deletes any remaining dump directories."""
    try:
      ops.reset_default_graph()
      for dump_root in self._dump_roots:
        # A directory may not exist if nothing was dumped into it or it was
        # already removed.
        if os.path.isdir(dump_root):
          file_io.delete_recursively(dump_root)
    finally:
      # Always run the base test case's cleanup, even if deletion failed.
      super().tearDown()

  def _get_concurrent_debug_urls(self):
    """Returns one file:// debug URL per concurrent run's dump root."""
    return ["file://%s" % dump_root for dump_root in self._dump_roots]


if __name__ == "__main__":
  googletest.main()