# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from torchgen.model import FunctionSchema, SchemaKind
from torchgen.native_function_generation import (
    functional_to_out_signature,
    mutable_to_out_signature,
    self_to_out_signature,
)
from torchgen.utils import NamespaceHelper


def gen_out_variant_schema(func_op_schema: str) -> str:
    """
    Build the schema string of the out= variant for an operator schema.

    Args:
        func_op_schema: Operator schema string. It may carry a namespace
            prefix (``ns::name(...) -> ...``); the prefix is split off with
            ``NamespaceHelper`` using ``max_level=1``.

    Returns:
        The out= variant schema as a string. When the input carried a
        namespace, the result is prefixed with ``<namespace>::``; otherwise
        the bare schema is returned. A schema that is already of kind
        ``out`` is re-serialized as-is.

    Raises:
        RuntimeError: If the schema kind is not one of inplace, functional,
            mutable or out.
    """
    # Separate the optional namespace from the schema text proper.
    parsed_name = NamespaceHelper.from_namespaced_entity(
        namespaced_entity=func_op_schema, max_level=1
    )
    func = FunctionSchema.parse(parsed_name.entity_name)
    namespace = parsed_name.get_cpp_namespace(default="")

    # Pick the converter that matches the schema kind.
    kind = func.kind()
    converters = {
        SchemaKind.inplace: self_to_out_signature,
        SchemaKind.functional: functional_to_out_signature,
        SchemaKind.mutable: mutable_to_out_signature,
    }
    if kind == SchemaKind.out:
        # Already an out= schema: nothing to convert.
        out_func = func
    elif kind in converters:
        out_func = converters[kind](func)
    else:
        raise RuntimeError(f"SchemaKind: {kind} is not supported")

    schema = str(out_func)
    if namespace:
        return f"{namespace}::{schema}"
    return schema