# mypy: allow-untyped-defs
import operator
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type

import sympy

import torch
import torch.fx
from torch.fx.experimental.symbolic_shapes import (
    compute_unbacked_bindings,
    rebind_unbacked,
    statically_known_true,
    sym_eq,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map

from .virtualized import V


# Check whether the pattern (nn.Module subclass, F.function / torch.Tensor method)
# matches at `node`. Only handles length-2 patterns with one module followed by
# one function/method.
def matches_module_function_pattern(
    pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
    node: torch.fx.node.Node,
    modules: Dict[str, torch.nn.modules.Module],
) -> bool:
    if len(node.args) == 0:
        return False
    if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
        node, torch.fx.Node
    ):
        return False
    # the first node is call_module
    if node.args[0].op != "call_module":
        return False
    if not isinstance(node.args[0].target, str):
        return False
    if node.args[0].target not in modules:
        return False
    if type(modules[node.args[0].target]) is not pattern[0]:
        return False
    # the second node is call_function or call_method
    if node.op != "call_function" and node.op != "call_method":
        return False
    if node.target != pattern[1]:
        return False
    # make sure node.args[0] output is only used by current node.
    if len(node.args[0].users) > 1:
        return False
    return True


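# Illustrative usage (a sketch, not part of the original module; `gm` and the
# conv/relu pattern are assumptions for the example): with a traced GraphModule,
# a (nn.Conv2d, torch.relu) pattern matches a conv -> relu pair whose conv
# output has no other users, which is the usual precondition for fusing the two.
#
#   modules = dict(gm.named_modules())
#   pattern = (torch.nn.Conv2d, torch.relu)
#   for node in gm.graph.nodes:
#       if matches_module_function_pattern(pattern, node, modules):
#           ...  # rewrite the matched pair here

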
class FakeTensorUpdater:
    """
    The main idea here is that it's difficult to maintain accurate fake
    tensors (our primary form of metadata) for each node in our graph as we
    transform it.

    The most reliable way to obtain this information is by rerunning
    faketensor propagation. However, in general, faketensor propagation is
    fairly expensive. So, instead we'd like to only rerun faketensor
    propagation on nodes that have changed.

    To detect which nodes have changed, we first hash each node together with
    its target and the identities of its argument lists (which are immutable
    in FX).

    Then, whenever we call incremental_update, we check which FX nodes have a
    new hash and recompute the faketensor metadata for those nodes. We then
    continue recursively recomputing the faketensors of their users until the
    fake tensors stop changing.
    """

    def __init__(self, graph: torch.fx.Graph) -> None:
        self.processed_hashes = set()
        self.graph = graph

        for node in self.graph.nodes:
            self.processed_hashes.add(self.hash_node(node))

    def hash_node(self, node: torch.fx.Node):
        # todo(chilli): Not a great hash function
        return (node, node.target, id(node.args), id(node.kwargs))

    def incremental_update(self):
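        # Re-run fake tensor propagation only for nodes whose hash has changed
        # since construction (or since the last call), then keep propagating to
        # their users until the recomputed fake tensors stop changing. Assumes
        # V.fake_mode (and, for unbacked symbols, its shape_env) is available.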
        processed = set()
        existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
        for node in self.graph.nodes:
            existing_storages[get_node_storage(node)] += 1

        def is_intlist_same(new, old):
            return statically_known_true(sym_eq(new, old))

        def is_fake_tensor_same(new, old):
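            # Conservative equality check between the freshly computed fake
            # value and the cached one: symbolic sizes/strides/offsets are
            # compared with statically_known_true, so anything that cannot be
            # proven equal counts as changed and triggers a metadata refresh.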
            if type(new) != type(old):
                return False
            if isinstance(new, (list, tuple)):
                if len(new) != len(old):
                    return False
                return all(
                    is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
                )
            if new is None:
                return old is None
            if not isinstance(new, torch.Tensor):
                assert isinstance(
                    new, (torch.SymInt, torch.SymBool, torch.SymFloat)
                ), f"Unknown type {type(new)} in {self.graph}"
                return (
                    new.node.shape_env._maybe_evaluate_static(
                        sympy.Eq(new.node.expr, old.node.expr)
                    )
                    == sympy.true
                )
            if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
                return False
            if new.layout == torch.strided and (
                not is_intlist_same(new.stride(), old.stride())
                or not statically_known_true(
                    new.storage_offset() == old.storage_offset()
                )
            ):
                return False

            if new.device != old.device:
                return False

            if get_storage(new) == get_storage(old):
                return True

            # This is the case where the op returned a completely fresh storage:
            # the old storage was only referenced by this node and the new
            # storage is not used anywhere else in the graph.
            if (
                existing_storages[get_storage(old)] == 1
                and get_storage(new) not in existing_storages
            ):
                return True
            return False

        def should_process_node(node):
            # node.target for nodes that return True from this function is
            # re-invoked under fake mode below, which does not work for
            # inductor lowerings. So we only accept call_function nodes whose
            # target is an aten operator (an OpOverload) or operator.getitem,
            # which is used when an op returns multiple tensors.
            return node.op == "call_function" and (
                isinstance(node.target, torch._ops.OpOverload)
                or node.target == operator.getitem
            )

        to_process = set()
        for node in self.graph.nodes:
            if (
                self.hash_node(node) in self.processed_hashes
                and id(node) not in to_process
            ):
                continue

            if not should_process_node(node):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue
            with V.fake_mode:
                new_fake_tensor = node.target(*args, **kwargs)
            if "val" in node.meta and is_fake_tensor_same(
                new_fake_tensor, node.meta["val"]
            ):
                continue

            rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor)

            node.meta["val"] = new_fake_tensor
            if (shape_env := V.fake_mode.shape_env) and (
                symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
            ):
                # Refresh the bindings to the new symbols
                node.meta["unbacked_bindings"] = symbol_to_path

            existing_storages[get_node_storage(node)] += 1

            to_process.update([id(user) for user in node.users])

            self.processed_hashes.add(self.hash_node(node))


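# Illustrative use of FakeTensorUpdater (a sketch; `gm` and the mutation step are
# assumptions for the example, and V.fake_mode must already be set up):
#
#   updater = FakeTensorUpdater(gm.graph)
#   ...  # transform gm.graph: replace nodes, rewrite args, etc.
#   updater.incremental_update()  # refreshes meta["val"] only where hashes changed
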
def get_storage(t: torch.Tensor) -> int:
    """Return an identifier for the tensor's untyped storage (its cdata pointer)."""
    return t.untyped_storage()._cdata


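# Return the storage identifier of a node's fake tensor metadata, or None if
# the node has no tensor "val" or the fake tensor has no storage.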
def get_node_storage(node: torch.fx.Node) -> Optional[int]:
    if "val" not in node.meta:
        return None
    if not isinstance(node.meta["val"], torch.Tensor):
        return None
    if not torch._C._has_storage(node.meta["val"]):
        return None
    return get_storage(node.meta["val"])


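# Map an fx.Node to its fake value from meta["val"]; nodes without metadata and
# non-Node leaves are returned unchanged.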
def get_fake(x):
    if isinstance(x, torch.fx.Node):
        if "val" not in x.meta:
            return x
        return x.meta["val"]
    return x


def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]:
    """
    The first return value is False if any of the input nodes does not yet have
    a fake tensor (it would still appear as an fx.Node after substitution);
    otherwise it is True and args/kwargs hold the fake values.
    """
    args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
    if any(
        isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs)
    ):
        return False, args, kwargs
    return True, args, kwargs

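# For example (informal): if every argument node of `n` already carries a fake
# tensor in meta["val"], get_fake_args_kwargs(n) returns (True, fake_args,
# fake_kwargs), and `n.target(*fake_args, **fake_kwargs)` can be re-run under a
# fake mode, as FakeTensorUpdater.incremental_update does above.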

def is_node_realized(node: torch.fx.Node) -> bool:
    """Returns true if a node is always realized when lowered to inductor IR.

    NOTE: This may return some false negatives. e.g. it doesn't
    handle buffers realized heuristically during lowering, or
    buffers realized indirectly through view ops.
    """
    from torch._inductor.lowering import fallbacks, needs_realized_inputs

    def is_buffer(node: torch.fx.Node) -> bool:
        if node.op == "call_function" and node.target is operator.getitem:
            # For nodes with multiple outputs, we get the fx graph:
            #     foo = torch.ops.aten.foo(...)
            #     getitem = foo[0]
            #     getitem_1 = foo[1]
            # where we need to check if foo is a fallback kernel
            return is_buffer(node.args[0])  # type: ignore[arg-type]
        return node.op in ("placeholder", "output") or node.target in fallbacks

    if is_buffer(node):
        return True

    def realizes_inputs(node: torch.fx.Node) -> bool:
        return node.op == "output" or node.target in needs_realized_inputs

    if any(realizes_inputs(user) for user in node.users):
        return True

    # Otherwise, assume node isn't realized
    return False

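# Informal example (a sketch, assuming aten.sin has an inductor lowering and
# aten.add does not require realized inputs): in a graph where `sin` feeds the
# graph output, is_node_realized(sin_node) is True because one of its users is
# the output node; if the same `sin` only fed a pointwise `add` that gets fused,
# this heuristic would report it as not realized.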