# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestClamp(unittest.TestCase):
    """Tests that push a ``torch.clamp`` module through the XNNPACK ``Tester``."""

    class Clamp(torch.nn.Module):
        """Clamp the input to optional bounds, then add the result to itself."""

        def __init__(self, min_val=None, max_val=None):
            super().__init__()
            # Either bound may be left as None; both are forwarded to
            # torch.clamp unchanged.
            self.min_val = min_val
            self.max_val = max_val

        def forward(self, x):
            clamped = torch.clamp(x, min=self.min_val, max=self.max_val)
            return clamped + clamped

    def _test_clamp(self, module, inputs):
        """Drive ``module`` with ``inputs`` through the non-quantized pipeline.

        Stages, in order: export, lower to edge, convert to ExecuTorch,
        serialize, then run and compare outputs. The graph checks between
        stages use the ``Tester`` helpers ``check_count`` and ``check_not``.

        Args:
            module: The ``torch.nn.Module`` under test.
            inputs: Tuple of example input tensors handed to ``Tester``.
        """
        # Stage 1: export and check the count of the aten clamp op.
        exported = (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.clamp.default": 1})
        )
        # Stage 2: lower; check the delegate-call count and that the edge
        # clamp op name is no longer present.
        lowered = (
            exported.to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
        )
        # Stage 3: build the program, serialize it, run it and compare outputs.
        lowered.to_executorch().serialize().run_method_and_compare_outputs()

    def test_fp16_clamp(self):
        # Half-precision input, clamped on both sides.
        sample = (torch.randn(1, 4, 122, 122).to(torch.float16) * 2,)
        self._test_clamp(self.Clamp(-0.5, 0.5), sample)

    def test_fp32_clamp(self):
        # Single-precision input, clamped on both sides.
        sample = (torch.randn(1, 4, 122, 122) * 2,)
        self._test_clamp(self.Clamp(-0.5, 0.5), sample)

    def test_fp32_clamp_lower(self):
        # Only the lower bound is supplied.
        sample = (torch.randn(1, 4, 122, 122) * 2,)
        self._test_clamp(self.Clamp(min_val=-0.5), sample)

    def test_fp32_clamp_upper(self):
        # Only the upper bound is supplied.
        sample = (torch.randn(1, 4, 122, 122) * 2,)
        self._test_clamp(self.Clamp(max_val=0.5), sample)

    def test_qs8_clamp(self):
        # Quantized variant: the pipeline gains a quantize() stage up front and
        # extra checks on the "torch.ops.quantized_decomposed" op name.
        sample = (torch.randn(1, 4, 122, 122),)
        exported = (
            Tester(self.Clamp(min_val=-1, max_val=1), sample)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.clamp.default": 1})
            .check(["torch.ops.quantized_decomposed"])
        )
        lowered = (
            exported.to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_clamp_default",
                    "torch.ops.quantized_decomposed",
                ]
            )
        )
        lowered.to_executorch().serialize().run_method_and_compare_outputs()