1# Copyright 2024 Arm Limited and/or its affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-unsafe 8 9import operator 10 11import torch 12from executorch.backends.arm._passes.arm_pass_utils import create_node 13from executorch.exir.dialects._ops import ops as exir_ops 14from executorch.exir.pass_base import ExportPass, PassResult 15 16 17def get_layer_norm_decomposition(op) -> tuple: 18 if op == exir_ops.edge.aten.native_layer_norm.default: 19 return ( 20 exir_ops.edge.aten.mean.dim, 21 exir_ops.edge.aten.sub.Tensor, 22 exir_ops.edge.aten.var.correction, 23 exir_ops.edge.aten.full.default, 24 exir_ops.edge.aten.add.Tensor, 25 exir_ops.edge.aten.rsqrt.default, 26 exir_ops.edge.aten.mul.Tensor, 27 exir_ops.edge.aten.view_copy.default, 28 ) 29 if op == torch.ops.aten.layer_norm.default: 30 return ( 31 torch.ops.aten.mean.dim, 32 torch.ops.aten.sub.Tensor, 33 torch.ops.aten.var.correction, 34 torch.ops.aten.full.default, 35 torch.ops.aten.add.Tensor, 36 torch.ops.aten.rsqrt.default, 37 torch.ops.aten.mul.Tensor, 38 torch.ops.aten.view_copy.default, 39 ) 40 raise RuntimeError(f"Can't get layer_norm composition for op {op}") 41 42 43class DecomposeLayerNormPass(ExportPass): 44 """ 45 layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias 46 Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of: 47 mean = op_mean(x, dims) # E[x] 48 var = op_var(x, dims) # Var[x] 49 denominator = op_sub(x, mean) # (x - E[x]) 50 add = op_add(var, eps) # Var[x] + eps 51 rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps) 52 mul = op_mul(denominator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths 53 bias = op_add(mul, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias 54 55 Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html 56 """ 57 58 def call(self, graph_module: torch.fx.GraphModule): 59 for node in graph_module.graph.nodes: 60 if node.op != "call_function" or node.target not in ( 61 exir_ops.edge.aten.native_layer_norm.default, 62 torch.ops.aten.layer_norm.default, 63 ): 64 continue 65 66 # epsilon default value 67 epsilon = torch.finfo().eps 68 weights = None 69 bias = None 70 args = node.args 71 meta = node.meta 72 match len(args): 73 case 5: 74 x, normalized_shape, weights, bias, epsilon = args 75 case 4: 76 x, normalized_shape, weights, bias = args 77 case 3: 78 x, normalized_shape, weights = args 79 case _: 80 x, normalized_shape = args 81 82 n_dims = len(normalized_shape) 83 if isinstance(meta["val"], tuple): 84 shape = meta["val"][0].size() 85 else: 86 shape = meta["val"].size() 87 dtype = meta["val"][0].dtype 88 rank = len(shape) 89 dims = list(range(-1, -1 * (n_dims + 1), -1)) 90 dims = [dim % rank for dim in dims] 91 weights_reshaped_shape = [shape[i] if i in dims else 1 for i in range(rank)] 92 epsilon_reshaped_shape = [1] * rank 93 94 ( 95 mean_op, 96 sub_op, 97 var_op, 98 full_op, 99 add_op, 100 rsqrt_op, 101 mul_op, 102 view_op, 103 ) = get_layer_norm_decomposition(node.target) 104 with graph_module.graph.inserting_before(node): 105 keepdim = True 106 mean = create_node(graph_module.graph, mean_op, args=(x, dims, keepdim)) 107 sub = create_node(graph_module.graph, sub_op, args=(x, mean)) 108 var = create_node( 109 graph_module.graph, 110 var_op, 111 args=(x, dims), 112 kwargs={"correction": 0, "keepdim": keepdim}, 113 ) 114 full = create_node( 115 graph_module.graph, 116 full_op, 117 args=(epsilon_reshaped_shape, epsilon), 118 kwargs={"dtype": dtype}, 119 ) 120 add0 = create_node(graph_module.graph, add_op, args=(var, full)) 121 rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,)) 122 mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt)) 123 if weights is not None: 124 weights_reshaped = create_node( 125 graph_module.graph, 126 view_op, 127 args=(weights, weights_reshaped_shape), 128 ) 129 mul1 = create_node( 130 graph_module.graph, mul_op, args=(mul0, weights_reshaped) 131 ) 132 else: 133 mul1 = mul0 134 output = mul1 135 if bias is not None: 136 bias_reshaped_shape = weights_reshaped_shape 137 bias_reshaped = create_node( 138 graph_module.graph, view_op, args=(bias, bias_reshaped_shape) 139 ) 140 output = create_node( 141 graph_module.graph, add_op, args=(mul1, bias_reshaped) 142 ) 143 144 users = [user for user in node.users if node != user] 145 node.replace_all_uses_with(output) 146 for user in users: 147 if user.target == operator.getitem: 148 user.replace_all_uses_with(output) 149 graph_module.graph.erase_node(node) 150 graph_module.graph.eliminate_dead_code() 151 graph_module.recompile() 152 graph_module = super().call(graph_module).graph_module 153 154 return PassResult(graph_module, True) 155