xref: /aosp_15_r20/external/pytorch/torch/autograd/profiler_util.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import bisect
3import itertools
4import math
5from collections import defaultdict, namedtuple
6from operator import attrgetter
7from typing import Any, Dict, List, Optional, Tuple
8from typing_extensions import deprecated
9
10import torch
11from torch.autograd import DeviceType
12
13
# Public API of this module (re-exported via torch.autograd.profiler).
__all__ = [
    "EventList",
    "FormattedTimesMixin",
    "Interval",
    "Kernel",
    "FunctionEvent",
    "FunctionEventAvg",
    "StringTable",
    "MemRecordsAcc",
]
24
25
class EventList(list):
    """A list of Events (for pretty printing)."""

    def __init__(self, *args, **kwargs):
        # Pop the profiler-specific options before delegating the rest of the
        # arguments to the plain ``list`` constructor.
        use_device = kwargs.pop("use_device", None)
        profile_memory = kwargs.pop("profile_memory", False)
        with_flops = kwargs.pop("with_flops", False)
        super().__init__(*args, **kwargs)
        self._use_device = use_device
        self._profile_memory = profile_memory
        self._tree_built = False
        self._with_flops = with_flops

    def _build_tree(self):
        """Link events into a parent/child tree, dedup, and derive backward stacks."""
        self._populate_cpu_children()
        self._remove_dup_nodes()
        self._set_backward_stacktraces()
        self._tree_built = True

    def __str__(self):
        return self.table()

    def _remove_dup_nodes(self):
        """Collapse chains of identically named nested events.

        When an event's parent has the same name and exactly one child, the
        child is redundant: its children and kernels are lifted onto the
        parent and the child is dropped.  Repeats until a fixed point.
        """
        while True:
            to_delete = set()
            for idx in range(len(self)):
                if (
                    self[idx].cpu_parent is not None
                    and self[idx].cpu_parent.name == self[idx].name
                    and len(self[idx].cpu_parent.cpu_children) == 1
                ):
                    self[idx].cpu_parent.cpu_children = self[idx].cpu_children
                    self[idx].cpu_parent.kernels = self[idx].kernels  # lift kernels up
                    for ch in self[idx].cpu_children:
                        ch.cpu_parent = self[idx].cpu_parent
                    to_delete.add(idx)
            if len(to_delete) == 0:
                break
            new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete]
            self.clear()
            self.extend(new_evts)

    def _populate_cpu_children(self):
        """Populate child events into each underlying FunctionEvent object.

        One event is a child of another if [s1, e1) is inside [s2, e2). Where
        s1 and e1 would be start and end of the child event's interval. And
        s2 and e2 start and end of the parent event's interval

        Example: In event list [[0, 10], [1, 3], [3, 4]] would have make [0, 10]
        be a parent of two other intervals.

        If for any reason two intervals intersect only partially, this function
        will not record a parent child relationship between then.
        """
        # Some events can be async (i.e. start and end on different threads),
        # since it's generally undefined how to attribute children ranges to
        # async ranges, we do not use them when calculating nested ranges and stats
        sync_events = [
            evt
            for evt in self
            if not evt.is_async and evt.device_type == DeviceType.CPU
        ]
        events = sorted(
            sync_events,
            key=attrgetter("thread"),
        )
        # Group by both thread and node_id, so that events that happen to have
        # the same thread_id but are from different nodes aren't incorrectly
        # grouped together.
        threads = itertools.groupby(
            events, key=lambda event: (event.thread, event.node_id)
        )

        # For each thread we keep a stack of current nested parents.
        # We maintain the invariant that each interval is a subset of all other
        # intervals lower in the stack.
        #
        # First we sort the intervals by their start time. Then we iterate over them.
        # Every time we see a new interval we remove several parents from
        # the top until we restore the invariant. Then parent child relationship
        # if recorded if the stack is not empty.
        # Finally we add new interval to the list
        #
        # Algorithm has O(N * log(N)) complexity where N is number of
        # intervals
        for thread_id, thread_events in threads:
            # Sort by start ascending, then by end descending so that an
            # enclosing interval is always visited before the intervals
            # it contains.
            thread_events_ = sorted(
                thread_events,
                key=lambda event: [event.time_range.start, -event.time_range.end],
            )
            current_events: List[FunctionEvent] = []
            for event in thread_events_:
                while len(current_events) > 0:
                    parent = current_events[-1]
                    if (
                        event.time_range.start >= parent.time_range.end
                        or event.time_range.end > parent.time_range.end
                    ):
                        # this can't be a parent
                        current_events.pop()
                    else:
                        parent.append_cpu_child(event)
                        assert (
                            event.cpu_parent is None
                        ), f"There is already a CPU parent event for {event.key}"
                        event.set_cpu_parent(parent)
                        break

                current_events.append(event)

    def _set_backward_stacktraces(self):
        """Copy forward-pass stack traces onto matching backward events.

        Forward and backward events are matched on (sequence_nr, thread):
        a backward event inherits the stack recorded for the forward op
        with the same sequence number on the forward thread.
        """

        def bw_parent(evt):
            # Walk up the CPU-parent chain to the nearest BACKWARD_FUNCTION
            # scope, or None if the event is not under a backward node.
            if evt is None:
                return None
            elif evt.scope == 1:  # BACKWARD_FUNCTION
                return evt
            else:
                return bw_parent(evt.cpu_parent)

        fwd_stacks = {}
        for evt in self:
            if bw_parent(evt) is None and evt.stack is not None:
                t = (evt.sequence_nr, evt.thread)
                if t not in fwd_stacks:
                    fwd_stacks[t] = evt.stack

        for evt in self:
            p = bw_parent(evt)
            if p is not None:
                assert p.fwd_thread is not None
                t = (p.sequence_nr, p.fwd_thread)
                if t in fwd_stacks:
                    evt.stack = fwd_stacks[t]
                else:
                    evt.stack = []

    @property
    def self_cpu_time_total(self):
        """Sum of the self (exclusive) CPU time of all events, in us."""
        return sum(event.self_cpu_time_total for event in self)

    def table(
        self,
        sort_by=None,
        row_limit=100,
        max_src_column_width=75,
        max_name_column_width=55,
        max_shapes_column_width=80,
        header=None,
        top_level_events_only=False,
    ):
        """Print an EventList as a nicely formatted table.

        Args:
            sort_by (str, optional): Attribute used to sort entries. By default
                they are printed in the same order as they were registered.
                Valid keys include: ``cpu_time``, ``cuda_time``, ``xpu_time``,
                ``cpu_time_total``, ``cuda_time_total``, ``xpu_time_total``,
                ``cpu_memory_usage``, ``cuda_memory_usage``, ``xpu_memory_usage``,
                ``self_cpu_memory_usage``, ``self_cuda_memory_usage``,
                ``self_xpu_memory_usage``, ``count``.
            top_level_events_only(bool, optional): Boolean flag to determine the
                selection of events to display. If true, the profiler will only
                display events at top level like top-level invocation of python
                `lstm`, python `add` or other functions, nested events like low-level
                cpu/cuda/xpu ops events are omitted for profiler result readability.

        Returns:
            A string containing the table.
        """
        return _build_table(
            self,
            sort_by=sort_by,
            row_limit=row_limit,
            max_src_column_width=max_src_column_width,
            max_name_column_width=max_name_column_width,
            max_shapes_column_width=max_shapes_column_width,
            header=header,
            profile_memory=self._profile_memory,
            with_flops=self._with_flops,
            top_level_events_only=top_level_events_only,
        )

    def export_chrome_trace(self, path):
        """Export an EventList as a Chrome tracing tools file.

        The checkpoint can be later loaded and inspected under ``chrome://tracing`` URL.

        Args:
            path (str): Path where the trace will be written.
        """
        import os

        device_name = "cuda" if not self._use_device else self._use_device
        with open(path, "w") as f:
            next_id = 0
            # Use file IO over using json.dump since JSON dumping is very slow and
            # this technique is proven to give a 4x speedup.
            f.write("[")
            for evt in self:
                if evt.trace_name is None:
                    continue
                f.write(
                    '{{"name": "{}", '
                    '"ph": "X", '
                    '"ts": {}, '
                    '"dur": {}, '
                    '"tid": {}, '
                    '"pid": "CPU functions", '
                    '"args": {{}}}}, '.format(
                        evt.trace_name,
                        evt.time_range.start,
                        evt.time_range.elapsed_us(),
                        evt.thread
                        if not evt.is_remote
                        else f'" node_id:{evt.node_id}, thread_id:{evt.thread} "',
                    )
                )
                for _ in evt.kernels:
                    # 's' and 'f' draw Flow arrows from
                    # the CPU launch to the GPU kernel
                    f.write(
                        f'{{"name": "{evt.trace_name}", '
                        '"ph": "s", '
                        f'"ts": {evt.time_range.start}, '
                        f'"tid": {evt.thread}, '
                        '"pid": "CPU functions", '
                        f'"id": {next_id}, '
                        f'"cat": "cpu_to_{device_name}", '
                        '"args": {}}, '
                    )
                    # Note: use torch.profiler to get device kernel trace
                    next_id += 1
            # Every record above ends with ", "; strip the trailing separator.
            # Key off the file position rather than len(self): events with
            # trace_name == None are skipped, so a non-empty list may still
            # have produced no records (seeking to -1 would raise otherwise).
            if f.tell() > 1:
                f.seek(f.tell() - 2, os.SEEK_SET)
                f.truncate()
            f.write("]")

    def supported_export_stacks_metrics(self):
        """Return the metric names accepted by :meth:`export_stacks`."""
        return [
            "self_cpu_time_total",
            "self_cuda_time_total",
            "self_xpu_time_total",
            "self_privateuse1_time_total",
        ]

    def export_stacks(self, path: str, metric: str):
        """Write collapsed-stack (flamegraph-style) lines to ``path``.

        Each output line is ``frame;frame;...;frame <metric value>`` for every
        event that has a non-empty stack and a positive metric value.

        Args:
            path (str): Output file path.
            metric (str): One of :meth:`supported_export_stacks_metrics`.

        Raises:
            ValueError: If ``metric`` is not a supported metric name.
        """
        if metric not in self.supported_export_stacks_metrics():
            raise ValueError(
                "metric should be one of: "
                + str(self.supported_export_stacks_metrics())
            )
        # Stack entries may contain separator characters; replace them so
        # each frame stays a single token in the collapsed format.
        translate_table = str.maketrans(" ;\t\n", "____")
        with open(path, "w") as f:
            for evt in self:
                if evt.stack and len(evt.stack) > 0:
                    # Device-specific metric names all map onto the generic
                    # ``self_device_time_total`` attribute.
                    metric_value = getattr(
                        evt,
                        metric.replace("cuda", "device")
                        .replace("xpu", "device")
                        .replace("privateuse1", "device"),
                    )
                    if int(metric_value) > 0:
                        stack_str = ""
                        for entry in reversed(evt.stack):
                            stack_str += entry.translate(translate_table)
                            stack_str += ";"
                        stack_str = stack_str[:-1] + " " + str(int(metric_value))
                        f.write(stack_str + "\n")

    def key_averages(self, group_by_input_shapes=False, group_by_stack_n=0):
        """Averages all function events over their keys.

        Args:
            group_by_input_shapes: group entries by
                (event name, input shapes) rather than just event name.
                This is useful to see which input shapes contribute to the runtime
                the most and may help with size-specific optimizations or
                choosing the best candidates for quantization (aka fitting a roof line)

            group_by_stack_n: group by top n stack trace entries

        Returns:
            An EventList containing FunctionEventAvg objects.
        """
        assert self._tree_built
        stats: Dict[Tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)

        def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]:
            key = [
                str(event.key),
                str(event.node_id),
                str(event.device_type),
                str(event.is_legacy),
                str(event.is_user_annotation),
            ]
            if group_by_input_shapes:
                key.append(str(event.input_shapes))
            if group_by_stack_n > 0:
                key += event.stack[:group_by_stack_n]
            return tuple(key)

        for evt in self:
            stats[get_key(evt, group_by_input_shapes, group_by_stack_n)].add(evt)

        avg_list = EventList(
            stats.values(),
            use_device=self._use_device,
            profile_memory=self._profile_memory,
            with_flops=self._with_flops,
        )
        for evt in avg_list:
            evt.stack = evt.stack[:group_by_stack_n]
            if not group_by_input_shapes:
                evt.input_shapes = ""
        return avg_list

    def total_average(self):
        """Averages all events.

        Returns:
            A FunctionEventAvg object.
        """
        total_stat = FunctionEventAvg()
        for evt in self:
            total_stat += evt
            # Reset the key between additions so FunctionEventAvg.add does
            # not assert on mismatched event names while accumulating.
            total_stat.key = None
        total_stat.key = "Total"
        return total_stat
358
359
360def _format_time(time_us):
361    """Define how to format time in FunctionEvent."""
362    US_IN_SECOND = 1000.0 * 1000.0
363    US_IN_MS = 1000.0
364    if time_us >= US_IN_SECOND:
365        return f"{time_us / US_IN_SECOND:.3f}s"
366    if time_us >= US_IN_MS:
367        return f"{time_us / US_IN_MS:.3f}ms"
368    return f"{time_us:.3f}us"
369
370
371def _format_time_share(time_us, total_time_us):
372    """Define how to format time in FunctionEvent."""
373    if total_time_us == 0:
374        assert time_us == 0, f"Expected time_us == 0 but got {time_us}"
375        return "NaN"
376    return f"{time_us * 100.0 / total_time_us:.2f}%"
377
378
379def _format_memory(nbytes):
380    """Return a formatted memory size string."""
381    KB = 1024
382    MB = 1024 * KB
383    GB = 1024 * MB
384    if abs(nbytes) >= GB:
385        return f"{nbytes * 1.0 / GB:.2f} Gb"
386    elif abs(nbytes) >= MB:
387        return f"{nbytes * 1.0 / MB:.2f} Mb"
388    elif abs(nbytes) >= KB:
389        return f"{nbytes * 1.0 / KB:.2f} Kb"
390    else:
391        return str(nbytes) + " b"
392
393
394def _attr_formatter(name):
395    return property(lambda self: _format_time(getattr(self, name)))
396
397
class FormattedTimesMixin:
    """Helpers for FunctionEvent and FunctionEventAvg.

    The subclass should define `*_time_total` and `count` attributes.
    """

    # Read-only "*_str" properties that render the corresponding numeric
    # attribute (microseconds) as a human-readable string via _format_time.
    cpu_time_str = _attr_formatter("cpu_time")
    device_time_str = _attr_formatter("device_time")
    cpu_time_total_str = _attr_formatter("cpu_time_total")
    device_time_total_str = _attr_formatter("device_time_total")
    self_cpu_time_total_str = _attr_formatter("self_cpu_time_total")
    self_device_time_total_str = _attr_formatter("self_device_time_total")

    @property
    def cpu_time(self):
        """Average CPU time per call: ``cpu_time_total / count`` (0.0 if count is 0)."""
        return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count  # type: ignore[attr-defined]

    @property
    def device_time(self):
        """Average device time per call: ``device_time_total / count`` (0.0 if count is 0)."""
        return 0.0 if self.count == 0 else 1.0 * self.device_time_total / self.count  # type: ignore[attr-defined]

    @property
    @deprecated(
        "`cuda_time` is deprecated, please use `device_time` instead.",
        category=FutureWarning,
    )
    def cuda_time(self):  # To be deprecated
        """Deprecated alias of :attr:`device_time` (emits FutureWarning)."""
        return self.device_time
426
427
class Interval:
    """A simple time span with explicit ``start`` and ``end`` stamps.

    In this module intervals hold microsecond timestamps (see
    ``FunctionEvent.time_range``).
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def elapsed_us(self):
        """Return the length of the interval (``end - start``)."""
        duration = self.end - self.start
        return duration
438
439
440Kernel = namedtuple("Kernel", ["name", "device", "duration"])
441
442
class FunctionEvent(FormattedTimesMixin):
    """Profiling information about a single function.

    Timestamps (``start_us``/``end_us``) and all derived durations are in
    microseconds.  Events form a tree via ``cpu_children``/``cpu_parent``
    (populated by ``EventList._populate_cpu_children``), and "self" metrics
    subtract the children's contribution from the event's total.
    """

    def __init__(
        self,
        id,
        name,
        thread,
        start_us,
        end_us,
        fwd_thread=None,
        input_shapes=None,
        stack=None,
        scope=0,
        use_device=None,
        cpu_memory_usage=0,
        device_memory_usage=0,
        is_async=False,
        is_remote=False,
        sequence_nr=-1,
        node_id=-1,
        device_type=DeviceType.CPU,
        device_index=0,
        device_resource_id=None,
        is_legacy=False,
        flops=None,
        trace_name=None,
        concrete_inputs=None,
        kwinputs=None,
        is_user_annotation=False,
    ):
        self.id: int = id
        self.node_id: int = node_id
        self.name: str = name
        # Name used when emitting chrome-trace records; may be None, in which
        # case the event is skipped by EventList.export_chrome_trace.
        self.trace_name: Optional[str] = trace_name
        self.time_range: Interval = Interval(start_us, end_us)
        self.thread: int = thread
        self.fwd_thread: Optional[int] = fwd_thread
        # Filled in later via append_kernel().
        self.kernels: List[Kernel] = []
        self.count: int = 1
        # Tree links; populated by EventList._populate_cpu_children.
        self.cpu_children: List[FunctionEvent] = []
        self.cpu_parent: Optional[FunctionEvent] = None
        self.input_shapes: Optional[Tuple[int, ...]] = input_shapes
        self.concrete_inputs: Optional[List[Any]] = concrete_inputs
        self.kwinputs: Optional[Dict[str, Any]] = kwinputs
        self.stack: Optional[List] = stack
        self.scope: int = scope
        self.use_device: Optional[str] = use_device
        self.cpu_memory_usage: int = cpu_memory_usage
        self.device_memory_usage: int = device_memory_usage
        self.is_async: bool = is_async
        self.is_remote: bool = is_remote
        self.sequence_nr: int = sequence_nr
        self.device_type: DeviceType = device_type
        self.device_index: int = device_index
        # Defaults to the launching thread when no explicit resource id given.
        self.device_resource_id: int = (
            thread if device_resource_id is None else device_resource_id
        )
        self.is_legacy: bool = is_legacy
        self.flops: Optional[int] = flops
        self.is_user_annotation: bool = is_user_annotation
        # Percentages filled in during table construction; -1 means "unset".
        self.self_cpu_percent = -1
        self.total_cpu_percent = -1
        self.total_device_percent = -1

    def append_kernel(self, name, device, duration):
        """Attach a device kernel record to this (CPU-side) event."""
        assert self.device_type == DeviceType.CPU
        self.kernels.append(Kernel(name, device, duration))

    def append_cpu_child(self, child):
        """Append a CPU child of type FunctionEvent.

        One is supposed to append only direct children to the event to have
        correct self cpu time being reported.
        """
        assert self.device_type == DeviceType.CPU
        assert isinstance(child, FunctionEvent)
        assert child.device_type == DeviceType.CPU
        self.cpu_children.append(child)

    def set_cpu_parent(self, parent):
        """Set the immediate CPU parent of type FunctionEvent.

        One profiling FunctionEvent should have only one CPU parent such that
        the child's range interval is completely inside the parent's. We use
        this connection to determine the event is from top-level op or not.
        """
        assert self.device_type == DeviceType.CPU
        assert isinstance(parent, FunctionEvent)
        assert parent.device_type == DeviceType.CPU
        self.cpu_parent = parent

    # Note: async events don't have children, are not used when computing 'self'
    # metrics of other events, have only total cpu time
    @property
    def self_cpu_memory_usage(self):
        """CPU memory attributed to this event excluding its children."""
        if self.is_async or self.device_type != DeviceType.CPU:
            return 0
        return self.cpu_memory_usage - sum(
            child.cpu_memory_usage for child in self.cpu_children
        )

    @property
    def self_device_memory_usage(self):
        """Device memory attributed to this event excluding its children."""
        if self.is_async or self.device_type != DeviceType.CPU:
            return 0
        return self.device_memory_usage - sum(
            child.device_memory_usage for child in self.cpu_children
        )

    @property
    @deprecated(
        "`self_cuda_memory_usage` is deprecated. Use `self_device_memory_usage` instead.",
        category=FutureWarning,
    )
    def self_cuda_memory_usage(self):  # To be deprecated
        """Deprecated alias of :attr:`self_device_memory_usage`."""
        return self.self_device_memory_usage

    @property
    def cpu_time_total(self):
        """Wall time of this event in us (0 for non-CPU events)."""
        if self.device_type == DeviceType.CPU:
            return self.time_range.elapsed_us()
        else:
            return 0

    @property
    def self_cpu_time_total(self):
        """CPU time of this event excluding time spent in its children (us)."""
        if self.is_async or self.device_type != DeviceType.CPU:
            return 0
        return self.cpu_time_total - sum(
            child.cpu_time_total for child in self.cpu_children
        )

    @property
    def device_time_total(self):
        """Total device time in us: kernel durations (plus children for
        non-legacy CPU events), or the event's own span for device events."""
        if self.is_async or not self.use_device:
            return 0
        if self.device_type == DeviceType.CPU:
            if not self.is_legacy:
                # account for the kernels in the children ops
                return sum(kinfo.duration for kinfo in self.kernels) + sum(
                    ch.device_time_total for ch in self.cpu_children
                )
            else:
                # each legacy cpu events has a single (fake) kernel
                return sum(kinfo.duration for kinfo in self.kernels)
        else:
            assert self.device_type in [
                DeviceType.CUDA,
                DeviceType.PrivateUse1,
                DeviceType.MTIA,
            ]
            return self.time_range.elapsed_us()

    @property
    @deprecated(
        "`cuda_time_total` is deprecated. Use `device_time_total` instead.",
        category=FutureWarning,
    )
    def cuda_time_total(self):  # To be deprecated
        """Deprecated alias of :attr:`device_time_total`."""
        return self.device_time_total

    @property
    def self_device_time_total(self):
        """Device time of this event excluding its children's device time (us)."""
        if self.is_async or not self.use_device:
            return 0
        if self.device_type == DeviceType.CPU:
            return self.device_time_total - sum(
                child.device_time_total for child in self.cpu_children
            )
        else:
            assert self.device_type in [
                DeviceType.CUDA,
                DeviceType.PrivateUse1,
                DeviceType.MTIA,
            ]
            return self.device_time_total

    @property
    @deprecated(
        "`self_cuda_time_total` is deprecated. Use `self_device_time_total` instead.",
        category=FutureWarning,
    )
    def self_cuda_time_total(self):  # To be deprecated
        """Deprecated alias of :attr:`self_device_time_total`."""
        return self.self_device_time_total

    @property
    def key(self):
        """Grouping key used by EventList.key_averages (the event name)."""
        return self.name

    def __repr__(self):
        device_name = self.use_device
        device_time = self.device_time_str
        device_memory_usage = self.device_memory_usage
        return (
            f"<FunctionEvent id={self.id} name={self.name} device_type={self.device_type} node_id={self.node_id} "
            f"cpu_time={self.cpu_time_str} start_us={self.time_range.start} end_us={self.time_range.end} "
            f"cpu_children={str([child.id for child in self.cpu_children])} {device_name}_time={device_time} "
            f"name={self.name} thread={self.thread} input_shapes={str(self.input_shapes)} "
            f"cpu_memory_usage={self.cpu_memory_usage} {device_name}_memory_usage={device_memory_usage} "
            f"is_async={self.is_async} is_remote={self.is_remote} seq_nr={self.sequence_nr} is_legacy={self.is_legacy}>"
        )
645
646
class FunctionEventAvg(FormattedTimesMixin):
    """Used to average stats over multiple FunctionEvent objects."""

    def __init__(self) -> None:
        # Identity fields; propagated from the first event passed to add().
        self.key: Optional[str] = None
        self.count: int = 0
        self.node_id: int = 0
        self.is_async: bool = False
        self.is_remote: bool = False
        self.use_device: Optional[str] = None
        # Accumulated totals (microseconds / bytes); add() sums into these.
        self.cpu_time_total: int = 0
        self.device_time_total: int = 0
        self.self_cpu_time_total: int = 0
        self.self_device_time_total: int = 0
        self.input_shapes: Optional[List[List[int]]] = None
        self.stack: Optional[List] = None
        self.scope: Optional[int] = None
        self.cpu_memory_usage: int = 0
        self.device_memory_usage: int = 0
        self.self_cpu_memory_usage: int = 0
        self.self_device_memory_usage: int = 0
        self.cpu_children: Optional[List[FunctionEvent]] = None
        self.cpu_parent: Optional[FunctionEvent] = None
        self.device_type: DeviceType = DeviceType.CPU
        self.is_legacy: bool = False
        # Previously this attribute was only assigned inside add(), so
        # touching it on a fresh instance raised AttributeError; initialize
        # it like every other propagated field.
        self.is_user_annotation: bool = False
        self.flops: int = 0

    def add(self, other):
        """Accumulate ``other`` (FunctionEvent or FunctionEventAvg) into self.

        The first event added propagates its identity fields (key, node_id,
        scope, ...); subsequent events must have the same key. Returns self.
        """
        if self.key is None:
            # First function being recorded as part of FunctionEventAvg, propagate
            # fields.
            self.key = other.key
            self.node_id = other.node_id
            self.is_async = other.is_async
            self.is_remote = other.is_remote
            self.cpu_parent = other.cpu_parent
            self.cpu_children = other.cpu_children

            self.input_shapes = other.input_shapes
            self.stack = other.stack
            self.scope = other.scope
            self.device_type = other.device_type
            self.is_legacy = other.is_legacy
            self.use_device = other.use_device
            self.is_user_annotation = other.is_user_annotation

        assert isinstance(other, (FunctionEvent, FunctionEventAvg))
        assert other.key == self.key
        self.cpu_time_total += other.cpu_time_total
        self.device_time_total += other.device_time_total
        self.self_cpu_time_total += other.self_cpu_time_total
        self.self_device_time_total += other.self_device_time_total
        self.cpu_memory_usage += other.cpu_memory_usage
        self.device_memory_usage += other.device_memory_usage
        self.self_cpu_memory_usage += other.self_cpu_memory_usage
        self.self_device_memory_usage += other.self_device_memory_usage
        self.count += other.count
        # flops starts at 0 in __init__, so the None branch only matters if a
        # caller explicitly reset it; kept for backward compatibility.
        if self.flops is None:
            self.flops = other.flops
        elif other.flops is not None:
            self.flops += other.flops
        return self

    def __iadd__(self, other):
        return self.add(other)

    def __repr__(self):
        device_name = "cuda" if not self.use_device else self.use_device
        self_device_time = self.self_device_time_total_str
        device_time = self.device_time_str
        device_memory = self.device_memory_usage
        return (
            f"<FunctionEventAvg key={self.key} self_cpu_time={self.self_cpu_time_total_str} cpu_time={self.cpu_time_str} "
            f" self_{device_name}_time={self_device_time} {device_name}_time={device_time} input_shapes={str(self.input_shapes)} "
            f"cpu_memory_usage={self.cpu_memory_usage} {device_name}_memory_usage={device_memory}>"
        )
723
724
class StringTable(defaultdict):
    def __missing__(self, key):
        """Demangle ``key`` on first access and memoize the result."""
        # Keys of length <= 1 (e.g. 't', which would demangle to
        # 'unsigned short') are kept verbatim to avoid surprising
        # expansions of short sequences.
        demangled = key if len(key) <= 1 else torch._C._demangle(key)
        self[key] = demangled
        return demangled
732
733
class MemRecordsAcc:
    """Acceleration structure for accessing mem_records in interval."""

    def __init__(self, mem_records):
        self._mem_records = mem_records
        self._start_nses: List[int] = []
        self._indices: List[int] = []
        if mem_records:
            # Sort record start times (ns) once, remembering each record's
            # original index, so lookups can binary-search by timestamp.
            ordered = sorted((rec[0].start_ns(), pos) for pos, rec in enumerate(mem_records))
            self._start_nses, self._indices = zip(*ordered)  # type: ignore[assignment]

    def in_interval(self, start_us, end_us):
        r"""
        Return all records in the given interval
        To maintain backward compatibility, convert us to ns in function
        """
        lo = bisect.bisect_left(self._start_nses, start_us * 1000)
        hi = bisect.bisect_right(self._start_nses, end_us * 1000)
        for idx in self._indices[lo:hi]:
            yield self._mem_records[idx]
754
755
756def _filter_stack_entry(entry):
757    filtered_entries = [
758        ("autograd/__init__", "_make_grads"),
759        ("autograd/__init__", "backward"),
760        ("torch/tensor", "backward"),
761        ("_internal/common_utils", "prof_callable"),
762        ("_internal/common_utils", "prof_func_call"),
763        ("_internal/common_utils", "prof_meth_call"),
764    ]
765    return all(not (f[0] in entry and f[1] in entry) for f in filtered_entries)
766
767
768MEMORY_EVENT_NAME = "[memory]"
769OUT_OF_MEMORY_EVENT_NAME = "[OutOfMemory]"
770
771
772def _filter_name(name):
773    # ignoring the following utility ops
774    filtered_out_names = [
775        MEMORY_EVENT_NAME,  # used only for the top-level memory events
776        OUT_OF_MEMORY_EVENT_NAME,
777        "profiler::_record_function_enter",
778        "profiler::_record_function_enter_new",
779        "profiler::_record_function_exit",
780        "aten::is_leaf",
781        "aten::output_nr",
782        "aten::_version",
783    ]
784    return name in filtered_out_names
785
786
787# Demangles and optionally rewrites the provided event name,
788# with_wildcard - whether to replace certain numbered event names
789# with a wildcard name to aggregate them together in the profiler table
790# output
def _rewrite_name(name, with_wildcard=False):
    """Demangle ``name``; optionally collapse numbered ProfilerStep events.

    With ``with_wildcard=True``, names like ``ProfilerStep#12`` are replaced
    by the single wildcard name ``ProfilerStep*`` so the table aggregates
    all steps into one row.
    """
    demangled = StringTable()[name]
    if with_wildcard and demangled.startswith("ProfilerStep#"):
        return "ProfilerStep*"
    return demangled
798
799
def _build_table(
    events,
    sort_by=None,
    header=None,
    row_limit=100,
    max_src_column_width=75,
    max_name_column_width=55,
    max_shapes_column_width=80,
    with_flops=False,
    profile_memory=False,
    top_level_events_only=False,
):
    """Print a summary of events (which can be a list of FunctionEvent or FunctionEventAvg).

    Args:
        events: events to render; ``events[0].use_device`` decides which device
            label (e.g. CUDA/XPU/PrivateUse1) appears in the device columns.
        sort_by: optional attribute name to sort rows by, descending; the
            substrings "cuda"/"xpu"/"privateuse1" are normalized to "device".
        header: optional title line printed above the table.
        row_limit: maximum number of event rows to print.
        max_src_column_width: cap on the "Source Location" column width
            (``None`` disables the cap).
        max_name_column_width: cap on the "Name" column width (``None``
            disables the cap; longer names are truncated with "...").
        max_shapes_column_width: cap on the "Input Shapes" column width.
        with_flops: append an auto-scaled FLOPs column; silently disabled if
            no event reports flops > 0.
        profile_memory: append CPU (and, if present, device) memory columns.
        top_level_events_only: only print events with no ``cpu_parent``.

    Returns:
        The formatted table as a single string, or "" when ``events`` is empty.

    Raises:
        RuntimeError: if device timing data exists but ``use_device`` is None.
    """
    if len(events) == 0:
        return ""

    # Scan once up front to decide which optional column groups are needed.
    has_device_time = any(event.self_device_time_total > 0 for event in events)
    has_device_mem = any(event.self_device_memory_usage > 0 for event in events)
    use_device = events[0].use_device
    # Running on PrivateUse1 device with profiler but not enable
    # ProfilerActivity.PrivateUse1 can also catch privateuse1 memory usage.
    # Here only need to check has_privateuse1_time if not use_device.
    if not use_device and has_device_time:
        raise RuntimeError("use_device is None, but there is device performance data.")

    has_input_shapes = any(
        (event.input_shapes is not None and len(event.input_shapes) > 0)
        for event in events
    )

    if sort_by is not None:
        # Sort descending by the requested attribute; legacy device-specific
        # sort keys ("cuda"/"xpu"/"privateuse1") map onto the generic
        # "device" attributes.
        events = EventList(
            sorted(
                events,
                key=lambda evt: getattr(
                    evt,
                    sort_by.replace("cuda", "device")
                    .replace("xpu", "device")
                    .replace("privateuse1", "device"),
                ),
                reverse=True,
            ),
            use_device=use_device,
            profile_memory=profile_memory,
            with_flops=with_flops,
        )

    # Column widths: sized to the widest cell plus 4 padding chars, then
    # clamped by the caller-supplied caps (a cap of None means unbounded).
    name_column_width = max(len(evt.key) for evt in events) + 4
    if max_name_column_width is not None:
        name_column_width = min(name_column_width, max_name_column_width)

    shapes_column_width = max(len(str(evt.input_shapes)) for evt in events) + 4
    if max_shapes_column_width is not None:
        shapes_column_width = min(shapes_column_width, max_shapes_column_width)

    DEFAULT_COLUMN_WIDTH = 12
    flops_column_width = DEFAULT_COLUMN_WIDTH

    src_column_width = None
    stacks = []
    for evt in events:
        if evt.stack is not None and len(evt.stack) > 0:
            stacks.append(evt.stack)
    has_stack = len(stacks) > 0
    if has_stack:
        # Widest single stack entry across all events determines the column.
        src_column_width = (
            max(max(len(entry) for entry in stack) for stack in stacks) + 4
        )
        if max_src_column_width is not None:
            src_column_width = min(src_column_width, max_src_column_width)

    headers = [
        "Name",
        "Self CPU %",
        "Self CPU",
        "CPU total %",
        "CPU total",
        "CPU time avg",
    ]
    device_name = use_device.upper() if use_device is not None else "None"
    if has_device_time:
        headers.extend(
            [
                f"Self {device_name}",
                f"Self {device_name} %",
                f"{device_name} total",
                f"{device_name} time avg",
            ]
        )
    if profile_memory:
        headers.extend(
            [
                "CPU Mem",
                "Self CPU Mem",
            ]
        )
        if use_device and has_device_mem:
            headers.extend(
                [
                    f"{device_name} Mem",
                    f"Self {device_name} Mem",
                ]
            )
    headers.append("# of Calls")
    # Only append Node ID if any event has a valid (>= 0) Node ID
    append_node_id = any(evt.node_id != -1 for evt in events)
    if append_node_id:
        headers.append("Node ID")

    # Have to use a list because nonlocal is Py3 only...
    # (single-element lists act as mutable cells for add_column below)
    SPACING_SIZE = 2
    row_format_lst = [""]
    header_sep_lst = [""]
    line_length_lst = [-SPACING_SIZE]

    def add_column(padding, text_dir=">"):
        # Grow the row format string, the "-----" separator line, and the
        # running total line length by one column of the given width.
        # text_dir ">" right-aligns (numbers), "<" left-aligns (source paths).
        row_format_lst[0] += (
            "{: " + text_dir + str(padding) + "}" + (" " * SPACING_SIZE)
        )
        header_sep_lst[0] += "-" * padding + (" " * SPACING_SIZE)
        line_length_lst[0] += padding + SPACING_SIZE

    def auto_scale_flops(flops):
        # Pick a unit (FLOPs..PFLOPs) from the magnitude of the smallest
        # flops value, returning (scale factor, unit header string).
        flop_headers = [
            "FLOPs",
            "KFLOPs",
            "MFLOPs",
            "GFLOPs",
            "TFLOPs",
            "PFLOPs",
        ]
        assert flops > 0
        log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
        assert log_flops >= 0 and log_flops < len(flop_headers)
        return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)])

    add_column(name_column_width)
    for _ in headers[1:]:
        add_column(DEFAULT_COLUMN_WIDTH)

    if has_input_shapes:
        headers.append("Input Shapes")
        add_column(shapes_column_width)

    if has_stack:
        headers.append("Source Location")
        add_column(src_column_width, text_dir="<")

    if with_flops:
        # Auto-scaling of flops header
        raw_flops = []
        for evt in events:
            if evt.flops > 0:
                raw_flops.append(evt.flops)
        if len(raw_flops) != 0:
            (flops_scale, flops_header) = auto_scale_flops(min(raw_flops))
            headers.append(f"Total {flops_header}")
            add_column(flops_column_width)
        else:
            with_flops = False  # can't find any valid flops

    # Freeze the accumulated format/separator strings; add_column is nulled
    # out so no further columns can be added past this point.
    row_format = row_format_lst[0]
    header_sep = header_sep_lst[0]
    line_length = line_length_lst[0]
    add_column = None  # type: ignore[assignment]

    # Have to use a list because nonlocal is Py3 only...
    result = []

    def append(s):
        # Accumulate output lines; joined once at the end.
        result.append(s)
        result.append("\n")  # Yes, newline after the end as well

    # Totals used as the denominators for the percentage columns.
    sum_self_cpu_time_total = 0
    sum_self_device_time_total = 0
    for evt in events:
        sum_self_cpu_time_total += evt.self_cpu_time_total
        if evt.device_type == DeviceType.CPU and evt.is_legacy:
            # in legacy profiler, kernel info is stored in cpu events
            sum_self_device_time_total += evt.self_device_time_total
        elif (
            evt.device_type
            in [
                DeviceType.CUDA,
                DeviceType.PrivateUse1,
                DeviceType.MTIA,
            ]
            and not evt.is_user_annotation
        ):
            # in kineto profiler, there're events with the correct device type (e.g. CUDA)
            sum_self_device_time_total += evt.self_device_time_total

    # Actual printing
    if header is not None:
        append("=" * line_length)
        append(header)
    if top_level_events_only:
        append("=" * line_length)
        append("This report only display top-level ops statistics")
    append(header_sep)
    append(row_format.format(*headers))

    append(header_sep)

    def trim_path(path, src_column_width):
        # Keep the (more informative) tail of an over-long source path and
        # mark the truncation with a leading "...".
        if len(path) > src_column_width:
            offset = len(path) - src_column_width
            path = path[offset:]
            if len(path) > 3:
                path = "..." + path[3:]
        return path

    event_limit = 0
    for evt in events:
        if event_limit == row_limit:
            break
        # Skipped child events do not count against row_limit.
        if top_level_events_only and evt.cpu_parent is not None:
            continue
        else:
            event_limit += 1
        name = evt.key
        if max_name_column_width is not None and len(name) >= max_name_column_width - 3:
            name = name[: (max_name_column_width - 3)] + "..."
        # NOTE: percentage strings are stored back onto the event objects,
        # so rendering the table mutates the events (visible to callers).
        evt.self_cpu_percent = _format_time_share(
            evt.self_cpu_time_total, sum_self_cpu_time_total
        )
        evt.total_cpu_percent = (
            _format_time_share(evt.cpu_time_total, sum_self_cpu_time_total)
            if not evt.is_async
            else 0
        )
        # row_values must stay in lockstep with the headers built above.
        row_values = [
            name,
            # Self CPU total %, 0 for async events.
            evt.self_cpu_percent,
            evt.self_cpu_time_total_str,  # Self CPU total
            # CPU total %, 0 for async events.
            evt.total_cpu_percent,
            evt.cpu_time_total_str,  # CPU total
            evt.cpu_time_str,  # CPU time avg
        ]
        if has_device_time:
            evt.total_device_percent = _format_time_share(
                evt.self_device_time_total, sum_self_device_time_total
            )
            row_values.extend(
                [
                    evt.self_device_time_total_str,
                    # device time total %
                    evt.total_device_percent,
                    evt.device_time_total_str,
                    evt.device_time_str,  # device time avg
                ]
            )
        if profile_memory:
            row_values.extend(
                [
                    # CPU Mem Total
                    _format_memory(evt.cpu_memory_usage),
                    # Self CPU Mem Total
                    _format_memory(evt.self_cpu_memory_usage),
                ]
            )
            if use_device and has_device_mem:
                row_values.extend(
                    [
                        # Device Mem Total
                        _format_memory(evt.device_memory_usage),
                        # Self Device Mem Total
                        _format_memory(evt.self_device_memory_usage),
                    ]
                )
        row_values.append(
            evt.count,  # Number of calls
        )

        if append_node_id:
            row_values.append(evt.node_id)
        if has_input_shapes:
            row_values.append(str(evt.input_shapes)[:shapes_column_width])
        if with_flops:
            if evt.flops <= 0:
                row_values.append("--")
            else:
                # flops_scale is guaranteed set here: with_flops is forced to
                # False above whenever no event had flops > 0.
                row_values.append(f"{evt.flops * flops_scale:8.3f}")  # type: ignore[possibly-undefined]
        if has_stack:
            src_field = ""
            if len(evt.stack) > 0:
                src_field = trim_path(evt.stack[0], src_column_width)
            row_values.append(src_field)
        append(row_format.format(*row_values))

        if has_stack:
            # Remaining stack frames go on their own rows: all cells blank
            # except the last (Source Location) column, plus one empty row
            # as a visual separator after the stack.
            empty_headers = [""] * (len(headers) - 1)
            for entry in evt.stack[1:]:
                append(
                    row_format.format(
                        *(empty_headers + [trim_path(entry, src_column_width)])
                    )
                )
            empty_headers.append("")
            append(row_format.format(*empty_headers))

    # Footer: separator plus self-time totals.
    append(header_sep)
    append(f"Self CPU time total: {_format_time(sum_self_cpu_time_total)}")
    if has_device_time:
        append(
            f"Self {use_device.upper() if use_device is not None else 'None'} "
            f"time total: {_format_time(sum_self_device_time_total)}"
        )
    return "".join(result)
1111