xref: /aosp_15_r20/external/pytorch/torch/_inductor/fx_passes/misc_patterns.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import functools
from typing import Dict, Set, Tuple

import torch
from torch._dynamo.utils import counters
from torch._ops import OpOverload, OpOverloadPacket

from ..pattern_matcher import fwd_only, register_replacement


aten = torch.ops.aten


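# NB: lru_cache(None) on a zero-arg function makes this a lazy, run-once
# initializer: pattern registration happens the first time
# _misc_patterns_init() is called and is a no-op on every later call.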
@functools.lru_cache(None)
def _misc_patterns_init():
    from .joint_graph import patterns as joint_graph_patterns
    from .post_grad import pass_patterns as post_grad_patterns_all

    post_grad_patterns = post_grad_patterns_all[1]  # medium priority

    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # These patterns do two things:
    # 1. Since the index (a slice of a randperm) is known to contain no
    #    duplicates, we can codegen it using plain stores instead of atomic
    #    adds, which is quite a bit faster.
    # 2. Since the indices are also guaranteed to be within bounds, we can
    #    use unsafe indexing and skip the debug asserts.
    def randperm_index_add_pattern(x, y):
        index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
        return torch.index_add(x, dim=0, source=y, index=index), index

    def randperm_index_add_replacement(x, y):
        index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
        return (
            torch.ops.aten._unsafe_index_put(
                x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False
            ),
            index,
        )

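    # register_replacement traces both functions on the example tensors below
    # (fwd_only: forward-only tracing, no joint fwd/bwd graph) and installs the
    # rewrite in both the post-grad and joint-graph pattern banks.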
    register_replacement(
        randperm_index_add_pattern,
        randperm_index_add_replacement,
        [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)],
        fwd_only,
        [post_grad_patterns, joint_graph_patterns],
    )

    def randperm_index_pattern(x, slice_shape):
        index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
        return torch.ops.aten.index(x, (index,)), index

    def randperm_index_replacement(x, slice_shape):
        index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
        return torch.ops.aten._unsafe_index(x, (index,)), index

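    # slice_shape is a non-Tensor (scalar) argument, so there is no example
    # tensor for it; scalar_workaround supplies an arbitrary concrete value
    # (42) so the pattern and replacement can be traced for registration.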
    register_replacement(
        randperm_index_pattern,
        randperm_index_replacement,
        [torch.empty(4, 8, device=device)],
        fwd_only,
        [post_grad_patterns, joint_graph_patterns],
        scalar_workaround={"slice_shape": 42},
    )


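# Normalizes numpy-style kwargs on torch-level call_function nodes in an FX
# graph to their torch names, e.g. torch.sum(t, axis=0, keepdims=True) is
# rewritten to torch.sum(t, dim=0, keepdim=True).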
class NumpyCompatNormalization:
    numpy_compat: Dict[str, Tuple[str, ...]] = {
        "dim": ("axis",),
        "keepdim": ("keepdims",),
        "input": ("x", "a", "x1"),
        "other": ("x2",),
    }
    inverse_mapping: Dict[str, str]
    cache: Dict["torch.fx.graph.Target", Set[str]]

    def __init__(self) -> None:
        self.cache = {}  # callable -> set of replaceable kwargs, e.g. {"axis"}
        self.inverse_mapping = {}
        for actual_kwarg, numpy_kwargs in self.numpy_compat.items():
            for numpy_kwarg in numpy_kwargs:
                assert numpy_kwarg not in self.inverse_mapping
                self.inverse_mapping[numpy_kwarg] = actual_kwarg
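        # After this loop, inverse_mapping maps each numpy alias back to the
        # torch kwarg, i.e. {"axis": "dim", "keepdims": "keepdim",
        # "x": "input", "a": "input", "x1": "input", "x2": "other"}.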

    def __call__(self, graph: torch.fx.Graph):
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if isinstance(node.target, (OpOverload, OpOverloadPacket)):
                # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't.
                continue
            kwargs = node.kwargs

            if node.target in self.cache:
                replaceable_kwargs = self.cache[node.target]
            else:
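                # get_signature_for_torch_op returns the Python signatures for
                # the op's overloads, or None for targets it does not recognize;
                # in the None case no kwargs are treated as replaceable.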
                signatures = torch.fx.operator_schemas.get_signature_for_torch_op(
                    node.target
                )
                signatures = () if signatures is None else signatures
                replaceable_kwargs = set()
                for sig in signatures:
                    for param_name in sig.parameters.keys():
                        if param_name in self.numpy_compat:
                            replaceable_kwargs.update(self.numpy_compat[param_name])

                self.cache[node.target] = replaceable_kwargs

            if not replaceable_kwargs:
                continue

            new_kwargs = {}
            kwargs_changed = False
            for k, v in kwargs.items():
                if k in replaceable_kwargs:
                    kwargs_changed = True
                    new_kwargs[self.inverse_mapping[k]] = v
                else:
                    new_kwargs[k] = v

            if kwargs_changed:
                node.kwargs = torch.fx.immutable_collections.immutable_dict(new_kwargs)
                counters["inductor"]["numpy_compat_normalization"] += 1


numpy_compat_normalization = NumpyCompatNormalization()
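# Usage sketch (illustrative only; assumes a traced torch.fx.GraphModule `gm`):
#
#   numpy_compat_normalization(gm.graph)
#   gm.recompile()
#
# The pass mutates node kwargs in place and bumps
# counters["inductor"]["numpy_compat_normalization"] once per rewritten node.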