xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/elu.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestElu(unittest.TestCase):
    class ELU(torch.nn.Module):
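        """Module form of ELU, using a non-default alpha (the default is 1.0)."""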
        def __init__(self):
            super().__init__()
            self.elu = torch.nn.ELU(alpha=0.5)

        def forward(self, x):
            return self.elu(x)

    class ELUFunctional(torch.nn.Module):
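        """Functional form of ELU; it traces to the same aten.elu.default op."""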
        def forward(self, x):
            return torch.nn.functional.elu(x, alpha=1.2)

    def _test_elu(self, inputs):
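        """Export, lower, and run the ELU module: aten.elu.default must be
        fully consumed by the XNNPACK delegate, and the serialized program's
        outputs must match the reference module's."""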
        (
            Tester(self.ELU(), inputs)
            .export()
            .check_count({"torch.ops.aten.elu.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_elu_default",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

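    # These cases are skipped and also carry a leading underscore, so unittest
    # discovery will not collect them until they are renamed back to test_*.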
    @unittest.skip("PyTorch Pin Update Required")
    def _test_fp16_elu(self):
        inputs = (torch.randn(1, 3, 3).to(torch.float16),)
        self._test_elu(inputs)

    @unittest.skip("PyTorch Pin Update Required")
    def _test_fp32_elu(self):
        inputs = (torch.randn(1, 3, 3),)
        self._test_elu(inputs)

    @unittest.skip("Update Quantizer to quantize Elu")
    def _test_qs8_elu(self):
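        """Same flow with an int8 quantization step first: the
        quantized_decomposed ops inserted by quantize() must be visible after
        export and then fully absorbed into the delegate."""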
        inputs = (torch.randn(1, 3, 4, 4),)
        (
            Tester(self.ELU(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.elu.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_elu_default",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    @unittest.skip("Update Quantizer to quantize Elu")
    def _test_qs8_elu_functional(self):
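        """Quantized variant driving the functional ELU module instead of
        torch.nn.ELU."""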
        inputs = (torch.randn(1, 3, 4, 4),)
        (
            Tester(self.ELUFunctional(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.elu.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_elu_default",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )