# xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/sub.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester

13class TestSub(unittest.TestCase):
14    class Sub(torch.nn.Module):
15        def __init__(self):
16            super().__init__()
17
18        def forward(self, x, y):
19            z = x - y
20            return z
21
22    class Sub2(torch.nn.Module):
23        def __init__(self):
24            super().__init__()
25
26        def forward(self, x):
27            z = x - x
28            return z
29
30    def _test_sub(self, inputs):
31        (
32            Tester(self.Sub(), inputs)
33            .export()
34            .check_count({"torch.ops.aten.sub.Tensor": 1})
35            .to_edge_transform_and_lower()
36            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
37            .check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])
38            .to_executorch()
39            .serialize()
40            .run_method_and_compare_outputs()
41        )
42
43    def test_fp16_sub(self):
44        inputs = (
45            torch.randn((1, 3)).to(torch.float16),
46            torch.randn((4, 3)).to(torch.float16),
47        )
48        self._test_sub(inputs)
49
50    def test_fp32_sub(self):
51        inputs = (torch.randn((1, 3)), torch.randn((4, 3)))
52        self._test_sub(inputs)
53
54    @unittest.skip("T171957656 - Quantized sub not implemented.")
55    def _test_qs8_sub(self):
56        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
57        (
58            Tester(self.Sub(), inputs)
59            .quantize()
60            .export()
61            .check_count({"torch.ops.aten.sub.Tensor": 1})
62            .check(["torch.ops.quantized_decomposed"])
63            .to_edge_transform_and_lower()
64            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
65            .check_not(
66                [
67                    "executorch_exir_dialects_edge__ops_aten_sub_Tensor",
68                    "torch.ops.quantized_decomposed",
69                ]
70            )
71            .to_executorch()
72            .serialize()
73            .run_method_and_compare_outputs()
74        )
75
76    @unittest.skip("T171957656 - Quantized sub not implemented.")
77    def _test_qs8_sub2(self):
78        inputs = (torch.randn(1, 1, 4, 4),)
79        (
80            Tester(self.Sub2(), inputs)
81            .quantize()
82            .export()
83            .check_count({"torch.ops.aten.sub.Tensor": 1})
84            .check(["torch.ops.quantized_decomposed"])
85            .to_edge_transform_and_lower()
86            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
87            .check_not(
88                [
89                    "executorch_exir_dialects_edge__ops_aten_sub_Tensor",
90                    "torch.ops.quantized_decomposed",
91                ]
92            )
93            .to_executorch()
94            .serialize()
95            .run_method_and_compare_outputs()
96        )
97
98    @unittest.skip("T171957656 - Quantized sub not implemented.")
99    def _test_qs8_sub3(self):
100        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
101        (
102            Tester(self.Sub(), inputs)
103            .quantize()
104            .export()
105            .check_count({"torch.ops.aten.sub.Tensor": 1})
106            .check(["torch.ops.quantized_decomposed"])
107            .to_edge_transform_and_lower()
108            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
109            .check_not(
110                [
111                    "executorch_exir_dialects_edge__ops_aten_sub_Tensor",
112                    "torch.ops.quantized_decomposed",
113                ]
114            )
115            .to_executorch()
116            .serialize()
117            .run_method_and_compare_outputs()
118        )
119
120    @unittest.skip("T171957656 - Quantized sub not implemented.")
121    def _test_qs8_sub_relu(self):
122        class Sub(torch.nn.Module):
123            def forward(self, x, y):
124                z = x - y
125                return torch.nn.functional.relu(z)
126
127        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
128        (
129            Tester(self.Sub(), inputs)
130            .quantize()
131            .export()
132            .check_count(
133                {
134                    "torch.ops.aten.sub.Tensor": 1,
135                    "torch.ops.aten.relu.default": 1,
136                }
137            )
138            .check(["torch.ops.quantized_decomposed"])
139            .to_edge_transform_and_lower()
140            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
141            .check_not(
142                [
143                    "executorch_exir_dialects_edge__ops_aten_sub_Tensor",
144                    "executorch_exir_dialects_edge__ops_aten_relu_default",
145                    "torch.ops.quantized_decomposed",
146                ]
147            )
148            .to_executorch()
149            .serialize()
150            .run_method_and_compare_outputs()
151        )
152