# xref: /aosp_15_r20/external/executorch/exir/operator/test/test_operator.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
8
9import unittest
10
11import torch
12from executorch.exir.operator.convert import _get_overload_schema, to_out_variant
13from executorch.exir.operator.util import gen_out_variant_schema
14from torch.library import _scoped_library, impl, impl_abstract
15
16
17class TestOperator(unittest.TestCase):
18    def setUp(self) -> None:
19        super().setUp()
20
21    def test_gen_out_variant_schema_from_functional(self) -> None:
22        func_schema = str(torch.ops.aten.mul.Scalar._schema)
23
24        out_schema = gen_out_variant_schema(func_schema)
25        self.assertEqual(out_schema, str(torch.ops.aten.mul.Scalar_out._schema))
26
27    def test_gen_out_variant_schema_from_inplace(self) -> None:
28        func_schema = str(torch.ops.aten.add_.Scalar._schema)
29
30        out_schema = gen_out_variant_schema(func_schema)
31        self.assertEqual(out_schema, str(torch.ops.aten.add.Scalar_out._schema))
32
33    def test_gen_out_variant_schema_for_custom_ops(self) -> None:
34        func_schema = "custom::foo(Tensor a, Tensor b) -> (Tensor c, Tensor d)"
35
36        out_schema = gen_out_variant_schema(func_schema)
37        self.assertEqual(
38            out_schema,
39            "custom::foo.out(Tensor a, Tensor b, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))",
40        )
41
42    def test_to_out_variant_mutable(self) -> None:
43
44        with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
45
46            lib.define("custom_mutator(Tensor x, Tensor(a!) y) -> Tensor")
47            lib.define(
48                "custom_mutator.out(Tensor x, Tensor(a!) y, *, Tensor(b!) out) -> Tensor(b!)"
49            )
50
51            @impl(lib, "custom_mutator", "Meta")
52            def custom_mutator_meta(
53                x: torch.Tensor,
54                y: torch.Tensor,
55            ) -> torch.Tensor:
56                return torch.empty_like(x)
57
58            @impl(lib, "custom_mutator", "CompositeExplicitAutograd")
59            def custom_mutator(
60                x: torch.Tensor,
61                y: torch.Tensor,
62            ) -> torch.Tensor:
63                return x + y.add_(1)
64
65            @impl_abstract("DO_NOT_USE_TEST_ONLY::custom_mutator.out")
66            def custom_mutator_out(
67                x: torch.Tensor,
68                y: torch.Tensor,
69                out: torch.Tensor,
70            ) -> torch.Tensor:
71                out = custom_mutator_meta(
72                    x,
73                    y,
74                )
75                return out
76
77            out, _ = to_out_variant(
78                torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator.default
79            )
80            schema = _get_overload_schema(out)
81            self.assertEqual(
82                schema.__str__(),
83                "DO_NOT_USE_TEST_ONLY::custom_mutator.out(Tensor x, Tensor(a!) y, *, Tensor(b!) out) -> Tensor(b!)",
84            )
85