# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import math

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def get_meandim_decomposition(op) -> tuple:
    """Return the (sum, full, mul) operators used to decompose ``op``.

    The returned operators come from the same dialect as ``op``: edge-dialect
    ops for ``exir_ops.edge.aten.mean.dim`` and ATen ops for
    ``torch.ops.aten.mean.dim``.

    Raises:
        RuntimeError: if ``op`` is neither of the two supported mean.dim ops.
    """
    if op == exir_ops.edge.aten.mean.dim:
        return (
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.mul.Tensor,
        )
    if op == torch.ops.aten.mean.dim:
        return (
            torch.ops.aten.sum.dim_IntList,
            torch.ops.aten.full.default,
            torch.ops.aten.mul.Tensor,
        )
    raise RuntimeError(f"Can't get meandim decomposition for op {op}")


class DecomposeMeanDimPass(ExportPass):
    """
    This pass decomposes meandim into a sum, a full and a mul node.

    Only the keepdim=True form is decomposed. When keepdim is False, or when
    dim == [-1, -2] (left for ConvertMeanDimToAveragePool), the mean.dim
    operator is passed through unchanged.

    Example:
        y = mean_dim(x, dim, keepdim)
    Becomes:
        sum = sum.dim_IntList(x, dim, keepdim)
        scale = full([1] * rank, 1/N)
        y = mul(sum, scale)
    where N is the number of input elements reduced over.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
            return super().call_operator(op, args, kwargs, meta)

        x = args[0]
        dim = args[1]
        # keepdim is normally positional. Also accept it as a keyword so that
        # an explicit keepdim=True is not silently treated as False.
        keepdim = args[2] if len(args) > 2 else kwargs.get("keepdim", False)
        if not keepdim:
            # Only the keepdim=True form is decomposed by this pass.
            return super().call_operator(op, args, kwargs, meta)
        # if keepdim == True and dim == [-1, -2], mean.dim can be
        # decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
        if dim == [-1, -2]:
            # Simply return the mean.dim operator for future decomposition.
            return super().call_operator(op, args, kwargs, meta)

        shape = meta["val"].size()
        dtype = meta["val"].dtype
        input_shape = x.data.size()
        # mean.dim / sum.dim_IntList treat a missing (None) or empty dim list
        # as "reduce over every dimension". Make that explicit so that N counts
        # the elements actually summed. Without this, N stayed 1 for dim=[],
        # yielding the total sum rather than the mean, and dim=None crashed.
        if dim is None or len(dim) == 0:
            dim = list(range(len(input_shape)))
        # Negative dims index torch.Size from the end, matching PyTorch's
        # dim semantics.
        n = math.prod(input_shape[d] for d in dim)

        sum_op, full_op, mul_op = get_meandim_decomposition(op)

        sum_node = super().call_operator(sum_op, (x, dim, keepdim), {}, meta)
        # Build a tensor of all-ones shape with the output's rank, holding 1/N.
        # It broadcasts against the keepdim sum in the multiply below.
        full_node = super().call_operator(
            full_op, ([1] * len(shape), 1 / n), {"dtype": dtype}, meta
        )
        return super().call_operator(mul_op, (sum_node, full_node), {}, meta)