# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from typing import Any, Dict, Optional, Sequence, Tuple

import torch
from executorch.exir.dialects.edge._ops import EdgeDialectFunctionSchema, EdgeOpOverload
from executorch.exir.emit._emitter import _Argument, _Target
from executorch.exir.error import ExportError, InternalError
from torch._ops import HigherOrderOperator


class RunHigherOrderOperatorError(Exception):
    """
    Raised when we try to run a delegate or another HigherOrderOperator in a graph module.
    E.g., %executorch_call_delegate : [#users=1] = call_function[
    target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %arg0_1), kwargs = {})
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)


# pyre-ignore[13]: Attribute `node` is never initialized.
class EdgeOpArgValidator(torch.fx.Interpreter):
    """
    Validate that the Tensor arguments and returns of every operator have allowed dtypes.

    Operators are expected to be EdgeOpOverloads, whose EdgeDialectFunctionSchema carries
    the allowed dtype information (``dtype_constraint``). Other targets are executed
    without validation, except HigherOrderOperators, which raise
    RunHigherOrderOperatorError. Violating operators are recorded in
    ``self.violating_ops`` as ``{op: {arg_or_return_name: dtype}}``.
    """

    # Node currently being interpreted; set by run_node so that call_function can read
    # the node's recorded return values from ``meta["val"]``.
    node: torch.fx.Node

    def __init__(self, graph_module: torch.fx.GraphModule) -> None:
        super().__init__(graph_module)
        self.violating_ops: Dict[EdgeOpOverload, Dict[str, Optional[torch.dtype]]] = (
            defaultdict(dict)
        )

    def run_node(self, n: torch.fx.Node) -> Any:
        """
        Run a single node and return its result, normalizing unexpected failures.

        InternalError, ExportError and RunHigherOrderOperatorError propagate unchanged;
        any other exception is re-raised as InternalError chained to the original.
        """
        self.node = n
        try:
            return super().run_node(n)
        except (InternalError, ExportError, RunHigherOrderOperatorError):
            raise
        except Exception as e:
            raise InternalError(str(e)) from e

    def _get_kernel_arg(self, schema_arg, schema_arg_idx, args, kwargs):
        """
        Return the value actually supplied for ``schema_arg``.

        Lookup order: keyword argument by name, then the positional argument at
        ``schema_arg_idx`` (only for non-kwarg-only arguments), then the schema default.
        """
        if schema_arg.name in kwargs:
            return kwargs[schema_arg.name]
        if not schema_arg.kwarg_only and schema_arg_idx < len(args):
            return args[schema_arg_idx]
        return schema_arg.default_value

    @staticmethod
    def _tensor_list_dtype(tensors: Sequence[torch.Tensor]) -> Optional[torch.dtype]:
        """Return the dtype representing a list of Tensors (its first element's dtype)."""
        # TODO(gasoonjia): Update Optional[torch.dtype] to a concrete class to support mixed dtypes in tensorlist.
        # If the list is empty, treat its type as None.
        # FunctionDtypeConstraint.validate will take None as any legal dtype.
        return tensors[0].dtype if len(tensors) else None

    def _collect_arg_dtypes(
        self,
        schema: EdgeDialectFunctionSchema,
        args: Tuple[_Argument, ...],
        kwargs: Dict[str, _Argument],
    ) -> Dict[str, Optional[torch.dtype]]:
        """Map each Tensor / Tensor? / Tensor[] argument name to the dtype passed for it."""
        dtypes: Dict[str, Optional[torch.dtype]] = {}
        for idx, schema_arg in enumerate(schema.arguments):
            if (
                isinstance(schema_arg.type, torch.TensorType)
                or schema_arg.type == torch.OptionalType.ofTensor()
            ):
                kernel_arg = self._get_kernel_arg(schema_arg, idx, args, kwargs)
                # Skip values that are not Tensors (e.g. an optional Tensor passed as None).
                if isinstance(kernel_arg, torch.Tensor):
                    dtypes[schema_arg.name] = kernel_arg.dtype
            elif schema_arg.type == torch.ListType.ofTensors():
                kernel_arg = self._get_kernel_arg(schema_arg, idx, args, kwargs)
                if isinstance(kernel_arg, list) and all(
                    isinstance(t, torch.Tensor) for t in kernel_arg
                ):
                    dtypes[schema_arg.name] = self._tensor_list_dtype(kernel_arg)
        return dtypes

    def _collect_ret_dtypes(
        self, schema: EdgeDialectFunctionSchema
    ) -> Dict[str, Optional[torch.dtype]]:
        """
        Map each Tensor / Tensor[] return to the dtype recorded in ``self.node.meta["val"]``.

        Unnamed returns are keyed ``__ret_<k>``, where k counts only the returns that
        have been recorded so far.
        """
        dtypes: Dict[str, Optional[torch.dtype]] = {}
        schema_rets = schema.returns
        kernel_rets = self.node.meta["val"]
        # For a single-return op, meta["val"] is the return value itself, so a Tensor[]
        # result is a list. Wrap it so the list is not mistaken for several separate
        # returns. Only multi-return ops store a sequence with one entry per return.
        if len(schema_rets) == 1 or not isinstance(kernel_rets, Sequence):
            kernel_rets = [kernel_rets]
        ret_iter = iter(kernel_rets)

        ret_index = 0
        for schema_ret in schema_rets:
            name = schema_ret.name if schema_ret.name else f"__ret_{ret_index}"
            kernel_ret = next(ret_iter)
            # Return value should not be in OptionalTensor type, so only check torch.TensorType here.
            if isinstance(schema_ret.type, torch.TensorType) and isinstance(
                kernel_ret, torch.Tensor
            ):
                dtypes[name] = kernel_ret.dtype
                ret_index += 1
            elif (
                schema_ret.type == torch.ListType.ofTensors()
                and isinstance(kernel_ret, (list, tuple))
                and all(isinstance(t, torch.Tensor) for t in kernel_ret)
            ):
                dtypes[name] = self._tensor_list_dtype(kernel_ret)
                ret_index += 1
        return dtypes

    def call_function(  # pyre-fixme[14]
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> Any:
        """
        Validate that the Tensor arguments and returns of ``target`` have allowed dtypes.

        Targets that are not edge ops are executed without validation, except
        HigherOrderOperators (e.g. delegate calls), which raise RunHigherOrderOperatorError.
        For an edge op, a failed dtype-constraint check records the op and its observed
        dtypes in ``self.violating_ops``. The op is then executed either way and its
        result returned.
        """
        if not isinstance(target, EdgeOpOverload) or not isinstance(
            target._schema, EdgeDialectFunctionSchema
        ):
            if isinstance(target, HigherOrderOperator):
                raise RunHigherOrderOperatorError("Can't run delegate")
            return super().call_function(target, args, kwargs)  # pyre-fixme[6]

        schema = target._schema
        tensor_arg_types = self._collect_arg_dtypes(schema, args, kwargs)
        # Returns are added after arguments; a return sharing an argument's name overrides it.
        tensor_arg_types.update(self._collect_ret_dtypes(schema))

        if not schema.dtype_constraint.validate(tensor_arg_types):
            self.violating_ops[target] = tensor_arg_types
        return super().call_function(target, args, kwargs)  # pyre-fixme[6]