# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestPrelu(unittest.TestCase):
    """Checks that ``torch.nn.PReLU`` is lowered to the XNNPACK delegate."""

    class PReLU(torch.nn.Module):
        """Minimal model wrapping a single ``torch.nn.PReLU`` layer.

        Args:
            num_parameters: Number of learnable slopes passed to
                ``torch.nn.PReLU``. The test inputs below use dim 1 of
                size 5, matching the default.
            init: Initial value of every slope.
        """

        def __init__(self, num_parameters: int = 5, init: float = 0.2):
            super().__init__()
            self.prelu = torch.nn.PReLU(num_parameters=num_parameters, init=init)

        def forward(self, x):
            return self.prelu(x)

    def _test_prelu(self, module, inputs):
        """Run ``module`` through the XNNPACK Tester pipeline.

        Steps:
            1. Export ``module`` and expect exactly one ``aten.prelu`` op.
            2. Lower with ``to_edge_transform_and_lower``.
            3. Expect exactly one ``executorch_call_delegate`` call.
            4. Expect no edge ``_prelu_kernel`` op left outside the delegate.
            5. Convert to ExecuTorch, serialize, run the method and compare
               outputs via ``run_method_and_compare_outputs``.

        Args:
            module: The ``torch.nn.Module`` under test.
            inputs: Tuple of example input tensors for ``module``.
        """
        (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.prelu.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                ["executorch_exir_dialects_edge__ops_aten__prelu_kernel_default"]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    # Named with the ``test`` prefix so unittest discovers it and reports
    # the skip (with its reason). The previous ``_test`` prefix hid it from
    # discovery entirely, making the skip decorator a no-op.
    @unittest.skip("XNNPACK Expects FP16 inputs but FP32 weights")
    def test_fp16_prelu(self):
        module = self.PReLU().to(torch.float16)
        # Dim 1 (size 5) must match the module's ``num_parameters``.
        inputs = (torch.randn(1, 5, 3, 2).to(torch.float16),)
        self._test_prelu(module, inputs)

    def test_fp32_prelu(self):
        module = self.PReLU()
        # Dim 1 (size 5) must match the module's ``num_parameters``.
        inputs = (torch.randn(1, 5, 3, 2),)
        self._test_prelu(module, inputs)