# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for the basic data structures and algorithms for profiling."""

from tensorflow.core.framework import step_stats_pb2
from tensorflow.python.debug.lib import profiling
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest


class AggregateProfile(test_util.TensorFlowTestCase):
  """Tests for `profiling.AggregateProfile` accumulation.

  The expected totals below are consistent with each datum contributing
  `op_end_rel_micros - op_start_rel_micros` to `total_op_time` and
  `all_end_rel_micros` to `total_exec_time`. They are also consistent with
  `node_count` counting distinct (device, node_name) pairs, while
  `node_exec_count` counts every datum added.
  """

  @staticmethod
  def _make_datum(device_name, node_name, op_start, op_end, all_end,
                  line_number, op_type):
    """Builds a ProfileDatum for a synthetic node in /foo/bar.py:func1.

    Args:
      device_name: Device string passed to ProfileDatum, e.g. "cpu:0".
      node_name: `node_name` field of the NodeExecStats proto.
      op_start: `op_start_rel_micros` field of the NodeExecStats proto.
      op_end: `op_end_rel_micros` field of the NodeExecStats proto.
      all_end: `all_end_rel_micros` field of the NodeExecStats proto.
      line_number: Source line number recorded in the datum.
      op_type: Op type string recorded in the datum.

    Returns:
      A `profiling.ProfileDatum` built from the arguments above.
    """
    node_stats = step_stats_pb2.NodeExecStats(
        node_name=node_name,
        op_start_rel_micros=op_start,
        op_end_rel_micros=op_end,
        all_end_rel_micros=all_end)
    return profiling.ProfileDatum(
        device_name, node_stats, "/foo/bar.py", line_number, "func1", op_type)

  def setUp(self):
    # Chain to TensorFlowTestCase.setUp so the base-class per-test fixture
    # still runs; the original override skipped it entirely.
    super(AggregateProfile, self).setUp()

    # Datums 1 and 3 are the same node ("Add/123") on the same device.
    # Datum 2 is a different node on that device.
    # Datum 4 reuses the "Add/123" name on another device.
    self.profile_datum_1 = self._make_datum(
        "cpu:0", "Add/123", op_start=3, op_end=5, all_end=4,
        line_number=10, op_type="Add")
    self.profile_datum_2 = self._make_datum(
        "cpu:0", "Mul/456", op_start=13, op_end=16, all_end=17,
        line_number=11, op_type="Mul")
    self.profile_datum_3 = self._make_datum(
        "cpu:0", "Add/123", op_start=103, op_end=105, all_end=4,
        line_number=12, op_type="Add")
    self.profile_datum_4 = self._make_datum(
        "gpu:0", "Add/123", op_start=203, op_end=205, all_end=4,
        line_number=13, op_type="Add")

  def testAggregateProfileConstructorWorks(self):
    """A single datum: op time 5-3=2, exec time 4."""
    aggregate_data = profiling.AggregateProfile(self.profile_datum_1)

    self.assertEqual(2, aggregate_data.total_op_time)
    self.assertEqual(4, aggregate_data.total_exec_time)
    self.assertEqual(1, aggregate_data.node_count)
    self.assertEqual(1, aggregate_data.node_exec_count)

  def testAddToAggregateProfileWithDifferentNodeWorks(self):
    """Two distinct nodes: op time 2+3=5, exec time 4+17=21."""
    aggregate_data = profiling.AggregateProfile(self.profile_datum_1)
    aggregate_data.add(self.profile_datum_2)

    self.assertEqual(5, aggregate_data.total_op_time)
    self.assertEqual(21, aggregate_data.total_exec_time)
    self.assertEqual(2, aggregate_data.node_count)
    self.assertEqual(2, aggregate_data.node_exec_count)

  def testAddToAggregateProfileWithSameNodeWorks(self):
    """A repeated node adds time and executions but not a new node."""
    aggregate_data = profiling.AggregateProfile(self.profile_datum_1)
    aggregate_data.add(self.profile_datum_2)
    aggregate_data.add(self.profile_datum_3)

    self.assertEqual(7, aggregate_data.total_op_time)
    self.assertEqual(25, aggregate_data.total_exec_time)
    self.assertEqual(2, aggregate_data.node_count)
    self.assertEqual(3, aggregate_data.node_exec_count)

  def testAddToAggregateProfileWithDifferentDeviceSameNodeWorks(self):
    """The same node name on a different device counts as a separate node."""
    aggregate_data = profiling.AggregateProfile(self.profile_datum_1)
    aggregate_data.add(self.profile_datum_4)

    self.assertEqual(4, aggregate_data.total_op_time)
    self.assertEqual(8, aggregate_data.total_exec_time)
    self.assertEqual(2, aggregate_data.node_count)
    self.assertEqual(2, aggregate_data.node_exec_count)


if __name__ == "__main__":
  googletest.main()