# xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_full.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the full op which creates a tensor of a given shape filled with a given value.
# The shape and value are set at compile time, i.e. they can't be set by a tensor input.
#

12import unittest
13from typing import Tuple
14
15import torch
16from executorch.backends.arm.test import common
17from executorch.backends.arm.test.tester.arm_tester import ArmTester
18from executorch.exir.backend.compile_spec_schema import CompileSpec
19from parameterized import parameterized
20
21
class TestFull(unittest.TestCase):
    """Tests the full op which creates a tensor of a given shape filled with a given value."""

    class Full(torch.nn.Module):
        """Module consisting of a single full op; forward takes no inputs."""

        def forward(self):
            return torch.full((3, 3), 4.5)

    class AddConstFull(torch.nn.Module):
        """Adds the input to a constant (2, 2, 3, 3) float32 full tensor of 4.5."""

        def forward(self, x: torch.Tensor):
            return torch.full((2, 2, 3, 3), 4.5, dtype=torch.float32) + x

    class AddVariableFull(torch.nn.Module):
        """Adds the input to a full tensor shaped like the input, filled with ``y``."""

        # Input shapes exercised by the parameterized tests. Every entry is a
        # tuple; ``(5,)`` is required for the 1-D case since ``(5)`` is just int 5.
        sizes = [
            (5,),
            (5, 5),
            (5, 5, 5),
            (1, 5, 5, 5),
        ]
        # One parameterized case per size. Each case is a single argument: the
        # (input tensor, fill value) pair that the tests forward as example data.
        test_parameters = [((torch.randn(n) * 10 - 5, 3.2),) for n in sizes]

        def forward(self, x: torch.Tensor, y):
            # Input + a full with the shape from the input and a given value 'y'.
            return x + torch.full(x.shape, y)

    def _lower_and_check_full(self, tester):
        """Export and lower *tester*'s module, checking the full op gets delegated.

        Shared by every pipeline below so they all assert the same things:
        the exported graph has exactly one ``aten.full.default``, no edge-dialect
        full op survives partitioning, and exactly one delegate call exists.

        Args:
            tester: An ``ArmTester`` chain, optionally already quantized.

        Returns:
            The result of ``to_executorch()``, so callers may go on to run it.
        """
        return (
            tester.export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_full_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        example_data: Tuple,
        test_data: Tuple | None = None,
    ):
        """Lower *module* for TOSA-0.80.0+MI, then run it and compare outputs.

        Args:
            module: Module under test.
            example_data: Inputs used to export the module.
            test_data: Inputs used when running the lowered program; defaults
                to ``example_data``.
        """
        if test_data is None:
            test_data = example_data
        tester = ArmTester(
            module,
            example_inputs=example_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
        )
        self._lower_and_check_full(tester).run_method_and_compare_outputs(
            inputs=test_data
        )

    def _test_full_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple,
        permute_memory_to_nhwc: bool,
    ):
        """Quantize and lower *module* for TOSA-0.80.0+BI, then run and compare.

        Args:
            module: Module under test.
            test_data: Inputs used both to export and to run the program.
            permute_memory_to_nhwc: Forwarded to the TOSA compile spec.
        """
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec(
                "TOSA-0.80.0+BI", permute_memory_to_nhwc=permute_memory_to_nhwc
            ),
        ).quantize()
        self._lower_and_check_full(tester).run_method_and_compare_outputs(
            inputs=test_data
        )

    def _test_full_tosa_ethos_pipeline(
        self, compile_spec: list[CompileSpec], module: torch.nn.Module, test_data: Tuple
    ):
        """Quantize and lower *module* with an Ethos-U *compile_spec*.

        Only lowering is checked; the resulting program is not executed here.
        """
        tester = ArmTester(
            module, example_inputs=test_data, compile_spec=compile_spec
        ).quantize()
        self._lower_and_check_full(tester)

    def _test_full_tosa_u55_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Run the Ethos lowering checks with the U55 compile spec."""
        self._test_full_tosa_ethos_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_full_tosa_u85_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Run the Ethos lowering checks with the U85 compile spec."""
        self._test_full_tosa_ethos_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    def test_only_full_tosa_MI(self):
        self._test_full_tosa_MI_pipeline(self.Full(), ())

    def test_const_full_tosa_MI(self):
        _input = torch.rand((2, 2, 3, 3)) * 10
        self._test_full_tosa_MI_pipeline(self.AddConstFull(), (_input,))

    def test_const_full_nhwc_tosa_BI(self):
        _input = torch.rand((2, 2, 3, 3)) * 10
        self._test_full_tosa_BI_pipeline(
            self.AddConstFull(), (_input,), permute_memory_to_nhwc=True
        )

    # In the parameterized tests below, ``test_tensor`` is the
    # (input tensor, fill value) pair from AddVariableFull.test_parameters.
    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_tosa_MI(self, test_tensor: Tuple):
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(), example_data=test_tensor
        )

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_tosa_BI(self, test_tensor: Tuple):
        self._test_full_tosa_BI_pipeline(
            self.AddVariableFull(), test_tensor, permute_memory_to_nhwc=False
        )

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_u55_BI(self, test_tensor: Tuple):
        self._test_full_tosa_u55_pipeline(
            self.AddVariableFull(),
            test_tensor,
        )

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_u85_BI(self, test_tensor: Tuple):
        self._test_full_tosa_u85_pipeline(
            self.AddVariableFull(),
            test_tensor,
        )

    # This fails since full outputs int64 by default if 'fill_value' is integer, which our backend doesn't support.
    @unittest.expectedFailure
    def test_integer_value(self):
        _input = torch.ones((2, 2))
        integer_fill_value = 1
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(), example_data=(_input, integer_fill_value)
        )

    # This fails since the fill value in the full tensor is set at compile time by the example data (1.).
    # Test data tries to set it again at runtime (to 2.) but it doesn't do anything.
    # In eager mode, the fill value can be set at runtime, causing the outputs to not match.
    @unittest.expectedFailure
    def test_set_value_at_runtime(self):
        _input = torch.ones((2, 2))
        example_fill_value = 1.0
        test_fill_value = 2.0
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(),
            example_data=(_input, example_fill_value),
            test_data=(_input, test_fill_value),
        )
179