1# mypy: allow-untyped-defs 2import json 3import math 4import os 5import re 6from typing import Dict, List, Optional, Set 7 8import torch 9import torch.utils.benchmark as benchmark 10from torch._C._profiler import ( 11 _EventType, 12 _ExtraFields_PyCall, 13 _ExtraFields_PyCCall, 14 _ExtraFields_TorchOp, 15 _ProfilerEvent, 16) 17from torch.profiler import profile 18from torch.profiler._utils import index_of_first_match, traverse_bfs, traverse_dfs 19 20 21class Pattern: 22 """ 23 Base class for all patterns, subclass this class and implement match() 24 to define custom patterns. 25 26 In subclass, define description and skip property. 27 """ 28 29 def __init__(self, prof: profile, should_benchmark: bool = False): 30 self.prof = prof 31 self.should_benchmark = should_benchmark 32 self.name = "Please specify a name for pattern" 33 self.description = "Please specify a description for pattern" 34 self.url = "" 35 assert prof.profiler is not None and prof.profiler.kineto_results is not None 36 self.event_tree = prof.profiler.kineto_results.experimental_event_tree() 37 self.tid_root: Dict[int, List[_ProfilerEvent]] = {} 38 for event in self.event_tree: 39 self.tid_root.setdefault(event.start_tid, []).append(event) 40 41 @property 42 def skip(self): 43 return False 44 45 def report(self, event: _ProfilerEvent): 46 msg = ( 47 f"{self.description}\n[Source Code Location] {source_code_location(event)}" 48 ) 49 return msg 50 51 def eventTreeTraversal(self): 52 """ 53 Traverse the event tree and yield all events. 54 Override this method in subclass to customize the traversal. 55 """ 56 yield from traverse_dfs(self.event_tree) 57 58 def summary(self, events: List[_ProfilerEvent]): 59 default_summary = f"{self.name}: {len(events)} events matched." 60 if self.should_benchmark: 61 # If benchmark summary is not empty, use it. 62 return ( 63 self.benchmark_summary(events) 64 if hasattr(self, "benchmark") # type: ignore[attr-defined] 65 else default_summary 66 ) 67 return default_summary 68 69 def benchmark_summary(self, events: List[_ProfilerEvent]): 70 def format_time(time_ns: int): 71 unit_lst = ["ns", "us", "ms"] 72 for unit in unit_lst: 73 if time_ns < 1000: 74 return f"{time_ns:.2f} {unit}" 75 time_ns //= 1000 76 return f"{time_ns:.2f} s" 77 78 assert hasattr(self, "benchmark"), "Please implement benchmark()" 79 shapes_factor_map = self.benchmark(events) # type: ignore[attr-defined] 80 original_time = sum(event.duration_time_ns for event in events) 81 new_time = sum( 82 shapes_factor_map[input_shapes(event)] * event.duration_time_ns 83 for event in events 84 ) 85 return ( 86 f"{self.name}: {len(events)} events matched. " 87 f"Total Estimated Speedup: {format_time(original_time - new_time)} ({round(original_time/new_time, 2)}X)" 88 ) 89 90 def match(self, event: _ProfilerEvent): 91 """ 92 Return True if the event matches the pattern. 93 This method should be overriden in subclass. 94 """ 95 raise NotImplementedError 96 97 def matched_events(self): 98 if self.skip: 99 return [] 100 matched_events = [] 101 for event in self.eventTreeTraversal(): 102 if self.match(event): 103 matched_events.append(event) 104 return matched_events 105 106 def root_of(self, event: _ProfilerEvent): 107 while event.parent: 108 event = event.parent 109 return event 110 111 def siblings_of(self, event: _ProfilerEvent): 112 if event.parent: 113 children = event.parent.children 114 else: 115 children = self.tid_root[event.start_tid] 116 index = children.index(event) 117 return children[:index], children[index + 1 :] 118 119 def next_of(self, event: _ProfilerEvent): 120 _, next_events = self.siblings_of(event) 121 return next_events[0] if next_events else None 122 123 def prev_of(self, event: _ProfilerEvent): 124 prev_events, _ = self.siblings_of(event) 125 return prev_events[-1] if prev_events else None 126 127 def go_up_until(self, event: _ProfilerEvent, predicate): 128 if not event: 129 return None 130 while event.parent and not predicate(event): 131 event = event.parent 132 return event 133 134 135# Patterns 136 137 138class NamePattern(Pattern): 139 def __init__(self, prof: profile, name: str, should_benchmark: bool = False): 140 super().__init__(prof, should_benchmark) 141 self.description = f"Matched Name Event: {name}" 142 self.name = name 143 144 def match(self, event: _ProfilerEvent): 145 return re.search(self.name, event.name) is not None 146 147 148class ExtraCUDACopyPattern(Pattern): 149 """ 150 This pattern identifies if we creates a constant tensor on CPU and immediately moves it to GPU. 151 example: torch.zeros((100, 100)).to("cuda") 152 153 Pattern: 154 build-in method |build-in method 155 ... | aten::to 156 aten::fill_/aten::zero_ | aten::_to_copy 157 158 Algorithm: 159 We start at node aten::to, go parent events' previous events, 160 and check if we have a aten::fill_/aten::zero_ as we keep going down the tree. 161 We always select the last child in the children list when we go down the tree. 162 If at any step we failed, it is not a match. 163 """ 164 165 def __init__(self, prof: profile, should_benchmark: bool = False): 166 super().__init__(prof, should_benchmark) 167 self.name = "Extra CUDA Copy Pattern" 168 self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU." 169 self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device" 170 self.init_ops = { 171 "aten::fill_", 172 "aten::zero_", 173 "aten::normal_", 174 "aten::uniform_", 175 } 176 177 @property 178 def skip(self): 179 return not self.prof.with_stack or not self.prof.record_shapes 180 181 def match(self, event): 182 # TODO: We should also check tensor identities 183 if event.name != "aten::to": 184 return False 185 to_event = event 186 if not event.children: 187 return False 188 event = event.children[-1] 189 if event.name != "aten::_to_copy": 190 return False 191 if not event.children: 192 return False 193 event = event.children[-1] 194 if event.name != "aten::copy_": 195 return False 196 # aten::copy_ should have the first 2 args dtype the same 197 dtypes = input_dtypes(event) 198 if len(dtypes) < 2: 199 return False 200 if dtypes[0] is None or dtypes[0] != dtypes[1]: 201 return False 202 event = to_event 203 # Up one level 204 event = event.parent 205 if event is None: 206 return False 207 # Check if we have a aten::fill_ in previous leaf 208 event = self.prev_of(event) 209 if event is None: 210 return False 211 while event.children: 212 event = event.children[-1] 213 # aten::zero_ is a special optimzation case where fill_ is not called 214 if event.name in self.init_ops: 215 return True 216 return event.name in self.init_ops 217 # TODO: Check if tensor is reused 218 219 def benchmark(self, events: List[_ProfilerEvent]): 220 shapes_factor_map = {input_shapes(event): 0.0 for event in events} 221 for shape in shapes_factor_map: 222 size = shape[0] 223 to_timer = benchmark.Timer( 224 stmt='torch.ones(size).to("cuda")', globals={"size": size} 225 ) 226 de_timer = benchmark.Timer( 227 stmt='torch.ones(size, device="cuda")', globals={"size": size} 228 ) 229 to_time = to_timer.timeit(10).mean 230 de_time = de_timer.timeit(10).mean 231 shapes_factor_map[shape] = de_time / to_time 232 return shapes_factor_map 233 234 235class ForLoopIndexingPattern(Pattern): 236 """ 237 This pattern identifies if we use a for loop to index a tensor that 238 can be vectorized. 239 example: 240 tensor = torch.empty((100, 100)) 241 for i in range(100): 242 tensor[i] = i 243 244 Pattern: 245 aten::select | ... | aten::select | ... (Repeat) 246 247 Algorithm: 248 We start at node aten::select, and we check if we can find this alternating patterns. 249 We also keep a dictionary to avoid duplicate match in the for loop. 250 """ 251 252 def __init__(self, prof: profile, should_benchmark: bool = False): 253 super().__init__(prof, should_benchmark) 254 self.name = "For Loop Indexing Pattern" 255 self.description = "For loop indexing detected. Vectorization recommended." 256 self.visited: Set[int] = set() 257 258 def eventTreeTraversal(self): 259 """ 260 We need to use BFS traversal order to avoid duplicate match. 261 """ 262 yield from traverse_bfs(self.event_tree) 263 264 def match(self, event: _ProfilerEvent): 265 if event.name != "aten::select": 266 return False 267 if event.id in self.visited: 268 return False 269 repeat_count = 1 270 _, next = self.siblings_of(event) 271 if len(next) <= 1: 272 return False 273 274 # Custom event list matching 275 def same_ops(list1, list2): 276 if len(list1) != len(list2): 277 return False 278 for op1, op2 in zip(list1, list2): 279 if op1.name != op2.name: 280 return False 281 return True 282 283 # Record the ops between two aten::select 284 next_select_idx = index_of_first_match(next, lambda e: e.name == "aten::select") 285 if next_select_idx is None: 286 return False 287 indexing_ops = [event] + next[:next_select_idx] 288 next = next[len(indexing_ops) - 1 :] 289 for i in range(0, len(next), len(indexing_ops)): 290 if same_ops(indexing_ops, next[i : i + len(indexing_ops)]): 291 repeat_count += 1 292 self.visited.add(next[i].id) 293 else: 294 break 295 return repeat_count >= 10 296 297 298class FP32MatMulPattern(Pattern): 299 def __init__(self, prof: profile, should_benchmark: bool = False): 300 super().__init__(prof, should_benchmark) 301 self.name = "FP32 MatMul Pattern" 302 self.description = ( 303 "You are currently using GPU that supports TF32. " 304 "Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'" 305 ) 306 self.url = "https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 307 308 @property 309 def skip(self): 310 if torch.version.hip is not None: 311 has_tf32 = False 312 else: 313 # Anything less than sm_80 is not Ampere which doesn't support TF32 314 has_tf32 = all(int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list()) 315 return has_tf32 is False or super().skip or not self.prof.record_shapes 316 317 def match(self, event: _ProfilerEvent): 318 # If we saw this pattern once, we don't need to match it again 319 if event.tag != _EventType.TorchOp: 320 return False 321 assert isinstance(event.extra_fields, _ExtraFields_TorchOp) 322 if event.name == "aten::mm": 323 if event.extra_fields.allow_tf32_cublas is False: 324 return True 325 return False 326 327 def report(self, event: _ProfilerEvent): 328 return self.description 329 330 def benchmark(self, events: List[_ProfilerEvent]): 331 shapes_factor_map = {input_shapes(event): 0.0 for event in events} 332 for shape in shapes_factor_map: 333 matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32) 334 matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32) 335 fp32_timer = benchmark.Timer( 336 stmt="torch.mm(matrixA, matrixB)", 337 globals={"matrixA": matrixA, "matrixB": matrixB}, 338 ) 339 tf32_timer = benchmark.Timer( 340 stmt="torch.mm(matrixA, matrixB)", 341 setup="torch.backends.cuda.matmul.allow_tf32 = True", 342 globals={"matrixA": matrixA, "matrixB": matrixB}, 343 ) 344 torch.backends.cuda.matmul.allow_tf32 = False 345 fp32_time = fp32_timer.timeit(10).mean 346 tf32_time = tf32_timer.timeit(10).mean 347 shapes_factor_map[shape] = tf32_time / fp32_time 348 return shapes_factor_map 349 350 351class OptimizerSingleTensorPattern(Pattern): 352 """ 353 This pattern identifies if we are using the single-tensor version of an optimizer. 354 example: 355 optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 356 By adding foreach=True to enable multi-tensor optimizer, we can gain speedup when 357 the kernels are relatively small. 358 359 Pattern: 360 XXXXX: _single_tenser_<OPTIMIZER_NAME> 361 362 Algorithm: 363 String match 364 """ 365 366 def __init__(self, prof: profile, should_benchmark: bool = False): 367 super().__init__(prof, should_benchmark) 368 self.name = "Optimizer Single Tensor Pattern" 369 self.optimizers_with_foreach = ["adam", "sgd", "adamw"] 370 self.description = ( 371 "Deteced optimizer running with single tensor implementation. " 372 "Please enable multi tensor implementation by passing 'foreach=True' into optimizer." 373 ) 374 self.url = "" 375 376 def match(self, event: _ProfilerEvent): 377 for optimizer in self.optimizers_with_foreach: 378 if event.name.endswith(f"_single_tensor_{optimizer}"): 379 return True 380 return False 381 382 383class SynchronizedDataLoaderPattern(Pattern): 384 """ 385 This pattern identifies if we are using num_workers=0 in DataLoader. 386 example: 387 torch.utils.data.DataLoader(dataset, batch_size=batch_size) 388 Add num_workers=N to the arguments. N depends on system configuration. 389 390 Pattern: 391 dataloader.py(...): __iter__ 392 dataloader.py(...): _get_iterator 393 NOT dataloader.py(...): check_worker_number_rationality 394 395 Algorithm: 396 If we don't see check_worker_number_rationality call in the dataloader __iter__, 397 It is not an asynchronous dataloader. 398 399 """ 400 401 def __init__(self, prof: profile, should_benchmark: bool = False): 402 super().__init__(prof, should_benchmark) 403 self.name = "Synchronized DataLoader Pattern" 404 self.description = ( 405 "Detected DataLoader running with synchronized implementation. " 406 "Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader." 407 ) 408 self.url = ( 409 "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html" 410 "#enable-async-data-loading-and-augmentation" 411 ) 412 413 def match(self, event: _ProfilerEvent): 414 def is_dataloader_function(name: str, function_name: str): 415 return name.startswith( 416 os.path.join("torch", "utils", "data", "dataloader.py") 417 ) and name.endswith(function_name) 418 419 # TODO: fixme! Due to lifetime issues of the function name, this field might 420 # actually point to an already freed string when the even is a PyCall. 421 # Just silently skip this to unblock testing. 422 try: 423 event.name 424 except UnicodeDecodeError: 425 return False 426 427 if not is_dataloader_function(event.name, "__iter__"): 428 return False 429 if not event.children: 430 return False 431 event = event.children[0] 432 if not is_dataloader_function(event.name, "_get_iterator"): 433 return False 434 if not event.children: 435 return False 436 event = event.children[0] 437 return not is_dataloader_function(event.name, "check_worker_number_rationality") 438 # TODO: We should also check if the loader is bottleneck. 439 440 441class GradNotSetToNonePattern(Pattern): 442 """ 443 This pattern identifies if we are not setting grad to None in zero_grad. 444 example: 445 optimizer.zero_grad() 446 By setting set_to_none=True, we can gain speedup 447 448 Pattern: 449 XXXXX: _zero_grad 450 NOT aten::zeros 451 aten::zero_ 452 453 aten::zero_ is called on each parameter in the model. 454 We also want to make sure it is not called by aten::zeros. 455 456 Algorithm: 457 String match 458 """ 459 460 def __init__(self, prof: profile, should_benchmark: bool = False): 461 super().__init__(prof, should_benchmark) 462 self.name = "Gradient Set To Zero Instead of None Pattern" 463 self.description = ( 464 "Detected gradient set to zero instead of None. " 465 "Please add 'set_to_none=True' when calling zero_grad()." 466 ) 467 self.url = ( 468 "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html" 469 "#disable-gradient-calculation-for-validation-or-inference" 470 ) 471 472 def match(self, event: _ProfilerEvent): 473 if not event.name.endswith(": zero_grad"): 474 return False 475 if not event.children: 476 return False 477 478 for sub_event in traverse_dfs(event.children): 479 if ( 480 sub_event.name == "aten::zero_" 481 and sub_event.parent.name != "aten::zeros" 482 ): 483 return True 484 # TODO: We should also check if the optimizer's numerical behavior will change. 485 return False 486 487 488class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern): 489 """ 490 This pattern identifies if we are enabling bias in Conv2d which is followed by BatchNorm2d. 491 Bias doesn't do anything when followed by batchnorm. 492 Pattern: 493 nn.Module: Conv2d | nn.Module: BatchNorm2d 494 ... 495 aten::conv2d AND dtype of third argument is not null 496 The third argument is the bias 497 Algorithm: 498 String match 499 """ 500 501 def __init__(self, prof: profile, should_benchmark: bool = False): 502 super().__init__(prof, should_benchmark) 503 self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern" 504 self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d." 505 self.url = ( 506 "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html" 507 "#disable-bias-for-convolutions-directly-followed-by-a-batch-norm" 508 ) 509 510 @property 511 def skip(self): 512 return self.prof.record_shapes is False or super().skip 513 514 def match(self, event: _ProfilerEvent): 515 if event.name != "aten::conv2d": 516 return False 517 if len(input_dtypes(event)) < 3 or input_dtypes(event)[2] is None: 518 return False 519 # This means bias=True 520 event = self.go_up_until( 521 event, lambda e: e.name.startswith("nn.Module: Conv2d") 522 ) 523 if not event: 524 return False 525 event = self.next_of(event) 526 if not event: 527 return False 528 return event.name.startswith("nn.Module: BatchNorm2d") 529 530 531class MatMulDimInFP16Pattern(Pattern): 532 def __init__(self, prof: profile, should_benchmark: bool = False): 533 super().__init__(prof, should_benchmark) 534 self.name = "Matrix Multiplication Dimension Not Aligned Pattern" 535 self.description = "Detected matmul with dimension not aligned. Please use matmul with aligned dimension." 536 self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp" 537 538 @property 539 def skip(self): 540 return not self.prof.with_stack or not self.prof.record_shapes 541 542 def match(self, event: _ProfilerEvent): 543 def mutiple_of(shapes, multiple): 544 return all(dim % multiple == 0 for shape in shapes for dim in shape[-2:]) 545 546 if event.name not in ("aten::mm", "aten::bmm", "aten::addmm"): 547 return False 548 if not input_dtypes(event): 549 return False 550 arg_dtype = input_dtypes(event)[0] 551 if arg_dtype in (torch.bfloat16, torch.half) and not mutiple_of( 552 input_shapes(event), 8 553 ): 554 return True 555 return False 556 557 def benchmark(self, events: List[_ProfilerEvent]): 558 def closest_multiple(shapes, multiple): 559 return [multiple * math.ceil(shape / multiple) for shape in shapes] 560 561 shapes_factor_map = {input_shapes(event): 0.0 for event in events} 562 for shape in shapes_factor_map: 563 matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float16) 564 matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float16) 565 not_aligned_dim_timer = benchmark.Timer( 566 stmt="torch.mm(matrixA, matrixB)", 567 globals={"matrixA": matrixA, "matrixB": matrixB}, 568 ) 569 matrixA = torch.randn( 570 closest_multiple(shape[0], 8), device="cuda", dtype=torch.float16 571 ) 572 matrixB = torch.randn( 573 closest_multiple(shape[1], 8), device="cuda", dtype=torch.float16 574 ) 575 aligned_dim_timer = benchmark.Timer( 576 stmt="torch.mm(matrixA, matrixB)", 577 globals={"matrixA": matrixA, "matrixB": matrixB}, 578 ) 579 not_aligned_dim_time = not_aligned_dim_timer.timeit(10).mean 580 aligned_dim_time = aligned_dim_timer.timeit(10).mean 581 shapes_factor_map[shape] = aligned_dim_time / not_aligned_dim_time 582 return shapes_factor_map 583 584 585def source_code_location(event: Optional[_ProfilerEvent]): 586 while event: 587 if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall: 588 assert isinstance( 589 event.extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall) 590 ) 591 if not event.extra_fields.caller.file_name.startswith("torch" + os.sep): 592 return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}" 593 event = event.parent 594 return "No source code location found" 595 596 597def input_shapes(event: _ProfilerEvent): 598 assert isinstance(event.extra_fields, _ExtraFields_TorchOp) 599 return tuple(tuple(getattr(i, "sizes", ())) for i in event.extra_fields.inputs) 600 601 602def input_dtypes(event: _ProfilerEvent): 603 assert isinstance(event.extra_fields, _ExtraFields_TorchOp) 604 return tuple(getattr(i, "dtype", None) for i in event.extra_fields.inputs) 605 606 607def report_all_anti_patterns( 608 prof, 609 should_benchmark: bool = False, 610 print_enable: bool = True, 611 json_report_dir: Optional[str] = None, 612): 613 report_dict: Dict = {} 614 anti_patterns = [ 615 ExtraCUDACopyPattern(prof, should_benchmark), 616 # ForLoopIndexingPattern(prof, should_benchmark), 617 FP32MatMulPattern(prof, should_benchmark), 618 OptimizerSingleTensorPattern(prof, should_benchmark), 619 SynchronizedDataLoaderPattern(prof, should_benchmark), 620 GradNotSetToNonePattern(prof, should_benchmark), 621 Conv2dBiasFollowedByBatchNorm2dPattern(prof, should_benchmark), 622 MatMulDimInFP16Pattern(prof, should_benchmark), 623 ] 624 reported = set() 625 summaries = [] 626 message_list = [f"{'-'*40}TorchTidy Report{'-'*40}"] 627 message_list.append("Matched Events:") 628 629 for anti_pattern in anti_patterns: 630 matched_events = anti_pattern.matched_events() 631 if not matched_events: 632 continue 633 summaries.append(anti_pattern.summary(matched_events)) 634 for event in matched_events: 635 report_msg = anti_pattern.report(event) 636 if report_msg not in reported: 637 message_list.append(report_msg) 638 reported.add(report_msg) 639 src_location, line_no = source_code_location(event).split(":") 640 report_dict.setdefault(src_location, []).append( 641 { 642 "line_number": int(line_no), 643 "name": anti_pattern.name, 644 "url": anti_pattern.url, 645 "message": anti_pattern.description, 646 } 647 ) 648 649 if json_report_dir is not None: 650 json_report_path = os.path.join(json_report_dir, "torchtidy_report.json") 651 if os.path.exists(json_report_path): 652 with open(json_report_path) as f: 653 exisiting_report = json.load(f) 654 exisiting_report.update(report_dict) 655 report_dict = exisiting_report 656 with open(json_report_path, "w") as f: 657 json.dump(report_dict, f, indent=4) 658 659 message_list.append("Summary:") 660 message_list += summaries 661 message_list.append(f"{'-'*40}TorchTidy Report{'-'*40}") 662 if print_enable: 663 print("\n".join(message_list)) 664