xref: /aosp_15_r20/external/executorch/exir/passes/replace_aten_with_edge_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
8
9import torch
10from executorch.exir.dialects._ops import ops
11from executorch.exir.dialects.edge._ops import EdgeOpOverload
12from executorch.exir.pass_base import ExportPass
13from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
14from torch.fx.node import Target
15
16
17DISALLOW_LIST = [
18    torch.ops.aten._assert_scalar.default,
19    torch.ops.aten._assert_async.msg,
20    torch.ops.aten.scalar_tensor.default,
21]
22
23
def aten_to_edge(aten_op: torch._ops.OpOverload) -> EdgeOpOverload:
    """Map a torch operator overload to its edge-dialect overload.

    The lookup path is ``ops.edge.<namespace>.<op name>.<overload name>``.
    The namespace is taken from ``aten_op.namespace``, so this is not limited
    to ``aten`` despite the function name.

    Args:
        aten_op: The torch operator overload to translate, for example the
            overload ``add.Tensor`` in the ``aten`` namespace.

    Returns:
        The attribute found at the path above in ``ops.edge``.

    Raises:
        IndexError: If the schema name contains no ``"::"`` separator.
        AttributeError: Presumably, if ``ops.edge`` cannot resolve one of the
            path components. This depends on how ``ops.edge`` resolves
            attributes, so confirm it against the dialect implementation.
    """
    # Read everything needed from the overload before touching ops.edge.
    namespace = aten_op.namespace
    # The schema name is namespace-qualified ("aten::add"), so keep the
    # component right after the first "::".
    bare_name = aten_op._schema.name.split("::")[1]
    overload = aten_op._overloadname

    namespace_ops = getattr(ops.edge, namespace)
    op_packet = getattr(namespace_ops, bare_name)
    return getattr(op_packet, overload)
35
36
def should_lower_to_edge(op: Target) -> bool:
    """Decide whether ``op`` should be replaced by its edge-dialect overload.

    Args:
        op: The FX node target being visited. It may be any FX target, for
            example an operator overload or a plain callable.

    Returns:
        ``True`` only if all of the following hold:

        * ``op`` is a ``torch._ops.OpOverload``;
        * ``op`` is not one of the ExecuTorch sym ops
          (``_EXECUTORCH_SYM_OPS``);
        * ``op`` is not listed in ``DISALLOW_LIST``.

        Otherwise ``False``.
    """
    # Only concrete operator overloads have an edge-dialect counterpart.
    if not isinstance(op, torch._ops.OpOverload):
        return False
    # ExecuTorch sym ops are handled outside the edge dialect.
    if op in _EXECUTORCH_SYM_OPS:
        return False
    return op not in DISALLOW_LIST
44
45
class OpReplacePass(ExportPass):
    """
    Rewrite torch operators (aten and custom namespaces) into edge-dialect ops.

    For every operator call visited by the pass, ``should_lower_to_edge``
    decides whether the target is replaced with the edge overload returned by
    ``aten_to_edge``. Targets it rejects are forwarded unchanged. Rejected
    targets are non-``OpOverload`` targets, ExecuTorch sym ops and
    ``DISALLOW_LIST`` entries. ``args``, ``kwargs`` and ``meta`` are always
    passed through as-is.
    """

    def __init__(self) -> None:
        # No extra state; only the base-class initialisation is required.
        super().__init__()

    def call_operator(self, op, args, kwargs, meta):
        # Pick the target first, then delegate to the base pass exactly once.
        target = aten_to_edge(op) if should_lower_to_edge(op) else op
        return super().call_operator(target, args, kwargs, meta)
60