xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/profiling_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Unit tests for the basic data structures and algorithms for profiling."""
16
17from tensorflow.core.framework import step_stats_pb2
18from tensorflow.python.debug.lib import profiling
19from tensorflow.python.framework import test_util
20from tensorflow.python.platform import googletest
21
22
class AggregateProfile(test_util.TensorFlowTestCase):
  """Tests for accumulating `ProfileDatum`s into `profiling.AggregateProfile`.

  The expected totals in these tests are consistent with each datum
  contributing (op_end_rel_micros - op_start_rel_micros) to the op time and
  all_end_rel_micros to the exec time, with nodes counted per device.
  """

  @staticmethod
  def _make_datum(device, node_name, op_start, op_end, all_end, line, op_type):
    """Builds a `ProfileDatum` attributed to `func1` in `/foo/bar.py`.

    Args:
      device: Device name string, e.g. "cpu:0".
      node_name: Name of the graph node.
      op_start: Value for `op_start_rel_micros`.
      op_end: Value for `op_end_rel_micros`.
      all_end: Value for `all_end_rel_micros`.
      line: Source line number recorded on the datum.
      op_type: Op type string, e.g. "Add".

    Returns:
      A `profiling.ProfileDatum` wrapping a freshly built `NodeExecStats`.
    """
    node_stats = step_stats_pb2.NodeExecStats(
        node_name=node_name,
        op_start_rel_micros=op_start,
        op_end_rel_micros=op_end,
        all_end_rel_micros=all_end)
    return profiling.ProfileDatum(
        device, node_stats, "/foo/bar.py", line, "func1", op_type)

  def setUp(self):
    # Datums 1 and 3 are two executions of node "Add/123" on cpu:0, datum 2 is
    # a different node ("Mul/456") on cpu:0, and datum 4 reuses the node name
    # "Add/123" but on gpu:0.
    self.profile_datum_1 = self._make_datum(
        "cpu:0", "Add/123", 3, 5, 4, 10, "Add")
    self.profile_datum_2 = self._make_datum(
        "cpu:0", "Mul/456", 13, 16, 17, 11, "Mul")
    self.profile_datum_3 = self._make_datum(
        "cpu:0", "Add/123", 103, 105, 4, 12, "Add")
    self.profile_datum_4 = self._make_datum(
        "gpu:0", "Add/123", 203, 205, 4, 13, "Add")

  def _aggregate(self, first, *others):
    """Constructs an AggregateProfile from `first`, then adds each of `others`."""
    aggregate = profiling.AggregateProfile(first)
    for datum in others:
      aggregate.add(datum)
    return aggregate

  def _assertAggregateTotals(self, aggregate, op_time, exec_time, nodes,
                             execs):
    """Checks all four aggregate counters against the expected values."""
    self.assertEqual(op_time, aggregate.total_op_time)
    self.assertEqual(exec_time, aggregate.total_exec_time)
    self.assertEqual(nodes, aggregate.node_count)
    self.assertEqual(execs, aggregate.node_exec_count)

  def testAggregateProfileConstructorWorks(self):
    self._assertAggregateTotals(
        self._aggregate(self.profile_datum_1),
        op_time=2, exec_time=4, nodes=1, execs=1)

  def testAddToAggregateProfileWithDifferentNodeWorks(self):
    self._assertAggregateTotals(
        self._aggregate(self.profile_datum_1, self.profile_datum_2),
        op_time=5, exec_time=21, nodes=2, execs=2)

  def testAddToAggregateProfileWithSameNodeWorks(self):
    # Datum 3 repeats datum 1's node on the same device, so the execution count
    # grows while the distinct-node count does not.
    self._assertAggregateTotals(
        self._aggregate(self.profile_datum_1, self.profile_datum_2,
                        self.profile_datum_3),
        op_time=7, exec_time=25, nodes=2, execs=3)

  def testAddToAggregateProfileWithDifferentDeviceSameNodeWorks(self):
    # Same node name on a different device is counted as a separate node.
    self._assertAggregateTotals(
        self._aggregate(self.profile_datum_1, self.profile_datum_4),
        op_time=4, exec_time=8, nodes=2, execs=2)
93
94
if __name__ == "__main__":
  # When executed directly, hand control to TensorFlow's googletest runner.
  googletest.main()
97