xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_clone.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7#
8# Tests the clone op which copies the data of the input tensor (possibly with new data format)
9#
10
11import unittest
12from typing import Tuple
13
14import torch
15
16from executorch.backends.arm.quantizer.arm_quantizer import (
17    ArmQuantizer,
18    get_symmetric_quantization_config,
19)
20from executorch.backends.arm.test import common
21from executorch.backends.arm.test.tester.arm_tester import ArmTester
22
23from executorch.backends.xnnpack.test.tester.tester import Quantize
24
25from executorch.exir.backend.compile_spec_schema import CompileSpec
26from parameterized import parameterized
27
28
class TestSimpleClone(unittest.TestCase):
    """Tests lowering of ``aten.clone`` through the Arm backend pipelines.

    Every pipeline exports a module containing a single ``clone``, checks the
    op is present after export, lowers to edge, partitions, and checks that
    exactly one delegate call was produced. The TOSA MI/BI pipelines also run
    the resulting program via ``run_method_and_compare_outputs``.
    """

    class Clone(torch.nn.Module):
        """Minimal module whose forward is a single ``Tensor.clone``."""

        # 1-D input lengths exercised by every parameterized test below.
        sizes = [10, 15, 50, 100]
        # Each entry is the positional-argument tuple that
        # ``parameterized.expand`` unpacks into a test method's arguments.
        test_parameters = [(torch.ones(n),) for n in sizes]

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x.clone()

    def _test_clone_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Lower ``module`` with the TOSA-0.80.0+MI spec and compare outputs.

        Args:
            module: Module under test; expected to contain exactly one clone.
            test_data: Positional inputs, used both as the export example
                inputs and as the runtime inputs for output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.clone.default": 1})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_clone_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with the TOSA-0.80.0+BI spec.

        The quantizer is an ``ArmQuantizer`` configured through ``set_io``
        with the symmetric quantization config. Outputs are compared with a
        quantization tolerance of ``qtol=1``.

        Args:
            module: Module under test; expected to contain exactly one clone.
            test_data: Positional inputs, used both as the export example
                inputs and as the runtime inputs for output comparison.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.clone.default": 1})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_clone_tosa_ethos_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize and lower ``module`` for an Ethos-U target.

        Uses the same quantization setup as the BI pipeline, but stops after
        ``to_executorch`` without comparing outputs.
        NOTE(review): presumably because executing requires Ethos-U hardware
        or a simulator that is not assumed here — confirm.

        Args:
            compile_spec: Target-specific compile spec (e.g. u55 or u85).
            module: Module under test; expected to contain exactly one clone.
            test_data: Positional example inputs used for export.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.clone.default": 1})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_clone_tosa_u55_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run the Ethos pipeline with the u55 compile spec."""
        self._test_clone_tosa_ethos_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_clone_tosa_u85_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run the Ethos pipeline with the u85 compile spec."""
        self._test_clone_tosa_ethos_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    @parameterized.expand(Clone.test_parameters)
    def test_clone_tosa_MI(self, test_tensor: torch.Tensor):
        self._test_clone_tosa_MI_pipeline(self.Clone(), (test_tensor,))

    @parameterized.expand(Clone.test_parameters)
    def test_clone_tosa_BI(self, test_tensor: torch.Tensor):
        self._test_clone_tosa_BI_pipeline(self.Clone(), (test_tensor,))

    @parameterized.expand(Clone.test_parameters)
    def test_clone_u55_BI(self, test_tensor: torch.Tensor):
        self._test_clone_tosa_u55_pipeline(self.Clone(), (test_tensor,))

    @parameterized.expand(Clone.test_parameters)
    def test_clone_u85_BI(self, test_tensor: torch.Tensor):
        self._test_clone_tosa_u85_pipeline(self.Clone(), (test_tensor,))
128