# mypy: allow-untyped-defs
import json
import math
import os
import re
from typing import Dict, List, Optional, Set

import torch
import torch.utils.benchmark as benchmark
from torch._C._profiler import (
    _EventType,
    _ExtraFields_PyCall,
    _ExtraFields_PyCCall,
    _ExtraFields_TorchOp,
    _ProfilerEvent,
)
from torch.profiler import profile
from torch.profiler._utils import index_of_first_match, traverse_bfs, traverse_dfs


class Pattern:
    """
    Base class for all patterns. Subclass this class and implement match()
    to define custom patterns.

    In the subclass, define the description and the skip property.
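
    Example (a minimal sketch, assuming prof is a completed torch.profiler.profile;
    the op name matched here is only illustrative):

        class SlowNonzeroPattern(Pattern):
            def __init__(self, prof, should_benchmark=False):
                super().__init__(prof, should_benchmark)
                self.name = "Slow Nonzero Pattern"
                self.description = "Detected a call to aten::nonzero."

            def match(self, event):
                return event.name == "aten::nonzero"

        pattern = SlowNonzeroPattern(prof)
        for event in pattern.matched_events():
            print(pattern.report(event))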
    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        self.prof = prof
        self.should_benchmark = should_benchmark
        self.name = "Please specify a name for pattern"
        self.description = "Please specify a description for pattern"
        self.url = ""
        assert prof.profiler is not None and prof.profiler.kineto_results is not None
        self.event_tree = prof.profiler.kineto_results.experimental_event_tree()
        self.tid_root: Dict[int, List[_ProfilerEvent]] = {}
        for event in self.event_tree:
            self.tid_root.setdefault(event.start_tid, []).append(event)

    @property
    def skip(self):
        return False

    def report(self, event: _ProfilerEvent):
        msg = (
            f"{self.description}\n[Source Code Location] {source_code_location(event)}"
        )
        return msg

    def eventTreeTraversal(self):
        """
        Traverse the event tree and yield all events.
        Override this method in a subclass to customize the traversal.
        """
        yield from traverse_dfs(self.event_tree)

    def summary(self, events: List[_ProfilerEvent]):
        default_summary = f"{self.name}: {len(events)} events matched."
        if self.should_benchmark:
            # Use the benchmark summary if the subclass implements benchmark().
            return (
                self.benchmark_summary(events)
                if hasattr(self, "benchmark")  # type: ignore[attr-defined]
                else default_summary
            )
        return default_summary

    def benchmark_summary(self, events: List[_ProfilerEvent]):
        def format_time(time_ns: float):
            unit_lst = ["ns", "us", "ms"]
            for unit in unit_lst:
                if time_ns < 1000:
                    return f"{time_ns:.2f} {unit}"
                time_ns /= 1000
            return f"{time_ns:.2f} s"

        assert hasattr(self, "benchmark"), "Please implement benchmark()"
        shapes_factor_map = self.benchmark(events)  # type: ignore[attr-defined]
        original_time = sum(event.duration_time_ns for event in events)
        new_time = sum(
            shapes_factor_map[input_shapes(event)] * event.duration_time_ns
            for event in events
        )
        return (
            f"{self.name}: {len(events)} events matched. "
            f"Total Estimated Speedup: {format_time(original_time - new_time)} ({round(original_time/new_time, 2)}X)"
        )

    def match(self, event: _ProfilerEvent):
        """
        Return True if the event matches the pattern.
        This method should be overridden in a subclass.
        """
        raise NotImplementedError

    def matched_events(self):
        if self.skip:
            return []
        matched_events = []
        for event in self.eventTreeTraversal():
            if self.match(event):
                matched_events.append(event)
        return matched_events

    def root_of(self, event: _ProfilerEvent):
        while event.parent:
            event = event.parent
        return event

    def siblings_of(self, event: _ProfilerEvent):
        if event.parent:
            children = event.parent.children
        else:
            children = self.tid_root[event.start_tid]
        index = children.index(event)
        return children[:index], children[index + 1 :]

    def next_of(self, event: _ProfilerEvent):
        _, next_events = self.siblings_of(event)
        return next_events[0] if next_events else None

    def prev_of(self, event: _ProfilerEvent):
        prev_events, _ = self.siblings_of(event)
        return prev_events[-1] if prev_events else None

    def go_up_until(self, event: _ProfilerEvent, predicate):
        if not event:
            return None
        while event.parent and not predicate(event):
            event = event.parent
        return event


# Patterns


class NamePattern(Pattern):
    def __init__(self, prof: profile, name: str, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.description = f"Matched Name Event: {name}"
        self.name = name

    def match(self, event: _ProfilerEvent):
        return re.search(self.name, event.name) is not None


class ExtraCUDACopyPattern(Pattern):
    """
    This pattern identifies if we create a constant tensor on CPU and immediately move it to GPU.
    example: torch.zeros((100, 100)).to("cuda")
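    The suggested rewrite creates the tensor directly on the target device:
    torch.zeros((100, 100), device="cuda")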

    Pattern:
    built-in method                 |built-in method
        ...                         |    aten::to
            aten::fill_/aten::zero_ |        aten::_to_copy

    Algorithm:
    We start at the aten::to node, go to the parent event's previous event,
    and check if we find an aten::fill_/aten::zero_ as we keep going down the tree.
    We always select the last child in the children list when we go down the tree.
    If any step fails, it is not a match.
    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Extra CUDA Copy Pattern"
        self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU."
        self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device"
        self.init_ops = {
            "aten::fill_",
            "aten::zero_",
            "aten::normal_",
            "aten::uniform_",
        }

    @property
    def skip(self):
        return not self.prof.with_stack or not self.prof.record_shapes

    def match(self, event):
        # TODO: We should also check tensor identities
        if event.name != "aten::to":
            return False
        to_event = event
        if not event.children:
            return False
        event = event.children[-1]
        if event.name != "aten::_to_copy":
            return False
        if not event.children:
            return False
        event = event.children[-1]
        if event.name != "aten::copy_":
            return False
        # aten::copy_ should have the same dtype for its first two arguments
        dtypes = input_dtypes(event)
        if len(dtypes) < 2:
            return False
        if dtypes[0] is None or dtypes[0] != dtypes[1]:
            return False
        event = to_event
        # Up one level
        event = event.parent
        if event is None:
            return False
        # Check if we have an aten::fill_ in the previous leaf
        event = self.prev_of(event)
        if event is None:
            return False
        while event.children:
            event = event.children[-1]
            # aten::zero_ is a special optimization case where fill_ is not called
            if event.name in self.init_ops:
                return True
        return event.name in self.init_ops
        # TODO: Check if tensor is reused

    def benchmark(self, events: List[_ProfilerEvent]):
        shapes_factor_map = {input_shapes(event): 0.0 for event in events}
        for shape in shapes_factor_map:
            size = shape[0]
            to_timer = benchmark.Timer(
                stmt='torch.ones(size).to("cuda")', globals={"size": size}
            )
            de_timer = benchmark.Timer(
                stmt='torch.ones(size, device="cuda")', globals={"size": size}
            )
            to_time = to_timer.timeit(10).mean
            de_time = de_timer.timeit(10).mean
            shapes_factor_map[shape] = de_time / to_time
        return shapes_factor_map


class ForLoopIndexingPattern(Pattern):
    """
    This pattern identifies if we use a for loop to index a tensor that
    can be vectorized.
    example:
    tensor = torch.empty((100, 100))
    for i in range(100):
        tensor[i] = i
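
    One possible vectorized rewrite (an illustrative sketch that relies on broadcasting):
    tensor[:] = torch.arange(100, dtype=tensor.dtype).unsqueeze(1)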

    Pattern:
    aten::select | ... | aten::select | ... (Repeat)

    Algorithm:
    We start at the aten::select node and check if we can find this alternating pattern.
    We also keep a set of visited event ids to avoid duplicate matches in the for loop.
    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "For Loop Indexing Pattern"
        self.description = "For loop indexing detected. Vectorization recommended."
        self.visited: Set[int] = set()

    def eventTreeTraversal(self):
        """
        We need to use BFS traversal order to avoid duplicate matches.
        """
        yield from traverse_bfs(self.event_tree)

    def match(self, event: _ProfilerEvent):
        if event.name != "aten::select":
            return False
        if event.id in self.visited:
            return False
        repeat_count = 1
        _, next = self.siblings_of(event)
        if len(next) <= 1:
            return False

        # Custom event list matching
        def same_ops(list1, list2):
            if len(list1) != len(list2):
                return False
            for op1, op2 in zip(list1, list2):
                if op1.name != op2.name:
                    return False
            return True

        # Record the ops between two aten::select
        next_select_idx = index_of_first_match(next, lambda e: e.name == "aten::select")
        if next_select_idx is None:
            return False
        indexing_ops = [event] + next[:next_select_idx]
        next = next[len(indexing_ops) - 1 :]
        for i in range(0, len(next), len(indexing_ops)):
            if same_ops(indexing_ops, next[i : i + len(indexing_ops)]):
                repeat_count += 1
                self.visited.add(next[i].id)
            else:
                break
        return repeat_count >= 10


class FP32MatMulPattern(Pattern):
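    """
    This pattern identifies float32 aten::mm calls that were recorded with TF32
    disabled on a GPU that supports TF32 (sm_80 or newer).
    example (illustrative):
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.mm(matrixA, matrixB)  # matrixA and matrixB are float32 CUDA tensors

    Suggested fix (see the description below):
    torch.backends.cuda.matmul.allow_tf32 = True

    Pattern:
    aten::mm whose recorded allow_tf32_cublas flag is False

    Algorithm:
    String match on the op name plus a check of the recorded allow_tf32_cublas flag
    """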
    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "FP32 MatMul Pattern"
        self.description = (
            "You are currently using a GPU that supports TF32. "
            "Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'"
        )
        self.url = "https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"

    @property
    def skip(self):
        if torch.version.hip is not None:
            has_tf32 = False
        else:
            # Anything below sm_80 is older than Ampere and does not support TF32
            has_tf32 = all(int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list())
        return has_tf32 is False or super().skip or not self.prof.record_shapes

    def match(self, event: _ProfilerEvent):
        # If we saw this pattern once, we don't need to match it again
        if event.tag != _EventType.TorchOp:
            return False
        assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
        if event.name == "aten::mm":
            if event.extra_fields.allow_tf32_cublas is False:
                return True
        return False

    def report(self, event: _ProfilerEvent):
        return self.description

    def benchmark(self, events: List[_ProfilerEvent]):
        shapes_factor_map = {input_shapes(event): 0.0 for event in events}
        for shape in shapes_factor_map:
            matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32)
            matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32)
            fp32_timer = benchmark.Timer(
                stmt="torch.mm(matrixA, matrixB)",
                globals={"matrixA": matrixA, "matrixB": matrixB},
            )
            tf32_timer = benchmark.Timer(
                stmt="torch.mm(matrixA, matrixB)",
                setup="torch.backends.cuda.matmul.allow_tf32 = True",
                globals={"matrixA": matrixA, "matrixB": matrixB},
            )
            torch.backends.cuda.matmul.allow_tf32 = False
            fp32_time = fp32_timer.timeit(10).mean
            tf32_time = tf32_timer.timeit(10).mean
            shapes_factor_map[shape] = tf32_time / fp32_time
        return shapes_factor_map


class OptimizerSingleTensorPattern(Pattern):
    """
    This pattern identifies if we are using the single-tensor version of an optimizer.
    example:
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    By adding foreach=True to enable the multi-tensor optimizer, we can gain a speedup when
    the kernels are relatively small.
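    For example (illustrative):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, foreach=True)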

    Pattern:
    XXXXX: _single_tensor_<OPTIMIZER_NAME>

    Algorithm:
    String match
    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Optimizer Single Tensor Pattern"
        self.optimizers_with_foreach = ["adam", "sgd", "adamw"]
        self.description = (
            "Detected optimizer running with single tensor implementation. "
            "Please enable the multi-tensor implementation by passing 'foreach=True' to the optimizer."
        )
        self.url = ""

    def match(self, event: _ProfilerEvent):
        for optimizer in self.optimizers_with_foreach:
            if event.name.endswith(f"_single_tensor_{optimizer}"):
                return True
        return False


class SynchronizedDataLoaderPattern(Pattern):
    """
    This pattern identifies if we are using num_workers=0 in DataLoader.
    example:
    torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    Add num_workers=N to the arguments. N depends on system configuration.
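    For example (4 is only an illustrative value):
    torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=4)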

    Pattern:
    dataloader.py(...): __iter__
        dataloader.py(...): _get_iterator
            NOT dataloader.py(...): check_worker_number_rationality

    Algorithm:
    If we don't see a check_worker_number_rationality call in the dataloader's __iter__,
    it is not an asynchronous dataloader.


    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Synchronized DataLoader Pattern"
        self.description = (
            "Detected DataLoader running with synchronized implementation. "
            "Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader."
        )
        self.url = (
            "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
            "#enable-async-data-loading-and-augmentation"
        )

    def match(self, event: _ProfilerEvent):
        def is_dataloader_function(name: str, function_name: str):
            return name.startswith(
                os.path.join("torch", "utils", "data", "dataloader.py")
            ) and name.endswith(function_name)

        # TODO: fixme! Due to lifetime issues of the function name, this field might
        # actually point to an already freed string when the event is a PyCall.
        # Just silently skip this to unblock testing.
        try:
            event.name
        except UnicodeDecodeError:
            return False

        if not is_dataloader_function(event.name, "__iter__"):
            return False
        if not event.children:
            return False
        event = event.children[0]
        if not is_dataloader_function(event.name, "_get_iterator"):
            return False
        if not event.children:
            return False
        event = event.children[0]
        return not is_dataloader_function(event.name, "check_worker_number_rationality")
        # TODO: We should also check if the loader is a bottleneck.


class GradNotSetToNonePattern(Pattern):
    """
    This pattern identifies if we are not setting grad to None in zero_grad.
    example:
    optimizer.zero_grad()
    By setting set_to_none=True, we can gain a speedup.

    Pattern:
    XXXXX: _zero_grad
        NOT aten::zeros
            aten::zero_

    aten::zero_ is called on each parameter in the model.
    We also want to make sure it is not called by aten::zeros.

    Algorithm:
    String match
    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Gradient Set To Zero Instead of None Pattern"
        self.description = (
            "Detected gradient set to zero instead of None. "
            "Please add 'set_to_none=True' when calling zero_grad()."
        )
        self.url = (
            "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
            "#disable-gradient-calculation-for-validation-or-inference"
        )

    def match(self, event: _ProfilerEvent):
        if not event.name.endswith(": zero_grad"):
            return False
        if not event.children:
            return False

        for sub_event in traverse_dfs(event.children):
            if (
                sub_event.name == "aten::zero_"
                and sub_event.parent.name != "aten::zeros"
            ):
                return True
        # TODO: We should also check if the optimizer's numerical behavior will change.
        return False


class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern):
    """
    This pattern identifies if we are enabling bias in a Conv2d that is followed by BatchNorm2d.
    Bias doesn't do anything when followed by batchnorm.
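    example (illustrative):
    torch.nn.Sequential(torch.nn.Conv2d(3, 32, kernel_size=3, bias=True), torch.nn.BatchNorm2d(32))
    The suggested rewrite passes bias=False to the Conv2d.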
    Pattern:
    nn.Module: Conv2d            | nn.Module: BatchNorm2d
        ...
            aten::conv2d AND the dtype of the third argument is not None
    The third argument is the bias.
    Algorithm:
    String match
    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern"
        self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d."
        self.url = (
            "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
            "#disable-bias-for-convolutions-directly-followed-by-a-batch-norm"
        )

    @property
    def skip(self):
        return self.prof.record_shapes is False or super().skip

    def match(self, event: _ProfilerEvent):
        if event.name != "aten::conv2d":
            return False
        if len(input_dtypes(event)) < 3 or input_dtypes(event)[2] is None:
            return False
        # A non-None third argument means bias=True
        event = self.go_up_until(
            event, lambda e: e.name.startswith("nn.Module: Conv2d")
        )
        if not event:
            return False
        event = self.next_of(event)
        if not event:
            return False
        return event.name.startswith("nn.Module: BatchNorm2d")


class MatMulDimInFP16Pattern(Pattern):
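    """
    This pattern identifies float16/bfloat16 matmuls (aten::mm, aten::bmm, aten::addmm)
    whose last two input dimensions are not all multiples of 8, which can prevent
    efficient use of Tensor Cores.
    example (illustrative):
    matrixA = torch.randn((1009, 1009), device="cuda", dtype=torch.float16)
    matrixB = torch.randn((1009, 1009), device="cuda", dtype=torch.float16)
    torch.mm(matrixA, matrixB)  # 1009 is not a multiple of 8

    Pattern:
    aten::mm|aten::bmm|aten::addmm with a half/bfloat16 input whose last two
    dimensions are not multiples of 8

    Algorithm:
    String match on the op name plus a check of the recorded input shapes and dtypes
    """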
    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Matrix Multiplication Dimension Not Aligned Pattern"
        self.description = "Detected matmul with dimensions not aligned. Please use matmul with aligned dimensions."
        self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp"

    @property
    def skip(self):
        return not self.prof.with_stack or not self.prof.record_shapes

    def match(self, event: _ProfilerEvent):
        def multiple_of(shapes, multiple):
            return all(dim % multiple == 0 for shape in shapes for dim in shape[-2:])

        if event.name not in ("aten::mm", "aten::bmm", "aten::addmm"):
            return False
        if not input_dtypes(event):
            return False
        arg_dtype = input_dtypes(event)[0]
        if arg_dtype in (torch.bfloat16, torch.half) and not multiple_of(
            input_shapes(event), 8
        ):
            return True
        return False

    def benchmark(self, events: List[_ProfilerEvent]):
        def closest_multiple(shapes, multiple):
            return [multiple * math.ceil(shape / multiple) for shape in shapes]

        shapes_factor_map = {input_shapes(event): 0.0 for event in events}
        for shape in shapes_factor_map:
            matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float16)
            matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float16)
            not_aligned_dim_timer = benchmark.Timer(
                stmt="torch.mm(matrixA, matrixB)",
                globals={"matrixA": matrixA, "matrixB": matrixB},
            )
            matrixA = torch.randn(
                closest_multiple(shape[0], 8), device="cuda", dtype=torch.float16
            )
            matrixB = torch.randn(
                closest_multiple(shape[1], 8), device="cuda", dtype=torch.float16
            )
            aligned_dim_timer = benchmark.Timer(
                stmt="torch.mm(matrixA, matrixB)",
                globals={"matrixA": matrixA, "matrixB": matrixB},
            )
            not_aligned_dim_time = not_aligned_dim_timer.timeit(10).mean
            aligned_dim_time = aligned_dim_timer.timeit(10).mean
            shapes_factor_map[shape] = aligned_dim_time / not_aligned_dim_time
        return shapes_factor_map


def source_code_location(event: Optional[_ProfilerEvent]):
    while event:
        if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall:
            assert isinstance(
                event.extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall)
            )
            if not event.extra_fields.caller.file_name.startswith("torch" + os.sep):
                return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"
        event = event.parent
    return "No source code location found"


def input_shapes(event: _ProfilerEvent):
    assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
    return tuple(tuple(getattr(i, "sizes", ())) for i in event.extra_fields.inputs)


def input_dtypes(event: _ProfilerEvent):
    assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
    return tuple(getattr(i, "dtype", None) for i in event.extra_fields.inputs)


def report_all_anti_patterns(
    prof,
    should_benchmark: bool = False,
    print_enable: bool = True,
    json_report_dir: Optional[str] = None,
):
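    """
    Run all of the patterns defined above over a finished profile and print a
    TorchTidy report, optionally merging it into torchtidy_report.json under
    json_report_dir.

    Usage (a minimal sketch; model and inputs are placeholders):

        with torch.profiler.profile(with_stack=True, record_shapes=True) as prof:
            model(inputs)
        report_all_anti_patterns(prof)
    """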
    report_dict: Dict = {}
    anti_patterns = [
        ExtraCUDACopyPattern(prof, should_benchmark),
        # ForLoopIndexingPattern(prof, should_benchmark),
        FP32MatMulPattern(prof, should_benchmark),
        OptimizerSingleTensorPattern(prof, should_benchmark),
        SynchronizedDataLoaderPattern(prof, should_benchmark),
        GradNotSetToNonePattern(prof, should_benchmark),
        Conv2dBiasFollowedByBatchNorm2dPattern(prof, should_benchmark),
        MatMulDimInFP16Pattern(prof, should_benchmark),
    ]
    reported = set()
    summaries = []
    message_list = [f"{'-'*40}TorchTidy Report{'-'*40}"]
    message_list.append("Matched Events:")

    for anti_pattern in anti_patterns:
        matched_events = anti_pattern.matched_events()
        if not matched_events:
            continue
        summaries.append(anti_pattern.summary(matched_events))
        for event in matched_events:
            report_msg = anti_pattern.report(event)
            if report_msg not in reported:
                message_list.append(report_msg)
                reported.add(report_msg)
                src_location, line_no = source_code_location(event).split(":")
                report_dict.setdefault(src_location, []).append(
                    {
                        "line_number": int(line_no),
                        "name": anti_pattern.name,
                        "url": anti_pattern.url,
                        "message": anti_pattern.description,
                    }
                )

    if json_report_dir is not None:
        json_report_path = os.path.join(json_report_dir, "torchtidy_report.json")
        if os.path.exists(json_report_path):
            with open(json_report_path) as f:
                existing_report = json.load(f)
                existing_report.update(report_dict)
                report_dict = existing_report
        with open(json_report_path, "w") as f:
            json.dump(report_dict, f, indent=4)

    message_list.append("Summary:")
    message_list += summaries
    message_list.append(f"{'-'*40}TorchTidy Report{'-'*40}")
    if print_enable:
        print("\n".join(message_list))
664