1# mypy: allow-untyped-defs 2"""Dispatcher for AtenLib functions from onnx-script.""" 3 4from __future__ import annotations 5 6from typing import Callable 7 8import torch 9import torch._ops 10import torch.fx 11from torch.onnx._internal.fx import registration 12 13 14def _create_onnx_supports_op_overload_table( 15 registry, 16) -> set[torch._ops.OperatorBase | Callable]: 17 """ 18 Creates a set of OperatorBase and Callable objects that represent ONNX-supported PyTorch operations. 19 20 Args: 21 registry (OnnxRegistry): The ONNX registry for PyTorch. 22 23 Returns: 24 A collection of OperatorBase and Callable objects representing ONNX-supported PyTorch operations. 25 """ 26 table: set[torch._ops.OperatorBase | Callable] = set() 27 28 # Some ops in `torch.ops.aten` are not discoverable through `dir(torch.ops.aten)`, 29 # but retrievable via explicit lookup. 30 # https://github.com/pytorch/pytorch/issues/99681 31 # This is a workaround to make sure we register ONNX symbolic functions for these. 32 onnx_supported_aten_lookup_table = [ 33 k.split("::")[1].split(".")[0] 34 for k in registry._all_registered_ops() 35 if k.startswith("aten::") 36 ] 37 38 for op_namespace in (torch.ops.aten, torch.ops.prims): 39 attr_names = dir(op_namespace) 40 if op_namespace is torch.ops.aten: 41 attr_names += onnx_supported_aten_lookup_table 42 for attr_name in attr_names: 43 if not hasattr(op_namespace, attr_name): 44 # torchlib owns some attributes that are not aten ops. 45 continue 46 op_overload_packet = getattr(op_namespace, attr_name) 47 if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket): 48 continue 49 50 for overload_name in op_overload_packet.overloads(): 51 op_overload = getattr(op_overload_packet, overload_name) 52 internal_op_name = registration.OpName.from_qualified_name( 53 qualified_name=op_overload.name() 54 ) 55 # NOTE: If the overload is supported in registry or it's default overload is supported in registry, 56 # we add it to the table. 57 if registry.is_registered_op( 58 namespace=internal_op_name.namespace, 59 op_name=internal_op_name.op_name, 60 overload=internal_op_name.overload, 61 ) or registry.is_registered_op( 62 namespace=internal_op_name.namespace, 63 op_name=internal_op_name.op_name, 64 overload=None, 65 ): 66 # This line maps torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar, torch.ops.aten.add.out, etc 67 # to "aten::add". This means the exporter for "aten::add" is used for all overloads of "aten::add". 68 # This is applied to all ops under torch.ops.aten. 69 table.add(op_overload) 70 return table 71 72 73def create_onnx_friendly_decomposition_table( 74 registry, 75) -> dict[torch._ops.OperatorBase, Callable]: 76 """ 77 This function creates a dictionary of op overloads and their decomposition functions 78 for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function, 79 its decomposition function is excluded from the table. The decomposition table is a subset of PyTorch's 80 built-in aten-to-aten decomposition. 81 82 Args: 83 registry (torch.onnx.OnnxRegistry): The ONNX registry for PyTorch. 84 85 Returns: 86 Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding 87 decomposition functions. 88 """ 89 decomposition_table: dict[torch._ops.OperatorBase, Callable] = {} 90 # Dictionary that maps torch.ops.aten.* to exporter look up key; e.g., 91 # _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE[torch.add.Tensor] is "aten::add". 92 _ONNX_SUPPORT_OP_OVERLOADS = _create_onnx_supports_op_overload_table(registry) 93 94 # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single 95 # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your 96 # definitions in a single TORCH_LIBRARY block. 97 for op_overload, decomp_fn in torch._decomp.decomposition_table.items(): # type: ignore[attr-defined] 98 # Skip decomposition into "prim::*" ops (defined in 'torch._refs'), because they 99 # are not generally supported by ONNX. 100 # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX 101 # symbolic function. 102 if ( 103 "torch._refs" in decomp_fn.__module__ 104 or op_overload in _ONNX_SUPPORT_OP_OVERLOADS 105 ): 106 continue 107 decomposition_table[op_overload] = decomp_fn 108 109 # NOTE: There are ops in core ATen and under torch._refs, 110 # that are not decomposed to prim::ops. We need to pick them 111 # back 112 for op_overload, decomp_fn in torch._decomp.core_aten_decompositions().items(): 113 if op_overload in _ONNX_SUPPORT_OP_OVERLOADS: 114 continue 115 decomposition_table[op_overload] = decomp_fn 116 return decomposition_table 117