# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def get_var_decomposition(op) -> tuple:
    """Return the ops used to decompose ``op``, taken from the same dialect.

    Returns a 5-tuple ``(mean, sub, mul, sum, full)``. These are edge-dialect
    ops for ``exir_ops.edge.aten.var.correction`` and ``torch.ops.aten`` ops
    for ``aten.var.correction`` / ``aten.var.dim``.

    Raises:
        RuntimeError: if ``op`` is not one of the supported var overloads.
    """
    if op == exir_ops.edge.aten.var.correction:
        return (
            exir_ops.edge.aten.mean.dim,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.full.default,
        )
    if op in (torch.ops.aten.var.correction, torch.ops.aten.var.dim):
        return (
            torch.ops.aten.mean.dim,
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.sum.dim_IntList,
            torch.ops.aten.full,
        )
    raise RuntimeError(f"Can't get var decomposition for op {op}")


class DecomposeVarPass(ExportPass):
    """
    This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html)

    Only calls with keepdim=True are decomposed; all other calls are passed
    through unchanged.

    Example:
        y = var_correction(x, dim, keepdim, correction)
    Becomes:
        mean = mean(x, dim)
        diff = sub(x, mean)
        squared_diff = mul(diff, diff)
        sum = sum(squared_diff, dim)
        y = mul(sum, full(1 / max(0, N-correction)))
    where N is the number of elements reduced over. When the denominator is
    0, the scale factor is inf. torch would then give inf (or nan for an
    all-zero sum), and multiplying by inf produces the same result.
    """

    def call_operator(self, op, args, kwargs, meta):
        """Decompose a keepdim=True var node; forward every other node as-is."""
        if op not in (
            exir_ops.edge.aten.var.correction,
            torch.ops.aten.var.correction,
            torch.ops.aten.var.dim,
        ):
            return super().call_operator(op, args, kwargs, meta)

        shape = meta["val"].size()
        dtype = meta["val"].dtype

        if op == torch.ops.aten.var.dim:
            # var.dim(self, dim, unbiased=True, keepdim=False): trailing
            # arguments may be omitted or supplied as keyword arguments.
            unbiased = args[2] if len(args) > 2 else kwargs.get("unbiased", True)
            keepdim = args[3] if len(args) > 3 else kwargs.get("keepdim", False)
            correction = 1 if unbiased else 0
        else:
            # correction is keyword-only and optional; None means the default
            # Bessel correction of 1. Compare with None so correction=0 is kept.
            correction = kwargs.get("correction")
            if correction is None:
                correction = 1
            keepdim = kwargs.get("keepdim", False)
        if not keepdim:
            return super().call_operator(op, args, kwargs, meta)

        dim = args[1] if len(args) > 1 else kwargs.get("dim")
        if dim is None:
            # No dim (or an explicit None) reduces over every dimension. With
            # keepdim=True the output rank equals the input rank.
            dim = list(range(len(shape)))
        elif isinstance(dim, int):
            dim = [dim]

        x = args[0]
        input_shape = x.data.size()
        num_elements = 1
        for d in dim:
            num_elements *= input_shape[d]

        denominator = max(0, num_elements - correction)
        # Dividing by zero must not abort the export: torch yields inf/nan
        # here, and multiplying by inf reproduces that.
        scale = 1 / denominator if denominator > 0 else float("inf")

        mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
        mean = super().call_operator(mean_op, (x, dim, keepdim), {}, meta)
        diff = super().call_operator(diff_op, (x, mean), {}, meta)
        squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
        sum_sq = super().call_operator(
            sum_op, (squared_diff, dim, keepdim), {}, meta
        )
        # A [1, ..., 1] tensor holding the scale factor broadcasts against
        # the keepdim-reduced sum.
        full = super().call_operator(
            full_op,
            ([1 for _ in shape], scale),
            {"dtype": dtype},
            meta,
        )
        return super().call_operator(mul_op, (sum_sq, full), {}, meta)