xref: /aosp_15_r20/external/executorch/exir/dialects/edge/op/api.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7"""
APIs to help lower edge dialect ops to other dialects.
9"""
10import dataclasses
11import logging
12from typing import List, Optional
13
14import torch
15
16from executorch.exir.operator.convert import _pybind_schema_to_native_schema
17from torch._ops import OpOverload, OpOverloadPacket
18from torchgen.model import FunctionSchema, SchemaKind
19
20
def get_torch_op_overload(
    namespace: str, opname: str, overload: Optional[str]
) -> torch._ops.OpOverload:
    """Look up ``torch.ops.<namespace>.<opname>.<overload>``.

    Args:
        namespace: operator namespace under ``torch.ops`` (e.g. ``"aten"``).
        opname: operator (packet) name inside that namespace.
        overload: overload name. ``None`` or an empty string selects the
            ``default`` overload.

    Returns:
        The resolved ``OpOverload``.
    """
    op_namespace = getattr(torch.ops, namespace)
    packet: OpOverloadPacket = getattr(op_namespace, opname)
    # A falsy overload name ("" or None) means the default overload.
    overload_attr = overload if overload else "default"
    return getattr(packet, overload_attr)
29
30
def get_callable(name: str) -> torch._ops.OpOverload:
    """Resolve an ``aten`` operator overload from a ``"<op>.<overload>"`` name.

    Args:
        name: operator name such as ``"add.Tensor"``. A bare name without an
            overload suffix (e.g. ``"relu"``) resolves to the ``default``
            overload instead of failing.

    Returns:
        The matching ``torch.ops.aten`` overload.
    """
    # partition() (unlike a two-way unpack of split()) tolerates a missing
    # suffix. The resulting empty suffix selects the default overload.
    main, _, suffix = name.partition(".")
    return get_torch_op_overload("aten", main, suffix)
34
35
def to_variant(op: OpOverload, variant: SchemaKind) -> OpOverload:
    """Given an operator overload, return its corresponding variant.

    Currently only the functional variant and the out variant are supported.
    Two overloads of the same operator match when their torchgen core
    signatures (``FunctionSchema.signature()``) are equal after their return
    values are stripped.

    Args:
        op (OpOverload): operator overload instance.
        variant (SchemaKind): the variant we are looking for; must be
            ``SchemaKind.functional`` or ``SchemaKind.out``.
    Returns:
        OpOverload: The matched variant operator. ``op`` itself is among the
        candidates, so it is returned when it already is of kind ``variant``.
    Raises:
        AssertionError: if ``variant`` is unsupported, or if ``op``'s schema
            cannot be converted to a torchgen ``FunctionSchema``.
        RuntimeError: if no overload of the same operator matches.
    Example:
        torch.ops.aten.add.Tensor, SchemaKind.out -> torch.ops.aten.add.out
        torch.ops.aten.add.out, SchemaKind.functional -> torch.ops.aten.add.Tensor
    """
    assert (
        variant == SchemaKind.functional or variant == SchemaKind.out
    ), f"Only support out variant and functional variant, got {variant}"
    # Convert the source operator's schema so its signature can be compared
    # against every sibling overload below.
    native_schema: Optional[FunctionSchema] = _pybind_schema_to_native_schema(
        op._schema
    )
    assert (
        native_schema is not None
    ), f"Schema: {op._schema} cannot be converted to torch.FunctionSchema"

    # Gather every overload registered under the same operator name.
    # `_jit_get_operation` is a private torch API; the second element of its
    # result is the list of overload names.
    op_name = op._schema.name.split("::")[1]
    torch_packet = getattr(getattr(torch.ops, op.namespace), op_name)
    overload_names = torch._C._jit_get_operation(op._schema.name)[1]
    # An empty overload name denotes the default overload. Resolve it as
    # "default" so the candidate is the same object `packet.default` returns,
    # rather than a second OpOverload cached under the "" attribute.
    candidates: List[OpOverload] = [
        getattr(torch_packet, name or "default") for name in overload_names
    ]

    # Returns are stripped before comparing, so overloads that differ only in
    # their return values (presumably the case between functional and out
    # variants) still match.
    signature = dataclasses.replace(native_schema.signature(), returns=())
    for candidate in candidates:
        schema = candidate._schema
        native_s: Optional[FunctionSchema] = _pybind_schema_to_native_schema(schema)
        if native_s is None:
            logging.warning(
                "Schema: %s cannot be converted to torch.FunctionSchema", schema
            )
            continue
        if (
            native_s.kind() == variant
            and dataclasses.replace(native_s.signature(), returns=()) == signature
        ):
            # Already resolved above; no need to look the overload up again.
            return candidate
    raise RuntimeError(
        f"{variant} variant of operator {op.name()} can't be found. We've found the schemas of all the overloads: {[str(c._schema) for c in candidates]}"
    )
86    )
87