xref: /aosp_15_r20/external/executorch/backends/arm/_passes/decompose_layernorm_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import operator
10
11import torch
12from executorch.backends.arm._passes.arm_pass_utils import create_node
13from executorch.exir.dialects._ops import ops as exir_ops
14from executorch.exir.pass_base import ExportPass, PassResult
15
16
def get_layer_norm_decomposition(op) -> tuple:
    """Return the operators used to decompose a layer-norm node.

    The returned operators belong to the same dialect as ``op`` (edge dialect
    for ``native_layer_norm``, aten for ``layer_norm``) so the decomposed graph
    stays in a single dialect.

    Args:
        op: Target of the layer-norm node, either
            ``exir_ops.edge.aten.native_layer_norm.default`` or
            ``torch.ops.aten.layer_norm.default``.

    Returns:
        Tuple ``(mean, sub, var, full, add, rsqrt, mul, view_copy)`` of ops.

    Raises:
        RuntimeError: If ``op`` is not a supported layer-norm target.
    """
    if op == exir_ops.edge.aten.native_layer_norm.default:
        return (
            exir_ops.edge.aten.mean.dim,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.var.correction,
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.rsqrt.default,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.view_copy.default,
        )
    if op == torch.ops.aten.layer_norm.default:
        return (
            torch.ops.aten.mean.dim,
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.var.correction,
            torch.ops.aten.full.default,
            torch.ops.aten.add.Tensor,
            torch.ops.aten.rsqrt.default,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.view_copy.default,
        )
    raise RuntimeError(f"Can't get layer_norm decomposition for op {op}")
41
42
class DecomposeLayerNormPass(ExportPass):
    """
    layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
    Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
    mean        = op_mean(x, dims)             # E[x]
    var         = op_var(x, dims)              # Var[x] (biased, correction=0)
    numerator   = op_sub(x, mean)              # (x - E[x])
    add         = op_add(var, eps)             # Var[x] + eps
    rsqrt       = op_rsqrt(add)                # 1 / sqrt(Var[x] + eps)
    mul         = op_mul(numerator, rsqrt)     # (x - E[x]) / sqrt(Var[x] + eps)
    weighted    = op_mul(mul, weights)         # ((x - E[x]) / sqrt(Var[x] + eps)) * weights
    output      = op_add(weighted, bias)       # ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias

    dims are the trailing len(normalized_shape) dimensions of x. eps is
    materialized with op_full as a tensor of shape [1] * rank, and weights and
    bias are viewed to a rank-`rank` shape so they broadcast against x. The
    weight multiply / bias add are skipped when weights / bias are None.

    native_layer_norm returns the tuple (output, mean, rstd); its getitem users
    for indices 0, 1 and 2 are rewired to the output, mean and rsqrt nodes.

    Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
    """

    # Default eps of aten.layer_norm (and torch.nn.LayerNorm). Used when the
    # node omits eps, e.g. because export dropped a trailing default argument.
    _DEFAULT_EPS = 1e-05

    @classmethod
    def _get_layer_norm_args(cls, node: torch.fx.Node) -> tuple:
        """Return ``(x, normalized_shape, weight, bias, eps)`` for a layer-norm node.

        Each argument is taken positionally if present, else from the node's
        kwargs, else from the schema default (weight/bias: None, eps: 1e-05).
        Extra trailing positional arguments (e.g. ``cudnn_enable``) are ignored.

        Raises:
            RuntimeError: If ``input`` or ``normalized_shape`` is missing.
        """
        names = ("input", "normalized_shape", "weight", "bias", "eps")
        defaults = (None, None, None, None, cls._DEFAULT_EPS)
        values = [
            node.args[i] if i < len(node.args) else node.kwargs.get(name, default)
            for i, (name, default) in enumerate(zip(names, defaults))
        ]
        if values[0] is None or values[1] is None:
            raise RuntimeError(
                f"layer_norm node {node.name} is missing input or normalized_shape"
            )
        return tuple(values)

    @staticmethod
    def _replace_layer_norm_uses(graph, node, output, mean, rstd) -> None:
        """Rewire every user of ``node`` to the decomposed nodes, then erase it.

        getitem users (native_layer_norm returns (output, mean, rstd)) are
        replaced by the decomposed node for their index and erased; any other
        use (aten.layer_norm returns a single tensor) is replaced by ``output``.
        """
        replacements = (output, mean, rstd)
        for user in list(node.users):
            if user.target == operator.getitem:
                user.replace_all_uses_with(replacements[user.args[1]])
                graph.erase_node(user)
        node.replace_all_uses_with(output)
        graph.erase_node(node)

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function" or node.target not in (
                exir_ops.edge.aten.native_layer_norm.default,
                torch.ops.aten.layer_norm.default,
            ):
                continue

            x, normalized_shape, weights, bias, epsilon = self._get_layer_norm_args(
                node
            )

            val = node.meta["val"]
            # native_layer_norm's fake value is a tuple; its first entry is the output.
            out_val = val[0] if isinstance(val, tuple) else val
            shape = out_val.size()
            dtype = out_val.dtype
            rank = len(shape)
            n_dims = len(normalized_shape)
            # Normalize over the trailing n_dims dimensions, as non-negative indices.
            dims = [dim % rank for dim in range(-1, -(n_dims + 1), -1)]
            weights_reshaped_shape = [shape[i] if i in dims else 1 for i in range(rank)]
            epsilon_reshaped_shape = [1] * rank

            (
                mean_op,
                sub_op,
                var_op,
                full_op,
                add_op,
                rsqrt_op,
                mul_op,
                view_op,
            ) = get_layer_norm_decomposition(node.target)
            with graph.inserting_before(node):
                keepdim = True
                mean = create_node(graph, mean_op, args=(x, dims, keepdim))
                sub = create_node(graph, sub_op, args=(x, mean))
                # correction=0: layer norm uses the biased variance.
                var = create_node(
                    graph,
                    var_op,
                    args=(x, dims),
                    kwargs={"correction": 0, "keepdim": keepdim},
                )
                full = create_node(
                    graph,
                    full_op,
                    args=(epsilon_reshaped_shape, epsilon),
                    kwargs={"dtype": dtype},
                )
                add0 = create_node(graph, add_op, args=(var, full))
                rsqrt = create_node(graph, rsqrt_op, args=(add0,))
                mul0 = create_node(graph, mul_op, args=(sub, rsqrt))
                if weights is not None:
                    weights_reshaped = create_node(
                        graph,
                        view_op,
                        args=(weights, weights_reshaped_shape),
                    )
                    mul1 = create_node(graph, mul_op, args=(mul0, weights_reshaped))
                else:
                    mul1 = mul0
                output = mul1
                if bias is not None:
                    bias_reshaped = create_node(
                        graph, view_op, args=(bias, weights_reshaped_shape)
                    )
                    output = create_node(graph, add_op, args=(mul1, bias_reshaped))

            self._replace_layer_norm_uses(graph, node, output, mean, rsqrt)
            modified = True

        if modified:
            graph.eliminate_dead_code()
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
154        return PassResult(graph_module, True)
155