xref: /aosp_15_r20/external/pytorch/torch/_export/passes/lift_constants_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import collections
3import warnings
4from typing import Any, Dict, List, Union
5
6import torch
7from torch._export.verifier import SpecViolationError
8from torch._guards import detect_fake_mode
9from torch._library.fake_class_registry import FakeScriptObject
10from torch._subclasses.fake_tensor import unset_fake_temporarily
11from torch.export.exported_program import (
12    ArgumentSpec,
13    CustomObjArgument,
14    ExportGraphSignature,
15    InputKind,
16    InputSpec,
17    TensorArgument,
18)
19
20
21class ConstantAttrMap(collections.abc.MutableMapping):
22    """A mapping class that understands how to use module constants (tensors,
23    ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally,
24    but ScriptObjects are stored by hash, because different torch.ScriptObjects can point to
25    the same underlying value (but we guarantee that they will `hash()` to the same value
26    if that's the case).
27    """
28
29    def __init__(self) -> None:
30        # Underlying dict that we use to implement this mapping.
31        self._constant_attrs: Dict[
32            Union[int, torch.Tensor, FakeScriptObject], List[Any]
33        ] = {}
34        # Map from the hash(ScriptObject) to the ScriptObject itself. Used for
35        # APIs like `__iter__` that should look like they're returning the
36        # original ScriptObjects.
37        self._script_object_map: Dict[int, torch.ScriptObject] = {}
38
39    def __getitem__(
40        self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]
41    ) -> Any:
42        real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
43        assert isinstance(real_key, (int, torch.Tensor, FakeScriptObject))
44        return self._constant_attrs[real_key]
45
46    def __setitem__(self, key: Union[torch.Tensor, torch.ScriptObject], value):
47        # we shouldn't actually call this, should go to add() instead to handle aliasing
48        raise NotImplementedError(
49            """Directly setting values for ConstantAttrMap is not supported, please use add(key, value) instead.
50The same key can be mapped to multiple values, for handling constant aliasing."""
51        )
52
53    def add(
54        self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject], value: Any
55    ) -> None:
56        if isinstance(key, torch.ScriptObject):
57            if hash(key) not in self._constant_attrs:
58                self._constant_attrs[hash(key)] = []
59            self._constant_attrs[hash(key)].append(value)
60            self._script_object_map[hash(key)] = key
61        elif isinstance(key, (torch.Tensor, FakeScriptObject)):
62            if key not in self._constant_attrs:
63                self._constant_attrs[key] = []
64            self._constant_attrs[key].append(value)
65        else:
66            raise TypeError(
67                f"Expected key to be a tensor or ScriptObject, got {type(key)}"
68            )
69
70    def __delitem__(self, key):
71        real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
72
73        del self._constant_attrs[real_key]
74
75    def __iter__(self):
76        for key in self._constant_attrs:
77            if isinstance(key, int):
78                yield self._script_object_map[key]
79            else:
80                yield key
81
82    def __len__(self):
83        return len(self._constant_attrs)
84
85    def __contains__(self, key: object) -> bool:
86        real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
87        return real_key in self._constant_attrs
88
89
def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str:
    """Build the state-dict FQN for a constant from the module in which the
    node was originally used, per its ``nn_module_stack`` metadata."""
    module_stack = node.meta["nn_module_stack"]
    if not module_stack:
        # No enclosing module recorded: the constant lives at the top level.
        return constant_name
    # The innermost (last) stack entry is the module that used the constant.
    parent_fqn = next(reversed(module_stack.values()))[0]
    return f"{parent_fqn}.{constant_name}" if parent_fqn else constant_name
101
102
def _get_first_fqn(
    const_attrs: ConstantAttrMap,
    key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],
) -> Any:
    """Return the first FQN recorded for `key`, or None when `key` is absent
    or maps to an empty list."""
    if key not in const_attrs:
        return None
    fqns = const_attrs[key]
    return fqns[0] if fqns else None
109
110
def lift_constants_pass(
    gm: torch.fx.GraphModule,
    graph_signature: ExportGraphSignature,
    constant_attrs: ConstantAttrMap,
) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]:
    """
    Takes a graph module, graph signature, and modifies them in place to lift any
    constants (tensors or custom classes) as inputs to the graph. Returns a
    dictionary of names to constants.

    Arguments:
        gm (torch.fx.GraphModule): The graph module containing the graph and constants to lift.
        graph_signature (ExportGraphSignature): This graph signature will be
            mutated to add additional CONSTANT_TENSOR and CUSTOM_OBJ inputs.
        constant_attrs (ConstantAttrMap): A mapping from a constant value to its
            fully-qualified path in `gm`. This is used to maintain consistent
            location of constants between the original module and the exported
            version.

    Returns:
        A dictionary of fqn => constant value.
    """
    all_constants: Dict[
        str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]
    ] = {}

    inputs = graph_signature.input_specs
    # Seed the name counters with the number of constants already in the
    # signature so freshly generated names don't collide with existing ones.
    num_custom_obj = sum(
        input_specs.kind == InputKind.CUSTOM_OBJ for input_specs in inputs
    )
    num_tensor_constants = sum(
        input_specs.kind == InputKind.CONSTANT_TENSOR for input_specs in inputs
    )

    # Reuse the fake mode of the existing placeholders (if any) so lifted
    # tensors are fakified consistently with the rest of the graph.
    fake_mode = detect_fake_mode(
        tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
    )

    # Locate the first user input: lifted constants are inserted immediately
    # before it, both as graph placeholders and as input_specs entries.
    # NOTE(review): first_user_input_loc counts *all* nodes preceding the first
    # user input; this only matches the input_specs index because placeholders
    # come first in an FX graph — confirm if node order is ever non-standard.
    first_user_input_loc, first_user_input = 0, None
    for node in gm.graph.nodes:
        if node.op == "placeholder" and node.name in graph_signature.user_inputs:
            first_user_input = node
            break
        first_user_input_loc += 1

    # Tracks constants lifted by *this* pass so that aliased get_attr nodes
    # resolving to the same value all share a single placeholder.
    lifted_objs = ConstantAttrMap()
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            constant_val = getattr(gm, node.target)
            if constant_val in lifted_objs:
                # We already lifted this constant elsewhere. Just rewrite uses
                # of this get_attr to point to the already-existing placeholder
                # node.
                const_placeholder_node = _get_first_fqn(lifted_objs, constant_val)
                node.replace_all_uses_with(const_placeholder_node)
                gm.graph.erase_node(node)
                continue

            # For ScriptObject, Tensor and FakeScriptObject constants:
            # First check if the constant was an attribute on some module by
            # consulting `constant_attrs` map. If it is, use the fqn that keeps
            # its location consistent with the eager module.
            #
            # If it's not in the `constant_attrs` map, that means it's an inline
            # constant (e.g. x + torch.tensor(0)), and thus did not have a
            # specific location in the eager module. In that case, just generate
            # some name and attach it to the module in which it was used.
            if isinstance(constant_val, (torch.ScriptObject, FakeScriptObject)):
                constant_kind = InputKind.CUSTOM_OBJ
                constant_fqn = _get_first_fqn(constant_attrs, constant_val)
                if constant_fqn is not None:
                    constant_name = constant_fqn.replace(".", "_")
                else:
                    constant_name = f"lifted_custom_{num_custom_obj}"
                    constant_fqn = get_constant_fqn(node, constant_name)
                    num_custom_obj += 1
            elif isinstance(constant_val, torch.Tensor):
                # Remove the parameterness of constant_val
                if isinstance(constant_val, torch.nn.Parameter):
                    warnings.warn(
                        f"{node.target} created when tracing {node.meta['stack_trace']} is a parameter. But"
                        f"it's not registered with register_parameter(). export will treat it as a constant tensor"
                    )
                    # We get the real data out of the parameter by disabling the surrounding fake mode.
                    with unset_fake_temporarily():
                        constant_val = constant_val.data
                constant_kind = InputKind.CONSTANT_TENSOR
                constant_fqn = _get_first_fqn(constant_attrs, constant_val)
                if constant_fqn is not None:
                    constant_name = constant_fqn.replace(".", "_")
                else:
                    constant_name = f"lifted_tensor_{num_tensor_constants}"
                    constant_fqn = get_constant_fqn(node, constant_name)
                    num_tensor_constants += 1
            elif isinstance(constant_val, torch.fx.GraphModule):
                # Submodule attribute, not a constant — leave it in place.
                continue
            elif "LoweredBackendModule" in type(constant_val).__name__:
                # NOTE(review): matched by class name, presumably to avoid a
                # hard dependency on the backend type — confirm.
                continue
            else:
                raise SpecViolationError(
                    f"getattr node {node} referencing unsupported type {type(constant_val)}"
                )

            with gm.graph.inserting_before(first_user_input):
                # Insert the constant node before the first user input
                const_placeholder_node = gm.graph.placeholder(constant_name)
                # match target name with its node name in case there is name collision
                # and suffix is added to node name in fx
                const_placeholder_node.target = const_placeholder_node.name

                # Carry the get_attr node's metadata over to the placeholder.
                for k, v in node.meta.items():
                    const_placeholder_node.meta[k] = v

                # Once the FQN has been used, remove nn_module_stack, stack_trace
                const_placeholder_node.meta.pop("nn_module_stack")
                const_placeholder_node.meta.pop("stack_trace", None)

                input_spec_arg: ArgumentSpec
                if isinstance(constant_val, torch.Tensor):
                    if fake_mode is not None:
                        # The placeholder carries a fake tensor; stash the real
                        # data on `.constant` so it can be recovered later.
                        const_placeholder_node.meta["val"] = fake_mode.from_tensor(
                            constant_val, static_shapes=True
                        )
                        const_placeholder_node.meta["val"].constant = constant_val
                    else:
                        const_placeholder_node.meta["val"] = constant_val
                    input_spec_arg = TensorArgument(name=const_placeholder_node.name)
                elif isinstance(constant_val, torch._C.ScriptObject):
                    class_fqn = constant_val._type().qualified_name()  # type: ignore[attr-defined]
                    const_placeholder_node.meta["val"] = CustomObjArgument(
                        constant_fqn, class_fqn
                    )
                    input_spec_arg = CustomObjArgument(
                        name=const_placeholder_node.name, class_fqn=class_fqn
                    )
                elif isinstance(constant_val, FakeScriptObject):
                    class_fqn = constant_val.script_class_name
                    const_placeholder_node.meta["val"] = CustomObjArgument(
                        constant_fqn, class_fqn, constant_val
                    )
                    input_spec_arg = CustomObjArgument(
                        name=const_placeholder_node.name,
                        class_fqn=class_fqn,
                        fake_val=constant_val,
                    )
                else:
                    raise SpecViolationError(
                        f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}"
                    )

                # Record the lift, then retarget every use of the get_attr at
                # the new placeholder and drop the now-dead get_attr node.
                lifted_objs.add(constant_val, const_placeholder_node)
                node.replace_all_uses_with(const_placeholder_node)
                gm.graph.erase_node(node)

                # Add the constant as a buffer to the graph signature
                graph_signature.input_specs.insert(
                    first_user_input_loc,
                    InputSpec(
                        kind=constant_kind,
                        arg=input_spec_arg,
                        target=constant_fqn,
                    ),
                )
                # An aliased constant is recorded under *every* fqn it had in
                # the original module, not just the one chosen above.
                if constant_val in constant_attrs:
                    for fqn in constant_attrs[constant_val]:
                        all_constants[fqn] = constant_val
                else:
                    all_constants[constant_fqn] = constant_val
                first_user_input_loc += 1

    return all_constants
282
283
284def rewrite_script_object_meta(
285    gm: torch.fx.GraphModule,
286) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]:
287    """When tracing, we produce a graph with FakeScriptObject in the
288    meta["val"].
289
290    For now, we rewrie meta["val"] to be a placeholder CustomObjArgument
291    """
292    constants: Dict[
293        str,
294        Union[
295            torch.Tensor,
296            torch.ScriptObject,
297            FakeScriptObject,
298        ],
299    ] = {}
300    for node in gm.graph.nodes:
301        if "val" not in node.meta:
302            continue
303
304        old_meta = node.meta["val"]
305
306        if isinstance(old_meta, torch.ScriptObject):
307            class_fqn = old_meta._type().qualified_name()  # type: ignore[attr-defined]
308            new_meta = CustomObjArgument(node.name, class_fqn)
309            constants[node.name] = old_meta
310            node.meta["val"] = new_meta
311
312        elif isinstance(old_meta, FakeScriptObject):
313            class_fqn = old_meta.script_class_name  # type: ignore[attr-defined]
314            new_meta = CustomObjArgument(node.name, class_fqn, old_meta)
315            constants[node.name] = old_meta
316            node.meta["val"] = new_meta
317
318    return constants
319