# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestElu(unittest.TestCase):
    """XNNPACK delegation tests for ELU, in both module and functional form.

    NOTE: every test method here is named with a leading underscore, so
    unittest's default loader (which collects methods prefixed ``test``)
    does not pick them up. Each is additionally marked ``@unittest.skip``
    with the reason it is currently disabled.
    """

    class ELU(torch.nn.Module):
        """Wraps ``torch.nn.ELU`` configured with ``alpha=0.5``."""

        def __init__(self):
            super().__init__()
            self.elu = torch.nn.ELU(alpha=0.5)

        def forward(self, x):
            return self.elu(x)

    class ELUFunctional(torch.nn.Module):
        """Calls ``torch.nn.functional.elu`` directly with ``alpha=1.2``."""

        def forward(self, x):
            return torch.nn.functional.elu(x, alpha=1.2)

    def _test_elu(self, inputs, module=None):
        """Lower a float ELU model to XNNPACK and compare outputs with eager.

        Args:
            inputs: Tuple of example input tensors passed to the model.
            module: Model under test. Defaults to a fresh ``self.ELU()`` so
                existing callers keep their behavior; pass e.g.
                ``self.ELUFunctional()`` to exercise the functional form.
        """
        if module is None:
            module = self.ELU()
        (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.elu.default": 1})
            .to_edge_transform_and_lower()
            # The single ELU should be absorbed into exactly one delegate
            # call and no longer appear as a standalone edge op.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_elu_default",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def _run_qs8_elu(self, module, inputs):
        """Quantize, lower and run ``module``; shared by the qs8 tests.

        Args:
            module: ELU model instance to quantize and lower.
            inputs: Tuple of example input tensors passed to the model.
        """
        (
            Tester(module, inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.elu.default": 1})
            # Quantization should have inserted quantized_decomposed ops.
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            # After lowering, both the ELU and the (de)quantize ops are
            # expected to be consumed by the delegate.
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_elu_default",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    @unittest.skip("PyTorch Pin Update Required")
    def _test_fp16_elu(self):
        inputs = (torch.randn(1, 3, 3).to(torch.float16),)
        self._test_elu(inputs)

    @unittest.skip("PyTorch Pin Update Required")
    def _test_fp32_elu(self):
        inputs = (torch.randn(1, 3, 3),)
        self._test_elu(inputs)

    @unittest.skip("Update Quantizer to quantize Elu")
    def _test_qs8_elu(self):
        inputs = (torch.randn(1, 3, 4, 4),)
        self._run_qs8_elu(self.ELU(), inputs)

    @unittest.skip("Update Quantizer to quantize Elu")
    def _test_qs8_elu_functional(self):
        inputs = (torch.randn(1, 3, 4, 4),)
        # Use the functional variant so torch.nn.functional.elu is covered,
        # rather than repeating the nn.ELU module test above.
        self._run_qs8_elu(self.ELUFunctional(), inputs)