xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/fx/decomposition_table.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""Dispatcher for AtenLib functions from onnx-script."""
3
4from __future__ import annotations
5
6from typing import Callable
7
8import torch
9import torch._ops
10import torch.fx
11from torch.onnx._internal.fx import registration
12
13
14def _create_onnx_supports_op_overload_table(
15    registry,
16) -> set[torch._ops.OperatorBase | Callable]:
17    """
18    Creates a set of OperatorBase and Callable objects that represent ONNX-supported PyTorch operations.
19
20    Args:
21        registry (OnnxRegistry): The ONNX registry for PyTorch.
22
23    Returns:
24        A collection of OperatorBase and Callable objects representing ONNX-supported PyTorch operations.
25    """
26    table: set[torch._ops.OperatorBase | Callable] = set()
27
28    # Some ops in `torch.ops.aten` are not discoverable through `dir(torch.ops.aten)`,
29    # but retrievable via explicit lookup.
30    # https://github.com/pytorch/pytorch/issues/99681
31    # This is a workaround to make sure we register ONNX symbolic functions for these.
32    onnx_supported_aten_lookup_table = [
33        k.split("::")[1].split(".")[0]
34        for k in registry._all_registered_ops()
35        if k.startswith("aten::")
36    ]
37
38    for op_namespace in (torch.ops.aten, torch.ops.prims):
39        attr_names = dir(op_namespace)
40        if op_namespace is torch.ops.aten:
41            attr_names += onnx_supported_aten_lookup_table
42        for attr_name in attr_names:
43            if not hasattr(op_namespace, attr_name):
44                # torchlib owns some attributes that are not aten ops.
45                continue
46            op_overload_packet = getattr(op_namespace, attr_name)
47            if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket):
48                continue
49
50            for overload_name in op_overload_packet.overloads():
51                op_overload = getattr(op_overload_packet, overload_name)
52                internal_op_name = registration.OpName.from_qualified_name(
53                    qualified_name=op_overload.name()
54                )
55                # NOTE: If the overload is supported in registry or it's default overload is supported in registry,
56                # we add it to the table.
57                if registry.is_registered_op(
58                    namespace=internal_op_name.namespace,
59                    op_name=internal_op_name.op_name,
60                    overload=internal_op_name.overload,
61                ) or registry.is_registered_op(
62                    namespace=internal_op_name.namespace,
63                    op_name=internal_op_name.op_name,
64                    overload=None,
65                ):
66                    # This line maps torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar, torch.ops.aten.add.out, etc
67                    # to "aten::add". This means the exporter for "aten::add" is used for all overloads of "aten::add".
68                    # This is applied to all ops under torch.ops.aten.
69                    table.add(op_overload)
70    return table
71
72
def create_onnx_friendly_decomposition_table(
    registry,
) -> dict[torch._ops.OperatorBase, Callable]:
    """Build a decomposition table restricted to ops without ONNX symbolics.

    Starts from PyTorch's built-in aten-to-aten decompositions and drops every
    entry whose op overload already has an ONNX symbolic function, so the
    exporter decomposes only what it cannot translate directly.

    Args:
        registry (torch.onnx.OnnxRegistry): The ONNX registry for PyTorch.

    Returns:
        Dict[torch._ops.OperatorBase, Callable]: Maps each op overload to its
        decomposition function.
    """
    # Every overload the ONNX exporter can already handle; e.g. this set
    # treats torch.ops.aten.add.Tensor as covered by the "aten::add" exporter.
    onnx_supported = _create_onnx_supports_op_overload_table(registry)

    # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single
    # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your
    # definitions in a single TORCH_LIBRARY block.
    result: dict[torch._ops.OperatorBase, Callable] = {}
    for overload, decomp_fn in torch._decomp.decomposition_table.items():  # type: ignore[attr-defined]
        # Skip decompositions defined in 'torch._refs' (they lower into
        # "prim::*" ops, which ONNX generally does not support) and skip any
        # overload that already has an ONNX symbolic function.
        if "torch._refs" in decomp_fn.__module__ or overload in onnx_supported:
            continue
        result[overload] = decomp_fn

    # NOTE: Some core ATen / torch._refs ops are not decomposed into prim::
    # ops; pick those back up here.
    for overload, decomp_fn in torch._decomp.core_aten_decompositions().items():
        if overload not in onnx_supported:
            result[overload] = decomp_fn
    return result
117