# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
APIs to help lowering edge dialect ops to other dialects.
"""
import dataclasses
import logging
from typing import List, Optional

import torch

from executorch.exir.operator.convert import _pybind_schema_to_native_schema
from torch._ops import OpOverload, OpOverloadPacket
from torchgen.model import FunctionSchema, SchemaKind


def get_torch_op_overload(
    namespace: str, opname: str, overload: Optional[str]
) -> torch._ops.OpOverload:
    """Look up ``torch.ops.<namespace>.<opname>.<overload>``.

    Args:
        namespace: operator namespace, e.g. ``"aten"``.
        opname: operator (packet) name, e.g. ``"add"``.
        overload: overload name, e.g. ``"Tensor"``. A falsy value (``None``
            or ``""``) selects the packet's ``default`` overload.

    Returns:
        The matching ``OpOverload``.

    Lookup failures propagate from the ``torch.ops`` attribute access
    unchanged (presumably ``AttributeError`` -- confirm against torch).
    """
    packet: OpOverloadPacket = getattr(getattr(torch.ops, namespace), opname)
    if overload:
        return getattr(packet, overload)
    return packet.default


def get_callable(name: str) -> torch._ops.OpOverload:
    """Resolve an ``aten`` operator overload from a ``"<op>.<overload>"`` string.

    Example:
        ``"add.Tensor"`` -> ``torch.ops.aten.add.Tensor``. A name without an
        overload suffix (e.g. ``"relu"``) resolves to the ``default`` overload.
    """
    # partition() (instead of split() + two-name unpacking) tolerates a missing
    # ".<overload>" suffix: the overload comes back as "", which
    # get_torch_op_overload treats as "use the default overload".
    opname, _, overload = name.partition(".")
    return get_torch_op_overload("aten", opname, overload)


def to_variant(op: OpOverload, variant: SchemaKind) -> OpOverload:
    """Given an operator overload, return its corresponding variant.

    Currently only supports the functional variant and the out variant.

    Args:
        op (OpOverload): operator overload instance.
        variant (SchemaKind): the variant we are looking for; must be
            ``SchemaKind.functional`` or ``SchemaKind.out``.

    Returns:
        OpOverload: The matched variant operator. If ``op`` already is of the
        requested kind, ``op`` itself is returned.

    Raises:
        AssertionError: (via ``assert``) if ``variant`` is unsupported or the
            schema of ``op`` cannot be converted to a torchgen
            ``FunctionSchema``.
        RuntimeError: if no overload of the same operator has the requested
            kind and the same signature (ignoring returns).

    Example:
        torch.ops.aten.add.Tensor, SchemaKind.out -> torch.ops.aten.add.out
        torch.ops.aten.add.out, SchemaKind.functional -> torch.ops.aten.add.Tensor
    """
    assert (
        variant == SchemaKind.functional or variant == SchemaKind.out
    ), f"Only support out variant and functional variant, got {variant}"
    native_schema: Optional[FunctionSchema] = _pybind_schema_to_native_schema(
        op._schema
    )
    assert (
        native_schema is not None
    ), f"Schema: {op._schema} cannot be converted to torch.FunctionSchema"

    # First check if the current operator already is the target variant; if so
    # there is nothing to search for (and no need to convert, and possibly warn
    # about, the schemas of all the other overloads).
    if native_schema.kind() == variant:
        return op

    # Get all overloads of the same operator packet. The packet name is the part
    # of the qualified schema name after "<namespace>::".
    torch_packet = getattr(
        getattr(torch.ops, op.namespace), op._schema.name.split("::")[1]
    )
    # NOTE(review): _jit_get_operation(...)[1] is a private torch API that
    # presumably yields the registered overload names ("" for the default
    # overload) -- confirm when upgrading torch.
    schemas: List[torch._C.FunctionSchema] = [
        getattr(torch_packet, o)._schema
        for o in torch._C._jit_get_operation(op._schema.name)[1]
    ]

    # Compare each overload's signature with the original overload's signature,
    # ignoring the return types.
    signature = dataclasses.replace(native_schema.signature(), returns=())
    for schema in schemas:
        native_s: Optional[FunctionSchema] = _pybind_schema_to_native_schema(schema)
        if native_s is None:
            # Best effort: skip overloads that cannot be represented natively.
            logging.warning(
                "Schema: %s cannot be converted to torch.FunctionSchema", schema
            )
            continue
        if (
            native_s.kind() == variant
            and dataclasses.replace(native_s.signature(), returns=()) == signature
        ):
            return get_torch_op_overload(
                op.namespace, schema.name.split("::")[1], schema.overload_name
            )
    raise RuntimeError(
        f"{variant} variant of operator {op.name()} can't be found. We've found the schemas of all the overloads: {[str(s) for s in schemas]}"
    )