# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from executorch.exir.operator.convert import _get_overload_schema, to_out_variant
from executorch.exir.operator.util import gen_out_variant_schema
from torch.library import _scoped_library, impl, impl_abstract


class TestOperator(unittest.TestCase):
    """Tests for out-variant schema generation and out-variant lookup."""

    def test_gen_out_variant_schema_from_functional(self) -> None:
        """A functional ATen schema maps to the registered ``Scalar_out`` schema."""
        func_schema = str(torch.ops.aten.mul.Scalar._schema)

        out_schema = gen_out_variant_schema(func_schema)
        self.assertEqual(out_schema, str(torch.ops.aten.mul.Scalar_out._schema))

    def test_gen_out_variant_schema_from_inplace(self) -> None:
        """An in-place schema (``add_``) maps to the functional op's out schema."""
        func_schema = str(torch.ops.aten.add_.Scalar._schema)

        out_schema = gen_out_variant_schema(func_schema)
        self.assertEqual(out_schema, str(torch.ops.aten.add.Scalar_out._schema))

    def test_gen_out_variant_schema_for_custom_ops(self) -> None:
        """Each Tensor return of a custom op becomes a keyword-only mutable out arg.

        The expected string shows the naming scheme: ``out0``, ``out1``, ...
        with alias sets ``a!``, ``b!``, ... and an ``.out`` overload name.
        """
        func_schema = "custom::foo(Tensor a, Tensor b) -> (Tensor c, Tensor d)"

        out_schema = gen_out_variant_schema(func_schema)
        self.assertEqual(
            out_schema,
            "custom::foo.out(Tensor a, Tensor b, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))",
        )

    def test_to_out_variant_mutable(self) -> None:
        """``to_out_variant`` finds a user-defined ``.out`` overload of a mutating op.

        The functional overload mutates ``y`` (``Tensor(a!) y``). The lookup
        must return the explicitly defined ``.out`` overload, whose schema is
        checked verbatim.
        """
        # All registrations are tied to the scoped library so they are torn
        # down when the ``with`` block exits.
        with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:

            lib.define("custom_mutator(Tensor x, Tensor(a!) y) -> Tensor")
            lib.define(
                "custom_mutator.out(Tensor x, Tensor(a!) y, *, Tensor(b!) out) -> Tensor(b!)"
            )

            # Kernels are registered so the op is fully defined. The
            # assertions below only inspect schemas and never invoke them.
            @impl(lib, "custom_mutator", "Meta")
            def custom_mutator_meta(
                x: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                return torch.empty_like(x)

            @impl(lib, "custom_mutator", "CompositeExplicitAutograd")
            def custom_mutator(
                x: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                return x + y.add_(1)

            @impl_abstract("DO_NOT_USE_TEST_ONLY::custom_mutator.out", lib=lib)
            def custom_mutator_out(
                x: torch.Tensor,
                y: torch.Tensor,
                out: torch.Tensor,
            ) -> torch.Tensor:
                # The schema's return is ``Tensor(b!)``, an alias of ``out``,
                # so the fake kernel hands back ``out`` itself rather than a
                # freshly allocated tensor.
                return out

            out_op, _ = to_out_variant(
                torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator.default
            )
            schema = _get_overload_schema(out_op)
            self.assertEqual(
                str(schema),
                "DO_NOT_USE_TEST_ONLY::custom_mutator.out(Tensor x, Tensor(a!) y, *, Tensor(b!) out) -> Tensor(b!)",
            )