xref: /aosp_15_r20/external/executorch/backends/arm/_passes/decompose_var_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9
10import torch
11from executorch.exir.dialects._ops import ops as exir_ops
12from executorch.exir.pass_base import ExportPass
13
14
def get_var_decomposition(op) -> tuple:
    """Return the ops used to decompose a variance op.

    Args:
        op: The variance op being decomposed. Supported values are
            ``exir_ops.edge.aten.var.correction`` (edge dialect) and
            ``torch.ops.aten.var.correction`` / ``torch.ops.aten.var.dim``
            (ATen dialect).

    Returns:
        A tuple ``(mean_op, sub_op, mul_op, sum_op, full_op)`` taken from the
        same dialect as ``op``, so the decomposed graph stays in one dialect.

    Raises:
        RuntimeError: If ``op`` is not one of the supported variance ops.
    """
    if op == exir_ops.edge.aten.var.correction:
        return (
            exir_ops.edge.aten.mean.dim,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.full.default,
        )
    if op in (torch.ops.aten.var.correction, torch.ops.aten.var.dim):
        return (
            torch.ops.aten.mean.dim,
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.sum.dim_IntList,
            # Use the concrete overload, like every other entry here, rather
            # than the ``torch.ops.aten.full`` overload packet.
            torch.ops.aten.full.default,
        )
    raise RuntimeError(f"Can't get var decomposition for op {op}")
33
34
class DecomposeVarPass(ExportPass):
    """
    This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html)

    Example:
        y = var_correction(x, dim, keepdim, correction)
    Becomes:
        mean = mean(x, dim)
        diff = sub(x, mean)
        squared_diff = mul(diff, diff)
        sum = sum(squared_diff, dim)
        y = mul(sum, full(1 / max(0, N - correction)))

    N is the number of elements reduced over (the product of the input sizes
    along ``dim``). Only calls with keepdim=True are decomposed; calls with
    keepdim=False are forwarded unchanged.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in (
            exir_ops.edge.aten.var.correction,
            torch.ops.aten.var.correction,
            torch.ops.aten.var.dim,
        ):
            return super().call_operator(op, args, kwargs, meta)

        if op == torch.ops.aten.var.dim:
            # Schema: var.dim(Tensor self, int[1]? dim, bool unbiased=True,
            #                 bool keepdim=False)
            unbiased = args[2] if len(args) > 2 else kwargs.get("unbiased", True)
            keepdim = args[3] if len(args) > 3 else kwargs.get("keepdim", False)
            # unbiased=True is Bessel's correction (1); False is correction 0.
            correction = 1 if unbiased else 0
        else:
            # Schema: var.correction(Tensor self, int[1]? dim=None, *,
            #                        Scalar? correction=None, bool keepdim=False)
            correction = kwargs.get("correction")
            if correction is None:
                # A missing/None correction means the default of 1.
                correction = 1
            keepdim = kwargs.get("keepdim", False)

        if not keepdim:
            # Only the keepdim=True form is decomposed here.
            return super().call_operator(op, args, kwargs, meta)

        shape = meta["val"].size()
        dtype = meta["val"].dtype

        x = args[0]
        input_shape = x.data.size()

        dim = args[1] if len(args) > 1 else kwargs.get("dim")
        if isinstance(dim, int):
            dim = [dim]
        if not dim:
            # None (or an empty list) reduces over every input dimension; list
            # them explicitly so mean/sum and N agree on what is reduced.
            dim = list(range(len(input_shape)))

        N = 1
        for d in dim:
            N *= input_shape[d]

        denominator = max(0, N - correction)
        # torch.var divides by max(0, N - correction). When that is zero, the
        # result is inf/nan rather than an error, so mirror that instead of
        # raising ZeroDivisionError while tracing.
        scale = 1 / denominator if denominator > 0 else float("inf")

        mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
        # keepdim is True here, so the mean keeps reduced dims with size 1
        # and broadcasts against x in the subtraction.
        mean = super().call_operator(mean_op, (x, dim, keepdim), {}, meta)
        diff = super().call_operator(diff_op, (x, mean), {}, meta)
        squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
        sum_sq = super().call_operator(
            sum_op, (squared_diff, dim, keepdim), {}, meta
        )
        # A broadcastable all-ones-shaped constant holding 1 / denominator.
        full = super().call_operator(
            full_op,
            ([1 for _ in shape], scale),
            {"dtype": dtype},
            meta,
        )
        return super().call_operator(mul_op, (sum_sq, full), {}, meta)
86