# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestSoftmax(unittest.TestCase):
    """Tests lowering of ``torch.nn.Softmax`` through the XNNPACK delegate."""

    class Softmax(torch.nn.Module):
        """Minimal module applying softmax along a fixed dimension ``dim``."""

        def __init__(self, dim):
            super().__init__()
            self.dim = dim

        def forward(self, x):
            return torch.nn.Softmax(dim=self.dim)(x)

    def _test_softmax(self, inputs):
        """Check that softmax over the last dimension is delegated and correct.

        Exports the module, lowers it, asserts exactly one delegate call and
        no remaining edge softmax op, then runs the serialized program and
        compares its outputs against eager mode.

        Args:
            inputs: Tuple of example inputs; the rank of ``inputs[0]``
                determines the last-dimension index under test.
        """
        # Dim can be either the last dimension index or -1 (last dimension),
        # as xnnpack only supports softmax on the last dimension.
        # Use the tensor rank (dim()); len() would return the size of the
        # first dimension, which only equals the rank by coincidence.
        valid_dims = [inputs[0].dim() - 1, -1]

        for dim in valid_dims:
            # subTest reports which dim failed instead of aborting the loop.
            with self.subTest(dim=dim):
                (
                    Tester(self.Softmax(dim), inputs)
                    .export()
                    .check_count({"torch.ops.aten.softmax": 1})
                    .to_edge_transform_and_lower()
                    .check_count(
                        {"torch.ops.higher_order.executorch_call_delegate": 1}
                    )
                    .check_not(
                        ["executorch_exir_dialects_edge__ops_aten__softmax_default"]
                    )
                    .to_executorch()
                    .serialize()
                    .run_method_and_compare_outputs()
                )

    def test_fp16_softmax(self):
        inputs = (torch.rand((3, 5, 7)).to(torch.float16),)
        self._test_softmax(inputs)

    def test_fp32_softmax(self):
        inputs = (torch.rand((3, 5, 7)),)
        self._test_softmax(inputs)

    def test_fp32_softmax_unsupported(self):
        """Softmax over any non-last dimension must not be delegated."""
        inputs = (torch.rand((3, 5, 7)),)

        # Dim can be either the last dimension index or -1 (last dimension),
        # as xnnpack only supports softmax on the last dimension.
        # This test validates the delegate does not attempt to delegate softmax
        # on any other dimension.
        # Iterate over the tensor's non-last dims (0 .. rank-2). The previous
        # range(len(inputs) - 1) was range(0) because `inputs` is a 1-tuple,
        # so the loop never ran and the test checked nothing.
        invalid_dims = range(inputs[0].dim() - 1)

        for dim in invalid_dims:
            with self.subTest(dim=dim):
                (
                    Tester(self.Softmax(dim), inputs)
                    .export()
                    .check_count({"torch.ops.aten.softmax": 1})
                    .to_edge_transform_and_lower()
                    # Should not be delegated
                    .check(
                        ["executorch_exir_dialects_edge__ops_aten__softmax_default"]
                    )
                )