# Source: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/mean_dim.py (xref revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestMeanDim(unittest.TestCase):
    """Tests lowering of ``aten.mean.dim`` to the XNNPACK delegate."""

    class MeanDim(torch.nn.Module):
        """Add the input to itself, then take the mean over ``dims``.

        The elementwise ``x + x`` places an op ahead of the mean in the
        graph. ``keepdim=True`` keeps each reduced dimension as a size-1 axis.
        """

        def __init__(self, dims):
            super().__init__()
            # Dimension(s) to reduce over; forwarded unchanged to torch.mean.
            self.dims = dims

        def forward(self, x):
            y = x + x
            z = torch.mean(y, self.dims, keepdim=True)
            return z

    def _test_mean_dim(self, inputs, dims=(-1, -2)):
        """Export, lower and run ``MeanDim(dims)``, expecting full delegation.

        Checks that export yields exactly one ``aten.mean.dim``, that lowering
        produces a single delegate call with no edge-dialect ``mean.dim`` left
        in the graph, and that the serialized program's outputs match eager.

        Args:
            inputs: Tuple of example input tensors.
            dims: Dimensions to reduce over. Defaults to the two innermost
                dimensions, the configuration XNNPACK is expected to delegate.
        """
        (
            Tester(self.MeanDim(dims), inputs)
            .export()
            .check_count({"torch.ops.aten.mean.dim": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def _test_mean_dim_not_partitioned(self, inputs, dims):
        """Export and lower ``MeanDim(dims)``, expecting the mean to stay undelegated.

        After lowering, exactly one edge-dialect ``mean.dim`` must remain in
        the graph, i.e. it was not partitioned to XNNPACK.

        Args:
            inputs: Tuple of example input tensors.
            dims: Dimensions to reduce over.
        """
        (
            Tester(self.MeanDim(dims), inputs)
            .export()
            .check_count({"torch.ops.aten.mean.dim": 1})
            .to_edge_transform_and_lower()
            .check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1})
        )

    def test_fp16_mean_dim(self):
        inputs = (torch.randn(1, 5, 4, 4).to(torch.float16),)
        self._test_mean_dim(inputs)

    def test_fp32_mean_dim(self):
        inputs = (torch.randn(1, 5, 4, 4),)
        self._test_mean_dim(inputs)

    def test_fp32_mean_dim_unsupported(self):
        """
        XNNPack mean.dim implementation only supports reducing over the two
        innermost dimensions together. As such, we expect it to fail to
        partition when reducing over dim 3 alone.
        """
        inputs = (torch.randn(1, 5, 4, 4),)
        # Explicit 1-tuple: ``(3)`` is just the int 3, not a tuple of dims.
        self._test_mean_dim_not_partitioned(inputs, (3,))

    def test_fp32_mean_dim_unsupported_3d(self):
        """
        XNNPack mean.dim implementation only supports 4D tensors, so a 3D
        input is expected to fail to partition.
        """
        inputs = (torch.randn(1, 5, 4),)
        self._test_mean_dim_not_partitioned(inputs, (-1, -2))

    def test_qs8_mean_dim(self):
        inputs = (torch.randn(1, 5, 4, 4),)
        (
            Tester(self.MeanDim((-1, -2)), inputs)
            .quantize()
            .export()
            .check_node_count(
                {
                    torch.ops.aten.mean.dim: 1,
                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
                }
            )
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_mean_dim",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(qtol=1)
        )