# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfdbg module debug_graphs."""
from tensorflow.python.debug.lib import debug_graphs
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class ParseNodeOrTensorNameTest(test_util.TensorFlowTestCase):
  """Tests for `debug_graphs.parse_node_or_tensor_name`."""

  def testParseNodeName(self):
    # A name without a ":<slot>" suffix is a node name; no slot is reported.
    node_name, slot = debug_graphs.parse_node_or_tensor_name(
        "namespace1/node_1")

    self.assertEqual("namespace1/node_1", node_name)
    self.assertIsNone(slot)

  def testParseTensorName(self):
    # A "<node>:<slot>" tensor name is split into the node name and an int slot.
    node_name, slot = debug_graphs.parse_node_or_tensor_name(
        "namespace1/node_2:3")

    self.assertEqual("namespace1/node_2", node_name)
    self.assertEqual(3, slot)


class GetNodeNameAndOutputSlotTest(test_util.TensorFlowTestCase):
  """Tests for `debug_graphs.get_node_name` and `get_output_slot`."""

  def testParseTensorNameInputWorks(self):
    self.assertEqual("a", debug_graphs.get_node_name("a:0"))
    self.assertEqual(0, debug_graphs.get_output_slot("a:0"))

    self.assertEqual("_b", debug_graphs.get_node_name("_b:1"))
    self.assertEqual(1, debug_graphs.get_output_slot("_b:1"))

  def testParseNodeNameInputWorks(self):
    # Unlike parse_node_or_tensor_name, a bare node name maps to slot 0 here.
    self.assertEqual("a", debug_graphs.get_node_name("a"))
    self.assertEqual(0, debug_graphs.get_output_slot("a"))


class NodeNameChecksTest(test_util.TensorFlowTestCase):
  """Tests for the copy-node and debug-node name predicates."""

  def testIsCopyNode(self):
    # Only the exact "__copy_" prefix (two leading underscores) qualifies.
    self.assertTrue(debug_graphs.is_copy_node("__copy_ns1/ns2/node3_0"))

    self.assertFalse(debug_graphs.is_copy_node("copy_ns1/ns2/node3_0"))
    self.assertFalse(debug_graphs.is_copy_node("_copy_ns1/ns2/node3_0"))
    self.assertFalse(debug_graphs.is_copy_node("_copyns1/ns2/node3_0"))
    # A debug-node name must not be mistaken for a copy node.
    self.assertFalse(debug_graphs.is_copy_node("__dbg_ns1/ns2/node3_0"))

  def testIsDebugNode(self):
    # Only the exact "__dbg_" prefix (two leading underscores) qualifies.
    self.assertTrue(
        debug_graphs.is_debug_node("__dbg_ns1/ns2/node3:0_0_DebugIdentity"))

    self.assertFalse(
        debug_graphs.is_debug_node("dbg_ns1/ns2/node3:0_0_DebugIdentity"))
    self.assertFalse(
        debug_graphs.is_debug_node("_dbg_ns1/ns2/node3:0_0_DebugIdentity"))
    self.assertFalse(
        debug_graphs.is_debug_node("_dbgns1/ns2/node3:0_0_DebugIdentity"))
    # A copy-node name must not be mistaken for a debug node.
    self.assertFalse(debug_graphs.is_debug_node("__copy_ns1/ns2/node3_0"))


class ParseDebugNodeNameTest(test_util.TensorFlowTestCase):
  """Tests for `debug_graphs.parse_debug_node_name`.

  The names exercised here have the layout
  "__dbg_<watched_node>:<output_slot>_<debug_op_index>_<debug_op>".
  """

  def testParseDebugNodeName_valid(self):
    debug_node_name_1 = "__dbg_ns_a/ns_b/node_c:1_0_DebugIdentity"
    (watched_node, watched_output_slot, debug_op_index,
     debug_op) = debug_graphs.parse_debug_node_name(debug_node_name_1)

    self.assertEqual("ns_a/ns_b/node_c", watched_node)
    self.assertEqual(1, watched_output_slot)
    self.assertEqual(0, debug_op_index)
    self.assertEqual("DebugIdentity", debug_op)

  def testParseDebugNodeName_invalidPrefix(self):
    # A copy-node prefix is rejected.
    invalid_debug_node_name_1 = "__copy_ns_a/ns_b/node_c:1_0_DebugIdentity"

    with self.assertRaisesRegex(ValueError, "Invalid prefix"):
      debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)

  def testParseDebugNodeName_missingDebugOpIndex(self):
    # The "_<debug_op_index>" field between the slot and the op is absent.
    invalid_debug_node_name_1 = "__dbg_node1:0_DebugIdentity"

    with self.assertRaisesRegex(ValueError, "Invalid debug node name"):
      debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)

  def testParseDebugNodeName_invalidWatchedTensorName(self):
    # The watched tensor lacks its ":<output_slot>" suffix.
    invalid_debug_node_name_1 = "__dbg_node1_0_DebugIdentity"

    with self.assertRaisesRegex(ValueError,
                                "Invalid tensor name in debug node name"):
      debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)


if __name__ == "__main__":
  test.main()