# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import operator
import types
from contextlib import nullcontext
from typing import Any, List, Optional, Tuple, Type

import torch
from executorch.exir.capture._config import EdgeCompileConfig
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.error import ExportError, ExportErrorType
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.verification.arg_validator import (
    EdgeOpArgValidator,
    RunHigherOrderOperatorError,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._export.utils import _detect_fake_mode_from_gm

from torch._export.verifier import SpecViolationError, Verifier
from torch._ops import OpOverload
from torch._subclasses import FakeTensor
from torch.export.exported_program import ExportedProgram
from torch.fx import GraphModule


ALLOWED_META_KEYS = {"spec", "stack_trace"}


def _check_tensors_are_contiguous(gm: GraphModule) -> None:
    """Raise ``SpecViolationError`` if any parameter or buffer is non-contiguous.

    Iterates over ``gm.named_parameters()`` and ``gm.named_buffers()`` and
    checks ``is_contiguous()`` on every entry that is a ``torch.Tensor``.
    """
    # Tensors must be in contiguous format
    for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
        if isinstance(param, torch.Tensor):
            if not param.is_contiguous():
                raise SpecViolationError(
                    f"Tensors in Aten dialect must be contiguous, {name} is not contiguous"
                )


def _check_valid_dim_order_ops(op, use_dim_order: bool) -> None:
    """Check that ``op`` is consistent with the selected dim_order mode.

    Args:
        op: The operator to check. Only compared against
            ``aten._to_copy.default`` and inspected for its ``namespace``.
        use_dim_order: Whether dim_order mode is enabled.

    Raises:
        SpecViolationError: If ``aten._to_copy.default`` is used in dim_order
            mode, or an op from the ``dim_order_ops`` namespace is used when
            dim_order mode is off.
    """
    if use_dim_order:
        if op in (torch.ops.aten._to_copy.default,):
            raise SpecViolationError(f"{op} should not be used in dim_order mode")
    else:  # not using dim_order
        if op.namespace in ("dim_order_ops",):
            raise SpecViolationError(f"{op} should not be used in non-dim_order mode")


def _run_graph_module_check(verifier: Verifier, *args: Any, **kwargs: Any) -> Any:
    """Invoke whichever graph-check entry point ``verifier`` exposes.

    ``_check_graph_module`` is preferred; ``check_valid`` is the fallback when
    the former attribute is absent.

    Raises:
        RuntimeError: If the verifier exposes neither entry point.
    """
    if hasattr(verifier, "_check_graph_module"):
        return verifier._check_graph_module(*args, **kwargs)
    if hasattr(verifier, "check_valid"):
        return verifier.check_valid(*args, **kwargs)
    raise RuntimeError(
        f"{type(verifier).__name__} has neither `_check_graph_module` nor "
        "`check_valid`; cannot verify the graph module."
    )


class EXIRATenDialectVerifierBase(Verifier):
    """Verifier base for the EXIR ATen dialect without any per-op validity check.

    Extends the allowed ``get_attr`` types and op types of ``Verifier`` and
    makes instances callable.
    """

    dialect = "OLD_EXIR_ATEN_DISABLED"

    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
        """Return the types that ``get_attr`` nodes may refer to."""
        return (
            torch.fx.GraphModule,
            LoweredBackendModule,
            torch.Tensor,
            torch.ScriptObject,
        )

    def allowed_op_types(self):
        """Return the base verifier's allowed op types plus ``OpOverloadPacket``."""
        return super().allowed_op_types() + (torch._ops.OpOverloadPacket,)

    def __call__(self, *args, **kwargs):
        """Run the graph-module check, forwarding all arguments unchanged."""
        return _run_graph_module_check(self, *args, **kwargs)


def EXIRATenDialectVerifier(  # noqa: C901
    edge_compile_config: Optional[EdgeCompileConfig] = None,
    class_only: bool = False,
    exception_list: Optional[List[torch._ops.OpOverload]] = None,
):
    """
    Returns a verifier class that runs ATen dialect specific checks on the graph module.

    Args:
        edge_compile_config: Optional config; when it carries a non-empty
            ``_core_aten_ops_exception_list``, that list is prepended to
            ``exception_list``.
        class_only: If True, return the verifier class; otherwise return an
            instance of it.
        exception_list: Additional ops that are exempt from the core-ATen check.
    """
    # merge the exception list from edge_compile_config and exception_list
    if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
        exception_list = edge_compile_config._core_aten_ops_exception_list + (
            exception_list or []
        )

    class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
        dialect = "OLD_EXIR_ATEN"

        def __init__(self) -> None:
            super().__init__()
            # Note: here we are using the exception list passed from EXIRATenDialectVerifier function!
            self._exception_list = exception_list if exception_list else []

        def _get_exception_list(self) -> List[torch._ops.OpOverload]:
            """Return the built-in exempt ops followed by the user-supplied ones."""
            exception_list = [
                torch.ops.aten.mkldnn_rnn_layer.default,
                torch.ops.aten._upsample_bilinear2d_aa.default,
                torch.ops.aten.quantize_per_tensor.default,
                torch.ops.aten.dequantize.self,
                torch.ops.aten.max.default,  # TODO(T188268054)
                torch.ops.aten.min.default,  # TODO(T188268054)
                torch.ops.aten.full_like.default,  # TODO(T183507359)
            ]
            exception_list += self._exception_list

            return exception_list

        def check_valid_op(self, op):
            """Raise ``SpecViolationError`` for non-exempt ``aten`` ops lacking the core/view_copy tag.

            Ops that are not ``OpOverload`` instances, ops outside the ``aten``
            namespace, and ops in the exception list are accepted.
            """
            if isinstance(op, OpOverload):
                # TODO These special ops should be removable easily.
                if op.namespace != "aten" or op in self._get_exception_list():
                    return
                if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
                    # NOTE(qihan): whether view_copy operators are marked as canonical is still under
                    #            discussion.
                    raise SpecViolationError(
                        f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
                    )

    ret = _EXIRATenDialectVerifier
    if not class_only:
        ret = ret()
    return ret


def get_aten_verifier(config: EdgeCompileConfig):
    """Return the ATen verifier class selected by ``config``.

    When ``config._check_ir_validity`` is truthy, the checking verifier class
    (built with ``config._core_aten_ops_exception_list``) is returned;
    otherwise ``EXIRATenDialectVerifierBase`` is returned.
    """
    return (
        EXIRATenDialectVerifier(
            class_only=True, exception_list=config._core_aten_ops_exception_list
        )
        if config._check_ir_validity
        else EXIRATenDialectVerifierBase
    )


def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]:
    """Collect one input per placeholder node from ``node.meta["val"]``.

    A placeholder without a ``"val"`` entry yields ``None`` when it has no
    users.

    Raises:
        ExportError: If a placeholder has users but no ``"val"`` metadata.
    """

    def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
        if "val" in node.meta:
            return node.meta["val"]

        if len(node.users) == 0:
            return None

        # TODO(ycao): `val` should always exist after we enable shape environment
        # serialization and deserialization.
        raise ExportError(
            ExportErrorType.VIOLATION_OF_SPEC,
            f"Cannot construct an input for graph module: {graph_module}.",
        )

    return [
        extract_input(node)
        for node in graph_module.graph.nodes
        if node.op == "placeholder"
    ]


def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
    """Run ``EdgeOpArgValidator`` over ``gm`` and reject dtype-violating ops.

    The validator is run on the inputs gathered by ``_get_inputs`` with the
    Python dispatcher enabled and, when one is detected on ``gm``, under its
    fake mode.

    Raises:
        SpecViolationError: If the validator recorded any violating ops.
    """
    validator = EdgeOpArgValidator(gm)
    inputs = _get_inputs(gm)
    fake_mode = _detect_fake_mode_from_gm(gm) or nullcontext()
    try:
        with enable_python_dispatcher(), fake_mode:
            validator.run(*inputs)
    except RunHigherOrderOperatorError:
        # NB: ignore higher order operator in the graph.
        # If we lower a graph module to delegate and then compose it with some other graph module, retrace it,
        # if we also turn on edge ops and validator (_check_ir_validity=True), we will run
        # into RunHigherOrderOperatorError. The only thing we can do right now is to ignore this error, since
        # by definition it's still a valid Edge dialect. This is not ideal because it ignores possible invalidity
        # later in the graph.
        return

    if validator.violating_ops:
        raise SpecViolationError(
            f"These operators are taking Tensor inputs with mismatched dtypes: {validator.violating_ops}"
        )


def EXIREdgeDialectVerifier(  # noqa: C901
    edge_compile_config: Optional[EdgeCompileConfig] = None,
    class_only: bool = False,
    exception_list: Optional[List[torch._ops.OpOverload]] = None,
):
    """
    Returns a verifier class that runs Edge dialect specific checks on the graph module.

    Args:
        edge_compile_config: Optional config controlling whether checks are
            enabled, whether edge ops are required, and whether dim_order mode
            is used. A default ``EdgeCompileConfig()`` is used when omitted.
            Its non-empty ``_core_aten_ops_exception_list`` is prepended to
            ``exception_list``.
        class_only: If True, return the verifier class; otherwise return an
            instance of it.
        exception_list: Additional ops that are exempt from the op checks.
    """
    # merge the exception list from edge_compile_config and exception_list
    if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
        exception_list = edge_compile_config._core_aten_ops_exception_list + (
            exception_list or []
        )

    class _EXIREdgeDialectVerifier(Verifier):
        dialect = "EDGE"

        def __init__(self) -> None:
            # Mirror the ATen verifier, which also initializes its base class.
            super().__init__()
            _edge_compile_config = edge_compile_config or EdgeCompileConfig()

            self.enable = _edge_compile_config._check_ir_validity
            self.check_edge_ops = _edge_compile_config._use_edge_ops
            self.use_dim_order = not _edge_compile_config._skip_dim_order

            self.aten_op_verifier = EXIRATenDialectVerifier(
                exception_list=exception_list
            )
            self.check_valid_aten_op = self.aten_op_verifier.check_valid_op

            # The per-op check is chosen once here and stored as an instance
            # attribute, which takes precedence over the class-level method of
            # the same name.
            if self.check_edge_ops:
                self.check_valid_op = self.check_valid_edge_op
            else:
                self.check_valid_op = self.check_valid_aten_op
            self._exception_list = exception_list if exception_list else []

        def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
            """Return the types that ``get_attr`` nodes may refer to."""
            return (
                torch.fx.GraphModule,
                LoweredBackendModule,
                torch.Tensor,
                torch.ScriptObject,
            )

        def allowed_op_types(self):
            """Return the base allowed op types plus ``EdgeOpOverload`` and plain functions."""
            return super().allowed_op_types() + (EdgeOpOverload, types.FunctionType)

        def check_valid_edge_op(self, op):
            """Validate a single op against the Edge dialect rules.

            Does nothing when the verifier is disabled or when ``op`` is one of
            the always-allowed ops / user-supplied exceptions.

            Raises:
                SpecViolationError: If ``op`` is an ``OpOverload`` that is not
                    an ``EdgeOpOverload``, or if the wrapped op fails the
                    dim_order or ATen checks.
                AssertionError: If ``op`` is a plain function not named
                    ``alloc``.
            """
            if not self.enable:
                return
            if (
                op
                in [
                    operator.getitem,
                    torch.ops.aten.sym_size.int,
                    torch.ops.aten.scalar_tensor.default,
                    torch.ops.aten._assert_async.msg,
                    torch.ops.aten._assert_scalar.default,
                ]
                + self._exception_list
            ):
                return

            if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):
                raise SpecViolationError(
                    "Operator {}.{} is not an Edge operator.".format(
                        op.__module__, op.__name__
                    )
                )
            if isinstance(op, EdgeOpOverload):
                _check_valid_dim_order_ops(op._op, self.use_dim_order)
                self.check_valid_aten_op(op._op)

            if isinstance(op, types.FunctionType):
                # NOTE: this check is skipped entirely when Python runs with -O.
                assert op.__name__ in (
                    "alloc",
                ), f"Unexpected function op in Edge dialect: {op.__name__}"

        def check_additional(self, gm: GraphModule) -> None:
            """Run the contiguity and dtype checks when enabled and edge ops are required."""
            if not self.enable:
                return
            if self.check_edge_ops:
                _check_tensors_are_contiguous(gm)
                _check_tensor_args_matching_op_allowed_dtype(gm)

        # NOTE: `__init__` always assigns an instance attribute named
        # `check_valid_op`, so on instances created through `__init__` this
        # class-level definition is shadowed.
        def check_valid_op(self, op):
            if isinstance(op, OpOverload):
                # TODO These special ops should be removable easily.
                if op.namespace in (
                    "quantized_decomposed",
                    "boltnn_nimble",
                    "nimble",
                    "quantized",
                    "dim_order_ops",
                ) or op in (
                    torch.ops.aten.mkldnn_rnn_layer.default,
                    torch.ops.aten._upsample_bilinear2d_aa.default,
                    torch.ops.aten.quantize_per_tensor.default,
                    torch.ops.aten.dequantize.self,
                    torch.ops.aten.max.default,
                    torch.ops.aten.full_like.default,  # TODO(T183507359)
                ):
                    return
                if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
                    # NOTE(qihan): whether view_copy operators are marked as canonical is still under
                    #            discussion.
                    raise SpecViolationError(
                        f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
                    )

        def is_valid(self, gm: GraphModule) -> bool:
            """Return True if calling the verifier on ``gm`` raises no ``SpecViolationError``."""
            try:
                self(gm)
                return True
            except SpecViolationError:
                return False

        def __call__(self, ep_or_gm):
            """Verify an ``ExportedProgram`` or a graph module; no-op when disabled."""
            if not self.enable:
                return
            gm = ep_or_gm
            if isinstance(gm, ExportedProgram):
                gm = ep_or_gm.graph_module
            return _run_graph_module_check(self, gm)

    ret = _EXIREdgeDialectVerifier
    if not class_only:
        ret = ret()
    return ret


# NOTE(review): the result is discarded, so this call is presumably made for its
# import-time side effect of creating the "EDGE" verifier class (e.g. so the
# Verifier machinery knows about the dialect) — confirm before removing.
EXIREdgeDialectVerifier()