# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestPrelu(unittest.TestCase):
    class PReLU(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # One learnable slope per channel; num_parameters=5 matches the
            # channel dimension of the (1, 5, 3, 2) test inputs below.
            self.prelu = torch.nn.PReLU(num_parameters=5, init=0.2)

        def forward(self, x):
            a = self.prelu(x)
            return a

    def _test_prelu(self, module, inputs):
        # Export the module, lower it through the XNNPACK partitioner, and
        # verify that prelu is fully delegated (no edge-dialect prelu kernel
        # remains), then run the serialized program and compare its outputs
        # against eager mode.
        (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.prelu.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                ["executorch_exir_dialects_edge__ops_aten__prelu_kernel_default"]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    @unittest.skip("XNNPACK expects FP16 inputs but FP32 weights")
    def _test_fp16_prelu(self):
        module = self.PReLU().to(torch.float16)
        inputs = (torch.randn(1, 5, 3, 2).to(torch.float16),)
        self._test_prelu(module, inputs)

    def test_fp32_prelu(self):
        module = self.PReLU()
        inputs = (torch.randn(1, 5, 3, 2),)
        self._test_prelu(module, inputs)
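
# Optional entry point so this file can be run directly with `python prelu.py`;
# test runners that use discovery (e.g. pytest or `python -m unittest`) do not
# need it.
if __name__ == "__main__":
    unittest.main()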
48