xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/cli/profile_analyzer_cli_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for profile_analyzer_cli."""
16
17import re
18
19from tensorflow.core.framework import step_stats_pb2
20from tensorflow.core.protobuf import config_pb2
21from tensorflow.core.protobuf import rewriter_config_pb2
22from tensorflow.python.client import session
23from tensorflow.python.debug.cli import debugger_cli_common
24from tensorflow.python.debug.cli import profile_analyzer_cli
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.platform import googletest
31from tensorflow.python.platform import test
32from tensorflow.python.util import tf_inspect
33
34
def no_rewrite_session_config():
  """Returns a session `ConfigProto` with graph rewriting switched off.

  The rewriter options disable model pruning and turn constant folding OFF,
  and are wrapped in `GraphOptions` inside the returned `ConfigProto`.
  """
  return config_pb2.ConfigProto(
      graph_options=config_pb2.GraphOptions(
          rewrite_options=rewriter_config_pb2.RewriterConfig(
              disable_model_pruning=True,
              constant_folding=rewriter_config_pb2.RewriterConfig.OFF)))
41
42
def _line_number_above():
  """Returns the line number directly above the line that calls this helper."""
  # Index 0 of the stack is this function's own frame; index 1 is the caller.
  caller_record = tf_inspect.stack()[1]
  # Element 2 of a stack record is the line number currently executing.
  caller_lineno = caller_record[2]
  return caller_lineno - 1
45
46
47def _at_least_one_line_matches(pattern, lines):
48  pattern_re = re.compile(pattern)
49  for i, line in enumerate(lines):
50    if pattern_re.search(line):
51      return True, i
52  return False, None
53
54
def _assert_at_least_one_line_matches(pattern, lines):
  """Raises `AssertionError` unless some entry of `lines` matches `pattern`."""
  found, _ = _at_least_one_line_matches(pattern, lines)
  if found:
    return
  raise AssertionError(
      "%s does not match any line in %s." % (pattern, str(lines)))
60
61
def _assert_no_lines_match(pattern, lines):
  """Raises `AssertionError` if any entry of `lines` matches `pattern`."""
  found, _ = _at_least_one_line_matches(pattern, lines)
  if not found:
    return
  raise AssertionError(
      "%s matched at least one line in %s." % (pattern, str(lines)))
67
68
@test_util.run_v1_only("Requires tf.Session")
class ProfileAnalyzerListProfileTest(test_util.TensorFlowTestCase):
  """Tests for the output of `ProfileAnalyzer.list_profile`."""

  def _make_mock_op(self, name, op_type, filename, lineno):
    """Creates a mock graph op with a single-frame traceback.

    Args:
      name: Value for the mock op's `name` attribute.
      op_type: Value for the mock op's `type` attribute.
      filename: File path recorded in the single traceback frame.
      lineno: Line number recorded in the single traceback frame.

    Returns:
      A `MagicMock` standing in for a graph operation.
    """
    mock_op = test.mock.MagicMock()
    mock_op.name = name
    mock_op.traceback = [(filename, lineno, "some_var")]
    mock_op.type = op_type
    return mock_op

  def _make_analyzer(self, node_stats, mock_ops, device_names=("deviceA",)):
    """Builds a `ProfileAnalyzer` over a mock graph and synthetic step stats.

    Args:
      node_stats: List of `NodeExecStats` added to every device.
      mock_ops: Mock ops returned by the mock graph's `get_operations()`.
      device_names: Names of the devices to create, in order.

    Returns:
      A `ProfileAnalyzer` wrapping the mock graph and the generated
      `RunMetadata`.
    """
    run_metadata = config_pb2.RunMetadata()
    for device_name in device_names:
      device_stats = run_metadata.step_stats.dev_stats.add()
      device_stats.device = device_name
      device_stats.node_stats.extend(node_stats)

    graph = test.mock.MagicMock()
    graph.get_operations.return_value = list(mock_ops)
    return profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)

  def _make_add_mul_analyzer(self):
    """Builds the two-node fixture shared by the sort/filter/time-unit tests.

    "Add/123" starts at 123us with op time 2us and exec time 4us, defined at
    a/b/file2:10. "Mul/456" starts at 122us with op time 1us and exec time
    5us, defined at a/b/file1:11. Both run on "deviceA".
    """
    add_stats = step_stats_pb2.NodeExecStats(
        node_name="Add/123",
        all_start_micros=123,
        op_start_rel_micros=3,
        op_end_rel_micros=5,
        all_end_rel_micros=4)
    mul_stats = step_stats_pb2.NodeExecStats(
        node_name="Mul/456",
        all_start_micros=122,
        op_start_rel_micros=1,
        op_end_rel_micros=2,
        all_end_rel_micros=5)
    return self._make_analyzer(
        [add_stats, mul_stats],
        [self._make_mock_op("Add/123", "add", "a/b/file2", 10),
         self._make_mock_op("Mul/456", "mul", "a/b/file1", 11)])

  def testNodeInfoEmpty(self):
    graph = ops.Graph()
    run_metadata = config_pb2.RunMetadata()

    prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
    prof_output = prof_analyzer.list_profile([]).lines
    # With no step stats recorded, the listing is a single empty line.
    self.assertEqual([""], prof_output)

  def testSingleDevice(self):
    node1 = step_stats_pb2.NodeExecStats(
        node_name="Add/123",
        op_start_rel_micros=3,
        op_end_rel_micros=5,
        all_end_rel_micros=4)

    node2 = step_stats_pb2.NodeExecStats(
        node_name="Mul/456",
        op_start_rel_micros=1,
        op_end_rel_micros=2,
        all_end_rel_micros=3)

    prof_analyzer = self._make_analyzer(
        [node1, node2],
        [self._make_mock_op("Add/123", "add", "a/b/file1", 10),
         self._make_mock_op("Mul/456", "mul", "a/b/file1", 11)])
    prof_output = prof_analyzer.list_profile([]).lines

    _assert_at_least_one_line_matches(r"Device 1 of 1: deviceA", prof_output)
    # Expected columns: op time (end - start), then exec time (all_end).
    _assert_at_least_one_line_matches(r"^Add/123.*add.*2us.*4us", prof_output)
    _assert_at_least_one_line_matches(r"^Mul/456.*mul.*1us.*3us", prof_output)

  def testMultipleDevices(self):
    node1 = step_stats_pb2.NodeExecStats(
        node_name="Add/123",
        op_start_rel_micros=3,
        op_end_rel_micros=5,
        all_end_rel_micros=3)

    # The same node stats are reported on both devices.
    prof_analyzer = self._make_analyzer(
        [node1],
        [self._make_mock_op("Add/123", "abc", "a/b/file1", 10)],
        device_names=("deviceA", "deviceB"))
    prof_output = prof_analyzer.list_profile([]).lines

    _assert_at_least_one_line_matches(r"Device 1 of 2: deviceA", prof_output)
    _assert_at_least_one_line_matches(r"Device 2 of 2: deviceB", prof_output)

    # Try filtering by device.
    prof_output = prof_analyzer.list_profile(["-d", "deviceB"]).lines
    _assert_at_least_one_line_matches(r"Device 2 of 2: deviceB", prof_output)
    _assert_no_lines_match(r"Device 1 of 2: deviceA", prof_output)

  def testWithSession(self):
    options = config_pb2.RunOptions()
    options.trace_level = config_pb2.RunOptions.FULL_TRACE
    run_metadata = config_pb2.RunMetadata()

    with session.Session(config=no_rewrite_session_config()) as sess:
      a = constant_op.constant([1, 2, 3])
      b = constant_op.constant([2, 2, 1])
      result = math_ops.add(a, b)

      sess.run(result, options=options, run_metadata=run_metadata)

      prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(
          sess.graph, run_metadata)
      prof_output = prof_analyzer.list_profile([]).lines

      _assert_at_least_one_line_matches("Device 1 of", prof_output)
      expected_headers = [
          "Node", r"Start Time \(us\)", r"Op Time \(.*\)", r"Exec Time \(.*\)",
          r"Filename:Lineno\(function\)"]
      # All column headers must appear, in this order, on one line.
      _assert_at_least_one_line_matches(
          ".*".join(expected_headers), prof_output)
      _assert_at_least_one_line_matches(r"^Add/", prof_output)
      _assert_at_least_one_line_matches(r"Device Total", prof_output)

  def testSorting(self):
    prof_analyzer = self._make_add_mul_analyzer()

    # Default sort by start time (i.e. all_start_micros).
    prof_output = prof_analyzer.list_profile([]).lines
    self.assertRegex("".join(prof_output), r"Mul/456.*Add/123")
    # Default sort in reverse.
    prof_output = prof_analyzer.list_profile(["-r"]).lines
    self.assertRegex("".join(prof_output), r"Add/123.*Mul/456")
    # Sort by name.
    prof_output = prof_analyzer.list_profile(["-s", "node"]).lines
    self.assertRegex("".join(prof_output), r"Add/123.*Mul/456")
    # Sort by op time (i.e. op_end_rel_micros - op_start_rel_micros).
    prof_output = prof_analyzer.list_profile(["-s", "op_time"]).lines
    self.assertRegex("".join(prof_output), r"Mul/456.*Add/123")
    # Sort by exec time (i.e. all_end_rel_micros).
    prof_output = prof_analyzer.list_profile(["-s", "exec_time"]).lines
    self.assertRegex("".join(prof_output), r"Add/123.*Mul/456")
    # Sort by line number.
    prof_output = prof_analyzer.list_profile(["-s", "line"]).lines
    self.assertRegex("".join(prof_output), r"Mul/456.*Add/123")

  def testFiltering(self):
    prof_analyzer = self._make_add_mul_analyzer()

    # Filter by name
    prof_output = prof_analyzer.list_profile(["-n", "Add"]).lines
    _assert_at_least_one_line_matches(r"Add/123", prof_output)
    _assert_no_lines_match(r"Mul/456", prof_output)
    # Filter by op_type
    prof_output = prof_analyzer.list_profile(["-t", "mul"]).lines
    _assert_at_least_one_line_matches(r"Mul/456", prof_output)
    _assert_no_lines_match(r"Add/123", prof_output)
    # Filter by file name.
    prof_output = prof_analyzer.list_profile(["-f", ".*file2"]).lines
    _assert_at_least_one_line_matches(r"Add/123", prof_output)
    _assert_no_lines_match(r"Mul/456", prof_output)
    # Filter by execution time.
    prof_output = prof_analyzer.list_profile(["-e", "[5, 10]"]).lines
    _assert_at_least_one_line_matches(r"Mul/456", prof_output)
    _assert_no_lines_match(r"Add/123", prof_output)
    # Filter by op time.
    prof_output = prof_analyzer.list_profile(["-o", ">=2"]).lines
    _assert_at_least_one_line_matches(r"Add/123", prof_output)
    _assert_no_lines_match(r"Mul/456", prof_output)

  def testSpecifyingTimeUnit(self):
    prof_analyzer = self._make_add_mul_analyzer()

    # Force time unit.
    prof_output = prof_analyzer.list_profile(["--time_unit", "ms"]).lines
    _assert_at_least_one_line_matches(r"Add/123.*add.*0\.002ms", prof_output)
    _assert_at_least_one_line_matches(r"Mul/456.*mul.*0\.005ms", prof_output)
    _assert_at_least_one_line_matches(r"Device Total.*0\.009ms", prof_output)
319
320
@test_util.run_v1_only("Requires tf.Session")
class ProfileAnalyzerPrintSourceTest(test_util.TensorFlowTestCase):
  """Tests for `ProfileAnalyzer.print_source` on a profiled while loop."""

  def _source_line_regex(self, op_counts, lineno, time_unit="us"):
    """Returns the regex for one annotated source line of the profile output.

    Args:
      op_counts: Regex fragment for the counts column, already escaped (the
        escaped form of text such as "2(22)"). Presumably these are the node
        count and execution count for the line; confirm against
        `print_source`.
      lineno: Source line number expected in the "L<lineno>" marker.
      time_unit: Time-unit suffix expected in the time column.

    Returns:
      A regular-expression string.
    """
    return r"\[(\|)+(\s)*\] .*%s .*%s .*L%d.*(\S)+" % (
        time_unit, op_counts, lineno)

  def _line_has_list_profile_link(self, prof_output, line_index, fragments):
    """Checks whether an output line carries a matching "lp" menu item.

    Args:
      prof_output: Output object returned by `print_source`.
      line_index: Index of the output line whose font attribute segments are
        inspected.
      fragments: Substrings that must all occur in the menu item's content.

    Returns:
      True if some segment on the line holds a `MenuItem` whose content starts
      with "lp --file_path_filter " and contains every fragment.
    """
    for seg in prof_output.font_attr_segs[line_index]:
      menu_item = seg[2][1]
      if (isinstance(menu_item, debugger_cli_common.MenuItem) and
          menu_item.content.startswith("lp --file_path_filter ") and
          all(fragment in menu_item.content for fragment in fragments)):
        return True
    return False

  def setUp(self):
    super(ProfileAnalyzerPrintSourceTest, self).setUp()

    options = config_pb2.RunOptions()
    options.trace_level = config_pb2.RunOptions.FULL_TRACE
    run_metadata = config_pb2.RunMetadata()
    with session.Session() as sess:
      # Each `_line_number_above()` call must stay on the line immediately
      # below the statement whose line number it records; do not insert
      # anything between the pairs below.
      loop_cond = lambda x: math_ops.less(x, 10)
      self.loop_cond_lineno = _line_number_above()
      loop_body = lambda x: math_ops.add(x, 1)
      self.loop_body_lineno = _line_number_above()
      x = constant_op.constant(0, name="x")
      self.x_lineno = _line_number_above()
      loop = control_flow_ops.while_loop(loop_cond, loop_body, [x])
      self.loop_lineno = _line_number_above()
      self.assertEqual(
          10, sess.run(loop, options=options, run_metadata=run_metadata))

      self.prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(
          sess.graph, run_metadata)

  def tearDown(self):
    ops.reset_default_graph()
    super(ProfileAnalyzerPrintSourceTest, self).tearDown()

  def testPrintSourceForWhileLoop(self):
    prof_output = self.prof_analyzer.print_source([__file__])

    _assert_at_least_one_line_matches(
        self._source_line_regex(r"2\(22\)", self.loop_cond_lineno),
        prof_output.lines)
    _assert_at_least_one_line_matches(
        self._source_line_regex(r"2\(20\)", self.loop_body_lineno),
        prof_output.lines)
    _assert_at_least_one_line_matches(
        self._source_line_regex(r"7\(55\)", self.loop_lineno),
        prof_output.lines)

  def testPrintSourceOutputContainsClickableLinks(self):
    prof_output = self.prof_analyzer.print_source([__file__])
    any_match, line_index = _at_least_one_line_matches(
        self._source_line_regex(r"2\(22\)", self.loop_cond_lineno),
        prof_output.lines)
    self.assertTrue(any_match)
    self.assertTrue(self._line_has_list_profile_link(
        prof_output, line_index, [
            "--min_lineno %d" % self.loop_cond_lineno,
            "--max_lineno %d" % (self.loop_cond_lineno + 1),
        ]))

  def testPrintSourceWithNonDefaultTimeUnit(self):
    prof_output = self.prof_analyzer.print_source([
        __file__, "--time_unit", "ms"])

    _assert_at_least_one_line_matches(
        self._source_line_regex(
            r"2\(22\)", self.loop_cond_lineno, time_unit="ms"),
        prof_output.lines)
    _assert_at_least_one_line_matches(
        self._source_line_regex(
            r"2\(20\)", self.loop_body_lineno, time_unit="ms"),
        prof_output.lines)
    _assert_at_least_one_line_matches(
        self._source_line_regex(r"7\(55\)", self.loop_lineno, time_unit="ms"),
        prof_output.lines)

  def testPrintSourceWithNodeNameFilter(self):
    prof_output = self.prof_analyzer.print_source([
        __file__, "--node_name_filter", "x$"])

    x_line_regex = self._source_line_regex(r"1\(1\)", self.x_lineno)
    _assert_at_least_one_line_matches(x_line_regex, prof_output.lines)
    _assert_no_lines_match(
        self._source_line_regex(r"2\(22\)", self.loop_cond_lineno),
        prof_output.lines)
    _assert_no_lines_match(
        self._source_line_regex(r"2\(20\)", self.loop_body_lineno),
        prof_output.lines)
    # This output uses the default "us" unit. The pattern previously looked
    # for "ms", which could never match and made the check vacuous.
    _assert_no_lines_match(
        self._source_line_regex(r"7\(55\)", self.loop_lineno),
        prof_output.lines)

    # Check clickable link.
    _, line_index = _at_least_one_line_matches(
        x_line_regex, prof_output.lines)
    self.assertTrue(self._line_has_list_profile_link(
        prof_output, line_index, [
            "--node_name_filter x$",
            "--min_lineno %d" % self.x_lineno,
            "--max_lineno %d" % (self.x_lineno + 1),
        ]))

  def testPrintSourceWithOpTypeFilter(self):
    prof_output = self.prof_analyzer.print_source([
        __file__, "--op_type_filter", "Less"])

    _assert_at_least_one_line_matches(
        self._source_line_regex(r"1\(11\)", self.loop_cond_lineno),
        prof_output.lines)
    _assert_no_lines_match(
        self._source_line_regex(r"2\(20\)", self.loop_body_lineno),
        prof_output.lines)
    _assert_no_lines_match(
        self._source_line_regex(r"7\(55\)", self.loop_lineno),
        prof_output.lines)

  def testPrintSourceWithNonexistentDeviceGivesCorrectErrorMessage(self):
    prof_output = self.prof_analyzer.print_source([
        __file__, "--device_name_filter", "foo_device"])

    _assert_at_least_one_line_matches(
        r"The source file .* does not contain any profile information for the "
        "previous Session run", prof_output.lines)
    _assert_at_least_one_line_matches(
        r".*--device_name_filter: foo_device", prof_output.lines)

  def testPrintSourceWithUnrelatedFileShowsCorrectErrorMessage(self):
    prof_output = self.prof_analyzer.print_source([tf_inspect.__file__])
    _assert_at_least_one_line_matches(
        r"The source file .* does not contain any profile information for the "
        "previous Session run", prof_output.lines)

  def testPrintSourceOutputContainsInitScrollPosAnnotation(self):
    prof_output = self.prof_analyzer.print_source([
        __file__, "--init_line", str(self.loop_cond_lineno)])
    self.assertEqual(
        self.loop_cond_lineno + 1,  # The extra line is due to the head lines.
        prof_output.annotations[debugger_cli_common.INIT_SCROLL_POS_KEY])
460
461
# Run the tests through the googletest runner when executed as a script.
if __name__ == "__main__":
  googletest.main()
464