# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.exir.dialects._ops import ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
from torch.fx.node import Target


# Operator overloads that are never rewritten into edge ops, even though they
# are ``torch._ops.OpOverload`` instances (checked in ``should_lower_to_edge``).
DISALLOW_LIST = [
    torch.ops.aten._assert_scalar.default,
    torch.ops.aten._assert_async.msg,
    torch.ops.aten.scalar_tensor.default,
]


def aten_to_edge(aten_op: torch._ops.OpOverload) -> EdgeOpOverload:
    """Resolve the edge-dialect operator matching ``aten_op``.

    The lookup walks ``ops.edge.<namespace>.<op name>.<overload name>``.
    For example, ``aten::add.Tensor`` resolves to ``ops.edge.aten.add.Tensor``.

    Args:
        aten_op: the ATen (or custom-namespace) operator overload to convert.

    Returns:
        The attribute found at that path under ``ops.edge``.
    """
    lookup_path = (
        aten_op.namespace,
        # The schema name is namespace-qualified ("aten::add"); keep the
        # segment after the first "::".
        aten_op._schema.name.split("::")[1],
        aten_op._overloadname,
    )
    resolved = ops.edge
    for attribute in lookup_path:
        resolved = getattr(resolved, attribute)
    return resolved


def should_lower_to_edge(op: Target) -> bool:
    """Decide whether ``op`` gets replaced by its edge-dialect counterpart.

    Only ``torch._ops.OpOverload`` targets qualify. Among those, anything
    listed in ``_EXECUTORCH_SYM_OPS`` or ``DISALLOW_LIST`` is left untouched.
    """
    if not isinstance(op, torch._ops.OpOverload):
        return False
    return op not in _EXECUTORCH_SYM_OPS and op not in DISALLOW_LIST


class OpReplacePass(ExportPass):
    """
    Rewrite every eligible torch (aten + custom) operator into its edge op.
    Operators that ``should_lower_to_edge`` rejects are passed through
    unchanged: non-``OpOverload`` targets, ``_EXECUTORCH_SYM_OPS`` entries,
    and ``DISALLOW_LIST`` entries.
    """

    def __init__(self) -> None:
        super().__init__()

    def call_operator(self, op, args, kwargs, meta):
        # Swap in the edge-dialect target when eligible; otherwise keep ``op``.
        target = aten_to_edge(op) if should_lower_to_edge(op) else op
        return super().call_operator(target, args, kwargs, meta)