xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/fx/decomposition_skip.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""A context manager that disables the decomposition of certain ops during dynamo tracing.
3
4The approach is to temporarily hijack the operator callable with PT2 custom operator.
5The custom operator will not be decomposed and will show up as a single node to be exported to ONNX.
6
7For the time being the decomposition of these ops is otherwise unavoidable.
8
9https://github.com/pytorch/pytorch/issues/116684
10https://github.com/pytorch/pytorch/issues/115883
11
This workaround will no longer be required once these issues are resolved.
13"""
14
15from __future__ import annotations
16
17import abc
18import contextlib
19from typing import Callable, Sequence
20
21from onnxscript.function_libs.torch_lib.ops import (  # type: ignore[import-not-found]
22    core as torchlib_core,
23    nn as torchlib_nn,
24)
25
26import torch
27from torch._decomp import decompositions
28
29
30_NEW_OP_NAMESPACE: str = "onnx_export"
31"""The namespace for the custom operator."""
32
33
class DecompSkip(abc.ABC):
    """Base description of one operator whose decomposition is to be skipped.

    A concrete subclass names the operator callable to hijack, the custom
    operator that temporarily stands in for it, and the ONNXScript function
    used to export that custom operator.
    """

    # The original operator callable to skip decomposition.
    op_callable: Callable
    # The ONNXScript function to be registered for exporting the custom operator.
    onnxscript_function: Callable

    # The name for the custom operator.
    new_op_name: str
    # The schema for the custom operator. This should match the signature of
    # the original operator.
    new_op_schema: str

    @classmethod
    @abc.abstractmethod
    def register(cls, export_options: torch.onnx.ExportOptions):
        """Register the custom operator and override the original operator.

        Implementations are expected to perform, in this order:

        1. Register the custom operator.
        2. Override the original operator with the replacement callable.
        3. Register the ONNXScript function for exporting the custom operator.
        """
        ...

    @classmethod
    @abc.abstractmethod
    def unregister(cls):
        """Put the original operator callable back in place."""
        ...

    @classmethod
    @abc.abstractmethod
    def abstract(cls, *args, **kwargs):
        """Abstract impl (meta kernel) describing the operator's output."""
        ...

    @classmethod
    def register_custom_op(cls):
        """Define ``<namespace>::<new_op_name>`` and attach its kernels."""
        qualname = "::".join((_NEW_OP_NAMESPACE, cls.new_op_name))
        torch.library.define(qualname, cls.new_op_schema)
        # Eager kernel: forwards straight to the original operator.
        torch.library.impl(qualname, "default", cls.replacement)
        # Fake kernel: the subclass-provided meta implementation.
        torch.library.register_fake(qualname, cls.abstract)

    @classmethod
    def replacement(cls, *args, **kwargs):
        """Invoke the hijacked operator unchanged.

        Same signature and eager behavior as the original operator.
        """
        original = cls.op_callable
        return original(*args, **kwargs)
85
86
class UpsampleBilinear2DDecompSkip(DecompSkip):
    """Keep ``torch._C._nn.upsample_bilinear2d`` as a single exported node."""

    op_callable = torch._C._nn.upsample_bilinear2d  # type: ignore[attr-defined]
    onnxscript_function = torchlib_nn.aten_upsample_bilinear2d_vec  # type: ignore[attr-defined]
    new_op_name = "upsample_bilinear2d"
    new_op_schema = "(Tensor self, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)"

    @classmethod
    def register(cls, export_options: torch.onnx.ExportOptions):
        """Patch in the custom op and register its ONNXScript exporter.

        The custom op is defined only if the namespace does not already expose
        it. An ``OnnxRegistry`` is created on ``export_options`` when missing.

        Args:
            export_options: Options whose ``onnx_registry`` receives the
                ONNXScript function for the custom op.
        """
        # Resolve namespace and op name through the shared constants instead of
        # a hard-coded ``torch.ops.onnx_export`` so that the patched callable
        # and the registry entry below always refer to the same namespace.
        namespace = getattr(torch.ops, _NEW_OP_NAMESPACE, None)
        if namespace is None or not hasattr(namespace, cls.new_op_name):
            cls.register_custom_op()
        custom_op = getattr(getattr(torch.ops, _NEW_OP_NAMESPACE), cls.new_op_name)

        torch._C._nn.upsample_bilinear2d = custom_op  # type: ignore[attr-defined]

        if export_options.onnx_registry is None:
            export_options.onnx_registry = torch.onnx.OnnxRegistry()
        export_options.onnx_registry.register_op(
            function=cls.onnxscript_function,
            namespace=_NEW_OP_NAMESPACE,
            op_name=cls.new_op_name,
        )

    @classmethod
    def unregister(cls):
        """Restore the callable captured in ``op_callable``."""
        torch._C._nn.upsample_bilinear2d = cls.op_callable  # type: ignore[attr-defined]

    @classmethod
    def abstract(cls, input, output_size, align_corners, scale_factors):
        """Fake kernel for the custom op.

        Returns an uninitialized tensor with ``input``'s dtype and device whose
        shape is the first two input dimensions followed by the spatial sizes
        from ``decompositions.upsample_compute_output_size``.
        """
        osize = decompositions.upsample_compute_output_size(
            input.size(), output_size, scale_factors
        )
        return torch.empty(
            (input.size(0), input.size(1), *osize),
            dtype=input.dtype,
            device=input.device,
        )
123
124
class UpsampleTrilinear3DDecompSkip(DecompSkip):
    """Keep ``torch._C._nn.upsample_trilinear3d`` as a single exported node."""

    op_callable = torch._C._nn.upsample_trilinear3d  # type: ignore[attr-defined]
    onnxscript_function = torchlib_nn.aten_upsample_trilinear3d_vec  # type: ignore[attr-defined]
    new_op_name = "upsample_trilinear3d"
    new_op_schema = "(Tensor self, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)"

    @classmethod
    def register(cls, export_options: torch.onnx.ExportOptions):
        """Patch in the custom op and register its ONNXScript exporter.

        The custom op is defined only if the namespace does not already expose
        it. An ``OnnxRegistry`` is created on ``export_options`` when missing.

        Args:
            export_options: Options whose ``onnx_registry`` receives the
                ONNXScript function for the custom op.
        """
        # Resolve namespace and op name through the shared constants instead of
        # a hard-coded ``torch.ops.onnx_export`` so that the patched callable
        # and the registry entry below always refer to the same namespace.
        namespace = getattr(torch.ops, _NEW_OP_NAMESPACE, None)
        if namespace is None or not hasattr(namespace, cls.new_op_name):
            cls.register_custom_op()
        custom_op = getattr(getattr(torch.ops, _NEW_OP_NAMESPACE), cls.new_op_name)

        torch._C._nn.upsample_trilinear3d = custom_op  # type: ignore[attr-defined]

        if export_options.onnx_registry is None:
            export_options.onnx_registry = torch.onnx.OnnxRegistry()
        export_options.onnx_registry.register_op(
            function=cls.onnxscript_function,
            namespace=_NEW_OP_NAMESPACE,
            op_name=cls.new_op_name,
        )

    @classmethod
    def unregister(cls):
        """Restore the callable captured in ``op_callable``."""
        torch._C._nn.upsample_trilinear3d = cls.op_callable  # type: ignore[attr-defined]

    @classmethod
    def abstract(cls, input, output_size, align_corners, scale_factors):
        """Fake kernel for the custom op.

        ``decompositions.upsample_compute_output_size`` returns one size per
        spatial dimension (three for a 5-D ``(N, C, D, H, W)`` input). The
        output shape is therefore the batch and channel dimensions followed by
        exactly those sizes. The input depth must not be carried over, or the
        fake tensor would become 6-D.
        """
        osize = decompositions.upsample_compute_output_size(
            input.size(), output_size, scale_factors
        )
        return torch.empty(
            (input.size(0), input.size(1), *osize),
            dtype=input.dtype,
            device=input.device,
        )
161
162
class InstanceNormDecompSkip(DecompSkip):
    """Keep ``torch.instance_norm`` as a single exported node."""

    op_callable = torch.instance_norm  # type: ignore[attr-defined]
    onnxscript_function = torchlib_core.aten_instance_norm  # type: ignore[attr-defined]
    new_op_name = "instance_norm"
    new_op_schema = (
        "(Tensor input, Tensor? weight, Tensor? bias, "
        "Tensor? running_mean, Tensor? running_var, "
        "bool use_input_stats, float momentum, float eps, "
        "bool cudnn_enabled) -> Tensor"
    )

    @classmethod
    def register(cls, export_options: torch.onnx.ExportOptions):
        """Patch in the custom op and register its ONNXScript exporter.

        The custom op is defined only if the namespace does not already expose
        it. An ``OnnxRegistry`` is created on ``export_options`` when missing.

        Args:
            export_options: Options whose ``onnx_registry`` receives the
                ONNXScript function for the custom op.
        """
        # Resolve namespace and op name through the shared constants instead of
        # a hard-coded ``torch.ops.onnx_export`` so that the patched callable
        # and the registry entry below always refer to the same namespace.
        namespace = getattr(torch.ops, _NEW_OP_NAMESPACE, None)
        if namespace is None or not hasattr(namespace, cls.new_op_name):
            cls.register_custom_op()
        custom_op = getattr(getattr(torch.ops, _NEW_OP_NAMESPACE), cls.new_op_name)

        torch.instance_norm = custom_op  # type: ignore[attr-defined]

        if export_options.onnx_registry is None:
            export_options.onnx_registry = torch.onnx.OnnxRegistry()
        export_options.onnx_registry.register_op(
            function=cls.onnxscript_function,
            namespace=_NEW_OP_NAMESPACE,
            op_name=cls.new_op_name,
        )

    @classmethod
    def unregister(cls):
        """Restore the callable captured in ``op_callable``."""
        torch.instance_norm = cls.op_callable  # type: ignore[attr-defined]

    @classmethod
    def abstract(
        cls,
        input,
        weight,
        bias,
        running_mean,
        running_var,
        use_input_stats: bool,
        momentum: float,
        eps: float,
        cudnn_enabled: bool,
    ):
        """Fake kernel: an uninitialized tensor with ``input``'s size, dtype and device."""
        return torch.empty(
            input.size(),
            dtype=input.dtype,
            device=input.device,
        )
213
214
# Skips applied by ``enable_decomposition_skips`` when the caller passes none;
# they are registered in the order listed here. Kept as a ``list`` so any code
# that already treats it as one keeps working.
_DEFAULT_SKIP_LIST: list[type[DecompSkip]] = [
    UpsampleBilinear2DDecompSkip,
    InstanceNormDecompSkip,
    UpsampleTrilinear3DDecompSkip,
]
220
221
@contextlib.contextmanager
def enable_decomposition_skips(
    export_options: torch.onnx.ExportOptions,
    skips: Sequence[type[DecompSkip]] = _DEFAULT_SKIP_LIST,
):
    """A context manager that enables the decomposition skips.

    The original operator callables that are otherwise decomposed are replaced
    with custom operators. The ONNXScript functions for exporting the custom
    operators are added to the ONNX registry inside ``export_options``.

    On exit, normal or exceptional, every skip whose ``register`` was attempted
    is unregistered in reverse order. Skips that were never reached because an
    earlier ``register`` raised are left untouched. If one ``unregister``
    raises, the remaining ones still run before the error propagates.

    Args:
        export_options: Export options whose ONNX registry receives the
            ONNXScript functions (a skip may create the registry if missing).
        skips: Skip classes to apply, registered in the given order.
    """
    with contextlib.ExitStack() as cleanup:
        for skip in skips:
            # Schedule the restore *before* registering so that a ``register``
            # that fails part-way (e.g. after patching the operator but before
            # updating the registry) is still rolled back.
            cleanup.callback(skip.unregister)
            skip.register(export_options)
        yield
239