1# mypy: allow-untyped-defs 2import bisect 3import itertools 4import math 5from collections import defaultdict, namedtuple 6from operator import attrgetter 7from typing import Any, Dict, List, Optional, Tuple 8from typing_extensions import deprecated 9 10import torch 11from torch.autograd import DeviceType 12 13 14__all__ = [ 15 "EventList", 16 "FormattedTimesMixin", 17 "Interval", 18 "Kernel", 19 "FunctionEvent", 20 "FunctionEventAvg", 21 "StringTable", 22 "MemRecordsAcc", 23] 24 25 26class EventList(list): 27 """A list of Events (for pretty printing).""" 28 29 def __init__(self, *args, **kwargs): 30 use_device = kwargs.pop("use_device", None) 31 profile_memory = kwargs.pop("profile_memory", False) 32 with_flops = kwargs.pop("with_flops", False) 33 super().__init__(*args, **kwargs) 34 self._use_device = use_device 35 self._profile_memory = profile_memory 36 self._tree_built = False 37 self._with_flops = with_flops 38 39 def _build_tree(self): 40 self._populate_cpu_children() 41 self._remove_dup_nodes() 42 self._set_backward_stacktraces() 43 self._tree_built = True 44 45 def __str__(self): 46 return self.table() 47 48 def _remove_dup_nodes(self): 49 while True: 50 to_delete = set() 51 for idx in range(len(self)): 52 if ( 53 self[idx].cpu_parent is not None 54 and self[idx].cpu_parent.name == self[idx].name 55 and len(self[idx].cpu_parent.cpu_children) == 1 56 ): 57 self[idx].cpu_parent.cpu_children = self[idx].cpu_children 58 self[idx].cpu_parent.kernels = self[idx].kernels # lift kernels up 59 for ch in self[idx].cpu_children: 60 ch.cpu_parent = self[idx].cpu_parent 61 to_delete.add(idx) 62 if len(to_delete) == 0: 63 break 64 new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete] 65 self.clear() 66 self.extend(new_evts) 67 68 def _populate_cpu_children(self): 69 """Populate child events into each underlying FunctionEvent object. 70 71 One event is a child of another if [s1, e1) is inside [s2, e2). Where 72 s1 and e1 would be start and end of the child event's interval. And 73 s2 and e2 start and end of the parent event's interval 74 75 Example: In event list [[0, 10], [1, 3], [3, 4]] would have make [0, 10] 76 be a parent of two other intervals. 77 78 If for any reason two intervals intersect only partially, this function 79 will not record a parent child relationship between then. 80 """ 81 # Some events can be async (i.e. start and end on different threads), 82 # since it's generally undefined how to attribute children ranges to 83 # async ranges, we do not use them when calculating nested ranges and stats 84 sync_events = [ 85 evt 86 for evt in self 87 if not evt.is_async and evt.device_type == DeviceType.CPU 88 ] 89 events = sorted( 90 sync_events, 91 key=attrgetter("thread"), 92 ) 93 # Group by both thread and node_id, so that events that happen to have 94 # the same thread_id but are from different nodes aren't incorrectly 95 # grouped together. 96 threads = itertools.groupby( 97 events, key=lambda event: (event.thread, event.node_id) 98 ) 99 100 # For each thread we keep a stack of current nested parents. 101 # We maintain the invariant that each interval is a subset of all other 102 # intervals lower in the stack. 103 # 104 # First we sort the intervals by their start time. Then we iterate over them. 105 # Every time we see a new interval we remove several parents from 106 # the top until we restore the invariant. Then parent child relationship 107 # if recorded if the stack is not empty. 108 # Finally we add new interval to the list 109 # 110 # Algorithm has O(N * log(N)) complexity where N is number of 111 # intervals 112 for thread_id, thread_events in threads: 113 thread_events_ = sorted( 114 thread_events, 115 key=lambda event: [event.time_range.start, -event.time_range.end], 116 ) 117 current_events: List[FunctionEvent] = [] 118 cur_end = 0 119 for event in thread_events_: 120 while len(current_events) > 0: 121 parent = current_events[-1] 122 if ( 123 event.time_range.start >= parent.time_range.end 124 or event.time_range.end > parent.time_range.end 125 ): 126 # this can't be a parent 127 current_events.pop() 128 else: 129 parent.append_cpu_child(event) 130 assert ( 131 event.cpu_parent is None 132 ), f"There is already a CPU parent event for {event.key}" 133 event.set_cpu_parent(parent) 134 break 135 136 current_events.append(event) 137 138 def _set_backward_stacktraces(self): 139 def bw_parent(evt): 140 if evt is None: 141 return None 142 elif evt.scope == 1: # BACKWARD_FUNCTION 143 return evt 144 else: 145 return bw_parent(evt.cpu_parent) 146 147 fwd_stacks = {} 148 for evt in self: 149 if bw_parent(evt) is None and evt.stack is not None: 150 t = (evt.sequence_nr, evt.thread) 151 if t not in fwd_stacks: 152 fwd_stacks[t] = evt.stack 153 154 for evt in self: 155 p = bw_parent(evt) 156 if p is not None: 157 assert p.fwd_thread is not None 158 t = (p.sequence_nr, p.fwd_thread) 159 if t in fwd_stacks: 160 evt.stack = fwd_stacks[t] 161 else: 162 evt.stack = [] 163 164 @property 165 def self_cpu_time_total(self): 166 return sum(event.self_cpu_time_total for event in self) 167 168 def table( 169 self, 170 sort_by=None, 171 row_limit=100, 172 max_src_column_width=75, 173 max_name_column_width=55, 174 max_shapes_column_width=80, 175 header=None, 176 top_level_events_only=False, 177 ): 178 """Print an EventList as a nicely formatted table. 179 180 Args: 181 sort_by (str, optional): Attribute used to sort entries. By default 182 they are printed in the same order as they were registered. 183 Valid keys include: ``cpu_time``, ``cuda_time``, ``xpu_time``, 184 ``cpu_time_total``, ``cuda_time_total``, ``xpu_time_total``, 185 ``cpu_memory_usage``, ``cuda_memory_usage``, ``xpu_memory_usage``, 186 ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, 187 ``self_xpu_memory_usage``, ``count``. 188 top_level_events_only(bool, optional): Boolean flag to determine the 189 selection of events to display. If true, the profiler will only 190 display events at top level like top-level invocation of python 191 `lstm`, python `add` or other functions, nested events like low-level 192 cpu/cuda/xpu ops events are omitted for profiler result readability. 193 194 Returns: 195 A string containing the table. 196 """ 197 return _build_table( 198 self, 199 sort_by=sort_by, 200 row_limit=row_limit, 201 max_src_column_width=max_src_column_width, 202 max_name_column_width=max_name_column_width, 203 max_shapes_column_width=max_shapes_column_width, 204 header=header, 205 profile_memory=self._profile_memory, 206 with_flops=self._with_flops, 207 top_level_events_only=top_level_events_only, 208 ) 209 210 def export_chrome_trace(self, path): 211 """Export an EventList as a Chrome tracing tools file. 212 213 The checkpoint can be later loaded and inspected under ``chrome://tracing`` URL. 214 215 Args: 216 path (str): Path where the trace will be written. 217 """ 218 import os 219 220 device_name = "cuda" if not self._use_device else self._use_device 221 with open(path, "w") as f: 222 chrome_events = [] 223 next_id = 0 224 # Use file IO over using json.dump since JSON dumping is very slow and 225 # this technique is proven to give a 4x speedup. 226 f.write("[") 227 for evt in self: 228 if evt.trace_name is None: 229 continue 230 f.write( 231 '{{"name": "{}", ' 232 '"ph": "X", ' 233 '"ts": {}, ' 234 '"dur": {}, ' 235 '"tid": {}, ' 236 '"pid": "CPU functions", ' 237 '"args": {{}}}}, '.format( 238 evt.trace_name, 239 evt.time_range.start, 240 evt.time_range.elapsed_us(), 241 evt.thread 242 if not evt.is_remote 243 else f'" node_id:{evt.node_id}, thread_id:{evt.thread} "', 244 ) 245 ) 246 for k in evt.kernels: 247 # 's' and 'f' draw Flow arrows from 248 # the CPU launch to the GPU kernel 249 f.write( 250 f'{{"name": "{evt.trace_name}", ' 251 '"ph": "s", ' 252 f'"ts": {evt.time_range.start}, ' 253 f'"tid": {evt.thread}, ' 254 '"pid": "CPU functions", ' 255 f'"id": {next_id}, ' 256 f'"cat": "cpu_to_{device_name}", ' 257 '"args": {}}, ' 258 ) 259 # Note: use torch.profiler to get device kernel trace 260 next_id += 1 261 if len(self) > 0: 262 # remove trailing whitespace and comma 263 f.seek(f.tell() - 2, os.SEEK_SET) 264 f.truncate() 265 f.write("]") 266 267 def supported_export_stacks_metrics(self): 268 return [ 269 "self_cpu_time_total", 270 "self_cuda_time_total", 271 "self_xpu_time_total", 272 "self_privateuse1_time_total", 273 ] 274 275 def export_stacks(self, path: str, metric: str): 276 if metric not in self.supported_export_stacks_metrics(): 277 raise ValueError( 278 "metric should be one of: " 279 + str(self.supported_export_stacks_metrics()) 280 ) 281 translate_table = str.maketrans(" ;\t\n", "____") 282 with open(path, "w") as f: 283 for evt in self: 284 if evt.stack and len(evt.stack) > 0: 285 metric_value = getattr( 286 evt, 287 metric.replace("cuda", "device") 288 .replace("xpu", "device") 289 .replace("privateuse1", "device"), 290 ) 291 if int(metric_value) > 0: 292 stack_str = "" 293 for entry in reversed(evt.stack): 294 stack_str += entry.translate(translate_table) 295 stack_str += ";" 296 stack_str = stack_str[:-1] + " " + str(int(metric_value)) 297 f.write(stack_str + "\n") 298 299 def key_averages(self, group_by_input_shapes=False, group_by_stack_n=0): 300 """Averages all function events over their keys. 301 302 Args: 303 group_by_input_shapes: group entries by 304 (event name, input shapes) rather than just event name. 305 This is useful to see which input shapes contribute to the runtime 306 the most and may help with size-specific optimizations or 307 choosing the best candidates for quantization (aka fitting a roof line) 308 309 group_by_stack_n: group by top n stack trace entries 310 311 Returns: 312 An EventList containing FunctionEventAvg objects. 313 """ 314 assert self._tree_built 315 stats: Dict[Tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg) 316 317 def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]: 318 key = [ 319 str(event.key), 320 str(event.node_id), 321 str(event.device_type), 322 str(event.is_legacy), 323 str(event.is_user_annotation), 324 ] 325 if group_by_input_shapes: 326 key.append(str(event.input_shapes)) 327 if group_by_stack_n > 0: 328 key += event.stack[:group_by_stack_n] 329 return tuple(key) 330 331 for evt in self: 332 stats[get_key(evt, group_by_input_shapes, group_by_stack_n)].add(evt) 333 334 avg_list = EventList( 335 stats.values(), 336 use_device=self._use_device, 337 profile_memory=self._profile_memory, 338 with_flops=self._with_flops, 339 ) 340 for evt in avg_list: 341 evt.stack = evt.stack[:group_by_stack_n] 342 if not group_by_input_shapes: 343 evt.input_shapes = "" 344 return avg_list 345 346 def total_average(self): 347 """Averages all events. 348 349 Returns: 350 A FunctionEventAvg object. 351 """ 352 total_stat = FunctionEventAvg() 353 for evt in self: 354 total_stat += evt 355 total_stat.key = None 356 total_stat.key = "Total" 357 return total_stat 358 359 360def _format_time(time_us): 361 """Define how to format time in FunctionEvent.""" 362 US_IN_SECOND = 1000.0 * 1000.0 363 US_IN_MS = 1000.0 364 if time_us >= US_IN_SECOND: 365 return f"{time_us / US_IN_SECOND:.3f}s" 366 if time_us >= US_IN_MS: 367 return f"{time_us / US_IN_MS:.3f}ms" 368 return f"{time_us:.3f}us" 369 370 371def _format_time_share(time_us, total_time_us): 372 """Define how to format time in FunctionEvent.""" 373 if total_time_us == 0: 374 assert time_us == 0, f"Expected time_us == 0 but got {time_us}" 375 return "NaN" 376 return f"{time_us * 100.0 / total_time_us:.2f}%" 377 378 379def _format_memory(nbytes): 380 """Return a formatted memory size string.""" 381 KB = 1024 382 MB = 1024 * KB 383 GB = 1024 * MB 384 if abs(nbytes) >= GB: 385 return f"{nbytes * 1.0 / GB:.2f} Gb" 386 elif abs(nbytes) >= MB: 387 return f"{nbytes * 1.0 / MB:.2f} Mb" 388 elif abs(nbytes) >= KB: 389 return f"{nbytes * 1.0 / KB:.2f} Kb" 390 else: 391 return str(nbytes) + " b" 392 393 394def _attr_formatter(name): 395 return property(lambda self: _format_time(getattr(self, name))) 396 397 398class FormattedTimesMixin: 399 """Helpers for FunctionEvent and FunctionEventAvg. 400 401 The subclass should define `*_time_total` and `count` attributes. 402 """ 403 404 cpu_time_str = _attr_formatter("cpu_time") 405 device_time_str = _attr_formatter("device_time") 406 cpu_time_total_str = _attr_formatter("cpu_time_total") 407 device_time_total_str = _attr_formatter("device_time_total") 408 self_cpu_time_total_str = _attr_formatter("self_cpu_time_total") 409 self_device_time_total_str = _attr_formatter("self_device_time_total") 410 411 @property 412 def cpu_time(self): 413 return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count # type: ignore[attr-defined] 414 415 @property 416 def device_time(self): 417 return 0.0 if self.count == 0 else 1.0 * self.device_time_total / self.count # type: ignore[attr-defined] 418 419 @property 420 @deprecated( 421 "`cuda_time` is deprecated, please use `device_time` instead.", 422 category=FutureWarning, 423 ) 424 def cuda_time(self): # To be deprecated 425 return self.device_time 426 427 428class Interval: 429 def __init__(self, start, end): 430 self.start = start 431 self.end = end 432 433 def elapsed_us(self): 434 r""" 435 Returns the length of the interval 436 """ 437 return self.end - self.start 438 439 440Kernel = namedtuple("Kernel", ["name", "device", "duration"]) 441 442 443class FunctionEvent(FormattedTimesMixin): 444 """Profiling information about a single function.""" 445 446 def __init__( 447 self, 448 id, 449 name, 450 thread, 451 start_us, 452 end_us, 453 fwd_thread=None, 454 input_shapes=None, 455 stack=None, 456 scope=0, 457 use_device=None, 458 cpu_memory_usage=0, 459 device_memory_usage=0, 460 is_async=False, 461 is_remote=False, 462 sequence_nr=-1, 463 node_id=-1, 464 device_type=DeviceType.CPU, 465 device_index=0, 466 device_resource_id=None, 467 is_legacy=False, 468 flops=None, 469 trace_name=None, 470 concrete_inputs=None, 471 kwinputs=None, 472 is_user_annotation=False, 473 ): 474 self.id: int = id 475 self.node_id: int = node_id 476 self.name: str = name 477 self.trace_name: str = trace_name 478 self.time_range: Interval = Interval(start_us, end_us) 479 self.thread: int = thread 480 self.fwd_thread: Optional[int] = fwd_thread 481 self.kernels: List[Kernel] = [] 482 self.count: int = 1 483 self.cpu_children: List[FunctionEvent] = [] 484 self.cpu_parent: Optional[FunctionEvent] = None 485 self.input_shapes: Tuple[int, ...] = input_shapes 486 self.concrete_inputs: List[Any] = concrete_inputs 487 self.kwinputs: Dict[str, Any] = kwinputs 488 self.stack: List = stack 489 self.scope: int = scope 490 self.use_device: Optional[str] = use_device 491 self.cpu_memory_usage: int = cpu_memory_usage 492 self.device_memory_usage: int = device_memory_usage 493 self.is_async: bool = is_async 494 self.is_remote: bool = is_remote 495 self.sequence_nr: int = sequence_nr 496 self.device_type: DeviceType = device_type 497 self.device_index: int = device_index 498 self.device_resource_id: int = ( 499 thread if device_resource_id is None else device_resource_id 500 ) 501 self.is_legacy: bool = is_legacy 502 self.flops: Optional[int] = flops 503 self.is_user_annotation: Optional[bool] = is_user_annotation 504 self.self_cpu_percent = -1 505 self.total_cpu_percent = -1 506 self.total_device_percent = -1 507 508 def append_kernel(self, name, device, duration): 509 assert self.device_type == DeviceType.CPU 510 self.kernels.append(Kernel(name, device, duration)) 511 512 def append_cpu_child(self, child): 513 """Append a CPU child of type FunctionEvent. 514 515 One is supposed to append only direct children to the event to have 516 correct self cpu time being reported. 517 """ 518 assert self.device_type == DeviceType.CPU 519 assert isinstance(child, FunctionEvent) 520 assert child.device_type == DeviceType.CPU 521 self.cpu_children.append(child) 522 523 def set_cpu_parent(self, parent): 524 """Set the immediate CPU parent of type FunctionEvent. 525 526 One profiling FunctionEvent should have only one CPU parent such that 527 the child's range interval is completely inside the parent's. We use 528 this connection to determine the event is from top-level op or not. 529 """ 530 assert self.device_type == DeviceType.CPU 531 assert isinstance(parent, FunctionEvent) 532 assert parent.device_type == DeviceType.CPU 533 self.cpu_parent = parent 534 535 # Note: async events don't have children, are not used when computing 'self' 536 # metrics of other events, have only total cpu time 537 @property 538 def self_cpu_memory_usage(self): 539 if self.is_async or self.device_type != DeviceType.CPU: 540 return 0 541 return self.cpu_memory_usage - sum( 542 child.cpu_memory_usage for child in self.cpu_children 543 ) 544 545 @property 546 def self_device_memory_usage(self): 547 if self.is_async or self.device_type != DeviceType.CPU: 548 return 0 549 return self.device_memory_usage - sum( 550 child.device_memory_usage for child in self.cpu_children 551 ) 552 553 @property 554 @deprecated( 555 "`self_cuda_memory_usage` is deprecated. Use `self_device_memory_usage` instead.", 556 category=FutureWarning, 557 ) 558 def self_cuda_memory_usage(self): # To be deprecated 559 return self.self_device_memory_usage 560 561 @property 562 def cpu_time_total(self): 563 if self.device_type == DeviceType.CPU: 564 return self.time_range.elapsed_us() 565 else: 566 return 0 567 568 @property 569 def self_cpu_time_total(self): 570 if self.is_async or self.device_type != DeviceType.CPU: 571 return 0 572 return self.cpu_time_total - sum( 573 child.cpu_time_total for child in self.cpu_children 574 ) 575 576 @property 577 def device_time_total(self): 578 if self.is_async or not self.use_device: 579 return 0 580 if self.device_type == DeviceType.CPU: 581 if not self.is_legacy: 582 # account for the kernels in the children ops 583 return sum(kinfo.duration for kinfo in self.kernels) + sum( 584 ch.device_time_total for ch in self.cpu_children 585 ) 586 else: 587 # each legacy cpu events has a single (fake) kernel 588 return sum(kinfo.duration for kinfo in self.kernels) 589 else: 590 assert self.device_type in [ 591 DeviceType.CUDA, 592 DeviceType.PrivateUse1, 593 DeviceType.MTIA, 594 ] 595 return self.time_range.elapsed_us() 596 597 @property 598 @deprecated( 599 "`cuda_time_total` is deprecated. Use `device_time_total` instead.", 600 category=FutureWarning, 601 ) 602 def cuda_time_total(self): # To be deprecated 603 return self.device_time_total 604 605 @property 606 def self_device_time_total(self): 607 if self.is_async or not self.use_device: 608 return 0 609 if self.device_type == DeviceType.CPU: 610 return self.device_time_total - sum( 611 child.device_time_total for child in self.cpu_children 612 ) 613 else: 614 assert self.device_type in [ 615 DeviceType.CUDA, 616 DeviceType.PrivateUse1, 617 DeviceType.MTIA, 618 ] 619 return self.device_time_total 620 621 @property 622 @deprecated( 623 "`self_cuda_time_total` is deprecated. Use `self_device_time_total` instead.", 624 category=FutureWarning, 625 ) 626 def self_cuda_time_total(self): # To be deprecated 627 return self.self_device_time_total 628 629 @property 630 def key(self): 631 return self.name 632 633 def __repr__(self): 634 device_name = self.use_device 635 device_time = self.device_time_str 636 device_memory_usage = self.device_memory_usage 637 return ( 638 f"<FunctionEvent id={self.id} name={self.name} device_type={self.device_type} node_id={self.node_id} " 639 f"cpu_time={self.cpu_time_str} start_us={self.time_range.start} end_us={self.time_range.end} " 640 f"cpu_children={str([child.id for child in self.cpu_children])} {device_name}_time={device_time} " 641 f"name={self.name} thread={self.thread} input_shapes={str(self.input_shapes)} " 642 f"cpu_memory_usage={self.cpu_memory_usage} {device_name}_memory_usage={device_memory_usage} " 643 f"is_async={self.is_async} is_remote={self.is_remote} seq_nr={self.sequence_nr} is_legacy={self.is_legacy}>" 644 ) 645 646 647class FunctionEventAvg(FormattedTimesMixin): 648 """Used to average stats over multiple FunctionEvent objects.""" 649 650 def __init__(self) -> None: 651 self.key: Optional[str] = None 652 self.count: int = 0 653 self.node_id: int = 0 654 self.is_async: bool = False 655 self.is_remote: bool = False 656 self.use_device: Optional[str] = None 657 self.cpu_time_total: int = 0 658 self.device_time_total: int = 0 659 self.self_cpu_time_total: int = 0 660 self.self_device_time_total: int = 0 661 self.input_shapes: Optional[List[List[int]]] = None 662 self.stack: Optional[List] = None 663 self.scope: Optional[int] = None 664 self.cpu_memory_usage: int = 0 665 self.device_memory_usage: int = 0 666 self.self_cpu_memory_usage: int = 0 667 self.self_device_memory_usage: int = 0 668 self.cpu_children: Optional[List[FunctionEvent]] = None 669 self.cpu_parent: Optional[FunctionEvent] = None 670 self.device_type: DeviceType = DeviceType.CPU 671 self.is_legacy: bool = False 672 self.flops: int = 0 673 674 def add(self, other): 675 if self.key is None: 676 # First function being recorded as part of FunctionEventAvg, propagate 677 # fields. 678 self.key = other.key 679 self.node_id = other.node_id 680 self.is_async = other.is_async 681 self.is_remote = other.is_remote 682 self.cpu_parent = other.cpu_parent 683 self.cpu_children = other.cpu_children 684 685 self.input_shapes = other.input_shapes 686 self.stack = other.stack 687 self.scope = other.scope 688 self.device_type = other.device_type 689 self.is_legacy = other.is_legacy 690 self.use_device = other.use_device 691 self.is_user_annotation = other.is_user_annotation 692 693 assert isinstance(other, (FunctionEvent, FunctionEventAvg)) 694 assert other.key == self.key 695 self.cpu_time_total += other.cpu_time_total 696 self.device_time_total += other.device_time_total 697 self.self_cpu_time_total += other.self_cpu_time_total 698 self.self_device_time_total += other.self_device_time_total 699 self.cpu_memory_usage += other.cpu_memory_usage 700 self.device_memory_usage += other.device_memory_usage 701 self.self_cpu_memory_usage += other.self_cpu_memory_usage 702 self.self_device_memory_usage += other.self_device_memory_usage 703 self.count += other.count 704 if self.flops is None: 705 self.flops = other.flops 706 elif other.flops is not None: 707 self.flops += other.flops 708 return self 709 710 def __iadd__(self, other): 711 return self.add(other) 712 713 def __repr__(self): 714 device_name = "cuda" if not self.use_device else self.use_device 715 self_device_time = self.self_device_time_total_str 716 device_time = self.device_time_str 717 device_memory = self.device_memory_usage 718 return ( 719 f"<FunctionEventAvg key={self.key} self_cpu_time={self.self_cpu_time_total_str} cpu_time={self.cpu_time_str} " 720 f" self_{device_name}_time={self_device_time} {device_name}_time={device_time} input_shapes={str(self.input_shapes)} " 721 f"cpu_memory_usage={self.cpu_memory_usage} {device_name}_memory_usage={device_memory}>" 722 ) 723 724 725class StringTable(defaultdict): 726 def __missing__(self, key): 727 # manage cases like 't' (demangled to 'unsigned short') separately, 728 # for now simply check the length to avoid unexpected results for 729 # the short sequences 730 self[key] = torch._C._demangle(key) if len(key) > 1 else key 731 return self[key] 732 733 734class MemRecordsAcc: 735 """Acceleration structure for accessing mem_records in interval.""" 736 737 def __init__(self, mem_records): 738 self._mem_records = mem_records 739 self._start_nses: List[int] = [] 740 self._indices: List[int] = [] 741 if len(mem_records) > 0: 742 tmp = sorted([(r[0].start_ns(), i) for i, r in enumerate(mem_records)]) 743 self._start_nses, self._indices = zip(*tmp) # type: ignore[assignment] 744 745 def in_interval(self, start_us, end_us): 746 r""" 747 Return all records in the given interval 748 To maintain backward compatibility, convert us to ns in function 749 """ 750 start_idx = bisect.bisect_left(self._start_nses, start_us * 1000) 751 end_idx = bisect.bisect_right(self._start_nses, end_us * 1000) 752 for i in range(start_idx, end_idx): 753 yield self._mem_records[self._indices[i]] 754 755 756def _filter_stack_entry(entry): 757 filtered_entries = [ 758 ("autograd/__init__", "_make_grads"), 759 ("autograd/__init__", "backward"), 760 ("torch/tensor", "backward"), 761 ("_internal/common_utils", "prof_callable"), 762 ("_internal/common_utils", "prof_func_call"), 763 ("_internal/common_utils", "prof_meth_call"), 764 ] 765 return all(not (f[0] in entry and f[1] in entry) for f in filtered_entries) 766 767 768MEMORY_EVENT_NAME = "[memory]" 769OUT_OF_MEMORY_EVENT_NAME = "[OutOfMemory]" 770 771 772def _filter_name(name): 773 # ignoring the following utility ops 774 filtered_out_names = [ 775 MEMORY_EVENT_NAME, # used only for the top-level memory events 776 OUT_OF_MEMORY_EVENT_NAME, 777 "profiler::_record_function_enter", 778 "profiler::_record_function_enter_new", 779 "profiler::_record_function_exit", 780 "aten::is_leaf", 781 "aten::output_nr", 782 "aten::_version", 783 ] 784 return name in filtered_out_names 785 786 787# Demangles and optionally rewrites the provided event name, 788# with_wildcard - whether to replace certain numbered event names 789# with a wildcard name to aggregate them together in the profiler table 790# output 791def _rewrite_name(name, with_wildcard=False): 792 string_table = StringTable() 793 name = string_table[name] 794 if with_wildcard: 795 if name.startswith("ProfilerStep#"): 796 name = "ProfilerStep*" 797 return name 798 799 800def _build_table( 801 events, 802 sort_by=None, 803 header=None, 804 row_limit=100, 805 max_src_column_width=75, 806 max_name_column_width=55, 807 max_shapes_column_width=80, 808 with_flops=False, 809 profile_memory=False, 810 top_level_events_only=False, 811): 812 """Print a summary of events (which can be a list of FunctionEvent or FunctionEventAvg).""" 813 if len(events) == 0: 814 return "" 815 816 has_device_time = any(event.self_device_time_total > 0 for event in events) 817 has_device_mem = any(event.self_device_memory_usage > 0 for event in events) 818 use_device = events[0].use_device 819 # Running on PrivateUse1 device with profiler but not enable 820 # ProfilerActivity.PrivateUse1 can also catch privateuse1 memory usage. 821 # Here only need to check has_privateuse1_time if not use_device. 822 if not use_device and has_device_time: 823 raise RuntimeError("use_device is None, but there is device performance data.") 824 825 has_input_shapes = any( 826 (event.input_shapes is not None and len(event.input_shapes) > 0) 827 for event in events 828 ) 829 830 if sort_by is not None: 831 events = EventList( 832 sorted( 833 events, 834 key=lambda evt: getattr( 835 evt, 836 sort_by.replace("cuda", "device") 837 .replace("xpu", "device") 838 .replace("privateuse1", "device"), 839 ), 840 reverse=True, 841 ), 842 use_device=use_device, 843 profile_memory=profile_memory, 844 with_flops=with_flops, 845 ) 846 847 name_column_width = max(len(evt.key) for evt in events) + 4 848 if max_name_column_width is not None: 849 name_column_width = min(name_column_width, max_name_column_width) 850 851 shapes_column_width = max(len(str(evt.input_shapes)) for evt in events) + 4 852 if max_shapes_column_width is not None: 853 shapes_column_width = min(shapes_column_width, max_shapes_column_width) 854 855 DEFAULT_COLUMN_WIDTH = 12 856 flops_column_width = DEFAULT_COLUMN_WIDTH 857 858 src_column_width = None 859 stacks = [] 860 for evt in events: 861 if evt.stack is not None and len(evt.stack) > 0: 862 stacks.append(evt.stack) 863 has_stack = len(stacks) > 0 864 if has_stack: 865 src_column_width = ( 866 max(max(len(entry) for entry in stack) for stack in stacks) + 4 867 ) 868 if max_src_column_width is not None: 869 src_column_width = min(src_column_width, max_src_column_width) 870 871 headers = [ 872 "Name", 873 "Self CPU %", 874 "Self CPU", 875 "CPU total %", 876 "CPU total", 877 "CPU time avg", 878 ] 879 device_name = use_device.upper() if use_device is not None else "None" 880 if has_device_time: 881 headers.extend( 882 [ 883 f"Self {device_name}", 884 f"Self {device_name} %", 885 f"{device_name} total", 886 f"{device_name} time avg", 887 ] 888 ) 889 if profile_memory: 890 headers.extend( 891 [ 892 "CPU Mem", 893 "Self CPU Mem", 894 ] 895 ) 896 if use_device and has_device_mem: 897 headers.extend( 898 [ 899 f"{device_name} Mem", 900 f"Self {device_name} Mem", 901 ] 902 ) 903 headers.append("# of Calls") 904 # Only append Node ID if any event has a valid (>= 0) Node ID 905 append_node_id = any(evt.node_id != -1 for evt in events) 906 if append_node_id: 907 headers.append("Node ID") 908 909 # Have to use a list because nonlocal is Py3 only... 910 SPACING_SIZE = 2 911 row_format_lst = [""] 912 header_sep_lst = [""] 913 line_length_lst = [-SPACING_SIZE] 914 915 def add_column(padding, text_dir=">"): 916 row_format_lst[0] += ( 917 "{: " + text_dir + str(padding) + "}" + (" " * SPACING_SIZE) 918 ) 919 header_sep_lst[0] += "-" * padding + (" " * SPACING_SIZE) 920 line_length_lst[0] += padding + SPACING_SIZE 921 922 def auto_scale_flops(flops): 923 flop_headers = [ 924 "FLOPs", 925 "KFLOPs", 926 "MFLOPs", 927 "GFLOPs", 928 "TFLOPs", 929 "PFLOPs", 930 ] 931 assert flops > 0 932 log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1))) 933 assert log_flops >= 0 and log_flops < len(flop_headers) 934 return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)]) 935 936 add_column(name_column_width) 937 for _ in headers[1:]: 938 add_column(DEFAULT_COLUMN_WIDTH) 939 940 if has_input_shapes: 941 headers.append("Input Shapes") 942 add_column(shapes_column_width) 943 944 if has_stack: 945 headers.append("Source Location") 946 add_column(src_column_width, text_dir="<") 947 948 if with_flops: 949 # Auto-scaling of flops header 950 raw_flops = [] 951 for evt in events: 952 if evt.flops > 0: 953 raw_flops.append(evt.flops) 954 if len(raw_flops) != 0: 955 (flops_scale, flops_header) = auto_scale_flops(min(raw_flops)) 956 headers.append(f"Total {flops_header}") 957 add_column(flops_column_width) 958 else: 959 with_flops = False # can't find any valid flops 960 961 row_format = row_format_lst[0] 962 header_sep = header_sep_lst[0] 963 line_length = line_length_lst[0] 964 add_column = None # type: ignore[assignment] 965 966 # Have to use a list because nonlocal is Py3 only... 967 result = [] 968 969 def append(s): 970 result.append(s) 971 result.append("\n") # Yes, newline after the end as well 972 973 sum_self_cpu_time_total = 0 974 sum_self_device_time_total = 0 975 for evt in events: 976 sum_self_cpu_time_total += evt.self_cpu_time_total 977 if evt.device_type == DeviceType.CPU and evt.is_legacy: 978 # in legacy profiler, kernel info is stored in cpu events 979 sum_self_device_time_total += evt.self_device_time_total 980 elif ( 981 evt.device_type 982 in [ 983 DeviceType.CUDA, 984 DeviceType.PrivateUse1, 985 DeviceType.MTIA, 986 ] 987 and not evt.is_user_annotation 988 ): 989 # in kineto profiler, there're events with the correct device type (e.g. CUDA) 990 sum_self_device_time_total += evt.self_device_time_total 991 992 # Actual printing 993 if header is not None: 994 append("=" * line_length) 995 append(header) 996 if top_level_events_only: 997 append("=" * line_length) 998 append("This report only display top-level ops statistics") 999 append(header_sep) 1000 append(row_format.format(*headers)) 1001 1002 append(header_sep) 1003 1004 def trim_path(path, src_column_width): 1005 if len(path) > src_column_width: 1006 offset = len(path) - src_column_width 1007 path = path[offset:] 1008 if len(path) > 3: 1009 path = "..." + path[3:] 1010 return path 1011 1012 event_limit = 0 1013 for evt in events: 1014 if event_limit == row_limit: 1015 break 1016 if top_level_events_only and evt.cpu_parent is not None: 1017 continue 1018 else: 1019 event_limit += 1 1020 name = evt.key 1021 if max_name_column_width is not None and len(name) >= max_name_column_width - 3: 1022 name = name[: (max_name_column_width - 3)] + "..." 1023 evt.self_cpu_percent = _format_time_share( 1024 evt.self_cpu_time_total, sum_self_cpu_time_total 1025 ) 1026 evt.total_cpu_percent = ( 1027 _format_time_share(evt.cpu_time_total, sum_self_cpu_time_total) 1028 if not evt.is_async 1029 else 0 1030 ) 1031 row_values = [ 1032 name, 1033 # Self CPU total %, 0 for async events. 1034 evt.self_cpu_percent, 1035 evt.self_cpu_time_total_str, # Self CPU total 1036 # CPU total %, 0 for async events. 1037 evt.total_cpu_percent, 1038 evt.cpu_time_total_str, # CPU total 1039 evt.cpu_time_str, # CPU time avg 1040 ] 1041 if has_device_time: 1042 evt.total_device_percent = _format_time_share( 1043 evt.self_device_time_total, sum_self_device_time_total 1044 ) 1045 row_values.extend( 1046 [ 1047 evt.self_device_time_total_str, 1048 # device time total % 1049 evt.total_device_percent, 1050 evt.device_time_total_str, 1051 evt.device_time_str, # device time avg 1052 ] 1053 ) 1054 if profile_memory: 1055 row_values.extend( 1056 [ 1057 # CPU Mem Total 1058 _format_memory(evt.cpu_memory_usage), 1059 # Self CPU Mem Total 1060 _format_memory(evt.self_cpu_memory_usage), 1061 ] 1062 ) 1063 if use_device and has_device_mem: 1064 row_values.extend( 1065 [ 1066 # Device Mem Total 1067 _format_memory(evt.device_memory_usage), 1068 # Self Device Mem Total 1069 _format_memory(evt.self_device_memory_usage), 1070 ] 1071 ) 1072 row_values.append( 1073 evt.count, # Number of calls 1074 ) 1075 1076 if append_node_id: 1077 row_values.append(evt.node_id) 1078 if has_input_shapes: 1079 row_values.append(str(evt.input_shapes)[:shapes_column_width]) 1080 if with_flops: 1081 if evt.flops <= 0: 1082 row_values.append("--") 1083 else: 1084 row_values.append(f"{evt.flops * flops_scale:8.3f}") # type: ignore[possibly-undefined] 1085 if has_stack: 1086 src_field = "" 1087 if len(evt.stack) > 0: 1088 src_field = trim_path(evt.stack[0], src_column_width) 1089 row_values.append(src_field) 1090 append(row_format.format(*row_values)) 1091 1092 if has_stack: 1093 empty_headers = [""] * (len(headers) - 1) 1094 for entry in evt.stack[1:]: 1095 append( 1096 row_format.format( 1097 *(empty_headers + [trim_path(entry, src_column_width)]) 1098 ) 1099 ) 1100 empty_headers.append("") 1101 append(row_format.format(*empty_headers)) 1102 1103 append(header_sep) 1104 append(f"Self CPU time total: {_format_time(sum_self_cpu_time_total)}") 1105 if has_device_time: 1106 append( 1107 f"Self {use_device.upper() if use_device is not None else 'None'} " 1108 f"time total: {_format_time(sum_self_device_time_total)}" 1109 ) 1110 return "".join(result) 1111