xref: /aosp_15_r20/external/executorch/exir/operator/util.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

9*523fa7a6SAndroid Build Coastguard Workerfrom torchgen.model import FunctionSchema, SchemaKind
10*523fa7a6SAndroid Build Coastguard Workerfrom torchgen.native_function_generation import (
11*523fa7a6SAndroid Build Coastguard Worker    functional_to_out_signature,
12*523fa7a6SAndroid Build Coastguard Worker    mutable_to_out_signature,
13*523fa7a6SAndroid Build Coastguard Worker    self_to_out_signature,
14*523fa7a6SAndroid Build Coastguard Worker)
15*523fa7a6SAndroid Build Coastguard Workerfrom torchgen.utils import NamespaceHelper
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Worker
def gen_out_variant_schema(func_op_schema: str) -> str:
    """
    Return the schema string of the out= variant for an operator schema.

    The input may carry a namespace prefix separated by "::"
    (e.g. "my_ns::op(...) -> ..."). ``NamespaceHelper`` is asked to accept at
    most one namespace level (``max_level=1``). The prefix is stripped before
    parsing and re-attached to the result when it is non-empty.

    The schema is converted based on its ``SchemaKind``:
      * inplace    -> ``self_to_out_signature``
      * functional -> ``functional_to_out_signature``
      * mutable    -> ``mutable_to_out_signature``
      * out        -> already an out variant; returned as-is (stringified)

    Args:
        func_op_schema: operator schema string, optionally "ns::"-prefixed.

    Returns:
        The out-variant schema string, prefixed with "ns::" if the input had
        a namespace.

    Raises:
        RuntimeError: if the parsed schema is of any other ``SchemaKind``.
    """
    # Separate the optional namespace from the bare schema text.
    helper = NamespaceHelper.from_namespaced_entity(
        namespaced_entity=func_op_schema, max_level=1
    )
    parsed = FunctionSchema.parse(helper.entity_name)
    kind = parsed.kind()

    # Kinds that need rewriting into an out= signature.
    converters = {
        SchemaKind.inplace: self_to_out_signature,
        SchemaKind.functional: functional_to_out_signature,
        SchemaKind.mutable: mutable_to_out_signature,
    }

    if kind == SchemaKind.out:
        out_schema = str(parsed)
    elif kind in converters:
        out_schema = str(converters[kind](parsed))
    else:
        raise RuntimeError(f"SchemaKind: {kind} is not supported")

    prefix = helper.get_cpp_namespace(default="")
    if not prefix:
        return out_schema
    return f"{prefix}::{out_schema}"
42