# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestSqrt(unittest.TestCase):
    """Lowering tests for ``torch.sqrt`` through the XNNPACK ``Tester`` pipeline."""

    class Sqrt(torch.nn.Module):
        """Module computing ``sqrt(abs(x))``.

        ``abs`` is applied first so that negative samples from ``torch.randn``
        do not produce NaN outputs from ``sqrt``.
        """

        def forward(self, x):
            return torch.sqrt(torch.abs(x))

    def _test_sqrt(self, inputs):
        """Export, lower, serialize and run ``Sqrt`` on ``inputs``.

        The checks performed, in order:
          * after export, exactly one ``aten.sqrt.default`` node is present;
          * after ``to_edge_transform_and_lower``, exactly one
            ``executorch_call_delegate`` node is present and no edge-dialect
            sqrt op remains (i.e. sqrt was taken by the delegate);
          * the program is converted with ``to_executorch``, serialized, and
            its outputs are checked via ``run_method_and_compare_outputs``.

        Args:
            inputs: tuple of positional example inputs for ``Sqrt.forward``.
        """
        aten_sqrt_op = "torch.ops.aten.sqrt.default"
        delegate_call_op = "torch.ops.higher_order.executorch_call_delegate"
        edge_sqrt_op = "executorch_exir_dialects_edge__ops_aten_sqrt_default"

        (
            Tester(self.Sqrt(), inputs)
            .export()
            .check_count({aten_sqrt_op: 1})
            .to_edge_transform_and_lower()
            .check_count({delegate_call_op: 1})
            .check_not([edge_sqrt_op])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_sqrt(self):
        # Sample in the default dtype, then cast to half precision.
        half_sample = torch.randn(20).to(torch.float16)
        self._test_sqrt((half_sample,))

    def test_fp32_sqrt(self):
        self._test_sqrt((torch.randn(20),))