xref: /aosp_15_r20/external/pytorch/torch/fx/passes/operator_support.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import abc
3import typing as t
4
5import torch
6import torch.fx
7from torch.fx._compatibility import compatibility
8from .shape_prop import TensorMetadata
9from .tools_common import get_node_target, CALLABLE_NODE_OPS
10
11
# Public API of this module.
__all__ = [
    'OperatorSupportBase',
    'OperatorSupport',
    'create_op_support',
    'chain',
    'OpSupports',
    'any_chain',
]
13
# fx.Node.target typename, as returned by `get_node_target()`
TargetTypeName = str

# Arguments' dtypes for a given node, see `OperatorSupport`.
# ``None`` means every dtype is accepted. Otherwise it is a pair of:
#   * a sequence indexed like ``node.args``; each entry lists the dtypes
#     accepted at that position, or is ``None`` to skip checking it;
#   * a dict mapping a kwarg name to the dtypes accepted for that kwarg.
SupportedArgumentDTypes = t.Optional[
    t.Tuple[
        t.Sequence[t.Optional[t.Sequence[torch.dtype]]],
        t.Dict[str, t.Sequence[torch.dtype]],
    ]
]

# Maps a target typename to the dtype rule applied to that target.
SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]
26
27
@compatibility(is_backward_compatible=False)
class OperatorSupportBase(abc.ABC):
    """Abstract interface deciding whether a backend can handle an fx.Node.

    Concrete subclasses must implement `is_node_supported`.
    """

    @abc.abstractmethod
    def is_node_supported(
        self,
        submodules: t.Mapping[str, torch.nn.Module],
        node: torch.fx.Node,
    ) -> bool:
        """Return whether ``node`` is supported.

        Args:
            submodules: mapping from module name to module (for example built
                from ``model.named_modules()``).
            node: the fx node being classified.
        """
        raise NotImplementedError
36
37
@compatibility(is_backward_compatible=False)
class OperatorSupport(OperatorSupportBase):
    """
    Table-driven operator support.

    ``_support_dict`` maps a node.target typename (the string returned by
    ``get_node_target()``) to a rule describing which input dtypes are
    accepted for that target:

    * ``None``: every dtype is accepted.
    * ``(args_rule, kwargs_rule)``, shaped like
      ``(([dtypes], ...), {"name": [dtypes], ...})``:

      - ``args_rule`` is a sequence whose i-th entry lists the dtypes accepted
        for ``node.args[i]``. An entry of ``None`` skips that position, so
        ``(None, [torch.float])`` ignores the first positional input and
        requires the second one to be ``torch.float``.
      - ``kwargs_rule`` maps a keyword name to the dtypes accepted for
        ``node.kwargs[name]``. Keywords that are not listed are not checked.

    Only inputs that are ``fx.Node`` objects are dtype-checked; any other
    value passes unconditionally. Targets absent from ``_support_dict`` are
    reported as unsupported.
    """

    _support_dict: SupportDict

    def __init__(
        self,
        support_dict: t.Optional[SupportDict] = None
    ):
        # Both None and an empty mapping end up as a fresh empty table.
        self._support_dict = {} if not support_dict else support_dict

    def is_node_supported(
        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        """
        Decide whether ``node`` satisfies the rule registered for its target.

        Args:
            `submodules`: mapping from module name to module, e.g. obtained
                          from model.named_modules().

            `node`: the Fx node under consideration.

        Returns:
            `is_supported`: True when the node is accepted, False otherwise.
        """
        # Nodes whose op is not in CALLABLE_NODE_OPS are accepted without
        # consulting the table.
        if node.op not in CALLABLE_NODE_OPS:
            return True

        target = get_node_target(submodules, node)

        # Unknown target: this op is not supported at all.
        if target not in self._support_dict:
            return False

        rule = self._support_dict[target]
        # A ``None`` rule accepts any dtype.
        if rule is None:
            return True

        args_dtypes, kwargs_dtypes = rule  # type: ignore[misc]

        # Positional inputs: zip stops at whichever of args / rule is shorter.
        for value, accepted in zip(node.args, args_dtypes):
            # ``None`` opts this position out; non-Node values are not checked.
            if accepted is None or not isinstance(value, torch.fx.Node):
                continue
            if _get_arg_dtype(value) not in accepted:
                return False

        # Keyword inputs: only names listed in the rule are inspected.
        for name, accepted in kwargs_dtypes.items():
            value = node.kwargs.get(name)
            if not isinstance(value, torch.fx.Node):
                continue
            if _get_arg_dtype(value) not in accepted:
                return False

        return True
125
126
# ======================================================================
# Functional interfaces and helpers: build basic operator support checks
# and compose them into more complex ones.
# ======================================================================

# Call signature shared by plain predicate functions and
# `OperatorSupportBase.is_node_supported`: (submodules, node) -> bool.
IsNodeSupported = t.Callable[
    [t.Mapping[str, torch.nn.Module], torch.fx.Node],
    bool,
]
133
134
@compatibility(is_backward_compatible=False)
def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase:
    """Adapt a plain predicate into an `OperatorSupportBase` instance.

    Args:
        is_node_supported: callable taking ``(submodules, node)`` and
            returning a bool. This matches the signature of
            `OperatorSupportBase.is_node_supported` without ``self``.

    Returns:
        A new `OperatorSupportBase` object. Its ``is_node_supported``
        forwards every call to the given callable.
    """
    predicate = is_node_supported

    class FunctionalOperatorSupport(OperatorSupportBase):
        def is_node_supported(
                self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
        ) -> bool:
            # Pure delegation to the wrapped callable.
            return predicate(submodules, node)

    wrapper = FunctionalOperatorSupport()
    return wrapper
148
149
@compatibility(is_backward_compatible=False)
def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
    """Combine several `OperatorSupportBase` instances with logical AND.

    The returned instance reports a node as supported only if every input
    instance reports it as supported. Inputs are queried in order, and
    evaluation stops at the first one that reports False. With no inputs,
    every node is reported as supported.
    """

    def _chain(submods, node) -> bool:
        for checker in op_support:
            if not checker.is_node_supported(submods, node):
                return False
        return True

    return create_op_support(_chain)
162
163
@compatibility(is_backward_compatible=False)
def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
    """Combine several `OperatorSupportBase` instances with logical OR.

    The returned instance reports a node as supported as soon as any input
    instance reports it as supported. Inputs are queried in order, and
    evaluation stops at the first one that reports True. With no inputs,
    every node is reported as unsupported.
    """

    def _any_chain(submods, node) -> bool:
        for checker in op_support:
            if checker.is_node_supported(submods, node):
                return True
        return False

    return create_op_support(_any_chain)
176
177
@compatibility(is_backward_compatible=False)
class OpSupports:
    """Factory namespace of small, composable `OperatorSupportBase` checks.

    Each classmethod returns a standalone `OperatorSupportBase` instance.
    The instances can be combined into more complex operator support logic.
    """

    @classmethod
    def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
        """Report a node as unsupported if any of its input nodes has ``dtype``.

        The dtype of each input node is obtained via `_get_arg_dtype`.
        """

        def _decline_if_input_dtype(
            submodules: t.Mapping[str, torch.nn.Module],
            node: torch.fx.Node,
        ) -> bool:
            # Stops at the first input whose dtype matches.
            return not any(
                _get_arg_dtype(input_node) == dtype
                for input_node in node.all_input_nodes
            )

        return create_op_support(_decline_if_input_dtype)

    @classmethod
    def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
        """Report a node as unsupported if its ``name`` is in ``disallow_set``."""

        def _decline_if_node_in_names(
            submodules: t.Mapping[str, torch.nn.Module],
            node: torch.fx.Node,
        ) -> bool:
            is_disallowed = node.name in disallow_set
            return not is_disallowed

        return create_op_support(_decline_if_node_in_names)
209
210
211def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
212    assert isinstance(arg, torch.fx.Node)
213    tensor_meta = arg.meta.get("tensor_meta")  # type: ignore[union-attr]
214    dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"]
215    return dtype
216