xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/io_adapter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4from typing import (
5    Any,
6    Callable,
7    Mapping,
8    Protocol,
9    runtime_checkable,
10    Sequence,
11    TYPE_CHECKING,
12)
13
14import torch
15import torch.export as torch_export
16from torch.utils import _pytree as pytree
17
18
19if TYPE_CHECKING:
20    import inspect
21
22# TODO(bowbao): Add diagnostics for IO adapters.
23
24
25@runtime_checkable
26class InputAdaptStep(Protocol):
27    """A protocol that defines a step in the input adapting process.
28
29    The input adapting process is a sequence of steps that are applied to the
30    PyTorch model inputs to transform them into the inputs format expected by the
31    exported ONNX model. Each step takes the PyTorch model inputs as arguments and
32    returns the transformed inputs.
33
34    This serves as a base formalized construct for the transformation done to model
35    input signature by any individual component in the exporter.
36    """
37
38    def apply(
39        self,
40        model_args: Sequence[Any],
41        model_kwargs: Mapping[str, Any],
42        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
43    ) -> tuple[Sequence[Any], Mapping[str, Any]]: ...
44
45
46class InputAdapter:
47    """A class that adapts the PyTorch model inputs to exported ONNX model inputs format."""
48
49    def __init__(self, steps: list[InputAdaptStep] | None = None):
50        self._steps = steps or []
51
52    def append_step(self, step: InputAdaptStep) -> None:
53        """Appends a step to the input adapt steps.
54
55        Args:
56            step: The step to append.
57        """
58        self._steps.append(step)
59
60    def apply(
61        self,
62        *model_args,
63        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
64        **model_kwargs,
65    ) -> Sequence[int | float | bool | str | torch.Tensor | torch.dtype | None]:
66        """Converts the PyTorch model inputs to exported ONNX model inputs format.
67
68        Args:
69            model_args: The PyTorch model inputs.
70            model: The PyTorch model.
71            model_kwargs: The PyTorch model keyword inputs.
72        Returns:
73            A sequence of tensors converted from PyTorch model inputs.
74        """
75        args: Sequence[Any] = model_args
76        kwargs: Mapping[str, Any] = model_kwargs
77        for step in self._steps:
78            args, kwargs = step.apply(args, kwargs, model=model)
79        assert not kwargs
80        return args
81
82
83@runtime_checkable
84class OutputAdaptStep(Protocol):
85    """A protocol that defines a step in the output adapting process.
86
87    The output adapting process is a sequence of steps that are applied to the
88    PyTorch model outputs to transform them into the outputs format produced by the
89    exported ONNX model. Each step takes the PyTorch model outputs as arguments and
90    returns the transformed outputs.
91
92    This serves as a base formalized construct for the transformation done to model
93    output signature by any individual component in the exporter.
94    """
95
96    def apply(
97        self,
98        model_outputs: Any,
99        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
100    ) -> Any: ...
101
102
103class OutputAdapter:
104    """A class that adapts the PyTorch model outputs to exported ONNX model outputs format."""
105
106    def __init__(self, steps: list[OutputAdaptStep] | None = None):
107        self._steps = steps or []
108
109    def append_step(self, step: OutputAdaptStep) -> None:
110        """Appends a step to the output format steps.
111
112        Args:
113            step: The step to append.
114        """
115        self._steps.append(step)
116
117    def apply(
118        self,
119        model_outputs: Any,
120        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
121    ) -> Sequence[torch.Tensor | int | float | bool | str]:
122        """Converts the PyTorch model outputs to exported ONNX model outputs format.
123
124        Args:
125            model_outputs: The PyTorch model outputs.
126            model: The PyTorch model.
127
128        Returns:
129            PyTorch model outputs in exported ONNX model outputs format.
130        """
131        for step in self._steps:
132            model_outputs = step.apply(model_outputs, model=model)
133        return model_outputs
134
135
136# TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276
137
138
def _replace_tuple_with_list(spec: pytree.TreeSpec) -> pytree.TreeSpec:
    """Rebuild ``spec`` with every ``tuple`` node turned into a ``list`` node.

    The rewrite recurses into all children; each node's context is carried over
    as-is and node types other than ``tuple`` are kept.
    """
    node_type = spec.type
    if node_type == tuple:
        node_type = list
    rebuilt_children = [
        _replace_tuple_with_list(child) for child in spec.children_specs
    ]
    return pytree.TreeSpec(node_type, spec.context, rebuilt_children)
144
145
146def _open_top_level_list_if_single_element(spec: pytree.TreeSpec) -> pytree.TreeSpec:
147    if spec.type == list and spec.num_children == 1:
148        return spec.children_specs[0]
149    return spec
150
151
def _assert_identical_pytree_spec(
    spec1: pytree.TreeSpec, spec2: pytree.TreeSpec, error_message: str
) -> None:
    """Check that two `TreeSpec` objects describe the same structure.

    Besides exact equality, a few known-benign differences are tolerated (see the
    FIXME notes below). The checks run in order and stop at the first one that
    passes.

    Args:
        spec1: The expected `TreeSpec`.
        spec2: The `TreeSpec` that was actually observed.
        error_message: Prefix of the error message raised on mismatch.

    Raises:
        ValueError: If none of the equivalence checks pass.
    """
    # TODO(bowbao): Turn this check into diagnostic. Consider warning instead of error.

    def _equivalent() -> bool:
        if spec1 == spec2:
            return True
        # FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'.
        if _replace_tuple_with_list(spec1) == _replace_tuple_with_list(spec2):
            return True
        # FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list.
        if _open_top_level_list_if_single_element(spec1) == spec2:
            return True
        return spec1 == _open_top_level_list_if_single_element(spec2)

    if not _equivalent():
        raise ValueError(f"{error_message}\nExpect {spec1}.\nActual {spec2}.")
178
179
class BindInputStep(InputAdaptStep):
    """Resolve the caller's arguments against the model's call signature."""

    def __init__(self, model_signature: inspect.Signature):
        # Signature the incoming arguments are bound against in `apply`.
        self._model_signature = model_signature

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Bind ``model_args``/``model_kwargs`` to the stored signature.

        After binding and filling in defaults, every argument is expected to be
        expressible positionally. The result is returned as a name -> value
        mapping in the kwargs slot.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs.
            model: The PyTorch model (unused by this step).

        Returns:
            A pair ``((), bound.arguments)``: the args slot is always empty and
            the kwargs slot maps every parameter name to its bound value.

        Raises:
            TypeError: Propagated from ``inspect.Signature.bind`` when the
                arguments do not match the signature.
            ValueError: If any argument can only be passed by keyword.
        """
        bound_arguments = self._model_signature.bind(*model_args, **model_kwargs)
        bound_arguments.apply_defaults()

        # `BoundArguments.kwargs` only keeps values that cannot be passed
        # positionally: keyword-only parameters and the contents of a `**kwargs`
        # parameter. Neither is supported here.
        if not bound_arguments.kwargs:
            return (), bound_arguments.arguments
        raise ValueError("Keyword-only arguments are not supported.")
218
219
class MergeKwargsIntoArgsInputStep(InputAdaptStep):
    """Fold the keyword arguments into the positional argument list."""

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Append the values of ``model_kwargs`` after ``model_args``.

        Values are appended in the mapping's iteration order and the keys are
        dropped.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs.
            model: The PyTorch model (unused by this step).

        Returns:
            A pair whose first item is a tuple of the merged args and whose second
            item is always an empty dict.
        """
        merged_args = (*model_args, *model_kwargs.values())
        return merged_args, {}
240
241
class LiftParametersAndBuffersIntoArgsInputStep(InputAdaptStep):
    """Extend the positional arguments with the model's parameters and buffers."""

    def __init__(self, inputs: tuple[torch.Tensor, ...]) -> None:
        # Tensors appended after the user's positional args on every `apply` call.
        self.inputs = inputs

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Return ``model_args`` followed by the stored ``inputs``.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs, passed through untouched.
            model: The PyTorch model (unused by this step).

        Returns:
            A pair of the extended args tuple and the original ``model_kwargs``.
        """
        extended_args = tuple(model_args) + tuple(self.inputs)
        return extended_args, model_kwargs
265
266
class ConvertComplexToRealRepresentationInputStep(InputAdaptStep):
    """Turn complex-valued tensor inputs into their real-valued representation.

    ONNX has no complex dtypes. Each complex tensor among the positional args is
    therefore replaced by a real tensor with an additional dimension holding the
    real and imaginary components (via :func:`torch.view_as_real`).
    """

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Replace complex tensors in ``model_args`` with real-valued views.

        Only positional args are inspected; ``model_kwargs`` is passed through.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs.
            model: The PyTorch model (unused by this step).

        Returns:
            A pair of the converted args (as a tuple) and ``model_kwargs``.
        """

        def _as_real(value: Any) -> Any:
            if isinstance(value, torch.Tensor) and value.is_complex():
                # `view_as_real` rejects tensors with an unresolved conjugate bit.
                return torch.view_as_real(value.resolve_conj())
            return value

        return tuple(map(_as_real, model_args)), model_kwargs
301
302
class RemoveNoneInputStep(InputAdaptStep):
    """Remove `None` from arguments.

    This adapt step requires ``model_kwargs`` to be empty. It also assumes
    ``model_args`` is flattened, i.e. it does not check `None` inside nested
    collections.
    """

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Remove `None` from arguments.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs; must be empty.
            model: The PyTorch model (unused by this step).

        Returns:
            A tuple of the model args (without `None` entries) and an empty kwargs dict.

        Raises:
            ValueError: If `model_kwargs` is not empty.
        """
        # Explicit check (instead of `assert`) so it matches the documented
        # ValueError and is not stripped under `python -O`.
        if model_kwargs:
            raise ValueError(
                f"{type(self).__name__} expects empty model_kwargs, "
                f"got keys {list(model_kwargs)}."
            )
        return tuple(arg for arg in model_args if arg is not None), {}
331
332
class RemoveNonTensorInputStep(InputAdaptStep):
    """Remove the non-tensor input arguments.

    Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534).

    Specifically, it does put the input into the graph as a node, but no other node
    consumes it. The concrete value is embedded into the graph as a constant arg of a
    target node. Meta suggests in this case that one should rewrite the model code to
    make it a tensor if the input value is supposed to change at runtime. We might need
    to further investigate the feasibility of that suggestion.

    For example,

        def func(x, b=1.0):
            y = x + b
            z = y.relu()
            return (y, z)

        x = torch.randn(1, 1, 2, dtype=torch.float32)
        gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real")

        # class GraphModule(torch.nn.Module):
        #     def forward(self, x, b):
        #         arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec)
        #         # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b
        #         add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0);  arg0 = None

        #         # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu()
        #         relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor)
        #         return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec)

    The unused torch.fx.Node input is dropped from the ONNX graph, which leads to a
    mismatch between the number of ONNX inputs and PyTorch inputs. Thus, we delete
    the useless input here.
    """

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Remove `int`, `float`, `bool` and `str` values from the arguments.

        Args:
            model_args: The model args (assumed to be already flattened).
            model_kwargs: The model kwargs; must be empty.
            model: The PyTorch model (unused by this step).

        Returns:
            A tuple of the remaining model args and an empty kwargs dict.

        Raises:
            ValueError: If `model_kwargs` is not empty.
        """
        # Explicit check (instead of `assert`) so it matches the documented
        # ValueError and is not stripped under `python -O`.
        if model_kwargs:
            raise ValueError(
                f"{type(self).__name__} expects empty model_kwargs, "
                f"got keys {list(model_kwargs)}."
            )
        return (
            tuple(
                arg
                for arg in model_args
                if not isinstance(arg, (int, float, bool, str))
            ),
            {},
        )
397
398
class FlattenInputWithTreeSpecValidationInputStep(InputAdaptStep):
    """Flatten nested collection types and return a flat list of elements.

    ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor,
    etc).

    This class stores the `TreeSpec` produced when `apply` is called the first time.
    It then validates the `TreeSpec` produced by later `apply` calls against it.
    """

    # `TreeSpec` of `(model_args, model_kwargs)` recorded on the first `apply` call;
    # `None` until then.
    _spec: pytree.TreeSpec | None = None

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Flatten the model args and kwargs and validate their `TreeSpec`.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs.
            model: The PyTorch model (unused by this step).

        Returns:
            A tuple of the flattened model args and kwargs. The kwargs is empty, because
            they are flattened and merged into the args.

        Raises:
            ValueError: If the `TreeSpec` of the current `(model_args, model_kwargs)`
                is not equivalent to the `TreeSpec` recorded from the first inputs
                that were passed to this method.
        """
        flattened_args, spec = pytree.tree_flatten((model_args, model_kwargs))
        if self._spec is None:
            self._spec = spec
        else:
            _assert_identical_pytree_spec(
                self._spec,
                spec,
                error_message="Model inputs incompatible with the format that was exported. ",
            )
        return flattened_args, {}
443
444
class FlattenOutputStep(OutputAdaptStep):
    """Flatten nested collection types and return a flat list of elements.

    ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor,
    etc).

    NOTE: Ideally we would want to use ``FlattenOutputWithTreeSpecValidationOutputStep``, such
    that `TreeSpec` can be validated for new model outputs. However, this is not possible
    currently because we never have access to real PyTorch model outputs during export.
    Only traced outputs may be available, but they are not an accurate reflection of the
    original PyTorch model outputs format as they are typically in their own unique format,
    depending on the tracing strategy.
    """

    def apply(
        self,
        model_outputs: Any,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> Sequence[Any]:
        """Flatten the model outputs.

        Args:
            model_outputs: The model outputs to flatten.
            model: The PyTorch model (unused by this step).

        Returns:
            A list of the flattened model outputs (the pytree leaves).
        """
        return pytree.tree_leaves(model_outputs)
474
475
class ConvertComplexToRealRepresentationOutputStep(OutputAdaptStep):
    """Convert complex dtype tensors to real representation tensors.

    ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors
    to real representation tensors (i.e., float dtype tensors with an extra dimension
    representing the real and imaginary parts of the complex number).

    """

    def apply(
        self,
        model_outputs: Any,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> Any:
        """Convert complex tensors in the outputs to real-representation tensors.

        Args:
            model_outputs: The model outputs. They are iterated element by element,
                so presumably they have already been flattened (TODO confirm
                against the step ordering used by callers).
            model: The PyTorch model (unused by this step).

        Returns:
            A list of the model outputs in which every complex tensor is replaced
            by ``torch.view_as_real`` of it; other elements are kept as-is.
        """
        return [
            # `resolve_conj` first: `view_as_real` rejects unresolved conjugate tensors.
            torch.view_as_real(output.resolve_conj())
            if isinstance(output, torch.Tensor) and torch.is_complex(output)
            else output
            for output in model_outputs
        ]
505
506
class FlattenOutputWithTreeSpecValidationOutputStep(OutputAdaptStep):
    """Flatten model outputs like ``FlattenOutputStep`` and also validate their `TreeSpec`.

    The `TreeSpec` seen on the first `apply` call is remembered. Each later call
    checks the `TreeSpec` of its outputs against it.
    """

    # `TreeSpec` recorded on the first `apply` call; `None` until then.
    _spec: pytree.TreeSpec | None = None

    def apply(
        self,
        model_outputs: Any,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> Sequence[Any]:
        """Flatten ``model_outputs`` and check their structure against the recorded one.

        Args:
            model_outputs: The model outputs to flatten.
            model: The PyTorch model (unused by this step).

        Returns:
            The list of pytree leaves of ``model_outputs``.

        Raises:
            ValueError: If the `TreeSpec` of ``model_outputs`` is not equivalent to
                the one recorded from the first ``model_outputs`` passed to this
                method.
        """
        leaves, observed_spec = pytree.tree_flatten(model_outputs)
        if self._spec is not None:
            _assert_identical_pytree_spec(
                self._spec,
                observed_spec,
                error_message="Model outputs incompatible with the format that was exported. ",
            )
        else:
            self._spec = observed_spec
        return leaves
545
546
class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep):
    """Prepend model parameters, buffers and constants to the user input.

    :func:`torch.export.export` lifts model parameters, buffers and constants as model input, thus, they
    must be added to the user input before the model is executed.

    The values are read from the ``model`` argument of :meth:`apply`, which is
    expected to be a :class:`torch.export.ExportedProgram` (it must provide
    ``state_dict``, ``constants`` and ``graph_signature``).
    """

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Prepend the lifted parameters, buffers and tensor constants to the inputs.

        Args:
            model_args: The user-supplied model args.
            model_kwargs: The user-supplied model kwargs. If non-empty, their values
                are appended after ``model_args``.
            model: The exported program providing the lifted values.

        Returns:
            A tuple of the model args and kwargs. The args are ordered as parameters,
            buffers, lifted tensor constants, user args, then user kwarg values; the
            kwargs are always empty.
        """
        ordered_params = tuple(
            model.state_dict[name]  # type: ignore[union-attr,index]
            for name in model.graph_signature.parameters  # type: ignore[union-attr]
        )
        # Non-persistent buffers are not part of a module's state_dict, so they are
        # looked up in `model.constants` instead.
        non_persistent_buffers = set(model.graph_signature.non_persistent_buffers)  # type: ignore[union-attr]
        ordered_buffers = tuple(
            model.constants[name]  # type: ignore[union-attr,index]
            if name in non_persistent_buffers
            else model.state_dict[name]  # type: ignore[union-attr,index]
            for name in model.graph_signature.buffers  # type: ignore[union-attr]
        )
        ordered_constant_tensors = tuple(
            model.constants[fqn]  # type: ignore[union-attr,index]
            for fqn in model.graph_signature.lifted_tensor_constants  # type: ignore[union-attr]
        )

        # NOTE: calling convention is first params, then buffers, then lifted tensor
        # constants, then args as user supplied them (followed by kwarg values, if any).
        # See: torch/_functorch/aot_autograd.py#L1034
        updated_args = (
            *ordered_params,
            *ordered_buffers,
            *ordered_constant_tensors,
            *model_args,
        )
        if model_kwargs:
            return MergeKwargsIntoArgsInputStep().apply(
                updated_args, model_kwargs, model=model
            )
        return updated_args, {}
602
603
class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep):
    """Prepend model's mutated buffers to the user output.

    :func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they
    must be added to the user output after the model is executed.

    The buffer values are read from the ``model`` argument of :meth:`apply`, which
    must be a :class:`torch.export.ExportedProgram`.
    """

    def apply(
        self,
        model_outputs: Any,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> Sequence[Any]:
        """Prepend the values of the mutated buffers to the model outputs.

        Args:
            model_outputs: The model outputs. They are unpacked with ``*``, so they
                must be iterable (presumably already flattened; TODO confirm).
            model: The exported program whose ``graph_signature.buffers_to_mutate``
                names the mutated buffers.

        Returns:
            A tuple of the mutated buffer values followed by the model outputs.

        Raises:
            AssertionError: If ``model`` is not a ``torch.export.ExportedProgram``.
        """

        assert isinstance(
            model, torch_export.ExportedProgram
        ), "'model' must be torch_export.ExportedProgram"
        # A mutated buffer lives in `state_dict` unless it is absent there, in which
        # case it is looked up in `constants`.
        ordered_buffers = tuple(
            model.state_dict[name]
            if name in model.state_dict
            else model.constants[name]
            for name in model.graph_signature.buffers_to_mutate.values()
        )

        # NOTE: calling convention is first mutated buffers, then outputs args as model returned them.
        updated_outputs = (*ordered_buffers, *model_outputs)
        return updated_outputs
642