xref: /aosp_15_r20/external/executorch/backends/arm/_passes/decompose_meandim_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import torch
10from executorch.exir.dialects._ops import ops as exir_ops
11from executorch.exir.pass_base import ExportPass
12
13
def get_meandim_decomposition(op) -> tuple:
    """Return the ops used to decompose ``op`` (a mean.dim overload).

    The result is a ``(sum, full, mul)`` triple of ops taken from the same
    namespace as ``op``:
    - edge-dialect ops for ``exir_ops.edge.aten.mean.dim``;
    - ATen ops for ``torch.ops.aten.mean.dim``.

    Raises:
        RuntimeError: if ``op`` is neither of those two overloads.
    """
    if op == exir_ops.edge.aten.mean.dim:
        namespace = exir_ops.edge.aten
    elif op == torch.ops.aten.mean.dim:
        namespace = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get meandim decomposition for op {op}")

    # Every replacement op comes from the namespace the input op lives in.
    return (
        namespace.sum.dim_IntList,
        namespace.full.default,
        namespace.mul.Tensor,
    )
28
29
class DecomposeMeanDimPass(ExportPass):
    """
    This pass decomposes mean.dim into a sum and a mul node.

    Only mean.dim calls with keepdim=True are decomposed. These calls are
    passed through unchanged:
    - calls with keepdim=False;
    - calls with dim == [-1, -2], which are left for
      ConvertMeanDimToAveragePool.

    Example:
        y = mean_dim(x, dim, keepdim=True)
    Becomes:
        sum = sum.dim_IntList(x, dim, keepdim=True)
        y = mul(sum, full([1, ..., 1], 1/N))

    N is the number of input elements reduced over. A dim of None or []
    reduces over every dimension of x.
    """

    @staticmethod
    def _normalize_dims(dim, rank):
        """Return the dimensions to reduce over as a sequence.

        Args:
            dim: the ``dim`` argument of mean.dim. It may be an int, a
                sequence of ints, None, or an empty sequence.
            rank: the number of dimensions of the input tensor.

        Returns:
            A sequence of dims, normalized as follows:
            - a single int is wrapped in a list;
            - None or an empty sequence (mean.dim's "reduce over every
              dimension") becomes the explicit list [0, ..., rank - 1];
            - any other sequence is returned unchanged.
        """
        if isinstance(dim, int):
            return [dim]
        if not dim:
            return list(range(rank))
        return dim

    def call_operator(self, op, args, kwargs, meta):
        """Rewrite mean.dim(x, dim, keepdim=True) as sum * full(1/N).

        Any other operator, and mean.dim calls this pass does not handle
        (see the class docstring), are forwarded unchanged to
        ExportPass.call_operator.
        """
        if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
            return super().call_operator(op, args, kwargs, meta)

        x = args[0]
        # dim and keepdim may arrive positionally or as keyword arguments.
        dim = args[1] if len(args) > 1 else kwargs.get("dim")
        keepdim = args[2] if len(args) > 2 else kwargs.get("keepdim", False)
        if not keepdim:
            return super().call_operator(op, args, kwargs, meta)
        # if keepdim == True and dim == [-1, -2], mean.dim can be
        # decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
        if dim == [-1, -2]:
            # Simply return the mean.dim operator for future decomposition.
            return super().call_operator(op, args, kwargs, meta)

        shape = meta["val"].size()
        dtype = meta["val"].dtype
        input_shape = x.data.size()
        # Make None / [] / int dims explicit. Otherwise the loop below would
        # crash (None, int) or leave N == 1 and emit a plain sum ([]).
        reduce_dims = self._normalize_dims(dim, len(input_shape))

        num_elements = 1
        for d in reduce_dims:
            num_elements *= input_shape[d]

        sum_op, full_op, mul_op = get_meandim_decomposition(op)

        summed = super().call_operator(sum_op, (x, reduce_dims, keepdim), {}, meta)
        # An all-ones-shaped tensor with the output's rank, holding 1/N,
        # which broadcasts over the keepdim sum in the multiply below.
        reciprocal = super().call_operator(
            full_op, ([1] * len(shape), 1 / num_elements), {"dtype": dtype}, meta
        )
        return super().call_operator(mul_op, (summed, reciprocal), {}, meta)
69