# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestClamp(unittest.TestCase):
    class Clamp(torch.nn.Module):
        def __init__(self, min_val=None, max_val=None):
            super().__init__()
            self.min_val = min_val
            self.max_val = max_val

        def forward(self, x):
            z = torch.clamp(x, min=self.min_val, max=self.max_val)
            return z + z

    def _test_clamp(self, module, inputs):
        # Export, lower to the XNNPACK delegate, and verify that the clamp op
        # is fully delegated before comparing runtime outputs with eager mode.
        (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.clamp.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_clamp(self):
        inputs = (torch.randn(1, 4, 122, 122).to(torch.float16) * 2,)
        module = self.Clamp(-0.5, 0.5)
        self._test_clamp(module, inputs)

    def test_fp32_clamp(self):
        inputs = (torch.randn(1, 4, 122, 122) * 2,)
        module = self.Clamp(-0.5, 0.5)
        self._test_clamp(module, inputs)

    def test_fp32_clamp_lower(self):
        inputs = (torch.randn(1, 4, 122, 122) * 2,)
        module = self.Clamp(min_val=-0.5)
        self._test_clamp(module, inputs)

    def test_fp32_clamp_upper(self):
        inputs = (torch.randn(1, 4, 122, 122) * 2,)
        module = self.Clamp(max_val=0.5)
        self._test_clamp(module, inputs)

    def test_qs8_clamp(self):
        # Quantized (qs8) path: after lowering, neither the clamp op nor the
        # quantize/dequantize ops should remain outside the delegate.
        inputs = (torch.randn(1, 4, 122, 122),)
        (
            Tester(self.Clamp(min_val=-1, max_val=1), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.clamp.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_clamp_default",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )
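

# Minimal optional entry point (an assumption for convenience, not part of the
# upstream test harness): allows running this file directly, e.g.
# `python clamp.py`. In the repo these tests are normally discovered and run
# by the project's test runner.
if __name__ == "__main__":
    unittest.main()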