xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/softmax.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester

12
13class TestSoftmax(unittest.TestCase):
14    class Softmax(torch.nn.Module):
15        def __init__(self, dim):
16            super().__init__()
17            self.dim = dim
18
19        def forward(self, x):
20            return torch.nn.Softmax(dim=self.dim)(x)
21
22    def _test_softmax(self, inputs):
23        # Dim can be either the last dimension index or -1 (last dimension),
24        # as xnnpack only supports softmax on the last dimension.
25        valid_dims = [len(inputs[0]) - 1, -1]
26
27        for dim in valid_dims:
28            (
29                Tester(self.Softmax(dim), inputs)
30                .export()
31                .check_count({"torch.ops.aten.softmax": 1})
32                .to_edge_transform_and_lower()
33                .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
34                .check_not(["executorch_exir_dialects_edge__ops_aten__softmax_default"])
35                .to_executorch()
36                .serialize()
37                .run_method_and_compare_outputs()
38            )
39
40    def test_fp16_softmax(self):
41        inputs = (torch.rand((3, 5, 7)).to(torch.float16),)
42        self._test_softmax(inputs)
43
44    def test_fp32_softmax(self):
45        inputs = (torch.rand((3, 5, 7)),)
46        self._test_softmax(inputs)
47
48    def test_fp32_softmax_unsupported(self):
49        inputs = (torch.rand((3, 5, 7)),)
50
51        # Dim can be either the last dimension index or -1 (last dimension),
52        # as xnnpack only supports softmax on the last dimension.
53        # This test validates the delegate does not attempt to delegate softmax
54        # on any other dimension.
55        invalid_dims = range(len(inputs) - 1)
56
57        for dim in invalid_dims:
58            (
59                Tester(self.Softmax(dim), inputs)
60                .export()
61                .check_count({"torch.ops.aten.softmax": 1})
62                .to_edge_transform_and_lower()
63                # Should not be delegated
64                .check(["executorch_exir_dialects_edge__ops_aten__softmax_default"])
65            )
66