# mypy: allow-untyped-defs
import abc
import typing as t

import torch
import torch.fx
from torch.fx._compatibility import compatibility
from .shape_prop import TensorMetadata
from .tools_common import get_node_target, CALLABLE_NODE_OPS


__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports', 'any_chain']

# fx.Node.target typename, as returned by `get_node_target()`
TargetTypeName = str

# Arguments' dtypes for a given node, see `OperatorSupport`.
# A `None` entry (positional or keyword) means "do not check this argument".
# The containers are covariant (`Sequence` / `Mapping`), so values typed
# against the stricter, `None`-free form remain valid.
SupportedArgumentDTypes = t.Optional[
    t.Tuple[
        t.Sequence[t.Optional[t.Sequence[torch.dtype]]],
        t.Mapping[str, t.Optional[t.Sequence[torch.dtype]]],
    ]
]

SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]


@compatibility(is_backward_compatible=False)
class OperatorSupportBase(abc.ABC):
    """Interface for determining if a fx.Node is supported by a backend"""
    @abc.abstractmethod
    def is_node_supported(
        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        """Return whether `node` is supported.

        Args:
            `submodules`: mapping from module name to the module.

            `node`: the Fx node to check.
        """
        raise NotImplementedError


@compatibility(is_backward_compatible=False)
class OperatorSupport(OperatorSupportBase):
    """
    `_support_dict` maps node.target typename to supported inputs dtypes.

    node.target typename is retrieved using helper function `get_node_target()`

    If supported inputs dtypes is None, it means any dtype is supported, else
    we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}).

    The first tuple ([dtypes], ...) indicates what dtypes are supported for
    inputs in node.args and the second dict {"name": [dtypes], ...} indicates
    what dtypes are supported for inputs in node.kwargs.

    For inputs in args, if we don't want to check it, we can put None there,
    e.g. (None, [torch.float]) indicates that we don't care about the type of
    the first input in args. And for inputs in kwargs, if not listed (or
    listed with None), will not be checked.
    """

    _support_dict: SupportDict

    def __init__(
        self,
        support_dict: t.Optional[SupportDict] = None
    ):
        # A missing (or empty) dict means that no callable target is supported.
        self._support_dict = support_dict or {}

    def is_node_supported(
        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        """
        Args:
            `submodules`: mapping from module name to the module. This can be
                          retrieved by calling dict(model.named_modules()).

            `node`: a Fx node that we want to determine whether it's supported.

        Returns:
            `is_supported`: whether the arg `node` is supported.
        """
        # Nodes whose op is not in CALLABLE_NODE_OPS are never rejected here
        if node.op not in CALLABLE_NODE_OPS:
            return True

        target = get_node_target(submodules, node)

        # Target not found in _support_dict meaning that we don't support this op at all
        if target not in self._support_dict:
            return False

        # Single lookup; binding it to a local also lets the None check narrow the type
        rule = self._support_dict[target]

        # The rule for target is None meaning that we accept any dtype
        if rule is None:
            return True

        args_dtypes, kwargs_dtypes = rule

        # Check args dtypes. zip() stops at the shorter of the two, so rules
        # without a matching positional argument (and arguments without a
        # rule) are not checked.
        for arg, dtypes in zip(node.args, args_dtypes):
            # None indicates we don't care about the dtype of this arg
            if dtypes is None:
                continue

            # If arg is not a node then we don't check it
            if not isinstance(arg, torch.fx.Node):
                continue

            if _get_arg_dtype(arg) not in dtypes:
                return False

        # Check kwargs dtypes
        for name, dtypes in kwargs_dtypes.items():
            # None indicates we don't care about the dtype of this kwarg
            # (same convention as for positional args)
            if dtypes is None:
                continue

            # A kwarg that is absent yields None here and is skipped by the
            # isinstance check, just like any other non-node value
            kwarg = node.kwargs.get(name)
            if not isinstance(kwarg, torch.fx.Node):
                continue

            if _get_arg_dtype(kwarg) not in dtypes:
                return False

        return True


# ======================================================================
# Functional interfaces and utils for defining basic
# operator support logic
# and composing them into more complex ones
# ======================================================================

# Same call signature as `OperatorSupportBase.is_node_supported`, minus `self`
IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool]


@compatibility(is_backward_compatible=False)
def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase:
    """Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance

    `IsNodeSupported` has the same call signature as
    `OperatorSupportBase.is_node_supported`
    """
    class FunctionalOperatorSupport(OperatorSupportBase):
        def is_node_supported(
            self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
        ) -> bool:
            # Delegate to the wrapped callable captured from the enclosing scope
            return is_node_supported(submodules, node)
    return FunctionalOperatorSupport()


@compatibility(is_backward_compatible=False)
def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
    """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
    instance by evaluating each input `OperatorSupportBase` instance, and returns False if
    any of them reports False. Evaluation stops at the first False; with no
    inputs the result is True.
    """
    def _chain(submods, node) -> bool:
        return all(
            x.is_node_supported(submods, node)
            for x in op_support
        )
    return create_op_support(_chain)


@compatibility(is_backward_compatible=False)
def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
    """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
    instance by evaluating each input `OperatorSupportBase` instance, and returns True if
    any of them reports True. Evaluation stops at the first True; with no
    inputs the result is False.
    """
    def _any_chain(submods, node) -> bool:
        return any(
            x.is_node_supported(submods, node)
            for x in op_support
        )
    return create_op_support(_any_chain)


@compatibility(is_backward_compatible=False)
class OpSupports:
    """A set of atomic `OperatorSupportBase` instances that can be combined together
    to form more complex operator support logic.
    """
    @classmethod
    def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
        """Report a node as non-supported, if any of its input nodes is of dtype"""

        def _decline_if_input_dtype(
            submodules: t.Mapping[str, torch.nn.Module],
            node: torch.fx.Node,
        ) -> bool:
            # Supported only when no input node carries the declined dtype;
            # a node without input nodes is therefore always supported.
            return not any(
                _get_arg_dtype(arg) == dtype
                for arg in node.all_input_nodes
            )
        return create_op_support(_decline_if_input_dtype)

    @classmethod
    def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
        """
        If a node has a name that is in the disallow set, report it as non-supported.
        """
        def _decline_if_node_in_names(
            submodules: t.Mapping[str, torch.nn.Module],
            node: torch.fx.Node,
        ) -> bool:
            return node.name not in disallow_set
        return create_op_support(_decline_if_node_in_names)


def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
    """Return the dtype recorded in the metadata of `arg`.

    The dtype of `arg.meta["tensor_meta"]` is used when that entry is a
    `TensorMetadata`; otherwise whatever is stored under `arg.meta["type"]`
    is returned as-is.

    Raises:
        KeyError: if `arg.meta` has no `TensorMetadata` under "tensor_meta"
            and no "type" entry either.
    """
    assert isinstance(arg, torch.fx.Node)
    tensor_meta = arg.meta.get("tensor_meta")  # type: ignore[union-attr]
    if isinstance(tensor_meta, TensorMetadata):
        return tensor_meta.dtype
    try:
        return arg.meta["type"]
    except KeyError:
        # Still a KeyError, but one that says which node lacks which metadata
        # instead of a bare KeyError: 'type'
        raise KeyError(
            f"fx node '{arg.name}' has neither a TensorMetadata under "
            "meta['tensor_meta'] nor a meta['type'] entry"
        ) from None