xref: /aosp_15_r20/external/pytorch/torch/export/graph_signature.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import dataclasses
3from enum import auto, Enum
4from typing import Collection, Dict, List, Mapping, Optional, Set, TYPE_CHECKING, Union
5
6from torch._library.fake_class_registry import FakeScriptObject
7
8
9if TYPE_CHECKING:
10    import torch
11    from torch._functorch._aot_autograd.schemas import GraphSignature
12
# Public API of this module.
# NOTE(review): TokenArgument is defined below but absent from __all__ —
# confirm whether it is intentionally private or should be exported.
__all__ = [
    "ConstantArgument",
    "CustomObjArgument",
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "InputKind",
    "InputSpec",
    "OutputKind",
    "OutputSpec",
    "SymIntArgument",
    "TensorArgument",
]
25
26
@dataclasses.dataclass
class TensorArgument:
    """Refers to a tensor-valued graph input/output by its placeholder/node name."""

    name: str
30
31
@dataclasses.dataclass
class TokenArgument:
    """Refers to a token graph input/output (see InputKind.TOKEN / OutputKind.TOKEN) by node name."""

    name: str
35
36
@dataclasses.dataclass
class SymIntArgument:
    """Refers to a SymInt-valued graph input/output by its node name."""

    name: str
40
41
@dataclasses.dataclass
class CustomObjArgument:
    """Refers to a TorchBind custom object input/output.

    ``class_fqn`` is the fully qualified name of the object's script class;
    ``fake_val`` optionally carries the FakeScriptObject stand-in used during
    tracing.
    """

    name: str
    class_fqn: str
    fake_val: Optional[FakeScriptObject] = None
47
48
@dataclasses.dataclass
class ConstantArgument:
    """A constant (non-node) input/output value baked into the signature."""

    name: str
    value: Union[int, float, bool, str, None]
53
54
# Union of every argument payload an InputSpec / OutputSpec can carry.
ArgumentSpec = Union[
    TensorArgument,
    SymIntArgument,
    ConstantArgument,
    CustomObjArgument,
    TokenArgument,
]
62
63
class InputKind(Enum):
    """Classifies where each lifted graph input comes from."""

    USER_INPUT = auto()  # pytree-flattened argument supplied by the caller
    PARAMETER = auto()  # lifted parameter (target = fully qualified name)
    BUFFER = auto()  # lifted buffer; must also carry a persistent flag
    CONSTANT_TENSOR = auto()  # lifted constant tensor
    CUSTOM_OBJ = auto()  # lifted TorchBind custom object
    TOKEN = auto()  # effect-ordering token input
71
72
73@dataclasses.dataclass
74class InputSpec:
75    kind: InputKind
76    arg: ArgumentSpec
77    target: Optional[str]
78    persistent: Optional[bool] = None
79
80    def __post_init__(self):
81        if self.kind == InputKind.BUFFER:
82            assert (
83                self.persistent is not None
84            ), "Failed to specify persistent flag on BUFFER."
85        assert isinstance(
86            self.arg,
87            (
88                TensorArgument,
89                SymIntArgument,
90                ConstantArgument,
91                CustomObjArgument,
92                TokenArgument,
93            ),
94        ), f"got {type(self.arg)}"
95
96
class OutputKind(Enum):
    """Classifies each graph output."""

    USER_OUTPUT = auto()  # pytree-flattened output returned to the caller
    LOSS_OUTPUT = auto()  # loss value (joint/backward graphs)
    BUFFER_MUTATION = auto()  # updated value of a mutated buffer
    GRADIENT_TO_PARAMETER = auto()  # gradient w.r.t. a parameter (joint graphs)
    GRADIENT_TO_USER_INPUT = auto()  # gradient w.r.t. a user input (joint graphs)
    USER_INPUT_MUTATION = auto()  # updated value of a mutated user input
    TOKEN = auto()  # effect-ordering token output
105
106
107@dataclasses.dataclass
108class OutputSpec:
109    kind: OutputKind
110    arg: ArgumentSpec
111    target: Optional[str]
112
113    def __post_init__(self):
114        assert isinstance(
115            self.arg,
116            (
117                TensorArgument,
118                SymIntArgument,
119                ConstantArgument,
120                TokenArgument,
121                CustomObjArgument,
122            ),
123        ), self.arg
124
125
@dataclasses.dataclass
class ExportBackwardSignature:
    """Backward part of a joint graph's signature.

    Maps gradient output node names to the parameter / user-input names they
    correspond to, and names the loss output node.
    """

    gradients_to_parameters: Dict[str, str]
    gradients_to_user_inputs: Dict[str, str]
    loss_output: str
131
132
@dataclasses.dataclass
class ExportGraphSignature:
    """
    :class:`ExportGraphSignature` models the input/output signature of Export Graph,
    which is a fx.Graph with stronger invariant guarantees.

    Export Graph is functional and does not access "states" like parameters
    or buffers within the graph via ``getattr`` nodes. Instead, :func:`export`
    guarantees that parameters, buffers, and constant tensors are lifted out of
    the graph as inputs.  Similarly, any mutations to buffers are not included
    in the graph either, instead the updated values of mutated buffers are
    modeled as additional outputs of Export Graph.

    The ordering of all inputs and outputs are::

        Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
        Outputs = [*mutated_inputs, *flattened_user_outputs]

    e.g. If following module is exported::

        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super(CustomModule, self).__init__()

                # Define a parameter
                self.my_parameter = nn.Parameter(torch.tensor(2.0))

                # Define two buffers
                self.register_buffer('my_buffer1', torch.tensor(3.0))
                self.register_buffer('my_buffer2', torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0) # In-place addition

                return output

    Resulting Graph would be::

        graph():
            %arg0_1 := placeholder[target=arg0_1]
            %arg1_1 := placeholder[target=arg1_1]
            %arg2_1 := placeholder[target=arg2_1]
            %arg3_1 := placeholder[target=arg3_1]
            %arg4_1 := placeholder[target=arg4_1]
            %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
            %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
            %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
            %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
            %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
            return (add_tensor_2, add_tensor_1)

    Resulting ExportGraphSignature would be::

        ExportGraphSignature(
            input_specs=[
                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
            ],
            output_specs=[
                OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
                OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
            ]
        )
    """

    input_specs: List[InputSpec]
    output_specs: List[OutputSpec]

    # A list of parameters uniquely identified by mangled fully qualified name
    @property
    def parameters(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.PARAMETER
            if isinstance(s.target, str)
        )

    # A list of buffers uniquely identified by mangled fully qualified name
    @property
    def buffers(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            if isinstance(s.target, str)
        )

    # Buffers whose InputSpec.persistent flag is explicitly False.
    @property
    def non_persistent_buffers(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            if s.persistent is False
            if isinstance(s.target, str)
        )

    # A list of lifted constant tensors
    @property
    def lifted_tensor_constants(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.CONSTANT_TENSOR
            if isinstance(s.target, str)
        )

    # A list of lifted custom objects, identified by target name.
    @property
    def lifted_custom_objs(self) -> Collection[str]:
        return tuple(
            s.target
            for s in self.input_specs
            if s.kind == InputKind.CUSTOM_OBJ
            if isinstance(s.target, str)
        )

    # Graph node names of pytree-flattened inputs of original program
    # (ConstantArgument inputs contribute their literal value, not a name).
    @property
    def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]:
        user_inputs: List[Union[int, float, bool, None, str]] = []
        for s in self.input_specs:
            if s.kind != InputKind.USER_INPUT:
                continue

            if isinstance(s.arg, (TensorArgument, SymIntArgument, CustomObjArgument)):
                user_inputs.append(s.arg.name)
            elif isinstance(s.arg, ConstantArgument):
                user_inputs.append(s.arg.value)
            else:
                raise RuntimeError(f"{s.arg} is not a valid user inputs")
        return tuple(user_inputs)

    # Graph node names of pytree-flattened outputs of original program
    # (ConstantArgument outputs contribute their literal value, not a name).
    @property
    def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]:
        user_outputs: List[Union[int, float, bool, None, str]] = []
        for s in self.output_specs:
            if s.kind != OutputKind.USER_OUTPUT:
                continue

            if isinstance(s.arg, (TensorArgument, SymIntArgument)):
                user_outputs.append(s.arg.name)
            elif isinstance(s.arg, ConstantArgument):
                user_outputs.append(s.arg.value)
            elif isinstance(s.arg, CustomObjArgument):
                user_outputs.append(s.arg.name)
            else:
                raise RuntimeError(f"{s.arg} is not a valid user output")
        return tuple(user_outputs)

    # A dictionary mapping graph input node names to parameters. If a graph input
    # name is found in this dictionary, it is guaranteed to be a lifted parameter.
    @property
    def inputs_to_parameters(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.PARAMETER
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph input node names to buffers. If a graph input
    # name is found in this dictionary, it is guaranteed to be a lifted buffer.
    @property
    def inputs_to_buffers(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)  # type: ignore[union-attr, misc]
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph output node names to buffers that are mutated in the
    # original program. Buffers that are not mutated will not be found in this dictionary.
    @property
    def buffers_to_mutate(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.output_specs
            if s.kind == OutputKind.BUFFER_MUTATION
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph output node names to the user inputs they mutate.
    @property
    def user_inputs_to_mutate(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.output_specs
            if s.kind == OutputKind.USER_INPUT_MUTATION
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph input node names to lifted tensor constants.
    @property
    def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.CONSTANT_TENSOR
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        )

    # A dictionary mapping graph input node names to lifted custom objects.
    @property
    def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]:
        return _immutable_dict(
            (s.arg.name, s.target)
            for s in self.input_specs
            if s.kind == InputKind.CUSTOM_OBJ
            and isinstance(s.arg, CustomObjArgument)
            and isinstance(s.target, str)
        )

    # Reconstructs the backward signature from the output specs; returns None
    # when there is no LOSS_OUTPUT (i.e. this is not a joint graph).
    @property
    def backward_signature(self) -> Optional[ExportBackwardSignature]:
        loss_output = None
        gradients_to_parameters: Dict[str, str] = {}
        gradients_to_user_inputs: Dict[str, str] = {}
        for spec in self.output_specs:
            if spec.kind == OutputKind.LOSS_OUTPUT:
                # At most one loss output is expected.
                assert loss_output is None
                assert isinstance(spec.arg, TensorArgument)
                loss_output = spec.arg.name
            elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
                assert isinstance(spec.target, str)
                assert isinstance(spec.arg, TensorArgument)
                gradients_to_parameters[spec.arg.name] = spec.target
            elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT:
                assert isinstance(spec.target, str)
                assert isinstance(spec.arg, TensorArgument)
                gradients_to_user_inputs[spec.arg.name] = spec.target

        if loss_output is None:
            return None

        return ExportBackwardSignature(
            loss_output=loss_output,
            gradients_to_parameters=gradients_to_parameters,
            gradients_to_user_inputs=gradients_to_user_inputs,
        )

    # Map from assertion dependency token index to assertion dep token output
    # name in output. The shape of output after aot_autograd will be like:
    # (updated_inputs, user_outputs, dep_token).
    @property
    def assertion_dep_token(self) -> Optional[Mapping[int, str]]:
        return None

    # Node names of TOKEN-kind inputs, in input order.
    @property
    def input_tokens(self) -> Collection[str]:
        input_tokens: List[str] = []
        for s in self.input_specs:
            if s.kind == InputKind.TOKEN:
                assert isinstance(s.arg, TokenArgument)
                input_tokens.append(s.arg.name)
        return tuple(input_tokens)

    # Node names of TOKEN-kind outputs, in output order.
    @property
    def output_tokens(self) -> Collection[str]:
        output_tokens: List[str] = []
        for s in self.output_specs:
            if s.kind == OutputKind.TOKEN:
                assert isinstance(s.arg, TokenArgument)
                output_tokens.append(s.arg.name)
        return tuple(output_tokens)

    def __post_init__(self) -> None:
        # Validate the assertion dep token invariant: if present, there is
        # exactly one, and it sits right after user outputs + buffer mutations.
        assertion_dep_token = self.assertion_dep_token
        if assertion_dep_token is None:
            return
        assert len(assertion_dep_token) == 1
        assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
        assert (
            len(self.user_outputs) + len(self.buffers_to_mutate)
            == assertion_dep_token_index
        )

    def replace_all_uses(self, old: str, new: str):
        """
        Replace all uses of the old name with new name in the signature.
        """
        assert isinstance(old, str)
        assert isinstance(new, str)
        # ConstantArgument has no ``name`` field, so it is excluded here.
        arg_types = (TensorArgument, SymIntArgument, CustomObjArgument, TokenArgument)
        for o in self.output_specs:
            if isinstance(o.arg, arg_types):
                if o.arg.name == old:
                    o.arg.name = new
        for i in self.input_specs:
            if isinstance(i.arg, arg_types):
                if i.arg.name == old:
                    i.arg.name = new

    def get_replace_hook(self):
        # Returns a callback suitable for fx node-replacement hooks: when a
        # node used by an "output"/"input" node is replaced, rename it here too.
        # NOTE(review): "input" is not a standard fx opcode (placeholders use
        # op == "placeholder") — confirm which ops this is meant to match.
        def _(old, new, user):
            if user.op in ("output", "input"):
                self.replace_all_uses(old.name, new)

        return _
444
445
446def _immutable_dict(items):
447    """
448    Creates a mapping where items cannot be added, deleted, or updated.
449    NOTE: The immutability is shallow (like tuple is an immutable collection).
450    """
451    from types import MappingProxyType
452
453    return MappingProxyType(dict(items))
454
455
456def _make_argument_spec(node, token_names) -> ArgumentSpec:
457    from torch import ScriptObject, SymInt
458    from torch._library.fake_class_registry import FakeScriptObject
459    from torch._subclasses.fake_tensor import FakeTensor
460
461    if isinstance(node, (int, bool, float, type(None), str)):
462        # For const outputs we just directly return this
463        return ConstantArgument(name="", value=node)
464
465    assert (
466        "val" in node.meta
467    ), f"{node} is not a constant or a node with a 'val' metadata field"
468    val = node.meta["val"]
469    if node.name in token_names:
470        return TokenArgument(name=node.name)
471    elif isinstance(val, FakeTensor):
472        return TensorArgument(name=node.name)
473    elif isinstance(val, SymInt):
474        return SymIntArgument(name=node.name)
475    elif isinstance(val, ScriptObject):
476        return CustomObjArgument(name=node.name, class_fqn=val._type().qualified_name())  # type: ignore[attr-defined]
477    elif isinstance(val, FakeScriptObject):
478        return CustomObjArgument(
479            name=node.name, class_fqn=val.script_class_name, fake_val=val
480        )
481    elif isinstance(val, (int, bool, str, float, type(None))):
482        return ConstantArgument(name=node.name, value=val)
483    else:
484        raise AssertionError(
485            f"Encountered an unsupported object of type {type(val)} "
486            f"while writing the metadata for exported program"
487        )
488
489
def _convert_to_export_graph_signature(
    graph_signature: "GraphSignature",
    gm: "torch.fx.GraphModule",
    non_persistent_buffers: Set[str],
) -> "ExportGraphSignature":
    """Translate an AOTAutograd ``GraphSignature`` plus the traced GraphModule
    into an :class:`ExportGraphSignature`.

    Inputs are classified by matching placeholder node names against the
    AOTAutograd signature's name sets; outputs are classified positionally
    (mutations/tokens first) and then by name.
    """
    from torch.utils import _pytree as pytree

    is_joint = graph_signature.backward_signature is not None

    # unpack objects
    user_inputs = set(graph_signature.user_inputs)
    inputs_to_parameters = graph_signature.inputs_to_parameters
    inputs_to_buffers = graph_signature.inputs_to_buffers
    user_outputs = set(graph_signature.user_outputs)
    buffer_mutations = graph_signature.buffers_to_mutate
    user_input_mutations = graph_signature.user_inputs_to_mutate
    # NOTE(review): ``gradients_to_parameter`` (singular) — ExportBackwardSignature
    # in this file uses ``gradients_to_parameters``; confirm the attribute name on
    # the AOTAutograd backward signature, otherwise joint graphs would raise here.
    grad_params = graph_signature.backward_signature.gradients_to_parameter if is_joint else {}  # type: ignore[union-attr]
    grad_user_inputs = graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {}  # type: ignore[union-attr]
    loss_output = graph_signature.backward_signature.loss_output if is_joint else None  # type: ignore[union-attr]
    input_tokens = graph_signature.input_tokens
    output_tokens = graph_signature.output_tokens

    # All placeholders, in graph order.
    inputs = [
        _make_argument_spec(node, input_tokens)
        for node in gm.graph.nodes
        if node.op == "placeholder"
    ]
    # Flattened args of the last node in the graph (the output node).
    outputs = [
        _make_argument_spec(node, output_tokens)
        for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
    ]

    def to_input_spec(inp: ArgumentSpec) -> InputSpec:
        # Classify one input argument by name lookup in the unpacked sets.
        if isinstance(inp, TokenArgument):
            return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None)

        if not isinstance(inp, TensorArgument):
            # Non-tensor inputs (SymInt, constants, custom objects) are user inputs.
            return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
        name = inp.name
        if name in user_inputs:
            return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
        elif name in inputs_to_parameters:
            return InputSpec(
                kind=InputKind.PARAMETER,
                arg=inp,
                target=inputs_to_parameters[name],  # type: ignore[index]
            )
        elif name in inputs_to_buffers:
            return InputSpec(
                kind=InputKind.BUFFER,
                arg=inp,
                target=inputs_to_buffers[name],  # type: ignore[index]
                persistent=(inputs_to_buffers[name] not in non_persistent_buffers),  # type: ignore[index]
            )
        else:
            raise AssertionError(f"Unknown tensor input kind: {name}")

    def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec:
        # Classify one output argument; ``idx`` distinguishes the leading
        # mutation/token block from the trailing user/gradient outputs.
        if isinstance(o, TokenArgument):
            return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None)

        if not isinstance(o, TensorArgument):
            return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
        name = o.name
        if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
            # Positions before this boundary are mutation outputs.
            if name in buffer_mutations:
                return OutputSpec(
                    kind=OutputKind.BUFFER_MUTATION,
                    arg=o,
                    target=buffer_mutations[name],  # type: ignore[index]
                )
            elif name in user_input_mutations:
                return OutputSpec(
                    kind=OutputKind.USER_INPUT_MUTATION,
                    arg=o,
                    target=user_input_mutations[name],  # type: ignore[index]
                )
            else:
                raise AssertionError(f"Unknown tensor mutation kind: {name}")
        else:
            if name in user_outputs:
                return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)

            elif name in grad_params:
                return OutputSpec(
                    kind=OutputKind.GRADIENT_TO_PARAMETER,
                    arg=o,
                    target=grad_params[name],
                )
            elif name in grad_user_inputs:
                return OutputSpec(
                    kind=OutputKind.GRADIENT_TO_USER_INPUT,
                    arg=o,
                    target=grad_user_inputs[name],
                )
            elif name == loss_output:
                return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)

            else:
                raise AssertionError(f"Unknown tensor output kind: {name}")

    input_specs = [to_input_spec(inp) for inp in inputs]
    output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)]
    return ExportGraphSignature(input_specs=input_specs, output_specs=output_specs)
594