# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for profile_analyzer_cli."""

import re

from tensorflow.core.framework import step_stats_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import profile_analyzer_cli
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect


def no_rewrite_session_config():
  """Returns a ConfigProto with model pruning and constant folding disabled.

  Turning these rewrites off keeps the test's ops (e.g. an Add of two
  constants) from being folded away or pruned before execution, so they
  appear in the collected profile.
  """
  rewriter_config = rewriter_config_pb2.RewriterConfig(
      disable_model_pruning=True,
      constant_folding=rewriter_config_pb2.RewriterConfig.OFF)
  graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
  return config_pb2.ConfigProto(graph_options=graph_options)


def _line_number_above():
  """Returns the line number of the line directly above the caller's line.

  Call it on the line immediately following the statement whose source
  location you want to record; inserting any line in between shifts the
  result.
  """
  return tf_inspect.stack()[1][2] - 1


def _at_least_one_line_matches(pattern, lines):
  """Finds the first line that contains a match for `pattern`.

  Args:
    pattern: Regular-expression string, applied with `re.search` semantics.
    lines: Iterable of strings to scan.

  Returns:
    `(True, i)` where `i` is the index of the first matching line, or
    `(False, None)` if no line matches.
  """
  pattern_re = re.compile(pattern)
  for i, line in enumerate(lines):
    if pattern_re.search(line):
      return True, i
  return False, None


def _assert_at_least_one_line_matches(pattern, lines):
  """Raises AssertionError unless some line in `lines` matches `pattern`."""
  any_match, _ = _at_least_one_line_matches(pattern, lines)
  if not any_match:
    raise AssertionError(
        "%s does not match any line in %s." % (pattern, str(lines)))


def _assert_no_lines_match(pattern, lines):
  """Raises AssertionError if any line in `lines` matches `pattern`."""
  any_match, _ = _at_least_one_line_matches(pattern, lines)
  if any_match:
    raise AssertionError(
        "%s matched at least one line in %s." % (pattern, str(lines)))


def _create_two_node_prof_analyzer():
  """Builds a ProfileAnalyzer over a mocked graph with two profiled nodes.

  Both nodes run on "deviceA":
    - "Add/123" (type "add", a/b/file2:10): all_start_micros=123,
      op time 5 - 3 = 2us, exec time (all_end_rel_micros) 4us.
    - "Mul/456" (type "mul", a/b/file1:11): all_start_micros=122,
      op time 2 - 1 = 1us, exec time (all_end_rel_micros) 5us.

  A fresh analyzer is built on every call, so tests do not share state.

  Returns:
    A `profile_analyzer_cli.ProfileAnalyzer`.
  """
  node1 = step_stats_pb2.NodeExecStats(
      node_name="Add/123",
      all_start_micros=123,
      op_start_rel_micros=3,
      op_end_rel_micros=5,
      all_end_rel_micros=4)

  node2 = step_stats_pb2.NodeExecStats(
      node_name="Mul/456",
      all_start_micros=122,
      op_start_rel_micros=1,
      op_end_rel_micros=2,
      all_end_rel_micros=5)

  run_metadata = config_pb2.RunMetadata()
  device1 = run_metadata.step_stats.dev_stats.add()
  device1.device = "deviceA"
  device1.node_stats.extend([node1, node2])

  graph = test.mock.MagicMock()
  op1 = test.mock.MagicMock()
  op1.name = "Add/123"
  op1.traceback = [("a/b/file2", 10, "some_var")]
  op1.type = "add"
  op2 = test.mock.MagicMock()
  op2.name = "Mul/456"
  op2.traceback = [("a/b/file1", 11, "some_var")]
  op2.type = "mul"
  graph.get_operations.return_value = [op1, op2]

  return profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)


@test_util.run_v1_only("Requires tf.Session")
class ProfileAnalyzerListProfileTest(test_util.TensorFlowTestCase):
  """Tests for `ProfileAnalyzer.list_profile`."""

  def testNodeInfoEmpty(self):
    graph = ops.Graph()
    run_metadata = config_pb2.RunMetadata()

    prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
    prof_output = prof_analyzer.list_profile([]).lines
    self.assertEqual([""], prof_output)

  def testSingleDevice(self):
    node1 = step_stats_pb2.NodeExecStats(
        node_name="Add/123",
        op_start_rel_micros=3,
        op_end_rel_micros=5,
        all_end_rel_micros=4)

    node2 = step_stats_pb2.NodeExecStats(
        node_name="Mul/456",
        op_start_rel_micros=1,
        op_end_rel_micros=2,
        all_end_rel_micros=3)

    run_metadata = config_pb2.RunMetadata()
    device1 = run_metadata.step_stats.dev_stats.add()
    device1.device = "deviceA"
    device1.node_stats.extend([node1, node2])

    graph = test.mock.MagicMock()
    op1 = test.mock.MagicMock()
    op1.name = "Add/123"
    op1.traceback = [("a/b/file1", 10, "some_var")]
    op1.type = "add"
    op2 = test.mock.MagicMock()
    op2.name = "Mul/456"
    op2.traceback = [("a/b/file1", 11, "some_var")]
    op2.type = "mul"
    graph.get_operations.return_value = [op1, op2]

    prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
    prof_output = prof_analyzer.list_profile([]).lines

    _assert_at_least_one_line_matches(r"Device 1 of 1: deviceA", prof_output)
    _assert_at_least_one_line_matches(r"^Add/123.*add.*2us.*4us", prof_output)
    _assert_at_least_one_line_matches(r"^Mul/456.*mul.*1us.*3us", prof_output)

  def testMultipleDevices(self):
    node1 = step_stats_pb2.NodeExecStats(
        node_name="Add/123",
        op_start_rel_micros=3,
        op_end_rel_micros=5,
        all_end_rel_micros=3)

    run_metadata = config_pb2.RunMetadata()
    device1 = run_metadata.step_stats.dev_stats.add()
    device1.device = "deviceA"
    device1.node_stats.extend([node1])

    device2 = run_metadata.step_stats.dev_stats.add()
    device2.device = "deviceB"
    device2.node_stats.extend([node1])

    graph = test.mock.MagicMock()
    op = test.mock.MagicMock()
    op.name = "Add/123"
    op.traceback = [("a/b/file1", 10, "some_var")]
    op.type = "abc"
    graph.get_operations.return_value = [op]

    prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(graph, run_metadata)
    prof_output = prof_analyzer.list_profile([]).lines

    _assert_at_least_one_line_matches(r"Device 1 of 2: deviceA", prof_output)
    _assert_at_least_one_line_matches(r"Device 2 of 2: deviceB", prof_output)

    # Try filtering by device.
    prof_output = prof_analyzer.list_profile(["-d", "deviceB"]).lines
    _assert_at_least_one_line_matches(r"Device 2 of 2: deviceB", prof_output)
    _assert_no_lines_match(r"Device 1 of 2: deviceA", prof_output)

  def testWithSession(self):
    options = config_pb2.RunOptions()
    options.trace_level = config_pb2.RunOptions.FULL_TRACE
    run_metadata = config_pb2.RunMetadata()

    with session.Session(config=no_rewrite_session_config()) as sess:
      a = constant_op.constant([1, 2, 3])
      b = constant_op.constant([2, 2, 1])
      result = math_ops.add(a, b)

      sess.run(result, options=options, run_metadata=run_metadata)

      prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(
          sess.graph, run_metadata)
      prof_output = prof_analyzer.list_profile([]).lines

      _assert_at_least_one_line_matches("Device 1 of", prof_output)
      expected_headers = [
          "Node", r"Start Time \(us\)", r"Op Time \(.*\)", r"Exec Time \(.*\)",
          r"Filename:Lineno\(function\)"]
      _assert_at_least_one_line_matches(
          ".*".join(expected_headers), prof_output)
      _assert_at_least_one_line_matches(r"^Add/", prof_output)
      _assert_at_least_one_line_matches(r"Device Total", prof_output)

  def testSorting(self):
    prof_analyzer = _create_two_node_prof_analyzer()

    # Default sort by start time (i.e. all_start_micros).
    prof_output = prof_analyzer.list_profile([]).lines
    self.assertRegex("".join(prof_output), r"Mul/456.*Add/123")
    # Default sort in reverse.
    prof_output = prof_analyzer.list_profile(["-r"]).lines
    self.assertRegex("".join(prof_output), r"Add/123.*Mul/456")
    # Sort by name.
    prof_output = prof_analyzer.list_profile(["-s", "node"]).lines
    self.assertRegex("".join(prof_output), r"Add/123.*Mul/456")
    # Sort by op time (i.e. op_end_rel_micros - op_start_rel_micros).
    prof_output = prof_analyzer.list_profile(["-s", "op_time"]).lines
    self.assertRegex("".join(prof_output), r"Mul/456.*Add/123")
    # Sort by exec time (i.e. all_end_rel_micros).
    prof_output = prof_analyzer.list_profile(["-s", "exec_time"]).lines
    self.assertRegex("".join(prof_output), r"Add/123.*Mul/456")
    # Sort by line number.
    prof_output = prof_analyzer.list_profile(["-s", "line"]).lines
    self.assertRegex("".join(prof_output), r"Mul/456.*Add/123")

  def testFiltering(self):
    prof_analyzer = _create_two_node_prof_analyzer()

    # Filter by name
    prof_output = prof_analyzer.list_profile(["-n", "Add"]).lines
    _assert_at_least_one_line_matches(r"Add/123", prof_output)
    _assert_no_lines_match(r"Mul/456", prof_output)
    # Filter by op_type
    prof_output = prof_analyzer.list_profile(["-t", "mul"]).lines
    _assert_at_least_one_line_matches(r"Mul/456", prof_output)
    _assert_no_lines_match(r"Add/123", prof_output)
    # Filter by file name.
    prof_output = prof_analyzer.list_profile(["-f", ".*file2"]).lines
    _assert_at_least_one_line_matches(r"Add/123", prof_output)
    _assert_no_lines_match(r"Mul/456", prof_output)
    # Filter by execution time.
    prof_output = prof_analyzer.list_profile(["-e", "[5, 10]"]).lines
    _assert_at_least_one_line_matches(r"Mul/456", prof_output)
    _assert_no_lines_match(r"Add/123", prof_output)
    # Filter by op time.
    prof_output = prof_analyzer.list_profile(["-o", ">=2"]).lines
    _assert_at_least_one_line_matches(r"Add/123", prof_output)
    _assert_no_lines_match(r"Mul/456", prof_output)

  def testSpecifyingTimeUnit(self):
    prof_analyzer = _create_two_node_prof_analyzer()

    # Force time unit.
    prof_output = prof_analyzer.list_profile(["--time_unit", "ms"]).lines
    _assert_at_least_one_line_matches(r"Add/123.*add.*0\.002ms", prof_output)
    _assert_at_least_one_line_matches(r"Mul/456.*mul.*0\.005ms", prof_output)
    _assert_at_least_one_line_matches(r"Device Total.*0\.009ms", prof_output)


@test_util.run_v1_only("Requires tf.Session")
class ProfileAnalyzerPrintSourceTest(test_util.TensorFlowTestCase):
  """Tests for `ProfileAnalyzer.print_source` on a profiled while loop."""

  def setUp(self):
    super(ProfileAnalyzerPrintSourceTest, self).setUp()

    options = config_pb2.RunOptions()
    options.trace_level = config_pb2.RunOptions.FULL_TRACE
    run_metadata = config_pb2.RunMetadata()
    with session.Session() as sess:
      # Each `_line_number_above()` call must sit on the line immediately
      # after the statement it records; the print_source tests below look
      # up those exact line numbers in this file.
      loop_cond = lambda x: math_ops.less(x, 10)
      self.loop_cond_lineno = _line_number_above()
      loop_body = lambda x: math_ops.add(x, 1)
      self.loop_body_lineno = _line_number_above()
      x = constant_op.constant(0, name="x")
      self.x_lineno = _line_number_above()
      loop = control_flow_ops.while_loop(loop_cond, loop_body, [x])
      self.loop_lineno = _line_number_above()
      self.assertEqual(
          10, sess.run(loop, options=options, run_metadata=run_metadata))

      self.prof_analyzer = profile_analyzer_cli.ProfileAnalyzer(
          sess.graph, run_metadata)

  def tearDown(self):
    ops.reset_default_graph()
    super(ProfileAnalyzerPrintSourceTest, self).tearDown()

  def testPrintSourceForWhileLoop(self):
    prof_output = self.prof_analyzer.print_source([__file__])

    _assert_at_least_one_line_matches(
        r"\[(\|)+(\s)*\] .*us .*2\(22\) .*L%d.*(\S)+" % self.loop_cond_lineno,
        prof_output.lines)
    _assert_at_least_one_line_matches(
        r"\[(\|)+(\s)*\] .*us .*2\(20\) .*L%d.*(\S)+" % self.loop_body_lineno,
        prof_output.lines)
    _assert_at_least_one_line_matches(
        r"\[(\|)+(\s)*\] .*us .*7\(55\) .*L%d.*(\S)+" % self.loop_lineno,
        prof_output.lines)

  def testPrintSourceOutputContainsClickableLinks(self):
    prof_output = self.prof_analyzer.print_source([__file__])
    any_match, line_index = _at_least_one_line_matches(
        r"\[(\|)+(\s)*\] .*us .*2\(22\) .*L%d.*(\S)+" % self.loop_cond_lineno,
        prof_output.lines)
    self.assertTrue(any_match)
    any_menu_item_match = False
    for seg in prof_output.font_attr_segs[line_index]:
      if (isinstance(seg[2][1], debugger_cli_common.MenuItem) and
          seg[2][1].content.startswith("lp --file_path_filter ") and
          "--min_lineno %d" % self.loop_cond_lineno in seg[2][1].content and
          "--max_lineno %d" % (self.loop_cond_lineno + 1) in seg[2][1].content):
        any_menu_item_match = True
        break
    self.assertTrue(any_menu_item_match)

  def testPrintSourceWithNonDefaultTimeUnit(self):
    prof_output = self.prof_analyzer.print_source([
        __file__, "--time_unit", "ms"])

    _assert_at_least_one_line_matches(
        r"\[(\|)+(\s)*\] .*ms .*2\(22\) .*L%d.*(\S)+" % self.loop_cond_lineno,
        prof_output.lines)
    _assert_at_least_one_line_matches(
        r"\[(\|)+(\s)*\] .*ms .*2\(20\) .*L%d.*(\S)+" % self.loop_body_lineno,
        prof_output.lines)
    _assert_at_least_one_line_matches(
        r"\[(\|)+(\s)*\] .*ms .*7\(55\) .*L%d.*(\S)+" % self.loop_lineno,
        prof_output.lines)

  def testPrintSourceWithNodeNameFilter(self):
    prof_output = self.prof_analyzer.print_source([
        __file__, "--node_name_filter", "x$"])

    _assert_at_least_one_line_matches(
        r"\[(\|)+(\s)*\] .*us .*1\(1\) .*L%d.*(\S)+" % self.x_lineno,
        prof_output.lines)
    _assert_no_lines_match(
        r"\[(\|)+(\s)*\] .*us .*2\(22\) .*L%d.*(\S)+" % self.loop_cond_lineno,
        prof_output.lines)
    _assert_no_lines_match(
        r"\[(\|)+(\s)*\] .*us .*2\(20\) .*L%d.*(\S)+" % self.loop_body_lineno,
        prof_output.lines)
    # No --time_unit is given, so output is in the default "us"; matching on
    # "ms" here would make this negative check pass vacuously.
    _assert_no_lines_match(
        r"\[(\|)+(\s)*\] .*us .*7\(55\) .*L%d.*(\S)+" % self.loop_lineno,
        prof_output.lines)

    # Check clickable link.
    any_match, line_index = _at_least_one_line_matches(
        r"\[(\|)+(\s)*\] .*us .*1\(1\) .*L%d.*(\S)+" % self.x_lineno,
        prof_output.lines)
    self.assertTrue(any_match)
    any_menu_item_match = False
    for seg in prof_output.font_attr_segs[line_index]:
      if (isinstance(seg[2][1], debugger_cli_common.MenuItem) and
          seg[2][1].content.startswith("lp --file_path_filter ") and
          "--node_name_filter x$" in seg[2][1].content and
          "--min_lineno %d" % self.x_lineno in seg[2][1].content and
          "--max_lineno %d" % (self.x_lineno + 1) in seg[2][1].content):
        any_menu_item_match = True
        break
    self.assertTrue(any_menu_item_match)

  def testPrintSourceWithOpTypeFilter(self):
    prof_output = self.prof_analyzer.print_source([
        __file__, "--op_type_filter", "Less"])

    _assert_at_least_one_line_matches(
        r"\[(\|)+(\s)*\] .*us .*1\(11\) .*L%d.*(\S)+" % self.loop_cond_lineno,
        prof_output.lines)
    _assert_no_lines_match(
        r"\[(\|)+(\s)*\] .*us .*2\(20\) .*L%d.*(\S)+" % self.loop_body_lineno,
        prof_output.lines)
    _assert_no_lines_match(
        r"\[(\|)+(\s)*\] .*us .*7\(55\) .*L%d.*(\S)+" % self.loop_lineno,
        prof_output.lines)

  def testPrintSourceWithNonexistentDeviceGivesCorrectErrorMessage(self):
    prof_output = self.prof_analyzer.print_source([
        __file__, "--device_name_filter", "foo_device"])

    _assert_at_least_one_line_matches(
        r"The source file .* does not contain any profile information for the "
        "previous Session run", prof_output.lines)
    _assert_at_least_one_line_matches(
        r".*--device_name_filter: foo_device", prof_output.lines)

  def testPrintSourceWithUnrelatedFileShowsCorrectErrorMessage(self):
    prof_output = self.prof_analyzer.print_source([tf_inspect.__file__])
    _assert_at_least_one_line_matches(
        r"The source file .* does not contain any profile information for the "
        "previous Session run", prof_output.lines)

  def testPrintSourceOutputContainsInitScrollPosAnnotation(self):
    prof_output = self.prof_analyzer.print_source([
        __file__, "--init_line", str(self.loop_cond_lineno)])
    self.assertEqual(
        self.loop_cond_lineno + 1,  # The extra line is due to the head lines.
        prof_output.annotations[debugger_cli_common.INIT_SCROLL_POS_KEY])


if __name__ == "__main__":
  googletest.main()